Changeset 9151:c5c7bba3fad1 in orange


Ignore:
Timestamp:
10/31/11 01:59:42 (2 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
ba92054082b8a96ed6ff2edbe28d779217d7357e
Message:

More work on Reliability widget.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Prototypes/OWReliability.py

    r9150 r9151  
    2222                       ("Test Data", Orange.data.Table, self.set_test_data)] 
    2323         
    24         self.ouputs = [("Reliability Scores", Orange.data.Table)] 
     24        self.outputs = [("Reliability Scores", Orange.data.Table)] 
    2525         
    2626        self.variance_checked = False 
     
    3535        self.bias_e =  "0.01, 0.1, 0.5, 1.0, 2.0" 
    3636        self.bagged_m = 50 
    37         self.local_cv_k = 1 
     37        self.local_cv_k = 2 
    3838        self.local_pe_k = 5 
    3939        self.bagged_cn_m = 5 
     
    7373            return box 
    7474             
     75        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self) 
    7576        variance_box = method_box(rbox, "Sensitivity analysis (variance)", 
    7677                                  "variance_checked") 
    7778        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:",  
    7879                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.",  
    79                        callback=partial(self.method_param_changed, 0)) 
     80                       callback=partial(self.method_param_changed, 0), 
     81                       validator=e_validator) 
    8082         
    8183        bias_box = method_box(rbox, "Sensitivity analysis (bias)", 
     
    8385        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:",  
    8486                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.",  
    85                        callback=partial(self.method_param_changed, 1)) 
     87                       callback=partial(self.method_param_changed, 1), 
     88                       validator=e_validator) 
    8689         
    8790        bagged_box = method_box(rbox, "Variance of bagged models", 
     
    97100                                  "local_cv") 
    98101         
    99         OWGUI.spin(local_cv_box, self, "local_cv_k", 0, 20, step=1, 
     102        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1, 
    100103                   label="Nearest neighbors:", 
    101104                   tooltip="Number of nearest neighbors used in LCV estimate.", 
     
    106109                              "local_model_pred_error") 
    107110         
    108         OWGUI.spin(local_pe, self, "local_pe_k", 0, 20, step=1, 
     111        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1, 
    109112                   label="Nearest neighbors:", 
    110113                   tooltip="Number of nearest neighbors used in CNK estimate.", 
     
    121124                   keyboardTracking=False) 
    122125         
    123         OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 0, 20, step=1, 
     126        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1, 
    124127                   label="Nearest neighbors:", 
    125128                   tooltip="Number of nearest neighbors used in BVCK estimate.", 
     
    129132        mahalanobis_box = method_box(rbox, "Mahalanobis distance", 
    130133                                     "mahalanobis_distance") 
    131         OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 0, 20, step=1, 
     134        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1, 
    132135                   label="Nearest neighbors:", 
    133136                   tooltip="Number of nearest neighbors used in BVCK estimate.", 
     
    138141         
    139142        OWGUI.checkBox(box, self, "include_error", "Include prediction error", 
    140                        tooltip="Include predicion error in the output", 
     143                       tooltip="Include prediction error in the output", 
    141144                       callback=self.commit_if) 
    142145         
    143         OWGUI.checkBox(box, self, "include_class", "Include orignial class", 
    144                        tooltip="Include orignal class.", 
     146        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction", 
     147                       tooltip="Include original class and prediction in the output.", 
    145148                       callback=self.commit_if) 
    146149         
    147150        OWGUI.checkBox(box, self, "include_input_features", "Include input features", 
    148                        tooltip="Include faetures from the input data set.", 
     151                       tooltip="Include features from the input data set.", 
    149152                       callback=self.commit_if) 
     153         
     154        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change", 
     155                            callback=self.commit_if) 
     156        b = OWGUI.button(box, self, "Commit", 
     157                         callback=self.commit, 
     158                         autoDefault=True) 
     159         
     160        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit) 
    150161         
    151162        self.learner = None 
     
    172183        if self.learner: 
    173184            name = getattr(self.learner, "name") or type(self.learner).__name__ 
     185             
    174186        if self.train_data is not None: 
    175             test = "Train Data: %i features, %i instances" % \ 
     187            train = "Train Data: %i features, %i instances" % \ 
    176188                (len(self.train_data.domain), len(self.train_data)) 
    177189             
     
    182194            test = "Test data: using training data" 
    183195         
    184         self.info_box.setText("\n".join([name, test, train])) 
     196        self.info_box.setText("\n".join([name, train, test])) 
    185197         
    186198        if self.learner and self._test_data() is not None: 
     
    191203            self._cached_SA_estimates = None 
    192204            self.results = [None for f in self.methods] 
    193             print "Invalidating all" 
     205#            print "Invalidating all" 
    194206        else: 
    195207            for i in which: 
     
    197209            if 0 in which or 1 in which: 
    198210                self._cached_SA_estimates = None 
    199             print "Invalidating", which 
     211#            print "Invalidating", which 
    200212     
    201213    def run(self): 
    202214        for i, (selected, method) in enumerate(self.methods): 
    203215            if self.results[i] is None and getattr(self, selected): 
    204                 print 'Computing', i, selected, method 
     216#                print 'Computing', i, selected, method 
    205217                self.results[i] = method() 
    206                 print self.results[i] 
    207         self.commit() 
     218#                print self.results[i] 
     219        self.commit_if() 
    208220             
    209221    def _test_data(self): 
     
    231243            rel = reliability.Learner(self.learner, estimators=[est]) 
    232244            estimator = rel(self.train_data) 
    233             # TODO: SAVar and SABias report both estimates in one pass, 
    234245            self._cached_SA_estimates = self.get_estimates(estimator) 
    235246        return self._cached_SA_estimates  
    236247     
    237248    def run_SAVar(self): 
    238 #        est = reliability.SensitivityAnalysis() 
    239 #        rel = reliability.Learner(self.learner, estimators=[est]) 
    240 #        estimator = rel(self.train_data) 
    241 #        # TODO: SAVar and SABias report both estimates in one pass, 
    242 #        return self.get_estimates(estimator) 
    243         return self.get_SA_estimates() 
     249        est = reliability.SensitivityAnalysis(e=eval(self.var_e)) 
     250        return self.run_estimation(est) 
     251#        return self.get_SA_estimates() 
    244252         
    245253    def run_SABias(self): 
    246 #        est = reliability.SensitivityAnalysis() 
    247 #        rel = reliability.Learner(self.learner, estimators=[est]) 
    248 #        estimator = rel(self.train_data)  
    249 #        # TODO: SAVar and SABias report both estimates in one pass,    
    250 #        return self.get_estimates(estimator) 
    251         return self.get_SA_estimates() 
     254        est = reliability.SensitivityAnalysis(e=eval(self.bias_e)) 
     255        return self.run_estimation(est) 
     256#        return self.get_SA_estimates() 
    252257     
    253258    def run_BAGV(self): 
     
    274279     
    275280    def method_selection_changed(self, method=None): 
    276         if method is not None: 
    277             i = [i for i, (name, _) in enumerate(self.methods)][0] 
    278             self.invalidate_results([i]) 
    279281        self.run() 
    280282     
     
    292294    def commit(self): 
    293295        from Orange.data import variable 
    294         import numpy 
    295296         
    296297        all_predictions = [] 
    297298        all_estimates = [] 
    298299        score_vars = [] 
     300        features = [] 
    299301        table = None 
    300302        if self._test_data() is not None: 
    301303            scores = [] 
    302304             
     305            if self.include_class and not self.include_input_features: 
     306                original_class = self._test_data().domain.class_var 
     307                features.append(original_class) 
     308                 
     309            if self.include_class: 
     310                prediction_var = variable.Continuous("Prediction") 
     311                features.append(prediction_var) 
     312                 
     313            if self.include_error: 
     314                error_var = variable.Continuous("Error") 
     315                abs_error_var = variable.Continuous("Abs Error") 
     316                features.append(error_var) 
     317                features.append(abs_error_var) 
     318                 
    303319            for res, (selected, method) in zip(self.results, self.methods): 
    304320                if res is not None and getattr(self, selected): 
     
    313329                    name = estimates[0].method_name 
    314330                    var = variable.Continuous(name) 
     331                    features.append(var) 
    315332                    score_vars.append(var) 
    316333                    all_predictions.append(values) 
    317334                    all_estimates.append(estimates) 
    318             data = [[] for _ in self._test_data()] 
    319             for preds, estimations in zip(all_predictions, all_estimates): 
    320                 for d, p, e in zip(data, preds, estimations): 
    321                     d.append(e.estimate) 
    322              
    323             domain = Orange.data.Domain(score_vars, False) 
    324             print data 
    325             table = Orange.data.Table(domain, data) 
    326             print table[:] 
     335                     
     336            if self.include_input_features: 
     337                dom = self._test_data().domain 
     338                domain = Orange.data.Domain(dom.attributes, dom.class_var) 
     339                domain.add_metas(dom.get_metas()) 
     340                data = Orange.data.Table(domain, self._test_data()) 
     341            else: 
     342                domain = Orange.data.Domain([]) 
     343                data = Orange.data.Table(domain, [[] for _ in self._test_data()]) 
     344                 
     345            for f in features: 
     346                data.domain.add_meta(Orange.core.newmetaid(), f) 
     347             
     348            if self.include_class: 
     349                for d, inst, pred in zip(data, self._test_data(), all_predictions[0]): 
     350                    if not self.include_input_features: 
     351                        d[features[0]] = float(inst.get_class()) 
     352                    d[prediction_var] = float(pred) 
     353             
     354            if self.include_error: 
     355                for d, inst, pred in zip(data, self._test_data(), all_predictions[0]): 
     356                    error = float(pred) - float(inst.get_class()) 
     357                    d[error_var] = error 
     358                    d[abs_error_var] = abs(error) 
     359                     
     360            for estimations, var in zip(all_estimates, score_vars): 
     361                for d, e in zip(data, estimations): 
     362                    d[var] = e.estimate 
     363             
     364#            domain = Orange.data.Domain(features, False) 
     365#            print data 
     366#            table = Orange.data.Table(domain, data) 
     367#            print data[:] 
     368            table = data 
    327369             
    328370        self.send("Reliability Scores", table) 
     
    336378    data = Orange.data.Table("housing") 
    337379    indices = Orange.core.MakeRandomIndices2(p0=20)(data) 
    338     print indices 
    339380    data = data.select(indices, 0) 
    340381     
Note: See TracChangeset for help on using the changeset viewer.