source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10833:e5bf4036efba

Revision 10833:e5bf4036efba, 15.7 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fix the value of 'max_components' when the new dataset has less features.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icons>icons/PCA.png</icons>
6<tags>pca,principal,component,projection</tags>
7
8"""
9import Orange
10import Orange.utils.addons
11
12from OWWidget import *
13import OWGUI
14
15import Orange
16import Orange.projection.linear as plinear
17
18import numpy as np
19import sys
20
21from plot.owplot import OWPlot
22from plot.owcurve import OWCurve
23from plot import owaxis
24
25
26class ScreePlot(OWPlot):
27    def __init__(self, parent=None, name="Scree Plot"):
28        OWPlot.__init__(self, parent, name=name)
29        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
30                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
31        self.cutoff_curve.setVisible(False)
32        self.cutoff_curve.set_style(OWCurve.Lines)
33        self.add_custom_curve(self.cutoff_curve)
34
35    def is_cutoff_enabled(self):
36        return self.cutoff_curve and self.cutoff_curve.isVisible()
37
38    def set_cutoff_curve_enabled(self, state):
39        self.cutoff_curve.setVisible(state)
40
41    def set_cutoff_value(self, value):
42        xmin, xmax = self.x_scale()
43        x = min(max(value, xmin), xmax)
44        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
45
46    def mousePressEvent(self, event):
47        if self.isLegendEvent(event, QGraphicsView.mousePressEvent):
48            return
49
50        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
51            pos = self.mapToScene(event.pos())
52            x, _ = self.map_from_graph(pos)
53            xmin, xmax = self.x_scale()
54            if x >= xmin - 0.1 and x <= xmax + 0.1:
55                x = min(max(x, xmin), xmax)
56                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
57                self.emit_cutoff_moved(x)
58        return QGraphicsView.mousePressEvent(self, event)
59
60    def mouseMoveEvent(self, event):
61        if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent):
62            return
63
64        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
65            pos = self.mapToScene(event.pos())
66            x, _ = self.map_from_graph(pos)
67            xmin, xmax = self.x_scale()
68            if x >= xmin - 0.5 and x <= xmax + 0.5:
69                x = min(max(x, xmin), xmax)
70                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
71                self.emit_cutoff_moved(x)
72        return QGraphicsView.mouseMoveEvent(self, event)
73
74    def mouseReleaseEvene(self, event):
75        return QGraphicsView.mouseReleaseEvent(self, event)
76
77    def x_scale(self):
78        ax = self.axes[owaxis.xBottom]
79        if ax.labels:
80            return 0, len(ax.labels) - 1
81        elif ax.scale:
82            return ax.scale[0], ax.scale[1]
83        else:
84            raise ValueError
85
86    def emit_cutoff_moved(self, x):
87        self.emit(SIGNAL("cutoff_moved(double)"), x)
88
89    def set_axis_labels(self, *args):
90        OWPlot.set_axis_labels(self, *args)
91        self.map_transform = self.transform_for_axes()
92
93
94class CutoffCurve(OWCurve):
95    def __init__(self, *args, **kwargs):
96        OWCurve.__init__(self, *args, **kwargs)
97        self.setAcceptHoverEvents(True)
98        self.setCursor(Qt.SizeHorCursor)
99
100
101class OWPCA(OWWidget):
102    settingsList = ["standardize", "max_components", "variance_covered",
103                    "use_generalized_eigenvectors", "auto_commit"]
104
105    def __init__(self, parent=None, signalManager=None, title="PCA"):
106        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
107
108        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
109        self.outputs = [("Transformed Data", Orange.data.Table, Default),
110                        ("Eigen Vectors", Orange.data.Table)]
111
112        self.standardize = True
113        self.max_components = 0
114        self.variance_covered = 100.0
115        self.use_generalized_eigenvectors = False
116        self.auto_commit = False
117
118        self.loadSettings()
119
120        self.data = None
121        self.changed_flag = False
122
123        #####
124        # GUI
125        #####
126        grid = QGridLayout()
127        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
128                              orientation=grid)
129
130        label1 = QLabel("Max components", box)
131        grid.addWidget(label1, 1, 0)
132
133        sb1 = OWGUI.spin(box, self, "max_components", 0, 1000,
134                         tooltip="Maximum number of components",
135                         callback=self.on_update,
136                         addToLayout=False,
137                         keyboardTracking=False
138                         )
139        self.max_components_spin = sb1.control
140        self.max_components_spin.setSpecialValueText("All")
141        grid.addWidget(sb1.control, 1, 1)
142
143        label2 = QLabel("Variance covered", box)
144        grid.addWidget(label2, 2, 0)
145
146        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
147                               tooltip="Percent of variance covered.",
148                               callback=self.on_update,
149                               decimals=1,
150                               addToLayout=False,
151                               keyboardTracking=False
152                               )
153        sb2.control.setSuffix("%")
154        grid.addWidget(sb2.control, 2, 1)
155
156        OWGUI.rubber(self.controlArea)
157
158        box = OWGUI.widgetBox(self.controlArea, "Commit")
159        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
160        b = OWGUI.button(box, self, "Commit",
161                         callback=self.update_components)
162        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
163
164        self.scree_plot = ScreePlot(self)
165#        self.scree_plot.set_main_title("Scree Plot")
166#        self.scree_plot.set_show_main_title(True)
167        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
168        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
169        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
170        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
171
172        self.variance_curve = self.scree_plot.add_curve(
173                        "Variance",
174                        Qt.red, Qt.red, 2,
175                        xData=[],
176                        yData=[],
177                        style=OWCurve.Lines,
178                        enableLegend=True,
179                        lineWidth=2,
180                        autoScale=1,
181                        x_axis_key=owaxis.xBottom,
182                        y_axis_key=owaxis.yLeft,
183                        )
184
185        self.cumulative_variance_curve = self.scree_plot.add_curve(
186                        "Cumulative Variance",
187                        Qt.darkYellow, Qt.darkYellow, 2,
188                        xData=[],
189                        yData=[],
190                        style=OWCurve.Lines,
191                        enableLegend=True,
192                        lineWidth=2,
193                        autoScale=1,
194                        x_axis_key=owaxis.xBottom,
195                        y_axis_key=owaxis.yLeft,
196                        )
197
198        self.mainArea.layout().addWidget(self.scree_plot)
199        self.connect(self.scree_plot,
200                     SIGNAL("cutoff_moved(double)"),
201                     self.on_cutoff_moved
202                     )
203
204        self.connect(self.graphButton,
205                     SIGNAL("clicked()"),
206                     self.scree_plot.save_to_file)
207
208        self.components = None
209        self.variances = None
210        self.variances_sum = None
211        self.projector_full = None
212        self.currently_selected = 0
213
214        self.resize(800, 400)
215
216    def clear(self):
217        """Clear widget state
218        """
219        self.data = None
220        self.scree_plot.set_cutoff_curve_enabled(False)
221        self.clear_cached()
222        self.variance_curve.setVisible(False)
223        self.cumulative_variance_curve.setVisible(False)
224
225    def clear_cached(self):
226        """Clear cached components
227        """
228        self.components = None
229        self.variances = None
230        self.variances_cumsum = None
231        self.projector_full = None
232        self.currently_selected = 0
233
234    def set_data(self, data=None):
235        """Set the widget input data.
236        """
237        self.clear()
238        if data is not None:
239            self.data = data
240            self.on_change()
241        else:
242            self.send("Transformed Data", None)
243            self.send("Eigen Vectors", None)
244
245    def on_change(self):
246        """Data has changed and we need to recompute the projection.
247        """
248        if self.data is None:
249            return
250        self.clear_cached()
251        self.apply()
252
253    def on_update(self):
254        """Component selection was changed by the user.
255        """
256        if self.data is None:
257            return
258        self.update_cutoff_curve()
259        if self.currently_selected != self.number_of_selected_components():
260            self.update_components_if()
261
262    def construct_pca_all_comp(self):
263        pca = plinear.PCA(standardize=self.standardize,
264                          max_components=0,
265                          variance_covered=1,
266                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
267                          )
268        return pca
269
270    def construct_pca(self):
271        max_components = self.max_components
272        variance_covered = self.variance_covered
273        pca = plinear.PCA(standardize=self.standardize,
274                          max_components=max_components,
275                          variance_covered=variance_covered / 100.0,
276                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
277                          )
278        return pca
279
280    def apply(self):
281        """Apply PCA on input data, caching the full projection,
282        then updating the selected components.
283       
284        """
285        pca = self.construct_pca_all_comp()
286        self.projector_full = projector = pca(self.data)
287
288        self.variances = self.projector_full.variances
289        self.variances /= np.sum(self.variances)
290        self.variances_cumsum = np.cumsum(self.variances)
291
292        self.max_components_spin.setRange(0, len(self.variances))
293        self.max_components = min(self.max_components,
294                                  len(self.variances) - 1)
295        self.update_scree_plot()
296        self.update_cutoff_curve()
297        self.update_components_if()
298
299    def update_components_if(self):
300        if self.auto_commit:
301            self.update_components()
302        else:
303            self.changed_flag = True
304
305    def update_components(self):
306        """Update the output components.
307        """
308        if self.data is None:
309            return
310
311        scale = self.projector_full.scale
312        center = self.projector_full.center
313        components = self.projector_full.projection
314        input_domain = self.projector_full.input_domain
315        variances = self.projector_full.variances
316        variance_sum = self.projector_full.variance_sum
317
318        # Get selected components (based on max_components and
319        # variance_coverd)
320        pca = self.construct_pca()
321        variances, components, variance_sum = pca._select_components(variances, components)
322
323        projector = plinear.PcaProjector(input_domain=input_domain,
324                                         standardize=self.standardize,
325                                         scale=scale,
326                                         center=center,
327                                         projection=components,
328                                         variances=variances,
329                                         variance_sum=variance_sum)
330        projected_data = projector(self.data)
331        eigenvectors = self.eigenvectors_as_table(components)
332
333        self.currently_selected = self.number_of_selected_components()
334
335        self.send("Transformed Data", projected_data)
336        self.send("Eigen Vectors", eigenvectors)
337
338        self.changed_flag = False
339
340    def eigenvectors_as_table(self, U):
341        features = [Orange.feature.Continuous("C%i" % i) \
342                    for i in range(1, U.shape[1] + 1)]
343        domain = Orange.data.Domain(features, False)
344        return Orange.data.Table(domain, [list(v) for v in U])
345
346    def update_scree_plot(self):
347        x_space = np.arange(0, len(self.variances))
348        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
349        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
350        self.scree_plot.set_axis_labels(owaxis.xBottom,
351                                        ["PC" + str(i + 1) for i in x_space])
352
353        self.variance_curve.set_data(x_space, self.variances)
354        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
355        self.variance_curve.setVisible(True)
356        self.cumulative_variance_curve.setVisible(True)
357
358        self.scree_plot.set_cutoff_curve_enabled(True)
359        self.scree_plot.replot()
360
361    def on_cutoff_moved(self, value):
362        """Cutoff curve was moved by the user.
363        """
364        components = int(np.floor(value)) + 1
365        # Did the number of components actually change
366        self.max_components = components
367        self.variance_covered = self.variances_cumsum[components - 1] * 100
368        if self.currently_selected != self.number_of_selected_components():
369            self.update_components_if()
370
371    def update_cutoff_curve(self):
372        """Update cutoff curve from 'Components Selection' control box.
373        """
374        if self.max_components == 0:
375            # Special "All" value
376            max_components = len(self.variances_cumsum)
377        else:
378            max_components = self.max_components
379
380        variance = self.variances_cumsum[max_components - 1] * 100.0
381        if variance < self.variance_covered:
382            cutoff = float(max_components - 1)
383        else:
384            cutoff = np.searchsorted(self.variances_cumsum,
385                                     self.variance_covered / 100.0)
386        self.scree_plot.set_cutoff_value(cutoff + 0.5)
387
388    def number_of_selected_components(self):
389        """How many components are selected.
390        """
391        if self.data is None:
392            return 0
393
394        variance_components = np.searchsorted(self.variances_cumsum,
395                                              self.variance_covered / 100.0)
396        if self.max_components == 0:
397            # Special "All" value
398            max_components = len(self.variances_cumsum)
399        else:
400            max_components = self.max_components
401        return min(variance_components + 1, max_components)
402
403    def sendReport(self):
404        self.reportSettings("PCA Settings",
405                            [("Max. components", self.max_components),
406                             ("Variance covered", "%i%%" % self.variance_covered),
407                             ])
408        if self.data is not None and self.projector_full:
409            output_domain = self.projector_full.output_domain
410            st_dev = np.sqrt(self.projector_full.variances)
411            summary = [[""] + [a.name for a in output_domain.attributes],
412                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
413                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
414                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
415                       ]
416
417            th = "<th>%s</th>".__mod__
418            header = "".join(map(th, summary[0]))
419            td = "<td>%s</td>".__mod__
420            summary = ["".join(map(td, row)) for row in summary[1:]]
421            tr = "<tr>%s</tr>".__mod__
422            summary = "\n".join(map(tr, [header] + summary))
423            summary = "<table>\n%s\n</table>" % summary
424
425            self.reportSection("Summary")
426            self.reportRaw(summary)
427
428            self.reportSection("Scree Plot")
429            self.reportImage(self.scree_plot.save_to_file_direct)
430
431
432if __name__ == "__main__":
433    app = QApplication(sys.argv)
434    w = OWPCA()
435    data = Orange.data.Table("iris")
436    w.set_data(data)
437    w.show()
438    app.exec_()
Note: See TracBrowser for help on using the repository browser.