source: orange/orange/OrangeWidgets/Classify/OWClassificationTreeGraph.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 19.3 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""<name>Classification Tree Graph</name>
2<description>Classification tree viewer (graph view).</description>
3<icon>icons/ClassificationTreeGraph.png</icon>
4<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
5<priority>2110</priority>
6"""
7
8from OWTreeViewer2D import *
9import OWColorPalette
10
11class PieChart(QGraphicsRectItem):
12    def __init__(self, dist, r, parent, scene):
13        QGraphicsRectItem.__init__(self, parent, scene)
14        self.dist = dist
15        self.r = r
16       
17    def setR(self, r):
18        self.r = r
19       
20    def boundingRect(self):
21        return QRectF(-self.r, -self.r, 2*self.r, 2*self.r)
22       
23    def paint(self, painter, option, widget = None):
24        distSum = sum(self.dist)
25        startAngle = 0
26        colors = self.scene().colorPalette
27        for i in range(len(self.dist)):
28            angle = self.dist[i]*16 * 360./distSum
29            if angle == 0: continue
30            painter.setBrush(QBrush(colors[i]))
31            painter.setPen(QPen(colors[i]))
32            painter.drawPie(-self.r, -self.r, 2*self.r, 2*self.r, int(startAngle), int(angle))
33            startAngle += angle
34        painter.setPen(QPen(Qt.black))
35        painter.setBrush(QBrush())
36        painter.drawEllipse(-self.r, -self.r, 2*self.r, 2*self.r)
37
38class ClassificationTreeNode(GraphicsNode):
39    def __init__(self, attr, tree, parent=None, parentItem=None, scene=None):
40        GraphicsNode.__init__(self, tree, parent, parentItem, scene)
41        self.attr = attr
42        self.pie = PieChart(self.tree.distribution, 20, self, scene)
43        self.majorityClass, self.majorityCount = max(self.tree.distribution.items(), key=lambda (key, val): val)
44        fm = QFontMetrics(self.document().defaultFont())
45        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
46        self.attr_text_h = fm.lineSpacing()
47        self.line_descent = fm.descent()
48
49    def updateContents(self):
50        self.prepareGeometryChange()
51        if getattr(self, "_rect", QRectF()).isValid() and not self.truncateText:
52            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
53        else:
54            self.setTextWidth(-1)
55            self.setTextWidth(self.document().idealWidth())
56        self.droplet.setPos(self.rect().center().x(), self.rect().height())
57        self.droplet.setVisible(bool(self.branches))
58        self.pie.setPos(self.rect().right(), self.rect().center().y())
59        fm = QFontMetrics(self.document().defaultFont())
60        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
61        self.attr_text_h = fm.lineSpacing()
62        self.line_descent = fm.descent()
63       
64    def rect(self):
65        if self.truncateText and getattr(self, "_rect", QRectF()).isValid():
66            return self._rect
67        else:
68            rect = QRectF(QPointF(0,0), self.document().size())
69            return rect.adjusted(0, 0, self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0, 0) | getattr(self, "_rect", QRectF(0,0,1,1))
70   
71    def setRect(self, rect):
72        self.prepareGeometryChange()
73        rect = QRectF() if rect is None else rect
74        self._rect = rect
75        if rect.isValid() and not self.truncateText:
76            self.setTextWidth(self._rect.width() - self.pie.boundingRect().width() / 2 if hasattr(self, "pie") else 0)
77        else:
78            self.setTextWidth(-1)
79        self.updateContents()
80        self.update()
81       
82    def boundingRect(self):
83        if hasattr(self, "attr"):
84            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
85        else:
86            attr_rect = QRectF(0, 0, 1, 1)
87        rect = self.rect().adjusted(-5, -5, 5, 5)
88        if self.truncateText:
89            return rect | attr_rect
90        else:
91            return rect | GraphicsNode.boundingRect(self) | attr_rect
92
93    def rule(self):
94        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
95   
96    def paint(self, painter, option, widget=None):
97        if self.isSelected():
98            option.state = option.state.__xor__(QStyle.State_Selected)
99        if self.isSelected():
100            painter.save()
101            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
102            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
103            painter.restore()
104        painter.setFont(self.document().defaultFont())
105        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
106        painter.save()
107        painter.setBrush(self.backgroundBrush)
108        rect = self.rect()
109        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#self.borderRadius, self.borderRadius)
110        painter.restore()
111        if self.truncateText:
112#            self.setTextWidth(-1)`
113            painter.setClipRect(rect)
114        else:
115            painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
116        return QGraphicsTextItem.paint(self, painter, option, widget)
117#        TextTreeNode.paint(self, painter, option, widget)
118
119import re
120def parseRules(rules):
121    def joinCont(rule1, rule2):
122        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
123        rule=[rule1, rule2]
124        interval=[int1, int2]
125        for i in [0,1]:
126            if rule[i][1].startswith("in"):
127                r=rule[i][1][2:]
128                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
129            else:
130                if "<" in rule[i][1]:
131                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
132                    interval[i][2]=float(rule[i][1].strip("<>= "))
133                else:
134                    interval[i][0]=("=" in rule[i][1] and "[") or "("
135                    interval[i][1]=float(rule[i][1].strip("<>= "))
136
137        inter=[None]*4
138
139        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
140            interval.reverse()
141        inter[:2]=interval[0][:2]
142
143        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
144            interval.reverse()
145        inter[2:]=interval[0][2:]
146
147
148        if 1e1000 in inter or -1e1000 in inter:
149            rule=((-1e1000==inter[1] and "<") or ">")
150            rule+=(("[" in inter or "]" in inter) and "=") or ""
151            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
152        else:
153            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
154        return (rule1[0], rule)
155
156    def joinDisc(rule1, rule2):
157        r1,r2=rule1[1],rule2[1]
158        r1=re.sub("^in ","",r1)
159        r2=re.sub("^in ","",r2)
160        r1=r1.strip("[]=")
161        r2=r2.strip("[]=")
162        s1=set([s.strip(" ") for s in r1.split(",")])
163        s2=set([s.strip(" ") for s in r2.split(",")])
164        s=s1 & s2
165        if len(s)==1:
166            return (rule1[0], "= "+str(list(s)[0]))
167        else:
168            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
169
170    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
171    newRules=[rules[0]]
172    for r in rules[1:]:
173        if r[0].name==newRules[-1][0].name:
174            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
175                newRules[-1]=joinDisc(r,newRules[-1])
176            else:
177                newRules[-1]=joinCont(r,newRules[-1])
178        else:
179            newRules.append(r)
180    return newRules
181
182#BodyColor_Default = QColor(Qt.gray)
183BodyColor_Default = QColor(255, 225, 10)
184BodyCasesColor_Default = QColor(Qt.blue) #QColor(0, 0, 128)
185
186class OWClassificationTreeGraph(OWTreeViewer2D):
187    settingsList = OWTreeViewer2D.settingsList+['ShowPies', "colorSettings", "selectedColorSettingsIndex"]
188    contextHandlers = {"": DomainContextHandler("", ["TargetClassIndex"], matchValues=1)}
189   
190    nodeColorOpts = ['Default', 'Instances in node', 'Majority class probability', 'Target class probability', 'Target class distribution']
191    nodeInfoButtons = ['Majority class', 'Majority class probability', 'Target class probability', 'Number of instances']
192   
193    def __init__(self, parent=None, signalManager = None, name='ClassificationTreeViewer2D'):
194        self.ShowPies=1
195        self.TargetClassIndex=0
196        self.colorSettings = None
197        self.selectedColorSettingsIndex = 0
198        self.showNodeInfoText = False
199        self.NodeColorMethod = 2
200       
201        OWTreeViewer2D.__init__(self, parent, signalManager, name)
202
203        self.inputs = [("Classification Tree", orange.TreeClassifier, self.ctree)]
204        self.outputs = [("Data", ExampleTable)]
205
206        self.scene=TreeGraphicsScene(self)
207        self.sceneView=TreeGraphicsView(self, self.scene)
208        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
209        self.mainArea.layout().addWidget(self.sceneView)
210        self.toggleZoomSlider()
211#        self.scene.setSceneRect(0,0,800,800)
212
213        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
214
215        self.navWidget= OWBaseWidget(self)
216        self.navWidget.lay=QVBoxLayout(self.navWidget)
217
218        scene=TreeGraphicsScene(self.navWidget)
219        self.treeNav = TreeNavigator(self.sceneView)
220        self.navWidget.lay.addWidget(self.treeNav)
221        self.navWidget.resize(400,400)
222        self.navWidget.setWindowTitle("Navigator")
223        self.setMouseTracking(True)
224       
225        colorbox = OWGUI.widgetBox(self.NodeTab, "Node Color", addSpace=True)
226       
227        OWGUI.comboBox(colorbox, self, 'NodeColorMethod', items=self.nodeColorOpts,
228                                callback=self.toggleNodeColor)
229        self.targetCombo=OWGUI.comboBox(colorbox,self, "TargetClassIndex", orientation=0, items=[],label="Target class",callback=self.toggleTargetClass)
230
231        OWGUI.checkBox(colorbox, self, 'ShowPies', 'Show distribution pie charts', tooltip='Show pie graph with class distribution?', callback=self.togglePies)
232        OWGUI.separator(colorbox)
233        OWGUI.button(colorbox, self, "Set Colors", callback=self.setColors, debuggingEnabled = 0)
234
235        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info")
236        nodeInfoSettings = ['maj', 'majp', 'tarp', 'inst']
237        self.NodeInfoW = []; self.dummy = 0
238        for i in range(len(self.nodeInfoButtons)):
239            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
240            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
241                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
242            self.NodeInfoW.append(w)
243
244#        OWGUI.button(self.controlArea, self, "Save as", callback=self.saveGraph, debuggingEnabled = 0)
245        self.NodeInfoSorted=list(self.NodeInfo)
246        self.NodeInfoSorted.sort()
247       
248        dlg = self.createColorDialog()
249        self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
250
251        OWGUI.rubber(self.NodeTab)
252       
253
254    def sendReport(self):
255        if self.tree:
256            tclass = self.tree.examples.domain.classVar.values[self.TargetClassIndex]
257            tsize = "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree))
258        else:
259            tclass = "N/A"
260            tsize = "N/A"
261           
262        self.reportSettings("Information",
263                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
264                             ("Target class", tclass),
265                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
266                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
267                             ("Tree size", tsize) ])
268        OWTreeViewer2D.sendReport(self)
269
270    def setColors(self):
271        dlg = self.createColorDialog()
272        if dlg.exec_():
273            self.colorSettings = dlg.getColorSchemas()
274            self.selectedColorSettingsIndex = dlg.selectedSchemaIndex
275            self.scene.colorPalette = dlg.getDiscretePalette("colorPalette")
276            self.scene.update()
277
278    def createColorDialog(self):
279        c = OWColorPalette.ColorPaletteDlg(self, "Color Palette")
280        c.createDiscretePalette("colorPalette", "Discrete Palette")
281        c.setColorSchemas(self.colorSettings, self.selectedColorSettingsIndex)
282        return c
283
284    def setNodeInfo(self, widget=None, id=None):
285        flags = sum(2**i for i, name in enumerate(['maj',
286                        'majp', 'tarp', 'inst']) if getattr(self, name))
287           
288        for n in self.scene.nodes():
289            n.setRect(QRectF())
290            self.updateNodeInfo(n, flags)
291        if True:
292            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
293            for n in self.scene.nodes():
294                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
295        self.scene.fixPos(self.rootNode, 10, 10)
296       
297    def updateNodeInfo(self, node, flags=31):
298        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
299        text = ""
300       
301#        text += "%s<br>" % fix(node.attr if node.attr else "")
302           
303        lines = []
304        if flags & 1:
305            start = "Majority class: " if self.showNodeInfoText else "" 
306#            lines += [start + "<font color=%s>%s</font>" % (self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].name(), fix(node.majorityClass))]
307            lines += [start + fix(node.majorityClass)]
308        if flags & 2:
309            start = "Majority class probability: " if self.showNodeInfoText else "" 
310            lines += [start + "%.1f" % (100.0 * float(node.majorityCount) / node.tree.distribution.abs)]
311        if flags & 4:
312            start = "Target class probability: "  if self.showNodeInfoText else "" 
313            lines += [start + "%.1f" % (100.0 * float(node.tree.distribution[self.TargetClassIndex]) / node.tree.distribution.abs)]
314        if flags & 8:
315            start = "Instances: " if self.showNodeInfoText else "" 
316            lines += [start + "%i" % node.tree.distribution.cases]
317        text += "<br>".join(lines)
318        if node.tree.branchSelector:
319            text += "<hr>" + "%s" % fix(node.tree.branchSelector.classVar.name)
320        else:
321            text += "<hr>" + fix(node.majorityClass)
322        node.setHtml(text)
323
324    def activateLoadedSettings(self):
325        if not self.tree:
326            return
327        OWTreeViewer2D.activateLoadedSettings(self)
328        self.setNodeInfo()
329        self.toggleNodeColor()
330       
331    def toggleNodeSize(self):
332        self.setNodeInfo()
333        self.scene.update()
334        self.sceneView.repaint()
335       
336    def toggleNodeColor(self):
337        for node in self.scene.nodes():
338            if self.NodeColorMethod == 0:   # default
339                color = BodyColor_Default
340            elif self.NodeColorMethod == 1: # instances in node
341                div = self.tree.distribution.cases
342                if div > 1e-6:
343                    light = 400 - 300*node.tree.distribution.cases/div
344                else:
345                    light = 100
346                color = BodyCasesColor_Default.light(light)
347            elif self.NodeColorMethod == 2: # majority class probability
348                light=400- 300*float(node.majorityCount) / node.tree.distribution.abs
349                color = self.scene.colorPalette[node.tree.examples.domain.classVar.values.index(node.majorityClass)].light(light)
350            elif self.NodeColorMethod == 3: # target class probability
351                div = node.tree.distribution.cases
352                if div > 1e-6:
353                    light=400-300*node.tree.distribution[self.TargetClassIndex]/div
354                else:
355                    light = 100
356                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
357            elif self.NodeColorMethod == 4: # target class distribution
358                div = self.tree.distribution[self.TargetClassIndex]
359                if div > 1e-6:
360                    light=200 - 100*node.tree.distribution[self.TargetClassIndex]/div
361                else:
362                    light = 100
363                color = self.scene.colorPalette[self.TargetClassIndex].light(light)
364#            gradient = QLinearGradient(0, 0, 0, 100)
365#                gradient.setStops([(0, QColor(Qt.gray).lighter(120)), (1, QColor(Qt.lightGray).lighter())])
366#            gradient.setStops([(0, color), (1, color.lighter())])
367#            node.backgroundBrush = QBrush(gradient)
368            node.backgroundBrush = QBrush(color)
369        self.scene.update()
370
371    def toggleTargetClass(self):
372        if self.NodeColorMethod in [3,4]:
373            self.toggleNodeColor()
374        if self.tarp:
375            self.setNodeInfo()
376        self.scene.update()
377
378    def togglePies(self):
379        for n in self.scene.nodes():
380            n.pie.setVisible(self.ShowPies and n.isVisible())
381        self.scene.update()
382
383    def ctree(self, tree=None):
384        self.send("Data", None)
385        self.closeContext()
386        self.targetCombo.clear()
387        if tree:
388            for name in tree.tree.examples.domain.classVar.values:
389                self.targetCombo.addItem(name)
390            self.TargetClassIndex=0
391            self.openContext("", tree.domain)
392        else:
393            self.openContext("", None)
394        OWTreeViewer2D.ctree(self, tree)
395        self.togglePies()
396
397    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
398        node=ClassificationTreeNode(attrVal, tree, parent, None, self.scene)
399        if parent:
400            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
401        if tree.branches:
402            for i in range(len(tree.branches)):
403                if tree.branches[i]:
404                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
405        return node
406   
407    def nodeToolTip(self, node):
408        rule = list(node.rule())
409        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
410        if rule:
411            try:
412                rule=parseRules(list(rule))
413            except:
414                pass
415            text="<b>IF</b> "+" <b>AND</b><br>  ".join([fix(a[0].name+" = "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(node.majorityClass) + "<hr>"
416        else:
417            text="<b>THEN</b> "+fix(node.majorityClass) + "<hr>"
418        text += "Instances: %(ninst)i (%(prop).1f%%)<hr>" % {"ninst": node.tree.distribution.cases, "prop": float(node.tree.distribution.cases)/self.tree.distribution.cases*100}
419       
420        text += "<br>".join(["<font color=%(color)s>%(name)s: %(num)i (%(ratio).1f% %)</font>" % \
421                             {"name":fix(d[0]), "num":int(d[1]), "ratio":d[1]/sum(node.tree.distribution)*100, "color":self.scene.colorPalette[i].name()}\
422                             for i,d in enumerate(node.tree.distribution.items()) if d[1]!=0])
423        text += "<hr>Partition on: %(nodename)s" % {"nodename": node.tree.branchSelector.classVar.name} if node.tree.branches else ""
424        return text
425
426if __name__=="__main__":
427    a = QApplication(sys.argv)
428    ow = OWClassificationTreeGraph()
429##    a.setMainWidget(ow)
430
431    #data = orange.ExampleTable('../../doc/datasets/voting.tab')
432    data = orange.ExampleTable(r"../../doc/datasets/zoo.tab")
433    tree = orange.TreeLearner(data, storeExamples = 1)
434    ow.ctree(tree)
435
436    # here you can test setting some stuff
437    ow.show()
438    a.exec_()
439    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.