Ignore:
Timestamp:
03/20/12 15:50:14 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
b27e78cc27e112b769d1edb5e3d27ec79e783ce9
Message:

Fixed get_linear_svm_weights function with respect to the internal libsvm label ordering.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_svm.py

    r10579 r10584  
    11import Orange 
    2 from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE 
     2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights,\ 
     3                            LinearLearner, RFE, get_linear_svm_weights, \ 
     4                            example_weighted_sum 
     5                             
    36from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper 
    47from Orange.misc import testing 
     
    7073        svm_test_binary_classifier(self, dataset) 
    7174         
     75     
     76    @test_on_datasets(datasets=testing.CLASSIFICATION_DATASETS + ["zoo"]) 
     77    def test_linear_weights_on(self, dataset): 
     78        # Test get_linear_svm_weights 
     79        classifier = self.LEARNER(dataset) 
     80        weights = get_linear_svm_weights(classifier, sum=True) 
     81         
     82        weights = get_linear_svm_weights(classifier, sum=False) 
     83         
     84        n_class = len(classifier.class_var.values) 
     85         
     86        def class_pairs(n_class): 
     87            for i in range(n_class - 1): 
     88                for j in range(i + 1, n_class): 
     89                    yield i, j 
     90                     
     91        l_map = classifier._get_libsvm_labels_map() 
     92        # Would need to map the rho values 
     93        if l_map == sorted(l_map): 
     94            for inst in dataset[:20]: 
     95                dec_values = classifier.get_decision_values(inst) 
     96                 
     97                for dec_v, weight, rho, pair in zip(dec_values, weights, 
     98                                        classifier.rho, class_pairs(n_class)): 
     99                    t_inst = Orange.data.Instance(classifier.domain, inst)                     
     100                    dec_v1 = example_weighted_sum(t_inst, weight) - rho 
     101                    self.assertAlmostEqual(dec_v, dec_v1, 4) 
     102         
     103         
    72104@datasets_driven(datasets=datasets) 
    73105class PolySVMTestCase(testing.LearnerTestCase): 
Note: See TracChangeset for help on using the changeset viewer.