TTK
Loading...
Searching...
No Matches
MergeTreeBarycenter.h
Go to the documentation of this file.
1
14
15#pragma once
16
17#include <random>
18
19// ttk common includes
20#include <Debug.h>
21#include <Triangulation.h>
22
23#include "MergeTreeBase.h"
24#include "MergeTreeDistance.h"
25
26namespace ttk {
27
32 class MergeTreeBarycenter : virtual public Debug, public MergeTreeBase {
33
34 protected:
35 double tol_ = 0.0;
36 bool addNodes_ = true;
37 bool deterministic_ = true;
38 bool isCalled_ = false;
41 double alpha_ = 0.5;
44
45 double allDistanceTime_ = 0;
46
48
49 bool preprocess_ = true;
50 bool postprocess_ = true;
51
52 // Output
53 std::vector<double> finalDistances_;
54
55 public:
58 "MergeTreeBarycenter"); // inherited from Debug: prefix will be printed
59 // at the beginning of every msg
60#ifdef TTK_ENABLE_OPENMP4
61 omp_set_nested(1);
62#endif
63 }
64 ~MergeTreeBarycenter() override = default;
65
66 void setTol(double tolT) {
67 tol_ = tolT;
68 }
69
70 void setAddNodes(bool addNodesT) {
71 addNodes_ = addNodesT;
72 }
73
74 void setDeterministic(bool deterministicT) {
75 deterministic_ = deterministicT;
76 }
77
78 void setProgressiveBarycenter(bool progressive) {
79 progressiveBarycenter_ = progressive;
80 }
81
82 void setProgressiveSpeedDivisor(double progSpeed) {
83 progressiveSpeedDivisor_ = progSpeed;
84 }
85
86 void setIsCalled(bool ic) {
87 isCalled_ = ic;
88 }
89
91 return allDistanceTime_;
92 }
93
96 }
97
98 void setAlpha(double alpha) {
99 alpha_ = alpha;
100 }
101
102 void setBarycenterMaximumNumberOfPairs(unsigned int maxi) {
104 }
105
106 void setBarycenterSizeLimitPercent(double percent) {
108 }
109
110 void setPreprocess(bool preproc) {
111 preprocess_ = preproc;
112 }
113
114 void setPostprocess(bool postproc) {
115 postprocess_ = postproc;
116 }
117
118 std::vector<double> getFinalDistances() {
119 return finalDistances_;
120 }
121
125 // ------------------------------------------------------------------------
126 // Initialization
127 // ------------------------------------------------------------------------
128 template <class dataType>
129 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
130 std::vector<ftm::FTMTree_MT *> &trees2,
131 std::vector<std::vector<double>> &distanceMatrix,
132 bool useDoubleInput = false,
133 bool isFirstInput = true) {
134 distanceMatrix.clear();
135 distanceMatrix.resize(trees.size(), std::vector<double>(trees.size(), 0));
136#ifdef TTK_ENABLE_OPENMP4
137#pragma omp parallel for schedule(dynamic) \
138 num_threads(this->threadNumber_) if(parallelize_)
139#endif
140 for(unsigned int i = 0; i < trees.size(); ++i)
141 for(unsigned int j = i + 1; j < trees.size(); ++j) {
142 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
143 dataType distance;
144 computeOneDistance<dataType>(trees[i], trees2[j], matching, distance,
145 useDoubleInput, isFirstInput);
146 distanceMatrix[i][j] = distance;
147 distanceMatrix[j][i] = distance;
148 }
149 }
150
151 template <class dataType>
152 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
153 std::vector<std::vector<double>> &distanceMatrix,
154 bool useDoubleInput = false,
155 bool isFirstInput = true) {
156 getDistanceMatrix<dataType>(
157 trees, trees, distanceMatrix, useDoubleInput, isFirstInput);
158 }
159
160 template <class dataType>
162 std::vector<ftm::FTMTree_MT *> &trees,
163 unsigned int barycenterMaximumNumberOfPairs,
164 double sizeLimitPercent,
165 std::vector<ftm::MergeTree<dataType>> &mTreesLimited) {
166 mTreesLimited.resize(trees.size());
167 for(unsigned int i = 0; i < trees.size(); ++i) {
168 mTreesLimited[i] = ftm::copyMergeTree<dataType>(trees[i]);
169 limitSizeBarycenter(mTreesLimited[i], trees,
170 barycenterMaximumNumberOfPairs, sizeLimitPercent);
171 ftm::cleanMergeTree<dataType>(mTreesLimited[i]);
172 }
173 }
174
175 template <class dataType>
177 std::vector<ftm::FTMTree_MT *> &trees,
178 std::vector<std::vector<double>> &distanceMatrix,
179 unsigned int barycenterMaximumNumberOfPairs,
180 double sizeLimitPercent,
181 bool useDoubleInput = false,
182 bool isFirstInput = true) {
183 std::vector<ftm::MergeTree<dataType>> mTreesLimited;
184 getSizeLimitedTrees<dataType>(
185 trees, barycenterMaximumNumberOfPairs, sizeLimitPercent, mTreesLimited);
186 std::vector<ftm::FTMTree_MT *> treesLimited;
187 ftm::mergeTreeToFTMTree<dataType>(mTreesLimited, treesLimited);
188 getDistanceMatrix<dataType>(
189 trees, treesLimited, distanceMatrix, useDoubleInput, isFirstInput);
190 }
191
192 template <class dataType>
194 std::vector<ftm::FTMTree_MT *> &trees,
195 std::vector<std::vector<double>> &distanceMatrix,
196 unsigned int barycenterMaximumNumberOfPairs,
197 double sizeLimitPercent,
198 bool useDoubleInput = false,
199 bool isFirstInput = true) {
200 if(barycenterMaximumNumberOfPairs <= 0 and sizeLimitPercent <= 0.0)
201 getDistanceMatrix<dataType>(
202 trees, distanceMatrix, useDoubleInput, isFirstInput);
203 else
204 getSizeLimitedDistanceMatrix<dataType>(
205 trees, distanceMatrix, barycenterMaximumNumberOfPairs,
206 sizeLimitPercent, useDoubleInput, isFirstInput);
207 }
208
209 template <class dataType>
210 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
211 std::vector<ftm::FTMTree_MT *> &trees2,
212 unsigned int barycenterMaximumNumberOfPairs,
213 double sizeLimitPercent,
214 bool distMinimizer = true) {
215 std::vector<std::vector<double>> distanceMatrix, distanceMatrix2;
216 bool const useDoubleInput = (trees2.size() != 0);
217 getParametrizedDistanceMatrix<dataType>(trees, distanceMatrix,
218 barycenterMaximumNumberOfPairs,
219 sizeLimitPercent, useDoubleInput);
220 if(trees2.size() != 0)
221 getParametrizedDistanceMatrix<dataType>(
222 trees2, distanceMatrix2, barycenterMaximumNumberOfPairs,
223 sizeLimitPercent, useDoubleInput, false);
224
225 int bestIndex = -1;
226 dataType bestValue
227 = distMinimizer ? std::numeric_limits<dataType>::max() : 0;
228 std::vector<int> sizes(trees.size());
229 for(unsigned int i = 0; i < trees.size(); ++i) {
230 dataType value = 0;
231 for(unsigned int j = 0; j < distanceMatrix[i].size(); ++j)
232 value += (not useDoubleInput ? distanceMatrix[i][j]
233 : mixDistances(distanceMatrix[i][j],
234 distanceMatrix2[i][j]));
235 if((distMinimizer and value < bestValue)
236 or (not distMinimizer and value > bestValue)) {
237 bestIndex = i;
238 bestValue = value;
239 }
240 sizes[i] = -value;
241 sizes[i] *= (distMinimizer) ? 1 : -1;
242 }
243 if(not deterministic_) {
244 std::random_device rd;
245 std::default_random_engine generator(rd());
246 std::discrete_distribution<int> distribution(
247 sizes.begin(), sizes.end());
248 bestIndex = distribution(generator);
249 }
250 return bestIndex;
251 }
252
253 template <class dataType>
254 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
255 std::vector<ftm::FTMTree_MT *> &trees2,
256 double sizeLimitPercent,
257 bool distMinimizer = true) {
258 return getBestInitTreeIndex<dataType>(trees, trees2,
260 sizeLimitPercent, distMinimizer);
261 }
262
263 template <class dataType>
264 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
265 bool distMinimizer = true) {
266 std::vector<ftm::FTMTree_MT *> trees2;
267 return getBestInitTreeIndex<dataType>(
268 trees, trees2, barycenterMaximumNumberOfPairs_,
269 barycenterSizeLimitPercent_, distMinimizer);
270 }
271
272 template <class dataType>
273 void initBarycenterTree(std::vector<ftm::FTMTree_MT *> &trees,
274 ftm::MergeTree<dataType> &baryTree,
275 bool distMinimizer = true) {
276 int const bestIndex
277 = getBestInitTreeIndex<dataType>(trees, distMinimizer);
278 baryTree = ftm::copyMergeTree<dataType>(trees[bestIndex], true);
279 limitSizeBarycenter(baryTree, trees);
280 }
281
282 // ------------------------------------------------------------------------
283 // Update
284 // ------------------------------------------------------------------------
285 template <class dataType>
288 ftm::idNode nodeId1,
289 ftm::FTMTree_MT *tree2,
290 ftm::idNode nodeId2,
291 std::vector<dataType> &newScalarsVector,
292 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
293 ftm::idNode nodeCpt,
294 int i) {
295 // Get nodes and scalars to add
296 std::queue<std::tuple<ftm::idNode, ftm::idNode>> queue;
297 queue.emplace(nodeId2, nodeId1);
298 nodesToProcess.emplace_back(nodeId2, nodeId1, i);
299 while(!queue.empty()) {
300 auto queueTuple = queue.front();
301 queue.pop();
302 ftm::idNode const node = std::get<0>(queueTuple);
303 // Get scalars
304 newScalarsVector.push_back(
305 tree2->getValue<dataType>(tree2->getNode(node)->getOrigin()));
306 newScalarsVector.push_back(tree2->getValue<dataType>(node));
307 // Process children
308 std::vector<ftm::idNode> children;
309 tree2->getChildren(node, children);
310 for(auto child : children) {
311 queue.emplace(child, nodeCpt + 1);
312 nodesToProcess.emplace_back(child, nodeCpt + 1, i);
313 }
314 nodeCpt += 2; // we will add two nodes (birth and death)
315 }
316
317 return nodeCpt;
318 }
319
320 template <class dataType>
323 int noTrees,
324 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
325 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
326 &nodesProcessed) {
327 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
328
329 // Add nodes
330 nodesProcessed.clear();
331 nodesProcessed.resize(noTrees);
332 for(auto processTuple : nodesToProcess) {
333 ftm::idNode const parent = std::get<1>(processTuple);
334 ftm::idNode const nodeTree1 = tree1->getNumberOfNodes();
335 int const index = std::get<2>(processTuple);
336 nodesProcessed[index].emplace_back(
337 nodeTree1 + 1, std::get<0>(processTuple));
338 // Make node and its origin
339 tree1->makeNode(nodeTree1);
340 tree1->makeNode(nodeTree1 + 1);
341 tree1->setParent(nodeTree1 + 1, parent);
342 tree1->getNode(nodeTree1)->setOrigin(nodeTree1 + 1);
343 tree1->getNode(nodeTree1 + 1)->setOrigin(nodeTree1);
344 }
345 }
346
347 template <class dataType>
350 int noTrees,
351 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
352 std::vector<dataType> &newScalarsVector,
353 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
354 &nodesProcessed) {
355 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
356
357 // Create new tree
359 = ftm::createEmptyMergeTree<dataType>(newScalarsVector.size());
360 ftm::setTreeScalars<dataType>(mTreeNew, newScalarsVector);
361 ftm::FTMTree_MT *treeNew = &(mTreeNew.tree);
362
363 // Copy the old tree structure
364 treeNew->copyMergeTreeStructure(tree1);
365
366 // Add nodes in the other trees
367 addNodes<dataType>(mTreeNew, noTrees, nodesToProcess, nodesProcessed);
368
369 // Copy new tree
370 mTree1 = mTreeNew;
371 }
372
373 template <class dataType>
375 std::vector<ftm::FTMTree_MT *> &trees,
376 ftm::MergeTree<dataType> &baryMergeTree,
377 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
378 &matchings) {
379 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
380 ftm::idNode const baryTreeRoot = baryTree->getRoot();
381
382 // Init matching matrix
383 // m[i][j] contains the node in the barycenter matched to the jth node of
384 // the ith tree
385 std::vector<std::vector<ftm::idNode>> matrixMatchings(trees.size());
386 std::vector<bool> baryMatched(baryTree->getNumberOfNodes(), false);
387 for(unsigned int i = 0; i < matchings.size(); ++i) {
388 auto matching = matchings[i];
389 matrixMatchings[i].resize(trees[i]->getNumberOfNodes(),
390 std::numeric_limits<ftm::idNode>::max());
391 for(auto match : matching) {
392 matrixMatchings[i][std::get<1>(match)] = std::get<0>(match);
393 baryMatched[std::get<0>(match)] = true;
394 }
395 }
396
397 // Iterate through trees to get the nodes to add in the barycenter
398 std::vector<std::vector<ftm::idNode>> nodesToAdd(trees.size());
399 for(unsigned int i = 0; i < trees.size(); ++i) {
400 ftm::idNode const root = trees[i]->getRoot();
401 std::queue<ftm::idNode> queue;
402 queue.emplace(root);
403 while(!queue.empty()) {
404 ftm::idNode const node = queue.front();
405 queue.pop();
406 bool processChildren = true;
407 // if node in trees[i] is not matched
408 if(matrixMatchings[i][node]
409 == std::numeric_limits<ftm::idNode>::max()) {
410 if(not keepSubtree_) {
411 processChildren = false;
412 nodesToAdd[i].push_back(node);
413 } else {
414 // not todo manage if keepSubtree=true (not important since it is
415 // not a valid merge tree)
416 printErr(
417 "barycenter with keepSubtree_=true is not implemented yet");
418 }
419 }
420 if(processChildren) {
421 std::vector<ftm::idNode> children;
422 trees[i]->getChildren(node, children);
423 for(auto child : children)
424 if(not(trees[i]->isThereOnlyOnePersistencePair()
425 and trees[i]->isLeaf(child)))
426 queue.emplace(child);
427 }
428 }
429 }
430
431 bool foundRootNotMatched = false;
432 for(unsigned int i = 0; i < trees.size(); ++i)
433 foundRootNotMatched |= baryTree->isNodeIdInconsistent(
434 matrixMatchings[i][trees[i]->getRoot()]);
435 if(foundRootNotMatched)
436 printWrn("[updateBarycenterTreeStructure] an input tree has its root "
437 "not matched.");
438
439 // Delete nodes that are not matched in the barycenter
440 for(unsigned int i = 0; i < baryTree->getNumberOfNodes(); ++i)
441 if(not baryMatched[i])
442 baryTree->deleteNode(i);
443
444 if(not keepSubtree_) {
445 // Add scalars and nodes not present in the barycenter
446 ftm::idNode nodeCpt = baryTree->getNumberOfNodes();
447 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> nodesToProcess;
448 std::vector<dataType> newScalarsVector;
449 ftm::getTreeScalars<dataType>(baryMergeTree, newScalarsVector);
450 for(unsigned int i = 0; i < nodesToAdd.size(); ++i) {
451 for(auto node : nodesToAdd[i]) {
452 ftm::idNode parent
453 = matrixMatchings[i][trees[i]->getParentSafe(node)];
454 if(matchings[i].size() == 0)
455 parent = baryTreeRoot;
456
457 if((baryTree->isNodeIdInconsistent(parent)
458 or baryTree->isNodeAlone(parent))
459 and matchings[i].size() != 0) {
460 std::stringstream ss;
461 ss << trees[i]->getParentSafe(node) << " _ " << node;
462 printMsg(ss.str());
463 printMsg(trees[i]->printTree().str());
464 printMsg(trees[i]->printPairsFromTree<dataType>(true).str());
465 printMatching(matchings[i]);
466 std::stringstream ss2;
467 ss2 << "parent " << parent;
468 printMsg(ss2.str());
469 }
470 /*if(isRoot(trees[i], node))
471 parent = baryTree->getRoot();*/
472 std::vector<dataType> addedScalars;
473 nodeCpt = getNodesAndScalarsToAdd<dataType>(
474 baryMergeTree, parent, trees[i], node, addedScalars,
475 nodesToProcess, nodeCpt, i);
476 newScalarsVector.insert(
477 newScalarsVector.end(), addedScalars.begin(), addedScalars.end());
478 }
479 }
480 if(addNodes_) {
481 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
482 nodesProcessed;
483 updateNodesAndScalars<dataType>(baryMergeTree, trees.size(),
484 nodesToProcess, newScalarsVector,
485 nodesProcessed);
486 for(unsigned int i = 0; i < matchings.size(); ++i) {
487 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
488 nodesProcessedT;
489 for(auto tup : nodesProcessed[i])
490 nodesProcessedT.emplace_back(
491 std::get<0>(tup), std::get<1>(tup), -1);
492 matchings[i].insert(matchings[i].end(), nodesProcessedT.begin(),
493 nodesProcessedT.end());
494 }
495 }
496 } else {
497 // not todo manage if keepSubtree=true (not important since it is not a
498 // valid merge tree)
499 printErr("barycenter with keepSubtree_=true is not implemented yet");
500 }
501 }
502
503 template <class dataType>
504 std::tuple<dataType, dataType>
506 std::tuple<dataType, dataType> birthDeath;
507 // Normalized Wasserstein
509 birthDeath = getNormalizedBirthDeathDouble<dataType>(tree1, nodeId1);
510 // Classical Wasserstein
511 else
512 birthDeath = tree1->getBirthDeath<dataType>(nodeId1);
513 return birthDeath;
514 }
515
516 template <class dataType>
517 std::tuple<dataType, dataType>
519 ftm::idNode nodeId,
520 std::vector<dataType> &newScalarsVector,
521 std::vector<ftm::FTMTree_MT *> &trees,
522 std::vector<ftm::idNode> &nodes,
523 std::vector<double> &alphas) {
524 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
525 dataType mu_max = getMinMaxLocalFromVector<dataType>(
526 baryTree, nodeId, newScalarsVector, false);
527 dataType mu_min = getMinMaxLocalFromVector<dataType>(
528 baryTree, nodeId, newScalarsVector);
529 double newBirth = 0, newDeath = 0;
530
531 // Compute projection
532 double tempBirth = 0, tempDeath = 0;
533 double alphaSum = 0;
534 for(unsigned int i = 0; i < trees.size(); ++i)
535 if(nodes[i] != std::numeric_limits<ftm::idNode>::max())
536 alphaSum += alphas[i];
537 for(unsigned int i = 0; i < trees.size(); ++i) {
538 // if node is matched in trees[i]
539 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
540 auto iBirthDeath
541 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
542 double tTempBirth = 0, tTempDeath = 0;
543 tTempBirth += std::get<0>(iBirthDeath);
544 tTempDeath += std::get<1>(iBirthDeath);
545 tempBirth += tTempBirth * alphas[i] / alphaSum;
546 tempDeath += tTempDeath * alphas[i] / alphaSum;
547 }
548 }
549 double const projec = (tempBirth + tempDeath) / 2;
550
551 // Compute newBirth and newDeath
552 for(unsigned int i = 0; i < trees.size(); ++i) {
553 double iBirth = projec, iDeath = projec;
554 // if node is matched in trees[i]
555 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
556 auto iBirthDeath
557 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
558 iBirth = std::get<0>(iBirthDeath);
559 iDeath = std::get<1>(iBirthDeath);
560 }
561 newBirth += alphas[i] * iBirth;
562 newDeath += alphas[i] * iDeath;
563 }
565 newBirth = newBirth * (mu_max - mu_min) + mu_min;
566 newDeath = newDeath * (mu_max - mu_min) + mu_min;
567 }
568
569 return std::make_tuple(newBirth, newDeath);
570 }
571
572 template <class dataType>
573 std::tuple<dataType, dataType>
575 ftm::idNode nodeId,
576 double alpha,
577 ftm::MergeTree<dataType> &baryMergeTree,
578 ftm::idNode nodeB,
579 std::vector<dataType> &newScalarsVector) {
580 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
581 dataType mu_max = getMinMaxLocalFromVector<dataType>(
582 baryTree, nodeB, newScalarsVector, false);
583 dataType mu_min
584 = getMinMaxLocalFromVector<dataType>(baryTree, nodeB, newScalarsVector);
585
586 auto birthDeath = getParametrizedBirthDeath<dataType>(tree, nodeId);
587 double newBirth = std::get<0>(birthDeath);
588 double newDeath = std::get<1>(birthDeath);
589 double const projec = (newBirth + newDeath) / 2;
590
591 newBirth = alpha * newBirth + (1 - alpha) * projec;
592 newDeath = alpha * newDeath + (1 - alpha) * projec;
593
595 newBirth = newBirth * (mu_max - mu_min) + mu_min;
596 newDeath = newDeath * (mu_max - mu_min) + mu_min;
597 }
598
599 dataType newBirthT = newBirth;
600 dataType newDeathT = newDeath;
601 return std::make_tuple(newBirthT, newDeathT);
602 }
603
604 template <class dataType>
606 std::vector<ftm::FTMTree_MT *> &trees,
607 ftm::MergeTree<dataType> &baryMergeTree,
608 std::vector<double> &alphas,
609 unsigned int indexAddedNodes,
610 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
611 &matchings) {
612 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
613 bool const isJT = baryTree->isJoinTree<dataType>();
614
615 // Init matching matrix
616 // m[i][j] contains the node in trees[j] matched to the node i in the
617 // barycenter
618 std::vector<std::vector<ftm::idNode>> baryMatching(
619 baryTree->getNumberOfNodes(),
620 std::vector<ftm::idNode>(
621 trees.size(), std::numeric_limits<ftm::idNode>::max()));
622 std::vector<int> nodesAddedTree(baryTree->getNumberOfNodes(), -1);
623 for(unsigned int i = 0; i < matchings.size(); ++i) {
624 auto matching = matchings[i];
625 for(auto match : matching) {
626 baryMatching[std::get<0>(match)][i] = std::get<1>(match);
627 if(std::get<0>(match)
628 >= indexAddedNodes) // get the tree of this added node
629 nodesAddedTree[std::get<0>(match)] = i;
630 }
631 }
632
633 // Interpolate scalars
634 std::vector<dataType> newScalarsVector(baryTree->getNumberOfNodes());
635 ftm::idNode const root = baryTree->getRoot();
636 std::queue<ftm::idNode> queue;
637 queue.emplace(root);
638 while(!queue.empty()) {
639 ftm::idNode const node = queue.front();
640 queue.pop();
641 std::tuple<dataType, dataType> newBirthDeath;
642 if(node < indexAddedNodes) {
643 newBirthDeath
644 = interpolation<dataType>(baryMergeTree, node, newScalarsVector,
645 trees, baryMatching[node], alphas);
646 } else {
647 int const i = nodesAddedTree[node];
648 ftm::idNode const nodeT = baryMatching[node][i];
649 newBirthDeath = interpolationAdded<dataType>(
650 trees[i], nodeT, alphas[i], baryMergeTree, node, newScalarsVector);
651 }
652 dataType nodeScalar
653 = (isJT ? std::get<1>(newBirthDeath) : std::get<0>(newBirthDeath));
654 dataType nodeOriginScalar
655 = (isJT ? std::get<0>(newBirthDeath) : std::get<1>(newBirthDeath));
656 newScalarsVector[node] = nodeScalar;
657 newScalarsVector[baryTree->getNode(node)->getOrigin()]
658 = nodeOriginScalar;
659 std::vector<ftm::idNode> children;
660 baryTree->getChildren(node, children);
661 for(auto child : children)
662 queue.emplace(child);
663 }
664
665 if(baryMergeTree.tree.isFullMerge()) {
666 auto mergedRootOrigin = baryTree->getMergedRootOrigin<dataType>();
667 dataType mergedRootOriginScalar = 0.0;
668 for(unsigned int i = 0; i < trees.size(); ++i)
669 mergedRootOriginScalar += trees[i]->getValue<dataType>(
670 trees[i]->getMergedRootOrigin<dataType>());
671 mergedRootOriginScalar /= trees.size();
672 newScalarsVector[mergedRootOrigin] = mergedRootOriginScalar;
673 }
674
675 setTreeScalars(baryMergeTree, newScalarsVector);
676
677 std::vector<ftm::idNode> deletedNodesT;
678 persistenceThresholding<dataType>(
679 &(baryMergeTree.tree), 0, deletedNodesT);
680 limitSizeBarycenter(baryMergeTree, trees);
681 ftm::cleanMergeTree<dataType>(baryMergeTree);
682 }
683
684 template <class dataType>
686 std::vector<ftm::FTMTree_MT *> &trees,
687 ftm::MergeTree<dataType> &baryMergeTree,
688 std::vector<double> &alphas,
689 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
690 &matchings) {
691 int const indexAddedNodes = baryMergeTree.tree.getNumberOfNodes();
692 updateBarycenterTreeStructure<dataType>(trees, baryMergeTree, matchings);
693 updateBarycenterTreeScalars<dataType>(
694 trees, baryMergeTree, alphas, indexAddedNodes, matchings);
695 }
696
697 // ------------------------------------------------------------------------
698 // Assignment
699 // ------------------------------------------------------------------------
700 template <class dataType>
702 ftm::FTMTree_MT *tree,
703 ftm::FTMTree_MT *baryTree,
704 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
705 dataType &distance,
706 bool useDoubleInput = false,
707 bool isFirstInput = true) {
708 // Timer t_distance;
709 MergeTreeDistance mergeTreeDistance;
710 mergeTreeDistance.setDebugLevel(std::min(debugLevel_, 2));
711 mergeTreeDistance.setPreprocess(false);
712 mergeTreeDistance.setPostprocess(false);
713 mergeTreeDistance.setBranchDecomposition(true);
715 mergeTreeDistance.setKeepSubtree(keepSubtree_);
716 mergeTreeDistance.setAssignmentSolver(assignmentSolverID_);
717 mergeTreeDistance.setIsCalled(true);
718 mergeTreeDistance.setThreadNumber(this->threadNumber_);
719 mergeTreeDistance.setDistanceSquaredRoot(true); // squared root
720 mergeTreeDistance.setNodePerTask(nodePerTask_);
721 if(useDoubleInput) {
722 double const weight = mixDistancesMinMaxPairWeight(isFirstInput);
723 mergeTreeDistance.setMinMaxPairWeight(weight);
724 }
725 /*if(progressiveBarycenter_){
726 mergeTreeDistance.setAuctionNoRounds(1);
727 mergeTreeDistance.setAuctionEpsilonDiviser(NoIteration-1);
728 }*/
729 distance
730 = mergeTreeDistance.computeDistance<dataType>(baryTree, tree, matching);
731 std::stringstream ss, ss2;
732 ss << "distance tree : " << distance;
734 ss2 << "distance²tree : " << distance * distance;
736
737 // auto t_distance_time = t_distance.getElapsedTime();
738 // allDistanceTime_ += t_distance_time;
739 }
740
741 template <class dataType>
743 ftm::FTMTree_MT *tree,
744 ftm::MergeTree<dataType> &baryMergeTree,
745 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
746 dataType &distance,
747 bool useDoubleInput = false,
748 bool isFirstInput = true) {
749 computeOneDistance<dataType>(tree, &(baryMergeTree.tree), matching,
750 distance, useDoubleInput, isFirstInput);
751 }
752
753 template <class dataType>
755 ftm::MergeTree<dataType> &baryMergeTree,
756 ftm::MergeTree<dataType> &baryMergeTree2,
757 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
758 dataType &distance,
759 bool useDoubleInput = false,
760 bool isFirstInput = true) {
761 computeOneDistance<dataType>(&(baryMergeTree.tree), baryMergeTree2,
762 matching, distance, useDoubleInput,
763 isFirstInput);
764 }
765
766 template <class dataType>
768 std::vector<ftm::FTMTree_MT *> &trees,
769 ftm::MergeTree<dataType> &baryMergeTree,
770 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
771 &matchings,
772 std::vector<dataType> &distances,
773 bool useDoubleInput = false,
774 bool isFirstInput = true) {
775 if(not isCalled_)
776 assignmentPara(trees, baryMergeTree, matchings, distances,
777 useDoubleInput, isFirstInput);
778 else
779 assignmentTask(trees, baryMergeTree, matchings, distances,
780 useDoubleInput, isFirstInput);
781 }
782
783 template <class dataType>
785 std::vector<ftm::FTMTree_MT *> &trees,
786 ftm::MergeTree<dataType> &baryMergeTree,
787 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
788 &matchings,
789 std::vector<dataType> &distances,
790 bool useDoubleInput = false,
791 bool isFirstInput = true) {
792#ifdef TTK_ENABLE_OPENMP4
793#pragma omp parallel num_threads(this->threadNumber_) \
794 shared(baryMergeTree) if(parallelize_)
795 {
796#pragma omp single nowait
797#endif
798 assignmentTask(trees, baryMergeTree, matchings, distances,
799 useDoubleInput, isFirstInput);
800#ifdef TTK_ENABLE_OPENMP4
801 } // pragma omp parallel
802#endif
803 }
804
805 template <class dataType>
807 std::vector<ftm::FTMTree_MT *> &trees,
808 ftm::MergeTree<dataType> &baryMergeTree,
809 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
810 &matchings,
811 std::vector<dataType> &distances,
812 bool useDoubleInput = false,
813 bool isFirstInput = true) {
814 for(unsigned int i = 0; i < trees.size(); ++i)
815#ifdef TTK_ENABLE_OPENMP4
816#pragma omp task firstprivate(i) UNTIED() \
817 shared(baryMergeTree, matchings, distances)
818#endif
819 computeOneDistance<dataType>(trees[i], baryMergeTree, matchings[i],
820 distances[i], useDoubleInput,
821 isFirstInput);
822#ifdef TTK_ENABLE_OPENMP4
823#pragma omp taskwait
824#endif
825 }
826
827 // ------------------------------------------------------------------------
828 // Progressivity
829 // ------------------------------------------------------------------------
830 template <class dataType>
831 unsigned int
832 persistenceScaling(std::vector<ftm::FTMTree_MT *> &trees,
833 std::vector<ftm::MergeTree<dataType>> &mergeTrees,
834 std::vector<ftm::FTMTree_MT *> &oriTrees,
835 int iterationNumber,
836 std::vector<std::vector<ftm::idNode>> &deletedNodes) {
837 deletedNodes.clear();
838 deletedNodes.resize(oriTrees.size());
839 unsigned int noTreesUnscaled = 0;
840
841 // Scale trees
842 for(unsigned int i = 0; i < oriTrees.size(); ++i) {
843 double persistenceThreshold = 50.0;
844 if(iterationNumber != -1) {
845 // Get number of pairs in scaled merge tree
846 int const noPairs = mergeTrees[i].tree.getRealNumberOfNodes();
847
848 // Get pairs in original merge tree
849 std::vector<std::tuple<ftm::idNode, ftm::idNode, dataType>> pairs;
850 oriTrees[i]->getPersistencePairsFromTree<dataType>(
851 pairs, branchDecomposition_);
852
853 // Compute new persistence threshold
854 double const multiplier
856 ? 1.
857 : iterationNumber / progressiveSpeedDivisor_);
858 int const decrement = multiplier * pairs.size() / 10;
859 int thresholdIndex = pairs.size() - noPairs - std::max(decrement, 2);
860 thresholdIndex = std::max(thresholdIndex, 0);
861 const double persistence = std::get<2>(pairs[thresholdIndex]);
862 persistenceThreshold
863 = persistence / std::get<2>(pairs.back()) * 100.0;
864 if(thresholdIndex == 0) {
865 persistenceThreshold = 0.;
866 ++noTreesUnscaled;
867 }
868 }
869 if(persistenceThreshold != 0.) {
871 = ftm::copyMergeTree<dataType>(oriTrees[i]);
872 persistenceThresholding<dataType>(
873 &(mt.tree), persistenceThreshold, deletedNodes[i]);
874 if(mergeTrees.size() == 0)
875 mergeTrees.resize(oriTrees.size());
876 mergeTrees[i] = mt;
877 trees[i] = &(mt.tree);
878 } else {
879 trees[i] = oriTrees[i];
880 }
881 }
882
883 printTreesStats(trees);
884
885 return noTreesUnscaled;
886 }
887
888 template <class dataType>
890 std::vector<ftm::FTMTree_MT *> &oriTrees,
891 std::vector<std::vector<ftm::idNode>> &deletedNodes,
892 std::vector<dataType> &distances) {
893 for(unsigned int i = 0; i < oriTrees.size(); ++i)
894 for(auto node : deletedNodes[i])
895 distances[i] += deleteCost<dataType>(oriTrees[i], node);
896 }
897
898 // ------------------------------------------------------------------------
899 // Main Functions
900 // ------------------------------------------------------------------------
901 template <class dataType>
903 std::vector<ftm::FTMTree_MT *> &trees,
904 ftm::MergeTree<dataType> &baryMergeTree,
905 std::vector<double> &alphas,
906 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
907 &finalMatchings,
908 bool finalAsgnDoubleInput = false,
909 bool finalAsgnFirstInput = true) {
910 Timer t_bary;
911
912 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
913
914 // Persistence scaling
915 std::vector<ftm::FTMTree_MT *> oriTrees;
916 std::vector<ftm::MergeTree<dataType>> scaledMergeTrees;
917 std::vector<std::vector<ftm::idNode>> deletedNodes;
919 oriTrees.insert(oriTrees.end(), trees.begin(), trees.end());
920 persistenceScaling<dataType>(
921 trees, scaledMergeTrees, oriTrees, -1, deletedNodes);
922 std::vector<ftm::idNode> deletedNodesT;
923 persistenceThresholding<dataType>(baryTree, 50, deletedNodesT);
924 }
925 bool treesUnscaled = false;
926
927 // Print bary stats
928 printBaryStats(baryTree);
929
930 // Run
931 bool converged = false;
932 dataType frechetEnergy = -1;
933 dataType minFrechet = std::numeric_limits<dataType>::max();
934 int cptBlocked = 0;
935 int NoIteration = 0;
936 while(not converged) {
937 ++NoIteration;
938
940 std::stringstream ss;
941 ss << "Iteration " << NoIteration;
942 printMsg(ss.str());
943
944 // --- Assignment
945 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
946 matchings(trees.size());
947 std::vector<dataType> distances(trees.size(), -1);
948 Timer t_assignment;
949 assignment<dataType>(trees, baryMergeTree, matchings, distances);
950 Timer t_addDeletedNodes;
952 addScaledDeletedNodesCost<dataType>(
953 oriTrees, deletedNodes, distances);
954 addDeletedNodesTime_ += t_addDeletedNodes.getElapsedTime();
955 auto t_assignment_time
956 = t_assignment.getElapsedTime() - t_addDeletedNodes.getElapsedTime();
957 printMsg("Assignment", 1, t_assignment_time, this->threadNumber_,
959
960 // --- Update
961 Timer t_update;
962 updateBarycenterTree<dataType>(trees, baryMergeTree, alphas, matchings);
963 auto t_update_time = t_update.getElapsedTime();
964 baryTree = &(baryMergeTree.tree);
965 printMsg("Update", 1, t_update_time, this->threadNumber_,
967
968 // --- Check convergence
969 dataType currentFrechetEnergy = 0;
970 for(unsigned int i = 0; i < trees.size(); ++i)
971 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
972 auto frechetDiff
973 = std::abs((double)(frechetEnergy - currentFrechetEnergy));
974 converged = (frechetDiff <= tol_);
975 converged = converged and (not progressiveBarycenter_ or treesUnscaled);
976 frechetEnergy = currentFrechetEnergy;
977 tol_ = frechetEnergy / 125.0;
978
979 std::stringstream ss4;
980 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
981 printMsg("Total", 1, barycenterTime, this->threadNumber_,
983 printBaryStats(baryTree);
984 ss4 << "Frechet energy : " << frechetEnergy;
985 printMsg(ss4.str());
986
987 minFrechet = std::min(minFrechet, frechetEnergy);
988 if(not converged and (not progressiveBarycenter_ or treesUnscaled)) {
989 cptBlocked = (minFrechet < frechetEnergy) ? cptBlocked + 1 : 0;
990 converged = (cptBlocked >= 10);
991 }
992
993 // --- Persistence scaling
995 unsigned int const noTreesUnscaled = persistenceScaling<dataType>(
996 trees, scaledMergeTrees, oriTrees, NoIteration, deletedNodes);
997 treesUnscaled = (noTreesUnscaled == oriTrees.size());
998 }
999 }
1000
1001 // Final processing
1003 printMsg("Final assignment");
1004
1005 std::vector<dataType> distances(trees.size(), -1);
1006 assignment<dataType>(trees, baryMergeTree, finalMatchings, distances,
1007 finalAsgnDoubleInput, finalAsgnFirstInput);
1008 for(auto dist : distances)
1009 finalDistances_.push_back(dist);
1010 dataType currentFrechetEnergy = 0;
1011 for(unsigned int i = 0; i < trees.size(); ++i)
1012 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
1013
1014 std::stringstream ss, ss2;
1015 ss << "Frechet energy : " << currentFrechetEnergy;
1016 printMsg(ss.str());
1017 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
1018 printMsg("Total", 1, barycenterTime, this->threadNumber_,
1020 // std::cout << "Bary Distance Time = " << allDistanceTime_ << std::endl;
1021
1022 if(trees.size() == 2 and not isCalled_)
1023 verifyBarycenterTwoTrees<dataType>(
1024 trees, baryMergeTree, finalMatchings, distances);
1025
1026 // Persistence (un)scaling
1028 scaledMergeTrees.clear();
1029 trees.clear();
1030 trees.insert(trees.end(), oriTrees.begin(), oriTrees.end());
1031 }
1032 }
1033
// Computes the barycenter of `trees` using the per-tree weights `alphas`.
// Declaration: void execute(std::vector<ftm::MergeTree<dataType>> &trees,
// std::vector<double> &alphas, finalMatchings, ftm::MergeTree<dataType>
// &baryMergeTree, bool finalAsgnDoubleInput = false,
// bool finalAsgnFirstInput = true).
// Steps:
//   1. When preprocess_ is set, run preprocessingPipeline() on every input
//      tree, then print tree statistics.
//   2. Convert the inputs to FTMTree_MT pointers and initialize the barycenter.
//   3. Run computeBarycenter(), which fills finalMatchings.
//   4. When postprocess_ is set, post-process the inputs and the barycenter,
//      then convert each final matching with convertBranchDecompositionMatching().
// The unused local `allRealNodes` that used to be allocated in the
// postprocessing branch has been removed.
// NOTE(review): the signature head (1035) and lines 1048-1050 (presumably the
// rest of the preprocessingPipeline() call) are not part of this listing.
1034 template <class dataType>
1036 std::vector<ftm::MergeTree<dataType>> &trees,
1037 std::vector<double> &alphas,
1038 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1039 &finalMatchings,
1040 ftm::MergeTree<dataType> &baryMergeTree,
1041 bool finalAsgnDoubleInput = false,
1042 bool finalAsgnFirstInput = true) {
1043 // --- Preprocessing
1044 if(preprocess_) {
1045 treesNodeCorr_.resize(trees.size());
1046 for(unsigned int i = 0; i < trees.size(); ++i)
1047 preprocessingPipeline<dataType>(trees[i], epsilonTree2_,
1051 printTreesStats(trees);
1052 }
1053
1054 // --- Init barycenter
1055 std::vector<ftm::FTMTree_MT *> treesT;
1056 ftm::mergeTreeToFTMTree<dataType>(trees, treesT);
1057 initBarycenterTree<dataType>(treesT, baryMergeTree);
1058
1059 // --- Execute
1060 computeBarycenter<dataType>(treesT, baryMergeTree, alphas, finalMatchings,
1061 finalAsgnDoubleInput, finalAsgnFirstInput);
1062
1063 // --- Postprocessing
1064 if(postprocess_) {
1066 for(unsigned int i = 0; i < trees.size(); ++i) {
1067 postprocessingPipeline<dataType>(treesT[i]);
1068 }
1069
1070 // fixMergedRootOriginBarycenter<dataType>(baryMergeTree);
1071 postprocessingPipeline<dataType>(&(baryMergeTree.tree));
1072 for(unsigned int i = 0; i < trees.size(); ++i) {
1073 convertBranchDecompositionMatching<dataType>(
1074 &(baryMergeTree.tree), treesT[i], finalMatchings[i]);
1075 }
1076 }
1077 }
1078
1079 template <class dataType>
1081 std::vector<ftm::MergeTree<dataType>> &trees,
1082 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1083 &finalMatchings,
1084 ftm::MergeTree<dataType> &baryMergeTree,
1085 bool finalAsgnDoubleInput = false,
1086 bool finalAsgnFirstInput = true) {
1087 std::vector<double> alphas;
1088 if(trees.size() != 2) {
1089 for(unsigned int i = 0; i < trees.size(); ++i)
1090 alphas.push_back(1.0 / trees.size());
1091 } else {
1092 alphas.push_back(alpha_);
1093 alphas.push_back(1 - alpha_);
1094 }
1095
1096 execute<dataType>(trees, alphas, finalMatchings, baryMergeTree,
1097 finalAsgnDoubleInput, finalAsgnFirstInput);
1098 }
1099
1100 // ------------------------------------------------------------------------
1101 // Preprocessing
1102 // ------------------------------------------------------------------------
1103 template <class dataType>
1105 std::vector<ftm::FTMTree_MT *> &trees,
1106 double percent,
1107 bool useBD) {
1108 auto metric = getSizeLimitMetric(trees);
1109 unsigned int const newNoNodes = metric * percent / 100.0;
1110 keepMostImportantPairs<dataType>(&(bary.tree), newNoNodes, useBD);
1111
1112 unsigned int const noNodesAfter = bary.tree.getRealNumberOfNodes();
1113 if(bary.tree.isFullMerge() and noNodesAfter > newNoNodes * 1.1 + 1
1114 and noNodesAfter > 3) {
1115 std::cout << "metric = " << metric << std::endl;
1116 std::cout << "newNoNodes = " << newNoNodes << std::endl;
1117 std::cout << "noNodesAfter = " << noNodesAfter << std::endl;
1118 }
1119 }
1120
1121 template <class dataType>
1123 std::vector<ftm::FTMTree_MT *> &trees,
1124 unsigned int barycenterMaximumNumberOfPairs,
1125 double percent,
1126 bool useBD = true) {
1127 if(barycenterMaximumNumberOfPairs > 0)
1128 keepMostImportantPairs<dataType>(
1129 &(bary.tree), barycenterMaximumNumberOfPairs, useBD);
1130 if(percent > 0)
1131 limitSizePercent(bary, trees, percent, useBD);
1132 }
1133 template <class dataType>
1135 std::vector<ftm::FTMTree_MT *> &trees,
1136 double percent,
1137 bool useBD = true) {
1139 bary, trees, barycenterMaximumNumberOfPairs_, percent, useBD);
1140 }
// Overload of limitSizeBarycenter() taking neither a pair-count limit nor a
// percentage. Per its declaration it is:
//   void limitSizeBarycenter(ftm::MergeTree<dataType> &bary,
//                            std::vector<ftm::FTMTree_MT *> &trees,
//                            bool useBD = true)
// NOTE(review): the signature head (1142) and the body (1145-1146) are not
// part of this listing. Presumably it forwards to another overload using the
// configured size-limit settings; confirm against the header.
1141 template <class dataType>
1143 std::vector<ftm::FTMTree_MT *> &trees,
1144 bool useBD = true) {
1147 }
1148
1149 // ------------------------------------------------------------------------
1150 // Postprocessing
1151 // ------------------------------------------------------------------------
1152 template <class dataType>
1154 if(not barycenter.tree.isFullMerge())
1155 return;
1156
1157 ftm::FTMTree_MT *tree = &(barycenter.tree);
1158 auto tup = fixMergedRootOrigin<dataType>(tree);
1159 int maxIndex = std::get<0>(tup);
1160 dataType oldOriginValue = std::get<1>(tup);
1161
1162 // Verify that scalars are consistent
1163 ftm::idNode const treeRoot = tree->getRoot();
1164 std::vector<dataType> newScalarsVector;
1165 ftm::getTreeScalars<dataType>(tree, newScalarsVector);
1166 bool isJT = tree->isJoinTree<dataType>();
1167 if((isJT and tree->getValue<dataType>(maxIndex) > oldOriginValue)
1168 or (not isJT
1169 and tree->getValue<dataType>(maxIndex) < oldOriginValue)) {
1170 newScalarsVector[treeRoot] = newScalarsVector[maxIndex];
1171 newScalarsVector[maxIndex] = oldOriginValue;
1172 } else
1173 newScalarsVector[treeRoot] = oldOriginValue;
1174 setTreeScalars(barycenter, newScalarsVector);
1175 }
1176
1177 // ------------------------------------------------------------------------
1178 // Utils
1179 // ------------------------------------------------------------------------
1181 const debug::Priority &priority
1183 auto noNodesT = baryTree->getNumberOfNodes();
1184 auto noNodes = baryTree->getRealNumberOfNodes();
1185 std::stringstream ss;
1186 ss << "Barycenter number of nodes : " << noNodes << " / " << noNodesT;
1187 printMsg(ss.str(), priority);
1188 }
1189
1190 // ------------------------------------------------------------------------
1191 // Testing
1192 // ------------------------------------------------------------------------
// Debug check for the two-tree case.
// It recomputes the distance between trees[0] and trees[1] with
// computeOneDistance() and compares it with the sum of their distances to
// the barycenter (`distances`). When the two differ by more than 1e-7, all
// terms are printed.
// NOTE(review): the per-node comparison after the early `return` (1214) is
// unreachable; it is kept for debugging. Its bugs are fixed here anyway:
//   - ss3/ss6 were filled but ss2/ss were printed;
//   - when a barycenter node was unmatched in both trees, the old branch order
//     called deleteCost() with the sentinel id instead of reaching
//     printErr("problem").
// Declaration: void verifyBarycenterTwoTrees(std::vector<ftm::FTMTree_MT *>
// &trees, ftm::MergeTree<dataType> &baryMergeTree, finalMatchings,
// std::vector<dataType> distances). Its head (1194) and line 1252 are not
// part of this listing.
1193 template <class dataType>
1195 std::vector<ftm::FTMTree_MT *> &trees,
1196 ftm::MergeTree<dataType> &baryMergeTree,
1197 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1198 &finalMatchings,
1199 std::vector<dataType> distances) {
1200 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
1201 dataType distance;
1202 computeOneDistance(trees[0], trees[1], matching, distance);
// Tolerance instead of exact floating-point equality (same 1e-7 as below).
1203 if(std::abs(static_cast<double>(distance - (distances[0] + distances[1]))) > 1e-7) {
1204 std::stringstream ss, ss2, ss3, ss4;
1205 ss << "distance T1 T2 : " << distance;
1206 printMsg(ss.str());
1207 ss2 << "distance T1 T' T2 : " << distances[0] + distances[1];
1208 printMsg(ss2.str());
1209 ss3 << "distance T1 T' : " << distances[0];
1210 printMsg(ss3.str());
1211 ss4 << "distance T' T2 : " << distances[1];
1212 printMsg(ss4.str());
1213 }
// Early exit: everything below is currently unreachable.
1214 return;
1215
1216 auto baryTree = &(baryMergeTree.tree);
1217 std::vector<std::vector<ftm::idNode>> baryMatched(
1218 baryTree->getNumberOfNodes(),
1219 std::vector<ftm::idNode>(
1220 trees.size(), std::numeric_limits<ftm::idNode>::max()));
1221 for(unsigned int i = 0; i < finalMatchings.size(); ++i)
1222 for(auto match : finalMatchings[i])
1223 baryMatched[std::get<0>(match)][i] = std::get<1>(match);
1224
1225 std::queue<ftm::idNode> queue;
1226 queue.emplace(baryTree->getRoot());
1227 while(!queue.empty()) {
1228 auto node = queue.front();
1229 queue.pop();
1230 std::vector<dataType> costs(trees.size());
1231 for(unsigned int i = 0; i < trees.size(); ++i)
1232 if(baryMatched[node][i] != std::numeric_limits<ftm::idNode>::max())
1233 costs[i] = relabelCost<dataType>(
1234 baryTree, node, trees[i], baryMatched[node][i]);
1235 else
1236 costs[i] = deleteCost<dataType>(baryTree, node);
1237 dataType cost = 0;
1238 if(baryMatched[node][0] != std::numeric_limits<ftm::idNode>::max()
1239 and baryMatched[node][1] != std::numeric_limits<ftm::idNode>::max())
1240 cost = relabelCost<dataType>(
1241 trees[0], baryMatched[node][0], trees[1], baryMatched[node][1]);
// Each single-side branch requires the other side to be matched, so a node
// unmatched in both trees reaches printErr instead of deleteCost(sentinel).
1242 else if(baryMatched[node][0] == std::numeric_limits<ftm::idNode>::max() and baryMatched[node][1] != std::numeric_limits<ftm::idNode>::max())
1243 cost = deleteCost<dataType>(trees[1], baryMatched[node][1]);
1244 else if(baryMatched[node][1] == std::numeric_limits<ftm::idNode>::max() and baryMatched[node][0] != std::numeric_limits<ftm::idNode>::max())
1245 cost = deleteCost<dataType>(trees[0], baryMatched[node][0]);
1246 else
1247 printErr("problem");
1248 costs[0] = std::sqrt(costs[0]);
1249 costs[1] = std::sqrt(costs[1]);
1250 cost = std::sqrt(cost);
1251 if(std::abs((double)(costs[0] - costs[1])) > 1e-7) {
1253 std::stringstream ss, ss2, ss3, ss4;
1254 ss << "cost T' T0 : " << costs[0];
1255 printMsg(ss.str());
1256 ss2 << "cost T' T1 : " << costs[1];
1257 printMsg(ss2.str());
1258 ss3 << "cost T0 T1 : " << cost;
1259 printMsg(ss3.str());
1260 ss4 << "cost T0 T' T1 : " << costs[0] + costs[1];
1261 printMsg(ss4.str());
1262 if(std::abs((double)((costs[0] + costs[1]) - cost)) > 1e-7) {
1263 std::stringstream ss5;
1264 ss5 << "diff : "
1265 << std::abs((double)((costs[0] + costs[1]) - cost));
1266 printMsg(ss5.str());
1267 }
1268 std::stringstream ss6;
1269 ss6 << "diff2 : " << std::abs((double)(costs[0] - costs[1]));
1270 printMsg(ss6.str());
1271 // baryTree->printNode2<dataType>(node);
1272 // baryTree->printNode2<dataType>(baryTree->getParentSafe(node));
1273 for(unsigned int i = 0; i < 2; ++i)
1274 if(baryMatched[node][i]
1275 != std::numeric_limits<ftm::idNode>::max()) {
1276 printMsg(
1277 trees[i]->printNode2<dataType>(baryMatched[node][i]).str());
1278 printMsg(trees[i]
1279 ->printNode2<dataType>(
1280 trees[i]->getParentSafe(baryMatched[node][i]))
1281 .str());
1282 }
1283 }
1284 std::vector<ftm::idNode> children;
1285 baryTree->getChildren(node, children);
1286 for(auto child : children)
1287 queue.emplace(child);
1288 }
1289 }
1290
1291 }; // MergeTreeBarycenter class
1292
1293} // namespace ttk
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define UNTIED()
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
Minimalist debugging class.
Definition Debug.h:88
int debugLevel_
Definition Debug.h:379
int printWrn(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:159
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool distMinimizer=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double percent, bool useBD=true)
ftm::idNode getNodesAndScalarsToAdd(ftm::MergeTree< dataType > &ttkNotUsed(mTree1), ftm::idNode nodeId1, ftm::FTMTree_MT *tree2, ftm::idNode nodeId2, std::vector< dataType > &newScalarsVector, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, ftm::idNode nodeCpt, int i)
std::tuple< dataType, dataType > getParametrizedBirthDeath(ftm::FTMTree_MT *tree1, ftm::idNode nodeId1)
void getParametrizedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void verifyBarycenterTwoTrees(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, std::vector< dataType > distances)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::FTMTree_MT *baryTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void updateNodesAndScalars(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< dataType > &newScalarsVector, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
unsigned int persistenceScaling(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &mergeTrees, std::vector< ftm::FTMTree_MT * > &oriTrees, int iterationNumber, std::vector< std::vector< ftm::idNode > > &deletedNodes)
void fixMergedRootOriginBarycenter(ftm::MergeTree< dataType > &barycenter)
void setAddNodes(bool addNodesT)
void setPreprocess(bool preproc)
void computeOneDistance(ftm::MergeTree< dataType > &baryMergeTree, ftm::MergeTree< dataType > &baryMergeTree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, double percent, bool useBD=true)
std::tuple< dataType, dataType > interpolation(ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeId, std::vector< dataType > &newScalarsVector, std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::idNode > &nodes, std::vector< double > &alphas)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, bool distMinimizer=true)
void computeBarycenter(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void assignmentPara(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setPostprocess(bool postproc)
std::tuple< dataType, dataType > interpolationAdded(ftm::FTMTree_MT *tree, ftm::idNode nodeId, double alpha, ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeB, std::vector< dataType > &newScalarsVector)
unsigned int barycenterMaximumNumberOfPairs_
void setDeterministic(bool deterministicT)
~MergeTreeBarycenter() override=default
void updateBarycenterTreeScalars(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, unsigned int indexAddedNodes, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
void addNodes(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
void initBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryTree, bool distMinimizer=true)
void assignmentTask(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setProgressiveSpeedDivisor(double progSpeed)
void addScaledDeletedNodesCost(std::vector< ftm::FTMTree_MT * > &oriTrees, std::vector< std::vector< ftm::idNode > > &deletedNodes, std::vector< dataType > &distances)
void updateBarycenterTreeStructure(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > finalDistances_
void getSizeLimitedTrees(std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, std::vector< ftm::MergeTree< dataType > > &mTreesLimited)
void printBaryStats(ftm::FTMTree_MT *baryTree, const debug::Priority &priority=debug::Priority::INFO)
void setProgressiveBarycenter(bool progressive)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, bool useBD=true)
void assignment(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterMaximumNumberOfPairs(unsigned int maxi)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterSizeLimitPercent(double percent)
void getSizeLimitedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void updateBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > getFinalDistances()
void limitSizePercent(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, double percent, bool useBD)
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, double sizeLimitPercent, bool distMinimizer=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setBranchDecomposition(bool useBD)
void setNormalizedWasserstein(bool normalizedWasserstein)
void setDistanceSquaredRoot(bool distanceSquaredRoot)
void setAssignmentSolver(int assignmentSolver)
void setNodePerTask(int npt)
double mixDistances(dataType distance1, dataType distance2)
double getSizeLimitMetric(std::vector< ftm::FTMTree_MT * > &trees)
void printTreesStats(std::vector< ftm::FTMTree_MT * > &trees)
void printMatching(std::vector< MatchingType > &matchings)
std::vector< std::vector< int > > treesNodeCorr_
void setKeepSubtree(bool keepSubtree)
double mixDistancesMinMaxPairWeight(bool isFirstInput)
dataType computeDistance(ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
void setPreprocess(bool preproc)
void setPostprocess(bool postproc)
void setMinMaxPairWeight(double weight)
double getElapsedTime()
Definition Timer.h:15
const scalarType & getValue(SimplexId nodeId) const
Definition FTMTree_MT.h:339
bool isNodeIdInconsistent(idNode nodeId)
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
void setParent(idNode nodeId, idNode newParentNodeId)
std::tuple< dataType, dataType > getBirthDeath(idNode nodeId)
void getChildren(idNode nodeId, std::vector< idNode > &res)
void deleteNode(idNode nodeId)
idNode makeNode(SimplexId vertexId, SimplexId linked=nullVertex)
bool isNodeAlone(idNode nodeId)
void copyMergeTreeStructure(FTMTree_MT *tree)
Node * getNode(idNode nodeId)
Definition FTMTree_MT.h:393
void setOrigin(SimplexId linked)
Definition FTMNode.h:72
SimplexId getOrigin() const
Definition FTMNode.h:64
unsigned int idNode
Node index in vect_nodes_.
The Topology ToolKit.
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:472
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:901
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)