source: orange/docs/extend-widgets/rst/OWLearningCurveB.py @ 11049:f4dd8dbc57bb

Revision 11049:f4dd8dbc57bb, 8.3 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 16 months ago (diff)

From HTML to Sphinx.

Line 
1"""
2<name>Learning Curve (B)</name>
3<description>Takes a data set and a set of learners and shows a learning curve in a table</description>
4<icon>icons/LearningCurveB.png</icon>
5<priority>1010</priority>
6"""
7
8from OWWidget import *
9import OWGUI, orngTest, orngStat
10
class OWLearningCurveB(OWWidget):
    """Learning-curve widget that can also evaluate on separate test data.

    Receives training data, optional test data and any number of learners,
    and shows a table with one score per learner for each of `steps`
    training-set proportions (1/steps, 2/steps, ..., 1).  Without test data
    the curve comes from orngTest.learningCurveN with `folds` folds; with
    test data it comes from orngTest.learningCurveWithTestData, with `folds`
    passed as its `times` argument.
    """
    settingsList = ["folds", "steps", "scoringF", "commitOnChange"]

    def __init__(self, parent=None, signalManager=None):
        # Widget title; this used to read 'LearningCurveA', presumably a
        # leftover from copying the "A" version of this widget.
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveB')

        self.inputs = [("Train Data", ExampleTable, self.trainset, Default),
                       ("Test Data", ExampleTable, self.testset),
                       ("Learner", orange.Learner, self.learner, Multiple)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # scoring function (index into self.scoring)
        self.commitOnChange = 1 # compute curve on any change of parameters
        self.loadSettings()
        self.setCurvePoints() # sets self.curvePoints, self.steps equidistant points from 1/self.steps to 1
        self.scoring = [("Classification Accuracy", orngStat.CA), ("AUC", orngStat.AUC), ("BrierScore", orngStat.BrierScore), ("Information Score", orngStat.IS), ("Sensitivity", orngStat.sens), ("Specificity", orngStat.spec)]
        self.learners = [] # list of current learners from input channel, tuples (id, learner)
        self.data = None   # data on which to construct the learning curve
        self.testdata = None # optional test data
        self.curves = []   # list of evaluation results (one per learning curve point)
        self.scores = []   # list of current scores, one list of per-point scores per learner

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames, callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1, label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        OWGUI.spin(box, self, 'steps', 2, 100, step=1, label='Learning curve points:  ',
                   callback=[self.setCurvePoints, lambda: self.computeCurve(self.commitOnChange)])

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # table widget
        self.table = OWGUI.table(self.mainArea, selectionMode=QTableWidget.NoSelection)

        self.resize(500,200)

    ##############################################################################
    # slots: handle input signals

    def trainset(self, data):
        """Handle the "Train Data" input; None (or an empty table) clears it."""
        # Results computed for the previous data are invalid either way.
        self.curves = []
        self.scores = []
        if data:
            self.data = data
            self.infoa.setText('%d instances in input data set' % len(data))
        else:
            # Forget the old data so later learner/test-data signals and
            # option changes do not keep evaluating on it.
            self.data = None
            self.infoa.setText('No data on input.')
        if self.data is not None and self.learners:
            self.computeCurve()
        else:
            self.setTable()  # drop scores that belonged to the old data
        self._updateCommitButton()

    def testset(self, testdata):
        """Handle the optional "Test Data" input; None reverts to cross validation."""
        if not testdata and not self.testdata:
            return # avoid any unnecessary computation
        self.testdata = testdata
        self.computeCurve()  # no-op unless both training data and learners are present

    def learner(self, learner, id=None):
        """Handle the multiple "Learner" input.

        `id` identifies the sending channel.  A false `learner` (e.g. None)
        removes the learner previously received on that channel; otherwise
        the learner is added, or replaces the one already received from `id`.
        """
        ids = [x[0] for x in self.learners]
        if not learner: # remove a learner and corresponding results
            if id not in ids:
                return # no such learner, removed before
            indx = ids.index(id)
            del self.learners[indx]
            if self.learners:
                # Drop this learner's results from every curve point (the
                # loop is empty when nothing has been computed yet).
                for res in self.curves:
                    res.remove(indx)
                if self.scores:
                    del self.scores[indx]
            else:
                # Nothing left to evaluate; start afresh with the next learner.
                self.curves = []
                self.scores = []
        else:
            if id in ids: # update (already seen a learner from this source)
                indx = ids.index(id)
                self.learners[indx] = (id, learner)
            else: # add new learner
                indx = None
                self.learners.append((id, learner))
            if self.data is not None:
                self._evaluateLearner(learner, indx)

        if self.learners:
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        self.setTable()

    def _evaluateLearner(self, learner, indx):
        """Merge results for one new (indx is None) or replaced learner into the curve."""
        if not self.curves or len(self.curves) != len(self.curvePoints):
            # No curve yet, or the number of points was changed without the
            # curve being recomputed: an incremental merge would misalign,
            # so recompute the whole curve for all learners.
            self.computeCurve()
            return
        curve = self.getLearningCurve([learner])
        scoreF = self.scoring[self.scoringF][1]
        score = [scoreF(res)[0] for res in curve]
        if indx is None:
            self.scores.append(score)
            for res, new in zip(self.curves, curve):
                res.add(new, 0)
        else:
            self.scores[indx] = score
            for res, new in zip(self.curves, curve):
                res.add(new, 0, replace=indx)

    def _updateCommitButton(self):
        """Enable "Apply Setting" only when there is something to compute."""
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)

    ##############################################################################
    # learning curve table, callbacks

    def computeCurve(self, condition=1):
        """Recompute the learning curve for all learners if `condition` is true.

        Does nothing without training data or learners, so option changes
        made before the inputs arrive are safe.
        """
        if not condition or self.data is None or not self.learners:
            return
        learners = [lrn for (lid, lrn) in self.learners]
        self.curves = self.getLearningCurve(learners)
        self.computeScores()

    def computeScores(self):
        """Score every curve point with the selected scoring function and refresh the table."""
        scoreF = self.scoring[self.scoringF][1]
        self.scores = [[] for lrn in self.learners]
        for res in self.curves:
            for (i, s) in enumerate(scoreF(res)):
                self.scores[i].append(s)
        self.setTable()

    def getLearningCurve(self, learners):
        """Evaluate `learners` at each proportion in self.curvePoints.

        Uses cross validation on self.data, or self.testdata when it is set.
        Returns the evaluation results, one entry per curve point.
        """
        pb = OWGUI.ProgressBar(self, iterations=self.steps*self.folds)
        try:
            if not self.testdata:
                curve = orngTest.learningCurveN(learners, self.data, folds=self.folds,
                                                proportions=self.curvePoints, callback=pb.advance)
            else:
                curve = orngTest.learningCurveWithTestData(learners, self.data, self.testdata,
                                                           times=self.folds, proportions=self.curvePoints,
                                                           callback=pb.advance)
        finally:
            pb.finish()  # close the progress bar even if evaluation fails
        return curve

    @staticmethod
    def _proportions(steps):
        """Return `steps` equidistant proportions from 1/steps to 1 (empty for 0)."""
        return [(x+1.)/steps for x in range(steps)]

    def setCurvePoints(self):
        """Set self.curvePoints to self.steps equidistant proportions from 1/self.steps to 1."""
        self.curvePoints = self._proportions(self.steps)

    def setTable(self):
        """Show self.scores: one column per learner, one row per curve point."""
        # Rows follow the curve actually held in self.curves, which can lag
        # behind self.steps when "Apply setting on any change" is off.
        points = self._proportions(len(self.curves))
        self.table.setColumnCount(0)  # presumably clears items of the previous table
        self.table.setColumnCount(len(self.learners))
        self.table.setRowCount(len(points))

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for i,l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in points])

        # set the table contents
        for col, learnerScores in enumerate(self.scores):
            for row, score in enumerate(learnerScores):
                OWGUI.tableItem(self.table, row, col, "%7.5f" % score)

        for col in range(len(self.learners)):
            self.table.setColumnWidth(col, 80)
