source: orange/Orange/OrangeWidgets/Prototypes/OWPCA.py @ 10799:05f46689f643

Revision 10799:05f46689f643, 11.7 KB checked in by Ales Erjavec <ales.erjavec@…>, 2 years ago (diff)

Added cutoff line to scree plot.

Line 
1"""
2<name>PCA</name>
3
4"""
5import Orange
6import Orange.utils.addons
7
8from OWWidget import *
9import OWGUI
10
11import Orange
12import Orange.projection.linear as plinear
13
14import numpy as np
15import sys
16from Orange import orangeqt
17
18sys.modules["orangeqt"] = orangeqt
19
20from plot.owplot import OWPlot
21from plot.owcurve import OWCurve
22from plot import owaxis
23
24class ScreePlot(OWPlot):
25    def __init__(self, parent=None, name="Scree Plot"):
26        OWPlot.__init__(self, parent, name=name)
27        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0],
28                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft)
29        self.cutoff_curve.setVisible(False)
30        self.cutoff_curve.set_style(OWCurve.Lines)
31        self.add_custom_curve(self.cutoff_curve)
32
33    def is_cutoff_enabled(self):
34        return self.cutoff_curve and self.cutoff_curve.isVisible()
35
36    def set_cutoff_curve_enabled(self, state):
37        self.cutoff_curve.setVisible(state)
38
39    def set_cutoff_value(self, value):
40        xmin, xmax = self.x_scale()
41        x = min(max(value, xmin), xmax)
42        self.cutoff_curve.set_data([x, x], [0.0, 1.0])
43
44    def mousePressEvent(self, event):
45        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
46            pos = self.mapToScene(event.pos())
47            x, _  = self.map_from_graph(pos)
48            xmin, xmax = self.x_scale()
49            if x >= xmin and x <= xmax:
50                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
51                self.emit_cutoff_moved(x)
52        return QGraphicsView.mousePressEvent(self, event)
53
54    def mouseMoveEvent(self, event):
55        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton:
56            pos = self.mapToScene(event.pos())
57            x, _ = self.map_from_graph(pos)
58            xmin, xmax = self.x_scale()
59            if x >= xmin and x <= xmax:
60                self.cutoff_curve.set_data([x, x], [0.0, 1.0])
61                self.emit_cutoff_moved(x)
62        return QGraphicsView.mouseMoveEvent(self, event)
63
64    def mouseReleaseEvene(self, event):
65        return QGraphicsView.mouseReleaseEvent(self, event)
66
67    def x_scale(self):
68        ax = self.axes[owaxis.xBottom]
69        if ax.labels:
70            return 0, len(ax.labels) - 0.5
71        elif ax.scale:
72            return ax.scale[0], ax.scale[1] + 0.5
73        else:
74            raise ValueError
75
76    def emit_cutoff_moved(self, x):
77        self.emit(SIGNAL("cutoff_moved(double)"), x)
78
79    def set_axis_labels(self, *args):
80        OWPlot.set_axis_labels(self, *args)
81        self.map_transform = self.transform_for_axes()
82
83class CutoffCurve(OWCurve):
84    def __init__(self, *args, **kwargs):
85        OWCurve.__init__(self, *args, **kwargs)
86        self.setAcceptHoverEvents(True)
87        self.setCursor(Qt.SizeHorCursor)
88
89class OWPCA(OWWidget):
90    settingsList = ["standardize", "max_components", "variance_covered",
91                    "use_generalized_eigenvectors"]
92    def __init__(self, parent=None, signalManager=None, title="PCA"):
93        OWWidget.__init__(self, parent, signalManager, title)
94
95        self.inputs = [("Input Data", Orange.data.Table, self.set_data)]
96        self.outputs = [("Transformed Data", Orange.data.Table, Default),
97                        ("Eigen Vectors", Orange.data.Table)]
98
99        self.standardize = True
100        self.max_components = 0
101        self.variance_covered = 100.0
102        self.use_generalized_eigenvectors = False
103
104        self.loadSettings()
105
106        self.data = None
107
108        #####
109        # GUI
110        #####
111        grid = QGridLayout()
112        box = OWGUI.widgetBox(self.controlArea, "Settings",
113                              orientation=grid)
114
115        label1 = QLabel("Max components", box)
116        grid.addWidget(label1, 1, 0)
117
118        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000,
119                         tooltip="Maximum number of components",
120                         callback=self.on_update,
121                         addToLayout=False,
122                         keyboardTracking=False
123                         )
124        self.max_components_spin = sb1.control
125        grid.addWidget(sb1.control, 1, 1)
126
127        label2 = QLabel("Variance covered", box)
128        grid.addWidget(label2, 2, 0)
129
130        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0,
131                               tooltip="Percent of variance covered.",
132                               callback=self.on_update,
133                               decimals=1,
134                               addToLayout=False,
135                               keyboardTracking=False
136                               )
137        sb2.control.setSuffix("%")
138        grid.addWidget(sb2.control, 2, 1)
139
140        OWGUI.rubber(self.controlArea)
141
142        self.scree_plot = ScreePlot(self)
143#        self.scree_plot.set_main_title("Scree Plot")
144#        self.scree_plot.set_show_main_title(True)
145        self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components")
146        self.scree_plot.set_show_axis_title(owaxis.xBottom, 1)
147        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance")
148        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1)
149
150        self.variance_curve = self.scree_plot.add_curve(
151                        "Variance",
152                        Qt.red, Qt.red, 2, 
153                        xData=[],
154                        yData=[],
155                        style=OWCurve.Lines,
156                        enableLegend=True,
157                        lineWidth=2,
158                        autoScale=1,
159                        x_axis_key=owaxis.xBottom,
160                        y_axis_key=owaxis.yLeft,
161                        )
162
163        self.cumulative_variance_curve = self.scree_plot.add_curve(
164                        "Cumulative Variance",
165                        Qt.darkYellow, Qt.darkYellow, 2, 
166                        xData=[],
167                        yData=[],
168                        style=OWCurve.Lines,
169                        enableLegend=True,
170                        lineWidth=2,
171                        autoScale=1,
172                        x_axis_key=owaxis.xBottom,
173                        y_axis_key=owaxis.yLeft,
174                        )
175
176        self.mainArea.layout().addWidget(self.scree_plot)
177        self.connect(self.scree_plot,
178                     SIGNAL("cutoff_moved(double)"),
179                     self.on_cutoff_moved
180                     )
181        self.components = None
182        self.variances = None
183        self.variances_sum = None
184        self.projector_full = None
185
186        self.resize(800, 400)
187
188    def clear(self):
189        """Clear widget state
190        """
191        self.data = None
192        self.scree_plot.set_cutoff_curve_enabled(False)
193        self.clear_cached()
194        self.variance_curve.setVisible(False)
195        self.cumulative_variance_curve.setVisible(False)
196
197    def clear_cached(self):
198        """Clear cached components
199        """
200        self.components = None
201        self.variances = None
202        self.variances_cumsum = None
203        self.projector_full = None
204
205    def set_data(self, data=None):
206        """Set the widget input data.
207        """
208        self.clear()
209        if data is not None:
210            self.data = data
211            self.on_change()
212
213    def on_change(self):
214        if self.data is None:
215            return
216        self.clear_cached()
217        self.apply()
218
219    def on_update(self):
220        if self.data is None:
221            return
222        self.update_cutoff_curve()
223        self.update_components()
224
225    def construct_pca_all_comp(self):
226        pca = plinear.PCA(standardize=self.standardize,
227                          max_components=0,
228                          variance_covered=1,
229                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
230                          )
231        return pca
232
233    def construct_pca(self):
234        max_components = self.max_components
235        variance_covered = self.variance_covered
236        pca = plinear.PCA(standardize=self.standardize,
237                          max_components=max_components,
238                          variance_covered=variance_covered / 100.0,
239                          use_generalized_eigenvectors=self.use_generalized_eigenvectors
240                          )
241        return pca
242
243    def apply(self):
244        """Apply PCA in input data, caching the full projection,
245        then updating the selected components.
246       
247        """
248        pca = self.construct_pca_all_comp()
249        self.projector_full = projector = pca(self.data)
250
251        self.variances = self.projector_full.variances
252        self.variances /= np.sum(self.variances)
253        self.variances_cumsum = np.cumsum(self.variances)
254
255        self.max_components_spin.setRange(1, len(self.variances))
256        self.update_scree_plot()
257        self.update_components()
258
259    def update_components(self):
260        scale = self.projector_full.scale
261        center = self.projector_full.center
262        components = self.projector_full.projection
263        input_domain = self.projector_full.input_domain
264        variances = self.projector_full.variances
265        variance_sum = self.projector_full.variance_sum
266
267        # Get selected components (based on max_components and
268        # variance_coverd)
269        pca = self.construct_pca()
270        variances, components, variance_sum = pca._select_components(variances, components)
271
272        projector = plinear.PcaProjector(input_domain=input_domain,
273                                         standardize=self.standardize,
274                                         scale=scale,
275                                         center=center,
276                                         projection=components,
277                                         variances=variances,
278                                         variance_sum=variance_sum)
279        projected_data = projector(self.data)
280        eigenvectors = self.eigenvectors_as_table(components)
281
282        self.send("Transformed Data", projected_data)
283        self.send("Eigen Vectors", eigenvectors)
284
285    def eigenvectors_as_table(self, U):
286        features = [Orange.feature.Continuous("C%i" % i) \
287                    for i in range(1, U.shape[1] + 1)]
288        domain = Orange.data.Domain(features, False)
289        return Orange.data.Table(domain, [list(v) for v in U])
290
291    def update_scree_plot(self):
292        variances = self.projector_full.variances
293        s = np.sum(variances)
294        cv = variances / s
295        cs = np.cumsum(cv)
296        x_space = np.arange(0, len(variances))
297        self.scree_plot.set_axis_enabled(owaxis.xBottom, True)
298        self.scree_plot.set_axis_enabled(owaxis.yLeft, True)
299        self.scree_plot.set_axis_labels(owaxis.xBottom, 
300                                        ["PC" + str(i + 1) for i in x_space])
301
302        self.variance_curve.set_data(x_space, cv)
303        self.cumulative_variance_curve.set_data(x_space, cs)
304        self.variance_curve.setVisible(True)
305        self.cumulative_variance_curve.setVisible(True)
306
307        self.scree_plot.set_cutoff_curve_enabled(True)
308
309    def on_cutoff_moved(self, value):
310        components = int(np.floor(value)) + 1
311        if components != self.max_components:
312            self.max_components = int(np.floor(value)) + 1
313            self.variance_covered = self.variances_cumsum[self.max_components - 1] * 100
314            self.update_components()
315
316    def update_cutoff_curve(self):
317        """Update cutoff line from gui control elements.
318        """
319        variance = self.variances_cumsum[self.max_components - 1] * 100.0
320        if variance < self.variance_covered:
321            cutoff = float(self.max_components - 1)
322        else:
323            cutoff = np.searchsorted(self.variances_cumsum,
324                                     self.variance_covered / 100.0)
325        self.scree_plot.set_cutoff_value(cutoff + 0.5)
326
327
328if __name__ == "__main__":
329    app = QApplication(sys.argv)
330    w = OWPCA()
331    data = Orange.data.Table("iris")
332    w.set_data(data)
333    w.show()
334    app.exec_()
Note: See TracBrowser for help on using the repository browser.