source: orange/source/orange/hclust.cpp @ 9201:296ebc9e84da

Revision 9201:296ebc9e84da, 30.1 KB checked in by ales_erjavec <ales.erjavec@…>, 2 years ago (diff)

Fixed m_element_hash function.

Line 
1/*
2    This file is part of Orange.
3   
4    Copyright 1996-2010 Faculty of Computer and Information Science, University of Ljubljana
5    Contact: janez.demsar@fri.uni-lj.si
6
7    Orange is free software: you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation, either version 3 of the License, or
10    (at your option) any later version.
11
12    Orange is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with Orange.  If not, see <http://www.gnu.org/licenses/>.
19*/
20#include "progress.hpp"
21
22#include "hclust.ppp"
23
24DEFINE_TOrangeVector_classDescription(PHierarchicalCluster, "THierarchicalClusterList", true, ORANGE_API)
25
26class TClusterW {
27public:
28    TClusterW *next; // for cluster, this is next cluster, for elements next element
29    TClusterW *left, *right; // subclusters, if left==NULL, this is an element
30    int size;
31    int elementIndex;
32    float height;
33
34    float *distances; // distances to clusters before this one (lower left matrix)
35    float minDistance; // minimal distance
36    int rawIndexMinDistance; // index of minimal distance
37    int nDistances;
38
39    TClusterW(const int &elIndex, float *adistances, const int &anDistances)
40    : next(NULL),
41      left(NULL),
42      right(NULL),
43      size(1),
44      elementIndex(elIndex),
45      height(0.0),
46      distances(adistances),
47      minDistance(numeric_limits<float>::max()),
48      rawIndexMinDistance(-1),
49      nDistances(anDistances)
50    {
51      if (distances)
52        computeMinimalDistance();
53    }
54
55    void elevate(TClusterW **aright, const float &aheight)
56    {
57      left = mlnew TClusterW(*this);
58      right = *aright;
59
60      left->distances = NULL;
61      size = left->size + right->size;
62      elementIndex = -1;
63      height = aheight;
64
65      if (next == right)
66        next = right->next;
67      else
68        *aright = right->next;
69    }
70
71
72    void computeMinimalDistance()
73    {
74      float *dp = distances, *minp = dp++;
75      for(int i = nDistances; --i; dp++)
76        if ((*dp >= 0) && (*dp < *minp))
77          minp = dp;
78      minDistance = *minp;
79      rawIndexMinDistance = minp - distances;
80    }
81
82    TClusterW **clusterAt(int rawIndex, TClusterW **cluster)
83    {
84      for(float *dp = distances; rawIndex; dp++)
85        if (*dp>=0) {
86          cluster = &(*cluster)->next;
87          rawIndex--;
88        }
89      return cluster;
90    }
91
92    void clearDistances()
93    {
94      raiseErrorWho("TClusterW", "rewrite clean to adjust indexMinDistance");
95
96      float *dst = distances;
97      int i = nDistances;
98      for(; nDistances && (*dst >= 0); nDistances--, dst++);
99      if (!nDistances)
100        return;
101     
102      float *src = dst;
103      for(src++, nDistances--; nDistances && (*src >= 0); src++)
104        if (*src >= 0)
105          *dst++ = *src;
106
107      nDistances = dst - distances;
108    }
109};
110
111
112THierarchicalCluster::THierarchicalCluster()
113: branches(),
114  height(0.0),
115  mapping(),
116  first(0),
117  last(0)
118{}
119
120
121THierarchicalCluster::THierarchicalCluster(PIntList els, const int &elementIndex)
122: branches(),
123  height(0.0),
124  mapping(els),
125  first(elementIndex),
126  last(elementIndex+1)
127{}
128
129
130THierarchicalCluster::THierarchicalCluster(PIntList els, PHierarchicalCluster left, PHierarchicalCluster right, const float &h, const int &f, const int &l)
131: branches(new THierarchicalClusterList(2)),
132  height(h),
133  mapping(els),
134  first(f),
135  last(l)
136{ 
137  branches->at(0) = left;
138  branches->at(1) = right;
139}
140
141
142void THierarchicalCluster::swap()
143{
144  if (!branches || (branches->size()<2))
145    return;
146  if (branches->size() > 2)
147    raiseError("cannot swap multiple branches (use method 'permutation' instead)");
148
149  const TIntList::iterator beg0 = mapping->begin() + branches->at(0)->first;
150  const TIntList::iterator beg1 = mapping->begin() + branches->at(1)->first;
151  const TIntList::iterator end1 = mapping->begin() + branches->at(1)->last;
152
153  if ((branches->at(0)->first > branches->at(1)->first) || (branches->at(1)->first > branches->at(1)->last))
154    raiseError("internal inconsistency in clustering structure: invalid ordering of left's and right's 'first' and 'last'");
155
156  TIntList::iterator li0, li1;
157
158  int *temp = new int [beg1 - beg0], *t;
159  for(li0 = beg0, t = temp; li0 != beg1; *t++ = *li0++);
160  for(li0 = beg0, li1 = beg1; li1 != end1; *li0++ = *li1++);
161  for(t = temp; li0 != end1; *li0++ = *t++);
162  delete temp;
163
164  branches->at(0)->recursiveMove(end1 - beg1);
165  branches->at(1)->recursiveMove(beg0 - beg1);
166
167  PHierarchicalCluster tbr = branches->at(0);
168  branches->at(0) = branches->at(1);
169  branches->at(1) = tbr;
170}
171
172
173void THierarchicalCluster::permute(const TIntList &neworder)
174{
175  if ((!branches && neworder.size()) || (branches->size() != neworder.size()))
176    raiseError("the number of clusters does not match the lenght of the permutation vector");
177
178  int *temp = new int [last - first], *t = temp;
179  TIntList::const_iterator pi = neworder.begin();
180  THierarchicalClusterList::iterator bi(branches->begin()), be(branches->end());
181  THierarchicalClusterList newBranches;
182
183  for(; bi != be; bi++, pi++) {
184    PHierarchicalCluster branch = branches->at(*pi);
185    newBranches.push_back(branch);
186    TIntList::const_iterator bei(mapping->begin() + branch->first), bee(mapping->begin() + branch->last);
187    const int offset = (t - temp) - (branch->first - first);
188    for(; bei != bee; *t++ = *bei++);
189    if (offset)
190      branch->recursiveMove(offset);
191  }
192
193  TIntList::iterator bei(mapping->begin() + first), bee(mapping->begin() + last);
194  for(t = temp; bei!=bee; *bei++ = *t++);
195
196  bi = branches->begin();
197  THierarchicalClusterList::const_iterator nbi(newBranches.begin());
198  for(; bi != be; *bi++ = *nbi++);
199}
200
201
202void THierarchicalCluster::recursiveMove(const int &offset)
203{
204  first += offset;
205  last += offset;
206  if (branches)
207    PITERATE(THierarchicalClusterList, bi, branches)
208      (*bi)->recursiveMove(offset);
209}
210
211
212THierarchicalClustering::THierarchicalClustering()
213: linkage(Single),
214  overwriteMatrix(false)
215{}
216
217
218TClusterW **THierarchicalClustering::init(const int &dim, float *distanceMatrix)
219{
220  for(float *ddi = distanceMatrix, *dde = ddi + ((dim+1)*(dim+2))/2; ddi!=dde; ddi++)
221    if (*ddi < 0) {
222      int x, y;
223      TSymMatrix::index2coordinates(ddi-distanceMatrix, x, y);
224      raiseError("distance matrix contains negative element at (%i, %i)", x, y);
225    }
226
227  TClusterW **clusters = mlnew TClusterW *[dim];
228  TClusterW **clusteri = clusters;
229
230  *clusters = mlnew TClusterW(0, NULL, 0);
231  distanceMatrix++;
232 
233  for(int elementIndex = 1, e = dim; elementIndex < e; distanceMatrix += ++elementIndex) {
234    TClusterW *newcluster = mlnew TClusterW(elementIndex, distanceMatrix, elementIndex);
235    (*clusteri++)->next = newcluster;
236    *clusteri = newcluster; 
237  }
238
239  return clusters;
240}
241
242
243
244TClusterW *THierarchicalClustering::merge_SingleLinkage(TClusterW **clusters, float *milestones)
245{
246  float *milestone = milestones;
247
248  int step = 0;
249  while((*clusters)->next) {
250    if (milestone && (step++ ==*milestone))
251      progressCallback->call(*((++milestone)++));
252
253    TClusterW *cluster;
254    TClusterW **pcluster2;
255
256    float minDistance = numeric_limits<float>::max();
257    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
258      if ((*tcluster)->minDistance < minDistance) {
259        minDistance = (*tcluster)->minDistance;
260        pcluster2 = tcluster;
261      }
262
263    TClusterW *const cluster2 = *pcluster2;
264
265    const int rawIndex1 = cluster2->rawIndexMinDistance;
266    const int rawIndex2 = cluster2->nDistances;
267
268    TClusterW *const cluster1 = clusters[rawIndex1];
269
270    float *disti1 = cluster1->distances;
271    float *disti2 = cluster2->distances;
272
273    if (rawIndex1) { // not root - has no distances...
274      const float *minIndex1 = cluster1->distances + cluster1->rawIndexMinDistance;
275      for(int ndi = cluster1->nDistances; ndi--; disti1++, disti2++)
276        if (*disti2 < *disti1) // if one is -1, they both are
277          if ((*disti1 = *disti2) < *minIndex1)
278            minIndex1 = disti1;
279      cluster1->minDistance = *minIndex1;
280      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
281    }
282
283    while(*disti2 < 0)
284      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
285
286    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
287      while(*++disti2 < 0); // should have more - the one corresponding to cluster
288      if (*disti2 < cluster->distances[rawIndex1])
289        if ((cluster->distances[rawIndex1] = *disti2) < cluster->minDistance) {
290          cluster->minDistance = *disti2;
291          cluster->rawIndexMinDistance = rawIndex1;
292        }
293    }
294
295    for(cluster = cluster->next; cluster; cluster = cluster->next) {
296      if (cluster->distances[rawIndex2] < cluster->distances[rawIndex1])
297        cluster->distances[rawIndex1] = cluster->distances[rawIndex2];
298
299      // don't nest this in the above if -- both distances can be equal, yet index HAS TO move
300      if (rawIndex2 == cluster->rawIndexMinDistance)
301        cluster->rawIndexMinDistance = rawIndex1; // the smallest element got moved
302
303      cluster->distances[rawIndex2] = -1;
304    }
305
306    cluster1->elevate(pcluster2, minDistance);
307  }
308
309  return *clusters;
310}
311
312// Also computes Ward's linkage
313TClusterW *THierarchicalClustering::merge_AverageLinkage(TClusterW **clusters, float *milestones)
314{
315  float *milestone = milestones;
316  bool ward = linkage == Ward;
317
318  int step = 0;
319  while((*clusters)->next) {
320    if (milestone && (step++ ==*milestone))
321      progressCallback->call(*((++milestone)++));
322
323    TClusterW *cluster;
324    TClusterW **pcluster2;
325
326    float minDistance = numeric_limits<float>::max();
327    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
328      if ((*tcluster)->minDistance < minDistance) {
329        minDistance = (*tcluster)->minDistance;
330        pcluster2 = tcluster;
331      }
332
333    TClusterW *const cluster2 = *pcluster2;
334
335    const int rawIndex1 = cluster2->rawIndexMinDistance;
336    const int rawIndex2 = cluster2->nDistances;
337
338    TClusterW *const cluster1 = clusters[rawIndex1];
339
340    float *disti1 = cluster1->distances;
341    float *disti2 = cluster2->distances;
342
343    const int size1 = cluster1->size;
344    const int size2 = cluster2->size;
345    const int sumsize = size1 + size2;
346    cluster = (*clusters)->next;
347
348    if (rawIndex1) { // not root - has no distances...
349      const float sizeK = cluster->size;
350      *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
351                     : (*disti1 * size1 + *disti2 * size2) / sumsize;
352      const float *minIndex1 = disti1;
353      int ndi = cluster1->nDistances-1;
354      for(disti1++, disti2++, cluster = cluster->next; ndi--; disti1++, disti2++)
355        if (*disti1 >= 0) {
356          const float sizeK = cluster->size;
357          cluster = cluster->next;
358          *disti1 = ward ? (*disti1 * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
359                         : (*disti1 * size1         + *disti2 * size2) / sumsize;
360          if (*disti1 < *minIndex1)
361            minIndex1 = disti1;
362        }
363      cluster1->minDistance = *minIndex1;
364      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
365    }
366
367    while(*disti2 < 0)
368      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
369
370    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
371      while(*++disti2 < 0); // should have more - the one corresponding to cluster
372      float &distc = cluster->distances[rawIndex1];
373      const float sizeK = cluster->size;
374      distc = ward ? (distc * (size1+sizeK) + *disti2 * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
375                   : (distc * size1         + *disti2 * size2) / sumsize;
376      if (distc < cluster->minDistance) {
377        cluster->minDistance = distc;
378        cluster->rawIndexMinDistance = rawIndex1;
379      }
380      else if ((distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1))
381        cluster->computeMinimalDistance();
382    }
383
384    for(cluster = cluster->next; cluster; cluster = cluster->next) {
385      float &distc = cluster->distances[rawIndex1];
386      const float sizeK = cluster->size;
387      distc = ward ? (distc * (size1+sizeK) + cluster->distances[rawIndex2] * (size2+sizeK) - minDistance * sizeK) / (sumsize+sizeK)
388                   : (distc * size1         + cluster->distances[rawIndex2] * size2) / sumsize;
389      cluster->distances[rawIndex2] = -1;
390      if (distc < cluster->minDistance) {
391        cluster->minDistance = distc;
392        cluster->rawIndexMinDistance = rawIndex1;
393      }
394      else if (   (distc > cluster->minDistance) && (cluster->rawIndexMinDistance == rawIndex1)
395               || (cluster->rawIndexMinDistance == rawIndex2))
396        cluster->computeMinimalDistance();
397    }
398
399    cluster1->elevate(pcluster2, minDistance);
400  }
401
402  return *clusters;
403}
404
405
406
407TClusterW *THierarchicalClustering::merge_CompleteLinkage(TClusterW **clusters, float *milestones)
408{
409  float *milestone = milestones;
410
411  int step = 0;
412  while((*clusters)->next) {
413    if (milestone && (step++ ==*milestone))
414      progressCallback->call(*((++milestone)++));
415
416    TClusterW *cluster;
417    TClusterW **pcluster2;
418
419    float minDistance = numeric_limits<float>::max();
420    for(TClusterW **tcluster = &((*clusters)->next); *tcluster; tcluster = &(*tcluster)->next)
421      if ((*tcluster)->minDistance < minDistance) {
422        minDistance = (*tcluster)->minDistance;
423        pcluster2 = tcluster;
424      }
425
426    TClusterW *const cluster2 = *pcluster2;
427
428    const int rawIndex1 = cluster2->rawIndexMinDistance;
429    const int rawIndex2 = cluster2->nDistances;
430
431    TClusterW *const cluster1 = clusters[rawIndex1];
432
433    float *disti1 = cluster1->distances;
434    float *disti2 = cluster2->distances;
435
436    if (rawIndex1) { // not root - has no distances...
437      if (*disti2 > *disti1)
438        *disti1 = *disti2;
439      const float *minIndex1 = disti1;
440      int ndi = cluster1->nDistances-1;
441      for(disti1++, disti2++; ndi--; disti1++, disti2++) {
442        if (*disti1 >= 0) {
443          if (*disti2 > *disti1) // if one is -1, they both are
444            *disti1 = *disti2;
445          if (*disti1 < *minIndex1)
446            minIndex1 = disti1;
447        }
448      }
449      cluster1->minDistance = *minIndex1;
450      cluster1->rawIndexMinDistance = minIndex1 - cluster1->distances;
451    }
452
453    while(*disti2 < 0)
454      disti2++;        // should have at least one more >=0  - the one corresponding to distance to cluster1
455
456    for(cluster = cluster1->next; cluster != cluster2; cluster = cluster->next) {
457      while(*++disti2 < 0); // should have more - the one corresponding to cluster
458      float &distc = cluster->distances[rawIndex1];
459      if (*disti2 > distc) {
460        distc = *disti2;
461        if (cluster->rawIndexMinDistance == rawIndex1)
462          if (distc <= cluster->minDistance)
463            cluster->minDistance = distc;
464          else
465            cluster->computeMinimalDistance();
466      }
467    }
468
469    for(cluster = cluster->next; cluster; cluster = cluster->next) {
470      float &distc = cluster->distances[rawIndex1];
471      if (cluster->distances[rawIndex2] > distc)
472        distc = cluster->distances[rawIndex2];
473
474      cluster->distances[rawIndex2] = -1;
475      if ((cluster->rawIndexMinDistance == rawIndex1) || (cluster->rawIndexMinDistance == rawIndex2))
476        cluster->computeMinimalDistance();
477    }
478
479    cluster1->elevate(pcluster2, minDistance);
480  }
481
482  return *clusters;
483}
484
485
486
487TClusterW *THierarchicalClustering::merge(TClusterW **clusters, float *milestones)
488{
489  switch(linkage) {
490    case Complete: return merge_CompleteLinkage(clusters, milestones);
491    case Single: return merge_SingleLinkage(clusters, milestones);
492    case Average: 
493    case Ward:
494    default: return merge_AverageLinkage(clusters, milestones);
495  }
496}
497
498PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *root)
499{
500  PIntList elementIndices = new TIntList(root->size);
501  TIntList::iterator currentElement(elementIndices->begin());
502  int currentIndex = 0;
503
504  return restructure(root, elementIndices, currentElement, currentIndex);
505}
506
507
508PHierarchicalCluster THierarchicalClustering::restructure(TClusterW *node, PIntList elementIndices, TIntList::iterator &currentElement, int &currentIndex)
509{
510  PHierarchicalCluster cluster;
511
512  if (!node->left) {
513    *currentElement++ = node->elementIndex;
514    cluster = mlnew THierarchicalCluster(elementIndices, currentIndex++);
515  }
516  else {
517    PHierarchicalCluster left = restructure(node->left, elementIndices, currentElement, currentIndex);
518    PHierarchicalCluster right = restructure(node->right, elementIndices, currentElement, currentIndex);
519    cluster = mlnew THierarchicalCluster(elementIndices, left, right, node->height, left->first, right->last);
520  }
521
522  // no need to care about 'distances' - they've been removed during clustering (in 'elevate') :)
523  mldelete node;
524  return cluster;
525}
526
527
528PHierarchicalCluster THierarchicalClustering::operator()(PSymMatrix distanceMatrix)
529{
530  float *distanceMatrixElements = NULL;
531  TClusterW **clusters, *root;
532  float *callbackMilestones = NULL;
533 
534  try {
535    const int dim = distanceMatrix->dim;
536    const int size = ((dim+1)*(dim+2))/2;
537    float *distanceMatrixElements = overwriteMatrix ? distanceMatrix->elements : (float *)memcpy(new float[size], distanceMatrix->elements, size*sizeof(float));
538
539    clusters = init(dim, distanceMatrixElements);
540    callbackMilestones = (progressCallback && (distanceMatrix->dim >= 1000)) ? progressCallback->milestones(distanceMatrix->dim) : NULL;
541    root = merge(clusters, callbackMilestones);
542  }
543  catch (...) {
544    mldelete clusters;
545    mldelete callbackMilestones;
546    mldelete distanceMatrixElements;
547    throw;
548  }
549
550  mldelete clusters;
551  mldelete callbackMilestones;
552  mldelete distanceMatrixElements;
553
554  return restructure(root);
555}
556
557
558/*
559 *  Optimal leaf ordering.
560 */
561
562struct m_element {
563    THierarchicalCluster * cluster;
564    int left;
565    int right;
566
567    m_element(THierarchicalCluster * cluster, int left, int right);
568    m_element(const m_element & other);
569
570    inline bool operator< (const m_element & other) const;
571    inline bool operator== (const m_element & other) const;
572}; // cluster joined at left and right index
573
574struct ordering_element {
575    THierarchicalCluster * left;
576    unsigned int u; // the left most (outer) index of left luster
577    unsigned m; // the rightmost (inner) index of left cluster
578    THierarchicalCluster * right;
579    unsigned int w; // the right most (outer) index of the right cluster
580    unsigned int k; // the left most (inner) index of the right cluster
581
582    ordering_element();
583    ordering_element(THierarchicalCluster * left, unsigned int u, unsigned m,
584            THierarchicalCluster * right, unsigned int w, unsigned int k);
585    ordering_element(const ordering_element & other);
586};
587
588m_element::m_element(THierarchicalCluster * _cluster, int _left, int _right):
589    cluster(_cluster), left(_left), right(_right)
590{}
591
592m_element::m_element(const m_element & other):
593    cluster(other.cluster), left(other.left), right(other.right)
594{}
595
596bool m_element::operator< (const m_element & other) const
597{
598    if (cluster < other.cluster)
599        return true;
600    else
601        if (cluster == other.cluster)
602            if (left < other.left)
603                return true;
604            else
605                if (left == other.left)
606                    return right < other.right;
607                else
608                    return false;
609        else
610            return false;
611}
612
613bool m_element::operator== (const m_element & other) const
614{
615    return cluster == other.cluster && left == other.left && right == other.right;
616}
617
618struct m_element_hash
619{
620    inline size_t operator()(const m_element & m) const
621    {
622        size_t seed = 0;
623        hash_combine(seed, (size_t) m.cluster);
624        hash_combine(seed, (size_t) m.left);
625        hash_combine(seed, (size_t) m.right);
626        return seed;
627    }
628
629    // more or less taken from boost::hash_combine
630    inline void hash_combine(size_t &seed, size_t val) const
631    {
632        seed ^= val + 0x9e3779b9 + (seed << 6) + (seed >> 2);
633    }
634};
635
636ordering_element::ordering_element():
637    left(NULL), u(-1), m(-1),
638    right(NULL), w(-1), k(-1)
639{}
640
641ordering_element::ordering_element(THierarchicalCluster * _left,
642        unsigned int _u,
643        unsigned _m,
644        THierarchicalCluster * _right,
645        unsigned int _w,
646        unsigned int _k
647    ): left(_left), u(_u), m(_m),
648       right(_right), w(_w), k(_k)
649{}
650
651ordering_element::ordering_element(const ordering_element & other):
652       left(other.left), u(other.u), m(other.m),
653       right(other.right), w(other.w), k(other.k)
654{}
655
656#define USE_TR1 1
657
658#if USE_TR1
659    #if _MSC_VER
660        #define HAVE_TR1_DIR 0
661    #else
662        #define HAVE_TR1_DIR 1
663    #endif
664    // Diffrent includes required
665    #if HAVE_TR1_DIR
666        #include <tr1/unordered_map>
667    #else
668        #include <unordered_map>
669    #endif
670    typedef std::tr1::unordered_map<m_element, double, m_element_hash> join_scores;
671    typedef std::tr1::unordered_map<m_element, ordering_element, m_element_hash> cluster_ordering;
672#else
673    typedef std::map<m_element, double> join_scores;
674    typedef std::map<m_element, ordering_element> cluster_ordering;
675#endif
676
677// Return the minimum distance between elements in matrix
678//
679float min_distance(
680        TIntList::iterator indices1_begin,
681        TIntList::iterator indices1_end,
682        TIntList::iterator indices2_begin,
683        TIntList::iterator indices2_end,
684        TSymMatrix & matrix)
685{
686//  TIntList::iterator iter1 = indices1.begin(), iter2 = indices2.begin();
687    float minimum = std::numeric_limits<float>::infinity();
688    TIntList::iterator indices2;
689    for (; indices1_begin != indices1_end; indices1_begin++)
690        for (indices2 = indices2_begin; indices2 != indices2_end; indices2++){
691            minimum = std::min(matrix.getitem(*indices1_begin, *indices2), minimum);
692        }
693    return minimum;
694}
695
696struct CompareByScores
697{
698    join_scores & scores;
699    const THierarchicalCluster & cluster;
700    const int & fixed;
701
702    CompareByScores(join_scores & _scores, const THierarchicalCluster & _cluster, const int & _fixed):
703        scores(_scores), cluster(_cluster), fixed(_fixed)
704    {}
705    bool operator() (int lhs, int rhs)
706    {
707        m_element left((THierarchicalCluster*)&cluster, fixed, lhs);
708        m_element right((THierarchicalCluster*)&cluster, fixed, rhs);
709        return scores[left] < scores[right];
710    }
711};
712
713
714//#include <iostream>
715//#include <cassert>
716
717// This needs to be called with all left, right pairs to
718// update all scores for cluster.
719void partial_opt_ordering(
720        THierarchicalCluster & cluster,
721        THierarchicalCluster & left,
722        THierarchicalCluster & right,
723        THierarchicalCluster & left_left,
724        THierarchicalCluster & left_right,
725        THierarchicalCluster & right_left,
726        THierarchicalCluster & right_right,
727        TSymMatrix &matrix,
728        join_scores & M,
729        cluster_ordering & ordering)
730{
731    int u = 0, w = 0;
732    TIntList & mapping = cluster.mapping.getReference();
733    for (TIntList::iterator u_iter = mapping.begin() + left_left.first;
734            u_iter != mapping.begin() + left_left.last;
735            u_iter++)
736        for (TIntList::iterator w_iter = mapping.begin() + right_right.first;
737                w_iter != mapping.begin() + right_right.last;
738                w_iter++)
739        {
740            u = *u_iter;
741            w = *w_iter;
742            float curr_min = std::numeric_limits<float>::infinity();
743            int curr_k = 0, curr_m = 0;
744            float C = min_distance(mapping.begin() + left_right.first,
745                    mapping.begin() + left_right.last,
746                    mapping.begin() + right_left.first,
747                    mapping.begin() + right_left.last,
748                    matrix);
749
750            vector<int> m_ordered(mapping.begin() + left_right.first,
751                    mapping.begin() + left_right.last);
752            vector<int> k_ordered(mapping.begin() + right_left.first,
753                    mapping.begin() + right_left.last);
754
755            // TODO: precompute the scores for m and k in an array and use a simpler
756            // comparison function
757            std::sort(m_ordered.begin(), m_ordered.end(), CompareByScores(M, left, u));
758            std::sort(k_ordered.begin(), k_ordered.end(), CompareByScores(M, right, w));
759
760
761            int k0 = k_ordered.front();
762            m_element m_right_k0(&right, w, k0);
763            int m = 0, k = 0;
764            for (vector<int>::iterator iter_m=m_ordered.begin(); iter_m != m_ordered.end(); iter_m++)
765            {
766                m = *iter_m;
767
768                m_element m_left(&left, u, m);
769
770                if (M[m_left] + M[m_right_k0] + C >= curr_min){
771                    break;
772                }
773                for (vector<int>::iterator iter_k = k_ordered.begin(); iter_k != k_ordered.end(); iter_k++)
774                {
775                    k = *iter_k;
776                    m_element m_right(&right, w, k);
777                    if (M[m_left] + M[m_right] + C >= curr_min)
778                    {
779                        break;
780                    }
781                    float test_val = M[m_left] + M[m_right] + matrix.getitem(m, k);
782                    if (curr_min > test_val)
783                    {
784                        curr_min = test_val;
785                        curr_k = k;
786                        curr_m = m;
787                    }
788                }
789
790            }
791
792            M[m_element(&cluster, u, w)] = curr_min;
793            M[m_element(&cluster, w, u)] = curr_min;
794
795//          assert(M[m_element(&cluster, u, w)] == M[m_element(&cluster, u, w)]);
796//          assert(M[m_element(&cluster, u, w)] == curr_min);
797
798//          assert(ordering.find(m_element(&cluster, u, w)) == ordering.end());
799//          assert(ordering.find(m_element(&cluster, w, u)) == ordering.end());
800
801            ordering[m_element(&cluster, u, w)] = ordering_element(&left, u, curr_m, &right, w, curr_k);
802            ordering[m_element(&cluster, w, u)] = ordering_element(&right, w, curr_k, &left, u, curr_m);
803        }
804}
805
806void order_clusters(
807        THierarchicalCluster & cluster,
808        TSymMatrix &matrix,
809        join_scores & M,
810        cluster_ordering & ordering,
811        TProgressCallback * callback)
812{
813    if (cluster.size() == 1)
814    {
815        M[m_element(&cluster, cluster.mapping->at(cluster.first), cluster.mapping->at(cluster.first))] = 0.0;
816        return;
817    }
818    else if (cluster.branches->size() == 2)
819    {
820        PHierarchicalCluster left = cluster.branches->at(0);
821        PHierarchicalCluster right = cluster.branches->at(1);
822
823        order_clusters(left.getReference(), matrix, M, ordering, callback);
824        order_clusters(right.getReference(), matrix, M, ordering, callback);
825
826        PHierarchicalCluster  left_left = (!left->branches) ? left : left->branches->at(0);
827        PHierarchicalCluster  left_right = (!left->branches) ? left : left->branches->at(1);
828        PHierarchicalCluster  right_left = (!right->branches) ? right : right->branches->at(0);
829        PHierarchicalCluster  right_right = (!right->branches) ? right : right->branches->at(1);
830
831        // 1.)
832        partial_opt_ordering(cluster,
833                left.getReference(), right.getReference(),
834                left_left.getReference(), left_right.getReference(),
835                right_left.getReference(), right_right.getReference(),
836                matrix, M, ordering);
837
838        if (right->branches)
839            // 2.) Switch right branches.
840            // (if there are no right branches the ordering has already been evaluated in 1.)
841            partial_opt_ordering(cluster,
842                    left.getReference(), right.getReference(),
843                    left_left.getReference(), left_right.getReference(),
844                    right_right.getReference(), right_left.getReference(),
845                    matrix, M, ordering);
846
847        if (left->branches)
848            // 3.) Switch left branches.
849            // (if there are no left branches the ordering has already been evaluated in 1. and 2.)
850            partial_opt_ordering(cluster,
851                    left.getReference(), right.getReference(),
852                    left_right.getReference(), left_left.getReference(),
853                    right_left.getReference(), right_right.getReference(),
854                    matrix, M, ordering);
855
856        if (left->branches && right->branches)
857            // 4.) Switch both branches.
858            partial_opt_ordering(cluster,
859                    left.getReference(), right.getReference(),
860                    left_right.getReference(), left_left.getReference(),
861                    right_right.getReference(), right_left.getReference(),
862                    matrix, M, ordering);
863    }
864    if (callback)
865        // TODO: count the number of already processed nodes.
866        callback->operator()(0.0, PHierarchicalCluster(&cluster));
867}
868
869/* Check if TIntList contains element.
870 */
871bool contains(TIntList::iterator iter_begin, TIntList::iterator iter_end, int element)
872{
873    return std::find(iter_begin, iter_end, element) != iter_end;
874}
875
876void optimal_swap(THierarchicalCluster * cluster, int u, int w, cluster_ordering & ordering)
877{
878    if (cluster->branches)
879    {
880        assert(ordering.find(m_element(cluster, u, w)) != ordering.end());
881        ordering_element ord = ordering[m_element(cluster, u, w)];
882
883        PHierarchicalCluster left_right = (ord.left->branches)? ord.left->branches->at(1) : PHierarchicalCluster(NULL);
884        PHierarchicalCluster right_left = (ord.right->branches)? ord.right->branches->at(0) : PHierarchicalCluster(NULL);
885
886        TIntList & mapping = cluster->mapping.getReference();
887        if (left_right && !contains(mapping.begin() + left_right->first,
888                mapping.begin() + left_right->last, ord.m))
889        {
890            assert(!contains(mapping.begin() + left_right->first,
891                             mapping.begin() + left_right->last, ord.m));
892            ord.left->swap();
893            left_right = ord.left->branches->at(1);
894            assert(contains(mapping.begin() + left_right->first,
895                            mapping.begin() + left_right->last, ord.m));
896        }
897        optimal_swap(ord.left, ord.u, ord.m, ordering);
898
899        assert(mapping.at(ord.left->first) == ord.u);
900        assert(mapping.at(ord.left->last - 1) == ord.m);
901
902        if (right_left && !contains(mapping.begin() + right_left->first,
903                mapping.begin() + right_left->last, ord.k))
904        {
905            assert(!contains(mapping.begin() + right_left->first,
906                             mapping.begin() + right_left->last, ord.k));
907            ord.right->swap();
908            right_left = ord.right->branches->at(0);
909            assert(contains(mapping.begin() + right_left->first,
910                            mapping.begin() + right_left->last, ord.k));
911        }
912        optimal_swap(ord.right, ord.k, ord.w, ordering);
913
914        assert(mapping.at(ord.right->first) == ord.k);
915        assert(mapping.at(ord.right->last - 1) == ord.w);
916
917        assert(mapping.at(cluster->first) == ord.u);
918        assert(mapping.at(cluster->last - 1) == ord.w);
919    }
920}
921
922PHierarchicalCluster THierarchicalClusterOrdering::operator() (
923        PHierarchicalCluster root,
924        PSymMatrix matrix)
925{
926    join_scores M; // scores
927    cluster_ordering ordering;
928    order_clusters(root.getReference(), matrix.getReference(), M, ordering,
929            progress_callback.getUnwrappedPtr());
930
931    int u = 0, w = 0;
932    int min_u = 0, min_w = 0;
933    float min_score = std::numeric_limits<float>::infinity();
934    TIntList & mapping = root->mapping.getReference();
935
936    for (TIntList::iterator u_iter = mapping.begin() + root->branches->at(0)->first;
937            u_iter != mapping.begin() + root->branches->at(0)->last;
938            u_iter++)
939        for (TIntList::iterator w_iter = mapping.begin() + root->branches->at(1)->first;
940                    w_iter != mapping.begin() + root->branches->at(1)->last;
941                    w_iter++)
942        {
943            u = *u_iter; w = *w_iter;
944            m_element el(root.getUnwrappedPtr(), u, w);
945            if (M[el] < min_score)
946            {
947                min_score = M[el];
948                min_u = u;
949                min_w = w;
950            }
951        }
952//  std::cout << "Min score "<< min_score << endl;
953
954    optimal_swap(root.getUnwrappedPtr(), min_u, min_w, ordering);
955    return root;
956}
Note: See TracBrowser for help on using the repository browser.