Ignore:
Timestamp:
02/07/12 20:09:31 (2 years ago)
Author:
anze <anze.staric@…>
Branch:
default
Message:

Refactored CA

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/evaluation/scoring.py

    r9999 r10000  
    66from Orange import statc, corn 
    77from Orange.misc import deprecated_keywords 
     8from Orange.evaluation import testing 
    89 
    910#### Private stuff 
     
    126127MAE = ME 
    127128 
     129 
     130class ConfusionMatrix: 
     131    """ 
     132    Classification result summary 
     133 
     134    .. attribute:: TP 
     135 
     136        True Positive predictions 
     137 
     138    .. attribute:: TN 
     139 
     140        True Negative predictions 
     141 
     142    .. attribute:: FP 
     143 
     144        False Positive predictions 
     145 
     146    .. attribute:: FN 
     147 
     148        False Negative predictions 
     149    """ 
     150    def __init__(self): 
     151        self.TP = self.FN = self.FP = self.TN = 0.0 
     152 
     153    @deprecated_keywords({"predictedPositive": "predicted_positive", 
     154                          "isPositive": "is_positive"}) 
     155    def addTFPosNeg(self, predicted_positive, is_positive, weight = 1.0): 
     156        """ 
     157        Update confusion matrix with result of a single classification 
     158 
     159        :param predicted_positive: positive class value was predicted 
     160        :param is_positive: correct class value is positive 
     161        :param weight: weight of the selected instance 
     162         """ 
     163        if predicted_positive: 
     164            if is_positive: 
     165                self.TP += weight 
     166            else: 
     167                self.FP += weight 
     168        else: 
     169            if is_positive: 
     170                self.FN += weight 
     171            else: 
     172                self.TN += weight 
     173 
     174 
    128175######################################################################### 
    129176# PERFORMANCE MEASURES: 
     
    305352# Scores for evaluation of classifiers 
    306353 
    307 @deprecated_keywords({"reportSE": "report_se"}) 
    308 def CA(test_results, report_se = False, **argkw): 
    309     """Return percentage of matches between predicted and actual class. 
    310  
    311     :param test_results: :obj:`~Orange.evaluation.testing.ExperimentResults` 
    312                          or :obj:`ConfusionMatrix`. 
    313     :param report_se: include standard error in result. 
    314     :rtype: list of scores, one for each learner. 
    315  
    316     Standard errors are estimated from deviation of CAs across folds (if 
    317     test_results were produced by cross_validation) or approximated under 
    318     the assumption of normal distribution otherwise. 
    319     """ 
    320     if isinstance(test_results, list) and len(test_results) > 0 \ 
    321                              and isinstance(test_results[0], ConfusionMatrix): 
    322         results = [] 
    323         for cm in test_results: 
    324             div = cm.TP+cm.FN+cm.FP+cm.TN 
    325             check_non_zero(div) 
    326             results.append((cm.TP+cm.TN)/div) 
    327         return results 
    328     elif test_results.number_of_iterations==1: 
     354class CAClass(object): 
     355    CONFUSION_MATRIX = 0 
     356    CONFUSION_MATRIX_LIST = 1 
     357    CLASSIFICATION = 2 
     358    CROSS_VALIDATION = 3 
     359 
     360    @deprecated_keywords({"reportSE": "report_se"}) 
     361    def __call__(self, test_results, report_se = False, unweighted=False): 
     362        """Return percentage of matches between predicted and actual class. 
     363 
     364        :param test_results: :obj:`~Orange.evaluation.testing.ExperimentResults` 
     365                             or :obj:`ConfusionMatrix`. 
     366        :param report_se: include standard error in result. 
     367        :rtype: list of scores, one for each learner. 
     368 
     369        Standard errors are estimated from deviation of CAs across folds (if 
     370        test_results were produced by cross_validation) or approximated under 
     371        the assumption of normal distribution otherwise. 
     372        """ 
     373        input_type = self.get_input_type(test_results) 
     374        if input_type == self.CONFUSION_MATRIX: 
     375            return self.from_confusion_matrix(test_results, report_se) 
     376        elif input_type == self.CONFUSION_MATRIX_LIST: 
     377            return self.from_confusion_matrix_list(test_results, report_se) 
     378        elif input_type == self.CLASSIFICATION: 
     379            return self.from_classification_results( 
     380                                        test_results, report_se, unweighted) 
     381        elif input_type == self.CROSS_VALIDATION: 
     382            return self.from_crossvalidation_results( 
     383                                        test_results, report_se, unweighted) 
     384 
     385    def from_confusion_matrix(self, cm, report_se): 
     386        all_predictions = cm.TP+cm.FN+cm.FP+cm.TN 
     387        check_non_zero(all_predictions) 
     388        ca = (cm.TP+cm.TN)/all_predictions 
     389 
     390        if report_se: 
     391            return ca, ca*(1-ca)/math.sqrt(all_predictions) 
     392        else: 
     393            return ca 
     394 
     395    def from_confusion_matrix_list(self, confusion_matrices, report_se): 
     396        return map(self.from_confusion_matrix, confusion_matrices) # TODO: report_se 
     397 
     398    def from_classification_results(self, test_results, report_se, unweighted): 
    329399        CAs = [0.0]*test_results.number_of_learners 
    330         if argkw.get("unweighted", 0) or not test_results.weights: 
    331             totweight = gettotsize(test_results) 
    332             for tex in test_results.results: 
    333                 CAs = map(lambda res, cls: res+(cls==tex.actual_class), CAs, tex.classes) 
    334         else: 
    335             totweight = 0. 
    336             for tex in test_results.results: 
    337                 CAs = map(lambda res, cls: res+(cls==tex.actual_class and tex.weight), CAs, tex.classes) 
    338                 totweight += tex.weight 
     400        totweight = 0. 
     401        for tex in test_results.results: 
     402            w = 1. if unweighted else tex.weight 
     403            CAs = map(lambda res, cls: res+(cls==tex.actual_class and w), CAs, tex.classes) 
     404            totweight += w 
    339405        check_non_zero(totweight) 
    340406        ca = [x/totweight for x in CAs] 
    341              
     407 
    342408        if report_se: 
    343409            return [(x, x*(1-x)/math.sqrt(totweight)) for x in ca] 
    344410        else: 
    345411            return ca 
    346          
    347     else: 
     412 
     413    def from_crossvalidation_results(self, test_results, report_se, unweighted): 
    348414        CAsByFold = [[0.0]*test_results.number_of_iterations for i in range(test_results.number_of_learners)] 
    349415        foldN = [0.0]*test_results.number_of_iterations 
    350416 
    351         if argkw.get("unweighted", 0) or not test_results.weights: 
    352             for tex in test_results.results: 
    353                 for lrn in range(test_results.number_of_learners): 
    354                     CAsByFold[lrn][tex.iteration_number] += (tex.classes[lrn]==tex.actual_class) 
    355                 foldN[tex.iteration_number] += 1 
    356         else: 
    357             for tex in test_results.results: 
    358                 for lrn in range(test_results.number_of_learners): 
    359                     CAsByFold[lrn][tex.iteration_number] += (tex.classes[lrn]==tex.actual_class) and tex.weight 
    360                 foldN[tex.iteration_number] += tex.weight 
     417        for tex in test_results.results: 
     418            w = 1. if unweighted else tex.weight 
     419            for lrn in range(test_results.number_of_learners): 
     420                CAsByFold[lrn][tex.iteration_number] += (tex.classes[lrn]==tex.actual_class) and w 
     421            foldN[tex.iteration_number] += w 
    361422 
    362423        return statistics_by_folds(CAsByFold, foldN, report_se, False) 
    363424 
    364  
    365 # Obsolete, but kept for compatibility 
    366 def CA_se(res, **argkw): 
    367     return CA(res, True, **argkw) 
     425    def get_input_type(self, test_results): 
     426        if isinstance(test_results, ConfusionMatrix): 
     427            return self.CONFUSION_MATRIX 
     428        elif isinstance(test_results, testing.ExperimentResults): 
     429            if test_results.number_of_iterations == 1: 
     430                return self.CLASSIFICATION 
     431            else: 
     432                return self.CROSS_VALIDATION 
     433        elif isinstance(test_results, list): 
     434            return self.CONFUSION_MATRIX_LIST 
     435 
     436 
     437 
     438CA = CAClass 
    368439 
    369440@deprecated_keywords({"reportSE": "report_se"}) 
     
    554625    else: 
    555626        return apply(Friedman, (res, statistics), argkw) 
    556      
    557 class ConfusionMatrix: 
    558     """ 
    559     Classification result summary 
    560  
    561     .. attribute:: TP 
    562  
    563         True Positive predictions 
    564  
    565     .. attribute:: TN 
    566  
    567         True Negative predictions 
    568  
    569     .. attribute:: FP 
    570  
    571         False Positive predictions 
    572  
    573     .. attribute:: FN 
    574  
    575         False Negative predictions 
    576     """ 
    577     def __init__(self): 
    578         self.TP = self.FN = self.FP = self.TN = 0.0 
    579  
    580     @deprecated_keywords({"predictedPositive": "predicted_positive", 
    581                           "isPositive": "is_positive"}) 
    582     def addTFPosNeg(self, predicted_positive, is_positive, weight = 1.0): 
    583         """ 
    584         Update confusion matrix with result of a single classification 
    585  
    586         :param predicted_positive: positive class value was predicted 
    587         :param is_positive: correct class value is positive 
    588         :param weight: weight of the selected instance 
    589          """ 
    590         if predicted_positive: 
    591             if is_positive: 
    592                 self.TP += weight 
    593             else: 
    594                 self.FP += weight 
    595         else: 
    596             if is_positive: 
    597                 self.FN += weight 
    598             else: 
    599                 self.TN += weight 
     627 
    600628 
    601629@deprecated_keywords({"res": "test_results", 
Note: See TracChangeset for help on using the changeset viewer.