source: orange/orange/OrangeWidgets/Classify/OWITree.py @ 6538:a5f65d7f0b2c

Revision 6538:a5f65d7f0b2c, 9.0 KB checked in by Mitar <Mitar@…>, 4 years ago (diff)

Made XPM version of the icon 32x32.

Line 
1"""
2<name>Interactive Tree Builder</name>
3<description>Interactive Tree Builder</description>
4<icon>icons/ITree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>50</priority>
7"""
8from OWWidget import *
9from OWClassificationTreeViewer import *
10import OWGUI, sys, orngTree
11from orngDataCaching import *
12
class FixedTreeLearner(orange.Learner):
    """Learner that ignores its training data and returns a fixed classifier.

    OWITree.updateTree wraps the interactively built tree in an instance of
    this class and sends it on the widget's "Tree Learner" output.

    classifier -- the classifier returned by every call
    name       -- the learner's display name
    """

    def __init__(self, classifier, name):
        self.name = name
        self.classifier = classifier

    def __call__(self, *args):
        # Any examples / weight arguments are accepted and ignored.
        return self.classifier
20
class OWITree(OWClassificationTreeViewer):
    """Interactive Tree Builder widget.

    Extends the classification tree viewer with a "Build" tab from which the
    user grows a tree by hand: split the selected node on a chosen attribute
    (at a cut-off point for continuous attributes), cut the subtree below a
    node, or let a learner (the "Tree Learner" input, orngTree.TreeLearner by
    default) induce the subtree for a node.

    Inputs:  "Examples" (ExampleTable), "Tree Learner" (orange.Learner).
    Outputs: "Examples", "Classifier" (the tree built so far) and
             "Tree Learner" (a FixedTreeLearner that returns that tree).
    """
    settingsList = OWClassificationTreeViewer.settingsList
    contextHandlers = OWClassificationTreeViewer.contextHandlers

    def __init__(self, parent=None, signalManager=None):
        OWClassificationTreeViewer.__init__(self, parent, signalManager, 'I&nteractive Tree Builder')
        self.inputs = [("Examples", ExampleTable, self.setData), ("Tree Learner", orange.Learner, self.setLearner)]
        self.outputs = [("Examples", ExampleTable), ("Classifier", orange.TreeClassifier), ("Tree Learner", orange.Learner)]

        # Defaults, assigned before loadSettings() so saved values can override them.
        self.attridx = 0        # index of the attribute to split on
        self.cutoffPoint = 0.0  # threshold for splits on continuous attributes
        self.targetClass = 0
        self.loadSettings()

        self.data = None         # examples from the input (None if missing/unusable)
        self.treeLearner = None  # learner from the "Tree Learner" input
        self.tree = None         # orange.TreeClassifier being built
        self.learner = None      # FixedTreeLearner wrapping self.tree

        # Put a tab widget in front of the inherited control area: a new
        # "Build" tab, plus a "Display" tab created from the old control area.
        new_controlArea = OWGUI.widgetBox(self.leftWidgetPart, orientation="vertical", margin=4, addToLayout=False)
        self.leftWidgetPart.layout().insertWidget(0, new_controlArea)
        self.leftWidgetPart.layout().removeWidget(self.controlArea)

        tabWidget = OWGUI.tabWidget(new_controlArea)
        buildTab = OWGUI.createTabPage(tabWidget, "Build")

        self.old_controlArea = self.controlArea
        OWGUI.createTabPage(tabWidget, "Display", self.controlArea)
        self.controlArea = new_controlArea

        # Move the inherited info box into the "Build" tab.
        self.old_controlArea.layout().removeWidget(self.infBox)
        buildTab.layout().insertWidget(0, self.infBox)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Split selection")
        self.attrsCombo = OWGUI.comboBox(box, self, 'attridx', orientation="horizontal", callback=self.cbAttributeSelected)
        self.cutoffEdit = OWGUI.lineEdit(box, self, 'cutoffPoint', label='Cut off point: ', orientation='horizontal', validator=QDoubleValidator(self))
        OWGUI.button(box, self, "Split", callback=self.btnSplitClicked)

        OWGUI.separator(buildTab)
        box = OWGUI.widgetBox(buildTab, "Prune or grow tree")
        self.btnPrune = OWGUI.button(box, self, "Cut", callback=self.btnPruneClicked, disabled=1)
        self.btnBuild = OWGUI.button(box, self, "Build", callback=self.btnBuildClicked)

        OWGUI.rubber(buildTab)

        self.activateLoadedSettings()

    def sendReport(self):
        """Report the input data and the tree, via the viewer's report."""
        self.reportData(self.data)
        # presumably read by the parent's sendReport -- confirm in OWClassificationTreeViewer
        self.treeNodes, self.treeLeaves = orngTree.countNodes(self.tree), orngTree.countLeaves(self.tree)
        super(OWITree, self).sendReport()

    def cbAttributeSelected(self):
        """Refresh the cut-off edit after the split attribute changes.

        For a continuous attribute the edit is enabled and pre-filled with the
        attribute's average; otherwise (or without data) it is cleared and
        disabled.
        """
        val = ""
        if self.data:
            attr = self.data.domain[self.attridx]
            if attr.varType == orange.VarTypes.Continuous:
                val = str(orange.Value(attr, self.basstat[attr].avg))
        self.cutoffEdit.setDisabled(not val)
        self.cutoffEdit.setText(val)

    def activateLoadedSettings(self):
        """Sync the cut-off edit with the (possibly loaded) attribute index."""
        self.cbAttributeSelected()

    def updateTree(self):
        """Redraw the tree view, refresh the node/leaf counts and send the
        current tree on the "Classifier" and "Tree Learner" outputs."""
        self.setTreeView()
        self.learner = FixedTreeLearner(self.tree, self.captionTitle)
        self.infoa.setText("Number of nodes: %i" % orngTree.countNodes(self.tree))
        self.infob.setText("Number of leaves: %i" % orngTree.countLeaves(self.tree))
        self.send("Classifier", self.tree)
        self.send("Tree Learner", self.learner)

    def newTreeNode(self, data):
        """Create a leaf orange.TreeNode for `data`.

        The node stores the examples, their domain contingency and class
        distribution, and a node classifier trained on `data` by the input
        learner's `nodeLearner` (orange.MajorityLearner when there is none).
        """
        node = orange.TreeNode()
        node.examples = data
        node.contingency = orange.DomainContingency(data)
        node.distribution = node.contingency.classes
        nodeLearner = self.treeLearner and getattr(self.treeLearner, "nodeLearner", None) or orange.MajorityLearner()
        node.nodeClassifier = nodeLearner(data)
        return node

    def cutNode(self, node):
        """Turn `node` into a leaf by dropping its branches (no-op for None)."""
        if not node:
            return
        node.branchDescriptions = node.branchSelector = node.branchSizes = node.branches = None

    def findCurrentNode(self, exhaustively=0):
        """Return the tree node for the item selected in the view, or None.

        Without a selection, falls back to the view's current item, or to the
        single top-level item when there is exactly one. The fallback item is
        accepted only if it has no children in the view.

        `exhaustively` is kept for backward compatibility but currently has no
        effect: the fallback is always applied.
        """
        sitems = self.v.selectedItems()
        sitem = sitems[0] if sitems else None
        if not sitem:
            root = self.v.invisibleRootItem()
            sitem = self.v.currentItem() or (root.childCount() == 1 and root.child(0))
            if not sitem or sitem.childCount():
                return None
        return self.nodeClassDict.get(sitem, None)

    def btnSplitClicked(self):
        """Split the current node on the attribute chosen in the combo box.

        A continuous attribute is split in two at the cut-off point; a
        discrete one gets a branch per value. Examples are distributed among
        the branches by the input learner's `splitter`, or by
        orange.TreeExampleSplitter_IgnoreUnknowns when there is none.
        """
        if not self.data:
            return
        node = self.findCurrentNode(1)
        if not node:
            return

        attr = self.data.domain[self.attridx]
        if attr.varType == orange.VarTypes.Continuous:
            cutstr = str(self.cutoffEdit.text())
            if not cutstr:
                return
            try:
                cutoff = float(cutstr)
            except ValueError:
                # The validator lets the edit hold intermediate text such as "-".
                return

            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=self.data.domain, classVar=attr)
            node.branchSelector.transformer = orange.ThresholdDiscretizer(threshold=cutoff)
            node.branchDescriptions = ["<%5.3f" % cutoff, ">=%5.3f" % cutoff]

            # Two-valued variable computed from the attribute via the selector,
            # used to count examples per branch.
            cutvar = orange.EnumVariable(node.examples.domain[self.attridx].name, values=node.branchDescriptions)
            cutvar.getValueFrom = node.branchSelector
            node.branchSizes = orange.Distribution(cutvar, node.examples)
            node.branchSelector.classVar = cutvar

        else:
            node.branchSelector = orange.ClassifierFromVarFD(position=self.attridx, domain=self.data.domain, classVar=attr)
            node.branchDescriptions = node.branchSelector.classVar.values
            node.branchSizes = orange.Distribution(attr, node.examples)

        splitter = self.treeLearner and getattr(self.treeLearner, "splitter", None) or orange.TreeExampleSplitter_IgnoreUnknowns()
        # Empty subsets become None (no subtree for that branch).
        node.branches = [subset and self.newTreeNode(subset) or None for subset in splitter(node, node.examples)[0]]
        self.updateTree()

    def btnPruneClicked(self):
        """Remove the subtree below the current node."""
        node = self.findCurrentNode()
        if node:
            self.cutNode(node)
            self.updateTree()

    def btnBuildClicked(self):
        """Replace the subtree below the current node with one induced from
        the node's examples by the input learner (or orngTree.TreeLearner).

        Shows an error box and leaves the tree unchanged if the learner fails
        or does not produce a tree.
        """
        node = self.findCurrentNode()
        if not node or not len(node.examples):
            return

        try:
            newtree = (self.treeLearner or orngTree.TreeLearner(storeExamples=1))(node.examples)
        except Exception:
            QMessageBox.critical(None, "Tree Learner Failed", "The learner could not build a tree: %s" % sys.exc_info()[1], QMessageBox.Ok)
            return

        if not hasattr(newtree, "tree"):
            QMessageBox.critical(None, "Invalid Learner", "The learner on the input built a classifier which is not a tree.", QMessageBox.Ok)
            return

        # Copy the attributes of the new tree's root node onto the current node.
        for k, v in newtree.tree.__dict__.items():
            node.setattr(k, v)
        self.updateTree()

    def setData(self, data):
        """Handle the "Examples" input.

        Data whose domain checksum equals the current one is ignored, so the
        tree built so far is kept. Otherwise a new single-node tree is started
        from the data; if the data is missing or rejected by
        isDataWithClass(..., Discrete), the tree is cleared and None is sent
        on the "Classifier" and "Tree Learner" outputs.
        """
        # Check before closeContext(): returning after closing it would leave
        # the context closed with nothing reopening it.
        if self.data and data and data.domain.checksum() == self.data.domain.checksum():
            return

        self.closeContext()
        self.attrsCombo.clear()
        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete) and data or None

        self.targetCombo.clear()
        if self.data:
            self.attrsCombo.addItems([str(a) for a in data.domain.attributes])
            self.basstat = getCached(data, orange.DomainBasicAttrStat, (data,))

            self.attridx = 0
            self.cbAttributeSelected()
            self.tree = orange.TreeClassifier(domain=data.domain)
            self.tree.descender = orange.TreeDescender_UnknownMergeAsBranchSizes()
            self.tree.tree = self.newTreeNode(self.data)
            self.targetCombo.addItems(list(self.tree.tree.examples.domain.classVar.values))
            self.targetClass = 0
            self.openContext("", self.tree.domain)

            self.send("Examples", None)
            self.updateTree()
            # Select the root so Split/Cut/Build act on it by default.
            root = self.v.invisibleRootItem()
            if root.childCount():
                root.child(0).setSelected(True)
        else:
            self.tree = None
            self.learner = None
            self.cbAttributeSelected()  # no data: clears and disables the cut-off edit
            self.openContext("", None)
            self.send("Examples", None)
            self.setTreeView()
            self.infoa.setText("No tree.")
            self.infob.setText("")
            self.send("Classifier", None)
            self.send("Tree Learner", None)

    def setLearner(self, learner):
        """Handle the "Tree Learner" input.

        The learner is only used by later Split/Build actions and new nodes;
        the current tree is not rebuilt.
        """
        self.treeLearner = learner

    def handleSelectionChanged(self, item):
        """Called when a new node in the tree is selected: enable "Cut" only
        for nodes that have a branch selector (i.e. have been split)."""
        if item in self.nodeClassDict:
            self.btnPrune.setEnabled(self.nodeClassDict[item].branchSelector is not None)
217       
218
219
if __name__ == "__main__":
    # Manual smoke test: open the widget on the iris data set.
    app = QApplication(sys.argv)
    widget = OWITree()
    widget.setData(orange.ExampleTable(r'../../doc/datasets/iris'))
    widget.show()
    app.exec_()
Note: See TracBrowser for help on using the repository browser.