source: orange/Orange/OrangeWidgets/Classify/OWClassificationTreeGraph.py @ 11096:cf7d2ae9d22b

Revision 11096:cf7d2ae9d22b, 19.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 19 months ago (diff)

Added new svg icons for the widgets/categories.

Line 
1"""<name>Classification Tree Graph</name>
2<description>Classification tree viewer (graph view).</description>
3<icon>icons/ClassificationTreeGraph.svg</icon>
4<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
5<priority>2110</priority>
6"""
7
8from OWTreeViewer2D import *
9import OWColorPalette
10
11import Orange
12
13class PieChart(QGraphicsRectItem):
14    def __init__(self, dist, r, parent, scene):
15        QGraphicsRectItem.__init__(self, parent, scene)
16        self.dist = dist
17        self.r = r
18       
19    def setR(self, r):
20        self.r = r
21       
22    def boundingRect(self):
23        return QRectF(-self.r, -self.r, 2*self.r, 2*self.r)
24       
25    def paint(self, painter, option, widget = None):
26        distSum = sum(self.dist)
27        startAngle = 0
28        colors = self.scene().colorPalette
29        for i in range(len(self.dist)):
30            angle = self.dist[i]*16 * 360./distSum
31            if angle == 0: continue
32            painter.setBrush(QBrush(colors[i]))
33            painter.setPen(QPen(colors[i]))
34            painter.drawPie(-self.r, -self.r, 2*self.r, 2*self.r, int(startAngle), int(angle))
35            startAngle += angle
36        painter.setPen(QPen(Qt.black))
37        painter.setBrush(QBrush())
38        painter.drawEllipse(-self.r, -self.r, 2*self.r, 2*self.r)
39
40class ClassificationTreeNode(GraphicsNode):
41    def __init__(self, attr, tree, parent=None, parentItem=None, scene=None):
42        GraphicsNode.__init__(self, tree, parent, parentItem, scene)
43        self.attr = attr
44        self.pie = PieChart(self.tree.distribution, 20, self, scene)
45        self.majorityClass, self.majorityCount = max(self.tree.distribution.items(), key=lambda (key, val): val)
46        fm = QFontMetrics(self.document().defaultFont())
47        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
48        self.attr_text_h = fm.lineSpacing()
49        self.line_descent = fm.descent()
50
51    def updateContents(self):
52        self.prepareGeometryChange()
53        if getattr(self, "_rect", QRectF()).isValid() and not self.truncateText:
54            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
55        else:
56            self.setTextWidth(-1)
57            self.setTextWidth(self.document().idealWidth())
58        self.droplet.setPos(self.rect().center().x(), self.rect().height())
59        self.droplet.setVisible(bool(self.branches))
60        self.pie.setPos(self.rect().right(), self.rect().center().y())
61        fm = QFontMetrics(self.document().defaultFont())
62        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
63        self.attr_text_h = fm.lineSpacing()
64        self.line_descent = fm.descent()
65       
66    def rect(self):
67        if self.truncateText and getattr(self, "_rect", QRectF()).isValid():
68            return self._rect
69        else:
70            rect = QRectF(QPointF(0,0), self.document().size())
71            return rect.adjusted(0, 0, self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0, 0) | getattr(self, "_rect", QRectF(0,0,1,1))
72   
73    def setRect(self, rect):
74        self.prepareGeometryChange()
75        rect = QRectF() if rect is None else rect
76        self._rect = rect
77        if rect.isValid() and not self.truncateText:
78            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
79        else:
80            self.setTextWidth(-1)
81        self.updateContents()
82        self.update()
83       
84    def boundingRect(self):
85        if hasattr(self, "attr"):
86            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
87        else:
88            attr_rect = QRectF(0, 0, 1, 1)
89        rect = self.rect().adjusted(-5, -5, 5, 5)
90        if self.truncateText:
91            return rect | attr_rect
92        else:
93            return rect | GraphicsNode.boundingRect(self) | attr_rect
94
95    def rule(self):
96        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
97   
98    def paint(self, painter, option, widget=None):
99        if self.isSelected():
100            option.state = option.state.__xor__(QStyle.State_Selected)
101        if self.isSelected():
102            painter.save()
103            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
104            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
105            painter.restore()
106        painter.setFont(self.document().defaultFont())
107        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
108        painter.save()
109        painter.setBrush(self.backgroundBrush)
110        rect = self.rect()
111        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#self.borderRadius, self.borderRadius)
112        painter.restore()
113        if self.truncateText:
114#            self.setTextWidth(-1)`
115            painter.setClipRect(rect)
116        else:
117            painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
118        return QGraphicsTextItem.paint(self, painter, option, widget)
119#        TextTreeNode.paint(self, painter, option, widget)
120
121import re
122def parseRules(rules):
123    def joinCont(rule1, rule2):
124        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
125        rule=[rule1, rule2]
126        interval=[int1, int2]
127        for i in [0,1]:
128            if rule[i][1].startswith("in"):
129                r=rule[i][1][2:]
130                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
131            else:
132                if "<" in rule[i][1]:
133                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
134                    interval[i][2]=float(rule[i][1].strip("<>= "))
135                else:
136                    interval[i][0]=("=" in rule[i][1] and "[") or "("
137                    interval[i][1]=float(rule[i][1].strip("<>= "))
138
139        inter=[None]*4
140
141        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
142            interval.reverse()
143        inter[:2]=interval[0][:2]
144
145        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
146            interval.reverse()
147        inter[2:]=interval[0][2:]
148
149
150        if 1e1000 in inter or -1e1000 in inter:
151            rule=((-1e1000==inter[1] and "<") or ">")
152            rule+=(("[" in inter or "]" in inter) and "=") or ""
153            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
154        else:
155            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
156        return (rule1[0], rule)
157
158    def joinDisc(rule1, rule2):
159        r1,r2=rule1[1],rule2[1]
160        r1=re.sub("^in ","",r1)
161        r2=re.sub("^in ","",r2)
162        r1=r1.strip("[]=")
163        r2=r2.strip("[]=")
164        s1=set([s.strip(" ") for s in r1.split(",")])
165        s2=set([s.strip(" ") for s in r2.split(",")])
166        s=s1 & s2
167        if len(s)==1:
168            return (rule1[0], "= "+str(list(s)[0]))
169        else:
170            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
171
172    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
173    newRules=[rules[0]]
174    for r in rules[1:]:
175        if r[0].name==newRules[-1][0].name:
176            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
177                newRules[-1]=joinDisc(r,newRules[-1])
178            else:
179                newRules[-1]=joinCont(r,newRules[-1])
180        else:
181            newRules.append(r)
182    return newRules
183
184#BodyColor_Default = QColor(Qt.gray)
185BodyColor_Default = QColor(255, 225, 10)
186BodyCasesColor_Default = QColor(Qt.blue) #QColor(0, 0, 128)
187
188class OWClassificationTreeGraph(OWTreeViewer2D):
189    settingsList = OWTreeViewer2D.settingsList+['ShowPies', "colorSettings", "selectedColorSettingsIndex"]
190    contextHandlers = {"": DomainContextHandler("", ["TargetClassIndex"], matchValues=1)}
191   
192    nodeColorOpts = ['Default', 'Instances in node', 'Majority class probability', 'Target class probability', 'Target class distribution']
193    nodeInfoButtons = ['Majority class', 'Majority class probability', 'Target class probability', 'Number of instances']
194   
195    def __init__(self, parent=None, signalManager = None, name='ClassificationTreeViewer2D'):
196        self.ShowPies=1
197        self.TargetClassIndex=0
198        self.colorSettings = None
199        self.selectedColorSettingsIndex = 0
200        self.showNodeInfoText = False
201        self.NodeColorMethod = 2
202       
203        OWTreeViewer2D.__init__(self, parent, signalManager, name)
204
205        self.inputs = [("Classification Tree", Orange.classification.tree.TreeClassifier, self.ctree)]
206        self.outputs = [("Data", ExampleTable)]
207
208        self.scene=TreeGraphicsScene(self)
209        self.sceneView=TreeGraphicsView(self, self.scene)
210        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
211        self.mainArea.layout().addWidget(self.sceneView)
212        self.toggleZoomSlider()
213#        self.scene.setSceneRect(0,0,800,800)
214
215        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
216
217        self.navWidget= OWBaseWidget(self)
218        self.navWidget.lay=QVBoxLayout(self.navWidget)
219
220        scene=TreeGraphicsScene(self.navWidget)
221        self.treeNav = TreeNavigator(self.sceneView)
222        self.navWidget.lay.addWidget(self.treeNav)
223        self.navWidget.resize(400,400)
224        self.navWidget.setWindowTitle("Navigator")
225        self.setMouseTracking(True)
226       
227        colorbox = OWGUI.widgetBox(self.NodeTab, "Node Color", addSpace=True)
228       
229        OWGUI.comboBox(colorbox, self, 'NodeColorMethod', items=self.nodeColorOpts,
230                                callback=self.toggleNodeColor)
231        self.targetCombo=OWGUI.comboBox(colorbox,self, "TargetClassIndex", orientation=0, items=[],label="Target class",callback=self.toggleTargetClass)
232
233        OWGUI.checkBox(colorbox, self, 'ShowPies', 'Show distribution pie charts', tooltip='Show pie graph with class distribution?', callback=self.togglePies)
234        OWGUI.separator(colorbox)
235        OWGUI.button(colorbox, self, "Set Colors", callback=self.setColors, debuggingEnabled = 0)
236
237        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info")
238        nodeInfoSettings = ['maj', 'majp', 'tarp', 'inst']
239        self.NodeInfoW = []; self.dummy = 0
240        for i in range(len(self.nodeInfoButtons)):
241            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
242            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
243                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
244            self.NodeInfoW.append(w)
245
246#        OWGUI.button(self.controlArea, self, "Save as", callback=self.saveGraph, debuggingEnabled = 0)
247        self.NodeInfoSorted=list(self.NodeInfo)
248        self.NodeInfoSorted.sort()
249       
250        dlg = self.createColorDialog()
251        self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
252
253        OWGUI.rubber(self.NodeTab)
254       
255
256    def sendReport(self):
257        if self.tree:
258            tclass = self.tree.examples.domain.classVar.values[self.TargetClassIndex]
259            tsize = "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree))
260        else:
261            tclass = "N/A"
262            tsize = "N/A"
263           
264        self.reportSettings("Information",
265                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
266                             ("Target class", tclass),
267                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
268                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
269                             ("Tree size", tsize) ])
270        OWTreeViewer2D.sendReport(self)
271
272    def setColors(self):
273        dlg = self.createColorDialog()
274        if dlg.exec_():
275            self.colorSettings = dlg.getColorSchemas()
276            self.selectedColorSettingsIndex = dlg.selectedSchemaIndex
277            self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
278            self.scene.update()
279
280    def createColorDialog(self):
281        c = OWColorPalette.ColorPaletteDlg(self, "Color Palette")
282        c.createDiscretePalette("colorPalette", "Discrete Palette")
283        c.setColorSchemas(self.colorSettings, self.selectedColorSettingsIndex)
284        return c
285
286    def setNodeInfo(self, widget=None, id=None):
287        flags = sum(2**i for i, name in enumerate(['maj',
288                        'majp', 'tarp', 'inst']) if getattr(self, name))
289           
290        for n in self.scene.nodes():
291            n.setRect(QRectF())
292            self.updateNodeInfo(n, flags)
293        if True:
294            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
295            for n in self.scene.nodes():
296                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
297        self.scene.fixPos(self.rootNode, 10, 10)
298       
299    def updateNodeInfo(self, node, flags=31):
300        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
301        text = ""
302       
303#        text += "%s<br>" % fix(node.attr if node.attr else "")
304           
305        lines = []
306        if flags & 1:
307            start = "Majority class: " if self.showNodeInfoText else "" 
308#            lines += [start + "<font color=%s>%s</font>" % (self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].name(), fix(node.majorityClass))]
309            lines += [start + fix(node.majorityClass)]
310        if flags & 2:
311            start = "Majority class probability: " if self.showNodeInfoText else "" 
312            lines += [start + "%.1f" % (100.0 * float(node.majorityCount) / node.tree.distribution.abs)]
313        if flags & 4:
314            start = "Target class probability: "  if self.showNodeInfoText else "" 
315            lines += [start + "%.1f" % (100.0 * float(node.tree.distribution[self.TargetClassIndex]) / node.tree.distribution.abs)]
316        if flags & 8:
317            start = "Instances: " if self.showNodeInfoText else "" 
318            lines += [start + "%i" % node.tree.distribution.cases]
319        text += "<br>".join(lines)
320        if node.tree.branchSelector:
321            text += "<hr>" + "%s" % fix(node.tree.branchSelector.classVar.name)
322        else:
323            text += "<hr>" + fix(node.majorityClass)
324        node.setHtml(text)
325
326    def activateLoadedSettings(self):
327        if not self.tree:
328            return
329        OWTreeViewer2D.activateLoadedSettings(self)
330        self.setNodeInfo()
331        self.toggleNodeColor()
332       
333    def toggleNodeSize(self):
334        self.setNodeInfo()
335        self.scene.update()
336        self.sceneView.repaint()
337       
338    def toggleNodeColor(self):
339        for node in self.scene.nodes():
340            if self.NodeColorMethod == 0:   # default
341                color = BodyColor_Default
342            elif self.NodeColorMethod == 1: # instances in node
343                div = self.tree.distribution.cases
344                if div > 1e-6:
345                    light = 400 - 300*node.tree.distribution.cases/div
346                else:
347                    light = 100
348                color = BodyCasesColor_Default.light(light)
349            elif self.NodeColorMethod == 2: # majority class probability
350                light=400- 300*float(node.majorityCount) / node.tree.distribution.abs
351                color = self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].light(light)
352            elif self.NodeColorMethod == 3: # target class probability
353                div = node.tree.distribution.cases
354                if div > 1e-6:
355                    light=400-300*node.tree.distribution[self.TargetClassIndex]/div
356                else:
357                    light = 100
358                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
359            elif self.NodeColorMethod == 4: # target class distribution
360                div = self.tree.distribution[self.TargetClassIndex]
361                if div > 1e-6:
362                    light=200 - 100*node.tree.distribution[self.TargetClassIndex]/div
363                else:
364                    light = 100
365                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
366#            gradient = QLinearGradient(0, 0, 0, 100)
367#                gradient.setStops([(0, QColor(Qt.gray).lighter(120)), (1, QColor(Qt.lightGray).lighter())])
368#            gradient.setStops([(0, color), (1, color.lighter())])
369#            node.backgroundBrush = QBrush(gradient)
370            node.backgroundBrush = QBrush(color)
371        self.scene.update()
372
373    def toggleTargetClass(self):
374        if self.NodeColorMethod in [3,4]:
375            self.toggleNodeColor()
376        if self.tarp:
377            self.setNodeInfo()
378        self.scene.update()
379
380    def togglePies(self):
381        for n in self.scene.nodes():
382            n.pie.setVisible(self.ShowPies and n.isVisible())
383        self.scene.update()
384
385    def ctree(self, tree=None):
386        self.send("Data", None)
387        self.closeContext()
388        self.targetCombo.clear()
389        if tree:
390            for name in tree.tree.examples.domain.classVar.values:
391                self.targetCombo.addItem(name)
392            self.TargetClassIndex=0
393            self.openContext("", tree.domain)
394        else:
395            self.openContext("", None)
396        OWTreeViewer2D.ctree(self, tree)
397        self.togglePies()
398
399    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
400        node=ClassificationTreeNode(attrVal, tree, parent, None, self.scene)
401        if parent:
402            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
403        if tree.branches:
404            for i in range(len(tree.branches)):
405                if tree.branches[i]:
406                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
407        return node
408   
409    def nodeToolTip(self, node):
410        rule = list(node.rule())
411        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
412        if rule:
413            try:
414                rule=parseRules(list(rule))
415            except:
416                pass
417            text="<b>IF</b> "+" <b>AND</b><br>  ".join([fix(a[0].name+" = "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(node.majorityClass) + "<hr>"
418        else:
419            text="<b>THEN</b> "+fix(node.majorityClass) + "<hr>"
420        text += "Instances: %(ninst)i (%(prop).1f%%)<hr>" % {"ninst": node.tree.distribution.cases, "prop": float(node.tree.distribution.cases)/self.tree.distribution.cases*100}
421       
422        text += "<br>".join(["<font color=%(color)s>%(name)s: %(num)i (%(ratio).1f% %)</font>" % \
423                             {"name":fix(d[0]), "num":int(d[1]), "ratio":d[1]/sum(node.tree.distribution)*100, "color":self.scene.colorPalette[i].name()}\
424                             for i,d in enumerate(node.tree.distribution.items()) if d[1]!=0])
425        text += "<hr>Partition on: %(nodename)s" % {"nodename": node.tree.branchSelector.classVar.name} if node.tree.branches else ""
426        return text
427
428if __name__=="__main__":
429    a = QApplication(sys.argv)
430    ow = OWClassificationTreeGraph()
431##    a.setMainWidget(ow)
432
433    #data = orange.ExampleTable('../../doc/datasets/voting.tab')
434    data = orange.ExampleTable(r"../../doc/datasets/zoo.tab")
435    tree = orange.TreeLearner(data, storeExamples = 1)
436    ow.ctree(tree)
437
438    # here you can test setting some stuff
439    ow.show()
440    a.exec_()
441    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.