source: orange/Orange/testing/unit/tests/test_data_utils.py @ 10651:4f6fcf57db06

Revision 10651:4f6fcf57db06, 3.4 KB checked in by markotoplak, 2 years ago (diff)

Moved caching, collections, debugging, fileutil, r, testing from misc to utils.

Line 
1from Orange.utils import testing
2from Orange.data import utils
3import Orange
4
5try:
6    import unittest2 as unittest
7except:
8    import unittest
9import random
10
11
12@testing.datasets_driven
13class TestTake(testing.DataTestCase):
14    @testing.test_on_data
15    def test_take_domain(self, data):
16        size = len(data.domain)
17        indices = range(size)
18        to_mask = lambda inds: [i for i in indices if i in inds]
19
20        indices1 = [random.randrange(size)]
21        indices2 = indices1 + [random.randrange(size)]
22        indices3 = indices1 + [random.randrange(size)]
23
24        mask1 = to_mask(indices1)
25        mask2 = to_mask(indices2)
26        mask3 = to_mask(indices3)
27
28        data1 = utils.take(data, indices1, axis=1)
29        data1m = utils.take(data, mask1, axis=1)
30        self.assertEquals(len(data1.domain), len(indices1))
31        self.assertEquals(len(data1m.domain), len(set(indices1)))
32#        self.assertEquals(list(data1), list(data1m))
33
34        data2 = utils.take(data, indices2, axis=1)
35        data2m = utils.take(data, mask2, axis=1)
36        self.assertEquals(len(data2.domain), len(indices2))
37        self.assertEquals(len(data2m.domain), len(set(indices2)))
38#        self.assertEquals(list(data2), list(data2m))
39
40        data3 = utils.take(data, indices3, axis=1)
41        data3m = utils.take(data, mask3, axis=1)
42        self.assertEquals(len(data3.domain), len(indices3))
43        self.assertEquals(len(data3m.domain), len(set(indices3)))
44#        self.assertEquals(list(data3), list(data3m))
45
46    @testing.test_on_data
47    def test_take_instances(self, data):
48        size = len(data)
49        indices = range(len(data))
50        to_mask = lambda inds: [i for i in indices if i in inds]
51
52        indices1 = [random.randrange(size)]
53        indices2 = indices1 + [random.randrange(size)]
54        indices3 = indices1 + [random.randrange(size)]
55
56        mask1 = to_mask(indices1)
57        mask2 = to_mask(indices2)
58        mask3 = to_mask(indices3)
59
60        data1 = utils.take(data, indices1, axis=0)
61        data1m = utils.take(data, mask1, axis=0)
62        self.assertEquals(len(data1), len(indices1))
63        self.assertEquals(len(data1m), len(set(indices1)))
64
65        data2 = utils.take(data, indices2, axis=0)
66        data2m = utils.take(data, mask2, axis=0)
67        self.assertEquals(len(data2), len(indices2))
68        self.assertEquals(len(data2m), len(set(indices2)))
69
70        data3 = utils.take(data, indices3, axis=0)
71        data3m = utils.take(data, mask3, axis=0)
72        self.assertEquals(len(data3), len(indices3))
73        self.assertEquals(len(data3m), len(set(indices3)))
74
75def split(table):
76    size = len(table.domain)
77    indices = range(size)
78    to_mask = lambda inds: [i for i in indices if i in inds]
79
80    indices = [random.randrange(size) for i in range(2)]
81    part1 = utils.take(table, indices, axis=1)
82    complement = [i for i in range(len(table.domain)) if i not in indices]
83    part2 = utils.take(table, complement, axis=1)
84
85    return part1, part2
86
87@testing.datasets_driven
88class TestJoins(unittest.TestCase):
89    @testing.test_on_data
90    def test_left_join(self, table):
91        utils.add_row_id(table)
92        part1, part2 = split(table)
93        utils.left_join(part1, part2, utils._row_meta_id, utils._row_meta_id)
94
95    @testing.test_on_data
96    def test_right_join(self, table):
97        utils.add_row_id(table)
98        part1, part2 = split(table)
99        utils.right_join(part1, part2, utils._row_meta_id, utils._row_meta_id)
100
101if __name__ == "__main__":
102    unittest.main()
Note: See TracBrowser for help on using the repository browser.