Changeset 8721:0cd372efab82 in orange


Ignore:
Timestamp:
08/17/11 18:20:11 (3 years ago)
Author:
matejd <matejd@…>
Branch:
default
Convert:
0048a693064f6c4ef87ed66cc2bc97bd5a12b539
Message:

New approach, no reloads necessary anymore; also quite big refactor, methods moved to more appropriate classes

Location:
orange/OrangeWidgets
Files:
4 added
6 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Prototypes/OWScatterPlot3D.py

    r8568 r8721  
    1 """<name> 3D Scatterplot</name> 
    2 """ 
     1'''<name>Scatterplot 3D</name> 
     2''' 
    33 
    44from OWWidget import * 
     
    1111import OWGUI 
    1212import OWToolbars 
    13 import OWColorPalette 
    1413import orngVizRank 
    1514from OWkNNOptimization import * 
     
    1918 
    2019TooltipKind = enum('NONE', 'VISIBLE', 'ALL') # Which attributes should be displayed in tooltips? 
     20 
     21class ScatterPlotTheme(PlotTheme): 
     22    def __init__(self): 
     23        super(ScatterPlotTheme, self).__init__() 
     24        self.grid_color = [0.8, 0.8, 0.8, 1.] 
     25 
     26class LightTheme(ScatterPlotTheme): 
     27    pass 
     28 
     29class DarkTheme(ScatterPlotTheme): 
     30    def __init__(self): 
     31        super(DarkTheme, self).__init__() 
     32        self.grid_color = [0.3, 0.3, 0.3, 1.] 
     33        self.labels_color = [0.9, 0.9, 0.9, 1.] 
     34        self.helpers_color = [0.9, 0.9, 0.9, 1.] 
     35        self.axis_values_color = [0.7, 0.7, 0.7, 1.] 
     36        self.axis_color = [0.8, 0.8, 0.8, 1.] 
     37        self.background_color = [0., 0., 0., 1.] 
    2138 
    2239class ScatterPlot(OWPlot3D, orngScaleScatterPlotData): 
     
    2542        orngScaleScatterPlotData.__init__(self) 
    2643 
    27     def set_data(self, data, subsetData=None, **args): 
    28         orngScaleScatterPlotData.setData(self, data, subsetData, **args) 
     44        self.disc_palette = ColorPaletteGenerator() 
     45        self._theme = LightTheme() 
     46        self.show_grid = True 
     47        self.show_chassis = True 
     48 
     49    def set_data(self, data, subset_data=None, **args): 
     50        orngScaleScatterPlotData.set_data(self, data, subset_data, **args) 
     51        OWPlot3D.set_data(self, self.no_jittering_scaled_data, self.no_jittering_scaled_subset_data) 
     52        # TODO: wire jitter settings (actual jittering done in geometry shader) 
     53 
     54    def update_data(self, x_attr, y_attr, z_attr, 
     55                    color_attr, symbol_attr, size_attr, label_attr): 
     56        self.before_draw_callback = self.before_draw 
     57 
     58        color_discrete = symbol_discrete = size_discrete = False 
     59 
     60        color_index = -1 
     61        if color_attr != '' and color_attr != '(Same color)': 
     62            color_index = self.attribute_name_index[color_attr] 
     63            if self.data_domain[color_attr].varType == Discrete: 
     64                color_discrete = True 
     65                self.disc_palette.setNumberOfColors(len(self.data_domain[color_attr].values)) 
     66 
     67        symbol_index = -1 
     68        num_symbols_used = -1 
     69        if symbol_attr != '' and symbol_attr != 'Same symbol)' and\ 
     70           len(self.data_domain[symbol_attr].values) < len(Symbol): 
     71            symbol_index = self.attribute_name_index[symbol_attr] 
     72            if self.data_domain[symbol_attr].varType == Discrete: 
     73                symbol_discrete = True 
     74                num_symbols_used = len(self.data_domain[symbol_attr].values) 
     75 
     76        size_index = -1 
     77        if size_attr != '' and size_attr != '(Same size)': 
     78            size_index = self.attribute_name_index[size_attr] 
     79            if self.data_domain[size_attr].varType == Discrete: 
     80                size_discrete = True 
     81 
     82        label_index = -1 
     83        if label_attr != '' and label_attr != '(No labels)': 
     84            label_index = self.attribute_name_index[label_attr] 
     85 
     86        x_index = self.attribute_name_index[x_attr] 
     87        y_index = self.attribute_name_index[y_attr] 
     88        z_index = self.attribute_name_index[z_attr] 
     89 
     90        x_discrete = self.data_domain[x_attr].varType == Discrete 
     91        y_discrete = self.data_domain[y_attr].varType == Discrete 
     92        z_discrete = self.data_domain[z_attr].varType == Discrete 
     93 
     94        colors = [] 
     95        if color_discrete: 
     96            for i in range(len(self.data_domain[color_attr].values)): 
     97                c = self.disc_palette[i] 
     98                colors.append([c.red()/255., c.green()/255., c.blue()/255.]) 
     99 
     100        data_scale = [self.attr_values[x_attr][1] - self.attr_values[x_attr][0], 
     101                      self.attr_values[y_attr][1] - self.attr_values[y_attr][0], 
     102                      self.attr_values[z_attr][1] - self.attr_values[z_attr][0]] 
     103        data_translation = [self.attr_values[x_attr][0], 
     104                            self.attr_values[y_attr][0], 
     105                            self.attr_values[z_attr][0]] 
     106        data_scale = 1. / numpy.array(data_scale) 
     107        if x_discrete: 
     108            data_scale[0] = 0.5 / float(len(self.data_domain[x_attr].values)) 
     109            data_translation[0] = 1. 
     110        if y_discrete: 
     111            data_scale[1] = 0.5 / float(len(self.data_domain[y_attr].values)) 
     112            data_translation[1] = 1. 
     113        if z_discrete: 
     114            data_scale[2] = 0.5 / float(len(self.data_domain[z_attr].values)) 
     115            data_translation[2] = 1. 
     116 
     117        self.clear() 
     118        self.set_shown_attributes_indices(x_index, y_index, z_index, 
     119            color_index, symbol_index, size_index, label_index, 
     120            colors, num_symbols_used, data_scale, data_translation) 
     121 
     122        if self.show_legend: 
     123            legend_keys = {} 
     124            color_index = color_index if color_index != -1 and color_discrete else -1 
     125            size_index = size_index if size_index != -1 and size_discrete else -1 
     126            symbol_index = symbol_index if symbol_index != -1 and symbol_discrete else -1 
     127 
     128            single_legend = [color_index, size_index, symbol_index].count(-1) == 2 
     129            if single_legend: 
     130                legend_join = lambda name, val: val 
     131            else: 
     132                legend_join = lambda name, val: name + '=' + val  
     133 
     134            color_attr = self.data_domain[color_attr] if color_index != -1 else None 
     135            symbol_attr = self.data_domain[symbol_attr] if symbol_index != -1 else None 
     136            size_attr = self.data_domain[size_attr] if size_index != -1 else None 
     137 
     138            if color_index != -1: 
     139                num = len(color_attr.values) 
     140                val = [[], [], [1.]*num, [Symbol.RECT]*num] 
     141                var_values = get_variable_values_sorted(color_attr) 
     142                for i in range(num): 
     143                    val[0].append(legend_join(color_attr.name, var_values[i])) 
     144                    c = self.disc_palette[i] 
     145                    val[1].append([c.red()/255., c.green()/255., c.blue()/255., 1.]) 
     146                legend_keys[color_attr] = val 
     147 
     148            if symbol_index != -1: 
     149                num = len(symbol_attr.values) 
     150                if legend_keys.has_key(symbol_attr): 
     151                    val = legend_keys[symbol_attr] 
     152                else: 
     153                    val = [[], [(0, 0, 0, 1)]*num, [1.]*num, []] 
     154                var_values = get_variable_values_sorted(symbol_attr) 
     155                val[3] = [] 
     156                val[0] = [] 
     157                for i in range(num): 
     158                    val[3].append(i) 
     159                    val[0].append(legend_join(symbol_attr.name, var_values[i])) 
     160                legend_keys[symbol_attr] = val 
     161 
     162            if size_index != -1: 
     163                num = len(size_attr.values) 
     164                if legend_keys.has_key(size_attr): 
     165                    val = legend_keys[size_attr] 
     166                else: 
     167                    val = [[], [(0, 0, 0, 1)]*num, [], [Symbol.RECT]*num] 
     168                val[2] = [] 
     169                val[0] = [] 
     170                var_values = get_variable_values_sorted(size_attr) 
     171                for i in range(num): 
     172                    val[0].append(legend_join(size_attr.name, var_values[i])) 
     173                    val[2].append(0.1 + float(i) / len(var_values)) 
     174                legend_keys[size_attr] = val 
     175 
     176            for val in legend_keys.values(): 
     177                for i in range(len(val[1])): 
     178                    self.legend.add_item(val[3][i], val[1][i], val[2][i], val[0][i]) 
     179 
     180        self.set_axis_title(Axis.X, x_attr) 
     181        self.set_axis_title(Axis.Y, y_attr) 
     182        self.set_axis_title(Axis.Z, z_attr) 
     183 
     184        if x_discrete: 
     185            self.set_axis_labels(Axis.X, get_variable_values_sorted(self.data_domain[x_attr])) 
     186        if y_discrete: 
     187            self.set_axis_labels(Axis.Y, get_variable_values_sorted(self.data_domain[y_attr])) 
     188        if z_discrete: 
     189            self.set_axis_labels(Axis.Z, get_variable_values_sorted(self.data_domain[z_attr])) 
     190 
     191        self.updateGL() 
     192 
     193    def before_draw(self): 
     194        glMatrixMode(GL_PROJECTION) 
     195        glLoadIdentity() 
     196        glMultMatrixd(numpy.array(self.projection.data(), dtype=float)) 
     197        glMatrixMode(GL_MODELVIEW) 
     198        glLoadIdentity() 
     199        glMultMatrixd(numpy.array(self.modelview.data(), dtype=float)) 
     200 
     201        if self.show_grid: 
     202            self.draw_grid() 
     203        if self.show_chassis: 
     204            self.draw_chassis() 
     205 
     206    def draw_chassis(self): 
     207        glColor4f(*self._theme.axis_values_color) 
     208        glEnable(GL_LINE_STIPPLE) 
     209        glLineStipple(1, 0x00FF) 
     210        glDisable(GL_DEPTH_TEST) 
     211        glLineWidth(1) 
     212        glEnable(GL_BLEND) 
     213        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 
     214        edges = [self.x_axis, self.y_axis, self.z_axis, 
     215                 self.x_axis+self.unit_z, self.x_axis+self.unit_y, 
     216                 self.x_axis+self.unit_z+self.unit_y, 
     217                 self.y_axis+self.unit_x, self.y_axis+self.unit_z, 
     218                 self.y_axis+self.unit_x+self.unit_z, 
     219                 self.z_axis+self.unit_x, self.z_axis+self.unit_y, 
     220                 self.z_axis+self.unit_x+self.unit_y] 
     221        glBegin(GL_LINES) 
     222        for edge in edges: 
     223            start, end = edge 
     224            glVertex3f(*start) 
     225            glVertex3f(*end) 
     226        glEnd() 
     227        glDisable(GL_LINE_STIPPLE) 
     228        glEnable(GL_DEPTH_TEST) 
     229        glDisable(GL_BLEND) 
     230 
     231    def draw_grid(self): 
     232        cam_in_space = numpy.array([ 
     233          self.camera[0]*self.camera_distance, 
     234          self.camera[1]*self.camera_distance, 
     235          self.camera[2]*self.camera_distance 
     236        ]) 
     237 
     238        def _draw_grid(axis0, axis1, normal0, normal1, i, j): 
     239            glColor4f(*self._theme.grid_color) 
     240            for axis, normal, coord_index in zip([axis0, axis1], [normal0, normal1], [i, j]): 
     241                start, end = axis.copy() 
     242                start_value = self.map_to_data(start.copy())[coord_index] 
     243                end_value = self.map_to_data(end.copy())[coord_index] 
     244                values, _ = loose_label(start_value, end_value, 7) 
     245                for value in values: 
     246                    if not (start_value <= value <= end_value): 
     247                        continue 
     248                    position = start + (end-start)*((value-start_value) / float(end_value-start_value)) 
     249                    glBegin(GL_LINES) 
     250                    glVertex3f(*position) 
     251                    glVertex3f(*(position-normal*1.)) 
     252                    glEnd() 
     253 
     254        glDisable(GL_DEPTH_TEST) 
     255        glLineWidth(1) 
     256        glEnable(GL_BLEND) 
     257        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 
     258 
     259        planes = [self.axis_plane_xy, self.axis_plane_yz, 
     260                  self.axis_plane_xy_back, self.axis_plane_yz_right] 
     261        axes = [[self.x_axis, self.y_axis], 
     262                [self.y_axis, self.z_axis], 
     263                [self.x_axis+self.unit_z, self.y_axis+self.unit_z], 
     264                [self.z_axis+self.unit_x, self.y_axis+self.unit_x]] 
     265        normals = [[numpy.array([0,-1, 0]), numpy.array([-1, 0, 0])], 
     266                   [numpy.array([0, 0,-1]), numpy.array([ 0,-1, 0])], 
     267                   [numpy.array([0,-1, 0]), numpy.array([-1, 0, 0])], 
     268                   [numpy.array([0,-1, 0]), numpy.array([ 0, 0,-1])]] 
     269        coords = [[0, 1], 
     270                  [1, 2], 
     271                  [0, 1], 
     272                  [2, 1]] 
     273        visible_planes = [plane_visible(plane, cam_in_space) for plane in planes] 
     274        xz_visible = not plane_visible(self.axis_plane_xz, cam_in_space) 
     275        if xz_visible: 
     276            _draw_grid(self.x_axis, self.z_axis, numpy.array([0,0,-1]), numpy.array([-1,0,0]), 0, 2) 
     277        for visible, (axis0, axis1), (normal0, normal1), (i, j) in\ 
     278             zip(visible_planes, axes, normals, coords): 
     279            if not visible: 
     280                _draw_grid(axis0, axis1, normal0, normal1, i, j) 
     281 
     282        glEnable(GL_DEPTH_TEST) 
     283        glDisable(GL_BLEND) 
    29284 
    30285class OWScatterPlot3D(OWWidget): 
     
    34289                    'plot.show_chassis', 'plot.show_axes', 
    35290                    'auto_send_selection', 'auto_send_selection_update', 
    36                     'jitter_size', 'jitter_continuous'] 
    37     contextHandlers = {"": DomainContextHandler("", ["x_attr", "y_attr", "z_attr"])} 
     291                    'plot.jitter_size', 'plot.jitter_continuous'] 
     292    contextHandlers = {'': DomainContextHandler('', ['x_attr', 'y_attr', 'z_attr'])} 
    38293    jitter_sizes = [0.0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 40, 50] 
    39294 
    40     def __init__(self, parent=None, signalManager=None, name="Scatter Plot 3D"): 
     295    def __init__(self, parent=None, signalManager=None, name='Scatter Plot 3D'): 
    41296        OWWidget.__init__(self, parent, signalManager, name, True) 
    42297 
    43         self.inputs = [("Examples", ExampleTable, self.set_data, Default), ("Subset Examples", ExampleTable, self.set_subset_data)] 
    44         self.outputs = [("Selected Examples", ExampleTable), ("Unselected Examples", ExampleTable)] 
    45  
    46         self.x_attr = 0 
    47         self.y_attr = 0 
    48         self.z_attr = 0 
     298        self.inputs = [('Examples', ExampleTable, self.set_data, Default), ('Subset Examples', ExampleTable, self.set_subset_data)] 
     299        self.outputs = [('Selected Examples', ExampleTable), ('Unselected Examples', ExampleTable)] 
     300 
     301        self.x_attr = '' 
     302        self.y_attr = '' 
     303        self.z_attr = '' 
    49304 
    50305        self.x_attr_discrete = False 
     
    52307        self.z_attr_discrete = False 
    53308 
    54         self.color_attr = None 
    55         self.size_attr = None 
    56         self.shape_attr = None 
    57         self.label_attr = None 
    58  
    59         self.alpha_value = 255 
     309        self.color_attr = '' 
     310        self.size_attr = '' 
     311        self.symbol_attr = '' 
     312        self.label_attr = '' 
    60313 
    61314        self.tabs = OWGUI.tabWidget(self.controlArea) 
     
    63316        self.settings_tab = OWGUI.createTabPage(self.tabs, 'Settings', canScroll=True) 
    64317 
    65         self.x_attr_cb = OWGUI.comboBox(self.main_tab, self, "x_attr", box="X-axis attribute", 
    66             tooltip="Attribute to plot on X axis.", 
    67             callback=self.on_axis_change 
    68             ) 
    69  
    70         self.y_attr_cb = OWGUI.comboBox(self.main_tab, self, "y_attr", box="Y-axis attribute", 
    71             tooltip="Attribute to plot on Y axis.", 
    72             callback=self.on_axis_change 
    73             ) 
    74  
    75         self.z_attr_cb = OWGUI.comboBox(self.main_tab, self, "z_attr", box="Z-axis attribute", 
    76             tooltip="Attribute to plot on Z axis.", 
    77             callback=self.on_axis_change 
    78             ) 
    79  
    80         self.color_attr_cb = OWGUI.comboBox(self.main_tab, self, "color_attr", box="Point color", 
    81             tooltip="Attribute to use for point color", 
    82             callback=self.on_axis_change) 
    83  
    84         # Additional point properties (labels, size, shape). 
     318        self.x_attr_cb = OWGUI.comboBox(self.main_tab, self, 'x_attr', box='X-axis attribute', 
     319            tooltip='Attribute to plot on X axis.', 
     320            callback=self.on_axis_change, 
     321            sendSelectedValue=1, 
     322            valueType=str) 
     323 
     324        self.y_attr_cb = OWGUI.comboBox(self.main_tab, self, 'y_attr', box='Y-axis attribute', 
     325            tooltip='Attribute to plot on Y axis.', 
     326            callback=self.on_axis_change, 
     327            sendSelectedValue=1, 
     328            valueType=str) 
     329 
     330        self.z_attr_cb = OWGUI.comboBox(self.main_tab, self, 'z_attr', box='Z-axis attribute', 
     331            tooltip='Attribute to plot on Z axis.', 
     332            callback=self.on_axis_change, 
     333            sendSelectedValue=1, 
     334            valueType=str) 
     335 
     336        self.color_attr_cb = OWGUI.comboBox(self.main_tab, self, 'color_attr', box='Point color', 
     337            tooltip='Attribute to use for point color', 
     338            callback=self.on_axis_change, 
     339            sendSelectedValue=1, 
     340            valueType=str) 
     341 
     342        # Additional point properties (labels, size, symbol). 
    85343        additional_box = OWGUI.widgetBox(self.main_tab, 'Additional Point Properties') 
    86         self.size_attr_cb = OWGUI.comboBox(additional_box, self, "size_attr", label="Point size:", 
    87             tooltip="Attribute to use for pointSize", 
     344        self.size_attr_cb = OWGUI.comboBox(additional_box, self, 'size_attr', label='Point size:', 
     345            tooltip='Attribute to use for point size', 
    88346            callback=self.on_axis_change, 
    89347            indent=10, 
    90348            emptyString='(Same size)', 
    91             ) 
    92  
    93         self.shape_attr_cb = OWGUI.comboBox(additional_box, self, "shape_attr", label="Point shape:", 
    94             tooltip="Attribute to use for pointShape", 
     349            sendSelectedValue=1, 
     350            valueType=str) 
     351 
     352        self.symbol_attr_cb = OWGUI.comboBox(additional_box, self, 'symbol_attr', label='Point symbol:', 
     353            tooltip='Attribute to use for point symbol', 
    95354            callback=self.on_axis_change, 
    96355            indent=10, 
    97             emptyString='(Same shape)', 
    98             ) 
    99  
    100         self.label_attr_cb = OWGUI.comboBox(additional_box, self, "label_attr", label="Point label:", 
    101             tooltip="Attribute to use for pointLabel", 
     356            emptyString='(Same symbol)', 
     357            sendSelectedValue=1, 
     358            valueType=str) 
     359 
     360        self.label_attr_cb = OWGUI.comboBox(additional_box, self, 'label_attr', label='Point label:', 
     361            tooltip='Attribute to use for pointLabel', 
    102362            callback=self.on_axis_change, 
    103363            indent=10, 
    104             emptyString='(No labels)' 
    105             ) 
     364            emptyString='(No labels)', 
     365            sendSelectedValue=1, 
     366            valueType=str) 
    106367 
    107368        self.plot = ScatterPlot(self) 
    108         self.vizrank = OWVizRank(self, self.signalManager, self.plot, orngVizRank.SCATTERPLOT3D, "ScatterPlot3D") 
     369        self.vizrank = OWVizRank(self, self.signalManager, self.plot, orngVizRank.SCATTERPLOT3D, 'ScatterPlot3D') 
    109370        self.optimization_dlg = self.vizrank 
    110371 
    111372        self.optimization_buttons = OWGUI.widgetBox(self.main_tab, 'Optimization dialogs', orientation='horizontal') 
    112         OWGUI.button(self.optimization_buttons, self, "VizRank", callback=self.vizrank.reshow, 
     373        OWGUI.button(self.optimization_buttons, self, 'VizRank', callback=self.vizrank.reshow, 
    113374            tooltip='Opens VizRank dialog, where you can search for interesting projections with different subsets of attributes', 
    114375            debuggingEnabled=0) 
    115376 
    116377        box = OWGUI.widgetBox(self.settings_tab, 'Point properties') 
    117         ss = OWGUI.hSlider(box, self, "plot.symbol_scale", label="Symbol scale", 
    118             minValue=1, maxValue=20, 
    119             tooltip="Scale symbol size", 
    120             callback=self.on_checkbox_update, 
    121             ) 
    122         ss.setValue(5) 
    123  
    124         OWGUI.hSlider(box, self, "plot.transparency", label="Transparency", 
     378        ss = OWGUI.hSlider(box, self, 'plot.symbol_scale', label='Symbol scale', 
     379            minValue=0, maxValue=20, 
     380            tooltip='Scale symbol size', 
     381            callback=self.on_checkbox_update) 
     382        ss.setValue(4) 
     383 
     384        OWGUI.hSlider(box, self, 'plot.transparency', label='Transparency', 
    125385            minValue=10, maxValue=255, 
    126             tooltip="Point transparency value", 
     386            tooltip='Point transparency value', 
    127387            callback=self.on_checkbox_update) 
    128388        OWGUI.rubber(box) 
    129389 
    130         self.jitter_size = 0 
    131         self.jitter_continuous = False 
    132         box = OWGUI.widgetBox(self.settings_tab, "Jittering Options") 
    133         self.jitter_size_combo = OWGUI.comboBox(box, self, 'jitter_size', label='Jittering size (% of size)'+'  ', 
     390        box = OWGUI.widgetBox(self.settings_tab, 'Jittering Options') 
     391        self.jitter_size_combo = OWGUI.comboBox(box, self, 'plot.jitter_size', label='Jittering size (% of size)'+'  ', 
    134392            orientation='horizontal', 
    135393            callback=self.update_plot, 
     
    137395            sendSelectedValue=1, 
    138396            valueType=float) 
    139         OWGUI.checkBox(box, self, 'jitter_continuous', 'Jitter continuous attributes', 
     397        OWGUI.checkBox(box, self, 'plot.jitter_continuous', 'Jitter continuous attributes', 
    140398            callback=self.update_plot, 
    141399            tooltip='Does jittering apply also on continuous attributes?') 
     
    149407        OWGUI.checkBox(box, self, 'plot.show_legend',         'Show legend',    callback=self.on_checkbox_update) 
    150408        OWGUI.checkBox(box, self, 'plot.use_ortho',           'Use ortho',      callback=self.on_checkbox_update) 
    151         OWGUI.checkBox(box, self, 'plot.use_2d_symbols',      '2D symbols',     callback=self.on_checkbox_update) 
     409        OWGUI.checkBox(box, self, 'plot.use_2d_symbols',      '2D symbols',     callback=self.update_plot) 
    152410        OWGUI.checkBox(box, self, 'dark_theme',               'Dark theme',     callback=self.on_theme_change) 
    153411        OWGUI.checkBox(box, self, 'plot.show_grid',           'Show grid',      callback=self.on_checkbox_update) 
    154412        OWGUI.checkBox(box, self, 'plot.show_axes',           'Show axes',      callback=self.on_checkbox_update) 
    155413        OWGUI.checkBox(box, self, 'plot.show_chassis',        'Show chassis',   callback=self.on_checkbox_update) 
    156         OWGUI.checkBox(box, self, 'plot.draw_point_cloud',    'Point cloud',    callback=self.on_checkbox_update) 
    157414        OWGUI.checkBox(box, self, 'plot.hide_outside',        'Hide outside',   callback=self.on_checkbox_update) 
    158415        OWGUI.rubber(box) 
     
    179436 
    180437        self.tooltip_kind = TooltipKind.NONE 
    181         box = OWGUI.widgetBox(self.settings_tab, "Tooltips Settings") 
     438        box = OWGUI.widgetBox(self.settings_tab, 'Tooltips Settings') 
    182439        OWGUI.comboBox(box, self, 'tooltip_kind', items = [ 
    183440            'Don\'t Show Tooltips', 'Show Visible Attributes', 'Show All Attributes']) 
    184441 
    185442        self.plot.mouseover_callback = self.mouseover_callback 
    186         self.shown_attr_indices = [] 
    187443 
    188444        self.main_tab.layout().addStretch(100) 
     
    190446 
    191447        self.mainArea.layout().addWidget(self.plot) 
    192         self.connect(self.graphButton, SIGNAL("clicked()"), self.plot.save_to_file) 
     448        self.connect(self.graphButton, SIGNAL('clicked()'), self.plot.save_to_file) 
    193449 
    194450        self.loadSettings() 
     
    196452 
    197453        self.data = None 
    198         self.subsetData = None 
    199         self.data_array_jittered = None 
     454        self.subset_data = None 
    200455        self.resize(1100, 600) 
    201456 
     
    208463    def get_example_tooltip(self, example, indices=None, max_indices=20): 
    209464        if indices and type(indices[0]) == str: 
    210             indices = [self.attr_name_index[i] for i in indices] 
     465            indices = [self.plot.attribute_name_index[i] for i in indices] 
    211466        if not indices: 
    212467            indices = range(len(self.data.domain.attributes)) 
    213468 
    214469        if example.domain.classVar: 
    215             classIndex = self.attr_name_index[example.domain.classVar.name] 
     470            classIndex = self.plot.attribute_name_index[example.domain.classVar.name] 
    216471            while classIndex in indices: 
    217472                indices.remove(classIndex) 
     
    219474        text = '<b>Attributes:</b><br>' 
    220475        for index in indices[:max_indices]: 
    221             attr = self.attr_name[index] 
     476            attr = self.attribute_names[index] 
    222477            if attr not in example.domain:  text += '&nbsp;'*4 + '%s = ?<br>' % (attr) 
    223478            elif example[attr].isSpecial(): text += '&nbsp;'*4 + '%s = ?<br>' % (attr) 
     
    247502                self.plot.remove_all_selections() 
    248503                return 
    249             # TODO: refactor this properly 
    250             if self.data_array_jittered: 
    251                 X, Y, Z = self.data_array_jittered 
    252             else: 
    253                 X, Y, Z = self.data_array[:, self.x_attr],\ 
    254                           self.data_array[:, self.y_attr],\ 
    255                           self.data_array[:, self.z_attr] 
    256             X = [X[i] for i in indices] 
    257             Y = [Y[i] for i in indices] 
    258             Z = [Z[i] for i in indices] 
    259             min_x, max_x = numpy.min(X), numpy.max(X) 
    260             min_y, max_y = numpy.min(Y), numpy.max(Y) 
    261             min_z, max_z = numpy.min(Z), numpy.max(Z) 
    262             self.plot.set_new_zoom(min_x, max_x, min_y, max_y, min_z, max_z) 
     504            selected_indices = [1 if i in indices else 0 
     505                                for i in range(len(self.data))] 
     506            selected = self.plot.rawData.selectref(selected_indices) 
     507            x_min = y_min = z_min = 1e100 
     508            x_max = y_max = z_max = -1e100 
     509            x_index = self.plot.attribute_name_index[self.x_attr] 
     510            y_index = self.plot.attribute_name_index[self.y_attr] 
     511            z_index = self.plot.attribute_name_index[self.z_attr] 
     512            # TODO: there has to be a faster way 
     513            for example in selected: 
     514                x_min = min(example[x_index], x_min) 
     515                y_min = min(example[y_index], y_min) 
     516                z_min = min(example[z_index], z_min) 
     517                x_max = max(example[x_index], x_max) 
     518                y_max = max(example[y_index], y_max) 
     519                z_max = max(example[z_index], z_max) 
     520            self.plot.set_new_zoom(x_min, x_max, y_min, y_max, z_min, z_max) 
    263521        else: 
    264522            if self.auto_send_selection: 
    265                 self._send_selections() 
     523                self.send_selections() 
    266524 
    267525    def selection_updated_callback(self): 
    268526        if self.plot.selection_type != SelectionType.ZOOM and self.auto_send_selection_update: 
    269             self._send_selections() 
    270  
    271     def _send_selections(self): 
    272         # TODO: implement precise get_selection_indices 
    273         indices = self.plot.get_selection_indices() 
    274         if len(indices) < 1: 
    275             return 
    276  
    277         selected_indices = [1 if i in indices else 0 
    278                             for i in range(len(self.data))] 
    279         unselected_indices = [1-i for i in selected_indices] 
    280         selected = self.plot.rawData.selectref(selected_indices) 
    281         unselected = self.plot.rawData.selectref(unselected_indices) 
    282  
    283         if len(selected) == 0: 
    284             selected = None 
    285         if len(unselected) == 0: 
    286             unselected = None 
    287  
    288         self.send('Selected Examples', selected) 
    289         self.send('Unselected Examples', unselected) 
     527            self.send_selections() 
    290528 
    291529    def change_selection_type(self): 
     
    295533 
    296534    def set_data(self, data=None): 
    297         self.closeContext("") 
     535        self.closeContext() 
     536        self.vizrank.clearResults() 
     537        same_domain = self.data and data and\ 
     538            data.domain.checksum() == self.data.domain.checksum() 
    298539        self.data = data 
    299         self.plot.set_data(data, self.subsetData) 
     540        if not same_domain: 
     541            self.init_attr_values() 
     542        self.openContext('', data) 
     543 
     544    def init_attr_values(self): 
    300545        self.x_attr_cb.clear() 
    301546        self.y_attr_cb.clear() 
     
    303548        self.color_attr_cb.clear() 
    304549        self.size_attr_cb.clear() 
    305         self.shape_attr_cb.clear() 
     550        self.symbol_attr_cb.clear() 
    306551        self.label_attr_cb.clear() 
    307552 
    308553        self.discrete_attrs = {} 
    309554 
    310         if self.data is not None: 
    311             self.all_attrs = data.domain.variables + data.domain.getmetas().values() 
    312             self.candidate_attrs = [attr for attr in self.all_attrs if attr.varType in [Discrete, Continuous]] 
    313  
    314             self.attr_name_index = {} 
    315             for i, attr in enumerate(self.all_attrs): 
    316                 self.attr_name_index[attr.name] = i 
    317  
    318             self.attr_name = {} 
    319             for i, attr in enumerate(self.all_attrs): 
    320                 self.attr_name[i] = attr.name 
    321  
    322             self.color_attr_cb.addItem('(Same color)') 
    323             self.size_attr_cb.addItem('(Same size)') 
    324             self.shape_attr_cb.addItem('(Same shape)') 
    325             self.label_attr_cb.addItem('(No labels)') 
    326             icons = OWGUI.getAttributeIcons()  
    327             for (i, attr) in enumerate(self.candidate_attrs): 
     555        if not self.data: 
     556            return 
     557 
     558        self.color_attr_cb.addItem('(Same color)') 
     559        self.label_attr_cb.addItem('(No labels)') 
     560        self.symbol_attr_cb.addItem('(Same symbol)') 
     561        self.size_attr_cb.addItem('(Same size)') 
     562 
     563        icons = OWGUI.getAttributeIcons()  
     564        for metavar in [self.data.domain.getmeta(mykey) for mykey in self.data.domain.getmetas().keys()]: 
     565            self.label_attr_cb.addItem(icons[metavar.varType], metavar.name) 
     566 
     567        for attr in self.data.domain: 
     568            if attr.varType in [Discrete, Continuous]: 
    328569                self.x_attr_cb.addItem(icons[attr.varType], attr.name) 
    329570                self.y_attr_cb.addItem(icons[attr.varType], attr.name) 
     
    331572                self.color_attr_cb.addItem(icons[attr.varType], attr.name) 
    332573                self.size_attr_cb.addItem(icons[attr.varType], attr.name) 
    333                 self.label_attr_cb.addItem(icons[attr.varType], attr.name) 
    334                 if attr.varType == orange.VarTypes.Discrete: 
    335                     self.discrete_attrs[len(self.discrete_attrs)+1] = (i, attr) 
    336                     self.shape_attr_cb.addItem(icons[orange.VarTypes.Discrete], attr.name) 
    337  
    338             array, c, w = self.data.toNumpyMA() 
    339             if len(c): 
    340                 array = numpy.hstack((array, c.reshape(-1,1))) 
    341             self.data_array = array 
    342  
    343             self.x_attr, self.y_attr, self.z_attr = numpy.min([[0, 1, 2], 
    344                                                                [len(self.candidate_attrs) - 1]*3 
    345                                                               ], axis=0) 
    346             self.color_attr = 0 
    347             self.shown_attr_indices = [self.x_attr, self.y_attr, self.z_attr, self.color_attr] 
    348             self.openContext('', data) 
     574            if attr.varType == Discrete:  
     575                self.symbol_attr_cb.addItem(icons[attr.varType], attr.name) 
     576            self.label_attr_cb.addItem(icons[attr.varType], attr.name) 
     577 
     578        self.x_attr = str(self.x_attr_cb.itemText(0)) 
     579        if self.y_attr_cb.count() > 1: 
     580            self.y_attr = str(self.y_attr_cb.itemText(1)) 
     581        else: 
     582            self.y_attr = str(self.y_attr_cb.itemText(0)) 
     583 
     584        if self.z_attr_cb.count() > 2: 
     585            self.z_attr = str(self.z_attr_cb.itemText(2)) 
     586        else: 
     587            self.z_attr = str(self.z_attr_cb.itemText(0)) 
     588 
     589        if self.data.domain.classVar and self.data.domain.classVar.varType in [Discrete, Continuous]: 
     590            self.color_attr = self.data.domain.classVar.name 
     591        else: 
     592            self.color_attr = '' 
     593 
     594        self.symbol_attr = self.size_attr = self.label_attr = '' 
     595        self.shown_attr_indices = [self.x_attr, self.y_attr, self.z_attr, self.color_attr] 
    349596 
    350597    def set_subset_data(self, data=None): 
    351         self.subsetData = data 
     598        self.subset_data = data 
    352599 
    353600    def handleNewSignals(self): 
     601        self.vizrank.resetDialog() 
     602        self.plot.set_data(self.data, self.subset_data) 
    354603        self.update_plot() 
    355604        self.send_selections() 
     
    359608 
    360609    def sendReport(self): 
    361         self.startReport('%s [%s - %s - %s]' % (self.windowTitle(), self.attr_name[self.x_attr], 
    362                                                 self.attr_name[self.y_attr], self.attr_name[self.z_attr])) 
     610        self.startReport('%s [%s - %s - %s]' % (self.windowTitle(), self.x_attr, self.y_attr, self.z_attr)) 
    363611        self.reportSettings('Visualized attributes', 
    364                             [('X', self.attr_name[self.x_attr]), 
    365                              ('Y', self.attr_name[self.y_attr]), 
    366                              ('Z', self.attr_name[self.z_attr]), 
    367                              self.color_attr and ('Color', self.attr_name[self.color_attr]), 
    368                              self.label_attr and ('Label', self.attr_name[self.label_attr]), 
    369                              self.shape_attr and ('Shape', self.attr_name[self.shape_attr]), 
    370                              self.size_attr  and ('Size', self.attr_name[self.size_attr])]) 
     612                            [('X', self.x_attr), 
     613                             ('Y', self.y_attr), 
     614                             ('Z', self.z_attr), 
     615                             self.color_attr and ('Color', self.color_attr), 
     616                             self.label_attr and ('Label', self.label_attr), 
     617                             self.symbol_attr and ('Symbol', self.symbol_attr), 
     618                             self.size_attr  and ('Size', self.size_attr)]) 
    371619        self.reportSettings('Settings', 
    372620                            [('Symbol size', self.plot.symbol_scale), 
    373621                             ('Transparency', self.plot.transparency), 
    374                              ("Jittering", self.jitter_size), 
    375                              ("Jitter continuous attributes", OWGUI.YesNo[self.jitter_continuous]) 
     622                             ('Jittering', self.jitter_size), 
     623                             ('Jitter continuous attributes', OWGUI.YesNo[self.jitter_continuous]) 
    376624                             ]) 
    377625        self.reportSection('Plot') 
     
    406654            return 
    407655 
    408         self.x_attr_discrete = self.y_attr_discrete = self.z_attr_discrete = False 
    409  
    410         if self.candidate_attrs[self.x_attr].varType == Discrete: 
    411             self.x_attr_discrete = True 
    412         if self.candidate_attrs[self.y_attr].varType == Discrete: 
    413             self.y_attr_discrete = True 
    414         if self.candidate_attrs[self.z_attr].varType == Discrete: 
    415             self.z_attr_discrete = True 
    416  
    417         X, Y, Z, mask = self.get_axis_data(self.x_attr, self.y_attr, self.z_attr) 
    418  
    419         color_discrete = shape_discrete = size_discrete = False 
    420  
    421         if self.color_attr > 0: 
    422             color_attr = self.candidate_attrs[self.color_attr - 1] 
    423             C = self.data_array[:, self.color_attr - 1] 
    424             if color_attr.varType == Discrete: 
    425                 color_discrete = True 
    426                 palette = OWColorPalette.ColorPaletteHSV(len(color_attr.values)) 
    427                 colors = [palette[int(value)] for value in C.ravel()] 
    428                 colors = [[c.red()/255., c.green()/255., c.blue()/255., self.alpha_value/255.] for c in colors] 
    429                 palette_colors = [palette[i] for i in range(len(color_attr.values))] 
    430             else: 
    431                 palette = OWColorPalette.ColorPaletteBW() 
    432                 maxC, minC = numpy.max(C), numpy.min(C) 
    433                 C = (C - minC) / (maxC - minC) 
    434                 colors = [palette[value] for value in C.ravel()] 
    435                 colors = [[c.red()/255., c.green()/255., c.blue()/255., self.alpha_value/255.] for c in colors] 
    436         else: 
    437             colors = 'b' 
    438  
    439         if self.size_attr > 0: 
    440             size_attr = self.candidate_attrs[self.size_attr - 1] 
    441             S = self.data_array[:, self.size_attr - 1] 
    442             if size_attr.varType == Discrete: 
    443                 size_discrete = True 
    444                 sizes = [v+1. for v in S] 
    445             else: 
    446                 min, max = numpy.min(S), numpy.max(S) 
    447                 sizes = [(v - min) / (max-min) for v in S] 
    448         else: 
    449             sizes = 1. 
    450  
    451         shapes = None 
    452         if self.shape_attr > 0: 
    453             i, shape_attr = self.discrete_attrs[self.shape_attr] 
    454             if shape_attr.varType == Discrete: 
    455                 shape_discrete = True 
    456                 shapes = self.data_array[:, i] 
    457  
    458         labels = None 
    459         if self.label_attr > 0: 
    460             label_attr = self.candidate_attrs[self.label_attr - 1] 
    461             labels = self.data_array[:, self.label_attr - 1] 
    462             if label_attr.varType == Discrete: 
    463                 value_map = {key: label_attr.values[key] for key in range(len(label_attr.values))} 
    464                 labels = [value_map[value] for value in labels] 
    465  
    466         self.plot.clear() 
    467  
    468         if self.plot.show_legend: 
    469             legend_keys = {} 
    470             color_attr = color_attr if self.color_attr > 0 and color_discrete else None 
    471             size_attr = size_attr if self.size_attr > 0 and size_discrete else None 
    472             shape_attr = shape_attr if self.shape_attr > 0 and shape_discrete else None 
    473  
    474             single_legend = [color_attr, size_attr, shape_attr].count(None) == 2 
    475             if single_legend: 
    476                 legend_join = lambda name, val: val 
    477             else: 
    478                 legend_join = lambda name, val: name + '=' + val  
    479  
    480             if color_attr != None: 
    481                 num = len(color_attr.values) 
    482                 val = [[], [], [1.]*num, [Symbol.RECT]*num] 
    483                 var_values = getVariableValuesSorted(self.data.domain[self.attr_name_index[color_attr.name]]) 
    484                 for i in range(num): 
    485                     val[0].append(legend_join(color_attr.name, var_values[i])) 
    486                     c = palette_colors[i] 
    487                     val[1].append([c.red()/255., c.green()/255., c.blue()/255., 1.]) 
    488                 legend_keys[color_attr] = val 
    489  
    490             if shape_attr != None: 
    491                 num = len(shape_attr.values) 
    492                 if legend_keys.has_key(shape_attr): 
    493                     val = legend_keys[shape_attr] 
    494                 else: 
    495                     val = [[], [(0, 0, 0, 1)]*num, [1.]*num, []] 
    496                 var_values = getVariableValuesSorted(self.data.domain[self.attr_name_index[shape_attr.name]]) 
    497                 val[3] = [] 
    498                 val[0] = [] 
    499                 for i in range(num): 
    500                     val[3].append(i) 
    501                     val[0].append(legend_join(shape_attr.name, var_values[i])) 
    502                 legend_keys[shape_attr] = val 
    503  
    504             if size_attr != None: 
    505                 num = len(size_attr.values) 
    506                 if legend_keys.has_key(size_attr): 
    507                     val = legend_keys[size_attr] 
    508                 else: 
    509                     val = [[], [(0, 0, 0, 1)]*num, [], [Symbol.RECT]*num] 
    510                 val[2] = [] 
    511                 val[0] = [] 
    512                 var_values = getVariableValuesSorted(self.data.domain[self.attr_name_index[size_attr.name]]) 
    513                 for i in range(num): 
    514                     val[0].append(legend_join(size_attr.name, var_values[i])) 
    515                     val[2].append(0.1 + float(i) / len(var_values)) 
    516                 legend_keys[size_attr] = val 
    517         else: 
    518             legend_keys = {} 
    519  
    520         for val in legend_keys.values(): 
    521             for i in range(len(val[1])): 
    522                 self.plot.legend.add_item(val[3][i], val[1][i], val[2][i], val[0][i]) 
    523  
    524         self.plot.scatter(X, Y, Z, colors, sizes, shapes, labels) 
    525         self.plot.set_x_axis_title(self.candidate_attrs[self.x_attr].name) 
    526         self.plot.set_y_axis_title(self.candidate_attrs[self.y_attr].name) 
    527         self.plot.set_z_axis_title(self.candidate_attrs[self.z_attr].name) 
    528  
    529         def create_discrete_map(attr_index): 
    530             values = self.candidate_attrs[attr_index].values 
    531             return {key: value for key, value in enumerate(values)} 
    532  
    533         if self.candidate_attrs[self.x_attr].varType == Discrete: 
    534             self.plot.set_x_axis_map(create_discrete_map(self.x_attr)) 
    535         if self.candidate_attrs[self.y_attr].varType == Discrete: 
    536             self.plot.set_y_axis_map(create_discrete_map(self.y_attr)) 
    537         if self.candidate_attrs[self.z_attr].varType == Discrete: 
    538             self.plot.set_z_axis_map(create_discrete_map(self.z_attr)) 
    539  
    540     def get_axis_data(self, x_index, y_index, z_index): 
    541         array = self.data_array 
    542         X, Y, Z = array[:, x_index], array[:, y_index], array[:, z_index] 
    543  
    544         if self.jitter_size > 0: 
    545             X, Y, Z = map(numpy.copy, [X, Y, Z]) 
    546             x_range = numpy.max(X)-numpy.min(X) 
    547             y_range = numpy.max(Y)-numpy.min(Y) 
    548             z_range = numpy.max(Z)-numpy.min(Z) 
    549             if self.x_attr_discrete or self.jitter_continuous: 
    550                 X += (numpy.random.random(len(X))-0.5) * (self.jitter_size * x_range / 100.) 
    551             if self.y_attr_discrete or self.jitter_continuous: 
    552                 Y += (numpy.random.random(len(Y))-0.5) * (self.jitter_size * y_range / 100.) 
    553             if self.z_attr_discrete or self.jitter_continuous: 
    554                 Z += (numpy.random.random(len(Z))-0.5) * (self.jitter_size * z_range / 100.) 
    555             self.data_array_jittered = (X, Y, Z) 
    556         return X, Y, Z, None 
     656        self.plot.update_data(self.x_attr, self.y_attr, self.z_attr, 
     657                              self.color_attr, self.symbol_attr, self.size_attr, 
     658                              self.label_attr) 
    557659 
    558660    def showSelectedAttributes(self): 
     
    560662        if not val: return 
    561663        if self.data.domain.classVar: 
    562             self.attr_color = self.attr_name_index[self.data.domain.classVar.name] 
     664            self.attr_color = self.data.domain.classVar.name 
    563665        if not self.plot.have_data: 
    564666            return 
    565667        attr_list = val[3] 
    566668        if attr_list and len(attr_list) == 3: 
    567             self.x_attr = self.attr_name_index[attr_list[0]] 
    568             self.y_attr = self.attr_name_index[attr_list[1]] 
    569             self.z_attr = self.attr_name_index[attr_list[2]] 
    570  
    571         #if self.graph.dataHasDiscreteClass and (self.vizrank.showKNNCorrectButton.isChecked() or self.vizrank.showKNNWrongButton.isChecked()): 
    572         #    kNNExampleAccuracy, probabilities = self.vizrank.kNNClassifyData(self.graph.createProjectionAsExampleTable([self.graph.attributeNameIndex[self.attrX], self.graph.attributeNameIndex[self.attrY]])) 
    573         #    if self.vizrank.showKNNCorrectButton.isChecked(): kNNExampleAccuracy = ([1.0 - val for val in kNNExampleAccuracy], "Probability of wrong classification = %.2f%%") 
    574         #    else: kNNExampleAccuracy = (kNNExampleAccuracy, "Probability of correct classification = %.2f%%") 
    575         #else: 
    576         #    kNNExampleAccuracy = None 
    577         #self.graph.insideColors = insideColors or self.classificationResults or kNNExampleAccuracy or self.outlierValues 
    578         #self.graph.updateData(self.attrX, self.attrY, self.attrColor, self.attrShape, self.attrSize, self.attrLabel) 
     669            self.x_attr = attr_list[0] 
     670            self.y_attr = attr_list[1] 
     671            self.z_attr = attr_list[2] 
     672 
    579673        self.update_plot() 
    580674 
    581 if __name__ == "__main__": 
     675if __name__ == '__main__': 
    582676    app = QApplication(sys.argv) 
    583677    w = OWScatterPlot3D() 
    584     data = orange.ExampleTable("../../doc/datasets/iris") 
     678    data = orange.ExampleTable('../../doc/datasets/iris') 
    585679    w.set_data(data) 
    586680    w.handleNewSignals() 
  • orange/OrangeWidgets/plot/owplot3d.py

    r8568 r8721  
    99        If False, perspective projection is used instead. 
    1010 
    11     .. method:: set_x_axis_title(title) 
    12         Sets ``title`` as the current title (label) of x axis. 
    13  
    14     .. method:: set_y_axis_title(title) 
    15         Sets ``title`` as the current title (label) of y axis. 
    16  
    17     .. method:: set_z_axis_title(title) 
    18         Sets ``title`` as the current title (label) of z axis. 
    19  
    20     .. method:: set_show_x_axis_title(show) 
    21         Determines whether to show the title of x axis or not. 
    22  
    23     .. method:: set_show_y_axis_title(show) 
    24         Determines whether to show the title of y axis or not. 
    25  
    26     .. method:: set_show_z_axis_title(show) 
    27         Determines whether to show the title of z axis or not. 
    28  
    29     .. method:: scatter(X, Y, Z, c, s) 
    30         Adds scatter data to command buffer. ``X``, ``Y`` and ``Z` 
    31         should be arrays (of equal length) with example data. 
    32         ``c`` is optional, can be an array as well (setting 
    33         colors of each example) or string ('r', 'g' or 'b'). ``s`` 
    34         optionally sets sizes of individual examples. 
    35  
    3611    .. method:: clear() 
    3712        Removes everything from the graph. 
    3813""" 
    3914 
    40 # TODO: docs! 
     15import os 
     16import sys 
     17import time 
     18from math import sin, cos, pi, floor, ceil, log10 
     19import struct 
    4120 
    4221from PyQt4.QtCore import * 
     
    4827 
    4928import OpenGL 
    50 OpenGL.ERROR_CHECKING = False # Turned off for performance improvement. 
    51 OpenGL.ERROR_LOGGING = False 
    52 OpenGL.FULL_LOGGING = False 
    53 #OpenGL.ERROR_ON_COPY = True  # TODO: enable this to check for unwanted copying (wrappers) 
     29OpenGL.ERROR_CHECKING = True # Turned off for performance improvement. 
     30OpenGL.ERROR_LOGGING = True 
     31OpenGL.FULL_LOGGING = True 
     32OpenGL.ERROR_ON_COPY = True  # TODO: enable this to check for unwanted copying (wrappers) 
    5433from OpenGL.GL import * 
    55 from OpenGL.GLU import * 
    5634from OpenGL.GL.ARB.vertex_array_object import * 
    5735from OpenGL.GL.ARB.vertex_buffer_object import * 
    58 from ctypes import c_void_p 
    59  
    60 import sys 
    61 from math import sin, cos, pi, floor, ceil, log10 
    62 import time 
    63 import struct 
     36from ctypes import c_void_p, c_char, c_char_p, POINTER 
     37 
    6438import numpy 
     39from numpy import array, maximum 
    6540#numpy.seterr(all='raise') # Raises exceptions on invalid numerical operations. 
    6641 
    6742try: 
    6843    from itertools import izip as zip # Python 3 zip == izip in Python 2.x 
     44    from itertools import chain 
    6945except: 
    7046    pass 
    7147 
    72 def normalize(vec): 
    73     return vec / numpy.sqrt(numpy.sum(vec**2)) 
    74  
    75 def clamp(value, min, max): 
    76     if value < min: 
    77         return min 
    78     if value > max: 
    79         return max 
    80     return value 
    81  
    82 def normal_from_points(p1, p2, p3): 
    83     if isinstance(p1, (list, tuple)): 
    84         v1 = [p2[0]-p1[0], p2[1]-p1[1], p2[2]-p1[2]] 
    85         v2 = [p3[0]-p1[0], p3[1]-p1[1], p3[2]-p1[2]] 
    86     else: 
    87         v1 = p2 - p1 
    88         v2 = p3 - p1 
    89     return normalize(numpy.cross(v1, v2)) 
    90  
     48# TODO: modern opengl renderer 
    9149def draw_triangle(x0, y0, x1, y1, x2, y2): 
    9250    glBegin(GL_TRIANGLES) 
     
    10159    glVertex2f(x1, y1) 
    10260    glEnd() 
     61 
     62def plane_visible(plane, location): 
     63    normal = normal_from_points(*plane[:3]) 
     64    loc_plane = normalize(plane[0] - location) 
     65    if numpy.dot(normal, loc_plane) > 0: 
     66        return False 
     67    return True 
    10368 
    10469def nicenum(x, round): 
     
    151116SelectionType = enum('ZOOM', 'RECTANGLE', 'POLYGON') 
    152117 
    153 from owprimitives3d import get_symbol_data, get_2d_symbol_data, get_2d_symbol_edges 
     118Axis = enum('X', 'Y', 'Z', 'CUSTOM') 
     119 
     120from owprimitives3d import * 
    154121 
    155122class Legend(object): 
     
    352319    def __init__(self): 
    353320        self.labels_font = QFont('Helvetice', 8) 
     321        self.helper_font = self.labels_font 
     322        self.helpers_color = [0., 0., 0., 1.]        # Color used for helping arrows when scaling. 
     323        self.background_color = [1., 1., 1., 1.]     # Color in the background. 
    354324        self.axis_title_font = QFont('Helvetica', 10, QFont.Bold) 
    355325        self.axis_font = QFont('Helvetica', 9) 
    356         self.helper_font = self.labels_font 
    357         self.grid_color = [0.8, 0.8, 0.8, 1.]        # Color of the cube grid. 
    358         self.labels_color = [0., 0., 0., 1.]         # Color used for example labels. 
    359         self.helpers_color = [0., 0., 0., 1.]        # Color used for helping arrows when scaling. 
    360         self.axis_color = [0.1, 0.1, 0.1, 1.]        # Color of the axis lines. 
     326        self.labels_color = [0., 0., 0., 1.] 
     327        self.axis_color = [0.1, 0.1, 0.1, 1.] 
    361328        self.axis_values_color = [0.1, 0.1, 0.1, 1.] 
    362         self.background_color = [1., 1., 1., 1.]     # Color in the background. 
    363  
    364 class LightTheme(PlotTheme): 
    365     pass 
    366  
    367 class DarkTheme(PlotTheme): 
    368     def __init__(self): 
    369         super(DarkTheme, self).__init__() 
    370         self.grid_color = [0.3, 0.3, 0.3, 1.] 
    371         self.labels_color = [0.9, 0.9, 0.9, 1.] 
    372         self.helpers_color = [0.9, 0.9, 0.9, 1.] 
    373         self.axis_values_color = [0.7, 0.7, 0.7, 1.] 
    374         self.axis_color = [0.8, 0.8, 0.8, 1.] 
    375         self.background_color = [0., 0., 0., 1.] 
    376329 
    377330class OWPlot3D(QtOpenGL.QGLWidget): 
     
    379332        QtOpenGL.QGLWidget.__init__(self, QtOpenGL.QGLFormat(QtOpenGL.QGL.SampleBuffers), parent) 
    380333 
    381         self.commands = [] 
    382         self.minx = self.miny = self.minz = 0 
    383         self.maxx = self.maxy = self.maxz = 0 
    384         self.view_cube_edge = 10 
    385         self.camera_distance = 30 
     334        self.camera_distance = 3. 
    386335 
    387336        self.yaw = self.pitch = -pi / 4. 
    388337        self.rotation_factor = 0.3 
     338        self.panning_factor = 0.4 
    389339        self.update_camera() 
    390340 
     
    396346        self.camera_fov = 30. 
    397347        self.zoom_factor = 2000. 
    398         self.move_factor = 100. 
    399  
    400         self.x_axis_title = '' 
    401         self.y_axis_title = '' 
    402         self.z_axis_title = '' 
    403         self.show_x_axis_title = self.show_y_axis_title = self.show_z_axis_title = True 
    404  
    405         self.vertex_buffers = [] 
    406         self.index_buffers = [] 
    407         self.vaos = [] 
    408348 
    409349        self.use_ortho = False 
     
    414354        self.symbol_scale = 1. 
    415355        self.transparency = 255 
    416         self.show_grid = True 
    417         self.scale = numpy.array([1., 1., 1.]) 
    418         self.additional_scale = [0, 0, 0] 
    419         self.scale_x_axis = True 
    420         self.scale_factor = 0.05 
    421         self.data_scale = numpy.array([1., 1., 1.]) 
    422         self.data_center = numpy.array([0., 0., 0.]) 
    423         self.zoomed_size = [self.view_cube_edge, 
    424                             self.view_cube_edge, 
    425                             self.view_cube_edge] 
     356        self.zoomed_size = [1., 1., 1.] 
    426357 
    427358        self.state = PlotState.IDLE 
    428359 
    429         self.build_axes() 
    430360        self.selections = [] 
    431361        self.selection_changed_callback = None 
     
    436366        self.setMouseTracking(True) 
    437367        self.mouseover_callback = None 
    438  
    439         self.x_axis_map = None 
    440         self.y_axis_map = None 
    441         self.z_axis_map = None 
     368        self.before_draw_callback = None 
     369        self.after_draw_callback = None 
     370 
     371        self.x_axis_labels = None 
     372        self.y_axis_labels = None 
     373        self.z_axis_labels = None 
     374 
     375        self.x_axis_title = '' 
     376        self.y_axis_title = '' 
     377        self.z_axis_title = '' 
     378 
     379        self.show_x_axis_title = self.show_y_axis_title = self.show_z_axis_title = True 
     380 
     381        self.scale_factor = 0.05 
     382        self.additional_scale = array([0., 0., 0.]) 
     383        self.data_scale = array([1., 1., 1.]) 
     384        self.data_translation = array([0., 0., 0.]) 
     385        self.plot_scale = array([1., 1., 1.]) 
     386        self.plot_translation = -array([0.5, 0.5, 0.5]) 
    442387 
    443388        self.zoom_stack = [] 
    444         self.translation = numpy.array([0., 0., 0.]) 
    445  
    446         self._theme = LightTheme() 
     389 
     390        self._theme = PlotTheme() 
    447391        self.show_axes = True 
    448         self.show_chassis = True 
    449392 
    450393        self.tooltip_fbo_dirty = True 
    451394        self.selection_fbo_dirty = True 
     395 
    452396        self.use_fbos = True 
    453  
    454         self.draw_point_cloud = False 
     397        self.use_geometry_shader = True 
     398 
    455399        self.hide_outside = False 
     400 
     401        self.build_axes() 
    456402 
    457403    def __del__(self): 
    458404        # TODO: check if anything needs deleting 
     405        # TODO: yes it does! 
    459406        pass 
    460407 
     
    467414        glDisable(GL_CULL_FACE) 
    468415        glEnable(GL_MULTISAMPLE) 
    469         glEnable(GL_VERTEX_PROGRAM_POINT_SIZE) 
    470  
    471         self.symbol_shader = QtOpenGL.QGLShaderProgram() 
    472         vertex_shader_source = ''' 
    473             #extension GL_EXT_gpu_shader4 : enable 
    474  
    475             attribute vec4 position; 
    476             attribute vec3 offset; 
    477             attribute vec4 color; 
    478             attribute vec3 normal; 
    479  
    480             uniform bool use_2d_symbols; 
    481             uniform bool shrink_symbols; 
    482             uniform bool encode_color; 
    483             uniform bool hide_outside; 
    484             uniform vec4 force_color; 
    485             uniform vec2 transparency; // vec2 instead of float, fixing a bug on windows 
    486                                        // (setUniformValue with float crashes) 
    487             uniform vec2 symbol_scale; 
    488             uniform vec2 view_edge; 
    489  
    490             uniform vec3 scale; 
    491             uniform vec3 translation; 
    492  
    493             varying vec4 var_color; 
    494  
    495             void main(void) { 
    496               vec3 offset_rotated = offset; 
    497               offset_rotated.x *= symbol_scale.x; 
    498               offset_rotated.y *= symbol_scale.x; 
    499               offset_rotated.z *= symbol_scale.x; 
    500  
    501               if (use_2d_symbols) { 
    502                   // Calculate inverse of rotations (in this case, inverse 
    503                   // is actually just transpose), so that polygons face 
    504                   // camera all the time. 
    505                   mat3 invs; 
    506  
    507                   invs[0][0] = gl_ModelViewMatrix[0][0]; 
    508                   invs[0][1] = gl_ModelViewMatrix[1][0]; 
    509                   invs[0][2] = gl_ModelViewMatrix[2][0]; 
    510  
    511                   invs[1][0] = gl_ModelViewMatrix[0][1]; 
    512                   invs[1][1] = gl_ModelViewMatrix[1][1]; 
    513                   invs[1][2] = gl_ModelViewMatrix[2][1]; 
    514  
    515                   invs[2][0] = gl_ModelViewMatrix[0][2]; 
    516                   invs[2][1] = gl_ModelViewMatrix[1][2]; 
    517                   invs[2][2] = gl_ModelViewMatrix[2][2]; 
    518  
    519                   offset_rotated = invs * offset_rotated; 
    520               } 
    521  
    522               vec3 pos = position.xyz; 
    523               pos += translation; 
    524               pos *= scale; 
    525               vec4 off_pos = vec4(pos, 1.); 
    526  
    527               if (shrink_symbols) { 
    528                   // Shrink symbols into points by ignoring offsets. 
    529                   gl_PointSize = 2.; 
    530               } 
    531               else { 
    532                   off_pos = vec4(pos+offset_rotated, 1.); 
    533               } 
    534  
    535               gl_Position = gl_ProjectionMatrix * gl_ModelViewMatrix * off_pos; 
    536  
    537               if (force_color.a > 0.) { 
    538                 var_color = force_color; 
    539               } 
    540               else if (encode_color) { 
    541                 // We've packed example index into .w component of this vertex, 
    542                 // to output it to the screen, it has to be broken down into RGBA. 
    543                 uint index = uint(position.w); 
    544                 var_color = vec4(float((index & 0xFF)) / 255., 
    545                                  float((index & 0xFF00) >> 8) / 255., 
    546                                  float((index & 0xFF0000) >> 16) / 255., 
    547                                  float((index & 0xFF000000) >> 24) / 255.); 
    548               } 
    549               else { 
    550                 pos = abs(pos); 
    551                 float manhattan_distance = max(max(pos.x, pos.y), pos.z)+5.; 
    552                 float a = min(pow(min(1., view_edge.x / manhattan_distance), 5.), transparency.x); 
    553                 if (use_2d_symbols) { 
    554                     var_color = vec4(color.rgb, a); 
    555                 } 
    556                 else { 
    557                     // Calculate the amount of lighting this triangle receives (diffuse component only). 
    558                     // The calculations are physically wrong, but look better. TODO: make them look better 
    559                     vec3 light_direction = normalize(vec3(1., 1., 0.5)); 
    560                     float diffuse = max(0., 
    561                         dot(normalize((gl_ModelViewMatrix * vec4(normal, 0.)).xyz), light_direction)); 
    562                     var_color = vec4(color.rgb+diffuse*0.7, a); 
    563                 } 
    564                 if (manhattan_distance > view_edge.x && hide_outside) 
    565                     var_color.a = 0.; 
    566               } 
    567             } 
    568             ''' 
    569  
    570         fragment_shader_source = ''' 
    571             varying vec4 var_color; 
    572  
    573             void main(void) { 
    574               gl_FragColor = var_color; 
    575             } 
    576             ''' 
    577  
    578         self.symbol_shader.addShaderFromSourceCode(QtOpenGL.QGLShader.Vertex, vertex_shader_source) 
    579         self.symbol_shader.addShaderFromSourceCode(QtOpenGL.QGLShader.Fragment, fragment_shader_source) 
    580  
    581         self.symbol_shader.bindAttributeLocation('position', 0) 
    582         self.symbol_shader.bindAttributeLocation('offset',   1) 
    583         self.symbol_shader.bindAttributeLocation('color',    2) 
    584         self.symbol_shader.bindAttributeLocation('normal',   3) 
    585  
    586         if not self.symbol_shader.link(): 
     416 
     417        self.feedback_generated = False 
     418 
     419        # Build shader program which will generate triangle data to be outputed 
     420        # to the screen in subsequent frames. Geometry shader is the heart 
     421        # of the process - it will produce actual symbol geometry out of dummy points. 
     422        self.generating_program = QtOpenGL.QGLShaderProgram() 
     423        self.generating_program.addShaderFromSourceFile(QtOpenGL.QGLShader.Geometry, 
     424            os.path.join(os.path.dirname(__file__), 'generator.gs')) 
     425        self.generating_program.addShaderFromSourceFile(QtOpenGL.QGLShader.Vertex, 
     426            os.path.join(os.path.dirname(__file__), 'generator.vs')) 
     427        varyings = (c_char_p * 5)() 
     428        varyings[:] = ['out_position', 'out_offset', 'out_color', 'out_normal', 'out_index'] 
     429        glTransformFeedbackVaryings(self.generating_program.programId(), 5,  
     430            ctypes.cast(varyings, POINTER(POINTER(c_char))), GL_INTERLEAVED_ATTRIBS) 
     431 
     432        self.generating_program.bindAttributeLocation('index', 0) 
     433 
     434        if not self.generating_program.link(): 
     435            print('Failed to link generating shader! Attribute changes may be slow.') 
     436            self.use_geometry_shader = False 
     437        else: 
     438            print('Generating shader linked.') 
     439 
     440        self.symbol_program = QtOpenGL.QGLShaderProgram() 
     441        self.symbol_program.addShaderFromSourceFile(QtOpenGL.QGLShader.Vertex, 
     442            os.path.join(os.path.dirname(__file__), 'symbol.vs')) 
     443        self.symbol_program.addShaderFromSourceFile(QtOpenGL.QGLShader.Fragment, 
     444            os.path.join(os.path.dirname(__file__), 'symbol.fs')) 
     445 
     446        self.symbol_program.bindAttributeLocation('position', 0) 
     447        self.symbol_program.bindAttributeLocation('offset',   1) 
     448        self.symbol_program.bindAttributeLocation('color',    2) 
     449        self.symbol_program.bindAttributeLocation('normal',   3) 
     450        self.symbol_program.bindAttributeLocation('index',    4) 
     451 
     452        if not self.symbol_program.link(): 
    587453            print('Failed to link symbol shader!') 
    588454        else: 
    589455            print('Symbol shader linked.') 
    590         self.symbol_shader_use_2d_symbols = self.symbol_shader.uniformLocation('use_2d_symbols') 
    591         self.symbol_shader_symbol_scale   = self.symbol_shader.uniformLocation('symbol_scale') 
    592         self.symbol_shader_transparency   = self.symbol_shader.uniformLocation('transparency') 
    593         self.symbol_shader_view_edge      = self.symbol_shader.uniformLocation('view_edge') 
    594         self.symbol_shader_scale          = self.symbol_shader.uniformLocation('scale') 
    595         self.symbol_shader_translation    = self.symbol_shader.uniformLocation('translation') 
    596         self.symbol_shader_shrink_symbols = self.symbol_shader.uniformLocation('shrink_symbols') 
    597         self.symbol_shader_encode_color   = self.symbol_shader.uniformLocation('encode_color') 
    598         self.symbol_shader_hide_outside   = self.symbol_shader.uniformLocation('hide_outside') 
    599         self.symbol_shader_force_color    = self.symbol_shader.uniformLocation('force_color') 
    600  
     456 
     457        self.symbol_program_use_2d_symbols = self.symbol_program.uniformLocation('use_2d_symbols') 
     458        self.symbol_program_symbol_scale   = self.symbol_program.uniformLocation('symbol_scale') 
     459        self.symbol_program_transparency   = self.symbol_program.uniformLocation('transparency') 
     460        self.symbol_program_scale          = self.symbol_program.uniformLocation('scale') 
     461        self.symbol_program_translation    = self.symbol_program.uniformLocation('translation') 
     462        self.symbol_program_hide_outside   = self.symbol_program.uniformLocation('hide_outside') 
     463        self.symbol_program_force_color    = self.symbol_program.uniformLocation('force_color') 
     464 
     465        # TODO: if not self.use_geometry_shader 
     466 
     467        # Upload all symbol geometry into a TBO (texture buffer object), so that generating 
     468        # geometry shader will have access to it. (TBO is easier to use than a texture in this use case). 
     469        geometry_data = [] 
     470        symbols_indices = [] 
     471        symbols_sizes = [] 
     472        for symbol in range(len(Symbol)): 
     473            triangles = get_2d_symbol_data(symbol) 
     474            symbols_indices.append(len(geometry_data) / 3) 
     475            symbols_sizes.append(len(triangles)) 
     476            for tri in triangles: 
     477                geometry_data.extend(chain(*tri)) 
     478 
     479        for symbol in range(len(Symbol)): 
     480            triangles = get_symbol_data(symbol) 
     481            symbols_indices.append(len(geometry_data) / 3) 
     482            symbols_sizes.append(len(triangles)) 
     483            for tri in triangles: 
     484                geometry_data.extend(chain(*tri)) 
     485 
     486        self.symbols_indices = symbols_indices 
     487        self.symbols_sizes = symbols_sizes 
     488 
     489        tbo = glGenBuffers(1) 
     490        glBindBuffer(GL_TEXTURE_BUFFER, tbo) 
     491        glBufferData(GL_TEXTURE_BUFFER, len(geometry_data)*4, numpy.array(geometry_data, 'f'), GL_STATIC_DRAW) 
     492        glBindBuffer(GL_TEXTURE_BUFFER, 0) 
     493        self.symbol_buffer = glGenTextures(1) 
     494        glBindTexture(GL_TEXTURE_BUFFER, self.symbol_buffer) 
     495        glTexBuffer(GL_TEXTURE_BUFFER, GL_RGB32F, tbo) # 3 floating-point components 
     496        glBindTexture(GL_TEXTURE_BUFFER, 0) 
     497 
     498        # Generate dummy vertex buffer (points which will be fed to the geometry shader). 
     499        self.dummy_vao = GLuint(0) 
     500        glGenVertexArrays(1, self.dummy_vao) 
     501        glBindVertexArray(self.dummy_vao) 
     502        vertex_buffer_id = glGenBuffers(1) 
     503        glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer_id) 
     504        glBufferData(GL_ARRAY_BUFFER, numpy.arange(50*1000, dtype=numpy.float32), GL_STATIC_DRAW) 
     505        glVertexAttribPointer(0, 1, GL_FLOAT, GL_FALSE, 4, c_void_p(0)) 
     506        glEnableVertexAttribArray(0) 
     507        glBindVertexArray(0) 
     508        glBindBuffer(GL_ARRAY_BUFFER, 0) 
     509 
     510        # Specify an output VBO (and VAO) 
     511        self.feedback_vao = feedback_vao = GLuint(0) 
     512        glGenVertexArrays(1, feedback_vao) 
     513        glBindVertexArray(feedback_vao) 
     514        self.feedback_bid = feedback_bid = glGenBuffers(1) 
     515        glBindBuffer(GL_ARRAY_BUFFER, feedback_bid) 
     516        vertex_size = (3+3+3+3+1)*4 
     517        glBufferData(GL_ARRAY_BUFFER, 20*1000*144*vertex_size, c_void_p(0), GL_STATIC_DRAW) 
     518        glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(0)) 
     519        glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(3*4)) 
     520        glVertexAttribPointer(2, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(6*4)) 
     521        glVertexAttribPointer(3, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(9*4)) 
     522        glVertexAttribPointer(4, 1, GL_INT,   GL_FALSE, vertex_size, c_void_p(12*4)) 
     523        glEnableVertexAttribArray(0) 
     524        glEnableVertexAttribArray(1) 
     525        glEnableVertexAttribArray(2) 
     526        glEnableVertexAttribArray(3) 
     527        glEnableVertexAttribArray(4) 
     528        glBindVertexArray(0) 
     529        glBindBuffer(GL_ARRAY_BUFFER, 0) 
     530            
    601531        # Create two FBOs (framebuffer objects): 
    602532        # - one will be used together with stencil mask to find out which 
     
    612542            print('Failed to create selection FBO! Selections may be slow.') 
    613543            self.use_fbos = False 
    614         self.tooltip_fbo = QtOpenGL.QGLFramebufferObject(1024, 1024, format) 
     544 
     545        self.tooltip_fbo = QtOpenGL.QGLFramebufferObject(256, 256, format) 
    615546        if self.tooltip_fbo.isValid(): 
    616547            print('Tooltip FBO created.') 
     
    629560            sin(self.pitch)*sin(self.yaw)] 
    630561 
     562    def get_mvp(self): 
     563        projection = QMatrix4x4() 
     564        width, height = self.width(), self.height() 
     565        if self.use_ortho: 
     566            projection.ortho(-width / self.ortho_scale, 
     567                              width / self.ortho_scale, 
     568                             -height / self.ortho_scale, 
     569                              height / self.ortho_scale, 
     570                             self.ortho_near, 
     571                             self.ortho_far) 
     572        else: 
     573            aspect = float(width) / height if height != 0 else 1 
     574            projection.perspective(self.camera_fov, aspect, self.perspective_near, self.perspective_far) 
     575 
     576        modelview = QMatrix4x4() 
     577        modelview.lookAt( 
     578            QVector3D(self.camera[0]*self.camera_distance, 
     579                      self.camera[1]*self.camera_distance, 
     580                      self.camera[2]*self.camera_distance), 
     581            QVector3D(0,-0.1, 0), 
     582            QVector3D(0, 1, 0)) 
     583 
     584        return modelview, projection 
     585 
    631586    def paintGL(self): 
     587        if not self.feedback_generated: 
     588            return 
     589 
    632590        glClearColor(*self._theme.background_color) 
    633591        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) 
    634592 
    635         if len(self.commands) == 0: 
    636             return 
    637  
    638         glMatrixMode(GL_PROJECTION) 
    639         glLoadIdentity() 
    640         width, height = self.width(), self.height() 
    641         if self.use_ortho: 
    642             glOrtho(-width / self.ortho_scale, 
    643                      width / self.ortho_scale, 
    644                     -height / self.ortho_scale, 
    645                      height / self.ortho_scale, 
    646                      self.ortho_near, 
    647                      self.ortho_far) 
    648         else: 
    649             aspect = float(width) / height if height != 0 else 1 
    650             gluPerspective(self.camera_fov, aspect, self.perspective_near, self.perspective_far) 
    651         glMatrixMode(GL_MODELVIEW) 
    652         glLoadIdentity() 
    653         gluLookAt( 
    654             self.camera[0]*self.camera_distance, 
    655             self.camera[1]*self.camera_distance, 
    656             self.camera[2]*self.camera_distance, 
    657             0,-1, 0, 
    658             0, 1, 0) 
    659  
    660         if self.show_chassis: 
    661             self.draw_chassis() 
    662         self.draw_grid_and_axes() 
    663  
    664         for (cmd, params) in self.commands: 
    665             if cmd == 'scatter': 
    666                 vao_id, (X, Y, Z), labels = params 
    667                 scale = numpy.maximum([0., 0., 0.], self.scale + self.additional_scale) 
    668  
    669                 self.symbol_shader.bind() 
    670                 self.symbol_shader.setUniformValue(self.symbol_shader_use_2d_symbols, self.use_2d_symbols) 
    671                 self.symbol_shader.setUniformValue(self.symbol_shader_shrink_symbols, self.draw_point_cloud) 
    672                 self.symbol_shader.setUniformValue(self.symbol_shader_encode_color,   False) 
    673                 self.symbol_shader.setUniformValue(self.symbol_shader_hide_outside,   self.hide_outside) 
    674                 # Specifying float uniforms with vec2 because of a weird bug in PyQt 
    675                 self.symbol_shader.setUniformValue(self.symbol_shader_view_edge,      self.view_cube_edge, self.view_cube_edge) 
    676                 self.symbol_shader.setUniformValue(self.symbol_shader_symbol_scale,   self.symbol_scale, self.symbol_scale) 
    677                 self.symbol_shader.setUniformValue(self.symbol_shader_transparency,   self.transparency / 255., self.transparency / 255.) 
    678                 self.symbol_shader.setUniformValue(self.symbol_shader_scale,          *scale) 
    679                 self.symbol_shader.setUniformValue(self.symbol_shader_translation,    *self.translation) 
    680                 self.symbol_shader.setUniformValue(self.symbol_shader_force_color,    0., 0., 0., 0.) 
    681  
    682                 glBindVertexArray(vao_id) 
    683                 if self.draw_point_cloud: 
    684                     glDisable(GL_DEPTH_TEST) 
    685                     glDisable(GL_BLEND) 
    686                     glDrawArrays(GL_POINTS, 0, vao_id.num_3d_vertices) 
    687                 else: 
    688                     glEnable(GL_DEPTH_TEST) 
    689                     glEnable(GL_BLEND) 
    690                     glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 
    691                     if self.use_2d_symbols: 
    692                         glDrawArrays(GL_TRIANGLES, vao_id.num_3d_vertices, vao_id.num_2d_vertices) 
    693                         # Draw outlines (somewhat dark, discuss) 
    694                         #self.symbol_shader.setUniformValue(self.symbol_shader_force_color, 
    695                             #0., 0., 0., self.transparency / 255. + 0.01) 
    696                         #glDisable(GL_DEPTH_TEST) 
    697                         #glDrawArrays(GL_LINES, vao_id.num_3d_vertices+vao_id.num_2d_vertices, vao_id.num_edge_vertices) 
    698                         #self.symbol_shader.setUniformValue(self.symbol_shader_force_color, 0., 0., 0., 0.) 
    699                     else: 
    700                         glDrawArrays(GL_TRIANGLES, 0, vao_id.num_3d_vertices) 
    701                 glBindVertexArray(0) 
    702  
    703                 self.symbol_shader.release() 
    704  
    705                 if labels != None: 
    706                     glColor4f(*self._theme.labels_color) 
    707                     for x, y, z, label in zip(X, Y, Z, labels): 
    708                         x, y, z = self.transform_data_to_plot((x, y, z)) 
    709                         if isinstance(label, str): 
    710                             self.renderText(x,y,z, label, font=self._theme.labels_font) 
    711                         else: 
    712                             self.renderText(x,y,z, ('%f' % label).rstrip('0').rstrip('.'), 
    713                                             font=self._theme.labels_font) 
    714             elif cmd == 'custom': 
    715                 callback = params 
    716                 callback() 
     593        modelview, projection = self.get_mvp() 
     594        self.modelview = modelview 
     595        self.projection = projection 
     596 
     597        if self.before_draw_callback: 
     598            self.before_draw_callback() 
     599 
     600        if self.show_axes: 
     601            self.draw_axes() 
     602 
     603        self.symbol_program.bind() 
     604        self.symbol_program.setUniformValue('modelview', modelview) 
     605        self.symbol_program.setUniformValue('projection', projection) 
     606        self.symbol_program.setUniformValue(self.symbol_program_use_2d_symbols, self.use_2d_symbols) 
     607        self.symbol_program.setUniformValue(self.symbol_program_hide_outside,   self.hide_outside) 
     608        # Specifying float uniforms with vec2 because of a weird bug in PyQt 
     609        self.symbol_program.setUniformValue(self.symbol_program_symbol_scale,   self.symbol_scale, self.symbol_scale) 
     610        self.symbol_program.setUniformValue(self.symbol_program_transparency,   self.transparency / 255., self.transparency / 255.) 
     611        plot_scale = numpy.maximum([1e-5, 1e-5, 1e-5],                          self.plot_scale+self.additional_scale) 
     612        self.symbol_program.setUniformValue(self.symbol_program_scale,          *plot_scale) 
     613        self.symbol_program.setUniformValue(self.symbol_program_translation,    *self.plot_translation) 
     614        self.symbol_program.setUniformValue(self.symbol_program_force_color,    0., 0., 0., 0.) 
     615 
     616        glEnable(GL_DEPTH_TEST) 
     617        glDisable(GL_BLEND) 
     618        glBindVertexArray(self.feedback_vao) 
     619        glDrawArrays(GL_TRIANGLES, 0, self.num_primitives_generated*3) 
     620        glBindVertexArray(0) 
     621 
     622        self.symbol_program.release() 
     623 
     624        self.draw_labels() 
     625 
     626        if self.after_draw_callback: 
     627            self.after_draw_callback() 
     628 
     629        if self.tooltip_fbo_dirty: 
     630            self.tooltip_fbo.bind() 
     631            glClearColor(1, 1, 1, 1) 
     632            glClearDepth(1) 
     633            glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) 
     634            # Draw data the same as to the screen, but with 
     635            # disabled blending and enabled depth testing. 
     636            glDisable(GL_BLEND) 
     637            glEnable(GL_DEPTH_TEST) 
     638 
     639            # TODO: scissors 
     640 
     641            self.symbol_program.bind() 
     642            # Most uniforms retain their values. 
     643            #self.symbol_program.setUniformValue(self.symbol_program_encode_color, True) 
     644            #self.symbol_program.setUniformValue(self.symbol_program_shrink_symbols, False) 
     645            #glBindVertexArray(vao_id) 
     646            #glDrawArrays(GL_TRIANGLES, 0, vao_id.num_3d_vertices) 
     647            #glBindVertexArray(0) 
     648            self.symbol_program.release() 
     649            self.tooltip_fbo.release() 
     650            self.tooltip_fbo_dirty = False 
    717651 
    718652        if self.selection_fbo_dirty: 
     
    721655            glClearStencil(0) 
    722656            glClear(GL_COLOR_BUFFER_BIT | GL_STENCIL_BUFFER_BIT) 
     657 
     658            #self.symbol_program.bind() 
     659            ##self.symbol_program.setUniformValue(self.symbol_program_encode_color, True) 
     660            ##self.symbol_program.setUniformValue(self.symbol_program_shrink_symbols, True) 
     661            #glDisable(GL_DEPTH_TEST) 
     662            #glDisable(GL_BLEND) 
     663            ##glBindVertexArray(vao_id) 
     664            ##glDrawArrays(GL_POINTS, 0, vao_id.num_3d_vertices) 
     665            ##glBindVertexArray(0) 
     666            #self.symbol_program.release() 
     667 
     668            ## Also draw stencil masks to the screen. No need to 
     669            ## write color or depth information as well, so we 
     670            ## disable those. 
     671            #glMatrixMode(GL_PROJECTION) 
     672            #glLoadIdentity() 
     673            #glOrtho(0, self.width(), self.height(), 0, -1, 1) 
     674            #glMatrixMode(GL_MODELVIEW) 
     675            #glLoadIdentity() 
     676 
     677            #glColorMask(GL_FALSE, GL_FALSE, GL_FALSE, GL_FALSE) 
     678            #glDepthMask(GL_FALSE) 
     679            #glStencilMask(0x01) 
     680            #glStencilOp(GL_KEEP, GL_KEEP, GL_INVERT) 
     681            #glStencilFunc(GL_ALWAYS, 0, ~0) 
     682            #glEnable(GL_STENCIL_TEST) 
     683            #for selection in self.selections: 
     684                #selection.draw_mask() 
     685            #glDisable(GL_STENCIL_TEST) 
     686            #glColorMask(GL_TRUE, GL_TRUE, GL_TRUE, GL_TRUE) 
     687            #glDepthMask(GL_TRUE) 
    723688            self.selection_fbo.release() 
    724  
    725         if self.tooltip_fbo_dirty: 
    726             self.tooltip_fbo.bind() 
    727             glClearColor(1, 1, 1, 1) 
    728             glClearDepth(1) 
    729             glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT) 
    730             self.tooltip_fbo.release() 
    731  
    732         for (cmd, params) in self.commands: 
    733             if cmd == 'scatter': 
    734                 # Don't draw auxiliary info when rotating or selecting 
    735                 # (but make sure these are drawn the very first frame 
    736                 # plot returns to state idle because the user might want to 
    737                 # get tooltips or has just selected a bunch of stuff). 
    738                 vao_id, _, _ = params 
    739  
    740                 if self.tooltip_fbo_dirty: 
    741                     # Draw data the same as to the screen, but with 
    742                     # disabled blending and enabled depth testing. 
    743                     self.tooltip_fbo.bind() 
    744                     glDisable(GL_BLEND) 
    745                     glEnable(GL_DEPTH_TEST) 
    746  
    747                     self.symbol_shader.bind() 
    748                     # Most uniforms retain their values. 
    749                     self.symbol_shader.setUniformValue(self.symbol_shader_encode_color, True) 
    750                     self.symbol_shader.setUniformValue(self.symbol_shader_shrink_symbols, False) 
    751                     glBindVertexArray(vao_id) 
    752                     glDrawArrays(GL_TRIANGLES, 0, vao_id.num_3d_vertices) 
    753                     glBindVertexArray(0) 
    754                     self.symbol_shader.release() 
    755                     self.tooltip_fbo.release() 
    756                     self.tooltip_fbo_dirty = False 
    757  
    758                 if self.selection_fbo_dirty: 
    759                     # Draw data as points instead, this means that examples farther away 
    760                     # will still have a good chance at being visible (not covered). 
    761                     self.selection_fbo.bind() 
    762                     self.symbol_shader.bind() 
    763                     self.symbol_shader.setUniformValue(self.symbol_shader_encode_color, True) 
    764                     self.symbol_shader.setUniformValue(self.symbol_shader_shrink_symbols, True) 
    765                     glDisable(GL_DEPTH_TEST) 
    766                     glDisable(GL_BLEND) 
    767                     glBindVertexArray(vao_id) 
    768                     glDrawArrays(GL_POINTS, 0, vao_id.num_3d_vertices) 
    769                     glBindVertexArray(0) 
    770                     self.symbol_shader.release() 
    771  
    772                     # Also draw stencil masks to the screen. No need to 
    773                     # write color or depth information as well, so we 
    774                     # disable those. 
    775                     glMatrixMode(GL_PROJECTION) 
    776                     glLoadIdentity() 
    777                     glOrtho(0, self.width(), self.height(), 0, -1, 1) 
    778                     glMatrixMode(GL_MODELVIEW) 
    779                     glLoadIdentity() 
    780  
    781                     glColorMask(GL_FALSE, GL_FALSE, GL_FALSE, GL_FALSE) 
    782                     glDepthMask(GL_FALSE) 
    783                     glStencilMask(0x01) 
    784                     glStencilOp(GL_KEEP, GL_KEEP, GL_INVERT) 
    785                     glStencilFunc(GL_ALWAYS, 0, ~0) 
    786                     glEnable(GL_STENCIL_TEST) 
    787                     for selection in self.selections: 
    788                         selection.draw_mask() 
    789                     glDisable(GL_STENCIL_TEST) 
    790                     glColorMask(GL_TRUE, GL_TRUE, GL_TRUE, GL_TRUE) 
    791                     glDepthMask(GL_TRUE) 
    792                     self.selection_fbo.release() 
    793                     self.selection_fbo_dirty = False 
     689            self.selection_fbo_dirty = False 
    794690 
    795691        glDisable(GL_DEPTH_TEST) 
     
    805701        self.draw_helpers() 
    806702 
     703    def draw_labels(self): 
     704        if self.label_index < 0: 
     705            return 
     706 
     707        glMatrixMode(GL_PROJECTION) 
     708        glLoadIdentity() 
     709        glMultMatrixd(array(self.projection.data(), dtype=float)) 
     710        glMatrixMode(GL_MODELVIEW) 
     711        glLoadIdentity() 
     712        glMultMatrixd(array(self.modelview.data(), dtype=float)) 
     713 
     714        glColor4f(*self._theme.labels_color) 
     715        for example in self.data.transpose(): 
     716            x = example[self.x_index] 
     717            y = example[self.y_index] 
     718            z = example[self.z_index] 
     719            label = example[self.label_index] 
     720            x, y, z = self.map_to_plot(array([x, y, z]), original=False) 
     721            #if isinstance(label, str): 
     722                #self.renderText(x,y,z, label, font=self._theme.labels_font) 
     723            #else: 
     724            self.renderText(x,y,z, ('%f' % label).rstrip('0').rstrip('.'), 
     725                            font=self._theme.labels_font) 
     726 
    807727    def draw_helpers(self): 
    808728        glMatrixMode(GL_PROJECTION) 
     
    814734        if self.state == PlotState.SCALING: 
    815735            x, y = self.mouse_pos.x(), self.mouse_pos.y() 
     736            #TODO: replace with an image 
    816737            glColor4f(*self._theme.helpers_color) 
    817738            draw_triangle(x-5, y-30, x+5, y-30, x, y-40) 
     
    824745 
    825746            self.renderText(x, y-50, 'Scale y axis', font=self._theme.labels_font) 
    826             self.renderText(x+60, y+3, 
    827                             'Scale {0} axis'.format(['z', 'x'][self.scale_x_axis]), 
    828                             font=self._theme.labels_font) 
     747            self.renderText(x+60, y+3, 'Scale x and z axes', font=self._theme.labels_font) 
    829748        elif self.state == PlotState.SELECTING and self.new_selection != None: 
    830749            self.new_selection.draw() 
     
    833752            selection.draw() 
    834753 
    835     def set_x_axis_title(self, title): 
    836         self.x_axis_title = title 
    837         self.updateGL() 
    838  
    839     def set_show_x_axis_title(self, show): 
    840         self.show_x_axis_title = show 
    841         self.updateGL() 
    842  
    843     def set_y_axis_title(self, title): 
    844         self.y_axis_title = title 
    845         self.updateGL() 
    846  
    847     def set_show_y_axis_title(self, show): 
    848         self.show_y_axis_title = show 
    849         self.updateGL() 
    850  
    851     def set_z_axis_title(self, title): 
    852         self.z_axis_title = title 
    853         self.updateGL() 
    854  
    855     def set_show_z_axis_title(self, show): 
    856         self.show_z_axis_title = show 
    857         self.updateGL() 
    858  
    859     def draw_chassis(self): 
    860         glColor4f(*self._theme.axis_values_color) 
    861         glEnable(GL_LINE_STIPPLE) 
    862         glLineStipple(1, 0x00FF) 
    863         edges = [self.x_axis, self.y_axis, self.z_axis, 
    864                  self.x_axis+self.unit_z, self.x_axis+self.unit_y, 
    865                  self.x_axis+self.unit_z+self.unit_y, 
    866                  self.y_axis+self.unit_x, self.y_axis+self.unit_z, 
    867                  self.y_axis+self.unit_x+self.unit_z, 
    868                  self.z_axis+self.unit_x, self.z_axis+self.unit_y, 
    869                  self.z_axis+self.unit_x+self.unit_y] 
    870         glBegin(GL_LINES) 
    871         for edge in edges: 
    872             start, end = edge 
    873             glVertex3f(*start) 
    874             glVertex3f(*end) 
    875         glEnd() 
    876         glDisable(GL_LINE_STIPPLE) 
    877  
    878     def draw_grid_and_axes(self): 
    879         cam_in_space = numpy.array([ 
    880           self.camera[0]*self.camera_distance, 
    881           self.camera[1]*self.camera_distance, 
    882           self.camera[2]*self.camera_distance 
    883         ]) 
    884  
    885         def plane_visible(plane): 
    886             normal = normal_from_points(*plane[:3]) 
    887             cam_plane = normalize(plane[0] - cam_in_space) 
    888             if numpy.dot(normal, cam_plane) > 0: 
    889                 return False 
    890             return True 
     754    def build_axes(self): 
     755        edge_half = 1. / 2. 
     756        x_axis = [[-edge_half, -edge_half, -edge_half], [edge_half, -edge_half, -edge_half]] 
     757        y_axis = [[-edge_half, -edge_half, -edge_half], [-edge_half, edge_half, -edge_half]] 
     758        z_axis = [[-edge_half, -edge_half, -edge_half], [-edge_half, -edge_half, edge_half]] 
     759 
     760        self.x_axis = x_axis = numpy.array(x_axis) 
     761        self.y_axis = y_axis = numpy.array(y_axis) 
     762        self.z_axis = z_axis = numpy.array(z_axis) 
     763 
     764        self.unit_x = unit_x = numpy.array([1., 0., 0.]) 
     765        self.unit_y = unit_y = numpy.array([0., 1., 0.]) 
     766        self.unit_z = unit_z = numpy.array([0., 0., 1.]) 
     767  
     768        A = y_axis[1] 
     769        B = y_axis[1] + unit_x 
     770        C = x_axis[1] 
     771        D = x_axis[0] 
     772 
     773        E = A + unit_z 
     774        F = B + unit_z 
     775        G = C + unit_z 
     776        H = D + unit_z 
     777 
     778        self.axis_plane_xy = [A, B, C, D] 
     779        self.axis_plane_yz = [A, D, H, E] 
     780        self.axis_plane_xz = [D, C, G, H] 
     781 
     782        self.axis_plane_xy_back = [H, G, F, E] 
     783        self.axis_plane_yz_right = [B, F, G, C] 
     784        self.axis_plane_xz_top = [E, F, B, A] 
     785 
     786    def draw_axes(self): 
     787        glMatrixMode(GL_PROJECTION) 
     788        glLoadIdentity() 
     789        glMultMatrixd(numpy.array(self.projection.data(), dtype=float)) 
     790        glMatrixMode(GL_MODELVIEW) 
     791        glLoadIdentity() 
     792        glMultMatrixd(numpy.array(self.modelview.data(), dtype=float)) 
    891793 
    892794        def draw_axis(line): 
     
    898800            glEnd() 
    899801 
    900         def draw_discrete_axis_values(axis, coord_index, normal, axis_map): 
    901             start, end = axis 
    902             start_value = self.transform_plot_to_data(numpy.copy(start))[coord_index] 
    903             end_value = self.transform_plot_to_data(numpy.copy(end))[coord_index] 
     802        def draw_discrete_axis_values(axis, coord_index, normal, axis_labels): 
     803            start, end = axis.copy() 
     804            start_value = self.map_to_data(start.copy())[coord_index] 
     805            end_value = self.map_to_data(end.copy())[coord_index] 
    904806            length = end_value - start_value 
    905             offset = normal*0.8 
    906             for key in axis_map.keys(): 
    907                 if start_value <= key <= end_value: 
    908                     position = start + (end-start)*((key-start_value) / length) 
     807            for i, label in enumerate(axis_labels): 
     808                value = (i + 1) * 2 
     809                if start_value <= value <= end_value: 
     810                    position = start + (end-start)*((value-start_value) / length) 
    909811                    glBegin(GL_LINES) 
    910812                    glVertex3f(*(position)) 
    911                     glVertex3f(*(position+normal*0.2)) 
     813                    glVertex3f(*(position+normal*0.03)) 
    912814                    glEnd() 
    913                     position += offset 
     815                    position += normal * 0.1 
    914816                    self.renderText(position[0], 
    915817                                    position[1], 
    916818                                    position[2], 
    917                                     axis_map[key], font=self._theme.labels_font) 
    918  
    919         def draw_values(axis, coord_index, normal, axis_map): 
     819                                    label, font=self._theme.labels_font) 
     820 
     821        def draw_values(axis, coord_index, normal, axis_labels): 
    920822            glColor4f(*self._theme.axis_values_color) 
    921823            glLineWidth(1) 
    922             if axis_map != None: 
    923                 draw_discrete_axis_values(axis, coord_index, normal, axis_map) 
     824            if axis_labels != None: 
     825                draw_discrete_axis_values(axis, coord_index, normal, axis_labels) 
    924826                return 
    925             start, end = axis 
    926             start_value = self.transform_plot_to_data(numpy.copy(start))[coord_index] 
    927             end_value = self.transform_plot_to_data(numpy.copy(end))[coord_index] 
     827            start, end = axis.copy() 
     828            start_value = self.map_to_data(start.copy())[coord_index] 
     829            end_value = self.map_to_data(end.copy())[coord_index] 
    928830            values, num_frac = loose_label(start_value, end_value, 7) 
    929             format = '%%.%df' % num_frac 
    930             offset = normal*0.8 
    931831            for value in values: 
    932832                if not (start_value <= value <= end_value): 
    933833                    continue 
    934834                position = start + (end-start)*((value-start_value) / float(end_value-start_value)) 
     835                text = ('%%.%df' % num_frac) % value 
    935836                glBegin(GL_LINES) 
    936837                glVertex3f(*(position)) 
    937                 glVertex3f(*(position+normal*0.2)) 
     838                glVertex3f(*(position+normal*0.03)) 
    938839                glEnd() 
    939                 value = self.transform_plot_to_data(numpy.copy(position))[coord_index] 
    940                 position += offset 
     840                position += normal * 0.1 
    941841                self.renderText(position[0], 
    942842                                position[1], 
    943843                                position[2], 
    944                                 format % value) 
     844                                text, font=self._theme.axis_font) 
    945845 
    946846        def draw_axis_title(axis, title, normal): 
    947847            middle = (axis[0] + axis[1]) / 2. 
    948             middle += normal * 1. if axis[0][1] != axis[1][1] else normal * 2. 
     848            middle += normal * 0.1 if axis[0][1] != axis[1][1] else normal * 0.2 
    949849            self.renderText(middle[0], middle[1], middle[2], 
    950850                            title, 
    951851                            font=self._theme.axis_title_font) 
    952  
    953         def draw_grid(axis0, axis1, normal0, normal1, i, j): 
    954             glColor4f(*self._theme.grid_color) 
    955             for axis, normal, coord_index in zip([axis0, axis1], [normal0, normal1], [i, j]): 
    956                 start, end = axis 
    957                 start_value = self.transform_plot_to_data(numpy.copy(start))[coord_index] 
    958                 end_value = self.transform_plot_to_data(numpy.copy(end))[coord_index] 
    959                 values, _ = loose_label(start_value, end_value, 7) 
    960                 for value in values: 
    961                     if not (start_value <= value <= end_value): 
    962                         continue 
    963                     position = start + (end-start)*((value-start_value) / float(end_value-start_value)) 
    964                     glBegin(GL_LINES) 
    965                     glVertex3f(*position) 
    966                     glVertex3f(*(position-normal*10.)) 
    967                     glEnd() 
    968852 
    969853        glDisable(GL_DEPTH_TEST) 
     
    972856        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 
    973857 
     858        cam_in_space = numpy.array([ 
     859          self.camera[0]*self.camera_distance, 
     860          self.camera[1]*self.camera_distance, 
     861          self.camera[2]*self.camera_distance 
     862        ]) 
     863 
    974864        planes = [self.axis_plane_xy, self.axis_plane_yz, 
    975865                  self.axis_plane_xy_back, self.axis_plane_yz_right] 
    976         axes = [[self.x_axis, self.y_axis], 
    977                 [self.y_axis, self.z_axis], 
    978                 [self.x_axis+self.unit_z, self.y_axis+self.unit_z], 
    979                 [self.z_axis+self.unit_x, self.y_axis+self.unit_x]] 
    980866        normals = [[numpy.array([0,-1, 0]), numpy.array([-1, 0, 0])], 
    981867                   [numpy.array([0, 0,-1]), numpy.array([ 0,-1, 0])], 
    982868                   [numpy.array([0,-1, 0]), numpy.array([-1, 0, 0])], 
    983869                   [numpy.array([0,-1, 0]), numpy.array([ 0, 0,-1])]] 
    984         coords = [[0, 1], 
    985                   [1, 2], 
    986                   [0, 1], 
    987                   [2, 1]] 
    988         visible_planes = map(plane_visible, planes) 
    989         xz_visible = not plane_visible(self.axis_plane_xz) 
    990         if self.show_grid: 
    991             if xz_visible: 
    992                 draw_grid(self.x_axis, self.z_axis, numpy.array([0,0,-1]), numpy.array([-1,0,0]), 0, 2) 
    993             for visible, (axis0, axis1), (normal0, normal1), (i, j) in\ 
    994                  zip(visible_planes, axes, normals, coords): 
    995                 if not visible: 
    996                     draw_grid(axis0, axis1, normal0, normal1, i, j) 
    997  
    998         glEnable(GL_DEPTH_TEST) 
    999         glDisable(GL_BLEND) 
    1000  
    1001         if not self.show_axes: 
    1002             return 
     870        visible_planes = [plane_visible(plane, cam_in_space) for plane in planes] 
     871        xz_visible = not plane_visible(self.axis_plane_xz, cam_in_space) 
    1003872 
    1004873        if visible_planes[0 if xz_visible else 2]: 
    1005874            draw_axis(self.x_axis) 
    1006             draw_values(self.x_axis, 0, numpy.array([0, 0, -1]), self.x_axis_map) 
     875            draw_values(self.x_axis, 0, numpy.array([0, 0, -1]), self.x_axis_labels) 
    1007876            if self.show_x_axis_title: 
    1008877                draw_axis_title(self.x_axis, self.x_axis_title, numpy.array([0, 0, -1])) 
    1009878        elif visible_planes[2 if xz_visible else 0]: 
    1010879            draw_axis(self.x_axis + self.unit_z) 
    1011             draw_values(self.x_axis + self.unit_z, 0, numpy.array([0, 0, 1]), self.x_axis_map) 
     880            draw_values(self.x_axis + self.unit_z, 0, numpy.array([0, 0, 1]), self.x_axis_labels) 
    1012881            if self.show_x_axis_title: 
    1013882                draw_axis_title(self.x_axis + self.unit_z, 
     
    1016885        if visible_planes[1 if xz_visible else 3]: 
    1017886            draw_axis(self.z_axis) 
    1018             draw_values(self.z_axis, 2, numpy.array([-1, 0, 0]), self.z_axis_map) 
     887            draw_values(self.z_axis, 2, numpy.array([-1, 0, 0]), self.z_axis_labels) 
    1019888            if self.show_z_axis_title: 
    1020889                draw_axis_title(self.z_axis, self.z_axis_title, numpy.array([-1, 0, 0])) 
    1021890        elif visible_planes[3 if xz_visible else 1]: 
    1022891            draw_axis(self.z_axis + self.unit_x) 
    1023             draw_values(self.z_axis + self.unit_x, 2, numpy.array([1, 0, 0]), self.z_axis_map) 
     892            draw_values(self.z_axis + self.unit_x, 2, numpy.array([1, 0, 0]), self.z_axis_labels) 
    1024893            if self.show_z_axis_title: 
    1025894                draw_axis_title(self.z_axis + self.unit_x, self.z_axis_title, numpy.array([1, 0, 0])) 
     
    1042911        normal = normals[rightmost_visible] 
    1043912        draw_axis(axis) 
    1044         draw_values(axis, 1, normal, self.y_axis_map) 
     913        draw_values(axis, 1, normal, self.y_axis_labels) 
    1045914        if self.show_y_axis_title: 
    1046915            draw_axis_title(axis, self.y_axis_title, normal) 
    1047916 
    1048         # Remember which axis to scale when dragging mouse horizontally. 
    1049         self.scale_x_axis = False if rightmost_visible % 2 == 0 else True 
    1050  
    1051     def build_axes(self): 
    1052         edge_half = self.view_cube_edge / 2. 
    1053         x_axis = [[-edge_half,-edge_half,-edge_half], [edge_half,-edge_half,-edge_half]] 
    1054         y_axis = [[-edge_half,-edge_half,-edge_half], [-edge_half,edge_half,-edge_half]] 
    1055         z_axis = [[-edge_half,-edge_half,-edge_half], [-edge_half,-edge_half,edge_half]] 
    1056  
    1057         self.x_axis = x_axis = numpy.array(x_axis) 
    1058         self.y_axis = y_axis = numpy.array(y_axis) 
    1059         self.z_axis = z_axis = numpy.array(z_axis) 
    1060  
    1061         self.unit_x = unit_x = numpy.array([self.view_cube_edge,0,0]) 
    1062         self.unit_y = unit_y = numpy.array([0,self.view_cube_edge,0]) 
    1063         self.unit_z = unit_z = numpy.array([0,0,self.view_cube_edge]) 
    1064   
    1065         A = y_axis[1] 
    1066         B = y_axis[1] + unit_x 
    1067         C = x_axis[1] 
    1068         D = x_axis[0] 
    1069  
    1070         E = A + unit_z 
    1071         F = B + unit_z 
    1072         G = C + unit_z 
    1073         H = D + unit_z 
    1074  
    1075         self.axis_plane_xy = [A, B, C, D] 
    1076         self.axis_plane_yz = [A, D, H, E] 
    1077         self.axis_plane_xz = [D, C, G, H] 
    1078  
    1079         self.axis_plane_xy_back = [H, G, F, E] 
    1080         self.axis_plane_yz_right = [B, F, G, C] 
    1081         self.axis_plane_xz_top = [E, F, B, A] 
    1082  
    1083     def scatter(self, X, Y, Z, colors='b', sizes=5, symbols=None, labels=None, **kwargs): 
    1084         if len(X) != len(Y) != len(Z): 
    1085             raise ValueError('Axis data arrays must be of equal length') 
    1086         num_points = len(X) 
    1087  
    1088         if isinstance(colors, str): 
    1089             color_map = {'r': [1.0, 0.0, 0.0, 1.0], 
    1090                          'g': [0.0, 1.0, 0.0, 1.0], 
    1091                          'b': [0.0, 0.0, 1.0, 1.0]} 
    1092             default = [0.0, 0.0, 1.0, 1.0] 
    1093             colors = [color_map.get(colors, default) for _ in range(num_points)] 
    1094   
    1095         if isinstance(sizes, (int, float)): 
    1096             sizes = [sizes for _ in range(num_points)] 
    1097  
    1098         # Scale sizes to 0..1 
    1099         self.max_size = float(numpy.max(sizes)) 
    1100         sizes = [size / self.max_size for size in sizes] 
    1101  
    1102         if symbols == None: 
    1103             symbols = [Symbol.RECT for _ in range(num_points)] 
    1104  
    1105         # We scale and translate data into almost-unit cube centered around (0,0,0) in plot-space. 
    1106         # It's almost-unit because the length of its edge is specified with view_cube_edge. 
    1107         # This transform is done to ease later calculations and for presentation purposes. 
    1108         min = self.min_x, self.min_y, self.min_z = numpy.min(X), numpy.min(Y), numpy.min(Z) 
    1109         max = self.max_x, self.max_y, self.max_z = numpy.max(X), numpy.max(Y), numpy.max(Z) 
    1110         min = numpy.array(min) 
    1111         max = numpy.array(max) 
    1112         range_x, range_y, range_z = max-min 
    1113         self.data_center = (min + max) / 2  
    1114  
    1115         scale_x = self.view_cube_edge / range_x 
    1116         scale_y = self.view_cube_edge / range_y 
    1117         scale_z = self.view_cube_edge / range_z 
    1118  
    1119         self.data_scale = numpy.array([scale_x, scale_y, scale_z]) 
    1120  
    1121         # TODO: if self.use_2d_symbols 
    1122  
    1123         num_3d_vertices = 0 
    1124         num_2d_vertices = 0 
    1125         num_edge_vertices = 0 
    1126         vertices = [] 
    1127         ai = -1 # Array index (used in color-picking). 
    1128         for x, y, z, (r,g,b,a), size, symbol in zip(X, Y, Z, colors, sizes, symbols): 
    1129             x -= self.data_center[0] 
    1130             y -= self.data_center[1] 
    1131             z -= self.data_center[2] 
    1132             x *= scale_x 
    1133             y *= scale_y 
    1134             z *= scale_z 
    1135             triangles = get_symbol_data(symbol) 
    1136             ss = size*0.02 
    1137             ai += 1 
    1138             for v0, v1, v2, n0, n1, n2 in triangles: 
    1139                 num_3d_vertices += 3 
    1140                 vertices.extend([x,y,z, ai, ss*v0[0],ss*v0[1],ss*v0[2], r,g,b,a, n0[0],n0[1],n0[2], 
    1141                                  x,y,z, ai, ss*v1[0],ss*v1[1],ss*v1[2], r,g,b,a, n1[0],n1[1],n1[2], 
    1142                                  x,y,z, ai, ss*v2[0],ss*v2[1],ss*v2[2], r,g,b,a, n2[0],n2[1],n2[2]]) 
    1143  
    1144         for x, y, z, (r,g,b,a), size, symbol in zip(X, Y, Z, colors, sizes, symbols): 
    1145             x -= self.data_center[0] 
    1146             y -= self.data_center[1] 
    1147             z -= self.data_center[2] 
    1148             x *= scale_x 
    1149             y *= scale_y 
    1150             z *= scale_z 
    1151             triangles = get_2d_symbol_data(symbol) 
    1152             ss = size*0.02 
    1153             for v0, v1, v2, _, _, _ in triangles: 
    1154                 num_2d_vertices += 3 
    1155                 vertices.extend([x,y,z, 0, ss*v0[0],ss*v0[1],ss*v0[2], r,g,b,a, 0,0,0, 
    1156                                  x,y,z, 0, ss*v1[0],ss*v1[1],ss*v1[2], r,g,b,a, 0,0,0, 
    1157                                  x,y,z, 0, ss*v2[0],ss*v2[1],ss*v2[2], r,g,b,a, 0,0,0]) 
    1158  
    1159         for x, y, z, (r,g,b,a), size, symbol in zip(X, Y, Z, colors, sizes, symbols): 
    1160             x -= self.data_center[0] 
    1161             y -= self.data_center[1] 
    1162             z -= self.data_center[2] 
    1163             x *= scale_x 
    1164             y *= scale_y 
    1165             z *= scale_z 
    1166             edges = get_2d_symbol_edges(symbol) 
    1167             ss = size*0.02 
    1168             for v0, v1 in edges: 
    1169                 num_edge_vertices += 2 
    1170                 vertices.extend([x,y,z, 0, ss*v0[0],ss*v0[1],ss*v0[2], r,g,b,a, 0,0,0, 
    1171                                  x,y,z, 0, ss*v1[0],ss*v1[1],ss*v1[2], r,g,b,a, 0,0,0]) 
    1172  
    1173         # Build Vertex Buffer + Vertex Array Object. 
    1174         vao_id = GLuint(0) 
    1175         glGenVertexArrays(1, vao_id) 
    1176         glBindVertexArray(vao_id) 
    1177  
    1178         vertex_buffer_id = glGenBuffers(1) 
    1179         glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer_id) 
    1180         glBufferData(GL_ARRAY_BUFFER, numpy.array(vertices, 'f'), GL_STATIC_DRAW) 
    1181  
    1182         vertex_size = (4+3+3+4)*4 
    1183         glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(0))    # position 
    1184         glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(4*4))  # offset 
    1185         glVertexAttribPointer(2, 4, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(7*4))  # color 
    1186         glVertexAttribPointer(3, 3, GL_FLOAT, GL_FALSE, vertex_size, c_void_p(11*4)) # normal 
    1187         glEnableVertexAttribArray(0) 
    1188         glEnableVertexAttribArray(1) 
    1189         glEnableVertexAttribArray(2) 
    1190         glEnableVertexAttribArray(3) 
    1191  
     917    def set_shown_attributes_indices(self, x_index, y_index, z_index, 
     918            color_index, symbol_index, size_index, label_index, 
     919            colors, num_symbols_used, 
     920            data_scale=array([1., 1., 1.]), data_translation=array([0., 0., 0.])): 
     921        start = time.time() 
     922        self.makeCurrent() 
     923        self.data_scale = data_scale 
     924        self.data_translation = data_translation 
     925        self.x_index = x_index 
     926        self.y_index = y_index 
     927        self.z_index = z_index 
     928        self.label_index = label_index 
     929 
     930        # If color is a discrete attribute, colors should be a list of colors 
     931        # each specified with vec3 (RGB). 
     932 
     933        # Re-run generating program (geometry shader), store 
     934        # results through transform feedback into a VBO on the GPU. 
     935        self.generating_program.bind() 
     936        self.generating_program.setUniformValue('x_index', x_index) 
     937        self.generating_program.setUniformValue('y_index', y_index) 
     938        self.generating_program.setUniformValue('z_index', z_index) 
     939        self.generating_program.setUniformValue('color_index', color_index) 
     940        self.generating_program.setUniformValue('symbol_index', symbol_index) 
     941        self.generating_program.setUniformValue('size_index', size_index) 
     942        self.generating_program.setUniformValue('use_2d_symbols', self.use_2d_symbols) 
     943        self.generating_program.setUniformValue('example_size', self.example_size) 
     944        self.generating_program.setUniformValue('num_colors', len(colors)) 
     945        self.generating_program.setUniformValue('num_symbols_used', num_symbols_used) 
     946        glUniform3fv(glGetUniformLocation(self.generating_program.programId(), 'colors'), 
     947            len(colors), numpy.array(colors, 'f').ravel()) 
     948        glUniform1iv(glGetUniformLocation(self.generating_program.programId(), 'symbols_sizes'), 
     949            len(Symbol)*2, numpy.array(self.symbols_sizes, dtype='i')) 
     950        glUniform1iv(glGetUniformLocation(self.generating_program.programId(), 'symbols_indices'), 
     951            len(Symbol)*2, numpy.array(self.symbols_indices, dtype='i')) 
     952 
     953        glActiveTexture(GL_TEXTURE0) 
     954        glBindTexture(GL_TEXTURE_BUFFER, self.symbol_buffer) 
     955        self.generating_program.setUniformValue('symbol_buffer', 0) 
     956        glActiveTexture(GL_TEXTURE1) 
     957        glBindTexture(GL_TEXTURE_BUFFER, self.data_buffer) 
     958        self.generating_program.setUniformValue('data_buffer', 1) 
     959 
     960        qid = glGenQueries(1) 
     961        glBeginQuery(GL_TRANSFORM_FEEDBACK_PRIMITIVES_WRITTEN, qid) 
     962        glBindBufferBase(GL_TRANSFORM_FEEDBACK_BUFFER, 0, self.feedback_bid) 
     963        glEnable(GL_RASTERIZER_DISCARD) 
     964        glBeginTransformFeedback(GL_TRIANGLES) 
     965 
     966        glBindVertexArray(self.dummy_vao) 
     967        glDrawArrays(GL_POINTS, 0, self.num_examples) 
     968 
     969        glEndTransformFeedback() 
     970        glDisable(GL_RASTERIZER_DISCARD) 
     971 
     972        glEndQuery(GL_TRANSFORM_FEEDBACK_PRIMITIVES_WRITTEN) 
     973        self.num_primitives_generated = glGetQueryObjectuiv(qid, GL_QUERY_RESULT) 
    1192974        glBindVertexArray(0) 
    1193         glBindBuffer(GL_ARRAY_BUFFER, 0) 
    1194  
    1195         vao_id.num_3d_vertices = num_3d_vertices 
    1196         vao_id.num_2d_vertices = num_2d_vertices 
    1197         vao_id.num_edge_vertices = num_edge_vertices 
    1198         self.vertex_buffers.append(vertex_buffer_id) 
    1199         self.vaos.append(vao_id) 
    1200         self.commands.append(("scatter", [vao_id, (X,Y,Z), labels])) 
     975        self.feedback_generated = True 
     976        print('Num generated primitives: ' + str(self.num_primitives_generated)) 
     977 
     978        self.generating_program.release() 
     979        glActiveTexture(GL_TEXTURE0) 
     980        print('Generation took ' + str(time.time()-start) + ' seconds') 
    1201981        self.updateGL() 
    1202982 
    1203     def set_x_axis_map(self, map): 
    1204         self.x_axis_map = map 
    1205         self.updateGL() 
    1206  
    1207     def set_y_axis_map(self, map): 
    1208         self.y_axis_map = map 
    1209         self.updateGL() 
    1210  
    1211     def set_z_axis_map(self, map): 
    1212         self.z_axis_map = map 
    1213         self.updateGL() 
     983    def set_data(self, data, subset_data=None): 
     984        self.makeCurrent() 
     985        start = time.time() 
     986 
     987        data_array = numpy.array(data.transpose().flatten(), dtype='f') 
     988        self.example_size = len(data) 
     989        self.num_examples = len(data[0]) 
     990        self.data = data 
     991 
     992        tbo = glGenBuffers(1) 
     993        glBindBuffer(GL_TEXTURE_BUFFER, tbo) 
     994        glBufferData(GL_TEXTURE_BUFFER, len(data_array)*4, data_array, GL_STATIC_DRAW) 
     995        glBindBuffer(GL_TEXTURE_BUFFER, 0) 
     996 
     997        self.data_buffer = glGenTextures(1) 
     998        glBindTexture(GL_TEXTURE_BUFFER, self.data_buffer) 
     999        GL_R32F = 0x822E 
     1000        glTexBuffer(GL_TEXTURE_BUFFER, GL_R32F, tbo) 
     1001        glBindTexture(GL_TEXTURE_BUFFER, 0) 
     1002 
     1003        print('Uploading data to GPU took ' + str(time.time()-start) + ' seconds') 
     1004 
     1005    def set_axis_labels(self, axis_id, labels): 
     1006        '''labels should be a list of strings''' 
     1007        if Axis.is_valid(axis_id) and axis_id != Axis.CUSTOM: 
     1008            setattr(self, Axis.to_str(axis_id).lower() + '_axis_labels', labels) 
     1009 
     1010    def set_axis_title(self, axis_id, title): 
     1011        if Axis.is_valid(axis_id) and axis_id != Axis.CUSTOM: 
     1012            setattr(self, Axis.to_str(axis_id).lower() + '_axis_title', title) 
     1013 
     1014    def set_show_axis_title(self, axis_id, show): 
     1015        if Axis.is_valid(axis_id) and axis_id != Axis.CUSTOM: 
     1016            setattr(self, 'show_' + Axis.to_str(axis_id).lower() + '_axis_title', title) 
    12141017 
    12151018    def set_new_zoom(self, x_min, x_max, y_min, y_max, z_min, z_max): 
     1019        '''Specifies new zoom in data coordinates.''' 
    12161020        self.selections = [] 
    1217         self.zoom_stack.append((self.scale, self.translation)) 
    1218  
    1219         max = numpy.array([x_max, y_max, z_max]) 
    1220         min = numpy.array([x_min, y_min, z_min]) 
    1221         min, max = map(numpy.copy, [min, max]) 
    1222         min -= self.data_center 
     1021        self.zoom_stack.append((self.plot_scale, self.plot_translation)) 
     1022 
     1023        max = array([x_max, y_max, z_max]).copy() 
     1024        min = array([x_min, y_min, z_min]).copy() 
     1025        min -= self.data_translation 
    12231026        min *= self.data_scale 
    1224         max -= self.data_center 
     1027        max -= self.data_translation 
    12251028        max *= self.data_scale 
    12261029        center = (max + min) / 2. 
    1227         new_translation = -numpy.array(center) 
     1030        new_translation = -array(center) 
    12281031        # Avoid division by zero by adding a small value (this happens when zooming in 
    12291032        # on elements with the same value of an attribute). 
    1230         self.zoomed_size = numpy.array(map(lambda i: i+0.001 if i == 0 else i, max-min)) 
    1231         new_scale = self.view_cube_edge / self.zoomed_size 
     1033        self.zoomed_size = array(map(lambda i: i+1e-5 if i == 0 else i, max-min)) 
     1034        new_scale = 1. / self.zoomed_size 
    12321035        self._animate_new_scale_translation(new_scale, new_translation) 
    12331036 
    12341037    def _animate_new_scale_translation(self, new_scale, new_translation, num_steps=10): 
    1235         translation_step = (new_translation - self.translation) / float(num_steps) 
    1236         scale_step = (new_scale - self.scale) / float(num_steps) 
     1038        translation_step = (new_translation - self.plot_translation) / float(num_steps) 
     1039        scale_step = (new_scale - self.plot_scale) / float(num_steps) 
    12371040        # Animate zooming: translate first for a number of steps, 
    12381041        # then scale. Make sure it doesn't take too long. 
     
    12401043        for i in range(num_steps): 
    12411044            if time.time() - start > 1.: 
    1242                 self.translation = new_translation 
     1045                self.plot_translation = new_translation 
    12431046                break 
    1244             self.translation = self.translation + translation_step 
     1047            self.plot_translation = self.plot_translation + translation_step 
    12451048            self.updateGL() 
    12461049        for i in range(num_steps): 
    12471050            if time.time() - start > 1.: 
    1248                 self.scale = new_scale 
     1051                self.plot_scale = new_scale 
    12491052                break 
    1250             self.scale = self.scale + scale_step 
     1053            self.plot_scale = self.plot_scale + scale_step 
    12511054            self.updateGL() 
    12521055 
    1253     def pop_zoom(self): 
     1056    def zoom_out(self): 
    12541057        if len(self.zoom_stack) < 1: 
    1255             new_translation = numpy.array([0., 0., 0.]) 
    1256             new_scale = numpy.array([1., 1., 1.]) 
     1058            new_translation = -array([0.5, 0.5, 0.5]) 
     1059            new_scale = array([1., 1., 1.]) 
    12571060        else: 
    12581061            new_scale, new_translation = self.zoom_stack.pop() 
    12591062        self._animate_new_scale_translation(new_scale, new_translation) 
    1260         self.zoomed_size = self.view_cube_edge / new_scale 
     1063        self.zoomed_size = 1. / new_scale 
    12611064 
    12621065    def save_to_file(self): 
     
    12701073        return img.save(file_name) 
    12711074 
    1272     def transform_data_to_plot(self, vertex): 
    1273         vertex -= self.data_center 
    1274         vertex *= self.data_scale 
    1275         vertex += self.translation 
    1276         vertex *= numpy.maximum([0., 0., 0.], self.scale + self.additional_scale) 
    1277         return vertex 
    1278  
    1279     def transform_plot_to_data(self, vertex): 
    1280         denominator = numpy.maximum([0., 0., 0.], self.scale + self.additional_scale) 
    1281         denominator = numpy.array(map(lambda v: v+0.00001 if v == 0. else v, denominator)) 
    1282         vertex /= denominator 
    1283         vertex -= self.translation 
    1284         vertex /= self.data_scale 
    1285         vertex += self.data_center 
    1286         return vertex 
     1075    def map_to_plot(self, point, original=True): 
     1076        if original: 
     1077            point -= self.data_translation 
     1078            point *= self.data_scale 
     1079        point += self.plot_translation 
     1080        plot_scale = maximum([1e-5, 1e-5, 1e-5], self.plot_scale+self.additional_scale) 
     1081        point *= plot_scale 
     1082        return point 
     1083 
     1084    def map_to_data(self, point, original=True): 
     1085        plot_scale = maximum([1e-5, 1e-5, 1e-5], self.plot_scale+self.additional_scale) 
     1086        point /= plot_scale 
     1087        point -= self.plot_translation 
     1088        if original: 
     1089            point /= self.data_scale 
     1090            point += self.data_translation 
     1091        return point 
    12871092 
    12881093    def get_selection_indices(self): 
     
    12911096 
    12921097        width, height = self.width(), self.height() 
    1293         if self.use_fbos and width <= 1024 and height <= 1024: 
     1098        if False and self.use_fbos and width <= 1024 and height <= 1024: 
    12941099            self.selection_fbo_dirty = True 
    12951100            self.updateGL() 
     
    13151120        else: 
    13161121            # Slower method (projects points manually and checks containments). 
    1317             projection = QMatrix4x4() 
    1318             if self.use_ortho: 
    1319                 projection.ortho(-width / self.ortho_scale, width / self.ortho_scale, 
    1320                                  -height / self.ortho_scale, height / self.ortho_scale, 
    1321                                  self.ortho_near, self.ortho_far) 
    1322             else: 
    1323                 projection.perspective(self.camera_fov, float(width) / height, 
    1324                                        self.perspective_near, self.perspective_far) 
    1325  
    1326             modelview = QMatrix4x4() 
    1327             modelview.lookAt(QVector3D(self.camera[0]*self.camera_distance, 
    1328                                        self.camera[1]*self.camera_distance, 
    1329                                        self.camera[2]*self.camera_distance), 
    1330                              QVector3D(0,-1, 0), 
    1331                              QVector3D(0, 1, 0)) 
    1332  
     1122            modelview, projection = self.get_mvp() 
    13331123            proj_model = projection * modelview 
    13341124            viewport = [0, 0, width, height] 
     
    13431133 
    13441134            indices = [] 
    1345             for (cmd, params) in self.commands: 
    1346                 if cmd == 'scatter': 
    1347                     _, (X, Y, Z), _ = params 
    1348                     for i, (x, y, z) in enumerate(zip(X, Y, Z)): 
    1349                         x, y, z = self.transform_data_to_plot((x,y,z)) 
    1350                         x_win, y_win = project(x, y, z) 
    1351                         if any(sel.contains(x_win, y_win) for sel in self.selections): 
    1352                             indices.append(i) 
     1135            for i, example in enumerate(self.data.transpose()): 
     1136                x = example[self.x_index] 
     1137                y = example[self.y_index] 
     1138                z = example[self.z_index] 
     1139                x, y, z = self.map_to_plot(array([x,y,z]).copy(), original=False) 
     1140                x_win, y_win = project(x, y, z) 
     1141                if any(sel.contains(x_win, y_win) for sel in self.selections): 
     1142                    indices.append(i) 
    13531143 
    13541144            return indices 
     
    13841174                self.state = PlotState.SCALING 
    13851175                self.scaling_init_pos = self.mouse_pos 
    1386                 self.additional_scale = [0., 0., 0.] 
     1176                self.additional_scale = array([0., 0., 0.]) 
    13871177            else: 
    1388                 self.pop_zoom() 
     1178                self.zoom_out() 
    13891179            self.updateGL() 
    13901180        elif buttons & Qt.MiddleButton: 
     
    14321222                right_vec = normalize(numpy.cross(self.camera, [0, 1, 0])) 
    14331223                up_vec = normalize(numpy.cross(right_vec, self.camera)) 
    1434                 right_scale = self.width()*max(self.scale[0], self.scale[2])*0.1 
    1435                 up_scale = self.height()*self.scale[1]*0.1 
    1436                 self.translation -= right_vec*(dx / right_scale) +\ 
    1437                                     up_vec*(dy / up_scale) 
     1224                right_vec[0] *= dx / (self.width() * self.plot_scale[0] * self.panning_factor) 
     1225                right_vec[2] *= dx / (self.width() * self.plot_scale[2] * self.panning_factor) 
     1226                up_scale = self.height()*self.plot_scale[1]*self.panning_factor 
     1227                self.plot_translation -= right_vec + up_vec*(dy / up_scale) 
    14381228            else: 
    14391229                self.yaw += dx / (self.rotation_factor*self.width()) 
     
    14431233            dx = pos.x() - self.scaling_init_pos.x() 
    14441234            dy = pos.y() - self.scaling_init_pos.y() 
    1445             dx /= float(self.zoomed_size[0 if self.scale_x_axis else 2]) 
     1235            dx /= float(self.zoomed_size[0]) # TODO 
    14461236            dy /= float(self.zoomed_size[1]) 
    14471237            dx /= self.scale_factor * self.width() 
    14481238            dy /= self.scale_factor * self.height() 
    1449             self.additional_scale = [dx, dy, 0] if self.scale_x_axis else [0, dy, dx] 
     1239            self.additional_scale = [dx, dy, 0] 
    14501240        elif self.state == PlotState.PANNING: 
    14511241            self.dragged_selection.move(dx, dy) 
     
    14601250 
    14611251        if self.state == PlotState.SCALING: 
    1462             self.scale = numpy.maximum([0., 0., 0.], self.scale + self.additional_scale) 
    1463             self.additional_scale = [0., 0., 0.] 
     1252            self.plot_scale = numpy.maximum([1e-5, 1e-5, 1e-5], self.plot_scale+self.additional_scale) 
     1253            self.additional_scale = array([0., 0., 0.]) 
    14641254            self.state = PlotState.IDLE 
    14651255        elif self.state == PlotState.SELECTING: 
     
    14901280            self.selections = [] 
    14911281            delta = 1 + event.delta() / self.zoom_factor 
    1492             self.scale *= delta 
     1282            self.plot_scale *= delta 
    14931283            self.tooltip_fbo_dirty = True 
    14941284            self.updateGL() 
     
    15021292    def remove_all_selections(self): 
    15031293        self.selections = [] 
    1504         self.selection_changed_callback() if self.selection_changed_callback else None 
     1294        if self.selection_changed_callback and self.selection_type != SelectionType.ZOOM: 
     1295            self.selection_changed_callback() 
    15051296        self.updateGL() 
    15061297 
     
    15191310 
    15201311    def clear(self): 
    1521         self.commands = [] 
    15221312        self.selections = [] 
    15231313        self.legend.clear() 
    15241314        self.zoom_stack = [] 
    1525         self.zoomed_size = [self.view_cube_edge, 
    1526                             self.view_cube_edge, 
    1527                             self.view_cube_edge] 
    1528         self.translation = numpy.array([0., 0., 0.]) 
    1529         self.scale = numpy.array([1., 1., 1.]) 
    1530         self.additional_scale = numpy.array([0., 0., 0.]) 
    1531         self.x_axis_title = self.y_axis_title = self.z_axis_title = '' 
    1532         self.x_axis_map = self.y_axis_map = self.z_axis_map = None 
     1315        self.zoomed_size = [1., 1., 1.] 
     1316        self.plot_translation = -array([0.5, 0.5, 0.5]) 
     1317        self.plot_scale = array([1., 1., 1.]) 
     1318        self.additional_scale = array([0., 0., 0.]) 
     1319        self.data_scale = array([1., 1., 1.]) 
     1320        self.data_translation = array([0., 0., 0.]) 
     1321        self.x_axis_labels = None 
     1322        self.y_axis_labels = None 
     1323        self.z_axis_labels = None 
    15331324        self.tooltip_fbo_dirty = True 
    15341325        self.selection_fbo_dirty = True 
    1535         self.updateGL() 
    1536  
    15371326 
    15381327if __name__ == "__main__": 
  • orange/OrangeWidgets/plot/owprimitives3d.py

    r8567 r8721  
    11import os 
    22import re 
    3 from owplot3d import Symbol, normal_from_points 
     3from owplot3d import Symbol 
     4import numpy 
     5 
     6def normalize(vec): 
     7    return vec / numpy.sqrt(numpy.sum(vec**2)) 
     8 
     9def clamp(value, min, max): 
     10    if value < min: 
     11        return min 
     12    if value > max: 
     13        return max 
     14    return value 
     15 
     16def normal_from_points(p1, p2, p3): 
     17    if isinstance(p1, (list, tuple)): 
     18        v1 = [p2[0]-p1[0], p2[1]-p1[1], p2[2]-p1[2]] 
     19        v2 = [p3[0]-p1[0], p3[1]-p1[1], p3[2]-p1[2]] 
     20    else: 
     21        v1 = p2 - p1 
     22        v2 = p3 - p1 
     23    return normalize(numpy.cross(v1, v2)) 
    424 
    525symbol_map = { 
Note: See TracChangeset for help on using the changeset viewer.