TTK
Loading...
Searching...
No Matches
MergeTreePrincipalGeodesics.h
Go to the documentation of this file.
1
20
21#pragma once
22
23// ttk common includes
24#include <Debug.h>
26
27namespace ttk {
28
35 class MergeTreePrincipalGeodesics : virtual public Debug,
37
38 protected:
40 // TODO keepState works only when enabled before first computation
41 bool keepState_ = false;
// Forwarded by computePrincipalGeodesic as the `noProjectionStep` argument
// of projectionStep, where it bounds the projection loop.
42 unsigned int noProjectionStep_ = 2;
43
44 // Advanced parameters
46
47 // Old/Testing
49
50 // Filled by the algorithm
// Handed to MergeTreeAxesAlgorithmBase::initVectors by initVectors.
// presumably one distance per input tree to the barycenter — TODO confirm.
51 std::vector<double> inputToBaryDistances_;
// inputToGeodesicsDistances_[g] receives the per-input distances of
// geodesic g (written by computePrincipalGeodesic).
52 std::vector<std::vector<double>> inputToGeodesicsDistances_;
56
58 double cumulVariance_ = 0.0, cumulTVariance_ = 0.0;
59
60 public:
// NOTE(review): the line opening this body (source line 61) is not visible
// in this listing; presumably the class constructor — confirm.
62 // inherited from Debug: prefix will be printed at the beginning of every
63 // msg
64 this->setDebugMsgPrefix("MergeTreePrincipalGeodesics");
65#ifdef TTK_ENABLE_OPENMP
// Turns on nested OpenMP parallelism for the whole process.
// NOTE(review): omp_set_nested is deprecated since OpenMP 5.0
// (omp_set_max_active_levels is the replacement) — consider migrating.
66 omp_set_nested(1);
67#endif
68 }
69
70 unsigned int getGeodesicNumber() {
71 return vS_.size();
72 }
73
74 //----------------------------------------------------------------------------
75 // OpenMP Reduction
76 //----------------------------------------------------------------------------
// Candidate record used by assignmentImpl: one instance per (input tree,
// sampled parameter) pair, later reduced to the one with the smallest
// bestDistance.
// NOTE(review): source line 80 is not visible in this listing; it presumably
// declares `bestMatching2`, which assignmentImpl reads and writes — confirm.
77 struct Compare {
// Starts at the largest double so that any computed distance compares lower.
78 double bestDistance = std::numeric_limits<double>::max();
79 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> bestMatching,
// Index of the sampled parameter this record was computed for.
81 int bestIndex = 0;
82 };
83
84 //----------------------------------------------------------------------------
85 // Init
86 //----------------------------------------------------------------------------
// Initializes the axis vectors (v1, v2, and trees2V1/trees2V2 for the second
// input) of geodesic `geodesicNumber` by delegating to
// MergeTreeAxesAlgorithmBase::initVectors. The base routine also receives
// newVectorOffset_, inputToBaryDistances_ and a callback that forwards to
// projectionStep.
// NOTE(review): source lines 118-119 (part of the base-call argument list)
// are not visible in this listing.
87 template <class dataType>
88 void initVectors(int geodesicNumber,
89 ftm::MergeTree<dataType> &barycenter,
90 std::vector<ftm::MergeTree<dataType>> &trees,
91 ftm::MergeTree<dataType> &barycenter2,
92 std::vector<ftm::MergeTree<dataType>> &trees2,
93 std::vector<std::vector<double>> &v1,
94 std::vector<std::vector<double>> &v2,
95 std::vector<std::vector<double>> &trees2V1,
96 std::vector<std::vector<double>> &trees2V2) {
// Callback with the same parameter list as projectionStep. `[=]` implicitly
// captures `this`, which the body needs for this->projectionStep (that
// implicit capture is deprecated in C++20).
97 auto initializedVectorsProjection
98 = [=](int _geodesicNumber, ftm::MergeTree<dataType> &_barycenter,
99 std::vector<std::vector<double>> &_v,
100 std::vector<std::vector<double>> &_v2,
101 std::vector<std::vector<std::vector<double>>> &_vS,
102 std::vector<std::vector<std::vector<double>>> &_v2s,
103 ftm::MergeTree<dataType> &_barycenter2,
104 std::vector<std::vector<double>> &_trees2V,
105 std::vector<std::vector<double>> &_trees2V2,
106 std::vector<std::vector<std::vector<double>>> &_trees2Vs,
107 std::vector<std::vector<std::vector<double>>> &_trees2V2s,
108 bool _useSecondInput, unsigned int _noProjectionStep) {
109 return this->projectionStep(_geodesicNumber, _barycenter, _v, _v2,
110 _vS, _v2s, _barycenter2, _trees2V,
111 _trees2V2, _trees2Vs, _trees2V2s,
112 _useSecondInput, _noProjectionStep);
113 };
114
115 MergeTreeAxesAlgorithmBase::initVectors<dataType>(
116 geodesicNumber, barycenter, trees, barycenter2, trees2, v1, v2,
117 trees2V1, trees2V2, newVectorOffset_, inputToBaryDistances_,
120 initializedVectorsProjection);
121 }
122
123 //----------------------------------------------------------------------------
124 // Costs
125 //----------------------------------------------------------------------------
126 double orthogonalCost(std::vector<std::vector<std::vector<double>>> &vS,
127 std::vector<std::vector<std::vector<double>>> &v2s,
128 std::vector<std::vector<double>> &v,
129 std::vector<std::vector<double>> &v2) {
130 return verifyOrthogonality(vS, v2s, v, v2, false);
131 }
132
// Returns the square of a quantity derived from
// ttk::Geometry::dotProductFlatten(v, v2). convergenceStep reports this value
// as "Prop. cost".
// NOTE(review): source lines 136-137 (the rest of the `cost` expression) are
// not visible in this listing, so the exact formula cannot be stated here.
133 double regularizerCost(std::vector<std::vector<double>> &v,
134 std::vector<std::vector<double>> &v2) {
135 auto cost = ttk::Geometry::dotProductFlatten(v, v2)
138 return cost * cost;
139 }
140
141 double projectionCost(std::vector<std::vector<double>> &v,
142 std::vector<std::vector<double>> &v2,
143 std::vector<std::vector<std::vector<double>>> &vS,
144 std::vector<std::vector<std::vector<double>>> &v2s,
145 double optMapCost) {
146 return regularizerCost(v, v2) + orthogonalCost(vS, v2s, v, v2)
147 + optMapCost;
148 }
149
150 //----------------------------------------------------------------------------
151 // Projection
152 //----------------------------------------------------------------------------
// Rewrites the axis vector `v` so that each non-alone barycenter node points
// at its counterpart in `extremity`, and returns
// ttk::Geometry::distanceFlatten between the rewritten `v` and a copy of `v`
// taken before the rewrite.
// For node i: v[i] = t * (target - barycenter (birth, death)), with t = -1
// when isV1 is true and +1 otherwise. `target` is the (birth, death) of the
// matched extremity node, or the midpoint (birth + death) / 2 of the
// barycenter node when it has no match.
// NOTE(review): source line 154 (return type, function name and first
// parameter, used below as `barycenter`) is not visible in this listing.
153 template <class dataType>
155 ftm::MergeTree<dataType> &extremity,
156 std::vector<std::vector<double>> &v,
157 bool isV1,
158 bool useDoubleInput = false,
159 bool isFirstInput = true) {
160 ftm::FTMTree_MT *barycenterTree = &(barycenter.tree);
161 ftm::FTMTree_MT *extremityTree = &(extremity.tree);
162 double t = (isV1 ? -1.0 : 1.0);
163
164 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
165 dataType distance;
166 std::vector<ftm::idNode> matchingVector;
167
// An extremity without real nodes gets no matching: every barycenter node is
// marked unmatched with the idNode max sentinel.
168 if(extremityTree->getRealNumberOfNodes() != 0) {
169 computeOneDistance(barycenter, extremity, matching, distance, true,
170 useDoubleInput, isFirstInput);
171 getMatchingVector(barycenter, extremity, matching, matchingVector);
172 } else
173 matchingVector.resize(barycenterTree->getNumberOfNodes(),
174 std::numeric_limits<ftm::idNode>::max());
175
// Keep the incoming vector to measure how far the projection moved it.
176 std::vector<std::vector<double>> oriV = v;
177 for(unsigned int i = 0; i < barycenter.tree.getNumberOfNodes(); ++i) {
178 if(barycenter.tree.isNodeAlone(i))
179 continue;
180 auto matched = matchingVector[i];
181 auto birthDeathBary
182 = getParametrizedBirthDeath<dataType>(barycenterTree, i);
183 dataType birthBary = std::get<0>(birthDeathBary);
184 dataType deathBary = std::get<1>(birthDeathBary);
185 std::vector<double> newV{0.0, 0.0};
186 if(matched != std::numeric_limits<ftm::idNode>::max()) {
187 auto birthDeathMatched
188 = getParametrizedBirthDeath<dataType>(extremityTree, matched);
189 newV[0] = std::get<0>(birthDeathMatched);
190 newV[1] = std::get<1>(birthDeathMatched);
191 } else {
// Unmatched node: target the point with birth == death at the midpoint.
192 dataType projec = (birthBary + deathBary) / 2.0;
193 newV[0] = projec;
194 newV[1] = projec;
195 }
// Turn the target point into a signed offset from the barycenter node.
196 newV[0] = (newV[0] - birthBary) * t;
197 newV[1] = (newV[1] - deathBary) * t;
198 v[i] = newV;
199 }
200
201 // Compute distance between old and new extremity
202 double cost = ttk::Geometry::distanceFlatten(v, oriV);
203 return cost;
204 }
205
206 template <class dataType>
207 double
209 std::vector<std::vector<double>> &v,
210 std::vector<std::vector<double>> &v2,
211 ftm::MergeTree<dataType> &barycenter2,
212 std::vector<std::vector<double>> &trees2V,
213 std::vector<std::vector<double>> &trees2V2,
214 bool useSecondInput = false) {
215 std::vector<ftm::MergeTree<dataType>> extremities(
216 (useSecondInput ? 4 : 2));
217 getInterpolation<dataType>(barycenter, v, v2, 0.0, extremities[0]);
218 getInterpolation<dataType>(barycenter, v, v2, 1.0, extremities[1]);
219 if(useSecondInput) {
220 getInterpolation<dataType>(
221 barycenter2, trees2V, trees2V2, 0.0, extremities[2]);
222 getInterpolation<dataType>(
223 barycenter2, trees2V, trees2V2, 1.0, extremities[3]);
224 }
225 double cost = 0.0;
226#ifdef TTK_ENABLE_OPENMP
227#pragma omp parallel for schedule(dynamic) \
228 num_threads(this->threadNumber_) if(parallelize_)
229#endif
230 for(unsigned int i = 0; i < extremities.size(); ++i) {
231 bool isFirstInput = (i < 2);
232 ftm::MergeTree<dataType> &baryToUse
233 = (i < 2 ? barycenter : barycenter2);
234 std::vector<std::vector<double>> &vToUse
235 = (i == 0 ? v : (i == 1 ? v2 : (i == 2 ? trees2V : trees2V2)));
236 cost
237 += barycentricProjection(baryToUse, extremities[i], vToUse,
238 (i % 2 == 0), useSecondInput, isFirstInput);
239 }
240 return cost;
241 }
242
243 // Collinearity constraint
// Rebuilds the two flattened vectors as (1 - beta) * v and beta * v, where
// beta = |v2| / (|v1| + |v2|) and v is the output of
// ttk::Geometry::addVectors(v1_flatten, v2_flatten, v) — presumably the
// element-wise sum; confirm against ttk::Geometry.
// NOTE(review): source lines 248-249 and 257-258 are not visible in this
// listing; presumably they flatten v1/v2 into the *_flatten vectors and write
// the results back into v1/v2 — confirm.
// NOTE(review): beta is 0/0 when both magnitudes are zero — confirm callers
// never pass two null vectors.
244 void
245 trueGeneralizedGeodesicProjection(std::vector<std::vector<double>> &v1,
246 std::vector<std::vector<double>> &v2) {
247 std::vector<double> v1_flatten, v2_flatten;
250 double v1_norm = ttk::Geometry::magnitude(v1_flatten);
251 double v2_norm = ttk::Geometry::magnitude(v2_flatten);
252 double beta = v2_norm / (v1_norm + v2_norm);
253 std::vector<double> v;
254 ttk::Geometry::addVectors(v1_flatten, v2_flatten, v);
255 ttk::Geometry::scaleVector(v, (1 - beta), v1_flatten);
256 ttk::Geometry::scaleVector(v, beta, v2_flatten);
259 }
260
// Runs callGramSchmidt on the flattened v1 and on the flattened v2, each
// time against the same set `sumVs` built from vS and v2s.
// NOTE(review): source lines 268, 272-273 and 280-281 are not visible in this
// listing; per the surrounding comments they build sumVs, flatten v1/v2 and
// write v1_proj/v2_proj back into v1/v2 — confirm.
261 void
262 orthogonalProjection(std::vector<std::vector<double>> &v1,
263 std::vector<std::vector<double>> &v2,
264 std::vector<std::vector<std::vector<double>>> &vS,
265 std::vector<std::vector<std::vector<double>>> &v2s) {
266 // Multi flatten and sum vS and v2s
267 std::vector<std::vector<double>> sumVs;
269
270 // Flatten v1 and v2
271 std::vector<double> v1_flatten, v2_flatten, v1_proj, v2_proj;
274
275 // Call Gram Schmidt
276 callGramSchmidt(sumVs, v1_flatten, v1_proj);
277 callGramSchmidt(sumVs, v2_flatten, v2_proj);
278
279 // Unflatten the resulting vectors
282 }
283
284 // TODO avoid copying vectors
// Applies, noProjectionStep times, the chain of projections to the axis
// vectors (v, v2) and, with useSecondInput, to (trees2V, trees2V2):
//   1. optimalMappingSetProjection,
//   2. trueGeneralizedGeodesicProjection,
//   3. orthogonalProjection against the stored axes, skipped when
//      geodesicNumber == 0.
// With a second input, steps 2 and 3 run on concatenations of the first- and
// second-input vectors, which are split back afterwards. Time spent copying
// is accumulated into t_vectorCopy_time_.
// Returns the cost of the last optimalMappingSetProjection call (0.0 when
// noProjectionStep is 0).
// NOTE(review): several source lines (319-320, 325, 339-340, 345, 347, 352,
// 359) are not visible in this listing, including the tails of the printMsg
// calls and the `else` branch of step 2.
285 template <class dataType>
286 double
287 projectionStep(int geodesicNumber,
288 ftm::MergeTree<dataType> &barycenter,
289 std::vector<std::vector<double>> &v,
290 std::vector<std::vector<double>> &v2,
291 std::vector<std::vector<std::vector<double>>> &vS,
292 std::vector<std::vector<std::vector<double>>> &v2s,
293 ftm::MergeTree<dataType> &barycenter2,
294 std::vector<std::vector<double>> &trees2V,
295 std::vector<std::vector<double>> &trees2V2,
296 std::vector<std::vector<std::vector<double>>> &trees2Vs,
297 std::vector<std::vector<std::vector<double>>> &trees2V2s,
298 bool useSecondInput,
299 unsigned int noProjectionStep) {
// Stored axes of both inputs, concatenated per geodesic (second input
// appended after the first).
300 std::vector<std::vector<std::vector<double>>> vSConcat, v2sConcat;
301 if(useSecondInput) {
302 Timer t_vectorCopy;
303 vSConcat = vS;
304 v2sConcat = v2s;
305 for(unsigned int j = 0; j < vS.size(); ++j) {
306 vSConcat[j].insert(
307 vSConcat[j].end(), trees2Vs[j].begin(), trees2Vs[j].end());
308 v2sConcat[j].insert(
309 v2sConcat[j].end(), trees2V2s[j].begin(), trees2V2s[j].end());
310 }
311 t_vectorCopy_time_ += t_vectorCopy.getElapsedTime();
312 }
313
314 double optMapCost = 0.0;
315 for(unsigned i = 0; i < noProjectionStep; ++i) {
316 std::vector<std::vector<double>> vOld = v, v2Old = v2;
317
318 // --- Optimal mapping set projection
321 Timer t_optMap;
322 optMapCost = optimalMappingSetProjection(
323 barycenter, v, v2, barycenter2, trees2V, trees2V2, useSecondInput);
324 printMsg("OMS Proj.", 1, t_optMap.getElapsedTime(), threadNumber_,
326
327 // ---
// Current vectors of both inputs, concatenated in the same order as above.
328 std::vector<std::vector<double>> vConcat, v2Concat;
329 if(useSecondInput) {
330 Timer t_vectorCopy;
331 vConcat = v;
332 vConcat.insert(vConcat.end(), trees2V.begin(), trees2V.end());
333 v2Concat = v2;
334 v2Concat.insert(v2Concat.end(), trees2V2.begin(), trees2V2.end());
335 t_vectorCopy_time_ += t_vectorCopy.getElapsedTime();
336 }
337
338 // --- True generalized geodesic projection
341 Timer t_trueGeod;
342 if(useSecondInput)
343 trueGeneralizedGeodesicProjection(vConcat, v2Concat);
344 else
346 printMsg("TGG Proj.", 1, t_trueGeod.getElapsedTime(), threadNumber_,
348
349 // --- Orthogonal projection
350 if(geodesicNumber != 0) {
351 printMsg("Orth. Proj.", 0, 0, threadNumber_, debug::LineMode::REPLACE,
353 Timer t_ortho;
354 if(useSecondInput) {
355 orthogonalProjection(vConcat, v2Concat, vSConcat, v2sConcat);
356 } else
357 orthogonalProjection(v, v2, vS, v2s);
358 printMsg("Orth. Proj.", 1, t_ortho.getElapsedTime(), threadNumber_,
360 }
361
362 // ---
// Split the concatenated results back: the first v.size() entries go to the
// first input, the following trees2V.size() entries to the second.
363 if(useSecondInput) {
364 Timer t_vectorCopy;
365 for(unsigned int j = 0; j < v.size(); ++j) {
366 v[j] = vConcat[j];
367 v2[j] = v2Concat[j];
368 }
369 for(unsigned int j = 0; j < trees2V.size(); ++j) {
370 trees2V[j] = vConcat[v.size() + j];
371 trees2V2[j] = v2Concat[v.size() + j];
372 }
373 t_vectorCopy_time_ += t_vectorCopy.getElapsedTime();
374 }
375 }
376 return optMapCost;
377 }
378
379 //----------------------------------------------------------------------------
380 // Assignment
381 //----------------------------------------------------------------------------
// For every input tree i and each of the k_ sampled parameters t in [0, 1],
// builds an interpolated tree, computes its distance to trees[i] (mixed via
// mixDistances with the second-input distance when trees2 is non-empty) and
// stores the result in best[i][.]. A second pass keeps, per tree, the sample
// with the smallest distance and writes matchings[i], matchings2[i] (only
// when trees2 is non-empty), ts[i] and distances[i].
// The output vectors are indexed, not resized, here: the caller must size
// them beforehand (assignmentStep does).
// NOTE(review): source lines 383 (return type and function name) and 426
// (the call whose arguments follow on line 427) are not visible in this
// listing.
382 template <class dataType>
384 ftm::MergeTree<dataType> &barycenter,
385 std::vector<ftm::MergeTree<dataType>> &trees,
386 std::vector<std::vector<double>> &v,
387 std::vector<std::vector<double>> &v2,
388 ftm::MergeTree<dataType> &barycenter2,
389 std::vector<ftm::MergeTree<dataType>> &trees2,
390 std::vector<std::vector<double>> &trees2V,
391 std::vector<std::vector<double>> &trees2V2,
392 std::vector<std::vector<double>> &allTreesTs,
393 std::vector<std::vector<std::vector<double>>> &vS,
394 std::vector<std::vector<std::vector<double>>> &v2s,
395 std::vector<std::vector<std::vector<double>>> &trees2Vs,
396 std::vector<std::vector<std::vector<double>>> &trees2V2s,
397 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
398 &matchings,
399 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
400 &matchings2,
401 std::vector<double> &ts,
402 std::vector<double> &distances) {
// best[i][s]: candidate for tree i at sample s; each task writes a distinct
// (i, s) cell.
403 std::vector<std::vector<Compare>> best(
404 trees.size(), std::vector<Compare>(k_));
405
406#ifdef TTK_ENABLE_OPENMP
407#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
408 shared(best)
409 {
410#pragma omp single nowait
411 {
412#endif
413 for(unsigned int k = 0; k < k_; ++k) {
414 for(unsigned int i = 0; i < trees.size(); ++i) {
415#ifdef TTK_ENABLE_OPENMP
416#pragma omp task shared(best) firstprivate(i, k)
417 {
418#endif
// Samples are visited alternating from both ends: even k maps to k / 2, odd
// k to k_ - 1 - k / 2. kT is also used below as the index into best[i].
// NOTE(review): 1.0 / (k_ - 1) divides by zero when k_ == 1 — confirm that
// k_ >= 2 is enforced elsewhere.
419 double kT = (k % 2 == 0 ? k / 2 : k_ - 1 - (int)(k / 2));
420 double t = 1.0 / (k_ - 1) * kT;
421
422 dataType distance, distance2;
423 ftm::MergeTree<dataType> interpolated;
424 auto tsToUse = (allTreesTs.size() == 0 ? std::vector<double>()
425 : allTreesTs[i]);
427 barycenter, vS, v2s, v, v2, tsToUse, t, interpolated);
// A sample whose interpolation has no real node is skipped: its record keeps
// the default (maximal) bestDistance.
428 if(interpolated.tree.getRealNumberOfNodes() != 0) {
429 computeOneDistance<dataType>(interpolated, trees[i],
430 best[i][kT].bestMatching,
431 distance, true, useDoubleInput_);
432 if(trees2.size() != 0) {
433 ftm::MergeTree<dataType> interpolated2;
434 getMultiInterpolation(barycenter2, trees2Vs, trees2V2s,
435 trees2V, trees2V2, tsToUse, t,
436 interpolated2);
437 computeOneDistance<dataType>(
438 interpolated2, trees2[i], best[i][kT].bestMatching2,
439 distance2, true, useDoubleInput_, false);
440 distance = mixDistances(distance, distance2);
441 }
442 best[i][kT].bestDistance = distance;
443 best[i][kT].bestIndex = kT;
444 }
445#ifdef TTK_ENABLE_OPENMP
446 } // pragma omp task
447#endif
448 }
449 }
// All distance tasks must have finished before the per-tree reduction reads
// `best`.
450#ifdef TTK_ENABLE_OPENMP
451#pragma omp taskwait
452#endif
453
454 // Reduction
455 for(unsigned int i = 0; i < trees.size(); ++i) {
456#ifdef TTK_ENABLE_OPENMP
457#pragma omp task firstprivate(i)
458 {
459#endif
// Strict `<` keeps the lowest index among samples with equal distance.
460 double bestDistance = std::numeric_limits<double>::max();
461 int bestIndex = 0;
462 for(unsigned int k = 0; k < k_; ++k) {
463 if(best[i][k].bestDistance < bestDistance) {
464 bestIndex = k;
465 bestDistance = best[i][k].bestDistance;
466 }
467 }
468 matchings[i] = best[i][bestIndex].bestMatching;
469 if(trees2.size() != 0)
470 matchings2[i] = best[i][bestIndex].bestMatching2;
// Convert the winning sample index back to its parameter in [0, 1].
471 ts[i] = best[i][bestIndex].bestIndex * 1.0 / (k_ - 1);
472 distances[i] = best[i][bestIndex].bestDistance;
473#ifdef TTK_ENABLE_OPENMP
474 } // pragma omp task
475#endif
476 }
477#ifdef TTK_ENABLE_OPENMP
478 } // pragma omp single nowait
479 } // pragma omp parallel
480#endif
481 }
482
// Sizes the output vectors (matchings, ts and distances to trees.size(),
// matchings2 to trees2.size()) and forwards every argument unchanged to
// assignmentImpl, which fills them.
// NOTE(review): source line 484 (return type and function name) is not
// visible in this listing.
483 template <class dataType>
485 ftm::MergeTree<dataType> &barycenter,
486 std::vector<ftm::MergeTree<dataType>> &trees,
487 std::vector<std::vector<double>> &v,
488 std::vector<std::vector<double>> &v2,
489 ftm::MergeTree<dataType> &barycenter2,
490 std::vector<ftm::MergeTree<dataType>> &trees2,
491 std::vector<std::vector<double>> &trees2V,
492 std::vector<std::vector<double>> &trees2V2,
493 std::vector<std::vector<double>> &allTreesTs,
494 std::vector<std::vector<std::vector<double>>> &vS,
495 std::vector<std::vector<std::vector<double>>> &v2s,
496 std::vector<std::vector<std::vector<double>>> &trees2Vs,
497 std::vector<std::vector<std::vector<double>>> &trees2V2s,
498 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
499 &matchings,
500 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
501 &matchings2,
502 std::vector<double> &ts,
503 std::vector<double> &distances) {
504
505 matchings.resize(trees.size());
506 matchings2.resize(trees2.size());
507 ts.resize(trees.size());
508 distances.resize(trees.size());
509
510 // Assignment
511 assignmentImpl<dataType>(barycenter, trees, v, v2, barycenter2, trees2,
512 trees2V, trees2V2, allTreesTs, vS, v2s, trees2Vs,
513 trees2V2s, matchings, matchings2, ts, distances);
514 }
515
516 //----------------------------------------------------------------------------
517 // Update
518 //----------------------------------------------------------------------------
// Recomputes, for every non-alone barycenter node i, the two components
// (birth, death) of v[i] and v2[i] from the matched nodes of the input trees
// and the per-tree coefficients tss[i][j].
// Nodes flagged in isUniform are not recomputed: they take vR[i] / vR2[i].
// When geodesicNumber != 0 the reference (birth, death) of node i is read
// per tree from allInterpolated[j] instead of the barycenter.
// NOTE(review): source line 520 (return type and function name) is not
// visible in this listing.
519 template <class dataType>
521 int geodesicNumber,
522 ftm::MergeTree<dataType> &barycenter,
523 std::vector<ftm::MergeTree<dataType>> &trees,
524 std::vector<ftm::MergeTree<dataType>> &allInterpolated,
525 std::vector<std::vector<double>> &v,
526 std::vector<std::vector<double>> &v2,
527 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
528 &matchings,
529 std::vector<std::vector<double>> &tss,
530 std::vector<std::vector<double>> &vR,
531 std::vector<std::vector<double>> &vR2,
532 std::vector<bool> &isUniform) {
533
534 // Init
535 ftm::FTMTree_MT *barycenterTree = &(barycenter.tree);
536 std::vector<ftm::FTMTree_MT *> ftmTrees, allInterpolatedTrees;
537 ttk::ftm::mergeTreeToFTMTree<dataType>(trees, ftmTrees);
538 if(geodesicNumber != 0) {
539 ttk::ftm::mergeTreeToFTMTree<dataType>(
540 allInterpolated, allInterpolatedTrees);
541 }
542
543 // Get matching matrix
// Indexed below as matchingMatrix[barycenter node][tree]; the idNode max
// value marks "no match".
544 std::vector<std::vector<ftm::idNode>> matchingMatrix;
545 getMatchingMatrix(barycenter, trees, matchings, matchingMatrix);
546
547 // Update
548 for(unsigned int i = 0; i < barycenter.tree.getNumberOfNodes(); ++i) {
549 if(barycenter.tree.isNodeAlone(i))
550 continue;
551
552 // Verify that ts is not uniform
553 if(isUniform[i]) {
554 v[i] = vR[i];
555 v2[i] = vR2[i];
556 continue;
557 }
558
559 // Compute projection
// Defaults taken from the barycenter node; overridden per tree below when
// geodesicNumber != 0.
560 auto birthDeathBary
561 = getParametrizedBirthDeath<dataType>(barycenterTree, i);
562 dataType birthBary = std::get<0>(birthDeathBary);
563 dataType deathBary = std::get<1>(birthDeathBary);
564 dataType projec = (birthBary + deathBary) / 2.0;
565 std::vector<dataType> allBirthBary(trees.size(), birthBary);
566 std::vector<dataType> allDeathBary(trees.size(), deathBary);
567 std::vector<dataType> allProjec(trees.size(), projec);
568
569 if(geodesicNumber != 0)
570 for(unsigned int j = 0; j < trees.size(); ++j) {
571 auto birthDeathInterpol
572 = getParametrizedBirthDeath<dataType>(allInterpolatedTrees[j], i);
573 allBirthBary[j] = std::get<0>(birthDeathInterpol);
574 allDeathBary[j] = std::get<1>(birthDeathInterpol);
575 allProjec[j] = (allBirthBary[j] + allDeathBary[j]) / 2.0;
576 }
577
578 // Compute all matched values
// allMatched[j] = (birth, death) of the node of tree j matched to i, or the
// midpoint (allProjec[j], allProjec[j]) when i is unmatched in tree j.
579 std::vector<std::vector<dataType>> allMatched(trees.size());
580 for(unsigned int j = 0; j < trees.size(); ++j) {
581 dataType birth = allProjec[j];
582 dataType death = allProjec[j];
583 if(matchingMatrix[i][j] != std::numeric_limits<ftm::idNode>::max()) {
584 auto birthDeath = getParametrizedBirthDeath<dataType>(
585 ftmTrees[j], matchingMatrix[i][j]);
586 birth = std::get<0>(birthDeath);
587 death = std::get<1>(birthDeath);
588 }
589 allMatched[j].resize(2);
590 allMatched[j][0] = birth;
591 allMatched[j][1] = death;
592 }
593
594 // Compute general terms
// Sums over the coefficients t of this node: sum t^2, sum (1-t)^2 and
// sum t(1-t).
595 double ti_squared = 0.0, one_min_ti_squared = 0.0, ti_one_min_ti = 0.0;
596 for(auto t : tss[i]) {
597 ti_squared += t * t;
598 one_min_ti_squared += (1 - t) * (1 - t);
599 ti_one_min_ti += t * (1 - t);
600 }
601
602 // Compute multiplier
// NOTE(review): the divisions by ti_squared / one_min_ti_squared (and by
// divisorV1 / divisorV2 below) assume non-zero sums; presumably guaranteed
// by the isUniform shortcut above — confirm.
603 double multBirthV1 = 0.0, multDeathV1 = 0.0, multBirthV2 = 0.0,
604 multDeathV2 = 0.0;
605 for(unsigned int j = 0; j < trees.size(); ++j) {
606 multBirthV1
607 += tss[i][j] * (allMatched[j][0] - allBirthBary[j]) / ti_squared;
608 multDeathV1
609 += tss[i][j] * (allMatched[j][1] - allDeathBary[j]) / ti_squared;
610 multBirthV2 += (1 - tss[i][j]) * (-allMatched[j][0] + allBirthBary[j])
611 / one_min_ti_squared;
612 multDeathV2 += (1 - tss[i][j]) * (-allMatched[j][1] + allDeathBary[j])
613 / one_min_ti_squared;
614 }
615
616 // Compute new birth death
617 double newBirthV1 = 0.0, newDeathV1 = 0.0, newBirthV2 = 0.0,
618 newDeathV2 = 0.0;
619 for(unsigned int j = 0; j < trees.size(); ++j) {
620 newBirthV1 += (1 - tss[i][j])
621 * (-allMatched[j][0] + allBirthBary[j]
622 + tss[i][j] * multBirthV1);
623 newDeathV1 += (1 - tss[i][j])
624 * (-allMatched[j][1] + allDeathBary[j]
625 + tss[i][j] * multDeathV1);
626 newBirthV2 += tss[i][j]
627 * (allMatched[j][0] - allBirthBary[j]
628 + (1 - tss[i][j]) * multBirthV2);
629 newDeathV2 += tss[i][j]
630 * (allMatched[j][1] - allDeathBary[j]
631 + (1 - tss[i][j]) * multDeathV2);
632 }
633 double divisorV1
634 = one_min_ti_squared - ti_one_min_ti * ti_one_min_ti / ti_squared;
635 double divisorV2
636 = ti_squared - ti_one_min_ti * ti_one_min_ti / one_min_ti_squared;
637 newBirthV1 /= divisorV1;
638 newDeathV1 /= divisorV1;
639 newBirthV2 /= divisorV2;
640 newDeathV2 /= divisorV2;
641
642 // Update vectors
643 v[i][0] = newBirthV1;
644 v[i][1] = newDeathV1;
645 v2[i][0] = newBirthV2;
646 v2[i][1] = newDeathV2;
647 }
648 }
649
// Computes, for every non-alone barycenter node i, a per-tree coefficient
// vector tss[i] (via getTNew, starting from `ts`) and flags in isUniform[i]
// whether that vector is uniform according to
// ttk::Geometry::isVectorUniform.
// Outputs:
//   allInterpolated: resized to trees.size(); filled only when
//     geodesicNumber != 0.
//   noUniform: number of nodes flagged uniform.
//   foundAllUniform: true only if every non-alone node is flagged uniform
//     (stays true when there is no such node).
// NOTE(review): source line 670 (the call whose arguments follow on line
// 671) is not visible in this listing.
650 template <class dataType>
651 void
652 manageIndividualTs(int geodesicNumber,
653 ftm::MergeTree<dataType> &barycenter,
654 std::vector<ftm::MergeTree<dataType>> &trees,
655 std::vector<std::vector<double>> &v,
656 std::vector<std::vector<double>> &v2,
657 std::vector<std::vector<std::vector<double>>> &vS,
658 std::vector<std::vector<std::vector<double>>> &v2s,
659 std::vector<double> &ts,
660 std::vector<std::vector<double>> &allTreesTs,
661 std::vector<ftm::MergeTree<dataType>> &allInterpolated,
662 std::vector<bool> &isUniform,
663 std::vector<std::vector<double>> &tss,
664 unsigned int &noUniform,
665 bool &foundAllUniform) {
666 // Get multi interpolation
667 allInterpolated.resize(trees.size());
668 if(geodesicNumber != 0) {
669 for(unsigned int i = 0; i < trees.size(); ++i)
671 barycenter, vS, v2s, allTreesTs[i], allInterpolated[i]);
672 }
673
674 // Manage individuals t
675 noUniform = 0;
676 foundAllUniform = true;
677 isUniform.resize(barycenter.tree.getNumberOfNodes(), false);
678 tss.resize(barycenter.tree.getNumberOfNodes());
679 for(unsigned int i = 0; i < barycenter.tree.getNumberOfNodes(); ++i) {
680 if(barycenter.tree.isNodeAlone(i))
681 continue;
// Start from the global coefficients, then refine each entry for node i.
682 tss[i] = ts;
683 for(unsigned int j = 0; j < tss[i].size(); ++j) {
// The first geodesic is measured from the barycenter, later ones from the
// per-tree interpolation computed above.
684 auto &treeToUse
685 = (geodesicNumber != 0 ? allInterpolated[j] : barycenter);
686 tss[i][j] = getTNew<dataType>(treeToUse, v, v2, i, ts[j]);
687 }
688 isUniform[i] = ttk::Geometry::isVectorUniform(tss[i]);
689 noUniform += isUniform[i];
690 foundAllUniform &= isUniform[i];
691 }
692 }
693
// One update of the axis vectors of geodesic `geodesicNumber`.
// Computes the per-node coefficients with manageIndividualTs (for both
// inputs when trees2 is non-empty), then:
//   - if every node is flagged uniform: increments newVectorOffset_,
//     re-initializes the vectors with initVectors and returns true;
//   - otherwise: runs updateClosedForm (on both inputs when trees2 is
//     non-empty, followed by copyMinMaxPairVector) and returns false. When
//     some nodes are uniform, replacement vectors for them are first drawn
//     with initVectors.
// NOTE(review): source line 695 (return type and function name) is not
// visible in this listing.
694 template <class dataType>
696 int geodesicNumber,
697 ftm::MergeTree<dataType> &barycenter,
698 std::vector<ftm::MergeTree<dataType>> &trees,
699 std::vector<std::vector<double>> &v,
700 std::vector<std::vector<double>> &v2,
701 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
702 &matchings,
703 std::vector<std::vector<std::vector<double>>> &vS,
704 std::vector<std::vector<std::vector<double>>> &v2s,
705 ftm::MergeTree<dataType> &barycenter2,
706 std::vector<ftm::MergeTree<dataType>> &trees2,
707 std::vector<std::vector<double>> &trees2V,
708 std::vector<std::vector<double>> &trees2V2,
709 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
710 &matchings2,
711 std::vector<std::vector<std::vector<double>>> &trees2Vs,
712 std::vector<std::vector<std::vector<double>>> &trees2V2s,
713 std::vector<double> &ts,
714 std::vector<std::vector<double>> &allTreesTs) {
715
716 std::vector<ftm::MergeTree<dataType>> allInterpolated, allInterpolated2;
717 std::vector<bool> isUniform, isUniform2;
718 std::vector<std::vector<double>> tss, tss2;
719 unsigned int noUniform;
720 bool foundAllUniform;
721 manageIndividualTs(geodesicNumber, barycenter, trees, v, v2, vS, v2s, ts,
722 allTreesTs, allInterpolated, isUniform, tss, noUniform,
723 foundAllUniform);
// Fold the second input into the same counters.
724 if(trees2.size() != 0) {
725 unsigned int noUniform2;
726 bool foundAllUniform2;
727 manageIndividualTs(geodesicNumber, barycenter2, trees2, trees2V,
728 trees2V2, trees2Vs, trees2V2s, ts, allTreesTs,
729 allInterpolated2, isUniform2, tss2, noUniform2,
730 foundAllUniform2);
731 noUniform += noUniform2;
732 foundAllUniform &= foundAllUniform2;
733 }
734
735 if(foundAllUniform) {
736 printMsg("All projection coefficients are the same.");
737 printMsg("New vectors will be initialized.");
738 newVectorOffset_ += 1;
739 initVectors(geodesicNumber, barycenter, trees, barycenter2, trees2, v,
740 v2, trees2V, trees2V2);
741 return true;
742 }
// Replacement vectors read by updateClosedForm for uniform nodes; they stay
// empty when noUniform == 0, in which case no node is flagged uniform.
743 std::vector<std::vector<double>> vR, vR2, trees2VR, trees2VR2;
744 if(noUniform != 0) {
745 printMsg("Found " + std::to_string(noUniform)
746 + " uniform coefficients.");
747 initVectors(geodesicNumber, barycenter, trees, barycenter2, trees2, vR,
748 vR2, trees2VR, trees2VR2);
749 }
750
751 updateClosedForm(geodesicNumber, barycenter, trees, allInterpolated, v,
752 v2, matchings, tss, vR, vR2, isUniform);
753 if(trees2.size() != 0) {
754 updateClosedForm(geodesicNumber, barycenter2, trees2, allInterpolated2,
755 trees2V, trees2V2, matchings2, tss2, trees2VR,
756 trees2VR2, isUniform2);
757 copyMinMaxPairVector(v, v2, trees2V, trees2V2);
758 }
759 return false;
760 }
761
// Forwards every argument unchanged to updateClosedFormStep and returns its
// result (true when the vectors were re-initialized, false otherwise).
// NOTE(review): source line 763 (return type and function name) is not
// visible in this listing.
762 template <class dataType>
764 int geodesicNumber,
765 ftm::MergeTree<dataType> &barycenter,
766 std::vector<ftm::MergeTree<dataType>> &trees,
767 std::vector<std::vector<double>> &v,
768 std::vector<std::vector<double>> &v2,
769 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
770 &matchings,
771 std::vector<std::vector<std::vector<double>>> &vS,
772 std::vector<std::vector<std::vector<double>>> &v2s,
773 ftm::MergeTree<dataType> &barycenter2,
774 std::vector<ftm::MergeTree<dataType>> &trees2,
775 std::vector<std::vector<double>> &trees2V,
776 std::vector<std::vector<double>> &trees2V2,
777 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
778 &matchings2,
779 std::vector<std::vector<std::vector<double>>> &trees2Vs,
780 std::vector<std::vector<std::vector<double>>> &trees2V2s,
781 std::vector<double> &ts,
782 std::vector<std::vector<double>> &allTreesTs) {
783 return updateClosedFormStep<dataType>(
784 geodesicNumber, barycenter, trees, v, v2, matchings, vS, v2s,
785 barycenter2, trees2, trees2V, trees2V2, matchings2, trees2Vs, trees2V2s,
786 ts, allTreesTs);
787 }
788
789 //----------------------------------------------------------------------------
790 // Main functions
791 //----------------------------------------------------------------------------
792 template <class dataType>
793 bool convergenceStep(std::vector<double> &distances,
794 std::vector<std::vector<double>> &v,
795 std::vector<std::vector<double>> &v2,
796 dataType &oldFrechetEnergy,
797 dataType &minFrechetEnergy,
798 int &cptBlocked,
799 bool &converged,
800 double optMapCost) {
801 bool isBestEnergy = false;
802
803 // Reconstruction cost (main energy)
804 double frechetEnergy = 0;
805 for(unsigned int i = 0; i < distances.size(); ++i)
806 frechetEnergy += distances[i] * distances[i] / distances.size();
807 std::stringstream ssEnergy;
808 ssEnergy << "Energy = " << frechetEnergy;
809 printMsg(ssEnergy.str());
810
811 // Prop. cost
812 std::stringstream ssReg;
813 auto reg = regularizerCost(v, v2);
814 ssReg << "Prop. cost = " << reg;
815 printMsg(ssReg.str());
816
817 // Ortho. cost
818 std::stringstream ssOrthoCost;
819 auto orthoCost = orthogonalCost(vS_, v2s_, v, v2);
820 ssOrthoCost << "Ortho. cost = " << orthoCost;
821 printMsg(ssOrthoCost.str());
822
823 // Map. cost
824 std::stringstream ssOptMapCost;
825 ssOptMapCost << "Map. cost = " << optMapCost;
826 printMsg(ssOptMapCost.str());
827
828 // Detect convergence
829 double tol = 0.01;
830 tol = oldFrechetEnergy / 125.0;
831 converged = std::abs(frechetEnergy - oldFrechetEnergy) < tol;
832 oldFrechetEnergy = frechetEnergy;
833
834 if(frechetEnergy + ENERGY_COMPARISON_TOLERANCE < minFrechetEnergy) {
835 minFrechetEnergy = frechetEnergy;
836 cptBlocked = 0;
837 isBestEnergy = true;
838 }
839 if(not converged) {
840 cptBlocked += (minFrechetEnergy < frechetEnergy) ? 1 : 0;
841 converged = (cptBlocked >= 10);
842 }
843
844 return isBestEnergy;
845 }
846
// Computes geodesic number `geodesicNumber`: initializes its axis vectors,
// then iterates assignment -> convergence test -> update -> projection until
// convergenceStep reports convergence. The vectors, coefficients and
// distances of the best iteration are restored at the end, stored in
// inputToGeodesicsDistances_ / allTs_ at index geodesicNumber, and the
// vectors are appended to vS_, v2s_, trees2Vs_ and trees2V2s_.
// NOTE(review): several source lines (855, 860, 884, 894, 913, 919, 924,
// 940, 946, 950, 964-965) are not visible in this listing, including parts
// of the assignmentStep and updateStep argument lists.
847 template <class dataType>
848 void
849 computePrincipalGeodesic(unsigned int geodesicNumber,
850 ftm::MergeTree<dataType> &barycenter,
851 std::vector<ftm::MergeTree<dataType>> &trees,
852 ftm::MergeTree<dataType> &barycenter2,
853 std::vector<ftm::MergeTree<dataType>> &trees2) {
854 // ----- Init Parameters
856 Timer t_init;
857 std::vector<std::vector<double>> v, v2, trees2V, trees2V2;
858 initVectors<dataType>(geodesicNumber, barycenter, trees, barycenter2,
859 trees2, v, v2, trees2V, trees2V2);
861 printMsg("Init", 1, t_init.getElapsedTime(), threadNumber_);
862
// Snapshot of the iteration with the lowest energy.
863 std::vector<std::vector<double>> bestV, bestV2, bestTrees2V, bestTrees2V2;
864 std::vector<double> bestTs, bestDistances;
865 int bestIteration = 0;
866
867 // ----- Init Loop
868 dataType oldFrechetEnergy, minFrechetEnergy;
869 int cptBlocked, iteration = 0;
// Also called again when updateStep re-initializes the vectors.
870 auto initLoop = [&]() {
871 oldFrechetEnergy = -1;
872 minFrechetEnergy = std::numeric_limits<dataType>::max();
873 cptBlocked = 0;
874 iteration = 0;
875 };
876 initLoop();
877
878 // ----- Algorithm
879 double optMapCost = 0.0;
880 bool converged = false;
881 while(not converged) {
882 std::stringstream ss;
883 ss << "Iteration " << iteration;
885 printMsg(ss.str());
886
887 // --- Assignment
888 printMsg("Assignment", 0, 0, threadNumber_, debug::LineMode::REPLACE);
889 Timer t_assignment;
890 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
891 matchings, matchings2;
892 std::vector<double> ts, distances;
893 assignmentStep(barycenter, trees, v, v2, barycenter2, trees2, trees2V,
895 matchings, matchings2, ts, distances);
896 inputToGeodesicsDistances_[geodesicNumber] = distances;
897 allTs_[geodesicNumber] = ts;
898 printMsg("Assignment", 1, t_assignment.getElapsedTime(), threadNumber_);
899
900 // --- Convergence
901 bool isBest = convergenceStep(distances, v, v2, oldFrechetEnergy,
902 minFrechetEnergy, cptBlocked, converged,
903 optMapCost);
904 if(isBest) {
905 Timer t_copy;
906 bestV = v;
907 bestV2 = v2;
908 bestTrees2V = trees2V;
909 bestTrees2V2 = trees2V2;
910 bestTs = ts;
911 bestDistances = distances;
912 bestIteration = iteration;
914 }
915 if(converged)
916 break;
917
918 // --- Update
920 Timer t_update;
921 bool reset
922 = updateStep(geodesicNumber, barycenter, trees, v, v2, matchings, vS_,
923 v2s_, barycenter2, trees2, trees2V, trees2V2, matchings2,
925 printMsg("Update", 1, t_update.getElapsedTime(), threadNumber_);
// updateStep re-initialized the vectors: restart the energy tracking and the
// iteration counter, and skip the projection for this pass.
926 if(reset) {
927 initLoop();
928 continue;
929 }
930
931 // --- Projection
932 printMsg("Projection", 0, 0, threadNumber_, debug::LineMode::REPLACE);
933 Timer t_projection;
934 optMapCost
935 = projectionStep(geodesicNumber, barycenter, v, v2, vS_, v2s_,
936 barycenter2, trees2V, trees2V2, trees2Vs_,
937 trees2V2s_, (trees2.size() != 0), noProjectionStep_);
// Report the projection time without the vector-copy time accumulated by
// projectionStep, then reset that accumulator.
938 auto projectionTime
939 = t_projection.getElapsedTime() - t_vectorCopy_time_;
941 t_vectorCopy_time_ = 0.0;
942 printMsg("Projection", 1, projectionTime, threadNumber_);
943
944 ++iteration;
945 }
947 printMsg("Best energy is " + std::to_string(minFrechetEnergy)
948 + " (iteration " + std::to_string(bestIteration) + " / "
949 + std::to_string(iteration) + ")");
951
// Restore the best iteration and record it as this geodesic's result.
952 Timer t_copy;
953 v = bestV;
954 v2 = bestV2;
955 trees2V = bestTrees2V;
956 trees2V2 = bestTrees2V2;
957 inputToGeodesicsDistances_[geodesicNumber] = bestDistances;
958 allTs_[geodesicNumber] = bestTs;
959
960 vS_.push_back(v);
961 v2s_.push_back(v2);
962 trees2Vs_.push_back(trees2V);
963 trees2V2s_.push_back(trees2V2);
966 }
967
// Computes the barycenter(s) of the input trees, then every principal
// geodesic in turn (one computePrincipalGeodesic call per geodesic).
// `trees2` is the optional second input: its code paths only run when
// trees2.size() != 0.
// NOTE(review): this listing drops several original lines (embedded numbers
// 970, 977, 979, 988, 994, 1013, 1033, 1042, 1047-1048, 1059 and 1069 are
// absent), so a few statements below appear truncated. The dropped line 970
// presumably holds the name `computePrincipalGeodesics` and the `trees`
// parameter (execute() calls it that way) - confirm against the real header.
968 template <class dataType>
969 void
971 std::vector<ftm::MergeTree<dataType>> &trees2) {
972 // --- Compute barycenter
973 ftm::MergeTree<dataType> barycenter, barycenter2;
// Recompute unless keepState_ is set and a barycenter was already computed.
974 if(not keepState_ or not barycenterWasComputed_) {
975 Timer t_barycenter;
976 printMsg("Barycenter", 0, t_barycenter.getElapsedTime(), threadNumber_,
978 computeOneBarycenter<dataType>(trees, barycenter, baryMatchings_,
// Keep a double-typed copy of the freshly computed barycenter.
980 mergeTreeTemplateToDouble(barycenter, barycenterBDT_);
981 if(trees2.size() != 0) {
// Second input: separate barycenter, then its distances are mixed with the
// first-input distances through mixDistances.
982 std::vector<double> baryDistances2;
983 computeOneBarycenter<dataType>(trees2, barycenter2, baryMatchings2_,
984 baryDistances2, useDoubleInput_,
985 false);
986 mergeTreeTemplateToDouble(barycenter2, barycenterInput2BDT_);
987 for(unsigned int i = 0; i < inputToBaryDistances_.size(); ++i)
989 = mixDistances(inputToBaryDistances_[i], baryDistances2[i]);
990
991 verifyMinMaxPair(barycenter, barycenter2);
992 }
993 printMsg("Barycenter", 1, t_barycenter.getElapsedTime(), threadNumber_);
995 } else {
// Reuse the stored barycenter(s), converting them back to dataType.
996 printMsg("KeepState is enabled and barycenter was already computed");
997 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(
998 barycenterBDT_, barycenter);
999 if(trees2.size() != 0) {
1000 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(
1001 barycenterInput2BDT_, barycenter2);
1002 }
1003 }
// Print statistics and store the barycenter(s) as double trees in
// barycenter_ / barycenterInput2_.
1004 printMsg(barycenter.tree.printTreeStats().str());
1005 mergeTreeTemplateToDouble(barycenter, barycenter_);
1006 if(trees2.size() != 0) {
1007 printMsg(barycenter2.tree.printTreeStats().str());
1008 mergeTreeTemplateToDouble(barycenter2, barycenterInput2_);
1009 }
1010
1011 // --- Compute global variance
// NOTE(review): the initializer of globalVariance (line 1013) is missing from
// this listing; presumably a computeGlobalVariance call - confirm.
1012 double globalVariance
1014
1015 // --- Manage maximum number of geodesics
// Upper bound: twice getRealNumberOfNodes() of the barycenter, plus the same
// quantity for the second barycenter when a second input is given.
1016 unsigned int maxNoGeodesics = barycenter.tree.getRealNumberOfNodes() * 2;
1017 if(trees2.size() != 0)
1018 maxNoGeodesics += barycenter2.tree.getRealNumberOfNodes() * 2;
// Clamp the requested number of geodesics and tell the user about it.
1019 if(maxNoGeodesics < numberOfGeodesics_) {
1020 std::stringstream ss;
1021 ss << numberOfGeodesics_ << " principal geodesics are asked but only "
1022 << maxNoGeodesics << " can be computed.";
1023 printMsg(ss.str());
1024 printMsg("(the maximum is twice the number of persistence pairs in the "
1025 "barycenter)");
1026 numberOfGeodesics_ = maxNoGeodesics;
1027 }
1028
1029 // --- Init
// NOTE(review): oldNoGeod is read from allTs_.size() before the keepState_
// branch and is not reset when keepState_ is false. If allTs_ is already
// non-empty from an earlier call, the geodesic loop below starts at that
// index instead of 0 - confirm whether this is intended.
1030 unsigned int oldNoGeod = allTs_.size();
1031 if(not keepState_) {
// Fresh run: resize the per-geodesic storage, clear accumulated vectors,
// reseed rand() and reset the timers / cumulative variances.
1032 allTs_.resize(numberOfGeodesics_, std::vector<double>(trees.size()));
1034 numberOfGeodesics_, std::vector<double>(trees.size()));
1035 vS_.clear();
1036 v2s_.clear();
1037 trees2Vs_.clear();
1038 trees2V2s_.clear();
1039 allTreesTs_.clear();
// Fixed seed 7 when deterministic_ is set, current time otherwise.
1040 srand(deterministic_ ? 7 : time(nullptr));
1041 t_vectorCopy_time_ = 0.0;
1043 cumulVariance_ = 0.0;
1044 cumulTVariance_ = 0.0;
1045 } else {
// keepState_: only grow the storage; previously computed geodesics are kept
// and the loop resumes at index oldNoGeod.
1046 allTs_.resize(numberOfGeodesics_, std::vector<double>(trees.size()));
1049 numberOfGeodesics_, std::vector<double>(trees.size()));
1050 if(oldNoGeod != 0)
1051 printMsg(
1052 "KeepState is enabled, restart the computation at geodesic number "
1053 + std::to_string(oldNoGeod));
1054 }
1055
1056 // --- Compute each geodesic
1057 for(unsigned int geodNum = oldNoGeod; geodNum < numberOfGeodesics_;
1058 ++geodNum) {
1060 std::stringstream ss;
1061 ss << "Compute geodesic " << geodNum;
1062 printMsg(ss.str());
1063
1064 // - Compute geodesic
1065 computePrincipalGeodesic<dataType>(
1066 geodNum, barycenter, trees, barycenter2, trees2);
1067
1068 // - Compute explained variance
// NOTE(review): the callee name (line 1069) is missing from this listing; the
// argument list matches printIterationVariances - confirm.
1070 barycenter, trees, barycenter2, trees2, geodNum, globalVariance);
1071 }
1072 }
1073
// Entry point of the algorithm: preprocesses the input trees, computes the
// principal geodesics, derives the geodesic extremities and the branches
// correlation matrix, prints the reconstruction error, then post-processes
// the trees and converts the barycenter matchings.
// trees  : first input ensemble (modified in place by the pre/post-processing).
// trees2 : optional second input ensemble; an empty vector disables the
//          second-input code paths (useDoubleInput_ is set from its size).
// NOTE(review): this listing drops original lines 1090 and 1102, so the
// closing brace at line 1108 has no visible opening statement here.
1074 template <class dataType>
1075 void execute(std::vector<ftm::MergeTree<dataType>> &trees,
1076 std::vector<ftm::MergeTree<dataType>> &trees2) {
1077 // --- Preprocessing
1078 Timer t_preprocess;
1079 preprocessingTrees<dataType>(trees, treesNodeCorr_);
1080 if(trees2.size() != 0)
1081 preprocessingTrees<dataType>(trees2, trees2NodeCorr_);
1082 printMsg(
1083 "Preprocessing", 1, t_preprocess.getElapsedTime(), threadNumber_);
1084 useDoubleInput_ = (trees2.size() != 0);
1085
1086 // --- Compute principal geodesics
1087 Timer t_total;
1088 computePrincipalGeodesics<dataType>(trees, trees2);
// The reported total excludes the accumulated vector-copy time.
1089 auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_;
1091 printMsg("Total time", 1, totalTime, threadNumber_);
// Typed copy of the stored (double) barycenter for the steps below.
1092 ftm::MergeTree<dataType> barycenter;
1093 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(barycenter_, barycenter);
1094
1095 // - Compute merge tree geodesic extremities
1096 computeGeodesicExtremities<dataType>();
1097
1098 // - Compute branches correlation matrix
1099 computeBranchesCorrelationMatrix<dataType>(barycenter, trees);
1100
1101 // --- Reconstruction
// NOTE(review): the statement opening this scope (line 1102) is missing from
// the listing; presumably a guard around the reconstruction error - confirm.
1103 auto reconstructionError = computeReconstructionError(
1104 barycenter, trees, vS_, v2s_, allTreesTs_);
1105 std::stringstream ss;
1106 ss << "Reconstruction Error = " << reconstructionError;
1107 printMsg(ss.str());
1108 }
1109
1110 // --- Postprocessing
1111 if(normalizedWasserstein_) { // keep BDT if input is a PD
1112 postprocessingPipeline<double>(&(barycenter_.tree));
1113 if(trees2.size() != 0)
1114 postprocessingPipeline<double>(&(barycenterInput2_.tree));
1115 }
1116
// Refresh the typed barycenter after the optional post-processing above, then
// post-process each input tree and convert its matching with the barycenter.
1117 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(barycenter_, barycenter);
1118 for(unsigned int i = 0; i < trees.size(); ++i) {
1119 postprocessingPipeline<dataType>(&(trees[i].tree));
1120 convertBranchDecompositionMatching<dataType>(
1121 &(barycenter.tree), &(trees[i].tree), baryMatchings_[i]);
1122 }
// Trees of the second input are only post-processed; no matching conversion.
1123 for(unsigned int i = 0; i < trees2.size(); ++i)
1124 postprocessingPipeline<dataType>(&(trees2[i].tree));
1125 }
1126
1127 // ----------------------------------------
1128 // End functions
1129 // ----------------------------------------
1130 void copyMinMaxPairVector(std::vector<std::vector<double>> &v,
1131 std::vector<std::vector<double>> &v2,
1132 std::vector<std::vector<double>> &trees2V,
1133 std::vector<std::vector<double>> &trees2V2) {
1134 auto root = barycenter_.tree.getRoot();
1135 auto root2 = barycenterInput2_.tree.getRoot();
1136 trees2V[root2] = v[root];
1137 trees2V2[root2] = v2[root];
1138 }
1139
// Fills allScaledTs_: for every geodesic i, each coefficient allTs_[i][j] is
// multiplied by the distance between the two extremities of that geodesic
// (interpolations of the barycenter at t = 0.0 and t = 1.0). When a second
// input exists (trees2NodeCorr_ non-empty) the distance of the second-input
// extremities is mixed in through mixDistances.
// NOTE(review): the signature line (1141) is missing from this listing;
// execute() calls `computeGeodesicExtremities<dataType>()`, which presumably
// is this function - confirm.
// NOTE(review): allTs_[0] is read without an emptiness check; this looks like
// it relies on at least one geodesic having been allocated - confirm.
1140 template <class dataType>
1142 allScaledTs_.resize(allTs_.size(), std::vector<double>(allTs_[0].size()));
// Typed copies of the stored (double) barycenters.
1143 ftm::MergeTree<dataType> barycenter, barycenter2;
1144 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(barycenter_, barycenter);
1145 if(trees2NodeCorr_.size() != 0)
1146 ttk::ftm::mergeTreeDoubleToTemplate<dataType>(
1147 barycenterInput2_, barycenter2);
// One loop iteration per geodesic; the loop is an OpenMP parallel-for when
// TTK_ENABLE_OPENMP is defined and parallelize_ is true.
1148#ifdef TTK_ENABLE_OPENMP
1149#pragma omp parallel for schedule(dynamic) \
1150 num_threads(this->threadNumber_) if(parallelize_)
1151#endif
1152 for(unsigned int i = 0; i < numberOfGeodesics_; ++i) {
1153 ftm::MergeTree<dataType> extremityV1, extremityV2;
1154 getInterpolation<dataType>(
1155 barycenter, vS_[i], v2s_[i], 0.0, extremityV1);
1156 getInterpolation<dataType>(
1157 barycenter, vS_[i], v2s_[i], 1.0, extremityV2);
1158 // Get distance
// NOTE(review): the callee name (line 1160) is missing from this listing;
// presumably computeOneDistance, as in the second-input case below - confirm.
1159 dataType distance;
1161 extremityV1, extremityV2, distance, true, useDoubleInput_);
1162 if(trees2NodeCorr_.size() != 0) {
1163 ftm::MergeTree<dataType> extremity2V1, extremity2V2;
1164 getInterpolation<dataType>(
1165 barycenter2, trees2Vs_[i], trees2V2s_[i], 0.0, extremity2V1);
1166 getInterpolation<dataType>(
1167 barycenter2, trees2Vs_[i], trees2V2s_[i], 1.0, extremity2V2);
1168 // Get distance
1169 dataType distance2;
1170 computeOneDistance(extremity2V1, extremity2V2, distance2, true,
1171 useDoubleInput_, false);
1172 distance = mixDistances(distance, distance2);
1173 }
// Scale the coefficients of geodesic i by the extremity distance.
1174 for(unsigned int j = 0; j < allTs_[i].size(); ++j)
1175 allScaledTs_[i][j] = allTs_[i][j] * distance;
1176 }
1177 }
1178
// Two-argument overload taking the barycenter and the input trees.
// NOTE(review): this listing drops the name line (1180) and the whole body
// (lines 1183-1185). execute() calls
// `computeBranchesCorrelationMatrix<dataType>(barycenter, trees)`, which
// presumably resolves to this overload - confirm against the real header.
1179 template <class dataType>
1181 ftm::MergeTree<dataType> &barycenter,
1182 std::vector<ftm::MergeTree<dataType>> &trees) {
1186 }
1187
1188 // ----------------------------------------
1189 // Message
1190 // ----------------------------------------
1191 template <class dataType>
1193 std::vector<ftm::MergeTree<dataType>> &trees,
1194 ftm::MergeTree<dataType> &barycenter2,
1195 std::vector<ftm::MergeTree<dataType>> &trees2,
1196 int geodesicNumber,
1197 double globalVariance) {
1198 bool printOriginalVariances = false;
1199 bool printSurfaceVariance = false;
1200 bool printTVariances = true;
1201
1202 if(printOriginalVariances) {
1203 // Variance
1204 double variance = computeExplainedVariance<dataType>(
1205 barycenter, trees, vS_[geodesicNumber], v2s_[geodesicNumber],
1206 allTs_[geodesicNumber]);
1207 double variancePercent = variance / globalVariance * 100.0;
1208 std::stringstream ssVariance, ssCumul;
1209 ssVariance << "Variance explained : "
1210 << round(variancePercent * 100.0) / 100.0 << " %";
1211 printMsg(ssVariance.str());
1212
1213 // Cumul Variance
1214 cumulVariance_ += variance;
1215 double cumulVariancePercent = cumulVariance_ / globalVariance * 100.0;
1216 ssCumul << "Cumulative explained variance : "
1217 << round(cumulVariancePercent * 100.0) / 100.0 << " %";
1218 printMsg(ssCumul.str());
1219 }
1220
1221 if(printSurfaceVariance) {
1222 // Surface Variance
1223 double surfaceVariance = computeSurfaceExplainedVariance<dataType>(
1224 barycenter, trees, vS_, v2s_, allTs_);
1225 double surfaceVariancePercent
1226 = surfaceVariance / globalVariance * 100.0;
1227 std::stringstream ssSurface;
1228 ssSurface << "Surface Variance explained : "
1229 << round(surfaceVariancePercent * 100.0) / 100.0 << " %";
1230 printMsg(ssSurface.str());
1231 }
1232
1233 if(printTVariances) {
1234 // T-Variance
1235 double tVariance;
1236 if(trees2.size() != 0) {
1237 tVariance = computeExplainedVarianceT(
1238 barycenter, vS_[geodesicNumber], v2s_[geodesicNumber], barycenter2,
1239 trees2Vs_[geodesicNumber], trees2V2s_[geodesicNumber],
1240 allTs_[geodesicNumber]);
1241 } else
1242 tVariance = computeExplainedVarianceT(barycenter, vS_[geodesicNumber],
1243 v2s_[geodesicNumber],
1244 allTs_[geodesicNumber]);
1245 double tVariancePercent = tVariance / globalVariance * 100.0;
1246 std::stringstream ssTVariance, ssCumulT;
1247 ssTVariance << "Explained T-Variance : "
1248 << round(tVariancePercent * 100.0) / 100.0 << " %";
1249 printMsg(ssTVariance.str());
1250
1251 // Cumul T-Variance
1252 cumulTVariance_ += tVariance;
1253 double cumulTVariancePercent = cumulTVariance_ / globalVariance * 100.0;
1254 ssCumulT << "Cumulative explained T-Variance : "
1255 << round(cumulTVariancePercent * 100.0) / 100.0 << " %";
1256 printMsg(ssCumulT.str());
1257 }
1258 }
1259
1260 //----------------------------------------------------------------------------
1261 // Utils
1262 //----------------------------------------------------------------------------
// Declarations only; the definitions are not part of this listing.
// First overload: takes the stacked vectors vS / v2s.
// Second overload: additionally takes a single pair of vectors v / v2.
// doPrint defaults to true in both overloads.
// NOTE(review): the meaning of the returned double is not visible here;
// presumably an orthogonality cost - confirm against the definition.
1263 double
1264 verifyOrthogonality(std::vector<std::vector<std::vector<double>>> &vS,
1265 std::vector<std::vector<std::vector<double>>> &v2s,
1266 bool doPrint = true);
1267 double
1268 verifyOrthogonality(std::vector<std::vector<std::vector<double>>> &vS,
1269 std::vector<std::vector<std::vector<double>>> &v2s,
1270 std::vector<std::vector<double>> &v,
1271 std::vector<std::vector<double>> &v2,
1272 bool doPrint = true);
1273
// Declaration only; the definition is not part of this listing.
// Takes a vector of distances and returns a value of the same element type.
// NOTE(review): presumably the variance derived from these distances, as the
// name suggests - confirm against the definition.
1274 template <class dataType>
1275 dataType computeVarianceFromDistances(std::vector<dataType> &distances);
1276
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the function name and first parameter
// (1279) is missing here. printIterationVariances calls
// computeExplainedVariance<dataType>(barycenter, trees, v, v2, ts), which
// matches the remaining parameters - confirm.
// computeGlobalVariance defaults to false.
1277 template <class dataType>
1278 double
1280 std::vector<ftm::MergeTree<dataType>> &trees,
1281 std::vector<std::vector<double>> &v,
1282 std::vector<std::vector<double>> &v2,
1283 std::vector<double> &ts,
1284 bool computeGlobalVariance = false);
1285
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the return type, name and first parameter
// (1287) is missing here; presumably
// `double computeGlobalVariance(barycenter, trees)` - confirm.
1286 template <class dataType>
1288 std::vector<ftm::MergeTree<dataType>> &trees);
1289
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the return type and name (1291) is missing
// here. printIterationVariances calls
// computeSurfaceExplainedVariance<dataType>(barycenter, trees, vS, v2s, ts)
// with matching arguments and stores the result in a double - confirm.
1290 template <class dataType>
1292 ftm::MergeTree<dataType> &barycenter,
1293 std::vector<ftm::MergeTree<dataType>> &trees,
1294 std::vector<std::vector<std::vector<double>>> &vS,
1295 std::vector<std::vector<std::vector<double>>> &v2s,
1296 std::vector<std::vector<double>> &ts);
1297
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the return type, name and first parameter
// (1299) is missing here; presumably
// `void computeProjectionDistances(barycenter, ...)` - confirm.
// `distances` is passed by non-const reference; useDoubleInput defaults to
// false and isFirstInput defaults to true.
1298 template <class dataType>
1300 std::vector<std::vector<double>> &v,
1301 std::vector<std::vector<double>> &v2,
1302 std::vector<double> &ts,
1303 std::vector<double> &distances,
1304 bool useDoubleInput = false,
1305 bool isFirstInput = true);
1306
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the return type, name and first parameter
// (1308) is missing here. The parameters match the single-input call
// computeExplainedVarianceT(barycenter, v, v2, ts) made in
// printIterationVariances - confirm.
1307 template <class dataType>
1309 std::vector<std::vector<double>> &v,
1310 std::vector<std::vector<double>> &v2,
1311 std::vector<double> &ts);
1312
// Declaration only; the definition is not part of this listing.
// NOTE(review): the line carrying the return type, name and first parameter
// (1314) is missing here. The parameters match the second-input call
// computeExplainedVarianceT(barycenter, v, v2, barycenter2, trees2V,
// trees2V2, ts) made in printIterationVariances - confirm.
1313 template <class dataType>
1315 std::vector<std::vector<double>> &v,
1316 std::vector<std::vector<double>> &v2,
1317 ftm::MergeTree<dataType> &barycenter2,
1318 std::vector<std::vector<double>> &trees2V,
1319 std::vector<std::vector<double>> &trees2V2,
1320 std::vector<double> &ts);
1321
1322 // ----------------------------------------
1323 // Testing
1324 // ----------------------------------------
1325 template <class dataType>
1327 ftm::MergeTree<dataType> &mTree) {
1328 if(not normalizedWasserstein_) // isPersistenceDiagram
1329 return;
1330 ftm::FTMTree_MT *tree = &(mTree.tree);
1331 std::stringstream ss;
1332 auto birthDeathRoot = tree->getMergedRootBirthDeath<dataType>();
1333 auto birthRoot = std::get<0>(birthDeathRoot);
1334 auto deathRoot = std::get<1>(birthDeathRoot);
1335 bool found = false;
1336 for(unsigned int i = 0; i < tree->getNumberOfNodes(); ++i) {
1337 if(tree->isNodeAlone(i))
1338 continue;
1339 auto birthDeath = tree->getBirthDeath<dataType>(i);
1340 auto birth = std::get<0>(birthDeath);
1341 auto death = std::get<1>(birthDeath);
1342 if(birth < birthRoot or birth > deathRoot or death < birthRoot
1343 or death > deathRoot) {
1344 ss << tree->printNode2<dataType>(i).str() << std::endl;
1345 found = true;
1346 }
1347 }
1348 if(found) {
1349 ftm::FTMTree_MT *tree1 = &(mTree1.tree);
1350 printMsg("Tree1 root:");
1351 printMsg(tree1->printNode2<dataType>(tree1->getRoot()).str());
1352 printMsg("Tree root:");
1353 printMsg(tree->printNode2<dataType>(tree->getRoot()).str());
1354 printMsg("Tree merged root:");
1355 printMsg(tree->printMergedRoot<dataType>().str());
1356 printMsg("Bad pairs:");
1357 printMsg(ss.str());
1358 printErr("[computePrincipalGeodesics] tree root is not min max.");
1359 }
1360 }
1361 }; // MergeTreePrincipalGeodesics class
1362} // namespace ttk
1363
#define ENERGY_COMPARISON_TOLERANCE
int threadNumber_
Definition: BaseClass.h:95
Minimalist debugging class.
Definition: Debug.h:88
void setDebugMsgPrefix(const std::string &prefix)
Definition: Debug.h:364
int printMsg(const std::string &msg, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
Definition: Debug.h:118
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition: Debug.h:149
std::vector< std::vector< int > > trees2NodeCorr_
void computeOneDistance(ftm::MergeTree< dataType > &tree1, ftm::MergeTree< dataType > &tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool isCalled=false, bool useDoubleInput=false, bool isFirstInput=true)
void getMatchingMatrix(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< ftm::idNode > > &matchingMatrix)
void computeBranchesCorrelationMatrix(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings, std::vector< std::vector< double > > &allTs, std::vector< std::vector< double > > &branchesCorrelationMatrix, std::vector< std::vector< double > > &persCorrelationMatrix)
void getMatchingVector(ftm::MergeTree< dataType > &barycenter, ftm::MergeTree< dataType > &tree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matchings, std::vector< ftm::idNode > &matchingVector)
double mixDistances(dataType distance1, dataType distance2)
std::vector< std::vector< int > > treesNodeCorr_
Definition: MergeTreeBase.h:60
std::vector< std::vector< double > > allTreesTs_
std::vector< std::vector< std::vector< double > > > v2s_
std::vector< std::vector< std::vector< double > > > trees2V2s_
void getMultiInterpolation(ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double * > > &vS, std::vector< std::vector< double * > > &v2s, size_t vSize, std::vector< double > &ts, ftm::MergeTree< dataType > &interpolated, bool transposeVector=true)
std::vector< std::vector< std::vector< double > > > vS_
std::vector< std::vector< double > > allTs_
std::vector< std::vector< double > > branchesCorrelationMatrix_
std::vector< std::vector< double > > persCorrelationMatrix_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings2_
std::vector< std::vector< std::vector< double > > > trees2Vs_
std::vector< std::vector< double > > allScaledTs_
dataType computeReconstructionError(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &inputTrees, std::vector< std::vector< double * > > &vS, std::vector< std::vector< double * > > &v2s, size_t vSize, std::vector< std::vector< double > > &allTreesTs, std::vector< double > &reconstructionErrors, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, bool transposeVector=true)
void callGramSchmidt(std::vector< std::vector< double > > &vS, std::vector< double > &v, std::vector< double > &newV)
double verifyOrthogonality(std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, bool doPrint=true)
void computeProjectionDistances(ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< double > &ts, std::vector< double > &distances, bool useDoubleInput=false, bool isFirstInput=true)
void orthogonalProjection(std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s)
void copyMinMaxPairVector(std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2)
void trueGeneralizedGeodesicProjection(std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2)
double computeExplainedVariance(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< double > &ts, bool computeGlobalVariance=false)
double projectionStep(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, ftm::MergeTree< dataType > &barycenter2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, bool useSecondInput, unsigned int noProjectionStep)
void computeBranchesCorrelationMatrix(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees)
void computePrincipalGeodesics(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< ftm::MergeTree< dataType > > &trees2)
bool updateStep(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings2, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, std::vector< double > &ts, std::vector< std::vector< double > > &allTreesTs)
void assignmentStep(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, std::vector< std::vector< double > > &allTreesTs, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings2, std::vector< double > &ts, std::vector< double > &distances)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< ftm::MergeTree< dataType > > &trees2)
dataType computeVarianceFromDistances(std::vector< dataType > &distances)
bool updateClosedFormStep(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings2, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, std::vector< double > &ts, std::vector< std::vector< double > > &allTreesTs)
double regularizerCost(std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2)
double computeSurfaceExplainedVariance(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< double > > &ts)
void assignmentImpl(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, std::vector< std::vector< double > > &allTreesTs, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings2, std::vector< double > &ts, std::vector< double > &distances)
void printIterationVariances(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, int geodesicNumber, double globalVariance)
double orthogonalCost(std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2)
std::vector< std::vector< double > > inputToGeodesicsDistances_
void verifyMinMaxPair(ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree)
void initVectors(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2, std::vector< std::vector< double > > &trees2V1, std::vector< std::vector< double > > &trees2V2)
void manageIndividualTs(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< double > &ts, std::vector< std::vector< double > > &allTreesTs, std::vector< ftm::MergeTree< dataType > > &allInterpolated, std::vector< bool > &isUniform, std::vector< std::vector< double > > &tss, unsigned int &noUniform, bool &foundAllUniform)
double computeExplainedVarianceT(ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< double > &ts)
double barycentricProjection(ftm::MergeTree< dataType > &barycenter, ftm::MergeTree< dataType > &extremity, std::vector< std::vector< double > > &v, bool isV1, bool useDoubleInput=false, bool isFirstInput=true)
double projectionCost(std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, double optMapCost)
void computePrincipalGeodesic(unsigned int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2)
double computeGlobalVariance(ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees)
double optimalMappingSetProjection(ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, ftm::MergeTree< dataType > &barycenter2, std::vector< std::vector< double > > &trees2V, std::vector< std::vector< double > > &trees2V2, bool useSecondInput=false)
bool convergenceStep(std::vector< double > &distances, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, dataType &oldFrechetEnergy, dataType &minFrechetEnergy, int &cptBlocked, bool &converged, double optMapCost)
void updateClosedForm(int geodesicNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< ftm::MergeTree< dataType > > &allInterpolated, std::vector< std::vector< double > > &v, std::vector< std::vector< double > > &v2, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< std::vector< double > > &tss, std::vector< std::vector< double > > &vR, std::vector< std::vector< double > > &vR2, std::vector< bool > &isUniform)
double getElapsedTime()
Definition: Timer.h:15
idNode getNumberOfNodes() const
Definition: FTMTree_MT.h:389
std::tuple< dataType, dataType > getBirthDeath(idNode nodeId)
std::stringstream printNode2(idNode nodeId, bool doPrint=true)
std::tuple< dataType, dataType > getMergedRootBirthDeath()
bool isNodeAlone(idNode nodeId)
std::stringstream printMergedRoot(bool doPrint=true)
std::stringstream printTreeStats(bool doPrint=true)
int scaleVector(const T *a, const T factor, T *out, const int &dimension=3)
Definition: Geometry.cpp:575
int addVectors(const T *a, const T *b, T *out, const int &dimension=3)
Definition: Geometry.cpp:538
T dotProductFlatten(const std::vector< std::vector< T > > &vA, const std::vector< std::vector< T > > &vB)
Definition: Geometry.cpp:394
int flattenMultiDimensionalVector(const std::vector< std::vector< T > > &a, std::vector< T > &out)
Definition: Geometry.cpp:670
bool isVectorUniform(const std::vector< T > &a)
Definition: Geometry.cpp:647
T magnitudeFlatten(const std::vector< std::vector< T > > &v)
Definition: Geometry.cpp:501
void transposeMatrix(const std::vector< std::vector< T > > &a, std::vector< std::vector< T > > &out)
Definition: Geometry.cpp:746
int unflattenMultiDimensionalVector(const std::vector< T > &a, std::vector< std::vector< T > > &out, const int &no_columns=2)
Definition: Geometry.cpp:690
T magnitude(const T *v, const int &dimension=3)
Definition: Geometry.cpp:491
T distanceFlatten(const std::vector< std::vector< T > > &p0, const std::vector< std::vector< T > > &p1)
Definition: Geometry.cpp:361
int multiAddVectorsFlatten(const std::vector< std::vector< std::vector< T > > > &a, const std::vector< std::vector< std::vector< T > > > &b, std::vector< std::vector< T > > &out)
Definition: Geometry.cpp:563
The Topology ToolKit.
std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > bestMatching2
std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > bestMatching
ftm::FTMTree_MT tree
Definition: FTMTree_MT.h:901