source: orange/orange/OrangeWidgets/Evaluate/OWPredictions.py @ 9633:d89f4632a0ea

Revision 9633:d89f4632a0ea, 22.6 KB checked in by ales_erjavec, 2 years ago (diff)

By default don't replace class with predicitons.

Line 
1"""
2<name>Predictions</name>
3<description>Displays predictions of models for a particular data set.</description>
4<icon>icons/Predictions.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>300</priority>
7"""
8
9from OWWidget import *
10import OWGUI
11import statc
12import orange
13
14from OWDataTable import ExampleTableModel, getCached
15
16def safe_call(func):
17    from functools import wraps
18    @wraps(func)
19    def wrapper(*args, **kwargs):
20        try:
21            return func(*args, **kwargs)
22        except Exception, ex:
23            print >> sys.stderr, func.__name__, "call error", ex
24            return QVariant()
25    return wrapper
26
27class PyTableModel(QAbstractTableModel):
28    """ A general list-of-lists table holding arbitrary Python objects.
29    To view it in a item view you must subclass an QItemDelegate
30   
31    Arguments:
32        - `table`: A 2D table of any python objects
33        - `headers`: A list of strings for view headers
34        - `parent`: models parent (default None)
35       
36    Examples::
37        view = QTableView()
38        view.setModel(PyTableView([[1, 2, 3], [1, 2, 3]], ["One, "Two", "Three"], parent=view)
39    """
40    def __init__(self, table=None, headers=None, parent=None):
41        QAbstractTableModel.__init__(self, parent)
42        self._table = [[]] if table is None else table
43        self._header = [None] * len(self._table) if headers is None else headers
44       
45    @safe_call
46    def data(self, index, role=Qt.DisplayRole):
47        row, column = index.row(), index.column()
48        if role == Qt.DisplayRole or role == Qt.EditRole:
49            val = self._table[row][column]
50            return QVariant(val)
51        else:
52            return QVariant() #QAbstractTableModel.data(self, index, role)
53       
54    def rowCount(self, parent=QModelIndex()):
55        if parent.isValid():
56            return 0
57        else:
58            return len(self._table)
59       
60    def columnCount(self, parent=QModelIndex()):
61        if parent.isValid():
62            return 0
63        else:
64            return max([len(row) for row in self._table]) if self._table else 0
65       
66    def headerData(self, section, orientation, role=Qt.DisplayRole):
67        if orientation == Qt.Vertical and  role == Qt.DisplayRole:
68            return QVariant(QString(str(section + 1)))
69        elif orientation == Qt.Horizontal and role == Qt.DisplayRole:
70            header = self._header[section] if section < len(self._header) else str(section)
71            return QVariant(QString(header)) if header is not None else QVariant()
72        else:
73            return QVariant() #QAbstractTableModel.headerData(self, section, orientation, role)
74       
75    def sort(self, column, order=Qt.AscendingOrder):
76        self._table.sort(key=lambda row: row[column], reverse=order == Qt.DescendingOrder)
77        self.reset()
78       
79       
80class PredictionTableModel(PyTableModel):
81    """ Item model for classifier predictions
82    """
83    def __init__(self, prediction_results, *args, **kwargs):
84        """ prediciton_results [(classifier, list-of-predictions) ...]
85        """
86        PyTableModel.__init__(self, *args, **kwargs)
87        self.prediction_results = prediction_results
88        self._header = [p[0].name for p in prediction_results]
89       
90        self._table = [[] for i in range(max([len(p) for c, p in prediction_results] or [0]))]
91        for c, pred in prediction_results:
92            for i, p in enumerate(pred):
93                self._table[i].append(p)
94
95   
96class PredictionItemDelegete(QStyledItemDelegate):
97    """ Item delegate for prediction results i.e. (class, probabilities)
98    tuples as returned by classifiers with orange.Both return value
99   
100    Optional arguments:
101        - `showProbs`: a list of `True` or `False` values for each class value.
102        These are the probabilities that will be displayed (default all False
103        i.e. no probabilities shown)
104        - `deciamals`: number of decimals to show
105    """
106    def __init__(self, parent=None, *kwargs):
107        QStyledItemDelegate.__init__(self, parent)
108        self.__dict__.update(kwargs)
109   
110    def displayText(self, value, locale):
111        pred = value.toPyObject()
112        if type(pred) >= tuple:
113            cls, prob = pred
114        elif type(pred) >= orange.Value:
115            cls, prob = pred, None
116        elif type(pred) >= orange.Distribution:
117            cls, prob = pred.modus(), pred
118        else:
119            return QString("")
120        text = ""
121        if prob and any(getattr(self, "showProbs", [])):
122            fmt = "%%.%if" % getattr(self, "decimals", 2)
123            text = " : ".join(fmt % f for f, show in zip(prob, self.showProbs) if show)
124            if getattr(self, "showClass", True):
125                text += " -> "
126        if getattr(self, "showClass", True):
127            text += str(cls)
128        return QString(text)
129
130class OWPredictions(OWWidget):
131    settingsList = ["showProb", "showClass", "ShowAttributeMethod", "sendOnChange", "precision"]
132
133    def __init__(self, parent=None, signalManager = None):
134        OWWidget.__init__(self, parent, signalManager, "Predictions")
135
136        self.callbackDeposit = []
137        self.inputs = [("Data", ExampleTable, self.setData), ("Predictors", orange.Classifier, self.setPredictor, Multiple)]
138        self.outputs = [("Predictions", ExampleTable)]
139        self.predictors = {}
140
141        # saveble settings
142        self.showProb = 1;
143        self.showClass = 1
144        self.ShowAttributeMethod = 0
145        self.sendOnChange = 1
146        self.classes = []
147        self.selectedClasses = []
148        self.loadSettings()
149        self.datalabel = "N/A"
150        self.predictorlabel = "N/A"
151        self.tasklabel = "N/A"
152        self.precision = 2
153        self.doPrediction = False
154        self.outvar = None # current output variable (set by the first predictor/data set send in)
155
156        self.data = None
157        self.changedFlag = False
158       
159        self.loadSettings()
160
161        # GUI - Options
162
163        # Options - classification
164        ibox = OWGUI.widgetBox(self.controlArea, "Info")
165        OWGUI.label(ibox, self, "Data: %(datalabel)s")
166        OWGUI.label(ibox, self, "Predictors: %(predictorlabel)s")
167        OWGUI.label(ibox, self, "Task: %(tasklabel)s")
168        OWGUI.separator(self.controlArea)
169       
170        self.copt = OWGUI.widgetBox(self.controlArea, "Options (classification)")
171        self.copt.setDisabled(1)
172        cb = OWGUI.checkBox(self.copt, self, 'showProb', "Show predicted probabilities", callback=self.setPredictionDelegate)#self.updateTableOutcomes)
173
174#        self.lbClasses = OWGUI.listBox(self.copt, self, selectionMode = QListWidget.MultiSelection, callback = self.updateTableOutcomes)
175        ibox = OWGUI.indentedBox(self.copt, sep=OWGUI.checkButtonOffsetHint(cb))
176        self.lbcls = OWGUI.listBox(ibox, self, "selectedClasses", "classes",
177                                   callback=[self.setPredictionDelegate, self.checksendpredictions],
178#                                   callback=[self.updateTableOutcomes, self.checksendpredictions],
179                                   selectionMode=QListWidget.MultiSelection)
180        self.lbcls.setFixedHeight(50)
181
182        OWGUI.spin(ibox, self, "precision", 1, 6, label="No. of decimals: ",
183                   orientation=0, callback=self.setPredictionDelegate) #self.updateTableOutcomes)
184       
185        cb.disables.append(ibox)
186        ibox.setEnabled(bool(self.showProb))
187
188        OWGUI.checkBox(self.copt, self, 'showClass', "Show predicted class",
189                       callback=[self.setPredictionDelegate, self.checksendpredictions])
190#                       callback=[self.updateTableOutcomes, self.checksendpredictions])
191
192        OWGUI.separator(self.controlArea)
193
194        self.att = OWGUI.widgetBox(self.controlArea, "Data attributes")
195        OWGUI.radioButtonsInBox(self.att, self, 'ShowAttributeMethod', ['Show all', 'Hide all'], callback=lambda :self.setDataModel(self.data)) #self.updateAttributes)
196        self.att.setDisabled(1)
197        OWGUI.rubber(self.controlArea)
198
199        OWGUI.separator(self.controlArea)
200        self.outbox = OWGUI.widgetBox(self.controlArea, "Output")
201       
202        b = self.commitBtn = OWGUI.button(self.outbox, self, "Send Predictions", callback=self.sendpredictions, default=True)
203        cb = OWGUI.checkBox(self.outbox, self, 'sendOnChange', 'Send automatically')
204        OWGUI.setStopper(self, b, cb, "changedFlag", callback=self.sendpredictions)
205        OWGUI.checkBox(self.outbox, self, "doPrediction", "Replace/add predicted class",
206                       tooltip="Apply the first predictor to input examples and replace/add the predicted value as the new class variable.",
207                       callback=self.checksendpredictions)
208
209        self.outbox.setDisabled(1)
210
211        ## GUI table
212
213        self.splitter = splitter = QSplitter(Qt.Horizontal, self.mainArea)
214        self.dataView = QTableView()
215        self.predictionsView = QTableView()
216       
217        self.dataView.verticalHeader().setDefaultSectionSize(22)
218        self.dataView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
219        self.dataView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
220        self.dataView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
221       
222        self.predictionsView.verticalHeader().setDefaultSectionSize(22)
223        self.predictionsView.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
224        self.predictionsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
225        self.predictionsView.verticalHeader().hide()
226       
227       
228#        def syncVertical(value):
229#            """ sync vertical scroll positions of the two views
230#            """
231#            v1 = self.predictionsView.verticalScrollBar().value()
232#            if v1 != value:
233#                self.predictionsView.verticalScrollBar().setValue(value)
234#            v2 = self.dataView.verticalScrollBar().value()
235#            if v2 != value:
236#                self.dataView.verticalScrollBar().setValue(v1)
237               
238        self.connect(self.dataView.verticalScrollBar(), SIGNAL("valueChanged(int)"), self.syncVertical)
239        self.connect(self.predictionsView.verticalScrollBar(), SIGNAL("valueChanged(int)"), self.syncVertical)
240       
241        splitter.addWidget(self.dataView)
242        splitter.addWidget(self.predictionsView)
243        splitter.setHandleWidth(3)
244        splitter.setChildrenCollapsible(False)
245        self.mainArea.layout().addWidget(splitter)
246       
247        self.spliter_restore_state = -1, 0
248        self.dataModel = None
249        self.predictionsModel = None
250       
251        self.resize(800, 600)
252       
253        self.handledAllSignalsFlag = False
254       
255    def syncVertical(self, value):
256        """ sync vertical scroll positions of the two views
257        """
258        v1 = self.predictionsView.verticalScrollBar().value()
259        if v1 != value:
260            self.predictionsView.verticalScrollBar().setValue(value)
261        v2 = self.dataView.verticalScrollBar().value()
262        if v2 != value:
263            self.dataView.verticalScrollBar().setValue(v1)
264       
265    def updateSpliter(self):
266        if not (self.dataModel and self.dataModel.columnCount() and \
267                self.predictionsModel and self.predictionsModel.columnCount()):
268            return
269        def width(view):
270            h_header = view.horizontalHeader()
271            v_header = view.verticalHeader()
272            return h_header.length() + v_header.width()
273       
274        def widthForColumns(view, start=-sys.maxint, end=sys.maxint):
275            h_header = view.horizontalHeader()
276            v_header = view.verticalHeader()
277            return sum([h_header.sectionSize(ind) for ind in range(h_header.count())[start : end]] or [0]) + v_header.width()
278       
279        if self.ShowAttributeMethod == 1:
280            w1, w2 = self.splitter.sizes()
281            w = width(self.dataView) + 4
282            self.splitter.setSizes([w, w1 + w2 - w])
283            self.dataView.setMaximumWidth(w)
284           
285            state, w = self.spliter_restore_state
286            if state == 0: # save the dataView width on change from 'show all' to 'hide all'
287                self.spliter_restore_state = 1, w1
288        else:
289            w1, w2 = self.splitter.sizes()
290            state, w = self.spliter_restore_state
291            if state == 1: # restore dataView on change from 'hide all' to 'show all'
292                w = min(w, (w1 + w2)*2 / 3)
293            else:
294                w1, w2 = self.splitter.sizes()
295                w = widthForColumns(self.dataView, -2) + 4
296                w = min(w,  (w1 + w2) / 2)
297                w = max(w,  min(w1 + w2 - widthForColumns(self.predictionsView) - 20, w1 + w2 - w))
298            self.splitter.setSizes([w, w1 + w2 - w])
299            self.dataView.setMaximumWidth(16777215) # This is QWidget's max width
300           
301            self.spliter_restore_state = 0, w
302           
303    def setDataModel(self, data):
304        if data is not None and self.outvar is not None:
305            if self.ShowAttributeMethod == 1: # Show only the class column
306                data = orange.ExampleTable(orange.Domain([self.outvar]), data)
307            elif not data.domain.classVar: # add outVar as class (all unknown values) to data
308                domain = orange.Domain(data.domain.attributes + [self.outvar])
309                domain.addmetas(data.domain.getmetas())
310                data = orange.ExampleTable(domain, data)
311               
312            dist = getCached(data, orange.DomainBasicAttrStat, (data,))
313            self.dataModel = ExampleTableModel(data, dist, None)
314            self.dataView.setModel(self.dataModel)
315            self.dataView.setItemDelegate(OWGUI.TableBarItem(self, data, color = Qt.lightGray))
316            count = self.dataModel.columnCount()
317            if count:
318                self.dataView.scrollTo(self.dataModel.index(0, count - 1))
319            self.dataView.show()
320        else:
321            self.clear()
322        self.updateSpliter()
323           
324    def setPredictionModel(self, classifiers, data):
325        predictions = [(c, [c(ex, orange.GetBoth) for ex in self.data]) for c in classifiers]
326        self.predictionsModel = PredictionTableModel(predictions)
327        self.predictionsView.setModel(self.predictionsModel)
328        self.setPredictionDelegate()
329        self.predictionsView.show()
330       
331       
332    def setPredictionDelegate(self):
333        delegate = PredictionItemDelegete(self)
334        delegate.showProbs = [self.showProb and i in self.selectedClasses for i in range(len(self.classes))]
335        delegate.decimals = self.precision
336        delegate.showClass = self.showClass
337        self.predictionsView.setItemDelegate(delegate)
338        self.predictionsView.resizeColumnsToContents()
339        self.updateSpliter()
340
341#    def sort(self, col):
342#        "sorts the table by column col"
343#        self.sortby = - self.sortby
344#        self.table.sortItems(col, self.sortby>=0)
345#
346#        # the table may be sorted, figure out data indices
347#        for i in range(len(self.data)):
348#            self.rindx[int(str(self.table.item(i,0).text()))-1] = i
349#        for (i, indx) in enumerate(self.rindx):
350#            self.vheader.setLabel(i, self.table.item(i,0).text())
351
352    def checkenable(self):
353        # following should be more complicated and depends on what data are we showing
354        cond = (self.outvar != None) and (self.data != None)
355        self.outbox.setEnabled(cond)
356        self.att.setEnabled(cond)
357        self.copt.setEnabled(cond)
358        e = (self.data and (self.data.domain.classVar <> None) + len(self.predictors)) >= 2
359        # need at least two classes to compare predictions
360
361    def clear(self):
362        self.send("Predictions", None)
363        self.checkenable()
364        if len(self.predictors) == 0:
365            self.outvar = None
366            self.classes = []
367            self.selectedClasses = []
368            self.predictorlabel = "N/A"
369        self.dataModel = PyTableModel([[]])
370        self.dataView.setModel(self.dataModel)
371        self.predictionsModel = PyTableModel([[]])
372        self.predictionsView.setModel(self.predictionsModel)
373       
374        self.dataView.hide()
375        self.predictionsView.hide()
376       
377
378    ##############################################################################
379    # Input signals
380
381    def handleNewSignals(self):
382        self.handledAllSignalsFlag = True
383        if self.data:
384            self.setDataModel(self.data)
385            self.setPredictionModel(self.predictors.values(), self.data)
386        self.checksendpredictions()
387
388    def setData(self, data):
389        self.handledAllSignalsFlag = False
390        if not data:
391            self.data = data
392            self.datalabel = "N/A"
393            self.clear()
394        else:
395            vartypes = {1:"discrete", 2:"continuous"}
396            self.data = data
397            self.rindx = range(len(self.data))
398            self.datalabel = "%d instances" % len(data)
399           
400        self.checkenable()
401        self.changedFlag = True
402
403    def setPredictor(self, predictor, id):
404        """handles incoming classifier (prediction, as could be a regressor as well)"""
405
406        def getoutvar(predictors):
407            """return outcome variable, if consistent among predictors, else None"""
408            if not len(predictors):
409                return None
410            ov = predictors[0].classVar
411            for predictor in predictors[1:]:
412                if ov != predictor.classVar:
413                    self.warning(0, "Mismatch in class variable (e.g., predictors %s and %s)" % (predictors[0].name, predictor.name))
414                    return None
415            return ov
416
417        self.handledAllSignalsFlag = False
418       
419        # remove the classifier with id, if empty
420        if not predictor:
421            if self.predictors.has_key(id):
422                del self.predictors[id]
423                if len(self.predictors) == 0:
424                    self.clear()
425                else:
426                    self.predictorlabel = "%d" % len(self.predictors)
427            return
428
429        # set the classifier
430        self.predictors[id] = predictor
431        self.predictorlabel = "%d" % len(self.predictors)
432
433        # set the outcome variable
434        ov = getoutvar(self.predictors.values())
435        if len(self.predictors) and not ov:
436            self.tasklabel = "N/A (type mismatch)"
437            self.classes = []
438            self.selectedClasses = []
439            self.clear()
440            self.outvar = None
441            return
442        self.warning(0) # clear all warnings
443
444        if ov != self.outvar:
445            self.outvar = ov
446            # regression or classification?
447            if self.outvar.varType == orange.VarTypes.Continuous:
448                self.copt.hide()
449                self.tasklabel = "Regression"
450            else:
451                self.copt.show()
452                self.classes = [str(v) for v in self.outvar.values]
453                self.selectedClasses = []
454                self.tasklabel = "Classification"
455               
456        self.checkenable()
457        self.changedFlag = True
458
459    ##############################################################################
460    # Ouput signals
461
462    def checksendpredictions(self):
463        if self.sendOnChange:
464            self.sendpredictions()
465        else:
466            self.changedFlag = True
467
468    def sendpredictions(self):
469        if not self.data or not self.outvar:
470            self.send("Predictions", None)
471            return
472
473        # predictions, data set with class predictions
474        classification = self.outvar.varType == orange.VarTypes.Discrete
475
476        metas = []
477        if classification:
478            if len(self.selectedClasses):
479                for c in self.predictors.values():
480                    m = [orange.FloatVariable(name=str("%s(%s)" % (c.name, str(self.outvar.values[i]))),
481                                              getValueFrom = lambda ex, rw, cindx=i, c=c: orange.Value(c(ex, c.GetProbabilities)[cindx])) \
482                         for i in self.selectedClasses]
483                    metas.extend(m)
484            if self.showClass:
485                mc = [orange.EnumVariable(name=str(c.name), values = self.outvar.values,
486                                         getValueFrom = lambda ex, rw, c=c: orange.Value(c(ex)))
487                      for c in self.predictors.values()]
488                metas.extend(mc)
489        else:
490            # regression
491            mc = [orange.FloatVariable(name="%s" % c.name, 
492                                       getValueFrom = lambda ex, rw, c=c: orange.Value(c(ex)))
493                  for c in self.predictors.values()]
494            metas.extend(mc)
495               
496        classVar = self.outvar
497        domain = orange.Domain(self.data.domain.attributes + [classVar])
498        domain.addmetas(self.data.domain.getmetas())
499        for m in metas:
500            domain.addmeta(orange.newmetaid(), m)
501        predictions = orange.ExampleTable(domain, self.data)
502        if self.doPrediction:
503            c = self.predictors.values()[0]
504            for ex in predictions:
505                ex[classVar] = c(ex)
506               
507        predictions.name = self.data.name
508        self.send("Predictions", predictions)
509       
510        self.changedFlag = False
511
512##############################################################################
513# Test the widget, run from DOS prompt
514
515if __name__=="__main__":
516    a = QApplication(sys.argv)
517    ow = OWPredictions()
518    ow.show()
519
520    import orngTree
521
522    dataset = orange.ExampleTable('../../doc/datasets/iris.tab')
523#    dataset = orange.ExampleTable('../../doc/datasets/auto-mpg.tab')
524    ind = orange.MakeRandomIndices2(p0=0.5)(dataset)
525    data = dataset.select(ind, 0)
526    test = dataset.select(ind, 1)
527    testnoclass = orange.ExampleTable(orange.Domain(test.domain.attributes, False), test)       
528    tree = orngTree.TreeLearner(data)
529    tree.name = "tree"
530    maj = orange.MajorityLearner(data)
531    maj.name = "maj"
532    knn = orange.kNNLearner(data, k = 10)
533    knn.name = "knn"
534   
535#    ow.setData(test)
536#   
537#    ow.setPredictor(maj, 1)
538   
539   
540
541    if 1: # data set only
542        ow.setData(test)
543    if 0: # two predictors, test data with class
544        ow.setPredictor(maj, 1)
545        ow.setPredictor(tree, 2)
546        ow.setData(test)
547    if 0: # two predictors, test data with no class
548        ow.setPredictor(maj, 1)
549        ow.setPredictor(tree, 2)
550        ow.setData(testnoclass)
551    if 1: # three predictors
552        ow.setPredictor(tree, 1)
553        ow.setPredictor(maj, 2)
554        ow.setData(data)
555        ow.setPredictor(knn, 3)
556    if 0: # just classifier, no data
557        ow.setData(data)
558        ow.setPredictor(maj, 1)
559        ow.setPredictor(knn, 2)
560    if 0: # change data set
561        ow.setPredictor(maj, 1)
562        ow.setPredictor(tree, 2)
563        ow.setData(testnoclass)
564        data = orange.ExampleTable('../../doc/datasets/titanic.tab')
565        tree = orngTree.TreeLearner(data)
566        tree.name = "tree"
567        ow.setPredictor(tree, 2)
568        ow.setData(data)
569       
570    ow.handleNewSignals()
571
572    a.exec_()
573    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.