source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10812:12ad1037bcb9

Revision 10812:12ad1037bcb9, 15.2 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Check for legend events.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icons>icons/PCA.png</icons>
6<tags>pca,principal,component,projection</tags>
7
8"""
9import Orange
10import Orange.utils.addons
11
12from OWWidget import *
13import OWGUI
14
15import Orange
16import Orange.projection.linear as plinear
17
18import numpy as np
19import sys
20
21from plot.owplot import OWPlot
22from plot.owcurve import OWCurve
23from plot import owaxis
24
25class ScreePlot(OWPlot):
26    def __init__(self, parent=None, name="Scree Plot"):
27        OWPlot.__init__(self, parent, name=name)
28        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
29                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
30        self.cutoff_curve.setVisible(False)
31        self.cutoff_curve.set_style(OWCurve.Lines)
32        self.add_custom_curve(self.cutoff_curve)
33
34    def is_cutoff_enabled(self):
35        return self.cutoff_curve and self.cutoff_curve.isVisible()
36
37    def set_cutoff_curve_enabled(self, state):
38        self.cutoff_curve.setVisible(state)
39
40    def set_cutoff_value(self, value):
41        xmin, xmax = self.x_scale()
42        x = min(max(value, xmin), xmax)
43        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
44
45    def mousePressEvent(self, event):
46        if self.isLegendEvent(event, QGraphicsView.mousePressEvent):
47            return
48
49        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
50            pos = self.mapToScene(event.pos())
51            x, _  = self.map_from_graph(pos)
52            xmin, xmax = self.x_scale()
53            if x >= xmin - 0.1 and x <= xmax + 0.1:
54                x = min(max(x, xmin), xmax)
55                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
56                self.emit_cutoff_moved(x)
57        return QGraphicsView.mousePressEvent(self, event)
58
59    def mouseMoveEvent(self, event):
60        if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent):
61            return
62
63        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
64            pos = self.mapToScene(event.pos())
65            x, _ = self.map_from_graph(pos)
66            xmin, xmax = self.x_scale()
67            if x >= xmin - 0.5 and x <= xmax + 0.5:
68                x = min(max(x, xmin), xmax)
69                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
70                self.emit_cutoff_moved(x)
71        return QGraphicsView.mouseMoveEvent(self, event)
72
73    def mouseReleaseEvene(self, event):
74        return QGraphicsView.mouseReleaseEvent(self, event)
75
76    def x_scale(self):
77        ax = self.axes[owaxis.xBottom]
78        if ax.labels:
79            return 0, len(ax.labels) - 1
80        elif ax.scale:
81            return ax.scale[0], ax.scale[1]
82        else:
83            raise ValueError
84
85    def emit_cutoff_moved(self, x):
86        self.emit(SIGNAL("cutoff_moved(double)"), x)
87
88    def set_axis_labels(self, *args):
89        OWPlot.set_axis_labels(self, *args)
90        self.map_transform = self.transform_for_axes()
91
92class CutoffCurve(OWCurve):
93    def __init__(self, *args, **kwargs):
94        OWCurve.__init__(self, *args, **kwargs)
95        self.setAcceptHoverEvents(True)
96        self.setCursor(Qt.SizeHorCursor)
97
98class OWPCA(OWWidget):
99    settingsList = ["standardize", "max_components", "variance_covered",
100                    "use_generalized_eigenvectors", "auto_commit"]
101    def __init__(self, parent=None, signalManager=None, title="PCA"):
102        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
103
104        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
105        self.outputs = [("Transformed Data", Orange.data.Table, Default),
106                        ("Eigen Vectors", Orange.data.Table)]
107
108        self.standardize = True
109        self.max_components = 0
110        self.variance_covered = 100.0
111        self.use_generalized_eigenvectors = False
112        self.auto_commit = False
113
114        self.loadSettings()
115
116        self.data = None
117        self.changed_flag = False
118
119        #####
120        # GUI
121        #####
122        grid = QGridLayout()
123        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
124                              orientation=grid)
125
126        label1 = QLabel("Max components", box)
127        grid.addWidget(label1, 1, 0)
128
129        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000,
130                         tooltip="Maximum number of components",
131                         callback=self.on_update,
132                         addToLayout=False,
133                         keyboardTracking=False
134                         )
135        self.max_components_spin = sb1.control
136        grid.addWidget(sb1.control, 1, 1)
137
138        label2 = QLabel("Variance covered", box)
139        grid.addWidget(label2, 2, 0)
140
141        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
142                               tooltip="Percent of variance covered.",
143                               callback=self.on_update,
144                               decimals=1,
145                               addToLayout=False,
146                               keyboardTracking=False
147                               )
148        sb2.control.setSuffix("%")
149        grid.addWidget(sb2.control, 2, 1)
150
151        OWGUI.rubber(self.controlArea)
152
153        box = OWGUI.widgetBox(self.controlArea, "Commit")
154        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change")
155        b = OWGUI.button(box, self, "Commit",
156                         callback=self.update_components)
157        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components)
158
159        self.scree_plot = ScreePlot(self)
160#        self.scree_plot.set_main_title("Scree Plot")
161#        self.scree_plot.set_show_main_title(True)
162        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
163        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
164        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
165        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
166
167        self.variance_curve = self.scree_plot.add_curve(
168                        "Variance",
169                        Qt.red, Qt.red, 2, 
170                        xData=[],
171                        yData=[],
172                        style=OWCurve.Lines,
173                        enableLegend=True,
174                        lineWidth=2,
175                        autoScale=1,
176                        x_axis_key=owaxis.xBottom,
177                        y_axis_key=owaxis.yLeft,
178                        )
179
180        self.cumulative_variance_curve = self.scree_plot.add_curve(
181                        "Cumulative Variance",
182                        Qt.darkYellow, Qt.darkYellow, 2, 
183                        xData=[],
184                        yData=[],
185                        style=OWCurve.Lines,
186                        enableLegend=True,
187                        lineWidth=2,
188                        autoScale=1,
189                        x_axis_key=owaxis.xBottom,
190                        y_axis_key=owaxis.yLeft,
191                        )
192
193        self.mainArea.layout().addWidget(self.scree_plot)
194        self.connect(self.scree_plot,
195                     SIGNAL("cutoff_moved(double)"),
196                     self.on_cutoff_moved
197                     )
198
199        self.connect(self.graphButton,
200                     SIGNAL("clicked()"),
201                     self.scree_plot.save_to_file)
202
203        self.components = None
204        self.variances = None
205        self.variances_sum = None
206        self.projector_full = None
207        self.currently_selected = 0
208
209        self.resize(800, 400)
210
211    def clear(self):
212        """Clear widget state
213        """
214        self.data = None
215        self.scree_plot.set_cutoff_curve_enabled(False)
216        self.clear_cached()
217        self.variance_curve.setVisible(False)
218        self.cumulative_variance_curve.setVisible(False)
219
220    def clear_cached(self):
221        """Clear cached components
222        """
223        self.components = None
224        self.variances = None
225        self.variances_cumsum = None
226        self.projector_full = None
227        self.currently_selected = 0
228
229    def set_data(self, data=None):
230        """Set the widget input data.
231        """
232        self.clear()
233        if data is not None:
234            self.data = data
235            self.on_change()
236        else:
237            self.send("Transformed Data", None)
238            self.send("Eigen Vectors", None)
239
240    def on_change(self):
241        """Data has changed and we need to recompute the projection.
242        """
243        if self.data is None:
244            return
245        self.clear_cached()
246        self.apply()
247
248    def on_update(self):
249        """Component selection was changed by the user.
250        """
251        if self.data is None:
252            return
253        self.update_cutoff_curve()
254        if self.currently_selected != self.number_of_selected_components():
255            self.update_components_if()
256
257    def construct_pca_all_comp(self):
258        pca = plinear.PCA(standardize=self.standardize,
259                          max_components=0,
260                          variance_covered=1,
261                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
262                          )
263        return pca
264
265    def construct_pca(self):
266        max_components = self.max_components
267        variance_covered = self.variance_covered
268        pca = plinear.PCA(standardize=self.standardize,
269                          max_components=max_components,
270                          variance_covered=variance_covered / 100.0,
271                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
272                          )
273        return pca
274
275    def apply(self):
276        """Apply PCA on input data, caching the full projection,
277        then updating the selected components.
278       
279        """
280        pca = self.construct_pca_all_comp()
281        self.projector_full = projector = pca(self.data)
282
283        self.variances = self.projector_full.variances
284        self.variances /= np.sum(self.variances)
285        self.variances_cumsum = np.cumsum(self.variances)
286
287        self.max_components_spin.setRange(1, len(self.variances))
288        self.update_scree_plot()
289        self.update_cutoff_curve()
290        self.update_components_if()
291
292    def update_components_if(self):
293        if self.auto_commit:
294            self.update_components()
295        else:
296            self.changed_flag = True
297       
298    def update_components(self):
299        """Update the output components.
300        """
301        if self.data is None:
302            return 
303
304        scale = self.projector_full.scale
305        center = self.projector_full.center
306        components = self.projector_full.projection
307        input_domain = self.projector_full.input_domain
308        variances = self.projector_full.variances
309        variance_sum = self.projector_full.variance_sum
310
311        # Get selected components (based on max_components and
312        # variance_coverd)
313        pca = self.construct_pca()
314        variances, components, variance_sum = pca._select_components(variances, components)
315
316        projector = plinear.PcaProjector(input_domain=input_domain,
317                                         standardize=self.standardize,
318                                         scale=scale,
319                                         center=center,
320                                         projection=components,
321                                         variances=variances,
322                                         variance_sum=variance_sum)
323        projected_data = projector(self.data)
324        eigenvectors = self.eigenvectors_as_table(components)
325
326        self.currently_selected = self.number_of_selected_components()
327
328        self.send("Transformed Data", projected_data)
329        self.send("Eigen Vectors", eigenvectors)
330
331        self.changed_flag = False
332
333    def eigenvectors_as_table(self, U):
334        features = [Orange.feature.Continuous("C%i" % i) \
335                    for i in range(1, U.shape[1] + 1)]
336        domain = Orange.data.Domain(features, False)
337        return Orange.data.Table(domain, [list(v) for v in U])
338
339    def update_scree_plot(self):
340        x_space = np.arange(0, len(self.variances))
341        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
342        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
343        self.scree_plot.set_axis_labels(owaxis.xBottom, 
344                                        ["PC" + str(i + 1) for i in x_space])
345
346        self.variance_curve.set_data(x_space, self.variances)
347        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
348        self.variance_curve.setVisible(True)
349        self.cumulative_variance_curve.setVisible(True)
350
351        self.scree_plot.set_cutoff_curve_enabled(True)
352        self.scree_plot.replot()
353
354    def on_cutoff_moved(self, value):
355        """Cutoff curve was moved by the user.
356        """
357        components = int(np.floor(value)) + 1
358        # Did the number of components actually change
359        self.max_components = components
360        self.variance_covered = self.variances_cumsum[components - 1] * 100
361        if self.currently_selected != self.number_of_selected_components():
362            self.update_components_if()
363
364    def update_cutoff_curve(self):
365        """Update cutoff curve from 'Components Selection' control box.
366        """
367        variance = self.variances_cumsum[self.max_components - 1] * 100.0
368        if variance < self.variance_covered:
369            cutoff = float(self.max_components - 1)
370        else:
371            cutoff = np.searchsorted(self.variances_cumsum,
372                                     self.variance_covered / 100.0)
373        self.scree_plot.set_cutoff_value(cutoff + 0.5)
374
375    def number_of_selected_components(self):
376        """How many components are selected.
377        """
378        if self.data is None:
379            return 0
380
381        variance_components = np.searchsorted(self.variances_cumsum,
382                                              self.variance_covered / 100.0)
383        return min(variance_components + 1, self.max_components)
384
385    def sendReport(self):
386        self.reportSettings("PCA Settings",
387                            [("Max. components", self.max_components),
388                             ("Variance covered", "%i%%" % self.variance_covered),
389                             ])
390        if self.data is not None and self.projector_full:
391            output_domain = self.projector_full.output_domain
392            st_dev = np.sqrt(self.projector_full.variances)
393            summary = [[""] + [a.name for a in output_domain.attributes],
394                       ["Std. deviation"] + ["%.3f" % sd for sd in st_dev],
395                       ["Proportion Var"] + ["%.3f" % v for v in self.variances * 100.0],
396                       ["Cumulative Var"] + ["%.3f" % v for v in self.variances_cumsum * 100.0]
397                       ]
398
399            th = "<th>%s</th>".__mod__
400            header = "".join(map(th, summary[0]))
401            td = "<td>%s</td>".__mod__
402            summary = ["".join(map(td, row)) for row in summary[1:]]
403            tr = "<tr>%s</tr>".__mod__
404            summary = "\n".join(map(tr, [header] + summary))
405            summary = "<table>\n%s\n</table>" % summary
406
407            self.reportSection("Summary")
408            self.reportRaw(summary)
409
410            self.reportSection("Scree Plot")
411            self.reportImage(self.scree_plot.save_to_file_direct)
412
413if __name__ == "__main__":
414    app = QApplication(sys.argv)
415    w = OWPCA()
416    data = Orange.data.Table("iris")
417    w.set_data(data)
418    w.show()
419    app.exec_()
Note: See TracBrowser for help on using the repository browser.