source: orange/Orange/OrangeWidgets/Classify/OWCN2.py @ 11765:1546cd04481b

Revision 11765:1546cd04481b, 12.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 5 months ago (diff)

Fixed widget layouts.

Line 
# The tagged fields below mirror the NAME, DESCRIPTION, ICON and PRIORITY
# module constants defined after the imports; keep the two in sync.
# NOTE(review): the tags are presumably parsed by the Orange Canvas widget
# registry, so treat this text as runtime data -- confirm before editing it.
"""
<name>CN2</name>
<description>Rule-based (CN2) learner/classifier.</description>
<icon>icons/CN2.svg</icon>
<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
<priority>300</priority>
"""
8from OWWidget import *
9import OWGUI, orange, orngCN2, sys
10
11from orngWrap import PreprocessedLearner
12
# ---------------------------------------------------------------------------
# Widget metadata (module-level counterpart of the tags in the docstring).
# ---------------------------------------------------------------------------
NAME = "CN2"
DESCRIPTION = "Rule-based (CN2) learner/classifier"
AUTHOR = "Ales Erjavec"
PRIORITY = 300
ICON = "icons/CN2.svg"

# Sphinx documentation label reference
HELP_REF = "CN2 Rules"

# Input signals. "handler" names the OWCN2 method invoked when the signal
# arrives (OWCN2.dataset / OWCN2.setPreprocessor).
INPUTS = (
    {"name": "Data",
     "type": ExampleTable,
     "handler": "dataset",
     "doc": "Training data set",
     "id": "train-data"},
    {"name": "Preprocess",
     "type": PreprocessedLearner,
     "handler": "setPreprocessor",
     "doc": "Data preprocessor",
     "id": "preprocessor"},
)

# Output signals. The two classifier outputs carry the same object.
OUTPUTS = (
    {"name": "Learner",
     "type": orange.Learner,
     "doc": "A CN2 Rules learner instance",
     "id": "learner"},
    {"name": "Classifier",
     "type": orange.Classifier,
     "doc": "A rule classifier induced from given training data.",
     "id": "classifier"},
    {"name": "Unordered CN2 Classifier",
     "type": orngCN2.CN2UnorderedClassifier,
     "doc": "Same as 'Classifier'",
     "id": "unordered-cn2-classifier"},
)
50
51
class CN2ProgressBar(orange.ProgressCallback):
    """Forward learner progress reports to the progress bar of *widget*.

    Each reported value is mapped into the [start, end] sub-range and shown
    as a percentage through ``widget.progressBarSet``.
    """

    def __init__(self, widget, start=0.0, end=0.0):
        # NOTE(review): with the defaults (start == end == 0) every update
        # would show 0%; presumably the learner adjusts start/end itself
        # (e.g. per target class) -- confirm in orngCN2 before changing.
        self.widget = widget
        self.start, self.end = start, end
        orange.ProgressCallback.__init__(self)

    def __call__(self, value, a):
        # ``a`` is the extra argument passed by the caller; it is ignored.
        span = self.end - self.start
        fraction = self.start + span * value
        self.widget.progressBarSet(100 * fraction)
