Ignore:
Timestamp:
01/27/12 23:39:05 (2 years ago)
Author:
Matija Polajnar <matija.polajnar@…>
Branch:
default
Message:

Multi-label classificaiton widgets. Merged in from Wencan Luo's work with some modifications.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Evaluate/OWTestLearners.py

    r9528 r9599  
    1717                        orange.AttributeWarning) 
    1818 
     19import Orange 
     20 
    1921############################################################################## 
    2022 
     
    4951    callbackDeposit = [] 
    5052 
     53    # Classification 
    5154    cStatistics = [Score(*s) for s in [\ 
    5255        ('Classification accuracy', 'CA', 'CA(res)', True), 
     
    6164        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]] 
    6265 
     66    # Regression 
    6367    rStatistics = [Score(*s) for s in [\ 
    6468        ("Mean squared error", "MSE", "MSE(res)", False), 
     
    6973        ("Relative absolute error", "RAE", "RAE(res)", False), 
    7074        ("R-squared", "R2", "R2(res)")]] 
    71  
     75     
     76    # Multi-Label 
     77    mStatistics = [Score(*s) for s in [\ 
     78        ('Hamming Loss', 'HammingLoss', 'mlc_hamming_loss(res)', False), 
     79        ('Accuracy', 'Accuracy', 'mlc_accuracy(res)', False), 
     80        ('Precision', 'Precision', 'mlc_precision(res)', False), 
     81        ('Recall', 'Recall', 'mlc_recall(res)', False),                                
     82        ]] 
     83     
    7284    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling", 
    7385                         "Test on train data", "Test on test data"] 
     
    7688        OWWidget.__init__(self, parent, signalManager, "TestLearners") 
    7789 
    78         self.inputs = [("Data", ExampleTable, self.setData, Default),  
     90        self.inputs = [("Data", ExampleTable, self.setData, Default), 
    7991                       ("Separate Test Data", ExampleTable, self.setTestData), 
    8092                       ("Learner", orange.Learner, self.setLearner, Multiple + Default), 
    8193                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)] 
    82          
     94 
    8395        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)] 
    8496 
     
    92104        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show] 
    93105        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show] 
     106        self.selectedMScores = [i for (i,s) in enumerate(self.mStatistics) if s.show] 
    94107        self.targetClass = 0 
    95108        self.loadSettings() 
     
    174187                                     selectionMode = QListWidget.MultiSelection, 
    175188                                     callback=self.newscoreselection) 
     189 
     190        self.mStatLabels = [s.name for s in self.mStatistics] 
     191        self.mbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False) 
     192        self.mstatLB = OWGUI.listBox(self.mbox, self, 'selectedMScores', 'mStatLabels', 
     193                                     selectionMode = QListWidget.MultiSelection, 
     194                                     callback=self.newscoreselection) 
    176195         
    177196        self.statLayout.addWidget(self.cbox) 
    178197        self.statLayout.addWidget(self.rbox) 
     198        self.statLayout.addWidget(self.mbox) 
    179199        self.controlArea.layout().addLayout(self.statLayout) 
    180200         
     
    193213            return True 
    194214        return self.data.domain.classVar.varType == orange.VarTypes.Discrete 
    195          
     215 
     216    def ismultilabel(self, data = 42): 
     217        if data==42: 
     218            data = self.data 
     219        if not data: 
     220            return False 
     221        return Orange.multilabel.is_multilabel(data) 
     222 
     223    def get_usestat(self): 
     224        return ([self.selectedRScores, self.selectedCScores, self.selectedMScores] 
     225                [2 if self.ismultilabel() else self.isclassification()]) 
     226 
    196227    def paintscores(self): 
    197228        """paints the table with evaluation scores""" 
     
    224255        self.tab.resizeColumnsToContents() 
    225256        self.tab.resizeRowsToContents() 
    226         usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
     257        usestat = self.get_usestat() 
    227258        for i in range(len(self.stat)): 
    228259            if i not in usestat: 
     
    237268        else: 
    238269            exset = [] 
    239         self.reportSettings("Validation method", 
     270        if not self.ismultilabel(): 
     271            self.reportSettings("Validation method", 
    240272                            [("Method", self.resamplingMethods[self.resampling])] 
    241273                            + exset + 
    242274                            ([("Target class", self.data.domain.classVar.values[self.targetClass])] if self.data else [])) 
     275        else: 
     276             self.reportSettings("Validation method", 
     277                            [("Method", self.resamplingMethods[self.resampling])] 
     278                            + exset) 
    243279         
    244280        self.reportData(self.data) 
     
    249285            learners.sort() 
    250286            learners = [lt[1] for lt in learners] 
    251             usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
    252              
     287            usestat = self.get_usestat() 
    253288            res = "<table><tr><th></th>"+"".join("<th><b>%s</b></th>" % hr for hr in [s.label for i, s in enumerate(self.stat) if i in usestat])+"</tr>" 
    254289            for i, l in enumerate(learners): 
     
    277312        new = self.data.selectref(indices(self.data)) 
    278313         
     314        multilabel = self.ismultilabel() 
     315         
    279316        self.warning(0) 
    280317        learner_exceptions = [] 
     
    285322            try: 
    286323                predictor = learner(new) 
    287                 if predictor(new[0]).varType == new.domain.classVar.varType: 
     324                if (multilabel and isinstance(learner, Orange.multilabel.MultiLabelLearner)) or predictor(new[0]).varType == new.domain.classVar.varType: 
    288325                    learners.append(learner) 
    289326                    used_ids.append(l.id) 
     
    291328                    l.scores = [] 
    292329                    l.results = None 
    293                  
     330 
    294331            except Exception, ex: 
    295332                learner_exceptions.append((l, ex)) 
    296333                l.scores = [] 
    297334                l.results = None 
    298          
     335 
    299336        if learner_exceptions: 
    300337            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \ 
     
    337374            pb.finish() 
    338375             
    339         if self.isclassification(): 
     376        if not self.ismultilabel() and self.isclassification(): 
    340377            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass) 
    341378        else: 
     
    374411                        scores_one.append(None) 
    375412                scores.append(scores_one) 
    376                  
     413 
    377414        for i, (id, l) in enumerate(zip(used_ids, learners)): 
    378415            self.learners[id].scores = [s[i] if s else None for s in scores] 
     
    390427                learner_indx = self.results.learners.index(l.learner) 
    391428                l.scores[indx] = score[learner_indx] 
    392                  
     429 
    393430        self.paintscores() 
    394431         
     
    396433        if ids is None: 
    397434            ids = self.learners.keys() 
    398              
     435 
    399436        for id in ids: 
    400437            self.learners[id].scores = [] 
    401438            self.learners[id].results = None 
    402          
     439 
    403440    # handle input signals 
    404441    def setData(self, data): 
    405442        """handle input train data set""" 
    406443        self.closeContext() 
    407         self.data = self.isDataWithClass(data, checkMissing=True) and data or None 
     444         
     445        multilabel= self.ismultilabel(data) 
     446        if not multilabel: 
     447            self.data = self.isDataWithClass(data, checkMissing=True) and data or None 
     448        else: 
     449            self.data = data 
     450         
    408451        self.fillClassCombo() 
    409452        if not self.data: 
     
    414457            # new data has arrived 
    415458            self.clearScores() 
    416              
    417             self.data = orange.Filter_hasClassValue(self.data) 
    418             self.statLayout.setCurrentWidget(self.cbox if self.isclassification() else self.rbox) 
    419              
    420             self.stat = [self.rStatistics, self.cStatistics][self.isclassification()] 
    421              
     459 
     460            if not multilabel: 
     461                self.data = orange.Filter_hasClassValue(self.data) 
     462 
     463            self.statLayout.setCurrentWidget([self.rbox, self.cbox, self.mbox][2 if self.ismultilabel() else self.isclassification()]) 
     464 
     465            self.stat = [self.rStatistics, self.cStatistics, self.mStatistics][2 if self.ismultilabel() else self.isclassification()] 
     466 
    422467            if self.learners: 
    423468                self.score([l.id for l in self.learners.values()]) 
    424  
     469             
    425470        self.openContext("", data) 
    426471        self.paintscores() 
     
    516561                results.remove(i) 
    517562                del results.learners[i] 
    518              
     563 
    519564        for r in rlist: 
    520565            for (i, l) in enumerate(r.learners): 
     
    539584    def newscoreselection(self): 
    540585        """handle change in set of scores to be displayed""" 
    541         usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
     586        usestat = self.get_usestat() 
    542587        for i in range(len(self.stat)): 
    543588            if i in usestat: 
     
    583628    data3 = orange.ExampleTable(r'../../sailing-big') 
    584629    data4 = orange.ExampleTable(r'../../sailing-test') 
     630    data5 = orange.ExampleTable('emotions') 
    585631 
    586632    l1 = orange.MajorityLearner(); l1.name = '1 - Majority' 
     
    599645    import orngRegression as r 
    600646    r5 = r.LinearRegressionLearner(name="0 - lin reg") 
     647 
     648    l5 = Orange.multilabel.BinaryRelevanceLearner() 
    601649 
    602650    testcase = 4 
     
    634682        ow.setLearner(l2, 5) 
    635683        ow.setTestData(None) 
     684    if testcase == 5: # MLC 
     685        ow.setData(data5) 
     686        ow.setLearner(l5, 6) 
    636687 
    637688    ow.saveSettings() 
Note: See TracChangeset for help on using the changeset viewer.