168
##############################################################################
# Test the widget, run from prompt

if __name__=="__main__":
    app = QApplication(sys.argv)
    widget = OWLearningCurveB()
    widget.show()

    def named(learnerObj, title):
        """Set the learner's display name (shown as its table column header) and return it."""
        learnerObj.name = title
        return learnerObj

    # This learner arrives before any data, so the widget only remembers it.
    widget.learner(named(orange.BayesLearner(), 'Naive Bayes'), 1)

    # Split iris into two parts with MakeRandomIndices2 (p0=0.7, presumably
    # ~70% in part 0): part 0 trains, part 1 tests.
    iris = orange.ExampleTable('../datasets/iris.tab')
    split = orange.MakeRandomIndices2(iris, p0=0.7)
    trainPart, testPart = iris.select(split, 0), iris.select(split, 1)

    widget.trainset(trainPart)   # curve by cross validation on the training part
    widget.testset(testPart)     # ... then recomputed against the test part

    # Naive Bayes with m-estimates (m=10) for both class and conditional probabilities.
    mBayes = named(orange.BayesLearner(), 'Naive Bayes (m=10)')
    mBayes.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    mBayes.conditionalEstimatorConstructor = \
        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
            estimatorConstructor=orange.ProbabilityEstimatorConstructor_m(m=10))
    widget.learner(mBayes, 2)

    import orngTree
    widget.learner(named(orngTree.TreeLearner(minSubset=2), "Decision Tree"), 4)

    # To exercise learner removal, send None on the known channel ids:
    # widget.learner(None, 1)
    # widget.learner(None, 2)
    # widget.learner(None, 4)

    app.exec_()
Note: See TracBrowser for help on using the repository browser.