source: orange/docs/extend-widgets/rst/OWLearningCurveC.py @ 11593:6edc44eb9655

Revision 11593:6edc44eb9655, 12.4 KB checked in by Ales Erjavec <ales.erjavec@…>, 10 months ago (diff)

Updated Widget development tutorial.

RevLine 
[11049]1"""
2<name>Learning Curve (C)</name>
3<description>Takes a data set and a set of learners and plots a learning curve in a table</description>
4<icon>icons/LearningCurveC.png</icon>
5<priority>1020</priority>
6"""
7
[11593]8import Orange
9
[11049]10from OWWidget import *
11from OWColorPalette import ColorPixmap
12from OWGraph import *
13
[11593]14import OWGUI
15
[11049]16import warnings
17
class OWLearningCurveC(OWWidget):
    """Plot learning curves of the input learners as a graph and a table.

    Every learner arriving on the "Learner" channel is evaluated by cross
    validation on increasing proportions of the "Data" input (the proportions
    are kept in ``self.curvePoints``). The widget stores its bookkeeping
    directly on the learner objects:

    - ``isSelected``: whether the learner's curve is shown in the graph
    - ``score``: list of scores, one per point of the learning curve
    - ``color``: color assigned to the learner in the learners list box
    - ``curve``: the graph curve last drawn for the learner
    """
    settingsList = ["folds", "steps", "scoringF", "commitOnChange",
                    "graphPointSize", "graphDrawLines", "graphShowGrid"]

    def __init__(self, parent=None, signalManager=None):
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveC')

        self.inputs = [("Data", Orange.data.Table, self.dataset),
                       ("Learner", Orange.classification.Learner,
                        self.learner, Multiple)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # scoring function
        self.commitOnChange = 1 # compute curve on any change of parameters
        self.graphPointSize = 5 # size of points in the graphs
        self.graphDrawLines = 1 # draw lines between points in the graph
        self.graphShowGrid = 1  # show gridlines in the graph
        self.selectedLearners = []

        self.loadSettings()

        warnings.filterwarnings("ignore", ".*builtin attribute.*", Orange.core.AttributeWarning)

        # self.curvePoints: proportions (1/steps .. 1) at which the current
        # results in self.curves were computed; refreshed from self.steps
        # only when a new computation starts (see computeCurve)
        self.setCurvePoints()
        self.scoring = [("Classification Accuracy",
                         Orange.evaluation.scoring.CA),
                        ("AUC", Orange.evaluation.scoring.AUC),
                        ("BrierScore", Orange.evaluation.scoring.Brier_score),
                        ("Information Score", Orange.evaluation.scoring.IS),
                        ("Sensitivity", Orange.evaluation.scoring.sens),
                        ("Specificity", Orange.evaluation.scoring.spec)]
        self.learners = [] # list of current learners from input channel, tuples (id, learner)
        self.data = None   # data on which to construct the learning curve
        self.curves = []   # list of evaluation results (one per learning curve point)
        self.scores = []   # list of current scores, one list per learner (same order as self.learners)

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        ## class selection (classQLB)
        OWGUI.separator(self.controlArea)

        # ~SPHINX start color cb~
        self.cbox = OWGUI.widgetBox(self.controlArea, "Learners")
        self.llb = OWGUI.listBox(self.cbox, self, "selectedLearners",
                                 selectionMode=QListWidget.MultiSelection,
                                 callback=self.learnerSelectionChanged)

        self.llb.setMinimumHeight(50)
        self.blockSelectionChanges = 0
        # ~SPHINX end color cb~

        OWGUI.separator(self.controlArea)

        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames,
                       callback=self.computeScores)

        OWGUI.separator(self.controlArea)

        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1,
                   label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        # the new number of points is picked up by computeCurve itself, so an
        # un-applied change never desynchronizes points and computed results
        OWGUI.spin(box, self, 'steps', 2, 100, step=1,
                   label='Learning curve points:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting",
                                      callback=self.computeCurve, disabled=1)

        # ~SPHINX start main area tabs~
        # start of content (right) area
        tabs = OWGUI.tabWidget(self.mainArea)

        # graph tab
        tab = OWGUI.createTabPage(tabs, "Graph")
        self.graph = OWGraph(tab)
        self.graph.setAxisAutoScale(QwtPlot.xBottom)
        self.graph.setAxisAutoScale(QwtPlot.yLeft)
        tab.layout().addWidget(self.graph)
        self.setGraphGrid()

        # table tab
        tab = OWGUI.createTabPage(tabs, "Table")
        self.table = OWGUI.table(tab, selectionMode=QTableWidget.NoSelection)
        # ~SPHINX end main area tabs~

        self.resize(550,200)

    ##############################################################################
    # slots: handle input signals

    def dataset(self, data):
        """Handle the "Data" input; `data` is None (or empty) when removed."""
        if data:
            self.infoa.setText('%d instances in input data set' % len(data))
            self.data = data
            if self.learners:
                self.computeCurve()
            self.replotGraph()
        else:
            self.infoa.setText('No data on input.')
            # forget the old data too, otherwise it keeps being evaluated
            self.data = None
            self.curves = []
            self.scores = []
            self.replotGraph()  # clears the graph (nothing is drawn without data)
            self.setTable()
        self._updateCommitButton()

    # manage learner signal
    # we use following additional attributes for learner:
    # - isSelected, learner is selected (display the learning curve)
    # - curve, learning curve for the learner
    # - score, evaluation score for the learning
    # - color, color of the learner in the list box and graph
    def learner(self, learner, id=None):
        """Handle the "Learner" input.

        A None `learner` removes the learner previously received with `id`;
        otherwise the learner is added, or replaces the one with the same `id`.
        Only the new/changed learner is evaluated when results already exist.
        """
        ids = [x[0] for x in self.learners]
        if not learner: # remove a learner and corresponding results
            if id not in ids:
                return # no such learner, removed before
            indx = ids.index(id)
            # iterate the results actually computed (there may be none yet)
            for res in self.curves:
                res.remove(indx)
            if indx < len(self.scores):
                del self.scores[indx]
            del self.learners[indx]
            if not self.learners:
                # start from a fresh evaluation when the next learner arrives
                self.curves = []
                self.scores = []
            self.updatellb()
        else:
            if id in ids: # update (already seen a learner from this source)
                indx = ids.index(id)
                learner.isSelected = self.learners[indx][1].isSelected
                self.learners[indx] = (id, learner)
                replace = True
            else: # add new learner
                indx = len(self.learners)
                learner.isSelected = 1
                self.learners.append((id, learner))
                replace = False
            # assign colors before anything gets drawn (setGraphStyle needs them)
            self.updatellb()
            if self.data is not None:
                if self.curves:
                    self._evaluateLearner(indx, learner, replace)
                else:
                    # nothing computed yet: evaluate all learners from scratch
                    self.computeCurve()

        if len(self.learners):
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        # redraw everything: removes a deleted learner's curve and applies the
        # colors regenerated by updatellb to all curves
        self.replotGraph()
        self.setTable()

    def _evaluateLearner(self, indx, learner, replace):
        """Evaluate one learner at the current curve points and merge the
        results into self.curves/self.scores at position `indx`
        (replacing the old results if `replace`, appending otherwise)."""
        curve = self.getLearningCurve([learner])
        score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
        learner.score = score
        if replace:
            self.scores[indx] = score
            for res, new in zip(self.curves, curve):
                res.add(new, 0, replace=indx)
        else:
            self.scores.append(score)
            for res, new in zip(self.curves, curve):
                res.add(new, 0)

    def _updateCommitButton(self):
        """Enable "Apply Setting" only when there is something to compute."""
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)

    ##############################################################################
    # learning curve table, callbacks

    # recomputes the learning curve
    def computeCurve(self, condition=1):
        """Recompute the curves of all learners if `condition` is true.

        Applies the current number of steps. Without data or learners the
        results are simply cleared.
        """
        if not condition:
            return
        self.setCurvePoints()  # apply a possibly pending change of self.steps
        if self.data is None or not self.learners:
            self.curves = []
            self.scores = []
        else:
            learners = [x[1] for x in self.learners]
            self.curves = self.getLearningCurve(learners)
        self.computeScores()

    def computeScores(self):
        """Score the current results with the selected scoring function and
        refresh the table and the graph."""
        scoringFunction = self.scoring[self.scoringF][1]
        self.scores = [[] for _ in self.learners]
        for res in self.curves:
            # one score per learner for this curve point
            for i, s in enumerate(scoringFunction(res)):
                self.scores[i].append(s)
        for (_, l), s in zip(self.learners, self.scores):
            l.score = s
        self.setTable()
        self.replotGraph()

    def getLearningCurve(self, learners):
        """Cross-validate `learners` at each proportion in self.curvePoints.

        Returns the result of learning_curve_n: one evaluation result per
        curve point. A progress bar is shown while computing.
        """
        pb = OWGUI.ProgressBar(self, iterations=len(self.curvePoints) * self.folds)
        try:
            return Orange.evaluation.testing.learning_curve_n(
                learners, self.data, folds=self.folds,
                proportions=self.curvePoints,
                callback=pb.advance)
        finally:
            pb.finish()  # close the progress bar even if evaluation fails

    def setCurvePoints(self):
        """Set self.curvePoints to self.steps equidistant points in (0, 1]."""
        self.curvePoints = [(x+1.)/self.steps for x in range(self.steps)]

    def setTable(self):
        """Fill the table: one column per learner, one row per curve point."""
        self.table.setColumnCount(0)  # drop the old items
        self.table.setColumnCount(len(self.learners))
        # only rows for points that were actually evaluated
        nPoints = len(self.curvePoints) if self.curves else 0
        self.table.setRowCount(nPoints)

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for _, l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints[:nPoints]])

        # set the table contents
        for col, learnerScores in enumerate(self.scores):
            for row, s in enumerate(learnerScores[:nPoints]):
                OWGUI.tableItem(self.table, row, col, "%7.5f" % s)

        for i in range(len(self.learners)):
            self.table.setColumnWidth(i, 80)


    # management of learner selection

    def updatellb(self):
        """Rebuild the learners list box and (re)assign a color to each learner."""
        self.blockSelectionChanges = 1  # programmatic changes are not user input
        self.llb.clear()
        colors = ColorPaletteHSV(len(self.learners))
        for i, (_, l) in enumerate(self.learners):
            item = QListWidgetItem(ColorPixmap(colors[i]), l.name)
            self.llb.addItem(item)
            item.setSelected(l.isSelected)
            l.color = colors[i]
        self.blockSelectionChanges = 0

    def learnerSelectionChanged(self):
        """Show/hide curves of learners whose selection changed in the list box."""
        if self.blockSelectionChanges:
            return
        for i, (_, l) in enumerate(self.learners):
            selected = i in self.selectedLearners
            if l.isSelected != selected:
                if l.isSelected: # learner was deselected
                    # a learner has no curve if it was never drawn (no data yet)
                    curve = getattr(l, "curve", None)
                    if curve is not None:
                        curve.detach()
                else: # learner was selected
                    self.drawLearningCurve(l)
                self.graph.replot()
            l.isSelected = selected

    # Graph specific methods

    def setGraphGrid(self):
        """Show or hide the graph grid according to self.graphShowGrid."""
        self.graph.enableGridYL(self.graphShowGrid)
        self.graph.enableGridXB(self.graphShowGrid)

    def setGraphStyle(self, learner):
        """Apply line/symbol style and the learner's color to its curve."""
        curve = learner.curve
        if self.graphDrawLines:
            curve.setStyle(QwtPlotCurve.Lines)
        else:
            curve.setStyle(QwtPlotCurve.NoCurve)

        curve.setSymbol(
            QwtSymbol(QwtSymbol.Ellipse,
                      QBrush(QColor(0,0,0)), QPen(QColor(0,0,0)),
                      QSize(self.graphPointSize, self.graphPointSize)))

        curve.setPen(QPen(learner.color, 5))

    def drawLearningCurve(self, learner):
        """Add the learner's scores as a curve to the graph (needs data and scores)."""
        if self.data is None or not getattr(learner, "score", None):
            return
        curve = self.graph.addCurve(
            learner.name,
            xData=self.curvePoints,
            yData=learner.score,
            autoScale=True)

        learner.curve = curve
        self.setGraphStyle(learner)
        self.graph.replot()

    def replotGraph(self):
        """Redraw the curves of all selected learners."""
        self.graph.removeDrawingCurves()
        for _, l in self.learners:
            if l.isSelected:  # deselected learners stay hidden
                self.drawLearningCurve(l)
        self.graph.replot()
