| [9157] | 1 | """ |
|---|
| 2 | <name>Reliability</name> |
|---|
| 3 | <contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact> |
|---|
| 4 | <priority>310</priority> |
|---|
| [9161] | 5 | <icon>icons/Reliability.png</icon> |
|---|
| [9157] | 6 | """ |
|---|
| 7 | |
|---|
| 8 | import Orange |
|---|
| 9 | from Orange.evaluation import reliability |
|---|
| 10 | from Orange.evaluation import testing |
|---|
| 11 | #from Orange.misc import progress_bar_milestones |
|---|
| 12 | from functools import partial |
|---|
| 13 | |
|---|
| 14 | from OWWidget import * |
|---|
| 15 | import OWGUI |
|---|
| 16 | |
|---|
class OWReliability(OWWidget):
    """Widget that computes reliability estimates for regression predictions.

    Inputs are a learner, a training data table and an optional test data
    table (the training data is reused when no test data is given).  The
    single output, "Reliability Scores", is a data table holding one column
    per selected reliability estimate plus, optionally, the original class,
    the prediction, the prediction error and the input features.
    """

    # Names of the attributes persisted by loadSettings()/saveSettings().
    settingsList = ["variance_checked", "bias_checked", "bagged_variance",
                    "local_cv", "local_model_pred_error", "bagging_variance_cn",
                    "mahalanobis_distance", "var_e", "bias_e", "bagged_m", "local_cv_k",
                    "local_pe_k", "bagged_cn_m", "bagged_cn_k", "mahalanobis_k",
                    "include_error", "include_class", "include_input_features",
                    "auto_commit"]

    def __init__(self, parent=None, signalManager=None, title="Reliability"):
        """Set default settings, build the control area and reset all results."""
        OWWidget.__init__(self, parent, signalManager, title, wantMainArea=False)

        self.inputs = [("Learner", Orange.core.Learner, self.set_learner),
                       ("Train Data", Orange.data.Table, self.set_train_data),
                       ("Test Data", Orange.data.Table, self.set_test_data)]

        self.outputs = [("Reliability Scores", Orange.data.Table)]

        # Which estimation methods are enabled.
        self.variance_checked = False
        self.bias_checked = False
        self.bagged_variance = False
        self.local_cv = False
        self.local_model_pred_error = False
        self.bagging_variance_cn = False
        self.mahalanobis_distance = True

        # Parameters of the individual methods.
        self.var_e = "0.01, 0.1, 0.5, 1.0, 2.0"
        self.bias_e = "0.01, 0.1, 0.5, 1.0, 2.0"
        self.bagged_m = 10
        self.local_cv_k = 2
        self.local_pe_k = 5
        self.bagged_cn_m = 5
        self.bagged_cn_k = 1
        self.mahalanobis_k = 3

        # Output options.
        self.include_error = True
        self.include_class = True
        self.include_input_features = False
        self.auto_commit = False

        # (selected attr name, getter function, count of returned estimators, index of estimator)
        self.estimators = \
            [("variance_checked", self.get_SAVar, 3, 0),
             ("bias_checked", self.get_SABias, 3, 1),
             ("bagged_variance", self.get_BAGV, 1, 0),
             ("local_cv", self.get_LCV, 1, 0),
             ("local_model_pred_error", self.get_CNK, 2, 1),
             ("bagging_variance_cn", self.get_BVCK, 4, 0),
             ("mahalanobis_distance", self.get_Mahalanobis, 1, 0)]

        #####
        # GUI
        #####
        self.loadSettings()

        box = OWGUI.widgetBox(self.controlArea, "Info", addSpace=True)
        self.info_box = OWGUI.widgetLabel(box, "\n\n")

        rbox = OWGUI.widgetBox(self.controlArea, "Methods", addSpace=True)

        def method_box(parent, name, value):
            """Create a checkable group box bound to the boolean setting `value`."""
            box = OWGUI.widgetBox(parent, name, flat=False)
            box.setCheckable(True)
            box.setChecked(bool(getattr(self, value)))
            self.connect(box, SIGNAL("toggled(bool)"),
                         lambda on: (setattr(self, value, on),
                                     self.method_selection_changed(value)))
            return box

        # NOTE(review): this pattern appears to demand a decimal point in every
        # number and a comma after each one (including the last), which the
        # default value above does not satisfy -- confirm the intended syntax.
        e_validator = QRegExpValidator(QRegExp(r"\s*(-?[0-9]+(\.[0-9]*)\s*,\s*)+"), self)
        variance_box = method_box(rbox, "Sensitivity analysis (variance)",
                                  "variance_checked")
        OWGUI.lineEdit(variance_box, self, "var_e", "Sensitivities:",
                       tooltip="List of possible e values (comma separated) for SAvar reliability estimates.",
                       callback=partial(self.method_param_changed, 0),
                       validator=e_validator)

        bias_box = method_box(rbox, "Sensitivity analysis (bias)",
                              "bias_checked")
        OWGUI.lineEdit(bias_box, self, "bias_e", "Sensitivities:",
                       tooltip="List of possible e values (comma separated) for SAbias reliability estimates.",
                       callback=partial(self.method_param_changed, 1),
                       validator=e_validator)

        bagged_box = method_box(rbox, "Variance of bagged models",
                                "bagged_variance")

        OWGUI.spin(bagged_box, self, "bagged_m", 2, 100, step=1,
                   label="Models:",
                   tooltip="Number of bagged models to be used with BAGV estimate.",
                   callback=partial(self.method_param_changed, 2),
                   keyboardTracking=False)

        local_cv_box = method_box(rbox, "Local cross validation",
                                  "local_cv")

        OWGUI.spin(local_cv_box, self, "local_cv_k", 2, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in LCV estimate.",
                   callback=partial(self.method_param_changed, 3),
                   keyboardTracking=False)

        local_pe = method_box(rbox, "Local modeling of prediction error",
                              "local_model_pred_error")

        OWGUI.spin(local_pe, self, "local_pe_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in CNK estimate.",
                   callback=partial(self.method_param_changed, 4),
                   keyboardTracking=False)

        bagging_cnn = method_box(rbox, "Bagging variance c-neighbors",
                                 "bagging_variance_cn")

        OWGUI.spin(bagging_cnn, self, "bagged_cn_m", 2, 100, step=1,
                   label="Models:",
                   tooltip="Number of bagged models to be used with BVCK estimate.",
                   callback=partial(self.method_param_changed, 5),
                   keyboardTracking=False)

        OWGUI.spin(bagging_cnn, self, "bagged_cn_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in BVCK estimate.",
                   callback=partial(self.method_param_changed, 5),
                   keyboardTracking=False)

        mahalanobis_box = method_box(rbox, "Mahalanobis distance",
                                     "mahalanobis_distance")
        OWGUI.spin(mahalanobis_box, self, "mahalanobis_k", 1, 20, step=1,
                   label="Nearest neighbors:",
                   tooltip="Number of nearest neighbors used in Mahalanobis estimate.",
                   callback=partial(self.method_param_changed, 6),
                   keyboardTracking=False)

        box = OWGUI.widgetBox(self.controlArea, "Output")

        OWGUI.checkBox(box, self, "include_error", "Include prediction error",
                       tooltip="Include prediction error in the output",
                       callback=self.commit_if)

        OWGUI.checkBox(box, self, "include_class", "Include original class and prediction",
                       tooltip="Include original class and prediction in the output.",
                       callback=self.commit_if)

        OWGUI.checkBox(box, self, "include_input_features", "Include input features",
                       tooltip="Include features from the input data set.",
                       callback=self.commit_if)

        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
                            callback=self.commit_if)

        self.commit_button = b = OWGUI.button(box, self, "Commit",
                                              callback=self.commit,
                                              autoDefault=True)

        OWGUI.setStopper(self, b, cb, "output_changed", callback=self.commit)

        self.commit_button.setEnabled(self._any_method_selected())

        self.learner = None
        self.train_data = None
        self.test_data = None
        self.output_changed = False
        # Predicted values from the most recent estimation run; None until
        # run() has computed something for the current inputs.
        self.predictions = None

        self.invalidate_results()

    def _any_method_selected(self):
        """Return True when at least one estimation method is switched on."""
        return any(getattr(self, selected)
                   for selected, _, _, _ in self.estimators)

    @staticmethod
    def _parse_sensitivities(text):
        """Parse a comma separated string of numbers into a list of floats.

        Empty items (e.g. produced by a trailing comma) are ignored.
        Raises ValueError when an item is not a number or no value is given.
        """
        values = [float(item) for item in text.split(",") if item.strip()]
        if not values:
            raise ValueError("At least one sensitivity value is required.")
        return values

    def set_train_data(self, data=None):
        """Input handler for "Train Data"; data without a continuous class is dropped."""
        self.error()
        if data is not None:
            if not self.isDataWithClass(data, Orange.core.VarTypes.Continuous):
                data = None

        self.train_data = data
        self.invalidate_results()

    def set_test_data(self, data=None):
        """Input handler for "Test Data"."""
        self.test_data = data
        self.invalidate_results()

    def set_learner(self, learner=None):
        """Input handler for "Learner"."""
        self.learner = learner
        self.invalidate_results()

    def handleNewSignals(self):
        """Refresh the info label and commit (if allowed) once inputs are set."""
        name = "No learner on input"
        train = "No train data on input"
        test = "No test data on input"

        if self.learner:
            # A learner without a (non-empty) name falls back to its type name.
            name = "Learner: " + (getattr(self.learner, "name", None)
                                  or type(self.learner).__name__)

        if self.train_data is not None:
            train = "Train Data: %i features, %i instances" % \
                (len(self.train_data.domain), len(self.train_data))

        if self.test_data is not None:
            test = "Test Data: %i features, %i instances" % \
                (len(self.test_data.domain), len(self.test_data))
        elif self.train_data is not None:
            test = "Test data: using training data"

        self.info_box.setText("\n".join([name, train, test]))

        if self.learner and self._test_data() is not None:
            self.commit_if()

    def invalidate_results(self, which=None):
        """Discard cached estimates.

        `which` is an iterable of indices into `self.estimators`; when it is
        None every cached result and the cached predictions are discarded.
        """
        if which is None:
            self.results = [None for _ in self.estimators]
            # Predictions belong to the inputs that were just replaced.
            self.predictions = None
        else:
            for i in which:
                self.results[i] = None

    def run(self):
        """Compute the selected estimates that have no cached result yet."""
        # Each plan entry is (index into self.estimators, estimator factory,
        # position of the wanted value in the combined per-instance estimates).
        plan = []
        estimate_index = 0
        for i, (selected, method, count, offset) in enumerate(self.estimators):
            if self.results[i] is None and getattr(self, selected):
                plan.append((i, method, estimate_index + offset))
                # Only planned estimators are run, so only they take up
                # positions in the combined estimate list.
                estimate_index += count

        if not plan:
            return

        estimators = [method() for _, method, _ in plan]

        pb = OWGUI.ProgressBar(self, len(self._test_data()))
        try:
            estimates = self.run_estimation(estimators, pb.advance)
        finally:
            # Close the progress bar even when the estimation fails.
            pb.finish()

        self.predictions = [v for v, _ in estimates]
        estimates = [prob.reliability_estimate for _, prob in estimates]

        for index, _, estimate_index in plan:
            self.results[index] = [e[estimate_index] for e in estimates]

    def _test_data(self):
        """Return the test data, or the training data when no test data is set."""
        if self.test_data is not None:
            return self.test_data
        else:
            return self.train_data

    def get_estimates(self, estimator, advance=None):
        """Apply `estimator` to every test instance.

        Returns a list of (value, prob) pairs as returned by the estimator
        for result type GetBoth; `advance`, if given, is called once per
        processed instance.
        """
        test = self._test_data()
        res = []
        for inst in test:
            value, prob = estimator(inst, result_type=Orange.core.GetBoth)
            res.append((value, prob))
            if advance:
                advance()
        return res

    def run_estimation(self, estimators, advance=None):
        """Train a reliability learner on the training data and estimate the test data."""
        rel = reliability.Learner(self.learner, estimators=estimators)
        estimator = rel(self.train_data)
        return self.get_estimates(estimator, advance)

    def get_SAVar(self):
        """Return a SensitivityAnalysis estimator using the values in `var_e`."""
        return reliability.SensitivityAnalysis(e=self._parse_sensitivities(self.var_e))

    def get_SABias(self):
        """Return a SensitivityAnalysis estimator using the values in `bias_e`."""
        return reliability.SensitivityAnalysis(e=self._parse_sensitivities(self.bias_e))

    def get_BAGV(self):
        """Return a BaggingVariance estimator with `bagged_m` models."""
        return reliability.BaggingVariance(m=self.bagged_m)

    def get_LCV(self):
        """Return a LocalCrossValidation estimator with `local_cv_k` neighbors."""
        return reliability.LocalCrossValidation(k=self.local_cv_k)

    def get_CNK(self):
        """Return a CNeighbours estimator with `local_pe_k` neighbors."""
        return reliability.CNeighbours(k=self.local_pe_k)

    def get_BVCK(self):
        """Return a BaggingVarianceCNeighbours estimator built from `bagged_cn_m` and `bagged_cn_k`."""
        bagv = reliability.BaggingVariance(m=self.bagged_cn_m)
        cnk = reliability.CNeighbours(k=self.bagged_cn_k)
        return reliability.BaggingVarianceCNeighbours(bagv, cnk)

    def get_Mahalanobis(self):
        """Return a Mahalanobis estimator with `mahalanobis_k` neighbors."""
        return reliability.Mahalanobis(k=self.mahalanobis_k)

    def method_selection_changed(self, method=None):
        """React to a method being switched on or off."""
        self.commit_button.setEnabled(self._any_method_selected())
        self.commit_if()

    def method_param_changed(self, method=None):
        """Drop the cached result of estimator index `method` and commit if allowed."""
        if method is not None:
            self.invalidate_results([method])
        self.commit_if()

    def commit_if(self):
        """Commit immediately when auto commit is on, else mark the output as stale."""
        if self.auto_commit:
            self.commit()
        else:
            self.output_changed = True

    def commit(self):
        """Compute missing estimates, assemble the output table and send it.

        None is sent when the learner, training data or test data is missing.
        """
        from Orange.data import variable
        name_mapper = {"Mahalanobis absolute": "Mahalanobis"}
        all_estimates = []
        score_vars = []
        features = []
        table = None
        if self.learner and self.train_data is not None \
                and self._test_data() is not None:
            self.run()

            test_data = self._test_data()
            # run() computes nothing (and leaves predictions unset) when no
            # method is selected; columns derived from predictions are then
            # left out instead of failing.
            predictions = self.predictions
            have_predictions = predictions is not None

            original_class = None
            prediction_var = None
            error_var = abs_error_var = None

            if self.include_class and not self.include_input_features:
                original_class = test_data.domain.class_var
                features.append(original_class)

            if self.include_class and have_predictions:
                prediction_var = variable.Continuous("Prediction")
                features.append(prediction_var)

            if self.include_error and have_predictions:
                error_var = variable.Continuous("Error")
                abs_error_var = variable.Continuous("Abs. Error")
                features.append(error_var)
                features.append(abs_error_var)

            for estimates, (selected, method, _, _) in zip(self.results, self.estimators):
                # `estimates` is None when not computed and empty when the
                # test data has no instances.
                if estimates and getattr(self, selected):
                    name = estimates[0].method_name
                    name = name_mapper.get(name, name)
                    var = variable.Continuous(name)
                    features.append(var)
                    score_vars.append(var)
                    all_estimates.append(estimates)

            if self.include_input_features:
                dom = test_data.domain
                attributes = list(dom.attributes) + features
                domain = Orange.data.Domain(attributes, dom.class_var)
                domain.add_metas(dom.get_metas())

                data = Orange.data.Table(domain, test_data)
            else:
                domain = Orange.data.Domain(features, None)
                data = Orange.data.Table(domain, [[None] * len(features) for _ in test_data])

            if original_class is not None:
                for d, inst in zip(data, test_data):
                    d[original_class] = float(inst.get_class())

            if prediction_var is not None:
                for d, pred in zip(data, predictions):
                    d[prediction_var] = float(pred)

            if error_var is not None:
                for d, inst, pred in zip(data, test_data, predictions):
                    error = float(pred) - float(inst.get_class())
                    d[error_var] = error
                    d[abs_error_var] = abs(error)

            for estimations, var in zip(all_estimates, score_vars):
                for d, e in zip(data, estimations):
                    d[var] = e.estimate

            table = data

        self.send("Reliability Scores", table)
        self.output_changed = False
| [9157] | 381 | |
|---|
| 382 | |
|---|
if __name__ == "__main__":
    # Manual smoke test: show the widget fed with a sample of the
    # "housing" data set and a regression tree learner.
    import sys

    application = QApplication(sys.argv)
    widget = OWReliability()

    housing = Orange.data.Table("housing")
    split = Orange.core.MakeRandomIndices2(p0=20)(housing)
    sample = housing.select(split, 0)

    widget.set_learner(Orange.regression.tree.TreeLearner())
    widget.set_train_data(sample)
    widget.handleNewSignals()
    widget.show()
    application.exec_()
|---|
| 397 | |
|---|
| 398 | |
|---|