source: orange/Orange/OrangeWidgets/Evaluate/OWReliability.py @ 10156:6726b609e76c

Revision 10156:6726b609e76c, 16.8 KB checked in by markotoplak, 2 years ago (diff)

Fixed changes from data.variable -> feature in some widgets.

Line 
1"""
2<name>Reliability</name>
3<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
4<priority>310</priority>
5<icon>icons/Reliability.png</icon>
6"""
7
8import Orange
9from Orange.evaluation import reliability
10from Orange.evaluation import testing
11#from Orange.misc import progress_bar_milestones
12from functools import partial
13 
14from OWWidget import *
15import OWGUI
16
17class OWReliability(OWWidget):
18    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
19        "local_cv", "local_model_pred_error", "bagging_variance_cn", 
20        "mahalanobis_distance", "var_e", "bias_e", "bagged_m", "local_cv_k",
21        "local_pe_k", "bagged_cn_m", "bagged_cn_k", "mahalanobis_k",
22        "include_error", "include_class", "include_input_features",
23        "auto_commit"]
24   
25    def __init__(self, parent=None, signalManager=None, title="Reliability"):
26        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)
27       
28        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
29                       ("Training Data", Orange.data.Table, self.set_train_data),
30                       ("Test Data", Orange.data.Table, self.set_test_data)]
31       
32        self.outputs = [("Reliability Scores", Orange.data.Table)]
33       
34        self.variance_checked = False
35        self.bias_checked = False
36        self.bagged_variance = False
37        self.local_cv = False
38        self.local_model_pred_error = False
39        self.bagging_variance_cn = False
40        self.mahalanobis_distance = True
41       
42        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
43        self.bias_e =  "0.01, 0.1, 0.5, 1.0, 2.0"
44        self.bagged_m = 10
45        self.local_cv_k = 2
46        self.local_pe_k = 5
47        self.bagged_cn_m = 5
48        self.bagged_cn_k = 1
49        self.mahalanobis_k = 3
50       
51        self.include_error = True
52        self.include_class = True
53        self.include_input_features = False
54        self.auto_commit = False
55       
56        # (selected attr name, getter function, count of returned estimators, indices of estimator results to use)
57        self.estimators = \
58            [("variance_checked", self.get_SAVar, 3, [0]),
59             ("bias_checked", self.get_SABias, 3, [1, 2]),
60             ("bagged_variance", self.get_BAGV, 1, [0]),
61             ("local_cv", self.get_LCV, 1, [0]),
62             ("local_model_pred_error", self.get_CNK, 2, [0, 1]),
63             ("bagging_variance_cn", self.get_BVCK, 4, [0]),
64             ("mahalanobis_distance", self.get_Mahalanobis, 1, [0])]
65       
66        #####
67        # GUI
68        #####
69        self.loadSettings()
70       
71        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
72        self.info_box = OWGUI.widgetLabel(box, "\n\n")
73       
74        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)
75        def method_box(parent, name, value):
76            box = OWGUI.widgetBox(rbox, name, flat=False)
77            box.setCheckable(True)
78            box.setChecked(bool(getattr(self, value)))
79            self.connect(box, SIGNAL("toggled(bool)"),
80                         lambda on: (setattr(self, value, on),
81                                     self.method_selection_changed(value)))
82            return box
83           
84        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)
85        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
86                                  "variance_checked")
87        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:", 
88                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.", 
89                       callback=partial(self.method_param_changed, 0),
90                       validator=e_validator)
91       
92        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
93                                    "bias_checked")
94        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:", 
95                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.", 
96                       callback=partial(self.method_param_changed, 1),
97                       validator=e_validator)
98       
99        bagged_box = method_box(rbox, "Variance of bagged models",
100                                "bagged_variance")
101       
102        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
103                   label="Models:",
104                   tooltip="Number of bagged models to be used with BAGV estimate.",
105                   callback=partial(self.method_param_changed, 2),
106                   keyboardTracking=False)
107       
108        local_cv_box = method_box(rbox, "Local cross validation",
109                                  "local_cv")
110       
111        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
112                   label="Nearest neighbors:",
113                   tooltip="Number of nearest neighbors used in LCV estimate.",
114                   callback=partial(self.method_param_changed, 3),
115                   keyboardTracking=False)
116       
117        local_pe = method_box(rbox, "Local modeling of prediction error",
118                              "local_model_pred_error")
119       
120        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
121                   label="Nearest neighbors:",
122                   tooltip="Number of nearest neighbors used in CNK estimate.",
123                   callback=partial(self.method_param_changed, 4),
124                   keyboardTracking=False)
125       
126        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
127                                 "bagging_variance_cn")
128       
129        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
130                   label="Models:",
131                   tooltip="Number of bagged models to be used with BVCK estimate.",
132                   callback=partial(self.method_param_changed, 5),
133                   keyboardTracking=False)
134       
135        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
136                   label="Nearest neighbors:",
137                   tooltip="Number of nearest neighbors used in BVCK estimate.",
138                   callback=partial(self.method_param_changed, 5),
139                   keyboardTracking=False)
140       
141        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
142                                     "mahalanobis_distance")
143        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
144                   label="Nearest neighbors:",
145                   tooltip="Number of nearest neighbors used in BVCK estimate.",
146                   callback=partial(self.method_param_changed, 6),
147                   keyboardTracking=False)
148       
149        box = OWGUI.widgetBox(self.controlArea, "Output")
150       
151        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
152                       tooltip="Include prediction error in the output",
153                       callback=self.commit_if)
154       
155        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction",
156                       tooltip="Include original class and prediction in the output.",
157                       callback=self.commit_if)
158       
159        OWGUI.checkBox(box, self, "include_input_features", "Include input features",
160                       tooltip="Include features from the input data set.",
161                       callback=self.commit_if)
162       
163        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
164                            callback=self.commit_if)
165       
166        self.commit_button = b = OWGUI.button(box, self, "Commit",
167                                              callback=self.commit,
168                                              autoDefault=True)
169       
170        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)
171       
172        self.commit_button.setEnabled(any([getattr(self, selected) \
173                                for selected, _, _, _ in  self.estimators]))
174       
175        self.learner = None
176        self.train_data = None
177        self.test_data = None
178        self.output_changed = False
179        self.train_data_has_no_class = False
180        self.train_data_has_discrete_class = False
181        self.invalidate_results()
182       
183    def set_train_data(self, data=None):
184        self.error()
185        self.train_data_has_no_class = False
186        self.train_data_has_discrete_class = False
187       
188        if data is not None:
189            if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous):
190                if not data.domain.class_var:
191                    self.train_data_has_no_class = True
192                elif not isinstance(data.domain.class_var,
193                                    Orange.feature.Continuous):
194                    self.train_data_has_discrete_class = True
195                   
196                data = None
197       
198        self.train_data = data
199        self.invalidate_results() 
200       
201    def set_test_data(self, data=None):
202        self.test_data = data
203        self.invalidate_results()
204       
205    def set_learner(self, learner=None):
206        self.learner = learner
207        self.invalidate_results()
208       
209    def handleNewSignals(self):
210        name = "No learner on input"
211        train = "No train data on input"
212        test = "No test data on input"
213       
214        if self.learner:
215            name = "Learner: " + (getattr(self.learner, "name") or type(self.learner).__name__)
216           
217        if self.train_data is not None:
218            train = "Train Data: %i features, %i instances" % \
219                (len(self.train_data.domain), len(self.train_data))
220        elif self.train_data_has_no_class:
221            train = "Train Data has no class variable"
222        elif self.train_data_has_discrete_class:
223            train = "Train Data doesn't have a continuous class"
224           
225        if self.test_data is not None:
226            test = "Test Data: %i features, %i instances" % \
227                (len(self.test_data.domain), len(self.test_data))
228        elif self.train_data:
229            test = "Test data: using training data"
230       
231        self.info_box.setText("\n".join([name, train, test]))
232       
233        if self.learner and self._test_data() is not None:
234            self.commit_if()
235       
236    def invalidate_results(self, which=None):
237        if which is None:
238            self.results = [None for f in self.estimators]
239#            print "Invalidating all"
240        else:
241            for i in which:
242                self.results[i] = None
243#            print "Invalidating", which
244       
245    def run(self):
246        plan = []
247        estimate_index = 0
248        for i, (selected, method, count, offsets) in enumerate(self.estimators):
249            if self.results[i] is None and getattr(self, selected):
250                plan.append((i, method, [estimate_index + offset for offset in offsets]))
251                estimate_index += count
252               
253        estimators = [method() for _, method, _ in plan]
254       
255        if not estimators:
256            return
257           
258        pb = OWGUI.ProgressBar(self, len(self._test_data()))
259        estimates = self.run_estimation(estimators, pb.advance)
260        pb.finish()
261       
262        self.predictions = [v for v, _ in estimates]
263        estimates = [prob.reliability_estimate for _, prob in estimates]
264       
265        for i, (index, method, estimate_indices) in enumerate(plan):
266            self.results[index] = [[e[est_index] for e in estimates] \
267                                   for est_index in estimate_indices]
268       
269    def _test_data(self):
270        if self.test_data is not None:
271            return self.test_data
272        else:
273            return self.train_data
274   
275    def get_estimates(self, estimator, advance=None):
276        test = self._test_data()
277        res = []
278        for i, inst in enumerate(test):
279            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
280            res.append((value, prob))
281            if advance:
282                advance()
283        return res
284               
285    def run_estimation(self, estimators, advance=None):
286        rel = reliability.Learner(self.learner, estimators=estimators)
287        estimator = rel(self.train_data)
288        return self.get_estimates(estimator, advance) 
289   
290    def get_SAVar(self):
291        return reliability.SensitivityAnalysis(e=eval(self.var_e))
292   
293    def get_SABias(self):
294        return reliability.SensitivityAnalysis(e=eval(self.bias_e))
295   
296    def get_BAGV(self):
297        return reliability.BaggingVariance(m=self.bagged_m)
298   
299    def get_LCV(self):
300        return reliability.LocalCrossValidation(k=self.local_cv_k)
301   
302    def get_CNK(self):
303        return reliability.CNeighbours(k=self.local_pe_k)
304   
305    def get_BVCK(self):
306        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
307        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
308        return reliability.BaggingVarianceCNeighbours(bagv, cnk)
309   
310    def get_Mahalanobis(self):
311        return reliability.Mahalanobis(k=self.mahalanobis_k)
312   
313    def method_selection_changed(self, method=None):
314        self.commit_button.setEnabled(any([getattr(self, selected) \
315                                for selected, _, _, _ in  self.estimators]))
316        self.commit_if()
317   
318    def method_param_changed(self, method=None):
319        if method is not None:
320            self.invalidate_results([method])
321        self.commit_if()
322       
323    def commit_if(self):
324        if self.auto_commit:
325            self.commit()
326        else:
327            self.output_changed = True
328           
329    def commit(self):
330        from Orange import feature as variable
331        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
332        all_predictions = []
333        all_estimates = []
334        score_vars = []
335        features = []
336        table = None
337        if self.learner and self.train_data is not None \
338                and self._test_data() is not None:
339            self.run()
340           
341            scores = []
342            if self.include_class and not self.include_input_features:
343                original_class = self._test_data().domain.class_var
344                features.append(original_class)
345               
346            if self.include_class:
347                prediction_var = variable.Continuous("Prediction")
348                features.append(prediction_var)
349               
350            if self.include_error:
351                error_var = variable.Continuous("Error")
352                abs_error_var = variable.Continuous("Abs. Error")
353                features.append(error_var)
354                features.append(abs_error_var)
355               
356            for estimates, (selected, method, _, _) in zip(self.results, self.estimators):
357                if estimates is not None and getattr(self, selected):
358                    for estimate in estimates:
359                        name = estimate[0].method_name
360                        name = name_mapper.get(name, name)
361                        var = variable.Continuous(name)
362                        features.append(var)
363                        score_vars.append(var)
364                        all_estimates.append(estimate)
365                   
366            if self.include_input_features:
367                dom = self._test_data().domain
368                attributes = list(dom.attributes) + features
369                domain = Orange.data.Domain(attributes, dom.class_var)
370                domain.add_metas(dom.get_metas())
371               
372                data = Orange.data.Table(domain, self._test_data())
373            else:
374                domain = Orange.data.Domain(features, None)
375                data = Orange.data.Table(domain, [[None] * len(features) for _ in self._test_data()])
376           
377            if self.include_class:
378                for d, inst, pred in zip(data, self._test_data(), self.predictions):
379                    if not self.include_input_features:
380                        d[features[0]] = float(inst.get_class())
381                    d[prediction_var] = float(pred)
382           
383            if self.include_error:
384                for d, inst, pred in zip(data, self._test_data(), self.predictions):
385                    error = float(pred) - float(inst.get_class())
386                    d[error_var] = error
387                    d[abs_error_var] = abs(error)
388                   
389            for estimations, var in zip(all_estimates, score_vars):
390                for d, e in zip(data, estimations):
391                    d[var] = e.estimate
392           
393            table = data
394           
395        self.send("Reliability Scores", table)
396        self.output_changed = False
397       
398       
399if __name__ == "__main__":
400    import sys
401    app = QApplication(sys.argv)
402    w = OWReliability()
403    data = Orange.data.Table("housing")
404    indices = Orange.core.MakeRandomIndices2(p0=20)(data)
405    data = data.select(indices, 0)
406   
407    learner = Orange.regression.tree.TreeLearner()
408    w.set_learner(learner)
409    w.set_train_data(data)
410    w.handleNewSignals()
411    w.show()
412    app.exec_()
413   
414       
Note: See TracBrowser for help on using the repository browser.