source: orange/Orange/OrangeWidgets/Evaluate/OWTestLearners.py @ 11680:45b798db4f91

Revision 11680:45b798db4f91, 28.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Explicitly track if widget output should be updated.

(fixes #1323).

RevLine 
[9458]1"""
2<name>Test Learners</name>
3<description>Estimates the predictive performance of learners on a data set.</description>
[11217]4<icon>icons/TestLearners1.svg</icon>
[9458]5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>200</priority>
7"""
[11679]8
[9458]9import time
10import warnings
[11679]11import itertools
12
13from OWWidget import *
14
15import orngTest, orngStat, OWGUI
16
[9458]17from orngWrap import PreprocessedLearner
18
[9599]19import Orange
20
[9458]21##############################################################################
22
23class Learner:
[11679]24    counter = itertools.count()
25
[9458]26    def __init__(self, learner, id):
27        self.learner = learner
28        self.name = learner.name
29        self.id = id
30        self.scores = []
31        self.results = None
[11679]32        # Used to order the learners in the table (named time for
33        # back compatibility reasons)
34        self.time = next(self.counter)
35
[9458]36
37class Score:
38    def __init__(self, name, label, f, show=True, cmBased=False):
39        self.name = name
40        self.label = label
41        self.f = f
42        self.show = show
43        self.cmBased = cmBased
[11679]44
45
[9483]46def dispatch(score_desc, res, cm):
47    """ Dispatch the call to orngStat method.
48    """
49    return eval("orngStat." + score_desc.f)
50
[9458]51
52class OWTestLearners(OWWidget):
53    settingsList = ["nFolds", "pLearning", "pRepeat", "precision",
54                    "selectedCScores", "selectedRScores", "applyOnAnyChange",
55                    "resampling"]
56    contextHandlers = {"": DomainContextHandler("", ["targetClass"])}
57    callbackDeposit = []
58
[9599]59    # Classification
[9458]60    cStatistics = [Score(*s) for s in [\
61        ('Classification accuracy', 'CA', 'CA(res)', True),
62        ('Sensitivity', 'Sens', 'sens(cm)', True, True),
63        ('Specificity', 'Spec', 'spec(cm)', True, True),
64        ('Area under ROC curve', 'AUC', 'AUC(res)', True),
65        ('Information score', 'IS', 'IS(res)', False),
66        ('F-measure', 'F1', 'F1(cm)', False, True),
67        ('Precision', 'Prec', 'precision(cm)', False, True),
68        ('Recall', 'Recall', 'recall(cm)', False, True),
69        ('Brier score', 'Brier', 'BrierScore(res)', True),
[8042]70        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]]
71
[9599]72    # Regression
[9458]73    rStatistics = [Score(*s) for s in [\
74        ("Mean squared error", "MSE", "MSE(res)", False),
75        ("Root mean squared error", "RMSE", "RMSE(res)"),
76        ("Mean absolute error", "MAE", "MAE(res)", False),
77        ("Relative squared error", "RSE", "RSE(res)", False),
78        ("Root relative squared error", "RRSE", "RRSE(res)"),
79        ("Relative absolute error", "RAE", "RAE(res)", False),
80        ("R-squared", "R2", "R2(res)")]]
[9599]81   
82    # Multi-Label
83    mStatistics = [Score(*s) for s in [\
84        ('Hamming Loss', 'HammingLoss', 'mlc_hamming_loss(res)', False),
85        ('Accuracy', 'Accuracy', 'mlc_accuracy(res)', False),
86        ('Precision', 'Precision', 'mlc_precision(res)', False),
87        ('Recall', 'Recall', 'mlc_recall(res)', False),                               
88        ]]
89   
[9458]90    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling",
91                         "Test on train data", "Test on test data"]
92
93    def __init__(self,parent=None, signalManager = None):
[9609]94        OWWidget.__init__(self, parent, signalManager, "Test Learners")
[9458]95
[9599]96        self.inputs = [("Data", ExampleTable, self.setData, Default),
[9510]97                       ("Separate Test Data", ExampleTable, self.setTestData),
98                       ("Learner", orange.Learner, self.setLearner, Multiple + Default),
99                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
[9599]100
[9458]101        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]
102
103        # Settings
104        self.resampling = 0             # cross-validation
105        self.nFolds = 5                 # cross validation folds
106        self.pLearning = 70   # size of learning set when sampling [%]
107        self.pRepeat = 10
108        self.precision = 4
109        self.applyOnAnyChange = True
110        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show]
111        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show]
[9599]112        self.selectedMScores = [i for (i,s) in enumerate(self.mStatistics) if s.show]
[9458]113        self.targetClass = 0
114        self.loadSettings()
115
116        self.stat = self.cStatistics
117
118        self.data = None                # input data set
119        self.testdata = None            # separate test data set
120        self.learners = {}              # set of learners (input)
121        self.results = None             # from orngTest
122        self.preprocessor = None
123
[11680]124        # should the output evaluation results be recomputed
125        self._resultsInvalid = False
126
[9458]127        self.controlArea.layout().setSpacing(8)
128        # GUI
129        self.sBtns = OWGUI.radioButtonsInBox(self.controlArea, self, "resampling", 
130                                             box="Sampling",
131                                             btnLabels=self.resamplingMethods[:1],
132                                             callback=self.newsampling)
133        indent = OWGUI.checkButtonOffsetHint(self.sBtns.buttons[-1])
134       
135        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
136        OWGUI.spin(ibox, self, 'nFolds', 2, 100, step=1,
137                   label='Number of folds:',
138                   callback=lambda p=0: self.conditionalRecompute(p),
139                   keyboardTracking=False)
140       
141        OWGUI.separator(self.sBtns, height = 3)
142       
143        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[1])      # leave one out
144        OWGUI.separator(self.sBtns, height = 3)
145        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[2])      # random sampling
146                       
147        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
148        OWGUI.spin(ibox, self, 'pRepeat', 1, 100, step=1,
149                   label='Repeat train/test:',
150                   callback=lambda p=2: self.conditionalRecompute(p),
151                   keyboardTracking=False)
152       
153        OWGUI.widgetLabel(ibox, "Relative training set size:")
154       
155        OWGUI.hSlider(ibox, self, 'pLearning', minValue=10, maxValue=100,
156                      step=1, ticks=10, labelFormat="   %d%%",
157                      callback=lambda p=2: self.conditionalRecompute(p))
158       
159        OWGUI.separator(self.sBtns, height = 3)
160        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[3])  # test on train
161        OWGUI.separator(self.sBtns, height = 3)
162        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[4])  # test on test
163
164        self.trainDataBtn = self.sBtns.buttons[-2]
165        self.testDataBtn = self.sBtns.buttons[-1]
166        self.testDataBtn.setDisabled(True)
167       
168        OWGUI.separator(self.sBtns)
169        OWGUI.checkBox(self.sBtns, self, 'applyOnAnyChange',
170                       label="Apply on any change", callback=self.applyChange)
171        self.applyBtn = OWGUI.button(self.sBtns, self, "&Apply",
172                                     callback=lambda f=True: self.recompute(f))
173        self.applyBtn.setDisabled(True)
174
175        if self.resampling == 4:
176            self.resampling = 3
177
178        # statistics
179        self.statLayout = QStackedLayout()
180        self.cbox = OWGUI.widgetBox(self.controlArea, addToLayout=False)
181        self.cStatLabels = [s.name for s in self.cStatistics]
182        self.cstatLB = OWGUI.listBox(self.cbox, self, 'selectedCScores',
183                                     'cStatLabels', box = "Performance scores",
184                                     selectionMode = QListWidget.MultiSelection,
185                                     callback=self.newscoreselection)
186       
187        self.cbox.layout().addSpacing(8)
188        self.targetCombo = OWGUI.comboBox(self.cbox, self, "targetClass", orientation=0,
189                                        callback=[self.changedTarget],
190                                        box="Target class")
191
192        self.rStatLabels = [s.name for s in self.rStatistics]
193        self.rbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
194        self.rstatLB = OWGUI.listBox(self.rbox, self, 'selectedRScores', 'rStatLabels',
195                                     selectionMode = QListWidget.MultiSelection,
196                                     callback=self.newscoreselection)
[9599]197
198        self.mStatLabels = [s.name for s in self.mStatistics]
199        self.mbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
200        self.mstatLB = OWGUI.listBox(self.mbox, self, 'selectedMScores', 'mStatLabels',
201                                     selectionMode = QListWidget.MultiSelection,
202                                     callback=self.newscoreselection)
[9458]203       
204        self.statLayout.addWidget(self.cbox)
205        self.statLayout.addWidget(self.rbox)
[9599]206        self.statLayout.addWidget(self.mbox)
[9458]207        self.controlArea.layout().addLayout(self.statLayout)
208       
209        self.statLayout.setCurrentWidget(self.cbox)
210
211        # score table
212        # table with results
213        self.g = OWGUI.widgetBox(self.mainArea, 'Evaluation Results')
214        self.tab = OWGUI.table(self.g, selectionMode = QTableWidget.NoSelection)
215
216        self.resize(680,470)
217
218    # scoring and painting of score table
219    def isclassification(self):
220        if not self.data or not self.data.domain.classVar:
221            return True
222        return self.data.domain.classVar.varType == orange.VarTypes.Discrete
[9599]223
224    def ismultilabel(self, data = 42):
225        if data==42:
226            data = self.data
227        if not data:
228            return False
229        return Orange.multilabel.is_multilabel(data)
230
231    def get_usestat(self):
232        return ([self.selectedRScores, self.selectedCScores, self.selectedMScores]
233                [2 if self.ismultilabel() else self.isclassification()])
234
[9458]235    def paintscores(self):
236        """paints the table with evaluation scores"""
237
238        self.tab.setColumnCount(len(self.stat)+1)
239        self.tab.setHorizontalHeaderLabels(["Method"] + [s.label for s in self.stat])
240       
241        prec="%%.%df" % self.precision
242
243        learners = [(l.time, l) for l in self.learners.values()]
244        learners.sort()
245        learners = [lt[1] for lt in learners]
246
247        self.tab.setRowCount(len(self.learners))
248        for (i, l) in enumerate(learners):
249            OWGUI.tableItem(self.tab, i,0, l.name)
250           
251        for (i, l) in enumerate(learners):
252            if l.scores:
253                for j in range(len(self.stat)):
254                    if l.scores[j] is not None:
255                        OWGUI.tableItem(self.tab, i, j+1, prec % l.scores[j])
256                    else:
257                        OWGUI.tableItem(self.tab, i, j+1, "N/A")
258            else:
259                for j in range(len(self.stat)):
260                    OWGUI.tableItem(self.tab, i, j+1, "")
261       
262        # adjust the width of the score table cloumns
263        self.tab.resizeColumnsToContents()
264        self.tab.resizeRowsToContents()
[9599]265        usestat = self.get_usestat()
[9458]266        for i in range(len(self.stat)):
267            if i not in usestat:
[10983]268                self.tab.hideColumn(i + 1)
269            else:
270                self.tab.showColumn(i + 1)
[9458]271
272    def sendReport(self):
[11000]273        method = [("Method", self.resamplingMethods[self.resampling])]
274
[9458]275        exset = []
[11000]276
[9458]277        if self.resampling == 0:
278            exset = [("Folds", self.nFolds)]
279        elif self.resampling == 2:
[11000]280            exset = [("Repetitions", self.pRepeat),
281                     ("Proportion of training instances", "%i%%" \
282                      % self.pLearning)]
[9458]283        else:
284            exset = []
[11000]285
286        if self.data and \
287                isinstance(self.data.domain.classVar, orange.EnumVariable):
288            target = [("Target class",
289                       self.data.domain.classVar.values[self.targetClass])]
290        else:
291            target = []
292
[9599]293        if not self.ismultilabel():
[11000]294            self.reportSettings("Validation method", method + exset + target)
[9599]295        else:
[11000]296            self.reportSettings("Validation method", method + exset)
297
[9458]298        self.reportData(self.data)
299
[11000]300        if self.data:
[9458]301            self.reportSection("Results")
302            learners = [(l.time, l) for l in self.learners.values()]
303            learners.sort()
304            learners = [lt[1] for lt in learners]
[9599]305            usestat = self.get_usestat()
[11000]306            # table header
307            res = "<table><tr><th></th>" + \
308                  "".join("<th><b>%s</b></th>" % s.label \
309                          for i, s in enumerate(self.stat) \
310                          if i in usestat) + \
311                  "</tr>"
312
313            for i, learner in enumerate(learners):
314                res += "<tr><th><b>%s</b></th>" % learner.name
315                if learner.scores:
[9458]316                    for j in usestat:
[11000]317                        score = learner.scores[j]
318                        score = "%.4f" % score if score is not None else ""
319                        res += "<td>" + score + "</td>"
[9458]320                res += "</tr>"
321            res += "</table>"
322            self.reportRaw(res)
[11000]323
[9458]324    def score(self, ids):
325        """compute scores for the list of learners"""
[11680]326        if not self.data:
[9458]327            return
[11680]328
[9458]329        # test which learners can accept the given data set
330        # e.g., regressions can't deal with classification data
331        learners = []
[9528]332        used_ids = []
[9458]333        n = len(self.data.domain.attributes)*2
334        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible)
[10999]335        new = self.data.selectref(indices(self.data), 0)
336
[9599]337        multilabel = self.ismultilabel()
338       
[9458]339        self.warning(0)
340        learner_exceptions = []
341        for l in [self.learners[id] for id in ids]:
342            learner = l.learner
343            if self.preprocessor:
344                learner = self.preprocessor.wrapLearner(learner)
345            try:
346                predictor = learner(new)
[11562]347            except Exception, ex:
348                learner_exceptions.append((l, ex))
349                l.scores = []
350                l.results = None
351            else:
[9599]352                if (multilabel and isinstance(learner, Orange.multilabel.MultiLabelLearner)) or predictor(new[0]).varType == new.domain.classVar.varType:
[9458]353                    learners.append(learner)
[9528]354                    used_ids.append(l.id)
[8042]355                else:
356                    l.scores = []
[9528]357                    l.results = None
[9599]358
[9458]359        if learner_exceptions:
360            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \
361                             for l, ex in learner_exceptions)
362            self.warning(0, text)
[11680]363
[9458]364        if not learners:
365            return
366
367        # computation of results (res, and cm if classification)
368        pb = None
369        if self.resampling==0:
370            pb = OWGUI.ProgressBar(self, iterations=self.nFolds)
[9482]371            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds,
372                                           strat=orange.MakeRandomIndices.StratifiedIfPossible,
373                                           callback=pb.advance, storeExamples = True)
[9458]374            pb.finish()
375        elif self.resampling==1:
376            pb = OWGUI.ProgressBar(self, iterations=len(self.data))
377            res = orngTest.leaveOneOut(learners, self.data,
378                                       callback=pb.advance, storeExamples = True)
379            pb.finish()
380        elif self.resampling==2:
381            pb = OWGUI.ProgressBar(self, iterations=self.pRepeat)
382            res = orngTest.proportionTest(learners, self.data, self.pLearning/100.,
383                                          times=self.pRepeat, callback=pb.advance, storeExamples = True)
384            pb.finish()
385        elif self.resampling==3:
386            pb = OWGUI.ProgressBar(self, iterations=len(learners))
387            res = orngTest.learnAndTestOnLearnData(learners, self.data, storeExamples = True, callback=pb.advance)
388            pb.finish()
389           
390        elif self.resampling==4:
391            if not self.testdata:
392                for l in self.learners.values():
393                    l.scores = []
394                return
395            pb = OWGUI.ProgressBar(self, iterations=len(learners))
396            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata, storeExamples = True, callback=pb.advance)
397            pb.finish()
[9483]398           
[9599]399        if not self.ismultilabel() and self.isclassification():
[9458]400            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass)
[9483]401        else:
402            cm = None
[9458]403
404        if self.preprocessor: # Unwrap learners
405            learners = [l.wrappedLearner for l in learners]
406           
407        res.learners = learners
408       
409        for l in [self.learners[id] for id in ids]:
410            if l.learner in learners:
411                l.results = res
[9528]412            else:
413                l.results = None
[9458]414
415        self.error(range(len(self.stat)))
416        scores = []
[9483]417       
[9458]418        for i, s in enumerate(self.stat):
[9483]419            if s.cmBased:
420                try:
421                    scores.append(dispatch(s, res, cm))
422                except Exception, ex:
423                    self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
424                               (" ".join([l.name for l in learners]), ex.message))
425                    scores.append([None] * len(self.learners))
426            else:
427                scores_one = []
428                for res_one in orngStat.split_by_classifiers(res):
429                    try:
430                        scores_one.extend(dispatch(s, res_one, cm))
431                    except Exception, ex:
432                        self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
433                                   (res.classifierNames[0], ex.message))
434                        scores_one.append(None)
435                scores.append(scores_one)
[9599]436
[9528]437        for i, (id, l) in enumerate(zip(used_ids, learners)):
438            self.learners[id].scores = [s[i] if s else None for s in scores]
[9458]439           
440        self.sendResults()
441
442    def recomputeCM(self):
443        if not self.results:
444            return
445        cm = orngStat.computeConfusionMatrices(self.results, classIndex = self.targetClass)
446        scores = [(indx, eval("orngStat." + s.f))
447                  for (indx, s) in enumerate(self.stat) if s.cmBased]
448        for (indx, score) in scores:
449            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]):
[9528]450                learner_indx = self.results.learners.index(l.learner)
451                l.scores[indx] = score[learner_indx]
[9599]452
[9458]453        self.paintscores()
454       
[9528]455    def clearScores(self, ids=None):
[11561]456        """
457        Clear/invalidate evaluation results/scores for learners for an
458        `ids` sequence (keys into self.learners). If `ids` is None then
459        invalidate data for all learners.
460
461        """
[9528]462        if ids is None:
463            ids = self.learners.keys()
[9599]464
[9528]465        for id in ids:
466            self.learners[id].scores = []
467            self.learners[id].results = None
[9599]468
[11680]469        self._resultsInvalid = bool(ids)
470
[9458]471    # handle input signals
472    def setData(self, data):
[11561]473        """
474        Set the input train data set.
475        """
[9458]476        self.closeContext()
[11561]477
478        self.clearScores()
479
480        multilabel = self.ismultilabel(data)
[9599]481        if not multilabel:
[11561]482            if self.isDataWithClass(data, checkMissing=True):
483                self.data = orange.Filter_hasClassValue(data)
484            else:
485                self.data = None
[9599]486        else:
487            self.data = data
[11561]488
[9458]489        self.fillClassCombo()
[9599]490
[11561]491        if self.data is not None:
492            # Ensure the correct statistics selection is visible.
493            if self.ismultilabel():
494                statwidget = self.mbox
495                stat = self.mStatistics
496            elif self.isclassification():
497                statwidget = self.cbox
498                stat = self.cStatistics
499            else:
500                statwidget = self.rbox
501                stat = self.rStatistics
[9599]502
[11561]503            self.statLayout.setCurrentWidget(statwidget)
[9599]504
[11561]505            self.stat = stat
[9458]506
[11680]507            self.openContext("", self.data)
508
[9458]509    def setTestData(self, data):
[11561]510        """
511        Set the 'Separate Test Data' input.
512        """
[9458]513        if data is None:
514            self.testdata = None
515        else:
516            self.testdata = orange.Filter_hasClassValue(data)
517        self.testDataBtn.setEnabled(self.testdata is not None)
518        if self.testdata is not None:
519            if self.resampling == 4:
[11561]520                self.clearScores()
521
[9458]522        elif self.resampling == 4 and self.data:
523            # test data removed, switch to testing on train data
524            self.resampling = 3
[11561]525            self.clearScores()
[9458]526
527    def fillClassCombo(self):
528        """upon arrival of new data appropriately set the target class combo"""
529        self.targetCombo.clear()
530        if not self.data or not self.data.domain.classVar or not self.isclassification():
531            return
532
533        domain = self.data.domain
534        self.targetCombo.addItems([str(v) for v in domain.classVar.values])
535       
536        if self.targetClass<len(domain.classVar.values):
[8042]537            self.targetCombo.setCurrentIndex(self.targetClass)
[9458]538        else:
539            self.targetCombo.setCurrentIndex(0)
540            self.targetClass=0
541
542    def setLearner(self, learner, id=None):
[11561]543        """
544        Set the input learner with `id`.
545        """
546        if learner is not None:
547            # a new or updated learner
548            if id in self.learners:
[9458]549                time = self.learners[id].time
550                self.learners[id] = Learner(learner, id)
551                self.learners[id].time = time
[11561]552                self.clearScores([id])
553            else:
[9458]554                self.learners[id] = Learner(learner, id)
[11680]555                self._resultsInvalid = True
[11561]556        else:
557            # remove a learner and corresponding results
[9458]558            if id in self.learners:
559                res = self.learners[id].results
560                if res and res.numberOfLearners > 1:
[11561]561                    # Remove the learner from the shared results instance
[9528]562                    old_learner = self.learners[id].learner
563                    indx = res.learners.index(old_learner)
[9458]564                    res.remove(indx)
565                    del res.learners[indx]
566                del self.learners[id]
[11561]567
[11680]568                self._resultsInvalid = True
569
[9458]570    def setPreprocessor(self, pp):
571        self.preprocessor = pp
[11561]572        self.clearScores()
573
574    def handleNewSignals(self):
575        """
576        Update the evaluations/scores after new inputs were received.
577        """
578        def needsupdate(learner_id):
579            return not (self.learners[learner_id].scores or
580                        self.learners[learner_id].results)
581
582        if self.applyOnAnyChange:
583            self.score(filter(needsupdate, self.learners))
[11680]584            if self._resultsInvalid:
585                self.sendResults()
586
[11561]587            self.applyBtn.setEnabled(False)
588        else:
589            self.applyBtn.setEnabled(True)
590
591        self.paintscores()
592
[11680]593        if self.data is None and self._resultsInvalid:
[11561]594            self.send("Evaluation Results", None)
[11680]595            self._resultsInvalid = False
596
[9458]597
598    # handle output signals
599
600    def sendResults(self):
601        """commit evaluation results"""
[11679]602        learners = sorted(self.learners.values(),
603                          key=lambda learner: learner.time)
[9458]604
[11679]605        learners = [learner for learner in learners
606                    if learner.results and learner.scores]
607
[11680]608        if self.data is None or len(learners) == 0:
[9458]609            self.send("Evaluation Results", None)
[11680]610            self._resultsInvalid = False
[9458]611            return
612
[11679]613        # Split combined results by learner/classifier
614        results_by_learner = {}
615        for learner in learners:
616            results = learner.results
617            res_split = orngStat.split_by_classifiers(results)
618            index = results.learners.index(learner.learner)
619            res_single = res_split[index]
620            res_single.learners = [learner.learner]
621            results_by_learner[learner] = res_single
[9599]622
[11679]623        def add_results(rhs, lhs):
624            assert(lhs.number_of_learners == 1)
625            rhs.add(lhs, 0)
626            rhs.learners.extend(lhs.learners)
627            return rhs
628
[11680]629        # Reconstruct the results in the same order as displayed.
[11679]630        results = [results_by_learner[learner] for learner in learners]
631        results = reduce(add_results, results)
632
633        self.results = results
[9458]634        self.send("Evaluation Results", results)
[11680]635        self._resultsInvalid = False
[11679]636        return
[9458]637
638    # signal processing
639
640    def newsampling(self):
641        """handle change of evaluation method"""
[11680]642        self.clearScores()
[9458]643        if not self.applyOnAnyChange:
644            self.applyBtn.setDisabled(self.applyOnAnyChange)
[11680]645        elif self.learners:
646            self.recompute()
[9458]647
648    def newscoreselection(self):
649        """handle change in set of scores to be displayed"""
[9599]650        usestat = self.get_usestat()
[9458]651        for i in range(len(self.stat)):
652            if i in usestat:
653                self.tab.showColumn(i+1)
654                self.tab.resizeColumnToContents(i+1)
655            else:
656                self.tab.hideColumn(i+1)
[8042]657
[9458]658    def recompute(self, forced=False):
659        """recompute the scores for all learners,
660           if not forced, will do nothing but enable the Apply button"""
661        if self.applyOnAnyChange or forced:
662            self.score([l.id for l in self.learners.values()])
663            self.paintscores()
664            self.applyBtn.setDisabled(True)
665        else:
666            self.applyBtn.setEnabled(True)
667
668    def conditionalRecompute(self, option):
669        """calls recompute only if specific sampling option enabled"""
670        if self.resampling == option:
671            self.recompute(False)
672
673    def applyChange(self):
[11561]674        """
675        A change of 'Apply on any change' check box.
676        """
677        def needsupdate(learner_id):
678            return not (self.learners[learner_id].scores or
679                        self.learners[learner_id].results)
680
[9458]681        if self.applyOnAnyChange:
682            self.applyBtn.setDisabled(True)
[11561]683            pending = filter(needsupdate, self.learners)
684            if pending:
685                self.score(pending)
686                self.paintscores()
687
[9458]688    def changedTarget(self):
689        self.recomputeCM()
690
691##############################################################################
692# Test the widget, run from DOS prompt
693
694if __name__=="__main__":
695    a=QApplication(sys.argv)
696    ow=OWTestLearners()
697    ow.show()
698
[11679]699    data1 = orange.ExampleTable('voting')
700#     data2 = orange.ExampleTable('golf')
701#     datar = orange.ExampleTable('auto-mpg')
702#     data3 = orange.ExampleTable('sailing-big')
703#     data4 = orange.ExampleTable('sailing-test')
[9599]704    data5 = orange.ExampleTable('emotions')
[9458]705
706    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'
707
708    l2 = orange.BayesLearner()
709    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
710    l2.conditionalEstimatorConstructor = \
711        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
712        estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
713    l2.name = '2 - NBC (m=10)'
714
715    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'
716
717    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"
[8042]718
[9458]719    import orngRegression as r
720    r5 = r.LinearRegressionLearner(name="0 - lin reg")
721
[9599]722    l5 = Orange.multilabel.BinaryRelevanceLearner()
723
[11679]724    testcase = 0
[9458]725
726    if testcase == 0: # 1(UPD), 3, 4
[11679]727        ow.setData(data1)
[9458]728        ow.setLearner(r5, 5)
729        ow.setLearner(l1, 1)
730        ow.setLearner(l2, 2)
731        ow.setLearner(l3, 3)
[11679]732        ow.handleNewSignals()
733
[9458]734        l1.name = l1.name + " UPD"
735        ow.setLearner(l1, 1)
736        ow.setLearner(None, 2)
737        ow.setLearner(l4, 4)
[11679]738        ow.handleNewSignals()
739
[9458]740    if testcase == 1: # data, but all learners removed
741        ow.setLearner(l1, 1)
742        ow.setLearner(l2, 2)
743        ow.setLearner(l1, 1)
744        ow.setLearner(None, 2)
745        ow.setData(data2)
746        ow.setLearner(None, 1)
747    if testcase == 2: # sends data, then learner, then removes the learner
748        ow.setData(data2)
749        ow.setLearner(l1, 1)
750        ow.setLearner(None, 1)
751    if testcase == 3: # regression first
752        ow.setData(datar)
753        ow.setLearner(r5, 5)
754    if testcase == 4: # separate train and test data
755        ow.setData(data3)
756        ow.setTestData(data4)
757        ow.setLearner(l2, 5)
758        ow.setTestData(None)
[9599]759    if testcase == 5: # MLC
760        ow.setData(data5)
761        ow.setLearner(l5, 6)
[8042]762
[11679]763    a.exec_()
[9458]764    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.