source: orange/Orange/testing/unit/tests/test_earth.py @ 10908:c3638c032fba

Revision 10908:c3638c032fba, 4.5 KB checked in by Ales Erjavec <ales.erjavec@…>, 23 months ago (diff)

Added a test for Earth basis-term computation.
(Makes sure that 'base_matrix' and 'base_features' produce the same basis values.)

Line 
1import Orange
2from Orange.testing import testing
3from Orange.testing.testing import datasets_driven, test_on_data
4from Orange.regression import earth
5import numpy
6
try:
    # Prefer the unittest2 backport when it is installed.
    import unittest2 as unittest
except ImportError:
    # Only a missing unittest2 should trigger the fallback; any other error
    # raised while importing it must not be silently swallowed.
    import unittest
11
12
@datasets_driven(datasets=testing.REGRESSION_DATASETS +
                 testing.CLASSIFICATION_DATASETS)
class TestEarthLearner(testing.LearnerTestCase):
    """Generic learner tests for ``earth.EarthLearner`` plus Earth-specific checks.

    Driven over every dataset listed in ``testing.REGRESSION_DATASETS`` and
    ``testing.CLASSIFICATION_DATASETS``.
    """

    def setUp(self):
        self.learner = earth.EarthLearner(degree=2, terms=10)

    @test_on_data
    def test_learner_on(self, dataset):
        """Run the generic learner test, then check basis-term consistency.

        The model's basis terms are computed two ways -- via the
        ``base_features()`` domain translation and via ``base_matrix()`` --
        and the two results must agree (up to 3 decimal places).
        """
        if len(dataset) < 30:
            raise unittest.SkipTest("Not enough examples.")
        testing.LearnerTestCase.test_learner_on(self, dataset)

        # Smoke-test the textual model description and variable importance;
        # the return values are not inspected.
        self.classifier.to_string()
        self.classifier.evimp()

        # Way 1: translate the data into the domain of basis features.
        basis_features = self.classifier.base_features()
        basis_domain = Orange.data.Domain(basis_features, None)
        basis_table = Orange.data.Table(basis_domain, dataset)
        basis_matrix_a = basis_table.to_numpy_MA("A")[0]

        # Way 2: compute the basis matrix directly, keep only the selected
        # terms (best_set) and drop the intercept column so the shape matches
        # the base_features table.
        basis_matrix = self.classifier.base_matrix(dataset)
        basis_matrix = basis_matrix[:, self.classifier.best_set]
        basis_matrix = basis_matrix[:, 1:]

        # Zero out unknowns in both matrices. getmaskarray always yields a
        # full boolean array (never the scalar ``nomask``), so it is safe to
        # use as an index.
        unknown = numpy.ma.getmaskarray(basis_matrix_a)
        basis_matrix[unknown] = 0
        basis_matrix_a = basis_matrix_a.filled(0)

        # Compare absolute differences: the signed maximum would let any
        # negative discrepancy pass unnoticed.
        abs_diff = numpy.abs(basis_matrix - basis_matrix_a)
        max_abs_diff = numpy.max(abs_diff) if abs_diff.size else 0.0
        self.assertAlmostEqual(max_abs_diff, 0, places=3)

    @test_on_data
    def test_bagged_evimp(self, dataset):
        """Check that variable importance can be aggregated over bagged Earth models."""
        from Orange.ensemble.bagging import BaggedLearner
        bagged_learner = BaggedLearner(earth.EarthLearner(terms=10, degree=2),
                                       t=5)

        bagged_classifier = bagged_learner(dataset)
        # Smoke test only: the aggregated importances are not inspected.
        earth.bagged_evimp(bagged_classifier, used_only=False)
52
53
@datasets_driven(datasets=testing.REGRESSION_DATASETS +
                 testing.CLASSIFICATION_DATASETS)
class TestScoreEarthImportance(testing.MeasureAttributeTestCase):
    """Generic attribute-measure tests applied to ``ScoreEarthImportance``.

    The measure under test is built with ``t=5`` and ``score_what="rss"``.
    """

    def setUp(self):
        from Orange.regression.earth import ScoreEarthImportance as Score
        measure_options = dict(t=5, score_what="rss")
        self.measure = Score(**measure_options)
