TTK
Loading...
Searching...
No Matches
MergeTreeClustering.h
Go to the documentation of this file.
1
27
// Matchings of one centroid against a set of trees: for each tree, a list of
// (ftm::idNode, ftm::idNode, double) triplets.
#define treesMatchingVector std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// One treesMatchingVector per centroid (indexed [centroid][tree] by callers).
#define matchingVectorType std::vector<treesMatchingVector>
31
32#pragma once
33
34#include <random>
35
36// ttk common includes
37#include <Debug.h>
38
39#include "MergeTreeBarycenter.h"
40
41namespace ttk {
42
47 // TODO rename dataType2 to dataType and remove template everywhere else in
48 // this class
49 template <class dataType2>
50 class MergeTreeClustering : virtual public Debug, public MergeTreeBarycenter {
51
52 private:
53 bool parallelizeUpdate_ = true;
54
55 unsigned int noCentroids_ = 2;
56
57 // Progressive parameters
58 int noIterationC_ = 0;
59 double addDeletedNodesTime_ = 0;
60
61 // Accelerated KMeans
62 bool acceleratedInitialized_ = false;
63 std::vector<std::vector<double>> lowerBound_;
64 std::vector<double> upperBound_;
65 std::vector<int> bestCentroid_, oldBestCentroid_;
66 std::vector<double> bestDistance_;
67 std::vector<bool> recompute_;
68 std::vector<ftm::MergeTree<dataType2>> oldCentroids_, oldCentroids2_;
69
70 // Clean correspondence
71 std::vector<std::vector<int>> trees2NodeCorr_;
72
73 public:
76 "MergeTreeClustering"); // inherited from Debug: prefix will be printed
77 // at the beginning of every msg
78 }
79 ~MergeTreeClustering() override = default;
80
81 void setNoCentroids(unsigned int noCentroidsT) {
82 noCentroids_ = noCentroidsT;
83 }
84
85 void setMixtureCoefficient(double coef) {
87 }
88
89 std::vector<std::vector<int>> getTrees2NodeCorr() {
90 return trees2NodeCorr_;
91 }
92
97 // ------------------------------------------------------------------------
98 // Initialization
99 // ------------------------------------------------------------------------
100 // KMeans++ init
101 template <class dataType>
103 std::vector<ftm::FTMTree_MT *> &trees,
104 std::vector<ftm::FTMTree_MT *> &trees2,
105 std::vector<std::vector<ftm::MergeTree<dataType>>> &allCentroids) {
106 allCentroids.resize(
107 2, std::vector<ftm::MergeTree<dataType>>(noCentroids_));
108 std::vector<dataType> distances(
109 trees.size(), std::numeric_limits<dataType>::max());
110
111 // Manage size limited trees
112 double limitPercent = barycenterSizeLimitPercent_ / noCentroids_;
113 std::vector<ftm::MergeTree<dataType>> mTreesLimited, mTrees2Limited;
114 bool doSizeLimit
115 = (limitPercent > 0.0 or barycenterMaximumNumberOfPairs_ > 0);
116 if(doSizeLimit) {
117 getSizeLimitedTrees<dataType>(
118 trees, barycenterMaximumNumberOfPairs_, limitPercent, mTreesLimited);
119 if(trees2.size() != 0)
120 getSizeLimitedTrees<dataType>(trees2, barycenterMaximumNumberOfPairs_,
121 limitPercent, mTrees2Limited);
122 }
123
124 // Init centroids
125 for(unsigned int i = 0; i < noCentroids_; ++i) {
126 int bestIndex = -1;
127 if(i == 0) {
128 bestIndex = getBestInitTreeIndex<dataType>(
129 trees, trees2, limitPercent, false);
130 } else {
131 // Create vector of probabilities
132 double sum = 0;
133 for(auto val : distances)
134 sum += val;
135 double bestValue = std::numeric_limits<double>::lowest();
136 std::vector<double> probabilities(trees.size());
137 for(unsigned int j = 0; j < distances.size(); ++j) {
138 probabilities[j]
139 = (sum != 0 ? distances[j] / sum : 1.0 / distances.size());
140 if(probabilities[j] > bestValue) {
141 bestValue = probabilities[j];
142 bestIndex = j;
143 }
144 }
145 if(not deterministic_) {
146 std::random_device rd;
147 std::default_random_engine generator(rd());
148 std::discrete_distribution<int> distribution(
149 probabilities.begin(), probabilities.end());
150 bestIndex = distribution(generator);
151 }
152 }
153 printMsg(
154 "Init index : " + std::to_string(bestIndex), debug::Priority::DETAIL);
155 // Create new centroid
156 allCentroids[0][i]
157 = ftm::copyMergeTree<dataType>(trees[bestIndex], true);
158 limitSizeBarycenter(allCentroids[0][i], trees, limitPercent);
159 ftm::cleanMergeTree<dataType>(allCentroids[0][i]);
160 if(trees2.size() != 0) {
161 allCentroids[1][i]
162 = ftm::copyMergeTree<dataType>(trees2[bestIndex], true);
163 limitSizeBarycenter(allCentroids[1][i], trees2, limitPercent);
164 ftm::cleanMergeTree<dataType>(allCentroids[1][i]);
165 }
166
167 if(i == noCentroids_ - 1)
168 continue;
169#ifdef TTK_ENABLE_OPENMP44
170#pragma omp parallel for schedule(dynamic) shared(allCentroids) \
171 num_threads(this->threadNumber_) if(parallelize_)
172#endif
173 for(unsigned int j = 0; j < trees.size(); ++j) {
174 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching,
175 matching2;
176 dataType distanceT, distanceT2;
177 ftm::FTMTree_MT *treeToUse
178 = (doSizeLimit ? &(mTreesLimited[j].tree) : trees[j]);
179 computeOneDistance<dataType>(treeToUse, allCentroids[0][i], matching,
180 distanceT, useDoubleInput_);
181 if(trees2.size() != 0) {
182 ftm::FTMTree_MT *tree2ToUse
183 = (doSizeLimit ? &(mTrees2Limited[j].tree) : trees2[j]);
184 computeOneDistance<dataType>(tree2ToUse, allCentroids[1][i],
185 matching2, distanceT2, useDoubleInput_,
186 false);
187 distanceT = mixDistances<dataType>(distanceT, distanceT2);
188 }
189 distances[j] = std::min(distances[j], distanceT);
190 }
191 }
192 }
193
194 template <class dataType>
195 void initNewCentroid(std::vector<ftm::FTMTree_MT *> &trees,
196 ftm::MergeTree<dataType> &centroid,
197 int noNewCentroid) {
198 std::vector<std::tuple<double, int>> distancesAndIndexes(
199 bestDistance_.size());
200 for(unsigned int i = 0; i < bestDistance_.size(); ++i)
201 distancesAndIndexes[i] = std::make_tuple(-bestDistance_[i], i);
202 std::sort(distancesAndIndexes.begin(), distancesAndIndexes.end());
203 int const bestIndex = std::get<1>(distancesAndIndexes[noNewCentroid]);
204 centroid = ftm::copyMergeTree<dataType>(trees[bestIndex], true);
205 limitSizeBarycenter(centroid, trees);
206 ftm::cleanMergeTree<dataType>(centroid);
207 }
208
209 template <class dataType>
211 std::vector<ftm::FTMTree_MT *> &trees,
212 std::vector<ftm::MergeTree<dataType>> &centroids,
213 std::vector<ftm::FTMTree_MT *> &ttkNotUsed(trees2)) {
214 lowerBound_.clear();
215 lowerBound_.resize(
216 trees.size(), std::vector<double>(centroids.size(), 0));
217 upperBound_.clear();
218 upperBound_.resize(trees.size(), std::numeric_limits<double>::max());
219 bestCentroid_.clear();
220 bestCentroid_.resize(trees.size(), -1);
221 oldBestCentroid_.clear();
222 oldBestCentroid_.resize(trees.size(), -1);
223 bestDistance_.clear();
224 bestDistance_.resize(trees.size(), std::numeric_limits<double>::max());
225 recompute_.clear();
226 recompute_.resize(trees.size(), true);
227 }
228
229 template <class dataType>
230 void
231 initAcceleratedKMeans(std::vector<ftm::FTMTree_MT *> &trees,
232 std::vector<ftm::MergeTree<dataType>> &centroids,
233 std::vector<ftm::FTMTree_MT *> &trees2,
234 std::vector<ftm::MergeTree<dataType>> &centroids2) {
235 acceleratedInitialized_ = true;
236 std::vector<std::tuple<int, int>> assignmentC;
237 std::vector<dataType> bestDistanceT(
238 trees.size(), std::numeric_limits<dataType>::max());
239 assignmentCentroidsNaive<dataType>(
240 trees, centroids, assignmentC, bestDistanceT, trees2, centroids2);
241 for(unsigned int i = 0; i < bestDistanceT.size(); ++i)
242 bestDistance_[i] = bestDistanceT[i];
243 for(auto asgn : assignmentC)
244 bestCentroid_[std::get<1>(asgn)] = std::get<0>(asgn);
245 for(unsigned int i = 0; i < bestDistance_.size(); ++i)
246 upperBound_[i] = bestDistance_[i];
247 }
248
249 template <class dataType>
250 void copyCentroids(std::vector<ftm::MergeTree<dataType>> &centroids,
251 std::vector<ftm::MergeTree<dataType>> &oldCentroids) {
252 oldCentroids.clear();
253 for(unsigned int i = 0; i < centroids.size(); ++i)
254 oldCentroids.push_back(ftm::copyMergeTree<dataType>(centroids[i]));
255 }
256
257 // ------------------------------------------------------------------------
258 // Assignment
259 // ------------------------------------------------------------------------
260 template <class dataType>
262 std::vector<ftm::FTMTree_MT *> &trees,
263 std::vector<ftm::MergeTree<dataType>> &centroids,
264 std::vector<std::tuple<int, int>> &assignmentC,
265 std::vector<dataType> &bestDistanceT,
266 std::vector<ftm::FTMTree_MT *> &trees2,
267 std::vector<ftm::MergeTree<dataType>> &centroids2) {
268 if(not acceleratedInitialized_) {
269 initAcceleratedKMeans<dataType>(trees, centroids, trees2, centroids2);
270 } else {
271 // Compute distance between old and new corresponding centroids
272 std::vector<dataType> distanceShift(centroids.size()),
273 distanceShift2(centroids2.size());
274#ifdef TTK_ENABLE_OPENMP4
275#pragma omp parallel for schedule(dynamic) \
276 shared(centroids, centroids2, oldCentroids_, oldCentroids2_) \
277 num_threads(this->threadNumber_) if(parallelize_)
278#endif
279 for(unsigned int i = 0; i < centroids.size(); ++i) {
280 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching,
281 matching2;
282 computeOneDistance<dataType>(centroids[i], oldCentroids_[i], matching,
283 distanceShift[i], useDoubleInput_);
284 if(trees2.size() != 0) {
285 computeOneDistance<dataType>(centroids2[i], oldCentroids2_[i],
286 matching2, distanceShift2[i],
287 useDoubleInput_, false);
288 distanceShift[i]
289 = mixDistances<dataType>(distanceShift[i], distanceShift2[i]);
290 }
291 }
292
293 // Step 5
294 for(unsigned int i = 0; i < trees.size(); ++i)
295 for(unsigned int c = 0; c < centroids.size(); ++c)
296 lowerBound_[i][c]
297 = std::max(lowerBound_[i][c] - distanceShift[c], 0.0);
298
299 // Step 6
300 for(unsigned int i = 0; i < trees.size(); ++i) {
301 upperBound_[i] = upperBound_[i] + distanceShift[bestCentroid_[i]];
302 recompute_[i] = true;
303 }
304 }
305
306 // Step 1
307 std::vector<std::vector<double>> centroidsDistance, centroidsDistance2;
308 getCentroidsDistanceMatrix<dataType>(
309 centroids, centroidsDistance, useDoubleInput_);
310 if(trees2.size() != 0) {
311 getCentroidsDistanceMatrix<dataType>(
312 centroids2, centroidsDistance2, useDoubleInput_, false);
313 mixDistancesMatrix(centroidsDistance, centroidsDistance2);
314 }
315 std::vector<double> centroidScore(
316 centroids.size(), std::numeric_limits<double>::max());
317 for(unsigned int i = 0; i < centroids.size(); ++i)
318 for(unsigned int j = i + 1; j < centroids.size(); ++j) {
319 if(0.5 * centroidsDistance[i][j] < centroidScore[i])
320 centroidScore[i] = 0.5 * centroidsDistance[i][j];
321 if(0.5 * centroidsDistance[i][j] < centroidScore[j])
322 centroidScore[j] = 0.5 * centroidsDistance[i][j];
323 }
324
325 // Step 2
326 std::vector<bool> identified(trees.size());
327 for(unsigned int i = 0; i < trees.size(); ++i)
328 identified[i] = (upperBound_[i] <= centroidScore[bestCentroid_[i]]);
329
330 // Step 3
331#ifdef TTK_ENABLE_OPENMP4
332#pragma omp parallel for schedule(dynamic) shared(centroids, centroids2) \
333 num_threads(this->threadNumber_) if(parallelize_)
334#endif
335 for(unsigned int i = 0; i < trees.size(); ++i)
336 for(unsigned int c = 0; c < centroids.size(); ++c) {
337 if(not identified[i] and (int) c != bestCentroid_[i]
338 and upperBound_[i] > lowerBound_[i][c]
339 and upperBound_[i]
340 > 0.5 * centroidsDistance[bestCentroid_[i]][c]) {
341 // Step 3a
342 if(recompute_[i]) {
343 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
344 matching, matching2;
345 dataType distance, distance2;
346 computeOneDistance<dataType>(trees[i],
347 centroids[bestCentroid_[i]],
348 matching, distance, useDoubleInput_);
349 if(trees2.size() != 0) {
350 computeOneDistance<dataType>(
351 trees2[i], centroids2[bestCentroid_[i]], matching2, distance2,
352 useDoubleInput_, false);
353 distance = mixDistances<dataType>(distance, distance2);
354 }
355 recompute_[i] = false;
356 lowerBound_[i][bestCentroid_[i]] = distance;
357 upperBound_[i] = distance;
358 bestDistance_[i] = distance;
359 } else {
360 bestDistance_[i] = upperBound_[i];
361 }
362 // Step 3b
363 if(bestDistance_[i] > lowerBound_[i][c]
364 and bestDistance_[i]
365 > 0.5 * centroidsDistance[bestCentroid_[i]][c]) {
366 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
367 matching, matching2;
368 dataType distance, distance2;
369 computeOneDistance<dataType>(
370 trees[i], centroids[c], matching, distance, useDoubleInput_);
371 if(trees2.size() != 0) {
372 computeOneDistance<dataType>(trees2[i], centroids2[c],
373 matching2, distance2,
374 useDoubleInput_, false);
375 distance = mixDistances<dataType>(distance, distance2);
376 }
377 lowerBound_[i][c] = distance;
378 if(distance < bestDistance_[i]) {
379 bestCentroid_[i] = c;
380 upperBound_[i] = distance;
381 bestDistance_[i] = distance;
382 }
383 }
384 }
385 }
386
387 // Copy centroids for next step
388 copyCentroids<dataType>(centroids, oldCentroids_);
389 if(trees2.size() != 0)
390 copyCentroids<dataType>(centroids2, oldCentroids2_);
391
392 // Manage output
393 for(unsigned int i = 0; i < bestDistance_.size(); ++i)
394 bestDistanceT[i] = bestDistance_[i];
395 for(unsigned int i = 0; i < bestCentroid_.size(); ++i)
396 assignmentC.emplace_back(bestCentroid_[i], i);
397 }
398
399 template <class dataType>
400 void
401 assignmentCentroids(std::vector<ftm::FTMTree_MT *> &trees,
402 std::vector<ftm::MergeTree<dataType>> &centroids,
403 std::vector<std::tuple<int, int>> &assignmentC,
404 std::vector<dataType> &bestDistanceT,
405 std::vector<ftm::FTMTree_MT *> &trees2,
406 std::vector<ftm::MergeTree<dataType>> &centroids2) {
407 oldBestCentroid_ = bestCentroid_;
408 assignmentCentroidsAccelerated<dataType>(
409 trees, centroids, assignmentC, bestDistanceT, trees2, centroids2);
410 }
411
412 template <class dataType>
414 std::vector<ftm::FTMTree_MT *> &trees,
415 std::vector<ftm::MergeTree<dataType>> &centroids,
416 matchingVectorType &matchingsC,
417 std::vector<std::tuple<int, int>> &assignmentC,
418 std::vector<dataType> &bestDistanceT,
419 std::vector<ftm::FTMTree_MT *> &trees2,
420 std::vector<ftm::MergeTree<dataType>> &centroids2,
421 matchingVectorType &matchingsC2) {
422 int noC = centroids.size();
423 std::vector<std::vector<ftm::FTMTree_MT *>> assignedTrees(noC),
424 assignedTrees2(noC);
425 std::vector<std::vector<int>> assignedTreesIndex(noC);
426
427 for(auto asgn : assignmentC) {
428 assignedTreesIndex[std::get<0>(asgn)].push_back(std::get<1>(asgn));
429 assignedTrees[std::get<0>(asgn)].push_back(trees[std::get<1>(asgn)]);
430 if(trees2.size() != 0)
431 assignedTrees2[std::get<0>(asgn)].push_back(
432 trees2[std::get<1>(asgn)]);
433 }
434
435#ifdef TTK_ENABLE_OPENMP4
436#pragma omp parallel for schedule(dynamic) shared(centroids, centroids2) \
437 num_threads(this->threadNumber_) if(parallelize_)
438#endif
439 for(unsigned int i = 0; i < centroids.size(); ++i) {
440 std::vector<dataType> distances(assignedTrees[i].size(), 0);
441 std::vector<dataType> distances2(assignedTrees[i].size(), 0);
442 treesMatchingVector matching(trees.size()), matching2(trees2.size());
443 assignment<dataType>(
444 assignedTrees[i], centroids[i], matching, distances, useDoubleInput_);
445 matchingsC[i] = matching;
446 if(trees2.size() != 0) {
447 assignment<dataType>(assignedTrees2[i], centroids2[i], matching2,
448 distances2, useDoubleInput_, false);
449 matchingsC2[i] = matching2;
450 for(unsigned int j = 0; j < assignedTreesIndex[i].size(); ++j)
451 distances[j] = mixDistances<dataType>(distances[j], distances2[j]);
452 }
453 for(unsigned int j = 0; j < assignedTreesIndex[i].size(); ++j) {
454 int const index = assignedTreesIndex[i][j];
455 bestDistanceT[index] = distances[j];
456 }
457 }
458 }
459
460 template <class dataType>
462 std::vector<ftm::FTMTree_MT *> &trees,
463 std::vector<ftm::MergeTree<dataType>> &centroids,
464 std::vector<std::tuple<int, int>> &assignmentC,
465 std::vector<dataType> &bestDistanceT,
466 std::vector<ftm::FTMTree_MT *> &trees2,
467 std::vector<ftm::MergeTree<dataType>> &centroids2) {
468 std::vector<int> bestCentroidT(trees.size(), -1);
469
470#ifdef TTK_ENABLE_OPENMP4
471#pragma omp parallel for schedule(dynamic) shared(centroids, centroids2) \
472 num_threads(this->threadNumber_) if(parallelize_)
473#endif
474 for(unsigned int i = 0; i < trees.size(); ++i) {
475 for(unsigned int j = 0; j < centroids.size(); ++j) {
476 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
477 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching2;
478 dataType distance, distance2;
479 computeOneDistance<dataType>(
480 trees[i], centroids[j], matching, distance, useDoubleInput_);
481 if(trees2.size() != 0) {
482 computeOneDistance<dataType>(trees2[i], centroids2[j], matching2,
483 distance2, useDoubleInput_, false);
484 distance = mixDistances<dataType>(distance, distance2);
485 }
486 if(distance < bestDistanceT[i]) {
487 bestDistanceT[i] = distance;
488 bestDistance_[i] = distance;
489 bestCentroidT[i] = j;
490 bestCentroid_[i] = j;
491 }
492 }
493 }
494
495 for(unsigned int i = 0; i < bestCentroidT.size(); ++i)
496 assignmentC.emplace_back(bestCentroidT[i], i);
497 }
498
499 template <class dataType>
501 std::vector<ftm::MergeTree<dataType>> &centroids,
502 std::vector<std::vector<double>> &distanceMatrix,
503 bool useDoubleInput = false,
504 bool isFirstInput = true) {
505 std::vector<ftm::FTMTree_MT *> trees(centroids.size());
506 for(size_t i = 0; i < centroids.size(); ++i) {
507 trees[i] = &(centroids[i].tree);
508 }
509 getDistanceMatrix<dataType>(
510 trees, distanceMatrix, useDoubleInput, isFirstInput);
511 }
512
514 std::vector<int> &nodeCorr,
515 std::vector<int> &assignedTreesIndex) {
516 for(int const i : assignedTreesIndex) {
517 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> newMatching;
518 for(auto tup : matchingT[i])
519 newMatching.emplace_back(
520 nodeCorr[std::get<0>(tup)], std::get<1>(tup), std::get<2>(tup));
521 matchingT[i] = newMatching;
522 }
523 }
524
525 // ------------------------------------------------------------------------
526 // Update
527 // ------------------------------------------------------------------------
528 bool samePreviousAssignment(int clusterId) {
529 for(unsigned int i = 0; i < bestCentroid_.size(); ++i)
530 if(bestCentroid_[i] == clusterId
531 and bestCentroid_[i] != oldBestCentroid_[i])
532 return false;
533 return true;
534 }
535
536 template <class dataType>
537 bool updateCentroids(std::vector<ftm::FTMTree_MT *> &trees,
538 std::vector<ftm::MergeTree<dataType>> &centroids,
539 std::vector<double> &alphas,
540 std::vector<std::tuple<int, int>> &assignmentC) {
541 bool oneCentroidUpdated = false;
542 int noC = centroids.size();
543 std::vector<std::vector<ftm::FTMTree_MT *>> assignedTrees(noC);
544 std::vector<std::vector<int>> assignedTreesIndex(noC);
545 std::vector<std::vector<double>> assignedAlphas(noC);
546
547 for(auto asgn : assignmentC) {
548 assignedTrees[std::get<0>(asgn)].push_back(trees[std::get<1>(asgn)]);
549 assignedTreesIndex[std::get<0>(asgn)].push_back(std::get<1>(asgn));
550 assignedAlphas[std::get<0>(asgn)].push_back(alphas[std::get<1>(asgn)]);
551 }
552
553 int cpt = 0;
554 std::vector<int> noNewCentroid(centroids.size(), -1);
555 for(unsigned int i = 0; i < centroids.size(); ++i)
556 if(assignedTrees[i].size() == 0) {
557 noNewCentroid[i] = cpt;
558 ++cpt;
559 }
560
561#ifdef TTK_ENABLE_OPENMP4
562#pragma omp parallel num_threads(this->threadNumber_) \
563 shared(centroids) if(parallelize_ and parallelizeUpdate_)
564 {
565#pragma omp single nowait
566 {
567#endif
568 for(unsigned int i = 0; i < centroids.size(); ++i) {
569#ifdef TTK_ENABLE_OPENMP4
570#pragma omp task firstprivate(i) shared(centroids)
571 {
572#endif
573 if(assignedTrees[i].size() == 0) {
574 // Init new centroid if no trees are assigned to it
575 initNewCentroid<dataType>(
576 trees, centroids[i], noNewCentroid[i]);
577 for(unsigned int t = 0; t < trees.size(); ++t)
578 lowerBound_[t][i] = 0;
579 } else if(assignedTrees[i].size() == 1) {
580 centroids[i]
581 = ftm::copyMergeTree<dataType>(assignedTrees[i][0], true);
582 limitSizeBarycenter(centroids[i], assignedTrees[i]);
583 ftm::cleanMergeTree<dataType>(centroids[i]);
584 } else if(not samePreviousAssignment(i)) {
585 // Do not update if same previous assignment
586 // And compute barycenter of the assigned trees otherwise
587 oneCentroidUpdated = true;
588 double alphasSum = 0;
589 for(unsigned int j = 0; j < assignedAlphas[i].size(); ++j)
590 alphasSum += assignedAlphas[i][j];
591 for(unsigned int j = 0; j < assignedAlphas[i].size(); ++j)
592 assignedAlphas[i][j] /= alphasSum;
593 treesMatchingVector matching(assignedTrees[i].size());
594 computeOneBarycenter<dataType>(
595 assignedTrees[i], centroids[i], assignedAlphas[i], matching);
596 std::vector<ftm::idNode> deletedNodesT;
597 persistenceThresholding<dataType>(
598 &(centroids[i].tree), 0, deletedNodesT);
599 ftm::cleanMergeTree<dataType>(centroids[i]);
600 }
601#ifdef TTK_ENABLE_OPENMP4
602 } // pragma omp task
603#endif
604 }
605#ifdef TTK_ENABLE_OPENMP4
606#pragma omp taskwait
607 } // pragma omp single nowait
608 } // pragma omp parallel
609#endif
610 return oneCentroidUpdated;
611 }
612
613 template <class dataType>
615 std::vector<ftm::FTMTree_MT *> &trees,
616 ftm::MergeTree<dataType> &baryMergeTree,
617 std::vector<double> &alphas,
618 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
619 &finalMatchings) {
620 MergeTreeBarycenter mergeTreeBary;
621 mergeTreeBary.setDebugLevel(std::min(debugLevel_, 2));
622 mergeTreeBary.setBranchDecomposition(true);
624 mergeTreeBary.setKeepSubtree(keepSubtree_);
626 mergeTreeBary.setIsCalled(true);
627 mergeTreeBary.setThreadNumber(this->threadNumber_);
628 mergeTreeBary.setDistanceSquaredRoot(true); // squared root
630 mergeTreeBary.setDeterministic(deterministic_);
631 mergeTreeBary.setTol(tol_);
635
636 mergeTreeBary.computeBarycenter<dataType>(
637 trees, baryMergeTree, alphas, finalMatchings);
638
639 addDeletedNodesTime_ += mergeTreeBary.getAddDeletedNodesTime();
640 }
641
642 // ------------------------------------------------------------------------
643 // Main Functions
644 // ------------------------------------------------------------------------
645 template <class dataType>
646 void computeCentroids(std::vector<ftm::FTMTree_MT *> &trees,
647 std::vector<ftm::MergeTree<dataType>> &centroids,
648 matchingVectorType &outputMatching,
649 std::vector<double> &alphas,
650 std::vector<int> &clusteringAssignment,
651 std::vector<ftm::FTMTree_MT *> &trees2,
652 std::vector<ftm::MergeTree<dataType>> &centroids2,
653 matchingVectorType &outputMatching2) {
654 Timer t_clust;
655
656 printCentroidsStats(centroids, centroids2);
657
658 // Run
659 int noCentroidsT = centroids.size();
660 bool converged = false;
661 dataType inertia = -1;
662 dataType minInertia = std::numeric_limits<dataType>::max();
663 int cptBlocked = 0;
664 noIterationC_ = 0;
665 std::vector<std::tuple<int, int>> assignmentC;
666 std::vector<dataType> bestDistanceT(
667 trees.size(), std::numeric_limits<dataType>::max());
668 while(not converged) {
669 ++noIterationC_;
670
672 std::stringstream ssIter;
673 ssIter << "Iteration " << noIterationC_;
674 printMsg(ssIter.str());
675
676 // --- Assignment
677 Timer t_assignment;
678 assignmentCentroids<dataType>(
679 trees, centroids, assignmentC, bestDistanceT, trees2, centroids2);
680 auto t_assignment_time = t_assignment.getElapsedTime();
681 printMsg("Assignment", 1, t_assignment_time, this->threadNumber_);
682
683 // --- Update
684 Timer t_update;
685 bool trees1Updated = true, trees2Updated = true;
686 trees1Updated
687 = updateCentroids<dataType>(trees, centroids, alphas, assignmentC);
688 if(trees2.size() != 0)
689 trees2Updated = updateCentroids<dataType>(
690 trees2, centroids2, alphas, assignmentC);
691 auto t_update_time = t_update.getElapsedTime();
692 printMsg("Update", 1, t_update_time, this->threadNumber_);
693 printCentroidsStats(centroids, centroids2);
694
695 // --- Check convergence
696 dataType currentInertia = 0;
697 for(auto distance : bestDistanceT)
698 currentInertia += distance * distance;
699 converged = std::abs((double)(inertia - currentInertia)) < 0.01;
700 inertia = currentInertia;
701 std::stringstream ss3;
702 ss3 << "Inertia : " << inertia;
703 printMsg(ss3.str());
704
705 minInertia = std::min(minInertia, inertia);
706 if(not converged) {
707 cptBlocked += (minInertia < inertia) ? 1 : 0;
708 converged = (cptBlocked >= 10);
709 }
710
711 // Converged if barycenters were not updated (same assignment than last
712 // iteration)
713 converged = converged or (not trees1Updated and not trees2Updated);
714
715 // --- Reset vectors
716 if(not converged) {
717 assignmentC.clear();
718 bestDistanceT.clear();
719 bestDistanceT.resize(
720 trees.size(), std::numeric_limits<dataType>::max());
721 }
722 }
723
724 // Final processing
726 printMsg("Final assignment");
727 matchingVectorType matchingsC(noCentroidsT);
728 matchingVectorType matchingsC2(noCentroidsT);
729 finalAssignmentCentroids<dataType>(trees, centroids, matchingsC,
730 assignmentC, bestDistanceT, trees2,
731 centroids2, matchingsC2);
732 for(auto dist : bestDistanceT)
733 finalDistances_.push_back(dist);
734 dataType currentInertia = 0;
735 for(auto distance : bestDistanceT)
736 currentInertia += distance * distance;
737 std::stringstream ss;
738 ss << "Inertia : " << currentInertia;
739 printMsg(ss.str());
740
741 // Manage output
742 std::vector<int> cptCentroid(centroids.size(), 0);
743 for(auto asgn : assignmentC) {
744 int const centroid = std::get<0>(asgn);
745 int const tree = std::get<1>(asgn);
746 // std::cout << centroid << " " << tree << std::endl;
747 clusteringAssignment[tree] = centroid;
748 outputMatching[centroid][tree]
749 = matchingsC[centroid][cptCentroid[centroid]];
750 if(trees2.size() != 0)
751 outputMatching2[centroid][tree]
752 = matchingsC2[centroid][cptCentroid[centroid]];
753 ++cptCentroid[centroid];
754 }
755
756 auto clusteringTime = t_clust.getElapsedTime() - addDeletedNodesTime_;
757 printMsg("Total", 1, clusteringTime, this->threadNumber_);
758 }
759
760 template <class dataType>
761 void execute(std::vector<ftm::MergeTree<dataType>> &trees,
762 matchingVectorType &outputMatching,
763 std::vector<double> &alphas,
764 std::vector<int> &clusteringAssignment,
765 std::vector<ftm::MergeTree<dataType>> &trees2,
766 matchingVectorType &outputMatching2,
767 std::vector<ftm::MergeTree<dataType>> &centroids,
768 std::vector<ftm::MergeTree<dataType>> &centroids2) {
769 // --- Preprocessing
770 // std::vector<ftm::FTMTree_MT*> oldTrees, oldTrees2;
771 treesNodeCorr_.resize(trees.size());
772 preprocessingClustering<dataType>(trees, treesNodeCorr_);
773 if(trees2.size() != 0) {
774 trees2NodeCorr_.resize(trees2.size());
775 preprocessingClustering<dataType>(trees2, trees2NodeCorr_, false);
776 }
777 std::vector<ftm::FTMTree_MT *> treesT;
778 ftm::mergeTreeToFTMTree<dataType>(trees, treesT);
779 std::vector<ftm::FTMTree_MT *> treesT2;
780 ftm::mergeTreeToFTMTree<dataType>(trees2, treesT2);
781 useDoubleInput_ = (trees2.size() != 0);
782
783 // --- Init centroids
784 std::vector<std::vector<ftm::MergeTree<dataType>>> allCentroids;
785 initCentroids<dataType>(treesT, treesT2, allCentroids);
786 centroids = allCentroids[0];
787 if(trees2.size() != 0)
788 centroids2 = allCentroids[1];
789 /*for(unsigned int i = 0; i < centroids.size(); ++i){
790 verifyBranchDecompositionInconsistency<dataType>(centroids[i]->tree);
791 if(trees2.size() != 0)
792 verifyBranchDecompositionInconsistency<dataType>(centroids2[i]->tree);
793 }*/
794
795 // --- Init accelerated kmeans
796 initAcceleratedKMeansVectors<dataType>(treesT, centroids, treesT2);
797
798 // --- Execute
799 computeCentroids<dataType>(treesT, centroids, outputMatching, alphas,
800 clusteringAssignment, treesT2, centroids2,
801 outputMatching2);
802
803 // --- Postprocessing
804 if(postprocess_) {
805 // fixMergedRootOriginClustering<dataType>(centroids);
806 postprocessingClustering<dataType>(
807 trees, centroids, outputMatching, clusteringAssignment);
808 /*if(trees2.size() != 0){
809 putBackMinMaxPair<dataType>(centroids, centroids2);
810 postprocessingClustering<dataType>(trees2, centroids2,
811 outputMatching2, clusteringAssignment);
812 }*/
813 }
814 }
815
816 template <class dataType>
817 void execute(std::vector<ftm::MergeTree<dataType>> &trees,
818 matchingVectorType &outputMatching,
819 std::vector<int> &clusteringAssignment,
820 std::vector<ftm::MergeTree<dataType>> &trees2,
821 matchingVectorType &outputMatching2,
822 std::vector<ftm::MergeTree<dataType>> &centroids,
823 std::vector<ftm::MergeTree<dataType>> &centroids2) {
824 if(trees2.size() != 0)
825 printMsg("Use join and split trees");
826
827 std::vector<double> alphas;
828 for(unsigned int i = 0; i < trees.size(); ++i)
829 alphas.push_back(1.0 / trees.size());
830
831 execute<dataType>(trees, outputMatching, alphas, clusteringAssignment,
832 trees2, outputMatching2, centroids, centroids2);
833 }
834
835 template <class dataType>
836 void execute(std::vector<ftm::MergeTree<dataType>> &trees,
837 matchingVectorType &outputMatching,
838 std::vector<int> &clusteringAssignment,
839 std::vector<ftm::MergeTree<dataType>> &centroids) {
840 std::vector<ftm::MergeTree<dataType>> trees2, centroids2;
841 matchingVectorType outputMatching2 = matchingVectorType();
842 execute<dataType>(trees, outputMatching, clusteringAssignment, trees2,
843 outputMatching2, centroids, centroids2);
844 }
845
846 // ------------------------------------------------------------------------
847 // Preprocessing
848 // ------------------------------------------------------------------------
849 template <class dataType>
851 std::vector<std::vector<int>> &nodeCorr,
852 bool useMinMaxPairT = true) {
853 for(unsigned int i = 0; i < trees.size(); ++i) {
854 preprocessingPipeline<dataType>(
856 branchDecomposition_, useMinMaxPairT, cleanTree_, nodeCorr[i]);
857 if(trees.size() < 40)
858 printTreeStats(trees[i]);
859 }
860 printTreesStats(trees);
861 }
862
863 // ------------------------------------------------------------------------
864 // Postprocessing
865 // ------------------------------------------------------------------------
866 template <class dataType>
868 std::vector<ftm::MergeTree<dataType>> &centroids) {
869 for(unsigned int i = 0; i < centroids.size(); ++i)
870 fixMergedRootOriginBarycenter<dataType>(centroids[i]);
871 }
872
873 template <class dataType>
874 void putBackMinMaxPair(std::vector<ftm::MergeTree<dataType>> &centroids,
875 std::vector<ftm::MergeTree<dataType>> &centroids2) {
876 for(unsigned int i = 0; i < centroids2.size(); ++i)
877 copyMinMaxPair(centroids[i], centroids2[i]);
878 }
879
880 template <class dataType>
881 void
883 std::vector<ftm::MergeTree<dataType>> &centroids,
884 matchingVectorType &outputMatching,
885 std::vector<int> &clusteringAssignment) {
886 for(unsigned int i = 0; i < trees.size(); ++i)
887 postprocessingPipeline<dataType>(&(trees[i].tree));
888 for(unsigned int i = 0; i < centroids.size(); ++i)
889 postprocessingPipeline<dataType>(&(centroids[i].tree));
890 for(unsigned int c = 0; c < centroids.size(); ++c)
891 for(unsigned int i = 0; i < trees.size(); ++i)
892 if(clusteringAssignment[i] == (int)c)
893 convertBranchDecompositionMatching<dataType>(
894 &(centroids[c].tree), &(trees[i].tree), outputMatching[c][i]);
895 }
896
897 // ------------------------------------------------------------------------
898 // Utils
899 // ------------------------------------------------------------------------
900 template <class dataType>
901 void
903 std::vector<ftm::MergeTree<dataType>> &centroids2) {
904 for(auto &centroid : centroids)
905 printBaryStats(&(centroid.tree), debug::Priority::DETAIL);
906 for(auto &centroid : centroids2)
907 printBaryStats(&(centroid.tree), debug::Priority::DETAIL);
908 }
909
910 }; // MergeTreeClustering class
911
912} // namespace ttk
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define matchingVectorType
#define treesMatchingVector
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
Minimalist debugging class.
Definition Debug.h:88
int debugLevel_
Definition Debug.h:379
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double percent, bool useBD=true)
void computeBarycenter(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
unsigned int barycenterMaximumNumberOfPairs_
void setDeterministic(bool deterministicT)
std::vector< double > finalDistances_
void printBaryStats(ftm::FTMTree_MT *baryTree, const debug::Priority &priority=debug::Priority::INFO)
void setProgressiveBarycenter(bool progressive)
void setBarycenterMaximumNumberOfPairs(unsigned int maxi)
void setBarycenterSizeLimitPercent(double percent)
void setBranchDecomposition(bool useBD)
void setNormalizedWasserstein(bool normalizedWasserstein)
void setDistanceSquaredRoot(bool distanceSquaredRoot)
void setAssignmentSolver(int assignmentSolver)
void printTreesStats(std::vector< ftm::FTMTree_MT * > &trees)
void copyMinMaxPair(ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree2, bool setOrigins=false)
std::vector< std::vector< int > > treesNodeCorr_
void setKeepSubtree(bool keepSubtree)
void mixDistancesMatrix(std::vector< std::vector< dataType > > &distanceMatrix, std::vector< std::vector< dataType > > &distanceMatrix2)
void computeCentroids(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, matchingVectorType &outputMatching, std::vector< double > &alphas, std::vector< int > &clusteringAssignment, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2, matchingVectorType &outputMatching2)
void assignmentCentroidsNaive(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< std::tuple< int, int > > &assignmentC, std::vector< dataType > &bestDistanceT, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2)
void copyCentroids(std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::MergeTree< dataType > > &oldCentroids)
void setMixtureCoefficient(double coef)
void postprocessingClustering(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, matchingVectorType &outputMatching, std::vector< int > &clusteringAssignment)
void initAcceleratedKMeans(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2)
void fixMergedRootOriginClustering(std::vector< ftm::MergeTree< dataType > > &centroids)
void getCentroidsDistanceMatrix(std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
void initAcceleratedKMeansVectors(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::FTMTree_MT * > &ttkNotUsed(trees2))
void printCentroidsStats(std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::MergeTree< dataType > > &centroids2)
void computeOneBarycenter(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings)
void setNoCentroids(unsigned int noCentroidsT)
void preprocessingClustering(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< int > > &nodeCorr, bool useMinMaxPairT=true)
void finalAssignmentCentroids(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, matchingVectorType &matchingsC, std::vector< std::tuple< int, int > > &assignmentC, std::vector< dataType > &bestDistanceT, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2, matchingVectorType &matchingsC2)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, matchingVectorType &outputMatching, std::vector< double > &alphas, std::vector< int > &clusteringAssignment, std::vector< ftm::MergeTree< dataType > > &trees2, matchingVectorType &outputMatching2, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::MergeTree< dataType > > &centroids2)
std::vector< std::vector< int > > getTrees2NodeCorr()
void assignmentCentroidsAccelerated(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< std::tuple< int, int > > &assignmentC, std::vector< dataType > &bestDistanceT, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, matchingVectorType &outputMatching, std::vector< int > &clusteringAssignment, std::vector< ftm::MergeTree< dataType > > &centroids)
void putBackMinMaxPair(std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::MergeTree< dataType > > &centroids2)
bool samePreviousAssignment(int clusterId)
bool updateCentroids(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< double > &alphas, std::vector< std::tuple< int, int > > &assignmentC)
void initCentroids(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< std::vector< ftm::MergeTree< dataType > > > &allCentroids)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, matchingVectorType &outputMatching, std::vector< int > &clusteringAssignment, std::vector< ftm::MergeTree< dataType > > &trees2, matchingVectorType &outputMatching2, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< ftm::MergeTree< dataType > > &centroids2)
~MergeTreeClustering() override=default
void matchingCorrespondence(treesMatchingVector &matchingT, std::vector< int > &nodeCorr, std::vector< int > &assignedTreesIndex)
void initNewCentroid(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &centroid, int noNewCentroid)
void assignmentCentroids(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &centroids, std::vector< std::tuple< int, int > > &assignmentC, std::vector< dataType > &bestDistanceT, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< ftm::MergeTree< dataType > > &centroids2)
double getElapsedTime()
Definition Timer.h:15
The Topology ToolKit.
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)