source: orange/orange/OrangeWidgets/Classify/OWClassificationTree.py @ 9608:0ff19219442a

Revision 9608:0ff19219442a, 8.0 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Added tree graph output signal.

Line 
1"""
2<name>Classification Tree</name>
3<description>Classification tree learner/classifier.</description>
4<icon>icons/ClassificationTree.png</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>30</priority>
7"""
8from OWWidget import *
9import orngTree, OWGUI
10from orngWrap import PreprocessedLearner
11from exceptions import Exception
12
13import Orange
14
15import warnings
16warnings.filterwarnings("ignore", ".*this class is not optimized for 'candidates' list and can be very slow.*", orange.KernelWarning, ".*orngTree", 34)
17
18
19
20class OWClassificationTree(OWWidget):
21    settingsList = ["name",
22                    "estim", "relK", "relM",
23                    "bin", "subset",
24                    "preLeafInst", "preNodeInst", "preNodeMaj",
25                    "preLeafInstP", "preNodeInstP", "preNodeMajP",
26                    "postMaj", "postMPruning", "postM",
27                    "limitDepth", "maxDepth"]
28
29    measures = (("Information Gain", "infoGain"), ("Gain Ratio", "gainRatio"), ("Gini Index", "gini"), ("ReliefF", "relief"))
30    binarizationOpts = ["No binarization", "Exhaustive search for optimal split", "One value against others"]
31
32    def __init__(self, parent=None, signalManager=None, name='Classification Tree'):
33        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=0, resizingEnabled=0)
34
35        self.inputs = [("Data", ExampleTable, self.setData), ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
36        self.outputs = [("Learner", orange.TreeLearner), ("Classification Tree", orange.TreeClassifier), ("Classification Tree Graph", Orange.network.Graph)]
37
38        self.name = 'Classification Tree'
39        self.estim = 0; self.relK = 5; self.relM = 100; self.limitRef = True
40        self.bin = 0; self.subset = 0
41        self.preLeafInstP = 2; self.preNodeInstP = 5; self.preNodeMajP = 95
42        self.preLeafInst = 1; self.preNodeInst = 0; self.preNodeMaj = 0
43        self.postMaj = 1; self.postMPruning = 1; self.postM = 2.0
44        self.limitDepth = False; self.maxDepth = 100
45        self.loadSettings()
46
47        self.data = None
48        self.preprocessor = None
49        self.setLearner()
50
51        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')
52        OWGUI.separator(self.controlArea)
53
54        qBox = OWGUI.widgetBox(self.controlArea, 'Attribute selection criterion')
55
56        self.qMea = OWGUI.comboBox(qBox, self, "estim", items=[m[0] for m in self.measures], callback=self.measureChanged)
57
58        b1 = OWGUI.widgetBox(qBox, orientation="horizontal")
59        OWGUI.separator(b1, 16, 0)
60        b2 = OWGUI.widgetBox(b1)
61        self.cbLimitRef, self.hbxRel1 = OWGUI.checkWithSpin(b2, self, "Limit the number of reference examples to ", 1, 1000, "limitRef", "relM")
62        OWGUI.separator(b2)
63        self.hbxRel2 = OWGUI.spin(b2, self, "relK", 1, 50, orientation="horizontal", label="Number of neighbours in ReliefF  ")
64
65        OWGUI.separator(self.controlArea)
66
67        OWGUI.radioButtonsInBox(self.controlArea, self, 'bin', self.binarizationOpts, "Binarization")
68        OWGUI.separator(self.controlArea)
69
70        self.measureChanged()
71
72        self.pBox = OWGUI.widgetBox(self.controlArea, 'Pre-Pruning')
73
74        self.preLeafInstBox, self.preLeafInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Min. instances in leaves ", 1, 1000, "preLeafInst", "preLeafInstP")
75        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with less instances than ", 1, 1000, "preNodeInst", "preNodeInstP")
76        self.preNodeMajBox, self.preNodeMajPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with a majority class of (%)", 1, 100, "preNodeMaj", "preNodeMajP")
77        self.cbLimitDepth, self.maxDepthBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes at depth", 0, 1000, "limitDepth", "maxDepth")
78        OWGUI.separator(self.controlArea)
79        self.mBox = OWGUI.widgetBox(self.controlArea, 'Post-Pruning')
80
81        OWGUI.checkBox(self.mBox, self, 'postMaj', 'Recursively merge leaves with same majority class')
82        self.postMPruningBox, self.postMPruningPBox = OWGUI.checkWithSpin(self.mBox, self, "Pruning with m-estimate, m=", 0, 1000, 'postMPruning', 'postM')
83
84        OWGUI.separator(self.controlArea)
85        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply", callback=self.setLearner, disabled=0, default=True)
86
87        OWGUI.rubber(self.controlArea)
88        self.resize(200, 200)
89
90
91    def sendReport(self):
92        self.reportSettings("Learning parameters",
93                            [("Attribute selection", self.measures[self.estim][0]),
94                             self.estim == 3 and ("ReliefF settings", "%i reference examples, %i neighbours" % (self.relM, self.relK)),
95                             ("Binarization", self.binarizationOpts[self.bin]),
96                             ("Pruning", ", ".join(s for s, c in (
97                                                 ("%i instances in leaves" % self.preLeafInstP, self.preLeafInst),
98                                                 ("%i instance in node" % self.preNodeInstP, self.preNodeInst),
99                                                 ("stop on %i%% purity" % self.preNodeMajP, self.preNodeMaj),
100                                                 ("maximum depth %i" % self.maxDepth, self.limitDepth)) if c)
101                                          or "None"),
102                             ("Recursively merge leaves with same majority class", OWGUI.YesNo[self.postMaj]),
103                             ("Pruning with m-estimate", ["No", "m=%i" % self.postM][self.postMPruning])])
104        self.reportData(self.data)
105
106    def setPreprocessor(self, preprocessor):
107        self.preprocessor = preprocessor
108        self.setLearner()
109
110    def setLearner(self):
111        if hasattr(self, "btnApply"):
112            self.btnApply.setFocus()
113        if not self.limitDepth:
114            mDepth = {}
115        else:
116            mDepth = {'maxDepth': self.maxDepth}
117        self.learner = orngTree.TreeLearner(measure=self.measures[self.estim][1],
118            reliefK=self.relK, reliefM=self.limitRef and self.relM or -1,
119            binarization=self.bin,
120            minExamples=self.preNodeInst and self.preNodeInstP,
121            minSubset=self.preLeafInst and self.preLeafInstP,
122            maxMajority=self.preNodeMaj and self.preNodeMajP / 100.0 or 1.0,
123            sameMajorityPruning=self.postMaj,
124            mForPruning=self.postMPruning and self.postM,
125            storeExamples=1, **mDepth)
126
127        self.learner.name = self.name
128        if self.preprocessor:
129            self.learner = self.preprocessor.wrapLearner(self.learner)
130
131        self.send("Learner", self.learner)
132
133        self.error()
134        if self.data:
135            try:
136                self.classifier = self.learner(self.data)
137                self.classifier.name = self.name
138            except Exception, (errValue):
139                self.error(str(errValue))
140                self.classifier = None
141        else:
142            self.classifier = None
143
144        tree_graph = None
145        if self.classifier is not None:
146            tree_graph = self.classifier.to_network()
147
148        self.send("Classification Tree", self.classifier)
149        self.send("Classification Tree Graph", tree_graph)
150
151
152    def measureChanged(self):
153        relief = self.estim == 3
154        self.hbxRel1.setEnabled(relief and self.limitRef)
155        self.hbxRel2.setEnabled(relief)
156        self.cbLimitRef.setEnabled(relief)
157
158    def setData(self, data):
159        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True) and data or None
160        self.setLearner()
161
162
163##############################################################################
164# Test the widget, run from DOS prompt
165# > python OWDataTable.py)
166# Make sure that a sample data set (adult_sample.tab) is in the directory
167
168if __name__ == "__main__":
169    a = QApplication(sys.argv)
170    ow = OWClassificationTree()
171
172    #d = orange.ExampleTable('adult_sample')
173    #ow.setData(d)
174
175    ow.show()
176    a.exec_()
177    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.