Changeset 8391:8e2a256ab5df in orange


Ignore:
Timestamp:
07/15/11 22:59:42 (3 years ago)
Author:
matejd <matejd@…>
Branch:
default
Convert:
420b2caa03bcf862126c88e23c5225508a3e3075
Message:

owplot3d: arbitrary scaling, refactored legend, bugfixes

Location:
orange/OrangeWidgets
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • orange/OrangeWidgets/Prototypes/OWScatterPlot3D.py

    r8389 r8391  
    178178        X, Y, Z, mask = self.get_axis_data(x_ind, y_ind, z_ind) 
    179179 
     180        color_legend_items = [] 
    180181        if self.color_attr > 0: 
    181182            color_attr = self.axis_candidate_attrs[self.color_attr - 1] 
     
    185186                colors = [palette[int(value)] for value in C.ravel()] 
    186187                colors = [[c.red()/255., c.green()/255., c.blue()/255., self.alpha_value/255.] for c in colors] 
     188                palette_colors = [palette[i] for i in range(len(color_attr.values))] 
     189                color_legend_items = [[Symbol.TRIANGLE, [c.red()/255., c.green()/255., c.blue()/255., 1], 1, title] 
     190                    for c, title in zip(palette_colors, color_attr.values)] 
    187191            else: 
    188192                palette = OWColorPalette.ColorPaletteBW() 
     
    206210 
    207211        shapes = None 
    208         legend_items = [] 
    209212        if self.shape_attr > 0: 
    210             i,shape_attr = self.discrete_attrs[self.shape_attr] 
    211             legend_items = shape_attr.values 
     213            i, shape_attr = self.discrete_attrs[self.shape_attr] 
    212214            if shape_attr.varType == orange.VarTypes.Discrete: 
    213215                # Map discrete attribute to [0...num shapes-1] 
     
    227229 
    228230        self.plot.clear() 
     231 
     232        num_symbols = len(Symbol) 
     233        if self.shape_attr > 0: 
     234            _, shape_attr = self.discrete_attrs[self.shape_attr] 
     235            titles = list(shape_attr.values) 
     236            for i, title in enumerate(titles): 
     237                if i == num_symbols-1: 
     238                    title = ', '.join(titles[i:]) 
     239                self.plot.legend.add_item(i, (0,0,0,1), 1, '{0}={1}'.format(shape_attr.name, title)) 
     240                if i == num_symbols-1: 
     241                    break 
     242 
     243        if color_legend_items: 
     244            for item in color_legend_items: 
     245                self.plot.legend.add_item(*item) 
     246 
    229247        self.plot.scatter(X, Y, Z, colors, sizes, shapes, labels) 
    230248        self.plot.set_x_axis_title(self.axis_candidate_attrs[self.x_attr].name) 
    231249        self.plot.set_x_axis_title(self.axis_candidate_attrs[self.y_attr].name) 
    232250        self.plot.set_x_axis_title(self.axis_candidate_attrs[self.z_attr].name) 
    233         for i, value in enumerate(legend_items): 
    234             self.plot.add_legend_item(i, value) 
    235251 
    236252    def get_axis_data(self, x_ind, y_ind, z_ind): 
  • orange/OrangeWidgets/owplot3d.py

    r8390 r8391  
    3737        Removes everything from the graph. 
    3838""" 
     39 
     40__all__ = ['OWPlot3D', 'Symbol'] 
    3941 
    4042from PyQt4.QtCore import * 
     
    101103    return value 
    102104 
     105def enum(*sequential): 
     106    enums = dict(zip(sequential, range(len(sequential)))) 
     107    enums['is_valid'] = lambda self, enum_value: enum_value < len(sequential) 
     108    enums['to_str'] = lambda self, enum_value: sequential[enum_value] 
     109    enums['__len__'] = lambda self: len(sequential) 
     110    return type('Enum', (), enums)() 
     111 
     112# States the plot can be in: 
     113# * idle: mostly doing nothing, rotations are not considered special 
     114# * dragging legend: user has pressed left mouse button and is now dragging legend 
     115#   to a more suitable location 
     116# * scaling: user has pressed right mouse button, dragging it up and down 
     117#   scales data in y-coordinate, dragging it right and left scales data 
     118#   in current horizontal coordinate (x or z, depends on rotation) 
     119PlotState = enum('IDLE', 'DRAGGING_LEGEND', 'SCALING', 'SELECTING') 
     120 
     121# TODO: more symbols 
     122Symbol = enum('TRIANGLE', 'RECTANGLE', 'PENTAGON', 'CIRCLE') 
     123 
     124class Legend(object): 
     125    def __init__(self, plot): 
     126        self.border_color = [0.5, 0.5, 0.5, 1] 
     127        self.border_thickness = 2 
     128        self.position = [0, 0] 
     129        self.size = [0, 0] 
     130        self.items = [] 
     131        self.plot = plot 
     132        self.symbol_scale = 6 
     133        self.font = QFont() 
     134        self.metrics = QFontMetrics(self.font) 
     135 
     136    def add_item(self, symbol, color, size, title): 
     137        '''Adds an item to the legend. 
     138           Symbol can be integer value or enum Symbol. 
     139           Color should be RGBA. Size should be between 0 and 1. 
     140        ''' 
     141        if not Symbol.is_valid(symbol): 
     142            return 
     143        self.items.append([symbol, color, size, title]) 
     144        self.size[0] = max(self.metrics.width(item[3]) for item in self.items) + 40 
     145        self.size[1] = len(self.items) * self.metrics.height() + 4 
     146 
     147    def clear(self): 
     148        self.items = [] 
     149 
     150    def draw(self): 
     151        if not self.items: 
     152            return 
     153 
     154        x, y = self.position 
     155        w, h = self.size 
     156        t = self.border_thickness 
     157 
     158        # Draw legend outline first. 
     159        glDisable(GL_DEPTH_TEST) 
     160        glColor4f(*self.border_color) 
     161        glBegin(GL_QUADS) 
     162        glVertex2f(x,   y) 
     163        glVertex2f(x+w, y) 
     164        glVertex2f(x+w, y+h) 
     165        glVertex2f(x,   y+h) 
     166        glEnd() 
     167 
     168        glColor4f(1, 1, 1, 1) 
     169        glBegin(GL_QUADS) 
     170        glVertex2f(x+t,   y+t) 
     171        glVertex2f(x+w-t, y+t) 
     172        glVertex2f(x+w-t, y+h-t) 
     173        glVertex2f(x+t,   y+h-t) 
     174        glEnd() 
     175 
     176        def draw_ngon(n, x, y, size): 
     177            glBegin(GL_TRIANGLES) 
     178            angle_inc = 2.*pi / n 
     179            angle = angle_inc / 2. 
     180            for i in range(n): 
     181                glVertex2f(x,y) 
     182                glVertex2f(x-cos(angle)*size, y-sin(angle)*size) 
     183                angle += angle_inc 
     184                glVertex2f(x-cos(angle)*size, y-sin(angle)*size) 
     185            glEnd() 
     186 
     187        item_pos_y = y + t + 13 
     188        symbol_to_n = {Symbol.TRIANGLE: 3, 
     189                       Symbol.RECTANGLE: 4, 
     190                       Symbol.PENTAGON: 5, 
     191                       Symbol.CIRCLE: 8} 
     192 
     193        for symbol, color, size, text in self.items: 
     194            glColor4f(*color) 
     195            draw_ngon(symbol_to_n[symbol], x+t+10, item_pos_y-4, size*self.symbol_scale) 
     196            self.plot.renderText(x+t+30, item_pos_y, text) 
     197            item_pos_y += self.metrics.height() 
     198 
     199    def point_inside(self, x, y): 
     200        return self.position[0] <= x <= self.position[0]+self.size[0] and\ 
     201               self.position[1] <= y <= self.position[1]+self.size[1] 
     202 
     203    def move(self, dx, dy): 
     204        self.position[0] += dx 
     205        self.position[1] += dy 
     206 
    103207 
    104208class OWPlot3D(QtOpenGL.QGLWidget): 
     209 
    105210    def __init__(self, parent=None): 
    106211        QtOpenGL.QGLWidget.__init__(self, QtOpenGL.QGLFormat(QtOpenGL.QGL.SampleBuffers), parent) 
     
    109214        self.minx = self.miny = self.minz = 0 
    110215        self.maxx = self.maxy = self.maxz = 0 
    111         self.b_box = [numpy.array([0,   0,   0]), numpy.array([0, 0, 0])] 
    112         self.camera = numpy.array([0.6, 0.8, 0]) # Location on a unit sphere around the center. This is where camera is looking from. 
     216        self.camera = numpy.array([0.6, 0.8, 0]) 
    113217        self.center = numpy.array([0,   0,   0]) 
    114  
    115         # TODO: move to center shortcut (maybe a GUI element?) 
     218        self.view_cube_edge = 10 
     219        self.camera_distance = 30 
    116220 
    117221        self.yaw = self.pitch = 0. 
    118222        self.rotation_factor = 100. 
    119         self.zoom_factor = 100. 
    120         self.zoom = 10. 
     223        self.zoom_factor = 500. 
    121224        self.move_factor = 100. 
    122225        self.mouse_pos = [100, 100] # TODO: get real mouse position, calculate camera, fix the initial jump 
     
    139242        self.ortho = False 
    140243        self.show_legend = True 
    141         self.legend_border_color = [0.5, 0.5, 0.5, 1] 
    142         self.legend_border_thickness = 2 
    143         self.legend_position = [10, 10] 
    144         self.legend_size = [200, 50] 
    145         self.dragging_legend = False 
    146         self.legend_items = [] 
     244        self.legend = Legend(self) 
    147245 
    148246        self.face_symbols = True 
     
    151249        self.transparency = 255 
    152250        self.grid = True 
     251        self.scale = numpy.array([1., 1., 1.]) 
     252        self.add_scale = [0, 0, 0] 
     253        self.scale_x_axis = True 
     254        self.scale_factor = 30. 
     255 
     256        # Beside n-gons, symbols should also include cubes, spheres and other stuff. TODO 
     257        self.available_symbols = [3, 4, 5, 8] 
     258        self.state = PlotState.IDLE 
     259 
     260        self.build_axes() 
    153261 
    154262    def __del__(self): 
    155263        # TODO: delete shaders and vertex buffer 
    156         glDeleteProgram(self.color_shader) 
     264        glDeleteProgram(self.symbol_shader) 
    157265 
    158266    def initializeGL(self): 
    159         self.update_axes() 
    160267        glClearColor(1.0, 1.0, 1.0, 1.0) 
    161268        glClearDepth(1.0) 
     
    164271        glEnable(GL_LINE_SMOOTH) 
    165272 
    166         self.color_shader = glCreateProgram() 
     273        self.symbol_shader = glCreateProgram() 
    167274        vertex_shader = glCreateShader(GL_VERTEX_SHADER) 
    168275        fragment_shader = glCreateShader(GL_FRAGMENT_SHADER) 
     
    177284            uniform float symbol_scale; 
    178285            uniform float transparency; 
     286 
     287            uniform vec3 scale; 
     288            uniform vec3 translation; 
    179289 
    180290            varying vec4 var_color; 
     
    205315                offset_rotated = invs * offset_rotated; 
    206316 
     317              position += translation; 
     318              position *= scale; 
    207319              vec4 off_pos = vec4(position+offset_rotated, 1); 
    208320 
    209321              gl_Position = gl_ProjectionMatrix * gl_ModelViewMatrix * off_pos; 
    210               var_color = vec4(color.rgb, transparency); 
     322              position = abs(position); 
     323              float manhattan_distance = max(max(position.x, position.y), position.z)+5.; 
     324              var_color = vec4(color.rgb, pow(min(1, 10. / manhattan_distance), 5)); 
    211325            } 
    212326            ''' 
     
    244358                return 
    245359            else: 
    246                 glAttachShader(self.color_shader, shader) 
    247  
    248         glBindAttribLocation(self.color_shader, 0, 'position') 
    249         glBindAttribLocation(self.color_shader, 1, 'offset') 
    250         glBindAttribLocation(self.color_shader, 2, 'color') 
    251         glLinkProgram(self.color_shader) 
    252         self.color_shader_face_symbols = glGetUniformLocation(self.color_shader, 'face_symbols') 
    253         self.color_shader_symbol_scale = glGetUniformLocation(self.color_shader, 'symbol_scale') 
    254         self.color_shader_transparency = glGetUniformLocation(self.color_shader, 'transparency') 
     360                glAttachShader(self.symbol_shader, shader) 
     361 
     362        glBindAttribLocation(self.symbol_shader, 0, 'position') 
     363        glBindAttribLocation(self.symbol_shader, 1, 'offset') 
     364        glBindAttribLocation(self.symbol_shader, 2, 'color') 
     365        glLinkProgram(self.symbol_shader) 
     366        self.symbol_shader_face_symbols = glGetUniformLocation(self.symbol_shader, 'face_symbols') 
     367        self.symbol_shader_symbol_scale = glGetUniformLocation(self.symbol_shader, 'symbol_scale') 
     368        self.symbol_shader_transparency = glGetUniformLocation(self.symbol_shader, 'transparency') 
     369        self.symbol_shader_scale        = glGetUniformLocation(self.symbol_shader, 'scale') 
     370        self.symbol_shader_translation  = glGetUniformLocation(self.symbol_shader, 'translation') 
    255371        linked = c_int() 
    256         glGetProgramiv(self.color_shader, GL_LINK_STATUS, byref(linked)) 
     372        glGetProgramiv(self.symbol_shader, GL_LINK_STATUS, byref(linked)) 
    257373        if not linked.value: 
    258374            print('Failed to link shader!') 
     
    268384        glLoadIdentity() 
    269385        width, height = self.width(), self.height() 
    270         divide = self.zoom*10. 
    271         if self.ortho: 
    272             # TODO: fix ortho 
    273             glOrtho(-width/divide, width/divide, -height/divide, height/divide, -1, 2000) 
    274         else: 
    275             aspect = float(width) / height if height != 0 else 1 
    276             gluPerspective(30.0, aspect, 0.1, 2000) 
     386        #if self.ortho: 
     387        #    # TODO: fix ortho 
     388        #    glOrtho(-width/divide, width/divide, -height/divide, height/divide, -1, 2000) 
     389        #else: 
     390        aspect = float(width) / height if height != 0 else 1 
     391        gluPerspective(30.0, aspect, 0.1, 2000) 
    277392        glMatrixMode(GL_MODELVIEW) 
    278393        glLoadIdentity() 
    279         zoom = 100 if self.ortho else self.zoom 
    280394        gluLookAt( 
    281             self.camera[0]*zoom + self.center[0], 
    282             self.camera[1]*zoom + self.center[1], 
    283             self.camera[2]*zoom + self.center[2], 
    284             self.center[0], 
    285             self.center[1], 
    286             self.center[2], 
     395            self.camera[0]*self.camera_distance, 
     396            self.camera[1]*self.camera_distance, 
     397            self.camera[2]*self.camera_distance, 
     398            0, 0, 0, 
    287399            0, 1, 0) 
    288400        self.paint_axes() 
     
    296408            if cmd == 'scatter': 
    297409                vao, vao_outline, array, labels = params 
    298                 glUseProgram(self.color_shader) 
    299                 glUniform1i(self.color_shader_face_symbols, self.face_symbols) 
    300                 glUniform1f(self.color_shader_symbol_scale, self.symbol_scale) 
    301                 glUniform1f(self.color_shader_transparency, self.transparency) 
     410                glUseProgram(self.symbol_shader) 
     411                glUniform1i(self.symbol_shader_face_symbols, self.face_symbols) 
     412                glUniform1f(self.symbol_shader_symbol_scale, self.symbol_scale) 
     413                glUniform1f(self.symbol_shader_transparency, self.transparency) 
     414                scale = numpy.maximum([0,0,0], self.scale + self.add_scale) 
     415                glUniform3f(self.symbol_shader_scale,        *scale) 
     416                glUniform3f(self.symbol_shader_translation,  *(-self.center)) 
     417 
    302418                if self.filled_symbols: 
    303419                    glBindVertexArray(vao.value) 
     
    314430                        self.renderText(x,y,z, '{0:.1}'.format(label), font=self.labels_font) 
    315431 
    316         self.draw_center() 
    317  
    318432        glDisable(GL_BLEND) 
    319433        if self.show_legend: 
    320             self.draw_legend() 
    321  
    322     def draw_legend(self): 
    323         glMatrixMode(GL_PROJECTION) 
    324         glLoadIdentity() 
    325         glOrtho(0, self.width(), self.height(), 0, -1, 1) 
    326         glMatrixMode(GL_MODELVIEW) 
    327         glLoadIdentity() 
    328  
    329         x, y = self.legend_position 
    330         w, h = self.legend_size 
    331         t = self.legend_border_thickness 
    332  
    333         glDisable(GL_DEPTH_TEST) 
    334         glColor4f(*self.legend_border_color) 
    335         glBegin(GL_QUADS) 
    336         glVertex2f(x,   y) 
    337         glVertex2f(x+w, y) 
    338         glVertex2f(x+w, y+h) 
    339         glVertex2f(x,   y+h) 
    340         glEnd() 
    341  
    342         glColor4f(1, 1, 1, 1) 
    343         glBegin(GL_QUADS) 
    344         glVertex2f(x+t,   y+t) 
    345         glVertex2f(x+w-t, y+t) 
    346         glVertex2f(x+w-t, y+h-t) 
    347         glVertex2f(x+t,   y+h-t) 
    348         glEnd() 
    349  
    350         # TODO: clean this up 
    351         glColor4f(0.1, 0.1, 0.1, 1) 
    352         item_pos_y = y + 2*t + 10 
    353         for shape, text in self.legend_items: 
    354             if shape == 0: 
    355                 glBegin(GL_TRIANGLES) 
    356                 glVertex2f(x+10, item_pos_y) 
    357                 glVertex2f(x+20, item_pos_y) 
    358                 glVertex2f(x+15, item_pos_y-10) 
    359                 glEnd() 
    360  
    361             self.renderText(x+20+4*t, item_pos_y, text) 
    362             item_pos_y += 15 
    363  
    364     def add_legend_item(self, shape, text): 
    365         self.legend_items.append((shape, text)) 
     434            glMatrixMode(GL_PROJECTION) 
     435            glLoadIdentity() 
     436            glOrtho(0, self.width(), self.height(), 0, -1, 1) 
     437            glMatrixMode(GL_MODELVIEW) 
     438            glLoadIdentity() 
     439            self.legend.draw() 
     440 
     441        self.draw_helpers() 
     442 
     443    def draw_helpers(self): 
     444        def draw_triangle(x0, y0, x1, y1, x2, y2): 
     445            glBegin(GL_TRIANGLES) 
     446            glVertex2f(x0, y0) 
     447            glVertex2f(x1, y1) 
     448            glVertex2f(x2, y2) 
     449            glEnd() 
     450 
     451        def draw_line(x0, y0, x1, y1): 
     452            glBegin(GL_LINES) 
     453            glVertex2f(x0, y0) 
     454            glVertex2f(x1, y1) 
     455            glEnd() 
     456 
     457        if self.state == PlotState.SCALING: 
     458            if not self.show_legend: 
     459                glMatrixMode(GL_PROJECTION) 
     460                glLoadIdentity() 
     461                glOrtho(0, self.width(), self.height(), 0, -1, 1) 
     462                glMatrixMode(GL_MODELVIEW) 
     463                glLoadIdentity() 
     464            x, y = self.mouse_pos.x(), self.mouse_pos.y() 
     465            glColor4f(0,0,0,1) 
     466            draw_triangle(x-5, y-30, x+5, y-30, x, y-40) 
     467            draw_line(x, y, x, y-30) 
     468            draw_triangle(x-5, y-10, x+5, y-10, x, y) 
     469            self.renderText(x, y-50, 'Scale y axis', font=self.labels_font) 
     470 
     471            draw_triangle(x+10, y, x+20, y-5, x+20, y+5) 
     472            draw_line(x+10, y, x+40, y) 
     473            draw_triangle(x+50, y, x+40, y-5, x+40, y+5) 
     474            self.renderText(x+60, y+3, 
     475                            'Scale {0} axis'.format(['x', 'z'][self.scale_x_axis]), 
     476                            font=self.labels_font) 
     477        elif self.state == PlotState.SELECTING: 
     478            s = self.new_selection 
     479            glColor4f(0, 0, 0, 1) 
     480            draw_line(s[0], s[1], s[0], s[3]) 
     481            draw_line(s[0], s[3], s[2], s[3]) 
     482            draw_line(s[2], s[3], s[2], s[1]) 
     483            draw_line(s[2], s[1], s[0], s[1]) 
    366484 
    367485    def set_x_axis_title(self, title): 
     
    389507        self.updateGL() 
    390508 
    391     def draw_center(self): 
    392         glColor3f(0,0,0) 
    393         glLineWidth(2) 
    394         glBegin(GL_LINES) 
    395         size = 2. 
    396         glVertex3f(self.center[0] - size*self.normal_size, 
    397                    self.center[1] + size*self.normal_size, 
    398                    self.center[2]) 
    399         glVertex3f(self.center[0] + size*self.normal_size, 
    400                    self.center[1] - size*self.normal_size, 
    401                    self.center[2]) 
    402         glVertex3f(self.center[0] - size*self.normal_size, 
    403                    self.center[1] - size*self.normal_size, 
    404                    self.center[2]) 
    405         glVertex3f(self.center[0] + size*self.normal_size, 
    406                    self.center[1] + size*self.normal_size, 
    407                    self.center[2]) 
    408         glEnd() 
    409         glLineWidth(1) 
    410  
    411509    def paint_axes(self): 
    412         zoom = 100 if self.ortho else self.zoom 
    413510        cam_in_space = numpy.array([ 
    414           self.center[0] + self.camera[0]*zoom, 
    415           self.center[1] + self.camera[1]*zoom, 
    416           self.center[2] + self.camera[2]*zoom 
     511          self.camera[0]*self.camera_distance, 
     512          self.camera[1]*self.camera_distance, 
     513          self.camera[2]*self.camera_distance 
    417514        ]) 
    418515 
     
    461558        glEnable(GL_BLEND) 
    462559        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 
    463         bb_center = (self.b_box[1] + self.b_box[0]) / 2. 
    464560 
    465561        # Draw axis labels. 
     
    549645        draw_values(y_axis_translated[rightmost_visible], 1, normal) 
    550646 
    551     def update_axes(self): 
    552         x_axis = [[self.minx, self.miny, self.minz], 
    553                   [self.maxx, self.miny, self.minz]] 
    554         y_axis = [[self.minx, self.miny, self.minz], 
    555                   [self.minx, self.maxy, self.minz]] 
    556         z_axis = [[self.minx, self.miny, self.minz], 
    557                   [self.minx, self.miny, self.maxz]] 
     647        # Remember which axis to scale when dragging mouse horizontally. 
     648        self.scale_x_axis = False if rightmost_visible % 2 == 0 else True 
     649 
     650    def build_axes(self): 
     651        edge_half = self.view_cube_edge / 2. 
     652        x_axis = [[-edge_half,-edge_half,-edge_half], [edge_half,-edge_half,-edge_half]] 
     653        y_axis = [[-edge_half,-edge_half,-edge_half], [-edge_half,edge_half,-edge_half]] 
     654        z_axis = [[-edge_half,-edge_half,-edge_half], [-edge_half,-edge_half,edge_half]] 
     655 
    558656        self.x_axis = x_axis = numpy.array(x_axis) 
    559657        self.y_axis = y_axis = numpy.array(y_axis) 
    560658        self.z_axis = z_axis = numpy.array(z_axis) 
    561659 
    562         self.unit_x = unit_x = numpy.array([self.maxx - self.minx, 0, 0]) 
    563         self.unit_y = unit_y = numpy.array([0, self.maxy - self.miny, 0]) 
    564         self.unit_z = unit_z = numpy.array([0, 0, self.maxz - self.minz]) 
     660        self.unit_x = unit_x = numpy.array([self.view_cube_edge,0,0]) 
     661        self.unit_y = unit_y = numpy.array([0,self.view_cube_edge,0]) 
     662        self.unit_z = unit_z = numpy.array([0,0,self.view_cube_edge]) 
    565663  
    566664        A = y_axis[1] 
     
    582680        self.axis_plane_xz_top = [E, F, B, A] 
    583681 
    584     def scatter(self, X, Y, Z, colors='b', sizes=5, shapes=None, labels=None, **kwargs): 
     682    def scatter(self, X, Y, Z, colors='b', sizes=5, symbols=None, labels=None, **kwargs): 
    585683        array = [[x, y, z] for x,y,z in zip(X, Y, Z)] 
    586684        if isinstance(colors, str): 
     
    594692            sizes = [sizes for _ in array] 
    595693 
    596         # Normalize sizes to 0..1 
    597         max_size = float(numpy.max(sizes)) 
    598         sizes = [size / max_size for size in sizes] 
    599  
    600         if shapes == None: 
    601             shapes = [0 for _ in array] 
     694        # Scale sizes to 0..1 
     695        self.max_size = float(numpy.max(sizes)) 
     696        sizes = [size / self.max_size for size in sizes] 
     697 
     698        if symbols == None: 
     699            symbols = [0 for _ in array] 
    602700 
    603701        max, min = numpy.max(array, axis=0), numpy.min(array, axis=0) 
    604         self.b_box = [max, min] 
    605         self.minx, self.miny, self.minz = min 
    606         self.maxx, self.maxy, self.maxz = max 
     702        self.min_x, self.min_y, self.min_z = min 
     703        self.max_x, self.max_y, self.max_z = max 
     704        self.range_x, self.range_y, self.range_z = max-min 
     705        self.middle_x, self.middle_y, self.middle_z = (min+max) / 2. 
    607706        self.center = (min + max) / 2  
    608         self.normal_size = numpy.max(self.center - self.b_box[1]) / 100. 
    609  
    610         # TODO: more shapes? For now, wrap other to these ones. 
    611         different_shapes = [3, 4, 5, 8] 
     707        self.normal_size = 0.2 
     708 
     709        self.scale_x = self.view_cube_edge / self.range_x 
     710        self.scale_y = self.view_cube_edge / self.range_y 
     711        self.scale_z = self.view_cube_edge / self.range_z 
    612712 
    613713        # Generate vertices for shapes and also indices for outlines. 
     
    615715        outline_indices = [] 
    616716        index = 0 
    617         for (x,y,z), (r,g,b,a), size, shape in zip(array, colors, sizes, shapes): 
     717        for (x,y,z), (r,g,b,a), size, symbol in zip(array, colors, sizes, symbols): 
    618718            sO2 = size * self.normal_size / 2. 
    619             n = different_shapes[shape % len(different_shapes)] 
     719            n = self.available_symbols[symbol % len(self.available_symbols)] 
    620720            angle_inc = 2.*pi / n 
    621721            angle = angle_inc / 2. 
     
    690790        self.vaos.append(vao_outline) 
    691791        self.commands.append(("scatter", [vao, vao_outline, array, labels])) 
    692         self.update_axes() 
    693792        self.updateGL() 
    694793 
    695794    def mousePressEvent(self, event): 
    696       self.mouse_pos = event.pos() 
     795        pos = self.mouse_pos = event.pos() 
     796        buttons = event.buttons() 
     797        if buttons & Qt.LeftButton: 
     798            if self.legend.point_inside(pos.x(), pos.y()): 
     799                self.state = PlotState.DRAGGING_LEGEND 
     800            else: 
     801                self.state = PlotState.SELECTING 
     802                self.new_selection = [pos.x(), pos.y(), 0, 0] 
     803        elif buttons & Qt.RightButton: 
     804            self.state = PlotState.SCALING 
     805            self.scaling_init_pos = self.mouse_pos 
     806            self.add_scale = [0, 0, 0] 
     807            self.updateGL() 
    697808 
    698809    def mouseMoveEvent(self, event): 
     
    702813 
    703814        if event.buttons() & Qt.LeftButton: 
    704             if self.dragging_legend: 
    705                 self.legend_position[0] += dx 
    706                 self.legend_position[1] += dy 
    707             elif self.legend_position[0] <= pos.x() <= self.legend_position[0]+self.legend_size[0] and\ 
    708                  self.legend_position[1] <= pos.y() <= self.legend_position[1]+self.legend_size[1]: 
    709                 self.dragging_legend = True 
     815            if self.state == PlotState.DRAGGING_LEGEND: 
     816                self.legend.move(dx, dy) 
     817            elif self.state == PlotState.SELECTING: 
     818                self.new_selection[2:] = [pos.x(), pos.y()] 
    710819        elif event.buttons() & Qt.MiddleButton: 
    711820            if QApplication.keyboardModifiers() & Qt.ShiftModifier: 
     
    715824                self.center += off_x 
    716825            else: 
    717                 self.yaw += dx /  self.rotation_factor 
     826                self.yaw += dx / self.rotation_factor 
    718827                self.pitch += dy / self.rotation_factor 
    719828                self.pitch = clamp(self.pitch, -3., -0.1) 
     
    723832                    sin(self.pitch)*sin(self.yaw)] 
    724833 
     834        if self.state == PlotState.SCALING: 
     835            dx = pos.x() - self.scaling_init_pos.x() 
     836            dy = pos.y() - self.scaling_init_pos.y() 
     837            self.add_scale = [dx / self.scale_factor, dy / self.scale_factor, 0]\ 
     838                if self.scale_x_axis else [0, dy / self.scale_factor, dx / self.scale_factor] 
     839 
    725840        self.mouse_pos = pos 
    726841        self.updateGL() 
    727842 
    728843    def mouseReleaseEvent(self, event): 
    729         self.dragging_legend = False 
     844        if self.state == PlotState.SCALING: 
     845            self.scale = numpy.maximum([0,0,0], self.scale + self.add_scale) 
     846            self.add_scale = [0,0,0] 
     847 
     848        self.state = PlotState.IDLE 
     849        self.updateGL() 
    730850 
    731851    def wheelEvent(self, event): 
    732852        if event.orientation() == Qt.Vertical: 
    733             self.zoom -= event.delta() / self.zoom_factor 
    734             if self.zoom < 2: 
    735                 self.zoom = 2 
     853            delta = 1 + event.delta() / self.zoom_factor 
     854            self.scale *= delta 
    736855            self.updateGL() 
    737856 
    738857    def clear(self): 
    739858        self.commands = [] 
    740         self.legend_items = [] 
     859        self.legend.clear() 
    741860 
    742861 
Note: See TracChangeset for help on using the changeset viewer.