source: orange/source/orange/hclust.cpp @ 11703:9b8d8ab7820c

Revision 11703:9b8d8ab7820c, 29.4 KB checked in by janezd <janez.demsar@…>, 7 months ago (diff)

Removed the GPL copyright notice from all files except orangeqt.

Line 
1#define USE_TR1 1
2
3#if USE_TR1
4    #if _MSC_VER
5        #define HAVE_TR1_DIR 0
6    #else
7        #define HAVE_TR1_DIR 1
8    #endif
9    // Diffrent includes required
10    #if HAVE_TR1_DIR
11        #include <tr1/unordered_map>
12    #else
13        #include <unordered_map>
14    #endif
15#endif
16
17
18#include "progress.hpp"
19
20#include "hclust.hpp"
21
22DEFINE_TOrangeVector_classDescription(PHierarchicalCluster, "THierarchicalClusterList", true, ORANGE_API)
23
24#include "hclust.ppp"
25
26
27class TClusterW {
28public:
29    TClusterW *next; // for cluster, this is next cluster, for elements next element
30    TClusterW *left, *right; // subclusters, if left==NULL, this is an element
31    int size;
32    int elementIndex;
33    float height;
34
35    float *distances; // distances to clusters before this one (lower left matrix)
36    float minDistance; // minimal distance
37    int rawIndexMinDistance; // index of minimal distance
38    int nDistances;
39
40    TClusterW(const int &elIndex, float *adistances, const int &anDistances)
41    : next(NULL),
42      left(NULL),
43      right(NULL),
44      size(1),
45      elementIndex(elIndex),
46      height(0.0),
47      distances(adistances),
48      minDistance(numeric_limits<float>::max()),
49      rawIndexMinDistance(-1),
50      nDistances(anDistances)
51    {
52      if (distances)
53        computeMinimalDistance();
54    }
55
56    void elevate(TClusterW **aright, const float &aheight)
57    {
58      left = mlnew TClusterW(*this);
59      right = *aright;
60
61      left->distances = NULL;
62      size = left->size + right->size;
63      elementIndex = -1;
64      height = aheight;
65
66      if (next == right)
67        next = right->next;
68      else
69        *aright = right->next;
70    }
71
72
73    void computeMinimalDistance()
74    {
75      float *dp = distances, *minp = dp++;
76      for(int i = nDistances; --i; dp++)
77        if ((*dp >= 0) && (*dp < *minp))
78          minp = dp;
79      minDistance = *minp;
80      rawIndexMinDistance = minp - distances;
81    }
82
83    TClusterW **clusterAt(int rawIndex, TClusterW **cluster)
84    {
85      for(float *dp = distances; rawIndex; dp++)
86        if (*dp>=0) {
87          cluster = &(*cluster)->next;
88          rawIndex--;
89        }
90      return cluster;
91    }
92
93    void clearDistances()
94    {
95      raiseErrorWho("TClusterW", "rewrite clean to adjust indexMinDistance");
96
97      float *dst = distances;
98      int i = nDistances;
99      for(; nDistances && (*dst >= 0); nDistances--, dst++);
100      if (!nDistances)
101        return;
102     
103      float *src = dst;
104      for(src++, nDistances--; nDistances && (*src >= 0); src++)
105        if (*src >= 0)
106          *dst++ = *src;
107
108      nDistances = dst - distances;
109    }
110};
111
112
113THierarchicalCluster::THierarchicalCluster()
114: branches(),
115  height(0.0),
116  mapping(),
117  first(0),
118  last(0)
119{}
120
121
122THierarchicalCluster::THierarchicalCluster(PIntList els, const int &elementIndex)
123: branches(),
124  height(0.0),
125  mapping(els),
126  first(elementIndex),
127  last(elementIndex+1)
128{}
129
130
131THierarchicalCluster::THierarchicalCluster(PIntList els, PHierarchicalCluster left, PHierarchicalCluster right, const float &h, const int &f, const int &l)
132: branches(new THierarchicalClusterList(2)),
133  height(h),
134  mapping(els),
135  first(f),
136  last(l)
137{ 
138  branches->at(0) = left;
139  branches->at(1) = right;
140}
141
142
143void THierarchicalCluster::swap()
144{
145  if (!branches || (branches->size()<2))
146    return;
147  if (branches->size() > 2)
148    raiseError("cannot swap multiple branches (use method 'permutation' instead)");
149
150  const TIntList::iterator beg0 = mapping->begin() + branches->at(0)->first;
151  const TIntList::iterator beg1 = mapping->begin() + branches->at(1)->first;
152  const TIntList::iterator end1 = mapping->begin() + branches->at(1)->last;
153
154  if ((branches->at(0)->first > branches->at(1)->first) || (branches->at(1)->first > branches->at(1)->last))
155    raiseError("internal inconsistency in clustering structure: invalid ordering of left's and right's 'first' and 'last'");
156
157  TIntList::iterator li0, li1;
158
159  int *temp = new int [beg1 - beg0], *t;
160  for(li0 = beg0, t = temp; li0 != beg1; *t++ = *li0++);
161  for(li0 = beg0, li1 = beg1; li1 != end1; *li0++ = *li1++);
162  for(t = temp; li0 != end1; *li0++ = *t++);
163  delete temp;
164
165  branches->at(0)->recursiveMove(end1 - beg1);
166  branches->at(1)->recursiveMove(beg0 - beg1);
167
168  PHierarchicalCluster tbr = branches->at(0);
169  branches->at(0) = branches->at(1);
170  branches->at(1) = tbr;
171}
172
173
174void THierarchicalCluster::permute(const TIntList &neworder)
175{
176  if ((!branches && neworder.size()) || (branches->size() != neworder.size()))
177    raiseError("the number of clusters does not match the lenght of the permutation vector");
178
179  int *temp = new int [last - first], *t = temp;
180  TIntList::const_iterator pi = neworder.begin();
181  THierarchicalClusterList::iterator bi(branches->begin()), be(branches->end());
182  THierarchicalClusterList newBranches;
183
184  for(; bi != be; bi++, pi++) {
185    PHierarchicalCluster branch = branches->at(*pi);
186    newBranches.push_back(branch);
187    TIntList::const_iterator bei(mapping->begin() + branch->first), bee(mapping->begin() + branch->last);
188    const int offset = (t - temp) - (branch->first - first);
189    for(; bei != bee; *t++ = *bei++);
190    if (offset)
191      branch->recursiveMove(offset);
192  }
193
194  TIntList::iterator bei(mapping->begin() + first), bee(mapping->begin() + last);
195  for(t = temp; bei!=bee; *bei++ = *t++);
196
197  bi = branches->begin();
198  THierarchicalClusterList::const_iterator nbi(newBranches.begin());
199  for(; bi != be; *bi++ = *nbi++);
200}
201
202
203void THierarchicalCluster::recursiveMove(const int &offset)
204{
205  first += offset;
206  last += offset;
207  if (branches)
208    PITERATE(THierarchicalClusterList, bi, branches)
209      (*bi)->recursiveMove(offset);
210}
211
212
213THierarchicalClustering::THierarchicalClustering()
214: linkage(Single),
215  overwriteMatrix(false)
216{}
217
218
219TClusterW **THierarchicalClustering::init(const int &dim, float *distanceMatrix)
220{
221  for(float *ddi = distanceMatrix, *dde = ddi + ((dim+1)*(dim+2))/2; ddi!=dde; ddi++)
222    if (*ddi < 0) {
223      int x, y;
224      TSymMatrix::index2coordinates(ddi-distanceMatrix, x, y);
225      raiseError("distance matrix contains negative element at (%i, %i)", x, y);
226    }
227
228  TClusterW **clusters = mlnew TClusterW *[dim];
229  TClusterW **clusteri = clusters;
230
231  *clusters = mlnew TClusterW(0, NULL, 0);
232  distanceMatrix++;
233 
234  for(int elementIndex = 1, e = dim; elementIndex < e; distanceMatrix += ++elementIndex) {
235    TClusterW *newcluster = mlnew TClusterW(elementIndex, distanceMatrix, elementIndex);
236    (*clusteri++)->next = newcluster;
237    *clusteri = newcluster; 
238  }
239
240  return clusters;
241}
242
243
244
245TClusterW *THierarchicalClustering::merge_SingleLinkage(TClusterW **clusters, float *milestones)
246{
247  float *milestone = milestones;
248
249  int step = 0;
250  while((*clusters)->next) {
251    if (milestone && (step++ ==*milestone))
252      progressCallback->call(*((++milestone)++));
253
254    TClusterW *cluster;
255    TClusterW **pcluster2;
256
257    float minDistance = numeric_limits<float>::max();
258    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
259      if ((*tcluster)->minDistance < minDistance) {
260        minDistance = (*tcluster)->minDistance;
261        pcluster2 = tcluster;
262      }
263
264    TClusterW *const cluster2 = *pcluster2;
265
266    const int rawIndex1 = cluster2->rawIndexMinDistance;
267    const int rawIndex2 = cluster2->nDistances;
268
269    TClusterW *const cluster1 = clusters[rawIndex1];
270
271    float *disti1 = cluster1->distances;
272    float *disti2 = cluster2->distances;
273
274    if (rawIndex1) { // not root - has no distances...
275      const float *minIndex1 = cluster1->distances + cluster1->rawIndexMinDistance;
276      for(int ndi = cluster1->nDistances; ndi--; disti1++, disti2++)
277        if (*disti2 < *disti1) // if one is -1, they both are
278          if ((*disti1 = *disti2) < *minIndex1)
279            minIndex1 = disti1;
280      cluster1->minDistance = *minIndex1;
281      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
282    }
283
284    while(*disti2 < 0)
285      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
286
287    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
288      while(*++disti2 < 0); // should have more - the one corresponding to cluster
289      if (*disti2 < cluster->distances[rawIndex1])
290        if ((cluster->distances[rawIndex1] = *disti2) < cluster->minDistance) {
291          cluster->minDistance = *disti2;
292          cluster->rawIndexMinDistance = rawIndex1;
293        }
294    }
295
296    for(cluster = cluster->next; cluster; cluster = cluster->next) {
297      if (cluster->distances[rawIndex2] < cluster->distances[rawIndex1])
298        cluster->distances[rawIndex1] = cluster->distances[rawIndex2];
299
300      // don't nest this in the above if -- both distances can be equal, yet index HAS TO move
301      if (rawIndex2 == cluster->rawIndexMinDistance)
302        cluster->rawIndexMinDistance = rawIndex1; // the smallest element got moved
303
304      cluster->distances[rawIndex2] = -1;
305    }
306
307    cluster1->elevate(pcluster2, minDistance);
308  }
309
310  return *clusters;
311}
312
313// Also computes Ward's linkage
314TClusterW *THierarchicalClustering::merge_AverageLinkage(TClusterW **clusters, float *milestones)
315{
316  float *milestone = milestones;
317  bool ward = linkage == Ward;
318
319  int step = 0;
320  while((*clusters)->next) {
321    if (milestone && (step++ ==*milestone))
322      progressCallback->call(*((++milestone)++));
323
324    TClusterW *cluster;
325    TClusterW **pcluster2;
326
327    float minDistance = numeric_limits<float>::max();
328    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
329      if ((*tcluster)->minDistance < minDistance) {
330        minDistance = (*tcluster)->minDistance;
331        pcluster2 = tcluster;
332      }
333
334    TClusterW *const cluster2 = *pcluster2;
335
336    const int rawIndex1 = cluster2->rawIndexMinDistance;
337    const int rawIndex2 = cluster2->nDistances;
338
339    TClusterW *const cluster1 = clusters[rawIndex1];
340
341    float *disti1 = cluster1->distances;
342    float *disti2 = cluster2->distances;
343
344    const int size1 = cluster1->size;
345    const int size2 = cluster2->size;
346    const int sumsize = size1 + size2;
347    cluster = (*clusters)->next;
348
349    if (rawIndex1) { // not root - has no distances...
350      const float sizeK = cluster->size;
351      *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
352                     : (*disti1 * size1 + *disti2 * size2) / sumsize;
353      const float *minIndex1 = disti1;
354      int ndi = cluster1->nDistances-1;
355      for(disti1++, disti2++, cluster = cluster->next; ndi--; disti1++, disti2++)
356        if (*disti1 >= 0) {
357          const float sizeK = cluster->size;
358          cluster = cluster->next;
359          *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
360                         : (*disti1 * size1         + *disti2 * size2) / sumsize;
361          if (*disti1 < *minIndex1)
362            minIndex1 = disti1;
363        }
364      cluster1->minDistance = *minIndex1;
365      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
366    }
367
368    while(*disti2 < 0)
369      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
370
371    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
372      while(*++disti2 < 0); // should have more - the one corresponding to cluster
373      float &distc = cluster->distances[rawIndex1];
374      const float sizeK = cluster->size;
375      distc = ward ? (distc * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
376                   : (distc * size1         + *disti2 * size2) / sumsize;
377      if (distc < cluster->minDistance) {
378        cluster->minDistance = distc;
379        cluster->rawIndexMinDistance = rawIndex1;
380      }
381      else if ((distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1))
382        cluster->computeMinimalDistance();
383    }
384
385    for(cluster = cluster->next; cluster; cluster = cluster->next) {
386      float &distc = cluster->distances[rawIndex1];
387      const float sizeK = cluster->size;
388      distc = ward ? (distc * (size1+sizeK) + cluster->distances[rawIndex2] * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
389                   : (distc * size1         + cluster->distances[rawIndex2] * size2) / sumsize;
390      cluster->distances[rawIndex2] = -1;
391      if (distc < cluster->minDistance) {
392        cluster->minDistance = distc;
393        cluster->rawIndexMinDistance = rawIndex1;
394      }
395      else if (   (distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1)
396               || (cluster->rawIndexMinDistance == rawIndex2))
397        cluster->computeMinimalDistance();
398    }
399
400    cluster1->elevate(pcluster2, minDistance);
401  }
402
403  return *clusters;
404}
405
406
407
408TClusterW *THierarchicalClustering::merge_CompleteLinkage(TClusterW **clusters, float *milestones)
409{
410  float *milestone = milestones;
411
412  int step = 0;
413  while((*clusters)->next) {
414    if (milestone && (step++ ==*milestone))
415      progressCallback->call(*((++milestone)++));
416
417    TClusterW *cluster;
418    TClusterW **pcluster2;
419
420    float minDistance = numeric_limits<float>::max();
421    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
422      if ((*tcluster)->minDistance < minDistance) {
423        minDistance = (*tcluster)->minDistance;
424        pcluster2 = tcluster;
425      }
426
427    TClusterW *const cluster2 = *pcluster2;
428
429    const int rawIndex1 = cluster2->rawIndexMinDistance;
430    const int rawIndex2 = cluster2->nDistances;
431
432    TClusterW *const cluster1 = clusters[rawIndex1];
433
434    float *disti1 = cluster1->distances;
435    float *disti2 = cluster2->distances;
436
437    if (rawIndex1) { // not root - has no distances...
438      if (*disti2 > *disti1)
439        *disti1 = *disti2;
440      const float *minIndex1 = disti1;
441      int ndi = cluster1->nDistances-1;
442      for(disti1++, disti2++; ndi--; disti1++, disti2++) {
443        if (*disti1 >= 0) {
444          if (*disti2 > *disti1) // if one is -1, they both are
445            *disti1 = *disti2;
446          if (*disti1 < *minIndex1)
447            minIndex1 = disti1;
448        }
449      }
450      cluster1->minDistance = *minIndex1;
451      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
452    }
453
454    while(*disti2 < 0)
455      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
456
457    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
458      while(*++disti2 < 0); // should have more - the one corresponding to cluster
459      float &distc = cluster->distances[rawIndex1];
460      if (*disti2 > distc) {
461        distc = *disti2;
462        if (cluster->rawIndexMinDistance == rawIndex1)
463          if (distc <= cluster->minDistance)
464            cluster->minDistance = distc;
465          else
466            cluster->computeMinimalDistance();
467      }
468    }
469
470    for(cluster = cluster->next; cluster; cluster = cluster->next) {
471      float &distc = cluster->distances[rawIndex1];
472      if (cluster->distances[rawIndex2] > distc)
473        distc = cluster->distances[rawIndex2];
474
475      cluster->distances[rawIndex2] = -1;
476      if ((cluster->rawIndexMinDistance == rawIndex1) || (cluster->rawIndexMinDistance == rawIndex2))
477        cluster->computeMinimalDistance();
478    }
479
480    cluster1->elevate(pcluster2, minDistance);
481  }
482
483  return *clusters;
484}
485
486
487
488TClusterW *THierarchicalClustering::merge(TClusterW **clusters, float *milestones)
489{
490  switch(linkage) {
491    case Complete: return merge_CompleteLinkage(clusters, milestones);
492    case Single: return merge_SingleLinkage(clusters, milestones);
493    case Average: 
494    case Ward:
495    default: return merge_AverageLinkage(clusters, milestones);
496  }
497}
498
499PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *root)
500{
501  PIntList elementIndices = new TIntList(root->size);
502  TIntList::iterator currentElement(elementIndices->begin());
503  int currentIndex = 0;
504
505  return restructure(root, elementIndices, currentElement, currentIndex);
506}
507
508
509PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *node, PIntList elementIndices, TIntList::iterator &currentElement, int &currentIndex)
510{
511  PHierarchicalCluster cluster;
512
513  if (!node->left) {
514    *currentElement++ = node->elementIndex;
515    cluster = mlnew THierarchicalCluster(elementIndices, currentIndex++);
516  }
517  else {
518    PHierarchicalCluster left = restructure(node->left, elementIndices, currentElement, currentIndex);
519    PHierarchicalCluster right = restructure(node->right, elementIndices, currentElement, currentIndex);
520    cluster = mlnew THierarchicalCluster(elementIndices, left, right, node->height, left->first, right->last);
521  }
522
523  // no need to care about 'distances' - they've been removed during clustering (in 'elevate') :)
524  mldelete node;
525  return cluster;
526}
527
528
529PHierarchicalCluster THierarchicalClustering::operator()(PSymMatrix distanceMatrix)
530{
531  float *distanceMatrixElements = NULL;
532  TClusterW **clusters, *root;
533  float *callbackMilestones = NULL;
534 
535  try {
536    const int dim = distanceMatrix->dim;
537    const int size = ((dim+1)*(dim+2))/2;
538    float *distanceMatrixElements = overwriteMatrix ? distanceMatrix->elements : (float *)memcpy(new float[size], distanceMatrix->elements, size*sizeof(float));
539
540    clusters = init(dim, distanceMatrixElements);
541    callbackMilestones = (progressCallback && (distanceMatrix->dim >= 1000)) ? progressCallback->milestones(distanceMatrix->dim) : NULL;
542    root = merge(clusters, callbackMilestones);
543  }
544  catch (...) {
545    mldelete clusters;
546    mldelete callbackMilestones;
547    mldelete distanceMatrixElements;
548    throw;
549  }
550
551  mldelete clusters;
552  mldelete callbackMilestones;
553  mldelete distanceMatrixElements;
554
555  return restructure(root);
556}
557
558
559/*
560 *  Optimal leaf ordering.
561 */
562
563struct m_element {
564    THierarchicalCluster * cluster;
565    int left;
566    int right;
567
568    m_element(THierarchicalCluster * cluster, int left, int right);
569    m_element(const m_element & other);
570
571    inline bool operator< (const m_element & other) const;
572    inline bool operator== (const m_element & other) const;
573}; // cluster joined at left and right index
574
575struct ordering_element {
576    THierarchicalCluster * left;
577    unsigned int u; // the left most (outer) index of left luster
578    unsigned m; // the rightmost (inner) index of left cluster
579    THierarchicalCluster * right;
580    unsigned int w; // the right most (outer) index of the right cluster
581    unsigned int k; // the left most (inner) index of the right cluster
582
583    ordering_element();
584    ordering_element(THierarchicalCluster * left, unsigned int u, unsigned m,
585            THierarchicalCluster * right, unsigned int w, unsigned int k);
586    ordering_element(const ordering_element & other);
587};
588
589m_element::m_element(THierarchicalCluster * _cluster, int _left, int _right):
590    cluster(_cluster), left(_left), right(_right)
591{}
592
593m_element::m_element(const m_element & other):
594    cluster(other.cluster), left(other.left), right(other.right)
595{}
596
597bool m_element::operator< (const m_element & other) const
598{
599    if (cluster < other.cluster)
600        return true;
601    else
602        if (cluster == other.cluster)
603            if (left < other.left)
604                return true;
605            else
606                if (left == other.left)
607                    return right < other.right;
608                else
609                    return false;
610        else
611            return false;
612}
613
614bool m_element::operator== (const m_element & other) const
615{
616    return cluster == other.cluster && left == other.left && right == other.right;
617}
618
619struct m_element_hash
620{
621    inline size_t operator()(const m_element & m) const
622    {
623        size_t seed = 0;
624        hash_combine(seed, (size_t) m.cluster);
625        hash_combine(seed, (size_t) m.left);
626        hash_combine(seed, (size_t) m.right);
627        return seed;
628    }
629
630    // more or less taken from boost::hash_combine
631    inline void hash_combine(size_t &seed, size_t val) const
632    {
633        seed ^= val + 0x9e3779b9 + (seed << 6) + (seed >> 2);
634    }
635};
636
637ordering_element::ordering_element():
638    left(NULL), u(-1), m(-1),
639    right(NULL), w(-1), k(-1)
640{}
641
642ordering_element::ordering_element(THierarchicalCluster * _left,
643        unsigned int _u,
644        unsigned _m,
645        THierarchicalCluster * _right,
646        unsigned int _w,
647        unsigned int _k
648    ): left(_left), u(_u), m(_m),
649       right(_right), w(_w), k(_k)
650{}
651
652ordering_element::ordering_element(const ordering_element & other):
653       left(other.left), u(other.u), m(other.m),
654       right(other.right), w(other.w), k(other.k)
655{}
656
657
658#if USE_TR1
659    typedef std::tr1::unordered_map<m_element, double, m_element_hash> join_scores;
660    typedef std::tr1::unordered_map<m_element, ordering_element, m_element_hash> cluster_ordering;
661#else
662    typedef std::map<m_element, double> join_scores;
663    typedef std::map<m_element, ordering_element> cluster_ordering;
664#endif
665
666// Return the minimum distance between elements in matrix
667//
668float min_distance(
669        TIntList::iterator indices1_begin,
670        TIntList::iterator indices1_end,
671        TIntList::iterator indices2_begin,
672        TIntList::iterator indices2_end,
673        TSymMatrix & matrix)
674{
675//  TIntList::iterator iter1 = indices1.begin(), iter2 = indices2.begin();
676    float minimum = std::numeric_limits<float>::infinity();
677    TIntList::iterator indices2;
678    for (; indices1_begin != indices1_end; indices1_begin++)
679        for (indices2 = indices2_begin; indices2 != indices2_end; indices2++){
680            minimum = std::min(matrix.getitem(*indices1_begin, *indices2), minimum);
681        }
682    return minimum;
683}
684
685struct CompareByScores
686{
687    join_scores & scores;
688    const THierarchicalCluster & cluster;
689    const int & fixed;
690
691    CompareByScores(join_scores & _scores, const THierarchicalCluster & _cluster, const int & _fixed):
692        scores(_scores), cluster(_cluster), fixed(_fixed)
693    {}
694    bool operator() (int lhs, int rhs)
695    {
696        m_element left((THierarchicalCluster*)&cluster, fixed, lhs);
697        m_element right((THierarchicalCluster*)&cluster, fixed, rhs);
698        return scores[left] < scores[right];
699    }
700};
701
702
703//#include <iostream>
704//#include <cassert>
705
706// This needs to be called with all left, right pairs to
707// update all scores for cluster.
708void partial_opt_ordering(
709        THierarchicalCluster & cluster,
710        THierarchicalCluster & left,
711        THierarchicalCluster & right,
712        THierarchicalCluster & left_left,
713        THierarchicalCluster & left_right,
714        THierarchicalCluster & right_left,
715        THierarchicalCluster & right_right,
716        TSymMatrix &matrix,
717        join_scores & M,
718        cluster_ordering & ordering)
719{
720    int u = 0, w = 0;
721    TIntList & mapping = cluster.mapping.getReference();
722    for (TIntList::iterator u_iter = mapping.begin() + left_left.first;
723            u_iter != mapping.begin() + left_left.last;
724            u_iter++)
725        for (TIntList::iterator w_iter = mapping.begin() + right_right.first;
726                w_iter != mapping.begin() + right_right.last;
727                w_iter++)
728        {
729            u = *u_iter;
730            w = *w_iter;
731            float curr_min = std::numeric_limits<float>::infinity();
732            int curr_k = 0, curr_m = 0;
733            float C = min_distance(mapping.begin() + left_right.first,
734                    mapping.begin() + left_right.last,
735                    mapping.begin() + right_left.first,
736                    mapping.begin() + right_left.last,
737                    matrix);
738
739            vector<int> m_ordered(mapping.begin() + left_right.first,
740                    mapping.begin() + left_right.last);
741            vector<int> k_ordered(mapping.begin() + right_left.first,
742                    mapping.begin() + right_left.last);
743
744            // TODO: precompute the scores for m and k in an array and use a simpler
745            // comparison function
746            std::sort(m_ordered.begin(), m_ordered.end(), CompareByScores(M, left, u));
747            std::sort(k_ordered.begin(), k_ordered.end(), CompareByScores(M, right, w));
748
749
750            int k0 = k_ordered.front();
751            m_element m_right_k0(&right, w, k0);
752            int m = 0, k = 0;
753            for (vector<int>::iterator iter_m=m_ordered.begin(); iter_m != m_ordered.end(); iter_m++)
754            {
755                m = *iter_m;
756
757                m_element m_left(&left, u, m);
758
759                if (M[m_left] + M[m_right_k0] + C >= curr_min){
760                    break;
761                }
762                for (vector<int>::iterator iter_k = k_ordered.begin(); iter_k != k_ordered.end(); iter_k++)
763                {
764                    k = *iter_k;
765                    m_element m_right(&right, w, k);
766                    if (M[m_left] + M[m_right] + C >= curr_min)
767                    {
768                        break;
769                    }
770                    float test_val = M[m_left] + M[m_right] + matrix.getitem(m, k);
771                    if (curr_min > test_val)
772                    {
773                        curr_min = test_val;
774                        curr_k = k;
775                        curr_m = m;
776                    }
777                }
778
779            }
780
781            M[m_element(&cluster, u, w)] = curr_min;
782            M[m_element(&cluster, w, u)] = curr_min;
783
784//          assert(M[m_element(&cluster, u, w)] == M[m_element(&cluster, u, w)]);
785//          assert(M[m_element(&cluster, u, w)] == curr_min);
786
787//          assert(ordering.find(m_element(&cluster, u, w)) == ordering.end());
788//          assert(ordering.find(m_element(&cluster, w, u)) == ordering.end());
789
790            ordering[m_element(&cluster, u, w)] = ordering_element(&left, u, curr_m, &right, w, curr_k);
791            ordering[m_element(&cluster, w, u)] = ordering_element(&right, w, curr_k, &left, u, curr_m);
792        }
793}
794
795void order_clusters(
796        THierarchicalCluster & cluster,
797        TSymMatrix &matrix,
798        join_scores & M,
799        cluster_ordering & ordering,
800        TProgressCallback * callback)
801{
802    if (cluster.size() == 1)
803    {
804        M[m_element(&cluster, cluster.mapping->at(cluster.first), cluster.mapping->at(cluster.first))] = 0.0;
805        return;
806    }
807    else if (cluster.branches->size() == 2)
808    {
809        PHierarchicalCluster left = cluster.branches->at(0);
810        PHierarchicalCluster right = cluster.branches->at(1);
811
812        order_clusters(left.getReference(), matrix, M, ordering, callback);
813        order_clusters(right.getReference(), matrix, M, ordering, callback);
814
815        PHierarchicalCluster  left_left = (!left->branches) ? left : left->branches->at(0);
816        PHierarchicalCluster  left_right = (!left->branches) ? left : left->branches->at(1);
817        PHierarchicalCluster  right_left = (!right->branches) ? right : right->branches->at(0);
818        PHierarchicalCluster  right_right = (!right->branches) ? right : right->branches->at(1);
819
820        // 1.)
821        partial_opt_ordering(cluster,
822                left.getReference(), right.getReference(),
823                left_left.getReference(), left_right.getReference(),
824                right_left.getReference(), right_right.getReference(),
825                matrix, M, ordering);
826
827        if (right->branches)
828            // 2.) Switch right branches.
829            // (if there are no right branches the ordering has already been evaluated in 1.)
830            partial_opt_ordering(cluster,
831                    left.getReference(), right.getReference(),
832                    left_left.getReference(), left_right.getReference(),
833                    right_right.getReference(), right_left.getReference(),
834                    matrix, M, ordering);
835
836        if (left->branches)
837            // 3.) Switch left branches.
838            // (if there are no left branches the ordering has already been evaluated in 1. and 2.)
839            partial_opt_ordering(cluster,
840                    left.getReference(), right.getReference(),
841                    left_right.getReference(), left_left.getReference(),
842                    right_left.getReference(), right_right.getReference(),
843                    matrix, M, ordering);
844
845        if (left->branches && right->branches)
846            // 4.) Switch both branches.
847            partial_opt_ordering(cluster,
848                    left.getReference(), right.getReference(),
849                    left_right.getReference(), left_left.getReference(),
850                    right_right.getReference(), right_left.getReference(),
851                    matrix, M, ordering);
852    }
853    if (callback)
854        // TODO: count the number of already processed nodes.
855        callback->operator()(0.0, PHierarchicalCluster(&cluster));
856}
857
858/* Check if TIntList contains element.
859 */
860bool contains(TIntList::iterator iter_begin, TIntList::iterator iter_end, int element)
861{
862    return std::find(iter_begin, iter_end, element) != iter_end;
863}
864
865void optimal_swap(THierarchicalCluster * cluster, int u, int w, cluster_ordering & ordering)
866{
867    if (cluster->branches)
868    {
869        assert(ordering.find(m_element(cluster, u, w)) != ordering.end());
870        ordering_element ord = ordering[m_element(cluster, u, w)];
871
872        PHierarchicalCluster left_right = (ord.left->branches)? ord.left->branches->at(1) : PHierarchicalCluster(NULL);
873        PHierarchicalCluster right_left = (ord.right->branches)? ord.right->branches->at(0) : PHierarchicalCluster(NULL);
874
875        TIntList & mapping = cluster->mapping.getReference();
876        if (left_right && !contains(mapping.begin() + left_right->first,
877                mapping.begin() + left_right->last, ord.m))
878        {
879            assert(!contains(mapping.begin() + left_right->first,
880                             mapping.begin() + left_right->last, ord.m));
881            ord.left->swap();
882            left_right = ord.left->branches->at(1);
883            assert(contains(mapping.begin() + left_right->first,
884                            mapping.begin() + left_right->last, ord.m));
885        }
886        optimal_swap(ord.left, ord.u, ord.m, ordering);
887
888        assert(mapping.at(ord.left->first) == ord.u);
889        assert(mapping.at(ord.left->last - 1) == ord.m);
890
891        if (right_left && !contains(mapping.begin() + right_left->first,
892                mapping.begin() + right_left->last, ord.k))
893        {
894            assert(!contains(mapping.begin() + right_left->first,
895                             mapping.begin() + right_left->last, ord.k));
896            ord.right->swap();
897            right_left = ord.right->branches->at(0);
898            assert(contains(mapping.begin() + right_left->first,
899                            mapping.begin() + right_left->last, ord.k));
900        }
901        optimal_swap(ord.right, ord.k, ord.w, ordering);
902
903        assert(mapping.at(ord.right->first) == ord.k);
904        assert(mapping.at(ord.right->last - 1) == ord.w);
905
906        assert(mapping.at(cluster->first) == ord.u);
907        assert(mapping.at(cluster->last - 1) == ord.w);
908    }
909}
910
911PHierarchicalCluster THierarchicalClusterOrdering::operator() (
912        PHierarchicalCluster root,
913        PSymMatrix matrix)
914{
915    join_scores M; // scores
916    cluster_ordering ordering;
917    order_clusters(root.getReference(), matrix.getReference(), M, ordering,
918            progress_callback.getUnwrappedPtr());
919
920    int u = 0, w = 0;
921    int min_u = 0, min_w = 0;
922    float min_score = std::numeric_limits<float>::infinity();
923    TIntList & mapping = root->mapping.getReference();
924
925    for (TIntList::iterator u_iter = mapping.begin() + root->branches->at(0)->first;
926            u_iter != mapping.begin() + root->branches->at(0)->last;
927            u_iter++)
928        for (TIntList::iterator w_iter = mapping.begin() + root->branches->at(1)->first;
929                    w_iter != mapping.begin() + root->branches->at(1)->last;
930                    w_iter++)
931        {
932            u = *u_iter; w = *w_iter;
933            m_element el(root.getUnwrappedPtr(), u, w);
934            if (M[el] < min_score)
935            {
936                min_score = M[el];
937                min_u = u;
938                min_w = w;
939            }
940        }
941//  std::cout << "Min score "<< min_score << endl;
942
943    optimal_swap(root.getUnwrappedPtr(), min_u, min_w, ordering);
944    return root;
945}
Note: See TracBrowser for help on using the repository browser.