source: orange/Orange/OrangeWidgets/Data/OWRank.py @ 11684:10858549f936

Revision 11684:10858549f936, 34.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Refactor Rank widget. Allow adding extra score functions through entry points.

Line 
1"""
2<name>Rank</name>
3<description>Ranks and filters attributes by their relevance.</description>
4<icon>icons/Rank.svg</icon>
5<contact>Janez Demsar (janez.demsar(@at@)fri.uni-lj.si)</contact>
6<priority>1102</priority>
7"""
8
9from collections import namedtuple
10from functools import partial
11
12import pkg_resources
13
14from OWWidget import *
15
16import OWGUI
17import orange
18
19from Orange.regression.earth import ScoreEarthImportance
20from orngSVM import MeasureAttribute_SVMWeights
21from orngEnsemble import MeasureAttribute_randomForests
22
23
24def _toPyObject(variant):
25    val = variant.toPyObject()
26    if isinstance(val, type(NotImplemented)):
27        # PyQt 4.4 converts python int, floats ... to C types and
28        # cannot convert them back again and returns an exception instance.
29        qtype = variant.type()
30        if qtype == QVariant.Double:
31            val, ok = variant.toDouble()
32        elif qtype == QVariant.Int:
33            val, ok = variant.toInt()
34        elif qtype == QVariant.LongLong:
35            val, ok = variant.toLongLong()
36        elif qtype == QVariant.String:
37            val = variant.toString()
38    return val
39
40def is_class_discrete(data):
41    return isinstance(data.domain.classVar, orange.EnumVariable)
42
43def is_class_continuous(data):
44    return isinstance(data.domain.classVar, orange.FloatVariable)
45
46def table(shape, fill=None):
47    """ Return a 2D table with shape filed with ``fill``
48    """
49    return [[fill for j in range(shape[1])] for i in range(shape[0])]
50
51
52MEASURE_PARAMS = {
53    ScoreEarthImportance: [
54        {"name": "t",
55         "type": int,
56         "display_name": "Num. models.",
57         "range": (1, 20),
58         "default": 10,
59         "doc": "Number of models to train for feature scoring."},
60        {"name": "terms",
61         "type": int,
62         "display_name": "Max. num of terms",
63         "range": (3, 200),
64         "default": 10,
65         "doc": "Maximum number of terms in the forward pass"},
66        {"name": "degree",
67         "type": int,
68         "display_name": "Max. term degree",
69         "range": (1, 3),
70         "default": 2,
71         "doc": "Maximum degree of terms included in the model."}
72    ],
73    orange.MeasureAttribute_relief: [
74        {"name": "k",
75         "type": int,
76         "display_name": "Neighbours",
77         "range": (1, 20),
78         "default": 10,
79         "doc": "Number of neighbors to consider."},
80        {"name":"m",
81         "type": int,
82         "display_name": "Examples",
83         "range": (20, 100),
84         "default": 20,
85         "doc": ""}
86        ],
87    MeasureAttribute_randomForests: [
88        {"name": "trees",
89         "type": int,
90         "display_name": "Num. of trees",
91         "range": (20, 100),
92         "default": 100,
93         "doc": "Number of trees in the random forest."}
94        ]
95    }
96
97
98_score_meta = namedtuple(
99    "_score_meta",
100    ["name",
101     "shortname",
102     "score",
103     "params",
104     "supports_regression",
105     "supports_classification",
106     "handles_discrete",
107     "handles_continuous"]
108)
109
110
111class score_meta(_score_meta):
112    # Add sensible defaults to __new__
113    def __new__(cls, name, shortname, score, params=None,
114                supports_regression=True, supports_classification=True,
115                handles_continuous=True, handles_discrete=True):
116        return _score_meta.__new__(
117            cls, name, shortname, score, params,
118            supports_regression, supports_classification,
119            handles_discrete, handles_continuous
120        )
121
122
123# Default scores.
124SCORES = [
125    score_meta(
126        "ReliefF", "ReliefF", orange.MeasureAttribute_relief,
127        params=MEASURE_PARAMS[orange.MeasureAttribute_relief],
128        handles_continuous=True,
129        handles_discrete=True),
130    score_meta(
131        "Information Gain", "Inf. gain", orange.MeasureAttribute_info,
132        params=None,
133        supports_regression=False,
134        supports_classification=True,
135        handles_continuous=False,
136        handles_discrete=True),
137    score_meta(
138        "Gain Ratio", "Gain Ratio", orange.MeasureAttribute_gainRatio,
139        params=None,
140        supports_regression=False,
141        handles_continuous=False,
142        handles_discrete=True),
143    score_meta(
144        "Gini Gain", "Gini", orange.MeasureAttribute_gini,
145        params=None,
146        supports_regression=False,
147        supports_classification=True,
148        handles_continuous=False),
149    score_meta(
150        "Log Odds Ratio", "log OR", orange.MeasureAttribute_logOddsRatio,
151        params=None,
152        supports_regression=False,
153        handles_continuous=False),
154    score_meta(
155        "MSE", "MSE", orange.MeasureAttribute_MSE,
156        params=None,
157        supports_classification=False,
158        handles_continuous=False),
159    score_meta(
160        "Linear SVM Weights", "SVM weight", MeasureAttribute_SVMWeights,
161        params=None),
162    score_meta(
163        "Random Forests", "RF", MeasureAttribute_randomForests,
164        params=MEASURE_PARAMS[MeasureAttribute_randomForests]),
165    score_meta(
166        "Earth Importance", "Earth imp.", ScoreEarthImportance,
167        params=MEASURE_PARAMS[ScoreEarthImportance],
168    )
169]
170
171_DEFAULT_SELECTED = set(m.name for m in SCORES[:6])
172
173
174class MethodParameter(object):
175    def __init__(self, name="", type=None, display_name="Parameter",
176                 range=None, default=None, doc=""):
177        self.name = name
178        self.type = type
179        self.display_name = display_name
180        self.range = range
181        self.default = default
182        self.doc = doc
183
184
185def measure_parameters(measure):
186    return [MethodParameter(**args) for args in (measure.params or [])]
187
188
189def param_attr_name(measure, param):
190    """Name of the OWRank widget's member where the parameter is stored.
191    """
192    return "param_" + measure.__name__ + "_" + param.name
193
194
195def drop_exceptions(iterable, exceptions=(Exception,)):
196    iterable = iter(iterable)
197    while True:
198        try:
199            yield next(iterable)
200        except StopIteration:
201            raise
202        except BaseException as ex:
203            if not isinstance(ex, exceptions):
204                raise
205
206
207def load_ep_drop_exceptions(entry_point):
208    for ep in pkg_resources.iter_entry_points(entry_point):
209        try:
210            yield ep.load()
211        except Exception:
212            log = logging.getLogger(__name__)
213            log.debug("", exc_info=True)
214
215
216def all_measures():
217    iter_ep = load_ep_drop_exceptions("orange.widgets.feature_score")
218    scores = [m for m in iter_ep if isinstance(m, score_meta)]
219    return SCORES + scores
220
221
222class OWRank(OWWidget):
223    settingsList = [
224        "nDecimals", "nIntervals", "sortBy", "nSelected",
225        "selectMethod", "autoApply", "showDistributions",
226        "distColorRgb"
227    ]
228
229    def __init__(self, parent=None, signalManager=None):
230        OWWidget.__init__(self, parent, signalManager, "Rank")
231
232        self.inputs = [("Data", ExampleTable, self.setData)]
233        self.outputs = [("Reduced Data", ExampleTable, Default + Single)]
234
235        self.nDecimals = 3
236        self.nIntervals = 4
237        self.sortBy = 2
238        self.selectMethod = 2
239        self.nSelected = 5
240        self.autoApply = True
241        self.showDistributions = 1
242        self.distColorRgb = (220, 220, 220, 255)
243        self.distColor = QColor(*self.distColorRgb)
244
245        self.all_measures = all_measures()
246
247        self.selectedMeasures = dict(
248            [(name, True) for name in _DEFAULT_SELECTED] +
249            [(m.name, False)
250             for m in self.all_measures[len(_DEFAULT_SELECTED):]]
251        )
252
253        self.data = None
254
255        self.methodParamAttrs = []
256        for m in self.all_measures:
257            params = measure_parameters(m)
258            for p in params:
259                name_mangled = param_attr_name(m.score, p)
260                setattr(self, name_mangled, p.default)
261                self.methodParamAttrs.append(name_mangled)
262
263        self.settingsList = self.settingsList + self.methodParamAttrs
264
265        self.loadSettings()
266
267        self.discMeasures = [m for m in self.all_measures
268                             if m.supports_classification]
269        self.contMeasures = [m for m in self.all_measures
270                             if m.supports_regression]
271
272        self.stackedLayout = QStackedLayout()
273        self.stackedLayout.setContentsMargins(0, 0, 0, 0)
274        self.stackedWidget = OWGUI.widgetBox(self.controlArea, margin=0,
275                                             orientation=self.stackedLayout,
276                                             addSpace=True)
277
278        # Discrete class scoring
279        discreteBox = OWGUI.widgetBox(self.stackedWidget, "Scoring",
280                                      addSpace=False,
281                                      addToLayout=False)
282        self.stackedLayout.addWidget(discreteBox)
283
284        # Continuous class scoring
285        continuousBox = OWGUI.widgetBox(self.stackedWidget, "Scoring",
286                                        addSpace=False,
287                                        addToLayout=False)
288        self.stackedLayout.addWidget(continuousBox)
289
290        def measure_control(container, measure):
291            """Construct UI control for `measure` (measure_meta instance).
292            """
293            name = measure.name
294            params = measure_parameters(measure)
295            if params:
296                hbox = OWGUI.widgetBox(container, orientation="horizontal")
297                OWGUI.checkBox(hbox, self.selectedMeasures, name, name,
298                               callback=partial(self.measuresSelectionChanged,
299                                                measure),
300                               tooltip="Enable " + name)
301
302                smallWidget = OWGUI.SmallWidgetLabel(
303                    hbox, pixmap=1, box=name + " Parameters",
304                    tooltip="Show " + name + "Parameters")
305
306                for param in params:
307                    OWGUI.spin(smallWidget.widget, self,
308                               param_attr_name(measure.score, param),
309                               param.range[0], param.range[-1],
310                               label=param.display_name,
311                               tooltip=param.doc,
312                               callback=partial(
313                                    self.measureParamChanged, measure, param),
314                               callbackOnReturn=True)
315
316                OWGUI.button(smallWidget.widget, self, "Load defaults",
317                             callback=partial(self.loadMeasureDefaults,
318                                              measure))
319            else:
320                OWGUI.checkBox(container, self.selectedMeasures, name, name,
321                               callback=partial(self.measuresSelectionChanged,
322                                                measure),
323                               tooltip="Enable " + name)
324
325        for measure in self.all_measures:
326            if measure.supports_classification:
327                measure_control(discreteBox, measure)
328
329            if measure.supports_regression:
330                measure_control(continuousBox, measure)
331
332        OWGUI.comboBox(discreteBox, self, "sortBy", label = "Sort by"+"  ",
333                       items = ["No Sorting", "Attribute Name", "Number of Values"] + \
334                               [m.name for m in self.discMeasures],
335                       orientation=0, valueType = int,
336                       callback=self.sortingChanged)
337       
338        OWGUI.comboBox(continuousBox, self, "sortBy", label = "Sort by"+"  ",
339                       items = ["No Sorting", "Attribute Name", "Number of Values"] + \
340                               [m.name for m in self.contMeasures],
341                       orientation=0, valueType = int,
342                       callback=self.sortingChanged)
343
344        box = OWGUI.widgetBox(self.controlArea, "Discretization",
345                              addSpace=True)
346        OWGUI.spin(box, self, "nIntervals", 2, 20,
347                   label="Intervals: ",
348                   orientation=0,
349                   tooltip="Disctetization for measures which cannot score continuous attributes.",
350                   callback=self.discretizationChanged,
351                   callbackOnReturn=True)
352
353        box = OWGUI.widgetBox(self.controlArea, "Precision", addSpace=True)
354        OWGUI.spin(box, self, "nDecimals", 1, 6, label="No. of decimals: ",
355                   orientation=0, callback=self.decimalsChanged)
356
357        box = OWGUI.widgetBox(self.controlArea, "Score bars",
358                              orientation="horizontal", addSpace=True)
359        self.cbShowDistributions = OWGUI.checkBox(box, self, "showDistributions",
360                                    'Enable', callback = self.cbShowDistributions)
361#        colBox = OWGUI.indentedBox(box, orientation = "horizontal")
362        OWGUI.rubber(box)
363        box = OWGUI.widgetBox(box, orientation="horizontal")
364        wl = OWGUI.widgetLabel(box, "Color: ")
365        OWGUI.separator(box)
366        self.colButton = OWGUI.toolButton(box, self, callback=self.changeColor, width=20, height=20, debuggingEnabled = 0)
367        self.cbShowDistributions.disables.extend([wl, self.colButton])
368        self.cbShowDistributions.makeConsistent()
369#        OWGUI.rubber(box)
370
371       
372        selMethBox = OWGUI.widgetBox(self.controlArea, "Select attributes", addSpace=True)
373        self.clearButton = OWGUI.button(selMethBox, self, "Clear", callback=self.clearSelection)
374        self.clearButton.setDisabled(True)
375       
376        buttonGrid = QGridLayout()
377        selMethRadio = OWGUI.radioButtonsInBox(selMethBox, self, "selectMethod", [], callback=self.selectMethodChanged)
378        b1 = OWGUI.appendRadioButton(selMethRadio, self, "selectMethod", "All", insertInto=selMethRadio, callback=self.selectMethodChanged, addToLayout=False)
379        b2 = OWGUI.appendRadioButton(selMethRadio, self, "selectMethod", "Manual", insertInto=selMethRadio, callback=self.selectMethodChanged, addToLayout=False)
380        b3 = OWGUI.appendRadioButton(selMethRadio, self, "selectMethod", "Best ranked", insertInto=selMethRadio, callback=self.selectMethodChanged, addToLayout=False)
381#        brBox = OWGUI.widgetBox(selMethBox, orientation="horizontal", margin=0)
382#        OWGUI.appendRadioButton(selMethRadio, self, "selectMethod", "Best ranked", insertInto=brBox, callback=self.selectMethodChanged)
383        spin = OWGUI.spin(OWGUI.widgetBox(selMethRadio, addToLayout=False), self, "nSelected", 1, 100, orientation=0, callback=self.nSelectedChanged)
384        buttonGrid.addWidget(b1, 0, 0)
385        buttonGrid.addWidget(b2, 1, 0)
386        buttonGrid.addWidget(b3, 2, 0)
387        buttonGrid.addWidget(spin, 2, 1)
388        selMethRadio.layout().addLayout(buttonGrid)
389        OWGUI.separator(selMethBox)
390
391        applyButton = OWGUI.button(selMethBox, self, "Commit", callback = self.apply, default=True)
392        autoApplyCB = OWGUI.checkBox(selMethBox, self, "autoApply", "Commit automatically")
393        OWGUI.setStopper(self, applyButton, autoApplyCB, "dataChanged", self.apply)
394
395        OWGUI.rubber(self.controlArea)
396       
397        # Discrete and continuous table views are stacked
398        self.ranksViewStack = QStackedLayout()
399        self.mainArea.layout().addLayout(self.ranksViewStack)
400       
401        self.discRanksView = QTableView()
402        self.ranksViewStack.addWidget(self.discRanksView)
403        self.discRanksView.setSelectionBehavior(QTableView.SelectRows)
404        self.discRanksView.setSelectionMode(QTableView.MultiSelection)
405        self.discRanksView.setSortingEnabled(True)
406#        self.discRanksView.horizontalHeader().restoreState(self.discRanksHeaderState)
407       
408        self.discRanksModel = QStandardItemModel(self)
409        self.discRanksModel.setHorizontalHeaderLabels(
410            ["Attribute", "#"] + [m.shortname for m in self.discMeasures]
411        )
412        self.discRanksProxyModel = MySortProxyModel(self)
413        self.discRanksProxyModel.setSourceModel(self.discRanksModel)
414        self.discRanksView.setModel(self.discRanksProxyModel)
415#        self.discRanksView.verticalHeader().setResizeMode(QHeaderView.ResizeToContents)
416        self.discRanksView.setColumnWidth(1, 20)
417        self.discRanksView.sortByColumn(2, Qt.DescendingOrder)
418        self.connect(self.discRanksView.selectionModel(),
419                     SIGNAL("selectionChanged(QItemSelection, QItemSelection)"),
420                     self.onSelectionChanged)
421        self.connect(self.discRanksView,
422                     SIGNAL("pressed(const QModelIndex &)"),
423                     self.onSelectItem)
424        self.connect(self.discRanksView.horizontalHeader(),
425                     SIGNAL("sectionClicked(int)"),
426                     self.headerClick)
427       
428        self.contRanksView = QTableView()
429        self.ranksViewStack.addWidget(self.contRanksView)
430        self.contRanksView.setSelectionBehavior(QTableView.SelectRows)
431        self.contRanksView.setSelectionMode(QTableView.MultiSelection)
432        self.contRanksView.setSortingEnabled(True)
433#        self.contRanksView.setItemDelegate(OWGUI.ColoredBarItemDelegate())
434#        self.contRanksView.horizontalHeader().restoreState(self.contRanksHeaderState)
435       
436        self.contRanksModel = QStandardItemModel(self)
437        self.contRanksModel.setHorizontalHeaderLabels(
438            ["Attribute", "#"] + [m.shortname for m in self.contMeasures]
439        )
440        self.contRanksProxyModel = MySortProxyModel(self)
441        self.contRanksProxyModel.setSourceModel(self.contRanksModel)
442        self.contRanksView.setModel(self.contRanksProxyModel)
443#        self.contRanksView.verticalHeader().setResizeMode(QHeaderView.ResizeToContents)
444        self.discRanksView.setColumnWidth(1, 20)
445        self.contRanksView.sortByColumn(2, Qt.DescendingOrder)
446        self.connect(self.contRanksView.selectionModel(),
447                     SIGNAL("selectionChanged(QItemSelection, QItemSelection)"),
448                     self.onSelectionChanged)
449        self.connect(self.contRanksView,
450                     SIGNAL("pressed(const QModelIndex &)"),
451                     self.onSelectItem)
452        self.connect(self.contRanksView.horizontalHeader(),
453                     SIGNAL("sectionClicked(int)"),
454                     self.headerClick)
455       
456        # Switch the current view to Discrete
457        self.switchRanksMode(0)
458        self.resetInternals()
459        self.updateDelegates()
460        self.updateVisibleScoreColumns()
461
462#        self.connect(self.table.horizontalHeader(), SIGNAL("sectionClicked(int)"), self.headerClick)
463       
464        self.resize(690,500)
465        self.updateColor()
466       
467        self.measure_scores = table((len(self.measures), 0), None)
468
469    def switchRanksMode(self, index):
470        """ Switch between discrete/continuous mode
471        """
472        self.ranksViewStack.setCurrentIndex(index)
473        self.stackedLayout.setCurrentIndex(index)
474
475        if index == 0:
476            self.ranksView = self.discRanksView
477            self.ranksModel = self.discRanksModel
478            self.ranksProxyModel = self.discRanksProxyModel
479            self.measures = self.discMeasures
480        else:
481            self.ranksView = self.contRanksView
482            self.ranksModel = self.contRanksModel
483            self.ranksProxyModel = self.contRanksProxyModel
484            self.measures = self.contMeasures
485
486        self.updateVisibleScoreColumns()
487           
488    def setData(self, data):
489        self.error(0)
490        self.resetInternals()
491        self.data = self.isDataWithClass(data) and data or None
492        if self.data:
493            attrs = self.data.domain.attributes
494            self.usefulAttributes = filter(lambda x:x.varType in [orange.VarTypes.Discrete, orange.VarTypes.Continuous],
495                                           attrs)
496            if is_class_continuous(self.data):
497                self.switchRanksMode(1)
498            elif is_class_discrete(self.data):
499                self.switchRanksMode(0)
500            else: # String or other.
501                self.error(0, "Cannot handle class variable type")
502           
503#            self.ranksView.setSortingEnabled(False)
504            self.ranksModel.setRowCount(len(attrs))
505            for i, a in enumerate(attrs):
506                if isinstance(a, orange.EnumVariable):
507                    v = len(a.values)
508                else:
509                    v = "C"
510                item = PyStandardItem()
511                item.setData(QVariant(v), Qt.DisplayRole)
512                self.ranksModel.setItem(i, 1, item)
513                item = PyStandardItem(a.name)
514                item.setData(QVariant(i), OWGUI.SortOrderRole)
515                self.ranksModel.setItem(i, 0, item)
516               
517            self.ranksView.resizeColumnToContents(1)
518           
519            self.measure_scores = table((len(self.measures),
520                                         len(attrs)), None)
521            self.updateScores()
522            if is_class_discrete(self.data):
523                self.setLogORTitle()
524            self.ranksView.setSortingEnabled(self.sortBy > 0)
525           
526        self.applyIf()
527
528    def updateScores(self, measuresMask=None):
529        """ Update the current computed measures. If measuresMask is given
530        it must be an list of bool values indicating what measures should be
531        computed.
532       
533        """ 
534        if not self.data:
535            return
536       
537#         estimators = self.estimators
538        measures = self.measures
539#         handlesContinous = self.handlesContinuous
540        # Invalidate all warnings
541        self.warning(range(max(len(self.discMeasures),
542                               len(self.contMeasures))))
543
544        if measuresMask is None:
545            # Update all selected measures
546            measuresMask = [self.selectedMeasures.get(m.name)
547                            for m in measures]
548
549        for measure_index, (meas, mask) in enumerate(zip(measures, measuresMask)):
550            if not mask:
551                continue
552
553            params = measure_parameters(meas)
554            estimator = meas.score()
555            if params:
556                for p in params:
557                    setattr(estimator, p.name,
558                            getattr(self, param_attr_name(meas.score, p)))
559
560            if not meas.handles_continuous:
561                data = self.getDiscretizedData()
562                attr_map = data.attrDict
563                data = self.data
564            else:
565                attr_map, data = {}, self.data
566
567            attr_scores = []
568            for i, attr in enumerate(data.domain.attributes):
569                attr = attr_map.get(attr, attr)
570                s = None
571                if attr is not None:
572                    try:
573                        s = estimator(attr, data)
574                    except Exception, ex:
575                        self.warning(measure_index, "Error evaluating %r: %r" % (meas.name, str(ex)))
576                        # TODO: store exception message (for widget info or item tooltip)
577                    if meas.name == "Log Odds Ratio" and s is not None:
578                        if s == -999999:
579                            attr = u"-\u221E"
580                        elif s == 999999:
581                            attr = u"\u221E"
582                        else:
583                            attr = attr.values[1]
584                        s = ("%%.%df" % self.nDecimals + " (%s)") % (s, attr)
585                attr_scores.append(s)
586            self.measure_scores[measure_index] = attr_scores
587       
588        self.updateRankModel(measuresMask)
589        self.ranksProxyModel.invalidate()
590       
591        if self.selectMethod in [0, 2]:
592            self.autoSelection()
593   
594    def updateRankModel(self, measuresMask=None):
595        """ Update the rankModel.
596        """
597        values = []
598        for i, scores in enumerate(self.measure_scores):
599            values_one = []
600            for j, s in enumerate(scores):
601                if isinstance(s, float):
602                    values_one.append(s)
603                else:
604                    values_one.append(None)
605                item = self.ranksModel.item(j, i + 2)
606                if not item:
607                    item = PyStandardItem()
608                    self.ranksModel.setItem(j ,i + 2, item)
609                item.setData(QVariant(s), Qt.DisplayRole)
610            values.append(values_one)
611       
612        for i, vals in enumerate(values):
613            valid_vals = [v for v in vals if v is not None]
614            if valid_vals:
615                vmin, vmax = min(valid_vals), max(valid_vals)
616                for j, v in enumerate(vals):
617                    if v is not None:
618                        # Set the bar ratio role for i-th measure.
619                        ratio = float((v - vmin) / ((vmax - vmin) or 1))
620                        if self.showDistributions:
621                            self.ranksModel.item(j, i + 2).setData(QVariant(ratio), OWGUI.BarRatioRole)
622                        else:
623                            self.ranksModel.item(j, i + 2).setData(QVariant(), OWGUI.BarRatioRole)
624                       
625        self.ranksView.resizeColumnsToContents()
626        self.ranksView.setColumnWidth(1, 20)
627        self.ranksView.resizeRowsToContents()
628           
629    def cbShowDistributions(self):
630        # This should be handled by the delegates only (must always set the BarRatioRole
631        self.updateRankModel()
632        # Need to update the selection
633        self.autoSelection()
634
635    def changeColor(self):
636        color = QColorDialog.getColor(self.distColor, self)
637        if color.isValid():
638            self.distColorRgb = color.getRgb()
639            self.updateColor()
640
641    def updateColor(self):
642        self.distColor = QColor(*self.distColorRgb)
643        w = self.colButton.width()-8
644        h = self.colButton.height()-8
645        pixmap = QPixmap(w, h)
646        painter = QPainter()
647        painter.begin(pixmap)
648        painter.fillRect(0,0,w,h, QBrush(self.distColor))
649        painter.end()
650        self.colButton.setIcon(QIcon(pixmap))
651        self.updateDelegates()
652
653    def resetInternals(self):
654        self.data = None
655        self.discretizedData = None
656        self.attributeOrder = []
657        self.selected = []
658        self.measured = {}
659        self.usefulAttributes = []
660        self.dataChanged = False
661        self.lastSentAttrs = None
662        self.ranksModel.setRowCount(0)
663
664    def onSelectionChanged(self, *args):
665        """ Called when the ranks view selection changes.
666        """
667        selected = self.selectedAttrs()
668        self.clearButton.setEnabled(bool(selected))
669        self.applyIf()
670       
671    def onSelectItem(self, index):
672        """ Called when the user selects/unselects an item in the table view.
673        """
674        self.selectMethod = 1 # Manual
675        self.clearButton.setEnabled(bool(self.selectedAttrs()))
676        self.applyIf()
677
678    def clearSelection(self):
679        self.ranksView.selectionModel().clear()
680
681    def selectMethodChanged(self):
682        if self.selectMethod in [0, 2]:
683            self.autoSelection()
684
685    def nSelectedChanged(self):
686        self.selectMethod = 2
687        self.selectMethodChanged()
688
689    def getDiscretizedData(self):
690        if not self.discretizedData:
691            discretizer = orange.EquiNDiscretization(numberOfIntervals=self.nIntervals)
692            contAttrs = filter(lambda attr: attr.varType == orange.VarTypes.Continuous, self.data.domain.attributes)
693            at = []
694            attrDict = {}
695            for attri in contAttrs:
696                try:
697                    nattr = discretizer(attri, self.data)
698                    at.append(nattr)
699                    attrDict[attri] = nattr
700                except:
701                    pass
702            self.discretizedData = self.data.select(orange.Domain(at, self.data.domain.classVar))
703            self.discretizedData.setattr("attrDict", attrDict)
704        return self.discretizedData
705
706    def discretizationChanged(self):
707        self.discretizedData = None
708        self.updateScores([not m.handles_continuous for m in self.measures])
709        self.autoSelection()
710
711    def measureParamChanged(self, measure, param=None):
712        index = self.measures.index(measure)
713        mask = [i == index for i, _ in enumerate(self.measures)]
714        self.updateScores(mask)
715   
716    def loadMeasureDefaults(self, measure):
717#         index = self.measures.index(measure)
718#         measure = self.estimators[index]
719        params = measure_parameters(measure)
720        for i, p in enumerate(params):
721            setattr(self, param_attr_name(measure.score, p), p.default)
722        self.measureParamChanged(measure)
723       
724    def autoSelection(self):
725        selModel = self.ranksView.selectionModel()
726        rowCount = self.ranksModel.rowCount()
727        columnCount = self.ranksModel.columnCount()
728        model = self.ranksProxyModel
729        if self.selectMethod == 0:
730           
731            selection = QItemSelection(model.index(0, 0),
732                                       model.index(rowCount - 1,
733                                       columnCount -1))
734            selModel.select(selection, QItemSelectionModel.ClearAndSelect)
735        if self.selectMethod == 2:
736            nSelected = min(self.nSelected, rowCount)
737            selection = QItemSelection(model.index(0, 0),
738                                       model.index(nSelected - 1,
739                                       columnCount - 1))
740            selModel.select(selection, QItemSelectionModel.ClearAndSelect)
741
742    def headerClick(self, index):
743        self.sortBy = index + 1
744        if not self.ranksView.isSortingEnabled():
745            # The sorting is disabled ("No sorting|" selected by user)
746            self.sortingChanged()
747           
748        if index > 1 and self.selectMethod == 2:
749            # Reselect the top ranked attributes
750            self.autoSelection()
751        self.sortBy = index + 1
752        return
753
754    def sortingChanged(self):
755        """ Sorting was changed by user (through the Sort By combo box.)
756        """
757        self.updateSorting()
758        self.autoSelection()
759       
760    def updateSorting(self):
761        """ Update the sorting of the model/view.
762        """
763        self.ranksProxyModel.invalidate()
764        if self.sortBy == 0:
765            self.ranksProxyModel.setSortRole(OWGUI.SortOrderRole)
766            self.ranksProxyModel.sort(0, Qt.DescendingOrder)
767            self.ranksView.setSortingEnabled(False)
768           
769        else:
770            self.ranksProxyModel.setSortRole(Qt.DisplayRole)
771            self.ranksView.sortByColumn(self.sortBy - 1, Qt.DescendingOrder)
772            self.ranksView.setSortingEnabled(True)
773
774    def setLogORTitle(self):
775        var = self.data.domain.classVar
776        if len(var.values) == 2:
777            title = "log OR (for %r)" % var.values[1][:10]
778        else:
779            title = "log OR"
780#         if "Log Odds Ratio" in self.discEstimators:
781#             index = self.discMeasures.index("Log Odds Ratio")
782        index = [m.name for m in self.discMeasures].index("Log Odds Ratio")
783
784        item = PyStandardItem(title)
785        self.ranksModel.setHorizontalHeaderItem(index + 2, item)
786
787    def measuresSelectionChanged(self, measure=None):
788        """Measure selection has changed. Update column visibility.
789        """
790        if measure is None:
791            # Update all scores
792            measuresMask = None
793        else:
794            # Update scores for shown column if they are not yet computed.
795            shown = self.selectedMeasures.get(measure.name, False)
796            index = self.measures.index(measure)
797            if all(s is None for s in self.measure_scores[index]) and shown:
798                measuresMask = [m == measure for m in self.measures]
799            else:
800                measuresMask = [False] * len(self.measures)
801        self.updateScores(measuresMask)
802       
803        self.updateVisibleScoreColumns()
804           
805    def updateVisibleScoreColumns(self):
806        """ Update the visible columns of the scores view.
807        """
808        for i, measure in enumerate(self.measures):
809            shown = self.selectedMeasures.get(measure.name)
810            self.ranksView.setColumnHidden(i + 2, not shown)
811
812    def sortByColumn(self, col):
813        if col < 2:
814            self.sortBy = 1 + col
815        else:
816            self.sortBy = 3 + self.selectedMeasures[col-2]
817        self.sortingChanged()
818
819    def decimalsChanged(self):
820        self.updateDelegates()
821        self.ranksView.resizeColumnsToContents()
822       
823    def updateDelegates(self):
824        self.contRanksView.setItemDelegate(OWGUI.ColoredBarItemDelegate(self,
825                            decimals=self.nDecimals,
826                            color=self.distColor))
827        self.discRanksView.setItemDelegate(OWGUI.ColoredBarItemDelegate(self,
828                            decimals=self.nDecimals,
829                            color=self.distColor))
830       
831    def sendReport(self):
832        self.reportData(self.data)
833        self.reportRaw(OWReport.reportTable(self.ranksView))
834
835    def applyIf(self):
836        if self.autoApply:
837            self.apply()
838        else:
839            self.dataChanged = True
840
841    def apply(self):
842        selected = self.selectedAttrs()
843        if not self.data or not selected:
844            self.send("Reduced Data", None)
845        else:
846            domain = orange.Domain(selected, self.data.domain.classVar)
847            domain.addmetas(self.data.domain.getmetas())
848            data = orange.ExampleTable(domain, self.data)
849            self.send("Reduced Data", data)
850        self.dataChanged = False
851       
852    def selectedAttrs(self):
853        if self.data:
854            inds = self.ranksView.selectionModel().selectedRows(0)
855            source = self.ranksProxyModel.mapToSource
856            inds = map(source, inds)
857            inds = [ind.row() for ind in inds]
858            return [self.data.domain.attributes[i] for i in inds]
859        else:
860            return []   
861
862
863class PyStandardItem(QStandardItem):
864    """ A StandardItem subclass for python objects.
865    """
866    def __init__(self, *args):
867        QStandardItem.__init__(self, *args)
868        self.setFlags(Qt.ItemIsSelectable| Qt.ItemIsEnabled)
869       
870    def __lt__(self, other):
871        my = self.data(Qt.DisplayRole).toPyObject()
872        other = other.data(Qt.DisplayRole).toPyObject()
873        if my is None:
874            return True
875        return my < other
876
877class MySortProxyModel(QSortFilterProxyModel):
878    def headerData(self, section, orientation, role):
879        """ Don't map headers.
880        """
881        source = self.sourceModel()
882        return source.headerData(section, orientation, role)
883   
884    def lessThan(self, left, right):
885        role = self.sortRole()
886        left = left.data(role).toPyObject()
887        right = right.data(role).toPyObject()
888        return left < right
889
890if __name__=="__main__":
891    a=QApplication(sys.argv)
892    ow=OWRank()
893    ow.setData(orange.ExampleTable("wine.tab"))
894    ow.setData(orange.ExampleTable("zoo.tab"))
895    ow.setData(orange.ExampleTable("servo.tab"))
896    ow.setData(orange.ExampleTable("iris.tab"))
897#    ow.setData(orange.ExampleTable("auto-mpg.tab"))
898    ow.show()
899    a.exec_()
900    ow.saveSettings()
901
Note: See TracBrowser for help on using the repository browser.