source: orange/orange/OrangeWidgets/Prototypes/OWPerformanceCurves.py @ 6913:845d764fe147

Revision 6913:845d764fe147, 7.3 KB checked in by janezd <janez.demsar@…>, 4 years ago (diff)
  • added widgets developed within the multilingual43 branch
Line 
1"""<name>Performance Curves</name>
2<description>Model performance at different thresholds</description>
3<icon>icons/PerformanceCurves.png</icon>
4<priority>30</priority>
5<contact>Janez Demsar (janez.demsar@fri.uni-lj.si)</contact>"""
6
7from OWWidget import *
8from OWGUI import *
9from PyQt4.QtGui import *
10from PyQt4.QtCore import *
11from OWDlgs import OWChooseImageSizeDlg
12import sip
13import orngTest
14from OWGraph import *
15
class PerformanceGraph(OWGraph):
    """Graph that forwards mouse clicks and drags to its master as thresholds.

    On a mouse press, and on every mouse move while the button is held, the
    x position of the cursor is converted to a value on the bottom axis and
    handed to ``master.thresholdChanged``.
    """

    def __init__(self, master, *arg):
        """Create the graph.

        master -- object providing ``thresholdChanged(value)``
        arg    -- positional arguments passed on to ``OWGraph.__init__``
        """
        OWGraph.__init__(self, *arg)
        # True between a press and the matching release; read by mouseMoveEvent.
        self.mousePressed = False
        self.master = master

    def mousePressEvent(self, e):
        """Remember that the button is down and report the clicked x value."""
        self.mousePressed = True
        # Express the event position in the canvas' coordinates, then turn
        # its x component into a value on the bottom axis.
        posOnCanvas = self.canvas().mapFrom(self, e.pos())
        xValue = self.invTransform(QwtPlot.xBottom, posOnCanvas.x())
        self.master.thresholdChanged(xValue)

    def mouseReleaseEvent(self, e):
        """Stop treating mouse moves as drags."""
        self.mousePressed = False

    def mouseMoveEvent(self, e):
        """While the button is held, handle a move exactly like a press."""
        if not self.mousePressed:
            return
        self.mousePressEvent(e)
33
# Remove if this widget ever goes multilingual!
def _(x):
    """Identity placeholder for a translation function: return x unchanged."""
    return x
