source: orange/Orange/testing/unit/tests/test_kmeans.py @ 10278:f3b1ffae9c29

Revision 10278:f3b1ffae9c29, 1.8 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Unittest2 for python 2.6.

Line 
try:
    import unittest2 as unittest
except ImportError:
    import unittest

import Orange
from Orange.misc import testing
from Orange.clustering import kmeans
from Orange.clustering.kmeans import Clustering
from Orange.distance import *
9
@testing.datasets_driven
class TestKMeans(unittest.TestCase):
    """Unit tests for k-means clustering in ``Orange.clustering.kmeans``.

    Methods decorated with ``testing.test_on_data`` take a data table as
    their ``table`` argument. The ``testing.datasets_driven`` class decorator
    presumably expands each of them into one test per data set (see
    ``Orange.misc.testing`` -- not visible here).
    """

    @testing.test_on_data
    def test_kmeans_on(self, table):
        """Cluster *table* into 5 groups and check the shape of the result."""
        km = Clustering(table, 5, maxiters=100, nstart=3)
        self.assertEqual(len(km.centroids), 5)
        # Cluster labels are expected to be 0-based, so the largest label
        # must be k - 1.
        self.assertEqual(max(km.clusters) + 1, 5)
        # One cluster label per instance of the input table.
        self.assertEqual(len(km.clusters), len(table))

        self._test_score_functions(km)

    def _test_score_functions(self, km):
        """Smoke-test the clustering score functions on a fitted *km*.

        Only checks that each scorer runs without raising; return values
        are not inspected.
        """
        kmeans.score_distance_to_centroids(km)
        kmeans.score_fast_silhouette(km, index=None)
        kmeans.score_silhouette(km, index=None)

    @testing.test_on_data
    def test_init_functs(self, table):
        """Run every centroid initializer on *table* for several values of k."""
        distfunc = Euclidean(table)
        for k in [1, 5, 10]:
            self._test_init_func(table, k, distfunc)

    def _test_init_func(self, table, k, distfunc):
        """Check that each initializer returns *k* centers from *table*'s domain.

        :param table: data table to draw initial centers from.
        :param k: number of centers requested.
        :param distfunc: distance measure passed through to the initializer.
        """
        # (label, initializer) pairs; the label appears in failure messages
        # so it is clear which initializer and which k broke.
        initializers = (
            ("init_random", kmeans.init_random),
            ("init_diversity", kmeans.init_diversity),
            ("init_hclustering(n=50)", kmeans.init_hclustering(n=50)),
        )
        for name, init in initializers:
            msg = "%s, k=%d" % (name, k)
            centers = init(table, k, distfunc)
            self.assertEqual(len(centers), k, msg=msg)
            self.assertEqual(centers[0].domain, table.domain, msg=msg)

    def test_kmeans_fail(self):
        """Test the reaction when the number of centroids is larger than
        the number of instances in the data table.
        """
        data = Orange.data.Table("iris")
        # The data and k must be passed as separate arguments. Wrapping them
        # in a tuple would call Clustering((data, k)), which fails for an
        # unrelated reason, and the test would pass without checking k > n.
        self.assertRaises(Exception, Clustering, data, len(data) + 1)
50
51
if __name__ == "__main__":
    # Executed as a script (not imported): discover and run the tests above.
    unittest.main()
54
Note: See TracBrowser for help on using the repository browser.