source: orange/Orange/testing/unit/tests/test_linear.py @ 11017:512e4e254ad5

Revision 11017:512e4e254ad5, 5.5 KB checked in by Ales Erjavec <ales.erjavec@…>, 18 months ago (diff)

Renamed test_missing_instances to missing_instances_test.

Nose picks up the global function and tries to run it (unsuccessfully).

Line 
1import cPickle
2
3import Orange
4from Orange.testing import testing
5from Orange.testing.testing import datasets_driven
6from Orange.classification.svm import LinearSVMLearner
7try:
8    import unittest2 as unittest
9except:
10    import unittest
11
12import numpy as np
13
14
def decision_values(classifier, instance):
    """Return the decision values (numpy.array) for classifying `instance`.

    The instance is converted to ``classifier.domain``. Missing attribute
    values are replaced by 0. When the model uses a bias term
    (``classifier.bias >= 0``), the bias value is appended as an extra
    constant feature. The result is ``x . w^T`` flattened to 1-D, with one
    value per weight vector in ``classifier.weights``.
    """
    table = Orange.data.Table(classifier.domain, [instance])
    (masked,) = table.to_numpy_MA("A")

    # Masked (missing) attribute values contribute nothing to the score.
    x = masked.filled(0.0)

    # liblinear appends the bias feature whenever bias >= 0 (a negative bias
    # disables it), and the weight vectors carry a matching extra entry.
    # Using ">= 0" keeps the shapes aligned even when bias == 0.
    if classifier.bias >= 0.0:
        x = np.hstack([x, [[classifier.bias]]])

    w = np.array(classifier.weights)

    return np.dot(x, w.T).ravel()
28
29
def classify_from_weights(classifier, instance):
    """Predict the class of `instance` using only the classifier's weights.

    The raw decision values come from `decision_values`.
    - Multi-class problems (more than two class values) take the index of
      the largest decision value.
    - Binary problems have a single decision value. A positive value
      selects the first class value; anything else selects the second.
    """
    scores = decision_values(classifier, instance)
    class_var = classifier.class_var

    if len(class_var.values) <= 2:
        index = 0 if scores[0] > 0 else 1
    else:
        # TODO: Check how liblinear handles ties (np.argmax returns the
        # first maximal index).
        index = int(np.argmax(scores))
    return class_var(index)
41
42
def classify_from_weights_test(self, classifier, data):
    """Assert that `classifier` and `classify_from_weights` agree on `data`.

    This is a helper called from the test functions in this module. `self`
    is the running TestCase and supplies ``assertEqual``. The check only
    applies to discrete class variables; for any other class variable it
    does nothing.
    """
    class_var = data.domain.class_var
    if not isinstance(class_var, Orange.feature.Discrete):
        return

    for inst in data[:]:
        pval1 = classifier(inst)
        pval2 = classify_from_weights(classifier, inst)
        self.assertEqual(pval1, pval2,
                         msg="classifier and classify_from_weights return "
                             "different values")


# This helper needs a TestCase ``self`` and cannot run on its own. nose's
# default testMatch also matches a ``_test`` suffix, so opt out of
# collection explicitly. This is safe because the function is not bound
# into any TestCase class.
classify_from_weights_test.__test__ = False
52
53
@testing.test_on_data
def test_learner_on(self, dataset):
    """Run the generic learner checks on `dataset`, then inspect the weights.

    After ``testing.LearnerTestCase.test_learner_on`` has produced
    ``self.classifier``, this checks that:
    - a problem with more than two class values yields one weight vector
      per class value, and any other problem yields a single weight vector;
    - every weight vector has one entry per attribute, plus one when
      ``classifier.bias >= 0``;
    - predictions agree with `classify_from_weights`.
    """
    testing.LearnerTestCase.test_learner_on(self, dataset)
    classifier = self.classifier

    n_vals = len(dataset.domain.class_var.values)
    expected_n_weights = n_vals if n_vals > 2 else 1
    self.assertEqual(len(classifier.weights), expected_n_weights)

    n_features = len(classifier.domain.attributes)
    # A non-negative bias adds one constant feature, and its weight is
    # stored as the last entry of each weight vector.
    if classifier.bias >= 0:
        n_features += 1

    # Check each vector individually so a failure reports the offending
    # length rather than a bare "False is not true".
    for w in classifier.weights:
        self.assertEqual(len(w), n_features)

    classify_from_weights_test(self, classifier, dataset)
