source: orange/Orange/doc/extend-widgets/OWLearningCurveC.py @ 10997:e82a939c44ce

Revision 10997:e82a939c44ce, 11.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 18 months ago (diff)

Updated channel specification flags documentation in 'extend-widgets'.

Line 
1"""
2<name>Learning Curve (C)</name>
3<description>Takes a data set and a set of learners and plots a learning curve in a table</description>
4<icon>icons/LearningCurveC.png</icon>
5<priority>1020</priority>
6"""
7
8from OWWidget import *
9from OWColorPalette import ColorPixmap
10import OWGUI, orngTest, orngStat
11from OWGraph import *
12
13import warnings
14
class OWLearningCurveC(OWWidget):
    """Learning curve widget with a graph and a table view.

    The widget receives a data set on the "Data" input and any number of
    learners on the "Learner" input (one per source id).  Learners are
    evaluated with orngTest.learningCurveN on increasing proportions of the
    data, and the selected evaluation score is shown in a plot and a table.

    The widget stores these extra attributes on every learner object:

    - isSelected: the learner's curve is shown in the graph
    - color:      color of the learner (assigned in updatellb)
    - score:      list of scores, one per learning curve point
    - curve:      the plot curve (assigned in drawLearningCurve)
    """

    settingsList = ["folds", "steps", "scoringF", "commitOnChange",
                    "graphPointSize", "graphDrawLines", "graphShowGrid"]

    def __init__(self, parent=None, signalManager=None):
        """Set up the default settings, the input channels and the GUI."""
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveC')

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Learner", orange.Learner, self.learner, Multiple)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # scoring function
        self.commitOnChange = 1 # compute curve on any change of parameters
        self.graphPointSize = 5 # size of points in the graphs
        self.graphDrawLines = 1 # draw lines between points in the graph
        self.graphShowGrid = 1  # show gridlines in the graph
        self.selectedLearners = []
        self.loadSettings()

        warnings.filterwarnings("ignore", ".*builtin attribute.*", orange.AttributeWarning)

        # sets self.curvePoints: self.steps equidistant points
        # from 1/self.steps to 1
        self.setCurvePoints()
        self.scoring = [("Classification Accuracy", orngStat.CA), ("AUC", orngStat.AUC), ("BrierScore", orngStat.BrierScore), ("Information Score", orngStat.IS), ("Sensitivity", orngStat.sens), ("Specificity", orngStat.spec)]
        self.learners = [] # list of current learners from input channel, tuples (id, learner)
        self.data = None   # data on which to construct the learning curve
        self.curves = []   # list of evaluation results (one per learning curve point)
        self.scores = []   # list of current scores, one list of scores per learner

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        ## learner selection (list box of learners whose curves are shown)
        OWGUI.separator(self.controlArea)
        self.cbox = OWGUI.widgetBox(self.controlArea, "Learners")
        self.llb = OWGUI.listBox(self.cbox, self, "selectedLearners", selectionMode=QListWidget.MultiSelection, callback=self.learnerSelectionChanged)

        self.llb.setMinimumHeight(50)
        self.blockSelectionChanges = 0

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames,
                       callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1,
                   label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        OWGUI.spin(box, self, 'steps', 2, 100, step=1,
                   label='Learning curve points:  ',
                   callback=[self.setCurvePoints, lambda: self.computeCurve(self.commitOnChange)])

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # start of content (right) area
        tabs = OWGUI.tabWidget(self.mainArea)

        # graph widget
        tab = OWGUI.createTabPage(tabs, "Graph")
        self.graph = OWGraph(tab)
        self.graph.setAxisAutoScale(QwtPlot.xBottom)
        self.graph.setAxisAutoScale(QwtPlot.yLeft)
        tab.layout().addWidget(self.graph)
        self.setGraphGrid()

        # table widget
        tab = OWGUI.createTabPage(tabs, "Table")
        self.table = OWGUI.table(tab, selectionMode=QTableWidget.NoSelection)

        self.resize(550,200)

    ##############################################################################
    # slots: handle input signals

    def dataset(self, data):
        """Handle the "Data" input signal.

        With data, the learning curves of all current learners are
        recomputed.  Without data (or with an empty data set), the stored
        data and all results are dropped and the views are cleared.
        """
        if data:
            self.infoa.setText('%d instances in input data set' % len(data))
            self.data = data
            if self.learners:
                # recomputes the scores and refreshes the table and the graph
                self.computeCurve()
            else:
                # results computed on previous data must not be mixed
                # with results computed on the new data
                self.curves = []
                self.scores = []
                self.replotGraph()
        else:
            self.infoa.setText('No data on input.')
            # forget the data, otherwise later learner signals would still
            # be evaluated on a data set that is no longer on the input
            self.data = None
            self.curves = []
            self.scores = []
            self.setTable()
            self.graph.removeDrawingCurves()
            self.graph.replot()
        self._updateCommitButton()

    def learner(self, learner, id=None):
        """Handle the "Learner" input signal.

        learner -- the learner, or None when the source with the given id
                   removes its learner
        id      -- identifier of the signal source
        """
        ids = [x[0] for x in self.learners]
        if not learner: # remove a learner and corresponding results
            if id not in ids:
                return # no such learner, removed before
            indx = ids.index(id)
            # curves and scores are empty until data has been received
            for results in self.curves:
                results.remove(indx)
            if indx < len(self.scores):
                del self.scores[indx]
            del self.learners[indx]
            self.updatellb()
        else:
            if id in ids: # update (already seen a learner from this source)
                indx = ids.index(id)
                learner.isSelected = self.learners[indx][1].isSelected
                self.learners[indx] = (id, learner)
            else: # add new learner
                indx = None
                learner.isSelected = 1
                self.learners.append((id, learner))
            # updatellb assigns learner.color, which is needed for drawing,
            # so it has to run before anything is plotted
            self.updatellb()
            if self.data:
                self._evaluateLearner(learner, indx)
        if len(self.learners):
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        self.setTable()
        # redraw all curves so that the graph agrees with the learner list
        self.replotGraph()

    def _updateCommitButton(self):
        """Enable the apply button only when a curve can be computed."""
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)

    def _evaluateLearner(self, learner, indx=None):
        """Compute the curve of one learner and merge it into the results.

        indx is the position of the learner that is being replaced, or None
        for a learner that was just appended to self.learners.  If the stored
        results do not match the current settings (e.g. the number of curve
        points was changed without applying), everything is recomputed.
        """
        if indx is None:
            expectedScores = len(self.learners) - 1
        else:
            expectedScores = len(self.learners)
        if len(self.curves) != len(self.curvePoints) or len(self.scores) != expectedScores:
            self.computeCurve()
            return

        curve = self.getLearningCurve([learner])
        scorer = self.scoring[self.scoringF][1]
        score = [scorer(x)[0] for x in curve]
        learner.score = score
        if indx is None:
            for results, point in zip(self.curves, curve):
                results.add(point, 0)
            self.scores.append(score)
        else:
            for results, point in zip(self.curves, curve):
                results.add(point, 0, replace=indx)
            self.scores[indx] = score

    ##############################################################################
    # learning curve table, callbacks

    def computeCurve(self, condition=1):
        """Recompute the learning curves of all learners.

        Nothing is done when condition is false, or when there is no data or
        no learner to evaluate.
        """
        if not condition:
            return
        if not self.data or not self.learners:
            return
        learners = [x[1] for x in self.learners]
        self.curves = self.getLearningCurve(learners)
        self.computeScores()

    def computeScores(self):
        """Score the stored evaluation results and refresh table and graph."""
        scorer = self.scoring[self.scoringF][1]
        self.scores = [[] for i in range(len(self.learners))]
        for x in self.curves:
            for (i, s) in enumerate(scorer(x)):
                self.scores[i].append(s)
        for (i, l) in enumerate(self.learners):
            l[1].score = self.scores[i]
        self.setTable()
        self.replotGraph()

    def getLearningCurve(self, learners):
        """Return evaluation results (one per curve point) for the learners."""
        pb = OWGUI.ProgressBar(self, iterations=self.steps*self.folds)
        curve = orngTest.learningCurveN(learners, self.data, folds=self.folds, proportions=self.curvePoints, callback=pb.advance)
        pb.finish()
        return curve

    def setCurvePoints(self):
        """Set self.curvePoints to self.steps proportions, 1/steps .. 1."""
        self.curvePoints = [(x+1.)/self.steps for x in range(self.steps)]

    def setTable(self):
        """Fill the table with the scores, one column per learner.

        The table is left empty when the scores are missing or do not match
        the current learners and curve points.
        """
        self.table.setColumnCount(0)
        npoints = len(self.curvePoints)
        complete = len(self.scores) == len(self.learners) and \
                   all(len(s) == npoints for s in self.scores)
        if not complete:
            self.table.setRowCount(0)
            return

        self.table.setColumnCount(len(self.learners))
        self.table.setRowCount(npoints)

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for i,l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints])

        # set the table contents
        for l in range(len(self.learners)):
            for p in range(npoints):
                OWGUI.tableItem(self.table, p, l, "%7.5f" % self.scores[l][p])

        for i in range(len(self.learners)):
            self.table.setColumnWidth(i, 80)


    # management of learner selection

    def updatellb(self):
        """Rebuild the learner list box and assign a color to each learner."""
        # filling the list box must not be handled as a user selection change
        self.blockSelectionChanges = 1
        self.llb.clear()
        colors = ColorPaletteHSV(len(self.learners))
        for (i,lt) in enumerate(self.learners):
            l = lt[1]
            item = QListWidgetItem(ColorPixmap(colors[i]), l.name)
            self.llb.addItem(item)
            item.setSelected(l.isSelected)
            l.color = colors[i]
        self.blockSelectionChanges = 0

    def learnerSelectionChanged(self):
        """Show or hide curves after the selection in the list box changed."""
        if self.blockSelectionChanges: return
        for (i,lt) in enumerate(self.learners):
            l = lt[1]
            nowSelected = i in self.selectedLearners
            if l.isSelected != nowSelected:
                if l.isSelected: # learner was deselected
                    # a learner that was never drawn has no curve attribute
                    curve = getattr(l, "curve", None)
                    if curve is not None:
                        curve.detach()
                else: # learner was selected
                    self.drawLearningCurve(l)
                self.graph.replot()
            l.isSelected = nowSelected

    # Graph specific methods

    def setGraphGrid(self):
        """Show or hide the grid lines according to self.graphShowGrid."""
        self.graph.enableGridYL(self.graphShowGrid)
        self.graph.enableGridXB(self.graphShowGrid)

    def setGraphStyle(self, learner):
        """Apply line style, point symbol and learner color to its curve."""
        curve = learner.curve
        if self.graphDrawLines:
            curve.setStyle(QwtPlotCurve.Lines)
        else:
            curve.setStyle(QwtPlotCurve.NoCurve)
        curve.setSymbol(QwtSymbol(QwtSymbol.Ellipse, \
          QBrush(QColor(0,0,0)), QPen(QColor(0,0,0)),
          QSize(self.graphPointSize, self.graphPointSize)))
        curve.setPen(QPen(learner.color, 5))

    def drawLearningCurve(self, learner):
        """Add the learner's curve to the graph and store it in learner.curve.

        Nothing is drawn without data, or when the learner has no scores
        matching the current curve points.
        """
        if not self.data: return
        score = getattr(learner, "score", None)
        if score is None or len(score) != len(self.curvePoints):
            return
        curve = self.graph.addCurve(learner.name, xData=self.curvePoints, yData=score, autoScale=True)

        learner.curve = curve
        self.setGraphStyle(learner)
        self.graph.replot()

    def replotGraph(self):
        """Redraw the curves of all selected learners."""
        self.graph.removeDrawingCurves()
        for l in self.learners:
            if l[1].isSelected:
                self.drawLearningCurve(l[1])
        self.graph.replot()
