source: orange/orange/OrangeWidgets/Evaluate/OWTestLearners.py @ 9505:4b798678cd3d

Revision 9505:4b798678cd3d, 24.8 KB checked in by matija <matija.polajnar@…>, 2 years ago (diff)

Merge in the (heavily modified) MLC code from GSOC 2011 (modules, documentation, evaluation code, regression test). Widgets will be merged in a little bit later, which will finally close ticket #992.

Line 
1"""
2<name>Test Learners</name>
3<description>Estimates the predictive performance of learners on a data set.</description>
4<icon>icons/TestLearners.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>200</priority>
7"""
8#
9# OWTestLearners.py
10#
11from OWWidget import *
12import orngTest, orngStat, OWGUI
13import time
14import warnings
15from orngWrap import PreprocessedLearner
16warnings.filterwarnings("ignore", "'id' is not a builtin attribute",
17                        orange.AttributeWarning)
18
19##############################################################################
20
21class Learner:
22    def __init__(self, learner, id):
23        learner.id = id
24        self.learner = learner
25        self.name = learner.name
26        self.id = id
27        self.scores = []
28        self.results = None
29        self.time = time.time() # used to order the learners in the table
30
31class Score:
32    def __init__(self, name, label, f, show=True, cmBased=False):
33        self.name = name
34        self.label = label
35        self.f = f
36        self.show = show
37        self.cmBased = cmBased
38       
39def dispatch(score_desc, res, cm):
40    """ Dispatch the call to orngStat method.
41    """
42    return eval("orngStat." + score_desc.f)
43
44
45class OWTestLearners(OWWidget):
46    settingsList = ["nFolds", "pLearning", "pRepeat", "precision",
47                    "selectedCScores", "selectedRScores", "applyOnAnyChange",
48                    "resampling"]
49    contextHandlers = {"": DomainContextHandler("", ["targetClass"])}
50    callbackDeposit = []
51
52    cStatistics = [Score(*s) for s in [\
53        ('Classification accuracy', 'CA', 'CA(res)', True),
54        ('Sensitivity', 'Sens', 'sens(cm)', True, True),
55        ('Specificity', 'Spec', 'spec(cm)', True, True),
56        ('Area under ROC curve', 'AUC', 'AUC(res)', True),
57        ('Information score', 'IS', 'IS(res)', False),
58        ('F-measure', 'F1', 'F1(cm)', False, True),
59        ('Precision', 'Prec', 'precision(cm)', False, True),
60        ('Recall', 'Recall', 'recall(cm)', False, True),
61        ('Brier score', 'Brier', 'BrierScore(res)', True),
62        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]]
63
64    rStatistics = [Score(*s) for s in [\
65        ("Mean squared error", "MSE", "MSE(res)", False),
66        ("Root mean squared error", "RMSE", "RMSE(res)"),
67        ("Mean absolute error", "MAE", "MAE(res)", False),
68        ("Relative squared error", "RSE", "RSE(res)", False),
69        ("Root relative squared error", "RRSE", "RRSE(res)"),
70        ("Relative absolute error", "RAE", "RAE(res)", False),
71        ("R-squared", "R2", "R2(res)")]]
72
73    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling",
74                         "Test on train data", "Test on test data"]
75
76    def __init__(self,parent=None, signalManager = None):
77        OWWidget.__init__(self, parent, signalManager, "TestLearners")
78
79        self.inputs = [("Data", ExampleTable, self.setData, Default), ("Separate Test Data", ExampleTable, self.setTestData), ("Learner", orange.Learner, self.setLearner, Multiple), ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
80        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]
81
82        # Settings
83        self.resampling = 0             # cross-validation
84        self.nFolds = 5                 # cross validation folds
85        self.pLearning = 70   # size of learning set when sampling [%]
86        self.pRepeat = 10
87        self.precision = 4
88        self.applyOnAnyChange = True
89        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show]
90        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show]
91        self.targetClass = 0
92        self.loadSettings()
93        self.resampling = 0             # cross-validation
94
95        self.stat = self.cStatistics
96
97        self.data = None                # input data set
98        self.testdata = None            # separate test data set
99        self.learners = {}              # set of learners (input)
100        self.results = None             # from orngTest
101        self.preprocessor = None
102
103        self.controlArea.layout().setSpacing(8)
104        # GUI
105        self.sBtns = OWGUI.radioButtonsInBox(self.controlArea, self, "resampling", 
106                                             box="Sampling",
107                                             btnLabels=self.resamplingMethods[:1],
108                                             callback=self.newsampling)
109        indent = OWGUI.checkButtonOffsetHint(self.sBtns.buttons[-1])
110       
111        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
112        OWGUI.spin(ibox, self, 'nFolds', 2, 100, step=1,
113                   label='Number of folds:',
114                   callback=lambda p=0: self.conditionalRecompute(p),
115                   keyboardTracking=False)
116       
117        OWGUI.separator(self.sBtns, height = 3)
118       
119        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[1])      # leave one out
120        OWGUI.separator(self.sBtns, height = 3)
121        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[2])      # random sampling
122                       
123        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
124        OWGUI.spin(ibox, self, 'pRepeat', 1, 100, step=1,
125                   label='Repeat train/test:',
126                   callback=lambda p=2: self.conditionalRecompute(p),
127                   keyboardTracking=False)
128       
129        OWGUI.widgetLabel(ibox, "Relative training set size:")
130       
131        OWGUI.hSlider(ibox, self, 'pLearning', minValue=10, maxValue=100,
132                      step=1, ticks=10, labelFormat="   %d%%",
133                      callback=lambda p=2: self.conditionalRecompute(p))
134       
135        OWGUI.separator(self.sBtns, height = 3)
136        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[3])  # test on train
137        OWGUI.separator(self.sBtns, height = 3)
138        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[4])  # test on test
139
140        self.trainDataBtn = self.sBtns.buttons[-2]
141        self.testDataBtn = self.sBtns.buttons[-1]
142        self.testDataBtn.setDisabled(True)
143       
144        OWGUI.separator(self.sBtns)
145        OWGUI.checkBox(self.sBtns, self, 'applyOnAnyChange',
146                       label="Apply on any change", callback=self.applyChange)
147        self.applyBtn = OWGUI.button(self.sBtns, self, "&Apply",
148                                     callback=lambda f=True: self.recompute(f))
149        self.applyBtn.setDisabled(True)
150
151        if self.resampling == 4:
152            self.resampling = 3
153
154        # statistics
155        self.statLayout = QStackedLayout()
156        self.cbox = OWGUI.widgetBox(self.controlArea, addToLayout=False)
157        self.cStatLabels = [s.name for s in self.cStatistics]
158        self.cstatLB = OWGUI.listBox(self.cbox, self, 'selectedCScores',
159                                     'cStatLabels', box = "Performance scores",
160                                     selectionMode = QListWidget.MultiSelection,
161                                     callback=self.newscoreselection)
162       
163        self.cbox.layout().addSpacing(8)
164        self.targetCombo = OWGUI.comboBox(self.cbox, self, "targetClass", orientation=0,
165                                        callback=[self.changedTarget],
166                                        box="Target class")
167
168        self.rStatLabels = [s.name for s in self.rStatistics]
169        self.rbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
170        self.rstatLB = OWGUI.listBox(self.rbox, self, 'selectedRScores', 'rStatLabels',
171                                     selectionMode = QListWidget.MultiSelection,
172                                     callback=self.newscoreselection)
173       
174        self.statLayout.addWidget(self.cbox)
175        self.statLayout.addWidget(self.rbox)
176        self.controlArea.layout().addLayout(self.statLayout)
177       
178        self.statLayout.setCurrentWidget(self.cbox)
179
180        # score table
181        # table with results
182        self.g = OWGUI.widgetBox(self.mainArea, 'Evaluation Results')
183        self.tab = OWGUI.table(self.g, selectionMode = QTableWidget.NoSelection)
184
185        self.resize(680,470)
186
187    # scoring and painting of score table
188    def isclassification(self):
189        if not self.data or not self.data.domain.classVar:
190            return True
191        return self.data.domain.classVar.varType == orange.VarTypes.Discrete
192       
193    def paintscores(self):
194        """paints the table with evaluation scores"""
195
196        self.tab.setColumnCount(len(self.stat)+1)
197        self.tab.setHorizontalHeaderLabels(["Method"] + [s.label for s in self.stat])
198       
199        prec="%%.%df" % self.precision
200
201        learners = [(l.time, l) for l in self.learners.values()]
202        learners.sort()
203        learners = [lt[1] for lt in learners]
204
205        self.tab.setRowCount(len(self.learners))
206        for (i, l) in enumerate(learners):
207            OWGUI.tableItem(self.tab, i,0, l.name)
208           
209        for (i, l) in enumerate(learners):
210            if l.scores:
211                for j in range(len(self.stat)):
212                    if l.scores[j] is not None:
213                        OWGUI.tableItem(self.tab, i, j+1, prec % l.scores[j])
214                    else:
215                        OWGUI.tableItem(self.tab, i, j+1, "N/A")
216            else:
217                for j in range(len(self.stat)):
218                    OWGUI.tableItem(self.tab, i, j+1, "")
219       
220        # adjust the width of the score table cloumns
221        self.tab.resizeColumnsToContents()
222        self.tab.resizeRowsToContents()
223        usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
224        for i in range(len(self.stat)):
225            if i not in usestat:
226                self.tab.hideColumn(i+1)
227
228    def sendReport(self):
229        exset = []
230        if self.resampling == 0:
231            exset = [("Folds", self.nFolds)]
232        elif self.resampling == 2:
233            exset = [("Repetitions", self.pRepeat), ("Proportion of training instances", "%i%%" % self.pLearning)]
234        else:
235            exset = []
236        self.reportSettings("Validation method",
237                            [("Method", self.resamplingMethods[self.resampling])]
238                            + exset +
239                            ([("Target class", self.data.domain.classVar.values[self.targetClass])] if self.data else []))
240       
241        self.reportData(self.data)
242
243        if self.data:       
244            self.reportSection("Results")
245            learners = [(l.time, l) for l in self.learners.values()]
246            learners.sort()
247            learners = [lt[1] for lt in learners]
248            usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
249           
250            res = "<table><tr><th></th>"+"".join("<th><b>%s</b></th>" % hr for hr in [s.label for i, s in enumerate(self.stat) if i in usestat])+"</tr>"
251            for i, l in enumerate(learners):
252                res += "<tr><th><b>%s</b></th>" % l.name
253                if l.scores:
254                    for j in usestat:
255                        scr = l.scores[j]
256                        res += "<td>" + ("%.4f" % scr if scr is not None else "") + "</td>"
257                res += "</tr>"
258            res += "</table>"
259            self.reportRaw(res)
260           
261    def score(self, ids):
262        """compute scores for the list of learners"""
263        if (not self.data):
264            for id in ids:
265                self.learners[id].results = None
266            return
267        # test which learners can accept the given data set
268        # e.g., regressions can't deal with classification data
269        learners = []
270        n = len(self.data.domain.attributes)*2
271        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible)
272        new = self.data.selectref(indices(self.data))
273       
274        self.warning(0)
275        learner_exceptions = []
276        for l in [self.learners[id] for id in ids]:
277            learner = l.learner
278            if self.preprocessor:
279                learner = self.preprocessor.wrapLearner(learner)
280            try:
281                predictor = learner(new)
282                if predictor(new[0]).varType == new.domain.classVar.varType:
283                    learners.append(learner)
284                else:
285                    l.scores = []
286            except Exception, ex:
287                learner_exceptions.append((l, ex))
288                l.scores = []
289
290        if learner_exceptions:
291            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \
292                             for l, ex in learner_exceptions)
293            self.warning(0, text)
294           
295        if not learners:
296            return
297
298        # computation of results (res, and cm if classification)
299        pb = None
300        if self.resampling==0:
301            pb = OWGUI.ProgressBar(self, iterations=self.nFolds)
302            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds,
303                                           strat=orange.MakeRandomIndices.StratifiedIfPossible,
304                                           callback=pb.advance, storeExamples = True)
305            pb.finish()
306        elif self.resampling==1:
307            pb = OWGUI.ProgressBar(self, iterations=len(self.data))
308            res = orngTest.leaveOneOut(learners, self.data,
309                                       callback=pb.advance, storeExamples = True)
310            pb.finish()
311        elif self.resampling==2:
312            pb = OWGUI.ProgressBar(self, iterations=self.pRepeat)
313            res = orngTest.proportionTest(learners, self.data, self.pLearning/100.,
314                                          times=self.pRepeat, callback=pb.advance, storeExamples = True)
315            pb.finish()
316        elif self.resampling==3:
317            pb = OWGUI.ProgressBar(self, iterations=len(learners))
318            res = orngTest.learnAndTestOnLearnData(learners, self.data, storeExamples = True, callback=pb.advance)
319            pb.finish()
320           
321        elif self.resampling==4:
322            if not self.testdata:
323                for l in self.learners.values():
324                    l.scores = []
325                return
326            pb = OWGUI.ProgressBar(self, iterations=len(learners))
327            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata, storeExamples = True, callback=pb.advance)
328            pb.finish()
329           
330        if self.isclassification():
331            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass)
332        else:
333            cm = None
334
335        if self.preprocessor: # Unwrap learners
336            learners = [l.wrappedLearner for l in learners]
337           
338        res.learners = learners
339       
340        for l in [self.learners[id] for id in ids]:
341            if l.learner in learners:
342                l.results = res
343
344        self.error(range(len(self.stat)))
345        scores = []
346       
347       
348       
349        for i, s in enumerate(self.stat):
350            if s.cmBased:
351                try:
352#                    scores.append(eval("orngStat." + s.f))
353                    scores.append(dispatch(s, res, cm))
354                except Exception, ex:
355                    self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
356                               (" ".join([l.name for l in learners]), ex.message))
357                    scores.append([None] * len(self.learners))
358            else:
359                scores_one = []
360                for res_one in orngStat.split_by_classifiers(res):
361                    try:
362#                        scores_one.append(eval("orngStat." + s.f)[0])
363                        scores_one.extend(dispatch(s, res_one, cm))
364                    except Exception, ex:
365                        self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
366                                   (res.classifierNames[0], ex.message))
367                        scores_one.append(None)
368                scores.append(scores_one)
369               
370        for i, l in enumerate(learners):
371            self.learners[l.id].scores = [s[i] if s else None for s in scores]
372           
373        self.sendResults()
374
375    def recomputeCM(self):
376        if not self.results:
377            return
378        cm = orngStat.computeConfusionMatrices(self.results, classIndex = self.targetClass)
379        scores = [(indx, eval("orngStat." + s.f))
380                  for (indx, s) in enumerate(self.stat) if s.cmBased]
381        for (indx, score) in scores:
382            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]):
383                l.scores[indx] = score[i]
384        self.paintscores()
385       
386
387    # handle input signals
388
389    def setData(self, data):
390        """handle input train data set"""
391        self.closeContext()
392        self.data = self.isDataWithClass(data, checkMissing=True) and data or None
393        self.fillClassCombo()
394        if not self.data:
395            # data was removed, remove the scores
396            for l in self.learners.values():
397                l.scores = []
398                l.results = None
399            self.send("Evaluation Results", None)
400        else:
401            # new data has arrived
402            self.data = orange.Filter_hasClassValue(self.data)
403            self.statLayout.setCurrentWidget(self.cbox if self.isclassification() else self.rbox)
404           
405            self.stat = [self.rStatistics, self.cStatistics][self.isclassification()]
406           
407            if self.learners:
408                self.score([l.id for l in self.learners.values()])
409
410        self.openContext("", data)
411        self.paintscores()
412
413    def setTestData(self, data):
414        """handle test data set"""
415        if data is None:
416            self.testdata = None
417        else:
418            self.testdata = orange.Filter_hasClassValue(data)
419        self.testDataBtn.setEnabled(self.testdata is not None)
420        if self.testdata is not None:
421            if self.resampling == 4:
422                if self.data:
423                    self.score([l.id for l in self.learners.values()])
424                else:
425                    for l in self.learners.values():
426                        l.scores = []
427                self.paintscores()
428        elif self.resampling == 4 and self.data:
429            # test data removed, switch to testing on train data
430            self.resampling = 3
431            self.recompute()
432
433    def fillClassCombo(self):
434        """upon arrival of new data appropriately set the target class combo"""
435        self.targetCombo.clear()
436        if not self.data or not self.data.domain.classVar or not self.isclassification():
437            return
438
439        domain = self.data.domain
440        self.targetCombo.addItems([str(v) for v in domain.classVar.values])
441       
442        if self.targetClass<len(domain.classVar.values):
443            self.targetCombo.setCurrentIndex(self.targetClass)
444        else:
445            self.targetCombo.setCurrentIndex(0)
446            self.targetClass=0
447
448    def setLearner(self, learner, id=None):
449        """add/remove a learner"""
450        if learner: # a new or updated learner
451            if id in self.learners: # updated learner
452                time = self.learners[id].time
453                self.learners[id] = Learner(learner, id)
454                self.learners[id].time = time
455            else: # new learner
456                self.learners[id] = Learner(learner, id)
457            if self.applyBtn.isEnabled():
458                self.recompute(True)
459            else:
460                self.score([id])
461        else: # remove a learner and corresponding results
462            if id in self.learners:
463                res = self.learners[id].results
464                if res and res.numberOfLearners > 1:
465                    indx = [l.id for l in res.learners].index(id)
466                    res.remove(indx)
467                    del res.learners[indx]
468                del self.learners[id]
469            self.sendResults()
470        self.paintscores()
471       
472    def setPreprocessor(self, pp):
473        self.preprocessor = pp
474        if self.learners:
475            self.score([l.id for l in self.learners.values()])
476            self.paintscores()
477
478    # handle output signals
479
480    def sendResults(self):
481        """commit evaluation results"""
482        # for each learner, we first find a list where a result is stored
483        # and remember the corresponding index
484
485        valid = [(l.results, [x.id for x in l.results.learners].index(l.id))
486                 for l in self.learners.values() if l.scores and l.results]
487           
488        if not (self.data and len(valid)):
489            self.send("Evaluation Results", None)
490            return
491
492        # find the result set for a largest number of learners
493        # and remove this set from the list of result sets
494        rlist = dict([(l.results,1) for l in self.learners.values() if l.scores]).keys()
495        rlen = [r.numberOfLearners for r in rlist]
496        results = rlist.pop(rlen.index(max(rlen)))
497       
498        for (i, l) in enumerate(results.learners):
499            if not l.id in self.learners:
500                results.remove(i)
501                del results.learners[i]
502        for r in rlist:
503            for (i, l) in enumerate(r.learners):
504                if (r, i) in valid:
505                    results.add(r, i)
506                    results.learners.append(r.learners[i])
507                    self.learners[r.learners[i].id].results = results
508        self.send("Evaluation Results", results)
509        self.results = results
510
511    # signal processing
512
513    def newsampling(self):
514        """handle change of evaluation method"""
515        if not self.applyOnAnyChange:
516            self.applyBtn.setDisabled(self.applyOnAnyChange)
517        else:
518            if self.learners:
519                self.recompute()
520
521    def newscoreselection(self):
522        """handle change in set of scores to be displayed"""
523        usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
524        for i in range(len(self.stat)):
525            if i in usestat:
526                self.tab.showColumn(i+1)
527                self.tab.resizeColumnToContents(i+1)
528            else:
529                self.tab.hideColumn(i+1)
530
531    def recompute(self, forced=False):
532        """recompute the scores for all learners,
533           if not forced, will do nothing but enable the Apply button"""
534        if self.applyOnAnyChange or forced:
535            self.score([l.id for l in self.learners.values()])
536            self.paintscores()
537            self.applyBtn.setDisabled(True)
538        else:
539            self.applyBtn.setEnabled(True)
540
541    def conditionalRecompute(self, option):
542        """calls recompute only if specific sampling option enabled"""
543        if self.resampling == option:
544            self.recompute(False)
545
546    def applyChange(self):
547        if self.applyOnAnyChange:
548            self.applyBtn.setDisabled(True)
549       
550    def changedTarget(self):
551        self.recomputeCM()
552
553##############################################################################
554# Test the widget, run from DOS prompt
555
556if __name__=="__main__":
557    a=QApplication(sys.argv)
558    ow=OWTestLearners()
559    ow.show()
560    a.exec_()
561
562    data1 = orange.ExampleTable(r'../../doc/datasets/voting')
563    data2 = orange.ExampleTable(r'../../golf')
564    datar = orange.ExampleTable(r'../../auto-mpg')
565    data3 = orange.ExampleTable(r'../../sailing-big')
566    data4 = orange.ExampleTable(r'../../sailing-test')
567
568    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'
569
570    l2 = orange.BayesLearner()
571    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
572    l2.conditionalEstimatorConstructor = \
573        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
574        estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
575    l2.name = '2 - NBC (m=10)'
576
577    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'
578
579    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"
580
581    import orngRegression as r
582    r5 = r.LinearRegressionLearner(name="0 - lin reg")
583
584    testcase = 4
585
586    if testcase == 0: # 1(UPD), 3, 4
587        ow.setData(data2)
588        ow.setLearner(r5, 5)
589        ow.setLearner(l1, 1)
590        ow.setLearner(l2, 2)
591        ow.setLearner(l3, 3)
592        l1.name = l1.name + " UPD"
593        ow.setLearner(l1, 1)
594        ow.setLearner(None, 2)
595        ow.setLearner(l4, 4)
596#        ow.setData(data1)
597#        ow.setData(datar)
598#        ow.setData(data1)
599    if testcase == 1: # data, but all learners removed
600        ow.setLearner(l1, 1)
601        ow.setLearner(l2, 2)
602        ow.setLearner(l1, 1)
603        ow.setLearner(None, 2)
604        ow.setData(data2)
605        ow.setLearner(None, 1)
606    if testcase == 2: # sends data, then learner, then removes the learner
607        ow.setData(data2)
608        ow.setLearner(l1, 1)
609        ow.setLearner(None, 1)
610    if testcase == 3: # regression first
611        ow.setData(datar)
612        ow.setLearner(r5, 5)
613    if testcase == 4: # separate train and test data
614        ow.setData(data3)
615        ow.setTestData(data4)
616        ow.setLearner(l2, 5)
617        ow.setTestData(None)
618
619    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.