source: orange-reliability/_reliability/widgets/OWReliability.py @ 0:55e4bdcfe4e3

Revision 0:55e4bdcfe4e3, 16.8 KB checked in by Matija Polajnar <matija.polajnar@…>, 2 years ago (diff)

Initial version as moved from main Orange. Without documentation.

Line 
1"""
2<name>Reliability</name>
3<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
4<priority>310</priority>
5<icon>icons/Reliability.png</icon>
6"""
7
8from __future__ import absolute_import
9
10import Orange
11import _reliability as reliability
12from Orange.evaluation import testing
13#from Orange.utils import progress_bar_milestones
14from functools import partial
15 
16from OWWidget import *
17import OWGUI
18
19class OWReliability(OWWidget):
20    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
21        "local_cv", "local_model_pred_error", "bagging_variance_cn", 
22        "mahalanobis_distance", "var_e", "bias_e", "bagged_m", "local_cv_k",
23        "local_pe_k", "bagged_cn_m", "bagged_cn_k", "mahalanobis_k",
24        "include_error", "include_class", "include_input_features",
25        "auto_commit"]
26   
27    def __init__(self, parent=None, signalManager=None, title="Reliability"):
28        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)
29       
30        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
31                       ("Training Data", Orange.data.Table, self.set_train_data),
32                       ("Test Data", Orange.data.Table, self.set_test_data)]
33       
34        self.outputs = [("Reliability Scores", Orange.data.Table)]
35       
36        self.variance_checked = False
37        self.bias_checked = False
38        self.bagged_variance = False
39        self.local_cv = False
40        self.local_model_pred_error = False
41        self.bagging_variance_cn = False
42        self.mahalanobis_distance = True
43       
44        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
45        self.bias_e =  "0.01, 0.1, 0.5, 1.0, 2.0"
46        self.bagged_m = 10
47        self.local_cv_k = 2
48        self.local_pe_k = 5
49        self.bagged_cn_m = 5
50        self.bagged_cn_k = 1
51        self.mahalanobis_k = 3
52       
53        self.include_error = True
54        self.include_class = True
55        self.include_input_features = False
56        self.auto_commit = False
57       
58        # (selected attr name, getter function, count of returned estimators, indices of estimator results to use)
59        self.estimators = \
60            [("variance_checked", self.get_SAVar, 3, [0]),
61             ("bias_checked", self.get_SABias, 3, [1, 2]),
62             ("bagged_variance", self.get_BAGV, 1, [0]),
63             ("local_cv", self.get_LCV, 1, [0]),
64             ("local_model_pred_error", self.get_CNK, 2, [0, 1]),
65             ("bagging_variance_cn", self.get_BVCK, 4, [0]),
66             ("mahalanobis_distance", self.get_Mahalanobis, 1, [0])]
67       
68        #####
69        # GUI
70        #####
71        self.loadSettings()
72       
73        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
74        self.info_box = OWGUI.widgetLabel(box, "\n\n")
75       
76        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)
77        def method_box(parent, name, value):
78            box = OWGUI.widgetBox(rbox, name, flat=False)
79            box.setCheckable(True)
80            box.setChecked(bool(getattr(self, value)))
81            self.connect(box, SIGNAL("toggled(bool)"),
82                         lambda on: (setattr(self, value, on),
83                                     self.method_selection_changed(value)))
84            return box
85           
86        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)
87        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
88                                  "variance_checked")
89        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:", 
90                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.", 
91                       callback=partial(self.method_param_changed, 0),
92                       validator=e_validator)
93       
94        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
95                                    "bias_checked")
96        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:", 
97                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.", 
98                       callback=partial(self.method_param_changed, 1),
99                       validator=e_validator)
100       
101        bagged_box = method_box(rbox, "Variance of bagged models",
102                                "bagged_variance")
103       
104        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
105                   label="Models:",
106                   tooltip="Number of bagged models to be used with BAGV estimate.",
107                   callback=partial(self.method_param_changed, 2),
108                   keyboardTracking=False)
109       
110        local_cv_box = method_box(rbox, "Local cross validation",
111                                  "local_cv")
112       
113        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
114                   label="Nearest neighbors:",
115                   tooltip="Number of nearest neighbors used in LCV estimate.",
116                   callback=partial(self.method_param_changed, 3),
117                   keyboardTracking=False)
118       
119        local_pe = method_box(rbox, "Local modeling of prediction error",
120                              "local_model_pred_error")
121       
122        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
123                   label="Nearest neighbors:",
124                   tooltip="Number of nearest neighbors used in CNK estimate.",
125                   callback=partial(self.method_param_changed, 4),
126                   keyboardTracking=False)
127       
128        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
129                                 "bagging_variance_cn")
130       
131        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
132                   label="Models:",
133                   tooltip="Number of bagged models to be used with BVCK estimate.",
134                   callback=partial(self.method_param_changed, 5),
135                   keyboardTracking=False)
136       
137        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
138                   label="Nearest neighbors:",
139                   tooltip="Number of nearest neighbors used in BVCK estimate.",
140                   callback=partial(self.method_param_changed, 5),
141                   keyboardTracking=False)
142       
143        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
144                                     "mahalanobis_distance")
145        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
146                   label="Nearest neighbors:",
147                   tooltip="Number of nearest neighbors used in BVCK estimate.",
148                   callback=partial(self.method_param_changed, 6),
149                   keyboardTracking=False)
150       
151        box = OWGUI.widgetBox(self.controlArea, "Output")
152       
153        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
154                       tooltip="Include prediction error in the output",
155                       callback=self.commit_if)
156       
157        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction",
158                       tooltip="Include original class and prediction in the output.",
159                       callback=self.commit_if)
160       
161        OWGUI.checkBox(box, self, "include_input_features", "Include input features",
162                       tooltip="Include features from the input data set.",
163                       callback=self.commit_if)
164       
165        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
166                            callback=self.commit_if)
167       
168        self.commit_button = b = OWGUI.button(box, self, "Commit",
169                                              callback=self.commit,
170                                              autoDefault=True)
171       
172        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)
173       
174        self.commit_button.setEnabled(any([getattr(self, selected) \
175                                for selected, _, _, _ in  self.estimators]))
176       
177        self.learner = None
178        self.train_data = None
179        self.test_data = None
180        self.output_changed = False
181        self.train_data_has_no_class = False
182        self.train_data_has_discrete_class = False
183        self.invalidate_results()
184       
185    def set_train_data(self, data=None):
186        self.error()
187        self.train_data_has_no_class = False
188        self.train_data_has_discrete_class = False
189       
190        if data is not None:
191            if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous):
192                if not data.domain.class_var:
193                    self.train_data_has_no_class = True
194                elif not isinstance(data.domain.class_var,
195                                    Orange.feature.Continuous):
196                    self.train_data_has_discrete_class = True
197                   
198                data = None
199       
200        self.train_data = data
201        self.invalidate_results() 
202       
203    def set_test_data(self, data=None):
204        self.test_data = data
205        self.invalidate_results()
206       
207    def set_learner(self, learner=None):
208        self.learner = learner
209        self.invalidate_results()
210       
211    def handleNewSignals(self):
212        name = "No learner on input"
213        train = "No train data on input"
214        test = "No test data on input"
215       
216        if self.learner:
217            name = "Learner: " + (getattr(self.learner, "name") or type(self.learner).__name__)
218           
219        if self.train_data is not None:
220            train = "Train Data: %i features, %i instances" % \
221                (len(self.train_data.domain), len(self.train_data))
222        elif self.train_data_has_no_class:
223            train = "Train Data has no class variable"
224        elif self.train_data_has_discrete_class:
225            train = "Train Data doesn't have a continuous class"
226           
227        if self.test_data is not None:
228            test = "Test Data: %i features, %i instances" % \
229                (len(self.test_data.domain), len(self.test_data))
230        elif self.train_data:
231            test = "Test data: using training data"
232       
233        self.info_box.setText("\n".join([name, train, test]))
234       
235        if self.learner and self._test_data() is not None:
236            self.commit_if()
237       
238    def invalidate_results(self, which=None):
239        if which is None:
240            self.results = [None for f in self.estimators]
241#            print "Invalidating all"
242        else:
243            for i in which:
244                self.results[i] = None
245#            print "Invalidating", which
246       
247    def run(self):
248        plan = []
249        estimate_index = 0
250        for i, (selected, method, count, offsets) in enumerate(self.estimators):
251            if self.results[i] is None and getattr(self, selected):
252                plan.append((i, method, [estimate_index + offset for offset in offsets]))
253                estimate_index += count
254               
255        estimators = [method() for _, method, _ in plan]
256       
257        if not estimators:
258            return
259           
260        pb = OWGUI.ProgressBar(self, len(self._test_data()))
261        estimates = self.run_estimation(estimators, pb.advance)
262        pb.finish()
263       
264        self.predictions = [v for v, _ in estimates]
265        estimates = [prob.reliability_estimate for _, prob in estimates]
266       
267        for i, (index, method, estimate_indices) in enumerate(plan):
268            self.results[index] = [[e[est_index] for e in estimates] \
269                                   for est_index in estimate_indices]
270       
271    def _test_data(self):
272        if self.test_data is not None:
273            return self.test_data
274        else:
275            return self.train_data
276   
277    def get_estimates(self, estimator, advance=None):
278        test = self._test_data()
279        res = []
280        for i, inst in enumerate(test):
281            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
282            res.append((value, prob))
283            if advance:
284                advance()
285        return res
286               
287    def run_estimation(self, estimators, advance=None):
288        rel = reliability.Learner(self.learner, estimators=estimators)
289        estimator = rel(self.train_data)
290        return self.get_estimates(estimator, advance) 
291   
292    def get_SAVar(self):
293        return reliability.SensitivityAnalysis(e=eval(self.var_e))
294   
295    def get_SABias(self):
296        return reliability.SensitivityAnalysis(e=eval(self.bias_e))
297   
298    def get_BAGV(self):
299        return reliability.BaggingVariance(m=self.bagged_m)
300   
301    def get_LCV(self):
302        return reliability.LocalCrossValidation(k=self.local_cv_k)
303   
304    def get_CNK(self):
305        return reliability.CNeighbours(k=self.local_pe_k)
306   
307    def get_BVCK(self):
308        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
309        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
310        return reliability.BaggingVarianceCNeighbours(bagv, cnk)
311   
312    def get_Mahalanobis(self):
313        return reliability.Mahalanobis(k=self.mahalanobis_k)
314   
315    def method_selection_changed(self, method=None):
316        self.commit_button.setEnabled(any([getattr(self, selected) \
317                                for selected, _, _, _ in  self.estimators]))
318        self.commit_if()
319   
320    def method_param_changed(self, method=None):
321        if method is not None:
322            self.invalidate_results([method])
323        self.commit_if()
324       
325    def commit_if(self):
326        if self.auto_commit:
327            self.commit()
328        else:
329            self.output_changed = True
330           
331    def commit(self):
332        from Orange import feature as variable
333        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
334        all_predictions = []
335        all_estimates = []
336        score_vars = []
337        features = []
338        table = None
339        if self.learner and self.train_data is not None \
340                and self._test_data() is not None:
341            self.run()
342           
343            scores = []
344            if self.include_class and not self.include_input_features:
345                original_class = self._test_data().domain.class_var
346                features.append(original_class)
347               
348            if self.include_class:
349                prediction_var = variable.Continuous("Prediction")
350                features.append(prediction_var)
351               
352            if self.include_error:
353                error_var = variable.Continuous("Error")
354                abs_error_var = variable.Continuous("Abs. Error")
355                features.append(error_var)
356                features.append(abs_error_var)
357               
358            for estimates, (selected, method, _, _) in zip(self.results, self.estimators):
359                if estimates is not None and getattr(self, selected):
360                    for estimate in estimates:
361                        name = estimate[0].method_name
362                        name = name_mapper.get(name, name)
363                        var = variable.Continuous(name)
364                        features.append(var)
365                        score_vars.append(var)
366                        all_estimates.append(estimate)
367                   
368            if self.include_input_features:
369                dom = self._test_data().domain
370                attributes = list(dom.attributes) + features
371                domain = Orange.data.Domain(attributes, dom.class_var)
372                domain.add_metas(dom.get_metas())
373               
374                data = Orange.data.Table(domain, self._test_data())
375            else:
376                domain = Orange.data.Domain(features, None)
377                data = Orange.data.Table(domain, [[None] * len(features) for _ in self._test_data()])
378           
379            if self.include_class:
380                for d, inst, pred in zip(data, self._test_data(), self.predictions):
381                    if not self.include_input_features:
382                        d[features[0]] = float(inst.get_class())
383                    d[prediction_var] = float(pred)
384           
385            if self.include_error:
386                for d, inst, pred in zip(data, self._test_data(), self.predictions):
387                    error = float(pred) - float(inst.get_class())
388                    d[error_var] = error
389                    d[abs_error_var] = abs(error)
390                   
391            for estimations, var in zip(all_estimates, score_vars):
392                for d, e in zip(data, estimations):
393                    d[var] = e.estimate
394           
395            table = data
396           
397        self.send("Reliability Scores", table)
398        self.output_changed = False
399       
400       
401if __name__ == "__main__":
402    import sys
403    app = QApplication(sys.argv)
404    w = OWReliability()
405    data = Orange.data.Table("housing")
406    indices = Orange.core.MakeRandomIndices2(p0=20)(data)
407    data = data.select(indices, 0)
408   
409    learner = Orange.regression.tree.TreeLearner()
410    w.set_learner(learner)
411    w.set_train_data(data)
412    w.handleNewSignals()
413    w.show()
414    app.exec_()
415   
416       
Note: See TracBrowser for help on using the repository browser.