source: orange/Orange/doc/extend-widgets/OWLearningCurveB.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 8.3 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name>Learning Curve (B)</name>
3<description>Takes a data set and a set of learners and shows a learning curve in a table</description>
4<icon>icons/LearningCurveB.png</icon>
5<priority>1010</priority>
6"""
7
8from OWWidget import *
9import OWGUI, orngTest, orngStat
10
class OWLearningCurveB(OWWidget):
    """Tabulate learning curves of several learners.

    Inputs are training data, optional test data and any number of learners
    (each keyed by the ``id`` passed along with it).  For every learner the
    widget evaluates ``self.steps`` training-set proportions and shows the
    score selected in the "Evaluation Scores" combo box in a table with one
    column per learner and one row per learning-curve point.
    """
    settingsList = ["folds", "steps", "scoringF", "commitOnChange"]

    def __init__(self, parent=None, signalManager=None):
        # The caption used to be 'LearningCurveA' (copied from the "A"
        # variant of this widget).
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveB')

        self.inputs = [("Train Data", ExampleTable, self.trainset),
                       ("Test Data", ExampleTable, self.testset, 1, 1),
                       ("Learner", orange.Learner, self.learner, 0)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # scoring function (index into self.scoring)
        self.commitOnChange = 1 # compute curve on any change of parameters
        self.loadSettings()
        self.setCurvePoints() # sets self.curvePoints: self.steps equidistant points from 1/self.steps to 1
        # (name shown in the combo box, orngStat scoring function)
        self.scoring = [("Classification Accuracy", orngStat.CA),
                        ("AUC", orngStat.AUC),
                        ("BrierScore", orngStat.BrierScore),
                        ("Information Score", orngStat.IS),
                        ("Sensitivity", orngStat.sens),
                        ("Specificity", orngStat.spec)]
        self.learners = [] # list of current learners from input channel, tuples (id, learner)
        self.data = None   # data on which to construct the learning curve (None or non-empty)
        self.testdata = None # optional test data
        self.curves = []   # list of evaluation results (one per learning curve point)
        self.scores = []   # one list of per-point scores per learner, in self.learners order

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames, callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1, label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        OWGUI.spin(box, self, 'steps', 2, 100, step=1, label='Learning curve points:  ',
                   callback=[self.setCurvePoints, lambda: self.computeCurve(self.commitOnChange)])

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # table widget
        self.table = OWGUI.table(self.mainArea, selectionMode=QTableWidget.NoSelection)

        self.resize(500,200)

    ##############################################################################
    # slots: handle input signals

    def trainset(self, data):
        """Handle the "Train Data" input; a missing or empty data set clears it."""
        # Results computed on the previous data are no longer valid; without
        # this, a learner arriving later would be merged into stale results.
        self.curves = []
        self.scores = []
        if data:
            self.infoa.setText('%d instances in input data set' % len(data))
            self.data = data
            if self.learners:
                self.computeCurve()
        else:
            self.infoa.setText('No data on input.')
            # Forget the old data too, otherwise it would still be used for
            # evaluating learners that arrive later.
            self.data = None
            self._clearTable()
        self._updateCommitButton()

    def testset(self, testdata):
        """Handle the optional "Test Data" input.

        When test data is present, getLearningCurve evaluates on it instead of
        using cross-validation.
        """
        if not testdata and not self.testdata:
            return # avoid any unnecessary computation
        self.testdata = testdata
        # Results obtained with the previous test data are no longer valid.
        self.curves = []
        self.scores = []
        if self.data is not None and self.learners:
            self.computeCurve()

    def learner(self, learner, id=None):
        """Handle the "Learner" input.

        Learners are keyed by ``id``.  A false ``learner`` removes the learner
        previously received with that ``id``; otherwise the learner is added,
        or replaces the one previously received with the same ``id``.
        """
        ids = [x[0] for x in self.learners]
        if not learner: # remove a learner and corresponding results
            if id not in ids:
                return # no such learner, removed before
            indx = ids.index(id)
            # self.curves and self.scores are empty while there is no data
            for results in self.curves:
                results.remove(indx)
            if indx < len(self.scores):
                del self.scores[indx]
            del self.learners[indx]
        else:
            if id in ids: # update (already seen a learner from this source)
                indx = ids.index(id)
                self.learners[indx] = (id, learner)
            else: # add new learner
                indx = None
                self.learners.append((id, learner))
            if self.data is not None:
                if len(self.curves) == self.steps:
                    # stored curves are current: evaluate only this learner
                    self._evaluateLearner(learner, indx)
                else:
                    # no curves yet, or they were computed for a different
                    # number of points (steps changed while commitOnChange was
                    # off): evaluate all learners from scratch
                    self.computeCurve()
        if self.learners:
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        if self.data is not None:
            self.setTable()

    ##############################################################################
    # learning curve table, callbacks

    def _evaluateLearner(self, learner, indx):
        """Evaluate one learner and merge its results into the stored curves.

        ``indx`` is the position of the learner to replace, or None to append
        a new one.  Assumes self.curves holds one results object per point.
        """
        curve = self.getLearningCurve([learner])
        scoringF = self.scoring[self.scoringF][1]
        score = [scoringF(x)[0] for x in curve]
        if indx is None:
            self.scores.append(score)
            for results, new in zip(self.curves, curve):
                results.add(new, 0)
        else:
            self.scores[indx] = score
            for results, new in zip(self.curves, curve):
                results.add(new, 0, replace=indx)

    def computeCurve(self, condition=1):
        """Recompute the learning curves of all learners and refresh scores.

        Does nothing when ``condition`` is false (the spin boxes pass
        self.commitOnChange) or when there is no training data to evaluate on.
        """
        if not condition or self.data is None:
            return
        learners = [l for _, l in self.learners]
        self.curves = self.getLearningCurve(learners)
        self.computeScores()

    def computeScores(self):
        """Score the stored curves with the selected scoring function.

        Fills self.scores and, when there is data, refreshes the table.
        """
        scoringF = self.scoring[self.scoringF][1]
        self.scores = [[] for _ in self.learners]
        for results in self.curves:
            for i, s in enumerate(scoringF(results)):
                self.scores[i].append(s)
        if self.data is not None:
            self.setTable()

    def getLearningCurve(self, learners):
        """Evaluate ``learners`` at every proportion in self.curvePoints.

        Uses cross-validation with self.folds folds, or, when test data is
        present, orngTest.learningCurveWithTestData with times=self.folds.
        Returns the orngTest result: one results object per curve point.
        """
        pb = OWGUI.ProgressBar(self, iterations=self.steps*self.folds)
        try:
            if not self.testdata:
                return orngTest.learningCurveN(learners, self.data, folds=self.folds,
                                               proportions=self.curvePoints, callback=pb.advance)
            return orngTest.learningCurveWithTestData(learners, self.data, self.testdata,
                                                      times=self.folds, proportions=self.curvePoints,
                                                      callback=pb.advance)
        finally:
            # close the progress bar even if evaluation fails
            pb.finish()

    def setCurvePoints(self):
        """Set self.curvePoints to self.steps equidistant proportions 1/steps .. 1."""
        self.curvePoints = [(x+1.)/self.steps for x in range(self.steps)]

    def _clearTable(self):
        """Remove all rows and columns from the table."""
        self.table.setColumnCount(0)
        self.table.setRowCount(0)

    def setTable(self):
        """Show self.scores: one column per learner, one row per curve point."""
        self.table.setColumnCount(0) # clear the previous contents
        self.table.setColumnCount(len(self.learners))
        self.table.setRowCount(self.steps)

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for _, l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curvePoints])

        # set the table contents; iterating over the scores themselves leaves
        # cells empty instead of raising IndexError when scores are missing.
        # NOTE(review): if steps was changed while commitOnChange is off, the
        # stored scores still belong to the old curve points, so the row
        # labels do not match them until the curve is recomputed.
        for col, learnerScores in enumerate(self.scores):
            for row, score in enumerate(learnerScores):
                OWGUI.tableItem(self.table, row, col, "%7.5f" % score)

        for col in range(len(self.learners)):
            self.table.setColumnWidth(col, 80)

    def _updateCommitButton(self):
        """Enable "Apply Setting" only when there are both data and learners."""
        self.commitBtn.setEnabled(self.data is not None and bool(self.learners))

##############################################################################
# Manual test: run this file directly to try the widget stand-alone

if __name__ == "__main__":
    app = QApplication(sys.argv)
    widget = OWLearningCurveB()
    widget.show()

    # the first learner is sent before there is any data
    bayes = orange.BayesLearner()
    bayes.name = 'Naive Bayes'
    widget.learner(bayes, 1)

    # split iris 70:30 into training and test data
    iris = orange.ExampleTable('../datasets/iris.tab')
    split = orange.MakeRandomIndices2(iris, p0=0.7)
    train, test = iris.select(split, 0), iris.select(split, 1)

    widget.trainset(train)
    widget.testset(test)

    # further learners are sent after the data
    bayesM = orange.BayesLearner()
    bayesM.name = 'Naive Bayes (m=10)'
    bayesM.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    bayesM.conditionalEstimatorConstructor = orange.ConditionalProbabilityEstimatorConstructor_ByRows(
        estimatorConstructor=orange.ProbabilityEstimatorConstructor_m(m=10))
    widget.learner(bayesM, 2)

    import orngTree
    tree = orngTree.TreeLearner(minSubset=2)
    tree.name = "Decision Tree"
    widget.learner(tree, 4)

    # to try learner removal, send None on a learner's id, e.g.:
    #   widget.learner(None, 1)
    #   widget.learner(None, 2)
    #   widget.learner(None, 4)

    app.exec_()
Note: See TracBrowser for help on using the repository browser.