Changeset 9165:9a3607f23503 in orange


Ignore:
Timestamp:
11/02/11 18:37:16 (2 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
6b053ed838c7d1e8a5b8a0b83807c7d0d2d7a892
Message:

Batch multiple estimators for evaluation in a single run.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Evaluate/OWReliability.py

    r9161 r9165  
    5353        self.include_input_features = False 
    5454        self.auto_commit = False 
    55               
    56         self.methods = [("variance_checked", self.run_SAVar), 
    57                         ("bias_checked", self.run_SABias), 
    58                         ("bagged_variance", self.run_BAGV), 
    59                         ("local_cv", self.run_LCV), 
    60                         ("local_model_pred_error", self.run_CNK), 
    61                         ("bagging_variance_cn", self.run_BVCK), 
    62                         ("mahalanobis_distance", self.run_Mahalanobis)] 
     55         
     56        # (selected attr name, getter function, count of returned estimators, index of estimator) 
     57        self.estimators = \ 
     58            [("variance_checked", self.get_SAVar, 3, 0), 
     59             ("bias_checked", self.get_SABias, 3, 1), 
     60             ("bagged_variance", self.get_BAGV, 1, 0), 
     61             ("local_cv", self.get_LCV, 1, 0), 
     62             ("local_model_pred_error", self.get_CNK, 2, 0), 
     63             ("bagging_variance_cn", self.get_BVCK, 4, 0), 
     64             ("mahalanobis_distance", self.get_Mahalanobis, 1, 0)] 
    6365         
    6466        ##### 
     
    169171         
    170172        self.commit_button.setEnabled(any([getattr(self, selected) \ 
    171                                 for selected, _ in  self.methods])) 
     173                                for selected, _, _, _ in  self.estimators])) 
    172174         
    173175        self.learner = None 
     
    179181         
    180182    def set_train_data(self, data=None): 
     183        self.error() 
     184        if data is not None: 
     185            if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous): 
     186                data = None 
     187         
    181188        self.train_data = data 
    182         self.invalidate_results() 
     189        self.invalidate_results()  
    183190         
    184191    def set_test_data(self, data=None): 
     
    191198         
    192199    def handleNewSignals(self): 
    193         name = test = train = "" 
     200        name = "No learner on input" 
     201        train = "No train data on input" 
     202        test = "No test data on input" 
     203         
    194204        if self.learner: 
    195205            name = getattr(self.learner, "name") or type(self.learner).__name__ 
     
    202212            test = "Train Data: %i features, %i instances" % \ 
    203213                (len(self.test_data.domain), len(self.test_data)) 
    204         else: 
     214        elif self.train_data: 
    205215            test = "Test data: using training data" 
    206216         
     
    212222    def invalidate_results(self, which=None): 
    213223        if which is None: 
    214             self.results = [None for f in self.methods] 
     224            self.results = [None for f in self.estimators] 
    215225#            print "Invalidating all" 
    216226        else: 
     
    218228                self.results[i] = None 
    219229#            print "Invalidating", which 
    220      
     230         
    221231    def run(self): 
    222232        plan = [] 
    223         for i, (selected, method) in enumerate(self.methods): 
     233        estimate_index = 0 
     234        for i, (selected, method, count, offset) in enumerate(self.estimators): 
    224235            if self.results[i] is None and getattr(self, selected): 
    225 #                print 'Computing', i, selected, method 
    226                 plan.append((i, method)) 
    227 #                self.results[i] = method() 
    228 #                print self.results[i] 
    229         count = len(plan) 
    230         pb = OWGUI.ProgressBar(self, count * len(self._test_data())) 
    231         for i, (index, method) in enumerate(plan): 
    232             self.results[index] = method(pb.advance)     
     236                plan.append((i, method, estimate_index + offset)) 
     237                estimate_index += count 
     238                 
     239        estimators = [method() for _, method, _ in plan] 
     240         
     241        if not estimators: 
     242            return 
     243             
     244        pb = OWGUI.ProgressBar(self, len(self._test_data())) 
     245        estimates = self.run_estimation(estimators, pb.advance) 
    233246        pb.finish() 
     247         
     248        self.predictions = [v for v, _ in estimates] 
     249        estimates = [prob.reliability_estimate for _, prob in estimates] 
     250         
     251        for i, (index, method, estimate_index) in enumerate(plan): 
     252            self.results[index] = [e[estimate_index] for e in estimates] 
    234253         
    235254    def _test_data(self): 
     
    249268        return res 
    250269                 
    251     def run_estimation(self, method, advance=None): 
    252         rel = reliability.Learner(self.learner, estimators=[method]) 
     270    def run_estimation(self, estimators, advance=None): 
     271        rel = reliability.Learner(self.learner, estimators=estimators) 
    253272        estimator = rel(self.train_data) 
    254273        return self.get_estimates(estimator, advance)  
    255274     
    256     def run_SAVar(self, advance=None): 
    257         est = reliability.SensitivityAnalysis(e=eval(self.var_e)) 
    258         return self.run_estimation(est, advance) 
    259          
    260     def run_SABias(self, advance=None): 
    261         est = reliability.SensitivityAnalysis(e=eval(self.bias_e)) 
    262         return self.run_estimation(est, advance) 
    263      
    264     def run_BAGV(self, advance=None): 
    265         est = reliability.BaggingVariance(m=self.bagged_m) 
    266         return self.run_estimation(est, advance) 
    267      
    268     def run_LCV(self, advance=None): 
    269         est = reliability.LocalCrossValidation(k=self.local_cv_k) 
    270         return self.run_estimation(est, advance) 
    271      
    272     def run_CNK(self, advance=None): 
    273         est = reliability.CNeighbours(k=self.local_pe_k) 
    274         return self.run_estimation(est, advance) 
    275      
    276     def run_BVCK(self, advance=None): 
     275    def get_SAVar(self): 
     276        return reliability.SensitivityAnalysis(e=eval(self.var_e)) 
     277     
     278    def get_SABias(self): 
     279        return reliability.SensitivityAnalysis(e=eval(self.bias_e)) 
     280     
     281    def get_BAGV(self): 
     282        return reliability.BaggingVariance(m=self.bagged_m) 
     283     
     284    def get_LCV(self): 
     285        return reliability.LocalCrossValidation(k=self.local_cv_k) 
     286     
     287    def get_CNK(self): 
     288        return reliability.CNeighbours(k=self.local_pe_k) 
     289     
     290    def get_BVCK(self): 
    277291        bagv = reliability.BaggingVariance(m=self.bagged_cn_m) 
    278292        cnk = reliability.CNeighbours(k=self.bagged_cn_k) 
    279         est = reliability.BaggingVarianceCNeighbours(bagv, cnk) 
    280         return self.run_estimation(est, advance) 
    281      
    282     def run_Mahalanobis(self, advance=None): 
    283         est = reliability.Mahalanobis(k=self.mahalanobis_k) 
    284         return self.run_estimation(est, advance) 
     293        return reliability.BaggingVarianceCNeighbours(bagv, cnk) 
     294     
     295    def get_Mahalanobis(self): 
     296        return reliability.Mahalanobis(k=self.mahalanobis_k) 
    285297     
    286298    def method_selection_changed(self, method=None): 
    287299        self.commit_button.setEnabled(any([getattr(self, selected) \ 
    288                                 for selected, _ in  self.methods])) 
     300                                for selected, _, _, _ in  self.estimators])) 
    289301        self.commit_if() 
    290302     
     
    304316         
    305317        self.run() 
    306         name_mapper = {"Mahalanobis": "Mahalanobis"} 
     318        name_mapper = {"Mahalanobis absolute": "Mahalanobis"} 
    307319        all_predictions = [] 
    308320        all_estimates = [] 
     
    323335            if self.include_error: 
    324336                error_var = variable.Continuous("Error") 
    325                 abs_error_var = variable.Continuous("Abs Error") 
     337                abs_error_var = variable.Continuous("Abs. Error") 
    326338                features.append(error_var) 
    327339                features.append(abs_error_var) 
    328340                 
    329             for res, (selected, method) in zip(self.results, self.methods): 
    330                 if res is not None and getattr(self, selected): 
    331                     if selected == "bias_checked": 
    332                         ei = 1 
    333                     else: 
    334                         ei = 0 
    335                     values, estimates = [], [] 
    336                     for value, probs in res: 
    337                         values.append(value) 
    338                         estimates.append(probs.reliability_estimate[ei]) 
     341            for estimates, (selected, method, _, _) in zip(self.results, self.estimators): 
     342                if estimates is not None and getattr(self, selected): 
    339343                    name = estimates[0].method_name 
    340344                    name = name_mapper.get(name, name) 
     
    342346                    features.append(var) 
    343347                    score_vars.append(var) 
    344                     all_predictions.append(values) 
    345348                    all_estimates.append(estimates) 
    346349                     
     
    355358                 
    356359            for f in features: 
    357                 data.domain.add_meta(Orange.core.newmetaid(), f) 
     360                data.domain.add_meta(Orange.data.new_meta_id(), f) 
    358361             
    359362            if self.include_class: 
    360                 for d, inst, pred in zip(data, self._test_data(), all_predictions[0]): 
     363                for d, inst, pred in zip(data, self._test_data(), self.predictions): 
    361364                    if not self.include_input_features: 
    362365                        d[features[0]] = float(inst.get_class()) 
     
    364367             
    365368            if self.include_error: 
    366                 for d, inst, pred in zip(data, self._test_data(), all_predictions[0]): 
     369                for d, inst, pred in zip(data, self._test_data(), self.predictions): 
    367370                    error = float(pred) - float(inst.get_class()) 
    368371                    d[error_var] = error 
     
    376379             
    377380        self.send("Reliability Scores", table) 
    378         self.output_changed = True 
     381        self.output_changed = False 
    379382         
    380383         
Note: See TracChangeset for help on using the changeset viewer.