73
74
@testing.test_on_data
def test_learner_with_bias_on(self, dataset):
    """Repeat `test_learner_on` using a copy of the learner with ``bias = 1``.

    The learner is copied with a pickle round-trip, so the configured
    ``self.learner`` is left untouched. It is put back afterwards, even
    when the checks fail.
    """
    original_learner = self.learner
    biased_learner = cPickle.loads(cPickle.dumps(original_learner))
    biased_learner.bias = 1
    try:
        self.learner = biased_learner
        test_learner_on(self, dataset)
    finally:
        # Always restore the learner so later tests see the original.
        self.learner = original_learner
85
86
def split(data, value):
    """Split `data` into two tables by class value.

    Returns ``(pos, neg)``. `pos` holds the instances whose class equals
    `value`, and `neg` holds the instances whose class differs from it.
    Both tables are built with ``data.domain`` passed explicitly, so an
    empty side never has to infer its domain from an empty list.
    """
    pos = [inst for inst in data if inst.get_class() == value]
    neg = [inst for inst in data if inst.get_class() != value]
    return (Orange.data.Table(data.domain, pos),
            Orange.data.Table(data.domain, neg))
91
92
def missing_instances_test(self):
    """Test the learner on training data in which one class never occurs.

    For each class value of the "iris" data set, the learner is trained
    only on instances of the other classes. The test then checks that:
    - the classifier still has one weight vector per class value;
    - the decision value at the left-out class's index is exactly 0.0
      for every instance of the full data set;
    - predictions agree with `classify_from_weights`.

    The TestCase classes in this module bind this function as
    ``test_missing_instances``.
    NOTE(review): nose's default testMatch may still match the ``_test``
    suffix of the module-level name. Confirm this function is not also
    collected as a stand-alone test.
    """
    data = Orange.data.Table("iris")
    class_var = data.domain.class_var

    for index, left_out in enumerate(class_var.values):
        # Keep only the instances whose class is NOT `left_out`.
        train = split(data, left_out)[1]
        classifier = self.learner(train)

        self.assertEqual(len(classifier.weights), len(class_var.values),
                         msg="Number of weight vectors differs from the number "
                             "of class values")

        scores = [decision_values(classifier, inst) for inst in data]
        self.assertTrue(all(score[index] == 0.0 for score in scores),
                        msg="Non zero decision value for unseen class")

        classify_from_weights_test(self, classifier, data)
116
117
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS_DUAL(testing.LearnerTestCase):
    """LinearSVMLearner with the L2R_L2LOSS_DUAL solver."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
125
126
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS(testing.LearnerTestCase):
    """LinearSVMLearner with the L2R_L2LOSS solver."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
134
135
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS_DUAL(testing.LearnerTestCase):
    """LinearSVMLearner with the L2R_L1LOSS_DUAL solver."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L1LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
143
144
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS(testing.LearnerTestCase):
    """LinearSVMLearner test case (currently configured with L2R_L2LOSS)."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    # NOTE(review): despite the class name, this uses L2R_L2LOSS, which
    # duplicates TestLinearSVMLearnerL2R_L2LOSS. liblinear's solver list has
    # no primal L2-regularized L1-loss SVC (only the dual form). Confirm
    # which solver was intended here.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
152
153
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL1R_L2LOSS(testing.LearnerTestCase):
    """LinearSVMLearner with the L1R_L2LOSS solver."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L1R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
161
162
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerMCSVM_CSS(testing.LearnerTestCase):
    """LinearSVMLearner with the MCSVM_CS (Crammer-Singer) solver."""
    # The keyword is ``solver_type``. The misspelled ``sover_type`` did not
    # select the solver.
    # NOTE(review): liblinear documents that MCSVM_CS keeps one weight vector
    # per class even for two-class problems. Confirm that the binary-case
    # expectations in test_learner_on / classify_from_weights hold now that
    # this solver is actually used.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.MCSVM_CS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
170
171
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCases in this module.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.