source: orange/orange/OrangeWidgets/Evaluate/OWPredictions.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 22.2 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""
2<name>Predictions</name>
3<description>Displays predictions of models for a particular data set.</description>
4<icon>icons/Predictions.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>300</priority>
7"""
8
9from OWWidget import *
10import OWGUI
11import statc
12import orange
13
14from OWDataTable import ExampleTableModel, getCached
15
16def safe_call(func):
17    from functools import wraps
18    @wraps(func)
19    def wrapper(*args, **kwargs):
20        try:
21            return func(*args, **kwargs)
22        except Exception, ex:
23            print >> sys.stderr, func.__name__, "call error", ex
24            return QVariant()
25    return wrapper
26
27class PyTableModel(QAbstractTableModel):
28    """ A general list-of-lists table holding arbitrary Python objects.
29    To view it in a item view you must subclass an QItemDelegate
30   
31    Arguments:
32        - `table`: A 2D table of any python objects
33        - `headers`: A list of strings for view headers
34        - `parent`: models parent (default None)
35       
36    Examples::
37        view = QTableView()
38        view.setModel(PyTableView([[1, 2, 3], [1, 2, 3]], ["One, "Two", "Three"], parent=view)
39    """
40    def __init__(self, table=None, headers=None, parent=None):
41        QAbstractTableModel.__init__(self, parent)
42        self._table = [[]] if table is None else table
43        self._header = [None] * len(self._table) if headers is None else headers
44       
45    @safe_call
46    def data(self, index, role=Qt.DisplayRole):
47        row, column = index.row(), index.column()
48        if role == Qt.DisplayRole or role == Qt.EditRole:
49            val = self._table[row][column]
50            return QVariant(val)
51        else:
52            return QVariant() #QAbstractTableModel.data(self, index, role)
53       
54    def rowCount(self, parent=QModelIndex()):
55        if parent.isValid():
56            return 0
57        else:
58            return len(self._table)
59       
60    def columnCount(self, parent=QModelIndex()):
61        if parent.isValid():
62            return 0
63        else:
64            return max([len(row) for row in self._table]) if self._table else 0
65       
66    def headerData(self, section, orientation, role=Qt.DisplayRole):
67        if orientation == Qt.Vertical and  role == Qt.DisplayRole:
68            return QVariant(QString(str(section + 1)))
69        elif orientation == Qt.Horizontal and role == Qt.DisplayRole:
70            header = self._header[section] if section < len(self._header) else str(section)
71            return QVariant(QString(header)) if header is not None else QVariant()
72        else:
73            return QVariant() #QAbstractTableModel.headerData(self, section, orientation, role)
74       
75    def sort(self, column, order=Qt.AscendingOrder):
76        self._table.sort(key=lambda row: row[column], reverse=order == Qt.DescendingOrder)
77        self.reset()
78       
79       
80class PredictionTableModel(PyTableModel):
81    """ Item model for classifier predictions
82    """
83    def __init__(self, prediction_results, *args, **kwargs):
84        """ prediciton_results [(classifier, list-of-predictions) ...]
85        """
86        PyTableModel.__init__(self, *args, **kwargs)
87        self.prediction_results = prediction_results
88        self._header = [p[0].name for p in prediction_results]
89       
90        self._table = [[] for i in range(max([len(p) for c, p in prediction_results] or [0]))]
91        for c, pred in prediction_results:
92            for i, p in enumerate(pred):
93                self._table[i].append(p)
94
95   
96class PredictionItemDelegete(QStyledItemDelegate):
97    """ Item delegate for prediction results i.e. (class, probabilities)
98    tuples as returned by classifiers with orange.Both return value
99   
100    Optional arguments:
101        - `showProbs`: a list of `True` or `False` values for each class value.
102        These are the probabilities that will be displayed (default all False
103        i.e. no probabilities shown)
104        - `deciamals`: number of decimals to show
105    """
106    def __init__(self, parent=None, *kwargs):
107        QStyledItemDelegate.__init__(self, parent)
108        self.__dict__.update(kwargs)
109   
110    def displayText(self, value, locale):
111        pred = value.toPyObject()
112        if type(pred) >= tuple:
113            cls, prob = pred
114        elif type(pred) >= orange.Value:
115            cls, prob = pred, None
116        elif type(pred) >= orange.Distribution:
117            cls, prob = pred.modus(), pred
118        else:
119            return QString("")
120        text = ""
121        if prob and any(getattr(self, "showProbs", [])):
122            fmt = "%%.%if" % getattr(self, "decimals", 2)
123            text = " : ".join(fmt % f for f, show in zip(prob, self.showProbs) if show)
124            if getattr(self, "showClass", True):
125                text += " -> "
126        if getattr(self, "showClass", True):
127            text += str(cls)
128        return QString(text)
129
130class OWPredictions(OWWidget):
131    settingsList = ["showProb", "showClass", "ShowAttributeMethod", "sendOnChange", "precision"]
132
133    def __init__(self, parent=None, signalManager = None):
134        OWWidget.__init__(self, parent, signalManager, "Predictions")
135
136        self.callbackDeposit = []
137        self.inputs = [("Data", ExampleTable, self.setData), ("Predictors", orange.Classifier, self.setPredictor, Multiple)]
138        self.outputs = [("Predictions", ExampleTable)]
139        self.predictors = {}
140
141        # saveble settings
142        self.showProb = 1;
143        self.showClass = 1
144        self.ShowAttributeMethod = 0
145        self.sendOnChange = 1
146        self.classes = []
147        self.selectedClasses = []
148        self.loadSettings()
149        self.datalabel = "N/A"
150        self.predictorlabel = "N/A"
151        self.tasklabel = "N/A"
152        self.precision = 2
153        self.doPrediction = True
154        self.outvar = None # current output variable (set by the first predictor/data set send in)
155
156        self.data = None
157        self.changedFlag = False
158       
159        self.loadSettings()
160
161        # GUI - Options
162
163        # Options - classification
164        ibox = OWGUI.widgetBox(self.controlArea, "Info")
165        OWGUI.label(ibox, self, "Data: %(datalabel)s")
166        OWGUI.label(ibox, self, "Predictors: %(predictorlabel)s")
167        OWGUI.label(ibox, self, "Task: %(tasklabel)s")
168        OWGUI.separator(self.controlArea)
169       
170        self.copt = OWGUI.widgetBox(self.controlArea, "Options (classification)")
171        self.copt.setDisabled(1)
172        cb = OWGUI.checkBox(self.copt, self, 'showProb', "Show predicted probabilities", callback=self.setPredictionDelegate)#self.updateTableOutcomes)
173
174#        self.lbClasses = OWGUI.listBox(self.copt, self, selectionMode = QListWidget.MultiSelection, callback = self.updateTableOutcomes)
175        ibox = OWGUI.indentedBox(self.copt, sep=OWGUI.checkButtonOffsetHint(cb))
176        self.lbcls = OWGUI.listBox(ibox, self, "selectedClasses", "classes",
177                                   callback=[self.setPredictionDelegate, self.checksendpredictions],
178#                                   callback=[self.updateTableOutcomes, self.checksendpredictions],
179                                   selectionMode=QListWidget.MultiSelection)
180        self.lbcls.setFixedHeight(50)
181
182        OWGUI.spin(ibox, self, "precision", 1, 6, label="No. of decimals: ",
183                   orientation=0, callback=self.setPredictionDelegate) #self.updateTableOutcomes)
184       
185        cb.disables.append(ibox)
186        ibox.setEnabled(bool(self.showProb))
187
188        OWGUI.checkBox(self.copt, self, 'showClass', "Show predicted class",
189                       callback=[self.setPredictionDelegate, self.checksendpredictions])
190#                       callback=[self.updateTableOutcomes, self.checksendpredictions])
191
192        OWGUI.separator(self.controlArea)
193
194        self.att = OWGUI.widgetBox(self.controlArea, "Data attributes")
195        OWGUI.radioButtonsInBox(self.att, self, 'ShowAttributeMethod', ['Show all', 'Hide all'], callback=lambda :self.setDataModel(self.data)) #self.updateAttributes)
196        self.att.setDisabled(1)
197        OWGUI.rubber(self.controlArea)
198
199        OWGUI.separator(self.controlArea)
200        self.outbox = OWGUI.widgetBox(self.controlArea, "Output")
201       
202        b = self.commitBtn = OWGUI.button(self.outbox, self, "Send Predictions", callback=self.sendpredictions, default=True)
203        cb = OWGUI.checkBox(self.outbox, self, 'sendOnChange', 'Send automatically')
204        OWGUI.setStopper(self, b, cb, "changedFlag", callback=self.sendpredictions)
205        OWGUI.checkBox(self.outbox, self, "doPrediction", "Replace/add predicted class",
206                       tooltip="Apply the first predictor to input examples and replace/add the predicted value as the new class variable.",
207                       callback=self.checksendpredictions)
208
209        self.outbox.setDisabled(1)
210
211        ## GUI table
212
213        self.splitter = splitter = QSplitter(Qt.Horizontal, self.mainArea)
214        self.dataView = QTableView()
215        self.predictionsView = QTableView()
216       
217        self.dataView.verticalHeader().setDefaultSectionSize(22)
218        self.dataView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
219        self.dataView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
220        self.dataView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
221       
222        self.predictionsView.verticalHeader().setDefaultSectionSize(22)
223        self.predictionsView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
224        self.predictionsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
225        self.predictionsView.verticalHeader().hide()
226       
227       
228        def syncVertical(value):
229            """ sync vertical scroll positions of the two views
230            """
231            v1 = self.predictionsView.verticalScrollBar().value()
232            if v1 != value:
233                self.predictionsView.verticalScrollBar().setValue(value)
234            v2 = self.dataView.verticalScrollBar().value()
235            if v2 != value:
236                self.dataView.verticalScrollBar().setValue(v1)
237               
238        self.connect(self.dataView.verticalScrollBar(), SIGNAL("valueChanged(int)"), syncVertical)
239        self.connect(self.predictionsView.verticalScrollBar(), SIGNAL("valueChanged(int)"), syncVertical)
240       
241        splitter.addWidget(self.dataView)
242        splitter.addWidget(self.predictionsView)
243        splitter.setHandleWidth(3)
244        splitter.setChildrenCollapsible(False)
245        self.mainArea.layout().addWidget(splitter)
246       
247        self.spliter_restore_state = -1, 0
248        self.dataModel = None
249        self.predictionsModel = None
250       
251        self.resize(800, 600)
252       
253        self.handledAllSignalsFlag = False
254       
255       
256    def updateSpliter(self):
257        if not (self.dataModel and self.dataModel.columnCount() and \
258                self.predictionsModel and self.predictionsModel.columnCount()):
259            return
260        def width(view):
261            h_header = view.horizontalHeader()
262            v_header = view.verticalHeader()
263            return h_header.length() + v_header.width()
264       
265        def widthForColumns(view, start=-sys.maxint, end=sys.maxint):
266            h_header = view.horizontalHeader()
267            v_header = view.verticalHeader()
268            return sum([h_header.sectionSize(ind) for ind in range(h_header.count())[start : end]] or [0]) + v_header.width()
269       
270        if self.ShowAttributeMethod == 1:
271            w1, w2 = self.splitter.sizes()
272            w = width(self.dataView) + 4
273            self.splitter.setSizes([w, w1 + w2 - w])
274            self.dataView.setMaximumWidth(w)
275           
276            state, w = self.spliter_restore_state
277            if state == 0: # save the dataView width on change from 'show all' to 'hide all'
278                self.spliter_restore_state = 1, w1
279        else:
280            w1, w2 = self.splitter.sizes()
281            state, w = self.spliter_restore_state
282            if state == 1: # restore dataView on change from 'hide all' to 'show all'
283                w = min(w, (w1 + w2)*2 / 3)
284            else:
285                w1, w2 = self.splitter.sizes()
286                w = widthForColumns(self.dataView, -2) + 4
287                w = min(w,  (w1 + w2) / 2)
288                w = max(w,  min(w1 + w2 - widthForColumns(self.predictionsView) - 20, w1 + w2 - w))
289            self.splitter.setSizes([w, w1 + w2 - w])
290            self.dataView.setMaximumWidth(16777215) # This is QWidget's max width
291           
292            self.spliter_restore_state = 0, w
293           
294    def setDataModel(self, data):
295        if data is not None and self.outvar is not None:
296            if self.ShowAttributeMethod == 1: # Show only the class column
297                data = orange.ExampleTable(orange.Domain([self.outvar]), data)
298            elif not data.domain.classVar: # add outVar as class (all unknown values) to data
299                domain = orange.Domain(data.domain.attributes + [self.outvar])
300                domain.addmetas(data.domain.getmetas())
301                data = orange.ExampleTable(domain, data)
302               
303            dist = getCached(data, orange.DomainBasicAttrStat, (data,))
304            self.dataModel = ExampleTableModel(data, dist, None)
305            self.dataView.setModel(self.dataModel)
306            self.dataView.setItemDelegate(OWGUI.TableBarItem(self, data, color = Qt.lightGray))
307            count = self.dataModel.columnCount()
308            if count:
309                self.dataView.scrollTo(self.dataModel.index(0, count - 1))
310            self.dataView.show()
311        else:
312            self.clear()
313        self.updateSpliter()
314           
315    def setPredictionModel(self, classifiers, data):
316        predictions = [(c, [c(ex, orange.GetBoth) for ex in self.data]) for c in classifiers]
317        self.predictionsModel = PredictionTableModel(predictions)
318        self.predictionsView.setModel(self.predictionsModel)
319        self.setPredictionDelegate()
320        self.predictionsView.show()
321       
322       
323    def setPredictionDelegate(self):
324        delegate = PredictionItemDelegete(self)
325        delegate.showProbs = [self.showProb and i in self.selectedClasses for i in range(len(self.classes))]
326        delegate.decimals = self.precision
327        delegate.showClass = self.showClass
328        self.predictionsView.setItemDelegate(delegate)
329        self.predictionsView.resizeColumnsToContents()
330        self.updateSpliter()
331
332#    def sort(self, col):
333#        "sorts the table by column col"
334#        self.sortby = - self.sortby
335#        self.table.sortItems(col, self.sortby>=0)
336#
337#        # the table may be sorted, figure out data indices
338#        for i in range(len(self.data)):
339#            self.rindx[int(str(self.table.item(i,0).text()))-1] = i
340#        for (i, indx) in enumerate(self.rindx):
341#            self.vheader.setLabel(i, self.table.item(i,0).text())
342
343    def checkenable(self):
344        # following should be more complicated and depends on what data are we showing
345        cond = (self.outvar != None) and (self.data != None)
346        self.outbox.setEnabled(cond)
347        self.att.setEnabled(cond)
348        self.copt.setEnabled(cond)
349        e = (self.data and (self.data.domain.classVar <> None) + len(self.predictors)) >= 2
350        # need at least two classes to compare predictions
351
352    def clear(self):
353        self.send("Predictions", None)
354        self.checkenable()
355        if len(self.predictors) == 0:
356            self.outvar = None
357            self.classes = []
358            self.selectedClasses = []
359            self.predictorlabel = "N/A"
360        self.dataModel = PyTableModel([[]])
361        self.dataView.setModel(self.dataModel)
362        self.predictionsModel = PyTableModel([[]])
363        self.predictionsView.setModel(self.predictionsModel)
364       
365        self.dataView.hide()
366        self.predictionsView.hide()
367       
368
369    ##############################################################################
370    # Input signals
371
372    def handleNewSignals(self):
373        self.handledAllSignalsFlag = True
374        if self.data:
375            self.setDataModel(self.data)
376            self.setPredictionModel(self.predictors.values(), self.data)
377        self.checksendpredictions()
378
379    def setData(self, data):
380        self.handledAllSignalsFlag = False
381        if not data:
382            self.data = data
383            self.datalabel = "N/A"
384            self.clear()
385        else:
386            vartypes = {1:"discrete", 2:"continuous"}
387            self.data = data
388            self.rindx = range(len(self.data))
389            self.datalabel = "%d instances" % len(data)
390           
391        self.checkenable()
392        self.changedFlag = True
393
394    def setPredictor(self, predictor, id):
395        """handles incoming classifier (prediction, as could be a regressor as well)"""
396
397        def getoutvar(predictors):
398            """return outcome variable, if consistent among predictors, else None"""
399            if not len(predictors):
400                return None
401            ov = predictors[0].classVar
402            for predictor in predictors[1:]:
403                if ov != predictor.classVar:
404                    self.warning(0, "Mismatch in class variable (e.g., predictors %s and %s)" % (predictors[0].name, predictor.name))
405                    return None
406            return ov
407
408        self.handledAllSignalsFlag = False
409       
410        # remove the classifier with id, if empty
411        if not predictor:
412            if self.predictors.has_key(id):
413                del self.predictors[id]
414                if len(self.predictors) == 0:
415                    self.clear()
416                else:
417                    self.predictorlabel = "%d" % len(self.predictors)
418            return
419
420        # set the classifier
421        self.predictors[id] = predictor
422        self.predictorlabel = "%d" % len(self.predictors)
423
424        # set the outcome variable
425        ov = getoutvar(self.predictors.values())
426        if len(self.predictors) and not ov:
427            self.tasklabel = "N/A (type mismatch)"
428            self.classes = []
429            self.selectedClasses = []
430            self.clear()
431            self.outvar = None
432            return
433        self.warning(0) # clear all warnings
434
435        if ov != self.outvar:
436            self.outvar = ov
437            # regression or classification?
438            if self.outvar.varType == orange.VarTypes.Continuous:
439                self.copt.hide()
440                self.tasklabel = "Regression"
441            else:
442                self.copt.show()
443                self.classes = [str(v) for v in self.outvar.values]
444                self.selectedClasses = []
445                self.tasklabel = "Classification"
446               
447        self.checkenable()
448        self.changedFlag = True
449
450    ##############################################################################
451    # Ouput signals
452
453    def checksendpredictions(self):
454        if self.sendOnChange:
455            self.sendpredictions()
456        else:
457            self.changedFlag = True
458
459    def sendpredictions(self):
460        if not self.data or not self.outvar:
461            self.send("Predictions", None)
462            return
463
464        # predictions, data set with class predictions
465        classification = self.outvar.varType == orange.VarTypes.Discrete
466
467        metas = []
468        if classification:
469            if len(self.selectedClasses):
470                for c in self.predictors.values():
471                    m = [orange.FloatVariable(name=str("%s(%s)" % (c.name, str(self.outvar.values[i]))),
472                                              getValueFrom = lambda ex, rw, cindx=i, c=c: orange.Value(c(ex, c.GetProbabilities)[cindx])) \
473                         for i in self.selectedClasses]
474                    metas.extend(m)
475            if self.showClass:
476                mc = [orange.EnumVariable(name=str(c.name), values = self.outvar.values,
477                                         getValueFrom = lambda ex, rw, c=c: orange.Value(c(ex)))
478                      for c in self.predictors.values()]
479                metas.extend(mc)
480        else:
481            # regression
482            mc = [orange.FloatVariable(name="%s" % c.name, 
483                                       getValueFrom = lambda ex, rw, c=c: orange.Value(c(ex)))
484                  for c in self.predictors.values()]
485            metas.extend(mc)
486               
487        classVar = self.outvar
488        domain = orange.Domain(self.data.domain.attributes + [classVar])
489        domain.addmetas(self.data.domain.getmetas())
490        for m in metas:
491            domain.addmeta(orange.newmetaid(), m)
492        predictions = orange.ExampleTable(domain, self.data)
493        if self.doPrediction:
494            c = self.predictors.values()[0]
495            for ex in predictions:
496                ex[classVar] = c(ex)
497               
498        predictions.name = self.data.name
499        self.send("Predictions", predictions)
500       
501        self.changedFlag = False
502
503##############################################################################
504# Test the widget, run from DOS prompt
505
506if __name__=="__main__":
507    a = QApplication(sys.argv)
508    ow = OWPredictions()
509    ow.show()
510
511    import orngTree
512
513    dataset = orange.ExampleTable('../../doc/datasets/iris.tab')
514#    dataset = orange.ExampleTable('../../doc/datasets/auto-mpg.tab')
515    ind = orange.MakeRandomIndices2(p0=0.5)(dataset)
516    data = dataset.select(ind, 0)
517    test = dataset.select(ind, 1)
518    testnoclass = orange.ExampleTable(orange.Domain(test.domain.attributes, False), test)       
519    tree = orngTree.TreeLearner(data)
520    tree.name = "tree"
521    maj = orange.MajorityLearner(data)
522    maj.name = "maj"
523    knn = orange.kNNLearner(data, k = 10)
524    knn.name = "knn"
525   
526#    ow.setData(test)
527#   
528#    ow.setPredictor(maj, 1)
529   
530   
531
532    if 1: # data set only
533        ow.setData(test)
534    if 0: # two predictors, test data with class
535        ow.setPredictor(maj, 1)
536        ow.setPredictor(tree, 2)
537        ow.setData(test)
538    if 0: # two predictors, test data with no class
539        ow.setPredictor(maj, 1)
540        ow.setPredictor(tree, 2)
541        ow.setData(testnoclass)
542    if 1: # three predictors
543        ow.setPredictor(tree, 1)
544        ow.setPredictor(maj, 2)
545        ow.setData(data)
546        ow.setPredictor(knn, 3)
547    if 0: # just classifier, no data
548        ow.setData(data)
549        ow.setPredictor(maj, 1)
550        ow.setPredictor(knn, 2)
551    if 0: # change data set
552        ow.setPredictor(maj, 1)
553        ow.setPredictor(tree, 2)
554        ow.setData(testnoclass)
555        data = orange.ExampleTable('../../doc/datasets/titanic.tab')
556        tree = orngTree.TreeLearner(data)
557        tree.name = "tree"
558        ow.setPredictor(tree, 2)
559        ow.setData(data)
560       
561    ow.handleNewSignals()
562
563    a.exec_()
564    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.