source: orange/Orange/OrangeWidgets/Classify/OWClassificationTreeViewer.py @ 10458:ebfcfb96dd9e

Revision 10458:ebfcfb96dd9e, 11.7 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Removed a redundant print.

Line 
1"""
2<name>Classification Tree Viewer</name>
3<description>Classification tree viewer (hierarchical list view).</description>
4<icon>icons/ClassificationTreeViewer.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>2100</priority>
7"""
8from OWWidget import *
9from orngTree import TreeLearner
10import OWGUI
11
12import orngTree
13
14class ColumnCallback:
15    def __init__(self, widget, attribute, f = None):
16        self.widget = widget
17        self.attribute = attribute
18        self.f = f
19        widget.callbackDeposit.append(self)
20
21    def __call__(self, value):
22        setattr(self.widget, self.attribute, self.f and self.f(value) or value)
23        self.widget.setTreeView(1)
24
25def checkColumn(widget, master, text, value):
26    wa = QCheckBox(text, widget)
27    widget.layout().addWidget(wa)
28    wa.setChecked(getattr(master, value))
29    master.connect(wa, SIGNAL("toggled(bool)"), ColumnCallback(master, value))
30    return wa
31
32class OWClassificationTreeViewer(OWWidget):
33    settingsList = ["maj", "pmaj", "ptarget", "inst", "dist", "adist", "expslider", "sliderValue"]
34    contextHandlers = {"": DomainContextHandler("", ["targetClass"], matchValues=1)}
35
36    def __init__(self, parent=None, signalManager = None, name='Classification Tree Viewer'):
37        OWWidget.__init__(self, parent, signalManager, name)
38
39        self.dataLabels = (('Majority class', 'Class'), 
40                  ('Probability of majority class', 'P(Class)'), 
41                  ('Probability of target class', 'P(Target)'), 
42                  ('Number of instances', '# Inst'), 
43                  ('Relative distribution', 'Distribution (rel)'), 
44                  ('Absolute distribution', 'Distribution (abs)'))
45
46#        self.callbackDeposit = []
47
48        self.inputs = [("Classification Tree", orange.TreeClassifier, self.setClassificationTree)]
49        self.outputs = [("Data", ExampleTable)]
50
51        # Settings
52        for s in self.settingsList[:6]:
53            setattr(self, s, 1)
54        self.expslider = 5
55        self.targetClass = 0
56        self.loadSettings()
57
58        self.tree = None
59        self.sliderValue = 5
60        self.precision = 3
61        self.precFrmt = "%%2.%if" % self.precision
62
63        # GUI
64        # parameters
65
66        self.dBox = OWGUI.widgetBox(self.controlArea, 'Displayed information')
67        for i in range(len(self.dataLabels)):
68            checkColumn(self.dBox, self, self.dataLabels[i][0], self.settingsList[i])
69
70        OWGUI.separator(self.controlArea)
71               
72        self.slider = OWGUI.hSlider(self.controlArea, self, "sliderValue", box = 'Expand/shrink to level', minValue = 1, maxValue = 9, step = 1, callback = self.sliderChanged)
73
74        OWGUI.separator(self.controlArea)
75        self.targetCombo=OWGUI.comboBox(self.controlArea, self, "targetClass", items=[], box="Target class", callback=self.setTarget, addSpace=True)
76
77        self.infBox = OWGUI.widgetBox(self.controlArea, 'Tree size')
78        self.infoa = OWGUI.widgetLabel(self.infBox, 'No tree.')
79        self.infob = OWGUI.widgetLabel(self.infBox, ' ')
80
81        OWGUI.rubber(self.controlArea)
82
83        # list view
84        self.splitter = QSplitter(Qt.Vertical, self.mainArea)
85        self.mainArea.layout().addWidget(self.splitter)
86
87        self.v = QTreeWidget(self.splitter)
88        self.splitter.addWidget(self.v)
89        self.v.setAllColumnsShowFocus(1)
90        self.v.setHeaderLabels(['Classification Tree'] + [label[1] for label in self.dataLabels])
91        self.v.setColumnWidth(0, 250)
92        self.connect(self.v, SIGNAL("itemSelectionChanged()"), self.viewSelectionChanged)
93
94        # rule
95        self.rule = QTextEdit(self.splitter)
96        self.splitter.addWidget(self.rule)
97        self.rule.setReadOnly(1)
98        self.splitter.setStretchFactor(0, 2)
99        self.splitter.setStretchFactor(1, 1)
100
101        self.resize(800,400)
102
103        self.resize(830, 400)
104
105    def sendReport(self):
106        if self.tree:
107            self.reportSettings("Information",
108                                [("Target class",self.tree.domain.classVar.values[self.targetClass]),
109                                 ("Tree size", "%i nodes, %i leaves" % (self.treeNodes, self.treeLeaves))])
110        else:
111            self.reportSettings("Information",
112                                [("Target class", "N/A"),
113                                 ("Tree size", "N/A")])
114        self.reportSection("Tree")
115        import OWReport
116        self.reportRaw(OWReport.reportTree(self.v))
117       
118    def getTreeItemSibling(self, item):
119            parent = item.parent()
120            if not parent:
121                parent = self.v.invisibleRootItem()
122            ind = parent.indexOfChild(item)
123            return parent.child(ind+1)
124
125    # main part:
126
127    def setTreeView(self, updateonly = 0):
128        f = self.precFrmt
129
130        def addNode(node, parent, desc, anew):
131            return li
132
133        def walkupdate(listviewitem):
134            node = self.nodeClassDict[listviewitem]
135            if not node: return
136            ncl = node.nodeClassifier
137            dist = node.distribution
138            a = dist.abs
139            if a < 1e-20:
140                a = 1
141            try:
142                p_majclass = f % float(dist[int(ncl.defaultVal)]/a)
143            except:
144                p_majclass = "N/A"
145            try:
146                p_tarclass = f % float(dist[self.targetClass]/a)
147            except:
148                p_tarclass = "N/A"
149           
150            colf = (str(ncl.defaultValue), 
151                    p_majclass, 
152                    p_tarclass, 
153                    "%d" % dist.cases, 
154                    len(dist) and reduce(lambda x, y: x+':'+y, [self.precFrmt % (x/a) for x in dist]) or "N/A", 
155                    len(dist) and reduce(lambda x, y: x+':'+y, ["%d" % int(x) for x in dist]) or "N/A"
156                   )
157
158            col = 1
159            for j in range(6):
160                if getattr(self, self.settingsList[j]):
161                    listviewitem.setText(col, colf[j])
162                    col += 1
163
164            for i in range(listviewitem.childCount()):
165                walkupdate(listviewitem.child(i))
166
167        def walkcreate(node, parent):
168            if not node: return
169            if node.branchSelector:
170                for i in range(len(node.branches)):
171                    if node.branches[i]:
172                        bd = node.branchDescriptions[i]
173                        if not bd[0] in ["<", ">"]:
174                            bd = node.branchSelector.classVar.name + " = " + bd
175                        else:
176                            bd = node.branchSelector.classVar.name + " " + bd
177                        li = QTreeWidgetItem(parent, [bd])
178                        li.setExpanded(1)
179                        self.nodeClassDict[li] = node.branches[i]
180                        walkcreate(node.branches[i], li)
181
182        headerItemStrings = []
183        for i in range(len(self.dataLabels)):
184            if getattr(self, self.settingsList[i]):
185                headerItemStrings.append(self.dataLabels[i][1])
186        self.v.setHeaderLabels(["Classification Tree"] + headerItemStrings)
187        self.v.setColumnCount(len(headerItemStrings)+1)
188        self.v.setRootIsDecorated(1)
189        self.v.header().setResizeMode(0, QHeaderView.Interactive)
190        for i in range(len(headerItemStrings)):
191            self.v.header().setResizeMode(1+i, QHeaderView.ResizeToContents)
192
193        if not updateonly:
194            self.v.clear()
195            self.nodeClassDict = {}
196            li = QTreeWidgetItem(self.v, ["<root>"])
197            li.setExpanded(1)
198            if self.tree:
199                self.nodeClassDict[li] = self.tree.tree
200                walkcreate(self.tree.tree, li)
201            self.rule.setText("")
202        if self.tree:
203            walkupdate(self.v.invisibleRootItem().child(0))
204        self.v.show()
205
206    # slots: handle input signals
207
208    def setClassificationTree(self, tree):
209        self.closeContext()
210        if tree and (not tree.classVar or tree.classVar.varType != orange.VarTypes.Discrete):
211            self.error("This viewer only shows trees with discrete classes.\nThere is another viewer for regression trees")
212            self.tree = None
213        else:
214            self.error()
215            self.tree = tree
216
217        self.setTreeView()
218        self.sliderChanged()
219
220        self.targetCombo.clear()
221        if tree:
222            self.treeNodes, self.treeLeaves = orngTree.countNodes(tree), orngTree.countLeaves(tree) 
223            self.infoa.setText('Number of nodes: %i' % self.treeNodes)
224            self.infob.setText('Number of leaves: %i' % self.treeLeaves)
225            self.targetCombo.addItems([name for name in tree.tree.examples.domain.classVar.values])
226            self.targetClass = 0
227            self.openContext("", tree.domain)
228        else:
229            self.treeNodes = self.treeLeaves = 0
230            self.infoa.setText('No tree on input.')
231            self.infob.setText('')
232            self.openContext("", None)
233
234    def setTarget(self):
235        def updatetarget(listviewitem):
236            dist = self.nodeClassDict[listviewitem].distribution
237            listviewitem.setText(targetindex, f % (dist[self.targetClass]/max(1, dist.abs)))
238
239            for i in range(listviewitem.childCount()):
240                updatetarget(listviewitem.child(i))
241
242        if self.ptarget:
243            targetindex = 1
244            for st in range(5):
245                if self.settingsList[st] == "ptarget":
246                    break
247                if getattr(self, self.settingsList[st]):
248                    targetindex += 1
249
250            f = self.precFrmt
251            if self.v.invisibleRootItem():
252                updatetarget(self.v.invisibleRootItem().child(0))
253
254    def expandTree(self, lev):
255        def expandTree0(listviewitem, lev):
256            if not listviewitem:
257                return
258            if not lev:
259                listviewitem.setExpanded(0)
260            else:
261                listviewitem.setExpanded(1)
262                for i in range(listviewitem.childCount()):
263                    child = listviewitem.child(i)
264                    expandTree0(child, lev-1)
265
266        expandTree0(self.v.invisibleRootItem().child(0), lev)
267
268    # signal processing
269   
270    def viewSelectionChanged(self):
271        """handles click on the tree"""
272        selected = self.v.selectedItems()
273        item = selected.pop() if selected else None 
274        self.handleSelectionChanged(item)
275        if self.tree and item:
276            data = self.nodeClassDict[item].examples
277            self.send("Data", data)
278
279            tx = ""
280            f = 1
281            nodeclsfr = self.nodeClassDict[item].nodeClassifier
282            while item and item.parent():
283                if f:
284                    tx = str(item.text(0))
285                    f = 0
286                else:
287                    tx = str(item.text(0)) + " AND\n    "+tx
288
289                item = item.parent()
290
291            classLabel = str(nodeclsfr.defaultValue)
292            className = str(nodeclsfr.classVar.name)
293            if tx:
294                self.rule.setText("IF %(tx)s\nTHEN %(className)s = %(classLabel)s" % vars())
295            else:
296                self.rule.setText("%(className)s = %(classLabel)s" % vars())
297        else:
298            self.send("Data", None)
299            self.rule.setText("")
300
301    def handleSelectionChanged(self, item):
302        pass
303
304    def sliderChanged(self):
305        self.expandTree(self.sliderValue)
306
307##############################################################################
308# Test the widget, run from DOS prompt
309# > python OWDataTable.py)
310# Make sure that a sample data set (adult_sample.tab) is in the directory
311
312if __name__=="__main__":
313    a=QApplication(sys.argv)
314    ow=OWClassificationTreeViewer()
315    #a.setMainWidget(ow)
316
317    data = orange.ExampleTable(r'../../doc/datasets/adult_sample')
318
319    tree = orange.TreeLearner(data, storeExamples = 1)
320    ow.setClassificationTree(tree)
321    ow.show()
322    a.exec_()
323    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.