source: orange/Orange/OrangeWidgets/Unsupervised/OWPCA.py @ 11454:7121b9a53a58

Revision 11454:7121b9a53a58, 17.2 KB checked in by Ales Erjavec <ales.erjavec@…>, 12 months ago (diff)

Preserve meta attributes in the 'Transformed Data' output.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icon>icons/PCA.svg</icon>
6<tags>pca,principal,component,projection</tags>
7<priority>3050</priority>
8
9"""
10import Orange
11import Orange.utils.addons
12
13from OWWidget import *
14import OWGUI
15
16import Orange
17import Orange.projection.linear as plinear
18
19import numpy as np
20import sys
21
22from plot.owplot import OWPlot
23from plot.owcurve import OWCurve
24from plot import owaxis
25
26
27class ScreePlot(OWPlot):
28    def __init__(self, parent=None, name="Scree Plot"):
29        OWPlot.__init__(self, parent, name=name)
30        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
31                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
32        self.cutoff_curve.setVisible(False)
33        self.cutoff_curve.set_style(OWCurve.Lines)
34        self.add_custom_curve(self.cutoff_curve)
35
36    def is_cutoff_enabled(self):
37        return self.cutoff_curve and self.cutoff_curve.isVisible()
38
39    def set_cutoff_curve_enabled(self, state):
40        self.cutoff_curve.setVisible(state)
41
42    def set_cutoff_value(self, value):
43        xmin, xmax = self.x_scale()
44        x = min(max(value, xmin), xmax)
45        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
46
47    def mousePressEvent(self, event):
48        if self.isLegendEvent(event, QGraphicsView.mousePressEvent):
49            return
50
51        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
52            pos = self.mapToScene(event.pos())
53            x, _ = self.map_from_graph(pos)
54            xmin, xmax = self.x_scale()
55            if x >= xmin - 0.1 and x <= xmax + 0.1:
56                x = min(max(x, xmin), xmax)
57                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
58                self.emit_cutoff_moved(x)
59        return QGraphicsView.mousePressEvent(self, event)
60
61    def mouseMoveEvent(self, event):
62        if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent):
63            return
64
65        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
66            pos = self.mapToScene(event.pos())
67            x, _ = self.map_from_graph(pos)
68            xmin, xmax = self.x_scale()
69            if x >= xmin - 0.5 and x <= xmax + 0.5:
70                x = min(max(x, xmin), xmax)
71                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
72                self.emit_cutoff_moved(x)
73        elif self.is_cutoff_enabled() and \
74                self.is_pos_over_cutoff_line(event.pos()):
75            self.setCursor(Qt.SizeHorCursor)
76        else:
77            self.setCursor(Qt.ArrowCursor)
78
79        return QGraphicsView.mouseMoveEvent(self, event)
80
81    def mouseReleaseEvene(self, event):
82        return QGraphicsView.mouseReleaseEvent(self, event)
83
84    def x_scale(self):
85        ax = self.axes[owaxis.xBottom]
86        if ax.labels:
87            return 0, len(ax.labels) - 1
88        elif ax.scale:
89            return ax.scale[0], ax.scale[1]
90        else:
91            raise ValueError
92
93    def emit_cutoff_moved(self, x):
94        self.emit(SIGNAL("cutoff_moved(double)"), x)
95
96    def set_axis_labels(self, *args):
97        OWPlot.set_axis_labels(self, *args)
98        self.map_transform = self.transform_for_axes()
99
100    def is_pos_over_cutoff_line(self, pos):
101        x1 = self.inv_transform(owaxis.xBottom, pos.x() - 1.5)
102        x2 = self.inv_transform(owaxis.xBottom, pos.x() + 1.5)
103        y = self.inv_transform(owaxis.yLeft, pos.y())
104        if y < 0.0 or y > 1.0:
105            return False
106        curve_data = self.cutoff_curve.data()
107        if not curve_data:
108            return False
109        cutoff = curve_data[0][0]
110        return x1 < cutoff and cutoff < x2
111
112
113class CutoffCurve(OWCurve):
114    def __init__(self, *args, **kwargs):
115        OWCurve.__init__(self, *args, **kwargs)
116        self.setAcceptHoverEvents(True)
117        self.setCursor(Qt.SizeHorCursor)
118
119
120class OWPCA(OWWidget):
121    settingsList = ["standardize", "max_components", "variance_covered",
122                    "use_generalized_eigenvectors", "auto_commit"]
123
124    def __init__(self, parent=None, signalManager=None, title="PCA"):
125        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
126
127        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
128        self.outputs = [("Transformed Data", Orange.data.Table, Default),
129                        ("Eigen Vectors", Orange.data.Table)]
130
131        self.standardize = True
132        self.max_components = 0
133        self.variance_covered = 100.0
134        self.use_generalized_eigenvectors = False
135        self.auto_commit = False
136
137        self.loadSettings()
138
139        self.data = None
140        self.changed_flag = False
141
142        #####
143        # GUI
144        #####
145        grid = QGridLayout()
146        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
147                              orientation=grid)
148
149        label1 = QLabel("Max components", box)
150        grid.addWidget(label1, 1, 0)
151
152        sb1 = OWGUI.spin(box, self, "max_components", 0, 1000,
153                         tooltip="Maximum number of components",
154                         callback=self.on_update,
155                         addToLayout=False,
156                         keyboardTracking=False
157                         )
158        self.max_components_spin = sb1.control
159        self.max_components_spin.setSpecialValueText("All")
160        grid.addWidget(sb1.control, 1, 1)
161
162        label2 = QLabel("Variance covered", box)
163        grid.addWidget(label2, 2, 0)
164
165        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
166                               tooltip="Percent of variance covered.",
167                               callback=self.on_update,
168                               decimals=1,
169                               addToLayout=False,
170                               keyboardTracking=False
171                               )
172        sb2.control.setSuffix("%")
173        grid.addWidget(sb2.control, 2, 1)
174
175        OWGUI.rubber(self.controlArea)
176
177        box = OWGUI.widgetBox(self.controlArea, "Commit")
178        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
179        b = OWGUI.button(box, self, "Commit",
180                         callback=self.update_components)
181        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
182
183        self.scree_plot = ScreePlot(self)
184#        self.scree_plot.set_main_title("Scree Plot")
185#        self.scree_plot.set_show_main_title(True)
186        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
187        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
188        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
189        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
190
191        self.variance_curve = self.scree_plot.add_curve(
192                        "Variance",
193                        Qt.red, Qt.red, 2,
194                        xData=[],
195                        yData=[],
196                        style=OWCurve.Lines,
197                        enableLegend=True,
198                        lineWidth=2,
199                        autoScale=1,
200                        x_axis_key=owaxis.xBottom,
201                        y_axis_key=owaxis.yLeft,
202                        )
203
204        self.cumulative_variance_curve = self.scree_plot.add_curve(
205                        "Cumulative Variance",
206                        Qt.darkYellow, Qt.darkYellow, 2,
207                        xData=[],
208                        yData=[],
209                        style=OWCurve.Lines,
210                        enableLegend=True,
211                        lineWidth=2,
212                        autoScale=1,
213                        x_axis_key=owaxis.xBottom,
214                        y_axis_key=owaxis.yLeft,
215                        )
216
217        self.mainArea.layout().addWidget(self.scree_plot)
218        self.connect(self.scree_plot,
219                     SIGNAL("cutoff_moved(double)"),
220                     self.on_cutoff_moved
221                     )
222
223        self.connect(self.graphButton,
224                     SIGNAL("clicked()"),
225                     self.scree_plot.save_to_file)
226
227        self.components = None
228        self.variances = None
229        self.variances_sum = None
230        self.projector_full = None
231        self.currently_selected = 0
232
233        self.resize(800, 400)
234
235    def clear(self):
236        """Clear widget state
237        """
238        self.data = None
239        self.scree_plot.set_cutoff_curve_enabled(False)
240        self.clear_cached()
241        self.variance_curve.setVisible(False)
242        self.cumulative_variance_curve.setVisible(False)
243
244    def clear_cached(self):
245        """Clear cached components
246        """
247        self.components = None
248        self.variances = None
249        self.variances_cumsum = None
250        self.projector_full = None
251        self.currently_selected = 0
252
253    def set_data(self, data=None):
254        """Set the widget input data.
255        """
256        self.clear()
257        if data is not None:
258            self.data = data
259            self.on_change()
260        else:
261            self.send("Transformed Data", None)
262            self.send("Eigen Vectors", None)
263
264    def on_change(self):
265        """Data has changed and we need to recompute the projection.
266        """
267        if self.data is None:
268            return
269        self.clear_cached()
270        self.apply()
271
272    def on_update(self):
273        """Component selection was changed by the user.
274        """
275        if self.data is None:
276            return
277        self.update_cutoff_curve()
278        if self.currently_selected != self.number_of_selected_components():
279            self.update_components_if()
280
281    def construct_pca_all_comp(self):
282        pca = plinear.PCA(standardize=self.standardize,
283                          max_components=0,
284                          variance_covered=1,
285                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
286                          )
287        return pca
288
289    def construct_pca(self):
290        max_components = self.max_components
291        variance_covered = self.variance_covered
292        pca = plinear.PCA(standardize=self.standardize,
293                          max_components=max_components,
294                          variance_covered=variance_covered / 100.0,
295                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
296                          )
297        return pca
298
299    def apply(self):
300        """
301        Apply PCA on input data, caching the full projection and
302        updating the selected components.
303
304        """
305        pca = self.construct_pca_all_comp()
306        self.projector_full = projector = pca(self.data)
307
308        self.variances = self.projector_full.variances
309        self.variances /= np.sum(self.variances)
310        self.variances_cumsum = np.cumsum(self.variances)
311
312        self.max_components_spin.setRange(0, len(self.variances))
313        self.max_components = min(self.max_components,
314                                  len(self.variances) - 1)
315        self.update_scree_plot()
316        self.update_cutoff_curve()
317        self.update_components_if()
318
319    def update_components_if(self):
320        if self.auto_commit:
321            self.update_components()
322        else:
323            self.changed_flag = True
324
325    def update_components(self):
326        """Update the output components.
327        """
328        if self.data is None:
329            return
330
331        scale = self.projector_full.scale
332        center = self.projector_full.center
333        components = self.projector_full.projection
334        input_domain = self.projector_full.input_domain
335        variances = self.projector_full.variances
336        variance_sum = self.projector_full.variance_sum
337
338        # Get selected components (based on max_components and
339        # variance_coverd)
340        pca = self.construct_pca()
341        variances, components, variance_sum = pca._select_components(variances, components)
342
343        projector = plinear.PcaProjector(input_domain=input_domain,
344                                         standardize=self.standardize,
345                                         scale=scale,
346                                         center=center,
347                                         projection=components,
348                                         variances=variances,
349                                         variance_sum=variance_sum)
350        projected_data = projector(self.data)
351
352        append_metas(projected_data, self.data)
353
354        eigenvectors = self.eigenvectors_as_table(components)
355
356        self.currently_selected = self.number_of_selected_components()
357
358        self.send("Transformed Data", projected_data)
359        self.send("Eigen Vectors", eigenvectors)
360
361        self.changed_flag = False
362
363    def eigenvectors_as_table(self, U):
364        features = [Orange.feature.Continuous("C%i" % i) \
365                    for i in range(1, U.shape[1] + 1)]
366        domain = Orange.data.Domain(features, False)
367        return Orange.data.Table(domain, [list(v) for v in U])
368
369    def update_scree_plot(self):
370        x_space = np.arange(0, len(self.variances))
371        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
372        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
373        self.scree_plot.set_axis_labels(owaxis.xBottom,
374                                        ["PC" + str(i + 1) for i in x_space])
375
376        self.variance_curve.set_data(x_space, self.variances)
377        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
378        self.variance_curve.setVisible(True)
379        self.cumulative_variance_curve.setVisible(True)
380
381        self.scree_plot.set_cutoff_curve_enabled(True)
382        self.scree_plot.replot()
383
384    def on_cutoff_moved(self, value):
385        """Cutoff curve was moved by the user.
386        """
387        components = int(np.floor(value)) + 1
388        # Did the number of components actually change
389        self.max_components = components
390        self.variance_covered = self.variances_cumsum[components - 1] * 100
391        if self.currently_selected != self.number_of_selected_components():
392            self.update_components_if()
393
394    def update_cutoff_curve(self):
395        """Update cutoff curve from 'Components Selection' control box.
396        """
397        if self.max_components == 0:
398            # Special "All" value
399            max_components = len(self.variances_cumsum)
400        else:
401            max_components = self.max_components
402
403        variance = self.variances_cumsum[max_components - 1] * 100.0
404        if variance < self.variance_covered:
405            cutoff = float(max_components - 1)
406        else:
407            cutoff = np.searchsorted(self.variances_cumsum,
408                                     self.variance_covered / 100.0)
409        self.scree_plot.set_cutoff_value(cutoff + 0.5)
410
411    def number_of_selected_components(self):
412        """How many components are selected.
413        """
414        if self.data is None:
415            return 0
416
417        variance_components = np.searchsorted(self.variances_cumsum,
418                                              self.variance_covered / 100.0)
419        if self.max_components == 0:
420            # Special "All" value
421            max_components = len(self.variances_cumsum)
422        else:
423            max_components = self.max_components
424        return min(variance_components + 1, max_components)
425
426    def sendReport(self):
427        self.reportSettings("PCA Settings",
428                            [("Max. components", self.max_components),
429                             ("Variance covered", "%i%%" % self.variance_covered),
430                             ])
431        if self.data is not None and self.projector_full:
432            output_domain = self.projector_full.output_domain
433            st_dev = np.sqrt(self.projector_full.variances)
434            summary = [[""] + [a.name for a in output_domain.attributes],
435                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
436                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
437                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
438                       ]
439
440            th = "<th>%s</th>".__mod__
441            header = "".join(map(th, summary[0]))
442            td = "<td>%s</td>".__mod__
443            summary = ["".join(map(td, row)) for row in summary[1:]]
444            tr = "<tr>%s</tr>".__mod__
445            summary = "\n".join(map(tr, [header] + summary))
446            summary = "<table>\n%s\n</table>" % summary
447
448            self.reportSection("Summary")
449            self.reportRaw(summary)
450
451            self.reportSection("Scree Plot")
452            self.reportImage(self.scree_plot.save_to_file_direct)
453
454
455def append_metas(dest, source):
456    """
457    Append all meta attributes from the `source` table to `dest` table.
458    The tables must be of the same length.
459
460    :param dest:
461        An data table into which the meta values will be copied.
462    :type dest: :class:`Orange.data.Table`
463
464    :param source:
465        A data table with the meta attributes/values to be copied into `dest`.
466    :type source: :class:`Orange.data.Table`
467
468    """
469    if len(dest) != len(source):
470        raise ValueError("'dest' and 'source' must have the same length.")
471
472    dest.domain.add_metas(source.domain.get_metas())
473    for dest_inst, source_inst in zip(dest, source):
474        for meta_id, val in source_inst.get_metas().items():
475            dest_inst[meta_id] = val
476
477
478if __name__ == "__main__":
479    app = QApplication(sys.argv)
480    w = OWPCA()
481    data = Orange.data.Table("iris")
482    w.set_data(data)
483    w.show()
484    app.exec_()
Note: See TracBrowser for help on using the repository browser.