source: orange/Orange/OrangeWidgets/Classify/OWITree.py @ 10457:5e5d93460ecb

Revision 10457:5e5d93460ecb, 9.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Using VariableListModel for displaying features in 'Split Selection' combo box (fixes #1117).

Line 
1"""
2<name>Interactive Tree Builder</name>
3<description>Interactive Tree Builder</description>
4<icon>icons/ITree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>50</priority>
7"""
8from OWWidget import *
9from OWClassificationTreeViewer import *
10import OWGUI, sys, orngTree
11from orngDataCaching import *
12
13from OWItemModels import VariableListModel
14
class FixedTreeLearner(orange.Learner):
    """Learner that ignores whatever it is called with and returns a fixed classifier.

    The widget wraps its hand-built tree in one of these so that the tree
    can be sent on the "Tree Learner" output.
    """

    def __init__(self, classifier, name):
        # Keep the display name and the prebuilt classifier to hand out later.
        self.name = name
        self.classifier = classifier

    def __call__(self, *_ignored):
        # Training data, weights, etc. are deliberately not used.
        return self.classifier
22
class OWITree(OWClassificationTreeViewer):
    """Interactive tree builder widget.

    Extends the classification tree viewer with a "Build" tab. From that tab
    the user can:
      - split the selected node on a chosen attribute, using a cut-off point
        for continuous attributes;
      - cut a subtree back to a leaf;
      - let a tree learner grow the subtree below the selected node.

    Inputs:
      "Data" -- must have a discrete class.
      "Tree Learner" -- optional. Its ``nodeLearner`` and ``splitter``
        attributes, when present, are reused for manual splits, and the
        learner itself grows subtrees on "Build".

    Outputs:
      "Classifier" -- the tree built so far.
      "Tree Learner" -- a learner that always returns that tree.
    """
    settingsList = OWClassificationTreeViewer.settingsList
    contextHandlers = OWClassificationTreeViewer.contextHandlers

    def __init__(self, parent=None, signalManager=None):
        OWClassificationTreeViewer.__init__(self, parent, signalManager, 'I&nteractive Tree Builder')
        self.inputs = [("Data", ExampleTable, self.setData), ("Tree Learner", orange.Learner, self.setLearner)]
        self.outputs = [("Data", ExampleTable), ("Classifier", orange.TreeClassifier), ("Tree Learner", orange.Learner)]

        # Defaults, possibly overridden by the stored settings.
        self.attridx = 0
        self.cutoffPoint = 0.0
        self.targetClass = 0
        self.loadSettings()

        self.data = None
        self.treeLearner = None
        self.tree = None
        self.learner = None

        # Replace the viewer's control area with a tabbed one. The new
        # "Build" tab holds the editing controls, and the viewer's original
        # control area becomes the "Display" tab.
        new_controlArea = OWGUI.widgetBox(self.leftWidgetPart, orientation="vertical", margin=4, addToLayout=False)
        self.leftWidgetPart.layout().insertWidget(0, new_controlArea)
        self.leftWidgetPart.layout().removeWidget(self.controlArea)

        tabWidget = OWGUI.tabWidget(new_controlArea)
        buildTab = OWGUI.createTabPage(tabWidget, "Build")

        self.old_controlArea = self.controlArea
        OWGUI.createTabPage(tabWidget, "Display", self.controlArea)
        self.controlArea = new_controlArea

        # Move the viewer's info box to the top of the "Build" tab.
        self.old_controlArea.layout().removeWidget(self.infBox)
        buildTab.layout().insertWidget(0, self.infBox)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Split selection")
        self.attrsCombo = OWGUI.comboBox(box, self, 'attridx', orientation="horizontal", callback=self.cbAttributeSelected)
        self.cutoffEdit = OWGUI.lineEdit(box, self, 'cutoffPoint', label='Cut off point: ', orientation='horizontal', validator=QDoubleValidator(self))
        OWGUI.button(box, self, "Split", callback=self.btnSplitClicked)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Prune or grow tree")
        self.btnPrune = OWGUI.button(box, self, "Cut", callback=self.btnPruneClicked, disabled=1)
        self.btnBuild = OWGUI.button(box, self, "Build", callback=self.btnBuildClicked)

        OWGUI.rubber(buildTab)

        self.activateLoadedSettings()

    def sendReport(self):
        """Report the input data and the size of the current tree."""
        self.reportData(self.data)
        self.treeNodes, self.treeLeaves = orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree)
        super(OWITree, self).sendReport()

    def _validAttrIndex(self):
        """Return True if ``attridx`` points at an attribute of the current data."""
        return bool(self.data) and 0 <= self.attridx < len(self.data.domain.attributes)

    def cbAttributeSelected(self):
        """Refresh the cut-off edit for the attribute chosen in the combo box.

        For a continuous attribute, the edit is enabled and pre-filled with
        the attribute's mean. Otherwise, including when there is no data,
        the edit is cleared and disabled.
        """
        val = ""
        if self._validAttrIndex():
            attr = self.data.domain[self.attridx]
            if attr.varType == orange.VarTypes.Continuous:
                val = str(orange.Value(attr, self.basstat[attr].avg))
        self.cutoffEdit.setDisabled(not val)
        self.cutoffEdit.setText(val)

    def activateLoadedSettings(self):
        """Sync the cut-off edit with the (possibly loaded) attribute index."""
        self.cbAttributeSelected()

    def updateTree(self):
        """Redraw the tree, refresh the node/leaf counts and send both outputs."""
        self.setTreeView()
        self.learner = FixedTreeLearner(self.tree, self.captionTitle)
        self.infoa.setText("Number of nodes: %i" % orngTree.countNodes(self.tree))
        self.infob.setText("Number of leaves: %i" % orngTree.countLeaves(self.tree))
        self.send("Classifier", self.tree)
        self.send("Tree Learner", self.learner)

    def _learnerAttr(self, name):
        """Return attribute ``name`` of the input tree learner, or None if absent."""
        if self.treeLearner:
            return getattr(self.treeLearner, name, None)
        return None

    def newTreeNode(self, data):
        """Create a leaf node that stores ``data`` and predicts from it.

        The node's classifier is trained by the input learner's
        ``nodeLearner`` when it has one, and by ``orange.MajorityLearner``
        otherwise.
        """
        node = orange.TreeNode()
        node.examples = data
        node.contingency = orange.DomainContingency(data)
        node.distribution = node.contingency.classes
        nodeLearner = self._learnerAttr("nodeLearner") or orange.MajorityLearner()
        node.nodeClassifier = nodeLearner(data)
        return node

    def cutNode(self, node):
        """Turn ``node`` into a leaf by dropping its split and all its branches."""
        if not node:
            return
        node.branchDescriptions = node.branchSelector = node.branchSizes = node.branches = None

    def findCurrentNode(self, exhaustively=0):
        """Return the tree node for the selected item, or None.

        If nothing is selected, this falls back to the current item, or to
        the root when it is the only top-level item. A fallback item is
        accepted only if it is a leaf in the view.

        ``exhaustively`` is accepted for compatibility but has no effect:
        the fallback is always used.
        """
        selected = self.v.selectedItems()
        item = selected[0] if selected else None
        if not item:
            item = self.v.currentItem()
            if not item:
                root = self.v.invisibleRootItem()
                item = root.child(0) if root.childCount() == 1 else None
            if not item or item.childCount():
                return None
        return self.nodeClassDict.get(item, None)

    def btnSplitClicked(self):
        """Split the current node on the attribute selected in the combo box.

        A continuous attribute gives two branches, split at the cut-off point
        typed in the line edit. A discrete attribute gives one branch per
        value. Each non-empty subset becomes a new leaf; empty subsets get
        no node.
        """
        node = self.findCurrentNode(1)
        if not node or not self._validAttrIndex() or getattr(node, "examples", None) is None:
            return

        domain = self.data.domain
        attr = domain[self.attridx]
        if attr.varType == orange.VarTypes.Continuous:
            try:
                cutoff = float(str(self.cutoffEdit.text()))
            except ValueError:
                # The text is empty or not (yet) a complete number.
                return

            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=domain, classVar=attr)
            node.branchSelector.transformer = orange.ThresholdDiscretizer(threshold=cutoff)
            node.branchDescriptions = ["<%5.3f" % cutoff, ">=%5.3f" % cutoff]

            # Two-valued variable whose value comes from the branch selector.
            # Its distribution over the node's examples gives the branch sizes.
            cutvar = orange.EnumVariable(node.examples.domain[self.attridx].name, values=node.branchDescriptions)
            cutvar.getValueFrom = node.branchSelector
            node.branchSizes = orange.Distribution(cutvar, node.examples)
            node.branchSelector.classVar = cutvar
        else:
            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=domain, classVar=attr)
            node.branchDescriptions = node.branchSelector.classVar.values
            node.branchSizes = orange.Distribution(attr, node.examples)

        splitter = self._learnerAttr("splitter") or orange.TreeExampleSplitter_IgnoreUnknowns()
        subsets = splitter(node, node.examples)[0]
        node.branches = [self.newTreeNode(subset) if subset else None for subset in subsets]
        self.updateTree()

    def btnPruneClicked(self):
        """Cut the subtree below the current node."""
        node = self.findCurrentNode()
        if node:
            self.cutNode(node)
            self.updateTree()

    def btnBuildClicked(self):
        """Grow a subtree below the current node with a tree learner.

        The input tree learner is used, or ``orngTree.TreeLearner`` (storing
        examples) when there is none. It is trained on the node's examples,
        and the root of the resulting tree is copied into the current node.
        A failing learner, or one that does not build a tree, is reported in
        a message box.
        """
        node = self.findCurrentNode()
        if not node:
            return
        examples = getattr(node, "examples", None)
        if examples is None or not len(examples):
            return

        try:
            newtree = (self.treeLearner or orngTree.TreeLearner(storeExamples=1))(examples)
        except Exception as err:
            QMessageBox.critical(None, "Build Failed", "The learner failed to build a tree:\n%s" % err, QMessageBox.Ok)
            return

        if not hasattr(newtree, "tree"):
            QMessageBox.critical(None, "Invalid Learner", "The learner on the input built a classifier which is not a tree.", QMessageBox.Ok)
            return

        # NOTE(review): this relies on Orange objects listing their properties
        # in __dict__ and providing a setattr() method -- confirm.
        for k, v in newtree.tree.__dict__.items():
            node.setattr(k, v)
        self.updateTree()

    def setData(self, data):
        """Handle the "Data" input by starting a new tree with a single leaf.

        Data without a discrete class is treated as no data. Data with the
        same domain as the current data is ignored, so the tree built so far
        is kept.
        """
        if self.data and data and data.domain.checksum() == self.data.domain.checksum():
            # Return before closeContext(), so the current context stays open.
            return
        self.closeContext()

        self.attrsCombo.clear()
        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete) and data or None

        self.targetCombo.clear()
        if self.data:
            self.attrsCombo.setModel(VariableListModel(list(data.domain.attributes)))
            self.basstat = getCached(data, orange.DomainBasicAttrStat, (data,))
            self.attridx = 0
            self.cbAttributeSelected()
            self.tree = orange.TreeClassifier(domain=data.domain)
            self.tree.descender = orange.TreeDescender_UnknownMergeAsBranchSizes()
            self.tree.tree = self.newTreeNode(self.data)
            # Offer the class values as target classes.
            self.targetCombo.addItems(list(self.tree.tree.examples.domain.classVar.values))
            self.targetClass = 0
            self.openContext("", self.tree.domain)
        else:
            self.tree = None
            self.learner = None
            # Clear and disable the cut-off edit.
            self.cbAttributeSelected()
            self.openContext("", None)

        self.send("Data", None)
        if self.tree is not None:
            self.updateTree()
            root = self.v.invisibleRootItem()
            if root.childCount():
                root.child(0).setSelected(1)
        else:
            self.setTreeView()
            self.infoa.setText("No tree.")
            self.infob.setText("")
            self.send("Classifier", None)
            self.send("Tree Learner", None)

    def setLearner(self, learner):
        """Handle the "Tree Learner" input; later splits and builds will use it."""
        self.treeLearner = learner

    def handleSelectionChanged(self, item):
        """Called when a new node in the tree is selected.

        Enables the "Cut" button only for nodes that have a split.
        """
        if item in self.nodeClassDict:
            self.btnPrune.setEnabled(self.nodeClassDict[item].branchSelector is not None)
213    def setLearner(self, learner):
214        self.treeLearner = learner
215
216    def handleSelectionChanged(self, item):
217        """called when new node in the tree is selected"""
218        if self.nodeClassDict.has_key(item):
219            self.btnPrune.setEnabled(self.nodeClassDict[item].branchSelector <> None)
220       
221
222
223if __name__ == "__main__":
224    a=QApplication(sys.argv)
225    owi=OWITree()
226#    a.setMainWidget(owi)
227    d = orange.ExampleTable(r'../../doc/datasets/iris')
228    owi.setData(d)
229    owi.show()
230    a.exec_()
Note: See TracBrowser for help on using the repository browser.