source: orange/Orange/testing/unit/tests/test_svm.py @ 10579:5d12fb993765

Revision 10579:5d12fb993765, 5.9 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Test multiclass prediction from the decision values.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE
3from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
4from Orange.misc import testing
5from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
6import orange
7
8import copy
9
10import pickle
11import numpy as np
12
def multiclass_from1sv1(dec_values, class_var):
    """Return the class chosen by one-vs-one voting over decision values.

    :param dec_values: indexable sequence holding one decision value per
        class pair ``(i, j)`` with ``i < j``, ordered
        ``(0, 1), (0, 2), ..., (1, 2), ...``.
    :param class_var: object with a ``values`` sequence (one entry per
        class) that is callable with an integer class index.
    :return: ``class_var(index)`` for the class index with the most votes.

    A strictly positive decision value is a vote for the first class of
    the pair; zero or a negative value is a vote for the second one.
    ``np.argmax`` reports the first maximum, so ties go to the lowest
    class index.
    """
    n_class = len(class_var.values)
    votes = [0] * n_class
    # Enumerate the pairs in the same order the decision values are stored.
    pairs = [(i, j)
             for i in range(n_class - 1)
             for j in range(i + 1, n_class)]
    for p, (i, j) in enumerate(pairs):
        if dec_values[p] > 0:
            votes[i] += 1
        else:
            votes[j] += 1
    max_i = np.argmax(votes)
    return class_var(int(max_i))
27   
28   
def svm_test_binary_classifier(self, data):
    """Check the binary (one-vs-one) classifiers of a multiclass SVM.

    Helper called by the SVM test cases with the test case as ``self``; it
    reads ``self.classifier`` and the learner under test. Nothing is
    checked unless the class variable of ``data`` is discrete.

    On a sample of ``data`` it asserts that:

    * the decision values of ``self.classifier`` agree (to 3 places) with
      those of the classifiers returned by ``get_binary_classifier``, both
      as returned and after a pickle round trip;
    * the prediction of a classifier trained with ``probability = False``
      equals the vote ``multiclass_from1sv1`` computes from that
      classifier's decision values.
    """
    if not isinstance(data.domain.class_var, Orange.feature.Discrete):
        return

    n_classes = len(data.domain.class_var.values)
    # Collect all binary classifiers, one per class pair (i, j), i < j.
    bin_cls = [self.classifier.get_binary_classifier(i, j)
               for i in range(n_classes - 1)
               for j in range(i + 1, n_classes)]
    pickled_bin_cls = pickle.loads(pickle.dumps(bin_cls))

    indices = Orange.data.sample.SubsetIndices2(p0=0.2)
    sample = data.select(indices(data), 0)

    # Prefer the learner instance configured on the test case: some test
    # cases keep a learner *class* in LEARNER and build the real learner in
    # ``self.learner``. ``copy.copy`` of a class returns the class itself,
    # so setting ``probability`` on it would modify the shared class.
    learner = copy.copy(getattr(self, "learner", self.LEARNER))
    learner.probability = False
    classifier_no_prob = learner(data)

    for inst in sample:
        d_val = list(self.classifier.get_decision_values(inst))
        d_val_b = [bc.get_decision_values(inst)[0] for bc in bin_cls]
        d_val_b1 = [bc.get_decision_values(inst)[0]
                    for bc in pickled_bin_cls]
        for v1, v2, v3 in zip(d_val, d_val_b, d_val_b1):
            self.assertAlmostEqual(v1, v2, places=3)
            self.assertAlmostEqual(v1, v3, places=3)

        # Without probability estimates the prediction should match the
        # vote over the pairwise decision values.
        prediction_1 = classifier_no_prob(inst)
        d_val = classifier_no_prob.get_decision_values(inst)
        prediction_2 = multiclass_from1sv1(d_val,
                                           classifier_no_prob.class_var)
        self.assertEqual(prediction_1, prediction_2)
60           
61
# Data sets every data-driven SVM test case in this module is run on.
datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS


@datasets_driven(datasets=datasets)
class LinearSVMTestCase(testing.LearnerTestCase):
    """Learner tests for ``SVMLearner`` with the linear kernel."""
    LEARNER = SVMLearner(name="svm-lin", kernel_type=SVMLearner.Linear)

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the base class test first, then the SVM specific checks of
        # the binary classifiers and decision values.
        testing.LearnerTestCase.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
71       
@datasets_driven(datasets=datasets)
class PolySVMTestCase(testing.LearnerTestCase):
    """Learner tests for ``SVMLearner`` with the polynomial kernel."""
    LEARNER = SVMLearner(name="svm-poly", kernel_type=SVMLearner.Polynomial)

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the base class test first, then the SVM specific checks of
        # the binary classifiers and decision values.
        testing.LearnerTestCase.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
