source: orange/orange/OrangeWidgets/Prototypes/OWScatterPlot3D.py @ 8478:cf58b3983294

Revision 8478:cf58b3983294, 21.5 KB checked in by matejd <matejd@…>, 23 months ago (diff)

Added light/dark color themes

Line 
1"""<name> 3D Scatterplot</name>
2"""
3
4from OWWidget import *
5from plot.owplot3d import *
6
7import orange
8Discrete = orange.VarTypes.Discrete
9Continuous = orange.VarTypes.Continuous
10
11import OWGUI
12import OWToolbars
13import OWColorPalette
14
15import numpy
16
17TooltipKind = enum('NONE', 'VISIBLE', 'ALL') # Which attributes should be displayed in tooltips?
18
19class OWScatterPlot3D(OWWidget):
20    settingsList = ['plot.show_legend', 'plot.symbol_size', 'plot.show_x_axis_title', 'plot.show_y_axis_title',
21                    'plot.show_z_axis_title', 'plot.show_legend', 'plot.use_2d_symbols',
22                    'plot.transparency', 'plot.show_grid', 'plot.pitch', 'plot.yaw', 'plot.use_ortho',
23                    'auto_send_selection', 'auto_send_selection_update',
24                    'jitter_size', 'jitter_continuous']
25    contextHandlers = {"": DomainContextHandler("", ["xAttr", "yAttr", "zAttr"])}
26    jitter_sizes = [0.0, 0.1, 0.5, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 40, 50]
27
28    def __init__(self, parent=None, signalManager=None, name="Scatter Plot 3D"):
29        OWWidget.__init__(self, parent, signalManager, name, True)
30
31        self.inputs = [("Examples", ExampleTable, self.set_data, Default), ("Subset Examples", ExampleTable, self.set_subset_data)]
32        self.outputs = [("Selected Examples", ExampleTable), ("Unselected Examples", ExampleTable)]
33
34        self.x_attr = 0
35        self.y_attr = 0
36        self.z_attr = 0
37
38        self.x_attr_discrete = False
39        self.y_attr_discrete = False
40        self.z_attr_discrete = False
41
42        self.color_attr = None
43        self.size_attr = None
44        self.shape_attr = None
45        self.label_attr = None
46
47        self.symbol_scale = 5
48        self.alpha_value = 255
49
50
51        self.tabs = OWGUI.tabWidget(self.controlArea)
52        self.main_tab = OWGUI.createTabPage(self.tabs, 'Main')
53        self.settings_tab = OWGUI.createTabPage(self.tabs, 'Settings', canScroll=True)
54
55        self.x_attr_cb = OWGUI.comboBox(self.main_tab, self, "x_attr", box="X-axis attribute",
56            tooltip="Attribute to plot on X axis.",
57            callback=self.on_axis_change
58            )
59
60        self.y_attr_cb = OWGUI.comboBox(self.main_tab, self, "y_attr", box="Y-axis attribute",
61            tooltip="Attribute to plot on Y axis.",
62            callback=self.on_axis_change
63            )
64
65        self.z_attr_cb = OWGUI.comboBox(self.main_tab, self, "z_attr", box="Z-axis attribute",
66            tooltip="Attribute to plot on Z axis.",
67            callback=self.on_axis_change
68            )
69
70        self.color_attr_cb = OWGUI.comboBox(self.main_tab, self, "color_attr", box="Point color",
71            tooltip="Attribute to use for point color",
72            callback=self.on_axis_change)
73
74        # Additional point properties (labels, size, shape).
75        additional_box = OWGUI.widgetBox(self.main_tab, 'Additional Point Properties')
76        self.size_attr_cb = OWGUI.comboBox(additional_box, self, "size_attr", label="Point size:",
77            tooltip="Attribute to use for pointSize",
78            callback=self.on_axis_change,
79            indent = 10,
80            emptyString = '(Same size)',
81            )
82
83        self.shape_attr_cb = OWGUI.comboBox(additional_box, self, "shape_attr", label="Point shape:",
84            tooltip="Attribute to use for pointShape",
85            callback=self.on_axis_change,
86            indent = 10,
87            emptyString = '(Same shape)',
88            )
89
90        self.label_attr_cb = OWGUI.comboBox(additional_box, self, "label_attr", label="Point label:",
91            tooltip="Attribute to use for pointLabel",
92            callback=self.on_axis_change,
93            indent = 10,
94            emptyString = '(No labels)'
95            )
96
97        self.plot = OWPlot3D(self)
98
99        box = OWGUI.widgetBox(self.settings_tab, 'Point properties')
100        OWGUI.hSlider(box, self, "plot.symbol_scale", label="Symbol scale",
101            minValue=1, maxValue=5,
102            tooltip="Scale symbol size",
103            callback=self.on_checkbox_update,
104            )
105
106        OWGUI.hSlider(box, self, "plot.transparency", label="Transparency",
107            minValue=10, maxValue=255,
108            tooltip="Point transparency value",
109            callback=self.on_checkbox_update)
110        OWGUI.rubber(box)
111
112        self.jitter_size = 0
113        self.jitter_continuous = False
114        box = OWGUI.widgetBox(self.settings_tab, "Jittering Options")
115        self.jitter_size_combo = OWGUI.comboBox(box, self, 'jitter_size', label='Jittering size (% of size)'+'  ',
116            orientation='horizontal',
117            callback=self.update_plot,
118            items=self.jitter_sizes,
119            sendSelectedValue=1,
120            valueType=float)
121        OWGUI.checkBox(box, self, 'jitter_continuous', 'Jitter continuous attributes',
122            callback=self.update_plot,
123            tooltip='Does jittering apply also on continuous attributes?')
124
125        self.dark_theme = False
126
127        box = OWGUI.widgetBox(self.settings_tab, 'General settings')
128        OWGUI.checkBox(box, self, 'plot.show_x_axis_title',   'X axis title',   callback=self.on_checkbox_update)
129        OWGUI.checkBox(box, self, 'plot.show_y_axis_title',   'Y axis title',   callback=self.on_checkbox_update)
130        OWGUI.checkBox(box, self, 'plot.show_z_axis_title',   'Z axis title',   callback=self.on_checkbox_update)
131        OWGUI.checkBox(box, self, 'plot.show_legend',         'Show legend',    callback=self.on_checkbox_update)
132        OWGUI.checkBox(box, self, 'plot.use_ortho',           'Use ortho',      callback=self.on_checkbox_update)
133        OWGUI.checkBox(box, self, 'plot.use_2d_symbols',      '2D symbols',     callback=self.on_checkbox_update)
134        OWGUI.checkBox(box, self, 'plot.show_grid',           'Show grid',      callback=self.on_checkbox_update)
135        OWGUI.checkBox(box, self, 'dark_theme',               'Dark theme',     callback=self.on_theme_change)
136        OWGUI.rubber(box)
137
138        self.auto_send_selection = True
139        self.auto_send_selection_update = False
140        self.plot.selection_changed_callback = self.selection_changed_callback
141        box = OWGUI.widgetBox(self.settings_tab, 'Auto Send Selected Data When...')
142        OWGUI.checkBox(box, self, 'auto_send_selection', 'Adding/Removing selection areas',
143            callback = self.on_checkbox_update, tooltip = 'Send selected data whenever a selection area is added or removed')
144        OWGUI.checkBox(box, self, 'auto_send_selection_update', 'Moving/Resizing selection areas',
145            callback = self.on_checkbox_update, tooltip = 'Send selected data when a user moves or resizes an existing selection area')
146
147        self.zoom_select_toolbar = OWToolbars.ZoomSelectToolbar(self, self.main_tab, self.plot, self.auto_send_selection)
148        self.connect(self.zoom_select_toolbar.buttonSendSelections, SIGNAL('clicked()'), self.send_selections)
149        self.connect(self.zoom_select_toolbar.buttonSelectRect, SIGNAL('clicked()'), self.change_selection_type)
150        self.connect(self.zoom_select_toolbar.buttonSelectPoly, SIGNAL('clicked()'), self.change_selection_type)
151        self.connect(self.zoom_select_toolbar.buttonZoom, SIGNAL('clicked()'), self.change_selection_type)
152        self.connect(self.zoom_select_toolbar.buttonRemoveLastSelection, SIGNAL('clicked()'), self.plot.remove_last_selection)
153        self.connect(self.zoom_select_toolbar.buttonRemoveAllSelections, SIGNAL('clicked()'), self.plot.remove_all_selections)
154        self.toolbarSelection = None
155
156        self.tooltip_kind = TooltipKind.NONE
157        box = OWGUI.widgetBox(self.settings_tab, "Tooltips Settings")
158        OWGUI.comboBox(box, self, 'tooltip_kind', items = [
159            'Don\'t Show Tooltips', 'Show Visible Attributes', 'Show All Attributes'], callback = self.on_axis_change)
160
161        self.plot.mouseover_callback = self.mouseover_callback
162        self.shown_attr_indices = []
163
164        self.main_tab.layout().addStretch(100)
165        self.settings_tab.layout().addStretch(100)
166
167        self.mainArea.layout().addWidget(self.plot)
168        self.connect(self.graphButton, SIGNAL("clicked()"), self.plot.save_to_file)
169
170        self.loadSettings()
171
172        self.data = None
173        self.subsetData = None
174        self.resize(1000, 600)
175
176    def mouseover_callback(self, index):
177        if self.tooltip_kind == TooltipKind.VISIBLE:
178            self.plot.show_tooltip(self.get_example_tooltip(self.data[index], self.shown_attr_indices))
179        elif self.tooltip_kind == TooltipKind.ALL:
180            self.plot.show_tooltip(self.get_example_tooltip(self.data[index]))
181
182    def get_example_tooltip(self, example, indices=None, max_indices=20):
183        if indices and type(indices[0]) == str:
184            indices = [self.attr_name_index[i] for i in indices]
185        if not indices:
186            indices = range(len(self.data.domain.attributes))
187
188        if example.domain.classVar:
189            classIndex = self.attr_name_index[example.domain.classVar.name]
190            while classIndex in indices:
191                indices.remove(classIndex)
192
193        text = '<b>Attributes:</b><br>'
194        for index in indices[:max_indices]:
195            attr = self.attr_name[index]
196            if attr not in example.domain:  text += '&nbsp;'*4 + '%s = ?<br>' % (attr)
197            elif example[attr].isSpecial(): text += '&nbsp;'*4 + '%s = ?<br>' % (attr)
198            else:                           text += '&nbsp;'*4 + '%s = %s<br>' % (attr, str(example[attr]))
199
200        if len(indices) > max_indices:
201            text += '&nbsp;'*4 + ' ... <br>'
202
203        if example.domain.classVar:
204            text = text[:-4]
205            text += '<hr><b>Class:</b><br>'
206            if example.getclass().isSpecial(): text += '&nbsp;'*4 + '%s = ?<br>' % (example.domain.classVar.name)
207            else:                              text += '&nbsp;'*4 + '%s = %s<br>' % (example.domain.classVar.name, str(example.getclass()))
208
209        if len(example.domain.getmetas()) != 0:
210            text = text[:-4]
211            text += '<hr><b>Meta attributes:</b><br>'
212            for key in example.domain.getmetas():
213                try: text += '&nbsp;'*4 + '%s = %s<br>' % (example.domain[key].name, str(example[key]))
214                except: pass
215        return text[:-4]
216
217    def selection_changed_callback(self):
218        if self.plot.selection_type == SelectionType.ZOOM:
219            indices = self.plot.get_selection_indices()
220            if len(indices) < 1:
221                self.plot.selections = []
222                return
223            X, Y, Z = self.data_array[:, self.x_attr],\
224                      self.data_array[:, self.y_attr],\
225                      self.data_array[:, self.z_attr]
226            X = [X[i] for i in indices]
227            Y = [Y[i] for i in indices]
228            Z = [Z[i] for i in indices]
229            min_x, max_x = numpy.min(X), numpy.max(X)
230            min_y, max_y = numpy.min(Y), numpy.max(Y)
231            min_z, max_z = numpy.min(Z), numpy.max(Z)
232            self.plot.set_new_zoom(min_x, max_x, min_y, max_y, min_z, max_z)
233
234    def change_selection_type(self):
235        if self.toolbarSelection < 3:
236            selection_type = [SelectionType.ZOOM, SelectionType.RECTANGLE, SelectionType.POLYGON][self.toolbarSelection]
237            self.plot.set_selection_type(selection_type)
238
239    def set_data(self, data=None):
240        self.closeContext("")
241        self.data = data
242        self.x_attr_cb.clear()
243        self.y_attr_cb.clear()
244        self.z_attr_cb.clear()
245        self.color_attr_cb.clear()
246        self.size_attr_cb.clear()
247        self.shape_attr_cb.clear()
248        self.label_attr_cb.clear()
249
250        self.discrete_attrs = {}
251
252        if self.data is not None:
253            self.all_attrs = data.domain.variables + data.domain.getmetas().values()
254            self.candidate_attrs = [attr for attr in self.all_attrs if attr.varType in [Discrete, Continuous]]
255
256            self.attr_name_index = {}
257            for i, attr in enumerate(self.all_attrs):
258                self.attr_name_index[attr.name] = i
259
260            self.attr_name = {}
261            for i, attr in enumerate(self.all_attrs):
262                self.attr_name[i] = attr.name
263
264            self.color_attr_cb.addItem('(Same color)')
265            self.size_attr_cb.addItem('(Same size)')
266            self.shape_attr_cb.addItem('(Same shape)')
267            self.label_attr_cb.addItem('(No labels)')
268            icons = OWGUI.getAttributeIcons() 
269            for (i, attr) in enumerate(self.candidate_attrs):
270                self.x_attr_cb.addItem(icons[attr.varType], attr.name)
271                self.y_attr_cb.addItem(icons[attr.varType], attr.name)
272                self.z_attr_cb.addItem(icons[attr.varType], attr.name)
273                self.color_attr_cb.addItem(icons[attr.varType], attr.name)
274                self.size_attr_cb.addItem(icons[attr.varType], attr.name)
275                self.label_attr_cb.addItem(icons[attr.varType], attr.name)
276                if attr.varType == orange.VarTypes.Discrete:
277                    self.discrete_attrs[len(self.discrete_attrs)+1] = (i, attr)
278                    self.shape_attr_cb.addItem(icons[orange.VarTypes.Discrete], attr.name)
279
280            array, c, w = self.data.toNumpyMA()
281            if len(c):
282                array = numpy.hstack((array, c.reshape(-1,1)))
283            self.data_array = array
284
285            self.x_attr, self.y_attr, self.z_attr = numpy.min([[0, 1, 2],
286                                                               [len(self.candidate_attrs) - 1]*3
287                                                              ], axis=0)
288            self.color_attr = max(len(self.candidate_attrs) - 1, 0)
289            self.shown_attr_indices = [self.x_attr, self.y_attr, self.z_attr, self.color_attr]
290            self.openContext('', data)
291
292    def set_subset_data(self, data=None):
293        self.subsetData = data # TODO: what should scatterplot do with this?
294
295    def handleNewSignals(self):
296        self.update_plot()
297        self.send_selections()
298
299    def saveSettings(self):
300        OWWidget.saveSettings(self)
301
302    def sendReport(self):
303        self.startReport('%s [%s - %s - %s]' % (self.windowTitle(), self.attr_name[self.x_attr],
304                                                self.attr_name[self.y_attr], self.attr_name[self.z_attr]))
305        self.reportSettings('Visualized attributes',
306                            [('X', self.attr_name[self.x_attr]),
307                             ('Y', self.attr_name[self.y_attr]),
308                             ('Z', self.attr_name[self.z_attr]),
309                             self.color_attr and ('Color', self.attr_name[self.color_attr]),
310                             self.label_attr and ('Label', self.attr_name[self.label_attr]),
311                             self.shape_attr and ('Shape', self.attr_name[self.shape_attr]),
312                             self.size_attr  and ('Size', self.attr_name[self.size_attr])])
313        self.reportSettings('Settings',
314                            [('Symbol size', self.plot.symbol_scale),
315                             ('Transparency', self.plot.transparency),
316                             ("Jittering", self.jitter_size),
317                             ("Jitter continuous attributes", OWGUI.YesNo[self.jitter_continuous])
318                             ])
319        self.reportSection('Plot')
320        self.reportImage(self.plot.save_to_file_direct, QSize(400, 400))
321
322    def send_selections(self):
323        if self.data == None:
324            return
325        indices = self.plot.get_selection_indices()
326        selected = [1 if i in indices else 0 for i in range(len(self.data))]
327        unselected = map(lambda n: 1-n, selected)
328        selected = self.data.selectref(selected)
329        unselected = self.data.selectref(unselected)
330        self.send('Selected Examples', selected)
331        self.send('Unselected Examples', unselected)
332
333    def on_axis_change(self):
334        if self.data is not None:
335            self.update_plot()
336
337    def on_theme_change(self):
338        if self.dark_theme:
339            self.plot.theme = DarkTheme()
340        else:
341            self.plot.theme = LightTheme()
342
343    def on_checkbox_update(self):
344        self.plot.updateGL()
345
346    def update_plot(self):
347        if self.data is None:
348            return
349
350        self.x_attr_discrete = self.y_attr_discrete = self.z_attr_discrete = False
351
352        if self.candidate_attrs[self.x_attr].varType == Discrete:
353            self.x_attr_discrete = True
354        if self.candidate_attrs[self.y_attr].varType == Discrete:
355            self.y_attr_discrete = True
356        if self.candidate_attrs[self.z_attr].varType == Discrete:
357            self.z_attr_discrete = True
358
359        X, Y, Z, mask = self.get_axis_data(self.x_attr, self.y_attr, self.z_attr)
360
361        color_legend_items = []
362        if self.color_attr > 0:
363            color_attr = self.candidate_attrs[self.color_attr - 1]
364            C = self.data_array[:, self.color_attr - 1]
365            if color_attr.varType == Discrete:
366                palette = OWColorPalette.ColorPaletteHSV(len(color_attr.values))
367                colors = [palette[int(value)] for value in C.ravel()]
368                colors = [[c.red()/255., c.green()/255., c.blue()/255., self.alpha_value/255.] for c in colors]
369                palette_colors = [palette[i] for i in range(len(color_attr.values))]
370                color_legend_items = [[Symbol.TRIANGLE, [c.red()/255., c.green()/255., c.blue()/255., 1], 1, title]
371                    for c, title in zip(palette_colors, color_attr.values)]
372            else:
373                palette = OWColorPalette.ColorPaletteBW()
374                maxC, minC = numpy.max(C), numpy.min(C)
375                C = (C - minC) / (maxC - minC)
376                colors = [palette[value] for value in C.ravel()]
377                colors = [[c.red()/255., c.green()/255., c.blue()/255., self.alpha_value/255.] for c in colors]
378        else:
379            colors = 'b'
380
381        if self.size_attr > 0:
382            size_attr = self.candidate_attrs[self.size_attr - 1]
383            S = self.data_array[:, self.size_attr - 1]
384            if size_attr.varType == Discrete:
385                sizes = [(v + 1) * len(size_attr.values) / (11 - self.symbol_scale) for v in S]
386            else:
387                min, max = numpy.min(S), numpy.max(S)
388                sizes = [(v - min) * self.symbol_scale / (max-min) for v in S]
389        else:
390            sizes = 1
391
392        shapes = None
393        if self.shape_attr > 0:
394            i, shape_attr = self.discrete_attrs[self.shape_attr]
395            if shape_attr.varType == Discrete:
396                # Map discrete attribute to [0...num shapes-1]
397                shapes = self.data_array[:, i]
398                num_shapes = 0
399                unique_shapes = {}
400                for shape in shapes:
401                    if shape not in unique_shapes:
402                        unique_shapes[shape] = num_shapes
403                        num_shapes += 1
404                shapes = [unique_shapes[value] for value in shapes]
405
406        labels = None
407        if self.label_attr > 0:
408            label_attr = self.candidate_attrs[self.label_attr - 1]
409            labels = self.data_array[:, self.label_attr - 1]
410            if label_attr.varType == Discrete:
411                value_map = {key: label_attr.values[key] for key in range(len(label_attr.values))}
412                labels = [value_map[value] for value in labels]
413
414        self.plot.clear()
415
416        num_symbols = len(Symbol)
417        if self.shape_attr > 0:
418            _, shape_attr = self.discrete_attrs[self.shape_attr]
419            titles = list(shape_attr.values)
420            for i, title in enumerate(titles):
421                if i == num_symbols-1:
422                    title = ', '.join(titles[i:])
423                self.plot.legend.add_item(i, (0,0,0,1), 1, '{0}={1}'.format(shape_attr.name, title))
424                if i == num_symbols-1:
425                    break
426
427        if color_legend_items:
428            for item in color_legend_items:
429                self.plot.legend.add_item(*item)
430
431        self.plot.scatter(X, Y, Z, colors, sizes, shapes, labels)
432        self.plot.set_x_axis_title(self.candidate_attrs[self.x_attr].name)
433        self.plot.set_y_axis_title(self.candidate_attrs[self.y_attr].name)
434        self.plot.set_z_axis_title(self.candidate_attrs[self.z_attr].name)
435
436        def create_discrete_map(attr_index):
437            values = self.candidate_attrs[attr_index].values
438            return {key: value for key, value in enumerate(values)}
439
440        if self.candidate_attrs[self.x_attr].varType == Discrete:
441            self.plot.set_x_axis_map(create_discrete_map(self.x_attr))
442        if self.candidate_attrs[self.y_attr].varType == Discrete:
443            self.plot.set_y_axis_map(create_discrete_map(self.y_attr))
444        if self.candidate_attrs[self.z_attr].varType == Discrete:
445            self.plot.set_z_axis_map(create_discrete_map(self.z_attr))
446
447    def get_axis_data(self, x_index, y_index, z_index):
448        array = self.data_array
449        X, Y, Z = array[:, x_index], array[:, y_index], array[:, z_index]
450
451        if self.jitter_size > 0:
452            X, Y, Z = map(numpy.copy, [X, Y, Z])
453            x_range = numpy.max(X)-numpy.min(X)
454            y_range = numpy.max(Y)-numpy.min(Y)
455            z_range = numpy.max(Z)-numpy.min(Z)
456            if self.x_attr_discrete or self.jitter_continuous:
457                X += (numpy.random.random(len(X))-0.5) * (self.jitter_size * x_range / 100.)
458            if self.y_attr_discrete or self.jitter_continuous:
459                Y += (numpy.random.random(len(Y))-0.5) * (self.jitter_size * y_range / 100.)
460            if self.z_attr_discrete or self.jitter_continuous:
461                Z += (numpy.random.random(len(Z))-0.5) * (self.jitter_size * z_range / 100.)
462        return X, Y, Z, None
463
464if __name__ == "__main__":
465    app = QApplication(sys.argv)
466    w = OWScatterPlot3D()
467    data = orange.ExampleTable("../../doc/datasets/iris")
468    w.set_data(data)
469    w.handleNewSignals()
470    w.show()
471    app.exec_()
Note: See TracBrowser for help on using the repository browser.