source: orange/orange/OrangeWidgets/Evaluate/OWCalibrationPlot.py @ 9599:69c91d52d3e4

Revision 9599:69c91d52d3e4, 14.2 KB checked in by Matija Polajnar <matija.polajnar@…>, 2 years ago (diff)

Multi-label classificaiton widgets. Merged in from Wencan Luo's work with some modifications.

Line 
1"""
2<name>Calibration Plot</name>
3<description>Displays calibration plot based on evaluation of classifiers.</description>
4<contact>Tomaz Curk</contact>
5<icon>icons/CalibrationPlot.png</icon>
6<priority>1030</priority>
7"""
8from OWColorPalette import ColorPixmap
9from OWWidget import *
10from OWGraph import *
11import OWGUI
12
13import orngTest, orngStat
14import statc, math
15from Orange.evaluation.testing import TEST_TYPE_SINGLE
16
17class singleClassCalibrationPlotGraph(OWGraph):
18    def __init__(self, parent = None, name = None, title = ""):
19        OWGraph.__init__(self, parent, name)
20        self.setYRlabels(None)
21        self.enableGridXB(0)
22        self.enableGridYL(0)
23        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
24        self.setAxisMaxMinor(QwtPlot.xBottom, 5)
25        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
26        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
27        self.setAxisScale(QwtPlot.xBottom, -0.0, 1.0, 0)
28        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
29        self.setYLaxisTitle('actual probability')
30        self.setShowYLaxisTitle(1)
31        self.setXaxisTitle('estimated probability')
32        self.setShowXaxisTitle(1)
33        self.setShowMainTitle(1)
34        self.setMainTitle(title)
35        self.dres = None
36        self.numberOfClasses = None
37        self.targetClass = None
38        self.rugHeight = 0.02
39
40        self.removeCurves()
41
42    def setData(self, classifierColor, dres, targetClass):
43        self.classifierColor = classifierColor
44        self.dres = dres
45        self.targetClass = targetClass
46
47        classifiersNum = len(self.dres.classifierNames)
48        self.removeCurves()
49        self.classifierColor = classifierColor
50        self.classifierNames = self.dres.classifierNames
51        self.numberOfClasses = len(self.dres.classValues)
52
53        for cNum in range(classifiersNum):
54            curve = self.addCurve('', pen=QPen(self.classifierColor[cNum], 3))
55            self.classifierCalibrationCKeys.append(curve)
56
57            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
58            curve.attach(self)
59            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
60            curve.setStyle(QwtPlotCurve.UserCurve)
61            self.classifierYesClassRugCKeys.append(curve)
62
63            curve = errorBarQwtPlotCurve('', connectPoints = 0, tickXw = 0.0)
64            curve.attach(self)
65            curve.setSymbol(QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 1), QSize(0,0)))
66            curve.setStyle(QwtPlotCurve.UserCurve)
67            self.classifierNoClassRugCKeys.append(curve)
68
69            self.showClassifiers.append(0)
70
71        ## compute curves for targetClass
72        if (self.dres <> None): ## check that targetClass in range
73            if self.targetClass < 0:
74                self.targetClass = 0
75            if self.targetClass >= self.numberOfClasses:
76                self.targetClass = self.numberOfClasses - 1
77            if self.targetClass < 0:
78                self.targetClass = None ## no classes, no target
79
80        if (self.dres == None) or (self.targetClass == None):
81            self.setMainTitle("")
82            for curve in self.classifierCalibrationCKeys + self.classifierYesClassRugCKeys + self.classifierNoClassRugCKeys:
83                curve.setData([], [])
84            return
85
86        self.setMainTitle(self.dres.classValues[self.targetClass])
87        calibrationCurves = orngStat.computeCalibrationCurve(self.dres, self.targetClass)
88
89        classifier = 0
90        for (curve, yesClassRugPoints, noClassRugPoints) in calibrationCurves:
91            x = [px for (px, py) in curve]
92            y = [py for (px, py) in curve]
93            curve = self.classifierCalibrationCKeys[classifier]
94            curve.setData(x, y)
95
96            x = []
97            y = []
98            for (px, py) in yesClassRugPoints:
99                n = py > 0.0 ##py
100                if n:
101                    py = 1.0
102                    x.append(px)
103                    y.append(py - self.rugHeight*n / 2.0)
104
105                    x.append(px)
106                    y.append(py)
107
108                    x.append(px)
109                    y.append(py - self.rugHeight*n)
110            curve = self.classifierYesClassRugCKeys[classifier]
111            curve.setData(x, y)
112
113            x = []
114            y = []
115            for (px, py) in noClassRugPoints:
116                n = py > 0.0 ##py
117                if n:
118                    py = 0.0
119                    x.append(px)
120                    y.append(py + self.rugHeight*n / 2.0)
121
122                    x.append(px)
123                    y.append(py + self.rugHeight*n)
124
125                    x.append(px)
126                    y.append(py)
127            curve = self.classifierNoClassRugCKeys[classifier]
128            curve.setData(x, y)
129            classifier += 1
130
131        self.updateCurveDisplay()
132
133    def removeCurves(self):
134        OWGraph.clear(self)
135        self.classifierColor = []
136        self.classifierNames = []
137        self.showClassifiers = []
138        self.showDiagonal = 0
139        self.showRugs = 1
140
141        self.classifierCalibrationCKeys = []
142        self.classifierYesClassRugCKeys = []
143        self.classifierNoClassRugCKeys = []
144
145        ## diagonal curve
146        self.diagonalCKey = self.addCurve("", pen = QPen(Qt.black, 1), style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, xData = [0.0, 1.0], yData = [0.0, 1.0])
147
148    def updateCurveDisplay(self):
149        self.diagonalCKey.setVisible(self.showDiagonal)
150
151        for cNum in range(len(self.showClassifiers)):
152            showCNum = (self.showClassifiers[cNum] <> 0)
153            self.classifierCalibrationCKeys[cNum].setVisible(showCNum)
154            b = showCNum and self.showRugs
155            self.classifierYesClassRugCKeys[cNum].setVisible(b)
156            self.classifierNoClassRugCKeys[cNum].setVisible(b)
157        self.updateLayout()
158        self.replot()
159
160    def setCalibrationCurveWidth(self, v):
161        for cNum in range(len(self.showClassifiers)):
162            self.classifierCalibrationCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
163        self.replot()
164
165    def setShowClassifiers(self, list):
166        self.showClassifiers = list
167        self.updateCurveDisplay()
168
169    def setShowDiagonal(self, v):
170        self.showDiagonal = v
171        self.updateCurveDisplay()
172
173    def setShowRugs(self, v):
174        self.showRugs = v
175        self.updateCurveDisplay()
176
177    def sizeHint(self):
178        return QSize(170, 170)
179
180class OWCalibrationPlot(OWWidget):
181    settingsList = ["CalibrationCurveWidth", "ShowDiagonal", "ShowRugs"]
182    contextHandlers = {"": EvaluationResultsContextHandler("", "targetClass", "selectedClassifiers")}
183   
184    def __init__(self,parent=None, signalManager = None):
185        OWWidget.__init__(self, parent, signalManager, "Calibration Plot", 1)
186
187        # inputs
188        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.results, Default)]
189
190        #set default settings
191        self.CalibrationCurveWidth = 3
192        self.ShowDiagonal = TRUE
193        self.ShowRugs = TRUE
194        #load settings
195        self.loadSettings()
196
197        # temp variables
198        self.dres = None
199        self.targetClass = None
200        self.numberOfClasses = 0
201        self.graphs = []
202        self.classifierColor = None
203        self.numberOfClassifiers = 0
204        self.classifiers = []
205        self.selectedClassifiers = []
206
207        # GUI
208        import sip
209        sip.delete(self.mainArea.layout())
210        self.graphsGridLayoutQGL = QGridLayout(self.mainArea)
211        self.mainArea.setLayout(self.graphsGridLayoutQGL)
212
213        ## save each ROC graph in separate file
214        self.graph = None
215        self.connect(self.graphButton, SIGNAL("clicked()"), self.saveToFile)
216
217        ## general tab
218        self.tabs = OWGUI.tabWidget(self.controlArea)
219        self.generalTab = OWGUI.createTabPage(self.tabs, "General")
220        self.settingsTab = OWGUI.createTabPage(self.tabs, "Settings")
221
222        self.splitQS = QSplitter()
223        self.splitQS.setOrientation(Qt.Vertical)
224
225        ## target class
226        self.classCombo = OWGUI.comboBox(self.generalTab, self, 'targetClass', box='Target class', items=[], callback=self.target)
227        OWGUI.separator(self.generalTab)
228
229        ## classifiers selection (classifiersQLB)
230        self.classifiersQVGB = OWGUI.widgetBox(self.generalTab, "Classifiers")
231        self.classifiersQLB = OWGUI.listBox(self.classifiersQVGB, self, "selectedClassifiers", selectionMode = QListWidget.MultiSelection, callback = self.classifiersSelectionChange)
232        self.unselectAllClassifiersQLB = OWGUI.button(self.classifiersQVGB, self, "(Un)select all", callback = self.SUAclassifiersQLB)
233
234        ## settings tab
235        OWGUI.hSlider(self.settingsTab, self, 'CalibrationCurveWidth', box='Calibration Curve Width', minValue=1, maxValue=9, step=1, callback=self.setCalibrationCurveWidth, ticks=1)
236        OWGUI.checkBox(self.settingsTab, self, 'ShowDiagonal', 'Show Diagonal Line', tooltip='', callback=self.setShowDiagonal)
237        OWGUI.checkBox(self.settingsTab, self, 'ShowRugs', 'Show Rugs', tooltip='', callback=self.setShowRugs)
238        self.settingsTab.layout().addStretch(100)
239
240    def sendReport(self):
241        # need to reimport - Qt provides something stupid instead
242        from __builtin__ import hex
243        self.reportSettings("Settings",
244                            [("Classifiers", ", ".join('<font color="#%s">%s</font>' % ("".join(("0"+hex(x)[2:])[-2:] for x in self.classifierColor[cNum].getRgb()[:3]), str(item.text()))
245                                                        for cNum, item in enumerate(self.classifiersQLB.item(i) for i in range(self.classifiersQLB.count()))
246                                                          if item.isSelected())),
247                             ("Target class", self.classCombo.itemText(self.targetClass)
248                                              if self.targetClass is not None else
249                                              "N/A"),
250                            ])
251        if self.targetClass is not None:
252            self.reportRaw("<br/>")
253            self.reportImage(self.graphs[self.targetClass].saveToFileDirect, QSize(400, 400))
254
255
256    def setCalibrationCurveWidth(self):
257        for g in self.graphs:
258            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
259
260    def setShowDiagonal(self):
261        for g in self.graphs:
262            g.setShowDiagonal(self.ShowDiagonal)
263
264    def setShowRugs(self):
265        for g in self.graphs:
266            g.setShowRugs(self.ShowRugs)
267
268    ##
269    def selectUnselectAll(self, qlb):
270        selected = 0
271        for i in range(qlb.count()):
272            if qlb.item(i).isSelected():
273                selected = 1
274                break
275        if selected: qlb.clearSelection()
276        else: qlb.selectAll()
277
278    def SUAclassifiersQLB(self):
279        self.selectUnselectAll(self.classifiersQLB)
280
281    def classifiersSelectionChange(self):
282        list = []
283        for i in range(self.classifiersQLB.count()):
284            if self.classifiersQLB.item(i).isSelected():
285                list.append( 1 )
286            else:
287                list.append( 0 )
288        for g in self.graphs:
289            g.setShowClassifiers(list)
290    ##
291
292    def calcAllClassGraphs(self):
293        cl = 0
294        for g in self.graphs:
295            g.setData(self.classifierColor, self.dres, cl)
296
297            ## user settings
298            g.setCalibrationCurveWidth(self.CalibrationCurveWidth)
299            g.setShowDiagonal(self.ShowDiagonal)
300            g.setShowRugs(self.ShowRugs)
301            cl += 1
302
303    def removeGraphs(self):
304        for g in self.graphs:
305            g.removeCurves()
306            g.hide()
307
308    def saveToFile(self):
309        if self.graph:
310            self.graph.saveToFile()
311
312    def target(self):
313        for g in self.graphs:
314            g.hide()
315
316        if (self.targetClass <> None) and (len(self.graphs) > 0):
317            if self.targetClass >= len(self.graphs):
318                self.targetClass = len(self.graphs) - 1
319            if self.targetClass < 0:
320                self.targetClass = 0
321            self.graph = self.graphs[self.targetClass]
322            self.graph.show()
323            self.graphsGridLayoutQGL.addWidget(self.graph, 0, 0)
324        else:
325            self.graph = None
326
327    def results(self, dres):
328        self.closeContext()
329
330        self.targetClass = None
331        self.classifiersQLB.clear()
332        self.removeGraphs()
333        self.classCombo.clear()
334
335        self.dres = dres
336       
337        if dres and dres.test_type != TEST_TYPE_SINGLE:
338            self.warning(0, "Calibration plot is supported only for single-target prediction problems.")
339            return
340        self.warning(0, None)
341
342        self.graphs = []
343        if self.dres <> None:
344            self.numberOfClasses = len(self.dres.classValues)
345            ## one graph for each class
346            for i in range(self.numberOfClasses):
347                graph = singleClassCalibrationPlotGraph(self.mainArea)
348                graph.hide()
349                self.graphs.append(graph)
350                self.classCombo.addItem(self.dres.classValues[i])
351
352            ## classifiersQLB
353            self.classifierColor = []
354            self.numberOfClassifiers = self.dres.numberOfLearners
355            if self.numberOfClassifiers > 1:
356                allCforHSV = self.numberOfClassifiers - 1
357            else:
358                allCforHSV = self.numberOfClassifiers
359            for i in range(self.numberOfClassifiers):
360                newColor = QColor()
361                newColor.setHsv(i*255/allCforHSV, 255, 255)
362                self.classifierColor.append( newColor )
363
364            self.calcAllClassGraphs()
365
366            ## update graphics
367            ## classifiersQLB
368            for i in range(self.numberOfClassifiers):
369                newColor = self.classifierColor[i]
370                self.classifiersQLB.addItem(QListWidgetItem(ColorPixmap(newColor), self.dres.classifierNames[i]))
371            self.classifiersQLB.selectAll()
372        else:
373            self.numberOfClasses = 0
374            self.classifierColor = None
375            self.targetClass = None ## no results, no target
376           
377        if not self.targetClass:
378            self.targetClass = 0
379           
380        self.openContext("", self.dres)
381        self.target()
382
383if __name__ == "__main__":
384    a = QApplication(sys.argv)
385    owdm = OWCalibrationPlot()
386    owdm.show()
387    a.exec_()
Note: See TracBrowser for help on using the repository browser.