TTK
Loading...
Searching...
No Matches
MergeTreeBarycenter.h
Go to the documentation of this file.
1
14
15#pragma once
16
17#include <random>
18
19// ttk common includes
20#include <Debug.h>
21#include <Triangulation.h>
22
23#include "MergeTreeBase.h"
24#include "MergeTreeDistance.h"
25
26namespace ttk {
27
32 class MergeTreeBarycenter : virtual public Debug, public MergeTreeBase {
33
34 protected:
35 double tol_ = 0.0;
36 bool addNodes_ = true;
37 bool deterministic_ = true;
38 bool isCalled_ = false;
41 double alpha_ = 0.5;
44
45 double allDistanceTime_ = 0;
46
48
49 bool preprocess_ = true;
50 bool postprocess_ = true;
51
52 // Output
53 std::vector<double> finalDistances_;
54
55 public:
58 "MergeTreeBarycenter"); // inherited from Debug: prefix will be printed
59 // at the beginning of every msg
60#ifdef TTK_ENABLE_OPENMP4
61 omp_set_nested(1);
62#endif
63 }
64 ~MergeTreeBarycenter() override = default;
65
66 void setTol(double tolT) {
67 tol_ = tolT;
68 }
69
70 void setAddNodes(bool addNodesT) {
71 addNodes_ = addNodesT;
72 }
73
74 void setDeterministic(bool deterministicT) {
75 deterministic_ = deterministicT;
76 }
77
78 void setProgressiveBarycenter(bool progressive) {
79 progressiveBarycenter_ = progressive;
80 }
81
82 void setProgressiveSpeedDivisor(double progSpeed) {
83 progressiveSpeedDivisor_ = progSpeed;
84 }
85
86 void setIsCalled(bool ic) {
87 isCalled_ = ic;
88 }
89
91 return allDistanceTime_;
92 }
93
96 }
97
98 void setAlpha(double alpha) {
99 alpha_ = alpha;
100 }
101
102 void setBarycenterMaximumNumberOfPairs(unsigned int maxi) {
104 }
105
106 void setBarycenterSizeLimitPercent(double percent) {
108 }
109
110 void setPreprocess(bool preproc) {
111 preprocess_ = preproc;
112 }
113
114 void setPostprocess(bool postproc) {
115 postprocess_ = postproc;
116 }
117
118 std::vector<double> getFinalDistances() {
119 return finalDistances_;
120 }
121
125 // ------------------------------------------------------------------------
126 // Initialization
127 // ------------------------------------------------------------------------
128 template <class dataType>
129 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
130 std::vector<ftm::FTMTree_MT *> &trees2,
131 std::vector<std::vector<double>> &distanceMatrix,
132 bool useDoubleInput = false,
133 bool isFirstInput = true) {
134 distanceMatrix.clear();
135 distanceMatrix.resize(trees.size(), std::vector<double>(trees.size(), 0));
136#ifdef TTK_ENABLE_OPENMP4
137#pragma omp parallel for schedule(dynamic) \
138 num_threads(this->threadNumber_) if(parallelize_)
139#endif
140 for(unsigned int i = 0; i < trees.size(); ++i)
141 for(unsigned int j = i + 1; j < trees.size(); ++j) {
142 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
143 dataType distance;
144 computeOneDistance<dataType>(trees[i], trees2[j], matching, distance,
145 useDoubleInput, isFirstInput);
146 distanceMatrix[i][j] = distance;
147 distanceMatrix[j][i] = distance;
148 }
149 }
150
151 template <class dataType>
152 void getDistanceMatrix(std::vector<ftm::FTMTree_MT *> &trees,
153 std::vector<std::vector<double>> &distanceMatrix,
154 bool useDoubleInput = false,
155 bool isFirstInput = true) {
156 getDistanceMatrix<dataType>(
157 trees, trees, distanceMatrix, useDoubleInput, isFirstInput);
158 }
159
160 template <class dataType>
162 std::vector<ftm::FTMTree_MT *> &trees,
163 unsigned int barycenterMaximumNumberOfPairs,
164 double sizeLimitPercent,
165 std::vector<ftm::MergeTree<dataType>> &mTreesLimited) {
166 mTreesLimited.resize(trees.size());
167 for(unsigned int i = 0; i < trees.size(); ++i) {
168 mTreesLimited[i] = ftm::copyMergeTree<dataType>(trees[i]);
169 limitSizeBarycenter(mTreesLimited[i], trees,
170 barycenterMaximumNumberOfPairs, sizeLimitPercent);
171 ftm::cleanMergeTree<dataType>(mTreesLimited[i]);
172 }
173 }
174
175 template <class dataType>
177 std::vector<ftm::FTMTree_MT *> &trees,
178 std::vector<std::vector<double>> &distanceMatrix,
179 unsigned int barycenterMaximumNumberOfPairs,
180 double sizeLimitPercent,
181 bool useDoubleInput = false,
182 bool isFirstInput = true) {
183 std::vector<ftm::MergeTree<dataType>> mTreesLimited;
184 getSizeLimitedTrees<dataType>(
185 trees, barycenterMaximumNumberOfPairs, sizeLimitPercent, mTreesLimited);
186 std::vector<ftm::FTMTree_MT *> treesLimited;
187 ftm::mergeTreeToFTMTree<dataType>(mTreesLimited, treesLimited);
188 getDistanceMatrix<dataType>(
189 trees, treesLimited, distanceMatrix, useDoubleInput, isFirstInput);
190 }
191
192 template <class dataType>
194 std::vector<ftm::FTMTree_MT *> &trees,
195 std::vector<std::vector<double>> &distanceMatrix,
196 unsigned int barycenterMaximumNumberOfPairs,
197 double sizeLimitPercent,
198 bool useDoubleInput = false,
199 bool isFirstInput = true) {
200 if(barycenterMaximumNumberOfPairs <= 0 and sizeLimitPercent <= 0.0)
201 getDistanceMatrix<dataType>(
202 trees, distanceMatrix, useDoubleInput, isFirstInput);
203 else
204 getSizeLimitedDistanceMatrix<dataType>(
205 trees, distanceMatrix, barycenterMaximumNumberOfPairs,
206 sizeLimitPercent, useDoubleInput, isFirstInput);
207 }
208
209 template <class dataType>
210 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
211 std::vector<ftm::FTMTree_MT *> &trees2,
212 unsigned int barycenterMaximumNumberOfPairs,
213 double sizeLimitPercent,
214 bool distMinimizer = true) {
215 std::vector<std::vector<double>> distanceMatrix, distanceMatrix2;
216 bool const useDoubleInput = (trees2.size() != 0);
217 getParametrizedDistanceMatrix<dataType>(trees, distanceMatrix,
218 barycenterMaximumNumberOfPairs,
219 sizeLimitPercent, useDoubleInput);
220 if(trees2.size() != 0)
221 getParametrizedDistanceMatrix<dataType>(
222 trees2, distanceMatrix2, barycenterMaximumNumberOfPairs,
223 sizeLimitPercent, useDoubleInput, false);
224
225 int bestIndex = -1;
226 dataType bestValue
227 = distMinimizer ? std::numeric_limits<dataType>::max() : 0;
228 std::vector<int> sizes(trees.size());
229 for(unsigned int i = 0; i < trees.size(); ++i) {
230 dataType value = 0;
231 for(unsigned int j = 0; j < distanceMatrix[i].size(); ++j)
232 value += (not useDoubleInput ? distanceMatrix[i][j]
233 : mixDistances(distanceMatrix[i][j],
234 distanceMatrix2[i][j]));
235 if((distMinimizer and value < bestValue)
236 or (not distMinimizer and value > bestValue)) {
237 bestIndex = i;
238 bestValue = value;
239 }
240 sizes[i] = -value;
241 sizes[i] *= (distMinimizer) ? 1 : -1;
242 }
243 if(not deterministic_) {
244 std::random_device rd;
245 std::default_random_engine generator(rd());
246 std::discrete_distribution<int> distribution(
247 sizes.begin(), sizes.end());
248 bestIndex = distribution(generator);
249 }
250 return bestIndex;
251 }
252
253 template <class dataType>
254 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
255 std::vector<ftm::FTMTree_MT *> &trees2,
256 double sizeLimitPercent,
257 bool distMinimizer = true) {
258 return getBestInitTreeIndex<dataType>(trees, trees2,
260 sizeLimitPercent, distMinimizer);
261 }
262
263 template <class dataType>
264 int getBestInitTreeIndex(std::vector<ftm::FTMTree_MT *> &trees,
265 bool distMinimizer = true) {
266 std::vector<ftm::FTMTree_MT *> trees2;
267 return getBestInitTreeIndex<dataType>(
268 trees, trees2, barycenterMaximumNumberOfPairs_,
269 barycenterSizeLimitPercent_, distMinimizer);
270 }
271
272 template <class dataType>
273 void initBarycenterTree(std::vector<ftm::FTMTree_MT *> &trees,
274 ftm::MergeTree<dataType> &baryTree,
275 bool distMinimizer = true) {
276 int const bestIndex
277 = getBestInitTreeIndex<dataType>(trees, distMinimizer);
278 baryTree = ftm::copyMergeTree<dataType>(trees[bestIndex], true);
279 limitSizeBarycenter(baryTree, trees);
280 }
281
282 // ------------------------------------------------------------------------
283 // Update
284 // ------------------------------------------------------------------------
285 template <class dataType>
288 ftm::idNode nodeId1,
289 ftm::FTMTree_MT *tree2,
290 ftm::idNode nodeId2,
291 std::vector<dataType> &newScalarsVector,
292 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
293 ftm::idNode nodeCpt,
294 int i) {
295 // Get nodes and scalars to add
296 std::queue<std::tuple<ftm::idNode, ftm::idNode>> queue;
297 queue.emplace(nodeId2, nodeId1);
298 nodesToProcess.emplace_back(nodeId2, nodeId1, i);
299 while(!queue.empty()) {
300 auto queueTuple = queue.front();
301 queue.pop();
302 ftm::idNode const node = std::get<0>(queueTuple);
303 // Get scalars
304 newScalarsVector.push_back(
305 tree2->getValue<dataType>(tree2->getNode(node)->getOrigin()));
306 newScalarsVector.push_back(tree2->getValue<dataType>(node));
307 // Process children
308 std::vector<ftm::idNode> children;
309 tree2->getChildren(node, children);
310 for(auto child : children) {
311 queue.emplace(child, nodeCpt + 1);
312 nodesToProcess.emplace_back(child, nodeCpt + 1, i);
313 }
314 nodeCpt += 2; // we will add two nodes (birth and death)
315 }
316
317 return nodeCpt;
318 }
319
320 template <class dataType>
323 int noTrees,
324 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
325 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
326 &nodesProcessed) {
327 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
328
329 // Add nodes
330 nodesProcessed.clear();
331 nodesProcessed.resize(noTrees);
332 for(auto processTuple : nodesToProcess) {
333 ftm::idNode const parent = std::get<1>(processTuple);
334 ftm::idNode const nodeTree1 = tree1->getNumberOfNodes();
335 int const index = std::get<2>(processTuple);
336 nodesProcessed[index].emplace_back(
337 nodeTree1 + 1, std::get<0>(processTuple));
338 // Make node and its origin
339 tree1->makeNode(nodeTree1);
340 tree1->makeNode(nodeTree1 + 1);
341 tree1->setParent(nodeTree1 + 1, parent);
342 tree1->getNode(nodeTree1)->setOrigin(nodeTree1 + 1);
343 tree1->getNode(nodeTree1 + 1)->setOrigin(nodeTree1);
344 }
345 }
346
347 template <class dataType>
350 int noTrees,
351 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> &nodesToProcess,
352 std::vector<dataType> &newScalarsVector,
353 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
354 &nodesProcessed) {
355 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
356
357 // Create new tree
359 = ftm::createEmptyMergeTree<dataType>(newScalarsVector.size());
360 ftm::setTreeScalars<dataType>(mTreeNew, newScalarsVector);
361 ftm::FTMTree_MT *treeNew = &(mTreeNew.tree);
362
363 // Copy the old tree structure
364 treeNew->copyMergeTreeStructure(tree1);
365
366 // Add nodes in the other trees
367 addNodes<dataType>(mTreeNew, noTrees, nodesToProcess, nodesProcessed);
368
369 // Copy new tree
370 mTree1 = mTreeNew;
371 }
372
373 template <class dataType>
375 std::vector<ftm::FTMTree_MT *> &trees,
376 ftm::MergeTree<dataType> &baryMergeTree,
377 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
378 &matchings) {
379 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
380 ftm::idNode const baryTreeRoot = baryTree->getRoot();
381
382 // Init matching matrix
383 // m[i][j] contains the node in the barycenter matched to the jth node of
384 // the ith tree
385 std::vector<std::vector<ftm::idNode>> matrixMatchings(trees.size());
386 std::vector<bool> baryMatched(baryTree->getNumberOfNodes(), false);
387 for(unsigned int i = 0; i < matchings.size(); ++i) {
388 auto matching = matchings[i];
389 matrixMatchings[i].resize(trees[i]->getNumberOfNodes(),
390 std::numeric_limits<ftm::idNode>::max());
391 for(auto match : matching) {
392 matrixMatchings[i][std::get<1>(match)] = std::get<0>(match);
393 baryMatched[std::get<0>(match)] = true;
394 }
395 }
396
397 // Iterate through trees to get the nodes to add in the barycenter
398 std::vector<std::vector<ftm::idNode>> nodesToAdd(trees.size());
399 for(unsigned int i = 0; i < trees.size(); ++i) {
400 ftm::idNode const root = trees[i]->getRoot();
401 std::queue<ftm::idNode> queue;
402 queue.emplace(root);
403 while(!queue.empty()) {
404 ftm::idNode const node = queue.front();
405 queue.pop();
406 bool processChildren = true;
407 // if node in trees[i] is not matched
408 if(matrixMatchings[i][node]
409 == std::numeric_limits<ftm::idNode>::max()) {
410 if(not keepSubtree_) {
411 processChildren = false;
412 nodesToAdd[i].push_back(node);
413 } else {
414 // not todo manage if keepSubtree=true (not important since it is
415 // not a valid merge tree)
416 printErr(
417 "barycenter with keepSubtree_=true is not implemented yet");
418 }
419 }
420 if(processChildren) {
421 std::vector<ftm::idNode> children;
422 trees[i]->getChildren(node, children);
423 for(auto child : children)
424 if(not(trees[i]->isThereOnlyOnePersistencePair()
425 and trees[i]->isLeaf(child)))
426 queue.emplace(child);
427 }
428 }
429 }
430
431 bool foundRootNotMatched = false;
432 for(unsigned int i = 0; i < trees.size(); ++i)
433 foundRootNotMatched |= baryTree->isNodeIdInconsistent(
434 matrixMatchings[i][trees[i]->getRoot()]);
435 if(foundRootNotMatched)
436 printWrn("[updateBarycenterTreeStructure] an input tree has its root "
437 "not matched.");
438
439 // Delete nodes that are not matched in the barycenter
440 for(unsigned int i = 0; i < baryTree->getNumberOfNodes(); ++i)
441 if(not baryMatched[i])
442 baryTree->deleteNode(i);
443
444 if(not keepSubtree_) {
445 // Add scalars and nodes not present in the barycenter
446 ftm::idNode nodeCpt = baryTree->getNumberOfNodes();
447 std::vector<std::tuple<ftm::idNode, ftm::idNode, int>> nodesToProcess;
448 std::vector<dataType> newScalarsVector;
449 ftm::getTreeScalars<dataType>(baryMergeTree, newScalarsVector);
450 for(unsigned int i = 0; i < nodesToAdd.size(); ++i) {
451 for(auto node : nodesToAdd[i]) {
452 ftm::idNode parent
453 = matrixMatchings[i][trees[i]->getParentSafe(node)];
454 if(matchings[i].size() == 0)
455 parent = baryTreeRoot;
456
457 if((baryTree->isNodeIdInconsistent(parent)
458 or baryTree->isNodeAlone(parent))
459 and matchings[i].size() != 0) {
460 std::stringstream ss;
461 ss << trees[i]->getParentSafe(node) << " _ " << node;
462 printMsg(ss.str());
463 printMsg(trees[i]->printTree().str());
464 printMsg(trees[i]->printPairsFromTree<dataType>(true).str());
465 printMatching(matchings[i]);
466 std::stringstream ss2;
467 ss2 << "parent " << parent;
468 printMsg(ss2.str());
469 }
470 /*if(isRoot(trees[i], node))
471 parent = baryTree->getRoot();*/
472 std::vector<dataType> addedScalars;
473 nodeCpt = getNodesAndScalarsToAdd<dataType>(
474 baryMergeTree, parent, trees[i], node, addedScalars,
475 nodesToProcess, nodeCpt, i);
476 newScalarsVector.insert(
477 newScalarsVector.end(), addedScalars.begin(), addedScalars.end());
478 }
479 }
480 if(addNodes_) {
481 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode>>>
482 nodesProcessed;
483 updateNodesAndScalars<dataType>(baryMergeTree, trees.size(),
484 nodesToProcess, newScalarsVector,
485 nodesProcessed);
486 for(unsigned int i = 0; i < matchings.size(); ++i) {
487 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
488 nodesProcessedT;
489 for(auto tup : nodesProcessed[i])
490 nodesProcessedT.emplace_back(
491 std::get<0>(tup), std::get<1>(tup), -1);
492 matchings[i].insert(matchings[i].end(), nodesProcessedT.begin(),
493 nodesProcessedT.end());
494 }
495 }
496 } else {
497 // not todo manage if keepSubtree=true (not important since it is not a
498 // valid merge tree)
499 printErr("barycenter with keepSubtree_=true is not implemented yet");
500 }
501 }
502
503 template <class dataType>
504 std::tuple<dataType, dataType>
506 std::tuple<dataType, dataType> birthDeath;
507 // Normalized Wasserstein
509 birthDeath = getNormalizedBirthDeath<dataType>(tree1, nodeId1);
510 // Classical Wasserstein
511 else
512 birthDeath = tree1->getBirthDeath<dataType>(nodeId1);
513 return birthDeath;
514 }
515
516 template <class dataType>
517 std::tuple<dataType, dataType>
519 ftm::idNode nodeId,
520 std::vector<dataType> &newScalarsVector,
521 std::vector<ftm::FTMTree_MT *> &trees,
522 std::vector<ftm::idNode> &nodes,
523 std::vector<double> &alphas) {
524 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
525 dataType mu_max = getMinMaxLocalFromVector<dataType>(
526 baryTree, nodeId, newScalarsVector, false);
527 dataType mu_min = getMinMaxLocalFromVector<dataType>(
528 baryTree, nodeId, newScalarsVector);
529 dataType newBirth = 0, newDeath = 0;
530
531 // Compute projection
532 dataType tempBirth = 0, tempDeath = 0;
533 double alphaSum = 0;
534 for(unsigned int i = 0; i < trees.size(); ++i)
535 if(nodes[i] != std::numeric_limits<ftm::idNode>::max())
536 alphaSum += alphas[i];
537 for(unsigned int i = 0; i < trees.size(); ++i) {
538 // if node is matched in trees[i]
539 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
540 auto iBirthDeath
541 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
542 dataType tTempBirth = 0, tTempDeath = 0;
543 tTempBirth += std::get<0>(iBirthDeath);
544 tTempDeath += std::get<1>(iBirthDeath);
545 tempBirth += tTempBirth * alphas[i] / alphaSum;
546 tempDeath += tTempDeath * alphas[i] / alphaSum;
547 }
548 }
549 dataType const projec = (tempBirth + tempDeath) / 2;
550
551 // Compute newBirth and newDeath
552 for(unsigned int i = 0; i < trees.size(); ++i) {
553 dataType iBirth = projec, iDeath = projec;
554 // if node is matched in trees[i]
555 if(nodes[i] != std::numeric_limits<ftm::idNode>::max()) {
556 auto iBirthDeath
557 = getParametrizedBirthDeath<dataType>(trees[i], nodes[i]);
558 iBirth = std::get<0>(iBirthDeath);
559 iDeath = std::get<1>(iBirthDeath);
560 }
561 newBirth += alphas[i] * iBirth;
562 newDeath += alphas[i] * iDeath;
563 }
565 // Forbid compiler optimization to have same results on different
566 // computers
567 volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
568 volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
569 newBirth = tempBirthT + mu_min;
570 newDeath = tempDeathT + mu_min;
571 }
572
573 return std::make_tuple(newBirth, newDeath);
574 }
575
576 template <class dataType>
577 std::tuple<dataType, dataType>
579 ftm::idNode nodeId,
580 double alpha,
581 ftm::MergeTree<dataType> &baryMergeTree,
582 ftm::idNode nodeB,
583 std::vector<dataType> &newScalarsVector) {
584 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
585 dataType mu_max = getMinMaxLocalFromVector<dataType>(
586 baryTree, nodeB, newScalarsVector, false);
587 dataType mu_min
588 = getMinMaxLocalFromVector<dataType>(baryTree, nodeB, newScalarsVector);
589
590 auto birthDeath = getParametrizedBirthDeath<dataType>(tree, nodeId);
591 dataType newBirth = std::get<0>(birthDeath);
592 dataType newDeath = std::get<1>(birthDeath);
593 dataType const projec = (newBirth + newDeath) / 2;
594
595 newBirth = alpha * newBirth + (1 - alpha) * projec;
596 newDeath = alpha * newDeath + (1 - alpha) * projec;
597
599 // Forbid compiler optimization to have same results on different
600 // computers
601 volatile dataType tempBirthT = newBirth * (mu_max - mu_min);
602 volatile dataType tempDeathT = newDeath * (mu_max - mu_min);
603 newBirth = tempBirthT + mu_min;
604 newDeath = tempDeathT + mu_min;
605 }
606
607 return std::make_tuple(newBirth, newDeath);
608 }
609
610 template <class dataType>
612 std::vector<ftm::FTMTree_MT *> &trees,
613 ftm::MergeTree<dataType> &baryMergeTree,
614 std::vector<double> &alphas,
615 unsigned int indexAddedNodes,
616 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
617 &matchings) {
618 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
619 bool const isJT = baryTree->isJoinTree<dataType>();
620
621 // Init matching matrix
622 // m[i][j] contains the node in trees[j] matched to the node i in the
623 // barycenter
624 std::vector<std::vector<ftm::idNode>> baryMatching(
625 baryTree->getNumberOfNodes(),
626 std::vector<ftm::idNode>(
627 trees.size(), std::numeric_limits<ftm::idNode>::max()));
628 std::vector<int> nodesAddedTree(baryTree->getNumberOfNodes(), -1);
629 for(unsigned int i = 0; i < matchings.size(); ++i) {
630 auto matching = matchings[i];
631 for(auto match : matching) {
632 baryMatching[std::get<0>(match)][i] = std::get<1>(match);
633 if(std::get<0>(match)
634 >= indexAddedNodes) // get the tree of this added node
635 nodesAddedTree[std::get<0>(match)] = i;
636 }
637 }
638
639 // Interpolate scalars
640 std::vector<dataType> newScalarsVector(baryTree->getNumberOfNodes());
641 ftm::idNode const root = baryTree->getRoot();
642 std::queue<ftm::idNode> queue;
643 queue.emplace(root);
644 while(!queue.empty()) {
645 ftm::idNode const node = queue.front();
646 queue.pop();
647 std::tuple<dataType, dataType> newBirthDeath;
648 if(node < indexAddedNodes) {
649 newBirthDeath
650 = interpolation<dataType>(baryMergeTree, node, newScalarsVector,
651 trees, baryMatching[node], alphas);
652 } else {
653 int const i = nodesAddedTree[node];
654 ftm::idNode const nodeT = baryMatching[node][i];
655 newBirthDeath = interpolationAdded<dataType>(
656 trees[i], nodeT, alphas[i], baryMergeTree, node, newScalarsVector);
657 }
658 dataType nodeScalar
659 = (isJT ? std::get<1>(newBirthDeath) : std::get<0>(newBirthDeath));
660 dataType nodeOriginScalar
661 = (isJT ? std::get<0>(newBirthDeath) : std::get<1>(newBirthDeath));
662 newScalarsVector[node] = nodeScalar;
663 newScalarsVector[baryTree->getNode(node)->getOrigin()]
664 = nodeOriginScalar;
665 std::vector<ftm::idNode> children;
666 baryTree->getChildren(node, children);
667 for(auto child : children)
668 queue.emplace(child);
669 }
670
671 if(baryMergeTree.tree.isFullMerge()) {
672 auto mergedRootOrigin = baryTree->getMergedRootOrigin<dataType>();
673 dataType mergedRootOriginScalar = 0.0;
674 for(unsigned int i = 0; i < trees.size(); ++i)
675 mergedRootOriginScalar += trees[i]->getValue<dataType>(
676 trees[i]->getMergedRootOrigin<dataType>());
677 mergedRootOriginScalar /= trees.size();
678 newScalarsVector[mergedRootOrigin] = mergedRootOriginScalar;
679 }
680
681 setTreeScalars(baryMergeTree, newScalarsVector);
682
683 std::vector<ftm::idNode> deletedNodesT;
684 persistenceThresholding<dataType>(
685 &(baryMergeTree.tree), 0, deletedNodesT);
686 limitSizeBarycenter(baryMergeTree, trees);
687 ftm::cleanMergeTree<dataType>(baryMergeTree);
688 }
689
690 template <class dataType>
692 std::vector<ftm::FTMTree_MT *> &trees,
693 ftm::MergeTree<dataType> &baryMergeTree,
694 std::vector<double> &alphas,
695 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
696 &matchings) {
697 int const indexAddedNodes = baryMergeTree.tree.getNumberOfNodes();
698 updateBarycenterTreeStructure<dataType>(trees, baryMergeTree, matchings);
699 updateBarycenterTreeScalars<dataType>(
700 trees, baryMergeTree, alphas, indexAddedNodes, matchings);
701 }
702
703 // ------------------------------------------------------------------------
704 // Assignment
705 // ------------------------------------------------------------------------
706 template <class dataType>
708 ftm::FTMTree_MT *tree,
709 ftm::FTMTree_MT *baryTree,
710 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
711 dataType &distance,
712 bool useDoubleInput = false,
713 bool isFirstInput = true) {
714 // Timer t_distance;
715 MergeTreeDistance mergeTreeDistance;
716 mergeTreeDistance.setDebugLevel(std::min(debugLevel_, 2));
717 mergeTreeDistance.setPreprocess(false);
718 mergeTreeDistance.setPostprocess(false);
719 mergeTreeDistance.setBranchDecomposition(true);
721 mergeTreeDistance.setKeepSubtree(keepSubtree_);
722 mergeTreeDistance.setAssignmentSolver(assignmentSolverID_);
723 mergeTreeDistance.setIsCalled(true);
724 mergeTreeDistance.setThreadNumber(this->threadNumber_);
725 mergeTreeDistance.setDistanceSquaredRoot(true); // squared root
726 mergeTreeDistance.setNodePerTask(nodePerTask_);
727 if(useDoubleInput) {
728 double const weight = mixDistancesMinMaxPairWeight(isFirstInput);
729 mergeTreeDistance.setMinMaxPairWeight(weight);
730 }
731 /*if(progressiveBarycenter_){
732 mergeTreeDistance.setAuctionNoRounds(1);
733 mergeTreeDistance.setAuctionEpsilonDiviser(NoIteration-1);
734 }*/
735 distance
736 = mergeTreeDistance.computeDistance<dataType>(baryTree, tree, matching);
737 std::stringstream ss, ss2;
738 ss << "distance tree : " << distance;
740 ss2 << "distance²tree : " << distance * distance;
742
743 // auto t_distance_time = t_distance.getElapsedTime();
744 // allDistanceTime_ += t_distance_time;
745 }
746
747 template <class dataType>
749 ftm::FTMTree_MT *tree,
750 ftm::MergeTree<dataType> &baryMergeTree,
751 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
752 dataType &distance,
753 bool useDoubleInput = false,
754 bool isFirstInput = true) {
755 computeOneDistance<dataType>(tree, &(baryMergeTree.tree), matching,
756 distance, useDoubleInput, isFirstInput);
757 }
758
759 template <class dataType>
761 ftm::MergeTree<dataType> &baryMergeTree,
762 ftm::MergeTree<dataType> &baryMergeTree2,
763 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
764 dataType &distance,
765 bool useDoubleInput = false,
766 bool isFirstInput = true) {
767 computeOneDistance<dataType>(&(baryMergeTree.tree), baryMergeTree2,
768 matching, distance, useDoubleInput,
769 isFirstInput);
770 }
771
772 template <class dataType>
774 std::vector<ftm::FTMTree_MT *> &trees,
775 ftm::MergeTree<dataType> &baryMergeTree,
776 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
777 &matchings,
778 std::vector<dataType> &distances,
779 bool useDoubleInput = false,
780 bool isFirstInput = true) {
781 if(not isCalled_)
782 assignmentPara(trees, baryMergeTree, matchings, distances,
783 useDoubleInput, isFirstInput);
784 else
785 assignmentTask(trees, baryMergeTree, matchings, distances,
786 useDoubleInput, isFirstInput);
787 }
788
789 template <class dataType>
791 std::vector<ftm::FTMTree_MT *> &trees,
792 ftm::MergeTree<dataType> &baryMergeTree,
793 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
794 &matchings,
795 std::vector<dataType> &distances,
796 bool useDoubleInput = false,
797 bool isFirstInput = true) {
798#ifdef TTK_ENABLE_OPENMP4
799#pragma omp parallel num_threads(this->threadNumber_) \
800 shared(baryMergeTree) if(parallelize_)
801 {
802#pragma omp single nowait
803#endif
804 assignmentTask(trees, baryMergeTree, matchings, distances,
805 useDoubleInput, isFirstInput);
806#ifdef TTK_ENABLE_OPENMP4
807 } // pragma omp parallel
808#endif
809 }
810
811 template <class dataType>
813 std::vector<ftm::FTMTree_MT *> &trees,
814 ftm::MergeTree<dataType> &baryMergeTree,
815 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
816 &matchings,
817 std::vector<dataType> &distances,
818 bool useDoubleInput = false,
819 bool isFirstInput = true) {
820 for(unsigned int i = 0; i < trees.size(); ++i)
821#ifdef TTK_ENABLE_OPENMP4
822#pragma omp task firstprivate(i) UNTIED() \
823 shared(baryMergeTree, matchings, distances)
824#endif
825 computeOneDistance<dataType>(trees[i], baryMergeTree, matchings[i],
826 distances[i], useDoubleInput,
827 isFirstInput);
828#ifdef TTK_ENABLE_OPENMP4
829#pragma omp taskwait
830#endif
831 }
832
833 // ------------------------------------------------------------------------
834 // Progressivity
835 // ------------------------------------------------------------------------
836 template <class dataType>
837 unsigned int
838 persistenceScaling(std::vector<ftm::FTMTree_MT *> &trees,
839 std::vector<ftm::MergeTree<dataType>> &mergeTrees,
840 std::vector<ftm::FTMTree_MT *> &oriTrees,
841 int iterationNumber,
842 std::vector<std::vector<ftm::idNode>> &deletedNodes) {
843 deletedNodes.clear();
844 deletedNodes.resize(oriTrees.size());
845 unsigned int noTreesUnscaled = 0;
846
847 // Scale trees
848 for(unsigned int i = 0; i < oriTrees.size(); ++i) {
849 double persistenceThreshold = 50.0;
850 if(iterationNumber != -1) {
851 // Get number of pairs in scaled merge tree
852 int const noPairs = mergeTrees[i].tree.getRealNumberOfNodes();
853
854 // Get pairs in original merge tree
855 std::vector<std::tuple<ftm::idNode, ftm::idNode, dataType>> pairs;
856 oriTrees[i]->getPersistencePairsFromTree<dataType>(
857 pairs, branchDecomposition_);
858
859 // Compute new persistence threshold
860 double const multiplier
862 ? 1.
863 : iterationNumber / progressiveSpeedDivisor_);
864 int const decrement = multiplier * pairs.size() / 10;
865 int thresholdIndex = pairs.size() - noPairs - std::max(decrement, 2);
866 thresholdIndex = std::max(thresholdIndex, 0);
867 const double persistence = std::get<2>(pairs[thresholdIndex]);
868 persistenceThreshold
869 = persistence / std::get<2>(pairs.back()) * 100.0;
870 if(thresholdIndex == 0) {
871 persistenceThreshold = 0.;
872 ++noTreesUnscaled;
873 }
874 }
875 if(persistenceThreshold != 0.) {
877 = ftm::copyMergeTree<dataType>(oriTrees[i]);
878 persistenceThresholding<dataType>(
879 &(mt.tree), persistenceThreshold, deletedNodes[i]);
880 if(mergeTrees.size() == 0)
881 mergeTrees.resize(oriTrees.size());
882 mergeTrees[i] = mt;
883 trees[i] = &(mt.tree);
884 } else {
885 trees[i] = oriTrees[i];
886 }
887 }
888
889 printTreesStats(trees);
890
891 return noTreesUnscaled;
892 }
893
894 template <class dataType>
896 std::vector<ftm::FTMTree_MT *> &oriTrees,
897 std::vector<std::vector<ftm::idNode>> &deletedNodes,
898 std::vector<dataType> &distances) {
899 for(unsigned int i = 0; i < oriTrees.size(); ++i)
900 for(auto node : deletedNodes[i])
901 distances[i] += deleteCost<dataType>(oriTrees[i], node);
902 }
903
904 // ------------------------------------------------------------------------
905 // Main Functions
906 // ------------------------------------------------------------------------
907 template <class dataType>
909 std::vector<ftm::FTMTree_MT *> &trees,
910 ftm::MergeTree<dataType> &baryMergeTree,
911 std::vector<double> &alphas,
912 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
913 &finalMatchings,
914 bool finalAsgnDoubleInput = false,
915 bool finalAsgnFirstInput = true) {
916 Timer t_bary;
917
918 ftm::FTMTree_MT *baryTree = &(baryMergeTree.tree);
919
920 // Persistence scaling
921 std::vector<ftm::FTMTree_MT *> oriTrees;
922 std::vector<ftm::MergeTree<dataType>> scaledMergeTrees;
923 std::vector<std::vector<ftm::idNode>> deletedNodes;
925 oriTrees.insert(oriTrees.end(), trees.begin(), trees.end());
926 persistenceScaling<dataType>(
927 trees, scaledMergeTrees, oriTrees, -1, deletedNodes);
928 std::vector<ftm::idNode> deletedNodesT;
929 persistenceThresholding<dataType>(baryTree, 50, deletedNodesT);
930 }
931 bool treesUnscaled = false;
932
933 // Print bary stats
934 printBaryStats(baryTree);
935
936 // Run
937 bool converged = false;
938 dataType frechetEnergy = -1;
939 dataType minFrechet = std::numeric_limits<dataType>::max();
940 int cptBlocked = 0;
941 int NoIteration = 0;
942 while(not converged) {
943 ++NoIteration;
944
946 std::stringstream ss;
947 ss << "Iteration " << NoIteration;
948 printMsg(ss.str());
949
950 // --- Assignment
951 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
952 matchings(trees.size());
953 std::vector<dataType> distances(trees.size(), -1);
954 Timer t_assignment;
955 assignment<dataType>(trees, baryMergeTree, matchings, distances);
956 Timer t_addDeletedNodes;
958 addScaledDeletedNodesCost<dataType>(
959 oriTrees, deletedNodes, distances);
960 addDeletedNodesTime_ += t_addDeletedNodes.getElapsedTime();
961 auto t_assignment_time
962 = t_assignment.getElapsedTime() - t_addDeletedNodes.getElapsedTime();
963 printMsg("Assignment", 1, t_assignment_time, this->threadNumber_,
965
966 // --- Update
967 Timer t_update;
968 updateBarycenterTree<dataType>(trees, baryMergeTree, alphas, matchings);
969 auto t_update_time = t_update.getElapsedTime();
970 baryTree = &(baryMergeTree.tree);
971 printMsg("Update", 1, t_update_time, this->threadNumber_,
973
974 // --- Check convergence
975 dataType currentFrechetEnergy = 0;
976 for(unsigned int i = 0; i < trees.size(); ++i)
977 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
978 auto frechetDiff
979 = std::abs((double)(frechetEnergy - currentFrechetEnergy));
980 converged = (frechetDiff <= tol_);
981 converged = converged and (not progressiveBarycenter_ or treesUnscaled);
982 frechetEnergy = currentFrechetEnergy;
983 tol_ = frechetEnergy / 125.0;
984
985 std::stringstream ss4;
986 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
987 printMsg("Total", 1, barycenterTime, this->threadNumber_,
989 printBaryStats(baryTree);
990 ss4 << "Frechet energy : " << frechetEnergy;
991 printMsg(ss4.str());
992
993 minFrechet = std::min(minFrechet, frechetEnergy);
994 if(not converged and (not progressiveBarycenter_ or treesUnscaled)) {
995 cptBlocked = (minFrechet < frechetEnergy) ? cptBlocked + 1 : 0;
996 converged = (cptBlocked >= 10);
997 }
998
999 // --- Persistence scaling
1001 unsigned int const noTreesUnscaled = persistenceScaling<dataType>(
1002 trees, scaledMergeTrees, oriTrees, NoIteration, deletedNodes);
1003 treesUnscaled = (noTreesUnscaled == oriTrees.size());
1004 }
1005 }
1006
1007 // Final processing
1009 printMsg("Final assignment");
1010
1011 std::vector<dataType> distances(trees.size(), -1);
1012 assignment<dataType>(trees, baryMergeTree, finalMatchings, distances,
1013 finalAsgnDoubleInput, finalAsgnFirstInput);
1014 for(auto dist : distances)
1015 finalDistances_.push_back(dist);
1016 dataType currentFrechetEnergy = 0;
1017 for(unsigned int i = 0; i < trees.size(); ++i)
1018 currentFrechetEnergy += alphas[i] * distances[i] * distances[i];
1019
1020 std::stringstream ss, ss2;
1021 ss << "Frechet energy : " << currentFrechetEnergy;
1022 printMsg(ss.str());
1023 auto barycenterTime = t_bary.getElapsedTime() - addDeletedNodesTime_;
1024 printMsg("Total", 1, barycenterTime, this->threadNumber_,
1026 // std::cout << "Bary Distance Time = " << allDistanceTime_ << std::endl;
1027
1028 if(trees.size() == 2 and not isCalled_)
1029 verifyBarycenterTwoTrees<dataType>(
1030 trees, baryMergeTree, finalMatchings, distances);
1031
1032 // Persistence (un)scaling
1034 scaledMergeTrees.clear();
1035 trees.clear();
1036 trees.insert(trees.end(), oriTrees.begin(), oriTrees.end());
1037 }
1038 }
1039
// Full barycenter pipeline with explicit per-tree weights.
//  - trees:          input merge trees; when preprocess_ is set each one is
//                    passed to preprocessingPipeline (presumably modified in
//                    place -- TODO confirm) and treesNodeCorr_ is resized to
//                    trees.size() beforehand.
//  - alphas:         weights forwarded unchanged to computeBarycenter.
//  - finalMatchings: output; filled by computeBarycenter and, when
//                    postprocess_ is set, converted per input tree by
//                    convertBranchDecompositionMatching.
//  - baryMergeTree:  output barycenter; initialized by initBarycenterTree.
//  - finalAsgnDoubleInput / finalAsgnFirstInput: forwarded to
//                    computeBarycenter.
// NOTE(review): this listing is missing the signature line (1041,
// `void execute(`) and the remaining arguments of the preprocessingPipeline
// call (1054-1056).
1040 template <class dataType>
1042 std::vector<ftm::MergeTree<dataType>> &trees,
1043 std::vector<double> &alphas,
1044 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1045 &finalMatchings,
1046 ftm::MergeTree<dataType> &baryMergeTree,
1047 bool finalAsgnDoubleInput = false,
1048 bool finalAsgnFirstInput = true) {
1049 // --- Preprocessing
1050 if(preprocess_) {
1051 treesNodeCorr_.resize(trees.size());
1052 for(unsigned int i = 0; i < trees.size(); ++i)
1053 preprocessingPipeline<dataType>(trees[i], epsilonTree2_,
1057 printTreesStats(trees);
1058 }
1059
1060 // --- Init barycenter
// FTMTree_MT pointers for the inputs, produced by ftm::mergeTreeToFTMTree;
// every later stage works on these rather than on `trees` directly.
1061 std::vector<ftm::FTMTree_MT *> treesT;
1062 ftm::mergeTreeToFTMTree<dataType>(trees, treesT);
1063 initBarycenterTree<dataType>(treesT, baryMergeTree);
1064
1065 // --- Execute
1066 computeBarycenter<dataType>(treesT, baryMergeTree, alphas, finalMatchings,
1067 finalAsgnDoubleInput, finalAsgnFirstInput);
1068
1069 // --- Postprocessing
1070 if(postprocess_) {
// NOTE(review): allRealNodes is never read; candidate for removal.
1071 std::vector<int> const allRealNodes(trees.size());
1072 for(unsigned int i = 0; i < trees.size(); ++i) {
1073 postprocessingPipeline<dataType>(treesT[i]);
1074 }
1075
1076 // fixMergedRootOriginBarycenter<dataType>(baryMergeTree);
// The barycenter is postprocessed as well, then each input's matching is
// converted against the postprocessed barycenter and input.
1077 postprocessingPipeline<dataType>(&(baryMergeTree.tree));
1078 for(unsigned int i = 0; i < trees.size(); ++i) {
1079 convertBranchDecompositionMatching<dataType>(
1080 &(baryMergeTree.tree), treesT[i], finalMatchings[i]);
1081 }
1082 }
1083 }
1084
1085 template <class dataType>
1087 std::vector<ftm::MergeTree<dataType>> &trees,
1088 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1089 &finalMatchings,
1090 ftm::MergeTree<dataType> &baryMergeTree,
1091 bool finalAsgnDoubleInput = false,
1092 bool finalAsgnFirstInput = true) {
1093 std::vector<double> alphas;
1094 if(trees.size() != 2) {
1095 for(unsigned int i = 0; i < trees.size(); ++i)
1096 alphas.push_back(1.0 / trees.size());
1097 } else {
1098 alphas.push_back(alpha_);
1099 alphas.push_back(1 - alpha_);
1100 }
1101
1102 execute<dataType>(trees, alphas, finalMatchings, baryMergeTree,
1103 finalAsgnDoubleInput, finalAsgnFirstInput);
1104 }
1105
1106 // ------------------------------------------------------------------------
1107 // Preprocessing
1108 // ------------------------------------------------------------------------
1109 template <class dataType>
1111 std::vector<ftm::FTMTree_MT *> &trees,
1112 double percent,
1113 bool useBD) {
1114 auto metric = getSizeLimitMetric(trees);
1115 unsigned int const newNoNodes = metric * percent / 100.0;
1116 keepMostImportantPairs<dataType>(&(bary.tree), newNoNodes, useBD);
1117
1118 unsigned int const noNodesAfter = bary.tree.getRealNumberOfNodes();
1119 if(bary.tree.isFullMerge() and noNodesAfter > newNoNodes * 1.1 + 1
1120 and noNodesAfter > 3) {
1121 std::cout << "metric = " << metric << std::endl;
1122 std::cout << "newNoNodes = " << newNoNodes << std::endl;
1123 std::cout << "noNodesAfter = " << noNodesAfter << std::endl;
1124 }
1125 }
1126
1127 template <class dataType>
1129 std::vector<ftm::FTMTree_MT *> &trees,
1130 unsigned int barycenterMaximumNumberOfPairs,
1131 double percent,
1132 bool useBD = true) {
1133 if(barycenterMaximumNumberOfPairs > 0)
1134 keepMostImportantPairs<dataType>(
1135 &(bary.tree), barycenterMaximumNumberOfPairs, useBD);
1136 if(percent > 0)
1137 limitSizePercent(bary, trees, percent, useBD);
1138 }
1139 template <class dataType>
1141 std::vector<ftm::FTMTree_MT *> &trees,
1142 double percent,
1143 bool useBD = true) {
1145 bary, trees, barycenterMaximumNumberOfPairs_, percent, useBD);
1146 }
// limitSizeBarycenter overload taking neither an explicit pair cap nor a
// percentage.
// NOTE(review): this listing is missing the signature line (1148) and the
// body lines (1151-1152). Per the declaration listed elsewhere in this file,
// the full signature is:
//   void limitSizeBarycenter(ftm::MergeTree<dataType> &bary,
//                            std::vector<ftm::FTMTree_MT *> &trees,
//                            bool useBD = true)
// Presumably it forwards to the percent overload with a size-limit
// percentage configured on the object -- confirm against the original source.
1147 template <class dataType>
1149 std::vector<ftm::FTMTree_MT *> &trees,
1150 bool useBD = true) {
1153 }
1154
1155 // ------------------------------------------------------------------------
1156 // Postprocessing
1157 // ------------------------------------------------------------------------
1158 template <class dataType>
1160 if(not barycenter.tree.isFullMerge())
1161 return;
1162
1163 ftm::FTMTree_MT *tree = &(barycenter.tree);
1164 auto tup = fixMergedRootOrigin<dataType>(tree);
1165 int maxIndex = std::get<0>(tup);
1166 dataType oldOriginValue = std::get<1>(tup);
1167
1168 // Verify that scalars are consistent
1169 ftm::idNode const treeRoot = tree->getRoot();
1170 std::vector<dataType> newScalarsVector;
1171 ftm::getTreeScalars<dataType>(tree, newScalarsVector);
1172 bool isJT = tree->isJoinTree<dataType>();
1173 if((isJT and tree->getValue<dataType>(maxIndex) > oldOriginValue)
1174 or (not isJT
1175 and tree->getValue<dataType>(maxIndex) < oldOriginValue)) {
1176 newScalarsVector[treeRoot] = newScalarsVector[maxIndex];
1177 newScalarsVector[maxIndex] = oldOriginValue;
1178 } else
1179 newScalarsVector[treeRoot] = oldOriginValue;
1180 setTreeScalars(barycenter, newScalarsVector);
1181 }
1182
1183 // ------------------------------------------------------------------------
1184 // Utils
1185 // ------------------------------------------------------------------------
1187 const debug::Priority &priority
1189 auto noNodesT = baryTree->getNumberOfNodes();
1190 auto noNodes = baryTree->getRealNumberOfNodes();
1191 std::stringstream ss;
1192 ss << "Barycenter number of nodes : " << noNodes << " / " << noNodesT;
1193 printMsg(ss.str(), priority);
1194 }
1195
1196 // ------------------------------------------------------------------------
1197 // Testing
1198 // ------------------------------------------------------------------------
1199 template <class dataType>
1201 std::vector<ftm::FTMTree_MT *> &trees,
1202 ftm::MergeTree<dataType> &baryMergeTree,
1203 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1204 &finalMatchings,
1205 std::vector<dataType> distances) {
1206 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
1207 dataType distance;
1208 computeOneDistance(trees[0], trees[1], matching, distance);
1209 if(distance != (distances[0] + distances[1])) {
1210 std::stringstream ss, ss2, ss3, ss4;
1211 ss << "distance T1 T2 : " << distance;
1212 printMsg(ss.str());
1213 ss2 << "distance T1 T' T2 : " << distances[0] + distances[1];
1214 printMsg(ss2.str());
1215 ss3 << "distance T1 T' : " << distances[0];
1216 printMsg(ss3.str());
1217 ss4 << "distance T' T2 : " << distances[1];
1218 printMsg(ss4.str());
1219 }
1220 return;
1221
1222 auto baryTree = &(baryMergeTree.tree);
1223 std::vector<std::vector<ftm::idNode>> baryMatched(
1224 baryTree->getNumberOfNodes(),
1225 std::vector<ftm::idNode>(
1226 trees.size(), std::numeric_limits<ftm::idNode>::max()));
1227 for(unsigned int i = 0; i < finalMatchings.size(); ++i)
1228 for(auto match : finalMatchings[i])
1229 baryMatched[std::get<0>(match)][i] = std::get<1>(match);
1230
1231 std::queue<ftm::idNode> queue;
1232 queue.emplace(baryTree->getRoot());
1233 while(!queue.empty()) {
1234 auto node = queue.front();
1235 queue.pop();
1236 std::vector<dataType> costs(trees.size());
1237 for(unsigned int i = 0; i < trees.size(); ++i)
1238 if(baryMatched[node][i] != std::numeric_limits<ftm::idNode>::max())
1239 costs[i] = relabelCost<dataType>(
1240 baryTree, node, trees[i], baryMatched[node][i]);
1241 else
1242 costs[i] = deleteCost<dataType>(baryTree, node);
1243 dataType cost = 0;
1244 if(baryMatched[node][0] != std::numeric_limits<ftm::idNode>::max()
1245 and baryMatched[node][1] != std::numeric_limits<ftm::idNode>::max())
1246 cost = relabelCost<dataType>(
1247 trees[0], baryMatched[node][0], trees[1], baryMatched[node][1]);
1248 else if(baryMatched[node][0] == std::numeric_limits<ftm::idNode>::max())
1249 cost = deleteCost<dataType>(trees[1], baryMatched[node][1]);
1250 else if(baryMatched[node][1] == std::numeric_limits<ftm::idNode>::max())
1251 cost = deleteCost<dataType>(trees[0], baryMatched[node][0]);
1252 else
1253 printErr("problem");
1254 costs[0] = std::sqrt(costs[0]);
1255 costs[1] = std::sqrt(costs[1]);
1256 cost = std::sqrt(cost);
1257 if(std::abs((double)(costs[0] - costs[1])) > 1e-7) {
1259 std::stringstream ss, ss2, ss3, ss4;
1260 ss << "cost T' T0 : " << costs[0];
1261 printMsg(ss.str());
1262 ss2 << "cost T' T1 : " << costs[1];
1263 printMsg(ss2.str());
1264 ss3 << "cost T0 T1 : " << cost;
1265 printMsg(ss2.str());
1266 ss4 << "cost T0 T' T1 : " << costs[0] + costs[1];
1267 printMsg(ss4.str());
1268 if(std::abs((double)((costs[0] + costs[1]) - cost)) > 1e-7) {
1269 std::stringstream ss5;
1270 ss5 << "diff : "
1271 << std::abs((double)((costs[0] + costs[1]) - cost));
1272 printMsg(ss5.str());
1273 }
1274 std::stringstream ss6;
1275 ss6 << "diff2 : " << std::abs((double)(costs[0] - costs[1]));
1276 printMsg(ss.str());
1277 // baryTree->printNode2<dataType>(node);
1278 // baryTree->printNode2<dataType>(baryTree->getParentSafe(node));
1279 for(unsigned int i = 0; i < 2; ++i)
1280 if(baryMatched[node][i]
1281 != std::numeric_limits<ftm::idNode>::max()) {
1282 printMsg(
1283 trees[i]->printNode2<dataType>(baryMatched[node][i]).str());
1284 printMsg(trees[i]
1285 ->printNode2<dataType>(
1286 trees[i]->getParentSafe(baryMatched[node][i]))
1287 .str());
1288 }
1289 }
1290 std::vector<ftm::idNode> children;
1291 baryTree->getChildren(node, children);
1292 for(auto child : children)
1293 queue.emplace(child);
1294 }
1295 }
1296
1297 }; // MergeTreeBarycenter class
1298
1299} // namespace ttk
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define UNTIED()
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
Minimalist debugging class.
Definition Debug.h:88
int debugLevel_
Definition Debug.h:379
int printWrn(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:159
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool distMinimizer=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double percent, bool useBD=true)
ftm::idNode getNodesAndScalarsToAdd(ftm::MergeTree< dataType > &ttkNotUsed(mTree1), ftm::idNode nodeId1, ftm::FTMTree_MT *tree2, ftm::idNode nodeId2, std::vector< dataType > &newScalarsVector, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, ftm::idNode nodeCpt, int i)
std::tuple< dataType, dataType > getParametrizedBirthDeath(ftm::FTMTree_MT *tree1, ftm::idNode nodeId1)
void getParametrizedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void verifyBarycenterTwoTrees(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, std::vector< dataType > distances)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::FTMTree_MT *baryTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void updateNodesAndScalars(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< dataType > &newScalarsVector, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
unsigned int persistenceScaling(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::MergeTree< dataType > > &mergeTrees, std::vector< ftm::FTMTree_MT * > &oriTrees, int iterationNumber, std::vector< std::vector< ftm::idNode > > &deletedNodes)
void fixMergedRootOriginBarycenter(ftm::MergeTree< dataType > &barycenter)
void setAddNodes(bool addNodesT)
void setPreprocess(bool preproc)
void computeOneDistance(ftm::MergeTree< dataType > &baryMergeTree, ftm::MergeTree< dataType > &baryMergeTree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, double percent, bool useBD=true)
std::tuple< dataType, dataType > interpolation(ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeId, std::vector< dataType > &newScalarsVector, std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::idNode > &nodes, std::vector< double > &alphas)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, bool distMinimizer=true)
void computeBarycenter(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void assignmentPara(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setPostprocess(bool postproc)
std::tuple< dataType, dataType > interpolationAdded(ftm::FTMTree_MT *tree, ftm::idNode nodeId, double alpha, ftm::MergeTree< dataType > &baryMergeTree, ftm::idNode nodeB, std::vector< dataType > &newScalarsVector)
unsigned int barycenterMaximumNumberOfPairs_
void setDeterministic(bool deterministicT)
~MergeTreeBarycenter() override=default
void updateBarycenterTreeScalars(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, unsigned int indexAddedNodes, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
void addNodes(ftm::MergeTree< dataType > &mTree1, int noTrees, std::vector< std::tuple< ftm::idNode, ftm::idNode, int > > &nodesToProcess, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode > > > &nodesProcessed)
void initBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryTree, bool distMinimizer=true)
void assignmentTask(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setProgressiveSpeedDivisor(double progSpeed)
void addScaledDeletedNodesCost(std::vector< ftm::FTMTree_MT * > &oriTrees, std::vector< std::vector< ftm::idNode > > &deletedNodes, std::vector< dataType > &distances)
void updateBarycenterTreeStructure(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > finalDistances_
void getSizeLimitedTrees(std::vector< ftm::FTMTree_MT * > &trees, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, std::vector< ftm::MergeTree< dataType > > &mTreesLimited)
void printBaryStats(ftm::FTMTree_MT *baryTree, const debug::Priority &priority=debug::Priority::INFO)
void setProgressiveBarycenter(bool progressive)
void limitSizeBarycenter(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, bool useBD=true)
void assignment(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< dataType > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterMaximumNumberOfPairs(unsigned int maxi)
void computeOneDistance(ftm::FTMTree_MT *tree, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool useDoubleInput=false, bool isFirstInput=true)
void setBarycenterSizeLimitPercent(double percent)
void getSizeLimitedDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, unsigned int barycenterMaximumNumberOfPairs, double sizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
void updateBarycenterTree(std::vector< ftm::FTMTree_MT * > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
std::vector< double > getFinalDistances()
void limitSizePercent(ftm::MergeTree< dataType > &bary, std::vector< ftm::FTMTree_MT * > &trees, double percent, bool useBD)
void getDistanceMatrix(std::vector< ftm::FTMTree_MT * > &trees, std::vector< std::vector< double > > &distanceMatrix, bool useDoubleInput=false, bool isFirstInput=true)
int getBestInitTreeIndex(std::vector< ftm::FTMTree_MT * > &trees, std::vector< ftm::FTMTree_MT * > &trees2, double sizeLimitPercent, bool distMinimizer=true)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setBranchDecomposition(bool useBD)
void setNormalizedWasserstein(bool normalizedWasserstein)
void setDistanceSquaredRoot(bool distanceSquaredRoot)
void setAssignmentSolver(int assignmentSolver)
void setNodePerTask(int npt)
double mixDistances(dataType distance1, dataType distance2)
double getSizeLimitMetric(std::vector< ftm::FTMTree_MT * > &trees)
void printTreesStats(std::vector< ftm::FTMTree_MT * > &trees)
void printMatching(std::vector< MatchingType > &matchings)
std::vector< std::vector< int > > treesNodeCorr_
void setKeepSubtree(bool keepSubtree)
double mixDistancesMinMaxPairWeight(bool isFirstInput)
dataType computeDistance(ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
void setPreprocess(bool preproc)
void setPostprocess(bool postproc)
void setMinMaxPairWeight(double weight)
double getElapsedTime()
Definition Timer.h:15
const scalarType & getValue(SimplexId nodeId) const
Definition FTMTree_MT.h:339
bool isNodeIdInconsistent(idNode nodeId)
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
void setParent(idNode nodeId, idNode newParentNodeId)
std::tuple< dataType, dataType > getBirthDeath(idNode nodeId)
void getChildren(idNode nodeId, std::vector< idNode > &res)
void deleteNode(idNode nodeId)
idNode makeNode(SimplexId vertexId, SimplexId linked=nullVertex)
bool isNodeAlone(idNode nodeId)
void copyMergeTreeStructure(FTMTree_MT *tree)
Node * getNode(idNode nodeId)
Definition FTMTree_MT.h:393
void setOrigin(SimplexId linked)
Definition FTMNode.h:72
SimplexId getOrigin() const
Definition FTMNode.h:64
unsigned int idNode
Node index in vect_nodes_.
The Topology ToolKit.
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:483
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:903
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)