source: orange/Orange/testing/unit/tests/test_earth.py @ 10775:c720c17c2f3f

Revision 10775:c720c17c2f3f, 4.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added base_features method to EarthClassifier.

Line 
1import Orange
2from Orange.testing import testing
3from Orange.testing.testing import datasets_driven, test_on_data, test_on_datasets
4from Orange.regression import earth
5
6try:
7    import unittest2 as unittest
8except:
9    import unittest
10
11@datasets_driven(datasets=testing.REGRESSION_DATASETS + \
12                 testing.CLASSIFICATION_DATASETS)
13class TestEarthLearner(testing.LearnerTestCase):
14
15    def setUp(self):
16        self.learner = earth.EarthLearner(degree=2, terms=10)
17
18    @test_on_data
19    def test_learner_on(self, dataset):
20        if len(dataset) < 30:
21            raise unittest.SkipTest("Not enough examples.")
22        testing.LearnerTestCase.test_learner_on(self, dataset)
23        str = self.classifier.to_string()
24        evimp = self.classifier.evimp()
25        # Test base_features (make sure the domain translation works)
26        basis_features = self.classifier.base_features()
27        basis_domain = Orange.data.Domain(basis_features, None)
28        basis_matrix = Orange.data.Table(basis_domain, dataset)
29
30    @test_on_data
31    def test_bagged_evimp(self, dataset):
32        from Orange.ensemble.bagging import BaggedLearner
33        bagged = BaggedLearner(earth.EarthLearner(terms=10, degree=2), t=5)(dataset)
34        evimp = earth.bagged_evimp(bagged, used_only=False)
35
36
37@datasets_driven(datasets=testing.REGRESSION_DATASETS + \
38                 testing.CLASSIFICATION_DATASETS)
39class TestScoreEarthImportance(testing.MeasureAttributeTestCase):
40    def setUp(self):
41        from Orange.regression.earth import ScoreEarthImportance
42        self.measure = ScoreEarthImportance(t=5, score_what="rss")
43
44@datasets_driven(datasets=["multitarget-synthetic"])
45class TestEarthMultitarget(unittest.TestCase):
46    @test_on_data
47    def test_multi_target_on_data(self, dataset):
48        self.learner = earth.EarthLearner(degree=2, terms=10)
49       
50        self.predictor = self.multi_target_test(self.learner, dataset)
51       
52        self.assertTrue(bool(self.predictor.multitarget))
53       
54        s = str(self.predictor)
55        self.assertEqual(s, self.predictor.to_string())
56        self.assertNotEqual(s, self.predictor.to_string(3, 6))
57       
58   
59    def multi_target_test(self, learner, data):
60        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data)
61        learn = data.select(indices, 1)
62        test = data.select(indices, 0)
63       
64        predictor = learner(learn)
65        self.assertIsInstance(predictor, Orange.classification.Classifier)
66        self.multi_target_predictor_interface(predictor, learn.domain)
67       
68        from Orange.evaluation import testing as _testing
69       
70        r = _testing.test_on_data([predictor], test)
71       
72        def all_values(vals):
73            for v in vals:
74                self.assertIsInstance(v, Orange.core.Value)
75               
76        def all_dists(dist):
77            for d in dist:
78                self.assertIsInstance(d, Orange.core.Distribution)
79               
80        for ex in test:
81            preds = predictor(ex, Orange.core.GetValue)
82            all_values(preds)
83           
84            dist = predictor(ex, Orange.core.GetProbabilities)
85            all_dists(dist)
86           
87            preds, dist = predictor(ex, Orange.core.GetBoth)
88            all_values(preds)
89            all_dists(dist)
90           
91            for d in dist:
92                if isinstance(d, Orange.core.ContDistribution):
93                    dist_sum = sum(d.values())
94                else:
95                    dist_sum = sum(d)
96                   
97                self.assertGreater(dist_sum, 0.0)
98                self.assertLess(abs(dist_sum - 1.0), 1e-3)
99           
100        return predictor
101   
102    def multi_target_predictor_interface(self, predictor, domain):
103        self.assertTrue(hasattr(predictor, "class_vars"))
104        self.assertIsInstance(predictor.class_vars, (list, Orange.core.VarList))
105        self.assertTrue(all(c1 == c2 for c1, c2 in \
106                            zip(predictor.class_vars, domain.class_vars)))
107
108def load_tests(loader, tests, ignore):
109    import doctest
110    tests.addTests(doctest.DocTestSuite(earth))
111    return tests
112
113if __name__ == "__main__":
114    unittest.main()
115
Note: See TracBrowser for help on using the repository browser.