source: orange/orange/Orange/clustering/hierarchical.py @ 7732:7b9bbc3e1a44

Revision 7732:7b9bbc3e1a44, 26.0 KB checked in by ales_erjavec <ales.erjavec@…>, 3 years ago (diff)

Fixed order_leaves calls in cluster(_feature) functions.
Some minor speed improvements in order_leaves (dict lookups tuple de/allocations).

Line 
1"""
2***********************
3Hierarchical clustering
4***********************
5
6.. index::
7   single: clustering, kmeans
8.. index:: aglomerative clustering
9
10Examples
11========
12
13An example.
14
15.. automethod:: Orange.clustering.hierarchical.clustering
16
17"""
18import orange
19from Orange.core import HierarchicalClustering, \
20                        HierarchicalCluster, \
21                        HierarchicalClusterList
22
23
24def clustering(data,
25               distanceConstructor=orange.ExamplesDistanceConstructor_Euclidean,
26               linkage=orange.HierarchicalClustering.Average,
27               order=False,
28               progressCallback=None):
29    """Return a hierarhical clustering of the data set."""
30    distance = distanceConstructor(data)
31    matrix = orange.SymMatrix(len(data))
32    for i in range(len(data)):
33        for j in range(i+1):
34            matrix[i, j] = distance(data[i], data[j])
35    root = orange.HierarchicalClustering(matrix, linkage=linkage, progressCallback=(lambda value, obj=None: progressCallback(value*100.0/(2 if order else 1))) if progressCallback else None)
36    if order:
37        order_leaves(root, matrix, progressCallback=(lambda value: progressCallback(50.0 + value/2)) if progressCallback else None)
38    return root
39
40def clustering_features(data, distance=None, linkage=orange.HierarchicalClustering.Average, order=False, progressCallback=None):
41    """Return hierarhical clustering of attributes in the data set."""
42    matrix = orange.SymMatrix(len(data.domain.attributes))
43    for a1 in range(len(data.domain.attributes)):
44        for a2 in range(a1):
45            matrix[a1, a2] = (1.0 - orange.PearsonCorrelation(a1, a2, data, 0).r) / 2.0
46    root = orange.HierarchicalClustering(matrix, linkage=linkage, progressCallback=(lambda value, obj=None: progressCallback(value*100.0/(2 if order else 1))) if progressCallback else None)
47    if order:
48        order_leaves(root, matrix, progressCallback=(lambda value: progressCallback(50.0 + value/2)) if progressCallback else None)
49    return root
50
51def cluster_to_list(node, prune=None):
52    """Return a list of clusters down from the node of hierarchical clustering."""
53    if prune:
54        if len(node) <= prune:
55            return [] 
56    if node.branches:
57        return [node] + cluster_to_list(node.left, prune) + cluster_to_list(node.right, prune)
58    return [node]
59
60def top_clusters(root, k):
61    """Return k topmost clusters from hierarchical clustering."""
62    candidates = set([root])
63    while len(candidates) < k:
64        repl = max([(max(c.left.height, c.right.height), c) for c in candidates if c.branches])[1]
65        candidates.discard(repl)
66        candidates.add(repl.left)
67        candidates.add(repl.right)
68    return candidates
69
70def top_cluster_membership(root, k):
71    """Return data instances' cluster membership (list of indices) to k topmost clusters."""
72    clist = top_clusters(root, k)
73    cmap = [None] * len(root)
74    for i, c in enumerate(clist):
75        for e in c:
76            cmap[e] = i
77    return cmap
78
79def order_leaves(tree, matrix, progressCallback=None):
80    """Order the leaves in the clustering tree.
81
82    (based on Ziv Bar-Joseph et al. (Fast optimal leaf ordering for hierarchical clustering')
83    Arguments:
84        tree   --binary hierarchical clustering tree of type orange.HierarchicalCluster
85        matrix --orange.SymMatrix that was used to compute the clustering
86        progressCallback --function used to report progress
87    """
88    objects = getattr(tree.mapping, "objects", None)
89    tree.mapping.setattr("objects", range(len(tree)))
90    M = {}
91    ordering = {}
92    visitedClusters = set()
93    def _optOrdering(tree):
94        if len(tree)==1:
95            for leaf in tree:
96                M[tree, leaf, leaf] = 0
97##                print "adding:", tree, leaf, leaf
98        else:
99            _optOrdering(tree.left)
100            _optOrdering(tree.right)
101##            print "ordering", [i for i in tree]
102            Vl = set(tree.left)
103            Vr = set(tree.right)
104            Vlr = set(tree.left.right or tree.left)
105            Vll = set(tree.left.left or tree.left)
106            Vrr = set(tree.right.right or tree.right)
107            Vrl = set(tree.right.left or tree.right)
108            other = lambda e, V1, V2: V2 if e in V1 else V1
109            tree_left, tree_right = tree.left, tree.right
110            for u in Vl:
111                for w in Vr:
112                    if True: #Improved search
113                        C = min([matrix[m, k] for m in other(u, Vll, Vlr) for k in other(w, Vrl, Vrr)])
114                        orderedMs = sorted(other(u, Vll, Vlr), key=lambda m: M[tree_left, u, m])
115                        orderedKs = sorted(other(w, Vrl, Vrr), key=lambda k: M[tree_right, w, k])
116                        k0 = orderedKs[0]
117                        curMin = 1e30000 
118                        curM = curK = None
119                        for m in orderedMs:
120                            if M[tree_left, u, m] + M[tree_right, w, k0] + C >= curMin:
121                                break
122                            for k in  orderedKs:
123                                if M[tree_left, u, m] + M[tree_right, w, k] + C >= curMin:
124                                    break
125                                testMin = M[tree_left, u, m] + M[tree_right, w, k] + matrix[m, k]
126                                if curMin > testMin:
127                                    curMin = testMin
128                                    curM = m
129                                    curK = k
130                        M[tree, u, w] = M[tree, w, u] = curMin
131                        ordering[tree, u, w] = (tree_left, u, curM, tree_right, w, curK)
132                        ordering[tree, w, u] = (tree_right, w, curK, tree_left, u, curM)
133                    else:
134                        def MFunc((m, k)):
135                            return M[tree_left, u, m] + M[tree_right, w, k] + matrix[m, k]
136                        m, k = min([(m, k) for m in other(u, Vll, Vlr) for k in other(w, Vrl, Vrr)], key=MFunc)
137                        M[tree, u, w] = M[tree, w, u] = MFunc((m, k))
138                        ordering[tree, u, w] = (tree_left, u, m, tree_right, w, k)
139                        ordering[tree, w, u] = (tree_right, w, k, tree_left, u, m)
140
141            if progressCallback:
142                progressCallback(100.0 * len(visitedClusters) / len(tree.mapping))
143                visitedClusters.add(tree)
144       
145    _optOrdering(tree)
146
147    def _order(tree, u, w):
148        if len(tree)==1:
149            return
150        left, u, m, right, w, k = ordering[tree, u, w]
151        if len(left)>1 and m not in left.right:
152            left.swap()
153        _order(left, u, m)
154##        if u!=left[0] or m!=left[-1]:
155##            print "error 4:", u, m, list(left)
156        if len(right)>1 and k not in right.left:
157            right.swap()
158        _order(right, k, w)
159##        if k!=right[0] or w!=right[-1]:
160##            print "error 5:", k, w, list(right)
161   
162    u, w = min([(u, w) for u in tree.left for w in tree.right], key=lambda (u, w): M[tree, u, w])
163   
164##    print "M(v) =", M[tree, u, w]
165   
166    _order(tree, u, w)
167
168    def _check(tree, u, w):
169        if len(tree)==1:
170            return
171        left, u, m, right, w, k = ordering[tree, u, w]
172        if tree[0] == u and tree[-1] == w:
173            _check(left, u, m)
174            _check(right, k, w)
175        else:
176            print "Error:", u, w, tree[0], tree[-1]
177
178    _check(tree, u ,w)
179   
180
181    if objects:
182        tree.mapping.setattr("objects", objects)
183
184try:
185    import numpy
186except ImportError:
187    numpy = None
188
189try:
190    import matplotlib
191    from matplotlib.figure import Figure
192    from matplotlib.table import Table, Cell
193    from matplotlib.text import Text
194    from matplotlib.artist import Artist
195##    import  matplotlib.pyplot as plt
196except (ImportError, IOError), ex:
197    matplotlib = None
198    Text , Artist, Table, Cell = object, object, object, object
199
200class TableCell(Cell):
201    PAD = 0.05
202    def __init__(self, *args, **kwargs):
203        Cell.__init__(self, *args, **kwargs)
204        self._text.set_clip_on(True)
205
206class TablePlot(Table):
207    max_fontsize = 12
208    def __init__(self, xy, axes=None, bbox=None):
209        Table.__init__(self, axes or plt.gca(), bbox=bbox)
210        self.xy = xy
211        self.set_transform(self._axes.transData)
212        self._fixed_widhts = None
213        import matplotlib.pyplot as plt
214        self.max_fontsize = plt.rcParams.get("font.size", 12)
215
216    def add_cell(self, row, col, *args, **kwargs):
217        xy = (0,0)
218
219        cell = TableCell(xy, *args, **kwargs)
220        cell.set_figure(self.figure)
221        cell.set_transform(self.get_transform())
222
223        cell.set_clip_on(True)
224        cell.set_clip_box(self._axes.bbox)
225        cell._text.set_clip_box(self._axes.bbox)
226        self._cells[(row, col)] = cell
227
228    def draw(self, renderer):
229        if not self.get_visible(): return
230        self._update_positions(renderer)
231
232        keys = self._cells.keys()
233        keys.sort()
234        for key in keys:
235            self._cells[key].draw(renderer)
236
237    def _update_positions(self, renderer):
238        keys = numpy.array(self._cells.keys())
239        cells = numpy.array([[self._cells.get((row, col), None) for col in range(max(keys[:, 1] + 1))] \
240                             for row in range(max(keys[:, 0] + 1))])
241       
242        widths = self._get_column_widths(renderer)
243        x = self.xy[0] + numpy.array([numpy.sum(widths[:i]) for i in range(len(widths))])
244        y = self.xy[1] - numpy.arange(cells.shape[0]) - 0.5
245       
246        for i in range(cells.shape[0]):
247            for j in range(cells.shape[1]):
248                cells[i, j].set_xy((x[j], y[i]))
249                cells[i, j].set_width(widths[j])
250                cells[i, j].set_height(1.0)
251
252        self._width = numpy.sum(widths)
253        self._height = cells.shape[0]
254
255        self.pchanged()
256
257    def _get_column_widths(self, renderer):
258        keys = numpy.array(self._cells.keys())
259        widths = numpy.zeros(len(keys)).reshape((numpy.max(keys[:,0]+1), numpy.max(keys[:,1]+1)))
260        fontSize = self._calc_fontsize(renderer)
261        for (row, col), cell in self._cells.items():
262            cell.set_fontsize(fontSize)
263            l, b, w, h = cell._text.get_window_extent(renderer).bounds
264            transform = self._axes.transData.inverted()
265            x1, _ = transform.transform_point((0, 0))
266            x2, _ = transform.transform_point((w + w*TableCell.PAD + 10, 0))
267            w = abs(x1 - x2)
268            widths[row, col] = w
269        return numpy.max(widths, 0)
270
271    def _calc_fontsize(self, renderer):
272        transform = self._axes.transData
273        _, y1 = transform.transform_point((0, 0))
274        _, y2 = transform.transform_point((0, 1))
275        return min(max(int(abs(y1 - y2)*0.85) ,4), self.max_fontsize)
276
277    def get_children(self):
278        return self._cells.values()
279
280    def get_bbox(self):
281        return matplotlib.transform.Bbox([self.xy[0], self.xy[1], self.xy[0] + 10, self.xy[1] + 180])
282
283class DendrogramPlotPylab(object):
284    def __init__(self, root, data=None, labels=None, dendrogram_width=None, heatmap_width=None, label_width=None, space_width=None, border_width=0.05, plot_attr_names=False, cmap=None, params={}):
285        if not matplotlib:
286            raise ImportError("Could not import matplotlib module. Please make sure matplotlib is installed on your system.")
287        import matplotlib.pyplot as plt
288        self.plt = plt
289        self.root = root
290        self.data = data
291        self.labels = labels if labels else [str(i) for i in range(len(root))]
292        self.dendrogram_width = dendrogram_width
293        self.heatmap_width = heatmap_width
294        self.label_width = label_width
295        self.space_width = space_width
296        self.border_width = border_width
297        self.params = params
298        self.plot_attr_names = plot_attr_names
299
300    def plotDendrogram(self):
301        self.text_items = []
302        def draw_tree(tree):
303            if tree.branches:
304                points = []
305                for branch in tree.branches:
306                    center = draw_tree(branch)
307                    self.plt.plot([center[0], tree.height], [center[1], center[1]], color="black")
308                    points.append(center)
309                self.plt.plot([tree.height, tree.height], [points[0][1], points[-1][1]], color="black")
310                return (tree.height, (points[0][1] + points[-1][1])/2.0)
311            else:
312                return (0.0, tree.first)
313        draw_tree(self.root)
314       
315    def plotHeatMap(self):
316        import numpy.ma as ma
317        import numpy
318        dx, dy = self.root.height, 0
319        fx, fy = self.root.height/len(self.data.domain.attributes), 1.0
320        data, c, w = self.data.toNumpyMA()
321        data = (data - ma.min(data))/(ma.max(data) - ma.min(data))
322        x = numpy.arange(data.shape[1] + 1)/float(numpy.max(data.shape))
323        y = numpy.arange(data.shape[0] + 1)/float(numpy.max(data.shape))*len(self.root)
324        self.heatmap_width = numpy.max(x)
325
326        X, Y = numpy.meshgrid(x, y - 0.5)
327
328        self.meshXOffset = numpy.max(X)
329
330        self.plt.jet()
331        mesh = self.plt.pcolormesh(X, Y, data[self.root.mapping], edgecolor="b", linewidth=2)
332
333        if self.plot_attr_names:
334            names = [attr.name for attr in self.data.domain.attributes]
335            self.plt.xticks(numpy.arange(data.shape[1] + 1)/float(numpy.max(data.shape)), names)
336        self.plt.gca().xaxis.tick_top()
337        for label in self.plt.gca().xaxis.get_ticklabels():
338            label.set_rotation(45)
339
340        for tick in self.plt.gca().xaxis.get_major_ticks():
341            tick.tick1On = False
342            tick.tick2On = False
343
344    def plotLabels_(self):
345        import numpy
346##        self.plt.yticks(numpy.arange(len(self.labels) - 1, 0, -1), self.labels)
347##        for tick in self.plt.gca().yaxis.get_major_ticks():
348##            tick.tick1On = False
349##            tick.label1On = False
350##            tick.label2On = True
351##        text = TableTextLayout(xy=(self.meshXOffset+1, len(self.root)), tableText=[[label] for label in self.labels])
352        text = TableTextLayout(xy=(self.meshXOffset*1.005, len(self.root) - 1), tableText=[[label] for label in self.labels])
353        text.set_figure(self.plt.gcf())
354        self.plt.gca().add_artist(text)
355        self.plt.gca()._set_artist_props(text)
356
357    def plotLabels(self):
358##        table = TablePlot(xy=(self.meshXOffset*1.005, len(self.root) -1), axes=self.plt.gca())
359        table = TablePlot(xy=(0, len(self.root) -1), axes=self.plt.gca())
360        table.set_figure(self.plt.gcf())
361        for i,label in enumerate(self.labels):
362            table.add_cell(i, 0, width=1, height=1, text=label, loc="left", edgecolor="w")
363        table.set_zorder(0)
364        self.plt.gca().add_artist(table)
365        self.plt.gca()._set_artist_props(table)
366   
367    def plot(self, filename=None, show=False):
368        self.plt.rcParams.update(self.params)
369        labelLen = max(len(label) for label in self.labels)
370        w, h = 800, 600
371        space = 0.01 if self.space_width == None else self.space_width
372        border = self.border_width
373        width = 1.0 - 2*border
374        height = 1.0 - 2*border
375        textLineHeight = min(max(h/len(self.labels), 4), self.plt.rcParams.get("font.size", 12))
376        maxTextLineWidthEstimate = textLineHeight*labelLen
377##        print maxTextLineWidthEstimate
378        textAxisWidthRatio = 2.0*maxTextLineWidthEstimate/w
379##        print textAxisWidthRatio
380        labelsAreaRatio = min(textAxisWidthRatio, 0.4) if self.label_width == None else self.label_width
381        x, y = len(self.data.domain.attributes), len(self.data)
382
383        heatmapAreaRatio = min(1.0*y/h*x/w, 0.3) if self.heatmap_width == None else self.heatmap_width
384        dendrogramAreaRatio = 1.0 - labelsAreaRatio - heatmapAreaRatio - 2*space if self.dendrogram_width == None else self.dendrogram_width
385
386        self.fig = self.plt.figure()
387        self.labels_offset = self.root.height/20.0
388        dendrogramAxes = self.plt.axes([border, border, width*dendrogramAreaRatio, height])
389        dendrogramAxes.xaxis.grid(True)
390        import matplotlib.ticker as ticker
391
392        dendrogramAxes.yaxis.set_major_locator(ticker.NullLocator())
393        dendrogramAxes.yaxis.set_minor_locator(ticker.NullLocator())
394        dendrogramAxes.invert_xaxis()
395        self.plotDendrogram()
396        heatmapAxes = self.plt.axes([border + width*dendrogramAreaRatio + space, border, width*heatmapAreaRatio, height], sharey=dendrogramAxes)
397
398        heatmapAxes.xaxis.set_major_locator(ticker.NullLocator())
399        heatmapAxes.xaxis.set_minor_locator(ticker.NullLocator())
400        heatmapAxes.yaxis.set_major_locator(ticker.NullLocator())
401        heatmapAxes.yaxis.set_minor_locator(ticker.NullLocator())
402       
403        self.plotHeatMap()
404        labelsAxes = self.plt.axes([border + width*(dendrogramAreaRatio + heatmapAreaRatio + 2*space), border, width*labelsAreaRatio, height], sharey=dendrogramAxes)
405        self.plotLabels()
406        labelsAxes.set_axis_off()
407        labelsAxes.xaxis.set_major_locator(ticker.NullLocator())
408        labelsAxes.xaxis.set_minor_locator(ticker.NullLocator())
409        labelsAxes.yaxis.set_major_locator(ticker.NullLocator())
410        labelsAxes.yaxis.set_minor_locator(ticker.NullLocator())
411        if filename:
412            import matplotlib.backends.backend_agg
413            canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(self.fig)
414            canvas.print_figure(filename)
415        if show:
416            self.plt.show()
417       
418       
419from orngMisc import ColorPalette, EPSRenderer
420class DendrogramPlot(object):
421    """ A class for drawing dendrograms
422    Example:
423    >>> a = DendrogramPlot(tree)
424    """
425    def __init__(self, tree, attr_tree = None, labels=None, data=None, width=None, height=None, tree_height=None, heatmap_width=None, text_width=None, 
426                 spacing=2, cluster_colors={}, color_palette=ColorPalette([(255, 0, 0), (0, 255, 0)]), maxv=None, minv=None, gamma=None, renderer=EPSRenderer):
427        self.tree = tree
428        self.attr_tree = attr_tree
429        self.labels = [str(ex.getclass()) for ex in data] if not labels and data and data.domain.classVar else (labels or [])
430#        self.attr_labels = [str(attr.name) for attr in data.domain.attributes] if not attr_labels and data else attr_labels or []
431        self.data = data
432        self.width, self.height = float(width) if width else None, float(height) if height else None
433        self.tree_height = tree_height
434        self.heatmap_width = heatmap_width
435        self.text_width = text_width
436        self.font_size = 10.0
437        self.linespacing = 0.0
438        self.cluster_colors = cluster_colors
439        self.horizontal_margin = 10.0
440        self.vertical_margin = 10.0
441        self.spacing = float(spacing) if spacing else None
442        self.color_palette = color_palette
443        self.minv = minv
444        self.maxv = maxv
445        self.gamma = gamma
446        self.set_matrix_color_schema(color_palette, minv, maxv, gamma)
447        self.renderer = renderer
448       
449    def set_matrix_color_schema(self, color_palette, minv, maxv, gamma=None):
450        """ Set the matrix color scheme.
451        """
452        if isinstance(color_palette, ColorPalette):
453            self.color_palette = color_palette
454        else:
455            self.color_palette = ColorPalette(color_palette)
456        self.minv = minv
457        self.maxv = maxv
458        self.gamma = gamma
459       
460    def color_shema(self):
461        vals = [float(val) for ex in self.data for val in ex if not val.isSpecial() and val.variable.varType==orange.VarTypes.Continuous] or [0]
462        avg = sum(vals)/len(vals)
463       
464        maxVal = self.maxv if self.maxv else max(vals)
465        minVal = self.minv if self.minv else min(vals)
466       
467        def _colorSchema(val):
468            if val.isSpecial():
469                return self.color_palette(None)
470            elif val.variable.varType==orange.VarTypes.Continuous:
471                r, g, b = self.color_palette((float(val) - minVal) / abs(maxVal - minVal), gamma=self.gamma)
472            elif val.variable.varType==orange.VarTypes.Discrete:
473                r = g = b = int(255.0*float(val)/len(val.variable.values))
474            return (r, g, b)
475        return _colorSchema
476   
477    def layout(self):
478        height_final = False
479        width_final = False
480        tree_height = self.tree_height or 100
481        if self.height:
482            height, height_final = self.height, True
483            heatmap_height = height - (tree_height + self.spacing if self.attr_tree else 0) - 2 * self.horizontal_margin
484            font_size =  heatmap_height / len(self.labels) #self.font_size or (height - (tree_height + self.spacing if self.attr_tree else 0) - 2 * self.horizontal_margin) / len(self.labels)
485        else:
486            font_size = self.font_size
487            heatmap_height = font_size * len(self.labels)
488            height = heatmap_height + (tree_height + self.spacing if self.attr_tree else 0) + 2 * self.horizontal_margin
489             
490        text_width = self.text_width or max([len(label) for label in self.labels] + [0]) * font_size #max([self.renderer.string_size_hint(label) for label in self.labels])
491       
492        if self.width:
493            width = self.width
494            heatmap_width = width - 2 * self.vertical_margin - tree_height - (2 if self.data else 1) * self.spacing - text_width if self.data else 0
495        else:
496            heatmap_width = len(self.data.domain.attributes) * heatmap_height / len(self.data) if self.data else 0
497            width = 2 * self.vertical_margin + tree_height + (heatmap_width + self.spacing if self.data else 0) + self.spacing + text_width
498           
499        return width, height, tree_height, heatmap_width, heatmap_height, text_width, font_size
500   
501    def plot(self, filename="graph.eps"):
502        width, height, tree_height, heatmap_width, heatmap_height, text_width, font_size = self.layout()
503        heatmap_cell_height = heatmap_height / len(self.labels)
504        heatmap_cell_width = heatmap_width / len(self.data.domain.attributes)
505       
506        self.renderer = self.renderer(width, height)
507       
508        def draw_tree(cluster, root, treeheight, treewidth, color):
509            height = treeheight * cluster.height / root.height
510            if cluster.branches:
511                centers = []
512                for branch in cluster.branches:
513                    center = draw_tree(branch, root, treeheight, treewidth, self.cluster_colors.get(branch, color))
514                    centers.append(center)
515                    self.renderer.draw_line(center[0], center[1], center[0], height, stroke_color = self.cluster_colors.get(branch, color))
516                   
517                self.renderer.draw_line(centers[0][0], height, centers[-1][0], height, stroke_color = self.cluster_colors.get(cluster, color))
518                return (centers[0][0] + centers[-1][0]) / 2.0, height
519            else:
520                return float(treewidth) * cluster.first / len(root), 0.0
521        self.renderer.save_render_state()
522        self.renderer.translate(self.vertical_margin + tree_height, self.horizontal_margin + (tree_height + self.spacing if self.attr_tree else 0) + heatmap_cell_height / 2.0)
523        self.renderer.rotate(90)
524#        print self.renderer.transform()
525        draw_tree(self.tree, self.tree, tree_height, heatmap_height, self.cluster_colors.get(self.tree, (0,0,0)))
526        self.renderer.restore_render_state()
527        if self.attr_tree:
528            self.renderer.save_render_state()
529            self.renderer.translate(self.vertical_margin + tree_height + self.spacing + heatmap_cell_width / 2.0, self.horizontal_margin + tree_height)
530            self.renderer.scale(1.0, -1.0)
531#            print self.renderer.transform()
532            draw_tree(self.attr_tree, self.attr_tree, tree_height, heatmap_width, self.cluster_colors.get(self.attr_tree, (0,0,0)))
533            self.renderer.restore_render_state()
534       
535        self.renderer.save_render_state()
536        self.renderer.translate(self.vertical_margin + tree_height + self.spacing, self.horizontal_margin + (tree_height + self.spacing if self.attr_tree else 0))
537#        print self.renderer.transform()
538        if self.data:
539            colorSchema = self.color_shema()
540            for i, ii in enumerate(self.tree):
541                ex = self.data[ii]
542                for j, jj in enumerate((self.attr_tree if self.attr_tree else range(len(self.data.domain.attributes)))):
543                    r, g, b = colorSchema(ex[jj])
544                    self.renderer.draw_rect(j * heatmap_cell_width, i * heatmap_cell_height, heatmap_cell_width, heatmap_cell_height, fill_color=(r, g, b), stroke_color=(255, 255, 255))
545       
546        self.renderer.translate(heatmap_width + self.spacing, heatmap_cell_height)
547#        print self.renderer.transform()
548        self.renderer.set_font("Times-Roman", font_size)
549        for index in self.tree: #label in self.labels:
550            self.renderer.draw_text(0.0, 0.0, self.labels[index])
551            self.renderer.translate(0.0, heatmap_cell_height)
552        self.renderer.restore_render_state()
553        self.renderer.save(filename)
554       
555def dendrogram_draw(filename, *args, **kwargs):
556    import os
557    from orngMisc import PILRenderer, EPSRenderer, SVGRenderer
558    name, ext = os.path.splitext(filename)
559    kwargs["renderer"] = {".eps":EPSRenderer, ".svg":SVGRenderer, ".png":PILRenderer}.get(ext.lower(), PILRenderer)
560#    print kwargs["renderer"], ext
561    d = DendrogramPlot(*args, **kwargs)
562    d.plot(filename)
563   
564if __name__=="__main__":
565    data = orange.ExampleTable("doc//datasets//brown-selected.tab")
566#    data = orange.ExampleTable("doc//datasets//iris.tab")
567    root = hierarchicalClustering(data, order=True) #, linkage=orange.HierarchicalClustering.Single)
568    attr_root = hierarchicalClustering_attributes(data, order=True)
569#    print root
570#    d = DendrogramPlotPylab(root, data=data, labels=[str(ex.getclass()) for ex in data], dendrogram_width=0.4, heatmap_width=0.3,  params={}, cmap=None)
571#    d.plot(show=True, filename="graph.png")
572
573    dendrogram_draw("graph.eps", root, attr_tree=attr_root, data=data, labels=[str(e.getclass()) for e in data], tree_height=50, #width=500, height=500,
574                          cluster_colors={root.right:(255,0,0), root.right.right:(0,255,0)}, 
575                          color_palette=ColorPalette([(255, 0, 0), (0,0,0), (0, 255,0)], gamma=0.5, 
576                                                     overflow=(255, 255, 255), underflow=(255, 255, 255))) #, minv=-0.5, maxv=0.5)
Note: See TracBrowser for help on using the repository browser.