source: orange/Orange/OrangeWidgets/Classify/OWClassificationTree.py @ 11012:19029caa4a32

Revision 11012:19029caa4a32, 8.2 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 18 months ago (diff)

Fixed more network references.

Line 
1"""
2<name>Classification Tree</name>
3<description>Classification tree learner/classifier.</description>
4<icon>icons/ClassificationTree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>30</priority>
7"""
8from OWWidget import *
9import orngTree, OWGUI
10from orngWrap import PreprocessedLearner
11from exceptions import Exception
12
13import Orange
14
15import warnings
16warnings.filterwarnings("ignore", ".*this class is not optimized for 'candidates' list and can be very slow.*", orange.KernelWarning, ".*orngTree", 34)
17
18
19
20class OWClassificationTree(OWWidget):
21    settingsList = ["name",
22                    "estim", "relK", "relM",
23                    "bin", "subset",
24                    "preLeafInst", "preNodeInst", "preNodeMaj",
25                    "preLeafInstP", "preNodeInstP", "preNodeMajP",
26                    "postMaj", "postMPruning", "postM",
27                    "limitDepth", "maxDepth"]
28
29    measures = (("Information Gain", "infoGain"), ("Gain Ratio", "gainRatio"), ("Gini Index", "gini"), ("ReliefF", "relief"))
30    binarizationOpts = ["No binarization", "Exhaustive search for optimal split", "One value against others"]
31
32    def __init__(self, parent=None, signalManager=None, name='Classification Tree'):
33        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=0, resizingEnabled=0)
34
35        self.inputs = [("Data", ExampleTable, self.setData),
36                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
37       
38        self.outputs = [("Learner", Orange.classification.tree.TreeLearner),
39                        ("Classification Tree", Orange.classification.tree.TreeClassifier), ]
40
41        try:
42            self.outputs.append(("Classification Tree Graph", Orange.network.Graph))
43        except:
44            pass
45
46        self.name = 'Classification Tree'
47        self.estim = 0; self.relK = 5; self.relM = 100; self.limitRef = True
48        self.bin = 0; self.subset = 0
49        self.preLeafInstP = 2; self.preNodeInstP = 5; self.preNodeMajP = 95
50        self.preLeafInst = 1; self.preNodeInst = 0; self.preNodeMaj = 0
51        self.postMaj = 1; self.postMPruning = 1; self.postM = 2.0
52        self.limitDepth = False; self.maxDepth = 100
53        self.loadSettings()
54
55        self.data = None
56        self.preprocessor = None
57        self.setLearner()
58
59        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')
60        OWGUI.separator(self.controlArea)
61
62        qBox = OWGUI.widgetBox(self.controlArea, 'Attribute selection criterion')
63
64        self.qMea = OWGUI.comboBox(qBox, self, "estim", items=[m[0] for m in self.measures], callback=self.measureChanged)
65
66        b1 = OWGUI.widgetBox(qBox, orientation="horizontal")
67        OWGUI.separator(b1, 16, 0)
68        b2 = OWGUI.widgetBox(b1)
69        self.cbLimitRef, self.hbxRel1 = OWGUI.checkWithSpin(b2, self, "Limit the number of reference examples to ", 1, 1000, "limitRef", "relM")
70        OWGUI.separator(b2)
71        self.hbxRel2 = OWGUI.spin(b2, self, "relK", 1, 50, orientation="horizontal", label="Number of neighbours in ReliefF  ")
72
73        OWGUI.separator(self.controlArea)
74
75        OWGUI.radioButtonsInBox(self.controlArea, self, 'bin', self.binarizationOpts, "Binarization")
76        OWGUI.separator(self.controlArea)
77
78        self.measureChanged()
79
80        self.pBox = OWGUI.widgetBox(self.controlArea, 'Pre-Pruning')
81
82        self.preLeafInstBox, self.preLeafInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Min. instances in leaves ", 1, 1000, "preLeafInst", "preLeafInstP")
83        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with less instances than ", 1, 1000, "preNodeInst", "preNodeInstP")
84        self.preNodeMajBox, self.preNodeMajPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with a majority class of (%)", 1, 100, "preNodeMaj", "preNodeMajP")
85        self.cbLimitDepth, self.maxDepthBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes at depth", 0, 1000, "limitDepth", "maxDepth")
86        OWGUI.separator(self.controlArea)
87        self.mBox = OWGUI.widgetBox(self.controlArea, 'Post-Pruning')
88
89        OWGUI.checkBox(self.mBox, self, 'postMaj', 'Recursively merge leaves with same majority class')
90        self.postMPruningBox, self.postMPruningPBox = OWGUI.checkWithSpin(self.mBox, self, "Pruning with m-estimate, m=", 0, 1000, 'postMPruning', 'postM')
91
92        OWGUI.separator(self.controlArea)
93        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply", callback=self.setLearner, disabled=0, default=True)
94
95        OWGUI.rubber(self.controlArea)
96        self.resize(200, 200)
97
98
99    def sendReport(self):
100        self.reportSettings("Learning parameters",
101                            [("Attribute selection", self.measures[self.estim][0]),
102                             self.estim == 3 and ("ReliefF settings", "%i reference examples, %i neighbours" % (self.relM, self.relK)),
103                             ("Binarization", self.binarizationOpts[self.bin]),
104                             ("Pruning", ", ".join(s for s, c in (
105                                                 ("%i instances in leaves" % self.preLeafInstP, self.preLeafInst),
106                                                 ("%i instance in node" % self.preNodeInstP, self.preNodeInst),
107                                                 ("stop on %i%% purity" % self.preNodeMajP, self.preNodeMaj),
108                                                 ("maximum depth %i" % self.maxDepth, self.limitDepth)) if c)
109                                          or "None"),
110                             ("Recursively merge leaves with same majority class", OWGUI.YesNo[self.postMaj]),
111                             ("Pruning with m-estimate", ["No", "m=%i" % self.postM][self.postMPruning])])
112        self.reportData(self.data)
113
114    def setPreprocessor(self, preprocessor):
115        self.preprocessor = preprocessor
116        self.setLearner()
117
118    def setLearner(self):
119        if hasattr(self, "btnApply"):
120            self.btnApply.setFocus()
121        if not self.limitDepth:
122            mDepth = {}
123        else:
124            mDepth = {'maxDepth': self.maxDepth}
125        self.learner = orngTree.TreeLearner(measure=self.measures[self.estim][1],
126            reliefK=self.relK, reliefM=self.limitRef and self.relM or -1,
127            binarization=self.bin,
128            minExamples=self.preNodeInst and self.preNodeInstP,
129            minSubset=self.preLeafInst and self.preLeafInstP,
130            maxMajority=self.preNodeMaj and self.preNodeMajP / 100.0 or 1.0,
131            sameMajorityPruning=self.postMaj,
132            mForPruning=self.postMPruning and self.postM,
133            storeExamples=1, **mDepth)
134
135        self.learner.name = self.name
136        if self.preprocessor:
137            self.learner = self.preprocessor.wrapLearner(self.learner)
138
139        self.send("Learner", self.learner)
140
141        self.error()
142        if self.data:
143            try:
144                self.classifier = self.learner(self.data)
145                self.classifier.name = self.name
146            except Exception, (errValue):
147                self.error(str(errValue))
148                self.classifier = None
149        else:
150            self.classifier = None
151
152        tree_graph = None
153        if self.classifier is not None:
154            tree_graph = self.classifier.to_network()
155
156        self.send("Classification Tree", self.classifier)
157
158        if "Classification Tree Graph" in self.outputs:
159            self.send("Classification Tree Graph", tree_graph)
160
161    def measureChanged(self):
162        relief = self.estim == 3
163        self.hbxRel1.setEnabled(relief and self.limitRef)
164        self.hbxRel2.setEnabled(relief)
165        self.cbLimitRef.setEnabled(relief)
166
167    def setData(self, data):
168        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True) and data or None
169        self.setLearner()
170
171
172##############################################################################
173# Test the widget, run from DOS prompt
174# > python OWDataTable.py)
175# Make sure that a sample data set (adult_sample.tab) is in the directory
176
177if __name__ == "__main__":
178    a = QApplication(sys.argv)
179    ow = OWClassificationTree()
180
181    #d = orange.ExampleTable('adult_sample')
182    #ow.setData(d)
183
184    ow.show()
185    a.exec_()
186    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.