source: orange/Orange/testing/unit/tests/test_pick_class.py @ 10895:e184536415f5

Revision 10895:e184536415f5, 5.2 KB checked in by Miran@…, 23 months ago (diff)

Added unitests for pick_class.

Line 
1from Orange import data
2try:
3    import unittest2 as unittest
4except:
5    import unittest
6
7class TestPickClass(unittest.TestCase):
8    def __init__(self, testCaseName):
9        unittest.TestCase.__init__(self, testCaseName)
10        self.orig = data.Table('multitarget-synthetic')
11
12    def test_pick_first(self):
13        d = data.Table(self.orig)
14
15        #picks first with no class_var
16        d.pick_class(d.domain.class_vars[0])
17        self.assertEquals(d.domain.class_var, self.orig.domain.class_vars[0])
18        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[1])
19        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[2])
20        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[3])
21        for i in range(len(d)):
22            self.assertEquals(d[i].get_class(), self.orig[i].get_classes()[0])
23            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[1])
24            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[2])
25            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[3])
26       
27        #picks first with existing class_var
28        d.pick_class(d.domain.class_vars[0])
29        self.assertEquals(d.domain.class_var, self.orig.domain.class_vars[1])
30        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[0])
31        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[2])
32        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[3])
33        for i in range(len(d)):
34            self.assertEquals(d[i].get_class(), self.orig[i].get_classes()[1])
35            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[0])
36            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[2])
37            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[3])
38
39        #picks None with existing class_var
40        d.pick_class(None)
41        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[1])
42        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[0])
43        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[2])
44        self.assertEquals(d.domain.class_vars[3], self.orig.domain.class_vars[3])
45        for i in range(len(d)):
46            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[1])
47            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[0])
48            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[2])
49            self.assertEquals(d[i].get_classes()[3], self.orig[i].get_classes()[3])
50
51    def test_pick_nonfirst(self):
52        d = data.Table(self.orig)
53
54        #picks not first with no class_var
55        d.pick_class(d.domain.class_vars[2])
56        self.assertEquals(d.domain.class_var, self.orig.domain.class_vars[2])
57        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[1])
58        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[0])
59        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[3])
60        for i in range(len(d)):
61            self.assertEquals(d[i].get_class(), self.orig[i].get_classes()[2])
62            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[1])
63            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[0])
64            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[3])
65
66        #picks not first with existing class_var
67        d.pick_class(d.domain.class_vars[2])
68        self.assertEquals(d.domain.class_var, self.orig.domain.class_vars[3])
69        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[1])
70        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[0])
71        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[2])
72        for i in range(len(d)):
73            self.assertEquals(d[i].get_class(), self.orig[i].get_classes()[3])
74            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[1])
75            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[0])
76            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[2])
77       
78        #picks None with existing class_var
79        d.pick_class(None)
80        self.assertEquals(d.domain.class_vars[0], self.orig.domain.class_vars[3])
81        self.assertEquals(d.domain.class_vars[1], self.orig.domain.class_vars[1])
82        self.assertEquals(d.domain.class_vars[2], self.orig.domain.class_vars[0])
83        self.assertEquals(d.domain.class_vars[3], self.orig.domain.class_vars[2])
84        for i in range(len(d)):
85            self.assertEquals(d[i].get_classes()[0], self.orig[i].get_classes()[3])
86            self.assertEquals(d[i].get_classes()[1], self.orig[i].get_classes()[1])
87            self.assertEquals(d[i].get_classes()[2], self.orig[i].get_classes()[0])
88            self.assertEquals(d[i].get_classes()[3], self.orig[i].get_classes()[2])
89
90    #uncomment when bug is fixed
91    #def test_pick_none(self):
92    #    d = data.Table(self.orig)
93    #    d.pick_class(None)
94    #    self.assertEquals(d.domain.class_vars,self.orig.domain.class_vars)
Note: See TracBrowser for help on using the repository browser.