source: orange/Orange/testing/unit/tests/test_svm.py @ 9679:3879dea56188

Revision 9679:3879dea56188, 3.2 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Moved and renamed testing.

Line 
1import Orange
2from Orange.classification.svm import SVMLearner, MeasureAttribute_SVMWeights, LinearLearner, RFE
3from Orange.classification.svm.kernels import BagOfWords, RBFKernelWrapper
4from Orange.misc import testing
5from Orange.misc.testing import datasets_driven, test_on_datasets, test_on_data
6import orange
7
8
# Combined list of classification and regression datasets from the testing
# helpers; passed to @datasets_driven by the test cases in this module.
datasets = (testing.CLASSIFICATION_DATASETS
            + testing.REGRESSION_DATASETS)


@datasets_driven(datasets=datasets)
class LinearSVMTestCase(testing.LearnerTestCase):
    """Generic learner checks for an SVM using the linear kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Linear, name="svm-lin")


@datasets_driven(datasets=datasets)
class PolySVMTestCase(testing.LearnerTestCase):
    """Generic learner checks for an SVM using the polynomial kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Polynomial, name="svm-poly")


@datasets_driven(datasets=datasets)
class RBFSVMTestCase(testing.LearnerTestCase):
    """Generic learner checks for an SVM using the RBF kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.RBF, name="svm-RBF")


@datasets_driven(datasets=datasets)
class SigmoidSVMTestCase(testing.LearnerTestCase):
    """Generic learner checks for an SVM using the sigmoid kernel."""

    LEARNER = SVMLearner(kernel_type=SVMLearner.Sigmoid, name="svm-sig")


30#def to_sparse(data):
31#    domain = Orange.data.Domain([], data.domain.class_var)
32#    domain.add_metas(dict([(Orange.core.newmetaid(), v) for v in data.domain.attributes]))
33#    return Orange.data.Table(domain, data)
34#
35#def sparse_data_iter():
36#    for name, (data, ) in testing.datasets_iter(datasets):
37#        yield name, (to_sparse(data), )
38#   
39## This needs sparse datasets.
40#@testing.data_driven(data_iter=sparse_data_iter())
41#class BagOfWordsSVMTestCase(testing.LearnerTestCase):
42#    LEARNER = SVMLearner(name="svm-bow", kernel_type=SVMLearner.Custom, kernelFunc=BagOfWords())
43   
44   
@datasets_driven(datasets=datasets)
class CustomWrapperSVMTestCase(testing.LearnerTestCase):
    """Generic learner checks for an SVM with a custom RBF kernel wrapper.

    ``LEARNER`` holds the ``SVMLearner`` class itself; the configured learner
    is only built inside ``test_learner_on`` because its kernel depends on
    the dataset being tested.
    """

    LEARNER = SVMLearner

    @test_on_data
    def test_learner_on(self, data):
        """ Test custom kernel wrapper
        """
        # The Euclidean distance constructor is given the data itself, so
        # the kernel (and therefore the learner) must be built per dataset.
        euclidean = orange.ExamplesDistanceConstructor_Euclidean(data)
        rbf_kernel = RBFKernelWrapper(euclidean, gamma=0.5)
        self.learner = self.LEARNER(kernel_type=SVMLearner.Custom,
                                    kernelFunc=rbf_kernel)

        # Delegate the actual checks to the inherited generic learner test,
        # which now sees the per-dataset learner stored on self.learner.
        testing.LearnerTestCase.test_learner_on(self, data)


@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinLearner(testing.LearnerTestCase):
    """Generic learner checks for ``LinearLearner`` on classification data only."""

    # The LinearLearner name itself is stored; no instance is created here.
    LEARNER = LinearLearner


@datasets_driven(datasets=datasets)
class TestMeasureAttr_LinWeights(testing.MeasureAttributeTestCase):
    """Attribute-measure checks for a default ``MeasureAttribute_SVMWeights``."""

    # Default-constructed measure instance shared by the generated tests.
    MEASURE = MeasureAttribute_SVMWeights()


@datasets_driven(datasets=["iris"])
class TestRFE(testing.DataTestCase):
    """Tests for the SVM-based recursive feature elimination (``RFE``)."""

    @test_on_data
    def test_rfe_on(self, data):
        """Check RFE's reduced domain and the scores of eliminated attributes.

        ``num_selected`` is kept strictly below the number of attributes
        whenever the data has more than one attribute, so RFE really has to
        eliminate something. With ``min(5, n)`` the 4-attribute iris data
        would keep every attribute, ``scores`` would be empty, and all the
        assertions below would hold trivially.
        """
        rfe = RFE()
        n_attrs = len(data.domain.attributes)
        # Keep at most 5 attributes, but at most half of them, so at least
        # one attribute is eliminated whenever there are two or more.
        num_selected = max(1, min(5, n_attrs // 2))

        reduced = rfe(data, num_selected)
        self.assertEqual(len(reduced.domain.attributes), num_selected)

        scores = rfe.getAttrScores(data, stopAt=num_selected)
        # Expect one score per eliminated attribute and none for kept ones.
        self.assertEqual(n_attrs - num_selected, len(scores))
        self.assertTrue(set(reduced.domain.attributes).isdisjoint(scores.keys()))

    def test_pickle(self):
        """An ``RFE`` instance must survive a cPickle dump/load round trip."""
        import cPickle
        cPickle.loads(cPickle.dumps(RFE()))


if __name__ == "__main__":
    # Allow running this module directly as a unittest script.
    from unittest import main as _run_tests
    _run_tests()
Note: See TracBrowser for help on using the repository browser.