| 1 | """ |
|---|
| 2 | <name>Reliability</name> |
|---|
| 3 | <contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact> |
|---|
| 4 | <priority>310</priority> |
|---|
| 5 | """ |
|---|
| 6 | |
|---|
| 7 | import Orange |
|---|
| 8 | from Orange.evaluation import reliability |
|---|
| 9 | from Orange.evaluation import testing |
|---|
| 10 | #from Orange.misc import progress_bar_milestones |
|---|
| 11 | from functools import partial |
|---|
| 12 | |
|---|
| 13 | from OWWidget import * |
|---|
| 14 | import OWGUI |
|---|
| 15 | |
|---|
class OWReliability(OWWidget):
    """Widget computing reliability estimates for a learner's predictions.

    Inputs (see ``self.inputs``): a learner, training data and optional
    test data.  When no test data is given the training data is scored.
    Output (see ``self.outputs``): a table named "Reliability Scores" in
    which every selected estimation method contributes one meta attribute,
    optionally accompanied by the prediction, the prediction error and the
    input features.
    """

    # Names of the attributes persisted between sessions.
    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
                    "local_cv", "local_model_pred_error",
                    "bagging_variance_cn", "mahalanobis_distance",
                    "var_e", "bias_e", "bagged_m", "local_cv_k",
                    "local_pe_k", "bagged_cn_m", "bagged_cn_k",
                    "mahalanobis_k", "include_error", "include_class",
                    "include_input_features", "auto_commit"]

    def __init__(self, parent=None, signalManager=None, title="Reliability"):
        """Set up the default settings, the signals and the control area."""
        OWWidget.__init__(self, parent, signalManager, title,
                          wantMainArea=False)

        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
                       ("Train Data", Orange.data.Table, self.set_train_data),
                       ("Test Data", Orange.data.Table, self.set_test_data)]

        self.outputs = [("Reliability Scores", Orange.data.Table)]

        # Which estimation methods are enabled.
        self.variance_checked = False
        self.bias_checked = False
        self.bagged_variance = False
        self.local_cv = False
        self.local_model_pred_error = False
        self.bagging_variance_cn = False
        self.mahalanobis_distance = True

        # Parameters of the individual methods.
        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
        self.bias_e = "0.01, 0.1, 0.5, 1.0, 2.0"
        self.bagged_m = 50
        self.local_cv_k = 2
        self.local_pe_k = 5
        self.bagged_cn_m = 5
        self.bagged_cn_k = 1
        self.mahalanobis_k = 5

        # Output options.
        self.include_error = True
        self.include_class = True
        self.include_input_features = True
        self.auto_commit = False

        # (setting name, runner) pairs.  The position of a pair is the index
        # used by ``self.results`` and by ``method_param_changed``.
        self.methods = [("variance_checked", self.run_SAVar),
                        ("bias_checked", self.run_SABias),
                        ("bagged_variance", self.run_BAGV),
                        ("local_cv", self.run_LCV),
                        ("local_model_pred_error", self.run_CNK),
                        ("bagging_variance_cn", self.run_BVCK),
                        ("mahalanobis_distance", self.run_Mahalanobis)]

        #####
        # GUI
        #####
        self.loadSettings()

        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
        self.info_box = OWGUI.widgetLabel(box, "\n\n")

        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)

        def method_box(parent, name, value):
            """Create a checkable group box bound to the setting `value`."""
            box = OWGUI.widgetBox(parent, name, flat=False)
            box.setCheckable(True)
            box.setChecked(bool(getattr(self, value)))
            # The lambda first stores the new state in the setting and then
            # notifies the widget about the changed selection.
            self.connect(box, SIGNAL("toggled(bool)"),
                         lambda on: (setattr(self, value, on),
                                     self.method_selection_changed(value)))
            return box

        # NOTE(review): this pattern requires every number to contain a
        # decimal point and to be followed by a comma, so the default value
        # (which has no trailing comma) is not a complete match -- confirm
        # that this is the intended validation.
        e_validator = QRegExpValidator(
            QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)

        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
                                  "variance_checked")
        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:",
                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.",
                       callback=partial(self.method_param_changed, 0),
                       validator=e_validator)

        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
                              "bias_checked")
        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:",
                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.",
                       callback=partial(self.method_param_changed, 1),
                       validator=e_validator)

        bagged_box = method_box(rbox, "Variance of bagged models",
                                "bagged_variance")
        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
                   label="Models:",
                   tooltip="Number of bagged models to be used with BAGV estimate.",
                   callback=partial(self.method_param_changed, 2),
                   keyboardTracking=False)

        local_cv_box = method_box(rbox, "Local cross validation",
                                  "local_cv")
        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in LCV estimate.",
                   callback=partial(self.method_param_changed, 3),
                   keyboardTracking=False)

        local_pe = method_box(rbox, "Local modeling of prediction error",
                              "local_model_pred_error")
        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in CNK estimate.",
                   callback=partial(self.method_param_changed, 4),
                   keyboardTracking=False)

        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
                                 "bagging_variance_cn")
        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
                   label="Models:",
                   tooltip="Number of bagged models to be used with BVCK estimate.",
                   callback=partial(self.method_param_changed, 5),
                   keyboardTracking=False)
        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in BVCK estimate.",
                   callback=partial(self.method_param_changed, 5),
                   keyboardTracking=False)

        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
                                     "mahalanobis_distance")
        # The tooltip used to name the BVCK estimate (copy-paste slip).
        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in Mahalanobis estimate.",
                   callback=partial(self.method_param_changed, 6),
                   keyboardTracking=False)

        box = OWGUI.widgetBox(self.controlArea, "Output")

        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
                       tooltip="Include prediction error in the output",
                       callback=self.commit_if)

        OWGUI.checkBox(box, self, "include_class",
                       "Include original class and prediction",
                       tooltip="Include original class and prediction in the output.",
                       callback=self.commit_if)

        OWGUI.checkBox(box, self, "include_input_features",
                       "Include input features",
                       tooltip="Include features from the input data set.",
                       callback=self.commit_if)

        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
                            callback=self.commit_if)

        self.commit_button = b = OWGUI.button(box, self, "Commit",
                                              callback=self.commit,
                                              autoDefault=True)

        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)

        self.commit_button.setEnabled(self._any_method_selected())

        self.learner = None
        self.train_data = None
        self.test_data = None
        self.output_changed = False

        self.invalidate_results()

    def _any_method_selected(self):
        """Return True when at least one estimation method is enabled."""
        return any(getattr(self, selected) for selected, _ in self.methods)

    def set_train_data(self, data=None):
        """Input handler for "Train Data"; drops all cached results."""
        self.train_data = data
        self.invalidate_results()

    def set_test_data(self, data=None):
        """Input handler for "Test Data"; drops all cached results."""
        self.test_data = data
        self.invalidate_results()

    def set_learner(self, learner=None):
        """Input handler for "Learner"; drops all cached results."""
        self.learner = learner
        self.invalidate_results()

    def handleNewSignals(self):
        """Refresh the info label and commit if auto commit is enabled."""
        name = test = train = ""
        if self.learner is not None:
            # Learners without a (non-empty) `name` are shown by type name.
            name = getattr(self.learner, "name", None) or \
                type(self.learner).__name__

        if self.train_data is not None:
            train = "Train Data: %i features, %i instances" % \
                (len(self.train_data.domain), len(self.train_data))

        if self.test_data is not None:
            test = "Test Data: %i features, %i instances" % \
                (len(self.test_data.domain), len(self.test_data))
        else:
            test = "Test data: using training data"

        self.info_box.setText("\n".join([name, train, test]))

        if self.learner is not None and self._test_data() is not None:
            self.commit_if()

    def invalidate_results(self, which=None):
        """Forget cached results.

        `which` is an iterable of indices into ``self.methods``; when it is
        None the results of all methods are dropped.
        """
        if which is None:
            self.results = [None] * len(self.methods)
        else:
            for index in which:
                self.results[index] = None

    def run(self):
        """Compute the results of every selected method that has none cached.

        Nothing is computed while the learner or the training data is
        missing, or when there is nothing to (re)compute.
        """
        if self.learner is None or self.train_data is None:
            return

        plan = [(index, method)
                for index, (selected, method) in enumerate(self.methods)
                if self.results[index] is None and getattr(self, selected)]
        if not plan:
            return

        # Every method advances the bar once per scored instance.
        progress = OWGUI.ProgressBar(self, len(plan) * len(self._test_data()))
        try:
            for index, method in plan:
                self.results[index] = method(progress.advance)
        finally:
            # Finish the progress bar even when an estimator raises.
            progress.finish()

    def _test_data(self):
        """Return the data to score: the test data, else the training data."""
        if self.test_data is not None:
            return self.test_data
        return self.train_data

    def get_estimates(self, estimator, advance=None):
        """Apply `estimator` to every instance of the data being scored.

        Returns a list of ``(value, prob)`` pairs as returned by the
        estimator; `advance`, when given, is called after each instance.
        """
        res = []
        for inst in self._test_data():
            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
            res.append((value, prob))
            if advance:
                advance()
        return res

    def run_estimation(self, method, advance=None):
        """Train a reliability learner using `method` and score the data."""
        rel = reliability.Learner(self.learner, estimators=[method])
        estimator = rel(self.train_data)
        return self.get_estimates(estimator, advance)

    @staticmethod
    def _parse_sensitivities(text):
        """Parse a comma separated list of numbers into a list of floats.

        Empty items (e.g. the one after a trailing comma) are ignored.
        Raises ValueError when an item is not a number or no number is
        given.  This replaces a former ``eval`` of the user editable text,
        which executed arbitrary expressions and produced a bare float
        (not a sequence) when a single value was entered.
        """
        values = [float(item) for item in text.split(",") if item.strip()]
        if not values:
            raise ValueError("No sensitivity values given: %r" % (text,))
        return values

    def run_SAVar(self, advance=None):
        """Run sensitivity analysis with the sensitivities from `var_e`."""
        est = reliability.SensitivityAnalysis(
            e=self._parse_sensitivities(self.var_e))
        return self.run_estimation(est, advance)

    def run_SABias(self, advance=None):
        """Run sensitivity analysis with the sensitivities from `bias_e`."""
        est = reliability.SensitivityAnalysis(
            e=self._parse_sensitivities(self.bias_e))
        return self.run_estimation(est, advance)

    def run_BAGV(self, advance=None):
        """Run the bagging variance estimate with `bagged_m` models."""
        est = reliability.BaggingVariance(m=self.bagged_m)
        return self.run_estimation(est, advance)

    def run_LCV(self, advance=None):
        """Run local cross validation with `local_cv_k` neighbors."""
        est = reliability.LocalCrossValidation(k=self.local_cv_k)
        return self.run_estimation(est, advance)

    def run_CNK(self, advance=None):
        """Run the CNeighbours estimate with `local_pe_k` neighbors."""
        est = reliability.CNeighbours(k=self.local_pe_k)
        return self.run_estimation(est, advance)

    def run_BVCK(self, advance=None):
        """Run the combined bagging variance / c-neighbors estimate."""
        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
        est = reliability.BaggingVarianceCNeighbours(bagv, cnk)
        return self.run_estimation(est, advance)

    def run_Mahalanobis(self, advance=None):
        """Run the Mahalanobis estimate with `mahalanobis_k` neighbors."""
        est = reliability.Mahalanobis(k=self.mahalanobis_k)
        return self.run_estimation(est, advance)

    def method_selection_changed(self, method=None):
        """React to a method being checked or unchecked."""
        self.commit_button.setEnabled(self._any_method_selected())
        self.commit_if()

    def method_param_changed(self, method=None):
        """Drop the cached result of the method at index `method`."""
        if method is not None:
            self.invalidate_results([method])
        self.commit_if()

    def commit_if(self):
        """Commit when auto commit is on, else mark the output as pending."""
        if self.auto_commit:
            self.commit()
        else:
            self.output_changed = True

    def _collect_scores(self):
        """Gather the cached results of the selected methods.

        Returns ``(score_vars, all_predictions, all_estimates)``: three
        parallel lists holding, per method, the new continuous variable,
        the predicted values and the reliability estimates.
        """
        from Orange.data import variable

        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
        score_vars, all_predictions, all_estimates = [], [], []
        for res, (selected, _method) in zip(self.results, self.methods):
            # Skip methods that are unselected, not computed, or that
            # produced nothing (no instances to score).
            if not res or not getattr(self, selected):
                continue
            # The bias variant reads the second estimate of each instance,
            # all other methods the first one -- presumably sensitivity
            # analysis reports the variance estimate first; verify against
            # the reliability module.
            ei = 1 if selected == "bias_checked" else 0
            values = [value for value, _probs in res]
            estimates = [probs.reliability_estimate[ei]
                         for _value, probs in res]
            name = estimates[0].method_name
            score_vars.append(variable.Continuous(name_mapper.get(name, name)))
            all_predictions.append(values)
            all_estimates.append(estimates)
        return score_vars, all_predictions, all_estimates

    def commit(self):
        """Compute pending results and send the "Reliability Scores" table.

        None is sent when there is no data to score.  The prediction and
        error columns are only added when at least one method produced
        predictions.
        """
        from Orange.data import variable

        self.run()
        table = None
        test_data = self._test_data()
        if test_data is not None:
            score_vars, all_predictions, all_estimates = \
                self._collect_scores()
            # Predictions are taken from the first scored method.
            predictions = all_predictions[0] if all_predictions else None

            features = []
            # The class is only added as a meta attribute when it is not
            # already part of the copied input features.
            class_as_meta = self.include_class and \
                not self.include_input_features
            if class_as_meta:
                original_class = test_data.domain.class_var
                features.append(original_class)

            prediction_var = error_var = abs_error_var = None
            if predictions is not None:
                if self.include_class:
                    prediction_var = variable.Continuous("Prediction")
                    features.append(prediction_var)
                if self.include_error:
                    error_var = variable.Continuous("Error")
                    abs_error_var = variable.Continuous("Abs Error")
                    features.append(error_var)
                    features.append(abs_error_var)
            features.extend(score_vars)

            if self.include_input_features:
                dom = test_data.domain
                domain = Orange.data.Domain(dom.attributes, dom.class_var)
                domain.add_metas(dom.get_metas())
                data = Orange.data.Table(domain, test_data)
            else:
                domain = Orange.data.Domain([])
                data = Orange.data.Table(domain, [[] for _ in test_data])

            for feature in features:
                data.domain.add_meta(Orange.core.newmetaid(), feature)

            if class_as_meta:
                for d, inst in zip(data, test_data):
                    d[original_class] = float(inst.get_class())

            if predictions is not None:
                for d, inst, pred in zip(data, test_data, predictions):
                    if prediction_var is not None:
                        d[prediction_var] = float(pred)
                    if error_var is not None:
                        error = float(pred) - float(inst.get_class())
                        d[error_var] = error
                        d[abs_error_var] = abs(error)

            for estimations, var in zip(all_estimates, score_vars):
                for d, e in zip(data, estimations):
                    d[var] = e.estimate

            table = data

        self.send("Reliability Scores", table)
        # The output is up to date now; `commit_if` raises the flag again on
        # the next change.
        self.output_changed = False
|---|
| 378 | |
|---|
| 379 | |
|---|
if __name__ == "__main__":
    # Manual test: show the widget with a tree learner and a sample of the
    # "housing" data set as training data.
    import sys
    application = QApplication(sys.argv)
    widget = OWReliability()

    housing = Orange.data.Table("housing")
    sampler = Orange.core.MakeRandomIndices2(p0=20)
    # Keep only the instances assigned to group 0 by the sampler.
    sample = housing.select(sampler(housing), 0)

    widget.set_learner(Orange.regression.tree.TreeLearner())
    widget.set_train_data(sample)
    widget.handleNewSignals()
    widget.show()
    application.exec_()
|---|
| 394 | |
|---|
| 395 | |
|---|