source: orange/Orange/OrangeWidgets/Classify/OWCN2.py @ 11766:bf168e71b620

Revision 11766:bf168e71b620, 12.7 KB checked in by Ales Erjavec <ales.erjavec@…>, 5 months ago (diff)

Removed the special value for the "Maximal rule length".

Line 
1"""
2<name>CN2</name>
3<description>Rule-based (CN2) learner/classifier.</description>
4<icon>icons/CN2.svg</icon>
5<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
6<priority>300</priority>
7"""
8from OWWidget import *
9import OWGUI, orange, orngCN2, sys
10
11from orngWrap import PreprocessedLearner
12
# Widget metadata: module-level constants describing this widget.
NAME = "CN2"
DESCRIPTION = "Rule-based (CN2) learner/classifier"
AUTHOR = "Ales Erjavec"
PRIORITY = 300
ICON = "icons/CN2.svg"

# Sphinx documentation label reference
HELP_REF = "CN2 Rules"

# Input channel descriptions. Each "handler" string matches the name of an
# OWCN2 method defined in this file (dataset, setPreprocessor).
INPUTS = (
    {
        "name": "Data",
        "type": ExampleTable,
        "handler": "dataset",
        "doc": "Training data set",
        "id": "train-data",
    },
    {
        "name": "Preprocess",
        "type": PreprocessedLearner,
        "handler": "setPreprocessor",
        "doc": "Data preprocessor",
        "id": "preprocessor",
    },
)

# Output channel descriptions; the names match the channels OWCN2 sends on.
OUTPUTS = (
    {
        "name": "Learner",
        "type": orange.Learner,
        "doc": "A CN2 Rules learner instance",
        "id": "learner",
    },
    {
        "name": "Classifier",
        "type": orange.Classifier,
        "doc": "A rule classifier induced from given training data.",
        "id": "classifier",
    },
    {
        "name": "Unordered CN2 Classifier",
        "type": orngCN2.CN2UnorderedClassifier,
        "doc": "Same as 'Classifier'",
        "id": "unordered-cn2-classifier",
    },
)
50
51
class CN2ProgressBar(orange.ProgressCallback):
    """Progress callback that drives a widget's progress bar.

    Every reported ``value`` is mapped linearly onto the interval
    [``start``, ``end``] and the result, multiplied by 100, is handed to
    ``widget.progressBarSet``.

    Note that with the default arguments (``start == end == 0.0``) the
    interval has zero width, so every call reports 0 regardless of
    ``value``; pass an explicit ``end`` to get a moving bar.
    """

    def __init__(self, widget, start=0.0, end=0.0):
        """Remember the target widget and the [start, end] sub-range."""
        self.widget = widget
        self.start = start
        self.end = end
        # Base-class initialisation is done last, after the attributes
        # are in place.
        orange.ProgressCallback.__init__(self)

    def __call__(self, value, a):
        """Forward ``value`` to the widget; the second argument is unused."""
        span = self.end - self.start
        scaled = self.start + span * value
        self.widget.progressBarSet(100 * scaled)
