Changeset 8990:5895df76d717 in orange
 Timestamp:
 09/20/11 14:05:23 (3 years ago)
 Branch:
 default
 Convert:
 7e425d0972d34baf62f1b9b53c6bf298c8462a08
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

orange/Orange/classification/svm/__init__.py
r8762 r8990 39 39  40 40 41 .. automethod:: Orange.classification.svm.max Nu42 43 .. automethod:: Orange.classification.svm.get LinearSVMWeights44 45 .. automethod:: Orange.classification.svm.table ToSVMFormat41 .. automethod:: Orange.classification.svm.max_nu 42 43 .. automethod:: Orange.classification.svm.get_linear_svm_weights 44 45 .. automethod:: Orange.classification.svm.table_to_svm_format 46 46 47 47 SVMderived feature weights … … 111 111 import Orange.core 112 112 import Orange.data 113 import Orange.misc 114 113 115 import kernels 114 116 import warnings … … 152 154 return _orange__new_wrapped 153 155 154 def max Nu(examples):156 def max_nu(examples): 155 157 """Return the maximum nu parameter for Nu_SVC support vector learning 156 158 for the given data table. … … 165 167 return min([2.0 * min(n1, n2) / (n1 + n2) for n1, n2 in pairs(dist) \ 166 168 if n1 != 0 and n2 !=0] + [nu]) 169 170 maxNu = max_nu 167 171 168 172 class SVMLearner(_SVMLearner): … … 181 185 :param coef0: kernel parameter (Polynomial/Sigmoid) (default 0) 182 186 :type coef0: int 183 :param kernel Func: function that will be called if `kernel_type` is187 :param kernel_func: function that will be called if `kernel_type` is 184 188 `Custom`. It must accept two :obj:`Orange.data.Instance` arguments and 185 189 return a float (the distance between the instances). 186 :type kernel Func: callable function190 :type kernel_func: callable function 187 191 :param C: C parameter for C_SVC, Epsilon_SVR, Nu_SVR 188 192 :type C: float … … 204 208 :type weights: list 205 209 210 Example: 211 212 >>> import Orange 213 >>> table = Orange.data.Table("vehicle.tab") 214 >>> svm = Orange.classification.svm.SVMLearner() 215 >>> results = Orange.evaluation.testing.cross_validation([svm], table, folds=5) 216 >>> print Orange.evaluation.scoring.CA(results) 217 206 218 """ 207 219 __new__ = _orange__new__(_SVMLearner) … … 213 225 Epsilon_SVR = _SVMLearner.Epsilon_SVR 214 226 227 @Orange.misc.deprecated_keywords({"kernelFunc": "kernel_func"}) 215 228 def __init__(self, svm_type=Nu_SVC, kernel_type=kernels.RBF, 216 kernel Func=None, C=1.0, nu=0.5, p=0.1, gamma=0.0, degree=3,229 kernel_func=None, C=1.0, nu=0.5, p=0.1, gamma=0.0, degree=3, 217 230 coef0=0, shrinking=True, probability=True, verbose=False, 218 231 cache_size=200, eps=0.001, normalization=True, … … 220 233 self.svm_type = SVMLearner.Nu_SVC 221 234 self.kernel_type = kernel_type 222 self.kernel Func = kernelFunc235 self.kernel_func = kernel_func 223 236 self.C = C 224 237 self.nu = nu … … 239 252 self.weight = weight 240 253 241 max Nu = staticmethod(maxNu)254 max_nu = staticmethod(max_nu) 242 255 243 256 def __call__(self, examples, weight=0): … … 261 274 self.svm_type=3 262 275 #raise AttributeError, "Cannot do regression on descrete class data. Use C_SVC or NU_SVC for classification" 263 if self.kernel_type==4 and not self.kernel Func:276 if self.kernel_type==4 and not self.kernel_func: 264 277 raise AttributeError, "Custom kernel function not supplied" 265 278 ################################################## … … 269 282 nu = self.nu 270 283 if self.svm_type == SVMLearner.Nu_SVC: #is nu feasibile 271 max Nu= self.maxNu(examples)272 if self.nu > max Nu:284 max_nu= self.max_nu(examples) 285 if self.nu > max_nu: 273 286 if getattr(self, "verbose", 0): 274 287 import warnings 275 288 warnings.warn("Specified nu %.3f is infeasible. \ 276 Setting nu to %.3f" % (self.nu, max Nu))277 nu = max(max Nu  1e7, 0.0)278 279 for name in ["svm_type", "kernel_type", "kernel Func", "C", "nu", "p",289 Setting nu to %.3f" % (self.nu, max_nu)) 290 nu = max(max_nu  1e7, 0.0) 291 292 for name in ["svm_type", "kernel_type", "kernel_func", "C", "nu", "p", 280 293 "gamma", "degree", "coef0", "shrinking", "probability", 281 294 "verbose", "cache_size", "eps"]: … … 285 298 return self.learnClassifier(examples) 286 299 287 def learn Classifier(self, examples):300 def learn_classifier(self, examples): 288 301 if self.normalization: 289 302 examples = self._normalize(examples) … … 295 308 return self.learner(examples) 296 309 297 def tuneParameters(self, examples, parameters=None, folds=5, verbose=0, 298 progressCallback=None): 310 @Orange.misc.deprecated_keywords({"progressCallback": "progress_callback"}) 311 def tune_parameters(self, examples, parameters=None, folds=5, verbose=0, 312 progress_callback=None): 299 313 """Tune the parameters of the SVMLearner on given instances using 300 314 cross validation. … … 308 322 :param verbose: default False 309 323 :type verbose: bool 310 :param progressCallback: report progress 311 :type progressCallback: callback function 312 313 Example:: 324 :param progress_callback: report progress 325 :type progress_callback: callback function 326 327 Example: 328 314 329 >>> svm = SVMLearner() 315 330 >>> svm.tuneParameters(examples, parameters=["gamma"], folds=3) … … 328 343 if self.svm_type == SVMLearner.Nu_SVC and "nu" in parameters: 329 344 numOfNuValues=9 330 max Nu = max(self.maxNu(examples)  1e7, 0.0)345 max_nu = max(self.max_nu(examples)  1e7, 0.0) 331 346 searchParams.append(("nu", [i/10.0 for i in range(1, 9) if \ 332 i/10.0 < max Nu] + [maxNu]))347 i/10.0 < max_nu] + [max_nu])) 333 348 elif "C" in parameters: 334 349 searchParams.append(("C", [2**a for a in range(5,15,2)])) … … 339 354 folds=folds, 340 355 returnWhat=orngWrap.TuneMParameters.returnLearner, 341 progressCallback=progress Callback342 if progress Callback else lambda i:None)356 progressCallback=progress_callback 357 if progress_callback else lambda i:None) 343 358 tunedLearner(examples, verbose=verbose) 344 359 if normalization: … … 353 368 return examples.translate(newdomain) 354 369 370 SVMLearner = Orange.misc.deprecated_members({ 371 "learnClassifier": "learn_classifier", 372 "tuneParameters": "tune_parameters", 373 "kernelFunc" : "kernel_func", 374 }, 375 wrap_methods=["__init__", "tune_parameters"])(SVMLearner) 376 355 377 class SVMClassifierWrapper(Orange.core.SVMClassifier): 356 378 def __new__(cls, wrapped): … … 366 388 return self.wrapped(example, what) 367 389 368 def class Distribution(self, example):390 def class_distribution(self, example): 369 391 example = Orange.data.Instance(self.wrapped.domain, example) 370 392 return self.wrapped.classDistribution(example) 371 393 372 def get DecisionValues(self, example):394 def get_decision_values(self, example): 373 395 example = Orange.data.Instance(self.wrapped.domain, example) 374 396 return self.wrapped.getDecisionValues(example) 375 397 376 def get Model(self):398 def get_model(self): 377 399 return self.wrapped.getModel() 378 400 … … 381 403 for name, val in self.__dict__.items() \ 382 404 if name not in self.wrapped.__dict__]) 405 406 SVMClassifierWrapper = Orange.misc.deprecated_members({ 407 "classDistribution": "class_distribution", 408 "getDecisionValues": "get_decision_values", 409 "getModel" : "get_model", 410 })(SVMClassifierWrapper) 383 411 384 412 class SVMLearnerSparse(SVMLearner): … … 416 444 self.learner = SVMLearner(**kwds) 417 445 418 def learn Classifier(self, examples):446 def learn_classifier(self, examples): 419 447 transformer=Orange.core.DomainContinuizer() 420 448 transformer.multinomialTreatment=Orange.core.DomainContinuizer.NValues … … 432 460 numOfNuValues=9 433 461 if self.svm_type == SVMLearner.Nu_SVC: 434 max Nu = max(self.maxNu(newexamples)  1e7, 0.0)462 max_nu = max(self.max_nu(newexamples)  1e7, 0.0) 435 463 else: 436 max Nu = 1.0464 max_nu = 1.0 437 465 parameters.append(("nu", [i/10.0 for i in range(1, 9) \ 438 if i/10.0 < max Nu] + [maxNu]))466 if i/10.0 < max_nu] + [max_nu])) 439 467 else: 440 468 parameters.append(("C", [2**a for a in range(5,15,2)])) … … 449 477 verbose=self.verbose)) 450 478 479 SVMLearner = Orange.misc.deprecated_members({ 480 "learnClassifier": "learn_classifier", 481 })(SVMLearner) 482 451 483 class SVMLearnerSparseClassEasy(SVMLearnerEasy, SVMLearnerSparse): 452 484 def __init__(self, **kwds): … … 478 510 setattr(self, name, val) 479 511 480 def get LinearSVMWeights(classifier, sum=True):512 def get_linear_svm_weights(classifier, sum=True): 481 513 """Extract attribute weights from the linear svm classifier. 482 514 … … 491 523 492 524 """ 493 def update Weights(w, key, val, mul):525 def update_weights(w, key, val, mul): 494 526 if key in w: 495 527 w[key]+=mul*val … … 518 550 for attr in attributes: 519 551 if attr.varType==Orange.data.Type.Continuous: 520 update Weights(w, attr, to_float(SVs[svInd][attr]), \552 update_weights(w, attr, to_float(SVs[svInd][attr]), \ 521 553 classifier.coef[coefInd][svInd]) 522 554 coefInd=i … … 526 558 for attr in attributes: 527 559 if attr.varType==Orange.data.Type.Continuous: 528 update Weights(w, attr, to_float(SVs[svInd][attr]), \560 update_weights(w, attr, to_float(SVs[svInd][attr]), \ 529 561 classifier.coef[coefInd][svInd]) 530 562 weights.append(w) … … 542 574 return weights 543 575 544 def exampleWeightedSum(example, weights): 576 getLinearSVMWeights = get_linear_svm_weights 577 578 def example_weighted_sum(example, weights): 545 579 sum=0 546 580 for attr, w in weights.items(): … … 548 582 return sum 549 583 584 exampleWeightedSum = example_weighted_sum 585 550 586 class MeasureAttribute_SVMWeights(Orange.core.MeasureAttribute): 551 587 … … 554 590 classifier) as the returned measure. 555 591 556 Example:: 592 Example: 593 557 594 >>> measure = MeasureAttribute_SVMWeights() 558 595 >>> for attr in table.domain.attributes: … … 604 641 weights. 605 642 606 Example: :643 Example: 607 644 608 645 >>> rfe = RFE(SVMLearner(kernel_type=kernels.Linear, … … 618 655 kernels.Linear, normalization=False) 619 656 620 def getAttrScores(self, data, stopAt=0, progressCallback=None): 657 @Orange.misc.deprecated_keywords({"progressCallback": "progress_callback"}) 658 def get_attr_scores(self, data, stopAt=0, progress_callback=None): 621 659 """Return a dict mapping attributes to scores (scores are not scores 622 660 in a general meaning; they represent the step number at which they … … 629 667 630 668 while len(attrs) > stopAt: 631 weights = get LinearSVMWeights(self.learner(data), sum=False)632 if progress Callback:633 progress Callback(100. * iter / (len(attrs)  stopAt))669 weights = get_linear_svm_weights(self.learner(data), sum=False) 670 if progress_callback: 671 progress_callback(100. * iter / (len(attrs)  stopAt)) 634 672 score = dict.fromkeys(attrs, 0) 635 673 for w in weights: … … 646 684 iter += 1 647 685 return attrScores 648 649 def __call__(self, data, numSelected=20, progressCallback=None): 686 687 @Orange.misc.deprecated_keywords({"numSelected": "num_selected", "progressCallback": "progress_callback"}) 688 def __call__(self, data, num_selected=20, progress_callback=None): 650 689 """Return a new dataset with only `numSelected` best scoring attributes 651 690 652 691 :param data: Data 653 692 :type data: Orange.data.Table 654 :param num Selected: number of features to preserve655 :type num Selected: int693 :param num_selected: number of features to preserve 694 :type num_selected: int 656 695 657 696 """ 658 scores = self.get AttrScores(data, progressCallback=progressCallback)697 scores = self.get_attr_scores(data, progressCallback=progress_callback) 659 698 scores = sorted(scores.items(), key=lambda item: item[1]) 660 699 661 scores = dict(scores[num Selected:])700 scores = dict(scores[num_selected:]) 662 701 attrs = [attr for attr in data.domain.attributes if attr in scores] 663 702 domain = Orange.data.Domain(attrs, data.domain.classVar) … … 666 705 return data 667 706 668 def exampleTableToSVMFormat(examples, file): 669 warnings.warn("Deprecated. Use tableToSVMFormat", DeprecationWarning) 670 tableToSVMFormat(examples, file) 671 672 def tableToSVMFormat(examples, file): 707 RFE = Orange.misc.deprecated_members({ 708 "getAttrScores": "get_attr_scores"}, 709 wrap_methods=["get_attr_scores", "__call__"])(RFE) 710 711 def example_table_to_svm_format(examples, file): 712 warnings.warn("Deprecated. Use table_to_svm_format", DeprecationWarning) 713 table_to_svm_format(examples, file) 714 715 exampleTableToSVMFormat = example_table_to_svm_format 716 717 def table_to_svm_format(examples, file): 673 718 """Save :obj:`Orange.data.Table` to a format used by LibSVM.""" 674 719 attrs = examples.domain.attributes + examples.domain.getmetas().values() … … 689 734 file.write("\n") 690 735 736 tableToSVMFormat = table_to_svm_format
Note: See TracChangeset
for help on using the changeset viewer.