Ignore:
Timestamp:
03/26/12 14:21:46 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added tests for get_linear_svm_weights for regression models.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_svm.py

    r10617 r10641  
     1try: 
     2    import unittest2 as unittest 
     3except: 
     4    import unittest 
     5     
    16import Orange 
    27from Orange.classification.svm import SVMLearner, SVMLearnerSparse, \ 
     
    49                            get_linear_svm_weights, \ 
    510                            example_weighted_sum 
    6                              
     11from Orange.classification import svm                             
    712from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper 
    813from Orange.misc import testing 
     
    1520import numpy as np 
    1621 
    17 def multiclass_from1sv1(dec_values, class_var): 
     22def multiclass_from1vs1(dec_values, class_var): 
    1823    n_class = len(class_var.values) 
    1924    votes = [0] * n_class 
     
    6065            prediction_1 = classifier_no_prob(inst) 
    6166            d_val = classifier_no_prob.get_decision_values(inst) 
    62             prediciton_2 = multiclass_from1sv1(d_val, classifier_no_prob.class_var) 
     67            prediciton_2 = multiclass_from1vs1(d_val, classifier_no_prob.class_var) 
    6368            self.assertEqual(prediction_1, prediciton_2) 
    6469             
     
    7580         
    7681     
    77     @test_on_datasets(datasets=testing.CLASSIFICATION_DATASETS + ["zoo"]) 
    78     def test_linear_weights_on(self, dataset): 
     82    # Don't test on "monks" the coefs are really large and 
     83    @test_on_datasets(datasets=["iris", "brown-selected", "lenses", "zoo"]) 
     84    def test_linear_classifier_weights_on(self, dataset): 
    7985        # Test get_linear_svm_weights 
    8086        classifier = self.LEARNER(dataset) 
     
    9197                     
    9298        l_map = classifier._get_libsvm_labels_map() 
    93         # Would need to map the rho values 
    94         if l_map == sorted(l_map): 
    95             for inst in dataset[:20]: 
    96                 dec_values = classifier.get_decision_values(inst) 
    97                  
    98                 for dec_v, weight, rho, pair in zip(dec_values, weights, 
    99                                         classifier.rho, class_pairs(n_class)): 
    100                     t_inst = Orange.data.Instance(classifier.domain, inst)                     
    101                     dec_v1 = example_weighted_sum(t_inst, weight) - rho 
    102                     self.assertAlmostEqual(dec_v, dec_v1, 4) 
    103          
    104          
     99     
     100        for inst in dataset[:20]: 
     101            dec_values = classifier.get_decision_values(inst) 
     102             
     103            for dec_v, weight, rho, pair in zip(dec_values, weights, 
     104                                    classifier.rho, class_pairs(n_class)): 
     105                t_inst = Orange.data.Instance(classifier.domain, inst)                     
     106                dec_v1 = example_weighted_sum(t_inst, weight) - rho 
     107                self.assertAlmostEqual(dec_v, dec_v1, 4) 
     108                     
     109    @test_on_datasets(datasets=testing.REGRESSION_DATASETS) 
     110    def test_linear_regression_weights_on(self, dataset): 
     111        predictor = self.LEARNER(dataset) 
     112        weights = get_linear_svm_weights(predictor) 
     113         
     114        for inst in dataset[:20]: 
     115            t_inst = Orange.data.Instance(predictor.domain, inst) 
     116            prediction = predictor(inst) 
     117            w_sum = example_weighted_sum(t_inst, weights) 
     118            self.assertAlmostEqual(float(prediction),  
     119                                   w_sum - predictor.rho[0], 
     120                                   places=4) 
     121         
     122 
    105123@datasets_driven(datasets=datasets) 
    106124class PolySVMTestCase(testing.LearnerTestCase): 
     
    199217        copy = cPickle.loads(cPickle.dumps(rfe)) 
    200218 
     219 
    201220if __name__ == "__main__": 
    202     try: 
    203         import unittest2 as unittest 
    204     except: 
    205         import unittest 
    206221    unittest.main() 
Note: See TracChangeset for help on using the changeset viewer.