source: orange/Orange/testing/unit/tests/test_evaluation_scoring.py @ 10426:ce57e8dbcc18

Revision 10426:ce57e8dbcc18, 8.3 KB checked in by anzeh <anze.staric@…>, 2 years ago (diff)

Refactored AUC.

Line 
1import random
2from Orange import data
3from Orange.evaluation import scoring, testing
4from Orange.statistics import distribution
5
try:
    # Prefer the unittest2 backport when it is installed.
    import unittest2 as unittest
except ImportError:
    # Fall back to the standard library only when the backport is missing.
    # A bare ``except:`` here would also swallow KeyboardInterrupt,
    # SystemExit and genuine errors raised while importing unittest2.
    import unittest
10
# Seed the shared generator once, at import time, so the random
# classifier below draws a reproducible stream of values.
random.seed(0)


def random_learner(data, *args):
    """Build a classifier that predicts a random class for ``data``.

    Note that the ``data`` parameter shadows the ``data`` module imported
    at the top of this file; inside this function it is the data set
    passed by the caller. Extra positional arguments are accepted and
    ignored.

    Returns a callable that ignores all of its arguments and returns a
    ``(value, probabilities)`` pair: a randomly indexed value of the
    class variable and a list of random numbers normalised to sum to 1,
    one per class value.
    """
    def random_classifier(*args, **kwargs):
        class_var = data.domain.class_var
        weights = [random.random() for _ in class_var.values]
        total = sum(weights)
        prob = [weight / total for weight in weights]
        # The constructed distribution is discarded; the call is kept
        # because the original code makes it at exactly this point.
        distribution.Discrete(prob)
        index = random.randint(0, len(class_var.values) - 1)
        return class_var[index], prob
    return random_classifier
21
class TestAuc(unittest.TestCase):
    """Smoke tests for the AUC scoring functions.

    A random classifier is evaluated and each test only asserts the
    length of the returned score list (1, for the single learner used),
    never the score values themselves.
    """

    def setUp(self):
        self.learner = random_learner

    def _iris_cv(self):
        """Run 5-fold cross-validation of the learner on the iris data."""
        iris = data.Table("iris")
        return testing.cross_validation([self.learner], iris, folds=5)

    def _assert_single_score(self, scores):
        """Assert that ``scores`` holds exactly one entry."""
        self.assertEqual(len(scores), 1)

    def _check_multiclass(self, method):
        """Score iris cross-validation results with the given AUC method."""
        results = self._iris_cv()
        self._assert_single_score(scoring.AUC(results, multiclass=method))

    def test_auc_on_monks(self):
        monks = data.Table("monks-1")
        # Both sampling runs are done before any scoring, cross-validation
        # first, so the shared random generator is consumed in a fixed order.
        cv_results = testing.cross_validation([self.learner], monks, folds=5)
        pt_results = testing.proportion_test([self.learner], monks, times=1)

        for results in (cv_results, pt_results):
            self._assert_single_score(scoring.AUC(results))

    def test_auc_on_iris(self):
        self._assert_single_score(scoring.AUC(self._iris_cv()))

    def test_auc_on_iris_by_pairs(self):
        self._check_multiclass(scoring.AUC.ByPairs)

    def test_auc_on_iris_by_weighted_pairs(self):
        self._check_multiclass(scoring.AUC.ByWeightedPairs)

    def test_auc_on_iris_one_against_all(self):
        self._check_multiclass(scoring.AUC.OneAgainstAll)

    def test_auc_on_iris_weighted_one_against_all(self):
        self._check_multiclass(scoring.AUC.WeightedOneAgainstAll)

    def test_auc_on_iris_single_class(self):
        results = self._iris_cv()
        # First without an explicit class index, then for indices 0, 1, 2.
        self._assert_single_score(scoring.AUC_for_single_class(results))
        for class_index in (0, 1, 2):
            self._assert_single_score(
                scoring.AUC_for_single_class(results, class_index))

    def test_auc_on_iris_pair(self):
        results = self._iris_cv()
        for first, second in ((0, 1), (0, 2), (1, 2)):
            self._assert_single_score(
                scoring.AUC_for_pair_of_classes(results, first, second))

    def test_auc_matrix_on_iris(self):
        matrices = scoring.AUC_matrix(self._iris_cv())
        self._assert_single_score(matrices)
        # The matrix for the single learner is expected to have 3 rows.
        self.assertEqual(len(matrices[0]), 3)
