source: orange/Orange/OrangeWidgets/Evaluate/OWTestLearners.py @ 11679:3c8959a16655

Revision 11679:3c8959a16655, 28.0 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Preserve learner order in 'TestLearners' widget output.

(fixes #1324)

RevLine 
[9458]1"""
2<name>Test Learners</name>
3<description>Estimates the predictive performance of learners on a data set.</description>
[11217]4<icon>icons/TestLearners1.svg</icon>
[9458]5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>200</priority>
7"""
[11679]8
[9458]9import time
10import warnings
[11679]11import itertools
12
13from OWWidget import *
14
15import orngTest, orngStat, OWGUI
16
[9458]17from orngWrap import PreprocessedLearner
18
[9599]19import Orange
20
[9458]21##############################################################################
22
23class Learner:
[11679]24    counter = itertools.count()
25
[9458]26    def __init__(self, learner, id):
27        self.learner = learner
28        self.name = learner.name
29        self.id = id
30        self.scores = []
31        self.results = None
[11679]32        # Used to order the learners in the table (named time for
33        # back compatibility reasons)
34        self.time = next(self.counter)
35
[9458]36
37class Score:
38    def __init__(self, name, label, f, show=True, cmBased=False):
39        self.name = name
40        self.label = label
41        self.f = f
42        self.show = show
43        self.cmBased = cmBased
[11679]44
45
[9483]46def dispatch(score_desc, res, cm):
47    """ Dispatch the call to orngStat method.
48    """
49    return eval("orngStat." + score_desc.f)
50
[9458]51
52class OWTestLearners(OWWidget):
53    settingsList = ["nFolds", "pLearning", "pRepeat", "precision",
54                    "selectedCScores", "selectedRScores", "applyOnAnyChange",
55                    "resampling"]
56    contextHandlers = {"": DomainContextHandler("", ["targetClass"])}
57    callbackDeposit = []
58
[9599]59    # Classification
[9458]60    cStatistics = [Score(*s) for s in [\
61        ('Classification accuracy', 'CA', 'CA(res)', True),
62        ('Sensitivity', 'Sens', 'sens(cm)', True, True),
63        ('Specificity', 'Spec', 'spec(cm)', True, True),
64        ('Area under ROC curve', 'AUC', 'AUC(res)', True),
65        ('Information score', 'IS', 'IS(res)', False),
66        ('F-measure', 'F1', 'F1(cm)', False, True),
67        ('Precision', 'Prec', 'precision(cm)', False, True),
68        ('Recall', 'Recall', 'recall(cm)', False, True),
69        ('Brier score', 'Brier', 'BrierScore(res)', True),
[8042]70        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]]
71
[9599]72    # Regression
[9458]73    rStatistics = [Score(*s) for s in [\
74        ("Mean squared error", "MSE", "MSE(res)", False),
75        ("Root mean squared error", "RMSE", "RMSE(res)"),
76        ("Mean absolute error", "MAE", "MAE(res)", False),
77        ("Relative squared error", "RSE", "RSE(res)", False),
78        ("Root relative squared error", "RRSE", "RRSE(res)"),
79        ("Relative absolute error", "RAE", "RAE(res)", False),
80        ("R-squared", "R2", "R2(res)")]]
[9599]81   
82    # Multi-Label
83    mStatistics = [Score(*s) for s in [\
84        ('Hamming Loss', 'HammingLoss', 'mlc_hamming_loss(res)', False),
85        ('Accuracy', 'Accuracy', 'mlc_accuracy(res)', False),
86        ('Precision', 'Precision', 'mlc_precision(res)', False),
87        ('Recall', 'Recall', 'mlc_recall(res)', False),                               
88        ]]
89   
[9458]90    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling",
91                         "Test on train data", "Test on test data"]
92
93    def __init__(self,parent=None, signalManager = None):
[9609]94        OWWidget.__init__(self, parent, signalManager, "Test Learners")
[9458]95
[9599]96        self.inputs = [("Data", ExampleTable, self.setData, Default),
[9510]97                       ("Separate Test Data", ExampleTable, self.setTestData),
98                       ("Learner", orange.Learner, self.setLearner, Multiple + Default),
99                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
[9599]100
[9458]101        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]
102
103        # Settings
104        self.resampling = 0             # cross-validation
105        self.nFolds = 5                 # cross validation folds
106        self.pLearning = 70   # size of learning set when sampling [%]
107        self.pRepeat = 10
108        self.precision = 4
109        self.applyOnAnyChange = True
110        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show]
111        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show]
[9599]112        self.selectedMScores = [i for (i,s) in enumerate(self.mStatistics) if s.show]
[9458]113        self.targetClass = 0
114        self.loadSettings()
115
116        self.stat = self.cStatistics
117
118        self.data = None                # input data set
119        self.testdata = None            # separate test data set
120        self.learners = {}              # set of learners (input)
121        self.results = None             # from orngTest
122        self.preprocessor = None
123
124        self.controlArea.layout().setSpacing(8)
125        # GUI
126        self.sBtns = OWGUI.radioButtonsInBox(self.controlArea, self, "resampling", 
127                                             box="Sampling",
128                                             btnLabels=self.resamplingMethods[:1],
129                                             callback=self.newsampling)
130        indent = OWGUI.checkButtonOffsetHint(self.sBtns.buttons[-1])
131       
132        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
133        OWGUI.spin(ibox, self, 'nFolds', 2, 100, step=1,
134                   label='Number of folds:',
135                   callback=lambda p=0: self.conditionalRecompute(p),
136                   keyboardTracking=False)
137       
138        OWGUI.separator(self.sBtns, height = 3)
139       
140        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[1])      # leave one out
141        OWGUI.separator(self.sBtns, height = 3)
142        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[2])      # random sampling
143                       
144        ibox = OWGUI.widgetBox(OWGUI.indentedBox(self.sBtns, sep=indent))
145        OWGUI.spin(ibox, self, 'pRepeat', 1, 100, step=1,
146                   label='Repeat train/test:',
147                   callback=lambda p=2: self.conditionalRecompute(p),
148                   keyboardTracking=False)
149       
150        OWGUI.widgetLabel(ibox, "Relative training set size:")
151       
152        OWGUI.hSlider(ibox, self, 'pLearning', minValue=10, maxValue=100,
153                      step=1, ticks=10, labelFormat="   %d%%",
154                      callback=lambda p=2: self.conditionalRecompute(p))
155       
156        OWGUI.separator(self.sBtns, height = 3)
157        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[3])  # test on train
158        OWGUI.separator(self.sBtns, height = 3)
159        OWGUI.appendRadioButton(self.sBtns, self, "resampling", self.resamplingMethods[4])  # test on test
160
161        self.trainDataBtn = self.sBtns.buttons[-2]
162        self.testDataBtn = self.sBtns.buttons[-1]
163        self.testDataBtn.setDisabled(True)
164       
165        OWGUI.separator(self.sBtns)
166        OWGUI.checkBox(self.sBtns, self, 'applyOnAnyChange',
167                       label="Apply on any change", callback=self.applyChange)
168        self.applyBtn = OWGUI.button(self.sBtns, self, "&Apply",
169                                     callback=lambda f=True: self.recompute(f))
170        self.applyBtn.setDisabled(True)
171
172        if self.resampling == 4:
173            self.resampling = 3
174
175        # statistics
176        self.statLayout = QStackedLayout()
177        self.cbox = OWGUI.widgetBox(self.controlArea, addToLayout=False)
178        self.cStatLabels = [s.name for s in self.cStatistics]
179        self.cstatLB = OWGUI.listBox(self.cbox, self, 'selectedCScores',
180                                     'cStatLabels', box = "Performance scores",
181                                     selectionMode = QListWidget.MultiSelection,
182                                     callback=self.newscoreselection)
183       
184        self.cbox.layout().addSpacing(8)
185        self.targetCombo = OWGUI.comboBox(self.cbox, self, "targetClass", orientation=0,
186                                        callback=[self.changedTarget],
187                                        box="Target class")
188
189        self.rStatLabels = [s.name for s in self.rStatistics]
190        self.rbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
191        self.rstatLB = OWGUI.listBox(self.rbox, self, 'selectedRScores', 'rStatLabels',
192                                     selectionMode = QListWidget.MultiSelection,
193                                     callback=self.newscoreselection)
[9599]194
195        self.mStatLabels = [s.name for s in self.mStatistics]
196        self.mbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False)
197        self.mstatLB = OWGUI.listBox(self.mbox, self, 'selectedMScores', 'mStatLabels',
198                                     selectionMode = QListWidget.MultiSelection,
199                                     callback=self.newscoreselection)
[9458]200       
201        self.statLayout.addWidget(self.cbox)
202        self.statLayout.addWidget(self.rbox)
[9599]203        self.statLayout.addWidget(self.mbox)
[9458]204        self.controlArea.layout().addLayout(self.statLayout)
205       
206        self.statLayout.setCurrentWidget(self.cbox)
207
208        # score table
209        # table with results
210        self.g = OWGUI.widgetBox(self.mainArea, 'Evaluation Results')
211        self.tab = OWGUI.table(self.g, selectionMode = QTableWidget.NoSelection)
212
213        self.resize(680,470)
214
215    # scoring and painting of score table
216    def isclassification(self):
217        if not self.data or not self.data.domain.classVar:
218            return True
219        return self.data.domain.classVar.varType == orange.VarTypes.Discrete
[9599]220
221    def ismultilabel(self, data = 42):
222        if data==42:
223            data = self.data
224        if not data:
225            return False
226        return Orange.multilabel.is_multilabel(data)
227
228    def get_usestat(self):
229        return ([self.selectedRScores, self.selectedCScores, self.selectedMScores]
230                [2 if self.ismultilabel() else self.isclassification()])
231
[9458]232    def paintscores(self):
233        """paints the table with evaluation scores"""
234
235        self.tab.setColumnCount(len(self.stat)+1)
236        self.tab.setHorizontalHeaderLabels(["Method"] + [s.label for s in self.stat])
237       
238        prec="%%.%df" % self.precision
239
240        learners = [(l.time, l) for l in self.learners.values()]
241        learners.sort()
242        learners = [lt[1] for lt in learners]
243
244        self.tab.setRowCount(len(self.learners))
245        for (i, l) in enumerate(learners):
246            OWGUI.tableItem(self.tab, i,0, l.name)
247           
248        for (i, l) in enumerate(learners):
249            if l.scores:
250                for j in range(len(self.stat)):
251                    if l.scores[j] is not None:
252                        OWGUI.tableItem(self.tab, i, j+1, prec % l.scores[j])
253                    else:
254                        OWGUI.tableItem(self.tab, i, j+1, "N/A")
255            else:
256                for j in range(len(self.stat)):
257                    OWGUI.tableItem(self.tab, i, j+1, "")
258       
259        # adjust the width of the score table cloumns
260        self.tab.resizeColumnsToContents()
261        self.tab.resizeRowsToContents()
[9599]262        usestat = self.get_usestat()
[9458]263        for i in range(len(self.stat)):
264            if i not in usestat:
[10983]265                self.tab.hideColumn(i + 1)
266            else:
267                self.tab.showColumn(i + 1)
[9458]268
269    def sendReport(self):
[11000]270        method = [("Method", self.resamplingMethods[self.resampling])]
271
[9458]272        exset = []
[11000]273
[9458]274        if self.resampling == 0:
275            exset = [("Folds", self.nFolds)]
276        elif self.resampling == 2:
[11000]277            exset = [("Repetitions", self.pRepeat),
278                     ("Proportion of training instances", "%i%%" \
279                      % self.pLearning)]
[9458]280        else:
281            exset = []
[11000]282
283        if self.data and \
284                isinstance(self.data.domain.classVar, orange.EnumVariable):
285            target = [("Target class",
286                       self.data.domain.classVar.values[self.targetClass])]
287        else:
288            target = []
289
[9599]290        if not self.ismultilabel():
[11000]291            self.reportSettings("Validation method", method + exset + target)
[9599]292        else:
[11000]293            self.reportSettings("Validation method", method + exset)
294
[9458]295        self.reportData(self.data)
296
[11000]297        if self.data:
[9458]298            self.reportSection("Results")
299            learners = [(l.time, l) for l in self.learners.values()]
300            learners.sort()
301            learners = [lt[1] for lt in learners]
[9599]302            usestat = self.get_usestat()
[11000]303            # table header
304            res = "<table><tr><th></th>" + \
305                  "".join("<th><b>%s</b></th>" % s.label \
306                          for i, s in enumerate(self.stat) \
307                          if i in usestat) + \
308                  "</tr>"
309
310            for i, learner in enumerate(learners):
311                res += "<tr><th><b>%s</b></th>" % learner.name
312                if learner.scores:
[9458]313                    for j in usestat:
[11000]314                        score = learner.scores[j]
315                        score = "%.4f" % score if score is not None else ""
316                        res += "<td>" + score + "</td>"
[9458]317                res += "</tr>"
318            res += "</table>"
319            self.reportRaw(res)
[11000]320
[9458]321    def score(self, ids):
322        """compute scores for the list of learners"""
323        if (not self.data):
324            for id in ids:
325                self.learners[id].results = None
[9528]326                self.learners[id].scores = []
[9458]327            return
328        # test which learners can accept the given data set
329        # e.g., regressions can't deal with classification data
330        learners = []
[9528]331        used_ids = []
[9458]332        n = len(self.data.domain.attributes)*2
333        indices = orange.MakeRandomIndices2(p0=min(n, len(self.data)), stratified=orange.MakeRandomIndices2.StratifiedIfPossible)
[10999]334        new = self.data.selectref(indices(self.data), 0)
335
[9599]336        multilabel = self.ismultilabel()
337       
[9458]338        self.warning(0)
339        learner_exceptions = []
340        for l in [self.learners[id] for id in ids]:
341            learner = l.learner
342            if self.preprocessor:
343                learner = self.preprocessor.wrapLearner(learner)
344            try:
345                predictor = learner(new)
[11562]346            except Exception, ex:
347                learner_exceptions.append((l, ex))
348                l.scores = []
349                l.results = None
350            else:
[9599]351                if (multilabel and isinstance(learner, Orange.multilabel.MultiLabelLearner)) or predictor(new[0]).varType == new.domain.classVar.varType:
[9458]352                    learners.append(learner)
[9528]353                    used_ids.append(l.id)
[8042]354                else:
355                    l.scores = []
[9528]356                    l.results = None
[9599]357
[9458]358        if learner_exceptions:
359            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \
360                             for l, ex in learner_exceptions)
361            self.warning(0, text)
362           
363        if not learners:
364            return
365
366        # computation of results (res, and cm if classification)
367        pb = None
368        if self.resampling==0:
369            pb = OWGUI.ProgressBar(self, iterations=self.nFolds)
[9482]370            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds,
371                                           strat=orange.MakeRandomIndices.StratifiedIfPossible,
372                                           callback=pb.advance, storeExamples = True)
[9458]373            pb.finish()
374        elif self.resampling==1:
375            pb = OWGUI.ProgressBar(self, iterations=len(self.data))
376            res = orngTest.leaveOneOut(learners, self.data,
377                                       callback=pb.advance, storeExamples = True)
378            pb.finish()
379        elif self.resampling==2:
380            pb = OWGUI.ProgressBar(self, iterations=self.pRepeat)
381            res = orngTest.proportionTest(learners, self.data, self.pLearning/100.,
382                                          times=self.pRepeat, callback=pb.advance, storeExamples = True)
383            pb.finish()
384        elif self.resampling==3:
385            pb = OWGUI.ProgressBar(self, iterations=len(learners))
386            res = orngTest.learnAndTestOnLearnData(learners, self.data, storeExamples = True, callback=pb.advance)
387            pb.finish()
388           
389        elif self.resampling==4:
390            if not self.testdata:
391                for l in self.learners.values():
392                    l.scores = []
393                return
394            pb = OWGUI.ProgressBar(self, iterations=len(learners))
395            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata, storeExamples = True, callback=pb.advance)
396            pb.finish()
[9483]397           
[9599]398        if not self.ismultilabel() and self.isclassification():
[9458]399            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass)
[9483]400        else:
401            cm = None
[9458]402
403        if self.preprocessor: # Unwrap learners
404            learners = [l.wrappedLearner for l in learners]
405           
406        res.learners = learners
407       
408        for l in [self.learners[id] for id in ids]:
409            if l.learner in learners:
410                l.results = res
[9528]411            else:
412                l.results = None
[9458]413
414        self.error(range(len(self.stat)))
415        scores = []
[9483]416       
[9458]417        for i, s in enumerate(self.stat):
[9483]418            if s.cmBased:
419                try:
420                    scores.append(dispatch(s, res, cm))
421                except Exception, ex:
422                    self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
423                               (" ".join([l.name for l in learners]), ex.message))
424                    scores.append([None] * len(self.learners))
425            else:
426                scores_one = []
427                for res_one in orngStat.split_by_classifiers(res):
428                    try:
429                        scores_one.extend(dispatch(s, res_one, cm))
430                    except Exception, ex:
431                        self.error(i, "An error occurred while evaluating orngStat." + s.f + "on %s due to %s" % \
432                                   (res.classifierNames[0], ex.message))
433                        scores_one.append(None)
434                scores.append(scores_one)
[9599]435
[9528]436        for i, (id, l) in enumerate(zip(used_ids, learners)):
437            self.learners[id].scores = [s[i] if s else None for s in scores]
[9458]438           
439        self.sendResults()
440
441    def recomputeCM(self):
442        if not self.results:
443            return
444        cm = orngStat.computeConfusionMatrices(self.results, classIndex = self.targetClass)
445        scores = [(indx, eval("orngStat." + s.f))
446                  for (indx, s) in enumerate(self.stat) if s.cmBased]
447        for (indx, score) in scores:
448            for (i, l) in enumerate([l for l in self.learners.values() if l.scores]):
[9528]449                learner_indx = self.results.learners.index(l.learner)
450                l.scores[indx] = score[learner_indx]
[9599]451
[9458]452        self.paintscores()
453       
[9528]454    def clearScores(self, ids=None):
[11561]455        """
456        Clear/invalidate evaluation results/scores for learners for an
457        `ids` sequence (keys into self.learners). If `ids` is None then
458        invalidate data for all learners.
459
460        """
[9528]461        if ids is None:
462            ids = self.learners.keys()
[9599]463
[9528]464        for id in ids:
465            self.learners[id].scores = []
466            self.learners[id].results = None
[9599]467
[9458]468    # handle input signals
469    def setData(self, data):
[11561]470        """
471        Set the input train data set.
472        """
[9458]473        self.closeContext()
[11561]474
475        self.clearScores()
476
477        multilabel = self.ismultilabel(data)
[9599]478        if not multilabel:
[11561]479            if self.isDataWithClass(data, checkMissing=True):
480                self.data = orange.Filter_hasClassValue(data)
481            else:
482                self.data = None
[9599]483        else:
484            self.data = data
[11561]485
[9458]486        self.fillClassCombo()
[9599]487
[11561]488        if self.data is not None:
489            # Ensure the correct statistics selection is visible.
490            if self.ismultilabel():
491                statwidget = self.mbox
492                stat = self.mStatistics
493            elif self.isclassification():
494                statwidget = self.cbox
495                stat = self.cStatistics
496            else:
497                statwidget = self.rbox
498                stat = self.rStatistics
[9599]499
[11561]500            self.statLayout.setCurrentWidget(statwidget)
[9599]501
[11561]502            self.stat = stat
[9458]503
504    def setTestData(self, data):
[11561]505        """
506        Set the 'Separate Test Data' input.
507        """
[9458]508        if data is None:
509            self.testdata = None
510        else:
511            self.testdata = orange.Filter_hasClassValue(data)
512        self.testDataBtn.setEnabled(self.testdata is not None)
513        if self.testdata is not None:
514            if self.resampling == 4:
[11561]515                self.clearScores()
516
[9458]517        elif self.resampling == 4 and self.data:
518            # test data removed, switch to testing on train data
519            self.resampling = 3
[11561]520            self.clearScores()
[9458]521
522    def fillClassCombo(self):
523        """upon arrival of new data appropriately set the target class combo"""
524        self.targetCombo.clear()
525        if not self.data or not self.data.domain.classVar or not self.isclassification():
526            return
527
528        domain = self.data.domain
529        self.targetCombo.addItems([str(v) for v in domain.classVar.values])
530       
531        if self.targetClass<len(domain.classVar.values):
[8042]532            self.targetCombo.setCurrentIndex(self.targetClass)
[9458]533        else:
534            self.targetCombo.setCurrentIndex(0)
535            self.targetClass=0
536
537    def setLearner(self, learner, id=None):
[11561]538        """
539        Set the input learner with `id`.
540        """
541        if learner is not None:
542            # a new or updated learner
543            if id in self.learners:
[9458]544                time = self.learners[id].time
545                self.learners[id] = Learner(learner, id)
546                self.learners[id].time = time
[11561]547                self.clearScores([id])
548            else:
[9458]549                self.learners[id] = Learner(learner, id)
[11561]550        else:
551            # remove a learner and corresponding results
[9458]552            if id in self.learners:
553                res = self.learners[id].results
554                if res and res.numberOfLearners > 1:
[11561]555                    # Remove the learner from the shared results instance
[9528]556                    old_learner = self.learners[id].learner
557                    indx = res.learners.index(old_learner)
[9458]558                    res.remove(indx)
559                    del res.learners[indx]
560                del self.learners[id]
[11561]561
[9458]562    def setPreprocessor(self, pp):
563        self.preprocessor = pp
[11561]564        self.clearScores()
565
566    def handleNewSignals(self):
567        """
568        Update the evaluations/scores after new inputs were received.
569        """
570        def needsupdate(learner_id):
571            return not (self.learners[learner_id].scores or
572                        self.learners[learner_id].results)
573
574        if self.applyOnAnyChange:
575            self.score(filter(needsupdate, self.learners))
576            self.applyBtn.setEnabled(False)
577        else:
578            self.applyBtn.setEnabled(True)
579
580        self.paintscores()
581
582        if self.data is not None:
583            self.openContext("", self.data)
584        else:
585            self.send("Evaluation Results", None)
[9458]586
587    # handle output signals
588
589    def sendResults(self):
590        """commit evaluation results"""
[11679]591        learners = sorted(self.learners.values(),
592                          key=lambda learner: learner.time)
[9458]593
[11679]594        learners = [learner for learner in learners
595                    if learner.results and learner.scores]
596
597        if not (self.data and len(learners)):
[9458]598            self.send("Evaluation Results", None)
599            return
600
[11679]601        # Split combined results by learner/classifier
602        results_by_learner = {}
603        for learner in learners:
604            results = learner.results
605            res_split = orngStat.split_by_classifiers(results)
606            index = results.learners.index(learner.learner)
607            res_single = res_split[index]
608            res_single.learners = [learner.learner]
609            results_by_learner[learner] = res_single
[9599]610
[11679]611        def add_results(rhs, lhs):
612            assert(lhs.number_of_learners == 1)
613            rhs.add(lhs, 0)
614            rhs.learners.extend(lhs.learners)
615            return rhs
616
617        results = [results_by_learner[learner] for learner in learners]
618
619        results = reduce(add_results, results)
620
621        self.results = results
[9458]622        self.send("Evaluation Results", results)
[11679]623        return
[9458]624
625    # signal processing
626
627    def newsampling(self):
628        """handle change of evaluation method"""
629        if not self.applyOnAnyChange:
630            self.applyBtn.setDisabled(self.applyOnAnyChange)
631        else:
[11561]632            self.clearScores()
[9458]633            if self.learners:
634                self.recompute()
635
636    def newscoreselection(self):
637        """handle change in set of scores to be displayed"""
[9599]638        usestat = self.get_usestat()
[9458]639        for i in range(len(self.stat)):
640            if i in usestat:
641                self.tab.showColumn(i+1)
642                self.tab.resizeColumnToContents(i+1)
643            else:
644                self.tab.hideColumn(i+1)
[8042]645
[9458]646    def recompute(self, forced=False):
647        """recompute the scores for all learners,
648           if not forced, will do nothing but enable the Apply button"""
649        if self.applyOnAnyChange or forced:
650            self.score([l.id for l in self.learners.values()])
651            self.paintscores()
652            self.applyBtn.setDisabled(True)
653        else:
654            self.applyBtn.setEnabled(True)
655
656    def conditionalRecompute(self, option):
657        """calls recompute only if specific sampling option enabled"""
658        if self.resampling == option:
659            self.recompute(False)
660
661    def applyChange(self):
[11561]662        """
663        A change of 'Apply on any change' check box.
664        """
665        def needsupdate(learner_id):
666            return not (self.learners[learner_id].scores or
667                        self.learners[learner_id].results)
668
[9458]669        if self.applyOnAnyChange:
670            self.applyBtn.setDisabled(True)
[11561]671            pending = filter(needsupdate, self.learners)
672            if pending:
673                self.score(pending)
674                self.paintscores()
675
[9458]676    def changedTarget(self):
677        self.recomputeCM()
678
679##############################################################################
680# Test the widget, run from DOS prompt
681
682if __name__=="__main__":
683    a=QApplication(sys.argv)
684    ow=OWTestLearners()
685    ow.show()
686
[11679]687    data1 = orange.ExampleTable('voting')
688#     data2 = orange.ExampleTable('golf')
689#     datar = orange.ExampleTable('auto-mpg')
690#     data3 = orange.ExampleTable('sailing-big')
691#     data4 = orange.ExampleTable('sailing-test')
[9599]692    data5 = orange.ExampleTable('emotions')
[9458]693
694    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'
695
696    l2 = orange.BayesLearner()
697    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
698    l2.conditionalEstimatorConstructor = \
699        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
700        estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
701    l2.name = '2 - NBC (m=10)'
702
703    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'
704
705    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"
[8042]706
[9458]707    import orngRegression as r
708    r5 = r.LinearRegressionLearner(name="0 - lin reg")
709
[9599]710    l5 = Orange.multilabel.BinaryRelevanceLearner()
711
[11679]712    testcase = 0
[9458]713
714    if testcase == 0: # 1(UPD), 3, 4
[11679]715        ow.setData(data1)
[9458]716        ow.setLearner(r5, 5)
717        ow.setLearner(l1, 1)
718        ow.setLearner(l2, 2)
719        ow.setLearner(l3, 3)
[11679]720        ow.handleNewSignals()
721
[9458]722        l1.name = l1.name + " UPD"
723        ow.setLearner(l1, 1)
724        ow.setLearner(None, 2)
725        ow.setLearner(l4, 4)
[11679]726        ow.handleNewSignals()
727
[9458]728    if testcase == 1: # data, but all learners removed
729        ow.setLearner(l1, 1)
730        ow.setLearner(l2, 2)
731        ow.setLearner(l1, 1)
732        ow.setLearner(None, 2)
733        ow.setData(data2)
734        ow.setLearner(None, 1)
735    if testcase == 2: # sends data, then learner, then removes the learner
736        ow.setData(data2)
737        ow.setLearner(l1, 1)
738        ow.setLearner(None, 1)
739    if testcase == 3: # regression first
740        ow.setData(datar)
741        ow.setLearner(r5, 5)
742    if testcase == 4: # separate train and test data
743        ow.setData(data3)
744        ow.setTestData(data4)
745        ow.setLearner(l2, 5)
746        ow.setTestData(None)
[9599]747    if testcase == 5: # MLC
748        ow.setData(data5)
749        ow.setLearner(l5, 6)
[8042]750
[11679]751    a.exec_()
[9458]752    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.