5using namespace torch::indexing;
13#ifdef TTK_ENABLE_TORCH
// Const accessor for the origin tree of the input basis (origin_, which
// setOrigin overwrites and initInputBasisOrigin computes).
// NOTE(review): the body is not visible in this view of the file; presumably
// it is `return origin_;` — confirm.
17const ttk::mtu::TorchMergeTree<float> &
18 ttk::MergeTreeNeuralLayer::getOrigin()
const {
// Const accessor for the origin tree of the output basis (originPrime_, used
// by outputBasisReconstruction).
// NOTE(review): body not visible here; presumably `return originPrime_;`.
22const ttk::mtu::TorchMergeTree<float> &
23 ttk::MergeTreeNeuralLayer::getOriginPrime()
const {
// Const accessor for the second-input origin tree of the input basis
// (origin2_, only meaningful when useDoubleInput_ is set).
// NOTE(review): body not visible here; presumably `return origin2_;`.
27const ttk::mtu::TorchMergeTree<float> &
28 ttk::MergeTreeNeuralLayer::getOrigin2()
const {
// Const accessor for the second-input origin tree of the output basis
// (origin2Prime_).
// NOTE(review): body not visible here; presumably `return origin2Prime_;`.
32const ttk::mtu::TorchMergeTree<float> &
33 ttk::MergeTreeNeuralLayer::getOrigin2Prime()
const {
// Const accessor for the axis vectors of the input basis (vSTensor_).
// NOTE(review): body not visible here; presumably `return vSTensor_;`.
37const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSTensor()
const {
41const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSPrimeTensor()
const {
42 return vSPrimeTensor_;
// Const accessor for the second-input axis vectors of the input basis
// (vS2Tensor_).
// NOTE(review): body not visible here; presumably `return vS2Tensor_;`.
45const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2Tensor()
const {
49const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2PrimeTensor()
const {
50 return vS2PrimeTensor_;
53void ttk::MergeTreeNeuralLayer::setOrigin(
54 const mtu::TorchMergeTree<float> &tmt) {
55 mtu::copyTorchMergeTree(tmt, origin_);
58void ttk::MergeTreeNeuralLayer::setOriginPrime(
59 const mtu::TorchMergeTree<float> &tmt) {
60 mtu::copyTorchMergeTree(tmt, originPrime_);
63void ttk::MergeTreeNeuralLayer::setOrigin2(
64 const mtu::TorchMergeTree<float> &tmt) {
65 mtu::copyTorchMergeTree(tmt, origin2_);
68void ttk::MergeTreeNeuralLayer::setOrigin2Prime(
69 const mtu::TorchMergeTree<float> &tmt) {
70 mtu::copyTorchMergeTree(tmt, origin2Prime_);
73void ttk::MergeTreeNeuralLayer::setVSTensor(
const torch::Tensor &vS) {
74 mtu::copyTensor(vS, vSTensor_);
77void ttk::MergeTreeNeuralLayer::setVSPrimeTensor(
const torch::Tensor &vS) {
78 mtu::copyTensor(vS, vSPrimeTensor_);
81void ttk::MergeTreeNeuralLayer::setVS2Tensor(
const torch::Tensor &vS) {
82 mtu::copyTensor(vS, vS2Tensor_);
85void ttk::MergeTreeNeuralLayer::setVS2PrimeTensor(
const torch::Tensor &vS) {
86 mtu::copyTensor(vS, vS2PrimeTensor_);
// Builds the merge-tree structure (nodes, arcs, pairing) of `originPrime`
// from the flat (birth, death) values stored in originPrime.tensor.
// Steps visible here:
//  - copy the tensor values to a CPU std::vector (2 scalars per pair);
//  - persistence diagram: attach every pair to pair 0, after swapping the
//    most persistent pair into slot 0;
//  - otherwise derive a nesting structure, either from pairwise interval
//    inclusion + createBalancedBDT, or from a pruned copy of `baseOrigin`;
//  - rebuild originPrime.mTree and its nodeCorr/parentsOri bookkeeping.
// NOTE(review): this view of the file is missing several lines (including a
// parameter between originPrime and baseOrigin — the caller passes
// isJoinTree<float>() there — and multiple statements); read with care.
92void ttk::MergeTreeNeuralLayer::initOutputBasisTreeStructure(
93 mtu::TorchMergeTree<float> &originPrime,
95 mtu::TorchMergeTree<float> &baseOrigin) {
// Read the scalars on the CPU: data_ptr on a device tensor is not host memory.
97 torch::Tensor originTensor = originPrime.tensor;
98 if(!originTensor.device().is_cpu())
99 originTensor = originTensor.cpu();
100 std::vector<float> scalarsVector(
101 originTensor.data_ptr<
float>(),
102 originTensor.data_ptr<
float>() + originTensor.numel());
// Each node of the output tree is one (birth, death) pair.
103 unsigned int noNodes = scalarsVector.size() / 2;
104 std::vector<std::vector<ftm::idNode>> childrenFinal(noNodes);
// Persistence diagram: flat structure, every other pair is a child of pair 0.
107 if(isPersistenceDiagram_) {
108 for(
unsigned int i = 2; i < scalarsVector.size(); i += 2)
109 childrenFinal[0].emplace_back(i / 2);
// Locate the pair with the largest death - birth and swap it into slot 0 so
// it becomes the root pair.
112 float maxPers = std::numeric_limits<float>::lowest();
113 unsigned int indMax = 0;
114 for(
unsigned int i = 0; i < scalarsVector.size(); i += 2) {
115 if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) {
116 maxPers = (scalarsVector[i + 1] - scalarsVector[i]);
121 float temp = scalarsVector[0];
122 scalarsVector[0] = scalarsVector[indMax];
123 scalarsVector[indMax] = temp;
124 temp = scalarsVector[1];
125 scalarsVector[1] = scalarsVector[indMax + 1];
126 scalarsVector[indMax + 1] = temp;
// Force every non-root pair to nest inside its reference pair.
129 for(
unsigned int i = 2; i < scalarsVector.size(); i += 2) {
131 adjustNestingScalars(scalarsVector, node, refNode);
// Merge tree: build the structure from interval inclusion unless copying the
// base structure is requested and the base has enough nodes.
134 if(not initOriginPrimeStructByCopy_
135 or (
int) noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) {
// Pair j is a candidate child of pair i when [birth_j, death_j] lies within
// [birth_i, death_i] (and symmetrically). O(noNodes^2) comparisons.
137 std::vector<std::vector<ftm::idNode>> parents(noNodes), children(noNodes);
138 for(
unsigned int i = 0; i < scalarsVector.size(); i += 2) {
139 for(
unsigned int j = i; j < scalarsVector.size(); j += 2) {
142 unsigned int iN = i / 2, jN = j / 2;
143 if(scalarsVector[i] <= scalarsVector[j]
144 and scalarsVector[i + 1] >= scalarsVector[j + 1]) {
146 parents[jN].emplace_back(iN);
147 children[iN].emplace_back(jN);
148 }
else if(scalarsVector[i] >= scalarsVector[j]
149 and scalarsVector[i + 1] <= scalarsVector[j + 1]) {
151 parents[iN].emplace_back(jN);
152 children[jN].emplace_back(iN);
156 createBalancedBDT(parents, children, scalarsVector, childrenFinal);
// Copy branch: prune a temporary tree down to noNodes pairs, then map its
// nodes (visited breadth-first) onto the pairs ordered by argsort of
// persistence (death - birth).
161 keepMostImportantPairs<float>(&(mTreeTemp.tree), noNodes, useBD);
162 torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2});
163 torch::Tensor order = torch::argsort(
164 (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1,
166 std::vector<unsigned int> nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0);
167 unsigned int nodeNum = 1;
168 std::queue<ftm::idNode> queue;
169 queue.emplace(mTreeTemp.tree.getRoot());
170 while(!queue.empty()) {
173 std::vector<ftm::idNode> children;
174 mTreeTemp.tree.getChildren(node, children);
175 for(
auto &child : children) {
176 queue.emplace(child);
177 unsigned int tNode = nodeCorr[node];
178 nodeCorr[child] = order[nodeNum].item<
int>();
180 unsigned int tChild = nodeCorr[child];
181 childrenFinal[tNode].emplace_back(tChild);
182 adjustNestingScalars(scalarsVector, tChild, tNode);
// Swap birth and death of every pair.
// NOTE(review): the guarding condition is not visible here — presumably tied
// to the join/split tree flag; confirm.
192 for(
unsigned int i = 0; i < scalarsVector.size(); i += 2) {
193 float temp = scalarsVector[i];
194 scalarsVector[i] = scalarsVector[i + 1];
195 scalarsVector[i + 1] = temp;
// Rebuild the tree: pair k becomes nodes 2k and 2k+1, each set as the
// other's origin; nodeCorr maps node 2k back to pair index k.
201 originPrime.nodeCorr.clear();
202 originPrime.nodeCorr.assign(
203 scalarsVector.size(), std::numeric_limits<unsigned int>::max());
204 for(
unsigned int i = 0; i < scalarsVector.size(); i += 2) {
206 tree->makeNode(i + 1);
207 tree->getNode(i)->setOrigin(i + 1);
208 tree->getNode(i + 1)->setOrigin(i);
209 originPrime.nodeCorr[i] = (
unsigned int)(i / 2);
// Link each child pair's first node (child * 2) to its parent pair's node i.
211 for(
unsigned int i = 0; i < scalarsVector.size(); i += 2) {
212 unsigned int node = i / 2;
213 for(
auto &child : childrenFinal[node])
214 tree->makeSuperArc(child * 2, i);
216 mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri);
// Diagnostic dump when values exceed bigValuesThreshold_.
218 if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) {
219 std::stringstream ss;
220 ss << originPrime.mTree.tree.printPairsFromTree<
float>(
true).str()
222 ss <<
"isTreeHasBigValues(originPrime.mTree)" << std::endl;
223 ss <<
"pause" << std::endl;
// Initializes the output basis (originPrime_ / origin2Prime_ and their axis
// vectors) from the input basis. Each output origin is a Xavier-initialized
// linear map `w` applied to the corresponding input origin tensor, plus a
// constant 0.01 bias. It is then shifted, optionally overwritten with the
// most persistent pairs of `baseTensor`, given a tree structure, and
// projected.
// @param dim        number of rows of w (size of the output origin tensor).
// @param dim2       same for the second input (used when useDoubleInput_).
// @param baseTensor tensor whose (birth, death) pairs are copied when
//                   initOriginPrimeValuesByCopy_ is set.
// NOTE(review): several lines are missing from this view (e.g. the guard
// before origin2Size, the `w` lambda parameter, the `b` declaration).
229void ttk::MergeTreeNeuralLayer::initOutputBasis(
230 const unsigned int dim,
231 const unsigned int dim2,
232 const torch::Tensor &baseTensor) {
233 unsigned int originSize = origin_.tensor.sizes()[0];
234 unsigned int origin2Size = 0;
236 origin2Size = origin2_.tensor.sizes()[0];
// Builds one output origin `tmt` from input origin `baseTmt`.
240 auto initOutputBasisOrigin = [
this, &baseTensor](
242 mtu::TorchMergeTree<float> &tmt,
243 mtu::TorchMergeTree<float> &baseTmt) {
245 torch::nn::init::xavier_normal_(w);
246 torch::Tensor baseTmtTensor = baseTmt.tensor;
247 if(normalizedWasserstein_)
249 mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor,
false);
251 = torch::full({w.sizes()[0], 1}, 0.01,
252 torch::TensorOptions().device(baseTmtTensor.device()));
253 tmt.tensor = (torch::matmul(w, baseTmtTensor) + b);
255 mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor);
257 mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor);
// Optional: copy the noK most persistent pairs of baseTensor onto the pair
// slots of tmt. Slot 0 is always included; the others are the top-(noK - 1)
// non-root slots ranked by persistence.
// NOTE(review): noK == 0 would make topk(noK - 1) invalid — confirm callers
// never pass empty tensors.
259 if(initOriginPrimeValuesByCopy_) {
260 auto baseTensorDiag = baseTensor.reshape({-1, 2});
261 auto basePersDiag = (baseTensorDiag.index({Slice(), 1})
262 - baseTensorDiag.index({Slice(), 0}));
263 auto tmtTensorDiag = tmt.tensor.reshape({-1, 2});
264 auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1})
265 - tmtTensorDiag.index({Slice(1, None), 0}));
266 int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]);
267 auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))});
268 auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1;
270 = torch::zeros(1, torch::TensorOptions().device(indexes.device()));
271 indexes = torch::cat({zeros, indexes}).to(torch::kLong);
// Blend the copied values with the current ones by the randomness factor.
272 if(initOriginPrimeValuesByCopyRandomness_ != 0) {
273 topVal = (1 - initOriginPrimeValuesByCopyRandomness_) * topVal
274 + initOriginPrimeValuesByCopyRandomness_
275 * tmtTensorDiag.index({indexes});
277 tmtTensorDiag.index_put_({indexes}, topVal);
280 initOutputBasisTreeStructure(
281 tmt, baseTmt.mTree.tree.isJoinTree<
float>(), baseTmt);
282 if(normalizedWasserstein_)
284 mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor,
true);
286 interpolationProjection(tmt);
// The linear maps are created on the device of the corresponding input origin.
288 torch::Tensor w = torch::zeros(
289 {dim, originSize}, torch::TensorOptions().device(origin_.tensor.device()));
290 initOutputBasisOrigin(w, originPrime_, origin_);
292 if(useDoubleInput_) {
293 w2 = torch::zeros({dim2, origin2Size},
294 torch::TensorOptions().device(origin2_.tensor.device()));
295 initOutputBasisOrigin(w2, origin2Prime_, origin2_);
// The same maps are reused to derive the output axis vectors.
300 initOutputBasisVectors(w, w2);
// Derives the output-basis axis vectors by applying the linear maps `w` / `w2`
// to the input-basis axis vectors, then normalizes them relative to the
// output origins when normalizedWasserstein_ is set.
// NOTE(review): the second parameter (w2) and the useDoubleInput_ guards are
// not visible in this view — presumably the vS2 lines are guarded; confirm.
303void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(torch::Tensor &w,
305 vSPrimeTensor_ = torch::matmul(w, vSTensor_);
307 vS2PrimeTensor_ = torch::matmul(w2, vS2Tensor_);
308 if(normalizedWasserstein_) {
309 mtu::normalizeVectors(originPrime_.tensor, vSPrimeTensor_);
311 mtu::normalizeVectors(origin2Prime_.tensor, vS2PrimeTensor_);
// Initializes the output-basis axis vectors from fresh Xavier-normal maps of
// shape (dim, originSize) and (dim2, origin2Size), then delegates to the
// (w, w2) overload.
// NOTE(review): unlike initOutputBasis, `w` and `w2` are created with default
// TensorOptions (CPU). If vSTensor_ lives on CUDA, matmul would mix devices —
// consider passing .device(origin_.tensor.device()); confirm intended usage.
// NOTE(review): the dim2 parameter line and the origin2Size guard are missing
// from this view.
315void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(
unsigned int dim,
317 unsigned int originSize = origin_.tensor.sizes()[0];
318 unsigned int origin2Size = 0;
320 origin2Size = origin2_.tensor.sizes()[0];
321 torch::Tensor w = torch::zeros({dim, originSize});
322 torch::nn::init::xavier_normal_(w);
323 torch::Tensor w2 = torch::zeros({dim2, origin2Size});
324 torch::nn::init::xavier_normal_(w2);
325 initOutputBasisVectors(w, w2);
// Computes the input-basis origin(s) as barycenters of the input trees, fills
// `inputToBaryDistances` (mixed with the second-input distances when
// useDoubleInput_), then converts origin_ (and origin2_) to tensors.
// The pair limit is applied either inside computeOneBarycenter (random init)
// or afterwards via keepMostImportantPairs (non-random init).
// NOTE(review): several parameter lines (trees, matchings) are missing from
// this view.
328void ttk::MergeTreeNeuralLayer::initInputBasisOrigin(
331 double barycenterSizeLimitPercent,
332 unsigned int barycenterMaxNoPairs,
333 unsigned int barycenterMaxNoPairs2,
334 std::vector<double> &inputToBaryDistances,
335 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
337 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// Random initial barycenter index; seed 0 when deterministic_ is set.
// NOTE(review): uniform_int_distribution(0, size() - 1) underflows if
// treesToUse is empty — confirm callers guarantee at least one tree.
339 int barycenterInitIndex = -1;
340 if(initBarycenterRandom_) {
341 std::random_device rd;
342 std::default_random_engine rng(deterministic_ ? 0 : rd());
344 = std::uniform_int_distribution<>(0, treesToUse.size() - 1)(rng);
346 int maxNoPairs = (initBarycenterRandom_ ? barycenterMaxNoPairs : 0);
347 computeOneBarycenter<float>(treesToUse, origin_.mTree, baryMatchings,
348 inputToBaryDistances, barycenterSizeLimitPercent,
349 maxNoPairs, barycenterInitIndex,
350 initBarycenterOneIter_, useDoubleInput_,
true);
351 if(not initBarycenterRandom_ and barycenterMaxNoPairs > 0)
352 keepMostImportantPairs<float>(
353 &(origin_.mTree.tree), barycenterMaxNoPairs,
true);
// Second input: same procedure; the per-tree distances are then combined with
// mixDistances.
354 if(useDoubleInput_) {
355 std::vector<double> baryDistances2;
356 int maxNoPairs2 = (initBarycenterRandom_ ? barycenterMaxNoPairs2 : 0);
357 computeOneBarycenter<float>(trees2ToUse, origin2_.mTree, baryMatchings2,
358 baryDistances2, barycenterSizeLimitPercent,
359 maxNoPairs2, barycenterInitIndex,
360 initBarycenterOneIter_, useDoubleInput_,
false);
361 if(not initBarycenterRandom_ and barycenterMaxNoPairs2 > 0)
362 keepMostImportantPairs<float>(
363 &(origin2_.mTree.tree), barycenterMaxNoPairs2,
true);
364 for(
unsigned int i = 0; i < inputToBaryDistances.size(); ++i)
365 inputToBaryDistances[i]
366 = mixDistances(inputToBaryDistances[i], baryDistances2[i]);
// Tensor representation of the origin(s).
// NOTE(review): the guards on the .cuda() moves are not visible here —
// presumably a GPU flag; confirm.
369 mtu::getParentsVector(origin_.mTree, origin_.parentsOri);
370 mtu::mergeTreeToTorchTensor<float>(
371 origin_.mTree, origin_.tensor, origin_.nodeCorr, normalizedWasserstein_);
373 origin_.tensor = origin_.tensor.cuda();
374 if(useDoubleInput_) {
375 mtu::getParentsVector(origin2_.mTree, origin2_.parentsOri);
376 mtu::mergeTreeToTorchTensor<float>(origin2_.mTree, origin2_.tensor,
378 normalizedWasserstein_);
380 origin2_.tensor = origin2_.tensor.cuda();
// Initializes the axis vectors (vSTensor / vS2Tensor) of a basis and the
// per-tree coefficient tensors `allAlphasInit`.
//  - randomAxesInit_: axes are the pseudo-inverse of a Xavier-normal
//    (noVectors x originSize) matrix; alphas are drawn from randn.
//  - otherwise: vectors are computed one at a time (noVectors iterations),
//    then each tree's coefficients and its distance to the basis
//    (inputToAxesDistances[0][i]) are refreshed in a parallel loop.
// NOTE(review): several lines are missing from this view (the trees/trees2
// parameters, the per-vector helper call name, the `k` and `bestIndex`
// declarations, vS.emplace_back, and the `else` branches).
384void ttk::MergeTreeNeuralLayer::initInputBasisVectors(
385 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
386 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
389 unsigned int noVectors,
390 std::vector<torch::Tensor> &allAlphasInit,
391 std::vector<double> &inputToBaryDistances,
392 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
394 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
396 mtu::TorchMergeTree<float> &origin,
397 mtu::TorchMergeTree<float> &origin2,
398 torch::Tensor &vSTensor,
399 torch::Tensor &vS2Tensor,
400 bool useInputBasis) {
401 if(randomAxesInit_) {
402 auto initRandomAxes = [&noVectors](mtu::TorchMergeTree<float> &originT,
403 torch::Tensor &axes) {
404 torch::Tensor w = torch::zeros({noVectors, originT.tensor.sizes()[0]});
405 torch::nn::init::xavier_normal_(w);
406 axes = torch::linalg_pinv(w);
408 initRandomAxes(origin, vSTensor);
410 vSTensor = vSTensor.cuda();
411 if(useDoubleInput_) {
412 initRandomAxes(origin2, vS2Tensor);
414 vS2Tensor = vS2Tensor.cuda();
416#ifdef TTK_ENABLE_OPENMP
417#pragma omp parallel for schedule(dynamic) \
418 num_threads(this->threadNumber_) if(parallelize_)
420 for(
unsigned int i = 0; i < trees.size(); ++i)
421 allAlphasInit[i] = torch::randn({noVectors, 1});
// Callback handed to the per-vector routine to project newly initialized
// vectors; only _v and _vS are used (the rest are ttkNotUsed).
426 auto initializedVectorsProjection
429 std::vector<std::vector<double>> &_v,
430 std::vector<std::vector<double>> &
ttkNotUsed(_v2),
431 std::vector<std::vector<std::vector<double>>> &_vS,
432 std::vector<std::vector<std::vector<double>>> &
ttkNotUsed(_v2s),
434 std::vector<std::vector<double>> &
ttkNotUsed(_trees2V),
435 std::vector<std::vector<double>> &
ttkNotUsed(_trees2V2),
436 std::vector<std::vector<std::vector<double>>> &
ttkNotUsed(_trees2Vs),
437 std::vector<std::vector<std::vector<double>>> &
ttkNotUsed(_trees2V2s),
440 std::vector<double> scaledV, scaledVSi;
444 for(
unsigned int i = 0; i < _vS.size(); ++i) {
// A product within `tol` of +/-1 triggers the reset loop below.
450 if(prod <= -1.0 + tol or prod >= 1.0 - tol) {
452 for(
unsigned int j = 0; j < _v.size(); ++j)
453 for(
unsigned int k = 0; k < _v[j].size(); ++k)
462 std::vector<std::vector<double>> inputToAxesDistances;
463 std::vector<std::vector<std::vector<double>>> vS, v2s, trees2Vs, trees2V2s;
464 std::stringstream ss;
// One axis vector per iteration.
465 for(
unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) {
467 ss <<
"Compute vectors " << vecNum;
469 std::vector<std::vector<double>> v1, v2, trees2V1, trees2V2;
470 int newVectorOffset = 0;
471 bool projectInitializedVectors =
true;
473 vecNum, origin.mTree, trees, origin2.mTree, trees2, v1, v2, trees2V1,
474 trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings,
475 baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s,
476 projectInitializedVectors, initializedVectorsProjection);
478 v2s.emplace_back(v2);
479 trees2Vs.emplace_back(trees2V1);
480 trees2V2s.emplace_back(trees2V2);
483 ss <<
"bestIndex = " << bestIndex;
488 inputToAxesDistances.resize(1, std::vector<double>(trees.size()));
// bestIndex == -1 presumably means the new vector did not come from an input
// tree; with normalizedWasserstein_ it is normalized.
// NOTE(review): trees2Vs is indexed with vS.size() - 1 — assumes vS and
// trees2Vs grow in lockstep; confirm.
489 if(bestIndex == -1 and normalizedWasserstein_) {
490 mtu::normalizeVectors(origin, vS[vS.size() - 1]);
492 mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]);
494 mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor);
496 vSTensor = vSTensor.cuda();
497 if(useDoubleInput_) {
498 mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor);
500 vS2Tensor = vS2Tensor.cuda();
// Placeholders passed when there is no second input.
502 mtu::TorchMergeTree<float> dummyTmt;
503 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
505#ifdef TTK_ENABLE_OPENMP
506#pragma omp parallel for schedule(dynamic) \
507 num_threads(this->threadNumber_) if(parallelize_)
509 for(
unsigned int i = 0; i < trees.size(); ++i) {
510 auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2[i]);
// Non-Euclidean init: append a new coefficient (1, or 0 when bestIndex == -1)
// and re-run the full assignment for tree i.
511 if(not euclideanVectorsInit_) {
513 auto newAlpha = torch::ones({1, 1});
514 if(bestIndex == -1) {
515 newAlpha = torch::zeros({1, 1});
517 allAlphasInit[i] = (allAlphasInit[i].defined()
518 ? torch::cat({allAlphasInit[i], newAlpha})
520 torch::Tensor bestAlphas;
521 bool isCalled =
true;
522 inputToAxesDistances[0][i]
523 = assignmentOneData(tmTrees[i], tmt2ToUse, k, allAlphasInit[i],
524 bestAlphas, isCalled, useInputBasis);
525 allAlphasInit[i] = bestAlphas.detach();
// Other branch: solve the coefficients directly from the barycenter matchings
// (computeAlphas), then evaluate the differentiable distance to the resulting
// interpolation.
527 auto &baryMatching2ToUse
528 = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]);
529 torch::Tensor alphas;
530 computeAlphas(tmTrees[i], origin, vSTensor, origin, baryMatchings[i],
531 tmt2ToUse, origin2, vS2Tensor, origin2,
532 baryMatching2ToUse, alphas);
533 mtu::TorchMergeTree<float> interpolated, interpolated2;
534 getMultiInterpolation(origin, vSTensor, alphas, interpolated);
536 getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
537 torch::Tensor tensorDist;
539 getDifferentiableDistanceFromMatchings(
540 interpolated, tmTrees[i], interpolated2, tmt2ToUse, baryMatchings[i],
541 baryMatching2ToUse, tensorDist, doSqrt);
542 inputToAxesDistances[0][i] = tensorDist.item<
double>();
543 allAlphasInit[i] = alphas.detach();
// Convenience overload: picks the input basis (origin_, vSTensor_, …) when
// useInputBasis is true, else the output basis (originPrime_,
// vSPrimeTensor_, …), and forwards to the explicit-basis overload.
// NOTE(review): the trees/trees2/matchings parameter lines and the tail of
// the forwarding call are missing from this view.
549void ttk::MergeTreeNeuralLayer::initInputBasisVectors(
550 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
551 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
554 unsigned int noVectors,
555 std::vector<torch::Tensor> &allAlphasInit,
556 std::vector<double> &inputToBaryDistances,
557 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
559 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
561 bool useInputBasis) {
562 mtu::TorchMergeTree<float> &origin = (useInputBasis ? origin_ : originPrime_);
563 mtu::TorchMergeTree<float> &origin2
564 = (useInputBasis ? origin2_ : origin2Prime_);
565 torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_);
566 torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_);
568 initInputBasisVectors(tmTrees, tmTrees2, trees, trees2, noVectors,
569 allAlphasInit, inputToBaryDistances, baryMatchings,
570 baryMatchings2, origin, origin2, vSTensor, vS2Tensor,
574void ttk::MergeTreeNeuralLayer::requires_grad(
const bool requireGrad) {
575 origin_.tensor.requires_grad_(requireGrad);
576 originPrime_.tensor.requires_grad_(requireGrad);
577 vSTensor_.requires_grad_(requireGrad);
578 vSPrimeTensor_.requires_grad_(requireGrad);
579 if(useDoubleInput_) {
580 origin2_.tensor.requires_grad_(requireGrad);
581 origin2Prime_.tensor.requires_grad_(requireGrad);
582 vS2Tensor_.requires_grad_(requireGrad);
583 vS2PrimeTensor_.requires_grad_(requireGrad);
587void ttk::MergeTreeNeuralLayer::cuda() {
588 origin_.tensor = origin_.tensor.cuda();
589 originPrime_.tensor = originPrime_.tensor.cuda();
590 vSTensor_ = vSTensor_.cuda();
591 vSPrimeTensor_ = vSPrimeTensor_.cuda();
592 if(useDoubleInput_) {
593 origin2_.tensor = origin2_.tensor.cuda();
594 origin2Prime_.tensor = origin2Prime_.tensor.cuda();
595 vS2Tensor_ = vS2Tensor_.cuda();
596 vS2PrimeTensor_ = vS2PrimeTensor_.cuda();
603void ttk::MergeTreeNeuralLayer::interpolationDiagonalProjection(
604 mtu::TorchMergeTree<float> &interpolation) {
605 torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
606 if(interpolation.tensor.requires_grad())
607 diagTensor = diagTensor.detach();
609 torch::Tensor birthTensor = diagTensor.index({Slice(), 0});
610 torch::Tensor deathTensor = diagTensor.index({Slice(), 1});
612 torch::Tensor indexer = (birthTensor > deathTensor);
614 torch::Tensor allProj = (birthTensor + deathTensor) / 2.0;
615 allProj = allProj.index({indexer});
616 allProj = allProj.reshape({-1, 1});
618 diagTensor.index_put_({indexer}, allProj);
621void ttk::MergeTreeNeuralLayer::interpolationNestingProjection(
622 mtu::TorchMergeTree<float> &interpolation) {
623 torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
624 if(interpolation.tensor.requires_grad())
625 diagTensor = diagTensor.detach();
627 torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0});
628 torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1});
630 torch::Tensor birthIndexer = (birthTensor < 0);
631 torch::Tensor deathIndexer = (deathTensor < 0);
632 birthTensor.index_put_(
633 {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer})));
634 deathTensor.index_put_(
635 {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer})));
637 birthIndexer = (birthTensor > 1);
638 deathIndexer = (deathTensor > 1);
639 birthTensor.index_put_(
640 {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer})));
641 deathTensor.index_put_(
642 {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer})));
// Makes an interpolated tensor admissible. It first projects below-diagonal
// pairs onto the diagonal. With normalizedWasserstein_, it also bounds the
// non-root pairs to [0, 1]. It then rebuilds the merge tree from the tensor
// (warning if no root is found) and applies persistenceThresholding(0.001).
// For persistence diagrams, it warns when pairs are missing.
// NOTE(review): lines declaring interpolationNew and the if/else around the
// no-root warning are missing from this view. The missing-pairs warning is
// tagged "[getMultiInterpolation]" although it is emitted from this function.
645void ttk::MergeTreeNeuralLayer::interpolationProjection(
646 mtu::TorchMergeTree<float> &interpolation) {
647 interpolationDiagonalProjection(interpolation);
648 if(normalizedWasserstein_)
649 interpolationNestingProjection(interpolation);
652 bool noRoot = mtu::torchTensorToMergeTree<float>(
653 interpolation, normalizedWasserstein_, interpolationNew);
655 printWrn(
"[interpolationProjection] no root found");
658 persistenceThresholding<float>(&(interpolation.mTree.tree), 0.001);
660 if(isPersistenceDiagram_ and isThereMissingPairs(interpolation))
661 printWrn(
"[getMultiInterpolation] missing pairs");
664void ttk::MergeTreeNeuralLayer::getMultiInterpolation(
665 const mtu::TorchMergeTree<float> &origin,
666 const torch::Tensor &vS,
667 torch::Tensor &alphas,
668 mtu::TorchMergeTree<float> &interpolation) {
669 mtu::copyTorchMergeTree<float>(origin, interpolation);
670 interpolation.tensor = origin.tensor + torch::matmul(vS, alphas);
671 interpolationProjection(interpolation);
// Prepares the tensors of the least-squares problem solved by computeAlphas
// for one input:
//  - reorderedTreeTensor / deltaOrigin: `tree` values reordered to match
//    `interpolated` via `matching` (dataReorderingGivenMatching);
//  - deltaA: correction of the axis vectors for pairs matched to the
//    diagonal (tensorMatching == -1);
//  - originTensor_f / vSTensor_f: origin tensor and axes, unchanged.
// NOTE(review): the tail of the deltaA expression (presumably a division by
// 2, i.e. the pair midpoint) is missing from this view — confirm.
677void ttk::MergeTreeNeuralLayer::getAlphasOptimizationTensors(
678 mtu::TorchMergeTree<float> &tree,
679 mtu::TorchMergeTree<float> &origin,
680 torch::Tensor &vSTensor,
681 mtu::TorchMergeTree<float> &interpolated,
682 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
683 torch::Tensor &reorderedTreeTensor,
684 torch::Tensor &deltaOrigin,
685 torch::Tensor &deltaA,
686 torch::Tensor &originTensor_f,
687 torch::Tensor &vSTensor_f) {
689 std::vector<int> tensorMatching;
690 mtu::getTensorMatching(interpolated, tree, matching, tensorMatching);
// Column mask of pairs with no match (index -1).
692 torch::Tensor indexes = torch::tensor(tensorMatching);
693 torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1});
695 dataReorderingGivenMatching(
696 origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin);
// View the axes as (noVectors, noPairs, 2), combine birth and death of each
// pair, duplicate over both coordinates, and keep only unmatched pairs.
699 deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2});
700 deltaA = (deltaA.index({Slice(), Slice(), 0})
701 + deltaA.index({Slice(), Slice(), 1}))
703 deltaA = torch::stack({deltaA, deltaA}, 2);
704 if(!deltaA.device().is_cpu())
705 projIndexer = projIndexer.to(deltaA.device());
706 deltaA = deltaA * projIndexer;
707 deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1);
710 originTensor_f = origin.tensor;
711 vSTensor_f = vSTensor;
// Solves for the coefficients `alphasOut` by linear least squares:
//   (vS - deltaA) * alphas ~= reorderedTree - origin + deltaOrigin
// stacking the second-input system under the first when useDoubleInput_.
// The "gelsd" LAPACK driver is CPU-only in torch::linalg_lstsq, hence the
// move to CPU and back to the original device.
// NOTE(review): the final `alphasOut.reshape({-1, 1});` discards its result
// (reshape is not in-place), so it is a no-op. lstsq on an (m, 1) right-hand
// side already yields (k, 1); if a reshape is intended, it should be
// `alphasOut = alphasOut.reshape({-1, 1});`.
// NOTE(review): the device-guard lines around the CPU moves are missing from
// this view.
714void ttk::MergeTreeNeuralLayer::computeAlphas(
715 mtu::TorchMergeTree<float> &tree,
716 mtu::TorchMergeTree<float> &origin,
717 torch::Tensor &vSTensor,
718 mtu::TorchMergeTree<float> &interpolated,
719 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
720 mtu::TorchMergeTree<float> &tree2,
721 mtu::TorchMergeTree<float> &origin2,
722 torch::Tensor &vS2Tensor,
723 mtu::TorchMergeTree<float> &interpolated2,
724 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
725 torch::Tensor &alphasOut) {
726 torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f,
728 getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching,
729 reorderedTreeTensor, deltaOrigin, deltaA,
730 originTensor_f, vSTensor_f);
// Second input: build its system and concatenate it row-wise with the first.
732 if(useDoubleInput_) {
733 torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f,
735 getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2,
736 matching2, reorderedTree2Tensor, deltaOrigin2,
737 deltaA2, origin2Tensor_f, vS2Tensor_f);
738 vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f});
739 deltaA = torch::cat({deltaA, deltaA2});
741 = torch::cat({reorderedTreeTensor, reorderedTree2Tensor});
742 originTensor_f = torch::cat({originTensor_f, origin2Tensor_f});
743 deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2});
746 torch::Tensor r_axes = vSTensor_f - deltaA;
747 torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin;
750 auto driver =
"gelsd";
751 bool is_cpu = r_axes.device().is_cpu();
752 auto device = r_axes.device();
754 r_axes = r_axes.cpu();
755 r_data = r_data.cpu();
758 = std::get<0>(torch::linalg_lstsq(r_axes, r_data, c10::nullopt, driver));
760 alphasOut = alphasOut.to(device);
762 alphasOut.reshape({-1, 1});
// Finds the coefficients that best reconstruct `tree` (and `tree2`) in the
// input basis (useInputBasis) or the output basis. It alternates between
// computing alphas from the current matchings (computeAlphas) and
// recomputing distances and matchings to the interpolation. It keeps the
// best distance, matchings and alphas. Alphas are re-randomized when an
// interpolation is empty, and a warning is printed after too many resets.
// The first iteration (i == 0) is never recorded as best.
// NOTE(review): the loop header, the declarations of `k`, `isCalled`, `i`,
// `distance`, `distance2`, and the return statement are missing from this
// view.
765float ttk::MergeTreeNeuralLayer::assignmentOneData(
766 mtu::TorchMergeTree<float> &tree,
767 mtu::TorchMergeTree<float> &tree2,
769 torch::Tensor &alphasInit,
770 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
771 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
772 torch::Tensor &bestAlphas,
774 bool useInputBasis) {
775 mtu::TorchMergeTree<float> &origin = (useInputBasis ? origin_ : originPrime_);
776 mtu::TorchMergeTree<float> &origin2
777 = (useInputBasis ? origin2_ : origin2Prime_);
778 torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_);
779 torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_);
781 torch::Tensor alphas, oldAlphas;
782 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching, matching2;
783 float bestDistance = std::numeric_limits<float>::max();
784 mtu::TorchMergeTree<float> interpolated, interpolated2;
787 alphasInit = torch::randn_like(alphas);
790 unsigned int noUpdate = 0;
791 unsigned int noReset = 0;
// Start from alphasInit when provided, else from zeros (one per axis vector).
794 if(alphasInit.defined())
797 alphas = torch::zeros({vSTensor.sizes()[1], 1});
799 computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2,
800 origin2, vS2Tensor, interpolated2, matching2, alphas);
// Unchanged alphas between iterations signal convergence.
801 if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas)
806 mtu::copyTensor(alphas, oldAlphas);
807 getMultiInterpolation(origin, vSTensor, alphas, interpolated);
809 getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
// Empty interpolation: reset handling (warning after 100 resets).
810 if(interpolated.mTree.tree.getRealNumberOfNodes() == 0
812 and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) {
815 printWrn(
"[assignmentOneData] noReset >= 100");
820 computeOneDistance<float>(interpolated.mTree, tree.mTree, matching,
821 distance, isCalled, useDoubleInput_);
822 if(useDoubleInput_) {
824 computeOneDistance<float>(interpolated2.mTree, tree2.mTree, matching2,
825 distance2, isCalled, useDoubleInput_,
false);
826 distance = mixDistances<float>(distance, distance2);
828 if(distance < bestDistance and i != 0) {
830 bestMatching = matching;
831 bestMatching2 = matching2;
838 printErr(
"[assignmentOneData] noUpdate == 0");
// Overload that discards the best matchings: it allocates local matching
// vectors and forwards to the full assignmentOneData, returning its result.
// NOTE(review): the `k`/`isCalled` parameter lines and the bestMatching2
// declaration continuation are missing from this view.
842float ttk::MergeTreeNeuralLayer::assignmentOneData(
843 mtu::TorchMergeTree<float> &tree,
844 mtu::TorchMergeTree<float> &tree2,
846 torch::Tensor &alphasInit,
847 torch::Tensor &bestAlphas,
849 bool useInputBasis) {
850 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> bestMatching,
852 return assignmentOneData(tree, tree2, k, alphasInit, bestMatching,
853 bestMatching2, bestAlphas, isCalled, useInputBasis);
// Reconstructs output tree(s) from coefficients `alphas` in the output basis.
// The coefficients are optionally passed through activation(). During
// training with dropout_ != 0, a torch::nn::Dropout module is built. The
// result is interpolated from originPrime_/vSPrimeTensor_ (and
// origin2Prime_/vS2PrimeTensor_ for out2).
// NOTE(review): the activate/train parameter lines, the dropout application
// and the useDoubleInput_ guard are missing from this view.
856void ttk::MergeTreeNeuralLayer::outputBasisReconstruction(
857 torch::Tensor &alphas,
858 mtu::TorchMergeTree<float> &out,
859 mtu::TorchMergeTree<float> &out2,
864 torch::Tensor act = (activate ? activation(alphas) : alphas);
865 if(dropout_ != 0.0 and train) {
866 torch::nn::Dropout model(torch::nn::DropoutOptions().p(dropout_));
869 getMultiInterpolation(originPrime_, vSPrimeTensor_, act, out);
871 getMultiInterpolation(origin2Prime_, vS2PrimeTensor_, act, out2);
// Forward pass of the layer for one input:
//  1. project `tree` (and `tree2`) onto the input basis (assignmentOneData);
//  2. reconstruct in the output basis with activation enabled.
// Both steps repeat until the output tree(s) are non-empty. Between
// attempts, alphasInit is re-randomized, and a warning is printed after
// many resets.
// NOTE(review): the k/train/bestDistance parameter lines, the reset counter
// and the return statement are missing from this view.
874bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree<float> &tree,
875 mtu::TorchMergeTree<float> &tree2,
877 torch::Tensor &alphasInit,
878 mtu::TorchMergeTree<float> &out,
879 mtu::TorchMergeTree<float> &out2,
880 torch::Tensor &bestAlphas,
883 bool goodOutput =
false;
885 while(not goodOutput) {
886 bool isCalled =
true;
888 = assignmentOneData(tree, tree2, k, alphasInit, bestAlphas, isCalled);
889 outputBasisReconstruction(bestAlphas, out, out2,
true, train);
// An output is acceptable when it (and out2, if used) has at least one node.
890 goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0
891 and (not useDoubleInput_
892 or out2.mTree.tree.getRealNumberOfNodes() != 0));
896 printWrn(
"[forwardOneLayer] noReset >= 100");
899 alphasInit = torch::randn_like(alphasInit);
// Overload of forward without the bestDistance output: forwards to the full
// version with a local distance.
// NOTE(review): the k/train parameter lines, the local bestDistance
// declaration and the head of the return statement are missing from this
// view.
905bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree<float> &tree,
906 mtu::TorchMergeTree<float> &tree2,
908 torch::Tensor &alphasInit,
909 mtu::TorchMergeTree<float> &out,
910 mtu::TorchMergeTree<float> &out2,
911 torch::Tensor &bestAlphas,
915 tree, tree2, k, alphasInit, out, out2, bestAlphas, bestDistance, train);
921void ttk::MergeTreeNeuralLayer::projectionStep() {
922 auto projectTree = [
this](mtu::TorchMergeTree<float> &tmt) {
923 interpolationProjection(tmt);
924 tmt.tensor = tmt.tensor.detach();
925 tmt.tensor.requires_grad_(
true);
927 projectTree(origin_);
928 projectTree(originPrime_);
929 if(useDoubleInput_) {
930 projectTree(origin2_);
931 projectTree(origin2Prime_);
// Copies all layer parameters between the layer and the given external
// objects. When `get` is true, the layer's members are the source and the
// arguments the destination (export). Otherwise the arguments are copied into
// the layer (import). Second-input parameters are copied only when
// useDoubleInput_ is set.
// NOTE(review): the vS, vS2 and `get` parameter lines are missing from this
// view.
938void ttk::MergeTreeNeuralLayer::copyParams(
939 mtu::TorchMergeTree<float> &origin,
940 mtu::TorchMergeTree<float> &originPrime,
942 torch::Tensor &vSPrime,
943 mtu::TorchMergeTree<float> &origin2,
944 mtu::TorchMergeTree<float> &origin2Prime,
946 torch::Tensor &vS2Prime,
// Source side: members when get, arguments otherwise.
950 mtu::TorchMergeTree<float> &srcOrigin = (get ? origin_ : origin);
951 mtu::TorchMergeTree<float> &srcOriginPrime
952 = (get ? originPrime_ : originPrime);
953 torch::Tensor &srcVS = (get ? vSTensor_ : vS);
954 torch::Tensor &srcVSPrime = (get ? vSPrimeTensor_ : vSPrime);
955 mtu::TorchMergeTree<float> &srcOrigin2 = (get ? origin2_ : origin2);
956 mtu::TorchMergeTree<float> &srcOrigin2Prime
957 = (get ? origin2Prime_ : origin2Prime);
958 torch::Tensor &srcVS2 = (get ? vS2Tensor_ : vS2);
959 torch::Tensor &srcVS2Prime = (get ? vS2PrimeTensor_ : vS2Prime);
// Destination side: the opposite choice.
962 mtu::TorchMergeTree<float> &dstOrigin = (!get ? origin_ : origin);
963 mtu::TorchMergeTree<float> &dstOriginPrime
964 = (!get ? originPrime_ : originPrime);
965 torch::Tensor &dstVS = (!get ? vSTensor_ : vS);
966 torch::Tensor &dstVSPrime = (!get ? vSPrimeTensor_ : vSPrime);
967 mtu::TorchMergeTree<float> &dstOrigin2 = (!get ? origin2_ : origin2);
968 mtu::TorchMergeTree<float> &dstOrigin2Prime
969 = (!get ? origin2Prime_ : origin2Prime);
970 torch::Tensor &dstVS2 = (!get ? vS2Tensor_ : vS2);
971 torch::Tensor &dstVS2Prime = (!get ? vS2PrimeTensor_ : vS2Prime);
974 mtu::copyTorchMergeTree(srcOrigin, dstOrigin);
975 mtu::copyTorchMergeTree(srcOriginPrime, dstOriginPrime);
976 mtu::copyTensor(srcVS, dstVS);
977 mtu::copyTensor(srcVSPrime, dstVSPrime);
978 if(useDoubleInput_) {
979 mtu::copyTorchMergeTree(srcOrigin2, dstOrigin2);
980 mtu::copyTorchMergeTree(srcOrigin2Prime, dstOrigin2Prime);
981 mtu::copyTensor(srcVS2, dstVS2);
982 mtu::copyTensor(srcVS2Prime, dstVS2Prime);
// Moves pair `node` (scalarsVector[2*node], scalarsVector[2*node+1]) so that
// it lies strictly inside reference pair `refNode`. The pair is first
// translated (down if its death exceeds the bound, up if its birth is below
// the bound), then each end is clamped to the bounds.
// getPrecValue(v) scales v by a relative 1e-6 so that, for v != 0, the result
// is slightly larger than v; with opp == true it is slightly smaller. The
// bounds are therefore just inside [birth, death] of the reference.
// NOTE(review): the upward branch shifts by getPrecValue(diff) (slightly more
// than diff) while the downward branch shifts by exactly diff — confirm the
// asymmetry is intended.
// NOTE(review): the parameter line (scalarsVector, node, refNode) is missing
// from this view.
986void ttk::MergeTreeNeuralLayer::adjustNestingScalars(
988 float birth = scalarsVector[refNode * 2];
989 float death = scalarsVector[refNode * 2 + 1];
990 auto getSign = [](
float v) {
return (v > 0 ? 1 : -1); };
991 auto getPrecValue = [&getSign](
float v,
bool opp =
false) {
992 return v * (1 + (opp ? -1 : 1) * getSign(v) * 1e-6);
// Translate the whole pair back inside the reference interval.
995 if(scalarsVector[node * 2 + 1] > getPrecValue(death,
true)) {
996 float diff = scalarsVector[node * 2 + 1] - getPrecValue(death,
true);
997 scalarsVector[node * 2] -= diff;
998 scalarsVector[node * 2 + 1] -= diff;
999 }
else if(scalarsVector[node * 2] < getPrecValue(birth)) {
1000 float diff = getPrecValue(birth) - scalarsVector[node * 2];
1001 scalarsVector[node * 2] += getPrecValue(diff);
1002 scalarsVector[node * 2 + 1] += getPrecValue(diff);
// Then clamp each end to the bounds (covers pairs longer than the reference).
1005 if(scalarsVector[node * 2] < getPrecValue(birth))
1006 scalarsVector[node * 2] = getPrecValue(birth);
1007 if(scalarsVector[node * 2 + 1] > getPrecValue(death,
true))
1008 scalarsVector[node * 2 + 1] = getPrecValue(death,
true);
1011void ttk::MergeTreeNeuralLayer::createBalancedBDT(
1012 std::vector<std::vector<ftm::idNode>> &parents,
1013 std::vector<std::vector<ftm::idNode>> &children,
1014 std::vector<float> &scalarsVector,
1015 std::vector<std::vector<ftm::idNode>> &childrenFinal) {
1017 unsigned int noNodes = scalarsVector.size() / 2;
1018 childrenFinal.resize(noNodes);
1019 int mtLevel = ceil(log(noNodes * 2) / log(2)) + 1;
1020 int bdtLevel = mtLevel - 1;
1021 int noDim = bdtLevel;
1024 std::vector<int> nodeLevels(noNodes, -1);
1025 std::queue<ftm::idNode> queueLevels;
1026 std::vector<int> noChildDone(noNodes, 0);
1027 for(
unsigned int i = 0; i < children.size(); ++i) {
1028 if(children[i].size() == 0) {
1029 queueLevels.emplace(i);
1033 while(!queueLevels.empty()) {
1036 for(
auto &parent : parents[node]) {
1037 ++noChildDone[parent];
1038 nodeLevels[parent] = std::max(nodeLevels[parent], nodeLevels[node] + 1);
1039 if(noChildDone[parent] >= (
int)children[parent].size())
1040 queueLevels.emplace(parent);
1045 auto sortChildren = [
this, &parents, &scalarsVector, &noNodes](
1046 ftm::idNode nodeOrigin, std::vector<bool> &nodeDone,
1047 std::vector<std::vector<ftm::idNode>> &childrenT) {
1048 double refPers = scalarsVector[1] - scalarsVector[0];
1049 auto getRemaining = [&nodeDone](std::vector<ftm::idNode> &vec) {
1050 unsigned int remaining = 0;
1052 remaining += (not nodeDone[e]);
1055 std::vector<unsigned int> parentsRemaining(noNodes, 0),
1056 childrenRemaining(noNodes, 0);
1057 for(
auto &child : childrenT[nodeOrigin]) {
1058 parentsRemaining[child] = getRemaining(parents[child]);
1059 childrenRemaining[child] = getRemaining(childrenT[child]);
1062 threadNumber_, childrenT[nodeOrigin].
begin(), childrenT[nodeOrigin].
end(),
1064 double persI = scalarsVector[nodeI * 2 + 1] - scalarsVector[nodeI * 2];
1065 double persJ = scalarsVector[nodeJ * 2 + 1] - scalarsVector[nodeJ * 2];
1066 return parentsRemaining[nodeI] + childrenRemaining[nodeI]
1067 - persI / refPers * noNodes
1068 < parentsRemaining[nodeJ] + childrenRemaining[nodeJ]
1069 - persJ / refPers * noNodes;
1074 const auto findStructGivenDim =
1075 [&children, &noNodes, &nodeLevels](
1076 ftm::idNode _nodeOrigin,
int _dimToFound,
bool _searchMaxDim,
1077 std::vector<bool> &_nodeDone, std::vector<bool> &_dimFound,
1078 std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
1080 auto findStructGivenDimImpl =
1081 [&children, &noNodes, &nodeLevels](
1082 ftm::idNode nodeOrigin,
int dimToFound,
bool searchMaxDim,
1083 std::vector<bool> &nodeDone, std::vector<bool> &dimFound,
1084 std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
1085 auto &findStructGivenDimRef)
mutable {
1086 childrenFinalOut.resize(noNodes);
1088 int dim = (searchMaxDim ? dimToFound - 1 : 0);
1091 auto searchMaxDimReset = [&i, &dim, &nodeDone]() {
1094 unsigned int noDone = 0;
1095 for(
auto done : nodeDone)
1098 return noDone == nodeDone.size() - 1;
1100 while(i < children[nodeOrigin].size()) {
1101 auto child = children[nodeOrigin][i];
1103 if(nodeDone[child]) {
1106 if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
1107 if(searchMaxDimReset())
1115 childrenFinalOut[nodeOrigin].emplace_back(child);
1116 nodeDone[child] =
true;
1118 if(dimToFound <= 1 or searchMaxDim)
1123 std::vector<std::vector<ftm::idNode>> childrenFinalDim;
1124 std::vector<bool> nodeDoneDim;
1125 std::vector<bool> dimFoundDim(dim);
1127 if(nodeLevels[child] > dim) {
1128 nodeDoneDim = nodeDone;
1129 found = findStructGivenDimRef(child, dim,
false, nodeDoneDim,
1130 dimFoundDim, childrenFinalDim,
1131 findStructGivenDimRef);
1134 dimFound[dim] =
true;
1135 childrenFinalOut[nodeOrigin].emplace_back(child);
1136 for(
unsigned int j = 0; j < childrenFinalDim.size(); ++j)
1137 for(
auto &e : childrenFinalDim[j])
1138 childrenFinalOut[j].emplace_back(e);
1139 nodeDone[child] =
true;
1140 for(
unsigned int j = 0; j < nodeDoneDim.size(); ++j)
1141 nodeDone[j] = nodeDone[j] || nodeDoneDim[j];
1143 if(dim == dimToFound - 1 and not searchMaxDim)
1147 if(searchMaxDimReset())
1153 }
else if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
1156 if(searchMaxDimReset())
1165 return findStructGivenDimImpl(_nodeOrigin, _dimToFound, _searchMaxDim,
1166 _nodeDone, _dimFound, _childrenFinalOut,
1167 findStructGivenDimImpl);
1169 std::vector<bool> dimFound(noDim - 1,
false);
1170 std::vector<bool> nodeDone(noNodes,
false);
1171 for(
unsigned int i = 0; i < children.size(); ++i)
1172 sortChildren(i, nodeDone, children);
1175 findStructGivenDim(startNode, noDim,
true, nodeDone, dimFound, childrenFinal);
1178 const auto createStructGivenDim =
1179 [
this, &children, &noNodes, &findStructGivenDim, &nodeLevels](
1180 int _nodeOrigin,
int _dimToCreate, std::vector<bool> &_nodeDone,
1181 ftm::idNode &_structOrigin, std::vector<float> &_scalarsVectorOut,
1182 std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
1184 auto createStructGivenDimImpl =
1185 [
this, &children, &noNodes, &findStructGivenDim, &nodeLevels](
1186 int nodeOrigin,
int dimToCreate, std::vector<bool> &nodeDoneImpl,
1187 ftm::idNode &structOrigin, std::vector<float> &scalarsVectorOut,
1188 std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
1189 auto &createStructGivenDimRef)
mutable {
1194 int dimToFound = dimToCreate - 1;
1195 std::vector<std::vector<std::vector<ftm::idNode>>> childrenFinalT(2);
1196 std::array<ftm::idNode, 2> structOrigins;
1197 for(
unsigned int n = 0; n < 2; ++n) {
1199 for(
unsigned int i = 0; i < children[nodeOrigin].size(); ++i) {
1200 auto child = children[nodeOrigin][i];
1201 if(nodeDoneImpl[child])
1203 if(dimToFound != 0) {
1204 if(nodeLevels[child] > dimToFound) {
1205 std::vector<bool> dimFoundT(dimToFound,
false);
1206 childrenFinalT[n].clear();
1207 childrenFinalT[n].resize(noNodes);
1208 std::vector<bool> nodeDoneImplFind = nodeDoneImpl;
1209 found = findStructGivenDim(child, dimToFound,
false,
1210 nodeDoneImplFind, dimFoundT,
1216 structOrigins[n] = child;
1217 nodeDoneImpl[child] =
true;
1218 for(
unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
1219 for(
auto &e : childrenFinalT[n][j]) {
1220 childrenFinalOut[j].emplace_back(e);
1221 nodeDoneImpl[e] =
true;
1228 if(dimToFound <= 0) {
1229 structOrigins[n] = std::numeric_limits<ftm::idNode>::max();
1232 childrenFinalT[n].clear();
1233 childrenFinalT[n].resize(noNodes);
1234 createStructGivenDimRef(
1235 nodeOrigin, dimToFound, nodeDoneImpl, structOrigins[n],
1236 scalarsVectorOut, childrenFinalT[n], createStructGivenDimRef);
1237 for(
unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
1238 for(
auto &e : childrenFinalT[n][j]) {
1239 if(e == structOrigins[n])
1241 childrenFinalOut[j].emplace_back(e);
1247 if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max()
1248 and structOrigins[1] == std::numeric_limits<ftm::idNode>::max()) {
1249 structOrigin = std::numeric_limits<ftm::idNode>::max();
1252 bool firstIsParent =
true;
1253 if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max())
1254 firstIsParent =
false;
1255 else if(structOrigins[1] == std::numeric_limits<ftm::idNode>::max())
1256 firstIsParent =
true;
1257 else if(scalarsVectorOut[structOrigins[1] * 2 + 1]
1258 - scalarsVectorOut[structOrigins[1] * 2]
1259 > scalarsVectorOut[structOrigins[0] * 2 + 1]
1260 - scalarsVectorOut[structOrigins[0] * 2])
1261 firstIsParent =
false;
1262 structOrigin = (firstIsParent ? structOrigins[0] : structOrigins[1]);
1264 = (firstIsParent ? structOrigins[1] : structOrigins[0]);
1265 childrenFinalOut[nodeOrigin].emplace_back(structOrigin);
1266 if(modOrigin != std::numeric_limits<ftm::idNode>::max()) {
1267 childrenFinalOut[structOrigin].emplace_back(modOrigin);
1268 std::queue<std::array<ftm::idNode, 2>> queue;
1269 queue.emplace(std::array<ftm::idNode, 2>{modOrigin, structOrigin});
1270 while(!queue.empty()) {
1271 auto &nodeAndParent = queue.front();
1275 adjustNestingScalars(scalarsVectorOut, node, parent);
1277 for(
auto &child : childrenFinalOut[node])
1278 queue.emplace(std::array<ftm::idNode, 2>{child, node});
1283 return createStructGivenDimImpl(
1284 _nodeOrigin, _dimToCreate, _nodeDone, _structOrigin, _scalarsVectorOut,
1285 _childrenFinalOut, createStructGivenDimImpl);
1287 for(
unsigned int i = 0; i < children.size(); ++i)
1288 sortChildren(i, nodeDone, children);
1290 for(
unsigned int i = 0; i < dimFound.size(); ++i) {
1294 createStructGivenDim(
1295 startNode, i, nodeDone, structOrigin, scalarsVector, childrenFinal);
1305 for(
unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) {
1306 if(mTree.tree.isNodeAlone(n))
1308 auto birthDeath = mTree.tree.template getBirthDeath<float>(n);
1309 if(std::abs(std::get<0>(birthDeath)) > threshold
1310 or std::abs(std::get<1>(birthDeath)) > threshold) {
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
#define TTK_PSORT(NTHREADS,...)
Parallel sort macro.
void setDebugMsgPrefix(const std::string &prefix)
int initVectors(int axeNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2, std::vector< std::vector< double > > &trees2V1, std::vector< std::vector< double > > &trees2V2, int newVectorOffset, std::vector< double > &inputToOriginDistances, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings2, std::vector< std::vector< double > > &inputToAxesDistances, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, bool projectInitializedVectors, F initializedVectorsProjection)
int scaleVector(const T *a, const T factor, T *out, const int &dimension=3)
T dotProduct(const T *vA0, const T *vA1, const T *vB0, const T *vB1)
int flattenMultiDimensionalVector(const std::vector< std::vector< T > > &a, std::vector< T > &out)
T magnitude(const T *v, const int &dimension=3)
T distance(const T *p0, const T *p1, const int &dimension=3)
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
MergeTree< dataType > copyMergeTree(const ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs=false)
MergeTree< dataType > createEmptyMergeTree(int scalarSize)
unsigned int idNode
Node index in vect_nodes_.
T end(std::pair< T, T > &p)
T begin(std::pair< T, T > &p)
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)