source: orange/Orange/OrangeWidgets/Classify/OWClassificationTreeGraph.py @ 11038:51e701fc9845

Revision 11038:51e701fc9845, 19.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 17 months ago (diff)

Change the node text color, if the background is to dark.

Fixes #1240

Line 
1"""<name>Classification Tree Graph</name>
2<description>Classification tree viewer (graph view).</description>
3<icon>icons/ClassificationTreeGraph.png</icon>
4<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
5<priority>2110</priority>
6"""
7
8from OWTreeViewer2D import *
9import OWColorPalette
10
11import Orange
12
13class PieChart(QGraphicsRectItem):
14    def __init__(self, dist, r, parent, scene):
15        QGraphicsRectItem.__init__(self, parent, scene)
16        self.dist = dist
17        self.r = r
18       
19    def setR(self, r):
20        self.prepareGeometryChange()
21        self.r = r
22       
23    def boundingRect(self):
24        return QRectF(-self.r, -self.r, 2*self.r, 2*self.r)
25       
26    def paint(self, painter, option, widget = None):
27        distSum = sum(self.dist)
28        startAngle = 0
29        colors = self.scene().colorPalette
30        for i in range(len(self.dist)):
31            angle = self.dist[i]*16 * 360./distSum
32            if angle == 0: continue
33            painter.setBrush(QBrush(colors[i]))
34            painter.setPen(QPen(colors[i]))
35            painter.drawPie(-self.r, -self.r, 2*self.r, 2*self.r, int(startAngle), int(angle))
36            startAngle += angle
37        painter.setPen(QPen(Qt.black))
38        painter.setBrush(QBrush())
39        painter.drawEllipse(-self.r, -self.r, 2*self.r, 2*self.r)
40
41class ClassificationTreeNode(GraphicsNode):
42    def __init__(self, attr, tree, parent=None, parentItem=None, scene=None):
43        GraphicsNode.__init__(self, tree, parent, parentItem, scene)
44        self.attr = attr
45        self.pie = PieChart(self.tree.distribution, 20, self, scene)
46        self.majorityClass, self.majorityCount = max(self.tree.distribution.items(), key=lambda (key, val): val)
47        fm = QFontMetrics(self.document().defaultFont())
48        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
49        self.attr_text_h = fm.lineSpacing()
50        self.line_descent = fm.descent()
51
52    def updateContents(self):
53        self.prepareGeometryChange()
54        if getattr(self, "_rect", QRectF()).isValid() and not self.truncateText:
55            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
56        else:
57            self.setTextWidth(-1)
58            self.setTextWidth(self.document().idealWidth())
59        self.droplet.setPos(self.rect().center().x(), self.rect().height())
60        self.droplet.setVisible(bool(self.branches))
61        self.pie.setPos(self.rect().right(), self.rect().center().y())
62        fm = QFontMetrics(self.document().defaultFont())
63        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
64        self.attr_text_h = fm.lineSpacing()
65        self.line_descent = fm.descent()
66       
67    def rect(self):
68        if self.truncateText and getattr(self, "_rect", QRectF()).isValid():
69            return self._rect
70        else:
71            rect = QRectF(QPointF(0,0), self.document().size())
72            return rect.adjusted(0, 0, self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0, 0) | getattr(self, "_rect", QRectF(0,0,1,1))
73   
74    def setRect(self, rect):
75        self.prepareGeometryChange()
76        rect = QRectF() if rect is None else rect
77        self._rect = rect
78        if rect.isValid() and not self.truncateText:
79            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
80        else:
81            self.setTextWidth(-1)
82        self.updateContents()
83        self.update()
84       
85    def boundingRect(self):
86        if hasattr(self, "attr"):
87            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
88        else:
89            attr_rect = QRectF(0, 0, 1, 1)
90        rect = self.rect().adjusted(-5, -5, 5, 5)
91        if self.truncateText:
92            return rect | attr_rect
93        else:
94            return rect | GraphicsNode.boundingRect(self) | attr_rect
95
96    def rule(self):
97        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
98   
99    def paint(self, painter, option, widget=None):
100        if self.isSelected():
101            option.state = option.state.__xor__(QStyle.State_Selected)
102        if self.isSelected():
103            painter.save()
104            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
105            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
106            painter.restore()
107        painter.setFont(self.document().defaultFont())
108        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
109        painter.save()
110        painter.setBrush(self.backgroundBrush)
111        rect = self.rect()
112        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#self.borderRadius, self.borderRadius)
113        painter.restore()
114        if self.truncateText:
115#            self.setTextWidth(-1)`
116            painter.setClipRect(rect)
117        else:
118            painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
119        return QGraphicsTextItem.paint(self, painter, option, widget)
120#        TextTreeNode.paint(self, painter, option, widget)
121
122import re
123def parseRules(rules):
124    def joinCont(rule1, rule2):
125        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
126        rule=[rule1, rule2]
127        interval=[int1, int2]
128        for i in [0,1]:
129            if rule[i][1].startswith("in"):
130                r=rule[i][1][2:]
131                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
132            else:
133                if "<" in rule[i][1]:
134                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
135                    interval[i][2]=float(rule[i][1].strip("<>= "))
136                else:
137                    interval[i][0]=("=" in rule[i][1] and "[") or "("
138                    interval[i][1]=float(rule[i][1].strip("<>= "))
139
140        inter=[None]*4
141
142        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
143            interval.reverse()
144        inter[:2]=interval[0][:2]
145
146        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
147            interval.reverse()
148        inter[2:]=interval[0][2:]
149
150
151        if 1e1000 in inter or -1e1000 in inter:
152            rule=((-1e1000==inter[1] and "<") or ">")
153            rule+=(("[" in inter or "]" in inter) and "=") or ""
154            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
155        else:
156            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
157        return (rule1[0], rule)
158
159    def joinDisc(rule1, rule2):
160        r1,r2=rule1[1],rule2[1]
161        r1=re.sub("^in ","",r1)
162        r2=re.sub("^in ","",r2)
163        r1=r1.strip("[]=")
164        r2=r2.strip("[]=")
165        s1=set([s.strip(" ") for s in r1.split(",")])
166        s2=set([s.strip(" ") for s in r2.split(",")])
167        s=s1 & s2
168        if len(s)==1:
169            return (rule1[0], "= "+str(list(s)[0]))
170        else:
171            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
172
173    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
174    newRules=[rules[0]]
175    for r in rules[1:]:
176        if r[0].name==newRules[-1][0].name:
177            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
178                newRules[-1]=joinDisc(r,newRules[-1])
179            else:
180                newRules[-1]=joinCont(r,newRules[-1])
181        else:
182            newRules.append(r)
183    return newRules
184
185#BodyColor_Default = QColor(Qt.gray)
186BodyColor_Default = QColor(255, 225, 10)
187BodyCasesColor_Default = QColor(Qt.blue) #QColor(0, 0, 128)
188
189class OWClassificationTreeGraph(OWTreeViewer2D):
190    settingsList = OWTreeViewer2D.settingsList+['ShowPies', "colorSettings", "selectedColorSettingsIndex"]
191    contextHandlers = {"": DomainContextHandler("", ["TargetClassIndex"], matchValues=1)}
192   
193    nodeColorOpts = ['Default', 'Instances in node', 'Majority class probability', 'Target class probability', 'Target class distribution']
194    nodeInfoButtons = ['Majority class', 'Majority class probability', 'Target class probability', 'Number of instances']
195   
196    def __init__(self, parent=None, signalManager = None, name='ClassificationTreeViewer2D'):
197        self.ShowPies=1
198        self.TargetClassIndex=0
199        self.colorSettings = None
200        self.selectedColorSettingsIndex = 0
201        self.showNodeInfoText = False
202        self.NodeColorMethod = 2
203       
204        OWTreeViewer2D.__init__(self, parent, signalManager, name)
205
206        self.inputs = [("Classification Tree", Orange.classification.tree.TreeClassifier, self.ctree)]
207        self.outputs = [("Data", ExampleTable)]
208
209        self.scene=TreeGraphicsScene(self)
210        self.sceneView=TreeGraphicsView(self, self.scene)
211        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
212        self.mainArea.layout().addWidget(self.sceneView)
213        self.toggleZoomSlider()
214#        self.scene.setSceneRect(0,0,800,800)
215
216        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
217
218        self.navWidget= OWBaseWidget(self)
219        self.navWidget.lay=QVBoxLayout(self.navWidget)
220
221        scene=TreeGraphicsScene(self.navWidget)
222        self.treeNav = TreeNavigator(self.sceneView)
223        self.navWidget.lay.addWidget(self.treeNav)
224        self.navWidget.resize(400,400)
225        self.navWidget.setWindowTitle("Navigator")
226        self.setMouseTracking(True)
227       
228        colorbox = OWGUI.widgetBox(self.NodeTab, "Node Color", addSpace=True)
229       
230        OWGUI.comboBox(colorbox, self, 'NodeColorMethod', items=self.nodeColorOpts,
231                                callback=self.toggleNodeColor)
232        self.targetCombo=OWGUI.comboBox(colorbox,self, "TargetClassIndex", orientation=0, items=[],label="Target class",callback=self.toggleTargetClass)
233
234        OWGUI.checkBox(colorbox, self, 'ShowPies', 'Show distribution pie charts', tooltip='Show pie graph with class distribution?', callback=self.togglePies)
235        OWGUI.separator(colorbox)
236        OWGUI.button(colorbox, self, "Set Colors", callback=self.setColors, debuggingEnabled = 0)
237
238        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info")
239        nodeInfoSettings = ['maj', 'majp', 'tarp', 'inst']
240        self.NodeInfoW = []; self.dummy = 0
241        for i in range(len(self.nodeInfoButtons)):
242            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
243            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
244                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
245            self.NodeInfoW.append(w)
246
247#        OWGUI.button(self.controlArea, self, "Save as", callback=self.saveGraph, debuggingEnabled = 0)
248        self.NodeInfoSorted=list(self.NodeInfo)
249        self.NodeInfoSorted.sort()
250       
251        dlg = self.createColorDialog()
252        self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
253
254        OWGUI.rubber(self.NodeTab)
255       
256
257    def sendReport(self):
258        if self.tree:
259            tclass = self.tree.examples.domain.classVar.values[self.TargetClassIndex]
260            tsize = "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree))
261        else:
262            tclass = "N/A"
263            tsize = "N/A"
264           
265        self.reportSettings("Information",
266                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
267                             ("Target class", tclass),
268                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
269                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
270                             ("Tree size", tsize) ])
271        OWTreeViewer2D.sendReport(self)
272
273    def setColors(self):
274        dlg = self.createColorDialog()
275        if dlg.exec_():
276            self.colorSettings = dlg.getColorSchemas()
277            self.selectedColorSettingsIndex = dlg.selectedSchemaIndex
278            self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
279            self.scene.update()
280
281    def createColorDialog(self):
282        c = OWColorPalette.ColorPaletteDlg(self, "Color Palette")
283        c.createDiscretePalette("colorPalette", "Discrete Palette")
284        c.setColorSchemas(self.colorSettings, self.selectedColorSettingsIndex)
285        return c
286
287    def setNodeInfo(self, widget=None, id=None):
288        flags = sum(2**i for i, name in enumerate(['maj',
289                        'majp', 'tarp', 'inst']) if getattr(self, name))
290           
291        for n in self.scene.nodes():
292            n.setRect(QRectF())
293            self.updateNodeInfo(n, flags)
294        if True:
295            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
296            for n in self.scene.nodes():
297                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
298        self.scene.fixPos(self.rootNode, 10, 10)
299       
300    def updateNodeInfo(self, node, flags=31):
301        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
302        text = ""
303       
304#        text += "%s<br>" % fix(node.attr if node.attr else "")
305           
306        lines = []
307        if flags & 1:
308            start = "Majority class: " if self.showNodeInfoText else "" 
309#            lines += [start + "<font color=%s>%s</font>" % (self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].name(), fix(node.majorityClass))]
310            lines += [start + fix(node.majorityClass)]
311        if flags & 2:
312            start = "Majority class probability: " if self.showNodeInfoText else "" 
313            lines += [start + "%.1f" % (100.0 * float(node.majorityCount) / node.tree.distribution.abs)]
314        if flags & 4:
315            start = "Target class probability: "  if self.showNodeInfoText else "" 
316            lines += [start + "%.1f" % (100.0 * float(node.tree.distribution[self.TargetClassIndex]) / node.tree.distribution.abs)]
317        if flags & 8:
318            start = "Instances: " if self.showNodeInfoText else "" 
319            lines += [start + "%i" % node.tree.distribution.cases]
320        text += "<br>".join(lines)
321        if node.tree.branchSelector:
322            text += "<hr>" + "%s" % fix(node.tree.branchSelector.classVar.name)
323        else:
324            text += "<hr>" + fix(node.majorityClass)
325        node.setHtml(text)
326
327    def activateLoadedSettings(self):
328        if not self.tree:
329            return
330        OWTreeViewer2D.activateLoadedSettings(self)
331        self.setNodeInfo()
332        self.toggleNodeColor()
333       
334    def toggleNodeSize(self):
335        self.setNodeInfo()
336        self.scene.update()
337        self.sceneView.repaint()
338       
339    def toggleNodeColor(self):
340        for node in self.scene.nodes():
341            if self.NodeColorMethod == 0:   # default
342                color = BodyColor_Default
343            elif self.NodeColorMethod == 1: # instances in node
344                div = self.tree.distribution.cases
345                if div > 1e-6:
346                    light = 400 - 300*node.tree.distribution.cases/div
347                else:
348                    light = 100
349                color = BodyCasesColor_Default.light(light)
350            elif self.NodeColorMethod == 2: # majority class probability
351                light=400- 300*float(node.majorityCount) / node.tree.distribution.abs
352                color = self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].light(light)
353            elif self.NodeColorMethod == 3: # target class probability
354                div = node.tree.distribution.cases
355                if div > 1e-6:
356                    light=400-300*node.tree.distribution[self.TargetClassIndex]/div
357                else:
358                    light = 100
359                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
360            elif self.NodeColorMethod == 4: # target class distribution
361                div = self.tree.distribution[self.TargetClassIndex]
362                if div > 1e-6:
363                    light=200 - 100*node.tree.distribution[self.TargetClassIndex]/div
364                else:
365                    light = 100
366                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
367#            gradient = QLinearGradient(0, 0, 0, 100)
368#                gradient.setStops([(0, QColor(Qt.gray).lighter(120)), (1, QColor(Qt.lightGray).lighter())])
369#            gradient.setStops([(0, color), (1, color.lighter())])
370#            node.backgroundBrush = QBrush(gradient)
371            node.backgroundBrush = QBrush(color)
372        self.scene.update()
373
374    def toggleTargetClass(self):
375        if self.NodeColorMethod in [3,4]:
376            self.toggleNodeColor()
377        if self.tarp:
378            self.setNodeInfo()
379        self.scene.update()
380
381    def togglePies(self):
382        for n in self.scene.nodes():
383            n.pie.setVisible(self.ShowPies and n.isVisible())
384        self.scene.update()
385
386    def ctree(self, tree=None):
387        self.send("Data", None)
388        self.closeContext()
389        self.targetCombo.clear()
390        if tree:
391            for name in tree.tree.examples.domain.classVar.values:
392                self.targetCombo.addItem(name)
393            self.TargetClassIndex=0
394            self.openContext("", tree.domain)
395        else:
396            self.openContext("", None)
397        OWTreeViewer2D.ctree(self, tree)
398        self.togglePies()
399
400    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
401        node=ClassificationTreeNode(attrVal, tree, parent, None, self.scene)
402        if parent:
403            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
404        if tree.branches:
405            for i in range(len(tree.branches)):
406                if tree.branches[i]:
407                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
408        return node
409   
410    def nodeToolTip(self, node):
411        rule = list(node.rule())
412        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
413        if rule:
414            try:
415                rule=parseRules(list(rule))
416            except:
417                pass
418            text="<b>IF</b> "+" <b>AND</b><br>  ".join([fix(a[0].name+" = "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(node.majorityClass) + "<hr>"
419        else:
420            text="<b>THEN</b> "+fix(node.majorityClass) + "<hr>"
421        text += "Instances: %(ninst)i (%(prop).1f%%)<hr>" % {"ninst": node.tree.distribution.cases, "prop": float(node.tree.distribution.cases)/self.tree.distribution.cases*100}
422       
423        text += "<br>".join(["<font color=%(color)s>%(name)s: %(num)i (%(ratio).1f% %)</font>" % \
424                             {"name":fix(d[0]), "num":int(d[1]), "ratio":d[1]/sum(node.tree.distribution)*100, "color":self.scene.colorPalette[i].name()}\
425                             for i,d in enumerate(node.tree.distribution.items()) if d[1]!=0])
426        text += "<hr>Partition on: %(nodename)s" % {"nodename": node.tree.branchSelector.classVar.name} if node.tree.branches else ""
427        return text
428
429if __name__=="__main__":
430    a = QApplication(sys.argv)
431    ow = OWClassificationTreeGraph()
432##    a.setMainWidget(ow)
433
434    #data = orange.ExampleTable('../../doc/datasets/voting.tab')
435    data = orange.ExampleTable(r"../../doc/datasets/zoo.tab")
436    tree = orange.TreeLearner(data, storeExamples = 1)
437    ow.ctree(tree)
438
439    # here you can test setting some stuff
440    ow.show()
441    a.exec_()
442    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.