source: orange/Orange/OrangeWidgets/Classify/OWRandomForest.py @ 11096:cf7d2ae9d22b

Revision 11096:cf7d2ae9d22b, 5.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 19 months ago (diff)

Added new svg icons for the widgets/categories.

Line 
1"""
2<name>Random Forest</name>
3<description>Random forest learner/classifier.</description>
4<icon>icons/RandomForest.svg</icon>
5<contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
6<priority>320</priority>
7"""
8
9from OWWidget import *
10import orngTree, OWGUI
11import orngEnsemble
12from exceptions import Exception
13from orngWrap import PreprocessedLearner
14
class OWRandomForest(OWWidget):
    """Widget that configures a random forest learner and trains it on data.

    Inputs:
        "Data" -- an ExampleTable, handled by setData.
        "Preprocess" -- a PreprocessedLearner, handled by setPreprocessor.

    Outputs:
        "Learner" -- the configured (untrained) learner.
        "Random Forest Classifier" -- the classifier trained on the input
        data, or None when there is no usable data or training failed.
    """

    # Names of the attributes that make up the widget's stored settings.
    # NOTE(review): presumably consumed by loadSettings/saveSettings of
    # OWWidget -- confirm against the base class.
    settingsList = ["name", "trees", "attributes", "attributesP", "preNodeInst", "preNodeInstP", "limitDepth", "limitDepthP", "rseed"]

    def __init__(self, parent=None, signalManager = None, name='Random Forest'):
        """Set default settings, build the control area and send the learner."""
        OWWidget.__init__(self, parent, signalManager, name, wantMainArea=False, resizingEnabled=False)

        self.inputs = [("Data", ExampleTable, self.setData),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]

        self.outputs = [("Learner", orange.Learner),
                        ("Random Forest Classifier", orange.Classifier)]

        # Defaults; loadSettings() below may overwrite them.
        # Each "xxx" flag enables the limit whose value is kept in "xxxP".
        self.name = 'Random Forest'
        self.trees = 10
        self.attributes = 0
        self.attributesP = 5
        self.preNodeInst = 1
        self.preNodeInstP = 5
        self.limitDepth = 0
        self.limitDepthP = 3
        self.rseed = 0

        # Upper bound of the "Number of trees in forest" spin box.
        self.maxTrees = 10000

        self.loadSettings()

        self.data = None
        self.preprocessor = None

        OWGUI.lineEdit(self.controlArea, self, 'name', box='Learner/Classifier Name', tooltip='Name to be used by other widgets to identify your learner/classifier.')

        OWGUI.separator(self.controlArea)

        self.bBox = OWGUI.widgetBox(self.controlArea, 'Basic Properties')

        self.treesBox = OWGUI.spin(self.bBox, self, "trees", 1, self.maxTrees, orientation="horizontal", label="Number of trees in forest")
        self.attributesBox, self.attributesPBox = OWGUI.checkWithSpin(self.bBox, self, "Consider exactly", 1, 10000, "attributes", "attributesP", " random attributes at each split.")
        self.rseedBox = OWGUI.spin(self.bBox, self, "rseed", 0, 100000, orientation="horizontal", label="Seed for random generator ")

        OWGUI.separator(self.controlArea)

        self.pBox = OWGUI.widgetBox(self.controlArea, 'Growth Control')

        self.limitDepthBox, self.limitDepthPBox = OWGUI.checkWithSpin(self.pBox, self, "Maximal depth of individual trees", 1, 1000, "limitDepth", "limitDepthP", "")
        self.preNodeInstBox, self.preNodeInstPBox = OWGUI.checkWithSpin(self.pBox, self, "Stop splitting nodes with ", 1, 1000, "preNodeInst", "preNodeInstP", " or fewer instances")

        OWGUI.separator(self.controlArea)

        OWGUI.separator(self.controlArea)

        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply Changes", callback = self.doBoth, disabled=0, default=True)

        self.resize(100,200)

        # Send a learner built from the current settings right away.
        self.setLearner()

    def sendReport(self):
        """Report the current learning parameters and the input data."""
        self.reportSettings("Learning parameters",
                            [("Number of trees", self.trees),
                             # The setting is named attributesP; the former
                             # "attributeP" raised AttributeError here.
                             ("Considered number of attributes at each split", self.attributesP if self.attributes else "not set"),
                             ("Seed for random generator", self.rseed),
                             ("Maximal depth of individual trees", self.limitDepthP if self.limitDepth else "not set"),
                             ("Minimal number of instances in a leaf", self.preNodeInstP if self.preNodeInst else "not limited")
                           ])
        self.reportData(self.data)

    def constructLearner(self):
        """Build and return a random forest learner from the current settings.

        The forest uses a SimpleTreeLearner as its base learner. When a
        preprocessor is set, the returned learner is the forest wrapped by
        ``self.preprocessor.wrapLearner``. The result is named ``self.name``.
        """
        # Imported here so the method does not rely on ``random`` leaking in
        # through ``from OWWidget import *``.
        import random
        from Orange.classification.tree import SimpleTreeLearner

        rand = random.Random(self.rseed)

        # None leaves the number of attributes per split to the learner.
        attrs = self.attributesP if self.attributes else None

        smallLearner = SimpleTreeLearner()
        smallLearner.min_instances = self.preNodeInstP if self.preNodeInst else 0

        # max_depth is only touched when the depth limit is enabled.
        if self.limitDepth:
            smallLearner.max_depth = self.limitDepthP

        learner = orngEnsemble.RandomForestLearner(base_learner=smallLearner,
                            trees=self.trees, rand=rand, attributes=attrs)

        if self.preprocessor:
            learner = self.preprocessor.wrapLearner(learner)
        learner.name = self.name
        return learner

    def setLearner(self):
        """Rebuild the learner, send it on "Learner" and reset the error."""
        # btnApply does not exist yet if this is reached before the button
        # is created.
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        self.learner = self.constructLearner()
        self.send("Learner", self.learner)

        self.error()

    def setData(self, data):
        """Store the input data, train on it and send the classifier.

        ``data`` is kept only if ``isDataWithClass`` accepts it as data with a
        discrete class (and it is truthy); otherwise ``self.data`` becomes
        None. Training errors are shown through ``self.error`` and result in
        None being sent on "Random Forest Classifier".
        """
        # Reset the message left by an earlier failed training run; without
        # this it was only reset by setLearner (i.e. via doBoth).
        self.error()

        hasClass = self.isDataWithClass(data, orange.VarTypes.Discrete, checkMissing=True)
        self.data = data if (hasClass and data) else None

        if self.data:
            learner = self.constructLearner()
            self.progressBarInit()
            # NOTE(review): the callback value is scaled by 100, so it is
            # presumably a fraction in [0, 1] -- confirm against the learner.
            learner.callback = lambda v: self.progressBarSet(100.0 * v)
            try:
                self.classifier = learner(self.data)
                self.classifier.name = self.name
            except Exception as err:
                self.error(str(err))
                self.classifier = None
            finally:
                # Close the progress bar whatever happened during training.
                self.progressBarFinished()
        else:
            self.classifier = None

        self.send("Random Forest Classifier", self.classifier)

    def setPreprocessor(self, pp):
        """Store the preprocessor and refresh both outputs."""
        self.preprocessor = pp
        self.doBoth()

    def doBoth(self):
        """Resend the learner and retrain the classifier on the stored data."""
        self.setLearner()
        self.setData(self.data)
149
150
151
##############################################################################
# Test the widget, run from the command line:
# > python OWRandomForest.py
# Make sure that a sample data set (adult_sample.tab) is in the directory

if __name__=="__main__":
    # QApplication.setMainWidget() and exec_loop() are Qt3-only API; with the
    # Qt4 bindings the widget is simply shown and the loop started with exec_().
    # NOTE(review): assumes OWWidget re-exports the PyQt4 QApplication and sys.
    a = QApplication(sys.argv)
    ow = OWRandomForest()

    d = orange.ExampleTable('adult_sample')
    ow.setData(d)

    ow.show()
    a.exec_()
    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.