source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10806:eb51260debca

Revision 10806:eb51260debca, 13.0 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed selected components update

Line 
1"""
2<name>PCA</name>
3<description>Perform Principal Component Analysis</description>
4<contact>ales.erjavec(@ at @)fri.uni-lj.si</contact>
5<icons>icons/PCA.png</icons>
6<tags>pca,principal,component,projection</tags>
7
8"""
9import Orange
10import Orange.utils.addons
11
12from OWWidget import *
13import OWGUI
14
15import Orange
16import Orange.projection.linear as plinear
17
18import numpy as np
19import sys
20
21from plot.owplot import OWPlot
22from plot.owcurve import OWCurve
23from plot import owaxis
24
25class ScreePlot(OWPlot):
26    def __init__(self, parent=None, name="Scree Plot"):
27        OWPlot.__init__(self, parent, name=name)
28        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
29                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
30        self.cutoff_curve.setVisible(False)
31        self.cutoff_curve.set_style(OWCurve.Lines)
32        self.add_custom_curve(self.cutoff_curve)
33
34    def is_cutoff_enabled(self):
35        return self.cutoff_curve and self.cutoff_curve.isVisible()
36
37    def set_cutoff_curve_enabled(self, state):
38        self.cutoff_curve.setVisible(state)
39
40    def set_cutoff_value(self, value):
41        xmin, xmax = self.x_scale()
42        x = min(max(value, xmin), xmax)
43        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
44
45    def mousePressEvent(self, event):
46        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
47            pos = self.mapToScene(event.pos())
48            x, _  = self.map_from_graph(pos)
49            xmin, xmax = self.x_scale()
50            if x >= xmin - 0.1 and x <= xmax + 0.1:
51                x = min(max(x, xmin), xmax)
52                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
53                self.emit_cutoff_moved(x)
54        return QGraphicsView.mousePressEvent(self, event)
55
56    def mouseMoveEvent(self, event):
57        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
58            pos = self.mapToScene(event.pos())
59            x, _ = self.map_from_graph(pos)
60            xmin, xmax = self.x_scale()
61            if x >= xmin - 0.5 and x <= xmax + 0.5:
62                x = min(max(x, xmin), xmax)
63                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
64                self.emit_cutoff_moved(x)
65        return QGraphicsView.mouseMoveEvent(self, event)
66
67    def mouseReleaseEvene(self, event):
68        return QGraphicsView.mouseReleaseEvent(self, event)
69
70    def x_scale(self):
71        ax = self.axes[owaxis.xBottom]
72        if ax.labels:
73            return 0, len(ax.labels) - 1
74        elif ax.scale:
75            return ax.scale[0], ax.scale[1]
76        else:
77            raise ValueError
78
79    def emit_cutoff_moved(self, x):
80        self.emit(SIGNAL("cutoff_moved(double)"), x)
81
82    def set_axis_labels(self, *args):
83        OWPlot.set_axis_labels(self, *args)
84        self.map_transform = self.transform_for_axes()
85
86class CutoffCurve(OWCurve):
87    def __init__(self, *args, **kwargs):
88        OWCurve.__init__(self, *args, **kwargs)
89        self.setAcceptHoverEvents(True)
90        self.setCursor(Qt.SizeHorCursor)
91
92class OWPCA(OWWidget):
93    settingsList = ["standardize", "max_components", "variance_covered",
94                    "use_generalized_eigenvectors"]
95    def __init__(self, parent=None, signalManager=None, title="PCA"):
96        OWWidget.__init__(self, parent, signalManager, title)
97
98        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
99        self.outputs = [("Transformed Data", Orange.data.Table, Default),
100                        ("Eigen Vectors", Orange.data.Table)]
101
102        self.standardize = True
103        self.max_components = 0
104        self.variance_covered = 100.0
105        self.use_generalized_eigenvectors = False
106
107        self.loadSettings()
108
109        self.data = None
110
111        #####
112        # GUI
113        #####
114        grid = QGridLayout()
115        box = OWGUI.widgetBox(self.controlArea, "Components Selection",
116                              orientation=grid)
117
118        label1 = QLabel("Max components", box)
119        grid.addWidget(label1, 1, 0)
120
121        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000,
122                         tooltip="Maximum number of components",
123                         callback=self.on_update,
124                         addToLayout=False,
125                         keyboardTracking=False
126                         )
127        self.max_components_spin = sb1.control
128        grid.addWidget(sb1.control, 1, 1)
129
130        label2 = QLabel("Variance covered", box)
131        grid.addWidget(label2, 2, 0)
132
133        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
134                               tooltip="Percent of variance covered.",
135                               callback=self.on_update,
136                               decimals=1,
137                               addToLayout=False,
138                               keyboardTracking=False
139                               )
140        sb2.control.setSuffix("%")
141        grid.addWidget(sb2.control, 2, 1)
142
143        OWGUI.rubber(self.controlArea)
144
145        self.scree_plot = ScreePlot(self)
146#        self.scree_plot.set_main_title("Scree Plot")
147#        self.scree_plot.set_show_main_title(True)
148        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
149        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
150        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
151        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
152
153        self.variance_curve = self.scree_plot.add_curve(
154                        "Variance",
155                        Qt.red, Qt.red, 2, 
156                        xData=[],
157                        yData=[],
158                        style=OWCurve.Lines,
159                        enableLegend=True,
160                        lineWidth=2,
161                        autoScale=1,
162                        x_axis_key=owaxis.xBottom,
163                        y_axis_key=owaxis.yLeft,
164                        )
165
166        self.cumulative_variance_curve = self.scree_plot.add_curve(
167                        "Cumulative Variance",
168                        Qt.darkYellow, Qt.darkYellow, 2, 
169                        xData=[],
170                        yData=[],
171                        style=OWCurve.Lines,
172                        enableLegend=True,
173                        lineWidth=2,
174                        autoScale=1,
175                        x_axis_key=owaxis.xBottom,
176                        y_axis_key=owaxis.yLeft,
177                        )
178
179        self.mainArea.layout().addWidget(self.scree_plot)
180        self.connect(self.scree_plot,
181                     SIGNAL("cutoff_moved(double)"),
182                     self.on_cutoff_moved
183                     )
184        self.components = None
185        self.variances = None
186        self.variances_sum = None
187        self.projector_full = None
188        self.currently_selected = 0
189
190        self.resize(800, 400)
191
192    def clear(self):
193        """Clear widget state
194        """
195        self.data = None
196        self.scree_plot.set_cutoff_curve_enabled(False)
197        self.clear_cached()
198        self.variance_curve.setVisible(False)
199        self.cumulative_variance_curve.setVisible(False)
200
201    def clear_cached(self):
202        """Clear cached components
203        """
204        self.components = None
205        self.variances = None
206        self.variances_cumsum = None
207        self.projector_full = None
208        self.currently_selected = 0
209
210    def set_data(self, data=None):
211        """Set the widget input data.
212        """
213        self.clear()
214        if data is not None:
215            self.data = data
216            self.on_change()
217        else:
218            self.send("Transformed Data", None)
219            self.send("Eigen Vectors", None)
220
221    def on_change(self):
222        """Data has changed and we need to recompute the projection.
223        """
224        if self.data is None:
225            return
226        self.clear_cached()
227        self.apply()
228
229    def on_update(self):
230        """Component selection was changed by the user.
231        """
232        if self.data is None:
233            return
234        self.update_cutoff_curve()
235        if self.currently_selected != self.number_of_selected_components():
236            self.update_components()
237
238    def construct_pca_all_comp(self):
239        pca = plinear.PCA(standardize=self.standardize,
240                          max_components=0,
241                          variance_covered=1,
242                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
243                          )
244        return pca
245
246    def construct_pca(self):
247        max_components = self.max_components
248        variance_covered = self.variance_covered
249        pca = plinear.PCA(standardize=self.standardize,
250                          max_components=max_components,
251                          variance_covered=variance_covered / 100.0,
252                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
253                          )
254        return pca
255
256    def apply(self):
257        """Apply PCA on input data, caching the full projection,
258        then updating the selected components.
259       
260        """
261        pca = self.construct_pca_all_comp()
262        self.projector_full = projector = pca(self.data)
263
264        self.variances = self.projector_full.variances
265        self.variances /= np.sum(self.variances)
266        self.variances_cumsum = np.cumsum(self.variances)
267
268        self.max_components_spin.setRange(1, len(self.variances))
269        self.update_scree_plot()
270        self.update_cutoff_curve()
271        self.update_components()
272
273    def update_components(self):
274        """Update the output components.
275        """
276        scale = self.projector_full.scale
277        center = self.projector_full.center
278        components = self.projector_full.projection
279        input_domain = self.projector_full.input_domain
280        variances = self.projector_full.variances
281        variance_sum = self.projector_full.variance_sum
282
283        # Get selected components (based on max_components and
284        # variance_coverd)
285        pca = self.construct_pca()
286        variances, components, variance_sum = pca._select_components(variances, components)
287
288        projector = plinear.PcaProjector(input_domain=input_domain,
289                                         standardize=self.standardize,
290                                         scale=scale,
291                                         center=center,
292                                         projection=components,
293                                         variances=variances,
294                                         variance_sum=variance_sum)
295        projected_data = projector(self.data)
296        eigenvectors = self.eigenvectors_as_table(components)
297
298        self.currently_selected = self.number_of_selected_components()
299
300        self.send("Transformed Data", projected_data)
301        self.send("Eigen Vectors", eigenvectors)
302
303    def eigenvectors_as_table(self, U):
304        features = [Orange.feature.Continuous("C%i" % i) \
305                    for i in range(1, U.shape[1] + 1)]
306        domain = Orange.data.Domain(features, False)
307        return Orange.data.Table(domain, [list(v) for v in U])
308
309    def update_scree_plot(self):
310        x_space = np.arange(0, len(self.variances))
311        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
312        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
313        self.scree_plot.set_axis_labels(owaxis.xBottom, 
314                                        ["PC" + str(i + 1) for i in x_space])
315
316        self.variance_curve.set_data(x_space, self.variances)
317        self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum)
318        self.variance_curve.setVisible(True)
319        self.cumulative_variance_curve.setVisible(True)
320
321        self.scree_plot.set_cutoff_curve_enabled(True)
322
323    def on_cutoff_moved(self, value):
324        """Cutoff curve was moved by the user.
325        """
326        components = int(np.floor(value)) + 1
327        # Did the number of components actually change
328        self.max_components = components
329        self.variance_covered = self.variances_cumsum[components - 1] * 100
330        if self.currently_selected != self.number_of_selected_components():
331#            self.max_components = int(np.floor(value)) + 1
332#            self.variance_covered = self.variances_cumsum[self.max_components - 1] * 100
333            self.update_components()
334
335    def update_cutoff_curve(self):
336        """Update cutoff curve from 'Components Selection' control box.
337        """
338        variance = self.variances_cumsum[self.max_components - 1] * 100.0
339        if variance < self.variance_covered:
340            cutoff = float(self.max_components - 1)
341        else:
342            cutoff = np.searchsorted(self.variances_cumsum,
343                                     self.variance_covered / 100.0)
344        self.scree_plot.set_cutoff_value(cutoff + 0.5)
345
346    def number_of_selected_components(self):
347        """How many components are selected.
348        """
349        if self.data is None:
350            return 0
351
352        variance_components = np.searchsorted(self.variances_cumsum,
353                                              self.variance_covered / 100.0)
354        return min(variance_components + 1, self.max_components)
355
356if __name__ == "__main__":
357    app = QApplication(sys.argv)
358    w = OWPCA()
359    data = Orange.data.Table("iris")
360    w.set_data(data)
361    w.show()
362    app.exec_()
Note: See TracBrowser for help on using the repository browser.