Changeset 11820:399588413aa1 in orange


Ignore:
Timestamp:
12/30/13 15:34:50 (4 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Using PyQwt5 for PCA scree plot.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/OrangeWidgets/Unsupervised/OWPCA.py

    r11454 r11820  
    88 
    99""" 
    10 import Orange 
    11 import Orange.utils.addons 
    12  
    13 from OWWidget import * 
    14 import OWGUI 
     10import sys 
     11 
     12import numpy as np 
     13 
     14from PyQt4.Qwt5 import QwtPlot, QwtPlotCurve, QwtSymbol 
     15from PyQt4.QtCore import pyqtSignal as Signal, pyqtSlot as Slot 
    1516 
    1617import Orange 
    1718import Orange.projection.linear as plinear 
    1819 
    19 import numpy as np 
    20 import sys 
    21  
    22 from plot.owplot import OWPlot 
    23 from plot.owcurve import OWCurve 
    24 from plot import owaxis 
    25  
    26  
    27 class ScreePlot(OWPlot): 
    28     def __init__(self, parent=None, name="Scree Plot"): 
    29         OWPlot.__init__(self, parent, name=name) 
    30         self.cutoff_curve = CutoffCurve([0.0, 0.0], [0.0, 1.0], 
    31                 x_axis_key=owaxis.xBottom, y_axis_key=owaxis.yLeft) 
    32         self.cutoff_curve.setVisible(False) 
    33         self.cutoff_curve.set_style(OWCurve.Lines) 
    34         self.add_custom_curve(self.cutoff_curve) 
    35  
    36     def is_cutoff_enabled(self): 
    37         return self.cutoff_curve and self.cutoff_curve.isVisible() 
    38  
    39     def set_cutoff_curve_enabled(self, state): 
    40         self.cutoff_curve.setVisible(state) 
    41  
    42     def set_cutoff_value(self, value): 
    43         xmin, xmax = self.x_scale() 
    44         x = min(max(value, xmin), xmax) 
    45         self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
    46  
     20from OWWidget import * 
     21from OWGraph import OWGraph 
     22 
     23import OWGUI 
     24 
     25 
     26def plot_curve(title=None, pen=None, brush=None, style=QwtPlotCurve.Lines, 
     27               symbol=QwtSymbol.Ellipse, legend=True, antialias=True, 
     28               auto_scale=True, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft): 
     29    curve = QwtPlotCurve(title or "") 
     30    return configure_curve(curve, pen=pen, brush=brush, style=style, 
     31                           symbol=symbol, legend=legend, antialias=antialias, 
     32                           auto_scale=auto_scale, xaxis=xaxis, yaxis=yaxis) 
     33 
     34 
     35def configure_curve(curve, title=None, pen=None, brush=None, 
     36          style=QwtPlotCurve.Lines, symbol=QwtSymbol.Ellipse, 
     37          legend=True, antialias=True, auto_scale=True, 
     38          xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft): 
     39    if title is not None: 
     40        curve.setTitle(title) 
     41    if pen is not None: 
     42        curve.setPen(pen) 
     43 
     44    if brush is not None: 
     45        curve.setBrush(brush) 
     46 
     47    if not isinstance(symbol, QwtSymbol): 
     48        symbol_ = QwtSymbol() 
     49        symbol_.setStyle(symbol) 
     50        symbol = symbol_ 
     51 
     52    curve.setStyle(style) 
     53    curve.setSymbol(QwtSymbol(symbol)) 
     54    curve.setRenderHint(QwtPlotCurve.RenderAntialiased, antialias) 
     55    curve.setItemAttribute(QwtPlotCurve.Legend, legend) 
     56    curve.setItemAttribute(QwtPlotCurve.AutoScale, auto_scale) 
     57    curve.setAxis(xaxis, yaxis) 
     58    return curve 
     59 
     60 
     61class PlotTool(QObject): 
     62    """ 
     63    A base class for Plot tools that operate on QwtPlot's canvas 
     64    widget by installing itself as its event filter. 
     65 
     66    """ 
     67    cursor = Qt.ArrowCursor 
     68 
     69    def __init__(self, parent=None, graph=None): 
     70        QObject.__init__(self, parent) 
     71        self.__graph = None 
     72        self.__oldCursor = None 
     73        self.setGraph(graph) 
     74 
     75    def setGraph(self, graph): 
     76        """ 
     77        Install this tool to operate on ``graph``. 
     78        """ 
     79        if self.__graph is graph: 
     80            return 
     81 
     82        if self.__graph is not None: 
     83            self.uninstall(self.__graph) 
     84 
     85        self.__graph = graph 
     86 
     87        if graph is not None: 
     88            self.install(graph) 
     89 
     90    def graph(self): 
     91        return self.__graph 
     92 
     93    def install(self, graph): 
     94        canvas = graph.canvas() 
     95        canvas.setMouseTracking(True) 
     96        canvas.installEventFilter(self) 
     97        canvas.destroyed.connect(self.__on_destroyed) 
     98        self.__oldCursor = canvas.cursor() 
     99        canvas.setCursor(self.cursor) 
     100 
     101    def uninstall(self, graph): 
     102        canvas = graph.canvas() 
     103        canvas.removeEventFilter(self) 
     104        canvas.setCursor(self.__oldCursor) 
     105        canvas.destroyed.disconnect(self.__on_destroyed) 
     106        self.__oldCursor = None 
     107 
     108    def eventFilter(self, obj, event): 
     109        if obj is self.__graph.canvas(): 
     110            return self.canvasEvent(event) 
     111        return False 
     112 
     113    def canvasEvent(self, event): 
     114        """ 
     115        Main handler for a canvas events. 
     116        """ 
     117        if event.type() == QEvent.MouseButtonPress: 
     118            return self.mousePressEvent(event) 
     119        elif event.type() == QEvent.MouseButtonRelease: 
     120            return self.mouseReleaseEvent(event) 
     121        elif event.type() == QEvent.MouseButtonDblClick: 
     122            return self.mouseDoubleClickEvent(event) 
     123        elif event.type() == QEvent.MouseMove: 
     124            return self.mouseMoveEvent(event) 
     125        elif event.type() == QEvent.Leave: 
     126            return self.leaveEvent(event) 
     127        elif event.type() == QEvent.Enter: 
     128            return self.enterEvent(event) 
     129        return False 
     130 
     131    # These are actually event filters (note the return values) 
    47132    def mousePressEvent(self, event): 
    48         if self.isLegendEvent(event, QGraphicsView.mousePressEvent): 
    49             return 
    50  
    51         if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton: 
    52             pos = self.mapToScene(event.pos()) 
    53             x, _ = self.map_from_graph(pos) 
    54             xmin, xmax = self.x_scale() 
    55             if x >= xmin - 0.1 and x <= xmax + 0.1: 
    56                 x = min(max(x, xmin), xmax) 
    57                 self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
    58                 self.emit_cutoff_moved(x) 
    59         return QGraphicsView.mousePressEvent(self, event) 
     133        return False 
    60134 
    61135    def mouseMoveEvent(self, event): 
    62         if self.isLegendEvent(event, QGraphicsView.mouseMoveEvent): 
    63             return 
    64  
    65         if self.is_cutoff_enabled() and event.buttons() & Qt.LeftButton: 
    66             pos = self.mapToScene(event.pos()) 
    67             x, _ = self.map_from_graph(pos) 
    68             xmin, xmax = self.x_scale() 
    69             if x >= xmin - 0.5 and x <= xmax + 0.5: 
    70                 x = min(max(x, xmin), xmax) 
    71                 self.cutoff_curve.set_data([x, x], [0.0, 1.0]) 
    72                 self.emit_cutoff_moved(x) 
    73         elif self.is_cutoff_enabled() and \ 
    74                 self.is_pos_over_cutoff_line(event.pos()): 
    75             self.setCursor(Qt.SizeHorCursor) 
     136        return False 
     137 
     138    def mouseReleaseEvent(self, event): 
     139        return False 
     140 
     141    def mouseDoubleClickEvent(self, event): 
     142        return False 
     143 
     144    def enterEvent(self, event): 
     145        return False 
     146 
     147    def leaveEvent(self, event): 
     148        return False 
     149 
     150    def keyPressEvent(self, event): 
     151        return False 
     152 
     153    def transform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft): 
     154        """ 
     155        Transform a QPointF from plot coordinates to canvas local coordinates. 
     156        """ 
     157        x = self.__graph.transform(xaxis, point.x()) 
     158        y = self.__graph.transform(yaxis, point.y()) 
     159        return QPoint(x, y) 
     160 
     161    def invTransform(self, point, xaxis=QwtPlot.xBottom, yaxis=QwtPlot.yLeft): 
     162        """ 
     163        Transform a QPoint from canvas local coordinates to plot coordinates. 
     164        """ 
     165        x = self.__graph.invTransform(xaxis, point.x()) 
     166        y = self.__graph.invTransform(yaxis, point.y()) 
     167        return QPointF(x, y) 
     168 
     169    @Slot() 
     170    def __on_destroyed(self, obj): 
     171        obj.removeEventFilter(self) 
     172 
     173 
     174class CutoffControler(PlotTool): 
     175 
     176    class CutoffCurve(QwtPlotCurve): 
     177        pass 
     178 
     179    cutoffChanged = Signal(float) 
     180    cutoffMoved = Signal(float) 
     181    cutoffPressed = Signal() 
     182    cutoffReleased = Signal() 
     183 
     184    NoState, Drag = 0, 1 
     185 
     186    def __init__(self, parent=None, graph=None): 
     187        self.__curve = None 
     188        self.__range = (0, 1) 
     189        self.__cutoff = 0 
     190        super(CutoffControler, self).__init__(parent, graph) 
     191        self._state = self.NoState 
     192 
     193    def install(self, graph): 
     194        super(CutoffControler, self).install(graph) 
     195        assert self.__curve is None 
     196        self.__curve = CutoffControler.CutoffCurve("") 
     197        configure_curve(self.__curve, symbol=QwtSymbol.NoSymbol, legend=False) 
     198        self.__curve.setData([self.__cutoff, self.__cutoff], [0.0, 1.0]) 
     199        self.__curve.attach(graph) 
     200 
     201    def uninstall(self, graph): 
     202        super(CutoffControler, self).uninstall(graph) 
     203        self.__curve.detach() 
     204        self.__curve = None 
     205 
     206    def _toRange(self, value): 
     207        minval, maxval = self.__range 
     208        return max(min(value, maxval), minval) 
     209 
     210    def mousePressEvent(self, event): 
     211        if event.button() == Qt.LeftButton: 
     212            cut = self.invTransform(event.pos()).x() 
     213            self.setCutoff(cut) 
     214            self.cutoffPressed.emit() 
     215            self._state = self.Drag 
     216        return True 
     217 
     218    def mouseMoveEvent(self, event): 
     219        if self._state == self.Drag: 
     220            cut = self._toRange(self.invTransform(event.pos()).x()) 
     221            self.setCutoff(cut) 
     222            self.cutoffMoved.emit(cut) 
    76223        else: 
    77             self.setCursor(Qt.ArrowCursor) 
    78  
    79         return QGraphicsView.mouseMoveEvent(self, event) 
    80  
    81     def mouseReleaseEvene(self, event): 
    82         return QGraphicsView.mouseReleaseEvent(self, event) 
    83  
    84     def x_scale(self): 
    85         ax = self.axes[owaxis.xBottom] 
    86         if ax.labels: 
    87             return 0, len(ax.labels) - 1 
    88         elif ax.scale: 
    89             return ax.scale[0], ax.scale[1] 
    90         else: 
    91             raise ValueError 
    92  
    93     def emit_cutoff_moved(self, x): 
    94         self.emit(SIGNAL("cutoff_moved(double)"), x) 
    95  
    96     def set_axis_labels(self, *args): 
    97         OWPlot.set_axis_labels(self, *args) 
    98         self.map_transform = self.transform_for_axes() 
    99  
    100     def is_pos_over_cutoff_line(self, pos): 
    101         x1 = self.inv_transform(owaxis.xBottom, pos.x() - 1.5) 
    102         x2 = self.inv_transform(owaxis.xBottom, pos.x() + 1.5) 
    103         y = self.inv_transform(owaxis.yLeft, pos.y()) 
    104         if y < 0.0 or y > 1.0: 
    105             return False 
    106         curve_data = self.cutoff_curve.data() 
    107         if not curve_data: 
    108             return False 
    109         cutoff = curve_data[0][0] 
    110         return x1 < cutoff and cutoff < x2 
    111  
    112  
    113 class CutoffCurve(OWCurve): 
     224            cx = self.transform(QPointF(self.cutoff(), 0)).x() 
     225            if abs(cx - event.pos().x()) < 2: 
     226                self.graph().canvas().setCursor(Qt.SizeHorCursor) 
     227            else: 
     228                self.graph().canvas().setCursor(self.cursor) 
     229        return True 
     230 
     231    def mouseReleaseEvent(self, event): 
     232        if event.button() == Qt.LeftButton and self._state == self.Drag: 
     233            cut = self._toRange(self.invTransform(event.pos()).x()) 
     234            self.setCutoff(cut) 
     235            self.cutoffReleased.emit() 
     236            self._state = self.NoState 
     237        return True 
     238 
     239    def setCutoff(self, cutoff): 
     240        minval, maxval = self.__range 
     241        cutoff = max(min(cutoff, maxval), minval) 
     242        if self.__cutoff != cutoff: 
     243            self.__cutoff = cutoff 
     244            if self.__curve is not None: 
     245                self.__curve.setData([cutoff, cutoff], [0.0, 1.0]) 
     246            self.cutoffChanged.emit(cutoff) 
     247            if self.graph() is not None: 
     248                self.graph().replot() 
     249 
     250    def cutoff(self): 
     251        return self.__cutoff 
     252 
     253    def setRange(self, minval, maxval): 
     254        maxval = max(minval, maxval) 
     255        if self.__range != (minval, maxval): 
     256            self.__range = (minval, maxval) 
     257            self.setCutoff(max(min(self.cutoff(), maxval), minval)) 
     258 
     259 
     260class Graph(OWGraph): 
    114261    def __init__(self, *args, **kwargs): 
    115         OWCurve.__init__(self, *args, **kwargs) 
    116         self.setAcceptHoverEvents(True) 
    117         self.setCursor(Qt.SizeHorCursor) 
     262        super(Graph, self).__init__(*args, **kwargs) 
     263        self.gridCurve.attach(self) 
     264 
     265    # bypass the OWGraph event handlers 
     266    def mousePressEvent(self, event): 
     267        QwtPlot.mousePressEvent(self, event) 
     268 
     269    def mouseMoveEvent(self, event): 
     270        QwtPlot.mouseMoveEvent(self, event) 
     271 
     272    def mouseReleaseEvent(self, event): 
     273        QwtPlot.mouseReleaseEvent(self, event) 
    118274 
    119275 
     
    181337        OWGUI.setStopper(self, b, cb, "changed_flag", self.update_components) 
    182338 
    183         self.scree_plot = ScreePlot(self) 
    184 #        self.scree_plot.set_main_title("Scree Plot") 
    185 #        self.scree_plot.set_show_main_title(True) 
    186         self.scree_plot.set_axis_title(owaxis.xBottom, "Principal Components") 
    187         self.scree_plot.set_show_axis_title(owaxis.xBottom, 1) 
    188         self.scree_plot.set_axis_title(owaxis.yLeft, "Proportion of Variance") 
    189         self.scree_plot.set_show_axis_title(owaxis.yLeft, 1) 
    190  
    191         self.variance_curve = self.scree_plot.add_curve( 
    192                         "Variance", 
    193                         Qt.red, Qt.red, 2, 
    194                         xData=[], 
    195                         yData=[], 
    196                         style=OWCurve.Lines, 
    197                         enableLegend=True, 
    198                         lineWidth=2, 
    199                         autoScale=1, 
    200                         x_axis_key=owaxis.xBottom, 
    201                         y_axis_key=owaxis.yLeft, 
    202                         ) 
    203  
    204         self.cumulative_variance_curve = self.scree_plot.add_curve( 
    205                         "Cumulative Variance", 
    206                         Qt.darkYellow, Qt.darkYellow, 2, 
    207                         xData=[], 
    208                         yData=[], 
    209                         style=OWCurve.Lines, 
    210                         enableLegend=True, 
    211                         lineWidth=2, 
    212                         autoScale=1, 
    213                         x_axis_key=owaxis.xBottom, 
    214                         y_axis_key=owaxis.yLeft, 
    215                         ) 
    216  
    217         self.mainArea.layout().addWidget(self.scree_plot) 
    218         self.connect(self.scree_plot, 
    219                      SIGNAL("cutoff_moved(double)"), 
    220                      self.on_cutoff_moved 
    221                      ) 
    222  
    223         self.connect(self.graphButton, 
    224                      SIGNAL("clicked()"), 
    225                      self.scree_plot.save_to_file) 
    226  
     339        self.plot = Graph() 
     340        canvas = self.plot.canvas() 
     341        canvas.setFrameStyle(QFrame.StyledPanel) 
     342        self.mainArea.layout().addWidget(self.plot) 
     343        self.plot.setAxisTitle(QwtPlot.yLeft, "Proportion of Variance") 
     344        self.plot.setAxisTitle(QwtPlot.xBottom, "Principal Components") 
     345        self.plot.setAxisScale(QwtPlot.yLeft, 0.0, 1.0) 
     346        self.plot.enableGridXB(True) 
     347        self.plot.enableGridYL(True) 
     348        self.plot.setGridColor(Qt.lightGray) 
     349 
     350        self.variance_curve = plot_curve( 
     351            "Variance", 
     352            pen=QPen(Qt.red, 2), 
     353            symbol=QwtSymbol.NoSymbol, 
     354            xaxis=QwtPlot.xBottom, 
     355            yaxis=QwtPlot.yLeft 
     356        ) 
     357        self.cumulative_variance_curve = plot_curve( 
     358            "Cumulative Variance", 
     359            pen=QPen(Qt.darkYellow, 2), 
     360            symbol=QwtSymbol.NoSymbol, 
     361            xaxis=QwtPlot.xBottom, 
     362            yaxis=QwtPlot.yLeft 
     363        ) 
     364 
     365        self.variance_curve.attach(self.plot) 
     366        self.cumulative_variance_curve.attach(self.plot) 
     367 
     368        self.selection_tool = CutoffControler(parent=self.plot.canvas()) 
     369        self.selection_tool.cutoffMoved.connect(self.on_cutoff_moved) 
     370 
     371        self.graphButton.clicked.connect(self.saveToFile) 
    227372        self.components = None 
    228373        self.variances = None 
     
    234379 
    235380    def clear(self): 
    236         """Clear widget state 
     381        """ 
     382        Clear (reset) the widget state. 
    237383        """ 
    238384        self.data = None 
    239         self.scree_plot.set_cutoff_curve_enabled(False) 
     385        self.selection_tool.setGraph(None) 
    240386        self.clear_cached() 
    241387        self.variance_curve.setVisible(False) 
     
    304450        """ 
    305451        pca = self.construct_pca_all_comp() 
    306         self.projector_full = projector = pca(self.data) 
     452        self.projector_full = pca(self.data) 
    307453 
    308454        self.variances = self.projector_full.variances 
     
    334480        input_domain = self.projector_full.input_domain 
    335481        variances = self.projector_full.variances 
    336         variance_sum = self.projector_full.variance_sum 
    337482 
    338483        # Get selected components (based on max_components and 
     
    369514    def update_scree_plot(self): 
    370515        x_space = np.arange(0, len(self.variances)) 
    371         self.scree_plot.set_axis_enabled(owaxis.xBottom, True) 
    372         self.scree_plot.set_axis_enabled(owaxis.yLeft, True) 
    373         self.scree_plot.set_axis_labels(owaxis.xBottom, 
    374                                         ["PC" + str(i + 1) for i in x_space]) 
    375  
    376         self.variance_curve.set_data(x_space, self.variances) 
    377         self.cumulative_variance_curve.set_data(x_space, self.variances_cumsum) 
     516        self.plot.enableAxis(QwtPlot.xBottom, True) 
     517        self.plot.enableAxis(QwtPlot.yLeft, True) 
     518        if len(x_space) <= 5: 
     519            self.plot.setXlabels(["PC" + str(i + 1) for i in x_space]) 
     520        else: 
     521            # Restore continuous plot scale 
     522            # TODO: disable minor ticks 
     523            self.plot.setXlabels(None) 
     524 
     525        self.variance_curve.setData(x_space, self.variances) 
     526        self.cumulative_variance_curve.setData(x_space, self.variances_cumsum) 
    378527        self.variance_curve.setVisible(True) 
    379528        self.cumulative_variance_curve.setVisible(True) 
    380529 
    381         self.scree_plot.set_cutoff_curve_enabled(True) 
    382         self.scree_plot.replot() 
     530        self.selection_tool.setRange(0, len(self.variances) - 1) 
     531        self.selection_tool.setGraph(self.plot) 
     532        self.plot.replot() 
    383533 
    384534    def on_cutoff_moved(self, value): 
     
    403553        variance = self.variances_cumsum[max_components - 1] * 100.0 
    404554        if variance < self.variance_covered: 
    405             cutoff = float(max_components - 1) 
     555            cutoff = max_components - 1 
    406556        else: 
    407557            cutoff = np.searchsorted(self.variances_cumsum, 
    408558                                     self.variance_covered / 100.0) 
    409         self.scree_plot.set_cutoff_value(cutoff + 0.5) 
     559 
     560        self.selection_tool.setCutoff(float(cutoff + 0.5)) 
    410561 
    411562    def number_of_selected_components(self): 
     
    450601 
    451602            self.reportSection("Scree Plot") 
    452             self.reportImage(self.scree_plot.save_to_file_direct) 
     603            self.reportImage(self.plot.saveToFileDirect) 
     604 
     605    def saveToFile(self): 
     606        self.plot.saveToFile() 
    453607 
    454608 
     
    482636    w.set_data(data) 
    483637    w.show() 
     638    w.set_data(Orange.data.Table("brown-selected")) 
    484639    app.exec_() 
Note: See TracChangeset for help on using the changeset viewer.