source: orange/Orange/OrangeWidgets/Classify/OWITree.py @ 10743:626f2fdde58f

Revision 10743:626f2fdde58f, 9.2 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Changed the remaining orange.TreeClassifier signal types to Orange.classification.tree.TreeClassifier.

Line 
1"""
2<name>Interactive Tree Builder</name>
3<description>Interactive Tree Builder</description>
4<icon>icons/ITree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>50</priority>
7"""
8from OWWidget import *
9from OWClassificationTreeViewer import *
10import OWGUI, sys, orngTree
11from orngDataCaching import *
12
13from OWItemModels import VariableListModel
14
15import Orange
16
class FixedTreeLearner(orange.Learner):
    """Pseudo-learner that always returns one pre-built classifier.

    The interactive builder wraps the tree it is editing in an instance of
    this class and sends it on its "Tree Learner" output. Calling the
    instance ignores every argument (including any training data) and hands
    back the stored classifier unchanged.
    """

    def __init__(self, classifier, name):
        # classifier: the ready-made classifier handed out on every call
        # name: label for the learner (the widget passes its caption title)
        self.classifier, self.name = classifier, name

    def __call__(self, *args):
        # Nothing is learned; whatever was passed in is deliberately unused.
        return self.classifier
24
class OWITree(OWClassificationTreeViewer):
    """Interactive tree builder widget.

    Extends the classification tree viewer with a "Build" tab. On that tab
    the user can:

    - split the selected leaf on a chosen attribute (with a cut-off point
      for continuous attributes);
    - cut away a node's subtree;
    - grow a subtree with the learner received on the "Tree Learner" input
      (an orngTree.TreeLearner is used when none is connected).

    The edited tree is sent on the "Classifier" output. A FixedTreeLearner
    wrapping it is sent on the "Tree Learner" output.
    """
    settingsList = OWClassificationTreeViewer.settingsList
    contextHandlers = OWClassificationTreeViewer.contextHandlers

    def __init__(self, parent=None, signalManager=None):
        OWClassificationTreeViewer.__init__(self, parent, signalManager, 'I&nteractive Tree Builder')
        self.inputs = [("Data", ExampleTable, self.setData),
                       ("Tree Learner", orange.Learner, self.setLearner)]

        self.outputs = [("Data", ExampleTable),
                        ("Classifier", Orange.classification.tree.TreeClassifier),
                        ("Tree Learner", orange.Learner)]

        # Defaults, possibly overwritten by loadSettings().
        self.attridx = 0
        self.cutoffPoint = 0.0
        self.targetClass = 0
        self.loadSettings()

        self.data = None         # input data (None when absent or rejected)
        self.treeLearner = None  # learner from the "Tree Learner" input
        self.tree = None         # the orange.TreeClassifier being edited
        self.learner = None      # FixedTreeLearner wrapping self.tree

        # Insert a new control area (holding a tab widget) in front of the
        # inherited one. The inherited control area becomes the "Display" tab.
        new_controlArea = OWGUI.widgetBox(self.leftWidgetPart, orientation="vertical", margin=4, addToLayout=False)
        self.leftWidgetPart.layout().insertWidget(0, new_controlArea)
        self.leftWidgetPart.layout().removeWidget(self.controlArea)

        tabWidget = OWGUI.tabWidget(new_controlArea)
        buildTab = OWGUI.createTabPage(tabWidget, "Build")

        self.old_controlArea = self.controlArea
        OWGUI.createTabPage(tabWidget, "Display", self.controlArea)
        self.controlArea = new_controlArea

        # Move the viewer's info box to the top of the "Build" tab.
        self.old_controlArea.layout().removeWidget(self.infBox)
        buildTab.layout().insertWidget(0, self.infBox)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Split selection")
        self.attrsCombo = OWGUI.comboBox(box, self, 'attridx', orientation="horizontal", callback=self.cbAttributeSelected)
        self.cutoffEdit = OWGUI.lineEdit(box, self, 'cutoffPoint', label='Cut off point: ', orientation='horizontal', validator=QDoubleValidator(self))
        OWGUI.button(box, self, "Split", callback=self.btnSplitClicked)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Prune or grow tree")
        self.btnPrune = OWGUI.button(box, self, "Cut", callback=self.btnPruneClicked, disabled=1)
        self.btnBuild = OWGUI.button(box, self, "Build", callback=self.btnBuildClicked)

        OWGUI.rubber(buildTab)

        self.activateLoadedSettings()

    def sendReport(self):
        """Report the input data and the tree's node and leaf counts.

        After setting those, defer to the viewer's report.
        """
        self.reportData(self.data)
        self.treeNodes, self.treeLeaves = orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree)
        super(OWITree, self).sendReport()

    def cbAttributeSelected(self):
        """Update the cut-off edit box for the attribute chosen in the combo.

        For a continuous attribute the box is enabled and prefilled with the
        attribute's average (from self.basstat). Otherwise it is cleared and
        disabled.
        """
        val = ""
        if self.data:
            attr = self.data.domain[self.attridx]
            if attr.varType == orange.VarTypes.Continuous:
                val = str(orange.Value(attr, self.basstat[attr].avg))
        self.cutoffEdit.setDisabled(not val)
        self.cutoffEdit.setText(val)

    def activateLoadedSettings(self):
        """Sync the cut-off edit box with the loaded attribute index."""
        self.cbAttributeSelected()

    def updateTree(self):
        """Redraw the tree, refresh the info labels and send the outputs.

        Sends self.tree on "Classifier" and a FixedTreeLearner wrapping it
        on "Tree Learner".
        """
        self.setTreeView()
        self.learner = FixedTreeLearner(self.tree, self.captionTitle)
        self.infoa.setText("Number of nodes: %i" % orngTree.countNodes(self.tree))
        self.infob.setText("Number of leaves: %i" % orngTree.countLeaves(self.tree))
        self.send("Classifier", self.tree)
        self.send("Tree Learner", self.learner)

    def newTreeNode(self, data):
        """Create a leaf node for the examples in *data*.

        The node stores the examples, their domain contingency and the class
        distribution. Its node classifier is trained with the input learner's
        ``nodeLearner`` if it has one, otherwise with a MajorityLearner.
        """
        node = orange.TreeNode()
        node.examples = data
        node.contingency = orange.DomainContingency(data)
        node.distribution = node.contingency.classes
        nodeLearner = self.treeLearner and getattr(self.treeLearner, "nodeLearner", None) or orange.MajorityLearner()
        node.nodeClassifier = nodeLearner(data)
        return node

    def cutNode(self, node):
        """Turn *node* into a leaf by clearing its branch attributes.

        Does nothing if *node* is falsy.
        """
        if not node:
            return
        node.branchDescriptions = node.branchSelector = node.branchSizes = node.branches = None

    def findCurrentNode(self, exhaustively=0):
        """Return the tree node of the item selected in the view, or None.

        If nothing is selected, fall back to the view's current item, or to
        the root item when it is the only top-level item. A fallback item is
        accepted only if it has no children, i.e. it is a leaf.

        *exhaustively* is accepted for compatibility but has no effect. The
        fallback has always been applied unconditionally.
        """
        sitems = self.v.selectedItems()
        sitem = sitems != [] and sitems[0] or None
        if not sitem:
            root = self.v.invisibleRootItem()
            sitem = self.v.currentItem() or (root.childCount() == 1 and root.child(0))
            if not sitem or sitem.childCount():
                return
        return sitem and self.nodeClassDict.get(sitem, None)

    def btnSplitClicked(self):
        """Split the current node on the attribute chosen in the combo box.

        For a continuous attribute the split is binary, at the cut-off point
        typed in the edit box. A discrete attribute gets one branch per
        value. The node's examples are distributed over new leaf nodes.
        """
        node = self.findCurrentNode(1)
        if not node:
            return

        attr = self.data.domain[self.attridx]
        if attr.varType == orange.VarTypes.Continuous:
            cutstr = str(self.cutoffEdit.text())
            if not cutstr:
                return
            try:
                cutoff = float(cutstr)
            except ValueError:
                # The validator allows intermediate input such as "-" or
                # "1e" that is not yet a number; ignore the click.
                return

            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=self.data.domain, classVar=attr)
            node.branchSelector.transformer = orange.ThresholdDiscretizer(threshold=cutoff)
            node.branchDescriptions = ["<%5.3f" % cutoff, ">=%5.3f" % cutoff]

            # Two-valued variable computed by the selector, used to count
            # the examples that fall into each branch.
            cutvar = orange.EnumVariable(node.examples.domain[self.attridx].name, values=node.branchDescriptions)
            cutvar.getValueFrom = node.branchSelector
            node.branchSizes = orange.Distribution(cutvar, node.examples)
            node.branchSelector.classVar = cutvar

        else:
            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=self.data.domain, classVar=attr)
            node.branchDescriptions = node.branchSelector.classVar.values
            node.branchSizes = orange.Distribution(attr, node.examples)

        splitter = self.treeLearner and getattr(self.treeLearner, "splitter", None) or orange.TreeExampleSplitter_IgnoreUnknowns()
        # Empty subsets get no node (None branch).
        node.branches = [subset and self.newTreeNode(subset) or None for subset in splitter(node, node.examples)[0]]
        self.updateTree()

    def btnPruneClicked(self):
        """Cut away the subtree below the current node, if there is one."""
        node = self.findCurrentNode()
        if node:
            self.cutNode(node)
            self.updateTree()

    def btnBuildClicked(self):
        """Grow a subtree from the current node using the input learner.

        The learner from the "Tree Learner" input is trained on the node's
        examples; without an input learner, an orngTree.TreeLearner with
        storeExamples=1 is used. The attributes in the resulting tree's
        ``tree.__dict__`` are then copied into the node with node.setattr.
        A message box is shown when training fails or the result is not a
        tree.
        """
        node = self.findCurrentNode()
        if not node or not len(node.examples):
            return

        try:
            newtree = (self.treeLearner or orngTree.TreeLearner(storeExamples=1))(node.examples)
        except Exception as err:
            QMessageBox.critical(None, "Build Failed", "The learner failed to build a tree:\n%s" % (err,), QMessageBox.Ok)
            return

        if not hasattr(newtree, "tree"):
            QMessageBox.critical(None, "Invalid Learner", "The learner on the input built a classifier which is not a tree.", QMessageBox.Ok)
            # Nothing to graft; continuing would fail on newtree.tree.
            return

        for k, v in newtree.tree.__dict__.items():
            node.setattr(k, v)
        self.updateTree()

    def setData(self, data):
        """Handle the "Data" input by starting a new one-node tree.

        New data whose domain checksum equals the current data's is ignored,
        so the tree built so far is kept. If isDataWithClass(data, Discrete)
        is false, self.data becomes None and the tree is cleared.
        """
        # Check before closing the context: returning early after
        # closeContext() would leave the widget without an open context.
        if self.data and data and data.domain.checksum() == self.data.domain.checksum():
            return
        self.closeContext()

        self.attrsCombo.clear()
        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete) and data or None

        self.targetCombo.clear()
        if self.data:
            model = VariableListModel(list(data.domain.attributes))
            self.attrsCombo.setModel(model)
            self.basstat = getCached(data, orange.DomainBasicAttrStat, (data,))

            self.attridx = 0
            self.cbAttributeSelected()
            self.tree = orange.TreeClassifier(domain=data.domain)
            self.tree.descender = orange.TreeDescender_UnknownMergeAsBranchSizes()
            self.tree.tree = self.newTreeNode(self.data)
            # Fill the target class combo with the class values.
            self.targetCombo.addItems([name for name in self.tree.tree.examples.domain.classVar.values])
            self.targetClass = 0
            self.openContext("", self.tree.domain)
        else:
            self.tree = None
            self.infoa.setText("No tree.")
            self.infob.setText("")
            self.send("Classifier", self.tree)
            self.send("Tree Learner", self.learner)
            self.openContext("", None)

        self.send("Data", None)
        self.updateTree()
        # Select the root node if the view shows one. Without a tree the
        # view may be empty, in which case child(0) would return None.
        root = self.v.invisibleRootItem()
        if root.childCount():
            root.child(0).setSelected(1)

    def setLearner(self, learner):
        """Handle the "Tree Learner" input. *learner* may be None."""
        self.treeLearner = learner

    def handleSelectionChanged(self, item):
        """Called when a new node in the tree is selected.

        Enables the "Cut" button only if the node has a branch selector,
        i.e. it is not a leaf.
        """
        if item in self.nodeClassDict:
            self.btnPrune.setEnabled(self.nodeClassDict[item].branchSelector is not None)
227
228
if __name__ == "__main__":
    # Manual check: open the builder on the iris data set.
    app = QApplication(sys.argv)
    widget = OWITree()
    widget.setData(orange.ExampleTable(r'../../doc/datasets/iris'))
    widget.show()
    app.exec_()
Note: See TracBrowser for help on using the repository browser.