source: orange/Orange/OrangeWidgets/Evaluate/OWTestLearners.py @ 11680:45b798db4f91

Revision 11680:45b798db4f91, 28.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Explicitly track if widget output should be updated.

(fixes #1323).

Line 
1"""
2<name>Test Learners</name>
3<description>Estimates the predictive performance of learners on a data set.</description>
4<icon>icons/TestLearners1.svg</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>200</priority>
7"""
8
9import time
10import warnings
11import itertools
12
13from OWWidget import *
14
15import orngTest, orngStat, OWGUI
16
17from orngWrap import PreprocessedLearner
18
19import Orange
20
21##############################################################################
22
23class Learner:
24    counter = itertools.count()
25
26    def __init__(self, learner, id):
27        self.learner = learner
28        self.name = learner.name
29        self.id = id
30        self.scores = []
31        self.results = None
32        # Used to order the learners in the table (named time for
33        # back compatibility reasons)
34        self.time = next(self.counter)
35
36
37class Score:
38    def __init__(self, name, label, f, show=True, cmBased=False):
39        self.name = name
40        self.label = label
41        self.f = f
42        self.show = show
43        self.cmBased = cmBased
44
45
46def dispatch(score_desc, res, cm):
47    """ Dispatch the call to orngStat method.
48    """
49    return eval("orngStat." + score_desc.f)
50
51
52class OWTestLearners(OWWidget):
53    settingsList = ["nFolds", "pLearning", "pRepeat", "precision",
54                    "selectedCScores", "selectedRScores", "applyOnAnyChange",
55                    "resampling"]
56    contextHandlers = {"": DomainContextHandler("", ["targetClass"])}
57    callbackDeposit = []
58
59    # Classification
60    cStatistics = [Score(*s) for s in [\
61        ('Classification accuracy', 'CA', 'CA(res)', True),
62        ('Sensitivity', 'Sens', 'sens(cm)', True, True),
63        ('Specificity', 'Spec', 'spec(cm)', True, True),
64        ('Area under ROC curve', 'AUC', 'AUC(res)', True),
65        ('Information score', 'IS', 'IS(res)', False),
66        ('F-measure', 'F1', 'F1(cm)', False, True),
67        ('Precision', 'Prec', 'precision(cm)', False, True),
68        ('Recall', 'Recall', 'recall(cm)', False, True),
69        ('Brier score', 'Brier', 'BrierScore(res)', True),
70        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]]
71
72    # Regression
73    rStatistics = [Score(*s) for s in [\
74        ("Mean squared error", "MSE", "MSE(res)", False),
75        ("Root mean squared error", "RMSE", "RMSE(res)"),
76        ("Mean absolute error", "MAE", "MAE(res)", False),
77        ("Relative squared error", "RSE", "RSE(res)", False),
78        ("Root relative squared error", "RRSE", "RRSE(res)"),
79        ("Relative absolute error", "RAE", "RAE(res)", False),
80        ("R-squared", "R2", "R2(res)")]]
81   
82    # Multi-Label
83    mStatistics = [Score(*s) for s in [\
84        ('Hamming Loss', 'HammingLoss', 'mlc_hamming_loss(res)', False),
85        ('Accuracy', 'Accuracy', 'mlc_accuracy(res)', False),
86        ('Precision', 'Precision', 'mlc_precision(res)', False),
87        ('Recall', 'Recall', 'mlc_recall(res)', False),                               
88        ]]
89   
90    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling",
91                         "Test on train data", "Test on test data"]
92
93    def __init__(self,parent=None, signalManager = None):
94        OWWidget.__init__(self, parent, signalManager, "Test Learners")
95
96        self.inputs = [("Data", ExampleTable, self.setData, Default),
97                       ("Separate Test Data", ExampleTable, self.setTestData),
98                       ("Learner", orange.Learner, self.setLearner, Multiple + Default),
99                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
100
101        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]
102
103        # Settings
104        self.resampling = 0             # cross-validation
105        self.nFolds = 5                 # cross validation folds
106        self.pLearning = 70   # size of learning set when sampling [%]
107        self.pRepeat = 10
108        self.precision = 4
109        self.applyOnAnyChange = True
110        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show]
111        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show]
112        self.selectedMScores = [i for (i,s) in enumerate(self.mStatistics) if s.show]
113        self.targetClass = 0
114        self.loadSettings()
115
116        self.stat = self.cStatistics
117
118        self.data = None                # input data set
119        self.testdata = None            # separate test data set
120        self.learners = {}              # set of learners (input)
121        self.results = None             # from orngTest
122        self.preprocessor = None
123
124        # should the output evaluation results be recomputed
125        self._resultsInvalid = False
126
127        self.controlArea.layout().setSpacing(8)
128        # GUI
129        self.sBtns = OWGUI.radioButtonsInBox(self.controlArea, self, "resampling", 
130                                             box="Sampling",
131                                             btnLabels=self.resamplingMethods[:1],
132                                             callback=self.newsampling)
133        indent = OWGUI.checkButtonOffsetHint(self.sBtns.buttons[-1])
134       
135        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
136        OWGUI.spin(ibox, self, 'nFolds', 2, 100, step=1,
137                   label='Number of folds:',
138                   callback=lambda p=0: self.conditionalRecompute(p),
139                   keyboardTracking=False)
140       
141        OWGUI.separator(self.sBtns, height = 3)
142       
143        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[1])      # leave one out
144        OWGUI.separator(self.sBtns, height = 3)
145        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[2])      # random sampling
146                       
147        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
148        OWGUI.spin(ibox, self, 'pRepeat', 1, 100, step=1,
149                   label='Repeat train/test:',
150                   callback=lambda p=2: self.conditionalRecompute(p),
151                   keyboardTracking=False)
152       
153        OWGUI.widgetLabel(ibox, "Relative training set size:")
154       
155        OWGUI.hSlider(ibox, self, 'pLearning', minValue=10, maxValue=100,
156                      step=1, ticks=10, labelFormat="   %d%%",
157                      callback=lambda p=2: self.conditionalRecompute(p))
158       
159        OWGUI.separator(self.sBtns, height = 3)
160        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[3])  # test on train
161        OWGUI.separator(self.sBtns, height = 3)
162        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[4])  # test on test
163
164        self.trainDataBtn = self.sBtns.buttons[-2]
165        self.testDataBtn = self.sBtns.buttons[-1]
166        self.testDataBtn.setDisabled(True)
167       
168        OWGUI.separator(self.sBtns)
169        OWGUI.checkBox(self.sBtns, self, 'applyOnAnyChange',
170                       label="Apply on any change", callback=self.applyChange)
171        self.applyBtn = OWGUI.button(self.sBtns, self, "&Apply",
172                                     callback=lambda f=True: self.recompute(f))
173        self.applyBtn.setDisabled(True)
174
175        if self.resampling == 4:
176            self.resampling = 3
177
178        # statistics
179        self.statLayout = QStackedLayout()
180        self.cbox = OWGUI.widgetBox(self.controlArea, addToLayout=False)
181        self.cStatLabels = [s.name for s in self.cStatistics]
182        self.cstatLB = OWGUI.listBox(self.cbox, self, 'selectedCScores',
183                                     'cStatLabels', box = "Performance scores",
184                                     selectionMode = QListWidget.MultiSelection,
185                                     callback=self.newscoreselection)
186       
187        self.cbox.layout().addSpacing(8)
188        self.targetCombo = OWGUI.comboBox(self.cbox, self, "targetClass", orientation=0,
189                                        callback=[self.changedTarget],
190                                        box="Target class")
191
192        self.rStatLabels = [s.name for s in self.rStatistics]
193        self.rbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
194        self.rstatLB = OWGUI.listBox(self.rbox, self, 'selectedRScores', 'rStatLabels',
195                                     selectionMode = QListWidget.MultiSelection,
196                                     callback=self.newscoreselection)
197
198        self.mStatLabels = [s.name for s in self.mStatistics]
199        self.mbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
200        self.mstatLB = OWGUI.listBox(self.mbox, self, 'selectedMScores', 'mStatLabels',
201                                     selectionMode = QListWidget.MultiSelection,
202                                     callback=self.newscoreselection)
203       
204        self.statLayout.addWidget(self.cbox)
205        self.statLayout.addWidget(self.rbox)
206        self.statLayout.addWidget(self.mbox)
207        self.controlArea.layout().addLayout(self.statLayout)
208       
209        self.statLayout.setCurrentWidget(self.cbox)
210
211        # score table
212        # table with results
213        self.g = OWGUI.widgetBox(self.mainArea, 'Evaluation Results')
214        self.tab = OWGUI.table(self.g, selectionMode = QTableWidget.NoSelection)
215
216        self.resize(680,470)
217
218    # scoring and painting of score table
219    def isclassification(self):
220        if not self.data or not self.data.domain.classVar:
221            return True
222        return self.data.domain.classVar.varType == orange.VarTypes.Discrete
223
224    def ismultilabel(self, data = 42):
225        if data==42:
226            data = self.data
227        if not data:
228            return False
229        return Orange.multilabel.is_multilabel(data)
230
231    def get_usestat(self):
232        return ([self.selectedRScores, self.selectedCScores, self.selectedMScores]
233                [2 if self.ismultilabel() else self.isclassification()])
234
235    def paintscores(self):
236        """paints the table with evaluation scores"""
237
238        self.tab.setColumnCount(len(self.stat)+1)
239        self.tab.setHorizontalHeaderLabels(["Method"] + [s.label for s in self.stat])
240       
241        prec="%%.%df" % self.precision
242
243        learners = [(l.time, l) for l in self.learners.values()]
244        learners.sort()
245        learners = [lt[1] for lt in learners]
246
247        self.tab.setRowCount(len(self.learners))
248        for (i, l) in enumerate(learners):
249            OWGUI.tableItem(self.tab, i,0, l.name)
250           
251        for (i, l) in enumerate(learners):
252            if l.scores:
253                for j in range(len(self.stat)):
254                    if l.scores[j] is not None:
255                        OWGUI.tableItem(self.tab, i, j+1, prec % l.scores[j])
256                    else:
257                        OWGUI.tableItem(self.tab, i, j+1, "N/A")
258            else:
259                for j in range(len(self.stat)):
260                    OWGUI.tableItem(self.tab, i, j+1, "")
261       
262        # adjust the width of the score table cloumns
263        self.tab.resizeColumnsToContents()
264        self.tab.resizeRowsToContents()
265        usestat = self.get_usestat()
266        for i in range(len(self.stat)):
267            if i not in usestat:
268                self.tab.hideColumn(i + 1)
269            else:
270                self.tab.showColumn(i + 1)
271
272    def sendReport(self):
273        method = [("Method", self.resamplingMethods[self.resampling])]
274
275        exset = []
276
277        if self.resampling == 0:
278            exset = [("Folds", self.nFolds)]
279        elif self.resampling == 2:
280            exset = [("Repetitions", self.pRepeat),
281                     ("Proportion of training instances", "%i%%" \
282                      % self.pLearning)]
283        else:
284            exset = []
285
286        if self.data and \
287                isinstance(self.data.domain.classVar, orange.EnumVariable):
288            target = [("Target class",
289                       self.data.domain.classVar.values[self.targetClass])]
290        else:
291            target = []
292
293        if not self.ismultilabel():
294            self.reportSettings("Validation method", method + exset + target)
295        else:
296            self.reportSettings("Validation method", method + exset)
297
298        self.reportData(self.data)
299
300        if self.data:
301            self.reportSection("Results")
302            learners = [(l.time, l) for l in self.learners.values()]
303            learners.sort()
304            learners = [lt[1] for lt in learners]
305            usestat = self.get_usestat()
306            # table header
307            res = "<table><tr><th></th>" + \
308                  "".join("<th><b>%s</b></th>" % s.label \
309                          for i, s in enumerate(self.stat) \
310                          if i in usestat) + \
311                  "</tr>"
312
313            for i, learner in enumerate(learners):
314                res += "<tr><th><b>%s</b></th>" % learner.name
315                if learner.scores:
316                    for j in usestat:
317                        score = learner.scores[j]
318                        score = "%.4f" % score if score is not None else ""
319                        res += "<td>" + score + "</td>"
320                res += "</tr>"
321            res += "</table>"
322            self.reportRaw(res)
323
324    def score(self, ids):
325        """compute scores for the list of learners"""
326        if not self.data:
327            return
328
329        # test which learners can accept the given data set
330        # e.g., regressions can't deal with classification data
331        learners = []
332        used_ids = []
333        n = len(self.data.domain.attributes)*2
334        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible)
335        new = self.data.selectref(indices(self.data), 0)
336
337        multilabel = self.ismultilabel()
338       
339        self.warning(0)
340        learner_exceptions = []
341        for l in [self.learners[id] for id in ids]:
342            learner = l.learner
343            if self.preprocessor:
344                learner = self.preprocessor.wrapLearner(learner)
345            try:
346                predictor = learner(new)
347            except Exception, ex:
348                learner_exceptions.append((l, ex))
349                l.scores = []
350                l.results = None
351            else:
352                if (multilabel and isinstance(learner, Orange.multilabel.MultiLabelLearner)) or predictor(new[0]).varType == new.domain.classVar.varType:
353                    learners.append(learner)
354                    used_ids.append(l.id)
355                else:
356                    l.scores = []
357                    l.results = None
358
359        if learner_exceptions:
360            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \
361                             for l, ex in learner_exceptions)
362            self.warning(0, text)
363
364        if not learners:
365            return
366
367        # computation of results (res, and cm if classification)
368        pb = None
369        if self.resampling==0:
370            pb = OWGUI.ProgressBar(self, iterations=self.nFolds)
371            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds,
372                                           strat=orange.MakeRandomIndices.StratifiedIfPossible,
373                                           callback=pb.advance, storeExamples = True)
374            pb.finish()
375        elif self.resampling==1:
376            pb = OWGUI.ProgressBar(self, iterations=len(self.data))
377            res = orngTest.leaveOneOut(learners, self.data,
378                                       callback=pb.advance, storeExamples = True)
379            pb.finish()
380        elif self.resampling==2:
381            pb = OWGUI.ProgressBar(self, iterations=self.pRepeat)
382            res = orngTest.proportionTest(learners, self.data, self.pLearning/100.,
383                                          times=self.pRepeat, callback=pb.advance, storeExamples = True)
384            pb.finish()
385        elif self.resampling==3:
386            pb = OWGUI.ProgressBar(self, iterations=len(learners))
387            res = orngTest.learnAndTestOnLearnData(learners, self.data, storeExamples = True, callback=pb.advance)
388            pb.finish()
389           
390        elif self.resampling==4:
391            if not self.testdata:
392                for l in self.learners.values():
393                    l.scores = []
394                return
395            pb = OWGUI.ProgressBar(self, iterations=len(learners))
396            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata, storeExamples = True, callback=pb.advance)
397            pb.finish()
398           
399        if not self.ismultilabel() and self.isclassification():
400            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass)
401        else:
402            cm = None
403
404        if self.preprocessor: # Unwrap learners
405            learners = [l.wrappedLearner for l in learners]
406           
407        res.learners = learners
408       
409        for l in [self.learners[id] for id in ids]:
410            if l.learner in learners:
411                l.results = res
412            else:
413                l.results = None
414
415        self.error(range(len(self.stat)))
416        scores = []
417       
418        for i, s in enumerate(self.stat):
419            if s.cmBased:
420                try:
421                    scores.append(dispatch(s, res, cm))
422                except Exception, ex:
423                    self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
424                               (" ".join([l.name for l in learners]), ex.message))
425                    scores.append([None] * len(self.learners))
426            else:
427                scores_one = []
428                for res_one in orngStat.split_by_classifiers(res):
429                    try:
430                        scores_one.extend(dispatch(s, res_one, cm))
431                    except Exception, ex:
432                        self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
433                                   (res.classifierNames[0], ex.message))
434                        scores_one.append(None)
435                scores.append(scores_one)
436
437        for i, (id, l) in enumerate(zip(used_ids, learners)):
438            self.learners[id].scores = [s[i] if s else None for s in scores]
439           
440        self.sendResults()
441
442    def recomputeCM(self):
443        if not self.results:
444            return
445        cm = orngStat.computeConfusionMatrices(self.results, classIndex = self.targetClass)
446        scores = [(indx, eval("orngStat." + s.f))
447                  for (indx, s) in enumerate(self.stat) if s.cmBased]
448        for (indx, score) in scores:
449            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]):
450                learner_indx = self.results.learners.index(l.learner)
451                l.scores[indx] = score[learner_indx]
452
453        self.paintscores()
454       
455    def clearScores(self, ids=None):
456        """
457        Clear/invalidate evaluation results/scores for learners for an
458        `ids` sequence (keys into self.learners). If `ids` is None then
459        invalidate data for all learners.
460
461        """
462        if ids is None:
463            ids = self.learners.keys()
464
465        for id in ids:
466            self.learners[id].scores = []
467            self.learners[id].results = None
468
469        self._resultsInvalid = bool(ids)
470
471    # handle input signals
472    def setData(self, data):
473        """
474        Set the input train data set.
475        """
476        self.closeContext()
477
478        self.clearScores()
479
480        multilabel = self.ismultilabel(data)
481        if not multilabel:
482            if self.isDataWithClass(data, checkMissing=True):
483                self.data = orange.Filter_hasClassValue(data)
484            else:
485                self.data = None
486        else:
487            self.data = data
488
489        self.fillClassCombo()
490
491        if self.data is not None:
492            # Ensure the correct statistics selection is visible.
493            if self.ismultilabel():
494                statwidget = self.mbox
495                stat = self.mStatistics
496            elif self.isclassification():
497                statwidget = self.cbox
498                stat = self.cStatistics
499            else:
500                statwidget = self.rbox
501                stat = self.rStatistics
502
503            self.statLayout.setCurrentWidget(statwidget)
504
505            self.stat = stat
506
507            self.openContext("", self.data)
508
509    def setTestData(self, data):
510        """
511        Set the 'Separate Test Data' input.
512        """
513        if data is None:
514            self.testdata = None
515        else:
516            self.testdata = orange.Filter_hasClassValue(data)
517        self.testDataBtn.setEnabled(self.testdata is not None)
518        if self.testdata is not None:
519            if self.resampling == 4:
520                self.clearScores()
521
522        elif self.resampling == 4 and self.data:
523            # test data removed, switch to testing on train data
524            self.resampling = 3
525            self.clearScores()
526
527    def fillClassCombo(self):
528        """upon arrival of new data appropriately set the target class combo"""
529        self.targetCombo.clear()
530        if not self.data or not self.data.domain.classVar or not self.isclassification():
531            return
532
533        domain = self.data.domain
534        self.targetCombo.addItems([str(v) for v in domain.classVar.values])
535       
536        if self.targetClass<len(domain.classVar.values):
537            self.targetCombo.setCurrentIndex(self.targetClass)
538        else:
539            self.targetCombo.setCurrentIndex(0)
540            self.targetClass=0
541
542    def setLearner(self, learner, id=None):
543        """
544        Set the input learner with `id`.
545        """
546        if learner is not None:
547            # a new or updated learner
548            if id in self.learners:
549                time = self.learners[id].time
550                self.learners[id] = Learner(learner, id)
551                self.learners[id].time = time
552                self.clearScores([id])
553            else:
554                self.learners[id] = Learner(learner, id)
555                self._resultsInvalid = True
556        else:
557            # remove a learner and corresponding results
558            if id in self.learners:
559                res = self.learners[id].results
560                if res and res.numberOfLearners > 1:
561                    # Remove the learner from the shared results instance
562                    old_learner = self.learners[id].learner
563                    indx = res.learners.index(old_learner)
564                    res.remove(indx)
565                    del res.learners[indx]
566                del self.learners[id]
567
568                self._resultsInvalid = True
569
570    def setPreprocessor(self, pp):
571        self.preprocessor = pp
572        self.clearScores()
573
574    def handleNewSignals(self):
575        """
576        Update the evaluations/scores after new inputs were received.
577        """
578        def needsupdate(learner_id):
579            return not (self.learners[learner_id].scores or
580                        self.learners[learner_id].results)
581
582        if self.applyOnAnyChange:
583            self.score(filter(needsupdate, self.learners))
584            if self._resultsInvalid:
585                self.sendResults()
586
587            self.applyBtn.setEnabled(False)
588        else:
589            self.applyBtn.setEnabled(True)
590
591        self.paintscores()
592
593        if self.data is None and self._resultsInvalid:
594            self.send("Evaluation Results", None)
595            self._resultsInvalid = False
596
597
598    # handle output signals
599
600    def sendResults(self):
601        """commit evaluation results"""
602        learners = sorted(self.learners.values(),
603                          key=lambda learner: learner.time)
604
605        learners = [learner for learner in learners
606                    if learner.results and learner.scores]
607
608        if self.data is None or len(learners) == 0:
609            self.send("Evaluation Results", None)
610            self._resultsInvalid = False
611            return
612
613        # Split combined results by learner/classifier
614        results_by_learner = {}
615        for learner in learners:
616            results = learner.results
617            res_split = orngStat.split_by_classifiers(results)
618            index = results.learners.index(learner.learner)
619            res_single = res_split[index]
620            res_single.learners = [learner.learner]
621            results_by_learner[learner] = res_single
622
623        def add_results(rhs, lhs):
624            assert(lhs.number_of_learners == 1)
625            rhs.add(lhs, 0)
626            rhs.learners.extend(lhs.learners)
627            return rhs
628
629        # Reconstruct the results in the same order as displayed.
630        results = [results_by_learner[learner] for learner in learners]
631        results = reduce(add_results, results)
632
633        self.results = results
634        self.send("Evaluation Results", results)
635        self._resultsInvalid = False
636        return
637
638    # signal processing
639
640    def newsampling(self):
641        """handle change of evaluation method"""
642        self.clearScores()
643        if not self.applyOnAnyChange:
644            self.applyBtn.setDisabled(self.applyOnAnyChange)
645        elif self.learners:
646            self.recompute()
647
648    def newscoreselection(self):
649        """handle change in set of scores to be displayed"""
650        usestat = self.get_usestat()
651        for i in range(len(self.stat)):
652            if i in usestat:
653                self.tab.showColumn(i+1)
654                self.tab.resizeColumnToContents(i+1)
655            else:
656                self.tab.hideColumn(i+1)
657
658    def recompute(self, forced=False):
659        """recompute the scores for all learners,
660           if not forced, will do nothing but enable the Apply button"""
661        if self.applyOnAnyChange or forced:
662            self.score([l.id for l in self.learners.values()])
663            self.paintscores()
664            self.applyBtn.setDisabled(True)
665        else:
666            self.applyBtn.setEnabled(True)
667
668    def conditionalRecompute(self, option):
669        """calls recompute only if specific sampling option enabled"""
670        if self.resampling == option:
671            self.recompute(False)
672
673    def applyChange(self):
674        """
675        A change of 'Apply on any change' check box.
676        """
677        def needsupdate(learner_id):
678            return not (self.learners[learner_id].scores or
679                        self.learners[learner_id].results)
680
681        if self.applyOnAnyChange:
682            self.applyBtn.setDisabled(True)
683            pending = filter(needsupdate, self.learners)
684            if pending:
685                self.score(pending)
686                self.paintscores()
687
688    def changedTarget(self):
689        self.recomputeCM()
690
691##############################################################################
692# Test the widget, run from DOS prompt
693
694if __name__=="__main__":
695    a=QApplication(sys.argv)
696    ow=OWTestLearners()
697    ow.show()
698
699    data1 = orange.ExampleTable('voting')
700#     data2 = orange.ExampleTable('golf')
701#     datar = orange.ExampleTable('auto-mpg')
702#     data3 = orange.ExampleTable('sailing-big')
703#     data4 = orange.ExampleTable('sailing-test')
704    data5 = orange.ExampleTable('emotions')
705
706    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'
707
708    l2 = orange.BayesLearner()
709    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
710    l2.conditionalEstimatorConstructor = \
711        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
712        estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
713    l2.name = '2 - NBC (m=10)'
714
715    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'
716
717    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"
718
719    import orngRegression as r
720    r5 = r.LinearRegressionLearner(name="0 - lin reg")
721
722    l5 = Orange.multilabel.BinaryRelevanceLearner()
723
724    testcase = 0
725
726    if testcase == 0: # 1(UPD), 3, 4
727        ow.setData(data1)
728        ow.setLearner(r5, 5)
729        ow.setLearner(l1, 1)
730        ow.setLearner(l2, 2)
731        ow.setLearner(l3, 3)
732        ow.handleNewSignals()
733
734        l1.name = l1.name + " UPD"
735        ow.setLearner(l1, 1)
736        ow.setLearner(None, 2)
737        ow.setLearner(l4, 4)
738        ow.handleNewSignals()
739
740    if testcase == 1: # data, but all learners removed
741        ow.setLearner(l1, 1)
742        ow.setLearner(l2, 2)
743        ow.setLearner(l1, 1)
744        ow.setLearner(None, 2)
745        ow.setData(data2)
746        ow.setLearner(None, 1)
747    if testcase == 2: # sends data, then learner, then removes the learner
748        ow.setData(data2)
749        ow.setLearner(l1, 1)
750        ow.setLearner(None, 1)
751    if testcase == 3: # regression first
752        ow.setData(datar)
753        ow.setLearner(r5, 5)
754    if testcase == 4: # separate train and test data
755        ow.setData(data3)
756        ow.setTestData(data4)
757        ow.setLearner(l2, 5)
758        ow.setTestData(None)
759    if testcase == 5: # MLC
760        ow.setData(data5)
761        ow.setLearner(l5, 6)
762
763    a.exec_()
764    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.