Changeset 8251:69c5e9d08181 in orange


Ignore:
Timestamp:
08/22/11 11:35:12 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
68281bf6ce3e852e4af7082464a0c1c96e639a0c
Message:

Added support for continuous class measures (ReliefF and MSE).
Rewrote large parts of the widget to do so.
Fixes #888.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Data/OWRank.py

    r8042 r8251  
    77""" 
    88from OWWidget import * 
    9  
    109import OWGUI 
    11  
     10import orange 
     11 
     12def _toPyObject(variant): 
     13    val = variant.toPyObject() 
     14    if isinstance(val, type(NotImplemented)): 
     15        # PyQt 4.4 converts python int, floats ... to C types and 
     16        # cannot convert them back again and returns an exception instance. 
     17        qtype = variant.type() 
     18        if qtype == QVariant.Double: 
     19            val, ok = variant.toDouble() 
     20        elif qtype == QVariant.Int: 
     21            val, ok = variant.toInt() 
     22        elif qtype == QVariant.LongLong: 
     23            val, ok = variant.toLongLong() 
     24        elif qtype == QVariant.String: 
     25            val = variant.toString() 
     26    return val 
     27 
     28def is_class_discrete(data): 
     29    return isinstance(data.domain.classVar, orange.EnumVariable) 
     30 
     31def is_class_continuous(data): 
     32    return isinstance(data.domain.classVar, orange.FloatVariable) 
     33 
     34def table(shape, fill=None): 
     35    """ Return a 2D table with shape filed with ``fill`` 
     36    """ 
     37    return [[fill for j in range(shape[1])] for i in range(shape[0])] 
     38  
    1239class OWRank(OWWidget): 
    1340    settingsList =  ["nDecimals", "reliefK", "reliefN", "nIntervals", "sortBy", "nSelected", "selectMethod", "autoApply", "showDistributions", "distColorRgb"] 
    14     measures          = ["ReliefF", "Information Gain", "Gain Ratio", "Gini Gain", "Log Odds Ratio"] 
    15     measuresShort     = ["ReliefF", "Inf. gain", "Gain ratio", "Gini", "log OR"] 
    16     measuresAttrs     = ["computeReliefF", "computeInfoGain", "computeGainRatio", "computeGini", "computeLogOdds"] 
    17     estimators        = [orange.MeasureAttribute_relief, orange.MeasureAttribute_info, orange.MeasureAttribute_gainRatio, orange.MeasureAttribute_gini, orange.MeasureAttribute_logOddsRatio] 
    18     handlesContinuous = [True, False, False, False, False] 
     41    discMeasures          = ["ReliefF", "Information Gain", "Gain Ratio", "Gini Gain", "Log Odds Ratio"] 
     42    discMeasuresShort     = ["ReliefF", "Inf. gain", "Gain ratio", "Gini", "log OR"] 
     43    discMeasuresAttrs     = ["computeReliefF", "computeInfoGain", "computeGainRatio", "computeGini", "computeLogOdds"] 
     44    discEstimators        = [orange.MeasureAttribute_relief, orange.MeasureAttribute_info, orange.MeasureAttribute_gainRatio, orange.MeasureAttribute_gini, orange.MeasureAttribute_logOddsRatio] 
     45    discHandlesContinuous = [True, False, False, False, False] 
     46     
     47    contMeasures      = ["ReliefF", "MSE"] 
     48    contMeasuresShort = ["ReliefF", "MSE"] 
     49    contMeasuresAttrs = ["computeReliefFCont", "computeMSECont"] 
     50    contEstimators    = [orange.MeasureAttribute_relief, orange.MeasureAttribute_MSE] 
     51    contHandlesContinuous   = [True, False] 
    1952 
    2053    def __init__(self,parent=None, signalManager = None): 
     
    2255 
    2356        self.inputs = [("Examples", ExampleTable, self.setData)] 
    24         self.outputs = [("Reduced Example Table", ExampleTable, Default + Single), ("ExampleTable Attributes", ExampleTable, NonDefault)] 
    25  
    26         self.settingsList += self.measuresAttrs 
    27         self.logORIdx = self.measuresShort.index("log OR") 
     57        self.outputs = [("Reduced Example Table", ExampleTable, Default + Single)] 
     58 
     59        self.settingsList += self.discMeasuresAttrs + self.contMeasuresAttrs 
     60        self.logORIdx = self.discMeasuresShort.index("log OR") 
    2861 
    2962        self.nDecimals = 3 
     
    3164        self.reliefN = 20 
    3265        self.nIntervals = 4 
    33         self.sortBy = 0 
     66        self.sortBy = 2 
    3467        self.selectMethod = 2 
    3568        self.nSelected = 5 
     
    4275        self.data = None 
    4376 
    44         for meas in self.measuresAttrs: 
     77        for meas in self.discMeasuresAttrs + self.contMeasuresAttrs: 
    4578            setattr(self, meas, True) 
    4679 
     
    4982        labelWidth = 80 
    5083 
    51         box = OWGUI.widgetBox(self.controlArea, "Scoring", addSpace=True) 
    52         for meas, valueName in zip(self.measures, self.measuresAttrs): 
     84        self.stackedLayout = QStackedLayout() 
     85        self.stackedLayout.setContentsMargins(0, 0, 0, 0) 
     86        self.stackedWidget = OWGUI.widgetBox(self.controlArea, margin=0, 
     87                                             orientation=self.stackedLayout, 
     88                                             addSpace=True) 
     89        # Discrete class scoring 
     90        box = OWGUI.widgetBox(self.stackedWidget, "Scoring", 
     91                              addSpace=False, 
     92                              addToLayout=False) 
     93        self.stackedLayout.addWidget(box) 
     94         
     95        for meas, valueName in zip(self.discMeasures, self.discMeasuresAttrs): 
    5396            if valueName == "computeReliefF": 
    5497                hbox = OWGUI.widgetBox(box, orientation = "horizontal") 
     
    62105            else: 
    63106                OWGUI.checkBox(box, self, valueName, meas, callback=self.measuresChanged) 
    64         OWGUI.separator(box) 
    65  
    66         OWGUI.comboBox(box, self, "sortBy", label = "Sort by"+"  ", items = ["No Sorting", "Attribute Name", "Number of Values"] + self.measures, orientation=0, valueType = int, callback=self.sortingChanged) 
    67  
    68  
    69         box = OWGUI.widgetBox(self.controlArea, "Discretization", addSpace=True) 
    70         OWGUI.spin(box, self, "nIntervals", 2, 20, label="Intervals: ", orientation=0, callback=self.discretizationChanged, callbackOnReturn = True) 
     107                 
     108        OWGUI.comboBox(box, self, "sortBy", label = "Sort by"+"  ", 
     109                       items = ["No Sorting", "Attribute Name", "Number of Values"] + self.discMeasures, 
     110                       orientation=0, valueType = int, callback=self.sortingChanged) 
     111                 
     112        # Continuous class scoring 
     113        box = OWGUI.widgetBox(self.stackedWidget, "Scoring", 
     114                              addSpace=False, 
     115                              addToLayout=False) 
     116        self.stackedLayout.addWidget(box) 
     117        for meas, valueName in zip(self.contMeasures, self.contMeasuresAttrs): 
     118            if valueName == "computeReliefFCont": 
     119                hbox = OWGUI.widgetBox(box, orientation = "horizontal") 
     120                OWGUI.checkBox(hbox, self, valueName, meas, callback=self.measuresChanged) 
     121                hbox.layout().addSpacing(5) 
     122                smallWidget = OWGUI.SmallWidgetLabel(hbox, pixmap = 1, box = "ReliefF Parameters", tooltip = "Show ReliefF parameters") 
     123                OWGUI.spin(smallWidget.widget, self, "reliefK", 1, 20, label="Neighbours", labelWidth=labelWidth, orientation=0, callback=self.reliefChanged, callbackOnReturn = True) 
     124                OWGUI.spin(smallWidget.widget, self, "reliefN", 20, 100, label="Examples", labelWidth=labelWidth, orientation=0, callback=self.reliefChanged, callbackOnReturn = True) 
     125                OWGUI.button(smallWidget.widget, self, "Load defaults", callback = self.loadReliefDefaults) 
     126                OWGUI.rubber(hbox) 
     127            else: 
     128                OWGUI.checkBox(box, self, valueName, meas, callback=self.measuresChanged) 
     129                 
     130        OWGUI.comboBox(box, self, "sortBy", label = "Sort by"+"  ", 
     131                       items = ["No Sorting", "Attribute Name", "Number of Values"] + self.contMeasures, 
     132                       orientation=0, valueType = int, callback=self.sortingChanged) 
     133 
     134        box = OWGUI.widgetBox(self.controlArea, "Discretization", 
     135                              addSpace=True) 
     136        OWGUI.spin(box, self, "nIntervals", 2, 20, 
     137                   label="Intervals: ", 
     138                   orientation=0, 
     139                   tooltip="Disctetization for measures which cannot score continuous attributes", 
     140                   callback=self.discretizationChanged, 
     141                   callbackOnReturn=True) 
    71142 
    72143        box = OWGUI.widgetBox(self.controlArea, "Precision", addSpace=True) 
    73         OWGUI.spin(box, self, "nDecimals", 1, 6, label="No. of decimals: ", orientation=0, callback=self.decimalsChanged) 
    74  
    75         box = OWGUI.widgetBox(self.controlArea, "Score bars", orientation="horizontal", addSpace=True) 
    76         self.cbShowDistributions = OWGUI.checkBox(box, self, "showDistributions", 'Enable', callback = self.cbShowDistributions) 
     144        OWGUI.spin(box, self, "nDecimals", 1, 6, label="No. of decimals: ", 
     145                   orientation=0, callback=self.decimalsChanged) 
     146 
     147        box = OWGUI.widgetBox(self.controlArea, "Score bars", 
     148                              orientation="horizontal", addSpace=True) 
     149        self.cbShowDistributions = OWGUI.checkBox(box, self, "showDistributions", 
     150                                    'Enable', callback = self.cbShowDistributions) 
    77151#        colBox = OWGUI.indentedBox(box, orientation = "horizontal") 
    78152        OWGUI.rubber(box) 
     
    111185        OWGUI.rubber(self.controlArea) 
    112186         
    113         self.table = QTableWidget() 
    114         self.mainArea.layout().addWidget(self.table) 
    115  
    116         self.table.setSelectionBehavior(QAbstractItemView.SelectRows) 
    117         self.table.setSelectionMode(QAbstractItemView.MultiSelection) 
    118         self.table.verticalHeader().setResizeMode(QHeaderView.ResizeToContents) 
    119         self.table.setItemDelegate(RankItemDelegate(self, self.table)) 
    120  
    121         self.topheader = self.table.horizontalHeader() 
    122         self.topheader.setSortIndicatorShown(1) 
    123         self.topheader.setHighlightSections(0) 
    124  
    125         self.setMeasures() 
     187        # Discrete and continuous table views are stacked 
     188        self.ranksViewStack = QStackedLayout() 
     189        self.mainArea.layout().addLayout(self.ranksViewStack) 
     190         
     191        self.discRanksView = QTableView() 
     192        self.ranksViewStack.addWidget(self.discRanksView) 
     193        self.discRanksView.setSelectionBehavior(QTableView.SelectRows) 
     194        self.discRanksView.setSelectionMode(QTableView.MultiSelection) 
     195        self.discRanksView.setSortingEnabled(True) 
     196#        self.discRanksView.horizontalHeader().restoreState(self.discRanksHeaderState) 
     197         
     198        self.discRanksModel = QStandardItemModel(self) 
     199        self.discRanksModel.setHorizontalHeaderLabels(["Attribute", "#"] + self.discMeasuresShort) 
     200        self.discRanksProxyModel = MySortProxyModel(self) 
     201        self.discRanksProxyModel.setSourceModel(self.discRanksModel) 
     202        self.discRanksView.setModel(self.discRanksProxyModel) 
     203#        self.discRanksView.verticalHeader().setResizeMode(QHeaderView.ResizeToContents) 
     204        self.discRanksView.setColumnWidth(1, 20) 
     205        self.discRanksView.sortByColumn(2, Qt.DescendingOrder) 
     206        self.connect(self.discRanksView.selectionModel(), 
     207                     SIGNAL("selectionChanged(QItemSelection, QItemSelection)"), 
     208                     self.onSelectionChanged) 
     209        self.connect(self.discRanksView, 
     210                     SIGNAL("pressed(const QModelIndex &)"), 
     211                     self.onSelectItem) 
     212        self.connect(self.discRanksView.horizontalHeader(), 
     213                     SIGNAL("sectionClicked(int)"), 
     214                     self.headerClick) 
     215         
     216        self.contRanksView = QTableView() 
     217        self.ranksViewStack.addWidget(self.contRanksView) 
     218        self.contRanksView.setSelectionBehavior(QTableView.SelectRows) 
     219        self.contRanksView.setSelectionMode(QTableView.MultiSelection) 
     220        self.contRanksView.setSortingEnabled(True) 
     221#        self.contRanksView.setItemDelegate(RankItemDelegate()) 
     222#        self.contRanksView.horizontalHeader().restoreState(self.contRanksHeaderState) 
     223         
     224        self.contRanksModel = QStandardItemModel(self) 
     225        self.contRanksModel.setHorizontalHeaderLabels(["Attribute", "#"] + self.contMeasuresShort) 
     226        self.contRanksProxyModel = MySortProxyModel(self) 
     227        self.contRanksProxyModel.setSourceModel(self.contRanksModel) 
     228        self.contRanksView.setModel(self.contRanksProxyModel) 
     229#        self.contRanksView.verticalHeader().setResizeMode(QHeaderView.ResizeToContents) 
     230        self.discRanksView.setColumnWidth(1, 20) 
     231        self.contRanksView.sortByColumn(2, Qt.DescendingOrder) 
     232        self.connect(self.contRanksView.selectionModel(), 
     233                     SIGNAL("selectionChanged(QItemSelection, QItemSelection)"), 
     234                     self.onSelectionChanged) 
     235        self.connect(self.contRanksView, 
     236                     SIGNAL("pressed(const QModelIndex &)"), 
     237                     self.onSelectItem) 
     238        self.connect(self.contRanksView.horizontalHeader(), 
     239                     SIGNAL("sectionClicked(int)"), 
     240                     self.headerClick) 
     241         
     242        # Switch the current view to Discrete 
     243        self.switchRanksMode(0) 
    126244        self.resetInternals() 
    127  
    128         self.connect(self.table.horizontalHeader(), SIGNAL("sectionClicked(int)"), self.headerClick) 
    129         self.connect(self.table, SIGNAL("clicked (const QModelIndex&)"), self.selectItem) 
    130         self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.onSelectionChanged) 
     245        self.updateDelegates() 
     246 
     247#        self.connect(self.table.horizontalHeader(), SIGNAL("sectionClicked(int)"), self.headerClick) 
    131248         
    132249        self.resize(690,500) 
    133250        self.updateColor() 
    134251 
     252    def switchRanksMode(self, index): 
     253        """ Switch between discrete/continuous mode 
     254        """ 
     255        self.ranksViewStack.setCurrentIndex(index) 
     256        self.stackedLayout.setCurrentIndex(index) 
     257         
     258        if index == 0: 
     259            self.ranksView = self.discRanksView 
     260            self.ranksModel = self.discRanksModel 
     261            self.ranksProxyModel = self.discRanksProxyModel 
     262            self.measures = self.discMeasures 
     263            self.handlesContinuous = self.discHandlesContinuous 
     264            self.estimators = self.discEstimators 
     265            self.measuresAttrs = self.discMeasuresAttrs 
     266        else: 
     267            self.ranksView = self.contRanksView 
     268            self.ranksModel = self.contRanksModel 
     269            self.ranksProxyModel = self.contRanksProxyModel 
     270            self.measures = self.contMeasures 
     271            self.handlesContinuous = self.contHandlesContinuous 
     272            self.estimators = self.contEstimators 
     273            self.measuresAttrs = self.contMeasuresAttrs 
     274             
     275    def setData(self, data): 
     276        self.error(0) 
     277        self.resetInternals() 
     278        self.data = self.isDataWithClass(data) and data or None 
     279        if self.data: 
     280            attrs = self.data.domain.attributes 
     281            self.usefulAttributes = filter(lambda x:x.varType in [orange.VarTypes.Discrete, orange.VarTypes.Continuous], 
     282                                           attrs) 
     283            if is_class_continuous(self.data): 
     284                self.switchRanksMode(1) 
     285            elif is_class_discrete(self.data): 
     286                self.switchRanksMode(0) 
     287            else: # String or other. 
     288                self.error(0, "Cannot handle class variable type") 
     289             
     290#            self.ranksView.setSortingEnabled(False) 
     291            self.ranksModel.setRowCount(len(attrs)) 
     292            for i, a in enumerate(attrs): 
     293                if isinstance(a, orange.EnumVariable): 
     294                    v = len(a.values) 
     295                else: 
     296                    v = "C" 
     297                item = PyStandardItem() 
     298                item.setData(QVariant(v), Qt.DisplayRole) 
     299                self.ranksModel.setItem(i, 1, item) 
     300                item = PyStandardItem(a.name) 
     301                item.setData(QVariant(i), OWGUI.SortOrderRole) 
     302                self.ranksModel.setItem(i, 0, item) 
     303                 
     304            self.ranksView.resizeColumnToContents(1) 
     305             
     306            self.measure_scores = table((len(self.measures), 
     307                                         len(attrs)), None) 
     308            self.updateScores() 
     309            if is_class_discrete(self.data): 
     310                self.setLogORTitle() 
     311            self.ranksView.setSortingEnabled(self.sortBy > 0) 
     312            self.ranksView 
     313             
     314        self.applyIf() 
     315 
     316    def updateScores(self, measuresMask=None): 
     317        """ Update the current computed measures. If measuresMask is given 
     318        it must be an list of bool values indicating what measures should be  
     319        computed. 
     320         
     321        """  
     322        if not self.data: 
     323            return 
     324         
     325        estimators = self.estimators 
     326        measures = self.measures 
     327        handlesContinous = self.handlesContinuous 
     328        self.warning(range(max(len(self.discEstimators), len(self.contEstimators)))) 
     329         
     330        if measuresMask is None: 
     331            measuresMask = [True] * len(measures) 
     332        for measure_index, (est, meas, handles, mask) in enumerate(zip( 
     333                estimators, measures, handlesContinous, measuresMask)): 
     334            if not mask: 
     335                continue 
     336            if meas == "ReliefF": 
     337                est = est() 
     338                est.k = self.reliefK 
     339                est.m = self.reliefN 
     340            else: 
     341                est = est() 
     342            if not handles: 
     343                data = self.getDiscretizedData() 
     344                attr_map = data.attrDict 
     345                data = self.data 
     346            else: 
     347                attr_map, data = {}, self.data 
     348            attr_scores = [] 
     349            for i, attr in enumerate(data.domain.attributes): 
     350                attr = attr_map.get(attr, attr) 
     351                s = None 
     352                if attr is not None: 
     353                    try: 
     354                        s = est(attr, data) 
     355                    except Exception, ex: 
     356                        self.warning(measure_index, "Error evaluating %r: %r" % (meas, str(ex))) 
     357                        # TODO: store exception message (for widget info or item tooltip) 
     358                    if meas == "Log Odds Ratio" and s is not None: 
     359                        if s == -999999: 
     360                            attr = u"-\u221E" 
     361                        elif s == 999999: 
     362                            attr = u"\u221E" 
     363                        else: 
     364                            attr = attr.values[1] 
     365                        s = ("%%.%df" % self.nDecimals + " (%s)") % (s, attr) 
     366                attr_scores.append(s) 
     367            self.measure_scores[measure_index] = attr_scores 
     368         
     369        self.updateRankModel() 
     370        self.ranksProxyModel.invalidate() 
     371         
     372        if self.selectMethod in [0, 2]: 
     373            self.autoSelection() 
     374     
     375    def updateRankModel(self, measuresMask=None): 
     376        """ Update the rankModel. 
     377        """ 
     378        values = [] 
     379        for i, scores in enumerate(self.measure_scores): 
     380            values_one = [] 
     381            for j, s in enumerate(scores): 
     382                if isinstance(s, float): 
     383                    values_one.append(s) 
     384                else: 
     385                    values_one.append(None) 
     386                 
     387#                if s is None: 
     388#                    s = "NA" 
     389                item = self.ranksModel.item(j, i + 2) 
     390                if not item: 
     391                    item = PyStandardItem() 
     392                    self.ranksModel.setItem(j ,i + 2, item) 
     393                item.setData(QVariant(s), Qt.DisplayRole) 
     394            values.append(values_one) 
     395         
     396        for i, vals in enumerate(values): 
     397            valid_vals = [v for v in vals if v is not None] 
     398            if valid_vals: 
     399                vmin, vmax = min(valid_vals), max(valid_vals) 
     400                for j, v in enumerate(vals): 
     401                    if v is not None: 
     402                        # Set the bar ratio role for i-th measure. 
     403                        ratio = (v - vmin) / ((vmax - vmin) or 1) 
     404                        if self.showDistributions: 
     405                            self.ranksModel.item(j, i + 2).setData(QVariant(ratio), OWGUI.BarRatioRole) 
     406                        else: 
     407                            self.ranksModel.item(j, i + 2).setData(QVariant(), OWGUI.BarRatioRole) 
     408                         
     409        self.ranksView.resizeColumnsToContents() 
     410        self.ranksView.setColumnWidth(1, 20) 
     411        self.ranksView.resizeRowsToContents() 
     412             
    135413    def cbShowDistributions(self): 
    136         self.table.reset() 
     414        self.updateRankModel() 
     415        # Need to update the selection 
     416        self.autoSelection() 
    137417 
    138418    def changeColor(self): 
     
    152432        painter.end() 
    153433        self.colButton.setIcon(QIcon(pixmap)) 
    154         self.table.viewport().update() 
    155  
     434        self.updateDelegates() 
    156435 
    157436    def resetInternals(self): 
     
    164443        self.dataChanged = False 
    165444        self.lastSentAttrs = None 
    166  
    167         self.table.setRowCount(0) 
    168  
    169     def onSelectionChanged(self): 
    170         if not getattr(self, "_reselecting", False): 
    171             selected = sorted(set(item.row() for item in self.table.selectedItems())) 
    172             self.clearButton.setEnabled(bool(selected)) 
    173             selected = [self.attributeOrder[row] for row in selected] 
    174             if set(selected) != set(self.selected): 
    175                 self.selected = selected 
    176                 self.selectMethod = 1 
    177             self.applyIf() 
     445        self.ranksModel.setRowCount(0) 
     446 
     447    def onSelectionChanged(self, *args): 
     448        """ Called when the ranks vire selection changes. 
     449        """ 
     450        selected = self.selectedAttrs() 
     451        self.clearButton.setEnabled(bool(selected)) 
     452#        # Change the selectionMethod to manual if necessary. 
     453#        if self.selectMethod == 0 and len(selected) != self.ranksModel.rowCount(): 
     454#            self.selectMethod = 1 
     455#        elif self.selectMethod == 2: 
     456#            inds = self.ranksView.selectionModel().selectedRows(0) 
     457#            rows = [ind.row() for ind in inds] 
     458#            if set(rows) != set(range(self.nSelected)): 
     459#                self.selectMethod = 1 
     460                 
     461        self.applyIf() 
     462         
     463    def onSelectItem(self, index): 
     464        """ Called when the user selects/unselects an item in the table view. 
     465        """ 
     466        self.selectMethod = 1 # Manual 
     467        self.clearButton.setEnabled(bool(self.selectedAttrs())) 
     468        self.applyIf() 
    178469 
    179470    def clearSelection(self): 
    180         self.selected = []  
    181         self.reselect() 
     471        self.ranksView.selectionModel().clear() 
    182472 
    183473    def selectMethodChanged(self): 
    184         if self.selectMethod == 0: 
    185             self.selected = self.attributeOrder[:] 
    186             self.reselect() 
    187         elif self.selectMethod == 2: 
    188             self.selected = self.attributeOrder[:self.nSelected] 
    189             self.reselect() 
    190         self.applyIf() 
     474        if self.selectMethod in [0, 2]: 
     475            self.autoSelection() 
    191476 
    192477    def nSelectedChanged(self): 
     
    194479        self.selectMethodChanged() 
    195480 
    196     def sendSelected(self): 
    197         attrs = self.data and [attr for i, attr in enumerate(self.attributeOrder) if self.table.isRowSelected(i)] 
    198         if not attrs: 
    199             self.send("ExampleTable Attributes", None) 
    200             return 
    201  
    202         nDomain = orange.Domain(attrs, self.data.domain.classVar) 
    203         for meta in [a.name for a in self.data.domain.getmetas().values()]: 
    204             nDomain.addmeta(orange.newmetaid(), self.data.domain[meta]) 
    205  
    206         self.send("ExampleTable Attributes", orange.ExampleTable(nDomain, self.data)) 
    207  
    208  
    209     def setData(self,data): 
    210         self.resetInternals() 
    211  
    212         self.data = self.isDataWithClass(data, orange.VarTypes.Discrete) and data or None 
    213         if self.data: 
    214             self.usefulAttributes = filter(lambda x:x.varType in [orange.VarTypes.Discrete, orange.VarTypes.Continuous], self.data.domain.attributes) 
    215             self.table.setRowCount(len(self.data.domain.attributes)) 
    216             self.reprint() 
    217  
    218         self.setLogORTitle() 
    219         self.resendAttributes() 
    220         self.applyIf() 
    221  
    222  
     481    def getDiscretizedData(self): 
     482        if not self.discretizedData: 
     483            discretizer = orange.EquiNDiscretization(numberOfIntervals=self.nIntervals) 
     484            contAttrs = filter(lambda attr: attr.varType == orange.VarTypes.Continuous, self.data.domain.attributes) 
     485            at = [] 
     486            attrDict = {} 
     487            for attri in contAttrs: 
     488                try: 
     489                    nattr = discretizer(attri, self.data) 
     490                    at.append(nattr) 
     491                    attrDict[attri] = nattr 
     492                except: 
     493                    pass 
     494            self.discretizedData = self.data.select(orange.Domain(at, self.data.domain.classVar)) 
     495            self.discretizedData.setattr("attrDict", attrDict) 
     496        return self.discretizedData 
     497         
    223498    def discretizationChanged(self): 
    224499        self.discretizedData = None 
    225  
    226         removed = False 
    227         for meas, cont in zip(self.measuresAttrs, self.handlesContinuous): 
    228             if not cont and self.measured.has_key(meas): 
    229                 del self.measured[meas] 
    230                 removed = True 
    231  
    232         if self.data and self.data.domain.hasContinuousAttributes(False): 
    233             sortedByThis = self.sortBy>=3 and not self.handlesContinuous[self.sortBy-3] 
    234             if removed or sortedByThis: 
    235                 self.reprint() 
    236                 self.resendAttributes() 
    237                 if sortedByThis and self.selectMethod == 2: 
    238                     self.applyIf() 
    239  
     500        self.updateScores([not b for b in self.handlesContinuous]) 
     501        self.autoSelection() 
    240502 
    241503    def reliefChanged(self): 
    242         removed = False 
    243         if self.measured.has_key("computeReliefF"): 
    244             del self.measured["computeReliefF"] 
    245             removed = True 
    246  
    247         if self.data: 
    248             sortedByReliefF = self.sortBy-3 == self.measuresAttrs.index("computeReliefF") 
    249             if removed or sortedByReliefF: 
    250                 self.reprint() 
    251                 self.resendAttributes() 
    252                 if sortedByReliefF and self.selectMethod == 2: 
    253                     self.applyIf() 
     504        self.updateScores([m == "ReliefF" for m in self.measures]) 
     505        self.autoSelection() 
    254506 
    255507    def loadReliefDefaults(self): 
     
    257509        self.reliefN = 20 
    258510        self.reliefChanged() 
    259  
    260  
    261     def selectItem(self, index): 
    262         pass 
    263 #        row = index.row() 
    264 #        attr = self.attributeOrder[row] 
    265 #        if attr in self.selected: 
    266 #            self.selected.remove(attr) 
    267 #        else: 
    268 #            self.selected.append(attr) 
    269 #        self.selectMethod = 1 
    270 #        self.applyIf() 
     511         
     512    def autoSelection(self): 
     513        selModel = self.ranksView.selectionModel() 
     514        rowCount = self.ranksModel.rowCount() 
     515        columnCount = self.ranksModel.columnCount() 
     516        model = self.ranksProxyModel 
     517        if self.selectMethod == 0: 
     518             
     519            selection = QItemSelection(model.index(0, 0), 
     520                                       model.index(rowCount - 1, 
     521                                       columnCount -1)) 
     522            selModel.select(selection, QItemSelectionModel.ClearAndSelect) 
     523        if self.selectMethod == 2: 
     524            nSelected = min(self.nSelected, rowCount) 
     525            selection = QItemSelection(model.index(0, 0), 
     526                                       model.index(nSelected - 1, 
     527                                       columnCount - 1)) 
     528            selModel.select(selection, QItemSelectionModel.ClearAndSelect) 
    271529 
    272530    def headerClick(self, index): 
    273         if index < 0: return 
    274  
    275         if index < 2: 
    276             self.sortBy = 1 + index 
    277         else: 
    278             self.sortBy = 3 + self.measuresShort.index(str(self.table.horizontalHeader().model().headerData(index, Qt.Horizontal).toString())) 
    279         self.sortingChanged() 
     531        self.sortBy = index + 1 
     532        if not self.ranksView.isSortingEnabled(): 
     533            # The sorting is disabled ("No sorting|" selected by user) 
     534            self.sortingChanged() 
     535             
     536        if index > 1 and self.selectMethod == 2: 
     537            # Reselect the top ranked attributes 
     538            self.autoSelection() 
     539        self.sortBy = index + 1 
     540        return 
    280541 
    281542    def sortingChanged(self): 
    282         self.reprint() 
    283         self.resendAttributes() 
    284         if self.selectMethod == 2: 
    285             self.applyIf() 
    286  
     543        """ Sorting was changed by user (through the Sort By combo box.) 
     544        """ 
     545        self.updateSorting() 
     546        self.autoSelection() 
     547         
     548    def updateSorting(self): 
     549        """ Update the sorting of the model/view. 
     550        """ 
     551        self.ranksProxyModel.invalidate() 
     552        if self.sortBy == 0: 
     553            self.ranksProxyModel.setSortRole(OWGUI.SortOrderRole) 
     554            self.ranksProxyModel.sort(0, Qt.DescendingOrder) 
     555            self.ranksView.setSortingEnabled(False) 
     556             
     557        else: 
     558            self.ranksProxyModel.setSortRole(Qt.DisplayRole) 
     559            self.ranksView.sortByColumn(self.sortBy - 1, Qt.DescendingOrder) 
     560            self.ranksView.setSortingEnabled(True) 
    287561 
    288562    def setLogORTitle(self): 
    289         selectedMeasures = list(self.selectedMeasures) 
    290         if self.logORIdx in selectedMeasures: 
    291             loi = selectedMeasures.index(self.logORIdx) 
    292             if  self.data and self.data.domain.classVar \ 
    293                 and self.data.domain.classVar.varType == orange.VarTypes.Discrete \ 
    294                 and len(self.data.domain.classVar.values) == 2: 
    295                     title = "log OR (for '%s')" % (self.data.domain.classVar.values[1][:10]) 
    296             else: 
    297                 title = "log OR" 
    298                  
    299             self.table.setHorizontalHeaderItem(2+loi, QTableWidgetItem(title)) 
    300             self.table.resizeColumnToContents(2+loi) 
    301          
    302  
    303     def setMeasures(self): 
    304         self.selectedMeasures = [i for i, ma in enumerate(self.measuresAttrs) if getattr(self, ma)] 
    305         self.table.setColumnCount(2 + len(self.selectedMeasures)) 
    306         for col, meas_idx in enumerate(self.selectedMeasures): 
    307             #self.topheader.setLabel(col+2, self.measuresShort[meas_idx]) 
    308             self.table.setColumnWidth(col+2, 80) 
    309         self.table.setHorizontalHeaderLabels(["Attribute", "#"] + [self.measuresShort[idx] for idx in self.selectedMeasures]) 
    310         self.setLogORTitle() 
    311  
     563        var =self.data.domain.classVar  
     564        if len(var.values) == 2: 
     565            title = "log OR (for %r)" % var.values[1][:10] 
     566        else: 
     567            title = "log OR" 
     568        item = PyStandardItem(title) 
     569        self.ranksModel.setHorizontalHeaderItem(self.ranksModel.columnCount() - 1, item) 
     570        return  
    312571 
    313572    def measuresChanged(self): 
    314         self.setMeasures() 
    315         if self.data: 
    316             self.reprint(True) 
    317             self.resendAttributes() 
    318  
     573        """ Measure selection has changed. Update column visibility. 
     574        """ 
     575        for i, valName in enumerate(self.measuresAttrs): 
     576            shown = getattr(self, valName, True) 
     577            self.ranksView.setColumnHidden(i + 2, not shown) 
    319578 
    320579    def sortByColumn(self, col): 
     
    325584        self.sortingChanged() 
    326585 
    327  
    328586    def decimalsChanged(self): 
    329         self.reprint(True) 
    330  
    331  
    332     def getMeasure(self, meas_idx): 
    333         measAttr = self.measuresAttrs[meas_idx] 
    334         mdict = self.measured.get(measAttr, False) 
    335         if mdict: 
    336             return mdict 
    337  
    338         estimator = self.estimators[meas_idx]() 
    339         if measAttr == "computeReliefF": 
    340             estimator.k, estimator.m = self.reliefK, self.reliefN 
    341  
    342         handlesContinuous = self.handlesContinuous[meas_idx] 
    343         mdict = {} 
    344         for attr in self.data.domain.attributes: 
    345             if handlesContinuous or attr.varType == orange.VarTypes.Discrete: 
    346                 act_attr, act_data = attr, self.data 
    347             else: 
    348                 if not self.discretizedData: 
    349                     discretizer = orange.EquiNDiscretization(numberOfIntervals=self.nIntervals) 
    350                     contAttrs = filter(lambda attr: attr.varType == orange.VarTypes.Continuous, self.data.domain.attributes) 
    351                     at = [] 
    352                     attrDict = {} 
    353                     for attri in contAttrs: 
    354                         try: 
    355                             nattr = discretizer(attri, self.data) 
    356                             at.append(nattr) 
    357                             attrDict[attri] = nattr 
    358                         except: 
    359                             pass 
    360                     self.discretizedData = self.data.select(orange.Domain(at, self.data.domain.classVar)) 
    361                     self.discretizedData.setattr("attrDict", attrDict) 
    362  
    363                 act_attr, act_data = self.discretizedData.attrDict.get(attr, None), self.discretizedData 
    364  
    365             try: 
    366                 if act_attr: 
    367                     mdict[attr] = act_attr and estimator(act_attr, act_data) 
    368                     if measAttr == "computeLogOdds": 
    369                         if mdict[attr] == -999999: 
    370                             act_attr = u"-\u221E" 
    371                         elif mdict[attr] == 999999: 
    372                             act_attr = u"\u221E" 
    373                         mdict[attr] = ("%%.%df" % self.nDecimals + " (%s)") % (mdict[attr], act_attr.values[1]) 
    374                 else: 
    375                     mdict[attr] = None 
    376             except: 
    377                 mdict[attr] = None 
    378  
    379         self.measured[measAttr] = mdict 
    380         return mdict 
    381  
    382  
    383     def reprint(self, noSort = False): 
    384         if not self.data: 
    385             return 
    386  
    387         prec = " %%.%df" % self.nDecimals 
    388  
    389         if not noSort: 
    390             self.resort() 
    391  
    392         for row, attr in enumerate(self.attributeOrder): 
    393             OWGUI.tableItem(self.table, row, 0, attr.name) 
    394             OWGUI.tableItem(self.table, row, 1, attr.varType==orange.VarTypes.Continuous and "C" or str(len(attr.values))) 
    395  
    396         self.minmax = {} 
    397  
    398         for col, meas_idx in enumerate(self.selectedMeasures): 
    399             mdict = self.getMeasure(meas_idx) 
    400             values = filter(lambda val: val != None, mdict.values()) 
    401             if values != []: 
    402                 self.minmax[col] = (min(values), max(values)) 
    403             else: 
    404                 self.minmax[col] = (0,1) 
    405             for row, attr in enumerate(self.attributeOrder): 
    406                 if mdict[attr] is None: 
    407                     mattr = "NA" 
    408                 elif isinstance(mdict[attr], (int, float)): 
    409                     mattr = prec % mdict[attr] 
    410                 else: 
    411                     mattr = mdict[attr] 
    412                 OWGUI.tableItem(self.table, row, col+2, mattr) 
    413  
    414         self.reselect() 
    415  
    416         if self.sortBy < 3: 
    417             self.topheader.setSortIndicator(self.sortBy-1, Qt.DescendingOrder) 
    418         elif self.sortBy-3 in self.selectedMeasures: 
    419             self.topheader.setSortIndicator(2 + self.selectedMeasures.index(self.sortBy-3), Qt.DescendingOrder) 
    420         else: 
    421             self.topheader.setSortIndicator(-1, Qt.DescendingOrder) 
    422  
    423         #self.table.resizeColumnsToContents() 
    424         self.table.resizeRowsToContents() 
    425         self.table.setColumnWidth(0, 100) 
    426         self.table.setColumnWidth(1, 20) 
    427         for col in range(len(self.selectedMeasures)): 
    428             self.table.setColumnWidth(col+2, 80) 
    429  
    430  
     587        self.updateDelegates() 
     588        self.ranksView.resizeColumnsToContents() 
     589         
     590    def updateDelegates(self): 
     591        self.contRanksView.setItemDelegate(RankItemDelegate(self, 
     592                            decimals=self.nDecimals, 
     593                            color=self.distColor)) 
     594        self.discRanksView.setItemDelegate(RankItemDelegate(self, 
     595                            decimals=self.nDecimals, 
     596                            color=self.distColor)) 
     597         
    431598    def sendReport(self): 
    432599        self.reportData(self.data) 
    433600        self.reportRaw(OWReport.reportTable(self.table)) 
    434601 
    435  
    436     def resendAttributes(self): 
    437         if not self.data: 
    438             self.send("ExampleTable Attributes", None) 
    439             return 
    440  
    441         attrDomain = orange.Domain(  [orange.StringVariable("attributes"), orange.EnumVariable("D/C", values = "DC"), orange.FloatVariable("#")] 
    442                                    + [orange.FloatVariable(self.measuresShort[meas_idx]) for meas_idx in self.selectedMeasures], 
    443                                      None) 
    444         attrData = orange.ExampleTable(attrDomain) 
    445         measDicts = [self.measured[self.measuresAttrs[meas_idx]] for meas_idx in self.selectedMeasures] 
    446         for attr in self.attributeOrder: 
    447             cont = attr.varType == orange.VarTypes.Continuous 
    448             attrData.append([attr.name, cont, cont and "?" or len(attr.values)] + [meas[attr] or "?" for meas in measDicts]) 
    449  
    450         self.send("ExampleTable Attributes", attrData) 
    451  
    452  
    453     def reselect(self): 
    454         self._reselecting = True 
    455         try: 
    456             self.table.clearSelection() 
    457      
    458             if not self.data: 
    459                 return 
    460      
    461             for attr in self.selected: 
    462                 self.table.selectRow(self.attributeOrder.index(attr)) 
    463      
    464             if self.selectMethod == 1 or self.selectMethod == 2 and self.selected == self.attributeOrder[:self.nSelected]: 
    465                 pass 
    466             elif self.selected == self.attributeOrder: 
    467                 self.selectMethod = 0 
    468             else: 
    469                 self.selectMethod = 1 
    470         finally: 
    471             self._reselecting = False 
    472             self.onSelectionChanged() 
    473  
    474     def resort(self): 
    475         self.attributeOrder = self.usefulAttributes 
    476  
    477         if self.sortBy: 
    478             if self.sortBy == 1: 
    479                 st = [(attr, attr.name) for attr in self.attributeOrder] 
    480                 st.sort(lambda x,y: cmp(x[1], y[1])) 
    481             elif self.sortBy == 2: 
    482                 st = [(attr, attr.varType == orange.VarTypes.Continuous and 1e30 or len(attr.values)) for attr in self.attributeOrder] 
    483                 st.sort(lambda x,y: cmp(x[1], y[1])) 
    484                 self.topheader.setSortIndicator(1, Qt.DescendingOrder) 
    485             else: 
    486                 st = [(m, a == None and -1e20 or a) for m, a in self.getMeasure(self.sortBy-3).items()] 
    487                 st.sort(lambda x,y: -cmp(x[1], y[1]) or cmp(x[0], y[0])) 
    488  
    489             self.attributeOrder = [attr for attr, meas in st] 
    490  
    491         if self.selectMethod == 2: 
    492             self.selected = self.attributeOrder[:self.nSelected] 
    493  
    494  
    495602    def applyIf(self): 
    496603        if self.autoApply: 
     
    500607 
    501608    def apply(self): 
    502         if not self.data or not self.selected: 
     609        selected = self.selectedAttrs() 
     610        if not self.data or not selected: 
    503611            self.send("Reduced Example Table", None) 
    504             self.lastSentAttrs = [] 
    505         else: 
    506             if self.lastSentAttrs != self.selected: 
    507                 nDomain = orange.Domain(self.selected, self.data.domain.classVar) 
    508                 for meta in [a.name for a in self.data.domain.getmetas().values()]: 
    509                     nDomain.addmeta(orange.newmetaid(), self.data.domain[meta]) 
    510  
    511                 self.send("Reduced Example Table", orange.ExampleTable(nDomain, self.data)) 
    512                 self.lastSentAttrs = self.selected[:] 
    513  
     612        else: 
     613            domain = orange.Domain(selected, self.data.domain.classVar) 
     614            domain.addmetas(self.data.domain.getmetas()) 
     615            data = orange.ExampleTable(domain, self.data) 
     616            self.send("Reduced Example Table", data) 
    514617        self.dataChanged = False 
    515  
    516  
    517  
    518 class RankItemDelegate(QItemDelegate): 
    519     def __init__(self, widget = None, table = None): 
    520         QItemDelegate.__init__(self, widget) 
    521         self.table = table 
    522         self.widget = widget 
    523  
     618         
     619    def selectedAttrs(self): 
     620        if self.data: 
     621            inds = self.ranksView.selectionModel().selectedRows(0) 
     622            source = self.ranksProxyModel.mapToSource 
     623            inds = map(source, inds) 
     624            inds = [ind.row() for ind in inds] 
     625            return [self.data.domain.attributes[i] for i in inds] 
     626        else: 
     627            return [] 
     628 
     629class RankItemDelegate(QStyledItemDelegate): 
     630    """ Item delegate that can also draw a distribution bar 
     631    """ 
     632    def __init__(self, parent=None, decimals=3, color=Qt.red): 
     633        QStyledItemDelegate.__init__(self, parent) 
     634        self.decimals = decimals 
     635        self.float_fmt = "%%.%if" % decimals 
     636        self.color = QColor(color) 
     637         
     638    def displayText(self, value, locale): 
     639        obj = _toPyObject(value) 
     640        if isinstance(obj, float): 
     641            return self.float_fmt % obj 
     642        elif isinstance(obj, basestring): 
     643            return obj 
     644        elif obj is None: 
     645            return "NA" 
     646        else: 
     647            return obj.__str__() 
     648         
     649    def sizeHint(self, option, index): 
     650        metrics = QFontMetrics(option.font) 
     651        height = metrics.lineSpacing() + 8 # 4 pixel margin 
     652        width = metrics.width(self.displayText(index.data(Qt.DisplayRole), QLocale())) + 8 
     653        return QSize(width, height) 
     654     
    524655    def paint(self, painter, option, index): 
    525         if not self.widget.showDistributions: 
    526             QItemDelegate.paint(self, painter, option, index) 
    527             return 
    528  
    529         col = index.column() 
    530         row = index.row() 
    531  
    532         if col < 2 or not self.widget.minmax.has_key(col-2):        # we don't paint first two columns 
    533             QItemDelegate.paint(self, painter, option, index) 
    534             return 
    535  
    536         min, max = self.widget.minmax[col-2] 
    537  
     656        text = self.displayText(index.data(Qt.DisplayRole), QLocale()) 
     657         
     658        bar_ratio = index.data(OWGUI.BarRatioRole) 
     659        ratio, have_ratio = bar_ratio.toDouble() 
     660        rect = option.rect 
     661        if have_ratio: 
     662            text_rect = rect.adjusted(4, 1, -4, -5) # Style dependent margins? 
     663        else: 
     664            text_rect = rect.adjusted(4, 4, -4, -4) 
     665             
    538666        painter.save() 
    539         self.drawBackground(painter, option, index) 
    540         value, ok = index.data(Qt.DisplayRole).toDouble() 
    541  
    542         if ok:        # in case we get "?" it is not ok 
    543             smallerWidth = option.rect.width() * (max - value) / (max-min or 1) 
    544             painter.fillRect(option.rect.adjusted(0,0,-smallerWidth,0), self.widget.distColor) 
    545  
    546         self.drawDisplay(painter, option, option.rect, index.data(Qt.DisplayRole).toString()) 
     667        painter.setFont(option.font) 
     668        qApp.style().drawPrimitive(QStyle.PE_PanelItemViewRow, option, painter) 
     669         
     670        # TODO: Check ForegroundRole. 
     671        if option.state & QStyle.State_Selected: 
     672            color = option.palette.highlightedText().color() 
     673        else: 
     674            color = option.palette.text().color() 
     675        painter.setPen(QPen(color)) 
     676         
     677        align = index.data(Qt.TextAlignmentRole) 
     678        if align.isValid(): 
     679            align = align.toInt() 
     680        else: 
     681            align = Qt.AlignLeft | Qt.AlignVCenter 
     682        painter.drawText(text_rect, align, text) 
     683        painter.setRenderHint(QPainter.Antialiasing, True) 
     684        if have_ratio: 
     685            bar_brush = index.data(OWGUI.BarBrushRole) 
     686            if bar_brush.isValid(): 
     687                bar_brush = bar_brush.toPyObject() 
     688                if not isinstance(bar_brush, (QColor, QBrush)): 
     689                    bar_brush = None 
     690            else: 
     691                bar_brush =  None 
     692            if bar_brush is None: 
     693                bar_brush = self.color 
     694            brush = QBrush(bar_brush) 
     695            painter.setBrush(brush) 
     696            painter.setPen(QPen(brush, 1)) 
     697            bar_rect = QRect(text_rect) 
     698            bar_rect.setTop(bar_rect.bottom() - 1) 
     699            bar_rect.setBottom(bar_rect.bottom() + 1) 
     700            w = text_rect.width() 
     701            bar_rect.setWidth(max(0, min(w * ratio, w))) 
     702            painter.drawRoundedRect(bar_rect, 2, 2) 
    547703        painter.restore() 
    548704 
     705         
     706class PyStandardItem(QStandardItem): 
     707    """ A StandardItem subclass for python objects. 
     708    """ 
     709    def __init__(self, *args): 
     710        QStandardItem.__init__(self, *args) 
     711        self.setFlags(Qt.ItemIsSelectable| Qt.ItemIsEnabled) 
     712         
     713    def __lt__(self, other): 
     714        my = self.data(Qt.DisplayRole).toPyObject() 
     715        other = other.data(Qt.DisplayRole).toPyObject() 
     716        if my is None: 
     717            return True 
     718        return my < other 
     719 
     720class MySortProxyModel(QSortFilterProxyModel): 
     721    def headerData(self, section, orientation, role): 
     722        """ Don't map headers. 
     723        """ 
     724        source = self.sourceModel() 
     725        return source.headerData(section, orientation, role) 
     726     
     727    def lessThan(self, left, right): 
     728        role = self.sortRole() 
     729        left = left.data(role).toPyObject() 
     730        right = right.data(role).toPyObject() 
     731#        print left, right 
     732        return left < right 
    549733 
    550734if __name__=="__main__": 
    551735    a=QApplication(sys.argv) 
    552736    ow=OWRank() 
    553     #ow.setData(orange.ExampleTable("../../doc/datasets/wine.tab")) 
    554     ow.setData(orange.ExampleTable(r"E:\Development\Orange Datasets\UCI\zoo.tab")) 
     737    ow.setData(orange.ExampleTable("wine.tab")) 
     738    ow.setData(orange.ExampleTable("zoo.tab")) 
     739#    ow.setData(orange.ExampleTable("servo.tab")) 
     740#    ow.setData(orange.ExampleTable("auto-mpg.tab")) 
    555741    ow.show() 
    556742    a.exec_() 
Note: See TracChangeset for help on using the changeset viewer.