source:orange/Orange/OrangeWidgets/Evaluate/OWConfusionMatrix.py@10752:a0a6ef1ab0b8

Revision 10752:a0a6ef1ab0b8, 13.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Making sure testing results are for a classification problem in Callibration Plot, Confusion Matrix, Lift Curve and ROC.

Line
1"""
2<name>Confusion Matrix</name>
3<description>Shows a confusion matrix.</description>
4<contact>Janez Demsar</contact>
5<icon>icons/ConfusionMatrix.png</icon>
6<priority>1001</priority>
7"""
8from OWWidget import *
9import OWGUI
10import orange
11import orngStat, orngTest
12import statc, math
13
15from collections import defaultdict
16
17from Orange.evaluation.testing import TEST_TYPE_SINGLE
18
19class TransformedLabel(QLabel):
20    def __init__(self, text, parent=None):
21        QLabel.__init__(self, text, parent)
22        self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding)
23        self.setMaximumWidth(self.sizeHint().width() + 2)
24        self.setMargin(4)
25
26    def sizeHint(self):
27        metrics = QFontMetrics(self.font())
28        rect = metrics.boundingRect(self.text())
29        size = QSize(rect.height() + self.margin(), rect.width() + self.margin())
30        return size
31
32    def setGeometry(self, rect):
33        QLabel.setGeometry(self, rect)
34
35    def paintEvent(self, event):
36        painter = QPainter(self)
37        rect = self.geometry()
38        textRect = QRect(0,0,rect.width(), rect.height())
39
40        painter.translate(textRect.bottomLeft())
41        painter.rotate(-90)
42        painter.drawText(QRect(QPoint(0, 0), QSize(rect.height(), rect.width())), Qt.AlignCenter, self.text())
43        painter.end()
44
45
46class OWConfusionMatrix(OWWidget):
47    settings = ["shownQuantity", "autoApply", "appendPredictions", "appendProbabilities"]
48
49    quantities = ["Number of examples", "Observed and expected examples", "Proportions of predicted", "Proportions of true"]
50    def __init__(self,parent=None, signalManager = None):
51        OWWidget.__init__(self, parent, signalManager, "Confusion Matrix", 1)
52
53        # inputs
54        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.setTestResults, Default)]
55        self.outputs=[("Selected Data", ExampleTable, 8)]
56
57        self.selectedLearner = []
58        self.learnerNames = []
59        self.selectionDirty = 0
60        self.autoApply = True
61        self.appendPredictions = True
62        self.appendProbabilities = False
63        self.shownQuantity = 0
64
65        self.learnerList = OWGUI.listBox(self.controlArea, self, "selectedLearner", "learnerNames", box = "Learners", callback = self.learnerChanged)
66        self.learnerList.setMinimumHeight(100)
67
68        OWGUI.separator(self.controlArea)
69
70        OWGUI.comboBox(self.controlArea, self, "shownQuantity", items = self.quantities, box = "Show", callback=self.reprint)
71
72        OWGUI.separator(self.controlArea)
73
74        box = OWGUI.widgetBox(self.controlArea, "Selection") #, addSpace=True)
75        OWGUI.button(box, self, "Correct", callback=self.selectCorrect)
76        OWGUI.button(box, self, "Misclassified", callback=self.selectWrong)
77        OWGUI.button(box, self, "None", callback=self.selectNone)
78
79        OWGUI.separator(self.controlArea)
80
81        self.outputBox = box = OWGUI.widgetBox(self.controlArea, "Output")
82        OWGUI.checkBox(box, self, "appendPredictions", "Append class predictions", callback = self.sendIf)
83        OWGUI.checkBox(box, self, "appendProbabilities", "Append predicted class probabilities", callback = self.sendIf)
84        applyButton = OWGUI.button(box, self, "Commit", callback = self.sendData, default=True)
85        autoApplyCB = OWGUI.checkBox(box, self, "autoApply", "Commit automatically")
86        OWGUI.setStopper(self, applyButton, autoApplyCB, "selectionDirty", self.sendData)
87
88        import sip
89        sip.delete(self.mainArea.layout())
90        self.layout = QGridLayout(self.mainArea)
91
92        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Prediction"), 0, 1, Qt.AlignCenter)
93
94        label = TransformedLabel("Correct Class")
96#        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Correct Class  "), 2, 0, Qt.AlignCenter)
97        self.table = OWGUI.table(self.mainArea, rows = 0, columns = 0, selectionMode = QTableWidget.MultiSelection, addToLayout = 0)
99        self.layout.setColumnStretch(1, 100)
100        self.layout.setRowStretch(2, 100)
101        self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.sendIf)
102
103        self.res = None
104        self.matrix = None
105        self.selectedLearner = None
106        self.resize(700,450)
107
108
109    def setTestResults(self, res):
110        self.res = res
111        self.warning([0, 1])
112        self.outputBox.setEnabled(True)
113
114        if res is not None and res.class_values is None:
115            self.warning(1, "Confusion Matrix cannot be used for regression results.")
116            self.res = res = None
117
118        if not res:
119            self.matrix = None
120            self.learnerNames = []
121            self.table.setRowCount(0)
122            self.table.setColumnCount(0)
123            return
124
125        if res and res.test_type != TEST_TYPE_SINGLE:
126            self.warning(0, "Confusion matrix can be calculated only for single-target prediction problems.")
127            return
128
129        canOutput = True
130        if not hasattr(res, "examples"):
131            self.warning(1, "Results do not have testing instances (Output is disabled).")
132            canOutput = False
133        elif not isinstance(res.examples, orange.ExampleTable):
134            self.warning(1, "Output for results from 'Proportion test' is not supported.")
135            canOutput = False
136
137        self.outputBox.setEnabled(canOutput)
138
139        self.matrix = orngStat.confusionMatrices(res, -2)
140
141        dim = len(res.classValues)
142
143        self.table.setRowCount(dim+1)
144        self.table.setColumnCount(dim+1)
145
148
149        for ri in range(dim+1):
150            for ci in range(dim+1):
151                it = QTableWidgetItem()
152                it.setFlags(Qt.ItemIsEnabled | (ri<dim and ci<dim and Qt.ItemIsSelectable or Qt.NoItemFlags))
153                it.setTextAlignment(Qt.AlignRight)
154                self.table.setItem(ri, ci, it)
155
156        boldf = self.table.item(0, dim).font()
157        boldf.setBold(True)
158        for ri in range(dim+1):
159            self.table.item(ri, dim).setFont(boldf)
160            self.table.item(dim, ri).setFont(boldf)
161
162        self.learnerNames = res.classifierNames[:]
163        if not self.selectedLearner and self.res.numberOfLearners:
164            self.selectedLearner = [0]
165        self.learnerChanged()
166        self.table.clearSelection()
167
168
169    def learnerChanged(self):
170        if not (self.res and self.res.numberOfLearners):
171            return
172
173        if self.selectedLearner and self.selectedLearner[0] > self.res.numberOfLearners:
174            self.selectedLearner = [0]
175        if not self.selectedLearner:
176            return
177
178        cm = self.matrix[self.selectedLearner[0]]
179
180        self.isInteger = " %i "
181        for r in reduce(add, cm):
182            if int(r) != r:
183                self.isInteger = " %5.3f "
184                break
185
186        self.reprint()
187        self.sendIf()
188
189
190    def reprint(self):
191        if not self.matrix or not self.selectedLearner:
192            return
193
194        cm = self.matrix[self.selectedLearner[0]]
195
196        dim = len(cm)
197        rowSums = [sum(r) for r in cm]
198        colSums = [sum([r[i] for r in cm]) for i in range(dim)]
199        total = sum(rowSums)
200        if self.shownQuantity == 1:
201            if total > 1e-5:
202                rowPriors = [r/total for r in rowSums]
203                colPriors = [r/total for r in colSums]
204            else:
205                rowPriors = [0 for r in rowSums]
206                colPriors = [0 for r in colSums]
207
208        for ri, r in enumerate(cm):
209            for ci, c in enumerate(r):
210                item = self.table.item(ri, ci)
211                if not item:
212                    continue
213                if self.shownQuantity == 0:
214                    item.setText(self.isInteger % c)
215                elif self.shownQuantity == 1:
216                    item.setText((self.isInteger + "/ %5.3f ") % (c, total*rowPriors[ri]*colPriors[ci]))
217                elif self.shownQuantity == 2:
218                    item.setText(colSums[ci] > 1e-5 and (" %2.1f %%  " % (100 * c / colSums[ci])) or " "+"N/A"+" ")
219                elif self.shownQuantity == 3:
220                    item.setText(rowSums[ri] > 1e-5 and (" %2.1f %%  " % (100 * c / rowSums[ri])) or " "+"N/A"+" ")
221
222        for ci in range(dim):
223            self.table.item(dim, ci).setText(self.isInteger % colSums[ci])
224            self.table.item(ci, dim).setText(self.isInteger % rowSums[ci])
225        self.table.item(dim, dim).setText(self.isInteger % total)
226
227        self.table.resizeColumnsToContents()
228
229    def sendReport(self):
230        self.reportSettings("Contents",
231                            [("Learner", self.learnerNames[self.selectedLearner[0]] \
232                                            if self.learnerNames else \
233                                         "N/A"),
234                             ("Data", self.quantities[self.shownQuantity])])
235
236        if self.res:
237            self.reportSection("Matrix")
238            classVals = self.res.classValues
239            nClassVals = len(classVals)
240            res = "<table>\n<tr><td></td>" + "".join('<td align="center"><b>&nbsp;&nbsp;%s&nbsp;&nbsp;</b></td>' % cv for cv in classVals) + "</tr>\n"
241            for i, cv in enumerate(classVals):
242                res += '<tr><th align="right"><b>%s</b></th>' % cv + \
243                       "".join('<td align="center">%s</td>' % self.table.item(i, j).text() for j in range(nClassVals)) + \
244                       '<th align="right"><b>%s</b></th>' % self.table.item(i, nClassVals).text() + \
245                       "</tr>\n"
246            res += '<tr><th></th>' + \
247                   "".join('<td align="center"><b>%s</b></td>' % self.table.item(nClassVals, j).text() for j in range(nClassVals+1)) + \
248                   "</tr>\n"
249            res += "</table>\n<p><b>Note:</b> columns represent predictions, row represent true classes</p>"
250            self.reportRaw(res)
251
252
253    def selectCorrect(self):
254        if not self.res:
255            return
256
257        sa = self.autoApply
258        self.autoApply = False
259        self.table.clearSelection()
260        for i in range(len(self.matrix[0])):
261            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), True)
262        self.autoApply = sa
263        self.sendIf()
264
265
266    def selectWrong(self):
267        if not self.res:
268            return
269
270        sa = self.autoApply
271        self.autoApply = False
272        self.table.clearSelection()
273        dim = len(self.matrix[0])
274        self.table.setRangeSelected(QTableWidgetSelectionRange(0, 0, dim-1, dim-1), True)
275        for i in range(len(self.matrix[0])):
276            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), False)
277        self.autoApply = sa
278        self.sendIf()
279
280
281    def selectNone(self):
282        self.table.clearSelection()
283
284
285    def sendIf(self):
286        if self.autoApply:
287            self.sendData()
288        else:
289            self.selectionDirty = True
290
291
292    def sendData(self):
293        self.selectionDirty = False
294
295        selected = [(x.row(), x.column()) for x in self.table.selectedIndexes()]
296        res = self.res
297        if not res or not selected or not self.selectedLearner:
298            self.send("Selected Data", None)
299            return
300
301        learnerI = self.selectedLearner[0]
302
303        data = None
304        if hasattr(res, "examples") and isinstance(res.examples, orange.ExampleTable):
305            selectionIndices = [i for i, rese in enumerate(res.results) if (rese.actualClass, rese.classes[learnerI]) in selected]
306            data = res.examples.getitemsref(selectionIndices)
307
308        if data is not None and (self.appendPredictions or self.appendProbabilities):
309            domain = orange.Domain(data.domain.attributes, data.domain.classVar)
311            data = orange.ExampleTable(domain, data)
312
313            if self.appendPredictions:
314                cname = self.learnerNames[learnerI]
315                predVar = type(domain.classVar)("%s(%s)" % (domain.classVar.name, cname.encode("utf-8") if isinstance(cname, unicode) else cname))
316                if hasattr(domain.classVar, "values"):
317                    predVar.values = domain.classVar.values
318                predictionsId = orange.newmetaid()
320                for i, ex in zip(selectionIndices, data):
321                    ex[predictionsId] = res.results[i].classes[learnerI]
322
323            if self.appendProbabilities:
324                probVars = [orange.FloatVariable("p(%s)" % v) for v in domain.classVar.values]
325                probIds = [orange.newmetaid() for pv in probVars]