303
304
if __name__ == "__main__":
    # Stand-alone demo: feeds a few learners and the iris data to the widget.
    import sys

    appl = QApplication(sys.argv)
    ow = OWLearningCurveC()
    ow.show()

    l1 = Orange.classification.bayes.NaiveLearner()
    l1.name = 'Naive Bayes'
    ow.learner(l1, 1)

    data = Orange.data.Table('iris.tab')
    ow.dataset(data)

    l2 = Orange.classification.bayes.NaiveLearner()
    l2.name = 'Naive Bayes (m=10)'
    l2.estimatorConstructor = Orange.statistics.estimate.M(m=10)
    l2.conditionalEstimatorConstructor = Orange.statistics.estimate.ConditionalByRows(estimatorConstructor = Orange.statistics.estimate.M(m=10))
    # l2 was configured but never sent to the widget
    ow.learner(l2, 2)

    l3 = Orange.classification.knn.kNNLearner(name="k-NN")
    ow.learner(l3, 3)

    l4 = Orange.classification.tree.TreeLearner(minSubset=2)
    l4.name = "Decision Tree"
    ow.learner(l4, 4)

    # Uncomment to try removing learners (a None learner removes that id):
#    ow.learner(None, 1)
#    ow.learner(None, 2)
#    ow.learner(None, 4)

    appl.exec_()
Note: See TracBrowser for help on using the repository browser.