source: orange/orange/OrangeWidgets/Evaluate/OWReliability.py @ 9157:15762b4ff142

Revision 9157:15762b4ff142, 15.9 KB checked in by ales_erjavec <ales.erjavec@…>, 20 months ago (diff)

Moved Reliability widget from Prototypes to Evaluate.
Added progress bar.

Line 
1"""
2<name>Reliability</name>
3<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
4<priority>310</priority>
5"""
6
7import Orange
8from Orange.evaluation import reliability
9from Orange.evaluation import testing
10#from Orange.misc import progress_bar_milestones
11from functools import partial
12 
13from OWWidget import *
14import OWGUI
15
16class OWReliability(OWWidget):
17    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
18        "local_cv", "local_model_pred_error", "bagging_variance_cn", 
19        "mahalanobis_distance", "var_e", "bias_e", "bagged_m", "local_cv_k",
20        "local_pe_k", "bagged_cn_m", "bagged_cn_k", "mahalanobis_k",
21        "include_error", "include_class", "include_input_features",
22        "auto_commit"]
23   
24    def __init__(self, parent=None, signalManager=None, title="Reliability"):
25        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)
26       
27        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
28                       ("Train Data", Orange.data.Table, self.set_train_data),
29                       ("Test Data", Orange.data.Table, self.set_test_data)]
30       
31        self.outputs = [("Reliability Scores", Orange.data.Table)]
32       
33        self.variance_checked = False
34        self.bias_checked = False
35        self.bagged_variance = False
36        self.local_cv = False
37        self.local_model_pred_error = False
38        self.bagging_variance_cn = False
39        self.mahalanobis_distance = True
40       
41        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
42        self.bias_e =  "0.01, 0.1, 0.5, 1.0, 2.0"
43        self.bagged_m = 50
44        self.local_cv_k = 2
45        self.local_pe_k = 5
46        self.bagged_cn_m = 5
47        self.bagged_cn_k = 1
48        self.mahalanobis_k = 5
49       
50        self.include_error = True
51        self.include_class = True
52        self.include_input_features = True
53        self.auto_commit = False
54             
55        self.methods = [("variance_checked", self.run_SAVar),
56                        ("bias_checked", self.run_SABias),
57                        ("bagged_variance", self.run_BAGV),
58                        ("local_cv", self.run_LCV),
59                        ("local_model_pred_error", self.run_CNK),
60                        ("bagging_variance_cn", self.run_BVCK),
61                        ("mahalanobis_distance", self.run_Mahalanobis)]
62       
63        #####
64        # GUI
65        #####
66        self.loadSettings()
67       
68        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
69        self.info_box = OWGUI.widgetLabel(box, "\n\n")
70       
71        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)
72        def method_box(parent, name, value):
73            box = OWGUI.widgetBox(rbox, name, flat=False)
74            box.setCheckable(True)
75            box.setChecked(bool(getattr(self, value)))
76            self.connect(box, SIGNAL("toggled(bool)"),
77                         lambda on: (setattr(self, value, on),
78                                     self.method_selection_changed(value)))
79            return box
80           
81        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)
82        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
83                                  "variance_checked")
84        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:", 
85                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.", 
86                       callback=partial(self.method_param_changed, 0),
87                       validator=e_validator)
88       
89        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
90                                    "bias_checked")
91        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:", 
92                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.", 
93                       callback=partial(self.method_param_changed, 1),
94                       validator=e_validator)
95       
96        bagged_box = method_box(rbox, "Variance of bagged models",
97                                "bagged_variance")
98       
99        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
100                   label="Models:",
101                   tooltip="Number of bagged models to be used with BAGV estimate.",
102                   callback=partial(self.method_param_changed, 2),
103                   keyboardTracking=False)
104       
105        local_cv_box = method_box(rbox, "Local cross validation",
106                                  "local_cv")
107       
108        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
109                   label="Nearest neighbors:",
110                   tooltip="Number of nearest neighbors used in LCV estimate.",
111                   callback=partial(self.method_param_changed, 3),
112                   keyboardTracking=False)
113       
114        local_pe = method_box(rbox, "Local modeling of prediction error",
115                              "local_model_pred_error")
116       
117        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
118                   label="Nearest neighbors:",
119                   tooltip="Number of nearest neighbors used in CNK estimate.",
120                   callback=partial(self.method_param_changed, 4),
121                   keyboardTracking=False)
122       
123        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
124                                 "bagging_variance_cn")
125       
126        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
127                   label="Models:",
128                   tooltip="Number of bagged models to be used with BVCK estimate.",
129                   callback=partial(self.method_param_changed, 5),
130                   keyboardTracking=False)
131       
132        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
133                   label="Nearest neighbors:",
134                   tooltip="Number of nearest neighbors used in BVCK estimate.",
135                   callback=partial(self.method_param_changed, 5),
136                   keyboardTracking=False)
137       
138        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
139                                     "mahalanobis_distance")
140        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
141                   label="Nearest neighbors:",
142                   tooltip="Number of nearest neighbors used in BVCK estimate.",
143                   callback=partial(self.method_param_changed, 6),
144                   keyboardTracking=False)
145       
146        box = OWGUI.widgetBox(self.controlArea, "Output")
147       
148        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
149                       tooltip="Include prediction error in the output",
150                       callback=self.commit_if)
151       
152        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction",
153                       tooltip="Include original class and prediction in the output.",
154                       callback=self.commit_if)
155       
156        OWGUI.checkBox(box, self, "include_input_features", "Include input features",
157                       tooltip="Include features from the input data set.",
158                       callback=self.commit_if)
159       
160        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
161                            callback=self.commit_if)
162       
163        self.commit_button = b = OWGUI.button(box, self, "Commit",
164                                              callback=self.commit,
165                                              autoDefault=True)
166       
167        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)
168       
169        self.commit_button.setEnabled(any([getattr(self, selected) \
170                                for selected, _ in  self.methods]))
171       
172        self.learner = None
173        self.train_data = None
174        self.test_data = None
175        self.output_changed = False
176         
177        self.invalidate_results()
178       
179    def set_train_data(self, data=None):
180        self.train_data = data
181        self.invalidate_results()
182       
183    def set_test_data(self, data=None):
184        self.test_data = data
185        self.invalidate_results()
186       
187    def set_learner(self, learner=None):
188        self.learner = learner
189        self.invalidate_results()
190       
191    def handleNewSignals(self):
192        name = test = train = ""
193        if self.learner:
194            name = getattr(self.learner, "name") or type(self.learner).__name__
195           
196        if self.train_data is not None:
197            train = "Train Data: %i features, %i instances" % \
198                (len(self.train_data.domain), len(self.train_data))
199           
200        if self.test_data is not None:
201            test = "Train Data: %i features, %i instances" % \
202                (len(self.test_data.domain), len(self.test_data))
203        else:
204            test = "Test data: using training data"
205       
206        self.info_box.setText("\n".join([name, train, test]))
207       
208        if self.learner and self._test_data() is not None:
209            self.commit_if()
210       
211    def invalidate_results(self, which=None):
212        if which is None:
213            self.results = [None for f in self.methods]
214#            print "Invalidating all"
215        else:
216            for i in which:
217                self.results[i] = None
218#            print "Invalidating", which
219   
220    def run(self):
221        plan = []
222        for i, (selected, method) in enumerate(self.methods):
223            if self.results[i] is None and getattr(self, selected):
224#                print 'Computing', i, selected, method
225                plan.append((i, method))
226#                self.results[i] = method()
227#                print self.results[i]
228        count = len(plan)
229        pb = OWGUI.ProgressBar(self, count * len(self._test_data()))
230        for i, (index, method) in enumerate(plan):
231            self.results[index] = method(pb.advance)   
232        pb.finish()
233       
234    def _test_data(self):
235        if self.test_data is not None:
236            return self.test_data
237        else:
238            return self.train_data
239   
240    def get_estimates(self, estimator, advance=None):
241        test = self._test_data()
242        res = []
243        for i, inst in enumerate(test):
244            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
245            res.append((value, prob))
246            if advance:
247                advance()
248        return res
249               
250    def run_estimation(self, method, advance=None):
251        rel = reliability.Learner(self.learner, estimators=[method])
252        estimator = rel(self.train_data)
253        return self.get_estimates(estimator, advance) 
254   
255    def run_SAVar(self, advance=None):
256        est = reliability.SensitivityAnalysis(e=eval(self.var_e))
257        return self.run_estimation(est, advance)
258       
259    def run_SABias(self, advance=None):
260        est = reliability.SensitivityAnalysis(e=eval(self.bias_e))
261        return self.run_estimation(est, advance)
262   
263    def run_BAGV(self, advance=None):
264        est = reliability.BaggingVariance(m=self.bagged_m)
265        return self.run_estimation(est, advance)
266   
267    def run_LCV(self, advance=None):
268        est = reliability.LocalCrossValidation(k=self.local_cv_k)
269        return self.run_estimation(est, advance)
270   
271    def run_CNK(self, advance=None):
272        est = reliability.CNeighbours(k=self.local_pe_k)
273        return self.run_estimation(est, advance)
274   
275    def run_BVCK(self, advance=None):
276        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
277        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
278        est = reliability.BaggingVarianceCNeighbours(bagv, cnk)
279        return self.run_estimation(est, advance)
280   
281    def run_Mahalanobis(self, advance=None):
282        est = reliability.Mahalanobis(k=self.mahalanobis_k)
283        return self.run_estimation(est, advance)
284   
285    def method_selection_changed(self, method=None):
286        self.commit_button.setEnabled(any([getattr(self, selected) \
287                                for selected, _ in  self.methods]))
288        self.commit_if()
289   
290    def method_param_changed(self, method=None):
291        if method is not None:
292            self.invalidate_results([method])
293        self.commit_if()
294       
295    def commit_if(self):
296        if self.auto_commit:
297            self.commit()
298        else:
299            self.output_changed = True
300           
301    def commit(self):
302        from Orange.data import variable
303       
304        self.run()
305        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
306        all_predictions = []
307        all_estimates = []
308        score_vars = []
309        features = []
310        table = None
311        if self._test_data() is not None:
312            scores = []
313           
314            if self.include_class and not self.include_input_features:
315                original_class = self._test_data().domain.class_var
316                features.append(original_class)
317               
318            if self.include_class:
319                prediction_var = variable.Continuous("Prediction")
320                features.append(prediction_var)
321               
322            if self.include_error:
323                error_var = variable.Continuous("Error")
324                abs_error_var = variable.Continuous("Abs Error")
325                features.append(error_var)
326                features.append(abs_error_var)
327               
328            for res, (selected, method) in zip(self.results, self.methods):
329                if res is not None and getattr(self, selected):
330                    if selected == "bias_checked":
331                        ei = 1
332                    else:
333                        ei = 0
334                    values, estimates = [], []
335                    for value, probs in res:
336                        values.append(value)
337                        estimates.append(probs.reliability_estimate[ei])
338                    name = estimates[0].method_name
339                    name = name_mapper.get(name, name)
340                    var = variable.Continuous(name)
341                    features.append(var)
342                    score_vars.append(var)
343                    all_predictions.append(values)
344                    all_estimates.append(estimates)
345                   
346            if self.include_input_features:
347                dom = self._test_data().domain
348                domain = Orange.data.Domain(dom.attributes, dom.class_var)
349                domain.add_metas(dom.get_metas())
350                data = Orange.data.Table(domain, self._test_data())
351            else:
352                domain = Orange.data.Domain([])
353                data = Orange.data.Table(domain, [[] for _ in self._test_data()])
354               
355            for f in features:
356                data.domain.add_meta(Orange.core.newmetaid(), f)
357           
358            if self.include_class:
359                for d, inst, pred in zip(data, self._test_data(), all_predictions[0]):
360                    if not self.include_input_features:
361                        d[features[0]] = float(inst.get_class())
362                    d[prediction_var] = float(pred)
363           
364            if self.include_error:
365                for d, inst, pred in zip(data, self._test_data(), all_predictions[0]):
366                    error = float(pred) - float(inst.get_class())
367                    d[error_var] = error
368                    d[abs_error_var] = abs(error)
369                   
370            for estimations, var in zip(all_estimates, score_vars):
371                for d, e in zip(data, estimations):
372                    d[var] = e.estimate
373           
374            table = data
375           
376        self.send("Reliability Scores", table)
377        self.output_changed = True
378       
379       
380if __name__ == "__main__":
381    import sys
382    app = QApplication(sys.argv)
383    w = OWReliability()
384    data = Orange.data.Table("housing")
385    indices = Orange.core.MakeRandomIndices2(p0=20)(data)
386    data = data.select(indices, 0)
387   
388    learner = Orange.regression.tree.TreeLearner()
389    w.set_learner(learner)
390    w.set_train_data(data)
391    w.handleNewSignals()
392    w.show()
393    app.exec_()
394   
395       
Note: See TracBrowser for help on using the repository browser.