Ignore:
Timestamp:
03/19/12 16:43:39 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
84eb1ad3ab7dbc2cf39ac79eaf19384b1317d358
Message:

Added get_binary_classifier method to SVMClassifierWrapper.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_svm.py

    r10278 r10573  
    66import orange 
    77 
     8import pickle 
     9 
     10                 
     11def svm_test_binary_classifier(self, data): 
     12    if isinstance(data.domain.class_var, Orange.feature.Discrete): 
     13        # Test binary classifiers equivalence.   
     14        classes = data.domain.class_var.values 
     15        bin_cls = [] 
     16        # Collect all binary classifiers 
     17        for i in range(len(classes) - 1): 
     18            for j in range(i + 1, len(classes)): 
     19                bin_cls.append(self.classifier.get_binary_classifier(i, j)) 
     20                 
     21        pickled_bin_cls = pickle.loads(pickle.dumps(bin_cls)) 
     22         
     23        indices = Orange.data.sample.SubsetIndices2(p0=0.2) 
     24        sample = data.select(indices(data)) 
     25         
     26        for inst in sample: 
     27            d_val = list(self.classifier.get_decision_values(inst)) 
     28            d_val_b = [bc.get_decision_values(inst)[0] for bc in bin_cls] 
     29            d_val_b1 = [bc.get_decision_values(inst)[0] for bc in pickled_bin_cls] 
     30            for v1, v2, v3 in zip(d_val, d_val_b, d_val_b1): 
     31                self.assertAlmostEqual(v1, v2, places=3) 
     32                self.assertAlmostEqual(v1, v3, places=3) 
    833 
    934datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS 
     
    1237    LEARNER = SVMLearner(name="svm-lin", kernel_type=SVMLearner.Linear) 
    1338 
    14  
     39    @test_on_data 
     40    def test_learner_on(self, dataset): 
     41        testing.LearnerTestCase.test_learner_on(self, dataset) 
     42        svm_test_binary_classifier(self, dataset) 
     43         
    1544@datasets_driven(datasets=datasets) 
    1645class PolySVMTestCase(testing.LearnerTestCase): 
    1746    LEARNER = SVMLearner(name="svm-poly", kernel_type=SVMLearner.Polynomial) 
    18  
     47     
     48    @test_on_data 
     49    def test_learner_on(self, dataset): 
     50        testing.LearnerTestCase.test_learner_on(self, dataset) 
     51        svm_test_binary_classifier(self, dataset) 
     52         
    1953 
    2054@datasets_driven(datasets=datasets) 
    2155class RBFSVMTestCase(testing.LearnerTestCase): 
    2256    LEARNER = SVMLearner(name="svm-RBF", kernel_type=SVMLearner.RBF) 
    23  
    24  
     57     
     58    @test_on_data 
     59    def test_learner_on(self, dataset): 
     60        testing.LearnerTestCase.test_learner_on(self, dataset) 
     61        svm_test_binary_classifier(self, dataset) 
     62         
     63         
    2564@datasets_driven(datasets=datasets) 
    2665class SigmoidSVMTestCase(testing.LearnerTestCase): 
    2766    LEARNER = SVMLearner(name="svm-sig", kernel_type=SVMLearner.Sigmoid) 
     67     
     68    @test_on_data 
     69    def test_learner_on(self, dataset): 
     70        testing.LearnerTestCase.test_learner_on(self, dataset) 
     71        svm_test_binary_classifier(self, dataset) 
     72         
    2873 
    2974 
Note: See TracChangeset for help on using the changeset viewer.