36
class OWPerformanceCurves(OWWidget):
    """Widget that plots performance scores of a model against the threshold.

    For the selected model and target class, the tested examples are sorted
    by the predicted probability of the target class. For every distinct
    probability ``v`` a row of scores is computed under the rule "predict
    positive iff probability >= v". The selected scores are drawn as curves
    over the threshold, and the statistics box shows the confusion counts
    and scores at the current threshold (changed by clicking or dragging in
    the graph).
    """
    settingsList = ["selectedScores", "threshold"]

    def __init__(self, parent=None, signalManager=None, name="Performance Curves"):
        """Set up the state, the controls and the graph."""
        OWWidget.__init__(self, parent, signalManager, name)
        self.inputs=[("Evaluation Results", orngTest.ExperimentResults, self.setTestResults, Default)]
        self.outputs=[]

        self.selectedScores = []
        self.classifiers = []
        self.selectedClassifier = []
        self.targetClass = -1
        self.threshold = 0.5
        self.thresholdCurve = None
        self.statistics = ""

        self.resize(980, 420)
        self.loadSettings()

        # self.scores, self.colors and the first six columns of
        # self.allScores are parallel: index i refers to the same score.
        self.scores = [_('Classification accuracy'), _('Sensitivity (Recall)'), _('Specificity'),
                      _('Positive predictive value (Precision)'), _('Negative predictive value'),
                      _('F-measure')]
        self.colors = [Qt.black, Qt.green, Qt.darkRed,
                       Qt.blue, Qt.red,
                       QColor(255, 128, 0)]
        self.res = None
        self.allScores = None
        # Thresholds (x values) matching the entries of each column in allScores.
        self.probs = []

        OWGUI.listBox(self.controlArea, self, 'selectedClassifier', 'classifiers', box = "Models", callback=self.classifierChanged, selectionMode = QListWidget.SingleSelection)
        self.comTarget = OWGUI.comboBox(self.controlArea, self, 'targetClass', box="Target Class", callback=self.classifierChanged, valueType=0)
        OWGUI.listBox(self.controlArea, self, 'selectedScores', 'scores', box = _("Performance scores"), callback=self.selectionChanged, selectionMode = QListWidget.MultiSelection)

        # Replace the main area's existing layout with a horizontal one.
        sip.delete(self.mainArea.layout())
        self.layout = QHBoxLayout(self.mainArea)

        self.dottedGrayPen = QPen(QBrush(Qt.gray), 1, Qt.DotLine)
        self.graph = graph = PerformanceGraph(self, self.mainArea)
        graph.state = NOTHING
        graph.setAxisScale(QwtPlot.xBottom, 0.0, 1.0, 0.0)
        graph.setAxisScale(QwtPlot.yLeft, 0.0, 1.0, 0.0)
        graph.useAntialiasing = True
        graph.insertLegend(QwtLegend(), QwtPlot.BottomLegend)
        graph.gridCurve.enableY(True)
        graph.gridCurve.setMajPen(self.dottedGrayPen)
        graph.gridCurve.attach(graph)
        self.mainArea.layout().addWidget(graph)

        b1 = OWGUI.widgetBox(self.mainArea, "Statistics")
        OWGUI.label(b1, self, "%(statistics)s").setTextFormat(Qt.RichText)
        OWGUI.rubber(b1)

        self.controlArea.setFixedWidth(220)

    def _clearResults(self):
        """Forget all computed scores and blank the graph and statistics."""
        self.graph.clear()
        self.graph.replot()
        self.thresholdCurve = None
        self.allScores = None
        self.probs = []
        self.statistics = ""

    def _scoreRow(self, TP, TN, FP, FN):
        """Return one row of scores and counts for a confusion matrix.

        The row holds, in this order: classification accuracy, sensitivity,
        specificity, positive predictive value, negative predictive value
        and F-measure (the order of self.scores), followed by TP, TN, FP,
        FN and the number of negative predictions. An empty denominator is
        replaced by 1, so the affected score comes out as 0 instead of
        raising ZeroDivisionError.
        """
        TP, TN, FP, FN = float(TP), float(TN), float(FP), float(FN)
        actualPos = TP + FN
        actualNeg = TN + FP
        predictedPos = TP + FP
        predictedNeg = TN + FN
        total = actualPos + actualNeg
        return ((TP + TN) / (total or 1),
                TP / (actualPos or 1),
                TN / (actualNeg or 1),
                TP / (predictedPos or 1),
                TN / (predictedNeg or 1),
                2 * TP / ((actualPos + predictedPos) or 1),
                TP, TN, FP, FN, predictedNeg)

    def setTestResults(self, res):
        """Input handler: store the new results and refresh the widget.

        Results without classifier names (or no results at all) clear the
        graph and the statistics.
        """
        self.res = res
        if res and res.classifierNames:
            self.classifiers = res.classifierNames
            self.selectedClassifier = [0]
            self.comTarget.clear()
            self.comTarget.addItems(self.res.classValues)
            # Prefer the second class value as the target, but never point
            # past the last one when there is only a single class value.
            self.targetClass = min(1, len(self.res.classValues) - 1)
            self.classifierChanged()
        else:
            self._clearResults()

    def classifierChanged(self):
        """Recompute all score curves for the selected model and target class.

        Fills self.probs with the distinct predicted probabilities in
        ascending order (plus 1.0 when the largest probability is below it)
        and self.allScores with the columns of the matching score rows; see
        _scoreRow for the column layout. The row stored for probability v
        describes the rule "positive iff probability >= v", which is the rule
        thresholdChanged relies on when it looks up the current threshold.
        """
        if not self.res or not self.selectedClassifier or not self.res.results:
            self._clearResults()
            return
        classNo = self.selectedClassifier[0]
        target = self.targetClass
        probsClasses = sorted((tex.probabilities[classNo][target], target == tex.actualClass) for tex in self.res.results)
        self.all = len(probsClasses)
        self.P = float(sum(x[1] for x in probsClasses))
        self.N = self.all - self.P

        # Start with every example predicted positive, then move examples to
        # the negative side in order of increasing probability.
        TP, FP = self.P, self.N
        TN = FN = 0.
        rows = []
        self.probs = []
        for prob, kls in probsClasses:
            if not self.probs or prob != self.probs[-1]:
                # Record the row BEFORE moving any example with this
                # probability: only examples below `prob` are negative.
                rows.append(self._scoreRow(TP, TN, FP, FN))
                self.probs.append(prob)
            if kls:
                TP -= 1
                FN += 1
            else:
                FP -= 1
                TN += 1
        if self.probs[-1] < 1.0:
            # Thresholds above the largest predicted probability classify
            # everything as negative.
            rows.append(self._scoreRow(TP, TN, FP, FN))
            self.probs.append(1.0)
        # Transpose rows into columns; list() keeps the result indexable.
        self.allScores = list(zip(*rows))
        self.selectionChanged()

    def selectionChanged(self):
        """Redraw the curves of the currently selected scores."""
        self.graph.clear()
        self.thresholdCurve = None
        if not self.allScores:
            return
        for c in self.selectedScores:
            self.graph.addCurve(self.scores[c], self.colors[c], self.colors[c], 1, xData=self.probs, yData=self.allScores[c], style = QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, lineWidth=3, enableLegend=1)
        self.thresholdChanged()
        # self.graph.replot is called in thresholdChanged

    def thresholdChanged(self, threshold=None):
        """Move the threshold line and rebuild the statistics text.

        threshold -- new threshold; None keeps the current one and only
                     refreshes the line and the statistics
        """
        if threshold is not None:
            self.threshold = threshold
        if self.thresholdCurve:
            self.thresholdCurve.detach()
        self.thresholdCurve = self.graph.addCurve("threshold", Qt.black, Qt.black, 1, xData=[self.threshold]*2, yData=[0,1], style=QwtPlotCurve.Lines, symbol = QwtSymbol.NoSymbol, lineWidth=1)
        self.graph.replot()
        if not self.allScores:
            self.statistics = ""
            return
        # Index of the first stored probability that is >= the threshold,
        # clamped to the last row.
        ind = 0
        while ind+1 < len(self.probs) and self.probs[ind] < self.threshold:
            ind += 1
        alls = self.allScores
        # The last five columns are TP, TN, FP, FN and the negative predictions.
        TP, TN, FP, FN, negatives = [column[ind] for column in alls[-5:]]
        stat = "<b>Sample size: %i instances</b><br/>  Positive: %i<br/>  Negative: %i<br/><br/>" % (self.all, self.P, self.N)
        stat += "<b>Current threshold: %.2f</b><br/><br/>" % self.threshold
        stat += "<b>Positive predictions: %i</b><br/>  True positive: %i<br/>  False positive: %i<br/><br/>" % (self.all-negatives, TP, FP)
        stat += "<b>Negative predictions: %i</b><br/>  True negative: %i<br/>  False negative: %i<br/><br/>" % (negatives, TN, FN)
        if self.selectedScores:
            stat += "<b>Performance</b><br/>"
        stat += "<br/>".join("%s: %.2f" % (self.scores[i], alls[i][ind]) for i in self.selectedScores)
        self.statistics = stat

    def sendReport(self):
        """Report the chosen model and target class, the graph and the statistics."""
        # Nothing to report unless scores were actually computed.
        if not self.res or not self.allScores or not self.selectedClassifier:
            return
        self.reportSettings(_("Performance Curves"),
                            [(_("Model"), self.res.classifierNames[self.selectedClassifier[0]]),
                             (_("Target class"), self.res.classValues[self.targetClass])])
        self.reportImage(self.graph.saveToFileDirect, QSize(790, 390))
        self.reportSection("Performance")
        self.reportRaw(self.statistics)
Note: See TracBrowser for help on using the repository browser.