source: orange/Orange/doc/extend-widgets/OWLearningCurveA.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 7.6 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name>Learning Curve (A)</name>
3<description>Takes a data set and a set of learners and shows a learning curve in a table</description>
4<icon>icons/LearningCurveA.png</icon>
5<priority>1000</priority>
6"""
7
8from OWWidget import *
9import OWGUI, orngTest, orngStat
10
class OWLearningCurveA(OWWidget):
    """Widget that shows learning curves of the input learners in a table.

    The widget receives a data set on the "Data" channel and any number of
    learners on the "Learner" channel.  For every learner it evaluates the
    learner on increasing proportions of the data (one table row per
    proportion, one table column per learner) using
    orngTest.learningCurveN and displays the score selected in the
    "Evaluation Scores" combo box.
    """

    settingsList = ["folds", "steps", "scoringF", "commitOnChange"]

    def __init__(self, parent=None, signalManager=None):
        """Set up default settings, internal state and the GUI."""
        OWWidget.__init__(self, parent, signalManager, 'LearningCurveA')

        self.inputs = [("Data", ExampleTable, self.dataset),
                       ("Learner", orange.Learner, self.learner, 0)]

        self.folds = 5     # cross validation folds
        self.steps = 10    # points in the learning curve
        self.scoringF = 0  # index of the scoring function in self.scoring
        self.commitOnChange = 1  # compute curve on any change of parameters
        self.loadSettings()
        # sets self.curvePoints: self.steps equidistant points from
        # 1/self.steps to 1
        self.setCurvePoints()
        self.scoring = [("Classification Accuracy", orngStat.CA),
                        ("AUC", orngStat.AUC),
                        ("BrierScore", orngStat.BrierScore),
                        ("Information Score", orngStat.IS),
                        ("Sensitivity", orngStat.sens),
                        ("Specificity", orngStat.spec)]
        self.learners = []  # current learners from input channel, tuples (id, learner)
        self.data = None    # data on which to construct the learning curve
        self.curves = []    # list of evaluation results (one per learning curve point)
        self.scores = []    # one list of scores per learner, in self.learners order
        # Proportions for which self.curves / self.scores were computed.  Kept
        # separately from self.curvePoints because the "steps" setting may be
        # changed without recomputing the curves (commitOnChange off).
        self.curveProportions = []

        # GUI
        box = OWGUI.widgetBox(self.controlArea, "Info")
        self.infoa = OWGUI.widgetLabel(box, 'No data on input.')
        self.infob = OWGUI.widgetLabel(box, 'No learners.')

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Evaluation Scores")
        scoringNames = [x[0] for x in self.scoring]
        OWGUI.comboBox(box, self, "scoringF", items=scoringNames, callback=self.computeScores)

        OWGUI.separator(self.controlArea)
        box = OWGUI.widgetBox(self.controlArea, "Options")
        OWGUI.spin(box, self, 'folds', 2, 100, step=1, label='Cross validation folds:  ',
                   callback=lambda: self.computeCurve(self.commitOnChange))
        OWGUI.spin(box, self, 'steps', 2, 100, step=1, label='Learning curve points:  ',
                   callback=[self.setCurvePoints, lambda: self.computeCurve(self.commitOnChange)])

        OWGUI.checkBox(box, self, 'commitOnChange', 'Apply setting on any change')
        self.commitBtn = OWGUI.button(box, self, "Apply Setting", callback=self.computeCurve, disabled=1)

        # table widget
        self.table = OWGUI.table(self.mainArea, selectionMode=QTableWidget.NoSelection)

        self.resize(500, 200)

    ##############################################################################
    # slots: handle input signals

    def dataset(self, data):
        """Handle the "Data" input signal.

        A non-empty data set is stored and, if learners are present, the
        curves are recomputed.  An empty/None input forgets the previous
        data set together with all results computed from it.
        """
        if data:
            self.infoa.setText('%d instances in input data set' % len(data))
            self.data = data
            if self.learners:
                self.computeCurve()
        else:
            self.infoa.setText('No data on input.')
            # Drop the old data as well; otherwise later learner signals
            # would still be evaluated on a data set that is no longer
            # connected.
            self.data = None
            self.curves = []
            self.scores = []
            self.curveProportions = []
            self.setTable()
        self._updateCommitButton()

    def learner(self, learner, id=None):
        """Handle the "Learner" input signal.

        learner -- the learner, or None when the source `id` is disconnected
        id      -- identifier of the signal source; a learner arriving from an
                   already known source replaces the previous one
        """
        ids = [x[0] for x in self.learners]
        if not learner:  # remove a learner and corresponding results
            if id not in ids:
                return  # no such learner, removed before
            indx = ids.index(id)
            # Iterate over what is actually stored: there may be no curves
            # or scores at all if the learner arrived before any data did.
            for results in self.curves:
                results.remove(indx)
            if indx < len(self.scores):
                del self.scores[indx]
            del self.learners[indx]
        else:
            known = id in ids
            if known:  # update (already seen a learner from this source)
                indx = ids.index(id)
                self.learners[indx] = (id, learner)
            else:  # add new learner
                self.learners.append((id, learner))
            if self.data:
                if self._curvesAreCurrent() and len(self.scores) == len(ids):
                    # Stored results match the current settings: evaluate
                    # only this learner and merge it into the results.
                    curve = self.getLearningCurve([learner])
                    scorer = self.scoring[self.scoringF][1]
                    score = [scorer(x)[0] for x in curve]
                    if known:
                        self.scores[indx] = score
                        for i, results in enumerate(self.curves):
                            results.add(curve[i], 0, replace=indx)
                    else:
                        self.scores.append(score)
                        for i, results in enumerate(self.curves):
                            results.add(curve[i], 0)
                else:
                    # No usable results (none yet, or computed for a
                    # different number of points): recompute for all learners.
                    self.computeCurve()
        if len(self.learners):
            self.infob.setText("%d learners on input." % len(self.learners))
        else:
            self.infob.setText("No learners.")
        self._updateCommitButton()
        self.setTable()

    ##############################################################################
    # learning curve table, callbacks

    def computeCurve(self, condition=1):
        """Recompute the learning curves of all learners if `condition` is true.

        Without data or without learners there is nothing to evaluate; any
        stored results are cleared and the table is refreshed instead.
        """
        if not condition:
            return
        if self.data is None or not self.learners:
            self.curves = []
            self.scores = []
            self.curveProportions = []
            self.setTable()
            return
        learners = [x[1] for x in self.learners]
        self.curves = self.getLearningCurve(learners)
        self.curveProportions = list(self.curvePoints)
        self.computeScores()

    def computeScores(self):
        """Recompute self.scores from self.curves with the selected scoring
        function and refresh the table."""
        scorer = self.scoring[self.scoringF][1]
        self.scores = [[] for _ in self.learners]
        for results in self.curves:
            for (i, s) in enumerate(scorer(results)):
                self.scores[i].append(s)
        self.setTable()

    def getLearningCurve(self, learners):
        """Evaluate `learners` on self.data at the proportions in
        self.curvePoints and return the list of results (one per proportion).

        A progress bar is advanced through the callback passed to
        orngTest.learningCurveN.
        """
        pb = OWGUI.ProgressBar(self, iterations=self.steps*self.folds)
        curve = orngTest.learningCurveN(learners, self.data, folds=self.folds,
                                        proportions=self.curvePoints, callback=pb.advance)
        pb.finish()
        return curve

    def setCurvePoints(self):
        """Set self.curvePoints to self.steps equidistant proportions,
        from 1/self.steps up to 1."""
        self.curvePoints = [(x+1.)/self.steps for x in range(self.steps)]

    def setTable(self):
        """Rebuild the table: one column per learner, one row per proportion
        for which scores are available (no rows when nothing was computed)."""
        nLearners = len(self.learners)
        # reset the columns before rebuilding them
        self.table.setColumnCount(0)
        self.table.setColumnCount(nLearners)

        # Only show rows that every learner has a score for; the scores can
        # be missing altogether (learners without data).
        if nLearners and len(self.scores) == nLearners:
            nRows = min([len(s) for s in self.scores] + [len(self.curveProportions)])
        else:
            nRows = 0
        self.table.setRowCount(nRows)

        # set the headers
        self.table.setHorizontalHeaderLabels([l.name for i, l in self.learners])
        self.table.setVerticalHeaderLabels(["%4.2f" % p for p in self.curveProportions[:nRows]])

        # set the table contents
        for col in range(nLearners):
            for row in range(nRows):
                OWGUI.tableItem(self.table, row, col, "%7.5f" % self.scores[col][row])

        for col in range(nLearners):
            self.table.setColumnWidth(col, 80)

    ##############################################################################
    # private helpers

    def _curvesAreCurrent(self):
        """Return True if results are stored and were computed for the
        proportions currently in self.curvePoints."""
        return len(self.curves) > 0 and self.curveProportions == self.curvePoints

    def _updateCommitButton(self):
        """Enable "Apply Setting" only when there is something to compute,
        i.e. both data and at least one learner are present."""
        self.commitBtn.setEnabled(self.data is not None and len(self.learners) > 0)
154
155##############################################################################
156# Test the widget, run from prompt
157
if __name__ == "__main__":
    # Stand-alone demo: feed the widget three learners and the iris data
    # set by calling its input handlers directly, then start the Qt loop.
    application = QApplication(sys.argv)
    widget = OWLearningCurveA()
    widget.show()

    # The first learner arrives before any data is available.
    naiveBayes = orange.BayesLearner()
    naiveBayes.name = 'Naive Bayes'
    widget.learner(naiveBayes, 1)

    irisData = orange.ExampleTable('../datasets/iris.tab')
    widget.dataset(irisData)

    # Naive Bayes with m-estimates (m=10) for both the class and the
    # conditional probabilities.
    naiveBayesM = orange.BayesLearner()
    naiveBayesM.name = 'Naive Bayes (m=10)'
    naiveBayesM.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    naiveBayesM.conditionalEstimatorConstructor = \
        orange.ConditionalProbabilityEstimatorConstructor_ByRows(
            estimatorConstructor=orange.ProbabilityEstimatorConstructor_m(m=10))
    widget.learner(naiveBayesM, 2)

    import orngTree
    decisionTree = orngTree.TreeLearner(minSubset=2)
    decisionTree.name = "Decision Tree"
    widget.learner(decisionTree, 4)

    # Uncomment to exercise removal of learners:
#    widget.learner(None, 1)
#    widget.learner(None, 2)
#    widget.learner(None, 4)

    application.exec_()
Note: See TracBrowser for help on using the repository browser.