Ignore:
Timestamp:
02/21/12 15:51:07 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added test for EarthLearner on a muti-target problem.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_earth.py

    r10278 r10329  
     1import Orange 
    12from Orange.misc import testing 
    23from Orange.misc.testing import datasets_driven, test_on_data, test_on_datasets 
    34from Orange.regression import earth 
    4 import Orange 
     5 
    56try: 
    67    import unittest2 as unittest 
     
    3738        self.measure = ScoreEarthImportance(t=5, score_what="rss") 
    3839 
    39  
     40@datasets_driven(datasets=["multitarget-synthetic"]) 
     41class TestEarthMultitarget(unittest.TestCase): 
     42    @test_on_data 
     43    def test_multi_target_on_data(self, dataset): 
     44        self.learner = earth.EarthLearner(degree=2, terms=10) 
     45         
     46        self.predictor = self.multi_target_test(self.learner, dataset) 
     47         
     48        self.assertTrue(bool(self.predictor.multitarget)) 
     49         
     50        s = str(self.predictor) 
     51        self.assertEqual(s, self.predictor.to_string()) 
     52        self.assertNotEqual(s, self.predictor.to_string(3, 6)) 
     53         
     54     
     55    def multi_target_test(self, learner, data): 
     56        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data) 
     57        learn = data.select(indices, 1) 
     58        test = data.select(indices, 0) 
     59         
     60        predictor = learner(learn) 
     61        self.assertIsInstance(predictor, Orange.classification.Classifier) 
     62        self.multi_target_predictor_interface(predictor, learn.domain) 
     63         
     64        from Orange.evaluation import testing as _testing 
     65         
     66        r = _testing.test_on_data([predictor], test) 
     67         
     68        def all_values(vals): 
     69            for v in vals: 
     70                self.assertIsInstance(v, Orange.core.Value) 
     71                 
     72        def all_dists(dist): 
     73            for d in dist: 
     74                self.assertIsInstance(d, Orange.core.Distribution) 
     75                 
     76        for ex in test: 
     77            preds = predictor(ex, Orange.core.GetValue) 
     78            all_values(preds) 
     79             
     80            dist = predictor(ex, Orange.core.GetProbabilities) 
     81            all_dists(dist) 
     82             
     83            preds, dist = predictor(ex, Orange.core.GetBoth) 
     84            all_values(preds) 
     85            all_dists(dist) 
     86             
     87            for d in dist: 
     88                if isinstance(d, Orange.core.ContDistribution): 
     89                    dist_sum = sum(d.values()) 
     90                else: 
     91                    dist_sum = sum(d) 
     92                     
     93                self.assertGreater(dist_sum, 0.0) 
     94                self.assertLess(abs(dist_sum - 1.0), 1e-3) 
     95             
     96        return predictor 
     97     
     98    def multi_target_predictor_interface(self, predictor, domain): 
     99        self.assertTrue(hasattr(predictor, "class_vars")) 
     100        self.assertIsInstance(predictor.class_vars, (list, Orange.core.VarList)) 
     101        self.assertTrue(all(c1 == c2 for c1, c2 in \ 
     102                            zip(predictor.class_vars, domain.class_vars))) 
     103         
     104     
    40105#@datasets_driven(datasets=testing.REGRESSION_DATASETS,) 
    41106#class TestScoreRSS(testing.MeasureAttributeTestCase): 
Note: See TracChangeset for help on using the changeset viewer.