source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10801:1d4afcbe2add

Revision 10801:1d4afcbe2add, 11.8 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed widget headers.

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icons>icons/PCA.png</icons>
6<tags>pca,principal,component,projection</tags>
7
8"""
9import Orange
10import Orange.utils.addons
11
12from OWWidget import *
13import OWGUI
14
15import Orange
16import Orange.projection.linear as plinear
17
18import numpy as np
19import sys
20
21from plot.owplot import OWPlot
22from plot.owcurve import OWCurve
23from plot import owaxis
24
25class ScreePlot(OWPlot):
26    def __init__(self, parent=None, name="Scree Plot"):
27        OWPlot.__init__(self, parent, name=name)
28        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
29                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
30        self.cutoff_curve.setVisible(False)
31        self.cutoff_curve.set_style(OWCurve.Lines)
32        self.add_custom_curve(self.cutoff_curve)
33
34    def is_cutoff_enabled(self):
35        return self.cutoff_curve and self.cutoff_curve.isVisible()
36
37    def set_cutoff_curve_enabled(self, state):
38        self.cutoff_curve.setVisible(state)
39
40    def set_cutoff_value(self, value):
41        xmin, xmax = self.x_scale()
42        x = min(max(value, xmin), xmax)
43        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
44
45    def mousePressEvent(self, event):
46        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
47            pos = self.mapToScene(event.pos())
48            x, _  = self.map_from_graph(pos)
49            xmin, xmax = self.x_scale()
50            if x >= xmin and x <= xmax:
51                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
52                self.emit_cutoff_moved(x)
53        return QGraphicsView.mousePressEvent(self, event)
54
55    def mouseMoveEvent(self, event):
56        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
57            pos = self.mapToScene(event.pos())
58            x, _ = self.map_from_graph(pos)
59            xmin, xmax = self.x_scale()
60            if x >= xmin and x <= xmax:
61                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
62                self.emit_cutoff_moved(x)
63        return QGraphicsView.mouseMoveEvent(self, event)
64
65    def mouseReleaseEvene(self, event):
66        return QGraphicsView.mouseReleaseEvent(self, event)
67
68    def x_scale(self):
69        ax = self.axes[owaxis.xBottom]
70        if ax.labels:
71            return 0, len(ax.labels) - 0.5
72        elif ax.scale:
73            return ax.scale[0], ax.scale[1] + 0.5
74        else:
75            raise ValueError
76
77    def emit_cutoff_moved(self, x):
78        self.emit(SIGNAL("cutoff_moved(double)"), x)
79
80    def set_axis_labels(self, *args):
81        OWPlot.set_axis_labels(self, *args)
82        self.map_transform = self.transform_for_axes()
83
84class CutoffCurve(OWCurve):
85    def __init__(self, *args, **kwargs):
86        OWCurve.__init__(self, *args, **kwargs)
87        self.setAcceptHoverEvents(True)
88        self.setCursor(Qt.SizeHorCursor)
89
90class OWPCA(OWWidget):
91    settingsList = ["standardize", "max_components", "variance_covered",
92                    "use_generalized_eigenvectors"]
93    def __init__(self, parent=None, signalManager=None, title="PCA"):
94        OWWidget.__init__(self, parent, signalManager, title)
95
96        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
97        self.outputs = [("Transformed Data", Orange.data.Table, Default),
98                        ("Eigen Vectors", Orange.data.Table)]
99
100        self.standardize = True
101        self.max_components = 0
102        self.variance_covered = 100.0
103        self.use_generalized_eigenvectors = False
104
105        self.loadSettings()
106
107        self.data = None
108
109        #####
110        # GUI
111        #####
112        grid = QGridLayout()
113        box = OWGUI.widgetBox(self.controlArea, "Settings",
114                              orientation=grid)
115
116        label1 = QLabel("Max components", box)
117        grid.addWidget(label1, 1, 0)
118
119        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000,
120                         tooltip="Maximum number of components",
121                         callback=self.on_update,
122                         addToLayout=False,
123                         keyboardTracking=False
124                         )
125        self.max_components_spin = sb1.control
126        grid.addWidget(sb1.control, 1, 1)
127
128        label2 = QLabel("Variance covered", box)
129        grid.addWidget(label2, 2, 0)
130
131        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
132                               tooltip="Percent of variance covered.",
133                               callback=self.on_update,
134                               decimals=1,
135                               addToLayout=False,
136                               keyboardTracking=False
137                               )
138        sb2.control.setSuffix("%")
139        grid.addWidget(sb2.control, 2, 1)
140
141        OWGUI.rubber(self.controlArea)
142
143        self.scree_plot = ScreePlot(self)
144#        self.scree_plot.set_main_title("Scree Plot")
145#        self.scree_plot.set_show_main_title(True)
146        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
147        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
148        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
149        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
150
151        self.variance_curve = self.scree_plot.add_curve(
152                        "Variance",
153                        Qt.red, Qt.red, 2, 
154                        xData=[],
155                        yData=[],
156                        style=OWCurve.Lines,
157                        enableLegend=True,
158                        lineWidth=2,
159                        autoScale=1,
160                        x_axis_key=owaxis.xBottom,
161                        y_axis_key=owaxis.yLeft,
162                        )
163
164        self.cumulative_variance_curve = self.scree_plot.add_curve(
165                        "Cumulative Variance",
166                        Qt.darkYellow, Qt.darkYellow, 2, 
167                        xData=[],
168                        yData=[],
169                        style=OWCurve.Lines,
170                        enableLegend=True,
171                        lineWidth=2,
172                        autoScale=1,
173                        x_axis_key=owaxis.xBottom,
174                        y_axis_key=owaxis.yLeft,
175                        )
176
177        self.mainArea.layout().addWidget(self.scree_plot)
178        self.connect(self.scree_plot,
179                     SIGNAL("cutoff_moved(double)"),
180                     self.on_cutoff_moved
181                     )
182        self.components = None
183        self.variances = None
184        self.variances_sum = None
185        self.projector_full = None
186
187        self.resize(800, 400)
188
189    def clear(self):
190        """Clear widget state
191        """
192        self.data = None
193        self.scree_plot.set_cutoff_curve_enabled(False)
194        self.clear_cached()
195        self.variance_curve.setVisible(False)
196        self.cumulative_variance_curve.setVisible(False)
197
198    def clear_cached(self):
199        """Clear cached components
200        """
201        self.components = None
202        self.variances = None
203        self.variances_cumsum = None
204        self.projector_full = None
205
206    def set_data(self, data=None):
207        """Set the widget input data.
208        """
209        self.clear()
210        if data is not None:
211            self.data = data
212            self.on_change()
213
214    def on_change(self):
215        if self.data is None:
216            return
217        self.clear_cached()
218        self.apply()
219
220    def on_update(self):
221        if self.data is None:
222            return
223        self.update_cutoff_curve()
224        self.update_components()
225
226    def construct_pca_all_comp(self):
227        pca = plinear.PCA(standardize=self.standardize,
228                          max_components=0,
229                          variance_covered=1,
230                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
231                          )
232        return pca
233
234    def construct_pca(self):
235        max_components = self.max_components
236        variance_covered = self.variance_covered
237        pca = plinear.PCA(standardize=self.standardize,
238                          max_components=max_components,
239                          variance_covered=variance_covered / 100.0,
240                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
241                          )
242        return pca
243
244    def apply(self):
245        """Apply PCA in input data, caching the full projection,
246        then updating the selected components.
247       
248        """
249        pca = self.construct_pca_all_comp()
250        self.projector_full = projector = pca(self.data)
251
252        self.variances = self.projector_full.variances
253        self.variances /= np.sum(self.variances)
254        self.variances_cumsum = np.cumsum(self.variances)
255
256        self.max_components_spin.setRange(1, len(self.variances))
257        self.update_scree_plot()
258        self.update_components()
259
260    def update_components(self):
261        scale = self.projector_full.scale
262        center = self.projector_full.center
263        components = self.projector_full.projection
264        input_domain = self.projector_full.input_domain
265        variances = self.projector_full.variances
266        variance_sum = self.projector_full.variance_sum
267
268        # Get selected components (based on max_components and
269        # variance_coverd)
270        pca = self.construct_pca()
271        variances, components, variance_sum = pca._select_components(variances, components)
272
273        projector = plinear.PcaProjector(input_domain=input_domain,
274                                         standardize=self.standardize,
275                                         scale=scale,
276                                         center=center,
277                                         projection=components,
278                                         variances=variances,
279                                         variance_sum=variance_sum)
280        projected_data = projector(self.data)
281        eigenvectors = self.eigenvectors_as_table(components)
282
283        self.send("Transformed Data", projected_data)
284        self.send("Eigen Vectors", eigenvectors)
285
286    def eigenvectors_as_table(self, U):
287        features = [Orange.feature.Continuous("C%i" % i) \
288                    for i in range(1, U.shape[1] + 1)]
289        domain = Orange.data.Domain(features, False)
290        return Orange.data.Table(domain, [list(v) for v in U])
291
292    def update_scree_plot(self):
293        variances = self.projector_full.variances
294        s = np.sum(variances)
295        cv = variances / s
296        cs = np.cumsum(cv)
297        x_space = np.arange(0, len(variances))
298        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
299        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
300        self.scree_plot.set_axis_labels(owaxis.xBottom, 
301                                        ["PC" + str(i + 1) for i in x_space])
302
303        self.variance_curve.set_data(x_space, cv)
304        self.cumulative_variance_curve.set_data(x_space, cs)
305        self.variance_curve.setVisible(True)
306        self.cumulative_variance_curve.setVisible(True)
307
308        self.scree_plot.set_cutoff_curve_enabled(True)
309
310    def on_cutoff_moved(self, value):
311        components = int(np.floor(value)) + 1
312        if components != self.max_components:
313            self.max_components = int(np.floor(value)) + 1
314            self.variance_covered = self.variances_cumsum[self.max_components - 1] * 100
315            self.update_components()
316
317    def update_cutoff_curve(self):
318        """Update cutoff line from gui control elements.
319        """
320        variance = self.variances_cumsum[self.max_components - 1] * 100.0
321        if variance < self.variance_covered:
322            cutoff = float(self.max_components - 1)
323        else:
324            cutoff = np.searchsorted(self.variances_cumsum,
325                                     self.variance_covered / 100.0)
326        self.scree_plot.set_cutoff_value(cutoff + 0.5)
327
328
329if __name__ == "__main__":
330    app = QApplication(sys.argv)
331    w = OWPCA()
332    data = Orange.data.Table("iris")
333    w.set_data(data)
334    w.show()
335    app.exec_()
Note: See TracBrowser for help on using the repository browser.