source: orange/Orange/testing/unit/tests/test_svm.py @ 10617:6371404841d6

Revision 10617:6371404841d6, 7.4 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added test case for sparse svm learner.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, SVMLearnerSparse, \
3                            ScoreSVMWeights, LinearLearner, RFE, \
4                            get_linear_svm_weights, \
5                            example_weighted_sum
6                           
7from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
8from Orange.misc import testing
9from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
10import orange
11
12import copy
13
14import pickle
15import numpy as np
16
17def multiclass_from1sv1(dec_values, class_var):
18    n_class = len(class_var.values)
19    votes = [0] * n_class
20    p = 0
21    for i in range(n_class - 1):
22        for j in range(i + 1, n_class):
23            val = dec_values[p]
24            if val > 0:
25                votes[i] += 1
26            else:
27                votes[j] += 1
28            p += 1
29    max_i = np.argmax(votes)
30    return class_var(int(max_i))
31   
32   
33def svm_test_binary_classifier(self, data):
34    if isinstance(data.domain.class_var, Orange.feature.Discrete):
35        # Test binary classifiers equivalence. 
36        classes = data.domain.class_var.values
37        bin_cls = []
38        # Collect all binary classifiers
39        for i in range(len(classes) - 1):
40            for j in range(i + 1, len(classes)):
41                bin_cls.append(self.classifier.get_binary_classifier(i, j))
42               
43        pickled_bin_cls = pickle.loads(pickle.dumps(bin_cls))
44       
45        indices = Orange.data.sample.SubsetIndices2(p0=0.2)
46        sample = data.select(indices(data), 0)
47       
48        learner = copy.copy(self.LEARNER)
49        learner.probability = False 
50        classifier_no_prob = learner(data)
51       
52        for inst in sample:
53            d_val = list(self.classifier.get_decision_values(inst))
54            d_val_b = [bc.get_decision_values(inst)[0] for bc in bin_cls]
55            d_val_b1 = [bc.get_decision_values(inst)[0] for bc in pickled_bin_cls]
56            for v1, v2, v3 in zip(d_val, d_val_b, d_val_b1):
57                self.assertAlmostEqual(v1, v2, places=3)
58                self.assertAlmostEqual(v1, v3, places=3)
59           
60            prediction_1 = classifier_no_prob(inst)
61            d_val = classifier_no_prob.get_decision_values(inst)
62            prediciton_2 = multiclass_from1sv1(d_val, classifier_no_prob.class_var)
63            self.assertEqual(prediction_1, prediciton_2)
64           
65
66datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS
67@datasets_driven(datasets=datasets)
68class LinearSVMTestCase(testing.LearnerTestCase):
69    LEARNER = SVMLearner(name="svm-lin", kernel_type=SVMLearner.Linear)
70
71    @test_on_data
72    def test_learner_on(self, dataset):
73        testing.LearnerTestCase.test_learner_on(self, dataset)
74        svm_test_binary_classifier(self, dataset)
75       
76   
77    @test_on_datasets(datasets=testing.CLASSIFICATION_DATASETS + ["zoo"])
78    def test_linear_weights_on(self, dataset):
79        # Test get_linear_svm_weights
80        classifier = self.LEARNER(dataset)
81        weights = get_linear_svm_weights(classifier, sum=True)
82       
83        weights = get_linear_svm_weights(classifier, sum=False)
84       
85        n_class = len(classifier.class_var.values)
86       
87        def class_pairs(n_class):
88            for i in range(n_class - 1):
89                for j in range(i + 1, n_class):
90                    yield i, j
91                   
92        l_map = classifier._get_libsvm_labels_map()
93        # Would need to map the rho values
94        if l_map == sorted(l_map):
95            for inst in dataset[:20]:
96                dec_values = classifier.get_decision_values(inst)
97               
98                for dec_v, weight, rho, pair in zip(dec_values, weights,
99                                        classifier.rho, class_pairs(n_class)):
100                    t_inst = Orange.data.Instance(classifier.domain, inst)                   
101                    dec_v1 = example_weighted_sum(t_inst, weight) - rho
102                    self.assertAlmostEqual(dec_v, dec_v1, 4)
103       
104       
105@datasets_driven(datasets=datasets)
106class PolySVMTestCase(testing.LearnerTestCase):
107    LEARNER = SVMLearner(name="svm-poly", kernel_type=SVMLearner.Polynomial)
108   
109    @test_on_data
110    def test_learner_on(self, dataset):
111        testing.LearnerTestCase.test_learner_on(self, dataset)
112        svm_test_binary_classifier(self, dataset)
113       
114
115@datasets_driven(datasets=datasets)
116class RBFSVMTestCase(testing.LearnerTestCase):
117    LEARNER = SVMLearner(name="svm-RBF", kernel_type=SVMLearner.RBF)
118   
119    @test_on_data
120    def test_learner_on(self, dataset):
121        testing.LearnerTestCase.test_learner_on(self, dataset)
122        svm_test_binary_classifier(self, dataset)
123       
124       
125@datasets_driven(datasets=datasets)
126class SigmoidSVMTestCase(testing.LearnerTestCase):
127    LEARNER = SVMLearner(name="svm-sig", kernel_type=SVMLearner.Sigmoid)
128   
129    @test_on_data
130    def test_learner_on(self, dataset):
131        testing.LearnerTestCase.test_learner_on(self, dataset)
132        svm_test_binary_classifier(self, dataset)
133       
134
135
136def to_sparse(data):
137    domain = Orange.data.Domain([], data.domain.class_var)
138    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
139    return Orange.data.Table(domain, data)
140
141def sparse_data_iter():
142    for name, (data, ) in testing.datasets_iter(datasets):
143        yield name, (to_sparse(data), )
144
145@testing.data_driven(data_iter=sparse_data_iter())
146class SparseSVMTestCase(testing.LearnerTestCase):
147    LEARNER = SVMLearnerSparse(name="svm-sparse")
148   
149    @test_on_data
150    def test_learner_on(self, dataset):
151        testing.LearnerTestCase.test_learner_on(self, dataset)
152        svm_test_binary_classifier(self, dataset)
153
154 
155@datasets_driven(datasets=datasets)
156class CustomWrapperSVMTestCase(testing.LearnerTestCase):
157    LEARNER = SVMLearner
158
159    @test_on_data
160    def test_learner_on(self, data):
161        """ Test custom kernel wrapper
162        """
163        if data.domain.has_continuous_attributes():
164            dist = orange.ExamplesDistanceConstructor_Euclidean(data)
165        else:
166            dist = orange.ExamplesDistanceConstructor_Hamming(data)
167        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
168                                    kernel_func=RBFKernelWrapper(dist, gamma=0.5))
169
170        testing.LearnerTestCase.test_learner_on(self, data)
171        svm_test_binary_classifier(self, data)
172
173
174@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
175class TestLinLearner(testing.LearnerTestCase):
176    LEARNER = LinearLearner
177
178
179@datasets_driven(datasets=datasets)
180class TestScoreSVMWeights(testing.MeasureAttributeTestCase):
181    MEASURE = ScoreSVMWeights()
182   
183
184@datasets_driven(datasets=["iris"])
185class TestRFE(testing.DataTestCase):
186    @test_on_data
187    def test_rfe_on(self, data):
188        rfe = RFE()
189        num_selected = min(5, len(data.domain.attributes))
190        reduced = rfe(data, num_selected)
191        self.assertEqual(len(reduced.domain.attributes), num_selected)
192        scores = rfe.get_attr_scores(data, stop_at=num_selected)
193        self.assertEqual(len(data.domain.attributes) - num_selected, len(scores))
194        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))
195
196    def test_pickle(self):
197        import cPickle
198        rfe = RFE()
199        copy = cPickle.loads(cPickle.dumps(rfe))
200
201if __name__ == "__main__":
202    try:
203        import unittest2 as unittest
204    except:
205        import unittest
206    unittest.main()
Note: See TracBrowser for help on using the repository browser.