source: orange/Orange/testing/unit/tests/test_svm.py @ 10584:edc3a37bc3c3

Revision 10584:edc3a37bc3c3, 7.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed get_linear_svm_weights function with respect to the internal libsvm label ordering.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights,\
3                            LinearLearner, RFE, get_linear_svm_weights, \
4                            example_weighted_sum
5                           
6from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
7from Orange.misc import testing
8from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
9import orange
10
11import copy
12
13import pickle
14import numpy as np
15
16def multiclass_from1sv1(dec_values, class_var):
17    n_class = len(class_var.values)
18    votes = [0] * n_class
19    p = 0
20    for i in range(n_class - 1):
21        for j in range(i + 1, n_class):
22            val = dec_values[p]
23            if val > 0:
24                votes[i] += 1
25            else:
26                votes[j] += 1
27            p += 1
28    max_i = np.argmax(votes)
29    return class_var(int(max_i))
30   
31   
32def svm_test_binary_classifier(self, data):
33    if isinstance(data.domain.class_var, Orange.feature.Discrete):
34        # Test binary classifiers equivalence. 
35        classes = data.domain.class_var.values
36        bin_cls = []
37        # Collect all binary classifiers
38        for i in range(len(classes) - 1):
39            for j in range(i + 1, len(classes)):
40                bin_cls.append(self.classifier.get_binary_classifier(i, j))
41               
42        pickled_bin_cls = pickle.loads(pickle.dumps(bin_cls))
43       
44        indices = Orange.data.sample.SubsetIndices2(p0=0.2)
45        sample = data.select(indices(data), 0)
46       
47        learner = copy.copy(self.LEARNER)
48        learner.probability = False 
49        classifier_no_prob = learner(data)
50       
51        for inst in sample:
52            d_val = list(self.classifier.get_decision_values(inst))
53            d_val_b = [bc.get_decision_values(inst)[0] for bc in bin_cls]
54            d_val_b1 = [bc.get_decision_values(inst)[0] for bc in pickled_bin_cls]
55            for v1, v2, v3 in zip(d_val, d_val_b, d_val_b1):
56                self.assertAlmostEqual(v1, v2, places=3)
57                self.assertAlmostEqual(v1, v3, places=3)
58           
59            prediction_1 = classifier_no_prob(inst)
60            d_val = classifier_no_prob.get_decision_values(inst)
61            prediciton_2 = multiclass_from1sv1(d_val, classifier_no_prob.class_var)
62            self.assertEqual(prediction_1, prediciton_2)
63           
64
65datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS
66@datasets_driven(datasets=datasets)
67class LinearSVMTestCase(testing.LearnerTestCase):
68    LEARNER = SVMLearner(name="svm-lin", kernel_type=SVMLearner.Linear)
69
70    @test_on_data
71    def test_learner_on(self, dataset):
72        testing.LearnerTestCase.test_learner_on(self, dataset)
73        svm_test_binary_classifier(self, dataset)
74       
75   
76    @test_on_datasets(datasets=testing.CLASSIFICATION_DATASETS + ["zoo"])
77    def test_linear_weights_on(self, dataset):
78        # Test get_linear_svm_weights
79        classifier = self.LEARNER(dataset)
80        weights = get_linear_svm_weights(classifier, sum=True)
81       
82        weights = get_linear_svm_weights(classifier, sum=False)
83       
84        n_class = len(classifier.class_var.values)
85       
86        def class_pairs(n_class):
87            for i in range(n_class - 1):
88                for j in range(i + 1, n_class):
89                    yield i, j
90                   
91        l_map = classifier._get_libsvm_labels_map()
92        # Would need to map the rho values
93        if l_map == sorted(l_map):
94            for inst in dataset[:20]:
95                dec_values = classifier.get_decision_values(inst)
96               
97                for dec_v, weight, rho, pair in zip(dec_values, weights,
98                                        classifier.rho, class_pairs(n_class)):
99                    t_inst = Orange.data.Instance(classifier.domain, inst)                   
100                    dec_v1 = example_weighted_sum(t_inst, weight) - rho
101                    self.assertAlmostEqual(dec_v, dec_v1, 4)
102       
103       
104@datasets_driven(datasets=datasets)
105class PolySVMTestCase(testing.LearnerTestCase):
106    LEARNER = SVMLearner(name="svm-poly", kernel_type=SVMLearner.Polynomial)
107   
108    @test_on_data
109    def test_learner_on(self, dataset):
110        testing.LearnerTestCase.test_learner_on(self, dataset)
111        svm_test_binary_classifier(self, dataset)
112       
113
114@datasets_driven(datasets=datasets)
115class RBFSVMTestCase(testing.LearnerTestCase):
116    LEARNER = SVMLearner(name="svm-RBF", kernel_type=SVMLearner.RBF)
117   
118    @test_on_data
119    def test_learner_on(self, dataset):
120        testing.LearnerTestCase.test_learner_on(self, dataset)
121        svm_test_binary_classifier(self, dataset)
122       
123       
124@datasets_driven(datasets=datasets)
125class SigmoidSVMTestCase(testing.LearnerTestCase):
126    LEARNER = SVMLearner(name="svm-sig", kernel_type=SVMLearner.Sigmoid)
127   
128    @test_on_data
129    def test_learner_on(self, dataset):
130        testing.LearnerTestCase.test_learner_on(self, dataset)
131        svm_test_binary_classifier(self, dataset)
132       
133
134
135#def to_sparse(data):
136#    domain = Orange.data.Domain([], data.domain.class_var)
137#    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
138#    return Orange.data.Table(domain, data)
139#
140#def sparse_data_iter():
141#    for name, (data, ) in testing.datasets_iter(datasets):
142#        yield name, (to_sparse(data), )
143#   
144## This needs sparse datasets.
145#@testing.data_driven(data_iter=sparse_data_iter())
146#class BagOfWordsSVMTestCase(testing.LearnerTestCase):
147#    LEARNER = SVMLearner(name="svm-bow", kernel_type=SVMLearner.Custom, kernelFunc=BagOfWords())
148
149
150@datasets_driven(datasets=datasets)
151class CustomWrapperSVMTestCase(testing.LearnerTestCase):
152    LEARNER = SVMLearner
153
154    @test_on_data
155    def test_learner_on(self, data):
156        """ Test custom kernel wrapper
157        """
158        if data.domain.has_continuous_attributes():
159            dist = orange.ExamplesDistanceConstructor_Euclidean(data)
160        else:
161            dist = orange.ExamplesDistanceConstructor_Hamming(data)
162        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
163                                    kernel_func=RBFKernelWrapper(dist, gamma=0.5))
164
165        testing.LearnerTestCase.test_learner_on(self, data)
166        svm_test_binary_classifier(self, data)
167
168
169@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
170class TestLinLearner(testing.LearnerTestCase):
171    LEARNER = LinearLearner
172
173
174@datasets_driven(datasets=datasets)
175class TestMeasureAttr_LinWeights(testing.MeasureAttributeTestCase):
176    MEASURE = MeasureAttribute_SVMWeights()
177
178
179@datasets_driven(datasets=["iris"])
180class TestRFE(testing.DataTestCase):
181    @test_on_data
182    def test_rfe_on(self, data):
183        rfe = RFE()
184        num_selected = min(5, len(data.domain.attributes))
185        reduced = rfe(data, num_selected)
186        self.assertEqual(len(reduced.domain.attributes), num_selected)
187        scores = rfe.get_attr_scores(data, stop_at=num_selected)
188        self.assertEqual(len(data.domain.attributes) - num_selected, len(scores))
189        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))
190
191    def test_pickle(self):
192        import cPickle
193        rfe = RFE()
194        copy = cPickle.loads(cPickle.dumps(rfe))
195
196if __name__ == "__main__":
197    try:
198        import unittest2 as unittest
199    except:
200        import unittest
201    unittest.main()
Note: See TracBrowser for help on using the repository browser.