source: orange/Orange/testing/unit/tests/test_svm.py @ 10664:1c41c9dd6c8f

Revision 10664:1c41c9dd6c8f, 8.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added doctests to the svm test suite.

Line 
1try:
2    import unittest2 as unittest
3except:
4    import unittest
5   
6import Orange
7from Orange.classification.svm import SVMLearner, SVMLearnerSparse, \
8                            ScoreSVMWeights, LinearLearner, RFE, \
9                            get_linear_svm_weights, \
10                            example_weighted_sum
11from Orange.classification import svm                           
12from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
13from Orange.misc import testing
14from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
15import orange
16
17import copy
18
19import pickle
20import numpy as np
21
22def multiclass_from1vs1(dec_values, class_var):
23    n_class = len(class_var.values)
24    votes = [0] * n_class
25    p = 0
26    for i in range(n_class - 1):
27        for j in range(i + 1, n_class):
28            val = dec_values[p]
29            if val > 0:
30                votes[i] += 1
31            else:
32                votes[j] += 1
33            p += 1
34    max_i = np.argmax(votes)
35    return class_var(int(max_i))
36   
37   
38def svm_test_binary_classifier(self, data):
39    if isinstance(data.domain.class_var, Orange.feature.Discrete):
40        # Test binary classifiers equivalence. 
41        classes = data.domain.class_var.values
42        bin_cls = []
43        # Collect all binary classifiers
44        for i in range(len(classes) - 1):
45            for j in range(i + 1, len(classes)):
46                bin_cls.append(self.classifier.get_binary_classifier(i, j))
47               
48        pickled_bin_cls = pickle.loads(pickle.dumps(bin_cls))
49       
50        indices = Orange.data.sample.SubsetIndices2(p0=0.2)
51        sample = data.select(indices(data), 0)
52       
53        learner = copy.copy(self.learner)
54        learner.probability = False 
55        classifier_no_prob = learner(data)
56       
57        for inst in sample:
58            d_val = list(self.classifier.get_decision_values(inst))
59            d_val_b = [bc.get_decision_values(inst)[0] for bc in bin_cls]
60            d_val_b1 = [bc.get_decision_values(inst)[0] for bc in pickled_bin_cls]
61            for v1, v2, v3 in zip(d_val, d_val_b, d_val_b1):
62                self.assertAlmostEqual(v1, v2, places=3)
63                self.assertAlmostEqual(v1, v3, places=3)
64           
65            prediction_1 = classifier_no_prob(inst)
66            d_val = classifier_no_prob.get_decision_values(inst)
67            prediciton_2 = multiclass_from1vs1(d_val, classifier_no_prob.class_var)
68            self.assertEqual(prediction_1, prediciton_2)
69           
70datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS
71@datasets_driven(datasets=datasets)
72class LinearSVMTestCase(testing.LearnerTestCase):
73    LEARNER = SVMLearner(name="svm-lin", kernel_type=SVMLearner.Linear)
74
75    @test_on_data
76    def test_learner_on(self, dataset):
77        testing.LearnerTestCase.test_learner_on(self, dataset)
78        svm_test_binary_classifier(self, dataset)
79       
80   
81    # Don't test on "monks" the coefs are really large and
82    @test_on_datasets(datasets=["iris", "brown-selected", "lenses", "zoo"])
83    def test_linear_classifier_weights_on(self, dataset):
84        # Test get_linear_svm_weights
85        classifier = self.LEARNER(dataset)
86        weights = get_linear_svm_weights(classifier, sum=True)
87       
88        weights = get_linear_svm_weights(classifier, sum=False)
89       
90        n_class = len(classifier.class_var.values)
91       
92        def class_pairs(n_class):
93            for i in range(n_class - 1):
94                for j in range(i + 1, n_class):
95                    yield i, j
96                   
97        l_map = classifier._get_libsvm_labels_map()
98   
99        for inst in dataset[:20]:
100            dec_values = classifier.get_decision_values(inst)
101           
102            for dec_v, weight, rho, pair in zip(dec_values, weights,
103                                    classifier.rho, class_pairs(n_class)):
104                t_inst = Orange.data.Instance(classifier.domain, inst)                   
105                dec_v1 = example_weighted_sum(t_inst, weight) - rho
106                self.assertAlmostEqual(dec_v, dec_v1, 4)
107                   
108    @test_on_datasets(datasets=testing.REGRESSION_DATASETS)
109    def test_linear_regression_weights_on(self, dataset):
110        predictor = self.LEARNER(dataset)
111        weights = get_linear_svm_weights(predictor)
112       
113        for inst in dataset[:20]:
114            t_inst = Orange.data.Instance(predictor.domain, inst)
115            prediction = predictor(inst)
116            w_sum = example_weighted_sum(t_inst, weights)
117            self.assertAlmostEqual(float(prediction), 
118                                   w_sum - predictor.rho[0],
119                                   places=4)
120       
121
122@datasets_driven(datasets=datasets)
123class PolySVMTestCase(testing.LearnerTestCase):
124    LEARNER = SVMLearner(name="svm-poly", kernel_type=SVMLearner.Polynomial)
125   
126    @test_on_data
127    def test_learner_on(self, dataset):
128        testing.LearnerTestCase.test_learner_on(self, dataset)
129        svm_test_binary_classifier(self, dataset)
130       
131
132@datasets_driven(datasets=datasets)
133class RBFSVMTestCase(testing.LearnerTestCase):
134    LEARNER = SVMLearner(name="svm-RBF", kernel_type=SVMLearner.RBF)
135   
136    @test_on_data
137    def test_learner_on(self, dataset):
138        testing.LearnerTestCase.test_learner_on(self, dataset)
139        svm_test_binary_classifier(self, dataset)
140
141@datasets_driven(datasets=datasets)
142class SigmoidSVMTestCase(testing.LearnerTestCase):
143    LEARNER = SVMLearner(name="svm-sig", kernel_type=SVMLearner.Sigmoid)
144   
145    @test_on_data
146    def test_learner_on(self, dataset):
147        testing.LearnerTestCase.test_learner_on(self, dataset)
148        svm_test_binary_classifier(self, dataset)
149
150
151def to_sparse(data):
152    domain = Orange.data.Domain([], data.domain.class_var)
153    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
154    return Orange.data.Table(domain, data)
155
156def sparse_data_iter():
157    for name, (data, ) in testing.datasets_iter(datasets):
158        yield name, (to_sparse(data), )
159
160@testing.data_driven(data_iter=sparse_data_iter())
161class SparseSVMTestCase(testing.LearnerTestCase):
162    LEARNER = SVMLearnerSparse(name="svm-sparse")
163   
164    @test_on_data
165    def test_learner_on(self, dataset):
166        testing.LearnerTestCase.test_learner_on(self, dataset)
167        svm_test_binary_classifier(self, dataset)
168
169
170@datasets_driven(datasets=datasets)
171class CustomWrapperSVMTestCase(testing.LearnerTestCase):
172    LEARNER = SVMLearner
173
174    @test_on_data
175    def test_learner_on(self, data):
176        """ Test custom kernel wrapper
177        """
178        if data.domain.has_continuous_attributes():
179            dist = orange.ExamplesDistanceConstructor_Euclidean(data)
180        else:
181            dist = orange.ExamplesDistanceConstructor_Hamming(data)
182        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
183                                    kernel_func=RBFKernelWrapper(dist, gamma=0.5))
184
185        testing.LearnerTestCase.test_learner_on(self, data)
186        svm_test_binary_classifier(self, data)
187
188@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
189class TestLinLearner(testing.LearnerTestCase):
190    LEARNER = LinearLearner
191
192
193@datasets_driven(datasets=datasets)
194class TestScoreSVMWeights(testing.MeasureAttributeTestCase):
195    MEASURE = ScoreSVMWeights()
196   
197
198@datasets_driven(datasets=["iris"])
199class TestRFE(testing.DataTestCase):
200    @test_on_data
201    def test_rfe_on(self, data):
202        rfe = RFE()
203        num_selected = min(5, len(data.domain.attributes))
204        reduced = rfe(data, num_selected)
205        self.assertEqual(len(reduced.domain.attributes), num_selected)
206        scores = rfe.get_attr_scores(data, stop_at=num_selected)
207        self.assertEqual(len(data.domain.attributes) - num_selected, len(scores))
208        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))
209
210    def test_pickle(self):
211        import cPickle
212        rfe = RFE()
213        copy = cPickle.loads(cPickle.dumps(rfe))
214
215def load_tests(loader, tests, ignore):
216    import doctest
217    tests.addTests(doctest.DocTestSuite(svm, **svm._doctest_args()))
218    return tests
219
220if __name__ == "__main__":
221    unittest.main()
Note: See TracBrowser for help on using the repository browser.