source: orange/orange/OrangeWidgets/Evaluate/OWCalibrationPlot.py @ 9505:4b798678cd3d

Revision 9505:4b798678cd3d, 13.9 KB checked in by matija <matija.polajnar@…>, 2 years ago (diff)

Merge in the (heavily modified) MLC code from GSOC 2011 (modules, documentation, evaluation code, regression test). Widgets will be merged in a little bit later, which will finally close ticket #992.

Line 
1"""
2<name>Calibration Plot</name>
3<description>Displays calibration plot based on evaluation of classifiers.</description>
4<contact>Tomaz Curk</contact>
5<icon>icons/CalibrationPlot.png</icon>
6<priority>1030</priority>
7"""
8from OWColorPalette import ColorPixmap
9from OWWidget import *
10from OWGraph import *
11import OWGUI
12
13import orngTest, orngStat
14import statc, math
15
16class singleClassCalibrationPlotGraph(OWGraph):
17    def __init__(self, parent = None, name = None, title = ""):
18        OWGraph.__init__(self, parent, name)
19        self.setYRlabels(None)
20        self.enableGridXB(0)
21        self.enableGridYL(0)
22        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
23        self.setAxisMaxMinor(QwtPlot.xBottom, 5)
24        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
25        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
26        self.setAxisScale(QwtPlot.xBottom, -0.0, 1.0, 0)
27        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
28        self.setYLaxisTitle('actual probability')
29        self.setShowYLaxisTitle(1)
30        self.setXaxisTitle('estimated probability')
31        self.setShowXaxisTitle(1)
32        self.setShowMainTitle(1)
33        self.setMainTitle(title)
34        self.dres = None
35        self.numberOfClasses = None
36        self.targetClass = None
37        self.rugHeight = 0.02
38
39        self.removeCurves()
40
41    def setData(self, classifierColor, dres, targetClass):
42        self.classifierColor = classifierColor
43        self.dres = dres
44        self.targetClass = targetClass
45
46        classifiersNum = len(self.dres.classifierNames)
47        self.removeCurves()
48        self.classifierColor = classifierColor
49        self.classifierNames = self.dres.classifierNames
50        self.numberOfClasses = len(self.dres.classValues)
51
52        for cNum in range(classifiersNum):
53            curve = self.addCurve('', pen=QPen(self.classifierColor[cNum], 3))
54            self.classifierCalibrationCKeys.append(curve)
55
56            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
57            curve.attach(self)
58            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
59            curve.setStyle(QwtPlotCurve.UserCurve)
60            self.classifierYesClassRugCKeys.append(curve)
61
62            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
63            curve.attach(self)
64            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
65            curve.setStyle(QwtPlotCurve.UserCurve)
66            self.classifierNoClassRugCKeys.append(curve)
67
68            self.showClassifiers.append(0)
69
70        ## compute curves for targetClass
71        if (self.dres <> None): ## check that targetClass in range
72            if self.targetClass < 0:
73                self.targetClass = 0
74            if self.targetClass >= self.numberOfClasses:
75                self.targetClass = self.numberOfClasses - 1
76            if self.targetClass < 0:
77                self.targetClass = None ## no classes, no target
78
79        if (self.dres == None) or (self.targetClass == None):
80            self.setMainTitle("")
81            for curve in self.classifierCalibrationCKeys + self.classifierYesClassRugCKeys + self.classifierNoClassRugCKeys:
82                curve.setData([], [])
83            return
84
85        self.setMainTitle(self.dres.classValues[self.targetClass])
86        calibrationCurves = orngStat.computeCalibrationCurve(self.dres, self.targetClass)
87
88        classifier = 0
89        for (curve, yesClassRugPoints, noClassRugPoints) in calibrationCurves:
90            x = [px for (px, py) in curve]
91            y = [py for (px, py) in curve]
92            curve = self.classifierCalibrationCKeys[classifier]
93            curve.setData(x, y)
94
95            x = []
96            y = []
97            for (px, py) in yesClassRugPoints:
98                n = py > 0.0 ##py
99                if n:
100                    py = 1.0
101                    x.append(px)
102                    y.append(py - self.rugHeight*n / 2.0)
103
104                    x.append(px)
105                    y.append(py)
106
107                    x.append(px)
108                    y.append(py - self.rugHeight*n)
109            curve = self.classifierYesClassRugCKeys[classifier]
110            curve.setData(x, y)
111
112            x = []
113            y = []
114            for (px, py) in noClassRugPoints:
115                n = py > 0.0 ##py
116                if n:
117                    py = 0.0
118                    x.append(px)
119                    y.append(py + self.rugHeight*n / 2.0)
120
121                    x.append(px)
122                    y.append(py + self.rugHeight*n)
123
124                    x.append(px)
125                    y.append(py)
126            curve = self.classifierNoClassRugCKeys[classifier]
127            curve.setData(x, y)
128            classifier += 1
129
130        self.updateCurveDisplay()
131
132    def removeCurves(self):
133        OWGraph.clear(self)
134        self.classifierColor = []
135        self.classifierNames = []
136        self.showClassifiers = []
137        self.showDiagonal = 0
138        self.showRugs = 1
139
140        self.classifierCalibrationCKeys = []
141        self.classifierYesClassRugCKeys = []
142        self.classifierNoClassRugCKeys = []
143
144        ## diagonal curve
145        self.diagonalCKey = self.addCurve("", pen = QPen(Qt.black, 1), style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0.0, 1.0], yData = [0.0, 1.0])
146
147    def updateCurveDisplay(self):
148        self.diagonalCKey.setVisible(self.showDiagonal)
149
150        for cNum in range(len(self.showClassifiers)):
151            showCNum = (self.showClassifiers[cNum] <> 0)
152            self.classifierCalibrationCKeys[cNum].setVisible(showCNum)
153            b = showCNum and self.showRugs
154            self.classifierYesClassRugCKeys[cNum].setVisible(b)
155            self.classifierNoClassRugCKeys[cNum].setVisible(b)
156        self.updateLayout()
157        self.replot()
158
159    def setCalibrationCurveWidth(self, v):
160        for cNum in range(len(self.showClassifiers)):
161            self.classifierCalibrationCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
162        self.replot()
163
164    def setShowClassifiers(self, list):
165        self.showClassifiers = list
166        self.updateCurveDisplay()
167
168    def setShowDiagonal(self, v):
169        self.showDiagonal = v
170        self.updateCurveDisplay()
171
172    def setShowRugs(self, v):
173        self.showRugs = v
174        self.updateCurveDisplay()
175
176    def sizeHint(self):
177        return QSize(170, 170)
178
179class OWCalibrationPlot(OWWidget):
180    settingsList = ["CalibrationCurveWidth", "ShowDiagonal", "ShowRugs"]
181    contextHandlers = {"": EvaluationResultsContextHandler("", "targetClass", "selectedClassifiers")}
182   
183    def __init__(self,parent=None, signalManager = None):
184        OWWidget.__init__(self, parent, signalManager, "Calibration Plot", 1)
185
186        # inputs
187        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.results, Default)]
188
189        #set default settings
190        self.CalibrationCurveWidth = 3
191        self.ShowDiagonal = TRUE
192        self.ShowRugs = TRUE
193        #load settings
194        self.loadSettings()
195
196        # temp variables
197        self.dres = None
198        self.targetClass = None
199        self.numberOfClasses = 0
200        self.graphs = []
201        self.classifierColor = None
202        self.numberOfClassifiers = 0
203        self.classifiers = []
204        self.selectedClassifiers = []
205
206        # GUI
207        import sip
208        sip.delete(self.mainArea.layout())
209        self.graphsGridLayoutQGL = QGridLayout(self.mainArea)
210        self.mainArea.setLayout(self.graphsGridLayoutQGL)
211
212        ## save each ROC graph in separate file
213        self.graph = None
214        self.connect(self.graphButton, SIGNAL("clicked()"), self.saveToFile)
215
216        ## general tab
217        self.tabs = OWGUI.tabWidget(self.controlArea)
218        self.generalTab = OWGUI.createTabPage(self.tabs, "General")
219        self.settingsTab = OWGUI.createTabPage(self.tabs, "Settings")
220
221        self.splitQS = QSplitter()
222        self.splitQS.setOrientation(Qt.Vertical)
223
224        ## target class
225        self.classCombo = OWGUI.comboBox(self.generalTab, self, 'targetClass', box='Target class', items=[], callback=self.target)
226        OWGUI.separator(self.generalTab)
227
228        ## classifiers selection (classifiersQLB)
229        self.classifiersQVGB = OWGUI.widgetBox(self.generalTab, "Classifiers")
230        self.classifiersQLB = OWGUI.listBox(self.classifiersQVGB, self, "selectedClassifiers", selectionMode = QListWidget.MultiSelection, callback = self.classifiersSelectionChange)
231        self.unselectAllClassifiersQLB = OWGUI.button(self.classifiersQVGB, self, "(Un)select all", callback = self.SUAclassifiersQLB)
232
233        ## settings tab
234        OWGUI.hSlider(self.settingsTab, self, 'CalibrationCurveWidth', box='Calibration Curve Width', minValue=1, maxValue=9, step=1, callback=self.setCalibrationCurveWidth, ticks=1)
235        OWGUI.checkBox(self.settingsTab, self, 'ShowDiagonal', 'Show Diagonal Line', tooltip='', callback=self.setShowDiagonal)
236        OWGUI.checkBox(self.settingsTab, self, 'ShowRugs', 'Show Rugs', tooltip='', callback=self.setShowRugs)
237        self.settingsTab.layout().addStretch(100)
238
239    def sendReport(self):
240        # need to reimport - Qt provides something stupid instead
241        from __builtin__ import hex
242        self.reportSettings("Settings",
243                            [("Classifiers", ", ".join('<font color="#%s">%s</font>' % ("".join(("0"+hex(x)[2:])[-2:] for x in self.classifierColor[cNum].getRgb()[:3]), str(item.text()))
244                                                        for cNum, item in enumerate(self.classifiersQLB.item(i) for i in range(self.classifiersQLB.count()))
245                                                          if item.isSelected())),
246                             ("Target class", self.classCombo.itemText(self.targetClass)
247                                              if self.targetClass is not None else
248                                              "N/A"),
249                            ])
250        if self.targetClass is not None:
251            self.reportRaw("<br/>")
252            self.reportImage(self.graphs[self.targetClass].saveToFileDirect, QSize(400, 400))
253
254
255    def setCalibrationCurveWidth(self):
256        for g in self.graphs:
257            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
258
259    def setShowDiagonal(self):
260        for g in self.graphs:
261            g.setShowDiagonal(self.ShowDiagonal)
262
263    def setShowRugs(self):
264        for g in self.graphs:
265            g.setShowRugs(self.ShowRugs)
266
267    ##
268    def selectUnselectAll(self, qlb):
269        selected = 0
270        for i in range(qlb.count()):
271            if qlb.item(i).isSelected():
272                selected = 1
273                break
274        if selected: qlb.clearSelection()
275        else: qlb.selectAll()
276
277    def SUAclassifiersQLB(self):
278        self.selectUnselectAll(self.classifiersQLB)
279
280    def classifiersSelectionChange(self):
281        list = []
282        for i in range(self.classifiersQLB.count()):
283            if self.classifiersQLB.item(i).isSelected():
284                list.append( 1 )
285            else:
286                list.append( 0 )
287        for g in self.graphs:
288            g.setShowClassifiers(list)
289    ##
290
291    def calcAllClassGraphs(self):
292        cl = 0
293        for g in self.graphs:
294            g.setData(self.classifierColor, self.dres, cl)
295
296            ## user settings
297            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
298            g.setShowDiagonal(self.ShowDiagonal)
299            g.setShowRugs(self.ShowRugs)
300            cl += 1
301
302    def removeGraphs(self):
303        for g in self.graphs:
304            g.removeCurves()
305            g.hide()
306
307    def saveToFile(self):
308        if self.graph:
309            self.graph.saveToFile()
310
311    def target(self):
312        for g in self.graphs:
313            g.hide()
314
315        if (self.targetClass <> None) and (len(self.graphs) > 0):
316            if self.targetClass >= len(self.graphs):
317                self.targetClass = len(self.graphs) - 1
318            if self.targetClass < 0:
319                self.targetClass = 0
320            self.graph = self.graphs[self.targetClass]
321            self.graph.show()
322            self.graphsGridLayoutQGL.addWidget(self.graph, 0, 0)
323        else:
324            self.graph = None
325
326    def results(self, dres):
327        self.closeContext()
328
329        self.targetClass = None
330        self.classifiersQLB.clear()
331        self.removeGraphs()
332        self.classCombo.clear()
333
334        self.dres = dres
335
336        self.graphs = []
337        if self.dres <> None:
338            self.numberOfClasses = len(self.dres.classValues)
339            ## one graph for each class
340            for i in range(self.numberOfClasses):
341                graph = singleClassCalibrationPlotGraph(self.mainArea)
342                graph.hide()
343                self.graphs.append(graph)
344                self.classCombo.addItem(self.dres.classValues[i])
345
346            ## classifiersQLB
347            self.classifierColor = []
348            self.numberOfClassifiers = self.dres.numberOfLearners
349            if self.numberOfClassifiers > 1:
350                allCforHSV = self.numberOfClassifiers - 1
351            else:
352                allCforHSV = self.numberOfClassifiers
353            for i in range(self.numberOfClassifiers):
354                newColor = QColor()
355                newColor.setHsv(i*255/allCforHSV, 255, 255)
356                self.classifierColor.append( newColor )
357
358            self.calcAllClassGraphs()
359
360            ## update graphics
361            ## classifiersQLB
362            for i in range(self.numberOfClassifiers):
363                newColor = self.classifierColor[i]
364                self.classifiersQLB.addItem(QListWidgetItem(ColorPixmap(newColor), self.dres.classifierNames[i]))
365            self.classifiersQLB.selectAll()
366        else:
367            self.numberOfClasses = 0
368            self.classifierColor = None
369            self.targetClass = None ## no results, no target
370           
371        if not self.targetClass:
372            self.targetClass = 0
373           
374        self.openContext("", self.dres)
375        self.target()
376
377if __name__ == "__main__":
378    a = QApplication(sys.argv)
379    owdm = OWCalibrationPlot()
380    owdm.show()
381    a.exec_()
Note: See TracBrowser for help on using the repository browser.