Ignore:
Timestamp:
06/07/12 15:32:18 (23 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added test for earth basis term computation.
(Make sure the 'base_matrix' and 'base_features' work the same)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/testing/unit/tests/test_earth.py

    r10775 r10908  
    11import Orange 
    22from Orange.testing import testing 
    3 from Orange.testing.testing import datasets_driven, test_on_data, test_on_datasets 
     3from Orange.testing.testing import datasets_driven, test_on_data 
    44from Orange.regression import earth 
     5import numpy 
    56 
    67try: 
     
    89except: 
    910    import unittest 
     11 
    1012 
    1113@datasets_driven(datasets=testing.REGRESSION_DATASETS + \ 
     
    2325        str = self.classifier.to_string() 
    2426        evimp = self.classifier.evimp() 
    25         # Test base_features (make sure the domain translation works)  
     27 
     28        # Test base_features (make sure the domain translation works) 
    2629        basis_features = self.classifier.base_features() 
    2730        basis_domain = Orange.data.Domain(basis_features, None) 
    28         basis_matrix = Orange.data.Table(basis_domain, dataset) 
     31        basis_table = Orange.data.Table(basis_domain, dataset) 
     32        basis_matrix = self.classifier.base_matrix(dataset) 
     33        # Filter best set 
     34        basis_matrix = basis_matrix[:, self.classifier.best_set] 
     35        # Remove intercept 
     36        basis_matrix = basis_matrix[:, 1:] 
     37        basis_matrix_a = basis_table.to_numpy_MA("A")[0] 
     38        # Fill unknowns 
     39        basis_matrix[basis_matrix_a.mask] = 0 
     40        basis_matrix_a = basis_matrix_a.filled(0) 
     41        diff = basis_matrix - basis_matrix_a 
     42        self.assertAlmostEqual(numpy.max(diff), 0, places=3) 
    2943 
    3044    @test_on_data 
    3145    def test_bagged_evimp(self, dataset): 
    3246        from Orange.ensemble.bagging import BaggedLearner 
    33         bagged = BaggedLearner(earth.EarthLearner(terms=10, degree=2), t=5)(dataset) 
    34         evimp = earth.bagged_evimp(bagged, used_only=False) 
     47        bagged_learner = BaggedLearner(earth.EarthLearner(terms=10, degree=2), 
     48                                       t=5) 
     49 
     50        bagged_classifier = bagged_learner(dataset) 
     51        evimp = earth.bagged_evimp(bagged_classifier, used_only=False) 
    3552 
    3653 
     
    4259        self.measure = ScoreEarthImportance(t=5, score_what="rss") 
    4360 
     61 
    4462@datasets_driven(datasets=["multitarget-synthetic"]) 
    4563class TestEarthMultitarget(unittest.TestCase): 
     
    4765    def test_multi_target_on_data(self, dataset): 
    4866        self.learner = earth.EarthLearner(degree=2, terms=10) 
    49          
     67 
    5068        self.predictor = self.multi_target_test(self.learner, dataset) 
    51          
     69 
    5270        self.assertTrue(bool(self.predictor.multitarget)) 
    53          
     71 
    5472        s = str(self.predictor) 
    5573        self.assertEqual(s, self.predictor.to_string()) 
    5674        self.assertNotEqual(s, self.predictor.to_string(3, 6)) 
    57          
    58      
     75 
    5976    def multi_target_test(self, learner, data): 
    6077        indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data) 
    6178        learn = data.select(indices, 1) 
    6279        test = data.select(indices, 0) 
    63          
     80 
    6481        predictor = learner(learn) 
    6582        self.assertIsInstance(predictor, Orange.classification.Classifier) 
    6683        self.multi_target_predictor_interface(predictor, learn.domain) 
    67          
     84 
    6885        from Orange.evaluation import testing as _testing 
    69          
     86 
    7087        r = _testing.test_on_data([predictor], test) 
    71          
     88 
    7289        def all_values(vals): 
    7390            for v in vals: 
    7491                self.assertIsInstance(v, Orange.core.Value) 
    75                  
     92 
    7693        def all_dists(dist): 
    7794            for d in dist: 
    7895                self.assertIsInstance(d, Orange.core.Distribution) 
    79                  
     96 
    8097        for ex in test: 
    8198            preds = predictor(ex, Orange.core.GetValue) 
    8299            all_values(preds) 
    83              
     100 
    84101            dist = predictor(ex, Orange.core.GetProbabilities) 
    85102            all_dists(dist) 
    86              
     103 
    87104            preds, dist = predictor(ex, Orange.core.GetBoth) 
    88105            all_values(preds) 
    89106            all_dists(dist) 
    90              
     107 
    91108            for d in dist: 
    92109                if isinstance(d, Orange.core.ContDistribution): 
     
    94111                else: 
    95112                    dist_sum = sum(d) 
    96                      
     113 
    97114                self.assertGreater(dist_sum, 0.0) 
    98115                self.assertLess(abs(dist_sum - 1.0), 1e-3) 
    99              
     116 
    100117        return predictor 
    101      
     118 
    102119    def multi_target_predictor_interface(self, predictor, domain): 
    103120        self.assertTrue(hasattr(predictor, "class_vars")) 
     
    106123                            zip(predictor.class_vars, domain.class_vars))) 
    107124 
     125 
    108126def load_tests(loader, tests, ignore): 
    109127    import doctest 
     
    111129    return tests 
    112130 
     131 
    113132if __name__ == "__main__": 
    114133    unittest.main() 
Note: See TracChangeset for help on using the changeset viewer.