source: orange/source/orange/hclust.cpp @ 11752:f1633f4dc201

Revision 11752:f1633f4dc201, 29.4 KB checked in by Ales Erjavec <ales.erjavec@…>, 5 months ago (diff)

Fix for MSVC compiler.

Line 
1
2#include <ciso646>
3#if defined(_LIBCPP_VERSION) || __cplusplus >= 201103L
4    // C++11 support or using libc++
5    #include <unordered_map>
6    #define _HCLUST_UNORDERED_MAP std::unordered_map
7#elif _MSC_VER
8    #include <unordered_map>
9    #define _HCLUST_UNORDERED_MAP std::tr1::unordered_map
10#else
11    // include and use from tr1 namespace
12    #include <tr1/unordered_map>
13    #define _HCLUST_UNORDERED_MAP std::tr1::unordered_map
14#endif
15
16
17#include "progress.hpp"
18
19#include "hclust.hpp"
20
21DEFINE_TOrangeVector_classDescription(PHierarchicalCluster, "THierarchicalClusterList", true, ORANGE_API)
22
23#include "hclust.ppp"
24
25
26class TClusterW {
27public:
28    TClusterW *next; // for cluster, this is next cluster, for elements next element
29    TClusterW *left, *right; // subclusters, if left==NULL, this is an element
30    int size;
31    int elementIndex;
32    float height;
33
34    float *distances; // distances to clusters before this one (lower left matrix)
35    float minDistance; // minimal distance
36    int rawIndexMinDistance; // index of minimal distance
37    int nDistances;
38
39    TClusterW(const int &elIndex, float *adistances, const int &anDistances)
40    : next(NULL),
41      left(NULL),
42      right(NULL),
43      size(1),
44      elementIndex(elIndex),
45      height(0.0),
46      distances(adistances),
47      minDistance(numeric_limits<float>::max()),
48      rawIndexMinDistance(-1),
49      nDistances(anDistances)
50    {
51      if (distances)
52        computeMinimalDistance();
53    }
54
55    void elevate(TClusterW **aright, const float &aheight)
56    {
57      left = mlnew TClusterW(*this);
58      right = *aright;
59
60      left->distances = NULL;
61      size = left->size + right->size;
62      elementIndex = -1;
63      height = aheight;
64
65      if (next == right)
66        next = right->next;
67      else
68        *aright = right->next;
69    }
70
71
72    void computeMinimalDistance()
73    {
74      float *dp = distances, *minp = dp++;
75      for(int i = nDistances; --i; dp++)
76        if ((*dp >= 0) && (*dp < *minp))
77          minp = dp;
78      minDistance = *minp;
79      rawIndexMinDistance = minp - distances;
80    }
81
82    TClusterW **clusterAt(int rawIndex, TClusterW **cluster)
83    {
84      for(float *dp = distances; rawIndex; dp++)
85        if (*dp>=0) {
86          cluster = &(*cluster)->next;
87          rawIndex--;
88        }
89      return cluster;
90    }
91
92    void clearDistances()
93    {
94      raiseErrorWho("TClusterW", "rewrite clean to adjust indexMinDistance");
95
96      float *dst = distances;
97      int i = nDistances;
98      for(; nDistances && (*dst >= 0); nDistances--, dst++);
99      if (!nDistances)
100        return;
101     
102      float *src = dst;
103      for(src++, nDistances--; nDistances && (*src >= 0); src++)
104        if (*src >= 0)
105          *dst++ = *src;
106
107      nDistances = dst - distances;
108    }
109};
110
111
112THierarchicalCluster::THierarchicalCluster()
113: branches(),
114  height(0.0),
115  mapping(),
116  first(0),
117  last(0)
118{}
119
120
121THierarchicalCluster::THierarchicalCluster(PIntList els, const int &elementIndex)
122: branches(),
123  height(0.0),
124  mapping(els),
125  first(elementIndex),
126  last(elementIndex+1)
127{}
128
129
130THierarchicalCluster::THierarchicalCluster(PIntList els, PHierarchicalCluster left, PHierarchicalCluster right, const float &h, const int &f, const int &l)
131: branches(new THierarchicalClusterList(2)),
132  height(h),
133  mapping(els),
134  first(f),
135  last(l)
136{ 
137  branches->at(0) = left;
138  branches->at(1) = right;
139}
140
141
142void THierarchicalCluster::swap()
143{
144  if (!branches || (branches->size()<2))
145    return;
146  if (branches->size() > 2)
147    raiseError("cannot swap multiple branches (use method 'permutation' instead)");
148
149  const TIntList::iterator beg0 = mapping->begin() + branches->at(0)->first;
150  const TIntList::iterator beg1 = mapping->begin() + branches->at(1)->first;
151  const TIntList::iterator end1 = mapping->begin() + branches->at(1)->last;
152
153  if ((branches->at(0)->first > branches->at(1)->first) || (branches->at(1)->first > branches->at(1)->last))
154    raiseError("internal inconsistency in clustering structure: invalid ordering of left's and right's 'first' and 'last'");
155
156  TIntList::iterator li0, li1;
157
158  int *temp = new int [beg1 - beg0], *t;
159  for(li0 = beg0, t = temp; li0 != beg1; *t++ = *li0++);
160  for(li0 = beg0, li1 = beg1; li1 != end1; *li0++ = *li1++);
161  for(t = temp; li0 != end1; *li0++ = *t++);
162  delete temp;
163
164  branches->at(0)->recursiveMove(end1 - beg1);
165  branches->at(1)->recursiveMove(beg0 - beg1);
166
167  PHierarchicalCluster tbr = branches->at(0);
168  branches->at(0) = branches->at(1);
169  branches->at(1) = tbr;
170}
171
172
173void THierarchicalCluster::permute(const TIntList &neworder)
174{
175  if ((!branches && neworder.size()) || (branches->size() != neworder.size()))
176    raiseError("the number of clusters does not match the lenght of the permutation vector");
177
178  int *temp = new int [last - first], *t = temp;
179  TIntList::const_iterator pi = neworder.begin();
180  THierarchicalClusterList::iterator bi(branches->begin()), be(branches->end());
181  THierarchicalClusterList newBranches;
182
183  for(; bi != be; bi++, pi++) {
184    PHierarchicalCluster branch = branches->at(*pi);
185    newBranches.push_back(branch);
186    TIntList::const_iterator bei(mapping->begin() + branch->first), bee(mapping->begin() + branch->last);
187    const int offset = (t - temp) - (branch->first - first);
188    for(; bei != bee; *t++ = *bei++);
189    if (offset)
190      branch->recursiveMove(offset);
191  }
192
193  TIntList::iterator bei(mapping->begin() + first), bee(mapping->begin() + last);
194  for(t = temp; bei!=bee; *bei++ = *t++);
195
196  bi = branches->begin();
197  THierarchicalClusterList::const_iterator nbi(newBranches.begin());
198  for(; bi != be; *bi++ = *nbi++);
199}
200
201
202void THierarchicalCluster::recursiveMove(const int &offset)
203{
204  first += offset;
205  last += offset;
206  if (branches)
207    PITERATE(THierarchicalClusterList, bi, branches)
208      (*bi)->recursiveMove(offset);
209}
210
211
212THierarchicalClustering::THierarchicalClustering()
213: linkage(Single),
214  overwriteMatrix(false)
215{}
216
217
218TClusterW **THierarchicalClustering::init(const int &dim, float *distanceMatrix)
219{
220  for(float *ddi = distanceMatrix, *dde = ddi + ((dim+1)*(dim+2))/2; ddi!=dde; ddi++)
221    if (*ddi < 0) {
222      int x, y;
223      TSymMatrix::index2coordinates(ddi-distanceMatrix, x, y);
224      raiseError("distance matrix contains negative element at (%i, %i)", x, y);
225    }
226
227  TClusterW **clusters = mlnew TClusterW *[dim];
228  TClusterW **clusteri = clusters;
229
230  *clusters = mlnew TClusterW(0, NULL, 0);
231  distanceMatrix++;
232 
233  for(int elementIndex = 1, e = dim; elementIndex < e; distanceMatrix += ++elementIndex) {
234    TClusterW *newcluster = mlnew TClusterW(elementIndex, distanceMatrix, elementIndex);
235    (*clusteri++)->next = newcluster;
236    *clusteri = newcluster; 
237  }
238
239  return clusters;
240}
241
242
243
244TClusterW *THierarchicalClustering::merge_SingleLinkage(TClusterW **clusters, float *milestones)
245{
246  float *milestone = milestones;
247
248  int step = 0;
249  while((*clusters)->next) {
250    if (milestone && (step++ ==*milestone))
251      progressCallback->call(*((++milestone)++));
252
253    TClusterW *cluster;
254    TClusterW **pcluster2;
255
256    float minDistance = numeric_limits<float>::max();
257    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
258      if ((*tcluster)->minDistance < minDistance) {
259        minDistance = (*tcluster)->minDistance;
260        pcluster2 = tcluster;
261      }
262
263    TClusterW *const cluster2 = *pcluster2;
264
265    const int rawIndex1 = cluster2->rawIndexMinDistance;
266    const int rawIndex2 = cluster2->nDistances;
267
268    TClusterW *const cluster1 = clusters[rawIndex1];
269
270    float *disti1 = cluster1->distances;
271    float *disti2 = cluster2->distances;
272
273    if (rawIndex1) { // not root - has no distances...
274      const float *minIndex1 = cluster1->distances + cluster1->rawIndexMinDistance;
275      for(int ndi = cluster1->nDistances; ndi--; disti1++, disti2++)
276        if (*disti2 < *disti1) // if one is -1, they both are
277          if ((*disti1 = *disti2) < *minIndex1)
278            minIndex1 = disti1;
279      cluster1->minDistance = *minIndex1;
280      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
281    }
282
283    while(*disti2 < 0)
284      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
285
286    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
287      while(*++disti2 < 0); // should have more - the one corresponding to cluster
288      if (*disti2 < cluster->distances[rawIndex1])
289        if ((cluster->distances[rawIndex1] = *disti2) < cluster->minDistance) {
290          cluster->minDistance = *disti2;
291          cluster->rawIndexMinDistance = rawIndex1;
292        }
293    }
294
295    for(cluster = cluster->next; cluster; cluster = cluster->next) {
296      if (cluster->distances[rawIndex2] < cluster->distances[rawIndex1])
297        cluster->distances[rawIndex1] = cluster->distances[rawIndex2];
298
299      // don't nest this in the above if -- both distances can be equal, yet index HAS TO move
300      if (rawIndex2 == cluster->rawIndexMinDistance)
301        cluster->rawIndexMinDistance = rawIndex1; // the smallest element got moved
302
303      cluster->distances[rawIndex2] = -1;
304    }
305
306    cluster1->elevate(pcluster2, minDistance);
307  }
308
309  return *clusters;
310}
311
312// Also computes Ward's linkage
313TClusterW *THierarchicalClustering::merge_AverageLinkage(TClusterW **clusters, float *milestones)
314{
315  float *milestone = milestones;
316  bool ward = linkage == Ward;
317
318  int step = 0;
319  while((*clusters)->next) {
320    if (milestone && (step++ ==*milestone))
321      progressCallback->call(*((++milestone)++));
322
323    TClusterW *cluster;
324    TClusterW **pcluster2;
325
326    float minDistance = numeric_limits<float>::max();
327    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
328      if ((*tcluster)->minDistance < minDistance) {
329        minDistance = (*tcluster)->minDistance;
330        pcluster2 = tcluster;
331      }
332
333    TClusterW *const cluster2 = *pcluster2;
334
335    const int rawIndex1 = cluster2->rawIndexMinDistance;
336    const int rawIndex2 = cluster2->nDistances;
337
338    TClusterW *const cluster1 = clusters[rawIndex1];
339
340    float *disti1 = cluster1->distances;
341    float *disti2 = cluster2->distances;
342
343    const int size1 = cluster1->size;
344    const int size2 = cluster2->size;
345    const int sumsize = size1 + size2;
346    cluster = (*clusters)->next;
347
348    if (rawIndex1) { // not root - has no distances...
349      const float sizeK = cluster->size;
350      *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
351                     : (*disti1 * size1 + *disti2 * size2) / sumsize;
352      const float *minIndex1 = disti1;
353      int ndi = cluster1->nDistances-1;
354      for(disti1++, disti2++, cluster = cluster->next; ndi--; disti1++, disti2++)
355        if (*disti1 >= 0) {
356          const float sizeK = cluster->size;
357          cluster = cluster->next;
358          *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
359                         : (*disti1 * size1         + *disti2 * size2) / sumsize;
360          if (*disti1 < *minIndex1)
361            minIndex1 = disti1;
362        }
363      cluster1->minDistance = *minIndex1;
364      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
365    }
366
367    while(*disti2 < 0)
368      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
369
370    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
371      while(*++disti2 < 0); // should have more - the one corresponding to cluster
372      float &distc = cluster->distances[rawIndex1];
373      const float sizeK = cluster->size;
374      distc = ward ? (distc * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
375                   : (distc * size1         + *disti2 * size2) / sumsize;
376      if (distc < cluster->minDistance) {
377        cluster->minDistance = distc;
378        cluster->rawIndexMinDistance = rawIndex1;
379      }
380      else if ((distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1))
381        cluster->computeMinimalDistance();
382    }
383
384    for(cluster = cluster->next; cluster; cluster = cluster->next) {
385      float &distc = cluster->distances[rawIndex1];
386      const float sizeK = cluster->size;
387      distc = ward ? (distc * (size1+sizeK) + cluster->distances[rawIndex2] * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
388                   : (distc * size1         + cluster->distances[rawIndex2] * size2) / sumsize;
389      cluster->distances[rawIndex2] = -1;
390      if (distc < cluster->minDistance) {
391        cluster->minDistance = distc;
392        cluster->rawIndexMinDistance = rawIndex1;
393      }
394      else if (   (distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1)
395               || (cluster->rawIndexMinDistance == rawIndex2))
396        cluster->computeMinimalDistance();
397    }
398
399    cluster1->elevate(pcluster2, minDistance);
400  }
401
402  return *clusters;
403}
404
405
406
407TClusterW *THierarchicalClustering::merge_CompleteLinkage(TClusterW **clusters, float *milestones)
408{
409  float *milestone = milestones;
410
411  int step = 0;
412  while((*clusters)->next) {
413    if (milestone && (step++ ==*milestone))
414      progressCallback->call(*((++milestone)++));
415
416    TClusterW *cluster;
417    TClusterW **pcluster2;
418
419    float minDistance = numeric_limits<float>::max();
420    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
421      if ((*tcluster)->minDistance < minDistance) {
422        minDistance = (*tcluster)->minDistance;
423        pcluster2 = tcluster;
424      }
425
426    TClusterW *const cluster2 = *pcluster2;
427
428    const int rawIndex1 = cluster2->rawIndexMinDistance;
429    const int rawIndex2 = cluster2->nDistances;
430
431    TClusterW *const cluster1 = clusters[rawIndex1];
432
433    float *disti1 = cluster1->distances;
434    float *disti2 = cluster2->distances;
435
436    if (rawIndex1) { // not root - has no distances...
437      if (*disti2 > *disti1)
438        *disti1 = *disti2;
439      const float *minIndex1 = disti1;
440      int ndi = cluster1->nDistances-1;
441      for(disti1++, disti2++; ndi--; disti1++, disti2++) {
442        if (*disti1 >= 0) {
443          if (*disti2 > *disti1) // if one is -1, they both are
444            *disti1 = *disti2;
445          if (*disti1 < *minIndex1)
446            minIndex1 = disti1;
447        }
448      }
449      cluster1->minDistance = *minIndex1;
450      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
451    }
452
453    while(*disti2 < 0)
454      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
455
456    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
457      while(*++disti2 < 0); // should have more - the one corresponding to cluster
458      float &distc = cluster->distances[rawIndex1];
459      if (*disti2 > distc) {
460        distc = *disti2;
461        if (cluster->rawIndexMinDistance == rawIndex1)
462          if (distc <= cluster->minDistance)
463            cluster->minDistance = distc;
464          else
465            cluster->computeMinimalDistance();
466      }
467    }
468
469    for(cluster = cluster->next; cluster; cluster = cluster->next) {
470      float &distc = cluster->distances[rawIndex1];
471      if (cluster->distances[rawIndex2] > distc)
472        distc = cluster->distances[rawIndex2];
473
474      cluster->distances[rawIndex2] = -1;
475      if ((cluster->rawIndexMinDistance == rawIndex1) || (cluster->rawIndexMinDistance == rawIndex2))
476        cluster->computeMinimalDistance();
477    }
478
479    cluster1->elevate(pcluster2, minDistance);
480  }
481
482  return *clusters;
483}
484
485
486
487TClusterW *THierarchicalClustering::merge(TClusterW **clusters, float *milestones)
488{
489  switch(linkage) {
490    case Complete: return merge_CompleteLinkage(clusters, milestones);
491    case Single: return merge_SingleLinkage(clusters, milestones);
492    case Average: 
493    case Ward:
494    default: return merge_AverageLinkage(clusters, milestones);
495  }
496}
497
498PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *root)
499{
500  PIntList elementIndices = new TIntList(root->size);
501  TIntList::iterator currentElement(elementIndices->begin());
502  int currentIndex = 0;
503
504  return restructure(root, elementIndices, currentElement, currentIndex);
505}
506
507
508PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *node, PIntList elementIndices, TIntList::iterator &currentElement, int &currentIndex)
509{
510  PHierarchicalCluster cluster;
511
512  if (!node->left) {
513    *currentElement++ = node->elementIndex;
514    cluster = mlnew THierarchicalCluster(elementIndices, currentIndex++);
515  }
516  else {
517    PHierarchicalCluster left = restructure(node->left, elementIndices, currentElement, currentIndex);
518    PHierarchicalCluster right = restructure(node->right, elementIndices, currentElement, currentIndex);
519    cluster = mlnew THierarchicalCluster(elementIndices, left, right, node->height, left->first, right->last);
520  }
521
522  // no need to care about 'distances' - they've been removed during clustering (in 'elevate') :)
523  mldelete node;
524  return cluster;
525}
526
527
528PHierarchicalCluster THierarchicalClustering::operator()(PSymMatrix distanceMatrix)
529{
530  float *distanceMatrixElements = NULL;
531  TClusterW **clusters, *root;
532  float *callbackMilestones = NULL;
533 
534  try {
535    const int dim = distanceMatrix->dim;
536    const int size = ((dim+1)*(dim+2))/2;
537    float *distanceMatrixElements = overwriteMatrix ? distanceMatrix->elements : (float *)memcpy(new float[size], distanceMatrix->elements, size*sizeof(float));
538
539    clusters = init(dim, distanceMatrixElements);
540    callbackMilestones = (progressCallback && (distanceMatrix->dim >= 1000)) ? progressCallback->milestones(distanceMatrix->dim) : NULL;
541    root = merge(clusters, callbackMilestones);
542  }
543  catch (...) {
544    mldelete clusters;
545    mldelete callbackMilestones;
546    mldelete distanceMatrixElements;
547    throw;
548  }
549
550  mldelete clusters;
551  mldelete callbackMilestones;
552  mldelete distanceMatrixElements;
553
554  return restructure(root);
555}
556
557
558/*
559 *  Optimal leaf ordering.
560 */
561
562struct m_element {
563    THierarchicalCluster * cluster;
564    int left;
565    int right;
566
567    m_element(THierarchicalCluster * cluster, int left, int right);
568    m_element(const m_element & other);
569
570    inline bool operator< (const m_element & other) const;
571    inline bool operator== (const m_element & other) const;
572}; // cluster joined at left and right index
573
574struct ordering_element {
575    THierarchicalCluster * left;
576    unsigned int u; // the left most (outer) index of left luster
577    unsigned m; // the rightmost (inner) index of left cluster
578    THierarchicalCluster * right;
579    unsigned int w; // the right most (outer) index of the right cluster
580    unsigned int k; // the left most (inner) index of the right cluster
581
582    ordering_element();
583    ordering_element(THierarchicalCluster * left, unsigned int u, unsigned m,
584            THierarchicalCluster * right, unsigned int w, unsigned int k);
585    ordering_element(const ordering_element & other);
586};
587
588m_element::m_element(THierarchicalCluster * _cluster, int _left, int _right):
589    cluster(_cluster), left(_left), right(_right)
590{}
591
592m_element::m_element(const m_element & other):
593    cluster(other.cluster), left(other.left), right(other.right)
594{}
595
596bool m_element::operator< (const m_element & other) const
597{
598    if (cluster < other.cluster)
599        return true;
600    else
601        if (cluster == other.cluster)
602            if (left < other.left)
603                return true;
604            else
605                if (left == other.left)
606                    return right < other.right;
607                else
608                    return false;
609        else
610            return false;
611}
612
613bool m_element::operator== (const m_element & other) const
614{
615    return cluster == other.cluster && left == other.left && right == other.right;
616}
617
618struct m_element_hash
619{
620    inline size_t operator()(const m_element & m) const
621    {
622        size_t seed = 0;
623        hash_combine(seed, (size_t) m.cluster);
624        hash_combine(seed, (size_t) m.left);
625        hash_combine(seed, (size_t) m.right);
626        return seed;
627    }
628
629    // more or less taken from boost::hash_combine
630    inline void hash_combine(size_t &seed, size_t val) const
631    {
632        seed ^= val + 0x9e3779b9 + (seed << 6) + (seed >> 2);
633    }
634};
635
636ordering_element::ordering_element():
637    left(NULL), u(-1), m(-1),
638    right(NULL), w(-1), k(-1)
639{}
640
641ordering_element::ordering_element(THierarchicalCluster * _left,
642        unsigned int _u,
643        unsigned _m,
644        THierarchicalCluster * _right,
645        unsigned int _w,
646        unsigned int _k
647    ): left(_left), u(_u), m(_m),
648       right(_right), w(_w), k(_k)
649{}
650
651ordering_element::ordering_element(const ordering_element & other):
652       left(other.left), u(other.u), m(other.m),
653       right(other.right), w(other.w), k(other.k)
654{}
655
656
657
658typedef _HCLUST_UNORDERED_MAP<m_element, double, m_element_hash> join_scores;
659typedef _HCLUST_UNORDERED_MAP<m_element, ordering_element, m_element_hash> cluster_ordering;
660
661
662// Return the minimum distance between elements in matrix
663//
664float min_distance(
665        TIntList::iterator indices1_begin,
666        TIntList::iterator indices1_end,
667        TIntList::iterator indices2_begin,
668        TIntList::iterator indices2_end,
669        TSymMatrix & matrix)
670{
671//  TIntList::iterator iter1 = indices1.begin(), iter2 = indices2.begin();
672    float minimum = std::numeric_limits<float>::infinity();
673    TIntList::iterator indices2;
674    for (; indices1_begin != indices1_end; indices1_begin++)
675        for (indices2 = indices2_begin; indices2 != indices2_end; indices2++){
676            minimum = std::min(matrix.getitem(*indices1_begin, *indices2), minimum);
677        }
678    return minimum;
679}
680
681struct CompareByScores
682{
683    join_scores & scores;
684    const THierarchicalCluster & cluster;
685    const int & fixed;
686
687    CompareByScores(join_scores & _scores, const THierarchicalCluster & _cluster, const int & _fixed):
688        scores(_scores), cluster(_cluster), fixed(_fixed)
689    {}
690    bool operator() (int lhs, int rhs)
691    {
692        m_element left((THierarchicalCluster*)&cluster, fixed, lhs);
693        m_element right((THierarchicalCluster*)&cluster, fixed, rhs);
694        return scores[left] < scores[right];
695    }
696};
697
698
699//#include <iostream>
700//#include <cassert>
701
702// This needs to be called with all left, right pairs to
703// update all scores for cluster.
704void partial_opt_ordering(
705        THierarchicalCluster & cluster,
706        THierarchicalCluster & left,
707        THierarchicalCluster & right,
708        THierarchicalCluster & left_left,
709        THierarchicalCluster & left_right,
710        THierarchicalCluster & right_left,
711        THierarchicalCluster & right_right,
712        TSymMatrix &matrix,
713        join_scores & M,
714        cluster_ordering & ordering)
715{
716    int u = 0, w = 0;
717    TIntList & mapping = cluster.mapping.getReference();
718    for (TIntList::iterator u_iter = mapping.begin() + left_left.first;
719            u_iter != mapping.begin() + left_left.last;
720            u_iter++)
721        for (TIntList::iterator w_iter = mapping.begin() + right_right.first;
722                w_iter != mapping.begin() + right_right.last;
723                w_iter++)
724        {
725            u = *u_iter;
726            w = *w_iter;
727            float curr_min = std::numeric_limits<float>::infinity();
728            int curr_k = 0, curr_m = 0;
729            float C = min_distance(mapping.begin() + left_right.first,
730                    mapping.begin() + left_right.last,
731                    mapping.begin() + right_left.first,
732                    mapping.begin() + right_left.last,
733                    matrix);
734
735            vector<int> m_ordered(mapping.begin() + left_right.first,
736                    mapping.begin() + left_right.last);
737            vector<int> k_ordered(mapping.begin() + right_left.first,
738                    mapping.begin() + right_left.last);
739
740            // TODO: precompute the scores for m and k in an array and use a simpler
741            // comparison function
742            std::sort(m_ordered.begin(), m_ordered.end(), CompareByScores(M, left, u));
743            std::sort(k_ordered.begin(), k_ordered.end(), CompareByScores(M, right, w));
744
745
746            int k0 = k_ordered.front();
747            m_element m_right_k0(&right, w, k0);
748            int m = 0, k = 0;
749            for (vector<int>::iterator iter_m=m_ordered.begin(); iter_m != m_ordered.end(); iter_m++)
750            {
751                m = *iter_m;
752
753                m_element m_left(&left, u, m);
754
755                if (M[m_left] + M[m_right_k0] + C >= curr_min){
756                    break;
757                }
758                for (vector<int>::iterator iter_k = k_ordered.begin(); iter_k != k_ordered.end(); iter_k++)
759                {
760                    k = *iter_k;
761                    m_element m_right(&right, w, k);
762                    if (M[m_left] + M[m_right] + C >= curr_min)
763                    {
764                        break;
765                    }
766                    float test_val = M[m_left] + M[m_right] + matrix.getitem(m, k);
767                    if (curr_min > test_val)
768                    {
769                        curr_min = test_val;
770                        curr_k = k;
771                        curr_m = m;
772                    }
773                }
774
775            }
776
777            M[m_element(&cluster, u, w)] = curr_min;
778            M[m_element(&cluster, w, u)] = curr_min;
779
780//          assert(M[m_element(&cluster, u, w)] == M[m_element(&cluster, u, w)]);
781//          assert(M[m_element(&cluster, u, w)] == curr_min);
782
783//          assert(ordering.find(m_element(&cluster, u, w)) == ordering.end());
784//          assert(ordering.find(m_element(&cluster, w, u)) == ordering.end());
785
786            ordering[m_element(&cluster, u, w)] = ordering_element(&left, u, curr_m, &right, w, curr_k);
787            ordering[m_element(&cluster, w, u)] = ordering_element(&right, w, curr_k, &left, u, curr_m);
788        }
789}
790
791void order_clusters(
792        THierarchicalCluster & cluster,
793        TSymMatrix &matrix,
794        join_scores & M,
795        cluster_ordering & ordering,
796        TProgressCallback * callback)
797{
798    if (cluster.size() == 1)
799    {
800        M[m_element(&cluster, cluster.mapping->at(cluster.first), cluster.mapping->at(cluster.first))] = 0.0;
801        return;
802    }
803    else if (cluster.branches->size() == 2)
804    {
805        PHierarchicalCluster left = cluster.branches->at(0);
806        PHierarchicalCluster right = cluster.branches->at(1);
807
808        order_clusters(left.getReference(), matrix, M, ordering, callback);
809        order_clusters(right.getReference(), matrix, M, ordering, callback);
810
811        PHierarchicalCluster  left_left = (!left->branches) ? left : left->branches->at(0);
812        PHierarchicalCluster  left_right = (!left->branches) ? left : left->branches->at(1);
813        PHierarchicalCluster  right_left = (!right->branches) ? right : right->branches->at(0);
814        PHierarchicalCluster  right_right = (!right->branches) ? right : right->branches->at(1);
815
816        // 1.)
817        partial_opt_ordering(cluster,
818                left.getReference(), right.getReference(),
819                left_left.getReference(), left_right.getReference(),
820                right_left.getReference(), right_right.getReference(),
821                matrix, M, ordering);
822
823        if (right->branches)
824            // 2.) Switch right branches.
825            // (if there are no right branches the ordering has already been evaluated in 1.)
826            partial_opt_ordering(cluster,
827                    left.getReference(), right.getReference(),
828                    left_left.getReference(), left_right.getReference(),
829                    right_right.getReference(), right_left.getReference(),
830                    matrix, M, ordering);
831
832        if (left->branches)
833            // 3.) Switch left branches.
834            // (if there are no left branches the ordering has already been evaluated in 1. and 2.)
835            partial_opt_ordering(cluster,
836                    left.getReference(), right.getReference(),
837                    left_right.getReference(), left_left.getReference(),
838                    right_left.getReference(), right_right.getReference(),
839                    matrix, M, ordering);
840
841        if (left->branches && right->branches)
842            // 4.) Switch both branches.
843            partial_opt_ordering(cluster,
844                    left.getReference(), right.getReference(),
845                    left_right.getReference(), left_left.getReference(),
846                    right_right.getReference(), right_left.getReference(),
847                    matrix, M, ordering);
848    }
849    if (callback)
850        // TODO: count the number of already processed nodes.
851        callback->operator()(0.0, PHierarchicalCluster(&cluster));
852}
853
854/* Check if TIntList contains element.
855 */
856bool contains(TIntList::iterator iter_begin, TIntList::iterator iter_end, int element)
857{
858    return std::find(iter_begin, iter_end, element) != iter_end;
859}
860
861void optimal_swap(THierarchicalCluster * cluster, int u, int w, cluster_ordering & ordering)
862{
863    if (cluster->branches)
864    {
865        assert(ordering.find(m_element(cluster, u, w)) != ordering.end());
866        ordering_element ord = ordering[m_element(cluster, u, w)];
867
868        PHierarchicalCluster left_right = (ord.left->branches)? ord.left->branches->at(1) : PHierarchicalCluster(NULL);
869        PHierarchicalCluster right_left = (ord.right->branches)? ord.right->branches->at(0) : PHierarchicalCluster(NULL);
870
871        TIntList & mapping = cluster->mapping.getReference();
872        if (left_right && !contains(mapping.begin() + left_right->first,
873                mapping.begin() + left_right->last, ord.m))
874        {
875            assert(!contains(mapping.begin() + left_right->first,
876                             mapping.begin() + left_right->last, ord.m));
877            ord.left->swap();
878            left_right = ord.left->branches->at(1);
879            assert(contains(mapping.begin() + left_right->first,
880                            mapping.begin() + left_right->last, ord.m));
881        }
882        optimal_swap(ord.left, ord.u, ord.m, ordering);
883
884        assert(mapping.at(ord.left->first) == ord.u);
885        assert(mapping.at(ord.left->last - 1) == ord.m);
886
887        if (right_left && !contains(mapping.begin() + right_left->first,
888                mapping.begin() + right_left->last, ord.k))
889        {
890            assert(!contains(mapping.begin() + right_left->first,
891                             mapping.begin() + right_left->last, ord.k));
892            ord.right->swap();
893            right_left = ord.right->branches->at(0);
894            assert(contains(mapping.begin() + right_left->first,
895                            mapping.begin() + right_left->last, ord.k));
896        }
897        optimal_swap(ord.right, ord.k, ord.w, ordering);
898
899        assert(mapping.at(ord.right->first) == ord.k);
900        assert(mapping.at(ord.right->last - 1) == ord.w);
901
902        assert(mapping.at(cluster->first) == ord.u);
903        assert(mapping.at(cluster->last - 1) == ord.w);
904    }
905}
906
907PHierarchicalCluster THierarchicalClusterOrdering::operator() (
908        PHierarchicalCluster root,
909        PSymMatrix matrix)
910{
911    join_scores M; // scores
912    cluster_ordering ordering;
913    order_clusters(root.getReference(), matrix.getReference(), M, ordering,
914            progress_callback.getUnwrappedPtr());
915
916    int u = 0, w = 0;
917    int min_u = 0, min_w = 0;
918    float min_score = std::numeric_limits<float>::infinity();
919    TIntList & mapping = root->mapping.getReference();
920
921    for (TIntList::iterator u_iter = mapping.begin() + root->branches->at(0)->first;
922            u_iter != mapping.begin() + root->branches->at(0)->last;
923            u_iter++)
924        for (TIntList::iterator w_iter = mapping.begin() + root->branches->at(1)->first;
925                    w_iter != mapping.begin() + root->branches->at(1)->last;
926                    w_iter++)
927        {
928            u = *u_iter; w = *w_iter;
929            m_element el(root.getUnwrappedPtr(), u, w);
930            if (M[el] < min_score)
931            {
932                min_score = M[el];
933                min_u = u;
934                min_w = w;
935            }
936        }
937//  std::cout << "Min score "<< min_score << endl;
938
939    optimal_swap(root.getUnwrappedPtr(), min_u, min_w, ordering);
940    return root;
941}
Note: See TracBrowser for help on using the repository browser.