source: orange/Orange/OrangeWidgets/Regression/OWRegressionTreeViewer2D.py @ 11096:cf7d2ae9d22b

Revision 11096:cf7d2ae9d22b, 12.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 19 months ago (diff)

Added new svg icons for the widgets/categories.

Line 
1"""
2<name> Regression Tree Graph</name>
3<description>Regression tree viewer (graph view).</description>
4<icon>icons/RegressionTreeGraph.svg</icon>
5<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
6<priority>2110</priority>
7"""
8from OWTreeViewer2D import *
9import re
10
11import Orange
12
13class RegressionTreeNode(GraphicsNode):
14    def __init__(self, attr, tree, parent=None, *args):
15        GraphicsNode.__init__(self, tree, parent, *args)
16        self.attr = attr
17        fm = QFontMetrics(self.document().defaultFont())
18        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
19        self.attr_text_h = fm.lineSpacing()
20        self.line_descent = fm.descent()
21       
22    def rule(self):
23        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
24   
25    def rect(self):
26        rect = GraphicsNode.rect(self)
27        rect.setRight(max(rect.right(), getattr(self, "attr_text_w", 0)))
28        return rect
29   
30    def boundingRect(self):
31        if hasattr(self, "attr"):
32            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
33        else:
34            attr_rect = QRectF(0, 0, 1, 1)
35        rect = self.rect().adjusted(-5, -5, 5, 5)
36        return rect | GraphicsNode.boundingRect(self) | attr_rect
37   
38    def paint(self, painter, option, widget=None):
39        if self.isSelected():
40            option.state = option.state.__xor__(QStyle.State_Selected)
41        if self.isSelected():
42            painter.save()
43            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
44            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
45            painter.restore()
46        painter.setFont(self.document().defaultFont())
47        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
48        painter.save()
49        painter.setBrush(self.backgroundBrush)
50        rect = self.rect()
51        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#, self.borderRadius, self.borderRadius)
52        painter.restore()
53        painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
54        return QGraphicsTextItem.paint(self, painter, option, widget)
55       
56def parseRules(rules):
57    def joinCont(rule1, rule2):
58        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
59        rule=[rule1, rule2]
60        interval=[int1, int2]
61        for i in [0,1]:
62            if rule[i][1].startswith("in"):
63                r=rule[i][1][2:]
64                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
65            else:
66                if "<" in rule[i][1]:
67                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
68                    interval[i][2]=float(rule[i][1].strip("<>= "))
69                else:
70                    interval[i][0]=("=" in rule[i][1] and "[") or "("
71                    interval[i][1]=float(rule[i][1].strip("<>= "))
72
73        inter=[None]*4
74
75        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
76            interval.reverse()
77        inter[:2]=interval[0][:2]
78
79        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
80            interval.reverse()
81        inter[2:]=interval[0][2:]
82
83
84        if 1e1000 in inter or -1e1000 in inter:
85            rule=((-1e1000==inter[1] and "<") or ">")
86            rule+=(("[" in inter or "]" in inter) and "=") or ""
87            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
88        else:
89            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
90        return (rule1[0], rule)
91
92    def joinDisc(rule1, rule2):
93        r1,r2=rule1[1],rule2[1]
94        r1=re.sub("^in ","",r1)
95        r2=re.sub("^in ","",r2)
96        r1=r1.strip("[]=")
97        r2=r2.strip("[]=")
98        s1=set([s.strip(" ") for s in r1.split(",")])
99        s2=set([s.strip(" ") for s in r2.split(",")])
100        s=s1 & s2
101        if len(s)==1:
102            return (rule1[0], "= "+str(list(s)[0]))
103        else:
104            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
105
106    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
107    newRules=[rules[0]]
108    for r in rules[1:]:
109        if r[0].name==newRules[-1][0].name:
110            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
111                newRules[-1]=joinDisc(r,newRules[-1])
112            else:
113                newRules[-1]=joinCont(r,newRules[-1])
114        else:
115            newRules.append(r)
116    return newRules
117
118BodyColor_Default = QColor(255, 225, 10)
119#BodyColor_Default = QColor(Qt.gray)
120BodyCasesColor_Default = QColor(0, 0, 128)
121
122class OWRegressionTreeViewer2D(OWTreeViewer2D):
123    nodeColorOpts = ['Default', 'Instances in node', 'Variance', 'Deviation', 'Error']
124    nodeInfoButtons = ['Predicted value', 'Variance', 'Deviation', 'Error', 'Number of instances']
125
126    def __init__(self, parent=None, signalManager = None, name='RegressionTreeViewer2D'):
127        OWTreeViewer2D.__init__(self, parent, signalManager, name)
128
129        self.inputs = [("Classification Tree", Orange.regression.tree.TreeClassifier, self.ctree)]
130        self.outputs = [("Data", ExampleTable)]
131       
132        self.NodeColorMethod = 1
133        self.showNodeInfoText = False
134       
135        self.scene = TreeGraphicsScene(self)
136        self.sceneView = TreeGraphicsView(self, self.scene)
137        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
138        self.mainArea.layout().addWidget(self.sceneView)
139        self.toggleZoomSlider()
140       
141        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
142
143        self.navWidget = OWBaseWidget(self) 
144        self.navWidget.lay=QVBoxLayout(self.navWidget)
145       
146#        scene = TreeGraphicsScene(self.navWidget)
147        self.treeNav = TreeNavigator(self.sceneView) #,self,scene,self.navWidget)
148#        self.treeNav.setScene(scene)
149        self.navWidget.layout().addWidget(self.treeNav)
150        self.navWidget.resize(400,400)
151        self.navWidget.setWindowTitle("Navigator")
152        self.setMouseTracking(True)
153
154        OWGUI.comboBox(self.NodeTab, self, 'NodeColorMethod', items=self.nodeColorOpts, box='Node Color',
155                                callback=self.toggleNodeColor, addSpace=True)
156       
157        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info On")
158        nodeInfoSettings = ['maj', 'majp', 'tarp', 'error', 'inst']
159        self.NodeInfoW = []; self.dummy = 0
160        for i in range(len(self.nodeInfoButtons)):
161            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
162            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
163                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
164            self.NodeInfoW.append(w)
165
166        OWGUI.rubber(self.NodeTab)
167       
168#        OWGUI.button(self.controlArea, self, "Save As", callback=self.saveGraph, debuggingEnabled = 0)
169        self.NodeInfoSorted=list(self.NodeInfo)
170        self.NodeInfoSorted.sort()
171
172    def sendReport(self):
173        self.reportSettings("Information",
174                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
175                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
176                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
177                             ("Tree size", "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree)))])
178        OWTreeViewer2D.sendReport(self)
179
180    def setNodeInfo(self, widget=None, id=None):
181        flags = sum(2**i for i, name in enumerate(['maj', 'majp', 'tarp', 'error', 'inst']) if getattr(self, name)) 
182        for n in self.scene.nodes():
183            n.setRect(QRectF())
184            self.updateNodeInfo(n, flags)
185        if True:
186            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
187            for n in self.scene.nodes():
188                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
189        self.scene.fixPos(self.rootNode, 10, 10)
190        self.scene.update()
191       
192    def updateNodeInfo(self, node, flags=63):
193        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
194        text = ""
195#        if node.attr:
196#            text += "%s<hr width=20000>" % fix(node.attr)
197        lines = []
198        if flags & 1:
199            start = "Predicted value: " if self.showNodeInfoText else ""
200            lines += [start + fix(str(node.tree.nodeClassifier.defaultValue))]
201        if flags & 2:
202            start = "Variance: " if self.showNodeInfoText else ""
203            lines += [start + "%.1f" % node.tree.distribution.var()]
204        if flags & 4:
205            start = "Deviance: " if self.showNodeInfoText else ""
206            lines += [start + "%.1f" % node.tree.distribution.dev()]
207        if flags & 8:
208            start = "Error: " if self.showNodeInfoText else ""
209            lines += [start + "%.1f" % node.tree.distribution.error()]
210        if flags & 16:
211            start = "Number of instances: " if self.showNodeInfoText else ""
212            lines += [start + "%i" % node.tree.distribution.cases]
213        text += "<br>".join(lines)
214        if node.tree.branchSelector:
215            text += "<hr>%s" % (fix(node.tree.branchSelector.classVar.name))
216        else:
217            text += "<hr>%s" % (fix(str(node.tree.nodeClassifier.defaultValue)))
218                               
219        node.setHtml(text) 
220
221    def activateLoadedSettings(self):
222        if not self.tree:
223            return
224        OWTreeViewer2D.activateLoadedSettings(self)
225        self.setNodeInfo()
226        self.toggleNodeColor()
227       
228    def toggleNodeSize(self):
229        self.setNodeInfo()
230        self.scene.update()
231        self.sceneView.repaint()
232
233    def toggleNodeColor(self):
234        for node in self.scene.nodes():
235            numInst=self.tree.distribution.cases
236            if self.NodeColorMethod == 0:   # default
237                color = BodyColor_Default
238            elif self.NodeColorMethod == 1: # instances in node
239                light = 400 - 300*node.tree.distribution.cases/numInst
240                color = BodyCasesColor_Default.light(light)
241            elif self.NodeColorMethod == 2:
242                light = 300-min([node.tree.distribution.var(),100])
243                color = BodyCasesColor_Default.light(light)
244            elif self.NodeColorMethod == 3:
245                light = 300 - min([node.tree.distribution.dev(),100])
246                color = BodyCasesColor_Default.light(light)
247            elif self.NodeColorMethod == 4:
248                light = 400 - 300*node.tree.distribution.error()
249                color = BodyCasesColor_Default.light(light)
250#            gradient = QLinearGradient(0, 0, 0, 100)
251#            gradient.setStops([(0, color.lighter(120)), (1, color.lighter())])
252#            node.backgroundBrush = QBrush(gradient)
253            node.backgroundBrush = QBrush(color)
254
255        self.scene.update()
256#        self.treeNav.leech()
257
258    def ctree(self, tree=None):
259        self.send("Data", None)
260        OWTreeViewer2D.ctree(self, tree)
261
262    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
263        node=RegressionTreeNode(attrVal, tree, parent, None, self.scene)
264        if parent:
265            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
266        if tree.branches:
267            for i in range(len(tree.branches)):
268                if tree.branches[i]:
269                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
270        return node
271   
272    def nodeToolTip(self, node):
273        rule=list(node.rule())
274        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
275        if rule:
276            try:
277                rule=parseRules(list(rule))
278            except:
279                pass
280            text="<b>IF</b> "+" <b>AND</b><br>\n  ".join([fix(a[0].name+" "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(str(node.tree.nodeClassifier.defaultValue))
281        else:
282            text="<b>THEN</b> "+fix(str(node.tree.nodeClassifier.defaultValue))
283        text += "<hr>Instances: %i (%.1f%%)" % (node.tree.distribution.cases, node.tree.distribution.cases/self.tree.distribution.cases*100)
284        text += "<hr>Partition on %s<hr>" % node.tree.branchSelector.classVar.name if node.tree.branchSelector else "<hr>"
285        text += fix(node.tree.nodeClassifier.classVar.name + " = " + str(node.tree.nodeClassifier.defaultValue))
286        return text
287
288if __name__=="__main__":
289    a = QApplication(sys.argv)
290    ow = OWRegressionTreeViewer2D()
291
292    data = orange.ExampleTable('../../doc/datasets/housing.tab')
293    tree = orange.TreeLearner(data, storeExamples = 1)
294    ow.ctree(tree)
295
296    # here you can test setting some stuff
297    ow.show()
298    a.exec_()
299    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.