source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10808:984a9f693ff2

Revision 10808:984a9f693ff2, 15.0 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added 'Save Graph' and 'Report'.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icons>icons/PCA.png</icons>
6<tags>pca,principal,component,projection</tags>
7
8"""
9import Orange
10import Orange.utils.addons
11
12from OWWidget import *
13import OWGUI
14
15import Orange
16import Orange.projection.linear as plinear
17
18import numpy as np
19import sys
20
21from plot.owplot import OWPlot
22from plot.owcurve import OWCurve
23from plot import owaxis
24
25class ScreePlot(OWPlot):
26    def __init__(self, parent=None, name="Scree Plot"):
27        OWPlot.__init__(self, parent, name=name)
28        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
29                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
30        self.cutoff_curve.setVisible(False)
31        self.cutoff_curve.set_style(OWCurve.Lines)
32        self.add_custom_curve(self.cutoff_curve)
33
34    def is_cutoff_enabled(self):
35        return self.cutoff_curve and self.cutoff_curve.isVisible()
36
37    def set_cutoff_curve_enabled(self, state):
38        self.cutoff_curve.setVisible(state)
39
40    def set_cutoff_value(self, value):
41        xmin, xmax = self.x_scale()
42        x = min(max(value, xmin), xmax)
43        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
44
45    def mousePressEvent(self, event):
46        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
47            pos = self.mapToScene(event.pos())
48            x, _  = self.map_from_graph(pos)
49            xmin, xmax = self.x_scale()
50            if x >= xmin - 0.1 and x <= xmax + 0.1:
51                x = min(max(x, xmin), xmax)
52                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
53                self.emit_cutoff_moved(x)
54        return QGraphicsView.mousePressEvent(self, event)
55
56    def mouseMoveEvent(self, event):
57        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
58            pos = self.mapToScene(event.pos())
59            x, _ = self.map_from_graph(pos)
60            xmin, xmax = self.x_scale()
61            if x >= xmin - 0.5 and x <= xmax + 0.5:
62                x = min(max(x, xmin), xmax)
63                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
64                self.emit_cutoff_moved(x)
65        return QGraphicsView.mouseMoveEvent(self, event)
66
67    def mouseReleaseEvene(self, event):
68        return QGraphicsView.mouseReleaseEvent(self, event)
69
70    def x_scale(self):
71        ax = self.axes[owaxis.xBottom]
72        if ax.labels:
73            return 0, len(ax.labels) - 1
74        elif ax.scale:
75            return ax.scale[0], ax.scale[1]
76        else:
77            raise ValueError
78
79    def emit_cutoff_moved(self, x):
80        self.emit(SIGNAL("cutoff_moved(double)"), x)
81
82    def set_axis_labels(self, *args):
83        OWPlot.set_axis_labels(self, *args)
84        self.map_transform = self.transform_for_axes()
85
86class CutoffCurve(OWCurve):
87    def __init__(self, *args, **kwargs):
88        OWCurve.__init__(self, *args, **kwargs)
89        self.setAcceptHoverEvents(True)
90        self.setCursor(Qt.SizeHorCursor)
91
92class OWPCA(OWWidget):
93    settingsList = ["standardize", "max_components", "variance_covered",
94                    "use_generalized_eigenvectors", "auto_commit"]
95    def __init__(self, parent=None, signalManager=None, title="PCA"):
96        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
97
98        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
99        self.outputs = [("Transformed Data", Orange.data.Table, Default),
100                        ("Eigen Vectors", Orange.data.Table)]
101
102        self.standardize = True
103        self.max_components = 0
104        self.variance_covered = 100.0
105        self.use_generalized_eigenvectors = False
106        self.auto_commit = False
107
108        self.loadSettings()
109
110        self.data = None
111        self.changed_flag = False
112
113        #####
114        # GUI
115        #####
116        grid = QGridLayout()
117        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
118                              orientation=grid)
119
120        label1 = QLabel("Max components", box)
121        grid.addWidget(label1, 1, 0)
122
123        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000,
124                         tooltip="Maximum number of components",
125                         callback=self.on_update,
126                         addToLayout=False,
127                         keyboardTracking=False
128                         )
129        self.max_components_spin = sb1.control
130        grid.addWidget(sb1.control, 1, 1)
131
132        label2 = QLabel("Variance covered", box)
133        grid.addWidget(label2, 2, 0)
134
135        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
136                               tooltip="Percent of variance covered.",
137                               callback=self.on_update,
138                               decimals=1,
139                               addToLayout=False,
140                               keyboardTracking=False
141                               )
142        sb2.control.setSuffix("%")
143        grid.addWidget(sb2.control, 2, 1)
144
145        OWGUI.rubber(self.controlArea)
146
147        box = OWGUI.widgetBox(self.controlArea, "Commit")
148        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
149        b = OWGUI.button(box, self, "Commit",
150                         callback=self.update_components)
151        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
152
153        self.scree_plot = ScreePlot(self)
154#        self.scree_plot.set_main_title("Scree Plot")
155#        self.scree_plot.set_show_main_title(True)
156        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
157        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
158        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
159        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
160
161        self.variance_curve = self.scree_plot.add_curve(
162                        "Variance",
163                        Qt.red, Qt.red, 2, 
164                        xData=[],
165                        yData=[],
166                        style=OWCurve.Lines,
167                        enableLegend=True,
168                        lineWidth=2,
169                        autoScale=1,
170                        x_axis_key=owaxis.xBottom,
171                        y_axis_key=owaxis.yLeft,
172                        )
173
174        self.cumulative_variance_curve = self.scree_plot.add_curve(
175                        "Cumulative Variance",
176                        Qt.darkYellow, Qt.darkYellow, 2, 
177                        xData=[],
178                        yData=[],
179                        style=OWCurve.Lines,
180                        enableLegend=True,
181                        lineWidth=2,
182                        autoScale=1,
183                        x_axis_key=owaxis.xBottom,
184                        y_axis_key=owaxis.yLeft,
185                        )
186
187        self.mainArea.layout().addWidget(self.scree_plot)
188        self.connect(self.scree_plot,
189                     SIGNAL("cutoff_moved(double)"),
190                     self.on_cutoff_moved
191                     )
192
193        self.connect(self.graphButton,
194                     SIGNAL("clicked()"),
195                     self.scree_plot.save_to_file)
196
197        self.components = None
198        self.variances = None
199        self.variances_sum = None
200        self.projector_full = None
201        self.currently_selected = 0
202
203        self.resize(800, 400)
204
205    def clear(self):
206        """Clear widget state
207        """
208        self.data = None
209        self.scree_plot.set_cutoff_curve_enabled(False)
210        self.clear_cached()
211        self.variance_curve.setVisible(False)
212        self.cumulative_variance_curve.setVisible(False)
213
214    def clear_cached(self):
215        """Clear cached components
216        """
217        self.components = None
218        self.variances = None
219        self.variances_cumsum = None
220        self.projector_full = None
221        self.currently_selected = 0
222
223    def set_data(self, data=None):
224        """Set the widget input data.
225        """
226        self.clear()
227        if data is not None:
228            self.data = data
229            self.on_change()
230        else:
231            self.send("Transformed Data", None)
232            self.send("Eigen Vectors", None)
233
234    def on_change(self):
235        """Data has changed and we need to recompute the projection.
236        """
237        if self.data is None:
238            return
239        self.clear_cached()
240        self.apply()
241
242    def on_update(self):
243        """Component selection was changed by the user.
244        """
245        if self.data is None:
246            return
247        self.update_cutoff_curve()
248        if self.currently_selected != self.number_of_selected_components():
249            self.update_components_if()
250
251    def construct_pca_all_comp(self):
252        pca = plinear.PCA(standardize=self.standardize,
253                          max_components=0,
254                          variance_covered=1,
255                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
256                          )
257        return pca
258
259    def construct_pca(self):
260        max_components = self.max_components
261        variance_covered = self.variance_covered
262        pca = plinear.PCA(standardize=self.standardize,
263                          max_components=max_components,
264                          variance_covered=variance_covered / 100.0,
265                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
266                          )
267        return pca
268
269    def apply(self):
270        """Apply PCA on input data, caching the full projection,
271        then updating the selected components.
272       
273        """
274        pca = self.construct_pca_all_comp()
275        self.projector_full = projector = pca(self.data)
276
277        self.variances = self.projector_full.variances
278        self.variances /= np.sum(self.variances)
279        self.variances_cumsum = np.cumsum(self.variances)
280
281        self.max_components_spin.setRange(1, len(self.variances))
282        self.update_scree_plot()
283        self.update_cutoff_curve()
284        self.update_components_if()
285
286    def update_components_if(self):
287        if self.auto_commit:
288            self.update_components()
289        else:
290            self.changed_flag = True
291       
292    def update_components(self):
293        """Update the output components.
294        """
295        if self.data is None:
296            return 
297
298        scale = self.projector_full.scale
299        center = self.projector_full.center
300        components = self.projector_full.projection
301        input_domain = self.projector_full.input_domain
302        variances = self.projector_full.variances
303        variance_sum = self.projector_full.variance_sum
304
305        # Get selected components (based on max_components and
306        # variance_coverd)
307        pca = self.construct_pca()
308        variances, components, variance_sum = pca._select_components(variances, components)
309
310        projector = plinear.PcaProjector(input_domain=input_domain,
311                                         standardize=self.standardize,
312                                         scale=scale,
313                                         center=center,
314                                         projection=components,
315                                         variances=variances,
316                                         variance_sum=variance_sum)
317        projected_data = projector(self.data)
318        eigenvectors = self.eigenvectors_as_table(components)
319
320        self.currently_selected = self.number_of_selected_components()
321
322        self.send("Transformed Data", projected_data)
323        self.send("Eigen Vectors", eigenvectors)
324
325        self.changed_flag = False
326
327    def eigenvectors_as_table(self, U):
328        features = [Orange.feature.Continuous("C%i" % i) \
329                    for i in range(1, U.shape[1] + 1)]
330        domain = Orange.data.Domain(features, False)
331        return Orange.data.Table(domain, [list(v) for v in U])
332
333    def update_scree_plot(self):
334        x_space = np.arange(0, len(self.variances))
335        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
336        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
337        self.scree_plot.set_axis_labels(owaxis.xBottom, 
338                                        ["PC" + str(i + 1) for i in x_space])
339
340        self.variance_curve.set_data(x_space, self.variances)
341        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
342        self.variance_curve.setVisible(True)
343        self.cumulative_variance_curve.setVisible(True)
344
345        self.scree_plot.set_cutoff_curve_enabled(True)
346
347    def on_cutoff_moved(self, value):
348        """Cutoff curve was moved by the user.
349        """
350        components = int(np.floor(value)) + 1
351        # Did the number of components actually change
352        self.max_components = components
353        self.variance_covered = self.variances_cumsum[components - 1] * 100
354        if self.currently_selected != self.number_of_selected_components():
355            self.update_components_if()
356
357    def update_cutoff_curve(self):
358        """Update cutoff curve from 'Components Selection' control box.
359        """
360        variance = self.variances_cumsum[self.max_components - 1] * 100.0
361        if variance < self.variance_covered:
362            cutoff = float(self.max_components - 1)
363        else:
364            cutoff = np.searchsorted(self.variances_cumsum,
365                                     self.variance_covered / 100.0)
366        self.scree_plot.set_cutoff_value(cutoff + 0.5)
367
368    def number_of_selected_components(self):
369        """How many components are selected.
370        """
371        if self.data is None:
372            return 0
373
374        variance_components = np.searchsorted(self.variances_cumsum,
375                                              self.variance_covered / 100.0)
376        return min(variance_components + 1, self.max_components)
377
378    def sendReport(self):
379        self.reportSettings("PCA Settings",
380                            [("Max. components", self.max_components),
381                             ("Variance covered", "%i%%" % self.variance_covered),
382                             ])
383        if self.data is not None and self.projector_full:
384            output_domain = self.projector_full.output_domain
385            st_dev = np.sqrt(self.projector_full.variances)
386            summary = [[""] + [a.name for a in output_domain.attributes],
387                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
388                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
389                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
390                       ]
391
392            th = "<th>%s</th>".__mod__
393            header = "".join(map(th, summary[0]))
394            td = "<td>%s</td>".__mod__
395            summary = ["".join(map(td, row)) for row in summary[1:]]
396            tr = "<tr>%s</tr>".__mod__
397            summary = "\n".join(map(tr, [header] + summary))
398            summary = "<table>\n%s\n</table>" % summary
399
400            self.reportSection("Summary")
401            self.reportRaw(summary)
402
403            self.reportSection("Scree Plot")
404            self.reportImage(self.scree_plot.save_to_file_direct)
405
406if __name__ == "__main__":
407    app = QApplication(sys.argv)
408    w = OWPCA()
409    data = Orange.data.Table("iris")
410    w.set_data(data)
411    w.show()
412    app.exec_()
Note: See TracBrowser for help on using the repository browser.