source: orange/Orange/OrangeWidgets/Prototypes/OWRandomForestOld.py @ 10587:5c9fc139ddfa

Revision 10587:5c9fc139ddfa, 6.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Changed how the callback function is called in RandomForestLearner (to match its documentation). Fixed the widgets accordingly.

Line 
1"""
2<name>Random Forest (Old)</name>
3<description>Random forest learner/classifier.</description>
4<icon>icons/RandomForest.png</icon>
5<contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
6<priority>320</priority>
7"""
8
9from OWWidget import *
10import orngTree, OWGUI
11import orngEnsemble
12from exceptions import Exception
13from orngWrap import PreprocessedLearner
14
class OWRandomForestOld(OWWidget):
    """Random forest learner/classifier widget (old prototype).

    Inputs:  "Data" (ExampleTable) and "Preprocess" (PreprocessedLearner).
    Outputs: "Learner", "Random Forest Classifier" and "Selected Tree" -- one
    tree of the trained forest, chosen by index with a spin box.
    """
    settingsList = ["name", "trees", "attributes", "attributesP", "preNodeInst", "preNodeInstP", "limitDepth", "limitDepthP", "rseed", "outtree" ]

    def __init__(self, parent=None, signalManager = None, name='Random Forest'):
        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=False, resizingEnabled=False)

        self.inputs = [("Data", ExampleTable, self.setData), ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
        self.outputs = [("Learner", orange.Learner),("Random Forest Classifier", orange.Classifier),("Selected Tree", orange.TreeClassifier) ]

        # Defaults; loadSettings() below may replace any attribute named in settingsList.
        self.name = 'Random Forest'
        self.trees = 10
        self.attributes = 0      # checkbox: consider a fixed number of random attributes per split
        self.attributesP = 5     # ... that number of attributes
        self.preNodeInst = 1     # checkbox: stop splitting small nodes
        self.preNodeInstP = 5    # ... node-size threshold
        self.limitDepth = 0      # checkbox: limit the depth of individual trees
        self.limitDepthP = 3     # ... maximal depth
        self.rseed = 0           # seed for random.Random in constructLearner
        self.outtree = 0         # index of the tree sent on "Selected Tree"

        self.maxTrees = 10000

        self.loadSettings()

        self.data = None
        self.preprocessor = None
        # Defined up front so extree() can test it before any data has arrived.
        self.classifier = None

        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')

        OWGUI.separator(self.controlArea)

        self.bBox = OWGUI.widgetBox(self.controlArea, 'Basic Properties')

        self.treesBox = OWGUI.spin(self.bBox, self, "trees", 1, self.maxTrees, orientation="horizontal", label="Number of trees in forest")
        self.attributesBox, self.attributesPBox = OWGUI.checkWithSpin(self.bBox, self, "Consider exactly", 1, 10000, "attributes", "attributesP", " "+"random attributes at each split.")
        self.rseedBox = OWGUI.spin(self.bBox, self, "rseed", 0, 100000, orientation="horizontal", label="Seed for random generator ")

        OWGUI.separator(self.controlArea)

        self.pBox = OWGUI.widgetBox(self.controlArea, 'Growth Control')

        self.limitDepthBox, self.limitDepthPBox = OWGUI.checkWithSpin(self.pBox, self, "Maximal depth of individual trees", 1, 1000, "limitDepth", "limitDepthP", "")
        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with ", 1, 1000, "preNodeInst", "preNodeInstP", " or fewer instances")

        OWGUI.separator(self.controlArea)

        # Spin range starts at -1 so that stepping below 0 wraps to the last tree (see period()).
        self.streesBox = OWGUI.spin(self.controlArea, self, "outtree", -1, self.maxTrees, orientation="horizontal", label="Index of tree on the output", callback=[self.period, self.extree])
        self.streeEnabled(False)

        OWGUI.separator(self.controlArea)

        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply Changes", callback = self.doBoth, disabled=0, default=True)

        self.resize(100,200)

        self.setLearner()

    def sendReport(self):
        """Report the learning parameters and the current input data."""
        self.reportSettings("Learning parameters",
                            [("Number of trees", self.trees),
                             ("Considered number of attributes at each split", self.attributesP if self.attributes else "not set"),
                             ("Seed for random generator", self.rseed),
                             ("Maximal depth of individual trees", self.limitDepthP if self.limitDepth else "not set"),
                             ("Minimal number of instances in a leaf", self.preNodeInstP if self.preNodeInst else "not limited")
                           ])
        self.reportData(self.data)

    def period(self):
        """Wrap ``self.outtree`` into the range 0 .. claTrees-1.

        -1 wraps to the last tree, and an index at or past the end wraps back
        to the first tree.
        """
        if self.outtree == -1:
            self.outtree = self.claTrees - 1
        elif self.outtree >= self.claTrees:
            self.outtree = 0

    def extree(self):
        """Send tree number ``self.outtree`` of the current forest on "Selected Tree".

        Sends None when no forest has been trained.
        """
        if self.classifier is None:
            self.send("Selected Tree", None)
            return
        self.send("Selected Tree", self.classifier.classifiers[self.outtree])

    def streeEnabled(self, status):
        """Enable and refresh, or disable, the single-tree output.

        When enabling, ``self.claTrees`` is set from ``self.trees``, the
        selected index is wrapped into range and the chosen tree is sent.
        When disabling, the spin box is disabled and None is sent on
        "Selected Tree". Otherwise downstream widgets would keep a tree from
        an earlier forest.
        """
        if status:
            self.claTrees = self.trees
            self.streesBox.setDisabled(False)
            self.period()
            self.extree()
        else:
            self.streesBox.setDisabled(True)
            self.send("Selected Tree", None)

    def constructLearner(self):
        """Build a RandomForestLearner from the current settings.

        The learner is wrapped by the connected preprocessor, if any, and its
        name is set to ``self.name``.
        """
        # Imported explicitly: ``random`` is not among this module's visible imports.
        import random

        rand = random.Random(self.rseed)
        attrs = self.attributesP if self.attributes else None

        smallLearner = orngTree.TreeLearner()
        # NOTE(review): assumes TreeLearner().stop exposes minExamples -- confirm
        # against the orngTree version in use.
        smallLearner.stop.minExamples = self.preNodeInstP if self.preNodeInst else 0

        smallLearner.storeExamples = 1
        smallLearner.storeNodeClassifier = 1
        smallLearner.storeContingencies = 1
        smallLearner.storeDistributions = 1

        if self.limitDepth:
            smallLearner.maxDepth = self.limitDepthP

        learner = orngEnsemble.RandomForestLearner(base_learner=smallLearner,
                            trees=self.trees, rand=rand, attributes=attrs)

        if self.preprocessor:
            learner = self.preprocessor.wrapLearner(learner)
        learner.name = self.name
        return learner

    def setLearner(self):
        """Rebuild the learner from the current settings and send it on "Learner".

        Also clears the widget's default error message.
        """
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        self.learner = self.constructLearner()
        self.send("Learner", self.learner)

        self.error()

    def setData(self, data):
        """Handle the "Data" input: train a forest on *data* and send the classifier.

        *data* is used only if ``isDataWithClass`` accepts it (discrete class,
        ``checkMissing=True``); otherwise the widget behaves as if there were
        no data. A training error is shown via ``self.error`` and results in
        a None classifier.
        """
        # Drop the message left by a previous failed training run. Without
        # this, a successful run on new data would still display the old error.
        self.error()

        accepted = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True)
        self.data = data if (accepted and data) else None

        self.classifier = None
        if self.data:
            learner = self.constructLearner()
            self.progressBarInit()
            # Callback values are scaled by 100 for progressBarSet
            # (presumably a 0..1 fraction -- see RandomForestLearner docs).
            learner.callback = lambda v: self.progressBarSet(100.0 * v)
            try:
                self.classifier = learner(self.data)
                self.classifier.name = self.name
                self.streeEnabled(True)
            except Exception as errValue:
                self.error(str(errValue))
                self.classifier = None
                self.streeEnabled(False)
            finally:
                # Close the progress bar even if training is interrupted.
                self.progressBarFinished()
        else:
            self.streeEnabled(False)

        self.send("Random Forest Classifier", self.classifier)

    def setPreprocessor(self, pp):
        """Handle the "Preprocess" input and rebuild both learner and classifier."""
        self.preprocessor = pp
        self.doBoth()

    def doBoth(self):
        """Resend the learner and retrain the classifier on the current data."""
        self.setLearner()
        self.setData(self.data)
176
177
178
##############################################################################
# Test the widget; run from the command line:
# > python OWRandomForestOld.py
# Make sure that a sample data set (adult_sample.tab) is in the directory.

if __name__=="__main__":
    a = QApplication(sys.argv)
    ow = OWRandomForestOld()

    d = orange.ExampleTable('adult_sample')
    ow.setData(d)

    ow.show()
    # PyQt4 API: exec_() replaces Qt3's exec_loop(); Qt4 has no setMainWidget().
    a.exec_()
    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.