source: orange/Orange/OrangeWidgets/Unsupervised/OWPCA.py @ 11454:7121b9a53a58

Revision 11454:7121b9a53a58, 17.2 KB checked in by Ales Erjavec <ales.erjavec@…>, 12 months ago (diff)

Preserve meta attributes in the 'Transformed Data' output.

RevLine 
[10798]1"""
2<name>PCA</name>
[10801]3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
[11217]5<icon>icons/PCA.svg</icon>
[10801]6<tags>pca,principal,component,projection</tags>
[10835]7<priority>3050</priority>
[10798]8
9"""
10import Orange
11import Orange.utils.addons
12
13from OWWidget import *
14import OWGUI
15
16import Orange
17import Orange.projection.linear as plinear
18
19import numpy as np
20import sys
21
22from plot.owplot import OWPlot
23from plot.owcurve import OWCurve
24from plot import owaxis
25
[10832]26
[10798]27class ScreePlot(OWPlot):
28    def __init__(self, parent=None, name="Scree Plot"):
29        OWPlot.__init__(self, parent, name=name)
[10799]30        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
31                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
32        self.cutoff_curve.setVisible(False)
33        self.cutoff_curve.set_style(OWCurve.Lines)
34        self.add_custom_curve(self.cutoff_curve)
35
36    def is_cutoff_enabled(self):
37        return self.cutoff_curve and self.cutoff_curve.isVisible()
38
39    def set_cutoff_curve_enabled(self, state):
40        self.cutoff_curve.setVisible(state)
41
42    def set_cutoff_value(self, value):
43        xmin, xmax = self.x_scale()
44        x = min(max(value, xmin), xmax)
45        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
46
47    def mousePressEvent(self, event):
[10812]48        if self.isLegendEvent(event, QGraphicsView.mousePressEvent):
49            return
50
[10799]51        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
52            pos = self.mapToScene(event.pos())
[10832]53            x, _ = self.map_from_graph(pos)
[10799]54            xmin, xmax = self.x_scale()
[10805]55            if x >= xmin - 0.1 and x <= xmax + 0.1:
56                x = min(max(x, xmin), xmax)
[10799]57                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
58                self.emit_cutoff_moved(x)
59        return QGraphicsView.mousePressEvent(self, event)
60
61    def mouseMoveEvent(self, event):
[10812]62        if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent):
63            return
64
[10799]65        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
66            pos = self.mapToScene(event.pos())
67            x, _ = self.map_from_graph(pos)
68            xmin, xmax = self.x_scale()
[10805]69            if x >= xmin - 0.5 and x <= xmax + 0.5:
70                x = min(max(x, xmin), xmax)
[10799]71                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
72                self.emit_cutoff_moved(x)
[10848]73        elif self.is_cutoff_enabled() and \
74                self.is_pos_over_cutoff_line(event.pos()):
75            self.setCursor(Qt.SizeHorCursor)
76        else:
77            self.setCursor(Qt.ArrowCursor)
78
[10799]79        return QGraphicsView.mouseMoveEvent(self, event)
80
81    def mouseReleaseEvene(self, event):
82        return QGraphicsView.mouseReleaseEvent(self, event)
83
84    def x_scale(self):
85        ax = self.axes[owaxis.xBottom]
86        if ax.labels:
[10805]87            return 0, len(ax.labels) - 1
[10799]88        elif ax.scale:
[10805]89            return ax.scale[0], ax.scale[1]
[10799]90        else:
91            raise ValueError
92
93    def emit_cutoff_moved(self, x):
94        self.emit(SIGNAL("cutoff_moved(double)"), x)
95
96    def set_axis_labels(self, *args):
97        OWPlot.set_axis_labels(self, *args)
98        self.map_transform = self.transform_for_axes()
99
[10848]100    def is_pos_over_cutoff_line(self, pos):
101        x1 = self.inv_transform(owaxis.xBottom, pos.x() - 1.5)
102        x2 = self.inv_transform(owaxis.xBottom, pos.x() + 1.5)
103        y = self.inv_transform(owaxis.yLeft, pos.y())
104        if y < 0.0 or y > 1.0:
105            return False
106        curve_data = self.cutoff_curve.data()
107        if not curve_data:
108            return False
109        cutoff = curve_data[0][0]
110        return x1 < cutoff and cutoff < x2
[10832]111
[11454]112
[10799]113class CutoffCurve(OWCurve):
114    def __init__(self, *args, **kwargs):
115        OWCurve.__init__(self, *args, **kwargs)
116        self.setAcceptHoverEvents(True)
117        self.setCursor(Qt.SizeHorCursor)
[10798]118
[10832]119
[10798]120class OWPCA(OWWidget):
121    settingsList = ["standardize", "max_components", "variance_covered",
[10807]122                    "use_generalized_eigenvectors", "auto_commit"]
[10832]123
[10798]124    def __init__(self, parent=None, signalManager=None, title="PCA"):
[10808]125        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
[10798]126
[10799]127        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
128        self.outputs = [("Transformed Data", Orange.data.Table, Default),
129                        ("Eigen Vectors", Orange.data.Table)]
[10798]130
131        self.standardize = True
132        self.max_components = 0
133        self.variance_covered = 100.0
134        self.use_generalized_eigenvectors = False
[10807]135        self.auto_commit = False
[10798]136
137        self.loadSettings()
138
139        self.data = None
[10807]140        self.changed_flag = False
[10798]141
142        #####
143        # GUI
144        #####
145        grid = QGridLayout()
[10806]146        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
[10798]147                              orientation=grid)
148
149        label1 = QLabel("Max components", box)
150        grid.addWidget(label1, 1, 0)
151
[10832]152        sb1 = OWGUI.spin(box, self, "max_components", 0, 1000,
[10798]153                         tooltip="Maximum number of components",
[10799]154                         callback=self.on_update,
[10798]155                         addToLayout=False,
156                         keyboardTracking=False
157                         )
[10799]158        self.max_components_spin = sb1.control
[10832]159        self.max_components_spin.setSpecialValueText("All")
[10798]160        grid.addWidget(sb1.control, 1, 1)
161
162        label2 = QLabel("Variance covered", box)
163        grid.addWidget(label2, 2, 0)
164
[10799]165        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
[10798]166                               tooltip="Percent of variance covered.",
[10799]167                               callback=self.on_update,
[10798]168                               decimals=1,
169                               addToLayout=False,
170                               keyboardTracking=False
171                               )
172        sb2.control.setSuffix("%")
173        grid.addWidget(sb2.control, 2, 1)
174
175        OWGUI.rubber(self.controlArea)
176
[10807]177        box = OWGUI.widgetBox(self.controlArea, "Commit")
178        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
179        b = OWGUI.button(box, self, "Commit",
180                         callback=self.update_components)
181        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
182
[10798]183        self.scree_plot = ScreePlot(self)
184#        self.scree_plot.set_main_title("Scree Plot")
185#        self.scree_plot.set_show_main_title(True)
186        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
187        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
188        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
189        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
[10799]190
191        self.variance_curve = self.scree_plot.add_curve(
192                        "Variance",
[10832]193                        Qt.red, Qt.red, 2,
[10799]194                        xData=[],
195                        yData=[],
196                        style=OWCurve.Lines,
197                        enableLegend=True,
198                        lineWidth=2,
199                        autoScale=1,
200                        x_axis_key=owaxis.xBottom,
201                        y_axis_key=owaxis.yLeft,
202                        )
203
204        self.cumulative_variance_curve = self.scree_plot.add_curve(
205                        "Cumulative Variance",
[10832]206                        Qt.darkYellow, Qt.darkYellow, 2,
[10799]207                        xData=[],
208                        yData=[],
209                        style=OWCurve.Lines,
210                        enableLegend=True,
211                        lineWidth=2,
212                        autoScale=1,
213                        x_axis_key=owaxis.xBottom,
214                        y_axis_key=owaxis.yLeft,
215                        )
216
[10798]217        self.mainArea.layout().addWidget(self.scree_plot)
[10799]218        self.connect(self.scree_plot,
219                     SIGNAL("cutoff_moved(double)"),
220                     self.on_cutoff_moved
221                     )
[10808]222
223        self.connect(self.graphButton,
224                     SIGNAL("clicked()"),
225                     self.scree_plot.save_to_file)
226
[10798]227        self.components = None
228        self.variances = None
229        self.variances_sum = None
230        self.projector_full = None
[10806]231        self.currently_selected = 0
[10798]232
[10799]233        self.resize(800, 400)
[10798]234
235    def clear(self):
236        """Clear widget state
237        """
238        self.data = None
[10799]239        self.scree_plot.set_cutoff_curve_enabled(False)
[10798]240        self.clear_cached()
[10799]241        self.variance_curve.setVisible(False)
242        self.cumulative_variance_curve.setVisible(False)
243
[10798]244    def clear_cached(self):
245        """Clear cached components
246        """
247        self.components = None
248        self.variances = None
[10799]249        self.variances_cumsum = None
[10798]250        self.projector_full = None
[10806]251        self.currently_selected = 0
[10798]252
253    def set_data(self, data=None):
254        """Set the widget input data.
255        """
256        self.clear()
257        if data is not None:
258            self.data = data
259            self.on_change()
[10806]260        else:
261            self.send("Transformed Data", None)
262            self.send("Eigen Vectors", None)
[10798]263
264    def on_change(self):
[10806]265        """Data has changed and we need to recompute the projection.
266        """
[10798]267        if self.data is None:
268            return
269        self.clear_cached()
270        self.apply()
271
272    def on_update(self):
[10806]273        """Component selection was changed by the user.
274        """
[10798]275        if self.data is None:
276            return
[10799]277        self.update_cutoff_curve()
[10806]278        if self.currently_selected != self.number_of_selected_components():
[10807]279            self.update_components_if()
[10798]280
281    def construct_pca_all_comp(self):
282        pca = plinear.PCA(standardize=self.standardize,
283                          max_components=0,
284                          variance_covered=1,
285                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
286                          )
287        return pca
288
289    def construct_pca(self):
[10799]290        max_components = self.max_components
291        variance_covered = self.variance_covered
[10798]292        pca = plinear.PCA(standardize=self.standardize,
293                          max_components=max_components,
294                          variance_covered=variance_covered / 100.0,
295                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
296                          )
297        return pca
298
299    def apply(self):
[11454]300        """
301        Apply PCA on input data, caching the full projection and
302        updating the selected components.
303
[10798]304        """
305        pca = self.construct_pca_all_comp()
306        self.projector_full = projector = pca(self.data)
[10799]307
308        self.variances = self.projector_full.variances
309        self.variances /= np.sum(self.variances)
310        self.variances_cumsum = np.cumsum(self.variances)
311
[10832]312        self.max_components_spin.setRange(0, len(self.variances))
[10833]313        self.max_components = min(self.max_components,
314                                  len(self.variances) - 1)
[10798]315        self.update_scree_plot()
[10806]316        self.update_cutoff_curve()
[10807]317        self.update_components_if()
[10798]318
[10807]319    def update_components_if(self):
320        if self.auto_commit:
321            self.update_components()
322        else:
323            self.changed_flag = True
[10832]324
[10798]325    def update_components(self):
[10806]326        """Update the output components.
327        """
[10807]328        if self.data is None:
[10832]329            return
[10807]330
[10798]331        scale = self.projector_full.scale
332        center = self.projector_full.center
333        components = self.projector_full.projection
334        input_domain = self.projector_full.input_domain
335        variances = self.projector_full.variances
336        variance_sum = self.projector_full.variance_sum
337
[10832]338        # Get selected components (based on max_components and
[10798]339        # variance_coverd)
340        pca = self.construct_pca()
341        variances, components, variance_sum = pca._select_components(variances, components)
342
343        projector = plinear.PcaProjector(input_domain=input_domain,
344                                         standardize=self.standardize,
345                                         scale=scale,
346                                         center=center,
347                                         projection=components,
348                                         variances=variances,
349                                         variance_sum=variance_sum)
350        projected_data = projector(self.data)
[11454]351
352        append_metas(projected_data, self.data)
353
[10798]354        eigenvectors = self.eigenvectors_as_table(components)
355
[10806]356        self.currently_selected = self.number_of_selected_components()
357
[10799]358        self.send("Transformed Data", projected_data)
359        self.send("Eigen Vectors", eigenvectors)
[10798]360
[10807]361        self.changed_flag = False
362
[10798]363    def eigenvectors_as_table(self, U):
364        features = [Orange.feature.Continuous("C%i" % i) \
365                    for i in range(1, U.shape[1] + 1)]
366        domain = Orange.data.Domain(features, False)
367        return Orange.data.Table(domain, [list(v) for v in U])
368
369    def update_scree_plot(self):
[10806]370        x_space = np.arange(0, len(self.variances))
[10798]371        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
372        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
[10832]373        self.scree_plot.set_axis_labels(owaxis.xBottom,
[10798]374                                        ["PC" + str(i + 1) for i in x_space])
375
[10806]376        self.variance_curve.set_data(x_space, self.variances)
377        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
[10799]378        self.variance_curve.setVisible(True)
379        self.cumulative_variance_curve.setVisible(True)
380
381        self.scree_plot.set_cutoff_curve_enabled(True)
[10809]382        self.scree_plot.replot()
[10799]383
384    def on_cutoff_moved(self, value):
[10806]385        """Cutoff curve was moved by the user.
386        """
[10799]387        components = int(np.floor(value)) + 1
[10806]388        # Did the number of components actually change
389        self.max_components = components
390        self.variance_covered = self.variances_cumsum[components - 1] * 100
391        if self.currently_selected != self.number_of_selected_components():
[10807]392            self.update_components_if()
[10799]393
394    def update_cutoff_curve(self):
[10806]395        """Update cutoff curve from 'Components Selection' control box.
[10799]396        """
[10832]397        if self.max_components == 0:
398            # Special "All" value
399            max_components = len(self.variances_cumsum)
400        else:
401            max_components = self.max_components
402
403        variance = self.variances_cumsum[max_components - 1] * 100.0
[10799]404        if variance < self.variance_covered:
[10832]405            cutoff = float(max_components - 1)
[10799]406        else:
407            cutoff = np.searchsorted(self.variances_cumsum,
408                                     self.variance_covered / 100.0)
409        self.scree_plot.set_cutoff_value(cutoff + 0.5)
410
[10806]411    def number_of_selected_components(self):
412        """How many components are selected.
413        """
414        if self.data is None:
415            return 0
416
417        variance_components = np.searchsorted(self.variances_cumsum,
418                                              self.variance_covered / 100.0)
[10832]419        if self.max_components == 0:
420            # Special "All" value
421            max_components = len(self.variances_cumsum)
422        else:
423            max_components = self.max_components
424        return min(variance_components + 1, max_components)
[10798]425
[10808]426    def sendReport(self):
427        self.reportSettings("PCA Settings",
428                            [("Max. components", self.max_components),
429                             ("Variance covered", "%i%%" % self.variance_covered),
430                             ])
431        if self.data is not None and self.projector_full:
432            output_domain = self.projector_full.output_domain
433            st_dev = np.sqrt(self.projector_full.variances)
434            summary = [[""] + [a.name for a in output_domain.attributes],
435                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
436                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
437                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
438                       ]
439
440            th = "<th>%s</th>".__mod__
441            header = "".join(map(th, summary[0]))
442            td = "<td>%s</td>".__mod__
443            summary = ["".join(map(td, row)) for row in summary[1:]]
444            tr = "<tr>%s</tr>".__mod__
445            summary = "\n".join(map(tr, [header] + summary))
446            summary = "<table>\n%s\n</table>" % summary
447
448            self.reportSection("Summary")
449            self.reportRaw(summary)
450
451            self.reportSection("Scree Plot")
452            self.reportImage(self.scree_plot.save_to_file_direct)
453
[10832]454
[11454]455def append_metas(dest, source):
456    """
457    Append all meta attributes from the `source` table to `dest` table.
458    The tables must be of the same length.
459
460    :param dest:
461        An data table into which the meta values will be copied.
462    :type dest: :class:`Orange.data.Table`
463
464    :param source:
465        A data table with the meta attributes/values to be copied into `dest`.
466    :type source: :class:`Orange.data.Table`
467
468    """
469    if len(dest) != len(source):
470        raise ValueError("'dest' and 'source' must have the same length.")
471
472    dest.domain.add_metas(source.domain.get_metas())
473    for dest_inst, source_inst in zip(dest, source):
474        for meta_id, val in source_inst.get_metas().items():
475            dest_inst[meta_id] = val
476
477
[10798]478if __name__ == "__main__":
479    app = QApplication(sys.argv)
480    w = OWPCA()
481    data = Orange.data.Table("iris")
482    w.set_data(data)
483    w.show()
484    app.exec_()
Note: See TracBrowser for help on using the repository browser.