source: orange/orange/OrangeWidgets/Evaluate/OWConfusionMatrix.py @ 9599:69c91d52d3e4

Revision 9599:69c91d52d3e4, 12.3 KB checked in by Matija Polajnar <matija.polajnar@…>, 2 years ago (diff)

Multi-label classificaiton widgets. Merged in from Wencan Luo's work with some modifications.

Line 
1"""
2<name>Confusion Matrix</name>
3<description>Shows a confusion matrix.</description>
4<contact>Janez Demsar</contact>
5<icon>icons/ConfusionMatrix.png</icon>
6<priority>1001</priority>
7"""
8from OWWidget import *
9import OWGUI
10import orngStat, orngTest
11import statc, math
12from operator import add
13from Orange.evaluation.testing import TEST_TYPE_SINGLE
14           
15class TransformedLabel(QLabel):
16    def __init__(self, text, parent=None):
17        QLabel.__init__(self, text, parent)
18        self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding)
19        self.setMaximumWidth(self.sizeHint().width() + 2)
20        self.setMargin(4)
21       
22    def sizeHint(self):
23        metrics = QFontMetrics(self.font())
24        rect = metrics.boundingRect(self.text())
25        size = QSize(rect.height() + self.margin(), rect.width() + self.margin())
26        return size
27   
28    def setGeometry(self, rect):
29        QLabel.setGeometry(self, rect)
30       
31    def paintEvent(self, event):
32        painter = QPainter(self)
33        rect = self.geometry()
34        textRect = QRect(0,0,rect.width(), rect.height())
35
36        painter.translate(textRect.bottomLeft())
37        painter.rotate(-90)
38        painter.drawText(QRect(QPoint(0, 0), QSize(rect.height(), rect.width())), Qt.AlignCenter, self.text())
39        painter.end()
40       
41       
42class OWConfusionMatrix(OWWidget):
43    settings = ["shownQuantity", "autoApply", "appendPredictions", "appendProbabilities"]
44
45    quantities = ["Number of examples", "Observed and expected examples", "Proportions of predicted", "Proportions of true"]
46    def __init__(self,parent=None, signalManager = None):
47        OWWidget.__init__(self, parent, signalManager, "Confusion Matrix", 1)
48
49        # inputs
50        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.setTestResults, Default)]
51        self.outputs=[("Selected Data", ExampleTable, 8)]
52
53        self.selectedLearner = []
54        self.learnerNames = []
55        self.selectionDirty = 0
56        self.autoApply = True
57        self.appendPredictions = True
58        self.appendProbabilities = False
59        self.shownQuantity = 0
60
61        self.learnerList = OWGUI.listBox(self.controlArea, self, "selectedLearner", "learnerNames", box = "Learners", callback = self.learnerChanged)
62        self.learnerList.setMinimumHeight(100)
63       
64        OWGUI.separator(self.controlArea)
65
66        OWGUI.comboBox(self.controlArea, self, "shownQuantity", items = self.quantities, box = "Show", callback=self.reprint)
67
68        OWGUI.separator(self.controlArea)
69       
70        box = OWGUI.widgetBox(self.controlArea, "Selection") #, addSpace=True)
71        OWGUI.button(box, self, "Correct", callback=self.selectCorrect)
72        OWGUI.button(box, self, "Misclassified", callback=self.selectWrong)
73        OWGUI.button(box, self, "None", callback=self.selectNone)
74       
75        OWGUI.separator(self.controlArea)
76
77        box = OWGUI.widgetBox(self.controlArea, "Output")
78        OWGUI.checkBox(box, self, "appendPredictions", "Append class predictions", callback = self.sendIf)
79        OWGUI.checkBox(box, self, "appendProbabilities", "Append predicted class probabilities", callback = self.sendIf)
80        applyButton = OWGUI.button(box, self, "Commit", callback = self.sendData, default=True)
81        autoApplyCB = OWGUI.checkBox(box, self, "autoApply", "Commit automatically")
82        OWGUI.setStopper(self, applyButton, autoApplyCB, "selectionDirty", self.sendData)
83
84        import sip
85        sip.delete(self.mainArea.layout())
86        self.layout = QGridLayout(self.mainArea)
87
88        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Prediction"), 0, 1, Qt.AlignCenter)
89       
90        label = TransformedLabel("Correct Class")
91        self.layout.addWidget(label, 2, 0, Qt.AlignCenter)
92#        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Correct Class  "), 2, 0, Qt.AlignCenter)
93        self.table = OWGUI.table(self.mainArea, rows = 0, columns = 0, selectionMode = QTableWidget.MultiSelection, addToLayout = 0)
94        self.layout.addWidget(self.table, 2, 1)
95        self.layout.setColumnStretch(1, 100)
96        self.layout.setRowStretch(2, 100)
97        self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.sendIf)
98       
99        self.res = None
100        self.matrix = None
101        self.selectedLearner = None
102        self.resize(700,450)
103
104
105    def setTestResults(self, res):
106        self.res = res
107        if not res:
108            self.matrix = None
109            self.table.setRowCount(0)
110            self.table.setColumnCount(0)
111            return
112
113        if res and res.test_type != TEST_TYPE_SINGLE:
114            self.warning(0, "Confusion matrix can be calculated only for single-target prediction problems.")
115            return
116        self.warning(0, None)
117       
118        self.matrix = orngStat.confusionMatrices(res, -2)
119
120        dim = len(res.classValues)
121
122        self.table.setRowCount(dim+1)
123        self.table.setColumnCount(dim+1)
124
125        self.table.setHorizontalHeaderLabels(res.classValues+[""])
126        self.table.setVerticalHeaderLabels(res.classValues+[""])
127
128        for ri in range(dim+1):
129            for ci in range(dim+1):
130                it = QTableWidgetItem()
131                it.setFlags(Qt.ItemIsEnabled | (ri<dim and ci<dim and Qt.ItemIsSelectable or Qt.NoItemFlags))
132                it.setTextAlignment(Qt.AlignRight)
133                self.table.setItem(ri, ci, it)
134
135        boldf = self.table.item(0, dim).font()
136        boldf.setBold(True)
137        for ri in range(dim+1):
138            self.table.item(ri, dim).setFont(boldf)
139            self.table.item(dim, ri).setFont(boldf)
140           
141        self.learnerNames = res.classifierNames[:]
142        if not self.selectedLearner and self.res.numberOfLearners:
143            self.selectedLearner = [0]
144        self.learnerChanged()
145        self.table.clearSelection()
146
147
148    def learnerChanged(self):
149        if not (self.res and self.res.numberOfLearners):
150            return
151       
152        if self.selectedLearner and self.selectedLearner[0] > self.res.numberOfLearners:
153            self.selectedLearner = [0]
154        if not self.selectedLearner:
155            return
156       
157        cm = self.matrix[self.selectedLearner[0]]
158       
159        self.isInteger = " %i "
160        for r in reduce(add, cm):
161            if int(r) != r:
162                self.isInteger = " %5.3f "
163                break
164
165        self.reprint()
166        self.sendIf()
167
168
169    def reprint(self):
170        if not self.matrix or not self.selectedLearner: 
171            return
172       
173        cm = self.matrix[self.selectedLearner[0]]
174
175        dim = len(cm)
176        rowSums = [sum(r) for r in cm]
177        colSums = [sum([r[i] for r in cm]) for i in range(dim)]
178        total = sum(rowSums)
179        if self.shownQuantity == 1:
180            if total > 1e-5:
181                rowPriors = [r/total for r in rowSums]
182                colPriors = [r/total for r in colSums]
183            else:
184                rowPriors = [0 for r in rowSums]
185                colPriors = [0 for r in colSums]
186
187        for ri, r in enumerate(cm):
188            for ci, c in enumerate(r):
189                item = self.table.item(ri, ci)
190                if not item:
191                    continue
192                if self.shownQuantity == 0:
193                    item.setText(self.isInteger % c)
194                elif self.shownQuantity == 1:
195                    item.setText((self.isInteger + "/ %5.3f ") % (c, total*rowPriors[ri]*colPriors[ci]))
196                elif self.shownQuantity == 2:
197                    item.setText(colSums[ci] > 1e-5 and (" %2.1f %%  " % (100 * c / colSums[ci])) or " "+"N/A"+" ")
198                elif self.shownQuantity == 3:
199                    item.setText(rowSums[ri] > 1e-5 and (" %2.1f %%  " % (100 * c / rowSums[ri])) or " "+"N/A"+" ")
200
201        for ci in range(dim):
202            self.table.item(dim, ci).setText(self.isInteger % colSums[ci])
203            self.table.item(ci, dim).setText(self.isInteger % rowSums[ci])
204        self.table.item(dim, dim).setText(self.isInteger % total)
205
206        self.table.resizeColumnsToContents()
207
208    def sendReport(self):
209        self.reportSettings("Contents",
210                            [("Learner", self.learnerNames[self.selectedLearner[0]] \
211                                            if self.learnerNames else \
212                                         "N/A"),
213                             ("Data", self.quantities[self.shownQuantity])])
214       
215        if self.res:
216            self.reportSection("Matrix")
217            classVals = self.res.classValues
218            nClassVals = len(classVals)
219            res = "<table>\n<tr><td></td>" + "".join('<td align="center"><b>&nbsp;&nbsp;%s&nbsp;&nbsp;</b></td>' % cv for cv in classVals) + "</tr>\n"
220            for i, cv in enumerate(classVals):
221                res += '<tr><th align="right"><b>%s</b></th>' % cv + \
222                       "".join('<td align="center">%s</td>' % self.table.item(i, j).text() for j in range(nClassVals)) + \
223                       '<th align="right"><b>%s</b></th>' % self.table.item(i, nClassVals).text() + \
224                       "</tr>\n"
225            res += '<tr><th></th>' + \
226                   "".join('<td align="center"><b>%s</b></td>' % self.table.item(nClassVals, j).text() for j in range(nClassVals+1)) + \
227                   "</tr>\n"
228            res += "</table>\n<p><b>Note:</b> columns represent predictions, row represent true classes</p>"
229            self.reportRaw(res)
230           
231
232    def selectCorrect(self):
233        if not self.res:
234            return
235
236        sa = self.autoApply
237        self.autoApply = False
238        self.table.clearSelection()
239        for i in range(len(self.matrix[0])):
240            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), True)
241        self.autoApply = sa
242        self.sendIf()
243
244
245    def selectWrong(self):
246        if not self.res:
247            return
248
249        sa = self.autoApply
250        self.autoApply = False
251        self.table.clearSelection()
252        dim = len(self.matrix[0])
253        self.table.setRangeSelected(QTableWidgetSelectionRange(0, 0, dim-1, dim-1), True)
254        for i in range(len(self.matrix[0])):
255            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), False)
256        self.autoApply = sa
257        self.sendIf()
258
259
260    def selectNone(self):
261        self.table.clearSelection()
262
263
264    def sendIf(self):
265        if self.autoApply:
266            self.sendData()
267        else:
268            self.selectionDirty = True
269
270
271    def sendData(self):
272        self.selectionDirty = False
273
274        selected = [(x.row(), x.column()) for x in self.table.selectedIndexes()]
275        res = self.res
276        if not res or not selected or not self.selectedLearner:
277            self.send("Selected Data", None)
278            return
279
280        learnerI = self.selectedLearner[0]
281        selectionIndices = [i for i, rese in enumerate(res.results) if (rese.actualClass, rese.classes[learnerI]) in selected]
282        data = res.examples.getitemsref(selectionIndices)
283       
284        if self.appendPredictions or self.appendProbabilities:
285            domain = orange.Domain(data.domain.attributes, data.domain.classVar)
286            domain.addmetas(data.domain.getmetas())
287            data = orange.ExampleTable(domain, data)
288       
289            if self.appendPredictions:
290                cname = self.learnerNames[learnerI]
291                predVar = type(domain.classVar)("%s(%s)" % (domain.classVar.name, cname.encode("utf-8") if isinstance(cname, unicode) else cname))
292                if hasattr(domain.classVar, "values"):
293                    predVar.values = domain.classVar.values
294                predictionsId = orange.newmetaid()
295                domain.addmeta(predictionsId, predVar)
296                for i, ex in zip(selectionIndices, data):
297                    ex[predictionsId] = res.results[i].classes[learnerI]
298                   
299            if self.appendProbabilities:
300                probVars = [orange.FloatVariable("p(%s)" % v) for v in domain.classVar.values]
301                probIds = [orange.newmetaid() for pv in probVars]
302                domain.addmetas(dict(zip(probIds, probVars)))
303                for i, ex in zip(selectionIndices, data):
304                    for id, p in zip(probIds, res.results[i].probabilities[learnerI]):
305                        ex[id] = p
306   
307        self.send("Selected Data", data)
308
309
310if __name__ == "__main__":
311    a = QApplication(sys.argv)
312    owdm = OWConfusionMatrix()
313    owdm.show()
314    a.exec_()
Note: See TracBrowser for help on using the repository browser.