80       
81
@datasets_driven(datasets=datasets)
class RBFSVMTestCase(testing.LearnerTestCase):
    """Learner tests for ``SVMLearner`` with the RBF kernel."""
    LEARNER = SVMLearner(name="svm-RBF", kernel_type=SVMLearner.RBF)

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the base class test first, then the SVM specific checks of
        # the binary classifiers and decision values.
        testing.LearnerTestCase.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
90       
91       
@datasets_driven(datasets=datasets)
class SigmoidSVMTestCase(testing.LearnerTestCase):
    """Learner tests for ``SVMLearner`` with the sigmoid kernel."""
    LEARNER = SVMLearner(name="svm-sig", kernel_type=SVMLearner.Sigmoid)

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the base class test first, then the SVM specific checks of
        # the binary classifiers and decision values.
        testing.LearnerTestCase.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
100       
101
102
103#def to_sparse(data):
104#    domain = Orange.data.Domain([], data.domain.class_var)
105#    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
106#    return Orange.data.Table(domain, data)
107#
108#def sparse_data_iter():
109#    for name, (data, ) in testing.datasets_iter(datasets):
110#        yield name, (to_sparse(data), )
111#   
112## This needs sparse datasets.
113#@testing.data_driven(data_iter=sparse_data_iter())
114#class BagOfWordsSVMTestCase(testing.LearnerTestCase):
115#    LEARNER = SVMLearner(name="svm-bow", kernel_type=SVMLearner.Custom, kernelFunc=BagOfWords())
116
117
@datasets_driven(datasets=datasets)
class CustomWrapperSVMTestCase(testing.LearnerTestCase):
    """Learner tests for ``SVMLearner`` with a custom (wrapped) kernel.

    ``LEARNER`` is the learner class; the learner instance is built per
    data set in ``test_learner_on`` because the kernel wraps a distance
    that is constructed from the data.
    """
    LEARNER = SVMLearner

    @test_on_data
    def test_learner_on(self, data):
        """ Test custom kernel wrapper
        """
        # Choose the distance constructor by attribute type: Euclidean when
        # the domain has continuous attributes, Hamming otherwise.
        if data.domain.has_continuous_attributes():
            dist = orange.ExamplesDistanceConstructor_Euclidean(data)
        else:
            dist = orange.ExamplesDistanceConstructor_Hamming(data)
        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
                                    kernel_func=RBFKernelWrapper(dist, gamma=0.5))

        testing.LearnerTestCase.test_learner_on(self, data)
        svm_test_binary_classifier(self, data)
135
136
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinLearner(testing.LearnerTestCase):
    """Run the inherited learner tests on ``LinearLearner``.

    Only the classification data sets are used.
    """
    LEARNER = LinearLearner
140
141
@datasets_driven(datasets=datasets)
class TestMeasureAttr_LinWeights(testing.MeasureAttributeTestCase):
    """Run the inherited measure tests on ``MeasureAttribute_SVMWeights``."""
    MEASURE = MeasureAttribute_SVMWeights()
145
146
@datasets_driven(datasets=["iris"])
class TestRFE(testing.DataTestCase):
    """Tests for ``RFE`` on the iris data set."""

    @test_on_data
    def test_rfe_on(self, data):
        rfe = RFE()
        # Never ask for more attributes than the data has.
        num_selected = min(5, len(data.domain.attributes))
        reduced = rfe(data, num_selected)
        self.assertEqual(len(reduced.domain.attributes), num_selected)
        # Scores are expected only for the attributes that were not kept,
        # so they must not overlap with the attributes of the reduced data.
        scores = rfe.get_attr_scores(data, stop_at=num_selected)
        self.assertEqual(len(data.domain.attributes) - num_selected, len(scores))
        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))

    def test_pickle(self):
        # Use the C pickler where it exists, the plain module otherwise.
        try:
            import cPickle as pickler
        except ImportError:
            import pickle as pickler
        rfe = RFE()
        # Named ``restored`` so the module-level ``copy`` import is not
        # shadowed; the round trip must give back an RFE instance.
        restored = pickler.loads(pickler.dumps(rfe))
        self.assertTrue(isinstance(restored, RFE))
163
if __name__ == "__main__":
    # Prefer the unittest2 backport when it is installed; catch only the
    # import failure so that other errors are not silently hidden.
    try:
        import unittest2 as unittest
    except ImportError:
        import unittest
    unittest.main()
Note: See TracBrowser for help on using the repository browser.