source: orange/Orange/OrangeWidgets/Classify/OWClassificationTree.py @ 10552:c0028bc3f865

Revision 10552:c0028bc3f865, 8.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Replaced orange.TreeClassifier with Orange.classification.tree.TreeClassifier in channel types (needed for dynamic signals to work when connecting to Tree Graph/Viewer widgets).

Line 
1"""
2<name>Classification Tree</name>
3<description>Classification tree learner/classifier.</description>
4<icon>icons/ClassificationTree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>30</priority>
7"""
8from OWWidget import *
9import orngTree, OWGUI
10from orngWrap import PreprocessedLearner
11from exceptions import Exception
12
13import Orange
14
15import warnings
# Ignore the KernelWarning whose message matches "...this class is not
# optimized for 'candidates' list and can be very slow...", but only when it is
# issued from line 34 of a module whose name matches ".*orngTree".
# NOTE(review): the lineno pin means this filter silently stops matching if
# the warning ever moves to another line in orngTree.
warnings.filterwarnings(
    "ignore",
    message=".*this class is not optimized for 'candidates' list and can be very slow.*",
    category=orange.KernelWarning,
    module=".*orngTree",
    lineno=34,
)
17
18
19
class OWClassificationTree(OWWidget):
    """Widget that configures an orngTree.TreeLearner and trains it on input data.

    Inputs:
        "Data" (ExampleTable) -- training data; kept only if
            isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True)
            accepts it (see setData).
        "Preprocess" (PreprocessedLearner) -- optional; when present, the
            learner is wrapped with preprocessor.wrapLearner().

    Outputs:
        "Learner" -- the configured (possibly wrapped) tree learner.
        "Classification Tree" -- the classifier trained on the current data,
            or None when there is no data or training raised an exception.
        "Classification Tree Graph" -- classifier.to_network(), or None when
            there is no classifier.
    """

    # Attributes restored by loadSettings() (and written by saveSettings()).
    # "limitRef" is included so that the "Limit the number of reference
    # examples" checkbox is remembered like every other option on the form.
    # NOTE(review): "subset" is initialised and saved but not read anywhere in
    # this file -- possibly obsolete.
    settingsList = ["name",
                    "estim", "relK", "relM", "limitRef",
                    "bin", "subset",
                    "preLeafInst", "preNodeInst", "preNodeMaj",
                    "preLeafInstP", "preNodeInstP", "preNodeMajP",
                    "postMaj", "postMPruning", "postM",
                    "limitDepth", "maxDepth"]

    # (combo-box label, measure name passed to orngTree.TreeLearner).
    # Index 3 ("ReliefF") is what measureChanged() and sendReport() test for.
    measures = (("Information Gain", "infoGain"), ("Gain Ratio", "gainRatio"), ("Gini Index", "gini"), ("ReliefF", "relief"))
    # Labels for the "bin" setting; its index is passed as ``binarization``.
    binarizationOpts = ["No binarization", "Exhaustive search for optimal split", "One value against others"]

    def __init__(self, parent=None, signalManager=None, name='Classification Tree'):
        """Declare signals, load settings, send an initial learner and build the GUI."""
        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=0, resizingEnabled=0)

        self.inputs = [("Data", ExampleTable, self.setData),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]

        self.outputs = [("Learner", Orange.classification.tree.TreeLearner),
                        ("Classification Tree", Orange.classification.tree.TreeClassifier),
                        ("Classification Tree Graph", Orange.network.Graph)]

        # Defaults; loadSettings() may overwrite any attribute in settingsList.
        self.name = 'Classification Tree'
        self.estim = 0; self.relK = 5; self.relM = 100; self.limitRef = True
        self.bin = 0; self.subset = 0
        self.preLeafInstP = 2; self.preNodeInstP = 5; self.preNodeMajP = 95
        self.preLeafInst = 1; self.preNodeInst = 0; self.preNodeMaj = 0
        self.postMaj = 1; self.postMPruning = 1; self.postM = 2.0
        self.limitDepth = False; self.maxDepth = 100
        self.loadSettings()

        self.data = None
        self.preprocessor = None
        # Emit a learner immediately. The GUI does not exist yet, which is why
        # setLearner() checks for btnApply with hasattr().
        self.setLearner()

        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')
        OWGUI.separator(self.controlArea)

        qBox = OWGUI.widgetBox(self.controlArea, 'Attribute selection criterion')

        self.qMea = OWGUI.comboBox(qBox, self, "estim", items=[m[0] for m in self.measures], callback=self.measureChanged)

        # ReliefF-specific options, indented under the measure combo box.
        b1 = OWGUI.widgetBox(qBox, orientation="horizontal")
        OWGUI.separator(b1, 16, 0)
        b2 = OWGUI.widgetBox(b1)
        self.cbLimitRef, self.hbxRel1 = OWGUI.checkWithSpin(b2, self, "Limit the number of reference examples to ", 1, 1000, "limitRef", "relM")
        OWGUI.separator(b2)
        self.hbxRel2 = OWGUI.spin(b2, self, "relK", 1, 50, orientation="horizontal", label="Number of neighbours in ReliefF  ")

        OWGUI.separator(self.controlArea)

        OWGUI.radioButtonsInBox(self.controlArea, self, 'bin', self.binarizationOpts, "Binarization")
        OWGUI.separator(self.controlArea)

        # Sync enabled state of the ReliefF controls with the loaded "estim".
        self.measureChanged()

        self.pBox = OWGUI.widgetBox(self.controlArea, 'Pre-Pruning')

        self.preLeafInstBox, self.preLeafInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Min. instances in leaves ", 1, 1000, "preLeafInst", "preLeafInstP")
        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with less instances than ", 1, 1000, "preNodeInst", "preNodeInstP")
        self.preNodeMajBox, self.preNodeMajPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with a majority class of (%)", 1, 100, "preNodeMaj", "preNodeMajP")
        self.cbLimitDepth, self.maxDepthBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes at depth", 0, 1000, "limitDepth", "maxDepth")
        OWGUI.separator(self.controlArea)
        self.mBox = OWGUI.widgetBox(self.controlArea, 'Post-Pruning')

        OWGUI.checkBox(self.mBox, self, 'postMaj', 'Recursively merge leaves with same majority class')
        self.postMPruningBox, self.postMPruningPBox = OWGUI.checkWithSpin(self.mBox, self, "Pruning with m-estimate, m=", 0, 1000, 'postMPruning', 'postM')

        OWGUI.separator(self.controlArea)
        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply", callback=self.setLearner, disabled=0, default=True)

        OWGUI.rubber(self.controlArea)
        self.resize(200, 200)

    def sendReport(self):
        """Add the current learning parameters and the input data to the report."""
        if self.limitRef:
            reliefRefs = "%i reference examples" % self.relM
        else:
            # setLearner() passes reliefM=-1 in this case, so relM is unused
            # and must not be reported as the reference-example count.
            reliefRefs = "all reference examples"

        # Only enabled pre-pruning criteria are listed.
        prePruning = [text for text, enabled in (
            ("%i instances in leaves" % self.preLeafInstP, self.preLeafInst),
            ("%i instance in node" % self.preNodeInstP, self.preNodeInst),
            ("stop on %i%% purity" % self.preNodeMajP, self.preNodeMaj),
            ("maximum depth %i" % self.maxDepth, self.limitDepth)) if enabled]

        self.reportSettings("Learning parameters",
                            [("Attribute selection", self.measures[self.estim][0]),
                             # Evaluates to False unless ReliefF is selected.
                             self.estim == 3 and ("ReliefF settings", "%s, %i neighbours" % (reliefRefs, self.relK)),
                             ("Binarization", self.binarizationOpts[self.bin]),
                             ("Pruning", ", ".join(prePruning) or "None"),
                             ("Recursively merge leaves with same majority class", OWGUI.YesNo[self.postMaj]),
                             ("Pruning with m-estimate", ["No", "m=%i" % self.postM][self.postMPruning])])
        self.reportData(self.data)

    def setPreprocessor(self, preprocessor):
        """Handle the "Preprocess" input and rebuild the learner."""
        self.preprocessor = preprocessor
        self.setLearner()

    def setLearner(self):
        """Build the learner from the current settings and send all outputs.

        Sends "Learner" unconditionally. If data is present, also trains a
        classifier; any exception raised while training is shown with
        self.error() and yields a None classifier. Finally sends
        "Classification Tree" and "Classification Tree Graph".
        """
        # Called once from __init__ before the Apply button exists.
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        # maxDepth is passed only when the depth limit is enabled.
        if self.limitDepth:
            mDepth = {'maxDepth': self.maxDepth}
        else:
            mDepth = {}

        # Disabled pre-/post-pruning options are passed as 0, and a disabled
        # majority threshold as 1.0. The spin-box minimums (>= 1) make these
        # conditional expressions equivalent to the former `and`/`or` chains.
        self.learner = orngTree.TreeLearner(
            measure=self.measures[self.estim][1],
            reliefK=self.relK,
            # -1 requests no limit on the number of ReliefF reference examples.
            reliefM=self.relM if self.limitRef else -1,
            binarization=self.bin,
            minExamples=self.preNodeInstP if self.preNodeInst else 0,
            minSubset=self.preLeafInstP if self.preLeafInst else 0,
            maxMajority=self.preNodeMajP / 100.0 if self.preNodeMaj else 1.0,
            sameMajorityPruning=self.postMaj,
            mForPruning=self.postM if self.postMPruning else 0,
            storeExamples=1, **mDepth)

        self.learner.name = self.name
        if self.preprocessor:
            self.learner = self.preprocessor.wrapLearner(self.learner)

        self.send("Learner", self.learner)

        self.error()
        self.classifier = None
        if self.data:
            try:
                self.classifier = self.learner(self.data)
                self.classifier.name = self.name
            except Exception as errValue:
                self.error(str(errValue))
                self.classifier = None

        tree_graph = None
        if self.classifier is not None:
            tree_graph = self.classifier.to_network()

        self.send("Classification Tree", self.classifier)
        self.send("Classification Tree Graph", tree_graph)

    def measureChanged(self):
        """Enable the ReliefF-only controls only while ReliefF (index 3) is selected."""
        relief = self.estim == 3
        self.hbxRel1.setEnabled(relief and self.limitRef)
        self.hbxRel2.setEnabled(relief)
        self.cbLimitRef.setEnabled(relief)

    def setData(self, data):
        """Handle the "Data" input and retrain.

        Keeps `data` only if isDataWithClass() accepts it with a discrete
        class and checkMissing=True, and only if it is truthy; otherwise
        stores None.
        """
        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True) and data or None
        self.setLearner()
165
166
##############################################################################
# Test the widget by running this module directly:
# > python OWClassificationTree.py [data-file]
# If a data file name is given (e.g. adult_sample.tab), it is loaded with
# orange.ExampleTable and passed to the widget's "Data" input.

if __name__ == "__main__":
    a = QApplication(sys.argv)
    ow = OWClassificationTree()

    # Optional training data from the command line. This replaces the
    # previously commented-out, hard-coded 'adult_sample' load.
    if len(sys.argv) > 1:
        ow.setData(orange.ExampleTable(sys.argv[1]))

    ow.show()
    a.exec_()
    # Persist any option changes made during the test session.
    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.