Changeset 9599:69c91d52d3e4 in orange


Ignore:
Timestamp:
01/27/12 23:39:05 (2 years ago)
Author:
Matija Polajnar <matija.polajnar@…>
Branch:
default
Message:

Multi-label classificaiton widgets. Merged in from Wencan Luo's work with some modifications.

Location:
orange
Files:
4 added
8 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/evaluation/scoring.py

    r9550 r9599  
    489489        return [res] 
    490490         
    491     ress = [Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded) 
     491    ress = [Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded, test_type=res.test_type, labels=res.labels) 
    492492            for i in range(res.number_of_iterations)] 
    493493    for te in res.results: 
     
    504504                    [res.classifierNames[i]], res.classValues, 
    505505                    weights=res.weights, baseClass=res.baseClass, 
    506                     classifiers=[res.classifiers[i]] if res.classifiers else []) 
     506                    classifiers=[res.classifiers[i]] if res.classifiers else [], 
     507                    test_type = res.test_type, labels = res.labels) 
    507508        r.results = [] 
    508509        for te in res.results: 
     
    15971598    import corn 
    15981599    ## merge multiple iterations into one 
    1599     mres = Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded) 
     1600    mres = Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded, test_type=res.test_type, labels=res.labels) 
    16001601    for te in res.results: 
    16011602        mres.results.append( te ) 
     
    16591660    import corn 
    16601661    ## merge multiple iterations into one 
    1661     mres = Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded) 
     1662    mres = Orange.evaluation.testing.ExperimentResults(1, res.classifier_names, res.class_values, res.weights, classifiers=res.classifiers, loaded=res.loaded, test_type=res.test_type, labels=res.labels) 
    16621663    for te in res.results: 
    16631664        mres.results.append( te ) 
     
    26032604    label_num = len(res.labels) 
    26042605    example_num = gettotsize(res) 
    2605      
     2606 
    26062607    for e in res.results: 
    26072608        aclass = e.actual_class 
  • orange/Orange/evaluation/testing.py

    r9555 r9599  
    9393                          "numberOfIterations": "number_of_iterations", 
    9494                          "numberOfLearners": "number_of_learners"}) 
    95     def __init__(self, iterations, classifier_names, class_values=None, weights=None, base_class=-1, domain=None, test_type=TEST_TYPE_SINGLE, **argkw): 
     95    def __init__(self, iterations, classifier_names, class_values=None, weights=None, base_class=-1, domain=None, test_type=TEST_TYPE_SINGLE, labels=None, **argkw): 
    9696        self.class_values = class_values 
    9797        self.classifier_names = classifier_names 
     
    104104        self.weights = weights 
    105105        self.test_type = test_type 
     106        self.labels = labels 
    106107 
    107108        if domain is not None: 
  • orange/OrangeWidgets/Data/OWFile.py

    r9546 r9599  
    284284        warnings = "" 
    285285        metas = data.domain.getmetas() 
    286         for status, messageUsed, messageNotUsed in [ 
    287                                 (orange.Variable.MakeStatus.Incompatible, 
    288                                  "", 
    289                                  "The following attributes already existed but had a different order of values, so new attributes needed to be created"), 
    290                                 (orange.Variable.MakeStatus.NoRecognizedValues, 
    291                                  "The following attributes were reused although they share no common values with the existing attribute of the same names", 
    292                                  "The following attributes were not reused since they share no common values with the existing attribute of the same names"), 
    293                                 (orange.Variable.MakeStatus.MissingValues, 
    294                                  "The following attribute(s) were reused although some values needed to be added", 
    295                                  "The following attribute(s) were not reused since they miss some values") 
    296                                 ]: 
    297             if self.createNewOn > status: 
    298                 message = messageUsed 
    299             else: 
    300                 message = messageNotUsed 
    301             if not message: 
    302                 continue 
    303             attrs = [attr.name for attr, stat in zip(data.domain, data.attributeLoadStatus) if stat == status] \ 
    304                   + [attr.name for id, attr in metas.items() if data.metaAttributeLoadStatus.get(id, -99) == status] 
    305             if attrs: 
    306                 jattrs = ", ".join(attrs) 
    307                 if len(jattrs) > 80: 
    308                     jattrs = jattrs[:80] + "..." 
    309                 if len(jattrs) > 30:  
    310                     warnings += "<li>%s:<br/> %s</li>" % (message, jattrs) 
     286        if hasattr(data, "attribute_load_status"):  # For some file formats, this is not populated 
     287            for status, messageUsed, messageNotUsed in [ 
     288                                    (orange.Variable.MakeStatus.Incompatible, 
     289                                     "", 
     290                                     "The following attributes already existed but had a different order of values, so new attributes needed to be created"), 
     291                                    (orange.Variable.MakeStatus.NoRecognizedValues, 
     292                                     "The following attributes were reused although they share no common values with the existing attribute of the same names", 
     293                                     "The following attributes were not reused since they share no common values with the existing attribute of the same names"), 
     294                                    (orange.Variable.MakeStatus.MissingValues, 
     295                                     "The following attribute(s) were reused although some values needed to be added", 
     296                                     "The following attribute(s) were not reused since they miss some values") 
     297                                    ]: 
     298                if self.createNewOn > status: 
     299                    message = messageUsed 
    311300                else: 
    312                     warnings += "<li>%s: %s</li>" % (message, jattrs) 
     301                    message = messageNotUsed 
     302                if not message: 
     303                    continue 
     304                attrs = [attr.name for attr, stat in zip(data.domain, data.attributeLoadStatus) if stat == status] \ 
     305                      + [attr.name for id, attr in metas.items() if data.metaAttributeLoadStatus.get(id, -99) == status] 
     306                if attrs: 
     307                    jattrs = ", ".join(attrs) 
     308                    if len(jattrs) > 80: 
     309                        jattrs = jattrs[:80] + "..." 
     310                    if len(jattrs) > 30:  
     311                        warnings += "<li>%s:<br/> %s</li>" % (message, jattrs) 
     312                    else: 
     313                        warnings += "<li>%s: %s</li>" % (message, jattrs) 
    313314 
    314315        self.warnings.setText(warnings) 
  • orange/OrangeWidgets/Evaluate/OWCalibrationPlot.py

    r9505 r9599  
    1313import orngTest, orngStat 
    1414import statc, math 
     15from Orange.evaluation.testing import TEST_TYPE_SINGLE 
    1516 
    1617class singleClassCalibrationPlotGraph(OWGraph): 
     
    333334 
    334335        self.dres = dres 
     336         
     337        if dres and dres.test_type != TEST_TYPE_SINGLE: 
     338            self.warning(0, "Calibration plot is supported only for single-target prediction problems.") 
     339            return 
     340        self.warning(0, None) 
    335341 
    336342        self.graphs = [] 
  • orange/OrangeWidgets/Evaluate/OWConfusionMatrix.py

    r9546 r9599  
    1111import statc, math 
    1212from operator import add 
     13from Orange.evaluation.testing import TEST_TYPE_SINGLE 
    1314             
    1415class TransformedLabel(QLabel): 
     
    110111            return 
    111112 
     113        if res and res.test_type != TEST_TYPE_SINGLE: 
     114            self.warning(0, "Confusion matrix can be calculated only for single-target prediction problems.") 
     115            return 
     116        self.warning(0, None) 
     117         
    112118        self.matrix = orngStat.confusionMatrices(res, -2) 
    113119 
  • orange/OrangeWidgets/Evaluate/OWLiftCurve.py

    r9505 r9599  
    1414import orngStat, orngEval 
    1515import statc, math 
     16from Orange.evaluation.testing import TEST_TYPE_SINGLE 
    1617 
    1718class singleClassLiftCurveGraph(singleClassROCgraph): 
     
    347348            return 
    348349 
     350        if dres and dres.test_type != TEST_TYPE_SINGLE: 
     351            self.warning(0, "Lift curve is supported only for single-target prediction problems.") 
     352            return 
     353        self.warning(0, None) 
     354 
    349355        self.defaultPerfLinePValues = [] 
    350356        if self.dres <> None: 
  • orange/OrangeWidgets/Evaluate/OWROC.py

    r9505 r9599  
    1212import orngStat, orngTest 
    1313import statc, math 
     14import warnings 
     15from Orange.evaluation.testing import TEST_TYPE_SINGLE 
    1416 
    1517def TCconvexHull(curves): 
     
    973975            self.openContext("", dres) 
    974976            return 
     977 
     978        if dres and dres.test_type != TEST_TYPE_SINGLE: 
     979            self.warning(0, "ROC is implemented only for single-target prediction problems.") 
     980            return 
     981        self.warning(0, None) 
     982 
    975983        self.dres = dres 
    976984 
  • orange/OrangeWidgets/Evaluate/OWTestLearners.py

    r9528 r9599  
    1717                        orange.AttributeWarning) 
    1818 
     19import Orange 
     20 
    1921############################################################################## 
    2022 
     
    4951    callbackDeposit = [] 
    5052 
     53    # Classification 
    5154    cStatistics = [Score(*s) for s in [\ 
    5255        ('Classification accuracy', 'CA', 'CA(res)', True), 
     
    6164        ('Matthews correlation coefficient', 'MCC', 'MCC(cm)', False, True)]] 
    6265 
     66    # Regression 
    6367    rStatistics = [Score(*s) for s in [\ 
    6468        ("Mean squared error", "MSE", "MSE(res)", False), 
     
    6973        ("Relative absolute error", "RAE", "RAE(res)", False), 
    7074        ("R-squared", "R2", "R2(res)")]] 
    71  
     75     
     76    # Multi-Label 
     77    mStatistics = [Score(*s) for s in [\ 
     78        ('Hamming Loss', 'HammingLoss', 'mlc_hamming_loss(res)', False), 
     79        ('Accuracy', 'Accuracy', 'mlc_accuracy(res)', False), 
     80        ('Precision', 'Precision', 'mlc_precision(res)', False), 
     81        ('Recall', 'Recall', 'mlc_recall(res)', False),                                
     82        ]] 
     83     
    7284    resamplingMethods = ["Cross-validation", "Leave-one-out", "Random sampling", 
    7385                         "Test on train data", "Test on test data"] 
     
    7688        OWWidget.__init__(self, parent, signalManager, "TestLearners") 
    7789 
    78         self.inputs = [("Data", ExampleTable, self.setData, Default),  
     90        self.inputs = [("Data", ExampleTable, self.setData, Default), 
    7991                       ("Separate Test Data", ExampleTable, self.setTestData), 
    8092                       ("Learner", orange.Learner, self.setLearner, Multiple + Default), 
    8193                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)] 
    82          
     94 
    8395        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)] 
    8496 
     
    92104        self.selectedCScores = [i for (i,s) in enumerate(self.cStatistics) if s.show] 
    93105        self.selectedRScores = [i for (i,s) in enumerate(self.rStatistics) if s.show] 
     106        self.selectedMScores = [i for (i,s) in enumerate(self.mStatistics) if s.show] 
    94107        self.targetClass = 0 
    95108        self.loadSettings() 
     
    174187                                     selectionMode = QListWidget.MultiSelection, 
    175188                                     callback=self.newscoreselection) 
     189 
     190        self.mStatLabels = [s.name for s in self.mStatistics] 
     191        self.mbox = OWGUI.widgetBox(self.controlArea, "Performance scores", addToLayout=False) 
     192        self.mstatLB = OWGUI.listBox(self.mbox, self, 'selectedMScores', 'mStatLabels', 
     193                                     selectionMode = QListWidget.MultiSelection, 
     194                                     callback=self.newscoreselection) 
    176195         
    177196        self.statLayout.addWidget(self.cbox) 
    178197        self.statLayout.addWidget(self.rbox) 
     198        self.statLayout.addWidget(self.mbox) 
    179199        self.controlArea.layout().addLayout(self.statLayout) 
    180200         
     
    193213            return True 
    194214        return self.data.domain.classVar.varType == orange.VarTypes.Discrete 
    195          
     215 
     216    def ismultilabel(self, data = 42): 
     217        if data==42: 
     218            data = self.data 
     219        if not data: 
     220            return False 
     221        return Orange.multilabel.is_multilabel(data) 
     222 
     223    def get_usestat(self): 
     224        return ([self.selectedRScores, self.selectedCScores, self.selectedMScores] 
     225                [2 if self.ismultilabel() else self.isclassification()]) 
     226 
    196227    def paintscores(self): 
    197228        """paints the table with evaluation scores""" 
     
    224255        self.tab.resizeColumnsToContents() 
    225256        self.tab.resizeRowsToContents() 
    226         usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
     257        usestat = self.get_usestat() 
    227258        for i in range(len(self.stat)): 
    228259            if i not in usestat: 
     
    237268        else: 
    238269            exset = [] 
    239         self.reportSettings("Validation method", 
     270        if not self.ismultilabel(): 
     271            self.reportSettings("Validation method", 
    240272                            [("Method", self.resamplingMethods[self.resampling])] 
    241273                            + exset + 
    242274                            ([("Target class", self.data.domain.classVar.values[self.targetClass])] if self.data else [])) 
     275        else: 
     276             self.reportSettings("Validation method", 
     277                            [("Method", self.resamplingMethods[self.resampling])] 
     278                            + exset) 
    243279         
    244280        self.reportData(self.data) 
     
    249285            learners.sort() 
    250286            learners = [lt[1] for lt in learners] 
    251             usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
    252              
     287            usestat = self.get_usestat() 
    253288            res = "<table><tr><th></th>"+"".join("<th><b>%s</b></th>" % hr for hr in [s.label for i, s in enumerate(self.stat) if i in usestat])+"</tr>" 
    254289            for i, l in enumerate(learners): 
     
    277312        new = self.data.selectref(indices(self.data)) 
    278313         
     314        multilabel = self.ismultilabel() 
     315         
    279316        self.warning(0) 
    280317        learner_exceptions = [] 
     
    285322            try: 
    286323                predictor = learner(new) 
    287                 if predictor(new[0]).varType == new.domain.classVar.varType: 
     324                if (multilabel and isinstance(learner, Orange.multilabel.MultiLabelLearner)) or predictor(new[0]).varType == new.domain.classVar.varType: 
    288325                    learners.append(learner) 
    289326                    used_ids.append(l.id) 
     
    291328                    l.scores = [] 
    292329                    l.results = None 
    293                  
     330 
    294331            except Exception, ex: 
    295332                learner_exceptions.append((l, ex)) 
    296333                l.scores = [] 
    297334                l.results = None 
    298          
     335 
    299336        if learner_exceptions: 
    300337            text = "\n".join("Learner %s ends with exception: %s" % (l.name, str(ex)) \ 
     
    337374            pb.finish() 
    338375             
    339         if self.isclassification(): 
     376        if not self.ismultilabel() and self.isclassification(): 
    340377            cm = orngStat.computeConfusionMatrices(res, classIndex = self.targetClass) 
    341378        else: 
     
    374411                        scores_one.append(None) 
    375412                scores.append(scores_one) 
    376                  
     413 
    377414        for i, (id, l) in enumerate(zip(used_ids, learners)): 
    378415            self.learners[id].scores = [s[i] if s else None for s in scores] 
     
    390427                learner_indx = self.results.learners.index(l.learner) 
    391428                l.scores[indx] = score[learner_indx] 
    392                  
     429 
    393430        self.paintscores() 
    394431         
     
    396433        if ids is None: 
    397434            ids = self.learners.keys() 
    398              
     435 
    399436        for id in ids: 
    400437            self.learners[id].scores = [] 
    401438            self.learners[id].results = None 
    402          
     439 
    403440    # handle input signals 
    404441    def setData(self, data): 
    405442        """handle input train data set""" 
    406443        self.closeContext() 
    407         self.data = self.isDataWithClass(data, checkMissing=True) and data or None 
     444         
     445        multilabel= self.ismultilabel(data) 
     446        if not multilabel: 
     447            self.data = self.isDataWithClass(data, checkMissing=True) and data or None 
     448        else: 
     449            self.data = data 
     450         
    408451        self.fillClassCombo() 
    409452        if not self.data: 
     
    414457            # new data has arrived 
    415458            self.clearScores() 
    416              
    417             self.data = orange.Filter_hasClassValue(self.data) 
    418             self.statLayout.setCurrentWidget(self.cbox if self.isclassification() else self.rbox) 
    419              
    420             self.stat = [self.rStatistics, self.cStatistics][self.isclassification()] 
    421              
     459 
     460            if not multilabel: 
     461                self.data = orange.Filter_hasClassValue(self.data) 
     462 
     463            self.statLayout.setCurrentWidget([self.rbox, self.cbox, self.mbox][2 if self.ismultilabel() else self.isclassification()]) 
     464 
     465            self.stat = [self.rStatistics, self.cStatistics, self.mStatistics][2 if self.ismultilabel() else self.isclassification()] 
     466 
    422467            if self.learners: 
    423468                self.score([l.id for l in self.learners.values()]) 
    424  
     469             
    425470        self.openContext("", data) 
    426471        self.paintscores() 
     
    516561                results.remove(i) 
    517562                del results.learners[i] 
    518              
     563 
    519564        for r in rlist: 
    520565            for (i, l) in enumerate(r.learners): 
     
    539584    def newscoreselection(self): 
    540585        """handle change in set of scores to be displayed""" 
    541         usestat = [self.selectedRScores, self.selectedCScores][self.isclassification()] 
     586        usestat = self.get_usestat() 
    542587        for i in range(len(self.stat)): 
    543588            if i in usestat: 
     
    583628    data3 = orange.ExampleTable(r'../../sailing-big') 
    584629    data4 = orange.ExampleTable(r'../../sailing-test') 
     630    data5 = orange.ExampleTable('emotions') 
    585631 
    586632    l1 = orange.MajorityLearner(); l1.name = '1 - Majority' 
     
    599645    import orngRegression as r 
    600646    r5 = r.LinearRegressionLearner(name="0 - lin reg") 
     647 
     648    l5 = Orange.multilabel.BinaryRelevanceLearner() 
    601649 
    602650    testcase = 4 
     
    634682        ow.setLearner(l2, 5) 
    635683        ow.setTestData(None) 
     684    if testcase == 5: # MLC 
     685        ow.setData(data5) 
     686        ow.setLearner(l5, 6) 
    636687 
    637688    ow.saveSettings() 
Note: See TracChangeset for help on using the changeset viewer.