# source:orange/Orange/OrangeWidgets/Evaluate/OWROC.py@10752:a0a6ef1ab0b8

Revision 10752:a0a6ef1ab0b8, 43.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Making sure testing results are for a classification problem in Callibration Plot, Confusion Matrix, Lift Curve and ROC.

Line
1"""
2<name>ROC Analysis</name>
3<description>Displays Receiver Operating Characteristics curve based on evaluation of classifiers.</description>
4<contact>Tomaz Curk</contact>
5<icon>icons/ROCAnalysis.png</icon>
6<priority>1010</priority>
7"""
8from OWColorPalette import ColorPixmap
9from OWWidget import *
10from OWGraph import *
11import OWGUI
12import orngStat, orngTest
13import statc, math
14import warnings
15from Orange.evaluation.testing import TEST_TYPE_SINGLE
16
17def TCconvexHull(curves):
18    ## merge curves into one
19    mergedCurve = []
20    for c in curves:
21        mergedCurve.extend(c)
22    mergedCurve.sort() ## increasing by fp, tp
23
24    if len(mergedCurve) == 0: return []
25
26    hull = []
27    (prevX, maxY, fscore) = (mergedCurve[0] + (0.0,))[:3]
28    prevPfscore = [fscore]
29    px = prevX
30    for p in mergedCurve[1:]:
31        (px, py, fscore) = (p + (0.0,))[:3]
32        if (px == prevX):
33            if py > maxY:
34                prevPfscore = [fscore]
35                maxY = py
36            elif py == maxY:
37                prevPfscore.append(fscore)
38        elif (px > prevX):
39            hull = orngStat.ROCaddPoint((prevX, maxY, prevPfscore), hull, keepConcavities=0)
40            prevX = px
41            maxY = py
42            prevPfscore = [fscore]
43    hull = orngStat.ROCaddPoint((prevX, maxY, prevPfscore), hull, keepConcavities=0)
44
45    return hull
46
47class singleClassROCgraph(OWGraph):
48    def __init__(self, parent = None, name = None, title = ""):
49        OWGraph.__init__(self, parent, name)
50
51        self.setYRlabels(None)
52        self.enableGridXB(0)
53        self.enableGridYL(0)
54        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
55        self.setAxisMaxMinor(QwtPlot.xBottom, 5)
56        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
57        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
58        self.setAxisScale(QwtPlot.xBottom, -0.0, 1.0, 0)
59        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
60        self.setShowXaxisTitle(1)
61        self.setXaxisTitle("FP Rate (1-Specificity)")
62        self.setShowYLaxisTitle(1)
63        self.setYLaxisTitle("TP Rate (Sensitivity)")
64        self.setMainTitle(title)
65        self.setShowMainTitle(1)
66        self.targetClass = 0
67        self.averagingMethod = None
68        self.splitByIterations = None
69        self.VTAsamples = 10 ## vertical threshold averaging, number of samples
70        self.FPcost = 500.0
71        self.FNcost = 500.0
72        self.pvalue = 400.0 ##0.400
73
74        self.performanceLineSymbol = QwtSymbol(QwtSymbol.Ellipse, QBrush(Qt.color0), QPen(Qt.black), QSize(7,7))
75        self.defaultLineSymbol = QwtSymbol(QwtSymbol.Ellipse, QBrush(Qt.black), QPen(Qt.black), QSize(8,8))
76        self.convexHullPen = QPen(Qt.yellow, 3)
77
78        self.removeMarkers()
79        self.performanceMarkerKeys = []
80
81        self.removeCurves()
82
83    def computeCurve(self, res, classIndex=-1, keepConcavities=1):
84        return orngStat.TCcomputeROC(res, classIndex, keepConcavities)
85
86    def setNumberOfClassifiersIterationsAndClassifierColors(self, classifierNames, iterationsNum, classifierColor):
87        classifiersNum = len(classifierNames)
88        self.removeCurves()
89        self.classifierColor = classifierColor
90        self.classifierNames = classifierNames
91
92        for cNum in range(classifiersNum):
93            self.classifierIterationCKeys.append([])
94            self.classifierIterationConvexCKeys.append([])
95            self.classifierIterationROCdata.append([])
96            for iNum in range(iterationsNum):
97                curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 3), style=QwtPlotCurve.Lines)
98                self.classifierIterationCKeys[cNum].append(curve)
99                curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 1), style=QwtPlotCurve.Lines)
100                self.classifierIterationConvexCKeys[cNum].append(curve)
101                self.classifierIterationROCdata[cNum].append(None)
102
103            self.showClassifiers.append(0)
104            self.showIterations.append(0)
105
106            ## 'merge' average curve keys
107            curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 2), style=QwtPlotCurve.Lines)
108            self.mergedCKeys.append(curve)
109
110            curve = self.addCurve('', pen = QPen(Qt.black, 5), style=QwtPlotCurve.Lines)
111            curve.setSymbol(self.defaultLineSymbol)
112
113            self.mergedCThresholdKeys.append(curve)
114            self.mergedCThresholdMarkers.append([])
115
116            curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 1), style=QwtPlotCurve.Lines)
117##            self.mergedConvexCKeys.append(QPen(Qt.black, 5))
118            self.mergedConvexCKeys.append(curve)
119
120            newSymbol = QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 2), QSize(0,0))
121            ## 'vertical' average curve keys
122            curve = errorBarQwtPlotCurve('', connectPoints = 1, tickXw = 1.0/self.VTAsamples/5.0)
123            curve.attach(self)
124            curve.setSymbol(newSymbol)
125            curve.setStyle(QwtPlotCurve.UserCurve)
126            self.verticalCKeys.append(curve)
127
128            ## 'threshold' average curve keys
129            curve = errorBarQwtPlotCurve('', connectPoints = 1, tickXw = 1.0/self.VTAsamples/5.0, tickYw = 1.0/self.VTAsamples/5.0, showVerticalErrorBar = 1, showHorizontalErrorBar = 1)
130            curve.attach(self)
131            curve.setSymbol(newSymbol)
132            curve.setStyle(QwtPlotCurve.UserCurve)
133            self.thresholdCKeys.append(curve)
134
135        ## iso-performance line on top of all curves
136        self.performanceLineCKey = self.addCurve('', pen = QPen(Qt.black, 2), style=QwtPlotCurve.Lines)
137        self.performanceLineCKey.setSymbol(self.performanceLineSymbol)
138
139    def removeCurves(self):
140        self.clear()
141        self.classifierColor = []
142        self.classifierNames = []
143        self.classifierIterationROCdata = []
144        self.showClassifiers = []
145        self.showIterations = []
146        self.showConvexCurves = 0
147        self.showConvexHull = 0
148        self.showPerformanceLine = 0
149        self.showDefaultThresholdPoint = 0
150        self.showDiagonal = 0
151
152        ## 'merge' average curve keys
153        self.mergedCKeys = []
154        self.mergedCThresholdKeys = []
155        self.mergedCThresholdMarkers = []
156        self.mergedConvexCKeys = []
157        ## 'vertical' average curve keys
158        self.verticalCKeys = []
159        ## 'threshold' average curve keys
160        self.thresholdCKeys = []
161        ## 'None' average curve keys
162        self.classifierIterationCKeys = []
163        self.classifierIterationConvexCKeys = []
164
165        ## convex hull calculation
166        self.mergedConvexHullData = []
167        self.verticalConvexHullData = []
168        self.thresholdConvexHullData = []
169        self.classifierConvexHullData = []
170        self.hullCurveDataForPerfLine = [] ## for performance analysis
171
172        ## diagonal curve
173        self.diagonalCKey = self.addCurve('', pen = QPen(Qt.black, 1), symbol = QwtSymbol.NoSymbol, xData = [0.0, 1.0], yData = [0.0, 1.0], style=QwtPlotCurve.Lines)
174
175        ## convex hull curve keys
176        self.mergedConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
177        self.verticalConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
178        self.thresholdConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
179        self.classifierConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
180
181        ## iso-performance line
182        self.performanceLineCKey = None
183
184    def setIterationCurves(self, iteration, curves):
185        classifier = 0
186        for c in curves:
187            x = [px for (px, py, pf) in c]
188            y = [py for (px, py, pf) in c]
189            curve = self.classifierIterationCKeys[classifier][iteration]
190            curve.setData(x, y)
191            self.classifierIterationROCdata[classifier][iteration] = c
192            classifier += 1
193
194    def setIterationConvexCurves(self, iteration, curves):
195        classifier = 0
196        for c in curves:
197            x = [px for (px, py, pf) in c]
198            y = [py for (px, py, pf) in c]
199            curve = self.classifierIterationConvexCKeys[classifier][iteration]
200            curve.setData(x, y)
201            classifier += 1
202
203    def setTestSetData(self, splitByIterations, targetClass):
204        self.splitByIterations = splitByIterations
205        ## generate the "base" unmodified ROC curves
206        self.targetClass = targetClass
207        iteration = 0
208        for isplit in splitByIterations:
209            # unmodified ROC curve
210            curves = self.computeCurve(isplit, self.targetClass, 1)
211            self.setIterationCurves(iteration, curves)
212
213            # convex ROC curve
214            curves = self.computeCurve(isplit, self.targetClass, 0)
215            self.setIterationConvexCurves(iteration, curves)
216            iteration += 1
217
218    def updateCurveDisplay(self):
219        self.diagonalCKey.setVisible(self.showDiagonal)
220
221        showSomething = 0
222        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
223            showCNum = (self.showClassifiers[cNum] <> 0)
224
225            ## 'merge' averaging
226            b = (self.averagingMethod == 'merge') and showCNum
227            showSomething = showSomething or b
228            if self.mergedCKeys[cNum]:
229                self.mergedCKeys[cNum].setVisible(b)
230
231            b2 = b and self.showDefaultThresholdPoint
232            if self.mergedCThresholdKeys[cNum] <> None:
233                self.mergedCThresholdKeys[cNum].setVisible(b2)
234
235            for marker in self.mergedCThresholdMarkers[cNum]:
236                marker.setVisible(b2)
237
238            b = b and self.showConvexCurves
239            if self.mergedConvexCKeys[cNum] <> None:
240                self.mergedConvexCKeys[cNum].setVisible(b)
241
242            ## 'vertical' averaging
243            b = (self.averagingMethod == 'vertical') and showCNum
244            showSomething = showSomething or b
245            self.verticalCKeys[cNum].setVisible(b)
246
247            ## 'threshold' averaging
248            b = (self.averagingMethod == 'threshold') and showCNum
249            showSomething = showSomething or b
250            self.thresholdCKeys[cNum].setVisible(b)
251
252            ## 'None' averaging
253            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum], self.classifierIterationConvexCKeys[cNum]))):
254                b = (self.averagingMethod == None) and showCNum and (self.showIterations[iNum] <> 0)
255                showSomething = showSomething or b
256                self.classifierIterationCKeys[cNum][iNum].setVisible(b)
257                b = b and self.showConvexCurves
258                self.classifierIterationConvexCKeys[cNum][iNum].setVisible(b)
259
260        chb = (showSomething) and (self.averagingMethod == None) and self.showConvexHull
261        if self.classifierConvexHullCKey:
262            self.classifierConvexHullCKey.setVisible(chb)
263
264        chb = (showSomething) and (self.averagingMethod == 'merge') and self.showConvexHull
265        if self.mergedConvexHullCKey:
266            self.mergedConvexHullCKey.setVisible(chb)
267
268        chb = (showSomething) and (self.averagingMethod == 'vertical') and self.showConvexHull
269        if self.verticalConvexHullCKey:
270            self.verticalConvexHullCKey.setVisible(chb)
271
272        chb = (showSomething) and (self.averagingMethod == 'threshold') and self.showConvexHull
273        if self.thresholdConvexHullCKey:
274            self.thresholdConvexHullCKey.setVisible(chb)
275
276        ## performance line
277        b = (self.averagingMethod == 'merge') and self.showPerformanceLine
278        for marker in self.performanceMarkerKeys:
279            marker.setVisible(b)
280
281        if self.performanceLineCKey:
282##            curve.setVisible(b)
283            self.performanceLineCKey.setVisible(b)
284
285##        self.updateLayout()
286##        self.update()
287        self.replot()
288
289    def setShowConvexCurves(self, b):
290        self.showConvexCurves = b
291        self.updateCurveDisplay()
292
293    def setShowConvexHull(self, b):
294        self.showConvexHull = b
295        self.updateCurveDisplay()
296
297    def setShowPerformanceLine(self, b):
298        self.showPerformanceLine = b
299        self.updateCurveDisplay()
300
301    def setShowDefaultThresholdPoint(self, b):
302        self.showDefaultThresholdPoint = b
303        self.updateCurveDisplay()
304
305    def setShowClassifiers(self, list):
306        self.showClassifiers = list
307        self.calcConvexHulls()
308        self.calcUpdatePerformanceLine() ## new data for performance line
309        self.updateCurveDisplay()
310
311    def setShowIterations(self, list):
312        self.showIterations = list
313        self.calcAverageCurves()
314        self.calcConvexHulls()
315        self.calcUpdatePerformanceLine() ## new data for performance line
316        self.updateCurveDisplay()
317
318    ## calculate the average curve for the selected test sets (with all the averaging methods)
319    def calcAverageCurves(self):
320        ##
321        ## self.averagingMethod == 'merge':
323        for i, (isplit, show) in enumerate(zip(self.splitByIterations, self.showIterations)):
324            if show:
325                for te in isplit.results:
326                    mergedIterations.results.append( te )
327
328        self.mergedConvexHullData = []
329        if len(mergedIterations.results) > 0:
330            curves = self.computeCurve(mergedIterations, self.targetClass, 1)
331            convexCurves = self.computeCurve(mergedIterations, self.targetClass, 0)
332            classifier = 0
333            for c in curves:
334                x = [px for (px, py, pf) in c]
335                y = [py for (px, py, pf) in c]
336                self.mergedCKeys[classifier].setData(x, y)
337
338                # points of defualt threshold classifiers
339                defPoint = [(abs(pf-0.5), pf, px, py) for (px, py, pf) in c]
340                defPoints = []
341                if len(defPoint) > 0:
342                    defPoint.sort()
343                    defPoints = [(px, py, pf) for (d, pf, px, py) in defPoint if d == defPoint[0][0]]
344                else:
345                    defPoints = []
346                defX = [px for (px, py, pf) in defPoints]
347                defY = [py for (px, py, pf) in defPoints]
348                self.mergedCThresholdKeys[classifier].setData(defX, defY)
349
350                for marker in self.mergedCThresholdMarkers[classifier]:
351                    marker.detach()
352                self.mergedCThresholdMarkers[classifier] = []
353                for (dx, dy, pf) in defPoints:
354                    dx = max(min(0.95, dx + 0.01), 0.01)
355                    dy = min(max(0.01, dy - 0.02), 0.95)
356                    marker = self.addMarker('%3.2g' % (pf), dx, dy, alignment = Qt.AlignRight)
357                    self.mergedCThresholdMarkers[classifier].append(marker)
358                classifier += 1
359            classifier = 0
360            for c in convexCurves:
361                self.mergedConvexHullData.append(c) ## put all points of all curves into one big array
362                x = [px for (px, py, pf) in c]
363                y = [py for (px, py, pf) in c]
364                curve = self.mergedConvexCKeys[classifier]
365                curve.setData(x, y)
366                classifier += 1
367        else:
368            for c in range(len(self.mergedCKeys)):
369                self.mergedCKeys[c].setData([], [])
370                self.mergedCThresholdKeys[c].setData([], [])
371                for marker in self.mergedCThresholdMarkers[c]:
372                    marker.detach()
373                self.mergedCThresholdMarkers[c] = []
374                self.mergedConvexCKeys[c].setData([], [])
375
376        ## prepare a common input structure for vertical and threshold averaging
377        ROCS = []
378        somethingToShow = 0
379        for c in self.classifierIterationROCdata:
380            ROCS.append([])
381            i = 0
382            for s in self.showIterations:
383                if s:
384                    somethingToShow = 1
385                    ROCS[-1].append(c[i])
386                i += 1
387
388        ## remove curve
389        self.verticalConvexHullData = []
390        self.thresholdConvexHullData = []
391
392        if somethingToShow == 0:
393            for curve in self.verticalCKeys:
394                curve.setData([], [])
395
396            for curve in self.thresholdCKeys:
397                curve.setData([], [])
398            return
399
400        ##
401        ## self.averagingMethod == 'vertical':
402        ## calculated from the self.classifierIterationROCdata data
403        try:
404            (averageCurves, verticalErrorBarValues) = orngStat.TCverticalAverageROC(ROCS, self.VTAsamples)
405        except (ValueError, SystemError), er:
406            print >> sys.stderr, "Failed to compute vertical average ROC curve. " + er.message
407            averageCurves, verticalErrorBarValues = [], []
408            pass
409
410        classifier = 0
411        for c in averageCurves:
412            self.verticalConvexHullData.append(c)
413            xs = []
414            mps = []
415            for pcn in range(len(c)):
416                (px, py) = c[pcn]
417                ## for the error bar plot
418                xs.append(px)
419                mps.append(py + 0.0)
420
421                xs.append(px)
422                mps.append(py + verticalErrorBarValues[classifier][pcn])
423
424                xs.append(px)
425                mps.append(py - verticalErrorBarValues[classifier][pcn])
426
427            ckey =  self.verticalCKeys[classifier]
428            ckey.setData(xs, mps)
429            classifier += 1
430
431        ##
432        ## self.averagingMethod == 'threshold':
433        ## calculated from the self.classifierIterationROCdata data
434        (averageCurves, verticalErrorBarValues, horizontalErrorBarValues) = orngStat.TCthresholdlAverageROC(ROCS, self.VTAsamples)
435        classifier = 0
436        for c in averageCurves:
437            self.thresholdConvexHullData.append(c)
438            xs = []
439            mps = []
440            for pcn in range(len(c)):
441                (px, py) = c[pcn]
442                ## for the error bar plot
443                xs.append(px + 0.0)
444                mps.append(py + 0.0)
445
446                xs.append(px - horizontalErrorBarValues[classifier][pcn])
447                mps.append(py + verticalErrorBarValues[classifier][pcn])
448
449                xs.append(px + horizontalErrorBarValues[classifier][pcn])
450                mps.append(py - verticalErrorBarValues[classifier][pcn])
451
452            ckey = self.thresholdCKeys[classifier]
453            ckey.setData(xs, mps)
454            classifier += 1
455
456        ## self.averagingMethod == 'None'
458
459    def calcConvexHulls(self):
460        ## self.classifierConvexHullCKey = -1
461        hullData = []
462        for cNum in range(len(self.showClassifiers)):
463            for iNum in range(len(self.showIterations)):
464                if (self.showClassifiers[cNum] <> 0) and (self.showIterations[iNum] <> 0):
465                    hullData.append(self.classifierIterationROCdata[cNum][iNum])
466
467        convexHullCurve = TCconvexHull(hullData)
468        x = [px for (px, py, pf) in convexHullCurve]
469        y = [py for (px, py, pf) in convexHullCurve]
470        self.classifierConvexHullCKey.setData(x, y)
471
472        ## self.mergedConvexHullCKey = -1
473        hullData = []
474        for cNum in range(len(self.mergedConvexHullData)):
475            if (self.showClassifiers[cNum] <> 0):
476                ncurve = []
477                for (px, py, pfscore) in self.mergedConvexHullData[cNum]:
478                    ncurve.append( (px, py, (cNum, pfscore)) )
479                hullData.append(ncurve)
480
481        self.hullCurveDataForPerfLine = TCconvexHull(hullData) # keep data about curve for performance line drawing
482        x = [px for (px, py, pf) in self.hullCurveDataForPerfLine]
483        y = [py for (px, py, pf) in self.hullCurveDataForPerfLine]
484        self.mergedConvexHullCKey.setData(x, y)
485
486        ## self.verticalConvexHullCKey = -1
487        hullData = []
488        for cNum in range(len(self.verticalConvexHullData)):
489            if (self.showClassifiers[cNum] <> 0):
490                hullData.append(self.verticalConvexHullData[cNum])
491
492        convexHullCurve = TCconvexHull(hullData)
493        x = [px for (px, py, pf) in convexHullCurve]
494        y = [py for (px, py, pf) in convexHullCurve]
495        self.verticalConvexHullCKey.setData(x, y)
496
497        ## self.thresholdConvexHullCKey = -1
498        hullData = []
499        for cNum in range(len(self.thresholdConvexHullData)):
500            if (self.showClassifiers[cNum] <> 0):
501                hullData.append(self.thresholdConvexHullData[cNum])
502
503        convexHullCurve = TCconvexHull(hullData)
504        x = [px for (px, py, pf) in convexHullCurve]
505        y = [py for (px, py, pf) in convexHullCurve]
506        self.thresholdConvexHullCKey.setData(x, y)
507
508    def setAveragingMethod(self, m):
509        self.averagingMethod = m
510        self.updateCurveDisplay()
511
512    ## performance line
513    def calcUpdatePerformanceLine(self):
514        closestpoints = orngStat.TCbestThresholdsOnROCcurve(self.FPcost, self.FNcost, self.pvalue, self.hullCurveDataForPerfLine)
515        m = (self.FPcost*(1.0 - self.pvalue)) / (self.FNcost*self.pvalue)
516
517        ## now draw the closest line to the curve
518        b = (self.averagingMethod == 'merge') and self.showPerformanceLine
519        lpx = []
520        lpy = []
521        first = 1
522        ## remove old markers
523        for marker in self.performanceMarkerKeys:
524            try:
525                marker.detach()
526            except RuntimeError: ## RuntimeError: underlying C/C++ object has been deleted
527                pass
528        self.performanceMarkerKeys = []
529        for (x, y, fscorelist) in closestpoints:
530            if first:
531                first = 0
532                lpx.append(x - 2.0)
533                lpy.append(y - 2.0*m)
534            lpx.append(x)
535            lpy.append(y)
536            px = x
537            py = y
538            for (cNum, threshold) in fscorelist:
539                s = "%1.3f %s" % (threshold, self.classifierNames[cNum])
540                px = max(min(0.95, px + 0.01), 0.01)
541                py = min(max(0.01, py - 0.02), 0.95)
542                marker = self.addMarker(s, px, py, alignment = Qt.AlignRight)
543                marker.setVisible(b)
544                self.performanceMarkerKeys.append(marker)
545        if len(closestpoints) > 0:
546            lpx.append(x + 2.0)
547            lpy.append(y + 2.0*m)
548
549        if self.performanceLineCKey:
550            self.performanceLineCKey.setData(lpx, lpy)
551            self.performanceLineCKey.setVisible(b)
552        self.replot()
553#        self.update()
554
555    def costChanged(self, FPcost, FNcost):
556        self.FPcost = float(FPcost)
557        self.FNcost = float(FNcost)
558        self.calcUpdatePerformanceLine()
559
560    def pChanged(self, pvalue):
561        self.pvalue = float(pvalue)
562        self.calcUpdatePerformanceLine()
563
564    def setPointWidth(self, v):
565        self.performanceLineSymbol.setSize(v, v)
566        if self.performanceLineCKey:
567            self.performanceLineCKey.setSymbol(self.performanceLineSymbol)
568
569        def setW(curve):
570            sym = curve.symbol() #.setPen(QPen(self.classifierColor[cNum], v))
571            sym.setSize(v + 1, v + 1)
572            if QWT_VERSION_STR >= "5.2": # in Qwt 5.1.* curve.setSymbol results in a crash
573                curve.setSymbol(sym)
574
575        for item in self.itemList():
576            setW(item)
577
578#        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
579#            setW(self.mergedCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
580#            setW(self.verticalCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
581#            setW(self.thresholdCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
582#            setW(self.mergedCThresholdKeys[cNum])
583#            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum]))):
584#                setW(self.classifierIterationCKeys[cNum][iNum]) #.setPen(QPen(self.classifierColor[cNum], v))
585        self.replot()
586#        self.update()
587
588    def setCurveWidth(self, v):
589        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
590            self.mergedCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
591            self.verticalCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
592            self.thresholdCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
593            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum]))):
594                self.classifierIterationCKeys[cNum][iNum].setPen(QPen(self.classifierColor[cNum], v))
595        self.replot()
596#        self.update()
597
598    def setConvexCurveWidth(self, v):
599        for cNum in range(len(zip(self.showClassifiers, self.mergedConvexCKeys))):
600            self.mergedConvexCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
601            for iNum in range(len(zip(self.showIterations, self.classifierIterationConvexCKeys[cNum]))):
602                self.classifierIterationConvexCKeys[cNum][iNum].setPen(QPen(self.classifierColor[cNum], v))
603        self.replot()
604#        self.update()
605
606    def setShowDiagonal(self, v):
607        self.showDiagonal = v
608        self.updateCurveDisplay()
609
610    def setConvexHullCurveWidth(self, v):
611        self.convexHullPen.setWidth(v)
612        self.mergedConvexHullCKey.setPen(self.convexHullPen)
613        self.verticalConvexHullCKey.setPen(self.convexHullPen)
614        self.thresholdConvexHullCKey.setPen(self.convexHullPen)
615        self.classifierConvexHullCKey.setPen(self.convexHullPen)
616        self.replot()
617#        self.update()
618
619    def setHullColor(self, c):
620        self.convexHullPen.setColor(c)
621        self.mergedConvexHullCKey.setPen(self.convexHullPen)
622        self.verticalConvexHullCKey.setPen(self.convexHullPen)
623        self.thresholdConvexHullCKey.setPen(self.convexHullPen)
624        self.classifierConvexHullCKey.setPen(self.convexHullPen)
625        self.replot()
626#        self.update()
627
628    def sizeHint(self):
629        return QSize(100, 100)
630
631class OWROC(OWWidget):
632    settingsList = ["PointWidth", "CurveWidth", "ConvexCurveWidth", "ShowDiagonal",
633                    "ConvexHullCurveWidth", "HullColor", "AveragingMethodIndex",
634                    "ShowConvexHull", "ShowConvexCurves", "EnablePerformance", "DefaultThresholdPoint"]
635    contextHandlers = {"": EvaluationResultsContextHandler("", "targetClass", "selectedClassifiers")}
636
637    def __init__(self,parent=None, signalManager = None):
638        OWWidget.__init__(self, parent, signalManager, "ROC Analysis", 1)
639
640        # inputs
641        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.test_results, Default)]
642
643        # default settings
644        self.PointWidth = 7
645        self.CurveWidth = 3
646        self.ConvexCurveWidth = 1
647        self.ShowDiagonal = TRUE
648        self.ConvexHullCurveWidth = 3
649        self.HullColor = str(QColor(Qt.yellow).name())
650        self.AveragingMethodIndex = 0 ##'merge'
651        self.ShowConvexHull = TRUE
652        self.ShowConvexCurves = FALSE
653        self.EnablePerformance = TRUE
654        self.DefaultThresholdPoint = TRUE
655
658
659        # temp variables
660        self.dres = None
661        self.classifierColor = None
662        self.numberOfClasses  = 0
663        self.targetClass = None
664        self.numberOfClassifiers = 0
665        self.numberOfIterations = 0
666        self.graphs = []
667        self.maxp = 1000
668        self.defaultPerfLinePValues = []
669        self.classifiers = []
670        self.selectedClassifiers = []
671
672        # performance analysis (temporary values
673        self.FPcost = 500.0
674        self.FNcost = 500.0
675        self.pvalue = 50.0 ##0.400
676
677        # list of values (remember for each class)
678        self.FPcostList = []
679        self.FNcostList = []
680        self.pvalueList = []
681
682        self.AveragingMethodNames = ['merge', 'vertical', 'threshold', None]
683        self.AveragingMethod = self.AveragingMethodNames[min(3, self.AveragingMethodIndex)]
684
685        # GUI
686        import sip
687        sip.delete(self.mainArea.layout())
688        self.graphsGridLayoutQGL = QGridLayout(self.mainArea)
689        self.mainArea.setLayout(self.graphsGridLayoutQGL)
690        # save each ROC graph in separate file
691        self.connect(self.graphButton, SIGNAL("clicked()"), self.saveToFile)
692
693        ## general tab
694        self.tabs = OWGUI.tabWidget(self.controlArea)
695        self.generalTab = OWGUI.createTabPage(self.tabs, "General")
696
697        ## target class
698        self.classCombo = OWGUI.comboBox(self.generalTab, self, 'targetClass', box='Target class', items=[], callback=self.target)
699        #self.classCombo.setMaximumSize(150, 20)
700
701        ## classifiers selection (classifiersQLB)
702        self.classifiersQVGB = OWGUI.widgetBox(self.generalTab, "Classifiers")
703        self.classifiersQLB = OWGUI.listBox(self.classifiersQVGB, self, "selectedClassifiers", selectionMode = QListWidget.MultiSelection, callback = self.classifiersSelectionChange)
704        self.unselectAllClassifiersQLB = OWGUI.button(self.classifiersQVGB, self, "(Un)select All", callback = self.SUAclassifiersQLB)
705
706        # show convex ROC curves and show ROC convex hull
707        self.convexCurvesQCB = OWGUI.checkBox(self.generalTab, self, 'ShowConvexCurves', 'Show convex ROC curves', tooltip='', callback=self.setShowConvexCurves)
708        OWGUI.checkBox(self.generalTab, self, 'ShowConvexHull', 'Show ROC convex hull', tooltip='', callback=self.setShowConvexHull)
709
710
711        # performance analysis
712        self.performanceTab = OWGUI.createTabPage(self.tabs, "Analysis")
713        self.performanceTabCosts = OWGUI.widgetBox(self.performanceTab, box = 1)
714        OWGUI.checkBox(self.performanceTabCosts, self, 'EnablePerformance', 'Show performance line', tooltip='', callback=self.setShowPerformanceAnalysis)
715        OWGUI.checkBox(self.performanceTabCosts, self, 'DefaultThresholdPoint', 'Default threshold (0.5) point', tooltip='', callback=self.setShowDefaultThresholdPoint)
716
717        ## FP and FN cost ranges
718        mincost = 1; maxcost = 1000; stepcost = 5;
719        self.maxpsum = 100; self.minp = 1; self.maxp = self.maxpsum - self.minp ## need it also in self.pvaluesUpdated
720        stepp = 1.0
721
722        OWGUI.hSlider(self.performanceTabCosts, self, 'FPcost', box='FP Cost', minValue=mincost, maxValue=maxcost, step=stepcost, callback=self.costsChanged, ticks=50)
723        OWGUI.hSlider(self.performanceTabCosts, self, 'FNcost', box='FN Cost', minValue=mincost, maxValue=maxcost, step=stepcost, callback=self.costsChanged, ticks=50)
724
725        ptc = OWGUI.widgetBox(self.performanceTabCosts, "Prior target class probability [%]")
726        OWGUI.hSlider(ptc, self, 'pvalue', minValue=self.minp, maxValue=self.maxp, step=stepp, callback=self.pvaluesUpdated, ticks=5, labelFormat="%2.1f")
727        OWGUI.button(ptc, self, 'Compute from data', self.setDefaultPValues) ## reset p values to default
728
729        ## test set selection (testSetsQLB)
730        self.testSetsQVGB = OWGUI.widgetBox(self.performanceTab, "Test sets")
731        self.testSetsQLB = OWGUI.listBox(self.testSetsQVGB, self, selectionMode = QListWidget.MultiSelection, callback = self.testSetsSelectionChange)
732        self.unselectAllTestSetsQLB = OWGUI.button(self.testSetsQVGB, self, "(Un)select All", callback = self.SUAtestSetsQLB)
733
734        # settings tab
735        self.settingsTab = OWGUI.createTabPage(self.tabs, "Settings")
736        OWGUI.radioButtonsInBox(self.settingsTab, self, 'AveragingMethodIndex', ['Merge (expected ROC perf.)', 'Vertical', 'Threshold', 'None'], box='Averaging ROC curves', callback=self.selectAveragingMethod)
737        OWGUI.hSlider(self.settingsTab, self, 'PointWidth', box='Point width', minValue=0, maxValue=9, step=1, callback=self.setPointWidth, ticks=1)
738        OWGUI.hSlider(self.settingsTab, self, 'CurveWidth', box='ROC curve width', minValue=1, maxValue=5, step=1, callback=self.setCurveWidth, ticks=1)
739        OWGUI.hSlider(self.settingsTab, self, 'ConvexCurveWidth', box='ROC convex curve width', minValue=1, maxValue=5, step=1, callback=self.setConvexCurveWidth, ticks=1)
740        OWGUI.hSlider(self.settingsTab, self, 'ConvexHullCurveWidth', box='ROC convex hull', minValue=2, maxValue=9, step=1, callback=self.setConvexHullCurveWidth, ticks=1)
741        OWGUI.checkBox(self.settingsTab, self, 'ShowDiagonal', 'Show diagonal ROC line', tooltip='', callback=self.setShowDiagonal)
743
744        self.resize(800, 600)
745
746    def sendReport(self):
747        # need to reimport - Qt provides something stupid instead
748        from __builtin__ import hex
749        self.reportSettings("Settings",
750                            [("Classifiers", ", ".join('<font color="#%s">%s</font>' % ("".join(("0"+hex(x)[2:])[-2:] for x in self.classifierColor[cNum].getRgb()[:3]), str(item.text()))
751                                                        for cNum, item in enumerate(self.classifiersQLB.item(i) for i in range(self.classifiersQLB.count()))
752                                                          if item.isSelected())),
753                             ("Target class", self.classCombo.itemText(self.targetClass)
754                                              if self.targetClass is not None else
755                                              "N/A"),
756                             ("Costs", "FP=%i, FN=%i" % (self.FPcost, self.FNcost)),
757                             ("Prior target class probability", "%i%%" % self.pvalue)
758                            ])
759        if self.targetClass is not None:
760            self.reportRaw("<br/>")
761            self.reportImage(self.graphs[self.targetClass].saveToFileDirect, QSize(400, 400))
762
763    def saveToFile(self):
764        for g in self.graphs:
765            if g.isVisible():
766                g.saveToFile()
767
768    def setPointWidth(self):
769        for g in self.graphs:
770            g.setPointWidth(self.PointWidth)
771
772    def setCurveWidth(self):
773        for g in self.graphs:
774            g.setCurveWidth(self.CurveWidth)
775
776    def setConvexCurveWidth(self):
777        for g in self.graphs:
778            g.setConvexCurveWidth(self.ConvexCurveWidth)
779
780    def setShowDiagonal(self):
781        for g in self.graphs:
782            g.setShowDiagonal(self.ShowDiagonal)
783
784    def setConvexHullCurveWidth(self):
785        for g in self.graphs:
786            g.setConvexHullCurveWidth(self.ConvexHullCurveWidth)
787
788    def setHullColor(self):
789        self.HullColor = str(c.name())
790        for g in self.graphs:
791            g.setHullColor(self.HullColor)
792
793    ##
794    def selectUnselectAll(self, qlb):
795        selected = 0
796        for i in range(qlb.count()):
797            if qlb.item(i).isSelected():
798                selected = 1
799                break
800        if selected: qlb.clearSelection()
801        else: qlb.selectAll()
802
803    def SUAclassifiersQLB(self):
804        self.selectUnselectAll(self.classifiersQLB)
805
806    def SUAtestSetsQLB(self):
807        self.selectUnselectAll(self.testSetsQLB)
808    ##
809
810    def selectAveragingMethod(self):
811        self.AveragingMethod = self.AveragingMethodNames[self.AveragingMethodIndex]
812        if self.AveragingMethod == 'merge':
813            self.performanceTabCosts.setEnabled(self.EnablePerformance)
814        elif self.AveragingMethod == 'vertical':
815            self.performanceTabCosts.setEnabled(0)
816        elif self.AveragingMethod == 'threshold':
817            self.performanceTabCosts.setEnabled(0)
818        else:
819            self.performanceTabCosts.setEnabled(0)
820
821        self.convexCurvesQCB.setEnabled(self.AveragingMethod == 'merge' or self.AveragingMethod == None)
822        self.performanceTabCosts.setEnabled(self.AveragingMethod == 'merge')
823
824        for g in self.graphs:
825            g.setAveragingMethod(self.AveragingMethod)
826
827    ## class selection (classQLB)
828    def target(self):
829        for i in range(len(self.graphs)):
830            self.graphs[i].hide()
831
832        if (self.targetClass <> None) and (len(self.graphs) > 0):
833            if self.targetClass >= len(self.graphs):
834                self.targetClass = len(self.graphs) - 1
835            if self.targetClass < 0:
836                self.targetClass = 0
838            self.graphs[self.targetClass].show()
839
840            self.FPcost = self.FPcostList[self.targetClass]
841            self.FNcost = self.FNcostList[self.targetClass]
842            self.pvalue = self.pvalueList[self.targetClass]
843    ##
844
845    ## classifiers selection (classifiersQLB)
847        list = []
848        for i in range(self.classifiersQLB.count()):
849            if self.classifiersQLB.item(i).isSelected():
850                list.append( 1 )
851            else:
852                list.append( 0 )
853        for g in self.graphs:
854            g.setShowClassifiers(list)
855
856    def setShowConvexCurves(self):
857        for g in self.graphs:
858            g.setShowConvexCurves(self.ShowConvexCurves)
859
860    def setShowConvexHull(self):
861        for g in self.graphs:
862            g.setShowConvexHull(self.ShowConvexHull)
863    ##
864
865    def setShowPerformanceAnalysis(self):
866        for g in self.graphs:
867            g.setShowPerformanceLine(self.EnablePerformance)
868
869    def setShowDefaultThresholdPoint(self):
870        for g in self.graphs:
871            g.setShowDefaultThresholdPoint(self.DefaultThresholdPoint)
872
873    ## test set selection (testSetsQLB)
874    def testSetsSelectionChange(self):
875        list = []
876        for i in range(self.testSetsQLB.count()):
877            if self.testSetsQLB.item(i).isSelected():
878                list.append( 1 )
879            else:
880                list.append( 0 )
881        for g in self.graphs:
882            g.setShowIterations(list)
883    ##
884
885    def calcAllClassGraphs(self):
886        for (cl, g) in enumerate(self.graphs):
887            g.setNumberOfClassifiersIterationsAndClassifierColors(self.dres.classifierNames, self.numberOfIterations, self.classifierColor)
888            g.setTestSetData(self.dresSplitByIterations, cl)
889            g.setShowConvexCurves(self.ShowConvexCurves)
890            g.setShowConvexHull(self.ShowConvexHull)
891            g.setAveragingMethod(self.AveragingMethod)
892            g.setShowPerformanceLine(self.EnablePerformance)
893            g.setShowDefaultThresholdPoint(self.DefaultThresholdPoint)
894
895            ## user settings
896            g.setPointWidth(self.PointWidth)
897            g.setCurveWidth(self.CurveWidth)
898            g.setConvexCurveWidth(self.ConvexCurveWidth)
899            g.setShowDiagonal(self.ShowDiagonal)
900            g.setConvexHullCurveWidth(self.ConvexHullCurveWidth)
901            g.setHullColor(QColor(self.HullColor))
902
903    def removeGraphs(self):
904        for g in self.graphs:
905            g.removeCurves()
906            g.hide()
907
908    def costsChanged(self):
909        if self.targetClass <> None and (len(self.graphs) > 0):
910            self.FPcostList[self.targetClass] = self.FPcost
911            self.FNcostList[self.targetClass] = self.FNcost
912            self.graphs[self.targetClass].costChanged(self.FPcost, self.FNcost)
913
914    def pvaluesUpdated(self):
915        if (self.targetClass == None) or (len(self.graphs) == 0): return
916
917        ## update p values
918        if self.pvalue > self.maxpsum - (len(self.pvalueList) - 1):
919            self.pvalue = self.maxpsum - (len(self.pvalueList) - 1)
920
921        self.pvalueList[self.targetClass] = self.pvalue ## set new value
922        sum = int(statc.sum(self.pvalueList))
923        ## adjust for big changes
924        distrib = []
925        for vi in range(len(self.pvalueList)):
926            if vi == self.targetClass:
927                distrib.append(0.0)
928            else:
929                distrib.append(0.0 if sum - self.pvalue == 0  else self.pvalueList[vi] / float(sum - self.pvalue))
930
931        dif = self.maxpsum - sum
932        for vi in range(len(distrib)):
933            self.pvalueList[vi] += int(float(dif) * distrib[vi])
934            if self.pvalueList[vi] < self.minp:
935                self.pvalueList[vi] = self.minp
936            if self.pvalueList[vi] > self.maxp:
937                self.pvalueList[vi] = self.maxp
938
939        ## small changes
940        dif = self.maxpsum - int(statc.sum(self.pvalueList))
941        while abs(dif) > 0 and len(self.pvalueList) > 1:
942            if dif > 0: vi = self.pvalueList.index(min(self.pvalueList[:self.targetClass] + [self.maxp + 1] + self.pvalueList[self.targetClass+1:]))
943            else: vi = self.pvalueList.index(max(self.pvalueList[:self.targetClass] + [self.minp - 1] + self.pvalueList[self.targetClass+1:]))
944
945            if dif > 0: self.pvalueList[vi] += 1
946            elif dif < 0: self.pvalueList[vi] -= 1
947
948            if self.pvalueList[vi] < self.minp: self.pvalueList[vi] = self.minp
949            if self.pvalueList[vi] > self.maxp: self.pvalueList[vi] = self.maxp
950            dif = self.maxpsum - int(statc.sum(self.pvalueList))
951
952        ## apply new pvalues
953        for (index, graph) in enumerate(self.graphs):
954            graph.pChanged(float(self.pvalueList[index]) / float(self.maxp))
955
956    def setDefaultPValues(self):
957        if self.defaultPerfLinePValues and self.targetClass != None:
958            self.pvaluesList = [v for v in self.defaultPerfLinePValues]
959            self.pvalue = self.pvaluesList[self.targetClass]
960            self.pvaluesUpdated()
961
962    def test_results(self, dres):
963        self.FPcostList = []
964        self.FNcostList = []
965        self.pvalueList = []
966
967        self.closeContext()
968
969        self.warning([0, 1])
970
971        if dres is not None and dres.class_values is None:
972            self.warning(1, "ROC cannot be used for regression results.")
973            dres = None
974
975        if not dres:
976            self.targetClass = None
977            self.classCombo.clear()
978            self.testSetsQLB.clear()
979            self.classifiersQLB.clear()
980            self.removeGraphs()
981            self.openContext("", dres)
982            return
983
984        if dres and dres.test_type != TEST_TYPE_SINGLE:
985            self.warning(0, "ROC is implemented only for single-target prediction problems.")
986            return
987
988        self.dres = dres
989
990        self.classifiersQLB.clear()
991        self.testSetsQLB.clear()
992        self.removeGraphs()
993        self.classCombo.clear()
994
995        self.defaultPerfLinePValues = []
996        if self.dres <> None:
997            ## classQLB
998            self.numberOfClasses = len(self.dres.classValues)
999            self.graphs = []
1000
1001            for i in range(self.numberOfClasses):
1002                self.FPcostList.append( 500)
1003                self.FNcostList.append( 500)
1004                graph = singleClassROCgraph(self.mainArea, "", "Predicted class: " + self.dres.classValues[i])
1005                self.graphs.append( graph )
1007
1008            ## classifiersQLB
1009            self.classifierColor = []
1010            self.numberOfClassifiers = self.dres.numberOfLearners
1011            if self.numberOfClassifiers > 1:
1012                allCforHSV = self.numberOfClassifiers - 1
1013            else:
1014                allCforHSV = self.numberOfClassifiers
1015            for i in range(self.numberOfClassifiers):
1016                newColor = QColor()
1017                newColor.setHsv(i*255/allCforHSV, 255, 255)
1018                self.classifierColor.append( newColor )
1019
1020            ## testSetsQLB
1021            self.dresSplitByIterations = orngStat.splitByIterations(self.dres)
1022            self.numberOfIterations = len(self.dresSplitByIterations)
1023
1024            self.calcAllClassGraphs()
1025
1026            ## classifiersQLB
1027            for i in range(self.numberOfClassifiers):
1028                newColor = self.classifierColor[i]
1030            self.classifiersQLB.selectAll()
1031
1032            ## testSetsQLB
1033            self.testSetsQLB.addItems([str(i) for i in range(self.numberOfIterations)])
1034            self.testSetsQLB.selectAll()
1035
1036            ## calculate default pvalues
1037            reminder = self.maxp
1038            for f in orngStat.classProbabilitiesFromRes(self.dres):
1039                v = int(round(f * self.maxp))
1040                reminder -= v
1041                if reminder < 0:
1042                    v = v+reminder
1043                self.defaultPerfLinePValues.append(v)
1044                self.pvalueList.append( v)
1045
1046            self.targetClass = 0 ## select first target
1047            self.target()
1048        else:
1049            self.classifierColor = None
1050        self.openContext("", self.dres)
1051        self.performanceTabCosts.setEnabled(self.AveragingMethod == 'merge')
1052        self.setDefaultPValues()
1053
1054if __name__ == "__main__":
1055    a = QApplication(sys.argv)
1056    owdm = OWROC()
1057    owdm.show()
1058    a.exec_()
Note: See TracBrowser for help on using the repository browser.