Changeset 10280:44d2bfa6236e in orange


Ignore:
Timestamp:
02/16/12 14:44:46 (2 years ago)
Author:
anzeh <anze.staric@…>
Branch:
default
rebase_source:
9ac845321789a6ea7d6bfb5605ff5204e9574b6c
Message:

Refactored CA

Location:
Orange
Files:
2 edited

Legend:

Unmodified (both old and new line numbers shown)
Added (new line number only)
Removed (old line number only)
  • Orange/evaluation/scoring.py

    r10258 r10280  
    356 356  # Scores for evaluation of classifiers
    357 357
    358     -class CAClass(object):
        358 +class CA(list):
    359 359      """Computation of CA from different types of test_results"""
    360 360      CONFUSION_MATRIX = 0
     
    365 365      @deprecated_keywords({"reportSE": "report_se",
    366 366                            "unweighted": "ignore_weights"})
    367     -    def __call__(self, test_results, report_se = False, ignore_weights=False):
        367 +    def __init__(self, test_results, report_se = False, ignore_weights=False):
    368 368          """Return percentage of matches between predicted and actual class.
    369 369
     
    378 378          the assumption of normal distribution otherwise.
    379 379          """
        380 +        super(CA, self).__init__()
        381 +        self.report_se = report_se
        382 +        self.ignore_weights = ignore_weights
        383 +
    380 384          input_type = self.get_input_type(test_results)
    381 385          if input_type == self.CONFUSION_MATRIX:
    382     -            return self.from_confusion_matrix(test_results, report_se)
        386 +            self[:] =  [self.from_confusion_matrix(test_results)]
    383 387          elif input_type == self.CONFUSION_MATRIX_LIST:
    384     -            return self.from_confusion_matrix_list(test_results, report_se)
        388 +            self[:] = self.from_confusion_matrix_list(test_results)
    385 389          elif input_type == self.CLASSIFICATION:
    386     -            return self.from_classification_results(
    387     -                                        test_results, report_se, ignore_weights)
        390 +            self[:] = self.from_classification_results(test_results)
    388 391          elif input_type == self.CROSS_VALIDATION:
    389     -            return self.from_crossvalidation_results(
    390     -                                        test_results, report_se, ignore_weights)
    391     -
    392     -    def from_confusion_matrix(self, cm, report_se):
    393     -        all_predictions = cm.TP+cm.FN+cm.FP+cm.TN
        392 +            self[:] =  self.from_crossvalidation_results(test_results)
        393 +
        394 +    def from_confusion_matrix(self, cm):
        395 +        all_predictions = 0.
        396 +        correct_predictions = 0.
        397 +        if isinstance(cm, ConfusionMatrix):
        398 +            all_predictions += cm.TP+cm.FN+cm.FP+cm.TN
        399 +            correct_predictions += cm.TP+cm.TN
        400 +        else:
        401 +            for r, row in enumerate(cm):
        402 +                for c, column in enumerate(row):
        403 +                    if r == c:
        404 +                        correct_predictions += column
        405 +                    all_predictions += column
        406 +
    394 407          check_non_zero(all_predictions)
    395     -        ca = (cm.TP+cm.TN)/all_predictions
    396     -
    397     -        if report_se:
        408 +        ca = correct_predictions/all_predictions
        409 +
        410 +        if self.report_se:
    398 411              return ca, ca*(1-ca)/math.sqrt(all_predictions)
    399 412          else:
    400 413              return ca
    401 414
    402     -    def from_confusion_matrix_list(self, confusion_matrices, report_se):
    403     -        return [self.from_confusion_matrix(cm, report_se=report_se)
    404     -                for cm in confusion_matrices]
    405     -
    406     -    def from_classification_results(self, test_results, report_se, ignore_results):
        415 +    def from_confusion_matrix_list(self, confusion_matrices):
        416 +        return [self.from_confusion_matrix(cm) for cm in confusion_matrices]
        417 +
        418 +    def from_classification_results(self, test_results):
    407 419          CAs = [0.0]*test_results.number_of_learners
    408 420          totweight = 0.
    409 421          for tex in test_results.results:
    410     -            w = 1. if ignore_results else tex.weight
        422 +            w = 1. if self.ignore_weights else tex.weight
    411 423              CAs = map(lambda res, cls: res+(cls==tex.actual_class and w), CAs, tex.classes)
    412 424              totweight += w
     
    414 426          ca = [x/totweight for x in CAs]
    415 427
    416     -        if report_se:
        428 +        if self.report_se:
    417 429              return [(x, x*(1-x)/math.sqrt(totweight)) for x in ca]
    418 430          else:
    419 431              return ca
    420 432
    421     -    def from_crossvalidation_results(self, test_results, report_se, ignore_weights):
        433 +    def from_crossvalidation_results(self, test_results):
    422 434          CAsByFold = [[0.0]*test_results.number_of_iterations for _ in range(test_results.number_of_learners)]
    423 435          foldN = [0.0]*test_results.number_of_iterations
    424 436
    425 437          for tex in test_results.results:
    426     -            w = 1. if ignore_weights else tex.weight
        438 +            w = 1. if self.ignore_weights else tex.weight
    427 439              for lrn in range(test_results.number_of_learners):
    428 440                  CAsByFold[lrn][tex.iteration_number] += (tex.classes[lrn]==tex.actual_class) and w
    429 441              foldN[tex.iteration_number] += w
    430 442
    431     -        return statistics_by_folds(CAsByFold, foldN, report_se, False)
        443 +        return statistics_by_folds(CAsByFold, foldN, self.report_se, False)
    432 444
    433 445      def get_input_type(self, test_results):
     
    442 454              return self.CONFUSION_MATRIX_LIST
    443 455
    444     -
    445     -CA = CAClass()
    446 456
    447 457  @deprecated_keywords({"reportSE": "report_se",
  • Orange/testing/unit/tests/test_evaluation_scoring.py

    r10279 r10280  
    117 117          cm = scoring.confusion_matrices(cv, class_index=1)
    118 118          ca = scoring.CA(cm[0])
    119     -        self.assertTrue(isinstance(ca, float))
        119 +        self.assertEqual(len(ca), 1)
    120 120
    121 121      def test_ca_from_confusion_matrix_for_classification_on_iris(self):
     
    139 139          cm = scoring.confusion_matrices(cv, class_index=1)
    140 140          ca = scoring.CA(cm[0], report_se=True)
    141     -        self.assertTrue(isinstance(ca, tuple))
        141 +        self.assertEqual(len(ca), 1)
    142 142
    143 143      def test_ca_on_iris(self):
Note: See TracChangeset for help on using the changeset viewer.