source: orange/orange/OrangeWidgets/Classify/OWCN2.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 12.5 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""
2<name>CN2</name>
3<description>Rule-based (CN2) learner/classifier.</description>
4<icon>icons/CN2.png</icon>
5<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
6<priority>300</priority>
7"""
8from OWWidget import *
9import OWGUI, orange, orngCN2, sys
10
11from orngWrap import PreprocessedLearner
12
class CN2ProgressBar(orange.ProgressCallback):
    """Progress callback that forwards progress to a widget's progress bar.

    A reported ``value`` is mapped linearly onto the ``[start, end]``
    interval and passed to ``widget.progressBarSet`` as a percentage.
    """

    def __init__(self, widget, start=0.0, end=0.0):
        """Store the target widget and the interval this callback reports into.

        widget -- object providing ``progressBarSet(percent)``
        start  -- lower bound of the interval (fraction of the whole bar)
        end    -- upper bound of the interval (fraction of the whole bar)
        """
        # NOTE(review): with the defaults start == end == 0.0 every call
        # reports 0 until start/end are changed from outside -- presumably
        # the learner updates them while it runs; confirm against orngCN2.
        self.start = start
        self.end = end
        self.widget = widget
        orange.ProgressCallback.__init__(self)

    def __call__(self, value, a):
        """Report ``value`` scaled into [start, end]; ``a`` is not used."""
        span = self.end - self.start
        fraction = self.start + span * value
        self.widget.progressBarSet(100 * fraction)
21
class OWCN2(OWWidget):
    """Widget that configures a CN2 rule learner and trains it on input data.

    Inputs:  "Data" (ExampleTable) and "Preprocess" (PreprocessedLearner).
    Outputs: "Learner", "Classifier" and "Unordered CN2 Classifier".

    ``QualityButton`` selects the rule quality estimate:
    0 = Laplace, 1 = m-estimate, 2 = EVC, 3 = WRACC.
    ``CoveringButton`` selects the covering algorithm:
    0 = exclusive covering, 1 = weighted covering.
    """

    # Attribute names listed here are presumably the ones stored/restored by
    # loadSettings()/saveSettings() -- confirm against OWWidget.
    settingsList = ["name", "QualityButton", "CoveringButton", "m",
                    "MaxRuleLength", "useMaxRuleLength",
                    "MinCoverage", "BeamWidth", "Alpha", "Weight", "stepAlpha"]
    callbackDeposit = []

    def __init__(self, parent=None, signalManager=None):
        """Set default parameters, load saved settings and build the GUI."""
        OWWidget.__init__(self, parent, signalManager, "CN2",
                          wantMainArea=0, resizingEnabled=0)

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
        self.outputs = [("Learner", orange.Learner),
                        ("Classifier", orange.Classifier),
                        ("Unordered CN2 Classifier", orngCN2.CN2UnorderedClassifier)]

        # Default parameter values; loadSettings() is called right after
        # them so that stored values can take their place.
        self.QualityButton = 0
        self.CoveringButton = 0
        self.Alpha = 0.05
        self.stepAlpha = 0.2
        self.BeamWidth = 5
        self.MinCoverage = 0
        self.MaxRuleLength = 0
        self.useMaxRuleLength = False
        self.Weight = 0.9
        self.m = 2
        self.name = "CN2 rules"
        self.loadSettings()

        self.data = None
        self.preprocessor = None

        ## GUI
        # The top-level boxes are created first, in their on-screen order;
        # their contents are filled in further below.
        labelWidth = 150
        self.learnerName = OWGUI.lineEdit(
            self.controlArea, self, "name", box="Learner/classifier name",
            tooltip="Name to be used by other widgets to identify the learner/classifier")
        OWGUI.separator(self.controlArea)

        self.ruleQualityBG = OWGUI.widgetBox(self.controlArea, "Rule quality estimation")

        OWGUI.separator(self.controlArea)
        self.ruleValidationGroup = OWGUI.widgetBox(self.controlArea, "Pre-prunning (LRS)")

        OWGUI.separator(self.controlArea)
        OWGUI.spin(self.controlArea, self, "BeamWidth", 1, 100, box="Beam width",
                   tooltip="The width of the search beam\n(number of rules to be specialized)")

        OWGUI.separator(self.controlArea)
        self.coveringAlgBG = OWGUI.widgetBox(self.controlArea, "Covering algorithm")

        # Rule quality radio buttons; the m-estimate button shares a
        # horizontal row with the spin box for its "m" parameter.
        laplaceButton = QRadioButton("Laplace", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(laplaceButton)
        mRow = OWGUI.widgetBox(self.ruleQualityBG, orientation="horizontal")
        mEstimateButton = QRadioButton("m-estimate", mRow)
        mRow.layout().addWidget(mEstimateButton)
        self.mSpin = OWGUI.doubleSpin(mRow, self, "m", 0, 100)
        evcButton = QRadioButton("EVC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(evcButton)
        wraccButton = QRadioButton("WRACC", self.ruleQualityBG)
        self.ruleQualityBG.layout().addWidget(wraccButton)
        # The position in this list is the value stored in QualityButton.
        self.ruleQualityBG.buttons = [laplaceButton, mEstimateButton,
                                      evcButton, wraccButton]

        for index, button in enumerate(self.ruleQualityBG.buttons):
            # index is bound as a default argument so that every button
            # reports its own position.
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=index: self.qualityButtonPressed(v))

        # Pre-pruning controls
        OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "Alpha", 0, 1, 0.001,
            label="Alpha (vs. default rule)",
            orientation="horizontal", labelWidth=labelWidth,
            tooltip="Required significance of the difference between the class distribution on all example and covered examples")
        OWGUI.doubleSpin(
            self.ruleValidationGroup, self, "stepAlpha", 0, 1, 0.001,
            label="Stopping Alpha (vs. parent rule)",
            orientation="horizontal", labelWidth=labelWidth,
            tooltip="Required significance of each specialization of a rule.")
        OWGUI.spin(
            self.ruleValidationGroup, self, "MinCoverage", 0, 100,
            label="Minimum coverage",
            orientation="horizontal", labelWidth=labelWidth,
            tooltip="Minimum number of examples a rule must\ncover (use 0 for not setting the limit)")
        OWGUI.checkWithSpin(
            self.ruleValidationGroup, self, "Maximal rule length", 0, 100,
            "useMaxRuleLength", "MaxRuleLength", labelWidth=labelWidth,
            tooltip="Maximal number of conditions in the left\npart of the rule (use 0 for don't care)")

        # Covering algorithm radio buttons; the weighted variant shares a
        # row with the spin box for its weight.
        exclusiveButton = QRadioButton("Exclusive covering", self.coveringAlgBG)
        self.coveringAlgBG.layout().addWidget(exclusiveButton)
        weightRow = OWGUI.widgetBox(self.coveringAlgBG, orientation="horizontal")
        weightedButton = QRadioButton("Weighted covering", weightRow)
        weightRow.layout().addWidget(weightedButton)
        # The position in this list is the value stored in CoveringButton.
        self.coveringAlgBG.buttons = [exclusiveButton, weightedButton]
        self.weightSpin = OWGUI.doubleSpin(weightRow, self, "Weight", 0, 0.95, 0.05)

        for index, button in enumerate(self.coveringAlgBG.buttons):
            self.connect(button, SIGNAL("clicked()"),
                         lambda v=index: self.coveringAlgButtonPressed(v))

        OWGUI.separator(self.controlArea)
        self.btnApply = OWGUI.button(self.controlArea, self, "&Apply",
                                     callback=self.applySettings, default=True)

        self.Alpha = float(self.Alpha)
        self.stepAlpha = float(self.stepAlpha)
        self.Weight = float(self.Weight)

        # Bring the radio buttons and dependent controls in line with the
        # (possibly loaded) settings, then build the initial learner.
        self.qualityButtonPressed(self.QualityButton)
        self.coveringAlgButtonPressed(self.CoveringButton)
        self.resize(100, 100)
        self.setLearner()

    def sendReport(self):
        """Report the current learning parameters and the input data."""
        # One label per QualityButton value. "EVC" (index 2) has to be in
        # the list: without it EVC was reported as WRACC and WRACC
        # (index 3) raised an IndexError.
        qualityLabels = ["Laplace",
                         "m-estimate with m=%.2f" % self.m,
                         "EVC",
                         "WRACC"]
        coveringLabels = ["Exclusive",
                          "Weighted with a weight of %.2f" % self.Weight]
        if self.useMaxRuleLength:
            maxLengthLabel = self.MaxRuleLength
        else:
            maxLengthLabel = "unlimited"

        self.reportSettings(
            "Learning parameters",
            [("Rule quality estimation", qualityLabels[self.QualityButton]),
             ("Pruning alpha (vs. default rule)", "%.3f" % self.Alpha),
             ("Stopping alpha (vs. parent rule)", "%.3f" % self.stepAlpha),
             ("Minimum coverage", "%.3f" % self.MinCoverage),
             ("Maximal rule length", maxLengthLabel),
             ("Beam width", self.BeamWidth),
             ("Covering", coveringLabels[self.CoveringButton])])
        self.reportData(self.data)

    def _buildLearner(self, maxRuleLength):
        """Return a new learner for the current settings (not yet wrapped
        by the preprocessor).

        maxRuleLength -- value passed on as max_rule_complexity
        """
        if self.QualityButton == 2:
            # EVC learning uses a completely different learner class.
            return orngCN2.CN2EVCUnorderedLearner(
                width=self.BeamWidth, rule_sig=self.Alpha,
                att_sig=self.stepAlpha, min_coverage=self.MinCoverage,
                max_rule_complexity=maxRuleLength)

        learner = orngCN2.CN2UnorderedLearner()
        learner.name = self.name

        ruleFinder = orange.RuleBeamFinder()
        if self.QualityButton == 0:
            ruleFinder.evaluator = orange.RuleEvaluator_Laplace()
        elif self.QualityButton == 1:
            ruleFinder.evaluator = orngCN2.mEstimate(self.m)
        elif self.QualityButton == 3:
            ruleFinder.evaluator = orngCN2.WRACCEvaluator()

        ruleFinder.ruleStoppingValidator = orange.RuleValidator_LRS(
            alpha=self.stepAlpha, min_coverage=self.MinCoverage,
            max_rule_complexity=maxRuleLength)
        ruleFinder.validator = orange.RuleValidator_LRS(
            alpha=self.Alpha, min_coverage=self.MinCoverage,
            max_rule_complexity=maxRuleLength)
        ruleFinder.ruleFilter = orange.RuleBeamFilter_Width(width=self.BeamWidth)
        learner.ruleFinder = ruleFinder

        if self.CoveringButton == 0:
            learner.coverAndRemove = orange.RuleCovererAndRemover_Default()
        elif self.CoveringButton == 1:
            learner.coverAndRemove = orngCN2.CovererAndRemover_multWeights(mult=self.Weight)
        return learner

    def setLearner(self):
        """Rebuild the learner from the current settings and send it on.

        When input data is present a classifier is also trained; otherwise
        None is sent on both classifier outputs.
        """
        if hasattr(self, "btnApply"):
            self.btnApply.setFocus()

        self.progressBarInit()
        try:
            # -1 is passed as max_rule_complexity when the length limit is
            # switched off.
            if self.useMaxRuleLength:
                maxRuleLength = self.MaxRuleLength
            else:
                maxRuleLength = -1

            self.learner = self._buildLearner(maxRuleLength)
            if self.preprocessor:
                self.learner = self.preprocessor.wrapLearner(self.learner)
            self.learner.name = self.name
            self.send("Learner", self.learner)

            self.classifier = None
            self.error()
            if self.data:
                # Training runs on a copy of the data built over a copy of
                # its domain; rule and classifier examples are converted
                # back to that domain afterwards.
                oldDomain = orange.Domain(self.data.domain)
                learnData = orange.ExampleTable(oldDomain, self.data)
                self.learner.progressCallback = CN2ProgressBar(self)
                try:
                    self.classifier = self.learner(learnData)
                finally:
                    # Detach the callback even when learning fails.
                    self.learner.progressCallback = None
                self.classifier.name = self.name
                for rule in self.classifier.rules:
                    rule.examples = orange.ExampleTable(oldDomain, rule.examples)
                self.classifier.examples = orange.ExampleTable(
                    oldDomain, self.classifier.examples)
                self.classifier.setattr("data", self.classifier.examples)
                self.error("")

            self.send("Classifier", self.classifier)
            self.send("Unordered CN2 Classifier", self.classifier)
        finally:
            # Always close the progress bar, also when an exception
            # propagates out of learner construction or training.
            self.progressBarFinished()

    def dataset(self, data):
        """Handler for the "Data" input: store the data and rebuild.

        The data is kept only if isDataWithClass() accepts it for a
        discrete class; otherwise self.data becomes None.
        """
        accepted = self.isDataWithClass(data, orange.VarTypes.Discrete,
                                        checkMissing=True)
        # `data or None` keeps the original and/or semantics: data that
        # evaluates as false is stored as None as well.
        self.data = (data or None) if accepted else None
        self.setLearner()

    def setPreprocessor(self, pp):
        """Handler for the "Preprocess" input: store it and rebuild."""
        self.preprocessor = pp
        self.setLearner()

    def qualityButtonPressed(self, id=0):
        """Select rule quality estimate ``id`` and update dependent controls."""
        self.QualityButton = id
        for index, button in enumerate(self.ruleQualityBG.buttons):
            button.setChecked(index == id)
        # The m spin box applies to the m-estimate (1) only; the covering
        # box is disabled while EVC (2) is selected.
        self.mSpin.control.setEnabled(id == 1)
        self.coveringAlgBG.setEnabled(id != 2)

    def coveringAlgButtonPressed(self, id=0):
        """Select covering algorithm ``id`` and update the weight spin box."""
        self.CoveringButton = id
        for index, button in enumerate(self.coveringAlgBG.buttons):
            button.setChecked(index == id)
        # The weight applies to weighted covering (1) only.
        self.weightSpin.control.setEnabled(id == 1)

    def applySettings(self):
        """Callback of the Apply button: rebuild the learner/classifier."""
        self.setLearner()
252
if __name__ == "__main__":
    # Stand-alone run: show the widget with the "titanic.tab" data set and
    # save its settings once the event loop has finished.
    app = QApplication(sys.argv)
    widget = OWCN2()
    widget.dataset(orange.ExampleTable("titanic.tab"))
    widget.show()
    app.exec_()
    widget.saveSettings()
Note: See TracBrowser for help on using the repository browser.