source: orange/orange/OrangeWidgets/Prototypes/OWCalibratedClassifier.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 6.6 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1"""
2<name>Calibrated Classifier</name>
3<description>Given a learner, it builds a classifier calibrated for optimal classification accuracy</description>
4<icon>icons/CalibratedClassifier.png</icon>
5<priority>1030</priority>
6"""
7
8from OWWidget import *
9from OWGraph import *
10import OWGUI
11
12import orngWrap
13
class ThresholdGraph(OWGraph):
    """Plot of classification accuracy versus the decision threshold.

    The x axis shows the threshold in percent (fixed range 0-100) and the
    y axis the classification accuracy (fixed range 0-1).  Two curves are
    kept: the accuracy curve itself (drawn in black) and a vertical marker
    at the currently selected threshold (drawn in blue).
    """

    def __init__(self, parent = None, title = ""):
        OWGraph.__init__(self, parent)
        self.setYRlabels(None)
        self.enableGridXB(0)
        self.enableGridYL(1)
        self.setAxisMaxMajor(QwtPlot.xBottom, 10)
        self.setAxisMaxMinor(QwtPlot.xBottom, 3)
        self.setAxisMaxMajor(QwtPlot.yLeft, 10)
        self.setAxisMaxMinor(QwtPlot.yLeft, 5)
        # Fixed axis ranges: threshold in percent, accuracy as a fraction.
        self.setAxisScale(QwtPlot.xBottom, -0.0, 100.0, 0)
        self.setAxisScale(QwtPlot.yLeft, -0.0, 1.0, 0)
        self.setYLaxisTitle('classification accuracy')
        self.setShowYLaxisTitle(1)
        self.setXaxisTitle('threshold')
        self.setShowXaxisTitle(1)
        self.setShowMainTitle(1)
        self.setMainTitle(title)

        # Accuracy curve and the vertical threshold marker; data is filled
        # in later by setCurve / setThreshold.
        self.curve = self.addCurve("")
        self.thresholdCurve = self.addCurve("Threshold")

    def setCurve(self, coords):
        """Draw the accuracy curve and replot.

        ``coords`` is a sequence of pairs: element 0 is a threshold
        (presumably a probability in [0, 1]; it is scaled by 100 to match
        the percent x axis) and element 1 is the accuracy plotted on the
        y axis.  An empty sequence leaves an empty curve.
        """
        self.curve.setData([100*x[0] for x in coords], [x[1] for x in coords])
        self.curve.setPen(QPen(Qt.black, 2))
        self.replot()

    def setThreshold(self, threshold):
        """Move the vertical marker to ``threshold`` (in percent) and replot."""
        # The marker spans the whole y range (accuracy 0..1).
        self.thresholdCurve.setData([threshold, threshold], [0, 1])
        self.thresholdCurve.setPen(QPen(Qt.blue, 1))
        self.replot()