100
101
class TestCA(unittest.TestCase):
    """Smoke tests for ``scoring.CA`` (classification accuracy).

    CA is computed from cross-validation results, proportion-test results,
    a list of confusion matrices and a single confusion matrix, both with
    and without ``report_se=True``. Every test only asserts that the
    result has length 1.
    """

    def setUp(self):
        self.learner = random_learner

    def test_ca_on_iris(self):
        """CA computed directly from cross-validation results."""
        ds = data.Table("iris")
        cv = testing.cross_validation([self.learner], ds, folds=5)
        ca = scoring.CA(cv)
        self.assertEqual(len(ca), 1)

    def test_ca_from_confusion_matrix_list_on_iris(self):
        """CA computed from the list returned by confusion_matrices."""
        ds = data.Table("iris")
        cv = testing.cross_validation([self.learner], ds, folds=5)
        cm = scoring.confusion_matrices(cv)
        ca = scoring.CA(cm)
        self.assertEqual(len(ca), 1)

    def test_ca_from_confusion_matrix_on_iris(self):
        """CA computed from a single confusion matrix (class_index=1)."""
        ds = data.Table("iris")
        cv = testing.cross_validation([self.learner], ds, folds=5)
        cm = scoring.confusion_matrices(cv, class_index=1)
        ca = scoring.CA(cm[0])
        self.assertEqual(len(ca), 1)

    def test_ca_from_confusion_matrix_for_classification_on_iris(self):
        """CA computed from single-iteration proportion-test results."""
        ds = data.Table("iris")
        pt = testing.proportion_test([self.learner], ds, times=1)
        self.assertEqual(pt.number_of_iterations, 1)
        ca = scoring.CA(pt)
        self.assertEqual(len(ca), 1)

    def test_ca_from_confusion_matrix_for_classification_on_iris_se(self):
        """Same as above, with report_se=True."""
        ds = data.Table("iris")
        pt = testing.proportion_test([self.learner], ds, times=1)
        self.assertEqual(pt.number_of_iterations, 1)
        ca = scoring.CA(pt, report_se=True)
        self.assertEqual(len(ca), 1)

    def test_ca_from_confusion_matrix_on_iris_se(self):
        """CA from a single confusion matrix, with report_se=True."""
        ds = data.Table("iris")
        cv = testing.cross_validation([self.learner], ds, folds=5)
        cm = scoring.confusion_matrices(cv, class_index=1)
        ca = scoring.CA(cm[0], report_se=True)
        self.assertEqual(len(ca), 1)

    # This test used to be a second ``test_ca_on_iris``. The later
    # definition replaced the earlier one in the class namespace, so the
    # plain cross-validation test above was never run. It has its own name
    # now so that both variants are collected.
    def test_ca_on_iris_se(self):
        """CA from cross-validation results, with report_se=True."""
        ds = data.Table("iris")
        cv = testing.cross_validation([self.learner], ds, folds=5)
        ca = scoring.CA(cv, report_se=True)
        self.assertEqual(len(ca), 1)
152
153
class TestConfusionMatrix(unittest.TestCase):
    """Check the kind of object ``scoring.confusion_matrices`` returns."""

    @staticmethod
    def _proportion_test(dataset_name):
        """Run one proportion-test iteration of the random learner."""
        table = data.Table(dataset_name)
        return testing.proportion_test([random_learner], table, times=1)

    def test_construct_confusion_matrix_from_multiclass(self):
        results = self._proportion_test("iris")
        matrices = scoring.confusion_matrices(results)

        # Without a class index the per-learner entry is a plain list.
        first = matrices[0]
        self.assertTrue(isinstance(first, list))

    def test_construct_confusion_matrix_from_biclass(self):
        results = self._proportion_test("monks-1")
        matrices = scoring.confusion_matrices(results, class_index=1)

        # With a class index the per-learner entry exposes a TP attribute.
        first = matrices[0]
        self.assertTrue(hasattr(first, "TP"))
171
class CMScoreTest(object):
    """Mixin with shared checks for confusion-matrix based scores.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    provide a ``score`` attribute holding the scoring callable under test.
    Each check asserts only that the scorer returns a list.
    """

    @staticmethod
    def _proportion_test(dataset_name):
        """Run one proportion-test iteration of the random learner."""
        table = data.Table(dataset_name)
        return testing.proportion_test([random_learner], table, times=1)

    def _check_on_test_results(self, dataset_name):
        """Feed the raw test results to the scorer."""
        results = self._proportion_test(dataset_name)
        self.assertIsInstance(self.score(results), list)

    def _check_on_confusion_matrices(self, dataset_name):
        """Feed confusion matrices built with class_index=1 to the scorer."""
        results = self._proportion_test(dataset_name)
        matrices = scoring.confusion_matrices(results, class_index=1)
        self.assertIsInstance(self.score(matrices), list)

    def test_with_test_results_on_biclass(self):
        self._check_on_test_results("monks-1")

    def test_with_test_results_on_multiclass(self):
        self._check_on_test_results("iris")

    def test_with_confusion_matrix_on_biclass(self):
        self._check_on_confusion_matrices("monks-1")

    def test_with_confusion_matrix_on_multiclass(self):
        self._check_on_confusion_matrices("iris")
203
class TestSensitivity(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.Sensitivity``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.Sensitivity)
208
class TestSpecificity(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.Specificity``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.Specificity)
213
class TestPrecision(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.Precision``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.Precision)
218
class TestRecall(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.Recall``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.Recall)
223
class TestPPV(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.PPV``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.PPV)
228
class TestNPV(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.NPV``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.NPV)
233
class TestF1(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.F1``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.F1)
238
class TestFalpha(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.Falpha``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.Falpha)
243
class TestMCC(CMScoreTest, unittest.TestCase):
    """Run the shared CMScoreTest checks against ``scoring.MCC``."""

    # Read-only property: the scorer is looked up on every access.
    score = property(lambda self: scoring.MCC)
248
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Note: See TracBrowser for help on using the repository browser.