Ignore:
Timestamp:
04/17/12 19:25:52 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Added cutoff line to scree plot.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/OrangeWidgets/Prototypes/OWPCA.py

    r10798 r10799  
    1313 
    1414import numpy as np 
    15  
    1615import sys 
    1716from Orange import orangeqt 
     
    2625    def __init__(self, parent=None, name="Scree Plot"): 
    2726        OWPlot.__init__(self, parent, name=name) 
     27        self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0], 
     28                x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft) 
     29        self.cutoff_curve.setVisible(False) 
     30        self.cutoff_curve.set_style(OWCurve.Lines) 
     31        self.add_custom_curve(self.cutoff_curve) 
     32 
     33    def is_cutoff_enabled(self): 
     34        return self.cutoff_curve and self.cutoff_curve.isVisible() 
     35 
     36    def set_cutoff_curve_enabled(self, state): 
     37        self.cutoff_curve.setVisible(state) 
     38 
     39    def set_cutoff_value(self, value): 
     40        xmin, xmax = self.x_scale() 
     41        x = min(max(value, xmin), xmax) 
     42        self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
     43 
     44    def mousePressEvent(self, event): 
     45        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton: 
     46            pos = self.mapToScene(event.pos()) 
     47            x, _  = self.map_from_graph(pos) 
     48            xmin, xmax = self.x_scale() 
     49            if x >= xmin and x <= xmax: 
     50                self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
     51                self.emit_cutoff_moved(x) 
     52        return QGraphicsView.mousePressEvent(self, event) 
     53 
     54    def mouseMoveEvent(self, event): 
     55        if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton: 
     56            pos = self.mapToScene(event.pos()) 
     57            x, _ = self.map_from_graph(pos) 
     58            xmin, xmax = self.x_scale() 
     59            if x >= xmin and x <= xmax: 
     60                self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
     61                self.emit_cutoff_moved(x) 
     62        return QGraphicsView.mouseMoveEvent(self, event) 
     63 
     64    def mouseReleaseEvene(self, event): 
     65        return QGraphicsView.mouseReleaseEvent(self, event) 
     66 
     67    def x_scale(self): 
     68        ax = self.axes[owaxis.xBottom] 
     69        if ax.labels: 
     70            return 0, len(ax.labels) - 0.5 
     71        elif ax.scale: 
     72            return ax.scale[0], ax.scale[1] + 0.5 
     73        else: 
     74            raise ValueError 
     75 
     76    def emit_cutoff_moved(self, x): 
     77        self.emit(SIGNAL("cutoff_moved(double)"), x) 
     78 
     79    def set_axis_labels(self, *args): 
     80        OWPlot.set_axis_labels(self, *args) 
     81        self.map_transform = self.transform_for_axes() 
     82 
     83class CutoffCurve(OWCurve): 
     84    def __init__(self, *args, **kwargs): 
     85        OWCurve.__init__(self, *args, **kwargs) 
     86        self.setAcceptHoverEvents(True) 
     87        self.setCursor(Qt.SizeHorCursor) 
    2888 
    2989class OWPCA(OWWidget): 
     
    3393        OWWidget.__init__(self, parent, signalManager, title) 
    3494 
    35         self.inputs = [("Data", Orange.data.Table, self.set_data)] 
    36         self.outputs = [("Projected Data", Orange.data.Table, Default), 
    37                         ("PCA Projector", Orange.projection.linear.PcaProjector), 
    38                         ("Principal Vectors", Orange.data.Table) 
    39                         ] 
     95        self.inputs = [("Input Data", Orange.data.Table, self.set_data)] 
     96        self.outputs = [("Transformed Data", Orange.data.Table, Default), 
     97                        ("Eigen Vectors", Orange.data.Table)] 
    4098 
    4199        self.standardize = True 
    42         self.use_variance_covered = 0 
    43100        self.max_components = 0 
    44101        self.variance_covered = 100.0 
     
    55112        box = OWGUI.widgetBox(self.controlArea, "Settings", 
    56113                              orientation=grid) 
    57         cb = OWGUI.checkBox(box, self, "standardize", "Standardize", 
    58                             tooltip="Standardize all input features.", 
    59                             callback=self.on_change,  
    60                             addToLayout=False 
    61                             ) 
    62         grid.addWidget(cb, 0, 0) 
    63  
    64 #        OWGUI.radioButtonsInBox(box, self, "use_variance_covered", [], 
    65 #                                callback=self.on_update) 
    66 #        rb1 = OWGUI.appendRadioButton(box, self, "use_variance_covered", 
    67 #                                      "Max components", 
    68 #                                      tooltip="Select max components", 
    69 #                                      callback=self.on_update, 
    70 #                                      addToLayout=False 
    71 #                                      ) 
    72 #        grid.addWidget(rb1, 1, 0) 
     114 
    73115        label1 = QLabel("Max components", box) 
    74116        grid.addWidget(label1, 1, 0) 
    75117 
    76         sb1 = OWGUI.spin(box, self, "max_components", 0, 1000, 
     118        sb1 = OWGUI.spin(box, self, "max_components", 1, 1000, 
    77119                         tooltip="Maximum number of components", 
    78                          callback=self.on_change, 
     120                         callback=self.on_update, 
    79121                         addToLayout=False, 
    80122                         keyboardTracking=False 
    81123                         ) 
    82         sb1.control.setSpecialValueText("All") 
     124        self.max_components_spin = sb1.control 
    83125        grid.addWidget(sb1.control, 1, 1) 
    84126 
    85 #        rb2 = OWGUI.appendRadioButton(box, self, "use_variance_covered", 
    86 #                                      "Variance covered",  
    87 #                                      tooltip="Percent of variance covered.", 
    88 #                                      callback=self.on_update, 
    89 #                                      addToLayout=False 
    90 #                                      ) 
    91 #        grid.addWidget(rb2, 2, 0) 
    92127        label2 = QLabel("Variance covered", box) 
    93128        grid.addWidget(label2, 2, 0) 
    94129 
    95         sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 5.0, 
     130        sb2 = OWGUI.doubleSpin(box, self, "variance_covered", 1.0, 100.0, 1.0, 
    96131                               tooltip="Percent of variance covered.", 
    97                                callback=self.on_change, 
     132                               callback=self.on_update, 
    98133                               decimals=1, 
    99134                               addToLayout=False, 
     
    102137        sb2.control.setSuffix("%") 
    103138        grid.addWidget(sb2.control, 2, 1) 
    104  
    105         cb = OWGUI.checkBox(box, self, "use_generalized_eigenvectors", 
    106                             "Use generalized eigenvectors", 
    107                             callback=self.on_change, 
    108                             addToLayout=False, 
    109                             ) 
    110         grid.addWidget(cb, 3, 0, 1, 2) 
    111139 
    112140        OWGUI.rubber(self.controlArea) 
     
    119147        self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance") 
    120148        self.scree_plot.set_show_axis_title(owaxis.yLeft, 1) 
    121          
     149 
     150        self.variance_curve = self.scree_plot.add_curve( 
     151                        "Variance", 
     152                        Qt.red, Qt.red, 2,  
     153                        xData=[], 
     154                        yData=[], 
     155                        style=OWCurve.Lines, 
     156                        enableLegend=True, 
     157                        lineWidth=2, 
     158                        autoScale=1, 
     159                        x_axis_key=owaxis.xBottom, 
     160                        y_axis_key=owaxis.yLeft, 
     161                        ) 
     162 
     163        self.cumulative_variance_curve = self.scree_plot.add_curve( 
     164                        "Cumulative Variance", 
     165                        Qt.darkYellow, Qt.darkYellow, 2,  
     166                        xData=[], 
     167                        yData=[], 
     168                        style=OWCurve.Lines, 
     169                        enableLegend=True, 
     170                        lineWidth=2, 
     171                        autoScale=1, 
     172                        x_axis_key=owaxis.xBottom, 
     173                        y_axis_key=owaxis.yLeft, 
     174                        ) 
     175 
    122176        self.mainArea.layout().addWidget(self.scree_plot) 
    123  
     177        self.connect(self.scree_plot, 
     178                     SIGNAL("cutoff_moved(double)"), 
     179                     self.on_cutoff_moved 
     180                     ) 
    124181        self.components = None 
    125182        self.variances = None 
     
    127184        self.projector_full = None 
    128185 
    129         self.resize(800, 600) 
     186        self.resize(800, 400) 
    130187 
    131188    def clear(self): 
     
    133190        """ 
    134191        self.data = None 
     192        self.scree_plot.set_cutoff_curve_enabled(False) 
    135193        self.clear_cached() 
    136          
     194        self.variance_curve.setVisible(False) 
     195        self.cumulative_variance_curve.setVisible(False) 
     196 
    137197    def clear_cached(self): 
    138198        """Clear cached components 
     
    140200        self.components = None 
    141201        self.variances = None 
    142         self.variances_sum = None 
     202        self.variances_cumsum = None 
    143203        self.projector_full = None 
    144204 
     
    160220        if self.data is None: 
    161221            return 
     222        self.update_cutoff_curve() 
    162223        self.update_components() 
    163224 
     
    171232 
    172233    def construct_pca(self): 
    173         max_components = self.max_components #if not self.use_variance_covered else 0 
    174         variance_covered = self.variance_covered #if self.use_variance_covered else 0 
     234        max_components = self.max_components 
     235        variance_covered = self.variance_covered 
    175236        pca = plinear.PCA(standardize=self.standardize, 
    176237                          max_components=max_components, 
     
    187248        pca = self.construct_pca_all_comp() 
    188249        self.projector_full = projector = pca(self.data) 
     250 
     251        self.variances = self.projector_full.variances 
     252        self.variances /= np.sum(self.variances) 
     253        self.variances_cumsum = np.cumsum(self.variances) 
     254 
     255        self.max_components_spin.setRange(1, len(self.variances)) 
    189256        self.update_scree_plot() 
    190257        self.update_components() 
     
    213280        eigenvectors = self.eigenvectors_as_table(components) 
    214281 
    215         self.send("Projected Data", projected_data) 
    216         self.send("PCA Projector", projector) 
    217         self.send("Principal Vectors", eigenvectors) 
     282        self.send("Transformed Data", projected_data) 
     283        self.send("Eigen Vectors", eigenvectors) 
    218284 
    219285    def eigenvectors_as_table(self, U): 
     
    234300                                        ["PC" + str(i + 1) for i in x_space]) 
    235301 
    236         self.c = self.scree_plot.add_curve("Variance", 
    237                         Qt.red, Qt.red, 2,  
    238                         xData=x_space, 
    239                         yData=cv, 
    240                         style=OWCurve.Lines, 
    241                         enableLegend=True, 
    242                         lineWidth=2, 
    243                         autoScale=1, 
    244                         x_axis_key=owaxis.xBottom, 
    245                         y_axis_key=owaxis.yLeft, 
    246                         ) 
    247          
    248         self.c = self.scree_plot.add_curve("Cumulative Variance", 
    249                         Qt.darkYellow, Qt.darkYellow, 2,  
    250                         xData=x_space, 
    251                         yData=cs, 
    252                         style=OWCurve.Lines, 
    253                         enableLegend=True, 
    254                         lineWidth=2, 
    255                         autoScale=1, 
    256                         x_axis_key=owaxis.xBottom, 
    257                         y_axis_key=owaxis.yLeft, 
    258                         ) 
     302        self.variance_curve.set_data(x_space, cv) 
     303        self.cumulative_variance_curve.set_data(x_space, cs) 
     304        self.variance_curve.setVisible(True) 
     305        self.cumulative_variance_curve.setVisible(True) 
     306 
     307        self.scree_plot.set_cutoff_curve_enabled(True) 
     308 
     309    def on_cutoff_moved(self, value): 
     310        components = int(np.floor(value)) + 1 
     311        if components != self.max_components: 
     312            self.max_components = int(np.floor(value)) + 1 
     313            self.variance_covered = self.variances_cumsum[self.max_components - 1] * 100 
     314            self.update_components() 
     315 
     316    def update_cutoff_curve(self): 
     317        """Update cutoff line from gui control elements. 
     318        """ 
     319        variance = self.variances_cumsum[self.max_components - 1] * 100.0 
     320        if variance < self.variance_covered: 
     321            cutoff = float(self.max_components - 1) 
     322        else: 
     323            cutoff = np.searchsorted(self.variances_cumsum, 
     324                                     self.variance_covered / 100.0) 
     325        self.scree_plot.set_cutoff_value(cutoff + 0.5) 
     326 
    259327 
    260328if __name__ == "__main__": 
Note: See TracChangeset for help on using the changeset viewer.