source: orange/Orange/doc/extend-widgets/OWLearningCurve_plot.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 11.7 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name>Learning Curve (C)</name>
3<description>Takes a data set and a set of learners and plots a learning curve in a table</description>
4<icon>icons/LearningCurveC.png</icon>
5<priority>1020</priority>
6"""
7
8from OWWidget import *
9from OWColorPalette import ColorPixmap
10import OWGUI, orngTest, orngStat
11from owplot import *
12
13import warnings
14
class OWLearningCurveC(OWWidget):
    """Plot and tabulate learning curves of several learners on one data set.

    Every learner on the "Learner" input is cross-validated (``self.folds``
    folds) on ``self.steps`` growing proportions of the "Data" input
    (``self.curvePoints``).  The score chosen in the "Evaluation Scores" combo
    is drawn per proportion in the "Graph" tab and listed in the "Table" tab.

    Per-learner bookkeeping is stored as attributes on the learner objects:
    ``isSelected`` (curve shown in the graph), ``score`` (list of scores, one
    per curve point), ``color`` (assigned by ``updatellb``) and ``curve``
    (the plotted graph item, or None when nothing is drawn).
    """
    settingsList = ["folds", "steps", "scoringF", "commitOnChange",
                    "graphPointSize", "graphDrawLines", "graphShowGrid"]

    def __init__(self, parent=None, signalManager=None):
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveC')

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Learner", orange.Learner, self.learner, 0)]

        self.folds = 5           # cross validation folds
        self.steps = 10          # points in the learning curve
        self.scoringF = 0        # index into self.scoring
        self.commitOnChange = 1  # compute curve on any change of parameters
        self.graphPointSize = 5  # size of points in the graphs
        self.graphDrawLines = 1  # draw lines between points in the graph
        self.graphShowGrid = 1   # show gridlines in the graph
        self.selectedLearners = []
        self.loadSettings()

        warnings.filterwarnings("ignore", ".*builtin attribute.*", orange.AttributeWarning)

        # sets self.curvePoints: self.steps equidistant points from 1/self.steps to 1
        self.setCurvePoints()
        self.scoring = [("Classification Accuracy", orngStat.CA),
                        ("AUC", orngStat.AUC),
                        ("BrierScore", orngStat.BrierScore),
                        ("Information Score", orngStat.IS),
                        ("Sensitivity", orngStat.sens),
                        ("Specificity", orngStat.spec)]
        self.learners = []  # current learners from the input channel, tuples (id, learner)
        self.data = None    # data on which to construct the learning curve
        self.curves = []    # evaluation results, one per learning curve point
        self.scores = []    # one list of scores per learner, same order as self.learners

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        ## learner selection
        OWGUI.separator(self.controlArea)
        self.cbox = OWGUI.widgetBox(self.controlArea, "Learners")
        self.llb = OWGUI.listBox(self.cbox, self, "selectedLearners",
                                 selectionMode=QListWidget.MultiSelection,
                                 callback=self.learnerSelectionChanged)
        self.llb.setMinimumHeight(50)
        self.blockSelectionChanges = 0

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames,
                       callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1,
                   label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        # computeCurve refreshes self.curvePoints itself. With commitOnChange
        # off, the points therefore stay consistent with the already computed
        # results until "Apply Setting" is pressed.
        OWGUI.spin(box, self, 'steps', 2, 100, step=1,
                   label='Learning curve points:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # start of content (right) area
        tabs = OWGUI.tabWidget(self.mainArea)

        # graph widget
        tab = OWGUI.createTabPage(tabs, "Graph")
        self.graph = OWPlot(tab)
        self.graph.set_axis_autoscale(xBottom)
        self.graph.set_axis_autoscale(yLeft)
        tab.layout().addWidget(self.graph)
        self.setGraphGrid()

        # table widget
        tab = OWGUI.createTabPage(tabs, "Table")
        self.table = OWGUI.table(tab, selectionMode=QTableWidget.NoSelection)

        self.resize(550, 200)

    ##############################################################################
    # slots: handle input signals

    def dataset(self, data):
        """Handle the "Data" input; a false-y ``data`` means no data."""
        if data:
            self.infoa.setText('%d instances in input data set' % len(data))
            self.data = data
            if self.learners:
                # recomputes curves and scores, refreshes table and graph
                self.computeCurve()
            else:
                # Forget results obtained on earlier data, otherwise the next
                # learner's results would be merged into them.
                self.curves = []
                self.scores = []
                self.setCurvePoints()
                self.replotGraph()
        else:
            self.infoa.setText('No data on input.')
            self.data = None
            self.curves = []
            self.scores = []
            self.graph.clear()
            for _, l in self.learners:
                l.curve = None  # removed from the graph by clear()
            self.graph.replot()
            self.table.setRowCount(0)
            self.table.setColumnCount(0)
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)

    # manage learner signal
    # we use following additional attributes for learner:
    # - isSelected, learner is selected (display the learning curve)
    # - curve, learning curve for the learner (graph item or None)
    # - score, evaluation score for the learner
    # - color, assigned by updatellb
    def learner(self, learner, id=None):
        """Handle the "Learner" input.

        A false-y ``learner`` removes the learner previously received with the
        same ``id``.  A learner with a known ``id`` replaces the old one.  Any
        other learner is added.  Curves are computed only when data is present.
        """
        ids = [x[0] for x in self.learners]
        if not learner:  # remove a learner and corresponding results
            if id not in ids:
                return  # no such learner, removed before
            indx = ids.index(id)
            for results in self.curves:
                results.remove(indx)
            if indx < len(self.scores):  # scores exist only once data was seen
                del self.scores[indx]
            self._detachCurve(self.learners[indx][1])
            del self.learners[indx]
            self.updatellb()
            if self.data:
                self.replotGraph()  # updatellb reassigned the colours
        elif id in ids:  # update (already seen a learner from this source)
            indx = ids.index(id)
            prevLearner = self.learners[indx][1]
            learner.isSelected = prevLearner.isSelected
            self.learners[indx] = (id, learner)
            self._detachCurve(prevLearner)
            # must run before drawing: it assigns learner.color
            self.updatellb()
            if self.data:
                curve = self.getLearningCurve([learner])
                score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
                self.scores[indx] = score
                for results, newResults in zip(self.curves, curve):
                    results.add(newResults, 0, replace=indx)
                learner.score = score
                if learner.isSelected:
                    self.drawLearningCurve(learner)
        else:  # add new learner
            learner.isSelected = 1
            self.learners.append((id, learner))
            if self.data:
                if not self.curves:
                    # nothing computed yet, so pick up the current steps value
                    self.setCurvePoints()
                curve = self.getLearningCurve([learner])
                score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
                self.scores.append(score)
                if self.curves:
                    for results, newResults in zip(self.curves, curve):
                        results.add(newResults, 0)
                else:
                    self.curves = curve
                learner.score = score
            self.updatellb()
            self.drawLearningCurve(learner)
        if self.learners:
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)
        if self.data:
            self.setTable()

    ##############################################################################
    # learning curve table, callbacks

    def computeCurve(self, condition=1):
        """Recompute the learning curves of all learners if ``condition`` is true.

        Also refreshes ``self.curvePoints`` from ``self.steps``.  Without data
        there is nothing to evaluate and only the points are refreshed.
        """
        if not condition:
            return
        self.setCurvePoints()
        if not self.data:
            return
        learners = [l for _, l in self.learners]
        self.curves = self.getLearningCurve(learners)
        self.computeScores()

    def computeScores(self):
        """Score ``self.curves`` with the selected scoring function and refresh the views."""
        if not self.curves:
            return  # nothing evaluated yet (no data)
        scorer = self.scoring[self.scoringF][1]
        self.scores = [[] for _ in self.learners]
        for results in self.curves:
            for i, s in enumerate(scorer(results)):
                self.scores[i].append(s)
        for (_, l), s in zip(self.learners, self.scores):
            l.score = s
        self.setTable()
        self.replotGraph()

    def getLearningCurve(self, learners):
        """Evaluate ``learners`` on ``self.data`` at every point of ``self.curvePoints``."""
        pb = OWGUI.ProgressBar(self, iterations=len(self.curvePoints) * self.folds)
        try:
            curve = orngTest.learningCurveN(learners, self.data, folds=self.folds,
                                            proportions=self.curvePoints,
                                            callback=pb.advance)
        finally:
            pb.finish()  # close the progress bar even if evaluation fails
        return curve

    def setCurvePoints(self):
        """Set ``self.curvePoints`` to ``self.steps`` equidistant proportions ending at 1."""
        self.curvePoints = [(x + 1.) / self.steps for x in range(self.steps)]

    def setTable(self):
        """Fill the table: one column per learner, one row per curve point."""
        self.table.setColumnCount(0)
        self.table.setColumnCount(len(self.learners))
        self.table.setRowCount(len(self.curvePoints))

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for i, l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints])

        # set the table contents; iterate the actual scores so that a length
        # mismatch cannot raise IndexError
        for l, learnerScores in enumerate(self.scores):
            for p, s in enumerate(learnerScores):
                OWGUI.tableItem(self.table, p, l, "%7.5f" % s)

        for i in range(len(self.learners)):
            self.table.setColumnWidth(i, 80)

    # management of learner selection

    def updatellb(self):
        """Rebuild the learner list box and assign each learner its colour."""
        self.blockSelectionChanges = 1
        self.llb.clear()
        colors = ColorPaletteHSV(len(self.learners))
        for i, (_, l) in enumerate(self.learners):
            item = QListWidgetItem(ColorPixmap(colors[i]), l.name)
            self.llb.addItem(item)
            item.setSelected(l.isSelected)
            l.color = colors[i]
        self.blockSelectionChanges = 0

    def learnerSelectionChanged(self):
        """Show or hide curves to match the selection in the list box."""
        if self.blockSelectionChanges:
            return
        for i, (_, l) in enumerate(self.learners):
            selected = i in self.selectedLearners
            if l.isSelected != selected:
                if selected:
                    self.drawLearningCurve(l)
                else:
                    self._detachCurve(l)
                self.graph.replot()
            l.isSelected = selected

    # Graph specific methods

    def _detachCurve(self, learner):
        """Remove the learner's curve from the graph, if it has one."""
        curve = getattr(learner, "curve", None)
        if curve is not None:
            curve.detach()
        learner.curve = None

    def setGraphGrid(self):
        self.graph.enableGridYL(self.graphShowGrid)
        self.graph.enableGridXB(self.graphShowGrid)

    def setGraphStyle(self, learner):
        """Apply line/point style settings and the learner's colour to its curve."""
        curve = learner.curve
        if self.graphDrawLines:
            curve.set_style(OWCurve.LinesPoints)
        else:
            curve.set_style(OWCurve.Points)
        curve.set_symbol(OWPoint.Ellipse)
        curve.set_point_size(self.graphPointSize)
        curve.set_color(self.graph.color(OWPalette.Data))
        curve.set_pen(QPen(learner.color, 5))

    def drawLearningCurve(self, learner):
        """Add the learner's scores as a curve to the graph (no-op without data)."""
        if not self.data:
            return
        curve = self.graph.add_curve(learner.name, xData=self.curvePoints,
                                     yData=learner.score, autoScale=True)
        learner.curve = curve
        self.setGraphStyle(learner)
        self.graph.replot()

    def replotGraph(self):
        """Redraw the curves of all selected learners from scratch."""
        self.graph.clear()
        for _, l in self.learners:
            l.curve = None  # removed from the graph by clear()
            if l.isSelected:
                self.drawLearningCurve(l)
        self.graph.replot()
268
269##############################################################################
270# Test the widget, run from prompt
271
if __name__=="__main__":
    # Manual test: open the widget and feed it the iris data set and a few
    # learners, as the input channels would.
    appl = QApplication(sys.argv)
    ow = OWLearningCurveC()
    ow.show()

    l1 = orange.BayesLearner()
    l1.name = 'Naive Bayes'
    ow.learner(l1, 1)

    data = orange.ExampleTable('../datasets/iris.tab')
    ow.dataset(data)

    l2 = orange.BayesLearner()
    l2.name = 'Naive Bayes (m=10)'
    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    l2.conditionalEstimatorConstructor = orange.ConditionalProbabilityEstimatorConstructor_ByRows(estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
    ow.learner(l2, 2)  # l2 was configured but never sent to the widget

    l3 = orange.kNNLearner(name="k-NN")
    ow.learner(l3, 3)

    import orngTree
    l4 = orngTree.TreeLearner(minSubset=2)
    l4.name = "Decision Tree"
    ow.learner(l4, 4)

    # Uncomment to exercise removal of learners from the input channel:
    # ow.learner(None, 1)
    # ow.learner(None, 2)
    # ow.learner(None, 4)

    appl.exec_()
Note: See TracBrowser for help on using the repository browser.