# Source: orange/Orange/testing/unit/tests/test_linear.py @ 11397:e4b810f1f493
#
# Revision 11397:e4b810f1f493, 6.4 KB, checked in by Ales Erjavec
# <ales.erjavec@...>, 13 months ago (diff).
#
# Commit message: Added 'multinomial_treatment' parameter to LIBLINEAR
# derived learners.
import cPickle

import Orange
from Orange.testing import testing
from Orange.testing.testing import datasets_driven
from Orange.classification.svm import LinearSVMLearner
from Orange.data.preprocess import DomainContinuizer
try:
    import unittest2 as unittest
except ImportError:
    # unittest2 is optional; fall back to the standard library module.
    # Only a failed import is caught so unrelated errors are not hidden.
    import unittest

import numpy as np
14
15
def clone(obj):
    """Return an independent copy of `obj` made by a pickle round trip.

    The object is serialized with ``cPickle.dumps`` and immediately
    restored with ``cPickle.loads``, so `obj` has to be picklable.
    """
    return cPickle.loads(cPickle.dumps(obj))
18
19
def decision_values(classifier, instance):
    """Return the decision values (numpy.array) for classifying `instance`.

    The values are computed by hand from ``classifier.weights``: the
    instance is converted to the classifier's domain, unknown values are
    replaced with 0.0, the bias term is appended when the classifier uses
    one, and the result is multiplied with every weight vector.

    :param classifier: a trained classifier exposing ``domain``, ``bias``
        and ``weights``.
    :param instance: the data instance to evaluate.
    :returns: a flat numpy array with one entry per weight vector.
    """
    table = Orange.data.Table(classifier.domain, [instance])
    # to_numpy_MA("A") yields a one-element tuple holding a masked array.
    (masked,) = table.to_numpy_MA("A")

    x = masked.filled(0.0)
    # A non-negative bias contributes one extra weight per vector; this is
    # the same condition `test_learner_on` uses when it checks the weight
    # vector length. With `> 0.0` a bias of exactly 0 would leave `x` one
    # column short of the weight vectors.
    if classifier.bias >= 0.0:
        x = np.hstack([x, [[classifier.bias]]])

    w = np.array(classifier.weights)

    return np.dot(x, w.T).ravel()
33
34
def classify_from_weights(classifier, instance):
    """Classify the instance using classifier's weights.

    :param classifier: a trained classifier exposing ``class_var`` and the
        attributes required by :func:`decision_values`.
    :param instance: the data instance to classify.
    :returns: a value of ``classifier.class_var``. With more than two class
        values this is the one with the largest decision value; otherwise
        it is value 0 when the first decision value is positive and
        value 1 when it is not.
    """
    dec_values = decision_values(classifier, instance)
    class_var = classifier.class_var

    if len(class_var.values) > 2:
        # TODO: Check how liblinear handles ties
        return class_var(int(np.argmax(dec_values)))

    # Only the first decision value is consulted in the two-class case.
    return class_var(0 if dec_values[0] > 0 else 1)
46
47
def classify_from_weights_test(self, classifier, data):
    """Assert that `classifier` agrees with its own weight vectors.

    For every instance in `data` the prediction returned by calling the
    classifier is compared with the one reconstructed by
    :func:`classify_from_weights`. Nothing is checked when the class
    variable of `data` is not discrete.

    :param self: the running test case (provides ``assertEqual``).
    :param classifier: the trained classifier under test.
    :param data: the instances to classify.
    """
    if not isinstance(data.domain.class_var, Orange.feature.Discrete):
        return

    for inst in data[:]:
        predicted = classifier(inst)
        reconstructed = classify_from_weights(classifier, inst)
        self.assertEqual(predicted, reconstructed,
                         msg="classifier and classify_from_weights return "
                             "different values")
57
58
@testing.test_on_data
def test_learner_on(self, dataset):
    """Run the generic learner test, then check the learned weights.

    After delegating to ``testing.LearnerTestCase.test_learner_on`` this
    verifies that

    * there is one weight vector per class value for problems with more
      than two class values and a single vector otherwise,
    * every weight vector has one entry per attribute of the classifier's
      domain, plus one when the bias is non-negative,
    * predictions match the ones reconstructed from the weights.

    NOTE(review): ``self.classifier`` is presumably set by the base class
    test called first -- confirm against ``LearnerTestCase``.
    """
    testing.LearnerTestCase.test_learner_on(self, dataset)

    weights = self.classifier.weights
    n_vals = len(dataset.domain.class_var.values)
    expected_vectors = n_vals if n_vals > 2 else 1
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(len(weights), expected_vectors)

    n_features = len(self.classifier.domain.attributes)
    if self.classifier.bias >= 0:
        # The bias term occupies one additional weight.
        n_features += 1

    self.assertTrue(all(len(w) == n_features for w in weights))

    classify_from_weights_test(self, self.classifier, dataset)
