Changeset 3630:e323ec96cbbe in orange


Ignore:
Timestamp:
05/03/07 10:04:58 (7 years ago)
Author:
blaz <blaz.zupan@…>
Branch:
default
Convert:
42c175d378d6e6beeeacb291d18ce38c649993f7
Message:

added regression, checking of variable types, warnings

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Evaluate/OWPredictions.py

    r3482 r3630  
    1010from OWWidget import * 
    1111import OWGUI 
     12import statc 
    1213 
    1314############################################################################## 
     
    2526 
    2627class OWPredictions(OWWidget): 
    27     settingsList = ["ShowProb", "ShowClass", "ShowTrueClass", 
    28                     "ShowAttributeMethod", "sendDataType", "commitOnChange"] 
     28    settingsList = ["showProb", "showClass", 
     29                    "ShowAttributeMethod", "sendDataType", "sendOnChange", 
     30                    "sendPredictions", "sendSelection", "classvalues", 
     31                    "rbest", "rpercentile"] 
    2932 
    3033    def __init__(self, parent=None, signalManager = None): 
    31         OWWidget.__init__(self, parent, signalManager, "Classifications") 
     34        OWWidget.__init__(self, parent, signalManager, "Predictions") 
    3235 
    3336        self.callbackDeposit = [] 
    3437        self.inputs = [("Examples", ExampleTable, self.setData),("Classifiers", orange.Classifier, self.setClassifier, Multiple)] 
    35         self.outputs = [("Selected Examples", ExampleTable)] 
     38        self.outputs = [("Predictions", ExampleTable), ("Selected Examples", ExampleTable)] 
    3639        self.classifiers = {} 
    3740 
    3841        # saveble settings 
    39         self.ShowProb = 1; self.ShowClass = 1; self.ShowTrueClass = 0 
     42        self.showProb = 1; self.showClass = 1 
    4043        self.ShowAttributeMethod = 0 
    41         self.sendDataType = 0; self.commitOnChange = 1 
     44        self.sendDataType = 0; self.sendOnChange = 1 
     45        self.sendPredictions = 0 
     46        self.sendSelection = 1 
     47        self.rbest = 0 
    4248        self.loadSettings() 
    43  
    44         self.freezeAttChange = 0 # 1 to block table update followed by changes in attribute list box 
    45         self.data=None 
     49        self.outvar = None # current output variable (set by the first predictor send in) 
     50 
     51        self.freezeAttChange = 0 # block table update after changes in attribute list box? 
     52        self.data = None 
    4653 
    4754        # GUI - Options 
    48         self.options = QVButtonGroup("Options", self.controlArea) 
    49         self.options.setDisabled(1) 
    50         OWGUI.checkBox(self.options, self, 'ShowProb', "Show predicted probabilities", 
     55 
     56        # Options - classification 
     57        self.copt = QVButtonGroup("Options (classification)", self.controlArea) 
     58        self.copt.setDisabled(1) 
     59        OWGUI.checkBox(self.copt, self, 'showProb', "Show predicted probabilities", 
    5160                       callback=self.updateTableOutcomes) 
    5261 
    53         self.lbClasses = QListBox(self.options) 
     62        self.lbClasses = QListBox(self.copt) 
    5463        self.lbClasses.setSelectionMode(QListBox.Multi) 
    5564        self.connect(self.lbClasses, SIGNAL("selectionChanged()"), self.updateTableOutcomes) 
    5665         
    57         OWGUI.checkBox(self.options, self, 'ShowClass', "Show predicted class", 
    58                        callback=[self.updateTableOutcomes, self.checkenable]) 
    59         self.trueClassCheckBox = OWGUI.checkBox(self.options, self, 'ShowTrueClass', 
    60                                                 "Show true class", callback=self.updateTrueClass, disabled=1) 
    61  
     66        OWGUI.checkBox(self.copt, self, 'showClass', "Show predicted class", 
     67                       callback=[self.updateTableOutcomes, self.checksendpredictions]) 
     68 
     69        # Options - regression 
     70        # self.ropt = QVButtonGroup("Options (regression)", self.controlArea) 
     71        # OWGUI.checkBox(self.ropt, self, 'showClass', "Show predicted class", 
     72        #                callback=[self.updateTableOutcomes, self.checksendpredictions]) 
     73        # self.ropt.hide() 
     74         
    6275        OWGUI.separator(self.controlArea) 
    63         self.att = QVButtonGroup("Data Attributes", self.controlArea) 
     76 
     77        self.att = QVButtonGroup("Data attributes", self.controlArea) 
    6478        OWGUI.radioButtonsInBox(self.att, self, 'ShowAttributeMethod', ['Show all', 'Hide all'], 
    6579                                callback=self.updateAttributes) 
     
    6781 
    6882        OWGUI.separator(self.controlArea) 
    69         self.outBox = QVButtonGroup("Output", self.controlArea) 
    70         OWGUI.radioButtonsInBox(self.outBox, self, 'sendDataType', 
    71                                 ['None', 'Data with class conflict', 'Data with class agreement'], 
    72                                 box='Data Selection', 
    73                                 tooltips=['No data will be sent to the output channel', 
    74                                           'Send data for which the predicted (and true class, if shown) are different.', 
    75                                           'Send data for which the predicted (and true class, if shown) match.'], 
    76                                 callback=self.checksenddata) 
    77         OWGUI.checkBox(self.outBox, self, 'commitOnChange', 'Commit data on any change') 
    78         self.commitBtn = OWGUI.button(self.outBox, self, "Commit", callback=self.senddata) 
    79  
    80         self.outBox.setDisabled(1) 
     83        self.outbox = QVButtonGroup("Output", self.controlArea) 
     84         
     85        self.dsel = OWGUI.checkBox(self.outbox, self, "sendSelection", "Send data selection") 
     86        # data selection for classification 
     87        self.csel = OWGUI.radioButtonsInBox(self.outbox, self, 'sendDataType', 
     88                                ['Examples with class conflict', 'Examples with class agreement'], 
     89                                box='Data selection', 
     90                                tooltips=['Data instances with different true and predictied class.', 
     91                                          'Data instances with matching true and predictied class.'], 
     92                                callback=self.checksendselection) 
     93        # data selection for regression 
     94        self.rsel = QVButtonGroup("Data selection", self.outbox) 
     95        OWGUI.radioButtonsInBox(self.rsel, self, "rbest", ["Highest variance", "Lowest variance"], 
     96                                callback = self.checksendselection) 
     97        hb = OWGUI.widgetBox(self.rsel, orientation = "horizontal") 
     98        QLabel('Percentiles: ', hb) 
     99        OWGUI.comboBox(hb, self, "rpercentile", 
     100                       items = [0.01, 0.02, 0.05, 0.1, 0.2], 
     101                       sendSelectedValue = 1, valueType = float, callback = self.checksendselection) 
     102 
     103        self.dsel.disables = [self.csel, self.rsel] 
     104 
     105        OWGUI.checkBox(self.outbox, self, 'sendPredictions', "Send predictions", 
     106                       callback=self.updateTableOutcomes) 
     107        OWGUI.separator(self.controlArea) 
     108 
     109        self.commitBtn = OWGUI.button(self.outbox, self, "Send data", callback=self.senddata) 
     110        OWGUI.checkBox(self.outbox, self, 'sendOnChange', 'Send automatically') 
     111 
     112        self.outbox.setDisabled(1) 
    81113 
    82114        # GUI - Table 
     
    91123 
    92124        self.layout.add(self.table) 
    93 #        self.table.hide() 
    94  
    95     # updates the columns associated with the classifiers 
     125 
     126    ############################################################################## 
     127    # Contents painting 
     128  
    96129    def updateTableOutcomes(self): 
     130        """updates the columns associated with the classifiers""" 
    97131        if self.freezeAttChange: # program-based changes should not alter the table immediately 
    98132            return 
     
    100134            return 
    101135 
    102         attsel = [self.lbClasses.isSelected(i) for i in range(len(self.data.domain.attributes))] 
    103         showatt = attsel.count(1) 
     136        classification = None 
     137        if self.outvar: 
     138            classification = self.outvar.varType == orange.VarTypes.Discrete 
     139            if classification: 
     140                selclass = [self.lbClasses.isSelected(i) for i in range(len(self.data.domain.classVar.values))] 
     141                showclass = selclass.count(1) 
     142            else: 
     143                showclass = 1 
     144             
    104145        # sindx is the column where these start 
    105146        sindx = 1 + len(self.data.domain.attributes) + 1 * (self.data.domain.classVar<>None) 
    106147        col = sindx 
    107         if self.ShowClass or self.ShowProb: 
     148        if self.showClass or self.showProb: 
    108149            for (cid, c) in enumerate(self.classifiers.values()): 
    109                 if self.data.domain.classVar.varType == orange.VarTypes.Continuous: 
     150                if classification: 
     151                    for (i, d) in enumerate(self.data): 
     152                        (cl, p) = c(d, orange.GetBoth) 
     153 
     154                        self.classifications[i].append(cl) 
     155                        if self.showProb and showclass: 
     156                            s = " : ".join(["%5.3f" % p for (vi,p) in enumerate(p) if selclass[vi]]) 
     157                            if self.showClass: s += " -> " 
     158                        else: 
     159                            s = "" 
     160                        if self.showClass: 
     161                            s += str(cl) 
     162                        self.table.setText(self.rindx[i], col, s) 
     163                else: 
    110164                    # regression 
    111165                    for (i, d) in enumerate(self.data): 
     
    113167                        self.classifications[i].append(cl) 
    114168                        self.table.setText(self.rindx[i], col, str(cl)) 
    115                 else: 
    116                     # classification 
    117                     for (i, d) in enumerate(self.data): 
    118                         (cl, p) = c(d, orange.GetBoth) 
    119                         self.classifications[i].append(cl) 
    120                         s = '' 
    121                         if self.ShowProb and showatt: 
    122                             s += reduce(lambda x,y: x+' : '+y, 
    123                                         map(lambda x: "%5.3f"%x[1], filter(lambda x,s=attsel: s[x[0]], enumerate(p)))) 
    124                             if self.ShowClass: 
    125                                 s += ' -> ' 
    126                         if self.ShowClass: 
    127                             s += str(cl) 
    128                         self.table.setText(self.rindx[i], col, s) 
    129169                col += 1 
    130170        else: 
     
    136176        for i in range(sindx, col): 
    137177            self.table.adjustColumn(i) 
    138             if self.ShowClass or self.ShowProb: 
     178            if self.showClass or self.showProb: 
    139179                self.table.showColumn(i) 
    140180            else: 
     
    142182 
    143183    def updateTrueClass(self): 
    144         if self.classifiers: 
    145             col = 1+len(self.data.domain.attributes) 
    146             if self.ShowTrueClass and self.data.domain.classVar: 
    147                 self.table.showColumn(col) 
    148                 self.table.adjustColumn(col) 
    149             else: 
    150                 self.table.hideColumn(col) 
     184        col = 1+len(self.data.domain.attributes) 
     185        if self.data.domain.classVar: 
     186            self.table.showColumn(col) 
     187            self.table.adjustColumn(col) 
     188        else: 
     189            self.table.hideColumn(col) 
    151190 
    152191    def updateAttributes(self): 
     
    159198                self.table.hideColumn(i+1) 
    160199 
    161     # defines the table and paints its contents 
    162200    def setTable(self): 
     201        """defines the attribute/predictions table and paints its contents""" 
    163202        if self.data==None: 
    164203            return 
    165204 
    166205        self.table.setNumCols(0) 
    167         self.table.setNumCols(1 + len(self.data.domain.attributes) + (self.ShowTrueClass) + len(self.classifiers)) 
     206        self.table.setNumCols(1 + len(self.data.domain.attributes) + (self.data.domain.classVar <> None) + len(self.classifiers)) 
    168207        self.table.setNumRows(len(self.data)) 
    169208 
     
    187226        self.classifications = [[]] * len(self.data) 
    188227        if self.data.domain.classVar: 
    189             # column for the true class 
    190228            for i in range(len(self.data)): 
    191229                c = self.data[i].getclass() 
     
    215253            self.vheader.setLabel(i, self.table.item(i,0).text()) 
    216254 
     255    def checkenable(self): 
     256        # following should be more complicated and depends on what data are we showing 
     257        cond = len(self.classifiers) 
     258        self.outbox.setEnabled(cond) 
     259        self.att.setEnabled(cond) 
     260        self.copt.setEnabled(cond) 
     261        e = (self.data and (self.data.domain.classVar <> None) + len(self.classifiers)) >= 2 
     262        # need at least two classes to compare predictions 
     263        self.dsel.setEnabled(e) 
     264        if e and self.sendSelection: 
     265            self.csel.setEnabled(1) 
     266            self.rsel.setEnabled(1) 
     267 
    217268    ############################################################################## 
    218269    # Input signals 
    219270 
    220     def setData(self,data): 
    221         self.data = self.isDataWithClass(data) and data or None 
    222         if not self.data: 
     271    def consistenttarget(self, target): 
     272        """returns TRUE if target is consistent with current predictiors and data""" 
     273        if self.classifiers: 
     274            return target == self.classifiers.values()[0] 
     275        return True 
     276 
     277    def setData(self, data): 
     278        if not data: 
     279            self.data = data 
    223280            self.table.hide() 
    224281            self.send("Selected Examples", None) 
     282            self.send("Predictions", None) 
     283            self.att.setDisabled(1) 
     284            self.outbox.setDisabled(1) 
    225285        else: 
    226             if self.data.domain.classVar.varType == orange.VarTypes.Continuous: 
    227                 # regression 
    228                 pass 
    229             else: 
    230                 lb = self.lbClasses 
    231                 lb.clear() 
    232                 for v in self.data.domain.classVar.values: 
    233                     lb.insertItem(str(v)) 
    234                 self.freezeAttChange = 1 
    235                 for i in range(len(self.data.domain.classVar.values)): 
    236                     lb.setSelected(i, 1) 
    237                 self.freezeAttChange = 0 
    238                 lb.show() 
    239                 # classification 
    240  
    241             if not self.classifiers: 
    242                 self.ShowTrueClass = 1 
    243  
     286            vartypes = {1:"discrete", 2:"continuous"} 
     287            if len(self.classifiers) and data.domain.classVar and data.domain.classVar <> self.outvar: 
     288                self.warning(id, "Data set %s ignored, inconsistent outcome variables\n%s/%s <> %s/%s (type or variable mismatch)" % (data.name, data.domain.classVar.name, vartypes.get(data.domain.classVar.varType, "?"), self.outvar.name, vartypes.get(self.outvar.varType, "?"))) 
     289                return 
     290            self.data = data 
    244291            self.rindx = range(len(self.data)) 
    245292            self.setTable() 
    246293            self.table.show() 
    247             self.checkenable() 
     294            self.checksenddata() 
     295        self.checkenable() 
    248296 
    249297    def setClassifier(self, c, id): 
     298        """handles incoming classifier (prediction, as could be a regressor as well)""" 
    250299        if not c: 
    251300            if self.classifiers.has_key(id): 
    252301                del self.classifiers[id] 
     302                if len(self.classifiers) == 0: self.outvar = None 
     303            else: 
     304                self.warning(id, "") 
    253305        else: 
     306            if len(self.classifiers) and c.classVar <> self.outvar: 
     307                vartypes = {1:"discrete", 2:"continuous"} 
     308                self.warning(id, "Predictor %s ignored, inconsistent outcome variables\n%s/%s <> %s/%s (type or variable mismatch)" % (c.name, c.classVar.name, vartypes.get(c.classVar.varType, "?"), self.outvar.name, vartypes.get(self.outvar.varType, "?"))) 
     309                return 
     310            else: 
     311                self.outvar = c.classVar 
    254312            self.classifiers[id] = c 
     313 
     314        if len(self.classifiers) == 1 and c: 
     315            # defines the outcome variable and the type of the problem we are dealing with (regression/classification) 
     316            self.outvar == c.classVar 
     317            if self.outvar.varType == orange.VarTypes.Continuous: 
     318                # regression 
     319                self.copt.hide(); self.csel.hide(); self.rsel.show() 
     320            else: 
     321                # classification 
     322                self.rsel.hide(); self.copt.show(); self.csel.show() 
     323                lb = self.lbClasses 
     324                lb.clear() 
     325                for v in self.outvar.values: 
     326                    lb.insertItem(str(v)) 
     327                self.freezeAttChange = 1 
     328                for i in range(len(self.outvar.values)): 
     329                    lb.setSelected(i, 1) 
     330                lb.show() 
     331                self.freezeAttChange = 0 
     332                 
    255333        if self.data: 
    256334            self.setTable() 
    257335            self.table.show() 
     336            self.checksenddata() 
    258337        self.checkenable() 
    259  
    260     # based on the data and classifiers enables/disables the control boxes 
    261     def checkenable(self): 
    262         # following should be more complicated and depends on what data are we showing 
    263         cond = self.data<>None and (len(self.classifiers)>1 or len(self.classifiers)>0 and self.ShowTrueClass) 
    264         self.outBox.setEnabled(cond) 
    265         if self.commitOnChange: 
    266             if cond: 
    267                 self.senddata() 
    268             else: 
    269                 self.send("Selected Examples", None) 
    270  
    271         self.trueClassCheckBox.setEnabled(self.data<>None and self.data.domain.classVar<>None) 
    272 ##        self.options.setEnabled(len(self.classifiers)>0) 
    273         self.att.setEnabled(self.data<>None) 
    274         self.options.setEnabled(self.data<>None) 
    275  
    276338 
    277339    ############################################################################## 
     
    279341 
    280342    def checksenddata(self): 
    281         if self.commitOnChange and self.outBox.isEnabled(): 
    282             self.senddata() 
     343        # if self.sendOnChange and self.outbox.isEnabled(): 
     344        if len(self.classifiers) and self.sendOnChange: self.senddata() 
     345 
     346    def checksendselection(self): 
     347        if len(self.classifiers) and self.sendOnChange: self.selection() 
     348 
     349    def checksendpredictions(self): 
     350        if len(self.classifiers) and self.sendOnChange: self.predictions() 
     351 
     352    def senddata(self): 
     353        self.predictions() 
     354        self.selection() 
    283355 
    284356    # assumes that the data and display conditions 
    285357    # (enough classes are displayed) have been checked 
    286358 
    287     def senddata(self): 
     359    def predictions(self): 
     360        if self.freezeAttChange: return 
     361        if not self.data or not self.classifiers: 
     362            self.send("Predictions", None) 
     363 
     364        if self.sendPredictions: 
     365            # predictions, data set with class predictions 
     366            classification = self.outvar.varType == orange.VarTypes.Discrete 
     367 
     368            metas = [] 
     369            if classification: 
     370                selclass = [self.lbClasses.isSelected(i) for i in range(len(self.data.domain.classVar.values))] 
     371                showclass = selclass.count(1) 
     372                if showclass: 
     373                    for c in self.classifiers.values(): 
     374                        m = [orange.FloatVariable(name="%s(%s)" % (c.name, str(v)), 
     375                                                  getValueFrom = lambda ex, rw, cindx=i: orange.Value(c(ex, c.GetProbabilities)[cindx])) \ 
     376                             for (i, v) in enumerate(self.data.domain.classVar.values) if selclass[i]] 
     377                        metas.extend(m) 
     378                if self.showClass: 
     379                    mc = [orange.EnumVariable(name="%s" % c.name, values = self.data.domain.classVar.values, 
     380                                             getValueFrom = lambda ex, rw: orange.Value(c(ex))) 
     381                          for c in self.classifiers.values()] 
     382                    metas.extend(mc) 
     383            else: 
     384                # regression 
     385                mc = [orange.FloatVariable(name="%s" % c.name,  
     386                                           getValueFrom = lambda ex, rw: orange.Value(c(ex))) 
     387                      for c in self.classifiers.values()] 
     388                metas.extend(mc) 
     389 
     390            domain = orange.Domain(self.data.domain.attributes + [self.data.domain.classVar]) 
     391            for m in metas: 
     392                domain.addmeta(orange.newmetaid(), m) 
     393            predictions = orange.ExampleTable(domain, self.data) 
     394            predictions.name = self.data.name 
     395            self.send("Predictions", predictions) 
     396 
     397    def selection(self): 
    288398        def cmpclasses(clist): 
     399            """returns True if all elements in clist are the same""" 
     400            clist = filter(lambda x: not x.isSpecial(), clist) 
    289401            ref = clist[0] 
    290402            for c in clist[1:]: 
     
    292404            return 1 
    293405 
    294         if not self.sendDataType or not self.data or not self.classifiers: 
     406        if not self.sendSelection: 
     407            return 
     408 
     409        if not self.data or not self.classifiers: 
    295410            self.send("Selected Examples", None) 
    296             return 
    297  
    298         # list of columns to check 
    299         selclass = [[],[0]][self.ShowTrueClass>0] 
    300         for (i, classifier) in enumerate(self.classifiers): 
    301             selclass.append(i+1) 
    302          
    303         s = [cmpclasses(map(lambda x: cls[x], selclass)) for cls in self.classifications] 
    304         if self.sendDataType == 1: 
    305             s = [not x for x in s] 
    306         data_selection = self.data.select(s) 
     411 
     412        classification = self.outvar.varType == orange.VarTypes.Discrete 
     413        if classification: 
     414            s = [cmpclasses(cls) for cls in self.classifications] 
     415            if self.sendDataType == 1: 
     416                s = [not x for x in s] 
     417            data_selection = self.data.select(s) 
     418        else: 
     419            if self.data.domain.classVar: 
     420                variance = [(i, statc.var(cls)) for (i, cls) in enumerate(self.classifications) 
     421                            if not cls[0].isSpecial()] 
     422            else: 
     423                variance = [(i, statc.var(cls)) for (i, cls) in enumerate(self.classifications)] 
     424            variance.sort(lambda x, y: cmp(x[1], y[1])) 
     425            if not self.rbest: 
     426                variance.reverse() 
     427            n = int(len(self.data) * self.rpercentile) 
     428            if not n: 
     429                return 
     430            sel = [r[0] for r in variance[:n]] 
     431            data_selection = self.data.getitems(sel) 
     432        data_selection.name = self.data.name 
    307433        self.send("Selected Examples", data_selection) 
    308434 
     
    319445        data = orange.ExampleTable('sailing') 
    320446        ow.setData(data) 
    321     elif 0: 
     447    if 0: 
    322448        data = orange.ExampleTable('outcome') 
    323449        test = orange.ExampleTable('cheat', uses=data.domain) 
     
    331457        ow.setClassifier(tree, 2) 
    332458        ow.setData(test) 
    333     elif 1: # two classifiers 
     459    if 1: # two classifiers 
    334460        data = orange.ExampleTable('sailing.txt') 
    335461        bayes = orange.BayesLearner(data) 
     
    345471        ow.setClassifier(knn, 3) 
    346472        ow.setData(data) 
    347     else: # regression 
     473    if 0: # regression 
    348474        data = orange.ExampleTable('auto-mpg') 
     475        data.name = 'auto-mpg' 
    349476        knn = orange.kNNLearner(data, name="knn") 
    350477        knn.name = "knn" 
    351478        maj = orange.MajorityLearner(data) 
    352479        maj.name = "Majority" 
    353         ow.setClassifier(knn, 1) 
     480        ow.setClassifier(knn, 10) 
    354481        ow.setClassifier(maj, 2) 
    355482        ow.setData(data) 
Note: See TracChangeset for help on using the changeset viewer.