source: orange/docs/extend-widgets/rst/OWLearningCurveB.py @ 11593:6edc44eb9655

Revision 11593:6edc44eb9655, 8.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 10 months ago (diff)

Updated Widget development tutorial.

Line 
1"""
2<name>Learning Curve (B)</name>
3<description>Takes a data set and a set of learners and shows a learning curve in a table</description>
4<icon>icons/LearningCurve.svg</icon>
5<priority>1010</priority>
6"""
7
8import Orange
9
10from OWWidget import *
11import OWGUI
12
13
14class OWLearningCurveB(OWWidget):
15    settingsList = ["folds", "steps", "scoringF", "commitOnChange"]
16   
17    def __init__(self, parent=None, signalManager=None):
18        OWWidget.__init__(self, parent, signalManager, 'LearningCurveA')
19
20        self.inputs = [("Train Data", Orange.data.Table, self.trainset, Default),
21                       ("Test Data", Orange.data.Table, self.testset),
22                       ("Learner", Orange.classification.Learner,
23                        self.learner, Multiple)]
24       
25        self.folds = 5     # cross validation folds
26        self.steps = 10    # points in the learning curve
27        self.scoringF = 0  # scoring function
28        self.commitOnChange = 1 # compute curve on any change of parameters
29        self.loadSettings()
30        self.setCurvePoints() # sets self.curvePoints, self.steps equidistan points from 1/self.steps to 1
31        self.scoring = [("Classification Accuracy", Orange.evaluation.scoring.CA),
32                        ("AUC", Orange.evaluation.scoring.AUC),
33                        ("BrierScore", Orange.evaluation.scoring.Brier_score),
34                        ("Information Score", Orange.evaluation.scoring.IS),
35                        ("Sensitivity", Orange.evaluation.scoring.Sensitivity),
36                        ("Specificity", Orange.evaluation.scoring.Specificity)]
37        self.learners = [] # list of current learners from input channel, tuples (id, learner)
38        self.data = None   # data on which to construct the learning curve
39        self.testdata = None # optional test data
40        self.curves = []   # list of evaluation results (one per learning curve point)
41        self.scores = []   # list of current scores, learnerID:[learner scores]
42
43        # GUI
44        box = OWGUI.widgetBox(self.controlArea, "Info")
45        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
46        self.infob = OWGUI.widgetLabel(box, 'No learners.')
47
48        OWGUI.separator(self.controlArea)
49
50        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
51        scoringNames = [x[0] for x in self.scoring]
52        OWGUI.comboBox(box, self, "scoringF", items=scoringNames,
53                       callback=self.computeScores)
54
55        OWGUI.separator(self.controlArea)
56
57        box = OWGUI.widgetBox(self.controlArea, "Options")
58        OWGUI.spin(box, self, 'folds', 2, 100, step=1,
59                   label='Cross validation folds:  ',
60                   callback=lambda: self.computeCurve(self.commitOnChange))
61
62        OWGUI.spin(box, self, 'steps', 2, 100, step=1,
63                   label='Learning curve points:  ',
64                   callback=[self.setCurvePoints,
65                             lambda: self.computeCurve(self.commitOnChange)])
66
67        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
68        self.commitBtn = OWGUI.button(box, self, "Apply Setting",
69                                      callback=self.computeCurve, disabled=1)
70
71        # table widget
72        self.table = OWGUI.table(self.mainArea,
73                                 selectionMode=QTableWidget.NoSelection)
74               
75        self.resize(500,200)
76
77    ##############################################################################
78    # slots: handle input signals
79
80    def trainset(self, data):
81        if data:
82            self.infoa.setText('%d instances in input data set' % len(data))
83            self.data = data
84            if (len(self.learners)):
85                self.computeCurve()
86        else:
87            self.infoa.setText('No data on input.')
88            self.curves = []
89            self.scores = []
90        self.commitBtn.setEnabled(self.data is not None)
91
92    def testset(self, testdata):
93        if not testdata and not self.testdata:
94            return # avoid any unnecessary computation
95        self.testdata = testdata
96        if self.data and len(self.learners):
97            self.computeCurve()
98
99    def learner(self, learner, id=None):
100        ids = [x[0] for x in self.learners]
101        if not learner: # remove a learner and corresponding results
102            if not ids.count(id):
103                return # no such learner, removed before
104            indx = ids.index(id)
105            for i in range(self.steps):
106                self.curves[i].remove(indx)
107            del self.scores[indx]
108            del self.learners[indx]
109            self.setTable()
110        else:
111            if ids.count(id): # update (already seen a learner from this source)
112                indx = ids.index(id)
113                self.learners[indx] = (id, learner)
114                if self.data:
115                    curve = self.getLearningCurve([learner])
116                    score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
117                    self.scores[indx] = score
118                    for i in range(self.steps):
119                        self.curves[i].add(curve[i], 0, replace=indx)
120            else: # add new learner
121                self.learners.append((id, learner))
122                if self.data:
123                    curve = self.getLearningCurve([learner])
124                    score = [self.scoring[self.scoringF][1](x)[0] for x in curve]
125                    self.scores.append(score)
126                    if len(self.curves):
127                        for i in range(self.steps):
128                            self.curves[i].add(curve[i], 0)
129                    else:
130                        self.curves = curve
131        if len(self.learners):
132            self.infob.setText("%d learners on input." % len(self.learners))
133        else:
134            self.infob.setText("No learners.")
135        self.commitBtn.setEnabled(len(self.learners))
136##        if len(self.scores):
137        if self.data:
138            self.setTable()
139
140    ##############################################################################   
141    # learning curve table, callbacks
142
143    # recomputes the learning curve
144    def computeCurve(self, condition=1):
145        if condition:
146            learners = [x[1] for x in self.learners]
147            self.curves = self.getLearningCurve(learners)
148            self.computeScores()
149
150    def computeScores(self):           
151        self.scores = [[] for i in range(len(self.learners))]
152        for x in self.curves:
153            for (i,s) in enumerate(self.scoring[self.scoringF][1](x)):
154                self.scores[i].append(s)
155        self.setTable()
156
157    def getLearningCurve(self, learners):   
158        pb = OWGUI.ProgressBar(self, iterations=self.steps*self.folds)
159        if not self.testdata:
160            curve = Orange.evaluation.testing.learning_curve_n(
161                learners, self.data, folds=self.folds,
162                proportions=self.curvePoints,
163                callback=pb.advance)
164        else:
165            curve = Orange.evaluation.testing.learning_curve_with_test_data(
166                learners, self.data, self.testdata, times=self.folds,
167                proportions=self.curvePoints,
168#                callback=pb.advance
169                )
170        pb.finish()
171        return curve
172
173    def setCurvePoints(self):
174        self.curvePoints = [(x + 1.)/self.steps for x in range(self.steps)]
175
176    def setTable(self):
177        self.table.setColumnCount(0)
178        self.table.setColumnCount(len(self.learners))
179        self.table.setRowCount(self.steps)
180
181        # set the headers
182        self.table.setHorizontalHeaderLabels([l.name for i,l in self.learners])
183        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints])
184
185        # set the table contents
186        for l in range(len(self.learners)):
187            for p in range(self.steps):
188                OWGUI.tableItem(self.table, p, l, "%7.5f" % self.scores[l][p])
189
190        for i in range(len(self.learners)):
191            self.table.setColumnWidth(i, 80)
192
193
194if __name__=="__main__":
195    appl = QApplication(sys.argv)
196    ow = OWLearningCurveB()
197    ow.show()
198   
199    l1 = Orange.classification.bayes.NaiveLearner()
200    l1.name = 'Naive Bayes'
201    ow.learner(l1, 1)
202
203    data = Orange.data.Table('iris.tab')
204    indices = Orange.data.sample.SubsetIndices2(data, p0 = 0.7)
205    train = data.select(indices, 0)
206    test = data.select(indices, 1)
207
208    ow.trainset(train)
209    ow.testset(test)
210
211    l2 = Orange.classification.bayes.NaiveLearner()
212    l2.name = 'Naive Bayes (m=10)'
213    l2.estimatorConstructor = Orange.statistics.estimate.M(m=10)
214    l2.conditionalEstimatorConstructor = \
215        Orange.statistics.estimate.ConditionalByRows(
216            estimatorConstructor = Orange.statistics.estimate.M(m=10))
217    ow.learner(l2, 2)
218
219    l4 = Orange.classification.tree.TreeLearner(minSubset=2)
220    l4.name = "Decision Tree"
221    ow.learner(l4, 4)
222
223#    ow.learner(None, 1)
224#    ow.learner(None, 2)
225#    ow.learner(None, 4)
226
227    appl.exec_()
Note: See TracBrowser for help on using the repository browser.