269
270##############################################################################
271# Test the widget, run from prompt
272
if __name__=="__main__":
    # Stand-alone test: feed the widget some learners and the iris data set.
    import sys

    appl = QApplication(sys.argv)
    ow = OWLearningCurveC()
    ow.show()

    # a learner received before the data
    l1 = orange.BayesLearner()
    l1.name = 'Naive Bayes'
    ow.learner(l1, 1)

    data = orange.ExampleTable('../datasets/iris.tab')
    ow.dataset(data)

    # learners received after the data
    l2 = orange.BayesLearner()
    l2.name = 'Naive Bayes (m=10)'
    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    l2.conditionalEstimatorConstructor = orange.ConditionalProbabilityEstimatorConstructor_ByRows(estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
    # l2 was configured but never sent to the widget
    ow.learner(l2, 2)

    l3 = orange.kNNLearner(name="k-NN")
    ow.learner(l3, 3)

    import orngTree
    l4 = orngTree.TreeLearner(minSubset=2)
    l4.name = "Decision Tree"
    ow.learner(l4, 4)

    # uncomment to test the removal of learners
#    ow.learner(None, 1)
#    ow.learner(None, 2)
#    ow.learner(None, 4)

    appl.exec_()
Note: See TracBrowser for help on using the repository browser.