60
class OWCN2(OWWidget):
    """Widget that builds a CN2 rule learner and, given data, a rule classifier.

    The learner is sent on the "Learner" output every time the settings are
    applied. When training data with a discrete class is available, the
    induced classifier is sent on both "Classifier" and
    "Unordered CN2 Classifier"; otherwise None is sent on both.
    """
    # Attributes persisted by the widget framework (restored by loadSettings()).
    settingsList = ["name", "QualityButton", "CoveringButton", "m",
                    "MaxRuleLength", "useMaxRuleLength",
                    "MinCoverage", "BeamWidth", "Alpha", "Weight", "stepAlpha"]
    callbackDeposit = []

    def __init__(self, parent=None, signalManager=None):
        OWWidget.__init__(self, parent, signalManager, "CN2",
                          wantMainArea=False, resizingEnabled=False)

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
        self.outputs = [("Learner", orange.Learner),
                        ("Classifier", orange.Classifier),
                        ("Unordered CN2 Classifier", orngCN2.CN2UnorderedClassifier)]

        # Default settings; loadSettings() overrides them with stored values.
        self.QualityButton = 0    # 0 Laplace, 1 m-estimate, 2 EVC, 3 WRACC
        self.CoveringButton = 0   # 0 exclusive, 1 weighted covering
        self.Alpha = 0.05
        self.stepAlpha = 0.2
        self.BeamWidth = 5
        self.MinCoverage = 0
        self.MaxRuleLength = 0
        self.useMaxRuleLength = False
        self.Weight = 0.9
        self.m = 2
        self.name = "CN2 rules"
        self.loadSettings()

        self.data = None
        self.preprocessor = None

        ## GUI
        self.learnerName = OWGUI.lineEdit(
            self.controlArea, self, "name", box="Learner/classifier name",
            tooltip="Name to be used by other widgets to identify the learner/classifier")
        OWGUI.separator(self.controlArea)

        self.ruleQualityBG = OWGUI.widgetBox(self.controlArea, "Rule quality estimation")
        self.ruleQualityBG.buttons = []

        OWGUI.separator(self.controlArea)
        self.ruleValidationGroup = OWGUI.widgetBox(self.controlArea, "Pre-pruning (LRS)")

        OWGUI.separator(self.controlArea)
        OWGUI.spin(self.controlArea, self, "BeamWidth", 1, 100, box="Beam width",
                   tooltip="The width of the search beam\n(number of rules to be specialized)")

        OWGUI.separator(self.controlArea)
        self.coveringAlgBG = OWGUI.widgetBox(self.controlArea, "Covering algorithm")
        self.coveringAlgBG.buttons = []

        # Rule quality radio buttons; the m-estimate button shares a row
        # with the spin box for m. Button order defines QualityButton values.
        b1 = QRadioButton("Laplace", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b1)
        g = OWGUI.widgetBox(self.ruleQualityBG, orientation="horizontal")
        b2 = QRadioButton("m-estimate", g)
        g.layout().addWidget(b2)
        self.mSpin = OWGUI.doubleSpin(g, self, "m", 0, 100)
        b3 = QRadioButton("EVC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b3)
        b4 = QRadioButton("WRACC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b4)
        self.ruleQualityBG.buttons = [b1, b2, b3, b4]

        for i, button in enumerate([b1, b2, b3, b4]):
            # Bind i as a default argument so each lambda keeps its own index.
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=i: self.qualityButtonPressed(v))

        form = QFormLayout(
            labelAlignment=Qt.AlignLeft, formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
        )

        self.ruleValidationGroup.layout().addLayout(form)

        alpha_spin = OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "Alpha", 0, 1, 0.001,
            tooltip="Required significance of the difference between the " +
                    "class distribution on all examples and covered examples")

        step_alpha_spin = OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "stepAlpha", 0, 1, 0.001,
            tooltip="Required significance of each specialization of a rule.")

        min_coverage_spin = OWGUI.spin(
            self.ruleValidationGroup, self, "MinCoverage", 0, 100,
            tooltip="Minimum number of examples a rule must cover " +
                    "(use 0 for not setting the limit)")

        min_coverage_spin.setSpecialValueText("Unlimited")

        # Check box needs to be in a layout for the form layout to center it
        # in the vertical direction.
        max_rule_box = OWGUI.widgetBox(self.ruleValidationGroup, "")
        max_rule_cb = OWGUI.checkBox(
            max_rule_box, self, "useMaxRuleLength", "Maximal rule length")

        max_rule_spin = OWGUI.spin(
            self.ruleValidationGroup, self, "MaxRuleLength", 0, 100,
            tooltip="Maximal number of conditions in the left\n" +
                    "part of the rule (use 0 for don't care)")
        max_rule_spin.setSpecialValueText("Unlimited")
        # The spin box is enabled only while the check box is ticked.
        max_rule_cb.disables += [max_rule_spin]
        max_rule_cb.makeConsistent()

        form.addRow("Alpha (vs. default rule)", alpha_spin)
        form.addRow("Stopping Alpha (vs. parent rule)", step_alpha_spin)
        form.addRow("Minimum coverage", min_coverage_spin)
        form.addRow(max_rule_box, max_rule_spin)

        # Covering radio buttons; button order defines CoveringButton values.
        B1 = QRadioButton("Exclusive covering", self.coveringAlgBG)
        self.coveringAlgBG.layout().addWidget(B1)
        g = OWGUI.widgetBox(self.coveringAlgBG, orientation="horizontal")
        B2 = QRadioButton("Weighted covering", g)
        g.layout().addWidget(B2)
        self.coveringAlgBG.buttons = [B1, B2]
        self.weightSpin = OWGUI.doubleSpin(g, self, "Weight", 0, 0.95, 0.05)

        for i, button in enumerate([B1, B2]):
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=i: self.coveringAlgButtonPressed(v))

        OWGUI.separator(self.controlArea)
        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply",
                                     callback=self.applySettings, default=True)

        # Normalize numeric settings that may have been restored by loadSettings().
        self.Alpha = float(self.Alpha)
        self.stepAlpha = float(self.stepAlpha)
        self.Weight = float(self.Weight)

        # Sync radio buttons and dependent controls with the loaded settings.
        self.qualityButtonPressed(self.QualityButton)
        self.coveringAlgButtonPressed(self.CoveringButton)
        self.resize(100, 100)
        self.setLearner()

    def sendReport(self):
        """Add the current learning parameters and the input data to the report."""
        # Indexed by self.QualityButton, so it must list all four radio
        # buttons (Laplace, m-estimate, EVC, WRACC) in the same order.
        quality = ["Laplace",
                   "m-estimate with m=%.2f" % self.m,
                   "EVC",
                   "WRACC"][self.QualityButton]
        covering = ["Exclusive",
                    "Weighted with a weight of %.2f" % self.Weight][self.CoveringButton]
        if self.useMaxRuleLength:
            maxRuleLength = self.MaxRuleLength
        else:
            maxRuleLength = "unlimited"
        self.reportSettings("Learning parameters",
                            [("Rule quality estimation", quality),
                             ("Pruning alpha (vs. default rule)", "%.3f" % self.Alpha),
                             ("Stopping alpha (vs. parent rule)", "%.3f" % self.stepAlpha),
                             ("Minimum coverage", "%.3f" % self.MinCoverage),
                             ("Maximal rule length", maxRuleLength),
                             ("Beam width", self.BeamWidth),
                             ("Covering", covering)])
        self.reportData(self.data)

    def setLearner(self):
        """Build the learner from the current settings, send it, and retrain.

        The learner is always sent on "Learner". If self.data is set, a
        classifier is induced from it; the classifier (or None) is then sent
        on "Classifier" and "Unordered CN2 Classifier". The progress bar is
        closed even when building or training raises.
        """
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()
        self.progressBarInit()
        try:
            self.learner = self._buildLearner()
            self.send("Learner", self.learner)

            self.classifier = None
            self.error()
            if self.data:
                self.classifier = self._trainClassifier(self.data)
                self.error("")
            self.send("Classifier", self.classifier)
            self.send("Unordered CN2 Classifier", self.classifier)
        finally:
            self.progressBarFinished()

    def _buildLearner(self):
        """Return a learner for the current settings, wrapped by the preprocessor if set."""
        # -1 is passed when the "Maximal rule length" check box is off.
        if self.useMaxRuleLength:
            maxRuleLength = self.MaxRuleLength
        else:
            maxRuleLength = -1

        if self.QualityButton == 2:
            # EVC uses its own learner class, configured via keyword arguments;
            # the covering settings do not apply to it.
            learner = orngCN2.CN2EVCUnorderedLearner(
                width=self.BeamWidth, rule_sig=self.Alpha, att_sig=self.stepAlpha,
                min_coverage=self.MinCoverage, max_rule_complexity=maxRuleLength)
        else:
            learner = orngCN2.CN2UnorderedLearner()
            learner.name = self.name
            learner.ruleFinder = self._buildRuleFinder(maxRuleLength)
            if self.CoveringButton == 0:
                learner.coverAndRemove = orange.RuleCovererAndRemover_Default()
            elif self.CoveringButton == 1:
                learner.coverAndRemove = orngCN2.CovererAndRemover_multWeights(mult=self.Weight)

        if self.preprocessor:
            learner = self.preprocessor.wrapLearner(learner)
        learner.name = self.name
        return learner

    def _buildRuleFinder(self, maxRuleLength):
        """Return a beam rule finder using the selected quality estimator and LRS validators."""
        ruleFinder = orange.RuleBeamFinder()
        if self.QualityButton == 0:
            ruleFinder.evaluator = orange.RuleEvaluator_Laplace()
        elif self.QualityButton == 1:
            ruleFinder.evaluator = orngCN2.mEstimate(self.m)
        elif self.QualityButton == 3:
            ruleFinder.evaluator = orngCN2.WRACCEvaluator()

        # stepAlpha is the "Stopping Alpha (vs. parent rule)" setting and
        # Alpha the "Alpha (vs. default rule)" setting from the GUI form.
        ruleFinder.ruleStoppingValidator = orange.RuleValidator_LRS(
            alpha=self.stepAlpha, min_coverage=self.MinCoverage,
            max_rule_complexity=maxRuleLength)
        ruleFinder.validator = orange.RuleValidator_LRS(
            alpha=self.Alpha, min_coverage=self.MinCoverage,
            max_rule_complexity=maxRuleLength)
        ruleFinder.ruleFilter = orange.RuleBeamFilter_Width(width=self.BeamWidth)
        return ruleFinder

    def _trainClassifier(self, data):
        """Induce a classifier from *data* with self.learner and return it."""
        # Train on a copy of the data in a copy of its domain, then convert
        # the examples kept on the classifier and its rules to that domain.
        oldDomain = orange.Domain(data.domain)
        learnData = orange.ExampleTable(oldDomain, data)
        self.learner.progressCallback = CN2ProgressBar(self)
        try:
            classifier = self.learner(learnData)
        finally:
            # Don't leave a reference to this widget on the learner that was
            # already sent downstream, even if learning failed.
            self.learner.progressCallback = None
        classifier.name = self.name
        for rule in classifier.rules:
            rule.examples = orange.ExampleTable(oldDomain, rule.examples)
        classifier.examples = orange.ExampleTable(oldDomain, classifier.examples)
        classifier.setattr("data", classifier.examples)
        return classifier

    def dataset(self, data):
        """Handle the "Data" input.

        The data is kept only if isDataWithClass() accepts it (discrete class,
        checkMissing=True) and it is non-empty; otherwise self.data is None.
        The learner/classifier is then rebuilt.
        """
        hasDiscreteClass = self.isDataWithClass(data, orange.VarTypes.Discrete,
                                                checkMissing=True)
        self.data = data if (hasDiscreteClass and data) else None
        self.setLearner()

    def setPreprocessor(self, pp):
        """Handle the "Preprocess" input and rebuild the learner."""
        self.preprocessor = pp
        self.setLearner()

    def qualityButtonPressed(self, id=0):
        """Select rule quality estimator *id* and update dependent controls.

        The m spin box is enabled only for the m-estimate (1); the covering
        box is disabled for EVC (2), whose learner does not use it.
        """
        self.QualityButton = id
        for index, button in enumerate(self.ruleQualityBG.buttons):
            button.setChecked(index == id)
        self.mSpin.control.setEnabled(id == 1)
        self.coveringAlgBG.setEnabled(id != 2)

    def coveringAlgButtonPressed(self, id=0):
        """Select covering algorithm *id*; the weight spin box is enabled only for weighted covering (1)."""
        self.CoveringButton = id
        for index, button in enumerate(self.coveringAlgBG.buttons):
            button.setChecked(index == id)
        self.weightSpin.control.setEnabled(id == 1)

    def applySettings(self):
        """Callback of the Apply button: rebuild the learner and classifier."""
        self.setLearner()
304
if __name__ == "__main__":
    # Manual smoke test: show the widget trained on "titanic.tab", then
    # persist its settings once the Qt event loop exits.
    application = QApplication(sys.argv)
    widget = OWCN2()
    widget.dataset(orange.ExampleTable("titanic.tab"))
    widget.show()
    application.exec_()
    widget.saveSettings()
Note: See TracBrowser for help on using the repository browser.