46   
class OWCalibratedClassifier(OWWidget):
    """Widget that calibrates a base learner's decision threshold.

    Inputs:
      * "Data" -- a data set with a two-valued class; the base learner is
        trained on it and the accuracy-vs-threshold curve is computed on it.
      * "Base Learner" -- the learner whose classifier gets calibrated.

    Outputs:
      * "Learner" -- ``orngWrap.ThresholdLearner`` when the optimal
        threshold is requested, otherwise ``orngWrap.ThresholdLearner_fixed``
        with the user-chosen threshold.
      * "Classifier" -- ``orngWrap.ThresholdClassifier`` wrapping the base
        learner trained on the input data.

    ``self.threshold`` is kept in percent (the spin box allows 1-99 and the
    optimal threshold is stored as ``computedThreshold*100``); it is divided
    by 100 before being handed to the orngWrap objects.
    """
    settingsList = ["name", "optimalThreshold", "threshold"]
    def __init__(self, parent=None, signalManager = None, title = "Calibrated Classifier"):
        OWWidget.__init__(self, parent, signalManager, title)

        self.inputs = [("Data", ExampleTable, self.setData), ("Base Learner", orange.Learner, self.setBaseLearner)]
        self.outputs = [("Learner", orange.Learner),("Classifier", orange.Classifier)]

        # Settings
        self.name = 'Calibrated Learner'
        self.optimalThreshold = 0
        self.threshold = self.accuracy = 50
        self.loadSettings()

        self.learner = None
        self.baseLearner = None
        self.data = None
        # Derived state.  Initialised here so GUI callbacks (e.g. toggling
        # "Use optimal threshold" -> setThreshold) work before any input
        # signal has arrived; previously these attributes did not exist yet.
        self.baseClassifier = None
        self.classifier = None
        self.computedThreshold = 0.5
        self.curve = []

        OWGUI.lineEdit(self.controlArea, self, 'name',
                       box='Learner/Classifier Name', 
                       tooltip='Name to be used by other widgets to identify your learner/classifier.')
        OWGUI.separator(self.controlArea)

        self.wbThreshold = OWGUI.widgetBox(self.controlArea, "Threshold", addSpace=True)
        self.cbOptimal = OWGUI.checkBox(self.wbThreshold, self, "optimalThreshold",
                                        "Use optimal threshold",
                                        callback=self.setThreshold)
       
        self.spThreshold = OWGUI.spin(self.wbThreshold, self, "threshold", 1, 99, step=5,
                                      label = "Threshold",
                                      orientation = "horizontal",
                                      callback = self.setThreshold)
       
        self.lbNotice = OWGUI.widgetLabel(self.wbThreshold, "Notice: If the widget is connected to a widget that takes a Learner, not a Classifier (eg 'Test Learners'), the automatically computed threshold can differ from the above.")
        self.lbNotice.setWordWrap(True)
         
        # Presumably OWGUI enables widgets in `disables` only while the
        # checkbox is checked -- the notice is relevant only then.
        self.cbOptimal.disables = [self.lbNotice]
        self.cbOptimal.makeConsistent()
        # A manual threshold makes no sense while the optimal one is used.
        self.spThreshold.setDisabled(self.optimalThreshold)
       
        OWGUI.rubber(self.controlArea)
       
        OWGUI.button(self.controlArea, self, "&Apply Setting",
                     callback = self.btApplyCallback,
                     disabled=0)

        self.btSave = OWGUI.button(self.controlArea, self, "&Save Graph", callback = self.saveToFile, disabled=1)

        self.graph = ThresholdGraph()
        self.mainArea.layout().addWidget(self.graph)

        self.resize(700, 330)

    def setData(self, data):
        """Handle the "Data" input; only binary-class data is accepted.

        ``None`` (disconnected input) clears the data silently.  Data whose
        class is missing, continuous, or has other than two values is
        rejected with an error and treated as no data.
        """
        self.error([0])
        self.data = None
        if data is not None:
            classVar = data.domain.classVar
            # Guard against class-less / continuous-class data, which the
            # old unguarded `classVar.values` access crashed on.
            if isinstance(classVar, orange.EnumVariable) and len(classVar.values) == 2:
                self.data = data
            else:
                self.error(0, "ThresholdLearner handles binary classes only!")
        self.compute_baseClassifier_curve_threshold()
        self.construct_classifier()

    def setBaseLearner(self, baseLearner):
        """Handle the "Base Learner" input and rebuild all outputs."""
        self.baseLearner = baseLearner
        self.construct_learner()
        self.compute_baseClassifier_curve_threshold()
        self.construct_classifier()

    def btApplyCallback(self):
        """Resend the learner and classifier with the current settings."""
        self.construct_learner()
        self.construct_classifier()

    def setThreshold(self):
        """GUI callback for the checkbox and spin box; updates the graph only.

        Outputs are not resent here -- that happens on "Apply Setting".
        """
        self.spThreshold.setDisabled(self.optimalThreshold)
        if self.optimalThreshold:
            self.threshold = self.computedThreshold*100
        self.graph.setThreshold(self.threshold)

    def construct_learner(self):
        """Build and send the calibrated learner (or None without a base learner)."""
        if self.baseLearner:
            if self.optimalThreshold:
                self.learner = orngWrap.ThresholdLearner(learner=self.baseLearner, storeCurve = 1)
            else:
                # Keyword was previously misspelled `threshhold`, so the
                # fixed threshold never reached the learner.
                self.learner = orngWrap.ThresholdLearner_fixed(learner=self.baseLearner, threshold=self.threshold/100.0)
            self.learner.name = self.name
        else:
            self.learner = None
        self.send("Learner", self.learner)

    def compute_baseClassifier_curve_threshold(self):
        """Train the base classifier and compute its accuracy curve.

        Sets ``baseClassifier``, ``computedThreshold`` (a fraction) and
        ``curve``; resets them to None / 0.5 / [] when either the base
        learner or the data is missing.  Refreshes the graph and the
        "Save Graph" button.
        """
        # Test the object actually used below (was `self.learner`).
        if not self.baseLearner or not self.data:
            self.baseClassifier = None
            self.computedThreshold = 0.5
            self.curve = []
        else:
            self.baseClassifier = self.baseLearner(self.data)
            self.computedThreshold, _, self.curve = orange.ThresholdCA(self.baseClassifier, self.data)
            if self.optimalThreshold:
                self.threshold = self.computedThreshold*100
        self.graph.setCurve(self.curve)
        self.graph.setThreshold(self.threshold)
        self.btSave.setDisabled(not self.curve)

    def construct_classifier(self):
        """Build and send the threshold classifier (or None if untrained)."""
        if not self.baseClassifier:
            self.classifier = None
        else:
            # self.threshold is in percent; ThresholdClassifier expects the
            # same fractional scale the learner receives above.
            self.classifier = orngWrap.ThresholdClassifier(self.baseClassifier, self.threshold/100.0)
            self.classifier.name = self.name
        self.send("Classifier", self.classifier)

    def saveToFile(self):
        """Open the image-size dialog to save the threshold graph."""
        from OWDlgs import OWChooseImageSizeDlg
        dlg = OWChooseImageSizeDlg(self.graph)
        dlg.exec_()
163
if __name__ == "__main__":
    # Manual smoke test: show the widget fed with a sample binary-class
    # data set and a naive Bayes base learner, then persist its settings.
    app = QApplication(sys.argv)
    widget = OWCalibratedClassifier()

    sample = orange.ExampleTable("../../doc/datasets/breast-cancer")
    widget.setData(sample)
    widget.setBaseLearner(orange.BayesLearner())

    widget.show()
    app.exec_()
    widget.saveSettings()
Note: See TracBrowser for help on using the repository browser.