source: orange/Orange/OrangeWidgets/Unsupervised/OWKMeans.py @ 10995:c18f18cf5624

Revision 10995:c18f18cf5624, 20.1 KB checked in by Ales Erjavec <ales.erjavec@…>, 19 months ago (diff)

Fixed inverted distribution bars in Optimization Report box.

Line 
1"""
2<name>k-Means Clustering</name>
3<description>k-means clustering.</description>
4<icon>icons/KMeans.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>2300</priority>
7"""
8
9from OWWidget import *
10import OWGUI
11import orange
12import orngClustering
13import math
14import random
15import statc
16
17from itertools import izip
18
19import orngDebugging
20
21
22class OWKMeans(OWWidget):
23    settingsList = ["K", "optimized", "optimizationFrom", "optimizationTo",
24                    "scoring", "distanceMeasure", "classifySelected",
25                    "addIdAs", "classifyName", "initializationType",
26                    "restarts", "runAnyChange"]
27
28    distanceMeasures = [
29        ("Euclidean",
30         orange.ExamplesDistanceConstructor_Euclidean),
31        ("Pearson Correlation",
32         orngClustering.ExamplesDistanceConstructor_PearsonR),
33        ("Spearman Rank Correlation",
34         orngClustering.ExamplesDistanceConstructor_SpearmanR),
35        ("Manhattan",
36         orange.ExamplesDistanceConstructor_Manhattan),
37        ("Maximal",
38         orange.ExamplesDistanceConstructor_Maximal),
39        ("Hamming",
40         orange.ExamplesDistanceConstructor_Hamming),
41        ]
42
43    initializations = [
44        ("Random",
45         orngClustering.kmeans_init_random),
46        ("Diversity",
47         orngClustering.kmeans_init_diversity),
48        ("Agglomerative clustering",
49         orngClustering.KMeans_init_hierarchicalClustering(n=100)),
50        ]
51
52    scoringMethods = [
53        ("Silhouette (heuristic)",
54         orngClustering.score_fastsilhouette),
55        ("Silhouette",
56         orngClustering.score_silhouette),
57        ("Between cluster distance",
58         orngClustering.score_betweenClusterDistance),
59        ("Distance to centroids",
60         orngClustering.score_distance_to_centroids)
61        ]
62
63    def __init__(self, parent=None, signalManager=None):
64        OWWidget.__init__(self, parent, signalManager, 'k-Means Clustering')
65
66        self.inputs = [("Data", ExampleTable, self.setData)]
67        self.outputs = [("Data", ExampleTable), ("Centroids", ExampleTable)]
68
69        #set default settings
70        self.K = 2
71        self.optimized = True
72        self.optimizationFrom = 2
73        self.optimizationTo = 5
74        self.scoring = 0
75        self.distanceMeasure = 0
76        self.initializationType = 0
77        self.restarts = 1
78        self.classifySelected = 1
79        self.addIdAs = 0
80        self.runAnyChange = 1
81        self.classifyName = "Cluster"
82
83        self.settingsChanged = False
84
85        self.loadSettings()
86
87        self.data = None  # holds input data
88        self.km = None    # holds clustering object
89
90        # GUI definition
91        # settings
92
93        box = OWGUI.widgetBox(self.controlArea, "Clusters (k)",
94                              addSpace=True, spacing=0)
95#        left, top, right, bottom = box.getContentsMargins()
96#        box.setContentsMargins(left, 0, right, 0)
97        bg = OWGUI.radioButtonsInBox(box, self, "optimized", [],
98                                     callback=self.setOptimization)
99
100        fixedBox = OWGUI.widgetBox(box, orientation="horizontal",
101                                   margin=0, spacing=bg.layout().spacing())
102
103        button = OWGUI.appendRadioButton(bg, self, "optimized", "Fixed",
104                                         insertInto=fixedBox,
105                                         tooltip="Fixed number of clusters")
106
107        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
108        fixedBox.layout().setAlignment(button, Qt.AlignLeft)
109        self.fixedSpinBox = OWGUI.spin(OWGUI.widgetBox(fixedBox), self, "K",
110                                       min=2, max=30,
111                                       tooltip="Fixed number of clusters",
112                                       callback=self.update,
113                                       callbackOnReturn=True)
114
115        optimizedBox = OWGUI.widgetBox(box, margin=0,
116                                       spacing=bg.layout().spacing())
117        button = OWGUI.appendRadioButton(bg, self, "optimized", "Optimized",
118                                         insertInto=optimizedBox)
119
120        box = OWGUI.indentedBox(optimizedBox,
121                                sep=OWGUI.checkButtonOffsetHint(button))
122
123        box.layout().setSpacing(0)
124        self.optimizationBox = box
125
126        OWGUI.spin(box, self, "optimizationFrom", label="From",
127                   min=2, max=99,
128                   tooltip="Minimum number of clusters to try",
129                   callback=self.updateOptimizationFrom,
130                   callbackOnReturn=True)
131
132        OWGUI.spin(box, self, "optimizationTo", label="To",
133                   min=3, max=100,
134                   tooltip="Maximum number of clusters to try",
135                   callback=self.updateOptimizationTo,
136                   callbackOnReturn=True)
137
138        OWGUI.comboBox(box, self, "scoring", label="Scoring",
139                       orientation="horizontal",
140                       items=[m[0] for m in self.scoringMethods],
141                       callback=self.update)
142
143        box = OWGUI.widgetBox(self.controlArea, "Settings", addSpace=True)
144
145        OWGUI.comboBox(box, self, "distanceMeasure", label="Distance measures",
146                       items=[name for name, _ in self.distanceMeasures],
147                       tooltip=None,
148                       indent=20,
149                       callback=self.update)
150
151        cb = OWGUI.comboBox(box, self, "initializationType",
152                            label="Initialization",
153                            items=[name for name, _ in self.initializations],
154                            tooltip=None,
155                            indent=20,
156                            callback=self.update)
157
158        OWGUI.spin(cb.box, self, "restarts", label="Restarts",
159                   orientation="horizontal",
160                   min=1,
161                   max=100 if not orngDebugging.orngDebuggingEnabled else 5,
162                   callback=self.update,
163                   callbackOnReturn=True)
164
165        box = OWGUI.widgetBox(self.controlArea, "Cluster IDs", addSpace=True)
166        cb = OWGUI.checkBox(box, self, "classifySelected",
167                            "Append cluster indices")
168
169        box = OWGUI.indentedBox(box, sep=OWGUI.checkButtonOffsetHint(cb))
170
171        form = QWidget()
172        le = OWGUI.lineEdit(form, self, "classifyName", None,
173                            orientation="horizontal",
174                            valueType=str)
175
176        cc = OWGUI.comboBox(form, self, "addIdAs", label=" ",
177                            orientation="horizontal",
178                            items=["Class attribute",
179                                   "Attribute",
180                                   "Meta attribute"])
181
182        layout = QFormLayout()
183        layout.setSpacing(8)
184        layout.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow)
185        layout.setLabelAlignment(Qt.AlignLeft | Qt.AlignJustify)
186        layout.addRow("Name  ", le)
187        layout.addRow("Place  ", cc)
188
189        form.setLayout(layout)
190        box.layout().addWidget(form)
191        left, top, right, bottom = layout.getContentsMargins()
192        layout.setContentsMargins(0, top, right, bottom)
193
194        cb.disables.append(box)
195        cb.makeConsistent()
196
197        box = OWGUI.widgetBox(self.controlArea, "Run")
198        cb = OWGUI.checkBox(box, self, "runAnyChange", "Run after any change")
199        self.runButton = b = OWGUI.button(box, self, "Run Clustering",
200                                          callback=self.run)
201
202        OWGUI.setStopper(self, b, cb, "settingsChanged", callback=self.run)
203
204        OWGUI.rubber(self.controlArea)
205
206        # display of clustering results
207        self.optimizationReportBox = OWGUI.widgetBox(self.mainArea)
208        self.tableBox = OWGUI.widgetBox(self.optimizationReportBox,
209                                        "Optimization Report")
210        self.table = OWGUI.table(self.tableBox,
211                                 selectionMode=QTableWidget.SingleSelection)
212
213        self.table.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
214        self.table.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
215        self.table.setSelectionBehavior(QAbstractItemView.SelectRows)
216        self.table.setColumnCount(3)
217        self.table.setHorizontalHeaderLabels(["k", "Best", "Score"])
218        self.table.verticalHeader().hide()
219        self.table.horizontalHeader().setStretchLastSection(True)
220
221        self.table.setItemDelegateForColumn(
222            2, OWGUI.TableBarItem(self, self.table))
223
224        self.table.setItemDelegateForColumn(
225            1, OWGUI.IndicatorItemDelegate(self))
226
227        self.table.setSizePolicy(QSizePolicy.MinimumExpanding,
228                                 QSizePolicy.MinimumExpanding)
229
230        self.connect(self.table,
231                     SIGNAL("itemSelectionChanged()"),
232                     self.tableItemSelected)
233
234        self.setSizePolicy(QSizePolicy.Preferred,
235                           QSizePolicy.Preferred)
236
237        self.mainArea.setSizePolicy(QSizePolicy.MinimumExpanding,
238                                    QSizePolicy.MinimumExpanding)
239
240        OWGUI.rubber(self.topWidgetPart)
241
242        self.updateOptimizationGui()
243
244    def adjustSize(self):
245        self.ensurePolished()
246        s = self.sizeHint()
247        self.resize(s)
248
249    def hideOptResults(self):
250        self.mainArea.hide()
251        QTimer.singleShot(100, self.adjustSize)
252
253    def showOptResults(self):
254        self.mainArea.show()
255        QTimer.singleShot(100, self.adjustSize)
256
257    def sizeHint(self):
258        s = self.leftWidgetPart.sizeHint()
259        if self.optimized and not self.mainArea.isHidden():
260            s.setWidth(s.width() + self.mainArea.sizeHint().width() + \
261                       self.childrenRect().x() * 4)
262        return s
263
264    def updateOptimizationGui(self):
265        self.fixedSpinBox.setDisabled(bool(self.optimized))
266        self.optimizationBox.setDisabled(not bool(self.optimized))
267        if self.optimized:
268            self.showOptResults()
269        else:
270            self.hideOptResults()
271
272    def updateOptimizationFrom(self):
273        self.optimizationTo = max([self.optimizationFrom + 1,
274                                   self.optimizationTo])
275        self.update()
276
277    def updateOptimizationTo(self):
278        self.optimizationFrom = min([self.optimizationFrom,
279                                     self.optimizationTo - 1])
280        self.update()
281
282    def setOptimization(self):
283        self.updateOptimizationGui()
284        self.update()
285
286    def runOptimization(self):
287        if self.optimizationTo > len(set(self.data)):
288            self.error("Not enough unique data instances (%d) for given "
289                       "number of clusters (%d)." % \
290                       (len(set(self.data)), self.optimizationTo))
291            return
292
293        random.seed(0)
294        data = self.data
295        nstart = self.restarts
296        initialization = self.initializations[self.initializationType][1]
297        distance = self.distanceMeasures[self.distanceMeasure][1]
298        scoring = self.scoringMethods[self.scoring][1]
299        try:
300            self.progressBarInit()
301            Ks = range(self.optimizationFrom, self.optimizationTo + 1)
302            outer_callback_count = len(Ks) * self.restarts
303            outer_callback_state = {"restart": 0}
304            optimizationRun = []
305            for k in Ks:
306                def outer_progress(km):
307                    outer_callback_state["restart"] += 1
308                    self.progressBarSet(
309                        100.0 * outer_callback_state["restart"] \
310                        / outer_callback_count
311                    )
312
313                def inner_progress(km):
314                    estimate = self.progressEstimate(km)
315                    self.progressBarSet(min(estimate / outer_callback_count + \
316                                            outer_callback_state["restart"] * \
317                                            100.0 / outer_callback_count,
318                                            100.0))
319
320                kmeans = orngClustering.KMeans(
321                    data,
322                    centroids=k,
323                    minscorechange=0,
324                    nstart=nstart,
325                    initialization=initialization,
326                    distance=distance,
327                    scoring=scoring,
328                    outer_callback=outer_progress,
329                    inner_callback=inner_progress
330                    )
331                optimizationRun.append((k, kmeans))
332
333                if nstart == 1:
334                    outer_progress(None)
335
336            self.optimizationRun = optimizationRun
337            minimize = getattr(scoring, "minimize", False)
338            self.optimizationRunSorted = \
339                    sorted(optimizationRun,
340                           key=lambda item: item[1].score,
341                           reverse=minimize)
342
343            self.progressBarFinished()
344
345            self.bestRun = self.optimizationRunSorted[-1]
346            self.showResults()
347            self.sendData()
348        except Exception, ex:
349            self.error(0, "An error occurred while running optimization. "
350                          "Reason: " + str(ex))
351            raise
352
353    def cluster(self):
354        if self.K > len(set(self.data)):
355            self.error("Not enough unique data instances (%d) for given "
356                       "number of clusters (%d)." % \
357                       (len(set(self.data)), self.K))
358            return
359        random.seed(0)
360
361        self.km = orngClustering.KMeans(
362            centroids=self.K,
363            minscorechange=0,
364            nstart=self.restarts,
365            initialization=self.initializations[self.initializationType][1],
366            distance=self.distanceMeasures[self.distanceMeasure][1],
367            scoring=self.scoringMethods[self.scoring][1],
368            inner_callback=self.clusterCallback,
369            )
370        self.progressBarInit()
371        self.km(self.data)
372        self.sendData()
373        self.progressBarFinished()
374
375    def clusterCallback(self, km):
376        norm = math.log(len(self.data), 10)
377        if km.iteration < norm:
378            self.progressBarSet(80.0 * km.iteration / norm)
379        else:
380            self.progressBarSet(80.0 + 0.15 * \
381                                (1.0 - math.exp(norm - km.iteration)))
382
383    def progressEstimate(self, km):
384        norm = math.log(len(km.data), 10)
385        if km.iteration < norm:
386            return min(80.0 * km.iteration / norm, 90.0)
387        else:
388            return min(80.0 + 0.15 * (1.0 - math.exp(norm - km.iteration)),
389                       90.0)
390
391    def scoreFmt(self, score, max_decimals=10):
392        if score > 0 and score < 1:
393            fmt = "%%.%if" % min(int(abs(math.log(max(score, 1e-10)))) + 2,
394                                 max_decimals)
395        else:
396            fmt = "%.1f"
397        return fmt
398
399    def showResults(self):
400        self.table.setRowCount(len(self.optimizationRun))
401        scoring = self.scoringMethods[self.scoring][1]
402        minimize = getattr(scoring, "minimize", False)
403
404        bestScore = self.bestRun[1].score
405        worstScore = self.optimizationRunSorted[0][1].score
406
407        if minimize:
408            bestScore, worstScore = worstScore, bestScore
409
410        scoreSpan = (bestScore - worstScore) or 1
411
412        for i, (k, run) in enumerate(self.optimizationRun):
413            item = OWGUI.tableItem(self.table, i, 0, k)
414            item.setData(Qt.TextAlignmentRole, QVariant(Qt.AlignCenter))
415
416            item = OWGUI.tableItem(self.table, i, 1, None)
417            item.setData(OWGUI.IndicatorItemDelegate.IndicatorRole,
418                         QVariant((k, run) == self.bestRun))
419
420            item.setData(Qt.TextAlignmentRole, QVariant(Qt.AlignCenter))
421
422            fmt = self.scoreFmt(run.score)
423            item = OWGUI.tableItem(self.table, i, 2, fmt % run.score)
424            barRatio = 0.95 * (run.score - worstScore) / scoreSpan
425
426            item.setData(OWGUI.TableBarItem.BarRole, QVariant(barRatio))
427            if (k, run) == self.bestRun:
428                self.table.selectRow(i)
429
430        for i in range(2):
431            self.table.resizeColumnToContents(i)
432
433        self.table.show()
434
435        if minimize:
436            self.tableBox.setTitle("Optimization Report (smaller is better)")
437        else:
438            self.tableBox.setTitle("Optimization Report (bigger is better)")
439
440        QTimer.singleShot(0, self.adjustSize)
441
442    def run(self):
443        self.error(0)
444        if not self.data:
445            return
446        if self.optimized:
447            self.runOptimization()
448        else:
449            self.cluster()
450
451    def update(self):
452        if self.runAnyChange:
453            self.run()
454        else:
455            self.settingsChanged = True
456
457    def tableItemSelected(self):
458        selectedItems = self.table.selectedItems()
459        rows = set([item.row() for item in selectedItems])
460        if len(rows) == 1:
461            row = rows.pop()
462            self.sendData(self.optimizationRun[row][1])
463
464    def sendData(self, km=None):
465        if km is None:
466            km = self.bestRun[1] if self.optimized else self.km
467        if not self.data or not km:
468            self.send("Data", None)
469            self.send("Centroids", None)
470            return
471
472        clustVar = orange.EnumVariable(self.classifyName,
473                                       values=["C%d" % (x + 1) \
474                                               for x in range(km.k)])
475
476        origDomain = self.data.domain
477        if self.addIdAs == 0:
478            domain = orange.Domain(origDomain.attributes, clustVar)
479            if origDomain.classVar:
480                domain.addmeta(orange.newmetaid(), origDomain.classVar)
481            aid = -1
482        elif self.addIdAs == 1:
483            domain = orange.Domain(origDomain.attributes + [clustVar],
484                                   origDomain.classVar)
485            aid = len(origDomain.attributes)
486        else:
487            domain = orange.Domain(origDomain.attributes,
488                                   origDomain.classVar)
489            aid = orange.newmetaid()
490            domain.addmeta(aid, clustVar)
491
492        domain.addmetas(origDomain.getmetas())
493
494        # construct a new data set, with a class as assigned by
495        # k-means clustering
496        new = orange.ExampleTable(domain, self.data)
497        for ex, midx in izip(new, km.clusters):
498            ex[aid] = midx
499
500        centroids = orange.ExampleTable(domain, km.centroids)
501        for i, c in enumerate(centroids):
502            c[aid] = i
503            if origDomain.classVar:
504                c[origDomain.classVar] = "?"
505
506        self.send("Data", new)
507        self.send("Centroids", centroids)
508
509    def setData(self, data):
510        """Handle data from the input signal."""
511        self.runButton.setEnabled(bool(data))
512        if not data:
513            self.data = None
514            self.table.setRowCount(0)
515        else:
516            self.data = data
517            self.run()
518
519    def sendReport(self):
520        settings = [("Distance measure",
521                     self.distanceMeasures[self.distanceMeasure][0]),
522                    ("Initialization",
523                     self.initializations[self.initializationType][0]),
524                    ("Restarts",
525                     self.restarts)]
526        if self.optimized:
527            self.reportSettings("Settings", settings)
528            self.reportSettings("Optimization",
529                                [("Minimum num. of clusters",
530                                  self.optimizationFrom),
531                                 ("Maximum num. of clusters",
532                                  self.optimizationTo),
533                                 ("Scoring method",
534                                  self.scoringMethods[self.scoring][0])])
535        else:
536            self.reportSettings("Settings",
537                                settings + [("Number of clusters (K)",
538                                             self.K)])
539
540        self.reportData(self.data)
541        if self.optimized:
542            import OWReport
543            self.reportSection("Cluster size optimization report")
544            self.reportRaw(OWReport.reportTable(self.table))
545
546###############################################################################
547# Test this widget
548
549if __name__ == "__main__":
550    import orange
551    a = QApplication(sys.argv)
552    ow = OWKMeans()
553    d = orange.ExampleTable("iris.tab")
554    ow.setData(d)
555    ow.show()
556    a.exec_()
557    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.