source: orange/Orange/OrangeWidgets/Evaluate/OWCalibrationPlot.py @ 10497:f8cc996f4e2e

Revision 10497:f8cc996f4e2e, 14.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed calibration curve coloring (fixes #1131).

Line 
1"""
2<name>Calibration Plot</name>
3<description>Displays calibration plot based on evaluation of classifiers.</description>
4<contact>Tomaz Curk</contact>
5<icon>icons/CalibrationPlot.png</icon>
6<priority>1030</priority>
7"""
8from OWColorPalette import ColorPixmap
9from OWWidget import *
10from OWGraph import *
11import OWGUI
12
13import orngTest, orngStat
14import statc, math
15from Orange.evaluation.testing import TEST_TYPE_SINGLE
16
17class singleClassCalibrationPlotGraph(OWGraph):
18    def __init__(self, parent = None, name = None, title = ""):
19        OWGraph.__init__(self, parent, name)
20        self.setYRlabels(None)
21        self.enableGridXB(0)
22        self.enableGridYL(0)
23        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
24        self.setAxisMaxMinor(QwtPlot.xBottom, 5)
25        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
26        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
27        self.setAxisScale(QwtPlot.xBottom, -0.0, 1.0, 0)
28        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
29        self.setYLaxisTitle('actual probability')
30        self.setShowYLaxisTitle(1)
31        self.setXaxisTitle('estimated probability')
32        self.setShowXaxisTitle(1)
33        self.setShowMainTitle(1)
34        self.setMainTitle(title)
35        self.dres = None
36        self.numberOfClasses = None
37        self.targetClass = None
38        self.rugHeight = 0.02
39
40        self.removeCurves()
41
42    def setData(self, classifierColor, dres, targetClass):
43        self.classifierColor = classifierColor
44        self.dres = dres
45        self.targetClass = targetClass
46
47        classifiersNum = len(self.dres.classifierNames)
48        self.removeCurves()
49        self.classifierColor = classifierColor
50        self.classifierNames = self.dres.classifierNames
51        self.numberOfClasses = len(self.dres.classValues)
52
53        for cNum in range(classifiersNum):
54            curve = self.addCurve('', 
55                                  pen=QPen(self.classifierColor[cNum], 3),
56                                  brushColor=self.classifierColor[cNum]
57                                  )
58            self.classifierCalibrationCKeys.append(curve)
59
60            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
61            curve.attach(self)
62            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
63            curve.setStyle(QwtPlotCurve.UserCurve)
64            self.classifierYesClassRugCKeys.append(curve)
65
66            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
67            curve.attach(self)
68            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
69            curve.setStyle(QwtPlotCurve.UserCurve)
70            self.classifierNoClassRugCKeys.append(curve)
71
72            self.showClassifiers.append(0)
73
74        ## compute curves for targetClass
75        if (self.dres <> None): ## check that targetClass in range
76            if self.targetClass < 0:
77                self.targetClass = 0
78            if self.targetClass >= self.numberOfClasses:
79                self.targetClass = self.numberOfClasses - 1
80            if self.targetClass < 0:
81                self.targetClass = None ## no classes, no target
82
83        if (self.dres == None) or (self.targetClass == None):
84            self.setMainTitle("")
85            for curve in self.classifierCalibrationCKeys + self.classifierYesClassRugCKeys + self.classifierNoClassRugCKeys:
86                curve.setData([], [])
87            return
88
89        self.setMainTitle(self.dres.classValues[self.targetClass])
90        calibrationCurves = orngStat.computeCalibrationCurve(self.dres, self.targetClass)
91
92        classifier = 0
93        for (curve, yesClassRugPoints, noClassRugPoints) in calibrationCurves:
94            x = [px for (px, py) in curve]
95            y = [py for (px, py) in curve]
96            curve = self.classifierCalibrationCKeys[classifier]
97            curve.setData(x, y)
98
99            x = []
100            y = []
101            for (px, py) in yesClassRugPoints:
102                n = py > 0.0 ##py
103                if n:
104                    py = 1.0
105                    x.append(px)
106                    y.append(py - self.rugHeight*n / 2.0)
107
108                    x.append(px)
109                    y.append(py)
110
111                    x.append(px)
112                    y.append(py - self.rugHeight*n)
113            curve = self.classifierYesClassRugCKeys[classifier]
114            curve.setData(x, y)
115
116            x = []
117            y = []
118            for (px, py) in noClassRugPoints:
119                n = py > 0.0 ##py
120                if n:
121                    py = 0.0
122                    x.append(px)
123                    y.append(py + self.rugHeight*n / 2.0)
124
125                    x.append(px)
126                    y.append(py + self.rugHeight*n)
127
128                    x.append(px)
129                    y.append(py)
130            curve = self.classifierNoClassRugCKeys[classifier]
131            curve.setData(x, y)
132            classifier += 1
133
134        self.updateCurveDisplay()
135
136    def removeCurves(self):
137        OWGraph.clear(self)
138        self.classifierColor = []
139        self.classifierNames = []
140        self.showClassifiers = []
141        self.showDiagonal = 0
142        self.showRugs = 1
143
144        self.classifierCalibrationCKeys = []
145        self.classifierYesClassRugCKeys = []
146        self.classifierNoClassRugCKeys = []
147
148        ## diagonal curve
149        self.diagonalCKey = self.addCurve("", pen = QPen(Qt.black, 1), style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0.0, 1.0], yData = [0.0, 1.0])
150
151    def updateCurveDisplay(self):
152        self.diagonalCKey.setVisible(self.showDiagonal)
153
154        for cNum in range(len(self.showClassifiers)):
155            showCNum = (self.showClassifiers[cNum] <> 0)
156            self.classifierCalibrationCKeys[cNum].setVisible(showCNum)
157            b = showCNum and self.showRugs
158            self.classifierYesClassRugCKeys[cNum].setVisible(b)
159            self.classifierNoClassRugCKeys[cNum].setVisible(b)
160        self.updateLayout()
161        self.replot()
162
163    def setCalibrationCurveWidth(self, v):
164        for cNum in range(len(self.showClassifiers)):
165            self.classifierCalibrationCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
166        self.replot()
167
168    def setShowClassifiers(self, list):
169        self.showClassifiers = list
170        self.updateCurveDisplay()
171
172    def setShowDiagonal(self, v):
173        self.showDiagonal = v
174        self.updateCurveDisplay()
175
176    def setShowRugs(self, v):
177        self.showRugs = v
178        self.updateCurveDisplay()
179
180    def sizeHint(self):
181        return QSize(170, 170)
182
183class OWCalibrationPlot(OWWidget):
184    settingsList = ["CalibrationCurveWidth", "ShowDiagonal", "ShowRugs"]
185    contextHandlers = {"": EvaluationResultsContextHandler("", "targetClass", "selectedClassifiers")}
186   
187    def __init__(self,parent=None, signalManager = None):
188        OWWidget.__init__(self, parent, signalManager, "Calibration Plot", 1)
189
190        # inputs
191        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.results, Default)]
192
193        #set default settings
194        self.CalibrationCurveWidth = 3
195        self.ShowDiagonal = TRUE
196        self.ShowRugs = TRUE
197        #load settings
198        self.loadSettings()
199
200        # temp variables
201        self.dres = None
202        self.targetClass = None
203        self.numberOfClasses = 0
204        self.graphs = []
205        self.classifierColor = None
206        self.numberOfClassifiers = 0
207        self.classifiers = []
208        self.selectedClassifiers = []
209
210        # GUI
211        import sip
212        sip.delete(self.mainArea.layout())
213        self.graphsGridLayoutQGL = QGridLayout(self.mainArea)
214        self.mainArea.setLayout(self.graphsGridLayoutQGL)
215
216        ## save each ROC graph in separate file
217        self.graph = None
218        self.connect(self.graphButton, SIGNAL("clicked()"), self.saveToFile)
219
220        ## general tab
221        self.tabs = OWGUI.tabWidget(self.controlArea)
222        self.generalTab = OWGUI.createTabPage(self.tabs, "General")
223        self.settingsTab = OWGUI.createTabPage(self.tabs, "Settings")
224
225        self.splitQS = QSplitter()
226        self.splitQS.setOrientation(Qt.Vertical)
227
228        ## target class
229        self.classCombo = OWGUI.comboBox(self.generalTab, self, 'targetClass', box='Target class', items=[], callback=self.target)
230        OWGUI.separator(self.generalTab)
231
232        ## classifiers selection (classifiersQLB)
233        self.classifiersQVGB = OWGUI.widgetBox(self.generalTab, "Classifiers")
234        self.classifiersQLB = OWGUI.listBox(self.classifiersQVGB, self, "selectedClassifiers", selectionMode = QListWidget.MultiSelection, callback = self.classifiersSelectionChange)
235        self.unselectAllClassifiersQLB = OWGUI.button(self.classifiersQVGB, self, "(Un)select all", callback = self.SUAclassifiersQLB)
236
237        ## settings tab
238        OWGUI.hSlider(self.settingsTab, self, 'CalibrationCurveWidth', box='Calibration Curve Width', minValue=1, maxValue=9, step=1, callback=self.setCalibrationCurveWidth, ticks=1)
239        OWGUI.checkBox(self.settingsTab, self, 'ShowDiagonal', 'Show Diagonal Line', tooltip='', callback=self.setShowDiagonal)
240        OWGUI.checkBox(self.settingsTab, self, 'ShowRugs', 'Show Rugs', tooltip='', callback=self.setShowRugs)
241        self.settingsTab.layout().addStretch(100)
242
243    def sendReport(self):
244        # need to reimport - Qt provides something stupid instead
245        from __builtin__ import hex
246        self.reportSettings("Settings",
247                            [("Classifiers", ", ".join('<font color="#%s">%s</font>' % ("".join(("0"+hex(x)[2:])[-2:] for x in self.classifierColor[cNum].getRgb()[:3]), str(item.text()))
248                                                        for cNum, item in enumerate(self.classifiersQLB.item(i) for i in range(self.classifiersQLB.count()))
249                                                          if item.isSelected())),
250                             ("Target class", self.classCombo.itemText(self.targetClass)
251                                              if self.targetClass is not None else
252                                              "N/A"),
253                            ])
254        if self.targetClass is not None:
255            self.reportRaw("<br/>")
256            self.reportImage(self.graphs[self.targetClass].saveToFileDirect, QSize(400, 400))
257
258
259    def setCalibrationCurveWidth(self):
260        for g in self.graphs:
261            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
262
263    def setShowDiagonal(self):
264        for g in self.graphs:
265            g.setShowDiagonal(self.ShowDiagonal)
266
267    def setShowRugs(self):
268        for g in self.graphs:
269            g.setShowRugs(self.ShowRugs)
270
271    ##
272    def selectUnselectAll(self, qlb):
273        selected = 0
274        for i in range(qlb.count()):
275            if qlb.item(i).isSelected():
276                selected = 1
277                break
278        if selected: qlb.clearSelection()
279        else: qlb.selectAll()
280
281    def SUAclassifiersQLB(self):
282        self.selectUnselectAll(self.classifiersQLB)
283
284    def classifiersSelectionChange(self):
285        list = []
286        for i in range(self.classifiersQLB.count()):
287            if self.classifiersQLB.item(i).isSelected():
288                list.append( 1 )
289            else:
290                list.append( 0 )
291        for g in self.graphs:
292            g.setShowClassifiers(list)
293    ##
294
295    def calcAllClassGraphs(self):
296        cl = 0
297        for g in self.graphs:
298            g.setData(self.classifierColor, self.dres, cl)
299
300            ## user settings
301            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
302            g.setShowDiagonal(self.ShowDiagonal)
303            g.setShowRugs(self.ShowRugs)
304            cl += 1
305
306    def removeGraphs(self):
307        for g in self.graphs:
308            g.removeCurves()
309            g.hide()
310
311    def saveToFile(self):
312        if self.graph:
313            self.graph.saveToFile()
314
315    def target(self):
316        for g in self.graphs:
317            g.hide()
318
319        if (self.targetClass <> None) and (len(self.graphs) > 0):
320            if self.targetClass >= len(self.graphs):
321                self.targetClass = len(self.graphs) - 1
322            if self.targetClass < 0:
323                self.targetClass = 0
324            self.graph = self.graphs[self.targetClass]
325            self.graph.show()
326            self.graphsGridLayoutQGL.addWidget(self.graph, 0, 0)
327        else:
328            self.graph = None
329
330    def results(self, dres):
331        self.closeContext()
332
333        self.targetClass = None
334        self.classifiersQLB.clear()
335        self.removeGraphs()
336        self.classCombo.clear()
337
338        self.dres = dres
339       
340        if dres and dres.test_type != TEST_TYPE_SINGLE:
341            self.warning(0, "Calibration plot is supported only for single-target prediction problems.")
342            return
343        self.warning(0, None)
344
345        self.graphs = []
346        if self.dres <> None:
347            self.numberOfClasses = len(self.dres.classValues)
348            ## one graph for each class
349            for i in range(self.numberOfClasses):
350                graph = singleClassCalibrationPlotGraph(self.mainArea)
351                graph.hide()
352                self.graphs.append(graph)
353                self.classCombo.addItem(self.dres.classValues[i])
354
355            ## classifiersQLB
356            self.classifierColor = []
357            self.numberOfClassifiers = self.dres.numberOfLearners
358            if self.numberOfClassifiers > 1:
359                allCforHSV = self.numberOfClassifiers - 1
360            else:
361                allCforHSV = self.numberOfClassifiers
362            for i in range(self.numberOfClassifiers):
363                newColor = QColor()
364                newColor.setHsv(i*255/allCforHSV, 255, 255)
365                self.classifierColor.append( newColor )
366
367            self.calcAllClassGraphs()
368
369            ## update graphics
370            ## classifiersQLB
371            for i in range(self.numberOfClassifiers):
372                newColor = self.classifierColor[i]
373                self.classifiersQLB.addItem(QListWidgetItem(ColorPixmap(newColor), self.dres.classifierNames[i]))
374            self.classifiersQLB.selectAll()
375        else:
376            self.numberOfClasses = 0
377            self.classifierColor = None
378            self.targetClass = None ## no results, no target
379           
380        if not self.targetClass:
381            self.targetClass = 0
382           
383        self.openContext("", self.dres)
384        self.target()
385
386if __name__ == "__main__":
387    a = QApplication(sys.argv)
388    owdm = OWCalibrationPlot()
389    owdm.show()
390    a.exec_()
Note: See TracBrowser for help on using the repository browser.