source: orange/orange/OrangeWidgets/Classify/OWSVM.py @ 9546:2b6cc6f397fe

Revision 9546:2b6cc6f397fe, 11.4 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Renamed widget channel names in line with the new naming rules/convention.
Added backwards compatibility in orngDoc loadDocument to enable loading of schemas saved before the change.

Line 
1# coding=utf-8
2"""
3<name>SVM</name>
4<description>Support Vector Machines learner/classifier.</description>
5<icon>icons/BasicSVM.png</icon>
6<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
7<priority>100</priority>
8"""
9import orange, orngSVM, OWGUI, sys
10from OWWidget import *
11from exceptions import SystemExit
12from orngWrap import PreprocessedLearner
13
14class OWSVM(OWWidget):
15    settingsList=["C","nu","p", "eps", "probability","gamma","degree", 
16                  "coef0", "kernel_type", "name", "useNu", "normalization"]
17    def __init__(self, parent=None, signalManager=None, name="SVM"):
18        OWWidget.__init__(self, parent, signalManager, name, wantMainArea = 0, resizingEnabled = 0)
19       
20        self.inputs = [("Data", ExampleTable, self.setData),
21                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
22        self.outputs = [("Learner", orange.Learner, Default),
23                        ("Classifier", orange.Classifier, Default),
24                        ("Support Vectors", ExampleTable)]
25
26        self.kernel_type = 2
27        self.gamma = 0.0
28        self.coef0 = 0.0
29        self.degree = 3
30        self.C = 1.0
31        self.p = 0.1
32        self.eps = 1e-3
33        self.nu = 0.5
34        self.shrinking = 1
35        self.probability=1
36        self.useNu=0
37        self.nomogram=0
38        self.normalization=1
39        self.data = None
40        self.selFlag=False
41        self.preprocessor = None
42        self.name="SVM"
43
44        OWGUI.lineEdit(self.controlArea, self, 'name', 
45                       box='Learner/Classifier Name', 
46                       tooltip='Name to be used by other widgets to identify your learner/classifier.')
47        OWGUI.separator(self.controlArea)
48
49        b = OWGUI.radioButtonsInBox(self.controlArea, self, "useNu", [], 
50                                    box="SVM Type", 
51                                    orientation = QGridLayout(), 
52                                    addSpace=True)
53       
54        b.layout().addWidget(OWGUI.appendRadioButton(b, self, "useNu", "C-SVM",
55                                                     addToLayout=False),
56                             0, 0, Qt.AlignLeft)
57       
58        b.layout().addWidget(QLabel("Cost (C)", b), 0, 1, Qt.AlignRight)
59        b.layout().addWidget(OWGUI.doubleSpin(b, self, "C", 0.5, 512.0, 0.5, 
60                                addToLayout=False, 
61                                callback=lambda *x: self.setType(0), 
62                                alignment=Qt.AlignRight,
63                                tooltip= "Cost for a mis-classified training instance."),
64                             0, 2)
65       
66        b.layout().addWidget(OWGUI.appendRadioButton(b, self, "useNu", u"ν-SVM", 
67                                                     addToLayout=False),
68                             1, 0, Qt.AlignLeft)
69       
70        b.layout().addWidget(QLabel(u"Complexity bound (\u03bd)", b), 1, 1, Qt.AlignRight)
71        b.layout().addWidget(OWGUI.doubleSpin(b, self, "nu", 0.1, 1.0, 0.1, 
72                                tooltip="Lower bound on the ratio of support vectors",
73                                addToLayout=False, 
74                                callback=lambda *x: self.setType(1),
75                                alignment=Qt.AlignRight),
76                             1, 2)
77       
78        self.kernelBox=b = OWGUI.widgetBox(self.controlArea, "Kernel")
79        self.kernelradio = OWGUI.radioButtonsInBox(b, self, "kernel_type",
80                                        btnLabels=[u"Linear,   x∙y",
81                                                   u"Polynomial,   (g x∙y + c)^d",
82                                                   u"RBF,   exp(-g|x-y|²)",
83                                                   u"Sigmoid,   tanh(g x∙y + c)"],
84                                        callback=self.changeKernel)
85
86        OWGUI.separator(b)
87        self.gcd = OWGUI.widgetBox(b, orientation="horizontal")
88        self.leg = OWGUI.doubleSpin(self.gcd, self, "gamma",0.0,10.0,0.0001,
89                                label="  g: ",
90                                orientation="horizontal",
91                                callback=self.changeKernel,
92                                alignment=Qt.AlignRight)
93       
94        self.led = OWGUI.doubleSpin(self.gcd, self, "coef0", 0.0,10.0,0.0001,
95                                label="  c: ",
96                                orientation="horizontal",
97                                callback=self.changeKernel,
98                                alignment=Qt.AlignRight)
99       
100        self.lec = OWGUI.doubleSpin(self.gcd, self, "degree", 0.0,10.0,0.5,
101                                label="  d: ",
102                                orientation="horizontal",
103                                callback=self.changeKernel,
104                                alignment=Qt.AlignRight)
105
106        OWGUI.separator(self.controlArea)
107       
108        self.optionsBox=b=OWGUI.widgetBox(self.controlArea, "Options",
109                                          addSpace=True)
110       
111        OWGUI.doubleSpin(b,self, "eps", 0.0005, 1.0, 0.0005,
112                         label=u"Numerical tolerance",
113                         labelWidth = 180,
114                         orientation="horizontal",
115                         tooltip="Numerical tolerance of termination criterion.",
116                         alignment=Qt.AlignRight)
117
118        self.probBox = OWGUI.checkBox(b,self, "probability",
119                                      label="Estimate class probabilities",
120                                      tooltip="Create classifiers that support class probability estimation."
121                                      )
122       
123        OWGUI.checkBox(b, self, "normalization",
124                       label="Normalize data", 
125                       tooltip="Use data normalization")
126
127        self.paramButton=OWGUI.button(self.controlArea, self, "Automatic parameter search",
128                                      callback=self.parameterSearch,
129                                      tooltip="Automatically searches for parameters that optimize classifier accuracy", 
130                                      debuggingEnabled=0)
131       
132        self.paramButton.setDisabled(True)
133
134        OWGUI.button(self.controlArea, self,"&Apply", 
135                     callback=self.applySettings, 
136                     default=True)
137       
138        OWGUI.rubber(self.controlArea)
139       
140        self.loadSettings()
141        self.changeKernel()
142        self.searching=False
143        self.applySettings()
144
145    def sendReport(self):
146        if self.kernel_type == 0:
147            kernel = "Linear, x.y"
148        elif self.kernel_type == 1:
149            kernel = "Polynomial, (%.4f*x.y+%.4f)<sup>%.4f</sup>" % (self.gamma, self.coef0, self.degree)
150        elif self.kernel_type == 2:
151            kernel = "RBF, e<sup>-%.4f*(x-y).(x-y)</sup>" % self.gamma
152        else:
153            kernel = "Sigmoid, tanh(%.4f*x.y+%.4f)" % (self.gamma, self.coef0)
154        self.reportSettings("Learning parameters",
155                            [("Kernel", kernel),
156                             ("Cost (C)", self.C),
157                             ("Numeric precision", self.eps),
158                             self.useNu and ("Complexity bound (nu)", self.nu),
159                             ("Estimate class probabilities", OWGUI.YesNo[self.probability]),
160                             ("Normalize data", OWGUI.YesNo[self.normalization])])
161        self.reportData(self.data)
162       
163    def setType(self, type):
164        self.useNu = type
165       
166    def changeKernel(self):
167        if self.kernel_type==0:
168            for a,b in zip([self.leg, self.led, self.lec], [1,1,1]):
169                a.setDisabled(b)
170        elif self.kernel_type==1:
171            for a,b in zip([self.leg, self.led, self.lec], [0,0,0]):
172                a.setDisabled(b)
173        elif self.kernel_type==2:
174            for a,b in zip([self.leg, self.led, self.lec], [0,1,1]):
175                a.setDisabled(b)
176        elif self.kernel_type==3:
177            for a,b in zip([self.leg, self.led, self.lec], [0,0,1]):
178                a.setDisabled(b)
179
180    def setData(self, data=None):
181        self.data = self.isDataWithClass(data, checkMissing=True) and data or None
182        self.paramButton.setDisabled(not self.data)
183       
184    def setPreprocessor(self, pp):
185        self.preprocessor = pp
186       
187    def handleNewSignals(self):
188        self.applySettings()
189
190    def applySettings(self):
191        self.learner=orngSVM.SVMLearner()
192        for attr in ("name", "kernel_type", "degree", "shrinking", "probability", "normalization"):
193            setattr(self.learner, attr, getattr(self, attr))
194
195        for attr in ("gamma", "coef0", "C", "p", "eps", "nu"):
196            setattr(self.learner, attr, float(getattr(self, attr)))
197
198        self.learner.svm_type=orngSVM.SVMLearner.C_SVC
199
200        if self.useNu:
201            self.learner.svm_type=orngSVM.SVMLearner.Nu_SVC
202
203        if self.preprocessor:
204            self.learner = self.preprocessor.wrapLearner(self.learner)
205        self.classifier=None
206        self.supportVectors=None
207       
208        if self.data:
209            if self.data.domain.classVar.varType==orange.VarTypes.Continuous:
210                self.learner.svm_type+=3
211            self.classifier=self.learner(self.data)
212            self.supportVectors=self.classifier.supportVectors
213            self.classifier.name=self.name
214           
215        self.send("Learner", self.learner)
216        self.send("Classifier", self.classifier)
217        self.send("Support Vectors", self.supportVectors)
218
219    def parameterSearch(self):
220        if not self.data:
221            return
222        if self.searching:
223            self.searching=False
224        else:
225            self.kernelBox.setDisabled(1)
226            self.optionsBox.setDisabled(1)
227            self.progressBarInit()
228            self.paramButton.setText("Stop")
229            self.searching=True
230            self.search_()
231
232    def progres(self, f, best=None):
233        qApp.processEvents()
234        self.progressBarSet(f)
235        if not self.searching:
236            raise UnhandledException()
237
238    def finishSearch(self):
239        self.progressBarFinished()
240        self.kernelBox.setDisabled(0)
241        self.optionsBox.setDisabled(0)
242        self.paramButton.setText("Automatic parameter search")
243        self.searching=False
244
245    def search_(self):
246        learner=orngSVM.SVMLearner()
247        for attr in ("name", "kernel_type", "degree", "shrinking", "probability", "normalization"):
248            setattr(learner, attr, getattr(self, attr))
249
250        for attr in ("gamma", "coef0", "C", "p", "eps", "nu"):
251            setattr(learner, attr, float(getattr(self, attr)))
252
253        learner.svm_type=0
254
255        if self.useNu:
256            learner.svm_type=1
257        params=[]
258        if self.useNu:
259            params.append("nu")
260        else:
261            params.append("C")
262        if self.kernel_type in [1,2]:
263            params.append("gamma")
264        if self.kernel_type==1:
265            params.append("degree")
266        try:
267            learner.tuneParameters(self.data, params, 4, verbose=0, progressCallback=self.progres)
268        except UnhandledException:
269            pass
270        for param in params:
271            setattr(self, param, getattr(learner, param))
272           
273        self.finishSearch()
274
275from exceptions import Exception
276class UnhandledException(Exception):
277    pass
278
279import sys
280if __name__=="__main__":
281    app=QApplication(sys.argv)
282    w=OWSVM()
283    w.show()
284    #d=orange.ExampleTable("../../doc/datasets/iris.tab")
285    #w.setData(d)
286    app.exec_()
287    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.