source: orange/Orange/OrangeWidgets/Classify/OWRandomForest.py @ 11628:0453b7a5b43b

Revision 11628:0453b7a5b43b, 8.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added back single tree output using the simple tree converter.

Line 
1"""
2<name>Random Forest</name>
3<description>Random forest learner/classifier.</description>
4<icon>icons/RandomForest.svg</icon>
5<contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
6<priority>320</priority>
7"""
8
9from OWWidget import *
10import orngTree, OWGUI
11import orngEnsemble
12from exceptions import Exception
13from orngWrap import PreprocessedLearner
14
15
16class OWRandomForest(OWWidget):
17    settingsList = ["name", "trees", "attributes", "attributesP",
18                    "preNodeInst", "preNodeInstP", "limitDepth",
19                    "limitDepthP", "rseed"]
20
21    def __init__(self, parent=None, signalManager=None, name='Random Forest'):
22        OWWidget.__init__(self, parent, signalManager, name,
23                          wantMainArea=False, resizingEnabled=False)
24
25        self.inputs = [("Data", ExampleTable, self.setData),
26                       ("Preprocess", PreprocessedLearner,
27                        self.setPreprocessor)]
28
29        self.outputs = [("Learner", orange.Learner),
30                        ("Random Forest Classifier", orange.Classifier),
31                        ("Selected Tree", Orange.classification.tree.TreeClassifier)]
32
33        self.name = 'Random Forest'
34        self.trees = 10
35        self.attributes = 0
36        self.attributesP = 5
37        self.preNodeInst = 1
38        self.preNodeInstP = 5
39        self.limitDepth = 0
40        self.limitDepthP = 3
41        self.rseed = 0
42        self.outtree = 0
43
44        self.maxTrees = 10000
45
46        self.loadSettings()
47
48        self.data = None
49        self.preprocessor = None
50
51        OWGUI.lineEdit(self.controlArea, self, 'name',
52                       box='Learner/Classifier Name',
53                       tooltip='Name to be used by other widgets to identify '
54                               'your learner/classifier.')
55
56        OWGUI.separator(self.controlArea)
57
58        self.bBox = OWGUI.widgetBox(self.controlArea, 'Basic Properties')
59
60        self.treesBox = OWGUI.spin(self.bBox, self, "trees", 1, self.maxTrees,
61                                   orientation="horizontal",
62                                   label="Number of trees in forest")
63        self.attributesBox, self.attributesPBox = \
64            OWGUI.checkWithSpin(self.bBox, self, "Consider exactly",
65                                1, 10000, "attributes", "attributesP",
66                                " random attributes at each split.")
67
68        self.rseedBox = OWGUI.spin(self.bBox, self, "rseed", 0, 100000,
69                                   orientation="horizontal",
70                                   label="Seed for random generator ")
71
72        OWGUI.separator(self.controlArea)
73
74        self.pBox = OWGUI.widgetBox(self.controlArea, 'Growth Control')
75
76        self.limitDepthBox, self.limitDepthPBox = \
77            OWGUI.checkWithSpin(self.pBox, self,
78                                "Maximal depth of individual trees",
79                                1, 1000, "limitDepth", "limitDepthP", "")
80
81        self.preNodeInstBox, self.preNodeInstPBox = \
82            OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with ",
83                                1, 1000, "preNodeInst", "preNodeInstP",
84                                " or fewer instances")
85
86        OWGUI.separator(self.controlArea)
87
88        self.streesBox = OWGUI.spin(self.controlArea, self, "outtree", -1,
89                                    self.maxTrees,
90                                    orientation="horizontal",
91                                    label="Index of tree on the output",
92                                    callback=[self.period, self.extree])
93        self.streeEnabled(False)
94
95        OWGUI.separator(self.controlArea)
96
97        self.btnApply = OWGUI.button(self.controlArea, self,
98                                     "&Apply Changes",
99                                     callback=self.doBoth,
100                                     disabled=0,
101                                     default=True)
102
103        self.resize(100, 200)
104
105        self.setLearner()
106
107    def sendReport(self):
108        self.reportSettings("Learning parameters",
109                    [("Number of trees", self.trees),
110                     ("Considered number of attributes at each split",
111                      self.attributeP if self.attributes else "not set"),
112                     ("Seed for random generator", self.rseed),
113                     ("Maximal depth of individual trees",
114                      self.limitDepthP if self.limitDepth else "not set"),
115                     ("Minimal number of instances in a leaf",
116                      self.preNodeInstP if self.preNodeInst else "not limited")
117                   ])
118        self.reportData(self.data)
119
120    def constructLearner(self):
121        rand = random.Random(self.rseed)
122
123        attrs = None
124        if self.attributes:
125            attrs = self.attributesP
126
127        from Orange.classification.tree import SimpleTreeLearner
128
129        smallLearner = SimpleTreeLearner()
130
131        if self.preNodeInst:
132            smallLearner.min_instances = self.preNodeInstP
133        else:
134            smallLearner.min_instances = 0
135
136        if self.limitDepth:
137            smallLearner.max_depth = self.limitDepthP
138
139        learner = orngEnsemble.RandomForestLearner(base_learner=smallLearner,
140                            trees=self.trees, rand=rand, attributes=attrs)
141
142        if self.preprocessor:
143            learner = self.preprocessor.wrapLearner(learner)
144        learner.name = self.name
145        return learner
146
147    def setLearner(self):
148
149        if hasattr(self, "btnApply"):
150            self.btnApply.setFocus()
151
152        #assemble learner
153
154        self.learner = self.constructLearner()
155        self.send("Learner", self.learner)
156
157        self.error()
158
159    def setData(self, data):
160        if not self.isDataWithClass(data, orange.VarTypes.Discrete,
161                                    checkMissing=True):
162            data = None
163        self.data = data
164
165        #self.setLearner()
166        self.streeEnabled(False)
167        if self.data:
168            learner = self.constructLearner()
169            self.progressBarInit()
170            learner.callback = lambda v: self.progressBarSet(100.0 * v)
171            try:
172                self.classifier = learner(self.data)
173                self.streeEnabled(True)
174                self.classifier.name = self.name
175            except Exception, (errValue):
176                self.error(str(errValue))
177                self.classifier = None
178            self.progressBarFinished()
179        else:
180            self.classifier = None
181
182        self.send("Random Forest Classifier", self.classifier)
183
184    def setPreprocessor(self, pp):
185        self.preprocessor = pp
186        self.doBoth()
187
188    def doBoth(self):
189        self.setLearner()
190        self.setData(self.data)
191
192    def period(self):
193        if self.outtree == -1:
194            self.outtree = self.claTrees - 1
195        elif self.outtree >= self.claTrees:
196            self.outtree = 0
197
198    def extree(self):
199        stc = self.classifier.classifiers[self.outtree]
200        if self.preprocessor:
201            # TODO: get the transformed data at learning step from the
202            # wrapped learner (or at least cache it here)
203            train_data = self.data.translate(self.classifier.domain)
204        else:
205            train_data = self.data
206
207        # Replay the bootstrap sampling as done by RandomForestLearner
208        rand = random.Random(self.claSeed)
209        n = len(train_data)
210        selection = [rand.randrange(n)
211                     for _ in range((self.outtree + 1) * n)]
212        # need the last n samples
213        selection = selection[-n:]
214        train_data = train_data.get_items_ref(selection)
215
216        tree = Orange.classification.tree._simple_tree_convert(
217            stc, self.classifier.domain, train_data)
218
219        self.send("Selected Tree", tree)
220
221    def streeEnabled(self, status):
222        if status:
223            self.claTrees = self.trees
224            self.claSeed = self.rseed
225            self.streesBox.setDisabled(False)
226            self.period()
227            self.extree()
228        else:
229            self.streesBox.setDisabled(True)
230
231
232if __name__ == "__main__":
233    a = QApplication(sys.argv)
234    ow = OWRandomForest()
235
236    d = orange.ExampleTable('adult_sample')
237    ow.setData(d)
238
239    ow.show()
240    a.exec_()
241    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.