source: orange/Orange/OrangeWidgets/Unsupervised/OWKMeans.py @ 9671:a7b056375472

Revision 9671:a7b056375472, 19.3 KB checked in by anze <anze.staric@…>, 2 years ago (diff)

Moved orange to Orange (part 2)

Line 
1"""
2<name>k-Means Clustering</name>
3<description>k-means clustering.</description>
4<icon>icons/KMeans.png</icon>
5<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact>
6<priority>2300</priority>
7"""
8
9from OWWidget import *
10import OWGUI
11import orange
12import orngClustering
13import math
14import random
15import statc
16#from PyQt4.Qwt5 import *
17from itertools import izip
18
19import orngDebugging
20
21##############################################################################
22# main class
23
24class OWKMeans(OWWidget):
25    settingsList = ["K", "optimized", "optimizationFrom", "optimizationTo", "scoring", "distanceMeasure", "classifySelected", "addIdAs", "classifyName",
26                    "initializationType", "restarts", "runAnyChange"]
27   
28    distanceMeasures = [
29        ("Euclidean", orange.ExamplesDistanceConstructor_Euclidean),
30        ("Pearson Correlation", orngClustering.ExamplesDistanceConstructor_PearsonR),
31        ("Spearman Rank Correlation", orngClustering.ExamplesDistanceConstructor_SpearmanR),
32        ("Manhattan", orange.ExamplesDistanceConstructor_Manhattan),
33        ("Maximal", orange.ExamplesDistanceConstructor_Maximal),
34        ("Hamming", orange.ExamplesDistanceConstructor_Hamming),
35        ]
36
37    initializations = [
38        ("Random", orngClustering.kmeans_init_random),
39        ("Diversity", orngClustering.kmeans_init_diversity),
40        ("Agglomerative clustering", orngClustering.KMeans_init_hierarchicalClustering(n=100)),
41        ]
42   
43    scoringMethods = [
44        ("Silhouette (heuristic)", orngClustering.score_fastsilhouette),
45        ("Silhouette", orngClustering.score_silhouette),
46        ("Between cluster distance", orngClustering.score_betweenClusterDistance),
47        ("Distance to centroids", orngClustering.score_distance_to_centroids)
48        ] 
49
50    def __init__(self, parent=None, signalManager = None):
51        OWWidget.__init__(self, parent, signalManager, 'k-Means Clustering')
52
53        self.inputs = [("Data", ExampleTable, self.setData)]
54        self.outputs = [("Data", ExampleTable), ("Centroids", ExampleTable)]
55
56        #set default settings
57        self.K = 2
58        self.optimized = True
59        self.optimizationFrom = 2
60        self.optimizationTo = 5
61        self.scoring = 0
62        self.distanceMeasure = 0
63        self.initializationType = 0
64        self.restarts = 1
65        self.classifySelected = 1
66        self.addIdAs = 0
67        self.runAnyChange = 1
68        self.classifyName = "Cluster"
69       
70        self.settingsChanged = False
71       
72        self.loadSettings()
73
74        self.data = None # holds input data
75        self.km = None   # holds clustering object
76
77        # GUI definition
78        # settings
79       
80        box = OWGUI.widgetBox(self.controlArea, "Clusters (k)", addSpace=True, spacing=0)
81        left, top, right, bottom = box.getContentsMargins()
82#        box.setContentsMargins(left, 0, right, 0)
83        bg = OWGUI.radioButtonsInBox(box, self, "optimized", [], callback=self.setOptimization)
84        fixedBox = OWGUI.widgetBox(box, orientation="horizontal", margin=0, spacing=bg.layout().spacing())
85        button = OWGUI.appendRadioButton(bg, self, "optimized", "Fixed", 
86                                         insertInto=fixedBox, tooltip="Fixed number of clusters")
87        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
88        fixedBox.layout().setAlignment(button, Qt.AlignLeft)
89        self.fixedSpinBox = OWGUI.spin(OWGUI.widgetBox(fixedBox), self, "K", min=2, max=30, tooltip="Fixed number of clusters",
90                                       callback=self.update, callbackOnReturn=True)
91       
92        optimizedBox = OWGUI.widgetBox(box, margin=0, spacing=bg.layout().spacing())
93        button = OWGUI.appendRadioButton(bg, self, "optimized", "Optimized", insertInto=optimizedBox)
94       
95        box = OWGUI.indentedBox(optimizedBox, sep=OWGUI.checkButtonOffsetHint(button))
96        box.layout().setSpacing(0)
97        self.optimizationBox = box
98        OWGUI.spin(box, self, "optimizationFrom", label="From", min=2, max=99,
99                   tooltip="Minimum number of clusters to try", callback=self.updateOptimizationFrom, callbackOnReturn=True)
100        b = OWGUI.spin(box, self, "optimizationTo", label="To", min=3, max=100,
101                   tooltip="Maximum number of clusters to try", callback=self.updateOptimizationTo, callbackOnReturn=True)
102#        b.control.setLineEdit(OWGUI.LineEditWFocusOut(b))
103        OWGUI.comboBox(box, self, "scoring", label="Scoring", orientation="horizontal",
104                       items=[m[0] for m in self.scoringMethods], callback=self.update)
105       
106       
107       
108        box = OWGUI.widgetBox(self.controlArea, "Settings", addSpace=True)
109#        OWGUI.spin(box, self, "K", label="Number of clusters"+"  ", min=1, max=30, step=1,
110#                   callback = self.initializeClustering)
111        OWGUI.comboBox(box, self, "distanceMeasure", label="Distance measures",
112                       items=[name for name, foo in self.distanceMeasures],
113                       tooltip=None, indent=20,
114                       callback = self.update)
115        cb = OWGUI.comboBox(box, self, "initializationType", label="Initialization",
116                       items=[name for name, foo in self.initializations],
117                       tooltip=None, indent=20,
118                       callback = self.update)
119        OWGUI.spin(cb.box, self, "restarts", label="Restarts", orientation="horizontal",
120                   min=1, max=100 if not orngDebugging.orngDebuggingEnabled else 5,
121                   callback=self.update, callbackOnReturn=True)
122
123        box = OWGUI.widgetBox(self.controlArea, "Cluster IDs", addSpace=True)
124        cb = OWGUI.checkBox(box, self, "classifySelected", "Append cluster indices")
125        box = OWGUI.indentedBox(box, sep=OWGUI.checkButtonOffsetHint(cb))
126        form = QWidget()
127        le = OWGUI.lineEdit(form, self, "classifyName", None, #"Name" + "  ",
128                            orientation="horizontal", #controlWidth=100,
129                            valueType=str,
130#                            callback=self.sendData,
131#                            callbackOnReturn=True
132                            )
133       
134        cc = OWGUI.comboBox(form, self, "addIdAs", label = " ", #"Place" + "  ",
135                            orientation="horizontal", items = ["Class attribute", "Attribute", "Meta attribute"],
136                            )
137       
138        layout = QFormLayout()
139        layout.setSpacing(8)
140        layout.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow)
141        layout.setLabelAlignment(Qt.AlignLeft | Qt.AlignJustify)
142        layout.addRow("Name  ", le)
143        layout.addRow("Place  ", cc)
144#        le.setFixedWidth(cc.sizeHint().width())
145        form.setLayout(layout)
146        box.layout().addWidget(form)
147        left, top, right, bottom = layout.getContentsMargins()
148        layout.setContentsMargins(0, top, right, bottom)
149       
150        cb.disables.append(box)
151#        cb.disables.append(cc.box)
152        cb.makeConsistent()
153#        OWGUI.separator(box)
154       
155        box = OWGUI.widgetBox(self.controlArea, "Run")
156        cb = OWGUI.checkBox(box, self, "runAnyChange", "Run after any change")
157        self.runButton = b = OWGUI.button(box, self, "Run Clustering", callback = self.run)
158        OWGUI.setStopper(self, b, cb, "settingsChanged", callback=self.run)
159
160        OWGUI.rubber(self.controlArea)
161        # display of clustering results
162       
163       
164        self.optimizationReportBox = OWGUI.widgetBox(self.mainArea)
165        tableBox = OWGUI.widgetBox(self.optimizationReportBox, "Optimization Report")
166        self.table = OWGUI.table(tableBox, selectionMode=QTableWidget.SingleSelection)
167        self.table.setHorizontalScrollMode(QTableWidget.ScrollPerPixel)
168        self.table.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
169        self.table.setSelectionBehavior(QAbstractItemView.SelectRows)
170        self.table.setColumnCount(3)
171        self.table.setHorizontalHeaderLabels(["k", "Best", "Score"])
172        self.table.verticalHeader().hide()
173        self.table.horizontalHeader().setStretchLastSection(True)
174        self.table.setItemDelegateForColumn(2, OWGUI.TableBarItem(self, self.table))
175        self.table.setItemDelegateForColumn(1, OWGUI.IndicatorItemDelegate(self))
176        self.table.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
177       
178        self.connect(self.table, SIGNAL("itemSelectionChanged()"), self.tableItemSelected)
179       
180        self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
181        self.mainArea.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
182       
183       
184        OWGUI.rubber(self.topWidgetPart)
185       
186        self.updateOptimizationGui()
187
188#        self.resize(100,100)
189
190    def adjustSize(self):
191        self.ensurePolished()
192        s = self.sizeHint()
193        self.resize(s)
194       
195    def hideOptResults(self):
196        self.mainArea.hide()
197        QTimer.singleShot(100, self.adjustSize)
198       
199       
200    def showOptResults(self):
201        self.mainArea.show()
202        QTimer.singleShot(100, self.adjustSize)
203       
204    def sizeHint(self):
205        s = self.leftWidgetPart.sizeHint()
206        if self.optimized and not self.mainArea.isHidden():
207            s.setWidth(s.width() + self.mainArea.sizeHint().width() + self.childrenRect().x() * 4)
208        return s
209       
210    def updateOptimizationGui(self):
211        self.fixedSpinBox.setDisabled(bool(self.optimized))
212        self.optimizationBox.setDisabled(not bool(self.optimized))
213        if self.optimized:
214            self.showOptResults()
215        else:
216            self.hideOptResults()
217           
218    def updateOptimizationFrom(self):
219        self.optimizationTo = max([self.optimizationFrom + 1, self.optimizationTo])
220        self.update()
221       
222    def updateOptimizationTo(self):
223        self.optimizationFrom = min([self.optimizationFrom, self.optimizationTo - 1])
224        self.update()
225       
226    def setOptimization(self):
227        self.updateOptimizationGui()
228        self.update()
229           
230    def runOptimization(self):
231        if self.optimizationTo > len(set(self.data)):
232            self.error("Not enough unique data instances (%d) for given number of clusters (%d)." % \
233                       (len(set(self.data)), self.optimizationTo))
234            return
235       
236        random.seed(0)
237        try:
238            self.progressBarInit()
239            Ks = range(self.optimizationFrom, self.optimizationTo + 1)
240            outer_callback_count = len(Ks) * self.restarts
241            outer_callback_state = {"restart": 0}
242            optimizationRun = []
243            for k in Ks:
244                def outer_progress(km):
245                    outer_callback_state["restart"] += 1
246                    self.progressBarSet(100.0 * outer_callback_state["restart"] / outer_callback_count)
247                   
248                def inner_progress(km):
249                    estimate = self.progressEstimate(km)
250                    self.progressBarSet(min(estimate / outer_callback_count + \
251                                            outer_callback_state["restart"] * \
252                                            100.0 / outer_callback_count,
253                                            100.0))
254                     
255                kmeans = orngClustering.KMeans(
256                    self.data,
257                    centroids = k,
258                    minscorechange=0,
259                    nstart = self.restarts,
260                    initialization = self.initializations[self.initializationType][1],
261                    distance = self.distanceMeasures[self.distanceMeasure][1],
262                    scoring = self.scoringMethods[self.scoring][1],
263                    outer_callback = outer_progress,
264                    inner_callback = inner_progress
265                    )
266                optimizationRun.append((k, kmeans))
267               
268                if self.restarts == 1:
269                    outer_progress(None)
270                   
271            self.optimizationRun = optimizationRun
272            self.progressBarFinished()
273            self.bestRun = (min if getattr(self.scoringMethods[self.scoring][1], "minimize", False) else max)(self.optimizationRun, key=lambda (k, run): run.score)
274            self.showResults()
275            self.sendData()
276        except Exception, ex:
277            self.error(0, "An error occured while running optimization. Reason: " + str(ex))
278            raise
279       
280    def cluster(self):
281        if self.K > len(set(self.data)):
282            self.error("Not enough unique data instances (%d) for given number of clusters (%d)." % \
283                       (len(set(self.data)), self.K))
284            return
285        random.seed(0)
286       
287        self.km = orngClustering.KMeans(
288            centroids = self.K,
289            minscorechange=0,
290            nstart = self.restarts,
291            initialization = self.initializations[self.initializationType][1],
292            distance = self.distanceMeasures[self.distanceMeasure][1],
293            scoring = self.scoringMethods[self.scoring][1],
294            inner_callback = self.clusterCallback,
295            )
296        self.progressBarInit()
297        self.km(self.data)
298        self.sendData()
299        self.progressBarFinished()
300
301    def clusterCallback(self, km):
302        norm = math.log(len(self.data), 10)
303        if km.iteration < norm:
304            self.progressBarSet(80.0 * km.iteration / norm)
305        else:
306            self.progressBarSet(80.0 + 0.15 * (1.0 - math.exp(norm - km.iteration)))
307           
308    def progressEstimate(self, km):
309        norm = math.log(len(km.data), 10)
310        if km.iteration < norm:
311            return min(80.0 * km.iteration / norm, 90.0)
312        else:
313            return min(80.0 + 0.15 * (1.0 - math.exp(norm - km.iteration)), 90.0)
314       
315    def showResults(self):
316        self.table.setRowCount(len(self.optimizationRun))
317        bestScore = self.bestRun[1].score
318        worstScore = (max if getattr(self.scoringMethods[self.scoring][1], "minimize", False) else min)([km.score for k, km in self.optimizationRun])
319        for i, (k, run) in enumerate(self.optimizationRun):
320            item = OWGUI.tableItem(self.table, i, 0, k)
321            item.setData(Qt.TextAlignmentRole, QVariant(Qt.AlignCenter))
322           
323            item = OWGUI.tableItem(self.table, i, 1, None)#" " if (k, run) == self.bestRun else "")
324            item.setData(OWGUI.IndicatorItemDelegate.IndicatorRole, QVariant((k, run) == self.bestRun))
325#            item.setData(Qt.DecorationRole, QVariant(QIcon(os.path.join(os.path.dirname(OWGUI.__file__), "icons", "circle.png"))))
326            item.setData(Qt.TextAlignmentRole, QVariant(Qt.AlignCenter))
327           
328            fmt = lambda score, max_decimals=10: "%%.%if" % min(int(abs(math.log(max(score, 1e-10)))) + 2, max_decimals) if score > 0 and score < 1 else "%.1f"
329            item = OWGUI.tableItem(self.table, i, 2, fmt(run.score) % run.score)
330            item.setData(OWGUI.TableBarItem.BarRole, QVariant((bestScore - run.score) / ((bestScore - worstScore) or 1) * 0.95))
331            if (k, run) == self.bestRun:
332                self.table.selectRow(i)
333           
334        for i in range(2):
335            self.table.resizeColumnToContents(i)
336        self.table.show()
337        qApp.processEvents()
338        self.adjustSize()
339
340    def run(self):
341        self.error(0)
342        if not self.data:
343            return
344        if self.optimized:
345            self.runOptimization()
346        else:
347            self.cluster()
348
349    def update(self):
350        if self.runAnyChange:
351            self.run()
352        else:
353            self.settingsChanged = True
354           
355    def tableItemSelected(self):
356        selectedItems = self.table.selectedItems()
357        rows = set([item.row() for item in selectedItems])
358        if len(rows) == 1:
359            row = rows.pop()
360            self.sendData(self.optimizationRun[row][1])
361
362    def sendData(self, km=None):
363        if km is None:
364            km = self.bestRun[1] if self.optimized else self.km
365        if not self.data or not km:
366            self.send("Data", None)
367            self.send("Centroids", None)
368            return
369        clustVar = orange.EnumVariable(self.classifyName, values = ["C%d" % (x+1) for x in range(km.k)])
370
371        origDomain = self.data.domain
372        if self.addIdAs == 0:
373            domain=orange.Domain(origDomain.attributes,clustVar)
374            if origDomain.classVar:
375                domain.addmeta(orange.newmetaid(), origDomain.classVar)
376            aid = -1
377        elif self.addIdAs == 1:
378            domain=orange.Domain(origDomain.attributes+[clustVar], origDomain.classVar)
379            aid = len(origDomain.attributes)
380        else:
381            domain=orange.Domain(origDomain.attributes, origDomain.classVar)
382            aid=orange.newmetaid()
383            domain.addmeta(aid, clustVar)
384
385        domain.addmetas(origDomain.getmetas())
386
387        # construct a new data set, with a class as assigned by k-means clustering
388        new = orange.ExampleTable(domain, self.data)
389        for ex, midx in izip(new, km.clusters):
390            ex[aid] = midx
391       
392        centroids = orange.ExampleTable(domain, km.centroids)
393        for i, c in enumerate(centroids):
394            c[aid] = i
395            if origDomain.classVar:
396                c[origDomain.classVar] = "?"
397
398        self.send("Data", new)
399        self.send("Centroids", centroids)
400       
401    def setData(self, data):
402        """Handle data from the input signal."""
403        self.runButton.setEnabled(bool(data))
404        if not data:
405            self.data = None
406            self.table.setRowCount(0)
407        else:
408            self.data = data
409            self.run()
410
411    def sendReport(self):
412        settings = [("Distance measure", self.distanceMeasures[self.distanceMeasure][0]),
413                    ("Initialization", self.initializations[self.initializationType][0]),
414                    ("Restarts", self.restarts)]
415        if self.optimized:
416            self.reportSettings("Settings", settings)
417            self.reportSettings("Optimization", [("Minimum num. of clusters", self.optimizationFrom),
418                                                 ("Maximum num. of clusters", self.optimizationTo),
419                                                 ("Scoring method", self.scoringMethods[self.scoring][0])])
420        else:
421            self.reportSettings("Settings", settings + [("Number of clusters (K)", self.K)])
422        self.reportData(self.data)
423        if self.optimized:
424            self.reportSection("Cluster size optimization report")
425            import OWReport 
426            self.reportRaw(OWReport.reportTable(self.table))
427
428
429##############################################################################
430
431class colorItem(QTableWidgetItem):
432    def __init__(self, table, i, j, text, flags = Qt.ItemIsSelectable | Qt.ItemIsEnabled | Qt.ItemIsEditable, color=Qt.lightGray):
433        self.color = color
434        QTableWidgetItem.__init__(self, unicode(text))
435        self.setFlags(flags)
436        table.setItem(i, j, self)
437
438    def paint(self, painter, colorgroup, rect, selected):
439        g = QPalette(colorgroup)
440        g.setColor(QPalette.Base, self.color)
441        QTableWidgetItem.paint(self, painter, g, rect, selected)
442
443
444##################################################################################################
445# Test this widget
446
447if __name__=="__main__":
448    import orange
449    a = QApplication(sys.argv)
450    ow = OWKMeans()
451    d = orange.ExampleTable("../../doc/datasets/iris.tab")
452    ow.setData(d)
453    ow.show()
454    a.exec_()
455    ow.saveSettings()
Note: See TracBrowser for help on using the repository browser.