source: orange/Orange/OrangeWidgets/Regression/OWRegressionTreeViewer2D.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 12.8 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name> Regression Tree Graph</name>
3<description>Regression tree viewer (graph view).</description>
4<icon>icons/RegressionTreeGraph.png</icon>
5<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
6<priority>2110</priority>
7"""
8from OWTreeViewer2D import *
9import re
10
11       
12class RegressionTreeNode(GraphicsNode):
13    def __init__(self, attr, tree, parent=None, *args):
14        GraphicsNode.__init__(self, tree, parent, *args)
15        self.attr = attr
16        fm = QFontMetrics(self.document().defaultFont())
17        self.attr_text_w = fm.width(str(self.attr if self.attr else ""))
18        self.attr_text_h = fm.lineSpacing()
19        self.line_descent = fm.descent()
20       
21    def rule(self):
22        return self.parent.rule() + [(self.parent.tree.branchSelector.classVar, self.attr)] if self.parent else []
23   
24    def rect(self):
25        rect = GraphicsNode.rect(self)
26        rect.setRight(max(rect.right(), getattr(self, "attr_text_w", 0)))
27        return rect
28   
29    def boundingRect(self):
30        if hasattr(self, "attr"):
31            attr_rect = QRectF(QPointF(0, -self.attr_text_h), QSizeF(self.attr_text_w, self.attr_text_h))
32        else:
33            attr_rect = QRectF(0, 0, 1, 1)
34        rect = self.rect().adjusted(-5, -5, 5, 5)
35        return rect | GraphicsNode.boundingRect(self) | attr_rect
36   
37    def paint(self, painter, option, widget=None):
38        if self.isSelected():
39            option.state = option.state.__xor__(QStyle.State_Selected)
40        if self.isSelected():
41            painter.save()
42            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
43            painter.drawRoundedRect(self.boundingRect().adjusted(-2, 1, -1, -1), 10, 10)#self.borderRadius, self.borderRadius)
44            painter.restore()
45        painter.setFont(self.document().defaultFont())
46        painter.drawText(QPointF(0, -self.line_descent), str(self.attr) if self.attr else "")
47        painter.save()
48        painter.setBrush(self.backgroundBrush)
49        rect = self.rect()
50        painter.drawRoundedRect(rect.adjusted(-3, 0, 0, 0), 10, 10)#, self.borderRadius, self.borderRadius)
51        painter.restore()
52        painter.setClipRect(rect | QRectF(QPointF(0, 0), self.document().size()))
53        return QGraphicsTextItem.paint(self, painter, option, widget)
54       
55def parseRules(rules):
56    def joinCont(rule1, rule2):
57        int1, int2=["(",-1e1000,1e1000,")"], ["(",-1e1000,1e1000,")"]
58        rule=[rule1, rule2]
59        interval=[int1, int2]
60        for i in [0,1]:
61            if rule[i][1].startswith("in"):
62                r=rule[i][1][2:]
63                interval[i]=[r.strip(" ")[0]]+map(lambda a: float(a), r.strip("()[] ").split(","))+[r.strip(" ")[-1]]
64            else:
65                if "<" in rule[i][1]:
66                    interval[i][3]=("=" in rule[i][1] and "]") or ")"
67                    interval[i][2]=float(rule[i][1].strip("<>= "))
68                else:
69                    interval[i][0]=("=" in rule[i][1] and "[") or "("
70                    interval[i][1]=float(rule[i][1].strip("<>= "))
71
72        inter=[None]*4
73
74        if interval[0][1]<interval[1][1] or (interval[0][1]==interval[1][1] and interval[0][0]=="["):
75            interval.reverse()
76        inter[:2]=interval[0][:2]
77
78        if interval[0][2]>interval[1][2] or (interval[0][2]==interval[1][2] and interval[0][3]=="]"):
79            interval.reverse()
80        inter[2:]=interval[0][2:]
81
82
83        if 1e1000 in inter or -1e1000 in inter:
84            rule=((-1e1000==inter[1] and "<") or ">")
85            rule+=(("[" in inter or "]" in inter) and "=") or ""
86            rule+=(-1e1000==inter[1] and str(inter[2])) or str(inter[1])
87        else:
88            rule="in "+inter[0]+str(inter[1])+","+str(inter[2])+inter[3]
89        return (rule1[0], rule)
90
91    def joinDisc(rule1, rule2):
92        r1,r2=rule1[1],rule2[1]
93        r1=re.sub("^in ","",r1)
94        r2=re.sub("^in ","",r2)
95        r1=r1.strip("[]=")
96        r2=r2.strip("[]=")
97        s1=set([s.strip(" ") for s in r1.split(",")])
98        s2=set([s.strip(" ") for s in r2.split(",")])
99        s=s1 & s2
100        if len(s)==1:
101            return (rule1[0], "= "+str(list(s)[0]))
102        else:
103            return (rule1[0], "in ["+",".join([str(st) for st in s])+"]")
104
105    rules.sort(lambda a,b: (a[0].name<b[0].name and -1) or 1 )
106    newRules=[rules[0]]
107    for r in rules[1:]:
108        if r[0].name==newRules[-1][0].name:
109            if re.search("(a-zA-Z\"')+",r[1].lstrip("in")):
110                newRules[-1]=joinDisc(r,newRules[-1])
111            else:
112                newRules[-1]=joinCont(r,newRules[-1])
113        else:
114            newRules.append(r)
115    return newRules
116
117BodyColor_Default = QColor(255, 225, 10)
118#BodyColor_Default = QColor(Qt.gray)
119BodyCasesColor_Default = QColor(0, 0, 128)
120
121class OWRegressionTreeViewer2D(OWTreeViewer2D):
122    nodeColorOpts = ['Default', 'Instances in node', 'Variance', 'Deviation', 'Error']
123    nodeInfoButtons = ['Predicted value', 'Variance', 'Deviation', 'Error', 'Number of instances']
124
125    def __init__(self, parent=None, signalManager = None, name='RegressionTreeViewer2D'):
126        OWTreeViewer2D.__init__(self, parent, signalManager, name)
127
128        self.inputs = [("Classification Tree", orange.TreeClassifier, self.ctree)]
129        self.outputs = [("Data", ExampleTable)]
130       
131        self.NodeColorMethod = 1
132        self.showNodeInfoText = False
133       
134        self.scene = TreeGraphicsScene(self)
135        self.sceneView = TreeGraphicsView(self, self.scene)
136        self.sceneView.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
137        self.mainArea.layout().addWidget(self.sceneView)
138        self.toggleZoomSlider()
139       
140        self.connect(self.scene, SIGNAL("selectionChanged()"), self.updateSelection)
141
142        self.navWidget = OWBaseWidget(self) 
143        self.navWidget.lay=QVBoxLayout(self.navWidget)
144       
145#        scene = TreeGraphicsScene(self.navWidget)
146        self.treeNav = TreeNavigator(self.sceneView) #,self,scene,self.navWidget)
147#        self.treeNav.setScene(scene)
148        self.navWidget.layout().addWidget(self.treeNav)
149        self.navWidget.resize(400,400)
150        self.navWidget.setWindowTitle("Navigator")
151        self.setMouseTracking(True)
152
153        OWGUI.comboBox(self.NodeTab, self, 'NodeColorMethod', items=self.nodeColorOpts, box='Node Color',
154                                callback=self.toggleNodeColor, addSpace=True)
155       
156        nodeInfoBox = OWGUI.widgetBox(self.NodeTab, "Show Info On")
157        nodeInfoSettings = ['maj', 'majp', 'tarp', 'error', 'inst']
158        self.NodeInfoW = []; self.dummy = 0
159        for i in range(len(self.nodeInfoButtons)):
160            setattr(self, nodeInfoSettings[i], i in self.NodeInfo)
161            w = OWGUI.checkBox(nodeInfoBox, self, nodeInfoSettings[i], \
162                               self.nodeInfoButtons[i], callback=self.setNodeInfo, getwidget=1, id=i)
163            self.NodeInfoW.append(w)
164
165        OWGUI.rubber(self.NodeTab)
166       
167#        OWGUI.button(self.controlArea, self, "Save As", callback=self.saveGraph, debuggingEnabled = 0)
168        self.NodeInfoSorted=list(self.NodeInfo)
169        self.NodeInfoSorted.sort()
170
171    def sendReport(self):
172        self.reportSettings("Information",
173                            [("Node color", self.nodeColorOpts[self.NodeColorMethod]),
174                             ("Data in nodes", ", ".join(s for i, s in enumerate(self.nodeInfoButtons) if self.NodeInfoW[i].isChecked())),
175                             ("Line widths", ["Constant", "Proportion of all instances", "Proportion of parent's instances"][self.LineWidthMethod]),
176                             ("Tree size", "%i nodes, %i leaves" % (orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree)))])
177        OWTreeViewer2D.sendReport(self)
178
179    def setNodeInfo(self, widget=None, id=None):
180        flags = sum(2**i for i, name in enumerate(['maj', 'majp', 'tarp', 'error', 'inst']) if getattr(self, name)) 
181        for n in self.scene.nodes():
182            n.setRect(QRectF())
183            self.updateNodeInfo(n, flags)
184        if True:
185            w = min(max([n.rect().width() for n in self.scene.nodes()] + [0]), self.MaxNodeWidth if self.LimitNodeWidth else sys.maxint)
186            for n in self.scene.nodes():
187                n.setRect(QRectF(n.rect().x(), n.rect().y(), w, n.rect().height()))
188        self.scene.fixPos(self.rootNode, 10, 10)
189        self.scene.update()
190       
191    def updateNodeInfo(self, node, flags=63):
192        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
193        text = ""
194#        if node.attr:
195#            text += "%s<hr width=20000>" % fix(node.attr)
196        lines = []
197        if flags & 1:
198            start = "Predicted value: " if self.showNodeInfoText else ""
199            lines += [start + fix(str(node.tree.nodeClassifier.defaultValue))]
200        if flags & 2:
201            start = "Variance: " if self.showNodeInfoText else ""
202            lines += [start + "%.1f" % node.tree.distribution.var()]
203        if flags & 4:
204            start = "Deviance: " if self.showNodeInfoText else ""
205            lines += [start + "%.1f" % node.tree.distribution.dev()]
206        if flags & 8:
207            start = "Error: " if self.showNodeInfoText else ""
208            lines += [start + "%.1f" % node.tree.distribution.error()]
209        if flags & 16:
210            start = "Number of instances: " if self.showNodeInfoText else ""
211            lines += [start + "%i" % node.tree.distribution.cases]
212        text += "<br>".join(lines)
213        if node.tree.branchSelector:
214            text += "<hr>%s" % (fix(node.tree.branchSelector.classVar.name))
215        else:
216            text += "<hr>%s" % (fix(str(node.tree.nodeClassifier.defaultValue)))
217                               
218        node.setHtml(text) 
219
220    def activateLoadedSettings(self):
221        if not self.tree:
222            return
223        OWTreeViewer2D.activateLoadedSettings(self)
224        self.setNodeInfo()
225        self.toggleNodeColor()
226       
227    def toggleNodeSize(self):
228        self.setNodeInfo()
229        self.scene.update()
230        self.sceneView.repaint()
231
232    def toggleNodeColor(self):
233        for node in self.scene.nodes():
234            numInst=self.tree.distribution.cases
235            if self.NodeColorMethod == 0:   # default
236                color = BodyColor_Default
237            elif self.NodeColorMethod == 1: # instances in node
238                light = 400 - 300*node.tree.distribution.cases/numInst
239                color = BodyCasesColor_Default.light(light)
240            elif self.NodeColorMethod == 2:
241                light = 300-min([node.tree.distribution.var(),100])
242                color = BodyCasesColor_Default.light(light)
243            elif self.NodeColorMethod == 3:
244                light = 300 - min([node.tree.distribution.dev(),100])
245                color = BodyCasesColor_Default.light(light)
246            elif self.NodeColorMethod == 4:
247                light = 400 - 300*node.tree.distribution.error()
248                color = BodyCasesColor_Default.light(light)
249#            gradient = QLinearGradient(0, 0, 0, 100)
250#            gradient.setStops([(0, color.lighter(120)), (1, color.lighter())])
251#            node.backgroundBrush = QBrush(gradient)
252            node.backgroundBrush = QBrush(color)
253
254        self.scene.update()
255#        self.treeNav.leech()
256
257    def ctree(self, tree=None):
258        self.send("Data", None)
259        OWTreeViewer2D.ctree(self, tree)
260
261    def walkcreate(self, tree, parent=None, level=0, attrVal=""):
262        node=RegressionTreeNode(attrVal, tree, parent, None, self.scene)
263        if parent:
264            parent.graph_add_edge(GraphicsEdge(None, self.scene, node1=parent, node2=node))
265        if tree.branches:
266            for i in range(len(tree.branches)):
267                if tree.branches[i]:
268                    self.walkcreate(tree.branches[i],node,level+1,tree.branchDescriptions[i])
269        return node
270   
271    def nodeToolTip(self, node):
272        rule=list(node.rule())
273        fix = lambda str: str.replace(">", "&gt;").replace("<", "&lt;")
274        if rule:
275            try:
276                rule=parseRules(list(rule))
277            except:
278                pass
279            text="<b>IF</b> "+" <b>AND</b><br>\n  ".join([fix(a[0].name+" "+a[1]) for a in rule])+"\n<br><b>THEN</b> "+fix(str(node.tree.nodeClassifier.defaultValue))
280        else:
281            text="<b>THEN</b> "+fix(str(node.tree.nodeClassifier.defaultValue))
282        text += "<hr>Instances: %i (%.1f%%)" % (node.tree.distribution.cases, node.tree.distribution.cases/self.tree.distribution.cases*100)
283        text += "<hr>Partition on %s<hr>" % node.tree.branchSelector.classVar.name if node.tree.branchSelector else "<hr>"
284        text += fix(node.tree.nodeClassifier.classVar.name + " = " + str(node.tree.nodeClassifier.defaultValue))
285        return text
286
287if __name__=="__main__":
288    a = QApplication(sys.argv)
289    ow = OWRegressionTreeViewer2D()
290
291    data = orange.ExampleTable('../../doc/datasets/housing.tab')
292    tree = orange.TreeLearner(data, storeExamples = 1)
293    ow.ctree(tree)
294
295    # here you can test setting some stuff
296    ow.show()
297    a.exec_()
298    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.