source: orange/Orange/OrangeWidgets/Classify/OWSVM.py @ 10730:61fcb884c0ed

Revision 10730:61fcb884c0ed, 11.5 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

More precise GUI controls.

Line 
1# coding=utf-8
2"""
3<name>SVM</name>
4<description>Support Vector Machines learner/classifier.</description>
5<icon>icons/BasicSVM.png</icon>
6<contact>Ales Erjavec (ales.erjavec(@at@)fri.uni-lj.si)</contact>
7<priority>100</priority>
8"""
9import orange, orngSVM, OWGUI, sys
10from OWWidget import *
11from exceptions import SystemExit
12from orngWrap import PreprocessedLearner
13
14class OWSVM(OWWidget):
15    settingsList=["C","nu","p", "eps", "probability","gamma","degree", 
16                  "coef0", "kernel_type", "name", "useNu", "normalization"]
17    def __init__(self, parent=None, signalManager=None, name="SVM"):
18        OWWidget.__init__(self, parent, signalManager, name, wantMainArea = 0, resizingEnabled = 0)
19       
20        self.inputs = [("Data", ExampleTable, self.setData),
21                       ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
22        self.outputs = [("Learner", orange.Learner, Default),
23                        ("Classifier", orange.Classifier, Default),
24                        ("Support Vectors", ExampleTable)]
25
26        self.kernel_type = 2
27        self.gamma = 0.0
28        self.coef0 = 0.0
29        self.degree = 3
30        self.C = 1.0
31        self.p = 0.1
32        self.eps = 1e-3
33        self.nu = 0.5
34        self.shrinking = 1
35        self.probability=1
36        self.useNu=0
37        self.nomogram=0
38        self.normalization=1
39        self.data = None
40        self.selFlag=False
41        self.preprocessor = None
42        self.name="SVM"
43
44        OWGUI.lineEdit(self.controlArea, self, 'name', 
45                       box='Learner/Classifier Name', 
46                       tooltip='Name to be used by other widgets to identify your learner/classifier.')
47        OWGUI.separator(self.controlArea)
48
49        b = OWGUI.radioButtonsInBox(self.controlArea, self, "useNu", [], 
50                                    box="SVM Type", 
51                                    orientation = QGridLayout(), 
52                                    addSpace=True)
53       
54        b.layout().addWidget(OWGUI.appendRadioButton(b, self, "useNu", "C-SVM",
55                                                     addToLayout=False),
56                             0, 0, Qt.AlignLeft)
57       
58        b.layout().addWidget(QLabel("Cost (C)", b), 0, 1, Qt.AlignRight)
59        b.layout().addWidget(OWGUI.doubleSpin(b, self, "C", 0.1, 512.0, 0.1,
60                                decimals=2,
61                                addToLayout=False,
62                                callback=lambda *x: self.setType(0),
63                                alignment=Qt.AlignRight,
64                                tooltip= "Cost for a mis-classified training instance."),
65                             0, 2)
66       
67        b.layout().addWidget(OWGUI.appendRadioButton(b, self, "useNu", u"ν-SVM",
68                                                     addToLayout=False),
69                             1, 0, Qt.AlignLeft)
70
71        b.layout().addWidget(QLabel(u"Complexity bound (\u03bd)", b), 1, 1, Qt.AlignRight)
72        b.layout().addWidget(OWGUI.doubleSpin(b, self, "nu", 0.05, 1.0, 0.05,
73                                tooltip="Lower bound on the ratio of support vectors",
74                                addToLayout=False,
75                                callback=lambda *x: self.setType(1),
76                                alignment=Qt.AlignRight),
77                             1, 2)
78       
79        self.kernelBox=b = OWGUI.widgetBox(self.controlArea, "Kernel")
80        self.kernelradio = OWGUI.radioButtonsInBox(b, self, "kernel_type",
81                                        btnLabels=[u"Linear,   x∙y",
82                                                   u"Polynomial,   (g x∙y + c)^d",
83                                                   u"RBF,   exp(-g|x-y|²)",
84                                                   u"Sigmoid,   tanh(g x∙y + c)"],
85                                        callback=self.changeKernel)
86
87        OWGUI.separator(b)
88        self.gcd = OWGUI.widgetBox(b, orientation="horizontal")
89        self.leg = OWGUI.doubleSpin(self.gcd, self, "gamma",0.0,10.0,0.0001,
90                                decimals=5,
91                                label="  g: ",
92                                orientation="horizontal",
93                                callback=self.changeKernel,
94                                alignment=Qt.AlignRight)
95       
96        self.led = OWGUI.doubleSpin(self.gcd, self, "coef0", 0.0,10.0,0.0001,
97                                label="  c: ",
98                                orientation="horizontal",
99                                callback=self.changeKernel,
100                                alignment=Qt.AlignRight)
101       
102        self.lec = OWGUI.doubleSpin(self.gcd, self, "degree", 0.0,10.0,0.5,
103                                label="  d: ",
104                                orientation="horizontal",
105                                callback=self.changeKernel,
106                                alignment=Qt.AlignRight)
107
108        OWGUI.separator(self.controlArea)
109       
110        self.optionsBox=b=OWGUI.widgetBox(self.controlArea, "Options",
111                                          addSpace=True)
112       
113        OWGUI.doubleSpin(b,self, "eps", 0.0005, 1.0, 0.0005,
114                         label=u"Numerical tolerance",
115                         labelWidth = 180,
116                         orientation="horizontal",
117                         tooltip="Numerical tolerance of termination criterion.",
118                         alignment=Qt.AlignRight)
119
120        self.probBox = OWGUI.checkBox(b,self, "probability",
121                                      label="Estimate class probabilities",
122                                      tooltip="Create classifiers that support class probability estimation."
123                                      )
124       
125        OWGUI.checkBox(b, self, "normalization",
126                       label="Normalize data", 
127                       tooltip="Use data normalization")
128
129        self.paramButton=OWGUI.button(self.controlArea, self, "Automatic parameter search",
130                                      callback=self.parameterSearch,
131                                      tooltip="Automatically searches for parameters that optimize classifier accuracy", 
132                                      debuggingEnabled=0)
133       
134        self.paramButton.setDisabled(True)
135
136        OWGUI.button(self.controlArea, self,"&Apply", 
137                     callback=self.applySettings, 
138                     default=True)
139       
140        OWGUI.rubber(self.controlArea)
141       
142        self.loadSettings()
143        self.changeKernel()
144        self.searching=False
145        self.applySettings()
146
147    def sendReport(self):
148        if self.kernel_type == 0:
149            kernel = "Linear, x.y"
150        elif self.kernel_type == 1:
151            kernel = "Polynomial, (%.4f*x.y+%.4f)<sup>%.4f</sup>" % (self.gamma, self.coef0, self.degree)
152        elif self.kernel_type == 2:
153            kernel = "RBF, e<sup>-%.4f*(x-y).(x-y)</sup>" % self.gamma
154        else:
155            kernel = "Sigmoid, tanh(%.4f*x.y+%.4f)" % (self.gamma, self.coef0)
156        self.reportSettings("Learning parameters",
157                            [("Kernel", kernel),
158                             ("Cost (C)", self.C),
159                             ("Numeric precision", self.eps),
160                             self.useNu and ("Complexity bound (nu)", self.nu),
161                             ("Estimate class probabilities", OWGUI.YesNo[self.probability]),
162                             ("Normalize data", OWGUI.YesNo[self.normalization])])
163        self.reportData(self.data)
164       
165    def setType(self, type):
166        self.useNu = type
167       
168    def changeKernel(self):
169        if self.kernel_type==0:
170            for a,b in zip([self.leg, self.led, self.lec], [1,1,1]):
171                a.setDisabled(b)
172        elif self.kernel_type==1:
173            for a,b in zip([self.leg, self.led, self.lec], [0,0,0]):
174                a.setDisabled(b)
175        elif self.kernel_type==2:
176            for a,b in zip([self.leg, self.led, self.lec], [0,1,1]):
177                a.setDisabled(b)
178        elif self.kernel_type==3:
179            for a,b in zip([self.leg, self.led, self.lec], [0,0,1]):
180                a.setDisabled(b)
181
182    def setData(self, data=None):
183        self.data = self.isDataWithClass(data, checkMissing=True) and data or None
184        self.paramButton.setDisabled(not self.data)
185       
186    def setPreprocessor(self, pp):
187        self.preprocessor = pp
188       
189    def handleNewSignals(self):
190        self.applySettings()
191
192    def applySettings(self):
193        self.learner=orngSVM.SVMLearner()
194        for attr in ("name", "kernel_type", "degree", "shrinking", "probability", "normalization"):
195            setattr(self.learner, attr, getattr(self, attr))
196
197        for attr in ("gamma", "coef0", "C", "p", "eps", "nu"):
198            setattr(self.learner, attr, float(getattr(self, attr)))
199
200        self.learner.svm_type=orngSVM.SVMLearner.C_SVC
201
202        if self.useNu:
203            self.learner.svm_type=orngSVM.SVMLearner.Nu_SVC
204
205        if self.preprocessor:
206            self.learner = self.preprocessor.wrapLearner(self.learner)
207        self.classifier=None
208        self.supportVectors=None
209       
210        if self.data:
211            if self.data.domain.classVar.varType==orange.VarTypes.Continuous:
212                self.learner.svm_type+=3
213            self.classifier=self.learner(self.data)
214            self.supportVectors=self.classifier.supportVectors
215            self.classifier.name=self.name
216           
217        self.send("Learner", self.learner)
218        self.send("Classifier", self.classifier)
219        self.send("Support Vectors", self.supportVectors)
220
221    def parameterSearch(self):
222        if not self.data:
223            return
224        if self.searching:
225            self.searching=False
226        else:
227            self.kernelBox.setDisabled(1)
228            self.optionsBox.setDisabled(1)
229            self.progressBarInit()
230            self.paramButton.setText("Stop")
231            self.searching=True
232            self.search_()
233
234    def progres(self, f, best=None):
235        qApp.processEvents()
236        self.progressBarSet(f)
237        if not self.searching:
238            raise UnhandledException()
239
240    def finishSearch(self):
241        self.progressBarFinished()
242        self.kernelBox.setDisabled(0)
243        self.optionsBox.setDisabled(0)
244        self.paramButton.setText("Automatic parameter search")
245        self.searching=False
246
247    def search_(self):
248        learner=orngSVM.SVMLearner()
249        for attr in ("name", "kernel_type", "degree", "shrinking", "probability", "normalization"):
250            setattr(learner, attr, getattr(self, attr))
251
252        for attr in ("gamma", "coef0", "C", "p", "eps", "nu"):
253            setattr(learner, attr, float(getattr(self, attr)))
254
255        learner.svm_type=0
256
257        if self.useNu:
258            learner.svm_type=1
259        params=[]
260        if self.useNu:
261            params.append("nu")
262        else:
263            params.append("C")
264        if self.kernel_type in [1,2]:
265            params.append("gamma")
266        if self.kernel_type==1:
267            params.append("degree")
268        try:
269            learner.tuneParameters(self.data, params, 4, verbose=0,
270                                   progressCallback=self.progres)
271        except UnhandledException:
272            pass
273        for param in params:
274            setattr(self, param, getattr(learner, param))
275           
276        self.finishSearch()
277
278from exceptions import Exception
279class UnhandledException(Exception):
280    pass
281
282import sys
283if __name__=="__main__":
284    app=QApplication(sys.argv)
285    w=OWSVM()
286    w.show()
287    #d=orange.ExampleTable("../../doc/datasets/iris.tab")
288    #w.setData(d)
289    app.exec_()
290    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.