source: orange/Orange/OrangeWidgets/Classify/OWClassificationTreeGraph.py @ 11629:80bbc6bff5c1

Revision 11629:80bbc6bff5c1, 19.0 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Fixed Classification Tree Graph widget so it does not require stored instances.

Line 
1"""<name>Classification Tree Graph</name>
2<description>Classification tree viewer (graph view).</description>
3<icon>icons/ClassificationTreeGraph.svg</icon>
4<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
5<priority>2110</priority>
6"""
7
8from OWTreeViewer2D import *
9import OWColorPalette
10
11import Orange
12
13class PieChart(QGraphicsRectItem):
14    def __init__(self, dist, r, parent, scene):
15        QGraphicsRectItem.__init__(self, parent, scene)
16        self.dist = dist
17        self.r = r
18       
19    def setR(self, r):
20        self.prepareGeometryChange()
21        self.r = r
22       
23    def boundingRect(self):
24        return QRectF(-self.r, -self.r, 2*self.r, 2*self.r)
25       
26    def paint(self, painter, option, widget = None):
27        distSum = sum(self.dist)
28        startAngle = 0
29        colors = self.scene().colorPalette
30        for i in range(len(self.dist)):
31            angle = self.dist[i]*16 * 360./distSum
32            if angle == 0: continue
33            painter.setBrush(QBrush(colors[i]))
34            painter.setPen(QPen(colors[i]))
35            painter.drawPie(-self.r, -self.r, 2*self.r, 2*self.r, int(startAngle), int(angle))
36            startAngle += angle
37        painter.setPen(QPen(Qt.black))
38        painter.setBrush(QBrush())
39        painter.drawEllipse(-self.r, -self.r, 2*self.r, 2*self.r)
40
41class ClassificationTreeNode(GraphicsNode):
42    def __init__(self, attr, tree, parent=None, parentItem=None, scene=None):
43        GraphicsNode.__init__(self, tree, parent, parentItem, scene)
44        self.attr = attr
45        self.pie = PieChart(self.tree.distribution, 20, self, scene)
46        self.majorityClass, self.majorityCount = max(self.tree.distribution.items(), key=lambda (key, val): val)
47        fm = QFontMetrics(self.document().defaultFont())
48        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
49        self.attr_text_h = fm.lineSpacing()
50        self.line_descent = fm.descent()
51
52    def updateContents(self):
53        self.prepareGeometryChange()
54        if getattr(self, "_rect", QRectF()).isValid() and not self.truncateText:
55            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
56        else:
57            self.setTextWidth(-1)
58            self.setTextWidth(self.document().idealWidth())
59        self.droplet.setPos(self.rect().center().x(), self.rect().height())
60        self.droplet.setVisible(bool(self.branches))
61        self.pie.setPos(self.rect().right(), self.rect().center().y())
62        fm = QFontMetrics(self.document().defaultFont())
63        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
64        self.attr_text_h = fm.lineSpacing()
65        self.line_descent = fm.descent()
66       
67    def rect(self):
68        if self.truncateText and getattr(self, "_rect", QRectF()).isValid():
69            return self._rect
70        else:
71            rect = QRectF(QPointF(0,0), self.document().size())
72            return rect.adjusted(0, 0, self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0, 0) | getattr(self, "_rect", QRectF(0,0,1,1))
73   
74    def setRect(self, rect):
75        self.prepareGeometryChange()
76        rect = QRectF() if rect is None else rect
77        self._rect = rect
78        if rect.isValid() and not self.truncateText:
79            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
80        else:
81            self.setTextWidth(-1)
82        self.updateContents()
83        self.update()
84       
85    def boundingRect(self):
86        if hasattr(self, "attr"):
87            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
88        else:
89            attr_rect = QRectF(0, 0, 1, 1)
90        rect = self.rect().adjusted(-5, -5, 5, 5)
91        if self.truncateText:
92            return rect | attr_rect
93        else:
94            return rect | GraphicsNode.boundingRect(self) | attr_rect
95
96    def rule(self):
97        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
98   
99    def paint(self, painter, option, widget=None):
100        if self.isSelected():
101            option.state = option.state.__xor__(QStyle.State_Selected)
102        if self.isSelected():
103            painter.save()
104            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
105            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
106            painter.restore()
107        painter.setFont(self.document().defaultFont())
108        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
109        painter.save()
110        painter.setBrush(self.backgroundBrush)
111        rect = self.rect()
112        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#self.borderRadius, self.borderRadius)
113        painter.restore()
114        if self.truncateText:
115#            self.setTextWidth(-1)`
116            painter.setClipRect(rect)
117        else:
118            painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
119        return QGraphicsTextItem.paint(self, painter, option, widget)
120#        TextTreeNode.paint(self, painter, option, widget)
121
122import re
123def parseRules(rules):
124    def joinCont(rule1, rule2):
125        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
126        rule=[rule1, rule2]
127        interval=[int1, int2]
128        for i in [0,1]:
129            if rule[i][1].startswith("in"):
130                r=rule[i][1][2:]
131                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
132            else:
133                if "<" in rule[i][1]:
134                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
135                    interval[i][2]=float(rule[i][1].strip("<>= "))
136                else:
137                    interval[i][0]=("=" in rule[i][1] and "[") or "("
138                    interval[i][1]=float(rule[i][1].strip("<>= "))
139
140        inter=[None]*4
141
142        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
143            interval.reverse()
144        inter[:2]=interval[0][:2]
145
146        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
147            interval.reverse()
148        inter[2:]=interval[0][2:]
149
150
151        if 1e1000 in inter or -1e1000 in inter:
152            rule=((-1e1000==inter[1] and "<") or ">")
153            rule+=(("[" in inter or "]" in inter) and "=") or ""
154            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
155        else:
156            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
157        return (rule1[0], rule)
158
159    def joinDisc(rule1, rule2):
160        r1,r2=rule1[1],rule2[1]
161        r1=re.sub("^in ","",r1)
162        r2=re.sub("^in ","",r2)
163        r1=r1.strip("[]=")
164        r2=r2.strip("[]=")
165        s1=set([s.strip(" ") for s in r1.split(",")])
166        s2=set([s.strip(" ") for s in r2.split(",")])
167        s=s1 & s2
168        if len(s)==1:
169            return (rule1[0], "= "+str(list(s)[0]))
170        else:
171            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
172
173    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
174    newRules=[rules[0]]
175    for r in rules[1:]:
176        if r[0].name==newRules[-1][0].name:
177            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
178                newRules[-1]=joinDisc(r,newRules[-1])
179            else:
180                newRules[-1]=joinCont(r,newRules[-1])
181        else:
182            newRules.append(r)
183    return newRules
184
185#BodyColor_Default = QColor(Qt.gray)
186BodyColor_Default = QColor(255, 225, 10)
187BodyCasesColor_Default = QColor(Qt.blue) #QColor(0, 0, 128)
188
189class OWClassificationTreeGraph(OWTreeViewer2D):
190    settingsList = OWTreeViewer2D.settingsList+['ShowPies', "colorSettings", "selectedColorSettingsIndex"]
191    contextHandlers = {"": DomainContextHandler("", ["TargetClassIndex"], matchValues=1)}
192   
193    nodeColorOpts = ['Default', 'Instances in node', 'Majority class probability', 'Target class probability', 'Target class distribution']
194    nodeInfoButtons = ['Majority class', 'Majority class probability', 'Target class probability', 'Number of instances']
195   
196    def __init__(self, parent=None, signalManager = None, name='ClassificationTreeViewer2D'):
197        self.ShowPies=1
198        self.TargetClassIndex=0
199        self.colorSettings = None
200        self.selectedColorSettingsIndex = 0
201        self.showNodeInfoText = False
202        self.NodeColorMethod = 2
203       
204        OWTreeViewer2D.__init__(self, parent, signalManager, name)
205
206        self.inputs = [("Classification Tree", Orange.classification.tree.TreeClassifier, self.ctree)]
207        self.outputs = [("Data", ExampleTable)]
208
209        self.scene=TreeGraphicsScene(self)
210        self.sceneView=TreeGraphicsView(self, self.scene)
211        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
212        self.mainArea.layout().addWidget(self.sceneView)
213        self.toggleZoomSlider()
214#        self.scene.setSceneRect(0,0,800,800)
215
216        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
217
218        self.navWidget= OWBaseWidget(self)
219        self.navWidget.lay=QVBoxLayout(self.navWidget)
220
221        scene=TreeGraphicsScene(self.navWidget)
222        self.treeNav = TreeNavigator(self.sceneView)
223        self.navWidget.lay.addWidget(self.treeNav)
224        self.navWidget.resize(400,400)
225        self.navWidget.setWindowTitle("Navigator")
226        self.setMouseTracking(True)
227       
228        colorbox = OWGUI.widgetBox(self.NodeTab, "Node Color", addSpace=True)
229       
230        OWGUI.comboBox(colorbox, self, 'NodeColorMethod', items=self.nodeColorOpts,
231                                callback=self.toggleNodeColor)
232        self.targetCombo=OWGUI.comboBox(colorbox,self, "TargetClassIndex", orientation=0, items=[],label="Target class",callback=self.toggleTargetClass)
233
234        OWGUI.checkBox(colorbox, self, 'ShowPies', 'Show distribution pie charts', tooltip='Show pie graph with class distribution?', callback=self.togglePies)
235        OWGUI.separator(colorbox)
236        OWGUI.button(colorbox, self, "Set Colors", callback=self.setColors, debuggingEnabled = 0)
237
238        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info")
239        nodeInfoSettings = ['maj', 'majp', 'tarp', 'inst']
240        self.NodeInfoW = []; self.dummy = 0
241        for i in range(len(self.nodeInfoButtons)):
242            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
243            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
244                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
245            self.NodeInfoW.append(w)
246
247#        OWGUI.button(self.controlArea, self, "Save as", callback=self.saveGraph, debuggingEnabled = 0)
248        self.NodeInfoSorted=list(self.NodeInfo)
249        self.NodeInfoSorted.sort()
250       
251        dlg = self.createColorDialog()
252        self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
253
254        OWGUI.rubber(self.NodeTab)
255
256    def sendReport(self):
257        if self.tree:
258            tclass = str(self.targetCombo.currentText())
259            tsize = "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree))
260        else:
261            tclass = "N/A"
262            tsize = "N/A"
263           
264        self.reportSettings("Information",
265                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
266                             ("Target class", tclass),
267                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
268                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
269                             ("Tree size", tsize) ])
270        OWTreeViewer2D.sendReport(self)
271
272    def setColors(self):
273        dlg = self.createColorDialog()
274        if dlg.exec_():
275            self.colorSettings = dlg.getColorSchemas()
276            self.selectedColorSettingsIndex = dlg.selectedSchemaIndex
277            self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
278            self.scene.update()
279
280    def createColorDialog(self):
281        c = OWColorPalette.ColorPaletteDlg(self, "Color Palette")
282        c.createDiscretePalette("colorPalette", "Discrete Palette")
283        c.setColorSchemas(self.colorSettings, self.selectedColorSettingsIndex)
284        return c
285
286    def setNodeInfo(self, widget=None, id=None):
287        flags = sum(2**i for i, name in enumerate(['maj',
288                        'majp', 'tarp', 'inst']) if getattr(self, name))
289           
290        for n in self.scene.nodes():
291            n.setRect(QRectF())
292            self.updateNodeInfo(n, flags)
293        if True:
294            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
295            for n in self.scene.nodes():
296                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
297        self.scene.fixPos(self.rootNode, 10, 10)
298       
299    def updateNodeInfo(self, node, flags=31):
300        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
301        text = ""
302       
303#        text += "%s<br>" % fix(node.attr if node.attr else "")
304           
305        lines = []
306        if flags & 1:
307            start = "Majority class: " if self.showNodeInfoText else "" 
308#            lines += [start + "<font color=%s>%s</font>" % (self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].name(), fix(node.majorityClass))]
309            lines += [start + fix(node.majorityClass)]
310        if flags & 2:
311            start = "Majority class probability: " if self.showNodeInfoText else "" 
312            lines += [start + "%.1f" % (100.0 * float(node.majorityCount) / node.tree.distribution.abs)]
313        if flags & 4:
314            start = "Target class probability: "  if self.showNodeInfoText else "" 
315            lines += [start + "%.1f" % (100.0 * float(node.tree.distribution[self.TargetClassIndex]) / node.tree.distribution.abs)]
316        if flags & 8:
317            start = "Instances: " if self.showNodeInfoText else "" 
318            lines += [start + "%i" % node.tree.distribution.cases]
319        text += "<br>".join(lines)
320        if node.tree.branchSelector:
321            text += "<hr>" + "%s" % fix(node.tree.branchSelector.classVar.name)
322        else:
323            text += "<hr>" + fix(node.majorityClass)
324        node.setHtml(text)
325
326    def activateLoadedSettings(self):
327        if not self.tree:
328            return
329        OWTreeViewer2D.activateLoadedSettings(self)
330        self.setNodeInfo()
331        self.toggleNodeColor()
332       
333    def toggleNodeSize(self):
334        self.setNodeInfo()
335        self.scene.update()
336        self.sceneView.repaint()
337
338    def toggleNodeColor(self):
339        palette = self.scene.colorPalette
340        for node in self.scene.nodes():
341            dist = node.tree.distribution
342            if self.NodeColorMethod == 0:
343                # default color
344                color = BodyColor_Default
345            elif self.NodeColorMethod == 1:
346                # number of instances in node
347                all_cases = self.tree.distribution.cases
348                light = 200 - 100 * dist.cases / (all_cases or 1)
349                color = BodyCasesColor_Default.light(light)
350            elif self.NodeColorMethod == 2:
351                # majority class probability
352                modus = dist.modus()
353                p = dist[modus] / (dist.abs or 1)
354                light = 400 - 300 * p
355                color = palette[int(modus)].light(light)
356            elif self.NodeColorMethod == 3:
357                # target class probability
358                p = dist[self.TargetClassIndex] / (dist.cases or 1)
359                light = 200 - 100 * p
360                color = palette[self.TargetClassIndex].light(light)
361            elif self.NodeColorMethod == 4:
362                # target class distribution
363                all_target = self.tree.distribution[self.TargetClassIndex] or 1
364                light = 200 - 100 * dist[self.TargetClassIndex] / all_target
365                color = palette[self.TargetClassIndex].light(light)
366            node.backgroundBrush = QBrush(color)
367        self.scene.update()
368
369    def toggleTargetClass(self):
370        if self.NodeColorMethod in [3,4]:
371            self.toggleNodeColor()
372        if self.tarp:
373            self.setNodeInfo()
374        self.scene.update()
375
376    def togglePies(self):
377        for n in self.scene.nodes():
378            n.pie.setVisible(self.ShowPies and n.isVisible())
379        self.scene.update()
380
381    def ctree(self, classifier=None):
382        """
383        Set the input TreeClassifier.
384        """
385        self.send("Data", None)
386        self.closeContext()
387        self.targetCombo.clear()
388        self.classifier = classifier
389        if classifier:
390            for name in classifier.domain.classVar.values:
391                self.targetCombo.addItem(name)
392            self.TargetClassIndex = 0
393            self.openContext("", classifier.domain)
394        else:
395            self.openContext("", None)
396        OWTreeViewer2D.ctree(self, classifier)
397        self.togglePies()
398
399    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
400        node=ClassificationTreeNode(attrVal, tree, parent, None, self.scene)
401        if parent:
402            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
403        if tree.branches:
404            for i in range(len(tree.branches)):
405                if tree.branches[i]:
406                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
407        return node
408   
409    def nodeToolTip(self, node):
410        rule = list(node.rule())
411        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
412        if rule:
413            try:
414                rule=parseRules(list(rule))
415            except:
416                pass
417            text="<b>IF</b> "+" <b>AND</b><br>  ".join([fix(a[0].name+" = "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(node.majorityClass) + "<hr>"
418        else:
419            text="<b>THEN</b> "+fix(node.majorityClass) + "<hr>"
420        text += "Instances: %(ninst)i (%(prop).1f%%)<hr>" % {"ninst": node.tree.distribution.cases, "prop": float(node.tree.distribution.cases)/self.tree.distribution.cases*100}
421       
422        text += "<br>".join(["<font color=%(color)s>%(name)s: %(num)i (%(ratio).1f% %)</font>" % \
423                             {"name":fix(d[0]), "num":int(d[1]), "ratio":d[1]/sum(node.tree.distribution)*100, "color":self.scene.colorPalette[i].name()}\
424                             for i,d in enumerate(node.tree.distribution.items()) if d[1]!=0])
425        text += "<hr>Partition on: %(nodename)s" % {"nodename": node.tree.branchSelector.classVar.name} if node.tree.branches else ""
426        return text
427
428if __name__=="__main__":
429    a = QApplication(sys.argv)
430    ow = OWClassificationTreeGraph()
431##    a.setMainWidget(ow)
432
433    #data = orange.ExampleTable('../../doc/datasets/voting.tab')
434    data = orange.ExampleTable(r"../../doc/datasets/zoo.tab")
435    tree = orange.TreeLearner(data, storeExamples = 1)
436    ow.ctree(tree)
437
438    # here you can test setting some stuff
439    ow.show()
440    a.exec_()
441    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.