source: orange/Orange/OrangeWidgets/VisualizeQt/OWScatterPlot3D.py @ 11474:df0622184ee6

Revision 11474:df0622184ee6, 35.8 KB checked in by markotoplak, 12 months ago (diff)

Renamed Visualize Qt to VisualizeQt (so it can be loaded in new canvas).

RevLine 
[8735]1'''
2<name>Scatterplot 3D</name>
[11474]3<icon>icons/ScatterPlot.svg</icon>
[8735]4<priority>2001</priority>
[8722]5'''
[8682]6
[8927]7from math import log10, ceil, floor
8
[8682]9from OWWidget import *
10from plot.owplot3d import *
[8838]11from plot.owtheme import ScatterLightTheme, ScatterDarkTheme
[8820]12from plot import OWPoint
[8682]13
14import orange
15Discrete = orange.VarTypes.Discrete
16Continuous = orange.VarTypes.Continuous
17
[10542]18from Orange.data.preprocess.scaling import get_variable_values_sorted
[8735]19
[8682]20import OWGUI
21import orngVizRank
22from OWkNNOptimization import *
23from orngScaleScatterPlotData import *
24
25import numpy
26
[8730]27TooltipKind = enum('NONE', 'VISIBLE', 'ALL')
[8927]28Axis = enum('X', 'Y', 'Z')
29
[8975]30class Plane:
31    '''Internal convenience class.'''
32    def __init__(self, A, B, C, D):
33        self.A = A
34        self.B = B
35        self.C = C
36        self.D = D
37
38    def normal(self):
39        v1 = self.A - self.B
40        v2 = self.A - self.C
41        return QVector3D.crossProduct(v1, v2).normalized()
42
43    def visible_from(self, location):
44        normal = self.normal()
45        loc_plane = (self.A - location).normalized()
46        if QVector3D.dotProduct(normal, loc_plane) > 0:
47            return False
48        return True
49
50class Edge:
51    def __init__(self, v0, v1):
52        self.v0 = v0
53        self.v1 = v1
54
55    def __add__(self, vec):
56        return Edge(self.v0 + vec, self.v1 + vec)
[8927]57
58def nicenum(x, round):
59    if x <= 0.:
60        return x # TODO: what to do in such cases?
61    expv = floor(log10(x))
62    f = x / pow(10., expv)
63    if round:
64        if f < 1.5: nf = 1.
65        elif f < 3.: nf = 2.
66        elif f < 7.: nf = 5.
67        else: nf = 10.
68    else:
69        if f <= 1.: nf = 1.
70        elif f <= 2.: nf = 2.
71        elif f <= 5.: nf = 5.
72        else: nf = 10.
73    return nf * pow(10., expv)
74
75def loose_label(min_value, max_value, num_ticks):
76    '''Algorithm by Paul S. Heckbert (Graphics Gems).
77       Generates a list of "nice" values between min and max,
78       given the number of ticks. Also returns the number
79       of fractional digits to use.
80    '''
81    range = nicenum(max_value-min_value, False)
82    d = nicenum(range / float(num_ticks-1), True)
83    if d <= 0.: # TODO
84        return numpy.arange(min_value, max_value, (max_value-min_value)/num_ticks), 1
85    plot_min = floor(min_value / d) * d
86    plot_max = ceil(max_value / d) * d
87    num_frac = int(max(-floor(log10(d)), 0))
88    return numpy.arange(plot_min, plot_max + 0.5*d, d), num_frac
[8682]89
90class ScatterPlot(OWPlot3D, orngScaleScatterPlotData):
91    def __init__(self, parent=None):
[8820]92        self.parent = parent
[8682]93        OWPlot3D.__init__(self, parent)
94        orngScaleScatterPlotData.__init__(self)
95
[8838]96        self._theme = ScatterLightTheme()
[8722]97        self.show_grid = True
98        self.show_chassis = True
[8927]99        self.show_axes = True
[8967]100        self._build_axes()
[8927]101
[8967]102        self._x_axis_labels = None
103        self._y_axis_labels = None
104        self._z_axis_labels = None
[8927]105
[8967]106        self._x_axis_title = ''
107        self._y_axis_title = ''
108        self._z_axis_title = ''
[8927]109
[8967]110        # These are public
[8927]111        self.show_x_axis_title = self.show_y_axis_title = self.show_z_axis_title = True
[8837]112
[8820]113        self.animate_plot = False
[8722]114
115    def set_data(self, data, subset_data=None, **args):
[8727]116        if data == None:
117            return
[8794]118        args['skipIfSame'] = False
[8722]119        orngScaleScatterPlotData.set_data(self, data, subset_data, **args)
[8970]120        # Optimization: calling set_plot_data here (and not in update_data) because data won't change.
[8794]121        OWPlot3D.set_plot_data(self, self.scaled_data, self.scaled_subset_data)
[8820]122        OWPlot3D.initializeGL(self)
[8722]123
124    def update_data(self, x_attr, y_attr, z_attr,
125                    color_attr, symbol_attr, size_attr, label_attr):
[8727]126        if self.data == None:
127            return
[8722]128        self.before_draw_callback = self.before_draw
129
[8928]130        color_discrete = size_discrete = False
[8722]131
132        color_index = -1
133        if color_attr != '' and color_attr != '(Same color)':
134            color_index = self.attribute_name_index[color_attr]
135            if self.data_domain[color_attr].varType == Discrete:
136                color_discrete = True
[8837]137                self.discrete_palette.setNumberOfColors(len(self.data_domain[color_attr].values))
[8722]138
139        symbol_index = -1
140        num_symbols_used = -1
141        if symbol_attr != '' and symbol_attr != 'Same symbol)' and\
[8820]142           len(self.data_domain[symbol_attr].values) < len(Symbol) and\
143           self.data_domain[symbol_attr].varType == Discrete:
[8722]144            symbol_index = self.attribute_name_index[symbol_attr]
[8820]145            num_symbols_used = len(self.data_domain[symbol_attr].values)
[8722]146
147        size_index = -1
148        if size_attr != '' and size_attr != '(Same size)':
149            size_index = self.attribute_name_index[size_attr]
150            if self.data_domain[size_attr].varType == Discrete:
151                size_discrete = True
152
153        label_index = -1
154        if label_attr != '' and label_attr != '(No labels)':
155            label_index = self.attribute_name_index[label_attr]
156
157        x_index = self.attribute_name_index[x_attr]
158        y_index = self.attribute_name_index[y_attr]
159        z_index = self.attribute_name_index[z_attr]
160
161        x_discrete = self.data_domain[x_attr].varType == Discrete
162        y_discrete = self.data_domain[y_attr].varType == Discrete
163        z_discrete = self.data_domain[z_attr].varType == Discrete
164
165        colors = []
166        if color_discrete:
167            for i in range(len(self.data_domain[color_attr].values)):
[8837]168                c = self.discrete_palette[i]
[8786]169                colors.append(c)
[8722]170
171        data_scale = [self.attr_values[x_attr][1] - self.attr_values[x_attr][0],
172                      self.attr_values[y_attr][1] - self.attr_values[y_attr][0],
173                      self.attr_values[z_attr][1] - self.attr_values[z_attr][0]]
174        data_translation = [self.attr_values[x_attr][0],
175                            self.attr_values[y_attr][0],
176                            self.attr_values[z_attr][0]]
177        data_scale = 1. / numpy.array(data_scale)
178        if x_discrete:
179            data_scale[0] = 0.5 / float(len(self.data_domain[x_attr].values))
180            data_translation[0] = 1.
181        if y_discrete:
182            data_scale[1] = 0.5 / float(len(self.data_domain[y_attr].values))
183            data_translation[1] = 1.
184        if z_discrete:
185            data_scale[2] = 0.5 / float(len(self.data_domain[z_attr].values))
186            data_translation[2] = 1.
187
[8975]188        self.data_scale = QVector3D(*data_scale)
189        self.data_translation = QVector3D(*data_translation)
190
[8967]191        self._x_axis_labels = None
192        self._y_axis_labels = None
193        self._z_axis_labels = None
[8927]194
[8834]195        self.clear()
[8820]196
[8834]197        attr_indices = [x_index, y_index, z_index]
198        if color_index > -1:
199            attr_indices.append(color_index)
200        if size_index > -1:
201            attr_indices.append(size_index)
202        if symbol_index > -1:
203            attr_indices.append(symbol_index)
204        if label_index > -1:
205            attr_indices.append(label_index)
206
207        valid_data = self.getValidList(attr_indices)
208        self.set_valid_data(valid_data)
209
[8933]210        self.set_features(x_index, y_index, z_index,
[8722]211            color_index, symbol_index, size_index, label_index,
[8730]212            colors, num_symbols_used,
[8975]213            x_discrete, y_discrete, z_discrete)
[8722]214
[8834]215        ## Legend
[8820]216        def_color = QColor(150, 150, 150)
217        def_symbol = 0
218        def_size = 10
[8722]219
[8820]220        if color_discrete:
221            num = len(self.data_domain[color_attr].values)
222            values = get_variable_values_sorted(self.data_domain[color_attr])
223            for ind in range(num):
[8837]224                self.legend().add_item(color_attr, values[ind], OWPoint(def_symbol, self.discrete_palette[ind], def_size))
[8722]225
[8820]226        if symbol_index != -1:
227            num = len(self.data_domain[symbol_attr].values)
228            values = get_variable_values_sorted(self.data_domain[symbol_attr])
229            for ind in range(num):
230                self.legend().add_item(symbol_attr, values[ind], OWPoint(ind, def_color, def_size))
[8722]231
[8820]232        if size_discrete:
233            num = len(self.data_domain[size_attr].values)
234            values = get_variable_values_sorted(self.data_domain[size_attr])
235            for ind in range(num):
[8837]236                self.legend().add_item(size_attr, values[ind], OWPoint(def_symbol, def_color, 6 + round(ind * 5 / len(values))))
[8722]237
[8820]238        if color_index != -1 and self.data_domain[color_attr].varType == Continuous:
239            self.legend().add_color_gradient(color_attr, [("%%.%df" % self.data_domain[color_attr].numberOfDecimals % v) for v in self.attr_values[color_attr]])
[8722]240
[8837]241        self.legend().max_size = QSize(400, 400)
242        self.legend().set_floating(True)
[8820]243        self.legend().set_orientation(Qt.Vertical)
244        if self.legend().pos().x() == 0:
245            self.legend().setPos(QPointF(100, 100))
246        self.legend().update_items()
[8837]247        self.legend().setVisible(self.show_legend)
[8722]248
[8834]249        ## Axes
[8967]250        self._x_axis_title = x_attr
251        self._y_axis_title = y_attr
252        self._z_axis_title = z_attr
[8722]253
254        if x_discrete:
[8967]255            self._x_axis_labels = get_variable_values_sorted(self.data_domain[x_attr])
[8722]256        if y_discrete:
[8967]257            self._y_axis_labels = get_variable_values_sorted(self.data_domain[y_attr])
[8722]258        if z_discrete:
[8967]259            self._z_axis_labels = get_variable_values_sorted(self.data_domain[z_attr])
[8722]260
[8820]261        self.update()
262
[8722]263    def before_draw(self):
[8856]264        if self.show_grid:
[8967]265            self._draw_grid()
[8856]266        if self.show_chassis:
[8967]267            self._draw_chassis()
[8927]268        if self.show_axes:
[8967]269            self._draw_axes()
[8856]270
[8967]271    def _draw_chassis(self):
[8722]272        glMatrixMode(GL_PROJECTION)
273        glLoadIdentity()
274        glMultMatrixd(numpy.array(self.projection.data(), dtype=float))
275        glMatrixMode(GL_MODELVIEW)
276        glLoadIdentity()
[8967]277        glMultMatrixd(numpy.array(self.view.data(), dtype=float))
278        glMultMatrixd(numpy.array(self.model.data(), dtype=float))
[8722]279
[8856]280        # TODO: line stipple with shaders?
[8810]281        self.qglColor(self._theme.axis_values_color)
[8722]282        glEnable(GL_LINE_STIPPLE)
283        glLineStipple(1, 0x00FF)
284        glDisable(GL_DEPTH_TEST)
285        glLineWidth(1)
286        glEnable(GL_BLEND)
287        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
288        edges = [self.x_axis, self.y_axis, self.z_axis,
289                 self.x_axis+self.unit_z, self.x_axis+self.unit_y,
290                 self.x_axis+self.unit_z+self.unit_y,
291                 self.y_axis+self.unit_x, self.y_axis+self.unit_z,
292                 self.y_axis+self.unit_x+self.unit_z,
293                 self.z_axis+self.unit_x, self.z_axis+self.unit_y,
294                 self.z_axis+self.unit_x+self.unit_y]
295        glBegin(GL_LINES)
296        for edge in edges:
[8975]297            start, end = edge.v0, edge.v1
298            glVertex3f(start.x(), start.y(), start.z())
299            glVertex3f(end.x(), end.y(), end.z())
[8722]300        glEnd()
301        glDisable(GL_LINE_STIPPLE)
302        glEnable(GL_DEPTH_TEST)
303        glDisable(GL_BLEND)
304
[8975]305    def _map_to_original(self, point, coord_index):
306        v = vec_div(self.map_to_data(point), self.data_scale) + self.data_translation
307        if coord_index == 0:
308            return v.x()
309        elif coord_index == 1:
310            return v.y()
311        elif coord_index == 2:
312            return v.z()
313
[8967]314    def _draw_grid(self):
315        self.renderer.set_transform(self.model, self.view, self.projection)
[8975]316        cam_in_space = self.camera * self.camera_distance
[8722]317
[8967]318        def _draw_grid_plane(axis0, axis1, normal0, normal1, i, j):
[8722]319            for axis, normal, coord_index in zip([axis0, axis1], [normal0, normal1], [i, j]):
[8975]320                start, end = axis.v0, axis.v1
321                start_value = self._map_to_original(start, coord_index)
322                end_value = self._map_to_original(end, coord_index)
[8722]323                values, _ = loose_label(start_value, end_value, 7)
324                for value in values:
325                    if not (start_value <= value <= end_value):
326                        continue
327                    position = start + (end-start)*((value-start_value) / float(end_value-start_value))
[8856]328                    self.renderer.draw_line(
[8975]329                        position,
330                        position - normal*1,
[8856]331                        color=self._theme.grid_color)
[8722]332
333        glDisable(GL_DEPTH_TEST)
334        glEnable(GL_BLEND)
335        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
336
337        planes = [self.axis_plane_xy, self.axis_plane_yz,
338                  self.axis_plane_xy_back, self.axis_plane_yz_right]
339        axes = [[self.x_axis, self.y_axis],
340                [self.y_axis, self.z_axis],
341                [self.x_axis+self.unit_z, self.y_axis+self.unit_z],
342                [self.z_axis+self.unit_x, self.y_axis+self.unit_x]]
[8975]343        normals = [[QVector3D(0,-1, 0), QVector3D(-1, 0, 0)],
344                   [QVector3D(0, 0,-1), QVector3D( 0,-1, 0)],
345                   [QVector3D(0,-1, 0), QVector3D(-1, 0, 0)],
346                   [QVector3D(0,-1, 0), QVector3D( 0, 0,-1)]]
[8722]347        coords = [[0, 1],
348                  [1, 2],
349                  [0, 1],
350                  [2, 1]]
[8975]351        visible_planes = [plane.visible_from(cam_in_space) for plane in planes]
352        xz_visible = not self.axis_plane_xz.visible_from(cam_in_space)
[8722]353        if xz_visible:
[8975]354            _draw_grid_plane(self.x_axis, self.z_axis, QVector3D(0, 0, -1), QVector3D(-1, 0, 0), 0, 2)
[8722]355        for visible, (axis0, axis1), (normal0, normal1), (i, j) in\
356             zip(visible_planes, axes, normals, coords):
357            if not visible:
[8967]358                _draw_grid_plane(axis0, axis1, normal0, normal1, i, j)
[8722]359
360        glEnable(GL_DEPTH_TEST)
361        glDisable(GL_BLEND)
[8682]362
[8967]363    def _build_axes(self):
[8927]364        edge_half = 1. / 2.
[8975]365        self.x_axis = Edge(QVector3D(-edge_half, -edge_half, -edge_half), QVector3D( edge_half, -edge_half, -edge_half))
366        self.y_axis = Edge(QVector3D(-edge_half, -edge_half, -edge_half), QVector3D(-edge_half,  edge_half, -edge_half))
367        self.z_axis = Edge(QVector3D(-edge_half, -edge_half, -edge_half), QVector3D(-edge_half, -edge_half,  edge_half))
[8927]368
[8975]369        self.unit_x = unit_x = QVector3D(1, 0, 0)
370        self.unit_y = unit_y = QVector3D(0, 1, 0)
371        self.unit_z = unit_z = QVector3D(0, 0, 1)
[8927]372 
[8975]373        A = self.y_axis.v1
374        B = self.y_axis.v1 + unit_x
375        C = self.x_axis.v1
376        D = self.x_axis.v0
[8927]377
378        E = A + unit_z
379        F = B + unit_z
380        G = C + unit_z
381        H = D + unit_z
382
[8975]383        self.axis_plane_xy = Plane(A, B, C, D)
384        self.axis_plane_yz = Plane(A, D, H, E)
385        self.axis_plane_xz = Plane(D, C, G, H)
[8927]386
[8975]387        self.axis_plane_xy_back = Plane(H, G, F, E)
388        self.axis_plane_yz_right = Plane(B, F, G, C)
389        self.axis_plane_xz_top = Plane(E, F, B, A)
[8927]390
[8967]391    def _draw_axes(self):
392        self.renderer.set_transform(self.model, self.view, self.projection)
[8927]393
[8975]394        def _draw_axis(axis):
[8967]395            glLineWidth(2)
[8975]396            self.renderer.draw_line(axis.v0,
397                                    axis.v1,
[8967]398                                    color=self._theme.axis_color)
399            glLineWidth(1)
[8927]400
[8967]401        def _draw_discrete_axis_values(axis, coord_index, normal, axis_labels):
[8975]402            start, end = axis.v0, axis.v1
403            start_value = self._map_to_original(start, coord_index)
404            end_value = self._map_to_original(end, coord_index)
[8927]405            length = end_value - start_value
406            for i, label in enumerate(axis_labels):
407                value = (i + 1) * 2
408                if start_value <= value <= end_value:
409                    position = start + (end-start)*((value-start_value) / length)
410                    self.renderer.draw_line(
[8975]411                        position,
412                        position + normal*0.03,
[8927]413                        color=self._theme.axis_values_color)
414                    position += normal * 0.1
[8975]415                    self.renderText(position.x(),
416                                    position.y(),
417                                    position.z(),
[8927]418                                    label, font=self._theme.labels_font)
419
[8967]420        def _draw_values(axis, coord_index, normal, axis_labels):
[8927]421            glLineWidth(1)
422            if axis_labels != None:
[8967]423                _draw_discrete_axis_values(axis, coord_index, normal, axis_labels)
[8927]424                return
[8975]425            start, end = axis.v0, axis.v1
426            start_value = self._map_to_original(start, coord_index)
427            end_value = self._map_to_original(end, coord_index)
[8927]428            values, num_frac = loose_label(start_value, end_value, 7)
429            for value in values:
430                if not (start_value <= value <= end_value):
431                    continue
432                position = start + (end-start)*((value-start_value) / float(end_value-start_value))
433                text = ('%%.%df' % num_frac) % value
434                self.renderer.draw_line(
[8975]435                    position,
436                    position+normal*0.03,
[8927]437                    color=self._theme.axis_values_color)
438                position += normal * 0.1
[8975]439                self.renderText(position.x(),
440                                position.y(),
441                                position.z(),
[8927]442                                text, font=self._theme.axis_font)
443
[8967]444        def _draw_axis_title(axis, title, normal):
[8975]445            middle = (axis.v0 + axis.v1) / 2.
446            middle += normal * 0.1 if axis.v0.y() != axis.v1.y() else normal * 0.2
447            self.renderText(middle.x(), middle.y(), middle.z(),
[8927]448                            title,
449                            font=self._theme.axis_title_font)
450
451        glDisable(GL_DEPTH_TEST)
452        glLineWidth(1)
453        glEnable(GL_BLEND)
454        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
455
[8975]456        cam_in_space = self.camera * self.camera_distance
[8927]457
[8967]458        # TODO: the code below is horrible and should be simplified
[8927]459        planes = [self.axis_plane_xy, self.axis_plane_yz,
460                  self.axis_plane_xy_back, self.axis_plane_yz_right]
[8975]461        normals = [[QVector3D(0,-1, 0), QVector3D(-1, 0, 0)],
462                   [QVector3D(0, 0,-1), QVector3D( 0,-1, 0)],
463                   [QVector3D(0,-1, 0), QVector3D(-1, 0, 0)],
464                   [QVector3D(0,-1, 0), QVector3D( 0, 0,-1)]]
465        visible_planes = [plane.visible_from(cam_in_space) for plane in planes]
466        xz_visible = not self.axis_plane_xz.visible_from(cam_in_space)
[8927]467
468        if visible_planes[0 if xz_visible else 2]:
[8967]469            _draw_axis(self.x_axis)
[8975]470            _draw_values(self.x_axis, 0, QVector3D(0, 0, -1), self._x_axis_labels)
[8927]471            if self.show_x_axis_title:
[8975]472                _draw_axis_title(self.x_axis, self._x_axis_title, QVector3D(0, 0, -1))
[8927]473        elif visible_planes[2 if xz_visible else 0]:
[8967]474            _draw_axis(self.x_axis + self.unit_z)
[8975]475            _draw_values(self.x_axis + self.unit_z, 0, QVector3D(0, 0, 1), self._x_axis_labels)
[8927]476            if self.show_x_axis_title:
[8967]477                _draw_axis_title(self.x_axis + self.unit_z,
[8975]478                                self._x_axis_title, QVector3D(0, 0, 1))
[8927]479
480        if visible_planes[1 if xz_visible else 3]:
[8967]481            _draw_axis(self.z_axis)
[8975]482            _draw_values(self.z_axis, 2, QVector3D(-1, 0, 0), self._z_axis_labels)
[8927]483            if self.show_z_axis_title:
[8975]484                _draw_axis_title(self.z_axis, self._z_axis_title, QVector3D(-1, 0, 0))
[8927]485        elif visible_planes[3 if xz_visible else 1]:
[8967]486            _draw_axis(self.z_axis + self.unit_x)
[8975]487            _draw_values(self.z_axis + self.unit_x, 2, QVector3D(1, 0, 0), self._z_axis_labels)
[8927]488            if self.show_z_axis_title:
[8975]489                _draw_axis_title(self.z_axis + self.unit_x, self._z_axis_title, QVector3D(1, 0, 0))
[8927]490
491        try:
492            rightmost_visible = visible_planes[::-1].index(True)
493        except ValueError:
494            return
495        if rightmost_visible == 0 and visible_planes[0] == True:
496            rightmost_visible = 3
497        y_axis_translated = [self.y_axis+self.unit_x,
498                             self.y_axis+self.unit_x+self.unit_z,
499                             self.y_axis+self.unit_z,
500                             self.y_axis]
[8975]501        normals = [QVector3D(1, 0, 0),
502                   QVector3D(0, 0, 1),
503                   QVector3D(-1,0, 0),
504                   QVector3D(0, 0,-1)]
[8927]505        axis = y_axis_translated[rightmost_visible]
506        normal = normals[rightmost_visible]
[8967]507        _draw_axis(axis)
508        _draw_values(axis, 1, normal, self._y_axis_labels)
[8927]509        if self.show_y_axis_title:
[8967]510            _draw_axis_title(axis, self._y_axis_title, normal)
[8927]511
[8682]512class OWScatterPlot3D(OWWidget):
513    settingsList = ['plot.show_legend', 'plot.symbol_size', 'plot.show_x_axis_title', 'plot.show_y_axis_title',
[8852]514                    'plot.show_z_axis_title', 'plot.show_legend', 'plot.use_2d_symbols', 'plot.symbol_scale',
[8930]515                    'plot.alpha_value', 'plot.show_grid', 'plot.pitch', 'plot.yaw',
[8682]516                    'plot.show_chassis', 'plot.show_axes',
517                    'auto_send_selection', 'auto_send_selection_update',
[8847]518                    'plot.jitter_size', 'plot.jitter_continuous', 'dark_theme']
[8722]519    contextHandlers = {'': DomainContextHandler('', ['x_attr', 'y_attr', 'z_attr'])}
[8682]520    jitter_sizes = [0.0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 40, 50]
521
[8852]522    def __init__(self, parent=None, signalManager=None, name='ScatterPlot 3D'):
[8682]523        OWWidget.__init__(self, parent, signalManager, name, True)
524
[9546]525        self.inputs = [('Data', ExampleTable, self.set_data, Default), ('Subset Examples', ExampleTable, self.set_subset_data)]
526        self.outputs = [('Selected Data', ExampleTable), ('Other Data', ExampleTable)]
[8682]527
[8722]528        self.x_attr = ''
529        self.y_attr = ''
530        self.z_attr = ''
[8682]531
532        self.x_attr_discrete = False
533        self.y_attr_discrete = False
534        self.z_attr_discrete = False
535
[8722]536        self.color_attr = ''
537        self.size_attr = ''
538        self.symbol_attr = ''
539        self.label_attr = ''
[8682]540
541        self.tabs = OWGUI.tabWidget(self.controlArea)
542        self.main_tab = OWGUI.createTabPage(self.tabs, 'Main')
543        self.settings_tab = OWGUI.createTabPage(self.tabs, 'Settings', canScroll=True)
544
[8722]545        self.x_attr_cb = OWGUI.comboBox(self.main_tab, self, 'x_attr', box='X-axis attribute',
546            tooltip='Attribute to plot on X axis.',
547            callback=self.on_axis_change,
548            sendSelectedValue=1,
549            valueType=str)
[8682]550
[8722]551        self.y_attr_cb = OWGUI.comboBox(self.main_tab, self, 'y_attr', box='Y-axis attribute',
552            tooltip='Attribute to plot on Y axis.',
553            callback=self.on_axis_change,
554            sendSelectedValue=1,
555            valueType=str)
[8682]556
[8722]557        self.z_attr_cb = OWGUI.comboBox(self.main_tab, self, 'z_attr', box='Z-axis attribute',
558            tooltip='Attribute to plot on Z axis.',
559            callback=self.on_axis_change,
560            sendSelectedValue=1,
561            valueType=str)
[8682]562
[8722]563        self.color_attr_cb = OWGUI.comboBox(self.main_tab, self, 'color_attr', box='Point color',
564            tooltip='Attribute to use for point color',
565            callback=self.on_axis_change,
566            sendSelectedValue=1,
567            valueType=str)
[8682]568
[8722]569        # Additional point properties (labels, size, symbol).
[8682]570        additional_box = OWGUI.widgetBox(self.main_tab, 'Additional Point Properties')
[8722]571        self.size_attr_cb = OWGUI.comboBox(additional_box, self, 'size_attr', label='Point size:',
572            tooltip='Attribute to use for point size',
[8682]573            callback=self.on_axis_change,
574            indent=10,
575            emptyString='(Same size)',
[8722]576            sendSelectedValue=1,
577            valueType=str)
[8682]578
[8722]579        self.symbol_attr_cb = OWGUI.comboBox(additional_box, self, 'symbol_attr', label='Point symbol:',
580            tooltip='Attribute to use for point symbol',
[8682]581            callback=self.on_axis_change,
582            indent=10,
[8722]583            emptyString='(Same symbol)',
584            sendSelectedValue=1,
585            valueType=str)
[8682]586
[8722]587        self.label_attr_cb = OWGUI.comboBox(additional_box, self, 'label_attr', label='Point label:',
588            tooltip='Attribute to use for pointLabel',
[8682]589            callback=self.on_axis_change,
590            indent=10,
[8722]591            emptyString='(No labels)',
592            sendSelectedValue=1,
593            valueType=str)
[8682]594
595        self.plot = ScatterPlot(self)
[8722]596        self.vizrank = OWVizRank(self, self.signalManager, self.plot, orngVizRank.SCATTERPLOT3D, 'ScatterPlot3D')
[8682]597        self.optimization_dlg = self.vizrank
598
599        self.optimization_buttons = OWGUI.widgetBox(self.main_tab, 'Optimization dialogs', orientation='horizontal')
[8722]600        OWGUI.button(self.optimization_buttons, self, 'VizRank', callback=self.vizrank.reshow,
[8682]601            tooltip='Opens VizRank dialog, where you can search for interesting projections with different subsets of attributes',
602            debuggingEnabled=0)
603
604        box = OWGUI.widgetBox(self.settings_tab, 'Point properties')
[8922]605        OWGUI.hSlider(box, self, 'plot.symbol_scale', label='Symbol scale',
[8805]606            minValue=1, maxValue=20,
[8722]607            tooltip='Scale symbol size',
608            callback=self.on_checkbox_update)
[8682]609
[8742]610        OWGUI.hSlider(box, self, 'plot.alpha_value', label='Transparency',
[8682]611            minValue=10, maxValue=255,
[8722]612            tooltip='Point transparency value',
[8682]613            callback=self.on_checkbox_update)
614        OWGUI.rubber(box)
615
[8722]616        box = OWGUI.widgetBox(self.settings_tab, 'Jittering Options')
617        self.jitter_size_combo = OWGUI.comboBox(box, self, 'plot.jitter_size', label='Jittering size (% of size)'+'  ',
[8682]618            orientation='horizontal',
[8794]619            callback=self.handleNewSignals,
[8682]620            items=self.jitter_sizes,
621            sendSelectedValue=1,
622            valueType=float)
[8722]623        OWGUI.checkBox(box, self, 'plot.jitter_continuous', 'Jitter continuous attributes',
[8794]624            callback=self.handleNewSignals,
[8682]625            tooltip='Does jittering apply also on continuous attributes?')
626
627        self.dark_theme = False
628
629        box = OWGUI.widgetBox(self.settings_tab, 'General settings')
630        OWGUI.checkBox(box, self, 'plot.show_x_axis_title',   'X axis title',   callback=self.on_checkbox_update)
631        OWGUI.checkBox(box, self, 'plot.show_y_axis_title',   'Y axis title',   callback=self.on_checkbox_update)
632        OWGUI.checkBox(box, self, 'plot.show_z_axis_title',   'Z axis title',   callback=self.on_checkbox_update)
633        OWGUI.checkBox(box, self, 'plot.show_legend',         'Show legend',    callback=self.on_checkbox_update)
[8722]634        OWGUI.checkBox(box, self, 'plot.use_2d_symbols',      '2D symbols',     callback=self.update_plot)
[8682]635        OWGUI.checkBox(box, self, 'dark_theme',               'Dark theme',     callback=self.on_theme_change)
636        OWGUI.checkBox(box, self, 'plot.show_grid',           'Show grid',      callback=self.on_checkbox_update)
637        OWGUI.checkBox(box, self, 'plot.show_axes',           'Show axes',      callback=self.on_checkbox_update)
638        OWGUI.checkBox(box, self, 'plot.show_chassis',        'Show chassis',   callback=self.on_checkbox_update)
639        OWGUI.checkBox(box, self, 'plot.hide_outside',        'Hide outside',   callback=self.on_checkbox_update)
640        OWGUI.rubber(box)
641
[8854]642        box = OWGUI.widgetBox(self.settings_tab, 'Mouse', orientation = "horizontal")
643        OWGUI.hSlider(box, self, 'plot.mouse_sensitivity', label='Sensitivity', minValue=1, maxValue=10,
644                      step=1,
645                      callback=self.plot.update,
646                      tooltip='Change mouse sensitivity')
647
[8922]648        gui = self.plot.gui
[9088]649        buttons = gui.default_zoom_select_buttons
650        buttons.insert(2, (gui.UserButton, 'Rotate', 'state', ROTATING, None, 'Dlg_undo'))
651        self.zoom_select_toolbar = gui.zoom_select_toolbar(self.main_tab, buttons=buttons)
[8811]652        self.connect(self.zoom_select_toolbar.buttons[gui.SendSelection], SIGNAL("clicked()"), self.send_selection)
[8922]653        self.connect(self.zoom_select_toolbar.buttons[gui.Zoom], SIGNAL("clicked()"), self.plot.unselect_all_points)
654        self.plot.set_selection_behavior(OWPlot.ReplaceSelection)
[8682]655
656        self.tooltip_kind = TooltipKind.NONE
[8722]657        box = OWGUI.widgetBox(self.settings_tab, 'Tooltips Settings')
[8682]658        OWGUI.comboBox(box, self, 'tooltip_kind', items = [
659            'Don\'t Show Tooltips', 'Show Visible Attributes', 'Show All Attributes'])
660
661        self.plot.mouseover_callback = self.mouseover_callback
662
663        self.main_tab.layout().addStretch(100)
664        self.settings_tab.layout().addStretch(100)
665
666        self.mainArea.layout().addWidget(self.plot)
[8722]667        self.connect(self.graphButton, SIGNAL('clicked()'), self.plot.save_to_file)
[8682]668
669        self.loadSettings()
670        self.plot.update_camera()
[8847]671        self.on_theme_change()
[8682]672
673        self.data = None
[8722]674        self.subset_data = None
[8682]675        self.resize(1100, 600)
676
677    def mouseover_callback(self, index):
678        if self.tooltip_kind == TooltipKind.VISIBLE:
[8834]679            self.plot.show_tooltip(self.get_example_tooltip(self.data[index], self.shown_attrs))
[8682]680        elif self.tooltip_kind == TooltipKind.ALL:
681            self.plot.show_tooltip(self.get_example_tooltip(self.data[index]))
682
683    def get_example_tooltip(self, example, indices=None, max_indices=20):
684        if indices and type(indices[0]) == str:
[8722]685            indices = [self.plot.attribute_name_index[i] for i in indices]
[8682]686        if not indices:
687            indices = range(len(self.data.domain.attributes))
688
689        if example.domain.classVar:
[8722]690            classIndex = self.plot.attribute_name_index[example.domain.classVar.name]
[8682]691            while classIndex in indices:
692                indices.remove(classIndex)
693
694        text = '<b>Attributes:</b><br>'
695        for index in indices[:max_indices]:
[8724]696            attr = self.plot.data_domain[index].name
[8682]697            if attr not in example.domain:  text += '&nbsp;'*4 + '%s = ?<br>' % (attr)
698            elif example[attr].isSpecial(): text += '&nbsp;'*4 + '%s = ?<br>' % (attr)
699            else:                           text += '&nbsp;'*4 + '%s = %s<br>' % (attr, str(example[attr]))
700
701        if len(indices) > max_indices:
702            text += '&nbsp;'*4 + ' ... <br>'
703
704        if example.domain.classVar:
705            text = text[:-4]
706            text += '<hr><b>Class:</b><br>'
707            if example.getclass().isSpecial(): text += '&nbsp;'*4 + '%s = ?<br>' % (example.domain.classVar.name)
708            else:                              text += '&nbsp;'*4 + '%s = %s<br>' % (example.domain.classVar.name, str(example.getclass()))
709
710        if len(example.domain.getmetas()) != 0:
711            text = text[:-4]
712            text += '<hr><b>Meta attributes:</b><br>'
713            for key in example.domain.getmetas():
714                try: text += '&nbsp;'*4 + '%s = %s<br>' % (example.domain[key].name, str(example[key]))
715                except: pass
716        return text[:-4]
717
718    def set_data(self, data=None):
[8722]719        self.closeContext()
720        self.vizrank.clearResults()
721        same_domain = self.data and data and\
722            data.domain.checksum() == self.data.domain.checksum()
[8682]723        self.data = data
[8722]724        if not same_domain:
725            self.init_attr_values()
726        self.openContext('', data)
727
728    def init_attr_values(self):
[8682]729        self.x_attr_cb.clear()
730        self.y_attr_cb.clear()
731        self.z_attr_cb.clear()
732        self.color_attr_cb.clear()
733        self.size_attr_cb.clear()
[8722]734        self.symbol_attr_cb.clear()
[8682]735        self.label_attr_cb.clear()
736
737        self.discrete_attrs = {}
738
[8722]739        if not self.data:
740            return
[8682]741
[8722]742        self.color_attr_cb.addItem('(Same color)')
743        self.label_attr_cb.addItem('(No labels)')
744        self.symbol_attr_cb.addItem('(Same symbol)')
745        self.size_attr_cb.addItem('(Same size)')
[8682]746
[8722]747        icons = OWGUI.getAttributeIcons() 
748        for metavar in [self.data.domain.getmeta(mykey) for mykey in self.data.domain.getmetas().keys()]:
749            self.label_attr_cb.addItem(icons[metavar.varType], metavar.name)
[8682]750
[8722]751        for attr in self.data.domain:
752            if attr.varType in [Discrete, Continuous]:
[8682]753                self.x_attr_cb.addItem(icons[attr.varType], attr.name)
754                self.y_attr_cb.addItem(icons[attr.varType], attr.name)
755                self.z_attr_cb.addItem(icons[attr.varType], attr.name)
756                self.color_attr_cb.addItem(icons[attr.varType], attr.name)
757                self.size_attr_cb.addItem(icons[attr.varType], attr.name)
[8820]758            if attr.varType == Discrete and len(attr.values) < len(Symbol):
[8722]759                self.symbol_attr_cb.addItem(icons[attr.varType], attr.name)
760            self.label_attr_cb.addItem(icons[attr.varType], attr.name)
[8682]761
[8722]762        self.x_attr = str(self.x_attr_cb.itemText(0))
763        if self.y_attr_cb.count() > 1:
764            self.y_attr = str(self.y_attr_cb.itemText(1))
765        else:
766            self.y_attr = str(self.y_attr_cb.itemText(0))
[8682]767
[8722]768        if self.z_attr_cb.count() > 2:
769            self.z_attr = str(self.z_attr_cb.itemText(2))
770        else:
771            self.z_attr = str(self.z_attr_cb.itemText(0))
772
773        if self.data.domain.classVar and self.data.domain.classVar.varType in [Discrete, Continuous]:
774            self.color_attr = self.data.domain.classVar.name
775        else:
776            self.color_attr = ''
777
778        self.symbol_attr = self.size_attr = self.label_attr = ''
[8834]779        self.shown_attrs = [self.x_attr, self.y_attr, self.z_attr, self.color_attr]
[8682]780
781    def set_subset_data(self, data=None):
[8722]782        self.subset_data = data
[8682]783
784    def handleNewSignals(self):
[8828]785        self.plot.set_data(self.data, self.subset_data)
[8722]786        self.vizrank.resetDialog()
[8682]787        self.update_plot()
[8811]788        self.send_selection()
[8682]789
790    def saveSettings(self):
791        OWWidget.saveSettings(self)
792
793    def sendReport(self):
[8722]794        self.startReport('%s [%s - %s - %s]' % (self.windowTitle(), self.x_attr, self.y_attr, self.z_attr))
[8682]795        self.reportSettings('Visualized attributes',
[8722]796                            [('X', self.x_attr),
797                             ('Y', self.y_attr),
798                             ('Z', self.z_attr),
799                             self.color_attr and ('Color', self.color_attr),
800                             self.label_attr and ('Label', self.label_attr),
801                             self.symbol_attr and ('Symbol', self.symbol_attr),
802                             self.size_attr  and ('Size', self.size_attr)])
[8682]803        self.reportSettings('Settings',
804                            [('Symbol size', self.plot.symbol_scale),
[8742]805                             ('Transparency', self.plot.alpha_value),
[8722]806                             ('Jittering', self.jitter_size),
807                             ('Jitter continuous attributes', OWGUI.YesNo[self.jitter_continuous])
[8682]808                             ])
809        self.reportSection('Plot')
810        self.reportImage(self.plot.save_to_file_direct, QSize(400, 400))
811
[8811]812    def send_selection(self):
[8682]813        if self.data == None:
814            return
[8852]815
816        selected = None#selected = self.plot.get_selected_indices() # TODO: crash
[8834]817        if selected == None or len(selected) != len(self.data):
[8815]818            return
[8852]819
[8811]820        unselected = numpy.logical_not(selected)
821        selected = self.data.selectref(list(selected))
822        unselected = self.data.selectref(list(unselected))
[9546]823        self.send('Selected Data', selected)
824        self.send('Other Data', unselected)
[8682]825
826    def on_axis_change(self):
827        if self.data is not None:
828            self.update_plot()
829
830    def on_theme_change(self):
831        if self.dark_theme:
[8838]832            self.plot.theme = ScatterDarkTheme()
[8682]833        else:
[8838]834            self.plot.theme = ScatterLightTheme()
[8682]835
836    def on_checkbox_update(self):
[8820]837        self.plot.update()
[8682]838
839    def update_plot(self):
840        if self.data is None:
841            return
842
[8722]843        self.plot.update_data(self.x_attr, self.y_attr, self.z_attr,
844                              self.color_attr, self.symbol_attr, self.size_attr,
845                              self.label_attr)
[8682]846
847    def showSelectedAttributes(self):
848        val = self.vizrank.getSelectedProjection()
849        if not val: return
850        if self.data.domain.classVar:
[8722]851            self.attr_color = self.data.domain.classVar.name
[8682]852        if not self.plot.have_data:
853            return
854        attr_list = val[3]
855        if attr_list and len(attr_list) == 3:
[8722]856            self.x_attr = attr_list[0]
857            self.y_attr = attr_list[1]
858            self.z_attr = attr_list[2]
[8682]859
860        self.update_plot()
861
[8722]862if __name__ == '__main__':
[8682]863    app = QApplication(sys.argv)
864    w = OWScatterPlot3D()
[8722]865    data = orange.ExampleTable('../../doc/datasets/iris')
[8682]866    w.set_data(data)
867    w.handleNewSignals()
868    w.show()
869    app.exec_()
Note: See TracBrowser for help on using the repository browser.