source: orange/Orange/OrangeWidgets/Evaluate/OWPredictions.py @ 11660:56999fdd0f33

Revision 11660:56999fdd0f33, 22.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 8 months ago (diff)

Added Context Settings for selected class values.

RevLine 
[8042]1"""
2<name>Predictions</name>
3<description>Displays predictions of models for a particular data set.</description>
[11217]4<icon>icons/Predictions.svg</icon>
[8042]5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>300</priority>
7"""
8
9from OWWidget import *
10import OWGUI
11import statc
12import orange
13
14from OWDataTable import ExampleTableModel, getCached
15
16def safe_call(func):
17    from functools import wraps
18    @wraps(func)
19    def wrapper(*args, **kwargs):
20        try:
21            return func(*args, **kwargs)
22        except Exception, ex:
23            print >> sys.stderr, func.__name__, "call error", ex
24            return QVariant()
25    return wrapper
26
27class PyTableModel(QAbstractTableModel):
28    """ A general list-of-lists table holding arbitrary Python objects.
29    To view it in a item view you must subclass an QItemDelegate
30   
31    Arguments:
32        - `table`: A 2D table of any python objects
33        - `headers`: A list of strings for view headers
34        - `parent`: models parent (default None)
35       
36    Examples::
37        view = QTableView()
38        view.setModel(PyTableView([[1, 2, 3], [1, 2, 3]], ["One, "Two", "Three"], parent=view)
39    """
40    def __init__(self, table=None, headers=None, parent=None):
41        QAbstractTableModel.__init__(self, parent)
42        self._table = [[]] if table is None else table
43        self._header = [None] * len(self._table) if headers is None else headers
44       
45    @safe_call
46    def data(self, index, role=Qt.DisplayRole):
47        row, column = index.row(), index.column()
48        if role == Qt.DisplayRole or role == Qt.EditRole:
49            val = self._table[row][column]
50            return QVariant(val)
51        else:
52            return QVariant() #QAbstractTableModel.data(self, index, role)
53       
54    def rowCount(self, parent=QModelIndex()):
55        if parent.isValid():
56            return 0
57        else:
58            return len(self._table)
59       
60    def columnCount(self, parent=QModelIndex()):
61        if parent.isValid():
62            return 0
63        else:
64            return max([len(row) for row in self._table]) if self._table else 0
65       
66    def headerData(self, section, orientation, role=Qt.DisplayRole):
67        if orientation == Qt.Vertical and  role == Qt.DisplayRole:
68            return QVariant(QString(str(section + 1)))
69        elif orientation == Qt.Horizontal and role == Qt.DisplayRole:
70            header = self._header[section] if section < len(self._header) else str(section)
71            return QVariant(QString(header)) if header is not None else QVariant()
72        else:
73            return QVariant() #QAbstractTableModel.headerData(self, section, orientation, role)
74       
75    def sort(self, column, order=Qt.AscendingOrder):
76        self._table.sort(key=lambda row: row[column], reverse=order == Qt.DescendingOrder)
77        self.reset()
78       
79       
80class PredictionTableModel(PyTableModel):
81    """ Item model for classifier predictions
82    """
83    def __init__(self, prediction_results, *args, **kwargs):
84        """ prediciton_results [(classifier, list-of-predictions) ...]
85        """
86        PyTableModel.__init__(self, *args, **kwargs)
87        self.prediction_results = prediction_results
88        self._header = [p[0].name for p in prediction_results]
89       
90        self._table = [[] for i in range(max([len(p) for c, p in prediction_results] or [0]))]
91        for c, pred in prediction_results:
92            for i, p in enumerate(pred):
93                self._table[i].append(p)
94
95   
96class PredictionItemDelegete(QStyledItemDelegate):
97    """ Item delegate for prediction results i.e. (class, probabilities)
98    tuples as returned by classifiers with orange.Both return value
99   
100    Optional arguments:
101        - `showProbs`: a list of `True` or `False` values for each class value.
102        These are the probabilities that will be displayed (default all False
103        i.e. no probabilities shown)
104        - `deciamals`: number of decimals to show
105    """
106    def __init__(self, parent=None, *kwargs):
107        QStyledItemDelegate.__init__(self, parent)
108        self.__dict__.update(kwargs)
109   
110    def displayText(self, value, locale):
111        pred = value.toPyObject()
[10841]112        if isinstance(pred, (tuple, list)) and len(pred) == 2:
[8042]113            cls, prob = pred
[10841]114        elif isinstance(pred, orange.Value):
[8042]115            cls, prob = pred, None
[10841]116        elif isinstance(pred, orange.Distribution):
[8042]117            cls, prob = pred.modus(), pred
118        else:
119            return QString("")
120        text = ""
121        if prob and any(getattr(self, "showProbs", [])):
122            fmt = "%%.%if" % getattr(self, "decimals", 2)
123            text = " : ".join(fmt % f for f, show in zip(prob, self.showProbs) if show)
124            if getattr(self, "showClass", True):
125                text += " -> "
126        if getattr(self, "showClass", True):
127            text += str(cls)
128        return QString(text)
129
130class OWPredictions(OWWidget):
[11660]131    contextHandlers = {
132        "": ClassValuesContextHandler("", ["selectedClasses"])
133    }
134    settingsList = ["showProb", "showClass", "ShowAttributeMethod",
135                    "sendOnChange", "precision"]
[8042]136
137    def __init__(self, parent=None, signalManager = None):
138        OWWidget.__init__(self, parent, signalManager, "Predictions")
139
140        self.callbackDeposit = []
[9546]141        self.inputs = [("Data", ExampleTable, self.setData), ("Predictors", orange.Classifier, self.setPredictor, Multiple)]
[8042]142        self.outputs = [("Predictions", ExampleTable)]
143        self.predictors = {}
144
145        # saveble settings
146        self.showProb = 1;
147        self.showClass = 1
148        self.ShowAttributeMethod = 0
149        self.sendOnChange = 1
150        self.classes = []
151        self.selectedClasses = []
152        self.loadSettings()
153        self.datalabel = "N/A"
154        self.predictorlabel = "N/A"
155        self.tasklabel = "N/A"
156        self.precision = 2
[9633]157        self.doPrediction = False
[8042]158        self.outvar = None # current output variable (set by the first predictor/data set send in)
159
160        self.data = None
161        self.changedFlag = False
162       
163        self.loadSettings()
164
165        # GUI - Options
166
167        # Options - classification
168        ibox = OWGUI.widgetBox(self.controlArea, "Info")
169        OWGUI.label(ibox, self, "Data: %(datalabel)s")
170        OWGUI.label(ibox, self, "Predictors: %(predictorlabel)s")
171        OWGUI.label(ibox, self, "Task: %(tasklabel)s")
172        OWGUI.separator(self.controlArea)
173       
174        self.copt = OWGUI.widgetBox(self.controlArea, "Options (classification)")
175        self.copt.setDisabled(1)
176        cb = OWGUI.checkBox(self.copt, self, 'showProb', "Show predicted probabilities", callback=self.setPredictionDelegate)#self.updateTableOutcomes)
177
178#        self.lbClasses = OWGUI.listBox(self.copt, self, selectionMode = QListWidget.MultiSelection, callback = self.updateTableOutcomes)
179        ibox = OWGUI.indentedBox(self.copt, sep=OWGUI.checkButtonOffsetHint(cb))
180        self.lbcls = OWGUI.listBox(ibox, self, "selectedClasses", "classes",
181                                   callback=[self.setPredictionDelegate, self.checksendpredictions],
182#                                   callback=[self.updateTableOutcomes, self.checksendpredictions],
183                                   selectionMode=QListWidget.MultiSelection)
184        self.lbcls.setFixedHeight(50)
185
186        OWGUI.spin(ibox, self, "precision", 1, 6, label="No. of decimals: ",
187                   orientation=0, callback=self.setPredictionDelegate) #self.updateTableOutcomes)
188       
189        cb.disables.append(ibox)
190        ibox.setEnabled(bool(self.showProb))
191
192        OWGUI.checkBox(self.copt, self, 'showClass', "Show predicted class",
193                       callback=[self.setPredictionDelegate, self.checksendpredictions])
194#                       callback=[self.updateTableOutcomes, self.checksendpredictions])
195
196        OWGUI.separator(self.controlArea)
197
198        self.att = OWGUI.widgetBox(self.controlArea, "Data attributes")
199        OWGUI.radioButtonsInBox(self.att, self, 'ShowAttributeMethod', ['Show all', 'Hide all'], callback=lambda :self.setDataModel(self.data)) #self.updateAttributes)
200        self.att.setDisabled(1)
201        OWGUI.rubber(self.controlArea)
202
203        OWGUI.separator(self.controlArea)
204        self.outbox = OWGUI.widgetBox(self.controlArea, "Output")
205       
206        b = self.commitBtn = OWGUI.button(self.outbox, self, "Send Predictions", callback=self.sendpredictions, default=True)
207        cb = OWGUI.checkBox(self.outbox, self, 'sendOnChange', 'Send automatically')
208        OWGUI.setStopper(self, b, cb, "changedFlag", callback=self.sendpredictions)
[9250]209        OWGUI.checkBox(self.outbox, self, "doPrediction", "Replace/add predicted class",
210                       tooltip="Apply the first predictor to input examples and replace/add the predicted value as the new class variable.",
211                       callback=self.checksendpredictions)
[8042]212
213        self.outbox.setDisabled(1)
214
215        ## GUI table
216
217        self.splitter = splitter = QSplitter(Qt.Horizontal, self.mainArea)
218        self.dataView = QTableView()
219        self.predictionsView = QTableView()
220       
221        self.dataView.verticalHeader().setDefaultSectionSize(22)
222        self.dataView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
223        self.dataView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
224        self.dataView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
225       
226        self.predictionsView.verticalHeader().setDefaultSectionSize(22)
227        self.predictionsView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
228        self.predictionsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
229        self.predictionsView.verticalHeader().hide()
230       
231       
[9601]232#        def syncVertical(value):
233#            """ sync vertical scroll positions of the two views
234#            """
235#            v1 = self.predictionsView.verticalScrollBar().value()
236#            if v1 != value:
237#                self.predictionsView.verticalScrollBar().setValue(value)
238#            v2 = self.dataView.verticalScrollBar().value()
239#            if v2 != value:
240#                self.dataView.verticalScrollBar().setValue(v1)
[8042]241               
[9601]242        self.connect(self.dataView.verticalScrollBar(), SIGNAL("valueChanged(int)"), self.syncVertical)
243        self.connect(self.predictionsView.verticalScrollBar(), SIGNAL("valueChanged(int)"), self.syncVertical)
[8042]244       
245        splitter.addWidget(self.dataView)
246        splitter.addWidget(self.predictionsView)
247        splitter.setHandleWidth(3)
248        splitter.setChildrenCollapsible(False)
249        self.mainArea.layout().addWidget(splitter)
250       
251        self.spliter_restore_state = -1, 0
252        self.dataModel = None
253        self.predictionsModel = None
254       
255        self.resize(800, 600)
256       
257        self.handledAllSignalsFlag = False
258       
[9601]259    def syncVertical(self, value):
260        """ sync vertical scroll positions of the two views
261        """
262        v1 = self.predictionsView.verticalScrollBar().value()
263        if v1 != value:
264            self.predictionsView.verticalScrollBar().setValue(value)
265        v2 = self.dataView.verticalScrollBar().value()
266        if v2 != value:
267            self.dataView.verticalScrollBar().setValue(v1)
[8042]268       
269    def updateSpliter(self):
270        if not (self.dataModel and self.dataModel.columnCount() and \
271                self.predictionsModel and self.predictionsModel.columnCount()):
272            return
273        def width(view):
274            h_header = view.horizontalHeader()
275            v_header = view.verticalHeader()
276            return h_header.length() + v_header.width()
277       
278        def widthForColumns(view, start=-sys.maxint, end=sys.maxint):
279            h_header = view.horizontalHeader()
280            v_header = view.verticalHeader()
281            return sum([h_header.sectionSize(ind) for ind in range(h_header.count())[start : end]] or [0]) + v_header.width()
282       
283        if self.ShowAttributeMethod == 1:
284            w1, w2 = self.splitter.sizes()
285            w = width(self.dataView) + 4
286            self.splitter.setSizes([w, w1 + w2 - w])
287            self.dataView.setMaximumWidth(w)
288           
289            state, w = self.spliter_restore_state
290            if state == 0: # save the dataView width on change from 'show all' to 'hide all'
291                self.spliter_restore_state = 1, w1
292        else:
293            w1, w2 = self.splitter.sizes()
294            state, w = self.spliter_restore_state
295            if state == 1: # restore dataView on change from 'hide all' to 'show all'
296                w = min(w, (w1 + w2)*2 / 3)
297            else:
298                w1, w2 = self.splitter.sizes()
299                w = widthForColumns(self.dataView, -2) + 4
300                w = min(w,  (w1 + w2) / 2)
301                w = max(w,  min(w1 + w2 - widthForColumns(self.predictionsView) - 20, w1 + w2 - w))
302            self.splitter.setSizes([w, w1 + w2 - w])
303            self.dataView.setMaximumWidth(16777215) # This is QWidget's max width
304           
305            self.spliter_restore_state = 0, w
306           
307    def setDataModel(self, data):
308        if data is not None and self.outvar is not None:
309            if self.ShowAttributeMethod == 1: # Show only the class column
310                data = orange.ExampleTable(orange.Domain([self.outvar]), data)
311            elif not data.domain.classVar: # add outVar as class (all unknown values) to data
312                domain = orange.Domain(data.domain.attributes + [self.outvar])
313                domain.addmetas(data.domain.getmetas())
314                data = orange.ExampleTable(domain, data)
315               
316            dist = getCached(data, orange.DomainBasicAttrStat, (data,))
317            self.dataModel = ExampleTableModel(data, dist, None)
318            self.dataView.setModel(self.dataModel)
319            self.dataView.setItemDelegate(OWGUI.TableBarItem(self, data, color = Qt.lightGray))
320            count = self.dataModel.columnCount()
321            if count:
322                self.dataView.scrollTo(self.dataModel.index(0, count - 1))
323            self.dataView.show()
324        else:
325            self.clear()
326        self.updateSpliter()
327           
328    def setPredictionModel(self, classifiers, data):
329        predictions = [(c, [c(ex, orange.GetBoth) for ex in self.data]) for c in classifiers]
330        self.predictionsModel = PredictionTableModel(predictions)
331        self.predictionsView.setModel(self.predictionsModel)
332        self.setPredictionDelegate()
333        self.predictionsView.show()
334       
335       
336    def setPredictionDelegate(self):
337        delegate = PredictionItemDelegete(self)
338        delegate.showProbs = [self.showProb and i in self.selectedClasses for i in range(len(self.classes))]
339        delegate.decimals = self.precision
340        delegate.showClass = self.showClass
341        self.predictionsView.setItemDelegate(delegate)
342        self.predictionsView.resizeColumnsToContents()
343        self.updateSpliter()
344
345#    def sort(self, col):
346#        "sorts the table by column col"
347#        self.sortby = - self.sortby
348#        self.table.sortItems(col, self.sortby>=0)
349#
350#        # the table may be sorted, figure out data indices
351#        for i in range(len(self.data)):
352#            self.rindx[int(str(self.table.item(i,0).text()))-1] = i
353#        for (i, indx) in enumerate(self.rindx):
354#            self.vheader.setLabel(i, self.table.item(i,0).text())
355
356    def checkenable(self):
357        # following should be more complicated and depends on what data are we showing
358        cond = (self.outvar != None) and (self.data != None)
359        self.outbox.setEnabled(cond)
360        self.att.setEnabled(cond)
361        self.copt.setEnabled(cond)
362        e = (self.data and (self.data.domain.classVar <> None) + len(self.predictors)) >= 2
363        # need at least two classes to compare predictions
364
365    def clear(self):
366        self.send("Predictions", None)
367        self.checkenable()
368        if len(self.predictors) == 0:
369            self.outvar = None
370            self.classes = []
371            self.selectedClasses = []
372            self.predictorlabel = "N/A"
373        self.dataModel = PyTableModel([[]])
374        self.dataView.setModel(self.dataModel)
375        self.predictionsModel = PyTableModel([[]])
376        self.predictionsView.setModel(self.predictionsModel)
377       
378        self.dataView.hide()
379        self.predictionsView.hide()
[9239]380       
[8042]381
382    ##############################################################################
383    # Input signals
384
385    def handleNewSignals(self):
386        self.handledAllSignalsFlag = True
387        if self.data:
388            self.setDataModel(self.data)
[11660]389            self.openContext("", list(self.classes))
[8042]390            self.setPredictionModel(self.predictors.values(), self.data)
[11660]391
[9239]392        self.checksendpredictions()
[8042]393
394    def setData(self, data):
[11660]395        """
396        Set input data table.
397        """
398        self.closeContext("")
399        if data is None:
[8042]400            self.data = data
401            self.datalabel = "N/A"
402            self.clear()
403        else:
404            vartypes = {1:"discrete", 2:"continuous"}
405            self.data = data
406            self.rindx = range(len(self.data))
407            self.datalabel = "%d instances" % len(data)
[9239]408           
[8042]409        self.checkenable()
[9239]410        self.changedFlag = True
[8042]411
412    def setPredictor(self, predictor, id):
413        """handles incoming classifier (prediction, as could be a regressor as well)"""
414
415        def getoutvar(predictors):
416            """return outcome variable, if consistent among predictors, else None"""
417            if not len(predictors):
418                return None
419            ov = predictors[0].classVar
420            for predictor in predictors[1:]:
421                if ov != predictor.classVar:
422                    self.warning(0, "Mismatch in class variable (e.g., predictors %s and %s)" % (predictors[0].name, predictor.name))
423                    return None
424            return ov
425
426        self.handledAllSignalsFlag = False
427       
428        # remove the classifier with id, if empty
429        if not predictor:
430            if self.predictors.has_key(id):
431                del self.predictors[id]
432                if len(self.predictors) == 0:
433                    self.clear()
434                else:
435                    self.predictorlabel = "%d" % len(self.predictors)
436            return
437
438        # set the classifier
439        self.predictors[id] = predictor
440        self.predictorlabel = "%d" % len(self.predictors)
441
442        # set the outcome variable
443        ov = getoutvar(self.predictors.values())
444        if len(self.predictors) and not ov:
445            self.tasklabel = "N/A (type mismatch)"
446            self.classes = []
447            self.selectedClasses = []
448            self.clear()
449            self.outvar = None
450            return
451        self.warning(0) # clear all warnings
452
453        if ov != self.outvar:
454            self.outvar = ov
455            # regression or classification?
456            if self.outvar.varType == orange.VarTypes.Continuous:
457                self.copt.hide()
458                self.tasklabel = "Regression"
459            else:
460                self.copt.show()
461                self.classes = [str(v) for v in self.outvar.values]
462                self.selectedClasses = []
463                self.tasklabel = "Classification"
464               
465        self.checkenable()
[9239]466        self.changedFlag = True
[8042]467
468    ##############################################################################
469    # Ouput signals
470
471    def checksendpredictions(self):
[9239]472        if self.sendOnChange:
[8042]473            self.sendpredictions()
474        else:
475            self.changedFlag = True
476
477    def sendpredictions(self):
478        if not self.data or not self.outvar:
479            self.send("Predictions", None)
480            return
481
482        # predictions, data set with class predictions
483        classification = self.outvar.varType == orange.VarTypes.Discrete
484
485        metas = []
486        if classification:
487            if len(self.selectedClasses):
488                for c in self.predictors.values():
489                    m = [orange.FloatVariable(name=str("%s(%s)" % (c.name, str(self.outvar.values[i]))),
490                                              getValueFrom = lambda ex, rw, cindx=i, c=c: orange.Value(c(ex, c.GetProbabilities)[cindx])) \
491                         for i in self.selectedClasses]
492                    metas.extend(m)
493            if self.showClass:
494                mc = [orange.EnumVariable(name=str(c.name), values = self.outvar.values,
495                                         getValueFrom = lambda ex, rw, c=c: orange.Value(c(ex)))
496                      for c in self.predictors.values()]
497                metas.extend(mc)
498        else:
499            # regression
[10795]500            mc = [orange.FloatVariable(name="%s" % str(c.name),
501                    getValueFrom=lambda ex, rw, c=c: orange.Value(c(ex)))
[8042]502                  for c in self.predictors.values()]
503            metas.extend(mc)
[9250]504               
505        classVar = self.outvar
506        domain = orange.Domain(self.data.domain.attributes + [classVar])
[8042]507        domain.addmetas(self.data.domain.getmetas())
508        for m in metas:
509            domain.addmeta(orange.newmetaid(), m)
510        predictions = orange.ExampleTable(domain, self.data)
[9250]511        if self.doPrediction:
512            c = self.predictors.values()[0]
513            for ex in predictions:
514                ex[classVar] = c(ex)
515               
[8042]516        predictions.name = self.data.name
517        self.send("Predictions", predictions)
518       
519        self.changedFlag = False
520
521##############################################################################
522# Test the widget, run from DOS prompt
523
524if __name__=="__main__":
525    a = QApplication(sys.argv)
526    ow = OWPredictions()
527    ow.show()
528
529    import orngTree
530
531    dataset = orange.ExampleTable('../../doc/datasets/iris.tab')
532#    dataset = orange.ExampleTable('../../doc/datasets/auto-mpg.tab')
533    ind = orange.MakeRandomIndices2(p0=0.5)(dataset)
534    data = dataset.select(ind, 0)
535    test = dataset.select(ind, 1)
536    testnoclass = orange.ExampleTable(orange.Domain(test.domain.attributes, False), test)       
537    tree = orngTree.TreeLearner(data)
538    tree.name = "tree"
539    maj = orange.MajorityLearner(data)
540    maj.name = "maj"
541    knn = orange.kNNLearner(data, k = 10)
542    knn.name = "knn"
543   
544#    ow.setData(test)
545#   
546#    ow.setPredictor(maj, 1)
547   
548   
549
550    if 1: # data set only
551        ow.setData(test)
552    if 0: # two predictors, test data with class
553        ow.setPredictor(maj, 1)
554        ow.setPredictor(tree, 2)
555        ow.setData(test)
556    if 0: # two predictors, test data with no class
557        ow.setPredictor(maj, 1)
558        ow.setPredictor(tree, 2)
559        ow.setData(testnoclass)
560    if 1: # three predictors
561        ow.setPredictor(tree, 1)
562        ow.setPredictor(maj, 2)
563        ow.setData(data)
564        ow.setPredictor(knn, 3)
565    if 0: # just classifier, no data
566        ow.setData(data)
567        ow.setPredictor(maj, 1)
568        ow.setPredictor(knn, 2)
569    if 0: # change data set
570        ow.setPredictor(maj, 1)
571        ow.setPredictor(tree, 2)
572        ow.setData(testnoclass)
573        data = orange.ExampleTable('../../doc/datasets/titanic.tab')
574        tree = orngTree.TreeLearner(data)
575        tree.name = "tree"
576        ow.setPredictor(tree, 2)
577        ow.setData(data)
578       
579    ow.handleNewSignals()
580
581    a.exec_()
582    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.