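// TTK - MergeTreeBarycenter.h
// (Source listing extracted from the generated documentation page; the number
// at the start of each following line is the line number in the original
// header, and some original lines are absent from this listing.)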
1
14
15#pragma once
16
17#include <random>
18
19// ttk common includes
20#include <Debug.h>
21#include <Triangulation.h>
22
23#include "MergeTreeBase.h"
24#include "MergeTreeDistance.h"
25
26namespace ttk {
27
32 class MergeTreeBarycenter : virtual public Debug, public MergeTreeBase {
33
34 protected:
35 double tol_ = 0.0;
36 bool addNodes_ = true;
37 bool deterministic_ = true;
40 bool isCalled_ = false;
43 double alpha_ = 0.5;
46
47 double allDistanceTime_ = 0;
48
50
51 bool preprocess_ = true;
52 bool postprocess_ = true;
53
54 // Output
55 std::vector<double> finalDistances_;
56
57 public:
60 "MergeTreeBarycenter"); // inherited from Debug: prefix will be printed
61 // at the beginning of every msg
62#ifdef TTK_ENABLE_OPENMP4
63 omp_set_max_active_levels(100);
64#endif
65 }
// Destructor: overrides a base-class destructor and uses the
// compiler-generated implementation.
66 ~MergeTreeBarycenter() override = default;
67
68 void setTol(double tolT) {
69 tol_ = tolT;
70 }
71
72 void setAddNodes(bool addNodesT) {
73 addNodes_ = addNodesT;
74 }
75
76 void setDeterministic(bool deterministicT) {
77 deterministic_ = deterministicT;
78 }
79
80 void setBarycenterInitIndex(int barycenterInitIndex) {
81 barycenterInitIndex_ = barycenterInitIndex;
82 }
83
84 void setBarycenterMaxIter(int barycenterMaxIter) {
85 barycenterMaxIter_ = barycenterMaxIter;
86 }
87
88 void setProgressiveBarycenter(bool progressive) {
89 progressiveBarycenter_ = progressive;
90 }
91
92 void setProgressiveSpeedDivisor(double progSpeed) {
93 progressiveSpeedDivisor_ = progSpeed;
94 }
95
96 void setIsCalled(bool ic) {
97 isCalled_ = ic;
98 }
99
101 return allDistanceTime_;
102 }
103
106 }
107
108 void setAlpha(double alpha) {
109 alpha_ = alpha;
110 }
111
// Setter taking an unsigned count `maxi`.
// NOTE(review): the statement forming the body (original line 113) is missing
// from this listing; presumably it stores `maxi` in a member - confirm against
// the real header.
112 void setBarycenterMaximumNumberOfPairs(unsigned int maxi) {
114 }
115
// Setter taking a double `percent`.
// NOTE(review): the statement forming the body (original line 117) is missing
// from this listing; presumably it stores `percent` in a member - confirm
// against the real header.
116 void setBarycenterSizeLimitPercent(double percent) {
118 }
119
120 void setPreprocess(bool preproc) {
121 preprocess_ = preproc;
122 }
123
124 void setPostprocess(bool postproc) {
125 postprocess_ = postproc;
126 }
127
128 std::vector<double> getFinalDistances() {
129 return finalDistances_;
130 }
131
135 // ------------------------------------------------------------------------
136 // Initialization
137 // ------------------------------------------------------------------------
138 template <class dataType>
139 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
140 std::vector<ftm::FTMTree_MT *> &trees2,
141 std::vector<std::vector<double>> &distanceMatrix,
142 bool useDoubleInput = false,
143 bool isFirstInput = true) {
144 distanceMatrix.clear();
145 distanceMatrix.resize(trees.size(), std::vector<double>(trees.size(), 0));
146#ifdef TTK_ENABLE_OPENMP4
147#pragma omp parallel for schedule(dynamic) \
148 num_threads(this->threadNumber_) if(parallelize_)
149#endif
150 for(unsigned int i = 0; i < trees.size(); ++i)
151 for(unsigned int j = i + 1; j < trees.size(); ++j) {
152 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
153 dataType distance;
154 computeOneDistance<dataType>(trees[i], trees2[j], matching, distance,
155 useDoubleInput, isFirstInput);
156 distanceMatrix[i][j] = distance;
157 distanceMatrix[j][i] = distance;
158 }
159 }
160
// Single-tree-set overload: the visible argument list passes `trees` for both
// tree arguments and forwards distanceMatrix, useDoubleInput and isFirstInput
// unchanged.
// NOTE(review): the line naming the callee (original line 166) is missing
// from this listing; presumably it is the two-set getDistanceMatrix - confirm.
161 template <class dataType>
162 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
163 std::vector<std::vector<double>> &distanceMatrix,
164 bool useDoubleInput = false,
165 bool isFirstInput = true) {
167 trees, trees, distanceMatrix, useDoubleInput, isFirstInput);
168 }
169
170 template <class dataType>
172 std::vector<ftm::FTMTree_MT *> &trees,
173 unsigned int barycenterMaximumNumberOfPairs,
174 double sizeLimitPercent,
175 std::vector<ftm::MergeTree<dataType>> &mTreesLimited) {
176 mTreesLimited.resize(trees.size());
177#ifdef TTK_ENABLE_OPENMP4
178#pragma omp parallel for schedule(dynamic) \
179 num_threads(this->threadNumber_) if(parallelize_)
180#endif
181 for(unsigned int i = 0; i < trees.size(); ++i) {
182 mTreesLimited[i] = ftm::copyMergeTree<dataType>(trees[i]);
183 limitSizeBarycenter(mTreesLimited[i], trees,
184 barycenterMaximumNumberOfPairs, sizeLimitPercent);
185 ftm::cleanMergeTree<dataType>(mTreesLimited[i]);
186 }
187 }
188
189 template <class dataType>
191 std::vector<ftm::FTMTree_MT *> &trees,
192 std::vector<std::vector<double>> &distanceMatrix,
193 unsigned int barycenterMaximumNumberOfPairs,
194 double sizeLimitPercent,
195 bool useDoubleInput = false,
196 bool isFirstInput = true) {
197 std::vector<ftm::MergeTree<dataType>> mTreesLimited;
199 trees, barycenterMaximumNumberOfPairs, sizeLimitPercent, mTreesLimited);
200 std::vector<ftm::FTMTree_MT *> treesLimited;
201 ftm::mergeTreeToFTMTree<dataType>(mTreesLimited, treesLimited);
203 trees, treesLimited, distanceMatrix, useDoubleInput, isFirstInput);
204 }
205
206 template <class dataType>
208 std::vector<ftm::FTMTree_MT *> &trees,
209 std::vector<std::vector<double>> &distanceMatrix,
210 unsigned int barycenterMaximumNumberOfPairs,
211 double sizeLimitPercent,
212 bool useDoubleInput = false,
213 bool isFirstInput = true) {
214 if(barycenterMaximumNumberOfPairs <= 0 and sizeLimitPercent <= 0.0)
216 trees, distanceMatrix, useDoubleInput, isFirstInput);
217 else
219 trees, distanceMatrix, barycenterMaximumNumberOfPairs,
220 sizeLimitPercent, useDoubleInput, isFirstInput);
221 }
222
// Choose the index of the input tree used to initialise the barycenter.
//
// A distance matrix is computed for `trees` (and a second one for `trees2`
// when it is non-empty). For each tree i the distances of row i are summed;
// when both inputs are used, the two matrices are combined entry-wise with
// mixDistances. With deterministic_ set, the returned index is the one with
// the smallest sum (distMinimizer == true) or the largest sum (false).
// Otherwise the index is drawn at random from a discrete distribution built
// on `sizes`.
//
// NOTE(review): original lines 230 and 237 are missing from this listing: the
// statement guarded by the `barycenterInitIndex_ != -1` test, and the callee
// of the second distance-matrix computation.
223 template <class dataType>
224 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
225 std::vector<ftm::FTMTree_MT *> &trees2,
226 unsigned int barycenterMaximumNumberOfPairs,
227 double sizeLimitPercent,
228 bool distMinimizer = true) {
229 if(barycenterInitIndex_ != -1)
231 std::vector<std::vector<double>> distanceMatrix, distanceMatrix2;
// The second input is taken into account only when trees2 is non-empty.
232 bool const useDoubleInput = (trees2.size() != 0);
233 getParametrizedDistanceMatrix<dataType>(trees, distanceMatrix,
234 barycenterMaximumNumberOfPairs,
235 sizeLimitPercent, useDoubleInput);
236 if(trees2.size() != 0)
238 trees2, distanceMatrix2, barycenterMaximumNumberOfPairs,
239 sizeLimitPercent, useDoubleInput, false);
240
// bestIndex stays -1 if no row improves on the initial bestValue (in
// particular when `trees` is empty).
241 int bestIndex = -1;
242 dataType bestValue
243 = distMinimizer ? std::numeric_limits<dataType>::max() : 0;
244 std::vector<int> sizes(trees.size());
245 for(unsigned int i = 0; i < trees.size(); ++i) {
// value = sum of the distances from tree i to every other tree.
246 dataType value = 0;
247 for(unsigned int j = 0; j < distanceMatrix[i].size(); ++j)
248 value += (not useDoubleInput ? distanceMatrix[i][j]
249 : mixDistances(distanceMatrix[i][j],
250 distanceMatrix2[i][j]));
251 if((distMinimizer and value < bestValue)
252 or (not distMinimizer and value > bestValue)) {
253 bestIndex = i;
254 bestValue = value;
255 }
// sizes[i] ends up as -value when minimising and +value when maximising,
// truncated to int.
// NOTE(review): when distMinimizer is true these weights are <= 0 (assuming
// distances are non-negative), but std::discrete_distribution requires
// non-negative weights with a positive sum - the random branch below looks
// ill-defined in that case; confirm the intended weighting.
256 sizes[i] = -value;
257 sizes[i] *= (distMinimizer) ? 1 : -1;
258 }
// Non-deterministic mode: replace the arg-min/arg-max by a random draw
// weighted by `sizes`.
259 if(not deterministic_) {
260 std::random_device rd;
261 std::default_random_engine generator(rd());
262 std::discrete_distribution<int> distribution(
263 sizes.begin(), sizes.end());
264 bestIndex = distribution(generator);
265 }
266 return bestIndex;
267 }
268
// Overload without the barycenterMaximumNumberOfPairs parameter: forwards
// trees, trees2, sizeLimitPercent and distMinimizer to another
// getBestInitTreeIndex overload.
// NOTE(review): the argument passed between `trees2` and `sizeLimitPercent`
// (original line 275) is missing from this listing.
269 template <class dataType>
270 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
271 std::vector<ftm::FTMTree_MT *> &trees2,
272 double sizeLimitPercent,
273 bool distMinimizer = true) {
274 return getBestInitTreeIndex<dataType>(trees, trees2,
276 sizeLimitPercent, distMinimizer);
277 }
278
// Single-input overload: builds an empty `trees2` and passes it, together
// with barycenterMaximumNumberOfPairs_ and barycenterSizeLimitPercent_, to
// the call whose argument list is visible below.
// NOTE(review): the line naming the callee (original line 283) is missing
// from this listing; presumably it returns the result of the five-parameter
// getBestInitTreeIndex overload - confirm.
279 template <class dataType>
280 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
281 bool distMinimizer = true) {
282 std::vector<ftm::FTMTree_MT *> trees2;
284 trees, trees2, barycenterMaximumNumberOfPairs_,
285 barycenterSizeLimitPercent_, distMinimizer);
286 }
287
288 template <class dataType>
289 void initBarycenterTree(std::vector<ftm::FTMTree_MT *> &trees,
290 ftm::MergeTree<dataType> &baryTree,
291 bool distMinimizer = true) {
292 int const bestIndex
293 = getBestInitTreeIndex<dataType>(trees, distMinimizer);
294 baryTree = ftm::copyMergeTree<dataType>(trees[bestIndex], true);
295 limitSizeBarycenter(baryTree, trees);
296 }
297
298 // ------------------------------------------------------------------------
299 // Update
300 // ------------------------------------------------------------------------
// Collect the nodes and scalars of a subtree of `tree` that must be added.
//
// Starting from nodeId2, the subtree of `tree` is walked breadth-first. For
// every visited node, two scalars are appended to newScalarsVector (the value
// of the node's origin, then the value of the node itself), and every child is
// recorded in nodesToProcess as (child, id its parent will get, i). The start
// node is recorded as (nodeId2, nodeId1, i). nodeCpt is advanced by 2 per
// visited node and its final value is returned.
//
// NOTE: the declaration line of this function (original line 315) is absent
// from this listing; the remaining lines keep the listing's format.
314 template <class dataType>
316 ftm::idNode nodeId1,
317 ftm::FTMTree_MT *tree,
318 ftm::idNode nodeId2,
319 std::vector<dataType> &newScalarsVector,
320 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
321 ftm::idNode nodeCpt,
322 int i) {
323 // Get nodes and scalars to add
324 std::queue<std::tuple<ftm::idNode, ftm::idNode>> queue;
325 queue.emplace(nodeId2, nodeId1);
326 nodesToProcess.emplace_back(nodeId2, nodeId1, i);
327 while(!queue.empty()) {
// Copy the front element before popping: a reference returned by front()
// refers to an element that pop() destroys, so reading it afterwards would
// be undefined behaviour.
328 auto const queueTuple = queue.front();
329 queue.pop();
330 ftm::idNode const node = std::get<0>(queueTuple);
331 // Get scalars
332 newScalarsVector.push_back(
333 tree->getValue<dataType>(tree->getNode(node)->getOrigin()));
334 newScalarsVector.push_back(tree->getValue<dataType>(node));
335 // Process children
336 std::vector<ftm::idNode> children;
337 tree->getChildren(node, children);
338 for(auto child : children) {
// nodeCpt + 1 is the id recorded as parent for the children of `node`.
339 queue.emplace(child, nodeCpt + 1);
340 nodesToProcess.emplace_back(child, nodeCpt + 1, i);
341 }
342 nodeCpt += 2; // we will add two nodes (birth and death)
343 }
344
345 return nodeCpt;
346 }
347
348 template <class dataType>
351 int noTrees,
352 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
353 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
354 &nodesProcessed) {
355 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
356
357 // Add nodes
358 nodesProcessed.clear();
359 nodesProcessed.resize(noTrees);
360 for(auto &processTuple : nodesToProcess) {
361 ftm::idNode const parent = std::get<1>(processTuple);
362 ftm::idNode const nodeTree1 = tree1->getNumberOfNodes();
363 int const index = std::get<2>(processTuple);
364 nodesProcessed[index].emplace_back(
365 nodeTree1 + 1, std::get<0>(processTuple));
366 // Make node and its origin
367 tree1->makeNode(nodeTree1);
368 tree1->makeNode(nodeTree1 + 1);
369 tree1->setParent(nodeTree1 + 1, parent);
370 tree1->getNode(nodeTree1)->setOrigin(nodeTree1 + 1);
371 tree1->getNode(nodeTree1 + 1)->setOrigin(nodeTree1);
372 }
373 }
374
375 template <class dataType>
378 int noTrees,
379 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
380 std::vector<dataType> &newScalarsVector,
381 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
382 &nodesProcessed) {
383 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
384
385 // Create new tree
387 = ftm::createEmptyMergeTree<dataType>(newScalarsVector.size());
388 ftm::setTreeScalars<dataType>(mTreeNew, newScalarsVector);
389 ftm::FTMTree_MT *treeNew = &(mTreeNew.tree);
390
391 // Copy the old tree structure
392 treeNew->copyMergeTreeStructure(tree1);
393
394 // Add nodes in the other trees
395 addNodes<dataType>(mTreeNew, noTrees, nodesToProcess, nodesProcessed);
396
397 // Copy new tree
398 mTree1 = mTreeNew;
399 }
400
401 template <class dataType>
403 std::vector<ftm::FTMTree_MT *> &trees,
404 ftm::MergeTree<dataType> &baryMergeTree,
405 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
406 &matchings) {
407 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
408 ftm::idNode const baryTreeRoot = baryTree->getRoot();
409
410 // Init matching matrix
411 // m[i][j] contains the node in the barycenter matched to the jth node of
412 // the ith tree
413 std::vector<std::vector<ftm::idNode>> matrixMatchings(trees.size());
414 std::vector<bool> baryMatched(baryTree->getNumberOfNodes(), false);
415 for(unsigned int i = 0; i < matchings.size(); ++i) {
416 auto &matching = matchings[i];
417 matrixMatchings[i].resize(trees[i]->getNumberOfNodes(),
418 std::numeric_limits<ftm::idNode>::max());
419 for(auto &match : matching) {
420 matrixMatchings[i][std::get<1>(match)] = std::get<0>(match);
421 baryMatched[std::get<0>(match)] = true;
422 }
423 }
424
425 // Iterate through trees to get the nodes to add in the barycenter
426 std::vector<std::vector<ftm::idNode>> nodesToAdd(trees.size());
427#ifdef TTK_ENABLE_OPENMP4
428#pragma omp parallel for schedule(dynamic) \
429 num_threads(this->threadNumber_) if(parallelize_)
430#endif
431 for(unsigned int i = 0; i < trees.size(); ++i) {
432 ftm::idNode const root = trees[i]->getRoot();
433 std::queue<ftm::idNode> queue;
434 queue.emplace(root);
435 while(!queue.empty()) {
436 ftm::idNode const node = queue.front();
437 queue.pop();
438 bool processChildren = true;
439 // if node in trees[i] is not matched
440 if(matrixMatchings[i][node]
441 == std::numeric_limits<ftm::idNode>::max()) {
442 if(not keepSubtree_) {
443 processChildren = false;
444 nodesToAdd[i].push_back(node);
445 } else {
446 // not todo manage if keepSubtree=true (not important since it is
447 // not a valid merge tree)
448 printErr(
449 "barycenter with keepSubtree_=true is not implemented yet");
450 }
451 }
452 if(processChildren) {
453 std::vector<ftm::idNode> children;
454 trees[i]->getChildren(node, children);
455 for(auto child : children)
456 if(not(trees[i]->isThereOnlyOnePersistencePair()
457 and trees[i]->isLeaf(child)))
458 queue.emplace(child);
459 }
460 }
461 }
462
463 bool foundRootNotMatched = false;
464 for(unsigned int i = 0; i < trees.size(); ++i)
465 foundRootNotMatched |= baryTree->isNodeIdInconsistent(
466 matrixMatchings[i][trees[i]->getRoot()]);
467 if(foundRootNotMatched)
468 printWrn("[updateBarycenterTreeStructure] an input tree has its root "
469 "not matched.");
470
471 // Delete nodes that are not matched in the barycenter
472 for(unsigned int i = 0; i < baryTree->getNumberOfNodes(); ++i)
473 if(not baryMatched[i])
474 baryTree->deleteNode(i);
475
476 if(not keepSubtree_) {
477 // Add scalars and nodes not present in the barycenter
478 ftm::idNode nodeCpt = baryTree->getNumberOfNodes();
479 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> nodesToProcess;
480 std::vector<dataType> newScalarsVector;
481 ftm::getTreeScalars<dataType>(baryMergeTree, newScalarsVector);
482 for(unsigned int i = 0; i < nodesToAdd.size(); ++i) {
483 for(auto node : nodesToAdd[i]) {
484 ftm::idNode parent
485 = matrixMatchings[i][trees[i]->getParentSafe(node)];
486 if(matchings[i].size() == 0)
487 parent = baryTreeRoot;
488
489 if((baryTree->isNodeIdInconsistent(parent)
490 or baryTree->isNodeAlone(parent))
491 and matchings[i].size() != 0) {
492 std::stringstream ss;
493 ss << trees[i]->getParentSafe(node) << " _ " << node;
494 printMsg(ss.str());
495 printMsg(trees[i]->printTree().str());
496 printMsg(trees[i]->printPairsFromTree<dataType>(true).str());
497 printMatching(matchings[i]);
498 std::stringstream ss2;
499 ss2 << "parent " << parent;
500 printMsg(ss2.str());
501 }
502 /*if(isRoot(trees[i], node))
503 parent = baryTree->getRoot();*/
504 std::vector<dataType> addedScalars;
506 parent, trees[i], node, addedScalars, nodesToProcess, nodeCpt, i);
507 newScalarsVector.insert(
508 newScalarsVector.end(), addedScalars.begin(), addedScalars.end());
509 }
510 }
511 if(addNodes_) {
512 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
513 nodesProcessed;
514 updateNodesAndScalars<dataType>(baryMergeTree, trees.size(),
515 nodesToProcess, newScalarsVector,
516 nodesProcessed);
517 for(unsigned int i = 0; i < matchings.size(); ++i) {
518 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
519 nodesProcessedT;
520 for(auto &tup : nodesProcessed[i])
521 nodesProcessedT.emplace_back(
522 std::get<0>(tup), std::get<1>(tup), -1);
523 matchings[i].insert(matchings[i].end(), nodesProcessedT.begin(),
524 nodesProcessedT.end());
525 }
526 }
527 } else {
528 // not todo manage if keepSubtree=true (not important since it is not a
529 // valid merge tree)
530 printErr("barycenter with keepSubtree_=true is not implemented yet");
531 }
532 }
533
534 template <class dataType>
535 std::tuple<dataType, dataType>
537 std::tuple<dataType, dataType> birthDeath;
538 // Normalized Wasserstein
540 birthDeath = getNormalizedBirthDeath<dataType>(tree1, nodeId1);
541 // Classical Wasserstein
542 else
543 birthDeath = tree1->getBirthDeath<dataType>(nodeId1);
544 return birthDeath;
545 }
546
547 template <class dataType>
548 std::tuple<dataType, dataType>
550 ftm::idNode nodeId,
551 std::vector<dataType> &newScalarsVector,
552 std::vector<ftm::FTMTree_MT *> &trees,
553 std::vector<ftm::idNode> &nodes,
554 std::vector<double> &alphas) {
555 dataType newBirth = 0, newDeath = 0;
556
557 // Compute projection
558 dataType tempBirth = 0, tempDeath = 0;
559 double alphaSum = 0;
560 for(unsigned int i = 0; i < trees.size(); ++i)
561 if(nodes[i] != std::numeric_limits<ftm::idNode>::max())
562 alphaSum += alphas[i];
563 for(unsigned int i = 0; i < trees.size(); ++i) {
564 // if node is matched in trees[i]
565 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
566 auto iBirthDeath
567 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
568 dataType tTempBirth = 0, tTempDeath = 0;
569 tTempBirth += std::get<0>(iBirthDeath);
570 tTempDeath += std::get<1>(iBirthDeath);
571 tempBirth += tTempBirth * alphas[i] / alphaSum;
572 tempDeath += tTempDeath * alphas[i] / alphaSum;
573 }
574 }
575 dataType const projec = (tempBirth + tempDeath) / 2;
576
577 // Compute newBirth and newDeath
578 for(unsigned int i = 0; i < trees.size(); ++i) {
579 dataType iBirth = projec, iDeath = projec;
580 // if node is matched in trees[i]
581 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
582 auto iBirthDeath
583 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
584 iBirth = std::get<0>(iBirthDeath);
585 iDeath = std::get<1>(iBirthDeath);
586 }
587 newBirth += alphas[i] * iBirth;
588 newDeath += alphas[i] * iDeath;
589 }
591 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
592 dataType mu_max = getMinMaxLocalFromVector<dataType>(
593 baryTree, nodeId, newScalarsVector, false);
594 dataType mu_min = getMinMaxLocalFromVector<dataType>(
595 baryTree, nodeId, newScalarsVector);
596 // Forbid compiler optimization to have same results on different
597 // computers
598 volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
599 volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
600 newBirth = tempBirthT + mu_min;
601 newDeath = tempDeathT + mu_min;
602 }
603
604 return std::make_tuple(newBirth, newDeath);
605 }
606
607 template <class dataType>
608 std::tuple<dataType, dataType>
610 ftm::idNode nodeId,
611 double alpha,
612 ftm::MergeTree<dataType> &baryMergeTree,
613 ftm::idNode nodeB,
614 std::vector<dataType> &newScalarsVector) {
615 auto birthDeath = getParametrizedBirthDeath<dataType>(tree, nodeId);
616 dataType newBirth = std::get<0>(birthDeath);
617 dataType newDeath = std::get<1>(birthDeath);
618 dataType const projec = (newBirth + newDeath) / 2;
619
620 newBirth = alpha * newBirth + (1 - alpha) * projec;
621 newDeath = alpha * newDeath + (1 - alpha) * projec;
622
624 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
625 dataType mu_max = getMinMaxLocalFromVector<dataType>(
626 baryTree, nodeB, newScalarsVector, false);
627 dataType mu_min = getMinMaxLocalFromVector<dataType>(
628 baryTree, nodeB, newScalarsVector);
629 // Forbid compiler optimization to have same results on different
630 // computers
631 volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
632 volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
633 newBirth = tempBirthT + mu_min;
634 newDeath = tempDeathT + mu_min;
635 }
636
637 return std::make_tuple(newBirth, newDeath);
638 }
639
640 template <class dataType>
642 std::vector<ftm::FTMTree_MT *> &trees,
643 ftm::MergeTree<dataType> &baryMergeTree,
644 std::vector<double> &alphas,
645 unsigned int indexAddedNodes,
646 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
647 &matchings) {
648 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
649 bool const isJT = baryTree->isJoinTree<dataType>();
650
651 // Init matching matrix
652 // m[i][j] contains the node in trees[j] matched to the node i in the
653 // barycenter
654 std::vector<std::vector<ftm::idNode>> baryMatching(
655 indexAddedNodes,
656 std::vector<ftm::idNode>(
657 trees.size(), std::numeric_limits<ftm::idNode>::max()));
658 std::vector<std::tuple<int, ftm::idNode>> nodesAddedTree(
659 baryTree->getNumberOfNodes(), std::make_tuple(-1, -1));
660 for(unsigned int i = 0; i < matchings.size(); ++i) {
661 auto &matching = matchings[i];
662 for(auto &match : matching) {
663 if(std::get<0>(match) >= indexAddedNodes)
664 // get the tree of this added node
665 nodesAddedTree[std::get<0>(match)]
666 = std::make_tuple(i, std::get<1>(match));
667 else
668 baryMatching[std::get<0>(match)][i] = std::get<1>(match);
669 }
670 }
671
672 // Interpolate scalars
673 std::vector<dataType> newScalarsVector(baryTree->getNumberOfNodes());
674 ftm::idNode const root = baryTree->getRoot();
675 std::queue<ftm::idNode> queue;
676 queue.emplace(root);
677 while(!queue.empty()) {
678 ftm::idNode const node = queue.front();
679 queue.pop();
680 std::tuple<dataType, dataType> newBirthDeath;
681 if(node < indexAddedNodes) {
682 newBirthDeath
683 = interpolation<dataType>(baryMergeTree, node, newScalarsVector,
684 trees, baryMatching[node], alphas);
685 } else {
686 int const i = std::get<0>(nodesAddedTree[node]);
687 ftm::idNode const nodeT = std::get<1>(nodesAddedTree[node]);
688 newBirthDeath = interpolationAdded<dataType>(
689 trees[i], nodeT, alphas[i], baryMergeTree, node, newScalarsVector);
690 }
691 dataType nodeScalar
692 = (isJT ? std::get<1>(newBirthDeath) : std::get<0>(newBirthDeath));
693 dataType nodeOriginScalar
694 = (isJT ? std::get<0>(newBirthDeath) : std::get<1>(newBirthDeath));
695 newScalarsVector[node] = nodeScalar;
696 newScalarsVector[baryTree->getNode(node)->getOrigin()]
697 = nodeOriginScalar;
698 std::vector<ftm::idNode> children;
699 baryTree->getChildren(node, children);
700 for(auto child : children)
701 queue.emplace(child);
702 }
703
704 if(baryMergeTree.tree.isFullMerge()) {
705 auto mergedRootOrigin = baryTree->getMergedRootOrigin<dataType>();
706 dataType mergedRootOriginScalar = 0.0;
707 for(unsigned int i = 0; i < trees.size(); ++i)
708 mergedRootOriginScalar += trees[i]->getValue<dataType>(
709 trees[i]->getMergedRootOrigin<dataType>());
710 mergedRootOriginScalar /= trees.size();
711 newScalarsVector[mergedRootOrigin] = mergedRootOriginScalar;
712 }
713
714 setTreeScalars(baryMergeTree, newScalarsVector);
715 std::vector<ftm::idNode> deletedNodesT;
717 &(baryMergeTree.tree), 0, deletedNodesT);
718 limitSizeBarycenter(baryMergeTree, trees);
719 ftm::cleanMergeTree<dataType>(baryMergeTree);
720 }
721
722 template <class dataType>
724 std::vector<ftm::FTMTree_MT *> &trees,
725 ftm::MergeTree<dataType> &baryMergeTree,
726 std::vector<double> &alphas,
727 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
728 &matchings) {
729 int const indexAddedNodes = baryMergeTree.tree.getNumberOfNodes();
730 updateBarycenterTreeStructure<dataType>(trees, baryMergeTree, matchings);
732 trees, baryMergeTree, alphas, indexAddedNodes, matchings);
733 }
734
735 // ------------------------------------------------------------------------
736 // Assignment
737 // ------------------------------------------------------------------------
738 template <class dataType>
740 ftm::FTMTree_MT *tree,
741 ftm::FTMTree_MT *baryTree,
742 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
743 dataType &distance,
744 bool useDoubleInput = false,
745 bool isFirstInput = true) {
746 // Timer t_distance;
747 MergeTreeDistance mergeTreeDistance;
748 mergeTreeDistance.setDebugLevel(std::min(debugLevel_, 2));
749 mergeTreeDistance.setPreprocess(false);
750 mergeTreeDistance.setPostprocess(false);
751 mergeTreeDistance.setBranchDecomposition(true);
753 mergeTreeDistance.setKeepSubtree(keepSubtree_);
754 mergeTreeDistance.setAssignmentSolver(assignmentSolverID_);
755 mergeTreeDistance.setIsCalled(true);
756 mergeTreeDistance.setThreadNumber(this->threadNumber_);
757 mergeTreeDistance.setDistanceSquaredRoot(true); // squared root
758 mergeTreeDistance.setNodePerTask(nodePerTask_);
759 if(useDoubleInput) {
760 double const weight = mixDistancesMinMaxPairWeight(isFirstInput);
761 mergeTreeDistance.setMinMaxPairWeight(weight);
762 }
763 /*if(progressiveBarycenter_){
764 mergeTreeDistance.setAuctionNoRounds(1);
765 mergeTreeDistance.setAuctionEpsilonDiviser(NoIteration-1);
766 }*/
767 distance
768 = mergeTreeDistance.computeDistance<dataType>(baryTree, tree, matching);
769 std::stringstream ss, ss2;
770 ss << "distance tree : " << distance;
772 ss2 << "distance²tree : " << distance * distance;
774
775 // auto t_distance_time = t_distance.getElapsedTime();
776 // allDistanceTime_ += t_distance_time;
777 }
778
779 template <class dataType>
781 ftm::FTMTree_MT *tree,
782 ftm::MergeTree<dataType> &baryMergeTree,
783 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
784 dataType &distance,
785 bool useDoubleInput = false,
786 bool isFirstInput = true) {
787 computeOneDistance<dataType>(tree, &(baryMergeTree.tree), matching,
788 distance, useDoubleInput, isFirstInput);
789 }
790
791 template <class dataType>
793 ftm::MergeTree<dataType> &baryMergeTree,
794 ftm::MergeTree<dataType> &baryMergeTree2,
795 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
796 dataType &distance,
797 bool useDoubleInput = false,
798 bool isFirstInput = true) {
799 computeOneDistance<dataType>(&(baryMergeTree.tree), baryMergeTree2,
800 matching, distance, useDoubleInput,
801 isFirstInput);
802 }
803
804 template <class dataType>
806 std::vector<ftm::FTMTree_MT *> &trees,
807 ftm::MergeTree<dataType> &baryMergeTree,
808 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
809 &matchings,
810 std::vector<dataType> &distances,
811 bool useDoubleInput = false,
812 bool isFirstInput = true) {
813 if(not isCalled_)
814 assignmentPara(trees, baryMergeTree, matchings, distances,
815 useDoubleInput, isFirstInput);
816 else
817 assignmentTask(trees, baryMergeTree, matchings, distances,
818 useDoubleInput, isFirstInput);
819 }
820
821 template <class dataType>
823 std::vector<ftm::FTMTree_MT *> &trees,
824 ftm::MergeTree<dataType> &baryMergeTree,
825 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
826 &matchings,
827 std::vector<dataType> &distances,
828 bool useDoubleInput = false,
829 bool isFirstInput = true) {
830#ifdef TTK_ENABLE_OPENMP4
831#pragma omp parallel num_threads(this->threadNumber_) \
832 shared(baryMergeTree) if(parallelize_)
833 {
834#pragma omp single nowait
835#endif
836 assignmentTask(trees, baryMergeTree, matchings, distances,
837 useDoubleInput, isFirstInput);
838#ifdef TTK_ENABLE_OPENMP4
839 } // pragma omp parallel
840#endif
841 }
842
843 template <class dataType>
845 std::vector<ftm::FTMTree_MT *> &trees,
846 ftm::MergeTree<dataType> &baryMergeTree,
847 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
848 &matchings,
849 std::vector<dataType> &distances,
850 bool useDoubleInput = false,
851 bool isFirstInput = true) {
852 for(unsigned int i = 0; i < trees.size(); ++i)
853#ifdef TTK_ENABLE_OPENMP4
854#pragma omp task firstprivate(i) UNTIED() \
855 shared(baryMergeTree, matchings, distances)
856#endif
857 computeOneDistance<dataType>(trees[i], baryMergeTree, matchings[i],
858 distances[i], useDoubleInput,
859 isFirstInput);
860#ifdef TTK_ENABLE_OPENMP4
861#pragma omp taskwait
862#endif
863 }
864
865 // ------------------------------------------------------------------------
866 // Progressivity
867 // ------------------------------------------------------------------------
// Select, for every original tree, the (possibly simplified) tree to use.
//
// deletedNodes is reset to one empty list per original tree. For each tree a
// persistence threshold is chosen: 50.0 when iterationNumber == -1, otherwise
// a value derived from the persistence pairs of the original tree and from
// the number of nodes of the current entry of mergeTrees. When the threshold
// is non-zero, a copy of the original tree is made, stored in mergeTrees[i]
// and referenced from trees[i]; otherwise trees[i] points to the original
// tree. Returns the number of trees whose threshold index reached 0.
//
// NOTE(review): original lines 893, 908 and 910 are missing from this
// listing (start of the `multiplier` initialiser, the declaration of `mt`,
// and the callee applied to `&(mt.tree)`).
868 template <class dataType>
869 unsigned int
870 persistenceScaling(std::vector<ftm::FTMTree_MT *> &trees,
871 std::vector<ftm::MergeTree<dataType>> &mergeTrees,
872 std::vector<ftm::FTMTree_MT *> &oriTrees,
873 int iterationNumber,
874 std::vector<std::vector<ftm::idNode>> &deletedNodes) {
875 deletedNodes.clear();
876 deletedNodes.resize(oriTrees.size());
877 unsigned int noTreesUnscaled = 0;
878
879 // Scale trees
880 for(unsigned int i = 0; i < oriTrees.size(); ++i) {
881 double persistenceThreshold = 50.0;
882 if(iterationNumber != -1) {
883 // Get number of pairs in scaled merge tree
// mergeTrees[i] is read here, so mergeTrees must already hold an entry for
// every tree when iterationNumber != -1.
884 int const noPairs = mergeTrees[i].tree.getRealNumberOfNodes();
885
886 // Get pairs in original merge tree
887 std::vector<std::tuple<ftm::idNode, ftm::idNode, dataType>> pairs;
888 oriTrees[i]->getPersistencePairsFromTree<dataType>(
889 pairs, branchDecomposition_);
890
891 // Compute new persistence threshold
892 double const multiplier
894 ? 1.
895 : iterationNumber / progressiveSpeedDivisor_);
// The index moves down by at least 2 and is clamped at 0.
896 int const decrement = multiplier * pairs.size() / 10;
897 int thresholdIndex = pairs.size() - noPairs - std::max(decrement, 2);
898 thresholdIndex = std::max(thresholdIndex, 0);
// Threshold expressed as a percentage of the persistence of the last pair.
// NOTE(review): pairs[thresholdIndex] and pairs.back() assume `pairs` is
// non-empty - confirm.
899 const double persistence = std::get<2>(pairs[thresholdIndex]);
900 persistenceThreshold
901 = persistence / std::get<2>(pairs.back()) * 100.0;
// Index 0 reached: this tree is no longer simplified.
902 if(thresholdIndex == 0) {
903 persistenceThreshold = 0.;
904 ++noTreesUnscaled;
905 }
906 }
907 if(persistenceThreshold != 0.) {
909 = ftm::copyMergeTree<dataType>(oriTrees[i]);
911 &(mt.tree), persistenceThreshold, deletedNodes[i]);
// mergeTrees is sized lazily on the first stored tree.
912 if(mergeTrees.size() == 0)
913 mergeTrees.resize(oriTrees.size());
914 mergeTrees[i] = mt;
// NOTE(review): if `mt` is declared inside this branch (its declaration line
// is not visible here), the pointer stored below refers to a member of an
// object destroyed at the closing brace; `&(mergeTrees[i].tree)` may be what
// is intended - confirm.
915 trees[i] = &(mt.tree);
916 } else {
917 trees[i] = oriTrees[i];
918 }
919 }
920
921 printTreesStats(trees);
922
923 return noTreesUnscaled;
924 }
925
926 template <class dataType>
928 std::vector<ftm::FTMTree_MT *> &oriTrees,
929 std::vector<std::vector<ftm::idNode>> &deletedNodes,
930 std::vector<dataType> &distances) {
931 for(unsigned int i = 0; i < oriTrees.size(); ++i)
932 for(auto node : deletedNodes[i])
933 distances[i] += deleteCost<dataType>(oriTrees[i], node);
934 }
935
936 // ------------------------------------------------------------------------
937 // Main Functions
938 // ------------------------------------------------------------------------
939 template <class dataType>
941 std::vector<ftm::FTMTree_MT *> &trees,
942 ftm::MergeTree<dataType> &baryMergeTree,
943 std::vector<double> &alphas,
944 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
945 &finalMatchings,
946 bool finalAsgnDoubleInput = false,
947 bool finalAsgnFirstInput = true) {
948 Timer t_bary;
949
950 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
951
952 // Persistence scaling
953 std::vector<ftm::FTMTree_MT *> oriTrees;
954 std::vector<ftm::MergeTree<dataType>> scaledMergeTrees;
955 std::vector<std::vector<ftm::idNode>> deletedNodes;
957 oriTrees.insert(oriTrees.end(), trees.begin(), trees.end());
959 trees, scaledMergeTrees, oriTrees, -1, deletedNodes);
960 std::vector<ftm::idNode> deletedNodesT;
961 persistenceThresholding<dataType>(baryTree, 50, deletedNodesT);
962 }
963 bool treesUnscaled = false;
964
965 // Print bary stats
966 printBaryStats(baryTree);
967
968 // Run
969 bool converged = false;
970 dataType frechetEnergy = -1;
971 dataType minFrechet = std::numeric_limits<dataType>::max();
972 int cptBlocked = 0;
973 int NoIteration = 0;
974 while(not converged) {
975 ++NoIteration;
976 if(barycenterMaxIter_ != -1 and NoIteration > barycenterMaxIter_)
977 break;
978
980 std::stringstream ss;
981 ss << "Iteration " << NoIteration;
982 printMsg(ss.str());
983
984 // --- Assignment
985 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
986 matchings(trees.size());
987 std::vector<dataType> distances(trees.size(), -1);
988 Timer t_assignment;
989 assignment<dataType>(trees, baryMergeTree, matchings, distances);
990 Timer t_addDeletedNodes;
993 oriTrees, deletedNodes, distances);
994 addDeletedNodesTime_ += t_addDeletedNodes.getElapsedTime();
995 auto t_assignment_time
996 = t_assignment.getElapsedTime() - t_addDeletedNodes.getElapsedTime();
997 printMsg("Assignment", 1, t_assignment_time, this->threadNumber_,
999
1000 // --- Update
1001 Timer t_update;
1002 updateBarycenterTree<dataType>(trees, baryMergeTree, alphas, matchings);
1003 auto t_update_time = t_update.getElapsedTime();
1004 baryTree = &(baryMergeTree.tree);
1005 printMsg("Update", 1, t_update_time, this->threadNumber_,
1007
1008 // --- Check convergence
1009 dataType currentFrechetEnergy = 0;
1010 for(unsigned int i = 0; i < trees.size(); ++i)
1011 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
1012 auto frechetDiff
1013 = std::abs((double)(frechetEnergy - currentFrechetEnergy));
1014 converged = (frechetDiff <= tol_);
1015 converged = converged and (not progressiveBarycenter_ or treesUnscaled);
1016 frechetEnergy = currentFrechetEnergy;
1017 tol_ = frechetEnergy / 125.0;
1018
1019 std::stringstream ss4;
1020 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
1021 printMsg("Total", 1, barycenterTime, this->threadNumber_,
1023 printBaryStats(baryTree);
1024 ss4 << "Frechet energy : " << frechetEnergy;
1025 printMsg(ss4.str());
1026
1027 minFrechet = std::min(minFrechet, frechetEnergy);
1028 if(not converged and (not progressiveBarycenter_ or treesUnscaled)) {
1029 cptBlocked = (minFrechet < frechetEnergy) ? cptBlocked + 1 : 0;
1030 converged = (cptBlocked >= 10);
1031 }
1032
1033 // --- Persistence scaling
1035 unsigned int const noTreesUnscaled = persistenceScaling<dataType>(
1036 trees, scaledMergeTrees, oriTrees, NoIteration, deletedNodes);
1037 treesUnscaled = (noTreesUnscaled == oriTrees.size());
1038 }
1039 }
1040
1041 // Final processing
1043 printMsg("Final assignment");
1044
1045 std::vector<dataType> distances(trees.size(), -1);
1046 assignment<dataType>(trees, baryMergeTree, finalMatchings, distances,
1047 finalAsgnDoubleInput, finalAsgnFirstInput);
1048 for(auto dist : distances)
1049 finalDistances_.push_back(dist);
1050 dataType currentFrechetEnergy = 0;
1051 for(unsigned int i = 0; i < trees.size(); ++i)
1052 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
1053
1054 std::stringstream ss, ss2;
1055 ss << "Frechet energy : " << currentFrechetEnergy;
1056 printMsg(ss.str());
1057 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
1058 printMsg("Total", 1, barycenterTime, this->threadNumber_,
1060 // std::cout << "Bary Distance Time = " << allDistanceTime_ << std::endl;
1061
1062 if(trees.size() == 2 and not isCalled_)
1064 trees, baryMergeTree, finalMatchings, distances);
1065
1066 // Persistence (un)scaling
1068 scaledMergeTrees.clear();
1069 trees.clear();
1070 trees.insert(trees.end(), oriTrees.begin(), oriTrees.end());
1071 }
1072 }
1073
// Barycenter entry point taking explicit per-tree weights (alphas).
// Visible steps: optional preprocessing (guarded by preprocess_), conversion
// of the inputs to FTMTree_MT pointers, initBarycenterTree(),
// computeBarycenter(), then optional postprocessing (guarded by
// postprocess_).
// NOTE(review): this listing lost original lines 1075, 1087-1090, 1108 and
// 1114 (the embedded numbering skips them). The symbol index at the end of
// the page gives the declaration as
//   void execute(std::vector<ftm::MergeTree<dataType>> &trees,
//                std::vector<double> &alphas, ...finalMatchings,
//                ftm::MergeTree<dataType> &baryMergeTree,
//                bool finalAsgnDoubleInput=false,
//                bool finalAsgnFirstInput=true)
// The missing loop bodies must be restored from the upstream header.
1074 template <class dataType>
1076 std::vector<ftm::MergeTree<dataType>> &trees,
1077 std::vector<double> &alphas,
1078 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1079 &finalMatchings,
1080 ftm::MergeTree<dataType> &baryMergeTree,
1081 bool finalAsgnDoubleInput = false,
1082 bool finalAsgnFirstInput = true) {
1083 // --- Preprocessing
1084 if(preprocess_) {
// One node-correspondence vector per input tree.
1085 treesNodeCorr_.resize(trees.size());
1086 for(unsigned int i = 0; i < trees.size(); ++i) {
// NOTE(review): loop body (lines 1087-1090) is missing here; presumably the
// per-tree preprocessingPipeline call that fills treesNodeCorr_[i] -- confirm.
1091 }
1092 printTreesStats(trees);
1093 }
1094
1095 // --- Init barycenter
// treesT holds FTMTree_MT pointers obtained from `trees`; the rest of the
// pipeline works on these pointers.
1096 std::vector<ftm::FTMTree_MT *> treesT;
1097 ftm::mergeTreeToFTMTree<dataType>(trees, treesT);
1098 initBarycenterTree<dataType>(treesT, baryMergeTree);
1099
1100 // --- Execute
1101 computeBarycenter<dataType>(treesT, baryMergeTree, alphas, finalMatchings,
1102 finalAsgnDoubleInput, finalAsgnFirstInput);
1103
1104 // --- Postprocessing
1105 if(postprocess_) {
// NOTE(review): allRealNodes is declared const (value-initialized to zeros)
// and is not read anywhere in the visible code. If the missing line 1108
// assigns into it, the const qualifier cannot compile -- confirm.
1106 std::vector<int> const allRealNodes(trees.size());
1107 for(unsigned int i = 0; i < trees.size(); ++i) {
1109 }
1110
1111 // fixMergedRootOriginBarycenter<dataType>(baryMergeTree);
1112 postprocessingPipeline<dataType>(&(baryMergeTree.tree));
1113 for(unsigned int i = 0; i < trees.size(); ++i) {
// NOTE(review): the callee name (line 1114) is missing. The three arguments
// match convertBranchDecompositionMatching(tree1, tree2, outputMatching)
// from the symbol index -- confirm.
1115 &(baryMergeTree.tree), treesT[i], finalMatchings[i]);
1116 }
1117 }
1118 }
1119
1120 template <class dataType>
1122 std::vector<ftm::MergeTree<dataType>> &trees,
1123 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1124 &finalMatchings,
1125 ftm::MergeTree<dataType> &baryMergeTree,
1126 bool finalAsgnDoubleInput = false,
1127 bool finalAsgnFirstInput = true) {
1128 std::vector<double> alphas;
1129 if(trees.size() != 2) {
1130 for(unsigned int i = 0; i < trees.size(); ++i)
1131 alphas.push_back(1.0 / trees.size());
1132 } else {
1133 alphas.push_back(alpha_);
1134 alphas.push_back(1 - alpha_);
1135 }
1136
1137 execute<dataType>(trees, alphas, finalMatchings, baryMergeTree,
1138 finalAsgnDoubleInput, finalAsgnFirstInput);
1139 }
1140
1141 // ------------------------------------------------------------------------
1142 // Preprocessing
1143 // ------------------------------------------------------------------------
1144 template <class dataType>
1146 std::vector<ftm::FTMTree_MT *> &trees,
1147 unsigned int barycenterMaximumNumberOfPairs,
1148 double percent,
1149 bool useBD = true) {
1150 auto metric = getSizeLimitMetric(trees);
1151 unsigned int percentMaxPairs = metric * percent / 100.0;
1152
1153 unsigned int newNoNodes;
1154 if(barycenterMaximumNumberOfPairs > 0 and percent > 0)
1155 newNoNodes = std::min(barycenterMaximumNumberOfPairs, percentMaxPairs);
1156 else if(barycenterMaximumNumberOfPairs > 0)
1157 newNoNodes = barycenterMaximumNumberOfPairs;
1158 else if(percent > 0)
1159 newNoNodes = percentMaxPairs;
1160 else
1161 return;
1162 keepMostImportantPairs<dataType>(&(bary.tree), newNoNodes, useBD);
1163 }
1164
1165 template <class dataType>
1167 std::vector<ftm::FTMTree_MT *> &trees,
1168 double percent,
1169 bool useBD = true) {
1171 bary, trees, barycenterMaximumNumberOfPairs_, percent, useBD);
1172 }
1173
// Three-parameter overload of limitSizeBarycenter.
// NOTE(review): this listing lost original lines 1175 (declaration) and
// 1178-1179 (the whole body). The symbol index at the end of the page gives
// the declaration as
//   void limitSizeBarycenter(ftm::MergeTree<dataType> &bary,
//                            std::vector<ftm::FTMTree_MT *> &trees,
//                            bool useBD=true)
// Presumably the body forwards to the five-parameter overload using the
// barycenterMaximumNumberOfPairs_ and barycenterSizeLimitPercent_ members --
// confirm against the upstream header before relying on this.
1174 template <class dataType>
1176 std::vector<ftm::FTMTree_MT *> &trees,
1177 bool useBD = true) {
1180 }
1181
1182 // ------------------------------------------------------------------------
1183 // Postprocessing
1184 // ------------------------------------------------------------------------
1185 template <class dataType>
1187 if(not barycenter.tree.isFullMerge())
1188 return;
1189
1190 ftm::FTMTree_MT *tree = &(barycenter.tree);
1191 auto &tup = fixMergedRootOrigin<dataType>(tree);
1192 int maxIndex = std::get<0>(tup);
1193 dataType oldOriginValue = std::get<1>(tup);
1194
1195 // Verify that scalars are consistent
1196 ftm::idNode const treeRoot = tree->getRoot();
1197 std::vector<dataType> newScalarsVector;
1198 ftm::getTreeScalars<dataType>(tree, newScalarsVector);
1199 bool isJT = tree->isJoinTree<dataType>();
1200 if((isJT and tree->getValue<dataType>(maxIndex) > oldOriginValue)
1201 or (not isJT
1202 and tree->getValue<dataType>(maxIndex) < oldOriginValue)) {
1203 newScalarsVector[treeRoot] = newScalarsVector[maxIndex];
1204 newScalarsVector[maxIndex] = oldOriginValue;
1205 } else
1206 newScalarsVector[treeRoot] = oldOriginValue;
1207 setTreeScalars(barycenter, newScalarsVector);
1208 }
1209
1210 // ------------------------------------------------------------------------
1211 // Utils
1212 // ------------------------------------------------------------------------
1214 const debug::Priority &priority
1216 auto noNodesT = baryTree->getNumberOfNodes();
1217 auto noNodes = baryTree->getRealNumberOfNodes();
1218 std::stringstream ss;
1219 ss << "Barycenter number of nodes : " << noNodes << " / " << noNodesT;
1220 printMsg(ss.str(), priority);
1221 }
1222
1223 // ------------------------------------------------------------------------
1224 // Testing
1225 // ------------------------------------------------------------------------
1226 template <class dataType>
1228 std::vector<ftm::FTMTree_MT *> &trees,
1229 ftm::MergeTree<dataType> &baryMergeTree,
1230 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1231 &finalMatchings,
1232 std::vector<dataType> distances) {
1233 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
1234 dataType distance;
1235 computeOneDistance(trees[0], trees[1], matching, distance);
1236 if(distance != (distances[0] + distances[1])) {
1237 std::stringstream ss, ss2, ss3, ss4;
1238 ss << "distance T1 T2 : " << distance;
1239 printMsg(ss.str());
1240 ss2 << "distance T1 T' T2 : " << distances[0] + distances[1];
1241 printMsg(ss2.str());
1242 ss3 << "distance T1 T' : " << distances[0];
1243 printMsg(ss3.str());
1244 ss4 << "distance T' T2 : " << distances[1];
1245 printMsg(ss4.str());
1246 }
1247 return;
1248
1249 auto baryTree = &(baryMergeTree.tree);
1250 std::vector<std::vector<ftm::idNode>> baryMatched(
1251 baryTree->getNumberOfNodes(),
1252 std::vector<ftm::idNode>(
1253 trees.size(), std::numeric_limits<ftm::idNode>::max()));
1254 for(unsigned int i = 0; i < finalMatchings.size(); ++i)
1255 for(auto &match : finalMatchings[i])
1256 baryMatched[std::get<0>(match)][i] = std::get<1>(match);
1257
1258 std::queue<ftm::idNode> queue;
1259 queue.emplace(baryTree->getRoot());
1260 while(!queue.empty()) {
1261 auto node = queue.front();
1262 queue.pop();
1263 std::vector<dataType> costs(trees.size());
1264 for(unsigned int i = 0; i < trees.size(); ++i)
1265 if(baryMatched[node][i] != std::numeric_limits<ftm::idNode>::max())
1266 costs[i] = relabelCost<dataType>(
1267 baryTree, node, trees[i], baryMatched[node][i]);
1268 else
1269 costs[i] = deleteCost<dataType>(baryTree, node);
1270 dataType cost = 0;
1271 if(baryMatched[node][0] != std::numeric_limits<ftm::idNode>::max()
1272 and baryMatched[node][1] != std::numeric_limits<ftm::idNode>::max())
1273 cost = relabelCost<dataType>(
1274 trees[0], baryMatched[node][0], trees[1], baryMatched[node][1]);
1275 else if(baryMatched[node][0] == std::numeric_limits<ftm::idNode>::max())
1276 cost = deleteCost<dataType>(trees[1], baryMatched[node][1]);
1277 else if(baryMatched[node][1] == std::numeric_limits<ftm::idNode>::max())
1278 cost = deleteCost<dataType>(trees[0], baryMatched[node][0]);
1279 else
1280 printErr("problem");
1281 costs[0] = std::sqrt(costs[0]);
1282 costs[1] = std::sqrt(costs[1]);
1283 cost = std::sqrt(cost);
1284 if(std::abs((double)(costs[0] - costs[1])) > 1e-7) {
1286 std::stringstream ss, ss2, ss3, ss4;
1287 ss << "cost T' T0 : " << costs[0];
1288 printMsg(ss.str());
1289 ss2 << "cost T' T1 : " << costs[1];
1290 printMsg(ss2.str());
1291 ss3 << "cost T0 T1 : " << cost;
1292 printMsg(ss2.str());
1293 ss4 << "cost T0 T' T1 : " << costs[0] + costs[1];
1294 printMsg(ss4.str());
1295 if(std::abs((double)((costs[0] + costs[1]) - cost)) > 1e-7) {
1296 std::stringstream ss5;
1297 ss5 << "diff : "
1298 << std::abs((double)((costs[0] + costs[1]) - cost));
1299 printMsg(ss5.str());
1300 }
1301 std::stringstream ss6;
1302 ss6 << "diff2 : " << std::abs((double)(costs[0] - costs[1]));
1303 printMsg(ss.str());
1304 // baryTree->printNode2<dataType>(node);
1305 // baryTree->printNode2<dataType>(baryTree->getParentSafe(node));
1306 for(unsigned int i = 0; i < 2; ++i)
1307 if(baryMatched[node][i]
1308 != std::numeric_limits<ftm::idNode>::max()) {
1309 printMsg(
1310 trees[i]->printNode2<dataType>(baryMatched[node][i]).str());
1311 printMsg(trees[i]
1312 ->printNode2<dataType>(
1313 trees[i]->getParentSafe(baryMatched[node][i]))
1314 .str());
1315 }
1316 }
1317 std::vector<ftm::idNode> children;
1318 baryTree->getChildren(node, children);
1319 for(auto child : children)
1320 queue.emplace(child);
1321 }
1322 }
1323
1324 }; // MergeTreeBarycenter class
1325
1326} // namespace ttk
#define UNTIED()
TTK processing package that efficiently computes the contour tree of scalar data and more (data segme...
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
int debugLevel_
Definition Debug.h:379
int printWrn(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:159
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool distMinimizer=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double percent, bool useBD=true)
std::tuple< dataType, dataType > getParametrizedBirthDeath(ftm::FTMTree_MT *tree1, ftm::idNode nodeId1)
void getParametrizedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void verifyBarycenterTwoTrees(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, std::vector< dataType > distances)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::FTMTree_MT *baryTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void updateNodesAndScalars(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< dataType > &newScalarsVector, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
unsigned int persistenceScaling(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &mergeTrees, std::vector< ftm::FTMTree_MT * > &oriTrees, int iterationNumber, std::vector< std::vector< ftm::idNode > > &deletedNodes)
void fixMergedRootOriginBarycenter(ftm::MergeTree< dataType > &barycenter)
void setAddNodes(bool addNodesT)
void setBarycenterMaxIter(int barycenterMaxIter)
void setPreprocess(bool preproc)
void computeOneDistance(ftm::MergeTree< dataType > &baryMergeTree, ftm::MergeTree< dataType > &baryMergeTree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, double percent, bool useBD=true)
std::tuple< dataType, dataType > interpolation(ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeId, std::vector< dataType > &newScalarsVector, std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::idNode > &nodes, std::vector< double > &alphas)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, bool distMinimizer=true)
void computeBarycenter(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void assignmentPara(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setPostprocess(bool postproc)
std::tuple< dataType, dataType > interpolationAdded(ftm::FTMTree_MT *tree, ftm::idNode nodeId, double alpha, ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeB, std::vector< dataType > &newScalarsVector)
unsigned int barycenterMaximumNumberOfPairs_
void setBarycenterInitIndex(int barycenterInitIndex)
ftm::idNode getNodesAndScalarsToAdd(ftm::idNode nodeId1, ftm::FTMTree_MT *tree, ftm::idNode nodeId2, std::vector< dataType > &newScalarsVector, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, ftm::idNode nodeCpt, int i)
Get information about the nodes to add in the barycenter.
void setDeterministic(bool deterministicT)
~MergeTreeBarycenter() override=default
void updateBarycenterTreeScalars(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, unsigned int indexAddedNodes, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
void addNodes(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
void initBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryTree, bool distMinimizer=true)
void assignmentTask(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setProgressiveSpeedDivisor(double progSpeed)
void addScaledDeletedNodesCost(std::vector< ftm::FTMTree_MT * > &oriTrees, std::vector< std::vector< ftm::idNode > > &deletedNodes, std::vector< dataType > &distances)
void updateBarycenterTreeStructure(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > finalDistances_
void getSizeLimitedTrees(std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, std::vector< ftm::MergeTree< dataType > > &mTreesLimited)
void printBaryStats(ftm::FTMTree_MT *baryTree, const debug::Priority &priority=debug::Priority::INFO)
void setProgressiveBarycenter(bool progressive)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, bool useBD=true)
void assignment(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterMaximumNumberOfPairs(unsigned int maxi)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterSizeLimitPercent(double percent)
void getSizeLimitedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void updateBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > getFinalDistances()
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, double sizeLimitPercent, bool distMinimizer=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setBranchDecomposition(bool useBD)
void setNormalizedWasserstein(bool normalizedWasserstein)
void setDistanceSquaredRoot(bool distanceSquaredRoot)
void setAssignmentSolver(int assignmentSolver)
dataType deleteCost(const ftm::FTMTree_MT *tree, ftm::idNode nodeId)
void preprocessingPipeline(ftm::MergeTree< dataType > &mTree, double epsilonTree, double epsilon2Tree, double epsilon3Tree, bool branchDecompositionT, bool useMinMaxPairT, bool cleanTreeT, double persistenceThreshold, std::vector< int > &nodeCorr, bool deleteInconsistentNodes=true)
void keepMostImportantPairs(ftm::FTMTree_MT *tree, int n, bool useBD)
void setNodePerTask(int npt)
void convertBranchDecompositionMatching(ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
dataType relabelCost(const ftm::FTMTree_MT *tree1, ftm::idNode nodeId1, const ftm::FTMTree_MT *tree2, ftm::idNode nodeId2)
double mixDistances(dataType distance1, dataType distance2)
double getSizeLimitMetric(std::vector< ftm::FTMTree_MT * > &trees)
void printTreesStats(std::vector< ftm::FTMTree_MT * > &trees)
void postprocessingPipeline(ftm::FTMTree_MT *tree)
void printMatching(std::vector< MatchingType > &matchings)
std::vector< std::vector< int > > treesNodeCorr_
void setKeepSubtree(bool keepSubtree)
void persistenceThresholding(ftm::FTMTree_MT *tree, double persistenceThresholdT, std::vector< ftm::idNode > &deletedNodes)
double mixDistancesMinMaxPairWeight(bool isFirstInput)
std::tuple< int, dataType > fixMergedRootOrigin(ftm::FTMTree_MT *tree)
void setPreprocess(bool preproc)
void setPostprocess(bool postproc)
void setMinMaxPairWeight(double weight)
dataType computeDistance(const ftm::FTMTree_MT *tree1, const ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
double getElapsedTime()
Definition Timer.h:15
Node * getNode(idNode nodeId) const
Definition FTMTree_MT.h:393
const scalarType & getValue(SimplexId nodeId) const
Definition FTMTree_MT.h:339
void getChildren(idNode nodeId, std::vector< idNode > &res) const
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
void setParent(idNode nodeId, idNode newParentNodeId)
void copyMergeTreeStructure(const FTMTree_MT *tree)
idNode getRoot() const
int getRealNumberOfNodes() const
void deleteNode(idNode nodeId)
std::tuple< dataType, dataType > getBirthDeath(idNode nodeId) const
bool isNodeIdInconsistent(idNode nodeId) const
idNode makeNode(SimplexId vertexId, SimplexId linked=nullVertex)
bool isNodeAlone(idNode nodeId) const
bool isFullMerge() const
void setOrigin(SimplexId linked)
Definition FTMNode.h:72
SimplexId getOrigin() const
Definition FTMNode.h:64
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
MergeTree< dataType > cleanMergeTree(ftm::FTMTree_MT *tree, std::vector< int > &nodeCorr, bool useBD=true)
void getTreeScalars(const ftm::FTMTree_MT *tree, std::vector< dataType > &scalarsVector)
MergeTree< dataType > copyMergeTree(const ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs=false)
void mergeTreeToFTMTree(std::vector< MergeTree< dataType > > &trees, std::vector< ftm::FTMTree_MT * > &treesT)
MergeTree< dataType > createEmptyMergeTree(int scalarSize)
unsigned int idNode
Node index in vect_nodes_.
TTK base package defining the standard types.
std::tuple< dataType, dataType > getNormalizedBirthDeath(const ftm::FTMTree_MT *tree, ftm::idNode nodeId, dataType newMin=0.0, dataType newMax=1.0)
dataType getMinMaxLocalFromVector(ftm::FTMTree_MT *tree, ftm::idNode nodeId, std::vector< dataType > &scalarsVector, bool getMin=true)
T end(std::pair< T, T > &p)
Definition ripser.cpp:503
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:906
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)