source: orange/orange/Orange/clustering/hierarchical.py @ 7645:254d35e2fd6c

Revision 7645:254d35e2fd6c, 26.0 KB checked in by ales_erjavec <ales.erjavec@…>, 3 years ago (diff)
  • added imports of HierarhicalClustering classes from Orange.core
Line 
1"""
2***********************
3Hierarchical clustering
4***********************
5
6.. index::
7   single: clustering, kmeans
8.. index:: aglomerative clustering
9
10Examples
11========
12
13An example.
14
15.. automethod:: Orange.clustering.hierarchical.clustering
16
17"""
18import orange
19from Orange.core import HierarchicalClustering, \
20                        HierarchicalCluster, \
21                        HierarchicalClusterList
22
23
24def clustering(data,
25               distanceConstructor=orange.ExamplesDistanceConstructor_Euclidean,
26               linkage=orange.HierarchicalClustering.Average,
27               order=False,
28               progressCallback=None):
29    """Return a hierarhical clustering of the data set."""
30    distance = distanceConstructor(data)
31    matrix = orange.SymMatrix(len(data))
32    for i in range(len(data)):
33        for j in range(i+1):
34            matrix[i, j] = distance(data[i], data[j])
35    root = orange.HierarchicalClustering(matrix, linkage=linkage, progressCallback=(lambda value, obj=None: progressCallback(value*100.0/(2 if order else 1))) if progressCallback else None)
36    if order:
37        orderLeaves(root, matrix, progressCallback=(lambda value: progressCallback(50.0 + value/2)) if progressCallback else None)
38    return root
39
40def clustering_features(data, distance=None, linkage=orange.HierarchicalClustering.Average, order=False, progressCallback=None):
41    """Return hierarhical clustering of attributes in the data set."""
42    matrix = orange.SymMatrix(len(data.domain.attributes))
43    for a1 in range(len(data.domain.attributes)):
44        for a2 in range(a1):
45            matrix[a1, a2] = (1.0 - orange.PearsonCorrelation(a1, a2, data, 0).r) / 2.0
46    root = orange.HierarchicalClustering(matrix, linkage=linkage, progressCallback=(lambda value, obj=None: progressCallback(value*100.0/(2 if order else 1))) if progressCallback else None)
47    if order:
48        orderLeaves(root, matrix, progressCallback=(lambda value: progressCallback(50.0 + value/2)) if progressCallback else None)
49    return root
50
51def cluster_to_list(node, prune=None):
52    """Return a list of clusters down from the node of hierarchical clustering."""
53    if prune:
54        if len(node) <= prune:
55            return [] 
56    if node.branches:
57        return [node] + cluster_to_list(node.left, prune) + cluster_to_list(node.right, prune)
58    return [node]
59
60def top_clusters(root, k):
61    """Return k topmost clusters from hierarchical clustering."""
62    candidates = set([root])
63    while len(candidates) < k:
64        repl = max([(max(c.left.height, c.right.height), c) for c in candidates if c.branches])[1]
65        candidates.discard(repl)
66        candidates.add(repl.left)
67        candidates.add(repl.right)
68    return candidates
69
70def top_cluster_membership(root, k):
71    """Return data instances' cluster membership (list of indices) to k topmost clusters."""
72    clist = top_clusters(root, k)
73    cmap = [None] * len(root)
74    for i, c in enumerate(clist):
75        for e in c:
76            cmap[e] = i
77    return cmap
78
79def order_leaves(tree, matrix, progressCallback=None):
80    """Order the leaves in the clustering tree.
81
82    (based on Ziv Bar-Joseph et al. (Fast optimal leaf ordering for hierarchical clustering')
83    Arguments:
84        tree   --binary hierarchical clustering tree of type orange.HierarchicalCluster
85        matrix --orange.SymMatrix that was used to compute the clustering
86        progressCallback --function used to report progress
87    """
88    objects = getattr(tree.mapping, "objects", None)
89    tree.mapping.setattr("objects", range(len(tree)))
90    M = {}
91    ordering = {}
92    visitedClusters = set()
93    def _optOrdering(tree):
94        if len(tree)==1:
95            for leaf in tree:
96                M[tree, leaf, leaf] = 0
97##                print "adding:", tree, leaf, leaf
98        else:
99            _optOrdering(tree.left)
100            _optOrdering(tree.right)
101##            print "ordering", [i for i in tree]
102            Vl = set(tree.left)
103            Vr = set(tree.right)
104            Vlr = set(tree.left.right or tree.left)
105            Vll = set(tree.left.left or tree.left)
106            Vrr = set(tree.right.right or tree.right)
107            Vrl = set(tree.right.left or tree.right)
108            other = lambda e, V1, V2: V2 if e in V1 else V1
109            tree_left, tree_right = tree.left, tree.right
110            for u in Vl:
111                for w in Vr:
112                    if True: #Improved search
113                        C = min([matrix[m, k] for m in other(u, Vll, Vlr) for k in other(w, Vrl, Vrr)])
114                        orderedMs = sorted(other(u, Vll, Vlr), key=lambda m: M[tree_left, u, m])
115                        orderedKs = sorted(other(w, Vrl, Vrr), key=lambda k: M[tree_right, w, k])
116                        k0 = orderedKs[0]
117                        curMin = 1e30000 
118                        curMK = ()
119                        for m in orderedMs:
120                            if M[tree_left, u, m] + M[tree_right, w, k0] + C >= curMin:
121                                break
122                            for k in  orderedKs:
123                                if M[tree_left, u, m] + M[tree_right, w, k] + C >= curMin:
124                                    break
125                                if curMin > M[tree_left, u, m] + M[tree_right, w, k] + matrix[m, k]:
126                                    curMin = M[tree_left, u, m] + M[tree_right, w, k] + matrix[m, k]
127                                    curMK = (m, k)
128                        M[tree, u, w] = M[tree, w, u] = curMin
129                        ordering[tree, u, w] = (tree_left, u, curMK[0], tree_right, w, curMK[1])
130                        ordering[tree, w, u] = (tree_right, w, curMK[1], tree_left, u, curMK[0])
131                    else:
132                        def MFunc((m, k)):
133                            return M[tree_left, u, m] + M[tree_right, w, k] + matrix[m, k]
134                        m, k = min([(m, k) for m in other(u, Vll, Vlr) for k in other(w, Vrl, Vrr)], key=MFunc)
135                        M[tree, u, w] = M[tree, w, u] = MFunc((m, k))
136                        ordering[tree, u, w] = (tree_left, u, m, tree_right, w, k)
137                        ordering[tree, w, u] = (tree_right, w, k, tree_left, u, m)
138
139            if progressCallback:
140                progressCallback(100.0 * len(visitedClusters) / len(tree.mapping))
141                visitedClusters.add(tree)
142       
143    _optOrdering(tree)
144
145    def _order(tree, u, w):
146        if len(tree)==1:
147            return
148        left, u, m, right, w, k = ordering[tree, u, w]
149        if len(left)>1 and m not in left.right:
150            left.swap()
151        _order(left, u, m)
152##        if u!=left[0] or m!=left[-1]:
153##            print "error 4:", u, m, list(left)
154        if len(right)>1 and k not in right.left:
155            right.swap()
156        _order(right, k, w)
157##        if k!=right[0] or w!=right[-1]:
158##            print "error 5:", k, w, list(right)
159   
160    u, w = min([(u, w) for u in tree.left for w in tree.right], key=lambda (u, w): M[tree, u, w])
161   
162##    print "M(v) =", M[tree, u, w]
163   
164    _order(tree, u, w)
165
166    def _check(tree, u, w):
167        if len(tree)==1:
168            return
169        left, u, m, right, w, k = ordering[tree, u, w]
170        if tree[0] == u and tree[-1] == w:
171            _check(left, u, m)
172            _check(right, k, w)
173        else:
174            print "Error:", u, w, tree[0], tree[-1]
175
176    _check(tree, u ,w)
177   
178
179    if objects:
180        tree.mapping.setattr("objects", objects)
181
182try:
183    import numpy
184except ImportError:
185    numpy = None
186
187try:
188    import matplotlib
189    from matplotlib.figure import Figure
190    from matplotlib.table import Table, Cell
191    from matplotlib.text import Text
192    from matplotlib.artist import Artist
193##    import  matplotlib.pyplot as plt
194except (ImportError, IOError), ex:
195    matplotlib = None
196    Text , Artist, Table, Cell = object, object, object, object
197
198class TableCell(Cell):
199    PAD = 0.05
200    def __init__(self, *args, **kwargs):
201        Cell.__init__(self, *args, **kwargs)
202        self._text.set_clip_on(True)
203
204class TablePlot(Table):
205    max_fontsize = 12
206    def __init__(self, xy, axes=None, bbox=None):
207        Table.__init__(self, axes or plt.gca(), bbox=bbox)
208        self.xy = xy
209        self.set_transform(self._axes.transData)
210        self._fixed_widhts = None
211        import matplotlib.pyplot as plt
212        self.max_fontsize = plt.rcParams.get("font.size", 12)
213
214    def add_cell(self, row, col, *args, **kwargs):
215        xy = (0,0)
216
217        cell = TableCell(xy, *args, **kwargs)
218        cell.set_figure(self.figure)
219        cell.set_transform(self.get_transform())
220
221        cell.set_clip_on(True)
222        cell.set_clip_box(self._axes.bbox)
223        cell._text.set_clip_box(self._axes.bbox)
224        self._cells[(row, col)] = cell
225
226    def draw(self, renderer):
227        if not self.get_visible(): return
228        self._update_positions(renderer)
229
230        keys = self._cells.keys()
231        keys.sort()
232        for key in keys:
233            self._cells[key].draw(renderer)
234
235    def _update_positions(self, renderer):
236        keys = numpy.array(self._cells.keys())
237        cells = numpy.array([[self._cells.get((row, col), None) for col in range(max(keys[:, 1] + 1))] \
238                             for row in range(max(keys[:, 0] + 1))])
239       
240        widths = self._get_column_widths(renderer)
241        x = self.xy[0] + numpy.array([numpy.sum(widths[:i]) for i in range(len(widths))])
242        y = self.xy[1] - numpy.arange(cells.shape[0]) - 0.5
243       
244        for i in range(cells.shape[0]):
245            for j in range(cells.shape[1]):
246                cells[i, j].set_xy((x[j], y[i]))
247                cells[i, j].set_width(widths[j])
248                cells[i, j].set_height(1.0)
249
250        self._width = numpy.sum(widths)
251        self._height = cells.shape[0]
252
253        self.pchanged()
254
255    def _get_column_widths(self, renderer):
256        keys = numpy.array(self._cells.keys())
257        widths = numpy.zeros(len(keys)).reshape((numpy.max(keys[:,0]+1), numpy.max(keys[:,1]+1)))
258        fontSize = self._calc_fontsize(renderer)
259        for (row, col), cell in self._cells.items():
260            cell.set_fontsize(fontSize)
261            l, b, w, h = cell._text.get_window_extent(renderer).bounds
262            transform = self._axes.transData.inverted()
263            x1, _ = transform.transform_point((0, 0))
264            x2, _ = transform.transform_point((w + w*TableCell.PAD + 10, 0))
265            w = abs(x1 - x2)
266            widths[row, col] = w
267        return numpy.max(widths, 0)
268
269    def _calc_fontsize(self, renderer):
270        transform = self._axes.transData
271        _, y1 = transform.transform_point((0, 0))
272        _, y2 = transform.transform_point((0, 1))
273        return min(max(int(abs(y1 - y2)*0.85) ,4), self.max_fontsize)
274
275    def get_children(self):
276        return self._cells.values()
277
278    def get_bbox(self):
279        return matplotlib.transform.Bbox([self.xy[0], self.xy[1], self.xy[0] + 10, self.xy[1] + 180])
280
281class DendrogramPlotPylab(object):
282    def __init__(self, root, data=None, labels=None, dendrogram_width=None, heatmap_width=None, label_width=None, space_width=None, border_width=0.05, plot_attr_names=False, cmap=None, params={}):
283        if not matplotlib:
284            raise ImportError("Could not import matplotlib module. Please make sure matplotlib is installed on your system.")
285        import matplotlib.pyplot as plt
286        self.plt = plt
287        self.root = root
288        self.data = data
289        self.labels = labels if labels else [str(i) for i in range(len(root))]
290        self.dendrogram_width = dendrogram_width
291        self.heatmap_width = heatmap_width
292        self.label_width = label_width
293        self.space_width = space_width
294        self.border_width = border_width
295        self.params = params
296        self.plot_attr_names = plot_attr_names
297
298    def plotDendrogram(self):
299        self.text_items = []
300        def draw_tree(tree):
301            if tree.branches:
302                points = []
303                for branch in tree.branches:
304                    center = draw_tree(branch)
305                    self.plt.plot([center[0], tree.height], [center[1], center[1]], color="black")
306                    points.append(center)
307                self.plt.plot([tree.height, tree.height], [points[0][1], points[-1][1]], color="black")
308                return (tree.height, (points[0][1] + points[-1][1])/2.0)
309            else:
310                return (0.0, tree.first)
311        draw_tree(self.root)
312       
313    def plotHeatMap(self):
314        import numpy.ma as ma
315        import numpy
316        dx, dy = self.root.height, 0
317        fx, fy = self.root.height/len(self.data.domain.attributes), 1.0
318        data, c, w = self.data.toNumpyMA()
319        data = (data - ma.min(data))/(ma.max(data) - ma.min(data))
320        x = numpy.arange(data.shape[1] + 1)/float(numpy.max(data.shape))
321        y = numpy.arange(data.shape[0] + 1)/float(numpy.max(data.shape))*len(self.root)
322        self.heatmap_width = numpy.max(x)
323
324        X, Y = numpy.meshgrid(x, y - 0.5)
325
326        self.meshXOffset = numpy.max(X)
327
328        self.plt.jet()
329        mesh = self.plt.pcolormesh(X, Y, data[self.root.mapping], edgecolor="b", linewidth=2)
330
331        if self.plot_attr_names:
332            names = [attr.name for attr in self.data.domain.attributes]
333            self.plt.xticks(numpy.arange(data.shape[1] + 1)/float(numpy.max(data.shape)), names)
334        self.plt.gca().xaxis.tick_top()
335        for label in self.plt.gca().xaxis.get_ticklabels():
336            label.set_rotation(45)
337
338        for tick in self.plt.gca().xaxis.get_major_ticks():
339            tick.tick1On = False
340            tick.tick2On = False
341
342    def plotLabels_(self):
343        import numpy
344##        self.plt.yticks(numpy.arange(len(self.labels) - 1, 0, -1), self.labels)
345##        for tick in self.plt.gca().yaxis.get_major_ticks():
346##            tick.tick1On = False
347##            tick.label1On = False
348##            tick.label2On = True
349##        text = TableTextLayout(xy=(self.meshXOffset+1, len(self.root)), tableText=[[label] for label in self.labels])
350        text = TableTextLayout(xy=(self.meshXOffset*1.005, len(self.root) - 1), tableText=[[label] for label in self.labels])
351        text.set_figure(self.plt.gcf())
352        self.plt.gca().add_artist(text)
353        self.plt.gca()._set_artist_props(text)
354
355    def plotLabels(self):
356##        table = TablePlot(xy=(self.meshXOffset*1.005, len(self.root) -1), axes=self.plt.gca())
357        table = TablePlot(xy=(0, len(self.root) -1), axes=self.plt.gca())
358        table.set_figure(self.plt.gcf())
359        for i,label in enumerate(self.labels):
360            table.add_cell(i, 0, width=1, height=1, text=label, loc="left", edgecolor="w")
361        table.set_zorder(0)
362        self.plt.gca().add_artist(table)
363        self.plt.gca()._set_artist_props(table)
364   
365    def plot(self, filename=None, show=False):
366        self.plt.rcParams.update(self.params)
367        labelLen = max(len(label) for label in self.labels)
368        w, h = 800, 600
369        space = 0.01 if self.space_width == None else self.space_width
370        border = self.border_width
371        width = 1.0 - 2*border
372        height = 1.0 - 2*border
373        textLineHeight = min(max(h/len(self.labels), 4), self.plt.rcParams.get("font.size", 12))
374        maxTextLineWidthEstimate = textLineHeight*labelLen
375##        print maxTextLineWidthEstimate
376        textAxisWidthRatio = 2.0*maxTextLineWidthEstimate/w
377##        print textAxisWidthRatio
378        labelsAreaRatio = min(textAxisWidthRatio, 0.4) if self.label_width == None else self.label_width
379        x, y = len(self.data.domain.attributes), len(self.data)
380
381        heatmapAreaRatio = min(1.0*y/h*x/w, 0.3) if self.heatmap_width == None else self.heatmap_width
382        dendrogramAreaRatio = 1.0 - labelsAreaRatio - heatmapAreaRatio - 2*space if self.dendrogram_width == None else self.dendrogram_width
383
384        self.fig = self.plt.figure()
385        self.labels_offset = self.root.height/20.0
386        dendrogramAxes = self.plt.axes([border, border, width*dendrogramAreaRatio, height])
387        dendrogramAxes.xaxis.grid(True)
388        import matplotlib.ticker as ticker
389
390        dendrogramAxes.yaxis.set_major_locator(ticker.NullLocator())
391        dendrogramAxes.yaxis.set_minor_locator(ticker.NullLocator())
392        dendrogramAxes.invert_xaxis()
393        self.plotDendrogram()
394        heatmapAxes = self.plt.axes([border + width*dendrogramAreaRatio + space, border, width*heatmapAreaRatio, height], sharey=dendrogramAxes)
395
396        heatmapAxes.xaxis.set_major_locator(ticker.NullLocator())
397        heatmapAxes.xaxis.set_minor_locator(ticker.NullLocator())
398        heatmapAxes.yaxis.set_major_locator(ticker.NullLocator())
399        heatmapAxes.yaxis.set_minor_locator(ticker.NullLocator())
400       
401        self.plotHeatMap()
402        labelsAxes = self.plt.axes([border + width*(dendrogramAreaRatio + heatmapAreaRatio + 2*space), border, width*labelsAreaRatio, height], sharey=dendrogramAxes)
403        self.plotLabels()
404        labelsAxes.set_axis_off()
405        labelsAxes.xaxis.set_major_locator(ticker.NullLocator())
406        labelsAxes.xaxis.set_minor_locator(ticker.NullLocator())
407        labelsAxes.yaxis.set_major_locator(ticker.NullLocator())
408        labelsAxes.yaxis.set_minor_locator(ticker.NullLocator())
409        if filename:
410            import matplotlib.backends.backend_agg
411            canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(self.fig)
412            canvas.print_figure(filename)
413        if show:
414            self.plt.show()
415       
416       
417from orngMisc import ColorPalette, EPSRenderer
418class DendrogramPlot(object):
419    """ A class for drawing dendrograms
420    Example:
421    >>> a = DendrogramPlot(tree)
422    """
423    def __init__(self, tree, attr_tree = None, labels=None, data=None, width=None, height=None, tree_height=None, heatmap_width=None, text_width=None, 
424                 spacing=2, cluster_colors={}, color_palette=ColorPalette([(255, 0, 0), (0, 255, 0)]), maxv=None, minv=None, gamma=None, renderer=EPSRenderer):
425        self.tree = tree
426        self.attr_tree = attr_tree
427        self.labels = [str(ex.getclass()) for ex in data] if not labels and data and data.domain.classVar else (labels or [])
428#        self.attr_labels = [str(attr.name) for attr in data.domain.attributes] if not attr_labels and data else attr_labels or []
429        self.data = data
430        self.width, self.height = float(width) if width else None, float(height) if height else None
431        self.tree_height = tree_height
432        self.heatmap_width = heatmap_width
433        self.text_width = text_width
434        self.font_size = 10.0
435        self.linespacing = 0.0
436        self.cluster_colors = cluster_colors
437        self.horizontal_margin = 10.0
438        self.vertical_margin = 10.0
439        self.spacing = float(spacing) if spacing else None
440        self.color_palette = color_palette
441        self.minv = minv
442        self.maxv = maxv
443        self.gamma = gamma
444        self.set_matrix_color_schema(color_palette, minv, maxv, gamma)
445        self.renderer = renderer
446       
447    def set_matrix_color_schema(self, color_palette, minv, maxv, gamma=None):
448        """ Set the matrix color scheme.
449        """
450        if isinstance(color_palette, ColorPalette):
451            self.color_palette = color_palette
452        else:
453            self.color_palette = ColorPalette(color_palette)
454        self.minv = minv
455        self.maxv = maxv
456        self.gamma = gamma
457       
458    def color_shema(self):
459        vals = [float(val) for ex in self.data for val in ex if not val.isSpecial() and val.variable.varType==orange.VarTypes.Continuous] or [0]
460        avg = sum(vals)/len(vals)
461       
462        maxVal = self.maxv if self.maxv else max(vals)
463        minVal = self.minv if self.minv else min(vals)
464       
465        def _colorSchema(val):
466            if val.isSpecial():
467                return self.color_palette(None)
468            elif val.variable.varType==orange.VarTypes.Continuous:
469                r, g, b = self.color_palette((float(val) - minVal) / abs(maxVal - minVal), gamma=self.gamma)
470            elif val.variable.varType==orange.VarTypes.Discrete:
471                r = g = b = int(255.0*float(val)/len(val.variable.values))
472            return (r, g, b)
473        return _colorSchema
474   
475    def layout(self):
476        height_final = False
477        width_final = False
478        tree_height = self.tree_height or 100
479        if self.height:
480            height, height_final = self.height, True
481            heatmap_height = height - (tree_height + self.spacing if self.attr_tree else 0) - 2 * self.horizontal_margin
482            font_size =  heatmap_height / len(self.labels) #self.font_size or (height - (tree_height + self.spacing if self.attr_tree else 0) - 2 * self.horizontal_margin) / len(self.labels)
483        else:
484            font_size = self.font_size
485            heatmap_height = font_size * len(self.labels)
486            height = heatmap_height + (tree_height + self.spacing if self.attr_tree else 0) + 2 * self.horizontal_margin
487             
488        text_width = self.text_width or max([len(label) for label in self.labels] + [0]) * font_size #max([self.renderer.string_size_hint(label) for label in self.labels])
489       
490        if self.width:
491            width = self.width
492            heatmap_width = width - 2 * self.vertical_margin - tree_height - (2 if self.data else 1) * self.spacing - text_width if self.data else 0
493        else:
494            heatmap_width = len(self.data.domain.attributes) * heatmap_height / len(self.data) if self.data else 0
495            width = 2 * self.vertical_margin + tree_height + (heatmap_width + self.spacing if self.data else 0) + self.spacing + text_width
496           
497        return width, height, tree_height, heatmap_width, heatmap_height, text_width, font_size
498   
499    def plot(self, filename="graph.eps"):
500        width, height, tree_height, heatmap_width, heatmap_height, text_width, font_size = self.layout()
501        heatmap_cell_height = heatmap_height / len(self.labels)
502        heatmap_cell_width = heatmap_width / len(self.data.domain.attributes)
503       
504        self.renderer = self.renderer(width, height)
505       
506        def draw_tree(cluster, root, treeheight, treewidth, color):
507            height = treeheight * cluster.height / root.height
508            if cluster.branches:
509                centers = []
510                for branch in cluster.branches:
511                    center = draw_tree(branch, root, treeheight, treewidth, self.cluster_colors.get(branch, color))
512                    centers.append(center)
513                    self.renderer.draw_line(center[0], center[1], center[0], height, stroke_color = self.cluster_colors.get(branch, color))
514                   
515                self.renderer.draw_line(centers[0][0], height, centers[-1][0], height, stroke_color = self.cluster_colors.get(cluster, color))
516                return (centers[0][0] + centers[-1][0]) / 2.0, height
517            else:
518                return float(treewidth) * cluster.first / len(root), 0.0
519        self.renderer.save_render_state()
520        self.renderer.translate(self.vertical_margin + tree_height, self.horizontal_margin + (tree_height + self.spacing if self.attr_tree else 0) + heatmap_cell_height / 2.0)
521        self.renderer.rotate(90)
522#        print self.renderer.transform()
523        draw_tree(self.tree, self.tree, tree_height, heatmap_height, self.cluster_colors.get(self.tree, (0,0,0)))
524        self.renderer.restore_render_state()
525        if self.attr_tree:
526            self.renderer.save_render_state()
527            self.renderer.translate(self.vertical_margin + tree_height + self.spacing + heatmap_cell_width / 2.0, self.horizontal_margin + tree_height)
528            self.renderer.scale(1.0, -1.0)
529#            print self.renderer.transform()
530            draw_tree(self.attr_tree, self.attr_tree, tree_height, heatmap_width, self.cluster_colors.get(self.attr_tree, (0,0,0)))
531            self.renderer.restore_render_state()
532       
533        self.renderer.save_render_state()
534        self.renderer.translate(self.vertical_margin + tree_height + self.spacing, self.horizontal_margin + (tree_height + self.spacing if self.attr_tree else 0))
535#        print self.renderer.transform()
536        if self.data:
537            colorSchema = self.color_shema()
538            for i, ii in enumerate(self.tree):
539                ex = self.data[ii]
540                for j, jj in enumerate((self.attr_tree if self.attr_tree else range(len(self.data.domain.attributes)))):
541                    r, g, b = colorSchema(ex[jj])
542                    self.renderer.draw_rect(j * heatmap_cell_width, i * heatmap_cell_height, heatmap_cell_width, heatmap_cell_height, fill_color=(r, g, b), stroke_color=(255, 255, 255))
543       
544        self.renderer.translate(heatmap_width + self.spacing, heatmap_cell_height)
545#        print self.renderer.transform()
546        self.renderer.set_font("Times-Roman", font_size)
547        for index in self.tree: #label in self.labels:
548            self.renderer.draw_text(0.0, 0.0, self.labels[index])
549            self.renderer.translate(0.0, heatmap_cell_height)
550        self.renderer.restore_render_state()
551        self.renderer.save(filename)
552       
553def dendrogram_draw(filename, *args, **kwargs):
554    import os
555    from orngMisc import PILRenderer, EPSRenderer, SVGRenderer
556    name, ext = os.path.splitext(filename)
557    kwargs["renderer"] = {".eps":EPSRenderer, ".svg":SVGRenderer, ".png":PILRenderer}.get(ext.lower(), PILRenderer)
558#    print kwargs["renderer"], ext
559    d = DendrogramPlot(*args, **kwargs)
560    d.plot(filename)
561   
562if __name__=="__main__":
563    data = orange.ExampleTable("doc//datasets//brown-selected.tab")
564#    data = orange.ExampleTable("doc//datasets//iris.tab")
565    root = hierarchicalClustering(data, order=True) #, linkage=orange.HierarchicalClustering.Single)
566    attr_root = hierarchicalClustering_attributes(data, order=True)
567#    print root
568#    d = DendrogramPlotPylab(root, data=data, labels=[str(ex.getclass()) for ex in data], dendrogram_width=0.4, heatmap_width=0.3,  params={}, cmap=None)
569#    d.plot(show=True, filename="graph.png")
570
571    dendrogram_draw("graph.eps", root, attr_tree=attr_root, data=data, labels=[str(e.getclass()) for e in data], tree_height=50, #width=500, height=500,
572                          cluster_colors={root.right:(255,0,0), root.right.right:(0,255,0)}, 
573                          color_palette=ColorPalette([(255, 0, 0), (0,0,0), (0, 255,0)], gamma=0.5, 
574                                                     overflow=(255, 255, 255), underflow=(255, 255, 255))) #, minv=-0.5, maxv=0.5)
Note: See TracBrowser for help on using the repository browser.