source: orange/Orange/testing/unit/tests/test_feature_selection.py @ 11647:dfa6d31c2fc2

Revision 11647:dfa6d31c2fc2, 3.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Preserve the domain's meta attributes and class_vars.

RevLine 
[10655]1import Orange.testing.testing as testing
[10522]2from Orange.feature import selection, scoring
3import Orange
4
5from operator import itemgetter
6
7try:
8    import unittest2 as unittest
9except ImportError:
10    import unittest
11
12class TestSelection(unittest.TestCase):
13    def setUp(self):
14        self.score = Orange.feature.scoring.Gini()
15        self.data = Orange.data.Table("lenses")
16       
17        self.scores = scoring.score_all(self.data, self.score)
18       
19    def test_best_n(self):
20        sorted_scores = sorted(self.scores, key=itemgetter(1),
21                               reverse=True)
22        # Test the descending order of scores
23        self.assertEqual(self.scores, sorted_scores)
24       
25        # best 3 scores
26        best_3 = map(itemgetter(0), sorted_scores[:3])
27       
28        # test best_n function
[10708]29        self.assertEqual(selection.top_rated(self.scores, 3), best_3)
[10522]30       
[10708]31        self.assertTrue(len(selection.top_rated(self.scores, 3)) == 3)
[10522]32       
33        # all returned values should be strings.
34        self.assertTrue(all(isinstance(item, basestring) for item in \
[10708]35                            selection.top_rated(self.scores, 3)))
[10522]36       
[10708]37        new_data = selection.select(self.data, self.scores, 3)
[10522]38        self.assertEqual(best_3, [a.name for a in new_data.domain.attributes])
39        self.assertEqual(new_data.domain.class_var, self.data.domain.class_var)
40       
41    def test_above_threashold(self):
42        threshold = self.scores[len(self.scores) / 2][1]
[11646]43        above = [item[0] for item in self.scores if item[1] >= threshold]
[10522]44       
45        self.assertEqual(above, 
46                         selection.above_threshold(self.scores, threshold)
47                         )
48       
49        new_data = selection.select_above_threshold(self.data, 
50                                                    self.scores, threshold)
51        self.assertEqual(above, [a.name for a in new_data.domain.attributes])
52        self.assertEqual(new_data.domain.class_var, self.data.domain.class_var)
[11647]53
54    def test_select_features_subset(self):
55        data = Orange.data.Table("lenses")
56
57        d1 = selection._select_features_subset(data, [])
58        self.assertSequenceEqual(d1.domain.features, [])
59        self.assertIs(d1.domain.class_var, data.domain.class_var)
60
61        d1 = selection._select_features_subset(data, [data.domain[0]])
62        self.assertSequenceEqual(d1.domain.features, [data.domain[0]])
63        self.assertIs(d1.domain.class_var, data.domain.class_var)
64
65        domain = Orange.data.Domain(data.domain.features[:2],
66                                    data.domain.class_var,
67                                    class_vars=[data.domain.features[2]])
68        domain.add_metas({-1, data.domain.features[3]})
69        data = Orange.data.Table(domain, data)
70
71        d1 = selection._select_features_subset(data, [data.domain[0]])
72        self.assertSequenceEqual(d1.domain.features, [data.domain[0]])
73        self.assertIs(d1.domain.class_var, data.domain.class_var)
74        self.assertSequenceEqual(d1.domain.class_vars, data.domain.class_vars)
75        self.assertEqual(d1.domain.get_metas(), data.domain.get_metas())
76
[10522]77if __name__ == "__main__":
78    unittest.main()
Note: See TracBrowser for help on using the repository browser.