source: orange/Orange/testing/unit/tests/test_svm.py @ 10278:f3b1ffae9c29

Revision 10278:f3b1ffae9c29, 3.2 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Unittest2 for python 2.6.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE
3from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
4from Orange.misc import testing
5from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
6import orange
7
8
# Datasets shared by most test cases in this module: every classification
# dataset followed by every regression dataset from ``Orange.misc.testing``.
datasets = (testing.CLASSIFICATION_DATASETS
            + testing.REGRESSION_DATASETS)
@datasets_driven(datasets=datasets)
class LinearSVMTestCase(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on a linear-kernel SVM."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Linear, name="svm-lin")
13
14
@datasets_driven(datasets=datasets)
class PolySVMTestCase(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on a polynomial-kernel SVM."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Polynomial, name="svm-poly")
18
19
@datasets_driven(datasets=datasets)
class RBFSVMTestCase(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on an RBF-kernel SVM."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.RBF, name="svm-RBF")
23
24
@datasets_driven(datasets=datasets)
class SigmoidSVMTestCase(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on a sigmoid-kernel SVM."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Sigmoid, name="svm-sig")
28
29
30#def to_sparse(data):
31#    domain = Orange.data.Domain([], data.domain.class_var)
32#    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
33#    return Orange.data.Table(domain, data)
34#
35#def sparse_data_iter():
36#    for name, (data, ) in testing.datasets_iter(datasets):
37#        yield name, (to_sparse(data), )
38#   
39## This needs sparse datasets.
40#@testing.data_driven(data_iter=sparse_data_iter())
41#class BagOfWordsSVMTestCase(testing.LearnerTestCase):
42#    LEARNER = SVMLearner(name="svm-bow", kernel_type=SVMLearner.Custom, kernelFunc=BagOfWords())
43
44
@datasets_driven(datasets=datasets)
class CustomWrapperSVMTestCase(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on an SVM with a custom kernel.

    ``LEARNER`` holds the ``SVMLearner`` class itself rather than an instance.
    The custom RBF kernel wraps a Euclidean distance that must be built from
    the data under test, so ``test_learner_on`` builds the learner for each
    dataset.
    """

    LEARNER = SVMLearner

    @test_on_data
    def test_learner_on(self, data):
        """Test the custom RBF kernel wrapper on ``data``."""
        # ExamplesDistanceConstructor_Euclidean is given the data, so the
        # kernel, and hence the learner, can only be assembled here.
        euclidean = orange.ExamplesDistanceConstructor_Euclidean(data)
        rbf_kernel = RBFKernelWrapper(euclidean, gamma=0.5)
        self.learner = self.LEARNER(kernel_func=rbf_kernel,
                                    kernel_type=SVMLearner.Custom)

        # Delegate the actual checks to the shared base-class implementation.
        testing.LearnerTestCase.test_learner_on(self, data)
58
59
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinLearner(testing.LearnerTestCase):
    """Run the ``testing.LearnerTestCase`` checks on ``LinearLearner``.

    Only the classification datasets are used for this learner.
    """

    LEARNER = LinearLearner
63
64
@datasets_driven(datasets=datasets)
class TestMeasureAttr_LinWeights(testing.MeasureAttributeTestCase):
    """Run the ``testing.MeasureAttributeTestCase`` checks on SVM-weight scoring."""

    MEASURE = MeasureAttribute_SVMWeights()
68
69
@datasets_driven(datasets=["iris"])
class TestRFE(testing.DataTestCase):
    """Tests for recursive feature elimination (``RFE``)."""

    @test_on_data
    def test_rfe_on(self, data):
        """Check RFE attribute selection and elimination scores on ``data``.

        The reduced data must keep exactly ``num_selected`` attributes.
        ``get_attr_scores(stop_at=num_selected)`` must score exactly the
        attributes that were eliminated, none of which may remain in the
        reduced data.
        """
        rfe = RFE()
        n_attrs = len(data.domain.attributes)
        # Select fewer attributes than the data has, so that RFE really
        # eliminates something. Selecting all of them (iris has only 4)
        # would leave ``scores`` empty and make the checks below vacuous.
        if n_attrs > 1:
            num_selected = min(5, n_attrs - 1)
        else:
            num_selected = n_attrs
        reduced = rfe(data, num_selected)
        self.assertEqual(len(reduced.domain.attributes), num_selected)
        scores = rfe.get_attr_scores(data, stop_at=num_selected)
        self.assertEqual(n_attrs - num_selected, len(scores))
        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))

    def test_pickle(self):
        """Check that an ``RFE`` object survives a pickle round trip."""
        import cPickle
        rfe = RFE()
        restored = cPickle.loads(cPickle.dumps(rfe))
        # assertIsInstance is avoided: plain unittest on Python 2.6 lacks it.
        self.assertTrue(isinstance(restored, RFE))
86
if __name__ == "__main__":
    # Prefer unittest2 when it is installed (needed for Python 2.6); fall back
    # to the standard-library module only when unittest2 is missing. Catching
    # just ImportError keeps real errors inside unittest2 from being hidden.
    try:
        import unittest2 as unittest
    except ImportError:
        import unittest
    unittest.main()
Note: See TracBrowser for help on using the repository browser.