source: orange/Orange/testing/unit/tests/test_svm.py @ 10574:f7272ba4865d

Revision 10574:f7272ba4865d, 5.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added support for custom kernels in get_binary_classifier.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE
3from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
4from Orange.misc import testing
5from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
6import orange
7
8import pickle
9
10               
def svm_test_binary_classifier(self, data):
    """Check pairwise binary classifiers against the full SVM classifier.

    For every pair of class indices ``(i, j)`` with ``i < j`` a binary
    classifier is obtained through ``self.classifier.get_binary_classifier``.
    On a sample of `data`, each entry of the full classifier's decision
    values must match (to 3 decimal places) the first decision value of the
    corresponding binary classifier, both as created and after a pickle
    round trip.

    Nothing is checked when the class variable of `data` is not discrete.

    :param self: the running test case; must provide ``classifier`` and
        ``assertAlmostEqual``.
    :param data: the data table the classifier was built for.
    """
    class_var = data.domain.class_var
    if not isinstance(class_var, Orange.feature.Discrete):
        return

    n_classes = len(class_var.values)

    # One binary classifier per unordered class pair, enumerated in the
    # order (0, 1), (0, 2), ..., (1, 2), ...
    pairwise = [self.classifier.get_binary_classifier(first, second)
                for first in range(n_classes - 1)
                for second in range(first + 1, n_classes)]

    # The same classifiers after being serialized and restored.
    restored = pickle.loads(pickle.dumps(pairwise))

    # Only a subset of the instances is compared.
    # NOTE(review): p0=0.2 presumably controls the sampled fraction -- confirm.
    make_indices = Orange.data.sample.SubsetIndices2(p0=0.2)
    subset = data.select(make_indices(data))

    for inst in subset:
        expected = list(self.classifier.get_decision_values(inst))
        from_binary = [clf.get_decision_values(inst)[0] for clf in pairwise]
        from_restored = [clf.get_decision_values(inst)[0] for clf in restored]
        for full_val, bin_val, restored_val in zip(expected, from_binary,
                                                   from_restored):
            self.assertAlmostEqual(full_val, bin_val, places=3)
            self.assertAlmostEqual(full_val, restored_val, places=3)
33
# Dataset names shared by the SVM test cases in this module: the
# classification datasets followed by the regression datasets.
datasets = testing.CLASSIFICATION_DATASETS + testing.REGRESSION_DATASETS


@datasets_driven(datasets=datasets)
class LinearSVMTestCase(testing.LearnerTestCase):
    """Learner tests for an SVMLearner configured with the linear kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Linear, name="svm-lin")

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the inherited LearnerTestCase check, then compare the pairwise
        # binary classifiers with the full classifier.
        # NOTE(review): the inherited check presumably sets self.classifier,
        # which svm_test_binary_classifier reads -- confirm.
        base_case = testing.LearnerTestCase
        base_case.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
43       
@datasets_driven(datasets=datasets)
class PolySVMTestCase(testing.LearnerTestCase):
    """Learner tests for an SVMLearner configured with the polynomial kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Polynomial, name="svm-poly")

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the inherited LearnerTestCase check, then compare the pairwise
        # binary classifiers with the full classifier.
        # NOTE(review): the inherited check presumably sets self.classifier,
        # which svm_test_binary_classifier reads -- confirm.
        base_case = testing.LearnerTestCase
        base_case.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
52       
53
@datasets_driven(datasets=datasets)
class RBFSVMTestCase(testing.LearnerTestCase):
    """Learner tests for an SVMLearner configured with the RBF kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.RBF, name="svm-RBF")

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the inherited LearnerTestCase check, then compare the pairwise
        # binary classifiers with the full classifier.
        # NOTE(review): the inherited check presumably sets self.classifier,
        # which svm_test_binary_classifier reads -- confirm.
        base_case = testing.LearnerTestCase
        base_case.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
62       
63       
@datasets_driven(datasets=datasets)
class SigmoidSVMTestCase(testing.LearnerTestCase):
    """Learner tests for an SVMLearner configured with the sigmoid kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Sigmoid, name="svm-sig")

    @test_on_data
    def test_learner_on(self, dataset):
        # Run the inherited LearnerTestCase check, then compare the pairwise
        # binary classifiers with the full classifier.
        # NOTE(review): the inherited check presumably sets self.classifier,
        # which svm_test_binary_classifier reads -- confirm.
        base_case = testing.LearnerTestCase
        base_case.test_learner_on(self, dataset)
        svm_test_binary_classifier(self, dataset)
