source: orange/Orange/OrangeWidgets/Unsupervised/OWPCA.py @ 11820:399588413aa1

Revision 11820:399588413aa1, 21.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 4 months ago (diff)

Using PyQwt5 for PCA scree plot.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icon>icons/PCA.svg</icon>
6<tags>pca,principal,component,projection</tags>
7<priority>3050</priority>
8
9"""
10import sys
11
12import numpy as np
13
14from PyQt4.Qwt5 import QwtPlot, QwtPlotCurve, QwtSymbol
15from PyQt4.QtCore import pyqtSignal as Signal, pyqtSlot as Slot
16
17import Orange
18import Orange.projection.linear as plinear
19
20from OWWidget import *
21from OWGraph import OWGraph
22
23import OWGUI
24
25
26def plot_curve(title=None, pen=None, brush=None, style=QwtPlotCurve.Lines,
27               symbol=QwtSymbol.Ellipse, legend=True, antialias=True,
28               auto_scale=True, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
29    curve = QwtPlotCurve(title or "")
30    return configure_curve(curve, pen=pen, brush=brush, style=style,
31                           symbol=symbol, legend=legend, antialias=antialias,
32                           auto_scale=auto_scale, xaxis=xaxis, yaxis=yaxis)
33
34
35def configure_curve(curve, title=None, pen=None, brush=None,
36          style=QwtPlotCurve.Lines, symbol=QwtSymbol.Ellipse,
37          legend=True, antialias=True, auto_scale=True,
38          xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
39    if title is not None:
40        curve.setTitle(title)
41    if pen is not None:
42        curve.setPen(pen)
43
44    if brush is not None:
45        curve.setBrush(brush)
46
47    if not isinstance(symbol, QwtSymbol):
48        symbol_ = QwtSymbol()
49        symbol_.setStyle(symbol)
50        symbol = symbol_
51
52    curve.setStyle(style)
53    curve.setSymbol(QwtSymbol(symbol))
54    curve.setRenderHint(QwtPlotCurve.RenderAntialiased, antialias)
55    curve.setItemAttribute(QwtPlotCurve.Legend, legend)
56    curve.setItemAttribute(QwtPlotCurve.AutoScale, auto_scale)
57    curve.setAxis(xaxis, yaxis)
58    return curve
59
60
61class PlotTool(QObject):
62    """
63    A base class for Plot tools that operate on QwtPlot's canvas
64    widget by installing itself as its event filter.
65
66    """
67    cursor = Qt.ArrowCursor
68
69    def __init__(self, parent=None, graph=None):
70        QObject.__init__(self, parent)
71        self.__graph = None
72        self.__oldCursor = None
73        self.setGraph(graph)
74
75    def setGraph(self, graph):
76        """
77        Install this tool to operate on ``graph``.
78        """
79        if self.__graph is graph:
80            return
81
82        if self.__graph is not None:
83            self.uninstall(self.__graph)
84
85        self.__graph = graph
86
87        if graph is not None:
88            self.install(graph)
89
90    def graph(self):
91        return self.__graph
92
93    def install(self, graph):
94        canvas = graph.canvas()
95        canvas.setMouseTracking(True)
96        canvas.installEventFilter(self)
97        canvas.destroyed.connect(self.__on_destroyed)
98        self.__oldCursor = canvas.cursor()
99        canvas.setCursor(self.cursor)
100
101    def uninstall(self, graph):
102        canvas = graph.canvas()
103        canvas.removeEventFilter(self)
104        canvas.setCursor(self.__oldCursor)
105        canvas.destroyed.disconnect(self.__on_destroyed)
106        self.__oldCursor = None
107
108    def eventFilter(self, obj, event):
109        if obj is self.__graph.canvas():
110            return self.canvasEvent(event)
111        return False
112
113    def canvasEvent(self, event):
114        """
115        Main handler for a canvas events.
116        """
117        if event.type() == QEvent.MouseButtonPress:
118            return self.mousePressEvent(event)
119        elif event.type() == QEvent.MouseButtonRelease:
120            return self.mouseReleaseEvent(event)
121        elif event.type() == QEvent.MouseButtonDblClick:
122            return self.mouseDoubleClickEvent(event)
123        elif event.type() == QEvent.MouseMove:
124            return self.mouseMoveEvent(event)
125        elif event.type() == QEvent.Leave:
126            return self.leaveEvent(event)
127        elif event.type() == QEvent.Enter:
128            return self.enterEvent(event)
129        return False
130
131    # These are actually event filters (note the return values)
132    def mousePressEvent(self, event):
133        return False
134
135    def mouseMoveEvent(self, event):
136        return False
137
138    def mouseReleaseEvent(self, event):
139        return False
140
141    def mouseDoubleClickEvent(self, event):
142        return False
143
144    def enterEvent(self, event):
145        return False
146
147    def leaveEvent(self, event):
148        return False
149
150    def keyPressEvent(self, event):
151        return False
152
153    def transform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
154        """
155        Transform a QPointF from plot coordinates to canvas local coordinates.
156        """
157        x = self.__graph.transform(xaxis, point.x())
158        y = self.__graph.transform(yaxis, point.y())
159        return QPoint(x, y)
160
161    def invTransform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft):
162        """
163        Transform a QPoint from canvas local coordinates to plot coordinates.
164        """
165        x = self.__graph.invTransform(xaxis, point.x())
166        y = self.__graph.invTransform(yaxis, point.y())
167        return QPointF(x, y)
168
169    @Slot()
170    def __on_destroyed(self, obj):
171        obj.removeEventFilter(self)
172
173
174class CutoffControler(PlotTool):
175
176    class CutoffCurve(QwtPlotCurve):
177        pass
178
179    cutoffChanged = Signal(float)
180    cutoffMoved = Signal(float)
181    cutoffPressed = Signal()
182    cutoffReleased = Signal()
183
184    NoState, Drag = 0, 1
185
186    def __init__(self, parent=None, graph=None):
187        self.__curve = None
188        self.__range = (0, 1)
189        self.__cutoff = 0
190        super(CutoffControler, self).__init__(parent, graph)
191        self._state = self.NoState
192
193    def install(self, graph):
194        super(CutoffControler, self).install(graph)
195        assert self.__curve is None
196        self.__curve = CutoffControler.CutoffCurve("")
197        configure_curve(self.__curve, symbol=QwtSymbol.NoSymbol, legend=False)
198        self.__curve.setData([self.__cutoff, self.__cutoff], [0.0, 1.0])
199        self.__curve.attach(graph)
200
201    def uninstall(self, graph):
202        super(CutoffControler, self).uninstall(graph)
203        self.__curve.detach()
204        self.__curve = None
205
206    def _toRange(self, value):
207        minval, maxval = self.__range
208        return max(min(value, maxval), minval)
209
210    def mousePressEvent(self, event):
211        if event.button() == Qt.LeftButton:
212            cut = self.invTransform(event.pos()).x()
213            self.setCutoff(cut)
214            self.cutoffPressed.emit()
215            self._state = self.Drag
216        return True
217
218    def mouseMoveEvent(self, event):
219        if self._state == self.Drag:
220            cut = self._toRange(self.invTransform(event.pos()).x())
221            self.setCutoff(cut)
222            self.cutoffMoved.emit(cut)
223        else:
224            cx = self.transform(QPointF(self.cutoff(), 0)).x()
225            if abs(cx - event.pos().x()) < 2:
226                self.graph().canvas().setCursor(Qt.SizeHorCursor)
227            else:
228                self.graph().canvas().setCursor(self.cursor)
229        return True
230
231    def mouseReleaseEvent(self, event):
232        if event.button() == Qt.LeftButton and self._state == self.Drag:
233            cut = self._toRange(self.invTransform(event.pos()).x())
234            self.setCutoff(cut)
235            self.cutoffReleased.emit()
236            self._state = self.NoState
237        return True
238
239    def setCutoff(self, cutoff):
240        minval, maxval = self.__range
241        cutoff = max(min(cutoff, maxval), minval)
242        if self.__cutoff != cutoff:
243            self.__cutoff = cutoff
244            if self.__curve is not None:
245                self.__curve.setData([cutoff, cutoff], [0.0, 1.0])
246            self.cutoffChanged.emit(cutoff)
247            if self.graph() is not None:
248                self.graph().replot()
249
250    def cutoff(self):
251        return self.__cutoff
252
253    def setRange(self, minval, maxval):
254        maxval = max(minval, maxval)
255        if self.__range != (minval, maxval):
256            self.__range = (minval, maxval)
257            self.setCutoff(max(min(self.cutoff(), maxval), minval))
258
259
260class Graph(OWGraph):
261    def __init__(self, *args, **kwargs):
262        super(Graph, self).__init__(*args, **kwargs)
263        self.gridCurve.attach(self)
264
265    # bypass the OWGraph event handlers
266    def mousePressEvent(self, event):
267        QwtPlot.mousePressEvent(self, event)
268
269    def mouseMoveEvent(self, event):
270        QwtPlot.mouseMoveEvent(self, event)
271
272    def mouseReleaseEvent(self, event):
273        QwtPlot.mouseReleaseEvent(self, event)
274
275
276class OWPCA(OWWidget):
277    settingsList = ["standardize", "max_components", "variance_covered",
278                    "use_generalized_eigenvectors", "auto_commit"]
279
280    def __init__(self, parent=None, signalManager=None, title="PCA"):
281        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
282
283        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
284        self.outputs = [("Transformed Data", Orange.data.Table, Default),
285                        ("Eigen Vectors", Orange.data.Table)]
286
287        self.standardize = True
288        self.max_components = 0
289        self.variance_covered = 100.0
290        self.use_generalized_eigenvectors = False
291        self.auto_commit = False
292
293        self.loadSettings()
294
295        self.data = None
296        self.changed_flag = False
297
298        #####
299        # GUI
300        #####
301        grid = QGridLayout()
302        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
303                              orientation=grid)
304
305        label1 = QLabel("Max components", box)
306        grid.addWidget(label1, 1, 0)
307
308        sb1 = OWGUI.spin(box, self, "max_components", 0, 1000,
309                         tooltip="Maximum number of components",
310                         callback=self.on_update,
311                         addToLayout=False,
312                         keyboardTracking=False
313                         )
314        self.max_components_spin = sb1.control
315        self.max_components_spin.setSpecialValueText("All")
316        grid.addWidget(sb1.control, 1, 1)
317
318        label2 = QLabel("Variance covered", box)
319        grid.addWidget(label2, 2, 0)
320
321        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
322                               tooltip="Percent of variance covered.",
323                               callback=self.on_update,
324                               decimals=1,
325                               addToLayout=False,
326                               keyboardTracking=False
327                               )
328        sb2.control.setSuffix("%")
329        grid.addWidget(sb2.control, 2, 1)
330
331        OWGUI.rubber(self.controlArea)
332
333        box = OWGUI.widgetBox(self.controlArea, "Commit")
334        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
335        b = OWGUI.button(box, self, "Commit",
336                         callback=self.update_components)
337        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
338
339        self.plot = Graph()
340        canvas = self.plot.canvas()
341        canvas.setFrameStyle(QFrame.StyledPanel)
342        self.mainArea.layout().addWidget(self.plot)
343        self.plot.setAxisTitle(QwtPlot.yLeft, "Proportion of Variance")
344        self.plot.setAxisTitle(QwtPlot.xBottom, "Principal Components")
345        self.plot.setAxisScale(QwtPlot.yLeft, 0.0, 1.0)
346        self.plot.enableGridXB(True)
347        self.plot.enableGridYL(True)
348        self.plot.setGridColor(Qt.lightGray)
349
350        self.variance_curve = plot_curve(
351            "Variance",
352            pen=QPen(Qt.red, 2),
353            symbol=QwtSymbol.NoSymbol,
354            xaxis=QwtPlot.xBottom,
355            yaxis=QwtPlot.yLeft
356        )
357        self.cumulative_variance_curve = plot_curve(
358            "Cumulative Variance",
359            pen=QPen(Qt.darkYellow, 2),
360            symbol=QwtSymbol.NoSymbol,
361            xaxis=QwtPlot.xBottom,
362            yaxis=QwtPlot.yLeft
363        )
364
365        self.variance_curve.attach(self.plot)
366        self.cumulative_variance_curve.attach(self.plot)
367
368        self.selection_tool = CutoffControler(parent=self.plot.canvas())
369        self.selection_tool.cutoffMoved.connect(self.on_cutoff_moved)
370
371        self.graphButton.clicked.connect(self.saveToFile)
372        self.components = None
373        self.variances = None
374        self.variances_sum = None
375        self.projector_full = None
376        self.currently_selected = 0
377
378        self.resize(800, 400)
379
380    def clear(self):
381        """
382        Clear (reset) the widget state.
383        """
384        self.data = None
385        self.selection_tool.setGraph(None)
386        self.clear_cached()
387        self.variance_curve.setVisible(False)
388        self.cumulative_variance_curve.setVisible(False)
389
390    def clear_cached(self):
391        """Clear cached components
392        """
393        self.components = None
394        self.variances = None
395        self.variances_cumsum = None
396        self.projector_full = None
397        self.currently_selected = 0
398
399    def set_data(self, data=None):
400        """Set the widget input data.
401        """
402        self.clear()
403        if data is not None:
404            self.data = data
405            self.on_change()
406        else:
407            self.send("Transformed Data", None)
408            self.send("Eigen Vectors", None)
409
410    def on_change(self):
411        """Data has changed and we need to recompute the projection.
412        """
413        if self.data is None:
414            return
415        self.clear_cached()
416        self.apply()
417
418    def on_update(self):
419        """Component selection was changed by the user.
420        """
421        if self.data is None:
422            return
423        self.update_cutoff_curve()
424        if self.currently_selected != self.number_of_selected_components():
425            self.update_components_if()
426
427    def construct_pca_all_comp(self):
428        pca = plinear.PCA(standardize=self.standardize,
429                          max_components=0,
430                          variance_covered=1,
431                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
432                          )
433        return pca
434
435    def construct_pca(self):
436        max_components = self.max_components
437        variance_covered = self.variance_covered
438        pca = plinear.PCA(standardize=self.standardize,
439                          max_components=max_components,
440                          variance_covered=variance_covered / 100.0,
441                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
442                          )
443        return pca
444
445    def apply(self):
446        """
447        Apply PCA on input data, caching the full projection and
448        updating the selected components.
449
450        """
451        pca = self.construct_pca_all_comp()
452        self.projector_full = pca(self.data)
453
454        self.variances = self.projector_full.variances
455        self.variances /= np.sum(self.variances)
456        self.variances_cumsum = np.cumsum(self.variances)
457
458        self.max_components_spin.setRange(0, len(self.variances))
459        self.max_components = min(self.max_components,
460                                  len(self.variances) - 1)
461        self.update_scree_plot()
462        self.update_cutoff_curve()
463        self.update_components_if()
464
465    def update_components_if(self):
466        if self.auto_commit:
467            self.update_components()
468        else:
469            self.changed_flag = True
470
471    def update_components(self):
472        """Update the output components.
473        """
474        if self.data is None:
475            return
476
477        scale = self.projector_full.scale
478        center = self.projector_full.center
479        components = self.projector_full.projection
480        input_domain = self.projector_full.input_domain
481        variances = self.projector_full.variances
482
483        # Get selected components (based on max_components and
484        # variance_coverd)
485        pca = self.construct_pca()
486        variances, components, variance_sum = pca._select_components(variances, components)
487
488        projector = plinear.PcaProjector(input_domain=input_domain,
489                                         standardize=self.standardize,
490                                         scale=scale,
491                                         center=center,
492                                         projection=components,
493                                         variances=variances,
494                                         variance_sum=variance_sum)
495        projected_data = projector(self.data)
496
497        append_metas(projected_data, self.data)
498
499        eigenvectors = self.eigenvectors_as_table(components)
500
501        self.currently_selected = self.number_of_selected_components()
502
503        self.send("Transformed Data", projected_data)
504        self.send("Eigen Vectors", eigenvectors)
505
506        self.changed_flag = False
507
508    def eigenvectors_as_table(self, U):
509        features = [Orange.feature.Continuous("C%i" % i) \
510                    for i in range(1, U.shape[1] + 1)]
511        domain = Orange.data.Domain(features, False)
512        return Orange.data.Table(domain, [list(v) for v in U])
513
514    def update_scree_plot(self):
515        x_space = np.arange(0, len(self.variances))
516        self.plot.enableAxis(QwtPlot.xBottom, True)
517        self.plot.enableAxis(QwtPlot.yLeft, True)
518        if len(x_space) <= 5:
519            self.plot.setXlabels(["PC" + str(i + 1) for i in x_space])
520        else:
521            # Restore continuous plot scale
522            # TODO: disable minor ticks
523            self.plot.setXlabels(None)
524
525        self.variance_curve.setData(x_space, self.variances)
526        self.cumulative_variance_curve.setData(x_space, self.variances_cumsum)
527        self.variance_curve.setVisible(True)
528        self.cumulative_variance_curve.setVisible(True)
529
530        self.selection_tool.setRange(0, len(self.variances) - 1)
531        self.selection_tool.setGraph(self.plot)
532        self.plot.replot()
533
534    def on_cutoff_moved(self, value):
535        """Cutoff curve was moved by the user.
536        """
537        components = int(np.floor(value)) + 1
538        # Did the number of components actually change
539        self.max_components = components
540        self.variance_covered = self.variances_cumsum[components - 1] * 100
541        if self.currently_selected != self.number_of_selected_components():
542            self.update_components_if()
543
544    def update_cutoff_curve(self):
545        """Update cutoff curve from 'Components Selection' control box.
546        """
547        if self.max_components == 0:
548            # Special "All" value
549            max_components = len(self.variances_cumsum)
550        else:
551            max_components = self.max_components
552
553        variance = self.variances_cumsum[max_components - 1] * 100.0
554        if variance < self.variance_covered:
555            cutoff = max_components - 1
556        else:
557            cutoff = np.searchsorted(self.variances_cumsum,
558                                     self.variance_covered / 100.0)
559
560        self.selection_tool.setCutoff(float(cutoff + 0.5))
561
562    def number_of_selected_components(self):
563        """How many components are selected.
564        """
565        if self.data is None:
566            return 0
567
568        variance_components = np.searchsorted(self.variances_cumsum,
569                                              self.variance_covered / 100.0)
570        if self.max_components == 0:
571            # Special "All" value
572            max_components = len(self.variances_cumsum)
573        else:
574            max_components = self.max_components
575        return min(variance_components + 1, max_components)
576
577    def sendReport(self):
578        self.reportSettings("PCA Settings",
579                            [("Max. components", self.max_components),
580                             ("Variance covered", "%i%%" % self.variance_covered),
581                             ])
582        if self.data is not None and self.projector_full:
583            output_domain = self.projector_full.output_domain
584            st_dev = np.sqrt(self.projector_full.variances)
585            summary = [[""] + [a.name for a in output_domain.attributes],
586                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
587                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
588                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
589                       ]
590
591            th = "<th>%s</th>".__mod__
592            header = "".join(map(th, summary[0]))
593            td = "<td>%s</td>".__mod__
594            summary = ["".join(map(td, row)) for row in summary[1:]]
595            tr = "<tr>%s</tr>".__mod__
596            summary = "\n".join(map(tr, [header] + summary))
597            summary = "<table>\n%s\n</table>" % summary
598
599            self.reportSection("Summary")
600            self.reportRaw(summary)
601
602            self.reportSection("Scree Plot")
603            self.reportImage(self.plot.saveToFileDirect)
604
605    def saveToFile(self):
606        self.plot.saveToFile()
607
608
609def append_metas(dest, source):
610    """
611    Append all meta attributes from the `source` table to `dest` table.
612    The tables must be of the same length.
613
614    :param dest:
615        An data table into which the meta values will be copied.
616    :type dest: :class:`Orange.data.Table`
617
618    :param source:
619        A data table with the meta attributes/values to be copied into `dest`.
620    :type source: :class:`Orange.data.Table`
621
622    """
623    if len(dest) != len(source):
624        raise ValueError("'dest' and 'source' must have the same length.")
625
626    dest.domain.add_metas(source.domain.get_metas())
627    for dest_inst, source_inst in zip(dest, source):
628        for meta_id, val in source_inst.get_metas().items():
629            dest_inst[meta_id] = val
630
631
632if __name__ == "__main__":
633    app = QApplication(sys.argv)
634    w = OWPCA()
635    data = Orange.data.Table("iris")
636    w.set_data(data)
637    w.show()
638    w.set_data(Orange.data.Table("brown-selected"))
639    app.exec_()
Note: See TracBrowser for help on using the repository browser.