source: orange/orange/OrangeWidgets/Evaluate/OWConfusionMatrix.py @ 9505:4b798678cd3d

Revision 9505:4b798678cd3d, 12.0 KB checked in by matija <matija.polajnar@…>, 2 years ago (diff)

Merge in the (heavily modified) MLC code from GSOC 2011 (modules, documentation, evaluation code, regression test). Widgets will be merged in a little bit later, which will finally close ticket #992.

Line 
1"""
2<name>Confusion Matrix</name>
3<description>Shows a confusion matrix.</description>
4<contact>Janez Demsar</contact>
5<icon>icons/ConfusionMatrix.png</icon>
6<priority>1001</priority>
7"""
8from OWWidget import *
9import OWGUI
10import orngStat, orngTest
11import statc, math
12from operator import add
13           
14class TransformedLabel(QLabel):
15    def __init__(self, text, parent=None):
16        QLabel.__init__(self, text, parent)
17        self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding)
18        self.setMaximumWidth(self.sizeHint().width() + 2)
19        self.setMargin(4)
20       
21    def sizeHint(self):
22        metrics = QFontMetrics(self.font())
23        rect = metrics.boundingRect(self.text())
24        size = QSize(rect.height() + self.margin(), rect.width() + self.margin())
25        return size
26   
27    def setGeometry(self, rect):
28        QLabel.setGeometry(self, rect)
29       
30    def paintEvent(self, event):
31        painter = QPainter(self)
32        rect = self.geometry()
33        textRect = QRect(0,0,rect.width(), rect.height())
34
35        painter.translate(textRect.bottomLeft())
36        painter.rotate(-90)
37        painter.drawText(QRect(QPoint(0, 0), QSize(rect.height(), rect.width())), Qt.AlignCenter, self.text())
38        painter.end()
39       
40       
41class OWConfusionMatrix(OWWidget):
42    settings = ["shownQuantity", "autoApply", "appendPredictions", "appendProbabilities"]
43
44    quantities = ["Number of examples", "Observed and expected examples", "Proportions of predicted", "Proportions of true"]
45    def __init__(self,parent=None, signalManager = None):
46        OWWidget.__init__(self, parent, signalManager, "Confusion Matrix", 1)
47
48        # inputs
49        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.setTestResults, Default)]
50        self.outputs=[("Selected Examples", ExampleTable, 8)]
51
52        self.selectedLearner = []
53        self.learnerNames = []
54        self.selectionDirty = 0
55        self.autoApply = True
56        self.appendPredictions = True
57        self.appendProbabilities = False
58        self.shownQuantity = 0
59
60        self.learnerList = OWGUI.listBox(self.controlArea, self, "selectedLearner", "learnerNames", box = "Learners", callback = self.learnerChanged)
61        self.learnerList.setMinimumHeight(100)
62       
63        OWGUI.separator(self.controlArea)
64
65        OWGUI.comboBox(self.controlArea, self, "shownQuantity", items = self.quantities, box = "Show", callback=self.reprint)
66
67        OWGUI.separator(self.controlArea)
68       
69        box = OWGUI.widgetBox(self.controlArea, "Selection") #, addSpace=True)
70        OWGUI.button(box, self, "Correct", callback=self.selectCorrect)
71        OWGUI.button(box, self, "Misclassified", callback=self.selectWrong)
72        OWGUI.button(box, self, "None", callback=self.selectNone)
73       
74        OWGUI.separator(self.controlArea)
75
76        box = OWGUI.widgetBox(self.controlArea, "Output")
77        OWGUI.checkBox(box, self, "appendPredictions", "Append class predictions", callback = self.sendIf)
78        OWGUI.checkBox(box, self, "appendProbabilities", "Append predicted class probabilities", callback = self.sendIf)
79        applyButton = OWGUI.button(box, self, "Commit", callback = self.sendData, default=True)
80        autoApplyCB = OWGUI.checkBox(box, self, "autoApply", "Commit automatically")
81        OWGUI.setStopper(self, applyButton, autoApplyCB, "selectionDirty", self.sendData)
82
83        import sip
84        sip.delete(self.mainArea.layout())
85        self.layout = QGridLayout(self.mainArea)
86
87        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Prediction"), 0, 1, Qt.AlignCenter)
88       
89        label = TransformedLabel("Correct Class")
90        self.layout.addWidget(label, 2, 0, Qt.AlignCenter)
91#        self.layout.addWidget(OWGUI.widgetLabel(self.mainArea, "Correct Class  "), 2, 0, Qt.AlignCenter)
92        self.table = OWGUI.table(self.mainArea, rows = 0, columns = 0, selectionMode = QTableWidget.MultiSelection, addToLayout = 0)
93        self.layout.addWidget(self.table, 2, 1)
94        self.layout.setColumnStretch(1, 100)
95        self.layout.setRowStretch(2, 100)
96        self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.sendIf)
97       
98        self.res = None
99        self.matrix = None
100        self.selectedLearner = None
101        self.resize(700,450)
102
103
104    def setTestResults(self, res):
105        self.res = res
106        if not res:
107            self.matrix = None
108            self.table.setRowCount(0)
109            self.table.setColumnCount(0)
110            return
111
112        self.matrix = orngStat.confusionMatrices(res, -2)
113
114        dim = len(res.classValues)
115
116        self.table.setRowCount(dim+1)
117        self.table.setColumnCount(dim+1)
118
119        self.table.setHorizontalHeaderLabels(res.classValues+[""])
120        self.table.setVerticalHeaderLabels(res.classValues+[""])
121
122        for ri in range(dim+1):
123            for ci in range(dim+1):
124                it = QTableWidgetItem()
125                it.setFlags(Qt.ItemIsEnabled | (ri<dim and ci<dim and Qt.ItemIsSelectable or Qt.NoItemFlags))
126                it.setTextAlignment(Qt.AlignRight)
127                self.table.setItem(ri, ci, it)
128
129        boldf = self.table.item(0, dim).font()
130        boldf.setBold(True)
131        for ri in range(dim+1):
132            self.table.item(ri, dim).setFont(boldf)
133            self.table.item(dim, ri).setFont(boldf)
134           
135        self.learnerNames = res.classifierNames[:]
136        if not self.selectedLearner and self.res.numberOfLearners:
137            self.selectedLearner = [0]
138        self.learnerChanged()
139        self.table.clearSelection()
140
141
142    def learnerChanged(self):
143        if not (self.res and self.res.numberOfLearners):
144            return
145       
146        if self.selectedLearner and self.selectedLearner[0] > self.res.numberOfLearners:
147            self.selectedLearner = [0]
148        if not self.selectedLearner:
149            return
150       
151        cm = self.matrix[self.selectedLearner[0]]
152       
153        self.isInteger = " %i "
154        for r in reduce(add, cm):
155            if int(r) != r:
156                self.isInteger = " %5.3f "
157                break
158
159        self.reprint()
160        self.sendIf()
161
162
163    def reprint(self):
164        if not self.matrix or not self.selectedLearner: 
165            return
166       
167        cm = self.matrix[self.selectedLearner[0]]
168
169        dim = len(cm)
170        rowSums = [sum(r) for r in cm]
171        colSums = [sum([r[i] for r in cm]) for i in range(dim)]
172        total = sum(rowSums)
173        if self.shownQuantity == 1:
174            if total > 1e-5:
175                rowPriors = [r/total for r in rowSums]
176                colPriors = [r/total for r in colSums]
177            else:
178                rowPriors = [0 for r in rowSums]
179                colPriors = [0 for r in colSums]
180
181        for ri, r in enumerate(cm):
182            for ci, c in enumerate(r):
183                item = self.table.item(ri, ci)
184                if not item:
185                    continue
186                if self.shownQuantity == 0:
187                    item.setText(self.isInteger % c)
188                elif self.shownQuantity == 1:
189                    item.setText((self.isInteger + "/ %5.3f ") % (c, total*rowPriors[ri]*colPriors[ci]))
190                elif self.shownQuantity == 2:
191                    item.setText(colSums[ci] > 1e-5 and (" %2.1f %%  " % (100 * c / colSums[ci])) or " "+"N/A"+" ")
192                elif self.shownQuantity == 3:
193                    item.setText(rowSums[ri] > 1e-5 and (" %2.1f %%  " % (100 * c / rowSums[ri])) or " "+"N/A"+" ")
194
195        for ci in range(dim):
196            self.table.item(dim, ci).setText(self.isInteger % colSums[ci])
197            self.table.item(ci, dim).setText(self.isInteger % rowSums[ci])
198        self.table.item(dim, dim).setText(self.isInteger % total)
199
200        self.table.resizeColumnsToContents()
201
202    def sendReport(self):
203        self.reportSettings("Contents",
204                            [("Learner", self.learnerNames[self.selectedLearner[0]] \
205                                            if self.learnerNames else \
206                                         "N/A"),
207                             ("Data", self.quantities[self.shownQuantity])])
208       
209        if self.res:
210            self.reportSection("Matrix")
211            classVals = self.res.classValues
212            nClassVals = len(classVals)
213            res = "<table>\n<tr><td></td>" + "".join('<td align="center"><b>&nbsp;&nbsp;%s&nbsp;&nbsp;</b></td>' % cv for cv in classVals) + "</tr>\n"
214            for i, cv in enumerate(classVals):
215                res += '<tr><th align="right"><b>%s</b></th>' % cv + \
216                       "".join('<td align="center">%s</td>' % self.table.item(i, j).text() for j in range(nClassVals)) + \
217                       '<th align="right"><b>%s</b></th>' % self.table.item(i, nClassVals).text() + \
218                       "</tr>\n"
219            res += '<tr><th></th>' + \
220                   "".join('<td align="center"><b>%s</b></td>' % self.table.item(nClassVals, j).text() for j in range(nClassVals+1)) + \
221                   "</tr>\n"
222            res += "</table>\n<p><b>Note:</b> columns represent predictions, row represent true classes</p>"
223            self.reportRaw(res)
224           
225
226    def selectCorrect(self):
227        if not self.res:
228            return
229
230        sa = self.autoApply
231        self.autoApply = False
232        self.table.clearSelection()
233        for i in range(len(self.matrix[0])):
234            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), True)
235        self.autoApply = sa
236        self.sendIf()
237
238
239    def selectWrong(self):
240        if not self.res:
241            return
242
243        sa = self.autoApply
244        self.autoApply = False
245        self.table.clearSelection()
246        dim = len(self.matrix[0])
247        self.table.setRangeSelected(QTableWidgetSelectionRange(0, 0, dim-1, dim-1), True)
248        for i in range(len(self.matrix[0])):
249            self.table.setRangeSelected(QTableWidgetSelectionRange(i, i, i, i), False)
250        self.autoApply = sa
251        self.sendIf()
252
253
254    def selectNone(self):
255        self.table.clearSelection()
256
257
258    def sendIf(self):
259        if self.autoApply:
260            self.sendData()
261        else:
262            self.selectionDirty = True
263
264
265    def sendData(self):
266        self.selectionDirty = False
267
268        selected = [(x.row(), x.column()) for x in self.table.selectedIndexes()]
269        res = self.res
270        if not res or not selected or not self.selectedLearner:
271            self.send("Selected Examples", None)
272            return
273
274        learnerI = self.selectedLearner[0]
275        selectionIndices = [i for i, rese in enumerate(res.results) if (rese.actualClass, rese.classes[learnerI]) in selected]
276        data = res.examples.getitemsref(selectionIndices)
277       
278        if self.appendPredictions or self.appendProbabilities:
279            domain = orange.Domain(data.domain.attributes, data.domain.classVar)
280            domain.addmetas(data.domain.getmetas())
281            data = orange.ExampleTable(domain, data)
282       
283            if self.appendPredictions:
284                cname = self.learnerNames[learnerI]
285                predVar = type(domain.classVar)("%s(%s)" % (domain.classVar.name, cname.encode("utf-8") if isinstance(cname, unicode) else cname))
286                if hasattr(domain.classVar, "values"):
287                    predVar.values = domain.classVar.values
288                predictionsId = orange.newmetaid()
289                domain.addmeta(predictionsId, predVar)
290                for i, ex in zip(selectionIndices, data):
291                    ex[predictionsId] = res.results[i].classes[learnerI]
292                   
293            if self.appendProbabilities:
294                probVars = [orange.FloatVariable("p(%s)" % v) for v in domain.classVar.values]
295                probIds = [orange.newmetaid() for pv in probVars]
296                domain.addmetas(dict(zip(probIds, probVars)))
297                for i, ex in zip(selectionIndices, data):
298                    for id, p in zip(probIds, res.results[i].probabilities[learnerI]):
299                        ex[id] = p
300   
301        self.send("Selected Examples", data)
302
303
304if __name__ == "__main__":
305    a = QApplication(sys.argv)
306    owdm = OWConfusionMatrix()
307    owdm.show()
308    a.exec_()
Note: See TracBrowser for help on using the repository browser.