78
79
@testing.test_on_data
def test_learner_with_bias_on(self, dataset):
    """Repeat :func:`test_learner_on` with the learner's bias set to 1.

    A pickled copy of ``self.learner`` is modified and temporarily
    installed as ``self.learner``; the original learner is always put
    back, even when the wrapped test fails.
    """
    original = self.learner
    biased = clone(original)
    biased.bias = 1
    try:
        self.learner = biased
        test_learner_on(self, dataset)
    finally:
        self.learner = original
90
91
def split(data, value):
    """Split `data` by class value.

    :param data: an iterable of instances providing ``get_class()``.
    :param value: the class value to split on.
    :returns: a pair of ``Orange.data.Table`` objects: the instances whose
        class equals `value`, and the instances whose class differs.
    """
    pos = [inst for inst in data if inst.get_class() == value]
    neg = [inst for inst in data if inst.get_class() != value]
    return Orange.data.Table(pos), Orange.data.Table(neg)
96
97
def missing_instances_test(self):
    """Test the learner on a dataset with no instances for
    some class.

    For every class value of the "iris" data set the learner is trained
    on the instances that do NOT belong to that class. The resulting
    classifier must still have one weight vector per class value, must
    give a decision value of exactly 0.0 for the unseen class on every
    instance of the full data set, and its predictions must match the
    ones reconstructed from its weights.
    """
    data = Orange.data.Table("iris")
    class_var = data.domain.class_var

    for i, value in enumerate(class_var.values):
        # Keep only the instances of the other classes for training.
        _, train = split(data, value)
        classifier = self.learner(train)

        self.assertEqual(len(classifier.weights), len(class_var.values),
                         msg="Number of weight vectors differs from the number "
                             "of class values")

        dec_values = [decision_values(classifier, instance)
                      for instance in data]

        self.assertTrue(all(val[i] == 0.0 for val in dec_values),
                        msg="Non zero decision value for unseen class")

        classify_from_weights_test(self, classifier, data)
121
122
def multinomial_test(self):
    """Check the effect of ``multinomial_treatment`` on the "lenses" data.

    A pickled copy of ``self.learner`` is trained with three settings:

    * ``DomainContinuizer.NValues`` -- classifier domain of length 7,
    * ``DomainContinuizer.FrequentIsBase`` -- classifier domain of length 6,
    * ``DomainContinuizer.ReportError`` -- training must raise
      ``Orange.core.KernelException``.
    """
    data = Orange.data.Table("lenses")
    learner = clone(self.learner)

    learner.multinomial_treatment = DomainContinuizer.NValues
    classifier = learner(data)
    self.assertEqual(len(classifier.domain), 7)

    learner.multinomial_treatment = DomainContinuizer.FrequentIsBase
    classifier = learner(data)
    self.assertEqual(len(classifier.domain), 6)

    learner.multinomial_treatment = DomainContinuizer.ReportError
    with self.assertRaises(Orange.core.KernelException):
        # Only the exception matters; the result is not needed.
        learner(data)
137
138
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS_DUAL(testing.LearnerTestCase):
    """Run the shared linear-learner checks with the L2R_L2LOSS_DUAL solver."""

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
147
148
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L2LOSS(testing.LearnerTestCase):
    """Run the shared linear-learner checks with the L2R_L2LOSS solver."""

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
157
158
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS_DUAL(testing.LearnerTestCase):
    """Run the shared linear-learner checks with the L2R_L1LOSS_DUAL solver."""

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L1LOSS_DUAL)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
167
168
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL2R_L1LOSS(testing.LearnerTestCase):
    """Run the shared linear-learner checks for the class named L2R_L1LOSS.

    NOTE(review): the class name says L1LOSS but the learner is configured
    with ``LinearSVMLearner.L2R_L2LOSS``, which makes it a duplicate of
    ``TestLinearSVMLearnerL2R_L2LOSS``. The constant is left unchanged
    because this file does not show that an ``L2R_L1LOSS`` solver constant
    exists -- confirm and either switch the constant or drop the class.
    """

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L2R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
177
178
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerL1R_L2LOSS(testing.LearnerTestCase):
    """Run the shared linear-learner checks with the L1R_L2LOSS solver."""

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.L1R_L2LOSS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
187
188
@datasets_driven(datasets=testing.CLASSIFICATION_DATASETS)
class TestLinearSVMLearnerMCSVM_CSS(testing.LearnerTestCase):
    """Run the shared linear-learner checks with the MCSVM_CS solver.

    NOTE(review): the shared checks expect a single weight vector for
    problems with two class values. Now that the solver is really
    applied, confirm that this expectation holds for MCSVM_CS.
    """

    # The keyword was misspelled "sover_type", so the requested solver was
    # never passed under the name the learner reads.
    LEARNER = LinearSVMLearner(solver_type=LinearSVMLearner.MCSVM_CS)

    test_learner_on = test_learner_on
    test_learner_with_bias_on = test_learner_with_bias_on
    test_missing_instances = missing_instances_test
    test_multinomial = multinomial_test
197
198
if __name__ == "__main__":
    # Run every test case defined in this module when executed as a script.
    unittest.main()
# Note: See TracBrowser for help on using the repository browser.