source: orange-bioinformatics/Orange/bioinformatics/widgets/prototypes/OWDifferentiationScale.py @ 1625:cefeb35cbfc9

Revision 1625:cefeb35cbfc9, 20.7 KB checked in by mitar, 2 years ago (diff)

Moving files around.

Line 
1"""<name>Differentiation Scale</name>
2<description></description>
3"""
4
5import os, sys
6import numpy
7import random
8
9import obiDifscale
10import Orange
11from operator import itemgetter, add
12from collections import defaultdict
13
14from OWWidget import *
15import OWGUI
16
17class OWDifferentiationScale(OWWidget):
18    def __init__(self, parent=None, signalManager=None, title="Differentiation Scale"):
19        OWWidget.__init__(self, parent, signalManager, title, wantGraph=True)
20       
21        self.inputs = [("Gene Expression Samples", Orange.data.Table, self.set_data), ("Additional Expression Samples", Orange.data.Table, self.set_additional_data)]
22        self.outputs = [("Selected Time Points", Orange.data.Table), ("Additional Selected Time Points", Orange.data.Table)]
23       
24        self.selected_time_label = 0
25        self.auto_commit = 0
26       
27        self.loadSettings()
28       
29        self.selection_changed_flag = False
30       
31        #####
32        # GUI
33        #####
34        box = OWGUI.widgetBox(self.controlArea, "Info")
35        self.info_label = OWGUI.widgetLabel(box, "No data on input")
36        self.info_label.setWordWrap(True)
37        self.info_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
38       
39        OWGUI.rubber(self.controlArea)
40       
41        box = OWGUI.widgetBox(self.controlArea, "Selection")
42       
43        cb = OWGUI.checkBox(box, self, "auto_commit", "Commit on any change",
44                            tooltip="Send updated selections automatically",
45                            callback=self.commit_if)
46       
47        b = OWGUI.button(box, self, "Commit",
48                         callback=self.commit,
49                         tooltip="Send selections on output signals")
50       
51        OWGUI.setStopper(self, b, cb, "selection_changed_flag",
52                         callback=self.commit)
53       
54        self.connect(self.graphButton, SIGNAL("pressed()"), self.save_graph)
55       
56        self.scene = QGraphicsScene()
57        self.scene_view = DiffScaleView(self.scene, self.mainArea)
58        self.scene_view.setRenderHint(QPainter.Antialiasing)
59        self.scene_view.setMinimumWidth(300)
60        self.mainArea.layout().addWidget(self.scene_view)
61        self.connect(self.scene, SIGNAL("selectionChanged()"), self.on_selection_changed)
62        self.connect(self.scene_view, SIGNAL("view_resized(QSize)"), lambda size: self.on_view_resized())
63       
64        self.data = None
65        self.additional_data = None
66        self.projections1 = []
67        self.projections2 = []
68        self.labels1 = []
69        self.labels2 = []
70       
71        self.selected_time_samples = [], []
72       
73        self.controlArea.setMaximumWidth(300)
74        self.resize(600, 480)
75       
76    def clear(self):
77        """ Clear the widget state
78        """
79        self.projections1 = []
80        self.projections2 = []
81        self.labels1 = []
82        self.labels2 = []
83        self.clear_selection()
84        self.scene.clear()
85       
86    def clear_selection(self):
87        """ Clear the time point selection.
88        """
89        self.selected_time_samples = [], []
90       
91    def set_data(self, data = None):
92        """ Set the data for the widget.
93        """
94        self.clear()
95        self.data = data
96       
97    def set_additional_data(self, data=None):
98        """ Set an additional data set.
99        """
100        self.clear()
101        self.additional_data = data
102       
103    def handleNewSignals(self):
104        if self.data is not None:
105            self.run_projections()
106            self.projection_layout()
107            self.update_graph()
108           
109            info_text = """\
110Data with {0} genes
111and {1} samples on input.\n""".format(len(self.data),
112                 len(self.data.domain.attributes))
113            if self.additional_data is not None:
114                info_text += """\
115Additional data with {0} genes
116and  {1} samples on input.""".format(len(self.additional_data),
117                                                    len(self.additional_data.domain.attributes))
118            self.info_label.setText(info_text)
119        else:
120            self.send("Selected Time Points", None)
121            self.send("Additional Selected Time Points", None)
122            self.info_label.setText("No data on input\n")
123           
124    def run_projections(self):
125        """ Run obiDifscale.get_projections with the current inputs.
126        """
127        self.error()
128#        try:
129#            attr_set = list(set(a.attributes['time'] for a in data.domain.attributes))
130#            self.time_points = obiDifscale.conv(attr_set, ticks=False)
131#        except KeyError, ex:
132#            self.error("Could not extract time data")
133#            self.clear()
134#            return
135       
136        try:
137            (self.projections1, self.labels1,
138             self.projections2, self.labels2) = \
139                obiDifscale.get_projections(self.data, data2=self.additional_data)
140        except Exception, ex:
141            self.error("Failed to obtain the projections due to: %r" % ex)
142            self.clear()
143            return
144       
145    def projection_layout(self):
146        """ Compute the layout for the projections.
147        """
148        if self.projections1: 
149            projections = self.projections1 + self.projections2
150            projections = numpy.array(projections)
151           
152            x_min = numpy.min(projections)
153            x_max = numpy.max(projections)
154           
155            # Scale projections
156            projections = (projections - x_min) / ((x_max - x_min) or 1.0)
157            projections = list(projections)
158           
159            labels = self.labels1 + self.labels2
160           
161            samples = [(attr, self.data) for attr in self.data.domain.attributes] + \
162                      ([(attr, self.additional_data) for attr in self.additional_data.domain.attributes] \
163                       if self.additional_data is not None else [])
164           
165            # TODO: handle samples with the same projection
166            # the point_layout should return the proj to sample mapping instead
167            proj_to_sample = dict([((label, proj), sample) for label, proj, sample \
168                                   in zip(labels, projections, samples)])
169            self.proj_to_sample = proj_to_sample
170           
171            time_points = point_layout(labels, projections)
172            self.time_points = time_points
173            level_height = 20
174            all_points = numpy.array(reduce(add, [p for _, p in time_points], []))
175            self.all_points = all_points
176           
177#            all_points[:, 1] *= -level_height
178            self.time_samples = [] # samples for time label (same order as in self.time_points)
179           
180            point_i = 0
181            for label, points, in time_points:
182                samples = [] 
183                for x, y in points:
184                    samples.append(proj_to_sample.get((label, x), None))
185                self.time_samples.append((label, samples))
186           
187    def update_graph(self):
188        """ Populate the Graphics Scene with the current projections.
189        """
190        scene_size_hint = self.scene_view.viewport().size()
191        scene_size_hint = QSizeF(max(scene_size_hint.width() - 50, 100),
192                                 scene_size_hint.height())
193        self.scene.clear()
194       
195        if self.projections1:
196            level_height = 20
197            all_points = self.all_points.copy()
198            all_points[:, 0] *= scene_size_hint.width()
199            all_points[:, 1] *= -level_height
200           
201            point_i = 0
202            centers = []
203            z_value = 0
204            for label, samples in self.time_samples:
205                # Points
206                p1 = all_points[point_i]
207                points = all_points[point_i: point_i + len(samples), :]
208                for (x, y), sample in zip(points, samples):
209                    item = GraphicsTimePoint(QRectF(QPointF(x-3, y-3), QSizeF(6, 6)))
210                    item.setBrush(QBrush(Qt.black))
211                    item.sample = sample
212                    item.setToolTip(sample[0].name if sample else "")
213                    item.setZValue(z_value)
214                    self.scene.addItem(item)
215                    point_i += 1
216                p2 = all_points[point_i - 1]
217               
218                # Line over all points
219                line = QGraphicsLineItem(QLineF(*(tuple(p1) + tuple(p2))))
220                line.setPen(QPen(Qt.black, 2))
221                line.setZValue(z_value - 1)
222                self.scene.addItem(line)
223               
224                # Time label on top of the median
225                n_points = len(points)
226                if n_points % 2:
227                    center = points[n_points / 2]
228                else:
229                    center = (points[n_points / 2] + points[n_points / 2 + 1]) / 2.0
230                centers.append(center)
231                x, y = center
232                text = QGraphicsSimpleTextItem(label)
233                w = text.boundingRect().width()
234                text.setPos(x - w / 2.0, y - 17.5)
235                self.scene.addItem(text)
236           
237            self.scene.addLine(QLineF(0.0, 0.0, scene_size_hint.width(), 0.0))
238           
239            polygon = QPolygonF([QPointF(3.0, 0.0),
240                                 QPointF(-2.0, -2.0),
241                                 QPointF(0.0, 0.0),
242                                 QPointF(-2.0, 2.0),
243                                 QPointF(3.0, 0.0)])
244           
245            arrow = QGraphicsPolygonItem(polygon)
246            arrow.setBrush(QBrush(Qt.black))
247            arrow.setPos(scene_size_hint.width(), 0.0)
248            arrow.scale(2, 2)
249            self.scene.addItem(arrow)
250           
251            title = QGraphicsSimpleTextItem("Development (time)")
252            font = self.font()
253            font.setPointSize(10)
254            title.setFont(font)
255            w = title.boundingRect().width()
256            title.setPos(scene_size_hint.width() - w, -15)
257            self.scene.addItem(title)
258           
259            rects = []
260            ticks = []
261            axis_label_items = []
262            labels = [(center, label) for center, (label, _) in zip(centers, self.time_samples)]
263            labels = sorted(labels, key=lambda (c, l): c[0])
264            for center, label in labels:
265                x, y = center
266                item = QGraphicsSimpleTextItem(label)
267                w = item.boundingRect().width()
268                item.setPos(x - w / 2.0, 4.0)
269                rects.append(item.sceneBoundingRect().normalized())
270                ticks.append(QPointF(x - w / 2.0, 4.0))
271                axis_label_items.append(item)
272           
273#            rects = SA_axis_label_layout(ticks, rects, max_time=0.5,
274#                                         x_factor=scene_size_hint.width() / 50.0,
275#                                         y_factor=10,
276#                                         random=random.Random(0))
277
278            rects = greedy_scale_label_layout(ticks, rects, spacing=5)
279           
280            for (tick, label), rect, item in zip(labels, rects, axis_label_items):
281                x, y = tick
282                self.scene.addLine(x, -2, x, 2)
283                if rect.top() - item.pos().y() > 5:
284                    self.scene.addLine(x, 2, rect.center().x(), 14.0)
285                if rect.top() - item.pos().y() > 15:
286                    self.scene.addLine(rect.center().x(), 14.0, rect.center().x(), rect.top())
287#                item.setPos(rect.topLeft())
288               
289#                text = QGraphicsSimpleTextItem(label)
290#            for tick, rect, item in zip(ticks, rects, axis_label_items):
291                item.setPos(rect.topLeft())
292                self.scene.addItem(item)
293#                w = text.boundingRect().width()
294#                text.setPos(x - w / 2.0, 4)
295                # Need to compute axis label layout.
296#                self.scene.addItem(text)
297
298            self.scene.setSceneRect(self.scene.itemsBoundingRect().adjusted(-10, -10, 10, 10))
299
300    def on_view_resized(self):
301        self.update_graph()
302       
303    def on_selection_changed(self):
304        try:
305            selected = self.scene.selectedItems()
306        except RuntimeError:
307            return
308       
309        selected_attrs1 = []
310        selected_attrs2  =[]
311        for point in selected:
312            attr, data = point.sample if point.sample else (None, None)
313            if data is self.data:
314                selected_attrs1.append(attr)
315            elif data is self.additional_data:
316                selected_attrs2.append(attr)
317               
318        self.selected_time_samples = selected_attrs1, selected_attrs2
319        print self.selected_time_samples
320        self.commit_if()
321           
322    def commit_if(self):
323        if self.auto_commit:
324            self.commit()
325        else:
326            self.selection_changed_flag = True
327   
328    def commit(self):
329        if self.data is not None:
330            selected1, selected2 = self.selected_time_samples
331            attrs1 = [a for a in self.data.domain.attributes \
332                      if a in selected1]
333            domain = Orange.data.Domain(attrs1, self.data.domain.class_var)
334            domain.add_metas(self.data.domain.get_metas())
335            data = Orange.data.Table(domain, self.data)
336            self.send("Selected Time Points", data)
337           
338            if self.additional_data is not None:
339                attrs2 = [a for a in self.additional_data.domain.attributes \
340                          if a in selected2]
341                domain = Orange.data.Domain(attrs2, self.additional_data.domain.class_var)
342                domain.add_metas(self.additional_data.domain.get_metas())
343                data = Orange.data.Table(domain, self.additional_data)
344                self.send("Additional Selected Time Points", data)
345        else:
346            self.send("Selected Time Points", None)
347            self.send("Additional Selected Time Points", None)
348        self.selection_changed_flag = False
349       
350    def save_graph(self):
351        from OWDlgs import OWChooseImageSizeDlg
352        dlg = OWChooseImageSizeDlg(self.scene, parent=self)
353        dlg.exec_()
354   
355   
356class GraphicsTimePoint(QGraphicsEllipseItem):
357    def __init__(self, *args):
358        QGraphicsEllipseItem.__init__(self, *args)
359        self.setFlags(QGraphicsItem.ItemIsSelectable)
360        self.setAcceptsHoverEvents(True)
361        self._is_hovering = False
362       
363    def paint(self, painter, option, widget=0):
364        if self.isSelected():
365            brush = QBrush(Qt.red)
366            pen = QPen(Qt.red, 1)
367        else:
368            brush = QBrush(Qt.darkGray)
369            pen = QPen(Qt.black, 1)
370        if self._is_hovering:
371            brush = QBrush(brush.color().darker(200))
372        painter.save()
373        painter.setBrush(brush)
374        painter.setPen(pen)
375        painter.drawEllipse(self.rect())
376        painter.restore()
377       
378    def hoverEnterEvent(self, event):
379        self._is_hovering = True
380        self.update()
381        return QGraphicsEllipseItem.hoverEnterEvent(self, event)
382   
383    def hoverLeaveEvent(self, event):
384        self._is_hovering = False
385        self.update()
386        return QGraphicsEllipseItem.hoverLeaveEvent(self, event)
387       
388   
389class DiffScaleView(QGraphicsView):
390    def resizeEvent(self, event):
391        QGraphicsView.resizeEvent(self, event)
392        self.emit(SIGNAL("view_resized(QSize)"), event.size())
393       
394
395def point_layout(labels, points, label_size_hints=None):
396    groups = defaultdict(list)
397    for label, point in zip(labels, points):
398        groups[label].append(point)
399       
400    for label, points in list(groups.items()):
401        points = sorted(points)
402        # TODO: Use label_size_hints for min, max
403        groups[label] = (points, (points[0], points[-1]))
404   
405    sorted_groups = sorted(groups.items(), key=itemgetter(1), reverse=True)
406    levels = {}
407    curr_level = 1
408    label_levels = {}
409    while sorted_groups:
410        label, (points, (x_min, x_max)) = sorted_groups.pop(-1)
411        max_level_pos = levels.get(curr_level, x_min)
412        if x_min < max_level_pos:
413            curr_level += 1
414            sorted_groups.append((label, (points, (x_min, x_max))))
415        else:
416            label_levels[label] = curr_level
417            levels[curr_level] = x_max
418            curr_level = 1
419           
420    for label, (points, _) in list(groups.items()):
421        level = float(label_levels[label])
422        groups[label] = [(x, level) for x in points]
423       
424    return list(groups.items())
425
426   
427def greedy_scale_label_layout(ticks, rects, spacing=3):
428    """ Layout the labels at ticks on a linear scale, by raising the
429    overlapping labels.
430   
431    """
432    def adjust_interval(start, end, min_v, max_v):
433        """ Adjust (start, end) interval to fit inside the (min_v, max_v).
434        """
435        if start < min_v:
436            return (min_v, min_v + (end - start))
437        elif max_v > end:
438            return (max_v - (end - start), max_v)
439        else:
440            return (start, end)
441       
442    def center_interval(start, end, center):
443        """ Center the interval on `center`
444        """
445        span = end - start
446        return centered(center, span)
447   
448    def centered(center, span):
449        """ Return an centered interval with span.
450        """
451        return (center - span / 2.0, center + span / 2.0)
452   
453    def contains((start, end), (start1, end1)):
454        return start <= start1  and end >= end1
455   
456    def fit(work, ticks, min_x, max_x):
457        """ Fit the work set between min_x and max_x  and centered on the
458        ticks, if possible.
459        """
460        fits = False
461        work_set = map(QRectF, work)
462        tick_center = sum([r.center().x() for r in work_set]) / len(work_set)
463        if len(work_set) == 1:
464            if work_set[0].left() >= min_x and work_set[0].right() <= max_x:
465                return work_set
466            else:
467                return []
468       
469        elif len(work_set) == 2: # TODO: MErge this with the > 2
470            w_sum = sum([r.width() for r in work_set]) + spacing
471            if w_sum < max_x - min_x:
472                r1, r2 = work_set
473                interval = centered(tick_center, w_sum)
474               
475                if not contains((min_x, max_x), interval):
476                    interval = adjust_interval(*(interval + (min_x, max_x)))
477                   
478                if contains((min_x, max_x), interval):
479                    r1.moveLeft(interval[0])
480                    r2.moveLeft(interval[1] - r2.width())
481                    r1.moveTop(r1.top() + 10)
482                    r2.moveTop(r2.top() + 10)
483                    return work_set
484                else:
485                    return []
486            else:
487                return []
488       
489        elif len(work_set) > 2:
490            center = (work_set[0].center().x() + work_set[-1].center().x()) / 2.0
491            w_sum = work_set[0].width() / 2.0 + work_set[-1].width() / 2.0 + spacing
492            for i, r in enumerate(work_set[1:-1]):
493                w_sum += r.width() + spacing
494            interval = centered(center, w_sum)
495           
496            if not contains((min_x, max_x), interval):
497                interval = adjust_interval(*(interval + (min_x, max_x)))
498               
499            if contains((min_x, max_x), interval):
500                istart, iend = interval
501                rstart, rend = work_set[0], work_set[-1]
502                rstart.moveLeft(istart)
503                rstart.moveTop(rstart.top() + 10)
504                rend.moveLeft(iend - rend.width())
505                rend.moveTop(rend.top() + 10)
506                istart += rstart.width() / 2.0
507                iend -= rend.width() / 2.0
508                for r in work_set[1: -1]:
509                    r.moveLeft(istart)
510                    r.moveTop(r.top() + 20)
511                    istart += r.width() + spacing
512                return work_set
513            else:
514                return []
515           
516    queue = sorted(zip(ticks, rects),
517                   key=lambda (t, _): t.x(),
518                   reverse=True)
519    done = False
520    rects = []
521   
522    min_x = -1e30
523    max_x = 1e30
524   
525    while queue:
526        work_set = [queue.pop(-1)]
527        set_fits = False
528        max_x = queue[-1][1].left() if queue else 1e30
529        while not set_fits:
530            new_rects = fit(map(itemgetter(1), work_set),
531                            map(itemgetter(0), work_set),
532                            min_x, max_x)
533            if new_rects: # Can the work set be fit.
534                set_fits = True
535                rects.extend(new_rects)
536                min_x = work_set[-1][1].right()
537               
538            else:
539                # Extend the work set with one more label rect
540                work_set.append(queue.pop(-1))
541                max_x = queue[-1][1].left() if queue else 1e30
542    return rects
543       
544   
545if __name__ == "__main__":
546    app = QApplication(sys.argv)
547    w = OWDifferentiationScale()
548    data = Orange.data.Table(os.path.expanduser("~/Documents/GDS2666n"))
549    w.show()
550    w.set_data(data)
551    w.handleNewSignals()
552    app.exec_()
553    w.saveSettings()
Note: See TracBrowser for help on using the repository browser.