72       
73
74
75#def to_sparse(data):
76#    domain = Orange.data.Domain([], data.domain.class_var)
77#    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
78#    return Orange.data.Table(domain, data)
79#
80#def sparse_data_iter():
81#    for name, (data, ) in testing.datasets_iter(datasets):
82#        yield name, (to_sparse(data), )
83#   
84## This needs sparse datasets.
85#@testing.data_driven(data_iter=sparse_data_iter())
86#class BagOfWordsSVMTestCase(testing.LearnerTestCase):
87#    LEARNER = SVMLearner(name="svm-bow", kernel_type=SVMLearner.Custom, kernelFunc=BagOfWords())
88
89
@datasets_driven(datasets=datasets)
class CustomWrapperSVMTestCase(testing.LearnerTestCase):
    """Learner tests for an SVMLearner using a custom kernel function.

    ``LEARNER`` is the learner class itself; the learner instance is built
    per dataset because the kernel wraps a distance built from that data.
    """

    LEARNER = SVMLearner

    @test_on_data
    def test_learner_on(self, data):
        """Test custom kernel wrapper

        The kernel is an RBFKernelWrapper (gamma=0.5) around a Euclidean
        distance when the domain has continuous attributes and a Hamming
        distance otherwise.
        """
        if data.domain.has_continuous_attributes():
            distance_constructor = orange.ExamplesDistanceConstructor_Euclidean
        else:
            distance_constructor = orange.ExamplesDistanceConstructor_Hamming
        distance = distance_constructor(data)

        kernel = RBFKernelWrapper(distance, gamma=0.5)
        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
                                    kernel_func=kernel)

        # Inherited learner check first, then the pairwise binary
        # classifier comparison.
        base_case = testing.LearnerTestCase
        base_case.test_learner_on(self, data)
        svm_test_binary_classifier(self, data)
107
108
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinLearner(testing.LearnerTestCase):
    """Runs the inherited LearnerTestCase tests with LinearLearner.

    Unlike the SVMLearner test cases in this module, only the
    classification datasets are used.
    """

    # The learner class (not an instance) is handed to the base test case.
    LEARNER = LinearLearner
112
113
@datasets_driven(datasets=datasets)
class TestMeasureAttr_LinWeights(testing.MeasureAttributeTestCase):
    """Runs the inherited MeasureAttributeTestCase tests on SVM weights."""

    # A default-constructed measure instance used by the base test case.
    MEASURE = MeasureAttribute_SVMWeights()
117
118
@datasets_driven(datasets=["iris"])
class TestRFE(testing.DataTestCase):
    """Tests for ``RFE`` from Orange.classification.svm, run on "iris"."""

    @test_on_data
    def test_rfe_on(self, data):
        rfe = RFE()
        # Never ask for more attributes than the domain has.
        num_selected = min(5, len(data.domain.attributes))

        # Calling the RFE instance returns data reduced to the requested
        # number of attributes.
        reduced = rfe(data, num_selected)
        self.assertEqual(len(reduced.domain.attributes), num_selected)

        # One score is expected for each attribute that was not kept, and
        # none of the kept attributes may appear among the scored ones.
        scores = rfe.get_attr_scores(data, stop_at=num_selected)
        self.assertEqual(len(data.domain.attributes) - num_selected, len(scores))
        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))

    def test_pickle(self):
        import cPickle
        rfe = RFE()
        restored = cPickle.loads(cPickle.dumps(rfe))
        # Besides not raising, the round trip must give back an RFE object.
        # assertTrue(isinstance(...)) is used instead of assertIsInstance so
        # the check also works with the Python 2.6 stdlib unittest.
        self.assertTrue(isinstance(restored, RFE))
135
if __name__ == "__main__":
    # Prefer the unittest2 backport when it is installed; fall back to the
    # standard library otherwise. Only a failed import triggers the fallback,
    # so KeyboardInterrupt/SystemExit and unrelated errors are not swallowed.
    try:
        import unittest2 as unittest
    except ImportError:
        import unittest
    unittest.main()
Note: See TracBrowser for help on using the repository browser.