source: orange/docs/extend-widgets/rst/OWLearningCurve_plot.py @ 11881:99bec0d8a70d

Revision 11881:99bec0d8a70d, 11.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 5 weeks ago (diff)

More fixes to widget development manual code snippets.

Line 
1"""
2<name>Learning Curve (C)</name>
3<description>Takes a data set and a set of learners and plots a learning curve in a table</description>
4<icon>icons/LearningCurveC.png</icon>
5<priority>1020</priority>
6"""
7
8from OWWidget import *
9from OWColorPalette import ColorPixmap
10import OWGUI, orngTest, orngStat
11from owplot import *
12
13import warnings
14
class OWLearningCurveC(OWWidget):
    """Learning curve widget: scores learners on growing portions of the data.

    Inputs: "Data" (ExampleTable) and "Learner" (orange.Learner, multiple).
    For the learners on input the widget runs orngTest.learningCurveN with
    ``self.folds`` cross-validation folds at the proportions stored in
    ``self.curvePoints`` and shows the selected score both in a graph tab and
    in a table tab.
    """
    settingsList = ["folds", "steps", "scoringF", "commitOnChange",
                    "graphPointSize", "graphDrawLines", "graphShowGrid"]

    def __init__(self, parent=None, signalManager=None):
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveC')

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Learner", orange.Learner, self.learner, Multiple)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # scoring function (index into self.scoring)
        self.commitOnChange = 1 # compute curve on any change of parameters
        self.graphPointSize = 5 # size of points in the graphs
        self.graphDrawLines = 1 # draw lines between points in the graph
        self.graphShowGrid = 1  # show gridlines in the graph
        self.selectedLearners = [] # indices of learners selected in the list box
        self.loadSettings()

        warnings.filterwarnings("ignore", ".*builtin attribute.*", orange.AttributeWarning)

        # self.curvePoints: self.steps equidistant points from 1/self.steps to 1.
        # It must always describe the *computed* results, so after this it is
        # refreshed only when the curves are recomputed (see computeCurve).
        self.updateCurvePoints()
        self.scoring = [("Classification Accuracy", orngStat.CA), ("AUC", orngStat.AUC), ("BrierScore", orngStat.BrierScore), ("Information Score", orngStat.IS), ("Sensitivity", orngStat.sens), ("Specificity", orngStat.spec)]
        self.learners = [] # list of current learners from input channel, tuples (id, learner)
        self.data = None   # data on which to construct the learning curve
        self.curves = []   # list of evaluation results (one per learning curve point)
        self.scores = []   # one list of scores (one per curve point) for each learner

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        ## class selection (classQLB)
        OWGUI.separator(self.controlArea)
        self.cbox = OWGUI.widgetBox(self.controlArea, "Learners")
        self.llb = OWGUI.listBox(self.cbox, self, "selectedLearners", selectionMode=QListWidget.MultiSelection, callback=self.learnerSelectionChanged)

        self.llb.setMinimumHeight(50)
        self.blockSelectionChanges = 0

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames,
                       callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1,
                   label='Cross validation folds:  ',
                   callback=self._autoCommit)
        # curvePoints are updated by computeCurve itself, so an uncommitted
        # change of 'steps' cannot desynchronize table/graph from the results
        OWGUI.spin(box, self, 'steps', 2, 100, step=1,
                   label='Learning curve points:  ',
                   callback=self._autoCommit)

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # ~start main area tabs~
        # start of content (right) area
        tabs = OWGUI.tabWidget(self.mainArea)

        # graph tab
        tab = OWGUI.createTabPage(tabs, "Graph")
        self.graph = OWPlot(tab)
        self.graph.set_axis_autoscale(xBottom)
        self.graph.set_axis_autoscale(yLeft)
        tab.layout().addWidget(self.graph)
        self.setGraphGrid()

        # table tab
        tab = OWGUI.createTabPage(tabs, "Table")
        self.table = OWGUI.table(tab, selectionMode=QTableWidget.NoSelection)
        # ~end main area tabs~

        self.resize(550,200)

    ##############################################################################
    # slots: handle input signals

    def dataset(self, data):
        """Handle the "Data" input; None means the data was removed.

        Results computed on previous data are always discarded: the curves
        are recomputed for the current learners (or cleared without data).
        """
        self.data = data
        if data is not None:
            self.infoa.setText('%d instances in input data set' % len(data))
        else:
            self.infoa.setText('No data on input.')
        self.computeCurve()
        self._updateCommitButton()

    # manage learner signal
    # we use following additional attributes for learner:
    # - isSelected, learner is selected (display the learning curve)
    # - curve, graph item of the learning curve for the learner (or None)
    # - score, evaluation scores for the learner (one per curve point)
    # - color, colour assigned in updatellb
    def learner(self, learner, id=None):
        """Handle the (multiple) "Learner" input.

        learner -- learner to add or to replace the one previously received
                   from channel *id*; None removes that channel's learner
        id      -- identifier of the sending channel
        """
        ids = [x[0] for x in self.learners]
        if learner is None: # remove a learner and corresponding results
            if id not in ids:
                return # no such learner, removed before
            indx = ids.index(id)
            for res in self.curves:
                res.remove(indx)
            if indx < len(self.scores): # there are no scores without data
                del self.scores[indx]
            self._detachCurve(self.learners[indx][1])
            del self.learners[indx]
        elif id in ids: # update (already seen a learner from this source)
            indx = ids.index(id)
            prevLearner = self.learners[indx][1]
            learner.isSelected = prevLearner.isSelected
            self._detachCurve(prevLearner)
            self.learners[indx] = (id, learner)
            if self.data:
                curve = self.getLearningCurve([learner])
                score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
                self.scores[indx] = score
                for res, newRes in zip(self.curves, curve):
                    res.add(newRes, 0, replace=indx)
                learner.score = score
        else: # add new learner
            learner.isSelected = 1
            self.learners.append((id, learner))
            if self.data:
                curve = self.getLearningCurve([learner])
                score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
                self.scores.append(score)
                if self.curves:
                    for res, newRes in zip(self.curves, curve):
                        res.add(newRes, 0)
                else:
                    self.curves = curve
                learner.score = score

        if self.learners:
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        # colours depend on the number of learners, so assign them first and
        # then redraw all curves so that graph and list box agree
        self.updatellb()
        self.setTable()
        self.replotGraph()

    ##############################################################################
    # learning curve table, callbacks

    def _autoCommit(self):
        """Recompute the curves after a parameter change if auto-apply is on."""
        if self.commitOnChange:
            self.computeCurve()

    def _updateCommitButton(self):
        """Enable "Apply Setting" only when there are both data and learners."""
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)

    # recomputes the learning curve
    def computeCurve(self):
        """(Re)compute learning curves of all learners and refresh the views.

        This is the commit action, so it also applies the current number of
        steps to self.curvePoints. Without data or learners the results are
        cleared instead of computed.
        """
        self.updateCurvePoints()
        learners = [l for _, l in self.learners]
        if self.data and learners:
            self.curves = self.getLearningCurve(learners)
        else:
            self.curves = []
        self.computeScores()

    def computeScores(self):
        """Score the stored results with the selected measure; refresh table and graph."""
        scorer = self.scoring[self.scoringF][1]
        if self.curves and self.learners:
            self.scores = [[] for _ in self.learners]
            for res in self.curves:
                for (i, s) in enumerate(scorer(res)):
                    self.scores[i].append(s)
        else:
            self.scores = [] # nothing has been evaluated (e.g. no data)
        for (i, (_, l)) in enumerate(self.learners):
            l.score = self.scores[i] if i < len(self.scores) else []
        self.setTable()
        self.replotGraph()

    def getLearningCurve(self, learners):
        """Evaluate *learners* on self.data at the proportions in self.curvePoints.

        Returns the result of orngTest.learningCurveN (indexed by curve point).
        """
        pb = OWGUI.ProgressBar(self, iterations=len(self.curvePoints)*self.folds)
        try:
            curve = orngTest.learningCurveN(learners, self.data, folds=self.folds, proportions=self.curvePoints, callback=pb.advance)
        finally:
            pb.finish() # close the progress bar even if learning fails
        return curve

    def updateCurvePoints(self):
        """Set self.curvePoints to self.steps equidistant points from 1/self.steps to 1."""
        self.curvePoints = [(x+1.)/self.steps for x in range(self.steps)]

    def setTable(self):
        """Show scores in the table: one column per learner, one row per curve point.

        The table is left empty when no (consistent) scores are available.
        """
        self.table.setColumnCount(0)
        if not self.curves or len(self.scores) != len(self.learners):
            self.table.setRowCount(0)
            return
        self.table.setColumnCount(len(self.learners))
        self.table.setRowCount(len(self.curvePoints))

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for _, l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints])

        # set the table contents
        for (col, learnerScores) in enumerate(self.scores):
            for (row, score) in enumerate(learnerScores):
                OWGUI.tableItem(self.table, row, col, "%7.5f" % score)
            self.table.setColumnWidth(col, 80)


    # management of learner selection

    def updatellb(self):
        """Rebuild the learner list box and assign a colour to every learner."""
        self.blockSelectionChanges = 1 # ignore selection signals while rebuilding
        self.llb.clear()
        colors = ColorPaletteHSV(len(self.learners))
        for (i, (_, l)) in enumerate(self.learners):
            item = QListWidgetItem(ColorPixmap(colors[i]), l.name)
            self.llb.addItem(item)
            item.setSelected(l.isSelected)
            l.color = colors[i]
        self.blockSelectionChanges = 0

    def _detachCurve(self, learner):
        """Remove *learner*'s curve from the graph if it is currently drawn."""
        curve = getattr(learner, "curve", None)
        if curve is not None:
            curve.detach()
        learner.curve = None

    def learnerSelectionChanged(self):
        """Show/hide curves to match the learners selected in the list box."""
        if self.blockSelectionChanges: return
        for (i, (_, l)) in enumerate(self.learners):
            selected = i in self.selectedLearners
            if l.isSelected != selected:
                if l.isSelected: # learner was deselected
                    self._detachCurve(l)
                else: # learner was selected
                    self.drawLearningCurve(l)
                self.graph.replot()
            l.isSelected = selected

    # Graph specific methods

    def setGraphGrid(self):
        """Show or hide the graph grid lines according to self.graphShowGrid."""
        self.graph.enableGridYL(self.graphShowGrid)
        self.graph.enableGridXB(self.graphShowGrid)

    def setGraphStyle(self, learner):
        """Apply line/point style settings and the learner's colour to its curve."""
        curve = learner.curve
        if self.graphDrawLines:
            curve.set_style(OWCurve.LinesPoints)
        else:
            curve.set_style(OWCurve.Points)
        curve.set_symbol(OWPoint.Ellipse)
        curve.set_point_size(self.graphPointSize)
        curve.set_color(self.graph.color(OWPalette.Data))
        curve.set_pen(QPen(learner.color, 5))

    def drawLearningCurve(self, learner):
        """Add *learner*'s curve to the graph; no-op without data or scores."""
        if not self.data or not getattr(learner, "score", None):
            return
        curve = self.graph.add_curve(
            learner.name,
            xData=self.curvePoints,
            yData=learner.score,
            autoScale=True)

        learner.curve = curve
        self.setGraphStyle(learner)
        self.graph.replot()

    def replotGraph(self):
        """Clear the graph and redraw the curves of the selected learners."""
        self.graph.clear()
        for _, l in self.learners:
            l.curve = None # any previous curve was removed by clear()
            if l.isSelected:
                self.drawLearningCurve(l)
        self.graph.replot()

##############################################################################
# Test the widget, run from prompt

if __name__=="__main__":
    import sys
    appl = QApplication(sys.argv)
    ow = OWLearningCurveC()
    ow.show()

    l1 = orange.BayesLearner()
    l1.name = 'Naive Bayes'
    ow.learner(l1, 1)

    data = orange.ExampleTable('iris.tab')
    ow.dataset(data)

    l2 = orange.BayesLearner()
    l2.name = 'Naive Bayes (m=10)'
    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    l2.conditionalEstimatorConstructor = orange.ConditionalProbabilityEstimatorConstructor_ByRows(estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
    ow.learner(l2, 2) # l2 was previously constructed but never sent to the widget

    l3 = orange.kNNLearner(name="k-NN")
    ow.learner(l3, 3)

    import orngTree
    l4 = orngTree.TreeLearner(minSubset=2)
    l4.name = "Decision Tree"
    ow.learner(l4, 4)

    # Uncomment to exercise removal of learners through the input signal:
    # ow.learner(None, 1)
    # ow.learner(None, 2)
    # ow.learner(None, 4)

    appl.exec_()
Note: See TracBrowser for help on using the repository browser.