source: orange/Orange/OrangeWidgets/Classify/OWRandomForest.py @ 11627:3c08b4b7e5ff

Revision 11627:3c08b4b7e5ff, 6.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Code style fixes for Random forest widgets.

Line 
1"""
2<name>Random Forest</name>
3<description>Random forest learner/classifier.</description>
4<icon>icons/RandomForest.svg</icon>
5<contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
6<priority>320</priority>
7"""
8
9from OWWidget import *
10import orngTree, OWGUI
11import orngEnsemble
12from exceptions import Exception
13from orngWrap import PreprocessedLearner
14
15
16class OWRandomForest(OWWidget):
17    settingsList = ["name", "trees", "attributes", "attributesP",
18                    "preNodeInst", "preNodeInstP", "limitDepth",
19                    "limitDepthP", "rseed"]
20
21    def __init__(self, parent=None, signalManager=None, name='Random Forest'):
22        OWWidget.__init__(self, parent, signalManager, name,
23                          wantMainArea=False, resizingEnabled=False)
24
25        self.inputs = [("Data", ExampleTable, self.setData),
26                       ("Preprocess", PreprocessedLearner,
27                        self.setPreprocessor)]
28
29        self.outputs = [("Learner", orange.Learner),
30                        ("Random Forest Classifier", orange.Classifier)]
31
32        self.name = 'Random Forest'
33        self.trees = 10
34        self.attributes = 0
35        self.attributesP = 5
36        self.preNodeInst = 1
37        self.preNodeInstP = 5
38        self.limitDepth = 0
39        self.limitDepthP = 3
40        self.rseed = 0
41
42        self.maxTrees = 10000
43
44        self.loadSettings()
45
46        self.data = None
47        self.preprocessor = None
48
49        OWGUI.lineEdit(self.controlArea, self, 'name',
50                       box='Learner/Classifier Name',
51                       tooltip='Name to be used by other widgets to identify '
52                               'your learner/classifier.')
53
54        OWGUI.separator(self.controlArea)
55
56        self.bBox = OWGUI.widgetBox(self.controlArea, 'Basic Properties')
57
58        self.treesBox = OWGUI.spin(self.bBox, self, "trees", 1, self.maxTrees,
59                                   orientation="horizontal",
60                                   label="Number of trees in forest")
61        self.attributesBox, self.attributesPBox = \
62            OWGUI.checkWithSpin(self.bBox, self, "Consider exactly",
63                                1, 10000, "attributes", "attributesP",
64                                " random attributes at each split.")
65
66        self.rseedBox = OWGUI.spin(self.bBox, self, "rseed", 0, 100000,
67                                   orientation="horizontal",
68                                   label="Seed for random generator ")
69
70        OWGUI.separator(self.controlArea)
71
72        self.pBox = OWGUI.widgetBox(self.controlArea, 'Growth Control')
73
74        self.limitDepthBox, self.limitDepthPBox = \
75            OWGUI.checkWithSpin(self.pBox, self,
76                                "Maximal depth of individual trees",
77                                1, 1000, "limitDepth", "limitDepthP", "")
78
79        self.preNodeInstBox, self.preNodeInstPBox = \
80            OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with ",
81                                1, 1000, "preNodeInst", "preNodeInstP",
82                                " or fewer instances")
83
84        OWGUI.separator(self.controlArea)
85
86        OWGUI.separator(self.controlArea)
87
88        self.btnApply = OWGUI.button(self.controlArea, self,
89                                     "&Apply Changes",
90                                     callback=self.doBoth,
91                                     disabled=0,
92                                     default=True)
93
94        self.resize(100, 200)
95
96        self.setLearner()
97
98    def sendReport(self):
99        self.reportSettings("Learning parameters",
100                    [("Number of trees", self.trees),
101                     ("Considered number of attributes at each split",
102                      self.attributeP if self.attributes else "not set"),
103                     ("Seed for random generator", self.rseed),
104                     ("Maximal depth of individual trees",
105                      self.limitDepthP if self.limitDepth else "not set"),
106                     ("Minimal number of instances in a leaf",
107                      self.preNodeInstP if self.preNodeInst else "not limited")
108                   ])
109        self.reportData(self.data)
110
111    def constructLearner(self):
112        rand = random.Random(self.rseed)
113
114        attrs = None
115        if self.attributes:
116            attrs = self.attributesP
117
118        from Orange.classification.tree import SimpleTreeLearner
119
120        smallLearner = SimpleTreeLearner()
121
122        if self.preNodeInst:
123            smallLearner.min_instances = self.preNodeInstP
124        else:
125            smallLearner.min_instances = 0
126
127        if self.limitDepth:
128            smallLearner.max_depth = self.limitDepthP
129
130        learner = orngEnsemble.RandomForestLearner(base_learner=smallLearner,
131                            trees=self.trees, rand=rand, attributes=attrs)
132
133        if self.preprocessor:
134            learner = self.preprocessor.wrapLearner(learner)
135        learner.name = self.name
136        return learner
137
138    def setLearner(self):
139
140        if hasattr(self, "btnApply"):
141            self.btnApply.setFocus()
142
143        #assemble learner
144
145        self.learner = self.constructLearner()
146        self.send("Learner", self.learner)
147
148        self.error()
149
150    def setData(self, data):
151        if not self.isDataWithClass(data, orange.VarTypes.Discrete,
152                                    checkMissing=True):
153            data = None
154        self.data = data
155
156        #self.setLearner()
157
158        if self.data:
159            learner = self.constructLearner()
160            self.progressBarInit()
161            learner.callback = lambda v: self.progressBarSet(100.0 * v)
162            try:
163                self.classifier = learner(self.data)
164                self.classifier.name = self.name
165            except Exception, (errValue):
166                self.error(str(errValue))
167                self.classifier = None
168            self.progressBarFinished()
169        else:
170            self.classifier = None
171
172        self.send("Random Forest Classifier", self.classifier)
173
174    def setPreprocessor(self, pp):
175        self.preprocessor = pp
176        self.doBoth()
177
178    def doBoth(self):
179        self.setLearner()
180        self.setData(self.data)
181
182
183if __name__ == "__main__":
184    a = QApplication(sys.argv)
185    ow = OWRandomForest()
186
187    d = orange.ExampleTable('adult_sample')
188    ow.setData(d)
189
190    ow.show()
191    a.exec_()
192    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.