source: orange/Orange/OrangeWidgets/Classify/OWRandomForest.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 5.8 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name>Random Forest</name>
3<description>Random forest learner/classifier.</description>
4<icon>icons/RandomForest.png</icon>
5<contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
6<priority>320</priority>
7"""
8
9from OWWidget import *
10import orngTree, OWGUI
11import orngEnsemble
12from exceptions import Exception
13from orngWrap import PreprocessedLearner
14
class OWRandomForest(OWWidget):
    """Random forest widget: emits a learner and, given data, a trained classifier.

    The learner built from the current settings is sent on the "Learner"
    output. When data arrives on the "Data" input, a forest is trained on it
    and sent on the "Random Forest Classifier" output (None when there is no
    usable data or training fails).
    """
    # Attribute names persisted through loadSettings()/saveSettings().
    settingsList = ["name", "trees", "attributes", "attributesP", "preNodeInst", "preNodeInstP", "limitDepth", "limitDepthP", "rseed"]

    def __init__(self, parent=None, signalManager = None, name='Random Forest'):
        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=False, resizingEnabled=False)

        self.inputs = [("Data", ExampleTable, self.setData),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]

        self.outputs = [("Learner", orange.Learner),
                        ("Random Forest Classifier", orange.Classifier)]

        # Defaults; presumably replaced by saved values in loadSettings() below.
        self.name = 'Random Forest'
        self.trees = 10
        self.attributes = 0       # checkbox: fix the number of attributes tried per split
        self.attributesP = 5      # ... to this many
        self.preNodeInst = 1      # checkbox: stop splitting small nodes
        self.preNodeInstP = 5     # ... with this many or fewer instances
        self.limitDepth = 0       # checkbox: limit the depth of each tree
        self.limitDepthP = 3      # ... to this depth
        self.rseed = 0

        self.maxTrees = 10000

        self.loadSettings()

        self.data = None
        self.preprocessor = None
        # Last trained classifier; None until usable data has been received.
        self.classifier = None

        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')

        OWGUI.separator(self.controlArea)

        self.bBox = OWGUI.widgetBox(self.controlArea, 'Basic Properties')

        self.treesBox = OWGUI.spin(self.bBox, self, "trees", 1, self.maxTrees, orientation="horizontal", label="Number of trees in forest")
        self.attributesBox, self.attributesPBox = OWGUI.checkWithSpin(self.bBox, self, "Consider exactly", 1, 10000, "attributes", "attributesP", " "+"random attributes at each split.")
        self.rseedBox = OWGUI.spin(self.bBox, self, "rseed", 0, 100000, orientation="horizontal", label="Seed for random generator ")

        OWGUI.separator(self.controlArea)

        self.pBox = OWGUI.widgetBox(self.controlArea, 'Growth Control')

        self.limitDepthBox, self.limitDepthPBox = OWGUI.checkWithSpin(self.pBox, self, "Maximal depth of individual trees", 1, 1000, "limitDepth", "limitDepthP", "")
        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with ", 1, 1000, "preNodeInst", "preNodeInstP", " or fewer instances")

        OWGUI.separator(self.controlArea)

        OWGUI.separator(self.controlArea)

        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply Changes", callback = self.doBoth, disabled=0, default=True)

        self.resize(100,200)

        self.setLearner()

    def sendReport(self):
        """Add the learning parameters and the current input data to the report."""
        self.reportSettings("Learning parameters",
                            [("Number of trees", self.trees),
                             # The spin value lives in attributesP (see settingsList).
                             ("Considered number of attributes at each split", self.attributesP if self.attributes else "not set"),
                             ("Seed for random generator", self.rseed),
                             ("Maximal depth of individual trees", self.limitDepthP if self.limitDepth else "not set"),
                             ("Minimal number of instances in a leaf", self.preNodeInstP if self.preNodeInst else "not limited")
                           ])
        self.reportData(self.data)

    def constructLearner(self):
        """Build a random forest learner from the current widget settings.

        A new random.Random seeded with self.rseed is created on every call,
        so learners built with the same settings start from the same seed.
        If a preprocessor was received, the learner is wrapped by it.

        Returns:
            The (possibly wrapped) learner, with its name set to self.name.
        """
        # Explicit import instead of relying on the OWWidget star import.
        import random
        from Orange.classification.tree import SimpleTreeLearner

        rand = random.Random(self.rseed)

        # None leaves the per-split attribute count to RandomForestLearner.
        attrs = self.attributesP if self.attributes else None

        smallLearner = SimpleTreeLearner()
        smallLearner.min_instances = self.preNodeInstP if self.preNodeInst else 0
        if self.limitDepth:
            smallLearner.max_depth = self.limitDepthP

        learner = orngEnsemble.RandomForestLearner(base_learner=smallLearner,
                            trees=self.trees, rand=rand, attributes=attrs)

        if self.preprocessor:
            learner = self.preprocessor.wrapLearner(learner)
        learner.name = self.name
        return learner

    def setLearner(self):
        """Rebuild the learner from the current settings and send it on "Learner"."""
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        self.learner = self.constructLearner()
        self.send("Learner", self.learner)

        self.error()

    def setData(self, data):
        """Train a forest on *data* and send it on "Random Forest Classifier".

        isDataWithClass presumably rejects data lacking a discrete class (or
        having missing class values); in that case, or when training raises,
        None is sent and the error text (if any) is shown on the widget.
        """
        # Clear an error left by a previous failed run; otherwise it would
        # persist after a later successful training triggered by new data.
        self.error()
        self.data = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True) and data or None

        self.classifier = None
        if self.data:
            learner = self.constructLearner()
            pb = OWGUI.ProgressBar(self, iterations=self.trees)
            learner.callback = pb.advance
            try:
                self.classifier = learner(self.data)
                self.classifier.name = self.name
            except Exception as errValue:
                self.error(str(errValue))
                self.classifier = None
            finally:
                # Close the progress bar even if something escapes the handler.
                pb.finish()

        self.send("Random Forest Classifier", self.classifier)

    def setPreprocessor(self, pp):
        """Store the preprocessor from the "Preprocess" input and rebuild outputs."""
        self.preprocessor = pp
        self.doBoth()

    def doBoth(self):
        """Resend the learner and retrain the classifier on the current data."""
        self.setLearner()
        self.setData(self.data)
148
149
150
##############################################################################
# Test the widget, run from a command prompt:
# > python OWRandomForest.py
# Make sure that a sample data set (adult_sample.tab) is in the directory.

if __name__=="__main__":
    a = QApplication(sys.argv)
    ow = OWRandomForest()

    d = orange.ExampleTable('adult_sample')
    ow.setData(d)

    ow.show()
    # Qt4 event loop: exec_() replaces Qt3's exec_loop(); Qt4's QApplication
    # has no setMainWidget(), so it is not called.
    a.exec_()
    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.