source: orange/Orange/testing/unit/tests/test_linear.py @ 10773:ccf1d708d349

Revision 10773:ccf1d708d349, 4.0 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added tests for linear learners with bias. Test the weights by reimplementing one-vs-rest classification in Python.

Line 
1import Orange
2from Orange.testing import testing
3from Orange.testing.testing import datasets_driven
4from Orange.classification.svm import LinearSVMLearner
5try:
6    import unittest2 as unittest
7except:
8    import unittest
9
10import numpy as np
11
def multiclass_from_1_vs_rest(dec_values, class_var):
    """Turn decision-function values into a predicted class value.

    :param dec_values: sequence of decision values, one per weight vector.
    :param class_var: discrete class variable; it is called with an integer
        index to build the resulting value.

    When the class has more than two values the index of the largest
    decision value wins.  Otherwise only ``dec_values[0]`` is consulted:
    a strictly positive value selects index 0, anything else index 1.
    """
    n_values = len(class_var.values)
    if n_values <= 2:
        # Binary case: the sign of the single decision value picks the class.
        index = 0 if dec_values[0] > 0 else 1
    else:
        index = int(np.argmax(dec_values))
    return class_var(index)
17
def binary_classifier_test(self, data):
    """Cross-check the classifier's predictions with a NumPy re-implementation.

    For every instance in *data* the feature vector is rebuilt in
    ``self.classifier.domain``.  Values whose ``is_special()`` is true
    become 0.0, and the last component is dropped (assumed to be the class
    value -- TODO confirm).  When ``self.classifier.bias >= 0`` the bias is
    appended.  The dot products with each vector in
    ``self.classifier.weights`` are mapped to a class value by
    :func:`multiclass_from_1_vs_rest` and compared with the classifier's
    own prediction.

    Nothing is checked unless the class variable is discrete, and the
    comparison is only asserted when the class has more than two values.
    """
    class_var = data.domain.class_var
    if not isinstance(class_var, Orange.feature.Discrete):
        return

    classifier = self.classifier
    n_class_values = len(class_var.values)
    bias_part = [classifier.bias] if classifier.bias >= 0 else []

    for instance in data[:]:
        converted = Orange.data.Instance(classifier.domain, instance)
        features = [0.0 if value.is_special() else float(value)
                    for value in converted]
        features = features[:-1] + bias_part
        decisions = [np.dot(features, weights)
                     for weights in classifier.weights]

        predicted = classifier(instance)
        reimplemented = multiclass_from_1_vs_rest(decisions, class_var)
        if n_class_values > 2:
            self.assertEqual(predicted, reimplemented)
        # TODO: handle order switch -- binary problems are not asserted yet;
        # presumably the class-value order can differ from the sign
        # convention used by multiclass_from_1_vs_rest.
40
@testing.test_on_data
def test_learner_on(self, dataset):
    """Run the generic learner checks, then validate the learned weights.

    After ``testing.LearnerTestCase.test_learner_on`` has run,
    ``self.classifier`` is inspected:

    * it must have one weight vector per class value when the class has
      more than two values, and exactly one weight vector otherwise;
    * every weight vector must have one component per attribute of
      ``self.classifier.domain``, plus one for the bias when
      ``self.classifier.bias >= 0``;
    * its predictions are cross-checked with :func:`binary_classifier_test`.
    """
    testing.LearnerTestCase.test_learner_on(self, dataset)

    n_vals = len(dataset.domain.class_var.values)
    expected_n_weights = n_vals if n_vals > 2 else 1
    # assertEqual instead of the deprecated assertEquals alias.
    self.assertEqual(len(self.classifier.weights), expected_n_weights)

    n_features = len(self.classifier.domain.attributes)
    if self.classifier.bias >= 0:
        # The bias is stored as an extra trailing weight component.
        n_features += 1

    # Check each vector separately so a failure reports the offending length
    # (assertTrue(all(...)) only reported "False is not true").
    for w in self.classifier.weights:
        self.assertEqual(len(w), n_features)

    binary_classifier_test(self, dataset)
58
@testing.test_on_data
def test_learner_with_bias_on(self, dataset):
    """Run :func:`test_learner_on` with a copy of the learner using ``bias = 1``.

    The learner is copied through a pickle round trip, so the shared
    ``self.learner`` object itself is not modified.  The copy is installed
    as ``self.learner`` only for the duration of the checks, and the
    original learner is restored even if the checks fail.
    """
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    learner = self.learner
    learner_b = pickle.loads(pickle.dumps(learner))
    learner_b.bias = 1

    self.learner = learner_b
    try:
        # The checks must run while the biased learner is installed;
        # restoring it before this call would test the unbiased learner.
        test_learner_on(self, dataset)
    finally:
        self.learner = learner
70         
71
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS_DUAL(testing.LearnerTestCase):
    """Linear SVM learner tests using the L2R_L2LOSS_DUAL solver."""
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
78
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS(testing.LearnerTestCase):
    """Linear SVM learner tests using the L2R_L2LOSS solver."""
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
85
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS_DUAL(testing.LearnerTestCase):
    """Linear SVM learner tests using the L2R_L1LOSS_DUAL solver."""
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L1LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
92
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS(testing.LearnerTestCase):
    """Linear SVM learner tests (currently configured with L2R_L2LOSS)."""
    # NOTE(review): despite the class name this uses L2R_L2LOSS, which makes
    # it a duplicate of TestLinearSVMLearnerL2R_L2LOSS.  Presumably there is
    # no primal L2R_L1LOSS solver constant to use here -- confirm, then
    # either point it at the intended solver or drop the class.
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
99
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL1R_L2LOSS(testing.LearnerTestCase):
    """Linear SVM learner tests using the L1R_L2LOSS solver."""
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L1R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
106
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerMCSVM_CS(testing.LearnerTestCase):
    """Linear SVM learner tests using the MCSVM_CS solver."""
    # This class must not reuse another test class's name: a duplicate
    # module-level name rebinds it and the earlier class's tests never run.
    # NOTE(review): for binary data this solver may yield one weight vector
    # per class rather than a single one; confirm that test_learner_on's
    # weight-count assertion holds for it.
    # ``solver_type`` is the LinearSVMLearner parameter that selects the
    # solver (the previous ``sover_type`` was a misspelling).
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.MCSVM_CS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
113
def _main():
    """Run this module's test cases through the unittest command-line runner."""
    unittest.main()


if __name__ == "__main__":
    _main()
Note: See TracBrowser for help on using the repository browser.