source: orange-reliability/orangecontrib/reliability/widgets/OWReliability.py @ 36:2d696344f9aa

Revision 36:2d696344f9aa, 17.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 7 months ago (diff)

Added new style widget meta description.

Line 
1"""
2<name>Reliability</name>
3<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
4<priority>310</priority>
5<icon>icons/Reliability.svg</icon>
6"""
7
8from __future__ import absolute_import
9
10import Orange
11import orangecontrib.reliability as reliability
12from Orange.evaluation import testing
13#from Orange.utils import progress_bar_milestones
14from functools import partial
15 
16from OWWidget import *
17import OWGUI
18
19NAME = "Reliability"
20DESCRIPTION = "Estimate the reliability of learners."
21CATEGORY = "Evaluate"
22PRIORITY = 310
23ICON = "icons/Reliability.svg"
24
25INPUTS = [("Learner", Orange.core.Learner, "set_learner"),
26          ("Training Data", Orange.data.Table, "set_train_data"),
27          ("Test Data", Orange.data.Table, "set_test_data")]
28
29OUTPUTS = [("Reliability Scores", Orange.data.Table)]
30
31REPLACES = ["_reliability.widgets.OWReliability.OWReliability"]
32
33
34class OWReliability(OWWidget):
35    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
36        "local_cv", "local_model_pred_error", "bagging_variance_cn", 
37        "mahalanobis_distance", "var_e", "bias_e", "bagged_m", "local_cv_k",
38        "local_pe_k", "bagged_cn_m", "bagged_cn_k", "mahalanobis_k",
39        "include_error", "include_class", "include_input_features",
40        "auto_commit"]
41   
42    def __init__(self, parent=None, signalManager=None, title="Reliability"):
43        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)
44       
45        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
46                       ("Training Data", Orange.data.Table, self.set_train_data),
47                       ("Test Data", Orange.data.Table, self.set_test_data)]
48       
49        self.outputs = [("Reliability Scores", Orange.data.Table)]
50       
51        self.variance_checked = False
52        self.bias_checked = False
53        self.bagged_variance = False
54        self.local_cv = False
55        self.local_model_pred_error = False
56        self.bagging_variance_cn = False
57        self.mahalanobis_distance = True
58       
59        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
60        self.bias_e =  "0.01, 0.1, 0.5, 1.0, 2.0"
61        self.bagged_m = 10
62        self.local_cv_k = 2
63        self.local_pe_k = 5
64        self.bagged_cn_m = 5
65        self.bagged_cn_k = 1
66        self.mahalanobis_k = 3
67       
68        self.include_error = True
69        self.include_class = True
70        self.include_input_features = False
71        self.auto_commit = False
72       
73        # (selected attr name, getter function, count of returned estimators, indices of estimator results to use)
74        self.estimators = \
75            [("variance_checked", self.get_SAVar, 3, [0]),
76             ("bias_checked", self.get_SABias, 3, [1, 2]),
77             ("bagged_variance", self.get_BAGV, 1, [0]),
78             ("local_cv", self.get_LCV, 1, [0]),
79             ("local_model_pred_error", self.get_CNK, 2, [0, 1]),
80             ("bagging_variance_cn", self.get_BVCK, 4, [0]),
81             ("mahalanobis_distance", self.get_Mahalanobis, 1, [0])]
82       
83        #####
84        # GUI
85        #####
86        self.loadSettings()
87       
88        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
89        self.info_box = OWGUI.widgetLabel(box, "\n\n")
90       
91        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)
92        def method_box(parent, name, value):
93            box = OWGUI.widgetBox(rbox, name, flat=False)
94            box.setCheckable(True)
95            box.setChecked(bool(getattr(self, value)))
96            self.connect(box, SIGNAL("toggled(bool)"),
97                         lambda on: (setattr(self, value, on),
98                                     self.method_selection_changed(value)))
99            return box
100           
101        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)
102        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
103                                  "variance_checked")
104        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:", 
105                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.", 
106                       callback=partial(self.method_param_changed, 0),
107                       validator=e_validator)
108       
109        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
110                                    "bias_checked")
111        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:", 
112                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.", 
113                       callback=partial(self.method_param_changed, 1),
114                       validator=e_validator)
115       
116        bagged_box = method_box(rbox, "Variance of bagged models",
117                                "bagged_variance")
118       
119        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
120                   label="Models:",
121                   tooltip="Number of bagged models to be used with BAGV estimate.",
122                   callback=partial(self.method_param_changed, 2),
123                   keyboardTracking=False)
124       
125        local_cv_box = method_box(rbox, "Local cross validation",
126                                  "local_cv")
127       
128        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
129                   label="Nearest neighbors:",
130                   tooltip="Number of nearest neighbors used in LCV estimate.",
131                   callback=partial(self.method_param_changed, 3),
132                   keyboardTracking=False)
133       
134        local_pe = method_box(rbox, "Local modeling of prediction error",
135                              "local_model_pred_error")
136       
137        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
138                   label="Nearest neighbors:",
139                   tooltip="Number of nearest neighbors used in CNK estimate.",
140                   callback=partial(self.method_param_changed, 4),
141                   keyboardTracking=False)
142       
143        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
144                                 "bagging_variance_cn")
145       
146        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
147                   label="Models:",
148                   tooltip="Number of bagged models to be used with BVCK estimate.",
149                   callback=partial(self.method_param_changed, 5),
150                   keyboardTracking=False)
151       
152        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
153                   label="Nearest neighbors:",
154                   tooltip="Number of nearest neighbors used in BVCK estimate.",
155                   callback=partial(self.method_param_changed, 5),
156                   keyboardTracking=False)
157       
158        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
159                                     "mahalanobis_distance")
160        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
161                   label="Nearest neighbors:",
162                   tooltip="Number of nearest neighbors used in BVCK estimate.",
163                   callback=partial(self.method_param_changed, 6),
164                   keyboardTracking=False)
165       
166        box = OWGUI.widgetBox(self.controlArea, "Output")
167       
168        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
169                       tooltip="Include prediction error in the output",
170                       callback=self.commit_if)
171       
172        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction",
173                       tooltip="Include original class and prediction in the output.",
174                       callback=self.commit_if)
175       
176        OWGUI.checkBox(box, self, "include_input_features", "Include input features",
177                       tooltip="Include features from the input data set.",
178                       callback=self.commit_if)
179       
180        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
181                            callback=self.commit_if)
182       
183        self.commit_button = b = OWGUI.button(box, self, "Commit",
184                                              callback=self.commit,
185                                              autoDefault=True)
186       
187        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)
188       
189        self.commit_button.setEnabled(any([getattr(self, selected) \
190                                for selected, _, _, _ in  self.estimators]))
191       
192        self.learner = None
193        self.train_data = None
194        self.test_data = None
195        self.output_changed = False
196        self.train_data_has_no_class = False
197        self.train_data_has_discrete_class = False
198        self.invalidate_results()
199       
200    def set_train_data(self, data=None):
201        self.error()
202        self.train_data_has_no_class = False
203        self.train_data_has_discrete_class = False
204       
205        if data is not None:
206            if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous):
207                if not data.domain.class_var:
208                    self.train_data_has_no_class = True
209                elif not isinstance(data.domain.class_var,
210                                    Orange.feature.Continuous):
211                    self.train_data_has_discrete_class = True
212                   
213                data = None
214       
215        self.train_data = data
216        self.invalidate_results() 
217       
218    def set_test_data(self, data=None):
219        self.test_data = data
220        self.invalidate_results()
221       
222    def set_learner(self, learner=None):
223        self.learner = learner
224        self.invalidate_results()
225       
226    def handleNewSignals(self):
227        name = "No learner on input"
228        train = "No train data on input"
229        test = "No test data on input"
230       
231        if self.learner:
232            name = "Learner: " + (getattr(self.learner, "name") or type(self.learner).__name__)
233           
234        if self.train_data is not None:
235            train = "Train Data: %i features, %i instances" % \
236                (len(self.train_data.domain), len(self.train_data))
237        elif self.train_data_has_no_class:
238            train = "Train Data has no class variable"
239        elif self.train_data_has_discrete_class:
240            train = "Train Data doesn't have a continuous class"
241           
242        if self.test_data is not None:
243            test = "Test Data: %i features, %i instances" % \
244                (len(self.test_data.domain), len(self.test_data))
245        elif self.train_data:
246            test = "Test data: using training data"
247       
248        self.info_box.setText("\n".join([name, train, test]))
249       
250        if self.learner and self._test_data() is not None:
251            self.commit_if()
252       
253    def invalidate_results(self, which=None):
254        if which is None:
255            self.results = [None for f in self.estimators]
256#            print "Invalidating all"
257        else:
258            for i in which:
259                self.results[i] = None
260#            print "Invalidating", which
261       
262    def run(self):
263        plan = []
264        estimate_index = 0
265        for i, (selected, method, count, offsets) in enumerate(self.estimators):
266            if self.results[i] is None and getattr(self, selected):
267                plan.append((i, method, [estimate_index + offset for offset in offsets]))
268                estimate_index += count
269               
270        estimators = [method() for _, method, _ in plan]
271       
272        if not estimators:
273            return
274           
275        pb = OWGUI.ProgressBar(self, len(self._test_data()))
276        estimates = self.run_estimation(estimators, pb.advance)
277        pb.finish()
278       
279        self.predictions = [v for v, _ in estimates]
280        estimates = [prob.reliability_estimate for _, prob in estimates]
281       
282        for i, (index, method, estimate_indices) in enumerate(plan):
283            self.results[index] = [[e[est_index] for e in estimates] \
284                                   for est_index in estimate_indices]
285       
286    def _test_data(self):
287        if self.test_data is not None:
288            return self.test_data
289        else:
290            return self.train_data
291   
292    def get_estimates(self, estimator, advance=None):
293        test = self._test_data()
294        res = []
295        for i, inst in enumerate(test):
296            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
297            res.append((value, prob))
298            if advance:
299                advance()
300        return res
301               
302    def run_estimation(self, estimators, advance=None):
303        rel = reliability.Learner(self.learner, estimators=estimators)
304        estimator = rel(self.train_data)
305        return self.get_estimates(estimator, advance) 
306   
307    def get_SAVar(self):
308        return reliability.SensitivityAnalysis(e=eval(self.var_e))
309   
310    def get_SABias(self):
311        return reliability.SensitivityAnalysis(e=eval(self.bias_e))
312   
313    def get_BAGV(self):
314        return reliability.BaggingVariance(m=self.bagged_m)
315   
316    def get_LCV(self):
317        return reliability.LocalCrossValidation(k=self.local_cv_k)
318   
319    def get_CNK(self):
320        return reliability.CNeighbours(k=self.local_pe_k)
321   
322    def get_BVCK(self):
323        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
324        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
325        return reliability.BaggingVarianceCNeighbours(bagv, cnk)
326   
327    def get_Mahalanobis(self):
328        return reliability.Mahalanobis(k=self.mahalanobis_k)
329   
330    def method_selection_changed(self, method=None):
331        self.commit_button.setEnabled(any([getattr(self, selected) \
332                                for selected, _, _, _ in  self.estimators]))
333        self.commit_if()
334   
335    def method_param_changed(self, method=None):
336        if method is not None:
337            self.invalidate_results([method])
338        self.commit_if()
339       
340    def commit_if(self):
341        if self.auto_commit:
342            self.commit()
343        else:
344            self.output_changed = True
345           
346    def commit(self):
347        from Orange import feature as variable
348        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
349        all_predictions = []
350        all_estimates = []
351        score_vars = []
352        features = []
353        table = None
354        if self.learner and self.train_data is not None \
355                and self._test_data() is not None:
356            self.run()
357           
358            scores = []
359            if self.include_class and not self.include_input_features:
360                original_class = self._test_data().domain.class_var
361                features.append(original_class)
362               
363            if self.include_class:
364                prediction_var = variable.Continuous("Prediction")
365                features.append(prediction_var)
366               
367            if self.include_error:
368                error_var = variable.Continuous("Error")
369                abs_error_var = variable.Continuous("Abs. Error")
370                features.append(error_var)
371                features.append(abs_error_var)
372               
373            for estimates, (selected, method, _, _) in zip(self.results, self.estimators):
374                if estimates is not None and getattr(self, selected):
375                    for estimate in estimates:
376                        name = estimate[0].method_name
377                        name = name_mapper.get(name, name)
378                        var = variable.Continuous(name)
379                        features.append(var)
380                        score_vars.append(var)
381                        all_estimates.append(estimate)
382                   
383            if self.include_input_features:
384                dom = self._test_data().domain
385                attributes = list(dom.attributes) + features
386                domain = Orange.data.Domain(attributes, dom.class_var)
387                domain.add_metas(dom.get_metas())
388               
389                data = Orange.data.Table(domain, self._test_data())
390            else:
391                domain = Orange.data.Domain(features, None)
392                data = Orange.data.Table(domain, [[None] * len(features) for _ in self._test_data()])
393           
394            if self.include_class:
395                for d, inst, pred in zip(data, self._test_data(), self.predictions):
396                    if not self.include_input_features:
397                        d[features[0]] = float(inst.get_class())
398                    d[prediction_var] = float(pred)
399           
400            if self.include_error:
401                for d, inst, pred in zip(data, self._test_data(), self.predictions):
402                    error = float(pred) - float(inst.get_class())
403                    d[error_var] = error
404                    d[abs_error_var] = abs(error)
405                   
406            for estimations, var in zip(all_estimates, score_vars):
407                for d, e in zip(data, estimations):
408                    d[var] = e.estimate
409           
410            table = data
411           
412        self.send("Reliability Scores", table)
413        self.output_changed = False
414       
415       
416if __name__ == "__main__":
417    import sys
418    app = QApplication(sys.argv)
419    w = OWReliability()
420    data = Orange.data.Table("housing")
421    indices = Orange.core.MakeRandomIndices2(p0=20)(data)
422    data = data.select(indices, 0)
423   
424    learner = Orange.regression.tree.TreeLearner()
425    w.set_learner(learner)
426    w.set_train_data(data)
427    w.handleNewSignals()
428    w.show()
429    app.exec_()
430   
431       
Note: See TracBrowser for help on using the repository browser.