Changeset 10908:c3638c032fba in orange
 Timestamp:
 06/07/12 15:32:18 (23 months ago)
 Branch:
 default
 File:

 1 edited
Legend:
 Unmodified
 Added
 Removed

Orange/testing/unit/tests/test_earth.py
r10775 r10908 1 1 import Orange 2 2 from Orange.testing import testing 3 from Orange.testing.testing import datasets_driven, test_on_data , test_on_datasets3 from Orange.testing.testing import datasets_driven, test_on_data 4 4 from Orange.regression import earth 5 import numpy 5 6 6 7 try: … … 8 9 except: 9 10 import unittest 11 10 12 11 13 @datasets_driven(datasets=testing.REGRESSION_DATASETS + \ … … 23 25 str = self.classifier.to_string() 24 26 evimp = self.classifier.evimp() 25 # Test base_features (make sure the domain translation works) 27 28 # Test base_features (make sure the domain translation works) 26 29 basis_features = self.classifier.base_features() 27 30 basis_domain = Orange.data.Domain(basis_features, None) 28 basis_matrix = Orange.data.Table(basis_domain, dataset) 31 basis_table = Orange.data.Table(basis_domain, dataset) 32 basis_matrix = self.classifier.base_matrix(dataset) 33 # Filter best set 34 basis_matrix = basis_matrix[:, self.classifier.best_set] 35 # Remove intercept 36 basis_matrix = basis_matrix[:, 1:] 37 basis_matrix_a = basis_table.to_numpy_MA("A")[0] 38 # Fill unknowns 39 basis_matrix[basis_matrix_a.mask] = 0 40 basis_matrix_a = basis_matrix_a.filled(0) 41 diff = basis_matrix  basis_matrix_a 42 self.assertAlmostEqual(numpy.max(diff), 0, places=3) 29 43 30 44 @test_on_data 31 45 def test_bagged_evimp(self, dataset): 32 46 from Orange.ensemble.bagging import BaggedLearner 33 bagged = BaggedLearner(earth.EarthLearner(terms=10, degree=2), t=5)(dataset) 34 evimp = earth.bagged_evimp(bagged, used_only=False) 47 bagged_learner = BaggedLearner(earth.EarthLearner(terms=10, degree=2), 48 t=5) 49 50 bagged_classifier = bagged_learner(dataset) 51 evimp = earth.bagged_evimp(bagged_classifier, used_only=False) 35 52 36 53 … … 42 59 self.measure = ScoreEarthImportance(t=5, score_what="rss") 43 60 61 44 62 @datasets_driven(datasets=["multitargetsynthetic"]) 45 63 class TestEarthMultitarget(unittest.TestCase): … … 47 65 def test_multi_target_on_data(self, dataset): 48 66 self.learner = earth.EarthLearner(degree=2, terms=10) 49 67 50 68 self.predictor = self.multi_target_test(self.learner, dataset) 51 69 52 70 self.assertTrue(bool(self.predictor.multitarget)) 53 71 54 72 s = str(self.predictor) 55 73 self.assertEqual(s, self.predictor.to_string()) 56 74 self.assertNotEqual(s, self.predictor.to_string(3, 6)) 57 58 75 59 76 def multi_target_test(self, learner, data): 60 77 indices = Orange.data.sample.SubsetIndices2(p0=0.3)(data) 61 78 learn = data.select(indices, 1) 62 79 test = data.select(indices, 0) 63 80 64 81 predictor = learner(learn) 65 82 self.assertIsInstance(predictor, Orange.classification.Classifier) 66 83 self.multi_target_predictor_interface(predictor, learn.domain) 67 84 68 85 from Orange.evaluation import testing as _testing 69 86 70 87 r = _testing.test_on_data([predictor], test) 71 88 72 89 def all_values(vals): 73 90 for v in vals: 74 91 self.assertIsInstance(v, Orange.core.Value) 75 92 76 93 def all_dists(dist): 77 94 for d in dist: 78 95 self.assertIsInstance(d, Orange.core.Distribution) 79 96 80 97 for ex in test: 81 98 preds = predictor(ex, Orange.core.GetValue) 82 99 all_values(preds) 83 100 84 101 dist = predictor(ex, Orange.core.GetProbabilities) 85 102 all_dists(dist) 86 103 87 104 preds, dist = predictor(ex, Orange.core.GetBoth) 88 105 all_values(preds) 89 106 all_dists(dist) 90 107 91 108 for d in dist: 92 109 if isinstance(d, Orange.core.ContDistribution): … … 94 111 else: 95 112 dist_sum = sum(d) 96 113 97 114 self.assertGreater(dist_sum, 0.0) 98 115 self.assertLess(abs(dist_sum  1.0), 1e3) 99 116 100 117 return predictor 101 118 102 119 def multi_target_predictor_interface(self, predictor, domain): 103 120 self.assertTrue(hasattr(predictor, "class_vars")) … … 106 123 zip(predictor.class_vars, domain.class_vars))) 107 124 125 108 126 def load_tests(loader, tests, ignore): 109 127 import doctest … … 111 129 return tests 112 130 131 113 132 if __name__ == "__main__": 114 133 unittest.main()
Note: See TracChangeset
for help on using the changeset viewer.