60
class OWCN2(OWWidget):
    """Widget that configures a CN2 rule learner and trains it on input data.

    The control area exposes the rule quality estimator, LRS pre-pruning
    thresholds, beam width and covering algorithm.  Whenever the data, the
    preprocessor or (via Apply) the settings change, ``setLearner`` rebuilds
    the learner, sends it on the "Learner" channel and, if data is present,
    induces a classifier that is sent on both classifier channels.

    ``QualityButton`` values: 0 Laplace, 1 m-estimate, 2 EVC, 3 WRACC.
    ``CoveringButton`` values: 0 exclusive, 1 weighted covering.
    """

    # Attribute names handled by loadSettings()/saveSettings()
    # (presumably persisted between sessions by OWWidget -- confirm).
    settingsList = ["name", "QualityButton", "CoveringButton", "m",
                    "MaxRuleLength", "useMaxRuleLength",
                    "MinCoverage", "BeamWidth", "Alpha", "Weight", "stepAlpha"]
    callbackDeposit = []

    def __init__(self, parent=None, signalManager=None):
        """Set default settings, build the GUI and create the initial learner."""
        OWWidget.__init__(self, parent, signalManager, "CN2",
                          wantMainArea=False, resizingEnabled=False)

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
        self.outputs = [("Learner", orange.Learner),
                        ("Classifier", orange.Classifier),
                        ("Unordered CN2 Classifier", orngCN2.CN2UnorderedClassifier)]

        # Defaults; loadSettings() below may replace any of them.
        self.QualityButton = 0
        self.CoveringButton = 0
        self.Alpha = 0.05
        self.stepAlpha = 0.2
        self.BeamWidth = 5
        self.MinCoverage = 0
        self.MaxRuleLength = 0
        self.useMaxRuleLength = False
        self.Weight = 0.9
        self.m = 2
        self.name = "CN2 rules"
        self.loadSettings()

        self.data = None
        self.preprocessor = None

        # ---- GUI ----
        # The group boxes are created first, in display order; their
        # contents are filled in further down.
        self.learnerName = OWGUI.lineEdit(
            self.controlArea, self, "name", box="Learner/classifier name",
            tooltip="Name to be used by other widgets to identify the learner/classifier")
        OWGUI.separator(self.controlArea)

        self.ruleQualityBG = OWGUI.widgetBox(self.controlArea, "Rule quality estimation")
        self.ruleQualityBG.buttons = []

        OWGUI.separator(self.controlArea)
        self.ruleValidationGroup = OWGUI.widgetBox(self.controlArea, "Pre-prunning (LRS)")

        OWGUI.separator(self.controlArea)
        OWGUI.spin(self.controlArea, self, "BeamWidth", 1, 100, box="Beam width",
                   tooltip="The width of the search beam\n(number of rules to be specialized)")

        OWGUI.separator(self.controlArea)
        self.coveringAlgBG = OWGUI.widgetBox(self.controlArea, "Covering algorithm")
        self.coveringAlgBG.buttons = []

        # Rule quality radio buttons; the list order defines the
        # QualityButton index passed to qualityButtonPressed().
        b1 = QRadioButton("Laplace", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b1)
        g = OWGUI.widgetBox(self.ruleQualityBG, orientation="horizontal")
        b2 = QRadioButton("m-estimate", g)
        g.layout().addWidget(b2)
        self.mSpin = OWGUI.doubleSpin(g, self, "m", 0, 100)
        b3 = QRadioButton("EVC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b3)
        b4 = QRadioButton("WRACC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(b4)
        self.ruleQualityBG.buttons = [b1, b2, b3, b4]

        for i, button in enumerate(self.ruleQualityBG.buttons):
            # v=i binds the current index at definition time.
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=i: self.qualityButtonPressed(v))

        # Pre-pruning parameters, laid out as label/field rows.
        form = QFormLayout(
            labelAlignment=Qt.AlignLeft, formAlignment=Qt.AlignLeft,
            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
        )
        self.ruleValidationGroup.layout().addLayout(form)

        alpha_spin = OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "Alpha", 0, 1, 0.001,
            tooltip="Required significance of the difference between the " +
                    "class distribution on all examples and covered examples")

        step_alpha_spin = OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "stepAlpha", 0, 1, 0.001,
            tooltip="Required significance of each specialization of a rule.")

        min_coverage_spin = OWGUI.spin(
            self.ruleValidationGroup, self, "MinCoverage", 0, 100,
            tooltip="Minimum number of examples a rule must cover " +
                    "(use 0 for not setting the limit)")
        min_coverage_spin.setSpecialValueText("Unlimited")

        # Check box needs to be in a layout for the form layout to center it
        # in the vertical direction.
        max_rule_box = OWGUI.widgetBox(self.ruleValidationGroup, "")
        max_rule_cb = OWGUI.checkBox(
            max_rule_box, self, "useMaxRuleLength", "Maximal rule length")

        max_rule_spin = OWGUI.spin(
            self.ruleValidationGroup, self, "MaxRuleLength", 1, 100,
            tooltip="Maximal number of conditions in the left " +
                    "part of the rule")
        max_rule_cb.disables += [max_rule_spin]
        max_rule_cb.makeConsistent()

        form.addRow("Alpha (vs. default rule)", alpha_spin)
        form.addRow("Stopping Alpha (vs. parent rule)", step_alpha_spin)
        form.addRow("Minimum coverage", min_coverage_spin)
        form.addRow(max_rule_box, max_rule_spin)

        # Covering algorithm radio buttons; the list order defines the
        # CoveringButton index passed to coveringAlgButtonPressed().
        B1 = QRadioButton("Exclusive covering", self.coveringAlgBG)
        self.coveringAlgBG.layout().addWidget(B1)
        g = OWGUI.widgetBox(self.coveringAlgBG, orientation="horizontal")
        B2 = QRadioButton("Weighted covering", g)
        g.layout().addWidget(B2)
        self.coveringAlgBG.buttons = [B1, B2]
        self.weightSpin = OWGUI.doubleSpin(g, self, "Weight", 0, 0.95, 0.05)

        for i, button in enumerate(self.coveringAlgBG.buttons):
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=i: self.coveringAlgButtonPressed(v))

        OWGUI.separator(self.controlArea)
        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply",
                                     callback=self.applySettings, default=True)

        # Coerce the numeric settings to float.
        self.Alpha = float(self.Alpha)
        self.stepAlpha = float(self.stepAlpha)
        self.Weight = float(self.Weight)

        # Bring the radio buttons and dependent controls in line with the
        # current settings, then build the initial learner.
        self.qualityButtonPressed(self.QualityButton)
        self.coveringAlgButtonPressed(self.CoveringButton)
        self.resize(100, 100)
        self.setLearner()

    def sendReport(self):
        """Report the current learning parameters and the input data."""
        # One label per QualityButton value (0..3).  The list must stay in
        # step with the radio buttons created in __init__; it used to lack
        # "EVC", which mislabelled EVC as WRACC and raised IndexError for
        # WRACC.
        qualityLabels = ["Laplace",
                         "m-estimate with m=%.2f" % self.m,
                         "EVC",
                         "WRACC"]
        coveringLabels = ["Exclusive",
                          "Weighted with a weight of %.2f" % self.Weight]
        if self.useMaxRuleLength:
            maxLengthText = self.MaxRuleLength
        else:
            maxLengthText = "unlimited"

        self.reportSettings(
            "Learning parameters",
            [("Rule quality estimation", qualityLabels[self.QualityButton]),
             ("Pruning alpha (vs. default rule)", "%.3f" % self.Alpha),
             ("Stopping alpha (vs. parent rule)", "%.3f" % self.stepAlpha),
             ("Minimum coverage", "%.3f" % self.MinCoverage),
             ("Maximal rule length", maxLengthText),
             ("Beam width", self.BeamWidth),
             ("Covering", coveringLabels[self.CoveringButton])])
        self.reportData(self.data)

    def setLearner(self):
        """Rebuild the learner from the settings and retrain if data is set.

        Sends the learner on "Learner" and the induced classifier (or None
        when there is no data) on "Classifier" and
        "Unordered CN2 Classifier".  Exceptions raised while building or
        training propagate to the caller, but the progress bar is always
        finished.
        """
        # btnApply does not exist if this is called before the GUI is built.
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        self.progressBarInit()
        try:
            self.learner = self._buildLearner()
            self.send("Learner", self.learner)

            self.classifier = None
            self.error()
            if self.data:
                self.classifier = self._induceClassifier()
                self.error("")
            self.send("Classifier", self.classifier)
            self.send("Unordered CN2 Classifier", self.classifier)
        finally:
            # Do not leave the progress bar hanging if learning failed.
            self.progressBarFinished()

    def _buildLearner(self):
        """Return a named learner configured from the current settings,
        wrapped by the preprocessor when one is connected."""
        # -1 is passed as max_rule_complexity when the limit is switched off.
        if self.useMaxRuleLength:
            maxRuleLength = self.MaxRuleLength
        else:
            maxRuleLength = -1

        if self.QualityButton == 2:
            # EVC uses a completely different learner class that takes all
            # parameters directly.
            learner = orngCN2.CN2EVCUnorderedLearner(
                width=self.BeamWidth, rule_sig=self.Alpha,
                att_sig=self.stepAlpha, min_coverage=self.MinCoverage,
                max_rule_complexity=maxRuleLength)
        else:
            learner = orngCN2.CN2UnorderedLearner()
            learner.name = self.name

            ruleFinder = orange.RuleBeamFinder()
            if self.QualityButton == 0:
                ruleFinder.evaluator = orange.RuleEvaluator_Laplace()
            elif self.QualityButton == 1:
                ruleFinder.evaluator = orngCN2.mEstimate(self.m)
            elif self.QualityButton == 3:
                ruleFinder.evaluator = orngCN2.WRACCEvaluator()

            ruleFinder.ruleStoppingValidator = orange.RuleValidator_LRS(
                alpha=self.stepAlpha, min_coverage=self.MinCoverage,
                max_rule_complexity=maxRuleLength)
            ruleFinder.validator = orange.RuleValidator_LRS(
                alpha=self.Alpha, min_coverage=self.MinCoverage,
                max_rule_complexity=maxRuleLength)
            ruleFinder.ruleFilter = orange.RuleBeamFilter_Width(width=self.BeamWidth)
            learner.ruleFinder = ruleFinder

            if self.CoveringButton == 0:
                learner.coverAndRemove = orange.RuleCovererAndRemover_Default()
            elif self.CoveringButton == 1:
                learner.coverAndRemove = orngCN2.CovererAndRemover_multWeights(
                    mult=self.Weight)

        if self.preprocessor:
            learner = self.preprocessor.wrapLearner(learner)
        learner.name = self.name
        return learner

    def _induceClassifier(self):
        """Train self.learner on self.data and return the named classifier."""
        # Train on a copy of the data bound to a fresh Domain object, then
        # rebind the examples stored on the classifier and on each of its
        # rules to that same domain.
        oldDomain = orange.Domain(self.data.domain)
        learnData = orange.ExampleTable(oldDomain, self.data)

        # An explicit 0..1 range is required: CN2ProgressBar's defaults
        # (start == end == 0.0) would pin the bar at 0.
        # Assumes the learner reports progress as a fraction in [0, 1]
        # -- TODO confirm.
        self.learner.progressCallback = CN2ProgressBar(self, 0.0, 1.0)
        try:
            classifier = self.learner(learnData)
        finally:
            self.learner.progressCallback = None

        classifier.name = self.name
        for rule in classifier.rules:
            rule.examples = orange.ExampleTable(oldDomain, rule.examples)
        classifier.examples = orange.ExampleTable(oldDomain, classifier.examples)
        classifier.setattr("data", classifier.examples)
        return classifier

    def dataset(self, data):
        """Input handler for "Data": keep the data only if it is non-empty
        and passes the discrete-class check, then rebuild."""
        if self.isDataWithClass(data, orange.VarTypes.Discrete,
                                checkMissing=True) and data:
            self.data = data
        else:
            self.data = None
        self.setLearner()

    def setPreprocessor(self, pp):
        """Input handler for "Preprocess": store the preprocessor and rebuild."""
        self.preprocessor = pp
        self.setLearner()

    def qualityButtonPressed(self, id=0):
        """Select rule quality estimator ``id`` and update dependent controls."""
        self.QualityButton = id
        for index, button in enumerate(self.ruleQualityBG.buttons):
            button.setChecked(id == index)
        # The m spin box applies to m-estimate (1) only; the covering
        # algorithm box is disabled for EVC (2).
        self.mSpin.control.setEnabled(id == 1)
        self.coveringAlgBG.setEnabled(id != 2)

    def coveringAlgButtonPressed(self, id=0):
        """Select covering algorithm ``id`` and update dependent controls."""
        self.CoveringButton = id
        for index, button in enumerate(self.coveringAlgBG.buttons):
            button.setChecked(id == index)
        # The weight spin box applies to weighted covering (1) only.
        self.weightSpin.control.setEnabled(id == 1)

    def applySettings(self):
        """Callback of the Apply button: rebuild learner and classifier."""
        self.setLearner()
303
if __name__ == "__main__":
    # Manual run: open the widget on the "titanic.tab" data set and save
    # its settings once the event loop exits.
    application = QApplication(sys.argv)
    widget = OWCN2()
    widget.dataset(orange.ExampleTable("titanic.tab"))
    widget.show()
    application.exec_()
    widget.saveSettings()
Note: See TracBrowser for help on using the repository browser.