source: orange/Orange/testing/unit/tests/test_kmeans.py @ 9679:3879dea56188

Revision 9679:3879dea56188, 1.8 KB checked in by Miha Stajdohar <miha.stajdohar@…>, 2 years ago (diff)

Moved and renamed testing.

Line 
1import unittest
2from Orange.misc import testing
3from Orange.clustering import kmeans
4from Orange.clustering.kmeans import Clustering
5from Orange.distance.instances import *
6
@testing.datasets_driven
class TestKMeans(unittest.TestCase):
    """Smoke tests for ``Orange.clustering.kmeans``.

    Methods decorated with ``testing.test_on_data`` take a data table as
    ``table``. The class-level ``testing.datasets_driven`` decorator is what
    turns them into runnable test cases (presumably one per test data set --
    see ``Orange.misc.testing``).
    """

    # Keep the default "first != second" text when a msg argument is given.
    # On Python 2.7 the default is False, which would hide the compared values.
    longMessage = True

    @testing.test_on_data
    def test_kmeans_on(self, table):
        """Cluster ``table`` into 5 groups and sanity-check the result."""
        km = Clustering(table, 5, maxiters=100, nstart=3)
        self.assertEqual(len(km.centroids), 5)
        # Cluster labels are expected to be 0-based, so the largest label
        # present must be k - 1.
        self.assertEqual(max(km.clusters) + 1, 5)
        # One cluster label per instance of the table.
        self.assertEqual(len(km.clusters), len(table))

        self._test_score_functions(km)

    def _test_score_functions(self, km):
        """Smoke-test the k-means score functions on a fitted ``km``.

        Only checks that each scorer runs without raising; the returned
        scores are not inspected.
        """
        kmeans.score_distance_to_centroids(km)
        kmeans.score_fast_silhouette(km, index=None)
        kmeans.score_silhouette(km, index=None)

    @testing.test_on_data
    def test_init_functs(self, table):
        """Exercise every centroid-initialisation function for several k."""
        distfunc = EuclideanConstructor(table)
        for k in [1, 5, 10]:
            self._test_init_func(table, k, distfunc)

    def _test_init_func(self, table, k, distfunc):
        """Check that each initialiser returns ``k`` centroids whose
        domain matches ``table.domain``.

        :param table: data table to draw initial centroids from.
        :param k: number of centroids requested.
        :param distfunc: distance measure passed through to the initialiser.
        """
        initialisers = [
            ("init_random", kmeans.init_random),
            ("init_diversity", kmeans.init_diversity),
            # init_hclustering(n=...) is called first to obtain the
            # initialiser itself, which is then called like the others.
            ("init_hclustering(n=50)", kmeans.init_hclustering(n=50)),
        ]
        for name, init in initialisers:
            centers = init(table, k, distfunc)
            # ``name`` is the assertion message, so a failure says which
            # initialiser misbehaved.
            self.assertEqual(len(centers), k, name)
            self.assertEqual(centers[0].domain, table.domain, name)

    @testing.test_on_data
    def test_kmeans_fail(self, table):
        """Requesting more centroids than the table has instances must fail."""
        # Previously this fetched data via an undefined ``testDatasets()``
        # under @unittest.expectedFailure. The NameError itself satisfied the
        # decorator, so Clustering was never exercised. Any exception is
        # accepted here, matching the breadth of the old expectedFailure;
        # tighten once Clustering documents a specific error type.
        self.assertRaises(Exception, Clustering, table, len(table) + 1)
49
50
def _main():
    """Run this module's test cases with the standard unittest runner."""
    unittest.main()


if __name__ == "__main__":
    _main()
53           
Note: See TracBrowser for help on using the repository browser.