source: orange/orange/OrangeWidgets/Evaluate/OWROC.py @ 9505:4b798678cd3d

Revision 9505:4b798678cd3d, 43.4 KB checked in by matija <matija.polajnar@…>, 2 years ago (diff)

Merge in the (heavily modified) MLC code from GSOC 2011 (modules, documentation, evaluation code, regression test). Widgets will be merged in a little bit later, which will finally close ticket #992.

Line 
1"""
2<name>ROC Analysis</name>
3<description>Displays Receiver Operating Characteristics curve based on evaluation of classifiers.</description>
4<contact>Tomaz Curk</contact>
5<icon>icons/ROCAnalysis.png</icon>
6<priority>1010</priority>
7"""
8from OWColorPalette import ColorPixmap
9from OWWidget import *
10from OWGraph import *
11import OWGUI
12import orngStat, orngTest
13import statc, math
14
15def TCconvexHull(curves):
16    ## merge curves into one
17    mergedCurve = []
18    for c in curves:
19        mergedCurve.extend(c)
20    mergedCurve.sort() ## increasing by fp, tp
21
22    if len(mergedCurve) == 0: return []
23
24    hull = []
25    (prevX, maxY, fscore) = (mergedCurve[0] + (0.0,))[:3]
26    prevPfscore = [fscore]
27    px = prevX
28    for p in mergedCurve[1:]:
29        (px, py, fscore) = (p + (0.0,))[:3]
30        if (px == prevX):
31            if py > maxY:
32                prevPfscore = [fscore]
33                maxY = py
34            elif py == maxY:
35                prevPfscore.append(fscore)
36        elif (px > prevX):
37            hull = orngStat.ROCaddPoint((prevX, maxY, prevPfscore), hull, keepConcavities=0)
38            prevX = px
39            maxY = py
40            prevPfscore = [fscore]
41    hull = orngStat.ROCaddPoint((prevX, maxY, prevPfscore), hull, keepConcavities=0)
42
43    return hull
44
45class singleClassROCgraph(OWGraph):
46    def __init__(self, parent = None, name = None, title = ""):
47        OWGraph.__init__(self, parent, name)
48
49        self.setYRlabels(None)
50        self.enableGridXB(0)
51        self.enableGridYL(0)
52        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
53        self.setAxisMaxMinor(QwtPlot.xBottom, 5)
54        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
55        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
56        self.setAxisScale(QwtPlot.xBottom, -0.0, 1.0, 0)
57        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
58        self.setShowXaxisTitle(1)
59        self.setXaxisTitle("FP Rate (1-Specificity)")
60        self.setShowYLaxisTitle(1)
61        self.setYLaxisTitle("TP Rate (Sensitivity)")
62        self.setMainTitle(title)
63        self.setShowMainTitle(1)
64        self.targetClass = 0
65        self.averagingMethod = None
66        self.splitByIterations = None
67        self.VTAsamples = 10 ## vertical threshold averaging, number of samples
68        self.FPcost = 500.0
69        self.FNcost = 500.0
70        self.pvalue = 400.0 ##0.400
71
72        self.performanceLineSymbol = QwtSymbol(QwtSymbol.Ellipse, QBrush(Qt.color0), QPen(Qt.black), QSize(7,7))
73        self.defaultLineSymbol = QwtSymbol(QwtSymbol.Ellipse, QBrush(Qt.black), QPen(Qt.black), QSize(8,8))
74        self.convexHullPen = QPen(Qt.yellow, 3)
75
76        self.removeMarkers()
77        self.performanceMarkerKeys = []
78
79        self.removeCurves()
80
81    def computeCurve(self, res, classIndex=-1, keepConcavities=1):
82        return orngStat.TCcomputeROC(res, classIndex, keepConcavities)
83
84    def setNumberOfClassifiersIterationsAndClassifierColors(self, classifierNames, iterationsNum, classifierColor):
85        classifiersNum = len(classifierNames)
86        self.removeCurves()
87        self.classifierColor = classifierColor
88        self.classifierNames = classifierNames
89
90        for cNum in range(classifiersNum):
91            self.classifierIterationCKeys.append([])
92            self.classifierIterationConvexCKeys.append([])
93            self.classifierIterationROCdata.append([])
94            for iNum in range(iterationsNum):
95                curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 3), style=QwtPlotCurve.Lines)
96                self.classifierIterationCKeys[cNum].append(curve)
97                curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 1), style=QwtPlotCurve.Lines)
98                self.classifierIterationConvexCKeys[cNum].append(curve)
99                self.classifierIterationROCdata[cNum].append(None)
100
101            self.showClassifiers.append(0)
102            self.showIterations.append(0)
103
104            ## 'merge' average curve keys
105            curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 2), style=QwtPlotCurve.Lines)
106            self.mergedCKeys.append(curve)
107
108            curve = self.addCurve('', pen = QPen(Qt.black, 5), style=QwtPlotCurve.Lines)
109            curve.setSymbol(self.defaultLineSymbol)
110           
111            self.mergedCThresholdKeys.append(curve)
112            self.mergedCThresholdMarkers.append([])
113
114            curve = self.addCurve('', pen = QPen(self.classifierColor[cNum], 1), style=QwtPlotCurve.Lines)
115##            self.mergedConvexCKeys.append(QPen(Qt.black, 5))
116            self.mergedConvexCKeys.append(curve)
117
118            newSymbol = QwtSymbol(QwtSymbol.NoSymbol, QBrush(Qt.color0), QPen(self.classifierColor[cNum], 2), QSize(0,0))
119            ## 'vertical' average curve keys
120            curve = errorBarQwtPlotCurve('', connectPoints = 1, tickXw = 1.0/self.VTAsamples/5.0)
121            curve.attach(self)
122            curve.setSymbol(newSymbol)
123            curve.setStyle(QwtPlotCurve.UserCurve)
124            self.verticalCKeys.append(curve)
125
126            ## 'threshold' average curve keys
127            curve = errorBarQwtPlotCurve('', connectPoints = 1, tickXw = 1.0/self.VTAsamples/5.0, tickYw = 1.0/self.VTAsamples/5.0, showVerticalErrorBar = 1, showHorizontalErrorBar = 1)
128            curve.attach(self)
129            curve.setSymbol(newSymbol)
130            curve.setStyle(QwtPlotCurve.UserCurve)
131            self.thresholdCKeys.append(curve)
132
133        ## iso-performance line on top of all curves
134        self.performanceLineCKey = self.addCurve('', pen = QPen(Qt.black, 2), style=QwtPlotCurve.Lines)
135        self.performanceLineCKey.setSymbol(self.performanceLineSymbol)
136
137    def removeCurves(self):
138        self.clear()
139        self.classifierColor = []
140        self.classifierNames = []
141        self.classifierIterationROCdata = []
142        self.showClassifiers = []
143        self.showIterations = []
144        self.showConvexCurves = 0
145        self.showConvexHull = 0
146        self.showPerformanceLine = 0
147        self.showDefaultThresholdPoint = 0
148        self.showDiagonal = 0
149
150        ## 'merge' average curve keys
151        self.mergedCKeys = []
152        self.mergedCThresholdKeys = []
153        self.mergedCThresholdMarkers = []
154        self.mergedConvexCKeys = []
155        ## 'vertical' average curve keys
156        self.verticalCKeys = []
157        ## 'threshold' average curve keys
158        self.thresholdCKeys = []
159        ## 'None' average curve keys
160        self.classifierIterationCKeys = []
161        self.classifierIterationConvexCKeys = []
162
163        ## convex hull calculation
164        self.mergedConvexHullData = []
165        self.verticalConvexHullData = []
166        self.thresholdConvexHullData = []
167        self.classifierConvexHullData = []
168        self.hullCurveDataForPerfLine = [] ## for performance analysis
169
170        ## diagonal curve
171        self.diagonalCKey = self.addCurve('', pen = QPen(Qt.black, 1), symbol = QwtSymbol.NoSymbol, xData = [0.0, 1.0], yData = [0.0, 1.0], style=QwtPlotCurve.Lines)
172
173        ## convex hull curve keys
174        self.mergedConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
175        self.verticalConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
176        self.thresholdConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
177        self.classifierConvexHullCKey = self.addCurve('', pen = self.convexHullPen, style=QwtPlotCurve.Lines)
178
179        ## iso-performance line
180        self.performanceLineCKey = None
181
182    def setIterationCurves(self, iteration, curves):
183        classifier = 0
184        for c in curves:
185            x = [px for (px, py, pf) in c]
186            y = [py for (px, py, pf) in c]
187            curve = self.classifierIterationCKeys[classifier][iteration]
188            curve.setData(x, y)
189            self.classifierIterationROCdata[classifier][iteration] = c
190            classifier += 1
191
192    def setIterationConvexCurves(self, iteration, curves):
193        classifier = 0
194        for c in curves:
195            x = [px for (px, py, pf) in c]
196            y = [py for (px, py, pf) in c]
197            curve = self.classifierIterationConvexCKeys[classifier][iteration]
198            curve.setData(x, y)
199            classifier += 1
200
201    def setTestSetData(self, splitByIterations, targetClass):
202        self.splitByIterations = splitByIterations
203        ## generate the "base" unmodified ROC curves
204        self.targetClass = targetClass
205        iteration = 0
206        for isplit in splitByIterations:
207            # unmodified ROC curve
208            curves = self.computeCurve(isplit, self.targetClass, 1)
209            self.setIterationCurves(iteration, curves)
210
211            # convex ROC curve
212            curves = self.computeCurve(isplit, self.targetClass, 0)
213            self.setIterationConvexCurves(iteration, curves)
214            iteration += 1
215
216    def updateCurveDisplay(self):
217        self.diagonalCKey.setVisible(self.showDiagonal)
218
219        showSomething = 0
220        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
221            showCNum = (self.showClassifiers[cNum] <> 0)
222
223            ## 'merge' averaging
224            b = (self.averagingMethod == 'merge') and showCNum
225            showSomething = showSomething or b
226            if self.mergedCKeys[cNum]:
227                self.mergedCKeys[cNum].setVisible(b)
228
229            b2 = b and self.showDefaultThresholdPoint
230            if self.mergedCThresholdKeys[cNum] <> None:
231                self.mergedCThresholdKeys[cNum].setVisible(b2)
232
233            for marker in self.mergedCThresholdMarkers[cNum]:
234                marker.setVisible(b2)
235
236            b = b and self.showConvexCurves
237            if self.mergedConvexCKeys[cNum] <> None:
238                self.mergedConvexCKeys[cNum].setVisible(b)
239
240            ## 'vertical' averaging
241            b = (self.averagingMethod == 'vertical') and showCNum
242            showSomething = showSomething or b
243            self.verticalCKeys[cNum].setVisible(b)
244
245            ## 'threshold' averaging
246            b = (self.averagingMethod == 'threshold') and showCNum
247            showSomething = showSomething or b
248            self.thresholdCKeys[cNum].setVisible(b)
249
250            ## 'None' averaging
251            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum], self.classifierIterationConvexCKeys[cNum]))):
252                b = (self.averagingMethod == None) and showCNum and (self.showIterations[iNum] <> 0)
253                showSomething = showSomething or b
254                self.classifierIterationCKeys[cNum][iNum].setVisible(b)
255                b = b and self.showConvexCurves
256                self.classifierIterationConvexCKeys[cNum][iNum].setVisible(b)
257
258        chb = (showSomething) and (self.averagingMethod == None) and self.showConvexHull
259        if self.classifierConvexHullCKey:
260            self.classifierConvexHullCKey.setVisible(chb)
261
262        chb = (showSomething) and (self.averagingMethod == 'merge') and self.showConvexHull
263        if self.mergedConvexHullCKey:
264            self.mergedConvexHullCKey.setVisible(chb)
265
266        chb = (showSomething) and (self.averagingMethod == 'vertical') and self.showConvexHull
267        if self.verticalConvexHullCKey:
268            self.verticalConvexHullCKey.setVisible(chb)
269
270        chb = (showSomething) and (self.averagingMethod == 'threshold') and self.showConvexHull
271        if self.thresholdConvexHullCKey:
272            self.thresholdConvexHullCKey.setVisible(chb)
273
274        ## performance line
275        b = (self.averagingMethod == 'merge') and self.showPerformanceLine
276        for marker in self.performanceMarkerKeys:
277            marker.setVisible(b)
278
279        if self.performanceLineCKey:
280##            curve.setVisible(b)
281            self.performanceLineCKey.setVisible(b)
282
283##        self.updateLayout()
284##        self.update()
285        self.replot()
286
287    def setShowConvexCurves(self, b):
288        self.showConvexCurves = b
289        self.updateCurveDisplay()
290
291    def setShowConvexHull(self, b):
292        self.showConvexHull = b
293        self.updateCurveDisplay()
294
295    def setShowPerformanceLine(self, b):
296        self.showPerformanceLine = b
297        self.updateCurveDisplay()
298
299    def setShowDefaultThresholdPoint(self, b):
300        self.showDefaultThresholdPoint = b
301        self.updateCurveDisplay()
302
303    def setShowClassifiers(self, list):
304        self.showClassifiers = list
305        self.calcConvexHulls()
306        self.calcUpdatePerformanceLine() ## new data for performance line
307        self.updateCurveDisplay()
308
309    def setShowIterations(self, list):
310        self.showIterations = list
311        self.calcAverageCurves()
312        self.calcConvexHulls()
313        self.calcUpdatePerformanceLine() ## new data for performance line
314        self.updateCurveDisplay()
315
316    ## calculate the average curve for the selected test sets (with all the averaging methods)
317    def calcAverageCurves(self):
318        ##
319        ## self.averagingMethod == 'merge':
320        mergedIterations = orngTest.ExperimentResults(1, self.splitByIterations[0].classifierNames, self.splitByIterations[0].classValues, self.splitByIterations[0].weights, classifiers=self.splitByIterations[0].classifiers, loaded=self.splitByIterations[0].loaded)
321        for i, (isplit, show) in enumerate(zip(self.splitByIterations, self.showIterations)):
322            if show:
323                for te in isplit.results:
324                    mergedIterations.results.append( te )
325                   
326        self.mergedConvexHullData = []
327        if len(mergedIterations.results) > 0:
328            curves = self.computeCurve(mergedIterations, self.targetClass, 1)
329            convexCurves = self.computeCurve(mergedIterations, self.targetClass, 0)
330            classifier = 0
331            for c in curves:
332                x = [px for (px, py, pf) in c]
333                y = [py for (px, py, pf) in c]
334                self.mergedCKeys[classifier].setData(x, y)
335
336                # points of defualt threshold classifiers
337                defPoint = [(abs(pf-0.5), pf, px, py) for (px, py, pf) in c]
338                defPoints = []
339                if len(defPoint) > 0:
340                    defPoint.sort()
341                    defPoints = [(px, py, pf) for (d, pf, px, py) in defPoint if d == defPoint[0][0]]
342                else:
343                    defPoints = []
344                defX = [px for (px, py, pf) in defPoints]
345                defY = [py for (px, py, pf) in defPoints]
346                self.mergedCThresholdKeys[classifier].setData(defX, defY)
347
348                for marker in self.mergedCThresholdMarkers[classifier]:
349                    marker.detach()
350                self.mergedCThresholdMarkers[classifier] = []
351                for (dx, dy, pf) in defPoints:
352                    dx = max(min(0.95, dx + 0.01), 0.01)
353                    dy = min(max(0.01, dy - 0.02), 0.95)
354                    marker = self.addMarker('%3.2g' % (pf), dx, dy, alignment = Qt.AlignRight)
355                    self.mergedCThresholdMarkers[classifier].append(marker)
356                classifier += 1
357            classifier = 0
358            for c in convexCurves:
359                self.mergedConvexHullData.append(c) ## put all points of all curves into one big array
360                x = [px for (px, py, pf) in c]
361                y = [py for (px, py, pf) in c]
362                curve = self.mergedConvexCKeys[classifier]
363                curve.setData(x, y)
364                classifier += 1
365        else:
366            for c in range(len(self.mergedCKeys)):
367                self.mergedCKeys[c].setData([], [])
368                self.mergedCThresholdKeys[c].setData([], [])
369                for marker in self.mergedCThresholdMarkers[c]:
370                    marker.detach()
371                self.mergedCThresholdMarkers[c] = []
372                self.mergedConvexCKeys[c].setData([], [])
373
374        ## prepare a common input structure for vertical and threshold averaging
375        ROCS = []
376        somethingToShow = 0
377        for c in self.classifierIterationROCdata:
378            ROCS.append([])
379            i = 0
380            for s in self.showIterations:
381                if s:
382                    somethingToShow = 1
383                    ROCS[-1].append(c[i])
384                i += 1
385
386        ## remove curve
387        self.verticalConvexHullData = []
388        self.thresholdConvexHullData = []
389
390        if somethingToShow == 0:
391            for curve in self.verticalCKeys:
392                curve.setData([], [])
393
394            for curve in self.thresholdCKeys:
395                curve.setData([], [])
396            return
397
398        ##
399        ## self.averagingMethod == 'vertical':
400        ## calculated from the self.classifierIterationROCdata data
401        try:
402            (averageCurves, verticalErrorBarValues) = orngStat.TCverticalAverageROC(ROCS, self.VTAsamples)
403        except (ValueError, SystemError), er:
404            print >> sys.stderr, "Failed to compute vertical average ROC curve. " + er.message
405            averageCurves, verticalErrorBarValues = [], []
406            pass
407         
408        classifier = 0
409        for c in averageCurves:
410            self.verticalConvexHullData.append(c)
411            xs = []
412            mps = []
413            for pcn in range(len(c)):
414                (px, py) = c[pcn]
415                ## for the error bar plot
416                xs.append(px)
417                mps.append(py + 0.0)
418
419                xs.append(px)
420                mps.append(py + verticalErrorBarValues[classifier][pcn])
421
422                xs.append(px)
423                mps.append(py - verticalErrorBarValues[classifier][pcn])
424
425            ckey =  self.verticalCKeys[classifier]
426            ckey.setData(xs, mps)
427            classifier += 1
428
429        ##
430        ## self.averagingMethod == 'threshold':
431        ## calculated from the self.classifierIterationROCdata data
432        (averageCurves, verticalErrorBarValues, horizontalErrorBarValues) = orngStat.TCthresholdlAverageROC(ROCS, self.VTAsamples)
433        classifier = 0
434        for c in averageCurves:
435            self.thresholdConvexHullData.append(c)
436            xs = []
437            mps = []
438            for pcn in range(len(c)):
439                (px, py) = c[pcn]
440                ## for the error bar plot
441                xs.append(px + 0.0)
442                mps.append(py + 0.0)
443
444                xs.append(px - horizontalErrorBarValues[classifier][pcn])
445                mps.append(py + verticalErrorBarValues[classifier][pcn])
446
447                xs.append(px + horizontalErrorBarValues[classifier][pcn])
448                mps.append(py - verticalErrorBarValues[classifier][pcn])
449
450            ckey = self.thresholdCKeys[classifier]
451            ckey.setData(xs, mps)
452            classifier += 1
453
454        ## self.averagingMethod == 'None'
455        ## already calculated
456
457    def calcConvexHulls(self):
458        ## self.classifierConvexHullCKey = -1
459        hullData = []
460        for cNum in range(len(self.showClassifiers)):
461            for iNum in range(len(self.showIterations)):
462                if (self.showClassifiers[cNum] <> 0) and (self.showIterations[iNum] <> 0):
463                    hullData.append(self.classifierIterationROCdata[cNum][iNum])
464
465        convexHullCurve = TCconvexHull(hullData)
466        x = [px for (px, py, pf) in convexHullCurve]
467        y = [py for (px, py, pf) in convexHullCurve]
468        self.classifierConvexHullCKey.setData(x, y)
469
470        ## self.mergedConvexHullCKey = -1
471        hullData = []
472        for cNum in range(len(self.mergedConvexHullData)):
473            if (self.showClassifiers[cNum] <> 0):
474                ncurve = []
475                for (px, py, pfscore) in self.mergedConvexHullData[cNum]:
476                    ncurve.append( (px, py, (cNum, pfscore)) )
477                hullData.append(ncurve)
478
479        self.hullCurveDataForPerfLine = TCconvexHull(hullData) # keep data about curve for performance line drawing
480        x = [px for (px, py, pf) in self.hullCurveDataForPerfLine]
481        y = [py for (px, py, pf) in self.hullCurveDataForPerfLine]
482        self.mergedConvexHullCKey.setData(x, y)
483
484        ## self.verticalConvexHullCKey = -1
485        hullData = []
486        for cNum in range(len(self.verticalConvexHullData)):
487            if (self.showClassifiers[cNum] <> 0):
488                hullData.append(self.verticalConvexHullData[cNum])
489
490        convexHullCurve = TCconvexHull(hullData)
491        x = [px for (px, py, pf) in convexHullCurve]
492        y = [py for (px, py, pf) in convexHullCurve]
493        self.verticalConvexHullCKey.setData(x, y)
494
495        ## self.thresholdConvexHullCKey = -1
496        hullData = []
497        for cNum in range(len(self.thresholdConvexHullData)):
498            if (self.showClassifiers[cNum] <> 0):
499                hullData.append(self.thresholdConvexHullData[cNum])
500
501        convexHullCurve = TCconvexHull(hullData)
502        x = [px for (px, py, pf) in convexHullCurve]
503        y = [py for (px, py, pf) in convexHullCurve]
504        self.thresholdConvexHullCKey.setData(x, y)
505
506    def setAveragingMethod(self, m):
507        self.averagingMethod = m
508        self.updateCurveDisplay()
509
510    ## performance line
511    def calcUpdatePerformanceLine(self):
512        closestpoints = orngStat.TCbestThresholdsOnROCcurve(self.FPcost, self.FNcost, self.pvalue, self.hullCurveDataForPerfLine)
513        m = (self.FPcost*(1.0 - self.pvalue)) / (self.FNcost*self.pvalue)
514
515        ## now draw the closest line to the curve
516        b = (self.averagingMethod == 'merge') and self.showPerformanceLine
517        lpx = []
518        lpy = []
519        first = 1
520        ## remove old markers
521        for marker in self.performanceMarkerKeys:
522            try:
523                marker.detach()
524            except RuntimeError: ## RuntimeError: underlying C/C++ object has been deleted
525                pass
526        self.performanceMarkerKeys = []
527        for (x, y, fscorelist) in closestpoints:
528            if first:
529                first = 0
530                lpx.append(x - 2.0)
531                lpy.append(y - 2.0*m)
532            lpx.append(x)
533            lpy.append(y)
534            px = x
535            py = y
536            for (cNum, threshold) in fscorelist:
537                s = "%1.3f %s" % (threshold, self.classifierNames[cNum])
538                px = max(min(0.95, px + 0.01), 0.01)
539                py = min(max(0.01, py - 0.02), 0.95)
540                marker = self.addMarker(s, px, py, alignment = Qt.AlignRight)
541                marker.setVisible(b)
542                self.performanceMarkerKeys.append(marker)
543        if len(closestpoints) > 0:
544            lpx.append(x + 2.0)
545            lpy.append(y + 2.0*m)
546
547        if self.performanceLineCKey:
548            self.performanceLineCKey.setData(lpx, lpy)
549            self.performanceLineCKey.setVisible(b)
550        self.replot()
551#        self.update()
552
553    def costChanged(self, FPcost, FNcost):
554        self.FPcost = float(FPcost)
555        self.FNcost = float(FNcost)
556        self.calcUpdatePerformanceLine()
557
558    def pChanged(self, pvalue):
559        self.pvalue = float(pvalue)
560        self.calcUpdatePerformanceLine()
561
562    def setPointWidth(self, v):
563        self.performanceLineSymbol.setSize(v, v)
564        if self.performanceLineCKey:
565            self.performanceLineCKey.setSymbol(self.performanceLineSymbol)
566           
567        def setW(curve):
568            sym = curve.symbol() #.setPen(QPen(self.classifierColor[cNum], v))
569            sym.setSize(v + 1, v + 1)
570            if QWT_VERSION_STR >= "5.2": # in Qwt 5.1.* curve.setSymbol results in a crash
571                curve.setSymbol(sym)
572           
573        for item in self.itemList():
574            setW(item)
575           
576#        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
577#            setW(self.mergedCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
578#            setW(self.verticalCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
579#            setW(self.thresholdCKeys[cNum]) #.setPen(QPen(self.classifierColor[cNum], v))
580#            setW(self.mergedCThresholdKeys[cNum])
581#            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum]))):
582#                setW(self.classifierIterationCKeys[cNum][iNum]) #.setPen(QPen(self.classifierColor[cNum], v))
583        self.replot()
584#        self.update()
585
586    def setCurveWidth(self, v):
587        for cNum in range(len(zip(self.showClassifiers, self.mergedCKeys))):
588            self.mergedCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
589            self.verticalCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
590            self.thresholdCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
591            for iNum in range(len(zip(self.showIterations, self.classifierIterationCKeys[cNum]))):
592                self.classifierIterationCKeys[cNum][iNum].setPen(QPen(self.classifierColor[cNum], v))
593        self.replot()
594#        self.update()
595
596    def setConvexCurveWidth(self, v):
597        for cNum in range(len(zip(self.showClassifiers, self.mergedConvexCKeys))):
598            self.mergedConvexCKeys[cNum].setPen(QPen(self.classifierColor[cNum], v))
599            for iNum in range(len(zip(self.showIterations, self.classifierIterationConvexCKeys[cNum]))):
600                self.classifierIterationConvexCKeys[cNum][iNum].setPen(QPen(self.classifierColor[cNum], v))
601        self.replot()
602#        self.update()
603
604    def setShowDiagonal(self, v):
605        self.showDiagonal = v
606        self.updateCurveDisplay()
607
608    def setConvexHullCurveWidth(self, v):
609        self.convexHullPen.setWidth(v)
610        self.mergedConvexHullCKey.setPen(self.convexHullPen)
611        self.verticalConvexHullCKey.setPen(self.convexHullPen)
612        self.thresholdConvexHullCKey.setPen(self.convexHullPen)
613        self.classifierConvexHullCKey.setPen(self.convexHullPen)
614        self.replot()
615#        self.update()
616
617    def setHullColor(self, c):
618        self.convexHullPen.setColor(c)
619        self.mergedConvexHullCKey.setPen(self.convexHullPen)
620        self.verticalConvexHullCKey.setPen(self.convexHullPen)
621        self.thresholdConvexHullCKey.setPen(self.convexHullPen)
622        self.classifierConvexHullCKey.setPen(self.convexHullPen)
623        self.replot()
624#        self.update()
625
626    def sizeHint(self):
627        return QSize(100, 100)
628
629class OWROC(OWWidget):
630    settingsList = ["PointWidth", "CurveWidth", "ConvexCurveWidth", "ShowDiagonal",
631                    "ConvexHullCurveWidth", "HullColor", "AveragingMethodIndex",
632                    "ShowConvexHull", "ShowConvexCurves", "EnablePerformance", "DefaultThresholdPoint"]
633    contextHandlers = {"": EvaluationResultsContextHandler("", "targetClass", "selectedClassifiers")}
634
635    def __init__(self,parent=None, signalManager = None):
636        OWWidget.__init__(self, parent, signalManager, "ROC Analysis", 1)
637
638        # inputs
639        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.test_results, Default)]
640
641        # default settings
642        self.PointWidth = 7
643        self.CurveWidth = 3
644        self.ConvexCurveWidth = 1
645        self.ShowDiagonal = TRUE
646        self.ConvexHullCurveWidth = 3
647        self.HullColor = str(QColor(Qt.yellow).name())
648        self.AveragingMethodIndex = 0 ##'merge'
649        self.ShowConvexHull = TRUE
650        self.ShowConvexCurves = FALSE
651        self.EnablePerformance = TRUE
652        self.DefaultThresholdPoint = TRUE
653
654        #load settings
655        self.loadSettings()
656
657        # temp variables
658        self.dres = None
659        self.classifierColor = None
660        self.numberOfClasses  = 0
661        self.targetClass = None
662        self.numberOfClassifiers = 0
663        self.numberOfIterations = 0
664        self.graphs = []
665        self.maxp = 1000
666        self.defaultPerfLinePValues = []
667        self.classifiers = []
668        self.selectedClassifiers = []
669
670        # performance analysis (temporary values
671        self.FPcost = 500.0
672        self.FNcost = 500.0
673        self.pvalue = 50.0 ##0.400
674
675        # list of values (remember for each class)
676        self.FPcostList = []
677        self.FNcostList = []
678        self.pvalueList = []
679
680        self.AveragingMethodNames = ['merge', 'vertical', 'threshold', None]
681        self.AveragingMethod = self.AveragingMethodNames[min(3, self.AveragingMethodIndex)]
682
683        # GUI
684        import sip
685        sip.delete(self.mainArea.layout())
686        self.graphsGridLayoutQGL = QGridLayout(self.mainArea)
687        self.mainArea.setLayout(self.graphsGridLayoutQGL)
688        # save each ROC graph in separate file
689        self.connect(self.graphButton, SIGNAL("clicked()"), self.saveToFile)
690
691        ## general tab
692        self.tabs = OWGUI.tabWidget(self.controlArea)
693        self.generalTab = OWGUI.createTabPage(self.tabs, "General")
694
695        ## target class
696        self.classCombo = OWGUI.comboBox(self.generalTab, self, 'targetClass', box='Target class', items=[], callback=self.target)
697        #self.classCombo.setMaximumSize(150, 20)
698
699        ## classifiers selection (classifiersQLB)
700        self.classifiersQVGB = OWGUI.widgetBox(self.generalTab, "Classifiers")
701        self.classifiersQLB = OWGUI.listBox(self.classifiersQVGB, self, "selectedClassifiers", selectionMode = QListWidget.MultiSelection, callback = self.classifiersSelectionChange)
702        self.unselectAllClassifiersQLB = OWGUI.button(self.classifiersQVGB, self, "(Un)select All", callback = self.SUAclassifiersQLB)
703
704        # show convex ROC curves and show ROC convex hull
705        self.convexCurvesQCB = OWGUI.checkBox(self.generalTab, self, 'ShowConvexCurves', 'Show convex ROC curves', tooltip='', callback=self.setShowConvexCurves)
706        OWGUI.checkBox(self.generalTab, self, 'ShowConvexHull', 'Show ROC convex hull', tooltip='', callback=self.setShowConvexHull)
707       
708
709        # performance analysis
710        self.performanceTab = OWGUI.createTabPage(self.tabs, "Analysis")
711        self.performanceTabCosts = OWGUI.widgetBox(self.performanceTab, box = 1)
712        OWGUI.checkBox(self.performanceTabCosts, self, 'EnablePerformance', 'Show performance line', tooltip='', callback=self.setShowPerformanceAnalysis)
713        OWGUI.checkBox(self.performanceTabCosts, self, 'DefaultThresholdPoint', 'Default threshold (0.5) point', tooltip='', callback=self.setShowDefaultThresholdPoint)
714
715        ## FP and FN cost ranges
716        mincost = 1; maxcost = 1000; stepcost = 5;
717        self.maxpsum = 100; self.minp = 1; self.maxp = self.maxpsum - self.minp ## need it also in self.pvaluesUpdated
718        stepp = 1.0
719
720        OWGUI.hSlider(self.performanceTabCosts, self, 'FPcost', box='FP Cost', minValue=mincost, maxValue=maxcost, step=stepcost, callback=self.costsChanged, ticks=50)
721        OWGUI.hSlider(self.performanceTabCosts, self, 'FNcost', box='FN Cost', minValue=mincost, maxValue=maxcost, step=stepcost, callback=self.costsChanged, ticks=50)
722
723        ptc = OWGUI.widgetBox(self.performanceTabCosts, "Prior target class probability [%]")
724        OWGUI.hSlider(ptc, self, 'pvalue', minValue=self.minp, maxValue=self.maxp, step=stepp, callback=self.pvaluesUpdated, ticks=5, labelFormat="%2.1f")
725        OWGUI.button(ptc, self, 'Compute from data', self.setDefaultPValues) ## reset p values to default
726
727        ## test set selection (testSetsQLB)
728        self.testSetsQVGB = OWGUI.widgetBox(self.performanceTab, "Test sets")
729        self.testSetsQLB = OWGUI.listBox(self.testSetsQVGB, self, selectionMode = QListWidget.MultiSelection, callback = self.testSetsSelectionChange)
730        self.unselectAllTestSetsQLB = OWGUI.button(self.testSetsQVGB, self, "(Un)select All", callback = self.SUAtestSetsQLB)
731
732        # settings tab
733        self.settingsTab = OWGUI.createTabPage(self.tabs, "Settings")
734        OWGUI.radioButtonsInBox(self.settingsTab, self, 'AveragingMethodIndex', ['Merge (expected ROC perf.)', 'Vertical', 'Threshold', 'None'], box='Averaging ROC curves', callback=self.selectAveragingMethod)
735        OWGUI.hSlider(self.settingsTab, self, 'PointWidth', box='Point width', minValue=0, maxValue=9, step=1, callback=self.setPointWidth, ticks=1)
736        OWGUI.hSlider(self.settingsTab, self, 'CurveWidth', box='ROC curve width', minValue=1, maxValue=5, step=1, callback=self.setCurveWidth, ticks=1)
737        OWGUI.hSlider(self.settingsTab, self, 'ConvexCurveWidth', box='ROC convex curve width', minValue=1, maxValue=5, step=1, callback=self.setConvexCurveWidth, ticks=1)
738        OWGUI.hSlider(self.settingsTab, self, 'ConvexHullCurveWidth', box='ROC convex hull', minValue=2, maxValue=9, step=1, callback=self.setConvexHullCurveWidth, ticks=1)
739        OWGUI.checkBox(self.settingsTab, self, 'ShowDiagonal', 'Show diagonal ROC line', tooltip='', callback=self.setShowDiagonal)
740        self.settingsTab.layout().addStretch(100)
741     
742        self.resize(800, 600)
743
744    def sendReport(self):
745        # need to reimport - Qt provides something stupid instead
746        from __builtin__ import hex
747        self.reportSettings("Settings",
748                            [("Classifiers", ", ".join('<font color="#%s">%s</font>' % ("".join(("0"+hex(x)[2:])[-2:] for x in self.classifierColor[cNum].getRgb()[:3]), str(item.text()))
749                                                        for cNum, item in enumerate(self.classifiersQLB.item(i) for i in range(self.classifiersQLB.count()))
750                                                          if item.isSelected())),
751                             ("Target class", self.classCombo.itemText(self.targetClass)
752                                              if self.targetClass is not None else
753                                              "N/A"),
754                             ("Costs", "FP=%i, FN=%i" % (self.FPcost, self.FNcost)),
755                             ("Prior target class probability", "%i%%" % self.pvalue)
756                            ])
757        if self.targetClass is not None:
758            self.reportRaw("<br/>")
759            self.reportImage(self.graphs[self.targetClass].saveToFileDirect, QSize(400, 400))
760       
761    def saveToFile(self):
762        for g in self.graphs:
763            if g.isVisible():
764                g.saveToFile()
765
766    def setPointWidth(self):
767        for g in self.graphs:
768            g.setPointWidth(self.PointWidth)
769
770    def setCurveWidth(self):
771        for g in self.graphs:
772            g.setCurveWidth(self.CurveWidth)
773
774    def setConvexCurveWidth(self):
775        for g in self.graphs:
776            g.setConvexCurveWidth(self.ConvexCurveWidth)
777
778    def setShowDiagonal(self):
779        for g in self.graphs:
780            g.setShowDiagonal(self.ShowDiagonal)
781
782    def setConvexHullCurveWidth(self):
783        for g in self.graphs:
784            g.setConvexHullCurveWidth(self.ConvexHullCurveWidth)
785
786    def setHullColor(self):
787        self.HullColor = str(c.name())
788        for g in self.graphs:
789            g.setHullColor(self.HullColor)
790
791    ##
792    def selectUnselectAll(self, qlb):
793        selected = 0
794        for i in range(qlb.count()):
795            if qlb.item(i).isSelected():
796                selected = 1
797                break
798        if selected: qlb.clearSelection()
799        else: qlb.selectAll()
800
801    def SUAclassifiersQLB(self):
802        self.selectUnselectAll(self.classifiersQLB)
803
804    def SUAtestSetsQLB(self):
805        self.selectUnselectAll(self.testSetsQLB)
806    ##
807
808    def selectAveragingMethod(self):
809        self.AveragingMethod = self.AveragingMethodNames[self.AveragingMethodIndex]
810        if self.AveragingMethod == 'merge':
811            self.performanceTabCosts.setEnabled(self.EnablePerformance)
812        elif self.AveragingMethod == 'vertical':
813            self.performanceTabCosts.setEnabled(0)
814        elif self.AveragingMethod == 'threshold':
815            self.performanceTabCosts.setEnabled(0)
816        else:
817            self.performanceTabCosts.setEnabled(0)
818
819        self.convexCurvesQCB.setEnabled(self.AveragingMethod == 'merge' or self.AveragingMethod == None)
820        self.performanceTabCosts.setEnabled(self.AveragingMethod == 'merge')
821
822        for g in self.graphs:
823            g.setAveragingMethod(self.AveragingMethod)
824
825    ## class selection (classQLB)
826    def target(self):
827        for i in range(len(self.graphs)):
828            self.graphs[i].hide()
829
830        if (self.targetClass <> None) and (len(self.graphs) > 0):
831            if self.targetClass >= len(self.graphs):
832                self.targetClass = len(self.graphs) - 1
833            if self.targetClass < 0:
834                self.targetClass = 0
835            self.graphsGridLayoutQGL.addWidget(self.graphs[self.targetClass], 0, 0)
836            self.graphs[self.targetClass].show()
837
838            self.FPcost = self.FPcostList[self.targetClass]
839            self.FNcost = self.FNcostList[self.targetClass]
840            self.pvalue = self.pvalueList[self.targetClass]
841    ##
842
843    ## classifiers selection (classifiersQLB)
844    def classifiersSelectionChange(self):
845        list = []
846        for i in range(self.classifiersQLB.count()):
847            if self.classifiersQLB.item(i).isSelected():
848                list.append( 1 )
849            else:
850                list.append( 0 )
851        for g in self.graphs:
852            g.setShowClassifiers(list)
853
854    def setShowConvexCurves(self):
855        for g in self.graphs:
856            g.setShowConvexCurves(self.ShowConvexCurves)
857
858    def setShowConvexHull(self):
859        for g in self.graphs:
860            g.setShowConvexHull(self.ShowConvexHull)
861    ##
862
863    def setShowPerformanceAnalysis(self):
864        for g in self.graphs:
865            g.setShowPerformanceLine(self.EnablePerformance)
866
867    def setShowDefaultThresholdPoint(self):
868        for g in self.graphs:
869            g.setShowDefaultThresholdPoint(self.DefaultThresholdPoint)
870
871    ## test set selection (testSetsQLB)
872    def testSetsSelectionChange(self):
873        list = []
874        for i in range(self.testSetsQLB.count()):
875            if self.testSetsQLB.item(i).isSelected():
876                list.append( 1 )
877            else:
878                list.append( 0 )
879        for g in self.graphs:
880            g.setShowIterations(list)
881    ##
882
883    def calcAllClassGraphs(self):
884        for (cl, g) in enumerate(self.graphs):
885            g.setNumberOfClassifiersIterationsAndClassifierColors(self.dres.classifierNames, self.numberOfIterations, self.classifierColor)
886            g.setTestSetData(self.dresSplitByIterations, cl)
887            g.setShowConvexCurves(self.ShowConvexCurves)
888            g.setShowConvexHull(self.ShowConvexHull)
889            g.setAveragingMethod(self.AveragingMethod)
890            g.setShowPerformanceLine(self.EnablePerformance)
891            g.setShowDefaultThresholdPoint(self.DefaultThresholdPoint)
892
893            ## user settings
894            g.setPointWidth(self.PointWidth)
895            g.setCurveWidth(self.CurveWidth)
896            g.setConvexCurveWidth(self.ConvexCurveWidth)
897            g.setShowDiagonal(self.ShowDiagonal)
898            g.setConvexHullCurveWidth(self.ConvexHullCurveWidth)
899            g.setHullColor(QColor(self.HullColor))
900
901    def removeGraphs(self):
902        for g in self.graphs:
903            g.removeCurves()
904            g.hide()
905
906    def costsChanged(self):
907        if self.targetClass <> None and (len(self.graphs) > 0):
908            self.FPcostList[self.targetClass] = self.FPcost
909            self.FNcostList[self.targetClass] = self.FNcost
910            self.graphs[self.targetClass].costChanged(self.FPcost, self.FNcost)
911
912    def pvaluesUpdated(self):
913        if (self.targetClass == None) or (len(self.graphs) == 0): return
914
915        ## update p values
916        if self.pvalue > self.maxpsum - (len(self.pvalueList) - 1):
917            self.pvalue = self.maxpsum - (len(self.pvalueList) - 1)
918
919        self.pvalueList[self.targetClass] = self.pvalue ## set new value
920        sum = int(statc.sum(self.pvalueList))
921        ## adjust for big changes
922        distrib = []
923        for vi in range(len(self.pvalueList)):
924            if vi == self.targetClass:
925                distrib.append(0.0)
926            else:
927                distrib.append(0.0 if sum - self.pvalue == 0  else self.pvalueList[vi] / float(sum - self.pvalue))
928
929        dif = self.maxpsum - sum
930        for vi in range(len(distrib)):
931            self.pvalueList[vi] += int(float(dif) * distrib[vi])
932            if self.pvalueList[vi] < self.minp:
933                self.pvalueList[vi] = self.minp
934            if self.pvalueList[vi] > self.maxp:
935                self.pvalueList[vi] = self.maxp
936
937        ## small changes
938        dif = self.maxpsum - int(statc.sum(self.pvalueList))
939        while abs(dif) > 0 and len(self.pvalueList) > 1:
940            if dif > 0: vi = self.pvalueList.index(min(self.pvalueList[:self.targetClass] + [self.maxp + 1] + self.pvalueList[self.targetClass+1:]))
941            else: vi = self.pvalueList.index(max(self.pvalueList[:self.targetClass] + [self.minp - 1] + self.pvalueList[self.targetClass+1:]))
942           
943            if dif > 0: self.pvalueList[vi] += 1
944            elif dif < 0: self.pvalueList[vi] -= 1
945
946            if self.pvalueList[vi] < self.minp: self.pvalueList[vi] = self.minp
947            if self.pvalueList[vi] > self.maxp: self.pvalueList[vi] = self.maxp
948            dif = self.maxpsum - int(statc.sum(self.pvalueList))
949
950        ## apply new pvalues
951        for (index, graph) in enumerate(self.graphs):
952            graph.pChanged(float(self.pvalueList[index]) / float(self.maxp))
953
954    def setDefaultPValues(self):
955        if self.defaultPerfLinePValues and self.targetClass != None:
956            self.pvaluesList = [v for v in self.defaultPerfLinePValues]
957            self.pvalue = self.pvaluesList[self.targetClass]
958            self.pvaluesUpdated()
959
960    def test_results(self, dres):
961        self.FPcostList = []
962        self.FNcostList = []
963        self.pvalueList = []
964
965        self.closeContext()
966
967        if not dres:
968            self.targetClass = None
969            self.classCombo.clear()
970            self.testSetsQLB.clear()
971            self.classifiersQLB.clear()
972            self.removeGraphs()
973            self.openContext("", dres)
974            return
975        self.dres = dres
976
977        self.classifiersQLB.clear()
978        self.testSetsQLB.clear()
979        self.removeGraphs()
980        self.classCombo.clear()
981
982        self.defaultPerfLinePValues = []
983        if self.dres <> None:
984            ## classQLB
985            self.numberOfClasses = len(self.dres.classValues)
986            self.graphs = []
987
988            for i in range(self.numberOfClasses):
989                self.FPcostList.append( 500)
990                self.FNcostList.append( 500)
991                graph = singleClassROCgraph(self.mainArea, "", "Predicted class: " + self.dres.classValues[i])
992                self.graphs.append( graph )
993                self.classCombo.addItem(self.dres.classValues[i])
994
995            ## classifiersQLB
996            self.classifierColor = []
997            self.numberOfClassifiers = self.dres.numberOfLearners
998            if self.numberOfClassifiers > 1:
999                allCforHSV = self.numberOfClassifiers - 1
1000            else:
1001                allCforHSV = self.numberOfClassifiers
1002            for i in range(self.numberOfClassifiers):
1003                newColor = QColor()
1004                newColor.setHsv(i*255/allCforHSV, 255, 255)
1005                self.classifierColor.append( newColor )
1006
1007            ## testSetsQLB
1008            self.dresSplitByIterations = orngStat.splitByIterations(self.dres)
1009            self.numberOfIterations = len(self.dresSplitByIterations)
1010
1011            self.calcAllClassGraphs()
1012
1013            ## classifiersQLB
1014            for i in range(self.numberOfClassifiers):
1015                newColor = self.classifierColor[i]
1016                self.classifiersQLB.addItem(QListWidgetItem(ColorPixmap(newColor), self.dres.classifierNames[i]))
1017            self.classifiersQLB.selectAll()
1018
1019            ## testSetsQLB
1020            self.testSetsQLB.addItems([str(i) for i in range(self.numberOfIterations)])
1021            self.testSetsQLB.selectAll()
1022
1023            ## calculate default pvalues
1024            reminder = self.maxp
1025            for f in orngStat.classProbabilitiesFromRes(self.dres):
1026                v = int(round(f * self.maxp))
1027                reminder -= v
1028                if reminder < 0:
1029                    v = v+reminder
1030                self.defaultPerfLinePValues.append(v)
1031                self.pvalueList.append( v)
1032
1033            self.targetClass = 0 ## select first target
1034            self.target()
1035        else:
1036            self.classifierColor = None
1037        self.openContext("", self.dres)
1038        self.performanceTabCosts.setEnabled(self.AveragingMethod == 'merge')
1039        self.setDefaultPValues()
1040
1041if __name__ == "__main__":
1042    a = QApplication(sys.argv)
1043    owdm = OWROC()
1044    owdm.show()
1045    a.exec_()
Note: See TracBrowser for help on using the repository browser.