source: orange/orange/OrangeWidgets/Evaluate/OWTestLearners.py @ 9510:0d054618385a

Revision 9510:0d054618385a, 24.9 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Added Default flag to 'Learners' input channel.

Line 
1"""
2<name>Test Learners</name>
3<description>Estimates the predictive performance of learners on a data set.</description>
4<icon>icons/TestLearners.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>200</priority>
7"""
8#
9# OWTestLearners.py
10#
11from OWWidget import *
12import orngTest, orngStat, OWGUI
13import time
14import warnings
15from orngWrap import PreprocessedLearner
16warnings.filterwarnings("ignore", "'id' is not a builtin attribute",
17                        orange.AttributeWarning)
18
19##############################################################################
20
21class Learner:
22    def __init__(self, learner, id):
23        learner.id = id
24        self.learner = learner
25        self.name = learner.name
26        self.id = id
27        self.scores = []
28        self.results = None
29        self.time = time.time() # used to order the learners in the table
30
31class Score:
32    def __init__(self, name, label, f, show=True, cmBased=False):
33        self.name = name
34        self.label = label
35        self.f = f
36        self.show = show
37        self.cmBased = cmBased
38       
39def dispatch(score_desc, res, cm):
40    """ Dispatch the call to orngStat method.
41    """
42    return eval("orngStat." + score_desc.f)
43
44
45class OWTestLearners(OWWidget):
46    settingsList = ["nFolds", "pLearning", "pRepeat", "precision",
47                    "selectedCScores", "selectedRScores", "applyOnAnyChange",
48                    "resampling"]
49    contextHandlers = {"": DomainContextHandler("", ["targetClass"])}
50    callbackDeposit = []
51
52    cStatistics = [Score(*s) for s in [\
53        ('Classification accuracy', 'CA', 'CA(res)', True),
54        ('Sensitivity', 'Sens', 'sens(cm)', True, True),
55        ('Specificity', 'Spec', 'spec(cm)', True, True),
56        ('Area under ROC curve', 'AUC', 'AUC(res)', True),
57        ('Information score', 'IS', 'IS(res)', False),
58        ('F-measure', 'F1', 'F1(cm)', False, True),
59        ('Precision', 'Prec', 'precision(cm)', False, True),
60        ('Recall', 'Recall', 'recall(cm)', False, True),
61        ('Brier score', 'Brier', 'BrierScore(res)', True),
62        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]]
63
64    rStatistics = [Score(*s) for s in [\
65        ("Mean squared error", "MSE", "MSE(res)", False),
66        ("Root mean squared error", "RMSE", "RMSE(res)"),
67        ("Mean absolute error", "MAE", "MAE(res)", False),
68        ("Relative squared error", "RSE", "RSE(res)", False),
69        ("Root relative squared error", "RRSE", "RRSE(res)"),
70        ("Relative absolute error", "RAE", "RAE(res)", False),
71        ("R-squared", "R2", "R2(res)")]]
72
73    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling",
74                         "Test on train data", "Test on test data"]
75
76    def __init__(self,parent=None, signalManager = None):
77        OWWidget.__init__(self, parent, signalManager, "TestLearners")
78
79        self.inputs = [("Data", ExampleTable, self.setData, Default), 
80                       ("Separate Test Data", ExampleTable, self.setTestData),
81                       ("Learner", orange.Learner, self.setLearner, Multiple + Default),
82                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
83       
84        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]
85
86        # Settings
87        self.resampling = 0             # cross-validation
88        self.nFolds = 5                 # cross validation folds
89        self.pLearning = 70   # size of learning set when sampling [%]
90        self.pRepeat = 10
91        self.precision = 4
92        self.applyOnAnyChange = True
93        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show]
94        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show]
95        self.targetClass = 0
96        self.loadSettings()
97        self.resampling = 0             # cross-validation
98
99        self.stat = self.cStatistics
100
101        self.data = None                # input data set
102        self.testdata = None            # separate test data set
103        self.learners = {}              # set of learners (input)
104        self.results = None             # from orngTest
105        self.preprocessor = None
106
107        self.controlArea.layout().setSpacing(8)
108        # GUI
109        self.sBtns = OWGUI.radioButtonsInBox(self.controlArea, self, "resampling", 
110                                             box="Sampling",
111                                             btnLabels=self.resamplingMethods[:1],
112                                             callback=self.newsampling)
113        indent = OWGUI.checkButtonOffsetHint(self.sBtns.buttons[-1])
114       
115        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
116        OWGUI.spin(ibox, self, 'nFolds', 2, 100, step=1,
117                   label='Number of folds:',
118                   callback=lambda p=0: self.conditionalRecompute(p),
119                   keyboardTracking=False)
120       
121        OWGUI.separator(self.sBtns, height = 3)
122       
123        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[1])      # leave one out
124        OWGUI.separator(self.sBtns, height = 3)
125        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[2])      # random sampling
126                       
127        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
128        OWGUI.spin(ibox, self, 'pRepeat', 1, 100, step=1,
129                   label='Repeat train/test:',
130                   callback=lambda p=2: self.conditionalRecompute(p),
131                   keyboardTracking=False)
132       
133        OWGUI.widgetLabel(ibox, "Relative training set size:")
134       
135        OWGUI.hSlider(ibox, self, 'pLearning', minValue=10, maxValue=100,
136                      step=1, ticks=10, labelFormat="   %d%%",
137                      callback=lambda p=2: self.conditionalRecompute(p))
138       
139        OWGUI.separator(self.sBtns, height = 3)
140        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[3])  # test on train
141        OWGUI.separator(self.sBtns, height = 3)
142        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[4])  # test on test
143
144        self.trainDataBtn = self.sBtns.buttons[-2]
145        self.testDataBtn = self.sBtns.buttons[-1]
146        self.testDataBtn.setDisabled(True)
147       
148        OWGUI.separator(self.sBtns)
149        OWGUI.checkBox(self.sBtns, self, 'applyOnAnyChange',
150                       label="Apply on any change", callback=self.applyChange)
151        self.applyBtn = OWGUI.button(self.sBtns, self, "&Apply",
152                                     callback=lambda f=True: self.recompute(f))
153        self.applyBtn.setDisabled(True)
154
155        if self.resampling == 4:
156            self.resampling = 3
157
158        # statistics
159        self.statLayout = QStackedLayout()
160        self.cbox = OWGUI.widgetBox(self.controlArea, addToLayout=False)
161        self.cStatLabels = [s.name for s in self.cStatistics]
162        self.cstatLB = OWGUI.listBox(self.cbox, self, 'selectedCScores',
163                                     'cStatLabels', box = "Performance scores",
164                                     selectionMode = QListWidget.MultiSelection,
165                                     callback=self.newscoreselection)
166       
167        self.cbox.layout().addSpacing(8)
168        self.targetCombo = OWGUI.comboBox(self.cbox, self, "targetClass", orientation=0,
169                                        callback=[self.changedTarget],
170                                        box="Target class")
171
172        self.rStatLabels = [s.name for s in self.rStatistics]
173        self.rbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
174        self.rstatLB = OWGUI.listBox(self.rbox, self, 'selectedRScores', 'rStatLabels',
175                                     selectionMode = QListWidget.MultiSelection,
176                                     callback=self.newscoreselection)
177       
178        self.statLayout.addWidget(self.cbox)
179        self.statLayout.addWidget(self.rbox)
180        self.controlArea.layout().addLayout(self.statLayout)
181       
182        self.statLayout.setCurrentWidget(self.cbox)
183
184        # score table
185        # table with results
186        self.g = OWGUI.widgetBox(self.mainArea, 'Evaluation Results')
187        self.tab = OWGUI.table(self.g, selectionMode = QTableWidget.NoSelection)
188
189        self.resize(680,470)
190
191    # scoring and painting of score table
192    def isclassification(self):
193        if not self.data or not self.data.domain.classVar:
194            return True
195        return self.data.domain.classVar.varType == orange.VarTypes.Discrete
196       
197    def paintscores(self):
198        """paints the table with evaluation scores"""
199
200        self.tab.setColumnCount(len(self.stat)+1)
201        self.tab.setHorizontalHeaderLabels(["Method"] + [s.label for s in self.stat])
202       
203        prec="%%.%df" % self.precision
204
205        learners = [(l.time, l) for l in self.learners.values()]
206        learners.sort()
207        learners = [lt[1] for lt in learners]
208
209        self.tab.setRowCount(len(self.learners))
210        for (i, l) in enumerate(learners):
211            OWGUI.tableItem(self.tab, i,0, l.name)
212           
213        for (i, l) in enumerate(learners):
214            if l.scores:
215                for j in range(len(self.stat)):
216                    if l.scores[j] is not None:
217                        OWGUI.tableItem(self.tab, i, j+1, prec % l.scores[j])
218                    else:
219                        OWGUI.tableItem(self.tab, i, j+1, "N/A")
220            else:
221                for j in range(len(self.stat)):
222                    OWGUI.tableItem(self.tab, i, j+1, "")
223       
224        # adjust the width of the score table cloumns
225        self.tab.resizeColumnsToContents()
226        self.tab.resizeRowsToContents()
227        usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
228        for i in range(len(self.stat)):
229            if i not in usestat:
230                self.tab.hideColumn(i+1)
231
232    def sendReport(self):
233        exset = []
234        if self.resampling == 0:
235            exset = [("Folds", self.nFolds)]
236        elif self.resampling == 2:
237            exset = [("Repetitions", self.pRepeat), ("Proportion of training instances", "%i%%" % self.pLearning)]
238        else:
239            exset = []
240        self.reportSettings("Validation method",
241                            [("Method", self.resamplingMethods[self.resampling])]
242                            + exset +
243                            ([("Target class", self.data.domain.classVar.values[self.targetClass])] if self.data else []))
244       
245        self.reportData(self.data)
246
247        if self.data:       
248            self.reportSection("Results")
249            learners = [(l.time, l) for l in self.learners.values()]
250            learners.sort()
251            learners = [lt[1] for lt in learners]
252            usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
253           
254            res = "<table><tr><th></th>"+"".join("<th><b>%s</b></th>" % hr for hr in [s.label for i, s in enumerate(self.stat) if i in usestat])+"</tr>"
255            for i, l in enumerate(learners):
256                res += "<tr><th><b>%s</b></th>" % l.name
257                if l.scores:
258                    for j in usestat:
259                        scr = l.scores[j]
260                        res += "<td>" + ("%.4f" % scr if scr is not None else "") + "</td>"
261                res += "</tr>"
262            res += "</table>"
263            self.reportRaw(res)
264           
265    def score(self, ids):
266        """compute scores for the list of learners"""
267        if (not self.data):
268            for id in ids:
269                self.learners[id].results = None
270            return
271        # test which learners can accept the given data set
272        # e.g., regressions can't deal with classification data
273        learners = []
274        n = len(self.data.domain.attributes)*2
275        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible)
276        new = self.data.selectref(indices(self.data))
277       
278        self.warning(0)
279        learner_exceptions = []
280        for l in [self.learners[id] for id in ids]:
281            learner = l.learner
282            if self.preprocessor:
283                learner = self.preprocessor.wrapLearner(learner)
284            try:
285                predictor = learner(new)
286                if predictor(new[0]).varType == new.domain.classVar.varType:
287                    learners.append(learner)
288                else:
289                    l.scores = []
290            except Exception, ex:
291                learner_exceptions.append((l, ex))
292                l.scores = []
293
294        if learner_exceptions:
295            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \
296                             for l, ex in learner_exceptions)
297            self.warning(0, text)
298           
299        if not learners:
300            return
301
302        # computation of results (res, and cm if classification)
303        pb = None
304        if self.resampling==0:
305            pb = OWGUI.ProgressBar(self, iterations=self.nFolds)
306            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds,
307                                           strat=orange.MakeRandomIndices.StratifiedIfPossible,
308                                           callback=pb.advance, storeExamples = True)
309            pb.finish()
310        elif self.resampling==1:
311            pb = OWGUI.ProgressBar(self, iterations=len(self.data))
312            res = orngTest.leaveOneOut(learners, self.data,
313                                       callback=pb.advance, storeExamples = True)
314            pb.finish()
315        elif self.resampling==2:
316            pb = OWGUI.ProgressBar(self, iterations=self.pRepeat)
317            res = orngTest.proportionTest(learners, self.data, self.pLearning/100.,
318                                          times=self.pRepeat, callback=pb.advance, storeExamples = True)
319            pb.finish()
320        elif self.resampling==3:
321            pb = OWGUI.ProgressBar(self, iterations=len(learners))
322            res = orngTest.learnAndTestOnLearnData(learners, self.data, storeExamples = True, callback=pb.advance)
323            pb.finish()
324           
325        elif self.resampling==4:
326            if not self.testdata:
327                for l in self.learners.values():
328                    l.scores = []
329                return
330            pb = OWGUI.ProgressBar(self, iterations=len(learners))
331            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata, storeExamples = True, callback=pb.advance)
332            pb.finish()
333           
334        if self.isclassification():
335            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass)
336        else:
337            cm = None
338
339        if self.preprocessor: # Unwrap learners
340            learners = [l.wrappedLearner for l in learners]
341           
342        res.learners = learners
343       
344        for l in [self.learners[id] for id in ids]:
345            if l.learner in learners:
346                l.results = res
347
348        self.error(range(len(self.stat)))
349        scores = []
350       
351       
352       
353        for i, s in enumerate(self.stat):
354            if s.cmBased:
355                try:
356#                    scores.append(eval("orngStat." + s.f))
357                    scores.append(dispatch(s, res, cm))
358                except Exception, ex:
359                    self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
360                               (" ".join([l.name for l in learners]), ex.message))
361                    scores.append([None] * len(self.learners))
362            else:
363                scores_one = []
364                for res_one in orngStat.split_by_classifiers(res):
365                    try:
366#                        scores_one.append(eval("orngStat." + s.f)[0])
367                        scores_one.extend(dispatch(s, res_one, cm))
368                    except Exception, ex:
369                        self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
370                                   (res.classifierNames[0], ex.message))
371                        scores_one.append(None)
372                scores.append(scores_one)
373               
374        for i, l in enumerate(learners):
375            self.learners[l.id].scores = [s[i] if s else None for s in scores]
376           
377        self.sendResults()
378
379    def recomputeCM(self):
380        if not self.results:
381            return
382        cm = orngStat.computeConfusionMatrices(self.results, classIndex = self.targetClass)
383        scores = [(indx, eval("orngStat." + s.f))
384                  for (indx, s) in enumerate(self.stat) if s.cmBased]
385        for (indx, score) in scores:
386            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]):
387                l.scores[indx] = score[i]
388        self.paintscores()
389       
390
391    # handle input signals
392
393    def setData(self, data):
394        """handle input train data set"""
395        self.closeContext()
396        self.data = self.isDataWithClass(data, checkMissing=True) and data or None
397        self.fillClassCombo()
398        if not self.data:
399            # data was removed, remove the scores
400            for l in self.learners.values():
401                l.scores = []
402                l.results = None
403            self.send("Evaluation Results", None)
404        else:
405            # new data has arrived
406            self.data = orange.Filter_hasClassValue(self.data)
407            self.statLayout.setCurrentWidget(self.cbox if self.isclassification() else self.rbox)
408           
409            self.stat = [self.rStatistics, self.cStatistics][self.isclassification()]
410           
411            if self.learners:
412                self.score([l.id for l in self.learners.values()])
413
414        self.openContext("", data)
415        self.paintscores()
416
417    def setTestData(self, data):
418        """handle test data set"""
419        if data is None:
420            self.testdata = None
421        else:
422            self.testdata = orange.Filter_hasClassValue(data)
423        self.testDataBtn.setEnabled(self.testdata is not None)
424        if self.testdata is not None:
425            if self.resampling == 4:
426                if self.data:
427                    self.score([l.id for l in self.learners.values()])
428                else:
429                    for l in self.learners.values():
430                        l.scores = []
431                self.paintscores()
432        elif self.resampling == 4 and self.data:
433            # test data removed, switch to testing on train data
434            self.resampling = 3
435            self.recompute()
436
437    def fillClassCombo(self):
438        """upon arrival of new data appropriately set the target class combo"""
439        self.targetCombo.clear()
440        if not self.data or not self.data.domain.classVar or not self.isclassification():
441            return
442
443        domain = self.data.domain
444        self.targetCombo.addItems([str(v) for v in domain.classVar.values])
445       
446        if self.targetClass<len(domain.classVar.values):
447            self.targetCombo.setCurrentIndex(self.targetClass)
448        else:
449            self.targetCombo.setCurrentIndex(0)
450            self.targetClass=0
451
452    def setLearner(self, learner, id=None):
453        """add/remove a learner"""
454        if learner: # a new or updated learner
455            if id in self.learners: # updated learner
456                time = self.learners[id].time
457                self.learners[id] = Learner(learner, id)
458                self.learners[id].time = time
459            else: # new learner
460                self.learners[id] = Learner(learner, id)
461            if self.applyBtn.isEnabled():
462                self.recompute(True)
463            else:
464                self.score([id])
465        else: # remove a learner and corresponding results
466            if id in self.learners:
467                res = self.learners[id].results
468                if res and res.numberOfLearners > 1:
469                    indx = [l.id for l in res.learners].index(id)
470                    res.remove(indx)
471                    del res.learners[indx]
472                del self.learners[id]
473            self.sendResults()
474        self.paintscores()
475       
476    def setPreprocessor(self, pp):
477        self.preprocessor = pp
478        if self.learners:
479            self.score([l.id for l in self.learners.values()])
480            self.paintscores()
481
482    # handle output signals
483
484    def sendResults(self):
485        """commit evaluation results"""
486        # for each learner, we first find a list where a result is stored
487        # and remember the corresponding index
488
489        valid = [(l.results, [x.id for x in l.results.learners].index(l.id))
490                 for l in self.learners.values() if l.scores and l.results]
491           
492        if not (self.data and len(valid)):
493            self.send("Evaluation Results", None)
494            return
495
496        # find the result set for a largest number of learners
497        # and remove this set from the list of result sets
498        rlist = dict([(l.results,1) for l in self.learners.values() if l.scores]).keys()
499        rlen = [r.numberOfLearners for r in rlist]
500        results = rlist.pop(rlen.index(max(rlen)))
501       
502        for (i, l) in enumerate(results.learners):
503            if not l.id in self.learners:
504                results.remove(i)
505                del results.learners[i]
506        for r in rlist:
507            for (i, l) in enumerate(r.learners):
508                if (r, i) in valid:
509                    results.add(r, i)
510                    results.learners.append(r.learners[i])
511                    self.learners[r.learners[i].id].results = results
512        self.send("Evaluation Results", results)
513        self.results = results
514
515    # signal processing
516
517    def newsampling(self):
518        """handle change of evaluation method"""
519        if not self.applyOnAnyChange:
520            self.applyBtn.setDisabled(self.applyOnAnyChange)
521        else:
522            if self.learners:
523                self.recompute()
524
525    def newscoreselection(self):
526        """handle change in set of scores to be displayed"""
527        usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()]
528        for i in range(len(self.stat)):
529            if i in usestat:
530                self.tab.showColumn(i+1)
531                self.tab.resizeColumnToContents(i+1)
532            else:
533                self.tab.hideColumn(i+1)
534
535    def recompute(self, forced=False):
536        """recompute the scores for all learners,
537           if not forced, will do nothing but enable the Apply button"""
538        if self.applyOnAnyChange or forced:
539            self.score([l.id for l in self.learners.values()])
540            self.paintscores()
541            self.applyBtn.setDisabled(True)
542        else:
543            self.applyBtn.setEnabled(True)
544
545    def conditionalRecompute(self, option):
546        """calls recompute only if specific sampling option enabled"""
547        if self.resampling == option:
548            self.recompute(False)
549
550    def applyChange(self):
551        if self.applyOnAnyChange:
552            self.applyBtn.setDisabled(True)
553       
554    def changedTarget(self):
555        self.recomputeCM()
556
557##############################################################################
558# Test the widget, run from DOS prompt
559
560if __name__=="__main__":
561    a=QApplication(sys.argv)
562    ow=OWTestLearners()
563    ow.show()
564    a.exec_()
565
566    data1 = orange.ExampleTable(r'../../doc/datasets/voting')
567    data2 = orange.ExampleTable(r'../../golf')
568    datar = orange.ExampleTable(r'../../auto-mpg')
569    data3 = orange.ExampleTable(r'../../sailing-big')
570    data4 = orange.ExampleTable(r'../../sailing-test')
571
572    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'
573
574    l2 = orange.BayesLearner()
575    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
576    l2.conditionalEstimatorConstructor = \
577        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
578        estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
579    l2.name = '2 - NBC (m=10)'
580
581    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'
582
583    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"
584
585    import orngRegression as r
586    r5 = r.LinearRegressionLearner(name="0 - lin reg")
587
588    testcase = 4
589
590    if testcase == 0: # 1(UPD), 3, 4
591        ow.setData(data2)
592        ow.setLearner(r5, 5)
593        ow.setLearner(l1, 1)
594        ow.setLearner(l2, 2)
595        ow.setLearner(l3, 3)
596        l1.name = l1.name + " UPD"
597        ow.setLearner(l1, 1)
598        ow.setLearner(None, 2)
599        ow.setLearner(l4, 4)
600#        ow.setData(data1)
601#        ow.setData(datar)
602#        ow.setData(data1)
603    if testcase == 1: # data, but all learners removed
604        ow.setLearner(l1, 1)
605        ow.setLearner(l2, 2)
606        ow.setLearner(l1, 1)
607        ow.setLearner(None, 2)
608        ow.setData(data2)
609        ow.setLearner(None, 1)
610    if testcase == 2: # sends data, then learner, then removes the learner
611        ow.setData(data2)
612        ow.setLearner(l1, 1)
613        ow.setLearner(None, 1)
614    if testcase == 3: # regression first
615        ow.setData(datar)
616        ow.setLearner(r5, 5)
617    if testcase == 4: # separate train and test data
618        ow.setData(data3)
619        ow.setTestData(data4)
620        ow.setLearner(l2, 5)
621        ow.setTestData(None)
622
623    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.