source: orange/Orange/testing/unit/tests/test_earth.py @ 10775:c720c17c2f3f

Revision 10775:c720c17c2f3f, 4.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 14 months ago (diff)

Added base_features method to EarthClassifier.

RevLine 
[10329]1import Orange
[10655]2from Orange.testing import testing
3from Orange.testing.testing import datasets_driven, test_on_data, test_on_datasets
[8149]4from Orange.regression import earth
[10329]5
[10278]6try:
7    import unittest2 as unittest
8except:
9    import unittest
[8149]10
[10278]11@datasets_driven(datasets=testing.REGRESSION_DATASETS + \
[8763]12                 testing.CLASSIFICATION_DATASETS)
[8149]13class TestEarthLearner(testing.LearnerTestCase):
[10278]14
[8149]15    def setUp(self):
16        self.learner = earth.EarthLearner(degree=2, terms=10)
[10278]17
[8149]18    @test_on_data
19    def test_learner_on(self, dataset):
[8763]20        if len(dataset) < 30:
21            raise unittest.SkipTest("Not enough examples.")
[8149]22        testing.LearnerTestCase.test_learner_on(self, dataset)
[9570]23        str = self.classifier.to_string()
[8149]24        evimp = self.classifier.evimp()
[10775]25        # Test base_features (make sure the domain translation works)
26        basis_features = self.classifier.base_features()
27        basis_domain = Orange.data.Domain(basis_features, None)
28        basis_matrix = Orange.data.Table(basis_domain, dataset)
[10278]29
[8763]30    @test_on_data
31    def test_bagged_evimp(self, dataset):
[8149]32        from Orange.ensemble.bagging import BaggedLearner
33        bagged = BaggedLearner(earth.EarthLearner(terms=10, degree=2), t=5)(dataset)
34        evimp = earth.bagged_evimp(bagged, used_only=False)
[10278]35
36
37@datasets_driven(datasets=testing.REGRESSION_DATASETS + \
[8763]38                 testing.CLASSIFICATION_DATASETS)
39class TestScoreEarthImportance(testing.MeasureAttributeTestCase):
[8149]40    def setUp(self):
41        from Orange.regression.earth import ScoreEarthImportance
42        self.measure = ScoreEarthImportance(t=5, score_what="rss")
[10278]43
[10329]44@datasets_driven(datasets=["multitarget-synthetic"])
45class TestEarthMultitarget(unittest.TestCase):
46    @test_on_data
47    def test_multi_target_on_data(self, dataset):
48        self.learner = earth.EarthLearner(degree=2, terms=10)
49       
50        self.predictor = self.multi_target_test(self.learner, dataset)
51       
52        self.assertTrue(bool(self.predictor.multitarget))
53       
54        s = str(self.predictor)
55        self.assertEqual(s, self.predictor.to_string())
56        self.assertNotEqual(s, self.predictor.to_string(3, 6))
57       
58   
59    def multi_target_test(self, learner, data):
60        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data)
61        learn = data.select(indices, 1)
62        test = data.select(indices, 0)
63       
64        predictor = learner(learn)
65        self.assertIsInstance(predictor, Orange.classification.Classifier)
66        self.multi_target_predictor_interface(predictor, learn.domain)
67       
68        from Orange.evaluation import testing as _testing
69       
70        r = _testing.test_on_data([predictor], test)
71       
72        def all_values(vals):
73            for v in vals:
74                self.assertIsInstance(v, Orange.core.Value)
75               
76        def all_dists(dist):
77            for d in dist:
78                self.assertIsInstance(d, Orange.core.Distribution)
79               
80        for ex in test:
81            preds = predictor(ex, Orange.core.GetValue)
82            all_values(preds)
83           
84            dist = predictor(ex, Orange.core.GetProbabilities)
85            all_dists(dist)
86           
87            preds, dist = predictor(ex, Orange.core.GetBoth)
88            all_values(preds)
89            all_dists(dist)
90           
91            for d in dist:
92                if isinstance(d, Orange.core.ContDistribution):
93                    dist_sum = sum(d.values())
94                else:
95                    dist_sum = sum(d)
96                   
97                self.assertGreater(dist_sum, 0.0)
98                self.assertLess(abs(dist_sum - 1.0), 1e-3)
99           
100        return predictor
101   
102    def multi_target_predictor_interface(self, predictor, domain):
103        self.assertTrue(hasattr(predictor, "class_vars"))
104        self.assertIsInstance(predictor.class_vars, (list, Orange.core.VarList))
105        self.assertTrue(all(c1 == c2 for c1, c2 in \
106                            zip(predictor.class_vars, domain.class_vars)))
[10278]107
[10663]108def load_tests(loader, tests, ignore):
109    import doctest
[10768]110    tests.addTests(doctest.DocTestSuite(earth))
111    return tests
[10278]112
[8149]113if __name__ == "__main__":
114    unittest.main()
[10278]115
Note: See TracBrowser for help on using the repository browser.