source: orange/testing/unittests/tests/test_data_utils.py @ 9109:d91271fec0e0

Revision 9109:d91271fec0e0, 3.5 KB checked in by ales_erjavec <ales.erjavec@…>, 3 years ago (diff)

Added unittests for Orange.data.utils

Line 
1from Orange.misc import testing
2from Orange.data import utils
3import Orange
4
5import unittest
6import random
7
8
9@testing.datasets_driven
10class TestTake(testing.DataTestCase):
11    @testing.test_on_data
12    def test_take_domain(self, data):
13        size = len(data.domain)
14        indices = range(size)
15        to_mask = lambda inds: [i for i in indices if i in inds]
16       
17        indices1 = [random.randrange(size)]
18        indices2 = indices1 + [random.randrange(size)]
19        indices3 = indices1 + [random.randrange(size)]
20       
21        mask1 = to_mask(indices1)
22        mask2 = to_mask(indices2)
23        mask3 = to_mask(indices3)
24       
25        data1 = utils.take(data, indices1, axis=1)
26        data1m = utils.take(data, mask1, axis=1)
27        self.assertEquals(len(data1.domain), len(indices1))
28        self.assertEquals(len(data1m.domain), len(set(indices1)))
29#        self.assertEquals(list(data1), list(data1m))
30       
31        data2 = utils.take(data, indices2, axis=1)
32        data2m = utils.take(data, mask2, axis=1)
33        self.assertEquals(len(data2.domain), len(indices2))
34        self.assertEquals(len(data2m.domain), len(set(indices2)))
35#        self.assertEquals(list(data2), list(data2m))
36       
37        data3 = utils.take(data, indices3, axis=1)
38        data3m = utils.take(data, mask3, axis=1)
39        self.assertEquals(len(data3.domain), len(indices3))
40        self.assertEquals(len(data3m.domain), len(set(indices3)))
41#        self.assertEquals(list(data3), list(data3m))
42
43    @testing.test_on_data
44    def test_take_instances(self, data):
45        size = len(data)
46        indices = range(len(data))
47        to_mask = lambda inds: [i for i in indices if i in inds]
48       
49        indices1 = [random.randrange(size)]
50        indices2 = indices1 + [random.randrange(size)]
51        indices3 = indices1 + [random.randrange(size)]
52       
53        mask1 = to_mask(indices1)
54        mask2 = to_mask(indices2)
55        mask3 = to_mask(indices3)
56       
57        data1 = utils.take(data, indices1, axis=0)
58        data1m = utils.take(data, mask1, axis=0)
59        self.assertEquals(len(data1), len(indices1))
60        self.assertEquals(len(data1m), len(set(indices1)))
61       
62        data2 = utils.take(data, indices2, axis=0)
63        data2m = utils.take(data, mask2, axis=0)
64        self.assertEquals(len(data2), len(indices2))
65        self.assertEquals(len(data2m), len(set(indices2)))
66       
67        data3 = utils.take(data, indices3, axis=0)
68        data3m = utils.take(data, mask3, axis=0)
69        self.assertEquals(len(data3), len(indices3))
70        self.assertEquals(len(data3m), len(set(indices3)))
71
72def split(table):
73    size = len(table.domain)
74    indices = range(size)
75    to_mask = lambda inds: [i for i in indices if i in inds]
76   
77    indices = [random.randrange(size) for i in range(2)]
78    part1 = utils.take(table, indices, axis=1)
79    complement = [i for i in range(len(table.domain)) if i not in indices]
80    part2 = utils.take(table, complement, axis=1)
81   
82    return part1, part2
83
84@testing.datasets_driven       
85class TestJoins(unittest.TestCase):
86    @testing.test_on_data
87    def test_left_join(self, table):
88        utils.add_row_id(table)
89        part1, part2 = split(table)
90        utils.left_join(part1, part2, utils._row_meta_id, utils._row_meta_id)
91       
92    @testing.test_on_data
93    def test_right_join(self, table):
94        utils.add_row_id(table)
95        part1, part2 = split(table)
96        utils.right_join(part1, part2, utils._row_meta_id, utils._row_meta_id)
97   
98if __name__ == "__main__":
99    unittest.main()
Note: See TracBrowser for help on using the repository browser.