source: orange/Orange/OrangeWidgets/Classify/OWClassificationTreeViewer.py @ 10552:c0028bc3f865

Revision 10552:c0028bc3f865, 11.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Replaced orange.TreeClassifier with Orange.classification.tree.TreeClassifier in channel types (needed for dynamic signals to work when connecting to Tree Graph/Viewer widgets).

Line 
1"""
2<name>Classification Tree Viewer</name>
3<description>Classification tree viewer (hierarchical list view).</description>
4<icon>icons/ClassificationTreeViewer.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>2100</priority>
7"""
8from OWWidget import *
9from orngTree import TreeLearner
10import OWGUI
11
12import orngTree
13import Orange
14
15class ColumnCallback:
16    def __init__(self, widget, attribute, f = None):
17        self.widget = widget
18        self.attribute = attribute
19        self.f = f
20        widget.callbackDeposit.append(self)
21
22    def __call__(self, value):
23        setattr(self.widget, self.attribute, self.f and self.f(value) or value)
24        self.widget.setTreeView(1)
25
26def checkColumn(widget, master, text, value):
27    wa = QCheckBox(text, widget)
28    widget.layout().addWidget(wa)
29    wa.setChecked(getattr(master, value))
30    master.connect(wa, SIGNAL("toggled(bool)"), ColumnCallback(master, value))
31    return wa
32
33class OWClassificationTreeViewer(OWWidget):
34    settingsList = ["maj", "pmaj", "ptarget", "inst", "dist", "adist", "expslider", "sliderValue"]
35    contextHandlers = {"": DomainContextHandler("", ["targetClass"], matchValues=1)}
36
37    def __init__(self, parent=None, signalManager = None, name='Classification Tree Viewer'):
38        OWWidget.__init__(self, parent, signalManager, name)
39
40        self.dataLabels = (('Majority class', 'Class'), 
41                  ('Probability of majority class', 'P(Class)'), 
42                  ('Probability of target class', 'P(Target)'), 
43                  ('Number of instances', '# Inst'), 
44                  ('Relative distribution', 'Distribution (rel)'), 
45                  ('Absolute distribution', 'Distribution (abs)'))
46
47#        self.callbackDeposit = []
48
49        self.inputs = [("Classification Tree", Orange.classification.tree.TreeClassifier, self.setClassificationTree)]
50        self.outputs = [("Data", ExampleTable)]
51
52        # Settings
53        for s in self.settingsList[:6]:
54            setattr(self, s, 1)
55        self.expslider = 5
56        self.targetClass = 0
57        self.loadSettings()
58
59        self.tree = None
60        self.sliderValue = 5
61        self.precision = 3
62        self.precFrmt = "%%2.%if" % self.precision
63
64        # GUI
65        # parameters
66
67        self.dBox = OWGUI.widgetBox(self.controlArea, 'Displayed information')
68        for i in range(len(self.dataLabels)):
69            checkColumn(self.dBox, self, self.dataLabels[i][0], self.settingsList[i])
70
71        OWGUI.separator(self.controlArea)
72               
73        self.slider = OWGUI.hSlider(self.controlArea, self, "sliderValue", box = 'Expand/shrink to level', minValue = 1, maxValue = 9, step = 1, callback = self.sliderChanged)
74
75        OWGUI.separator(self.controlArea)
76        self.targetCombo=OWGUI.comboBox(self.controlArea, self, "targetClass", items=[], box="Target class", callback=self.setTarget, addSpace=True)
77
78        self.infBox = OWGUI.widgetBox(self.controlArea, 'Tree size')
79        self.infoa = OWGUI.widgetLabel(self.infBox, 'No tree.')
80        self.infob = OWGUI.widgetLabel(self.infBox, ' ')
81
82        OWGUI.rubber(self.controlArea)
83
84        # list view
85        self.splitter = QSplitter(Qt.Vertical, self.mainArea)
86        self.mainArea.layout().addWidget(self.splitter)
87
88        self.v = QTreeWidget(self.splitter)
89        self.splitter.addWidget(self.v)
90        self.v.setAllColumnsShowFocus(1)
91        self.v.setHeaderLabels(['Classification Tree'] + [label[1] for label in self.dataLabels])
92        self.v.setColumnWidth(0, 250)
93        self.connect(self.v, SIGNAL("itemSelectionChanged()"), self.viewSelectionChanged)
94
95        # rule
96        self.rule = QTextEdit(self.splitter)
97        self.splitter.addWidget(self.rule)
98        self.rule.setReadOnly(1)
99        self.splitter.setStretchFactor(0, 2)
100        self.splitter.setStretchFactor(1, 1)
101
102        self.resize(800,400)
103
104        self.resize(830, 400)
105
106    def sendReport(self):
107        if self.tree:
108            self.reportSettings("Information",
109                                [("Target class",self.tree.domain.classVar.values[self.targetClass]),
110                                 ("Tree size", "%i nodes, %i leaves" % (self.treeNodes, self.treeLeaves))])
111        else:
112            self.reportSettings("Information",
113                                [("Target class", "N/A"),
114                                 ("Tree size", "N/A")])
115        self.reportSection("Tree")
116        import OWReport
117        self.reportRaw(OWReport.reportTree(self.v))
118       
119    def getTreeItemSibling(self, item):
120            parent = item.parent()
121            if not parent:
122                parent = self.v.invisibleRootItem()
123            ind = parent.indexOfChild(item)
124            return parent.child(ind+1)
125
126    # main part:
127
128    def setTreeView(self, updateonly = 0):
129        f = self.precFrmt
130
131        def addNode(node, parent, desc, anew):
132            return li
133
134        def walkupdate(listviewitem):
135            node = self.nodeClassDict[listviewitem]
136            if not node: return
137            ncl = node.nodeClassifier
138            dist = node.distribution
139            a = dist.abs
140            if a < 1e-20:
141                a = 1
142            try:
143                p_majclass = f % float(dist[int(ncl.defaultVal)]/a)
144            except:
145                p_majclass = "N/A"
146            try:
147                p_tarclass = f % float(dist[self.targetClass]/a)
148            except:
149                p_tarclass = "N/A"
150           
151            colf = (str(ncl.defaultValue), 
152                    p_majclass, 
153                    p_tarclass, 
154                    "%d" % dist.cases, 
155                    len(dist) and reduce(lambda x, y: x+':'+y, [self.precFrmt % (x/a) for x in dist]) or "N/A", 
156                    len(dist) and reduce(lambda x, y: x+':'+y, ["%d" % int(x) for x in dist]) or "N/A"
157                   )
158
159            col = 1
160            for j in range(6):
161                if getattr(self, self.settingsList[j]):
162                    listviewitem.setText(col, colf[j])
163                    col += 1
164
165            for i in range(listviewitem.childCount()):
166                walkupdate(listviewitem.child(i))
167
168        def walkcreate(node, parent):
169            if not node: return
170            if node.branchSelector:
171                for i in range(len(node.branches)):
172                    if node.branches[i]:
173                        bd = node.branchDescriptions[i]
174                        if not bd[0] in ["<", ">"]:
175                            bd = node.branchSelector.classVar.name + " = " + bd
176                        else:
177                            bd = node.branchSelector.classVar.name + " " + bd
178                        li = QTreeWidgetItem(parent, [bd])
179                        li.setExpanded(1)
180                        self.nodeClassDict[li] = node.branches[i]
181                        walkcreate(node.branches[i], li)
182
183        headerItemStrings = []
184        for i in range(len(self.dataLabels)):
185            if getattr(self, self.settingsList[i]):
186                headerItemStrings.append(self.dataLabels[i][1])
187        self.v.setHeaderLabels(["Classification Tree"] + headerItemStrings)
188        self.v.setColumnCount(len(headerItemStrings)+1)
189        self.v.setRootIsDecorated(1)
190        self.v.header().setResizeMode(0, QHeaderView.Interactive)
191        for i in range(len(headerItemStrings)):
192            self.v.header().setResizeMode(1+i, QHeaderView.ResizeToContents)
193
194        if not updateonly:
195            self.v.clear()
196            self.nodeClassDict = {}
197            li = QTreeWidgetItem(self.v, ["<root>"])
198            li.setExpanded(1)
199            if self.tree:
200                self.nodeClassDict[li] = self.tree.tree
201                walkcreate(self.tree.tree, li)
202            self.rule.setText("")
203        if self.tree:
204            walkupdate(self.v.invisibleRootItem().child(0))
205        self.v.show()
206
207    # slots: handle input signals
208
209    def setClassificationTree(self, tree):
210        self.closeContext()
211        if tree and (not tree.classVar or tree.classVar.varType != orange.VarTypes.Discrete):
212            self.error("This viewer only shows trees with discrete classes.\nThere is another viewer for regression trees")
213            self.tree = None
214        else:
215            self.error()
216            self.tree = tree
217
218        self.setTreeView()
219        self.sliderChanged()
220
221        self.targetCombo.clear()
222        if tree:
223            self.treeNodes, self.treeLeaves = orngTree.countNodes(tree), orngTree.countLeaves(tree) 
224            self.infoa.setText('Number of nodes: %i' % self.treeNodes)
225            self.infob.setText('Number of leaves: %i' % self.treeLeaves)
226            self.targetCombo.addItems([name for name in tree.tree.examples.domain.classVar.values])
227            self.targetClass = 0
228            self.openContext("", tree.domain)
229        else:
230            self.treeNodes = self.treeLeaves = 0
231            self.infoa.setText('No tree on input.')
232            self.infob.setText('')
233            self.openContext("", None)
234
235    def setTarget(self):
236        def updatetarget(listviewitem):
237            dist = self.nodeClassDict[listviewitem].distribution
238            listviewitem.setText(targetindex, f % (dist[self.targetClass]/max(1, dist.abs)))
239
240            for i in range(listviewitem.childCount()):
241                updatetarget(listviewitem.child(i))
242
243        if self.ptarget:
244            targetindex = 1
245            for st in range(5):
246                if self.settingsList[st] == "ptarget":
247                    break
248                if getattr(self, self.settingsList[st]):
249                    targetindex += 1
250
251            f = self.precFrmt
252            if self.v.invisibleRootItem():
253                updatetarget(self.v.invisibleRootItem().child(0))
254
255    def expandTree(self, lev):
256        def expandTree0(listviewitem, lev):
257            if not listviewitem:
258                return
259            if not lev:
260                listviewitem.setExpanded(0)
261            else:
262                listviewitem.setExpanded(1)
263                for i in range(listviewitem.childCount()):
264                    child = listviewitem.child(i)
265                    expandTree0(child, lev-1)
266
267        expandTree0(self.v.invisibleRootItem().child(0), lev)
268
269    # signal processing
270   
271    def viewSelectionChanged(self):
272        """handles click on the tree"""
273        selected = self.v.selectedItems()
274        item = selected.pop() if selected else None 
275        self.handleSelectionChanged(item)
276        if self.tree and item:
277            data = self.nodeClassDict[item].examples
278            self.send("Data", data)
279
280            tx = ""
281            f = 1
282            nodeclsfr = self.nodeClassDict[item].nodeClassifier
283            while item and item.parent():
284                if f:
285                    tx = str(item.text(0))
286                    f = 0
287                else:
288                    tx = str(item.text(0)) + " AND\n    "+tx
289
290                item = item.parent()
291
292            classLabel = str(nodeclsfr.defaultValue)
293            className = str(nodeclsfr.classVar.name)
294            if tx:
295                self.rule.setText("IF %(tx)s\nTHEN %(className)s = %(classLabel)s" % vars())
296            else:
297                self.rule.setText("%(className)s = %(classLabel)s" % vars())
298        else:
299            self.send("Data", None)
300            self.rule.setText("")
301
302    def handleSelectionChanged(self, item):
303        pass
304
305    def sliderChanged(self):
306        self.expandTree(self.sliderValue)
307
308##############################################################################
309# Test the widget, run from DOS prompt
310# > python OWDataTable.py)
311# Make sure that a sample data set (adult_sample.tab) is in the directory
312
313if __name__=="__main__":
314    a=QApplication(sys.argv)
315    ow=OWClassificationTreeViewer()
316    #a.setMainWidget(ow)
317
318    data = orange.ExampleTable(r'../../doc/datasets/adult_sample')
319
320    tree = orange.TreeLearner(data, storeExamples = 1)
321    ow.setClassificationTree(tree)
322    ow.show()
323    a.exec_()
324    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.