60
61
@datasets_driven(datasets=["multitarget-synthetic"])
class TestEarthMultitarget(unittest.TestCase):
    """Tests for ``earth.EarthLearner`` on data with several class variables."""

    @test_on_data
    def test_multi_target_on_data(self, dataset):
        """Train on multi-target data and check the predictor and its printout."""
        self.learner = earth.EarthLearner(degree=2, terms=10)

        self.predictor = self.multi_target_test(self.learner, dataset)

        self.assertTrue(bool(self.predictor.multitarget))

        # str() must match the default to_string(); non-default arguments
        # must produce a different rendering.
        s = str(self.predictor)
        self.assertEqual(s, self.predictor.to_string())
        self.assertNotEqual(s, self.predictor.to_string(3, 6))

    def multi_target_test(self, learner, data):
        """Train ``learner`` on part of ``data`` and check its predictions.

        ``data`` is split with ``SubsetIndices2(p0=0.3)``. The learner is
        trained on the subset selected with index 1 and tested on the rest.
        For each test instance, the predictor must return ``Value`` objects
        for ``GetValue``, ``Distribution`` objects for ``GetProbabilities``,
        and both for ``GetBoth``. Every distribution must sum to 1 (within
        1e-3).

        Returns the trained predictor.
        """
        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data)
        learn = data.select(indices, 1)
        test = data.select(indices, 0)

        predictor = learner(learn)
        self.assertIsInstance(predictor, Orange.classification.Classifier)
        self.multi_target_predictor_interface(predictor, learn.domain)

        from Orange.evaluation import testing as _testing

        # Only checks that the evaluation machinery accepts the predictor;
        # the returned results are not inspected.
        _testing.test_on_data([predictor], test)

        def all_values(vals):
            for v in vals:
                self.assertIsInstance(v, Orange.core.Value)

        def all_dists(dists):
            for d in dists:
                self.assertIsInstance(d, Orange.core.Distribution)

        for ex in test:
            preds = predictor(ex, Orange.core.GetValue)
            all_values(preds)

            dist = predictor(ex, Orange.core.GetProbabilities)
            all_dists(dist)

            preds, dist = predictor(ex, Orange.core.GetBoth)
            all_values(preds)
            all_dists(dist)

            for d in dist:
                # Continuous distributions are summed over their values();
                # other distributions are summed directly.
                if isinstance(d, Orange.core.ContDistribution):
                    dist_sum = sum(d.values())
                else:
                    dist_sum = sum(d)

                self.assertGreater(dist_sum, 0.0)
                self.assertLess(abs(dist_sum - 1.0), 1e-3)

        return predictor

    def multi_target_predictor_interface(self, predictor, domain):
        """Check that ``predictor`` exposes the same ``class_vars`` as ``domain``."""
        self.assertTrue(hasattr(predictor, "class_vars"))
        self.assertIsInstance(predictor.class_vars, (list, Orange.core.VarList))
        # zip() stops at the shorter sequence, so compare the lengths first;
        # otherwise missing or extra class variables would go unnoticed.
        self.assertEqual(len(predictor.class_vars), len(domain.class_vars))
        for c1, c2 in zip(predictor.class_vars, domain.class_vars):
            self.assertEqual(c1, c2)
124
125
def load_tests(loader, tests, ignore):
    """unittest ``load_tests`` hook: append the ``earth`` module's doctests.

    The already-collected ``tests`` suite is extended in place with a
    ``DocTestSuite`` built from ``Orange.regression.earth`` and returned.
    ``loader`` and ``ignore`` are not used.
    """
    from doctest import DocTestSuite
    earth_doctests = DocTestSuite(earth)
    tests.addTests(earth_doctests)
    return tests
130
131
if __name__ == "__main__":
    # Run this module's tests with the unittest command-line runner
    # (module="__main__" is unittest.main's default, spelled out here).
    unittest.main(module="__main__")
134
Note: See TracBrowser for help on using the repository browser.