TTK
Loading...
Searching...
No Matches
MergeTreeNeuralLayer.cpp
Go to the documentation of this file.
2#include <cmath>
3
4#ifdef TTK_ENABLE_TORCH
5using namespace torch::indexing;
6#endif
7
9 // inherited from Debug: prefix will be printed at the beginning of every msg
10 this->setDebugMsgPrefix("MergeTreeNeuralLayer");
11}
12
13#ifdef TTK_ENABLE_TORCH
14// -----------------------------------------------------------------------
15// --- Getter/Setter
16// -----------------------------------------------------------------------
/// @brief Read-only access to the input-basis origin merge tree.
const ttk::mtu::TorchMergeTree<float> &
  ttk::MergeTreeNeuralLayer::getOrigin() const {
  return origin_;
}
21
/// @brief Read-only access to the output (prime) basis origin merge tree.
const ttk::mtu::TorchMergeTree<float> &
  ttk::MergeTreeNeuralLayer::getOriginPrime() const {
  return originPrime_;
}
26
/// @brief Read-only access to the second-input basis origin merge tree
/// (used when useDoubleInput_ is set).
const ttk::mtu::TorchMergeTree<float> &
  ttk::MergeTreeNeuralLayer::getOrigin2() const {
  return origin2_;
}
31
/// @brief Read-only access to the second-input output (prime) basis origin
/// merge tree (used when useDoubleInput_ is set).
const ttk::mtu::TorchMergeTree<float> &
  ttk::MergeTreeNeuralLayer::getOrigin2Prime() const {
  return origin2Prime_;
}
36
/// @brief Read-only access to the input-basis vectors tensor.
const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSTensor() const {
  return vSTensor_;
}
40
/// @brief Read-only access to the output (prime) basis vectors tensor.
const torch::Tensor &ttk::MergeTreeNeuralLayer::getVSPrimeTensor() const {
  return vSPrimeTensor_;
}
44
/// @brief Read-only access to the second-input basis vectors tensor.
const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2Tensor() const {
  return vS2Tensor_;
}
48
/// @brief Read-only access to the second-input output (prime) basis vectors
/// tensor.
const torch::Tensor &ttk::MergeTreeNeuralLayer::getVS2PrimeTensor() const {
  return vS2PrimeTensor_;
}
52
/// @brief Copies \p tmt into the input-basis origin (origin_).
void ttk::MergeTreeNeuralLayer::setOrigin(
  const mtu::TorchMergeTree<float> &tmt) {
  mtu::copyTorchMergeTree(tmt, origin_);
}
57
/// @brief Copies \p tmt into the output (prime) basis origin (originPrime_).
void ttk::MergeTreeNeuralLayer::setOriginPrime(
  const mtu::TorchMergeTree<float> &tmt) {
  mtu::copyTorchMergeTree(tmt, originPrime_);
}
62
/// @brief Copies \p tmt into the second-input basis origin (origin2_).
void ttk::MergeTreeNeuralLayer::setOrigin2(
  const mtu::TorchMergeTree<float> &tmt) {
  mtu::copyTorchMergeTree(tmt, origin2_);
}
67
/// @brief Copies \p tmt into the second-input output (prime) basis origin
/// (origin2Prime_).
void ttk::MergeTreeNeuralLayer::setOrigin2Prime(
  const mtu::TorchMergeTree<float> &tmt) {
  mtu::copyTorchMergeTree(tmt, origin2Prime_);
}
72
/// @brief Copies \p vS into the input-basis vectors tensor (vSTensor_).
void ttk::MergeTreeNeuralLayer::setVSTensor(const torch::Tensor &vS) {
  mtu::copyTensor(vS, vSTensor_);
}
76
/// @brief Copies \p vS into the output (prime) basis vectors tensor
/// (vSPrimeTensor_).
void ttk::MergeTreeNeuralLayer::setVSPrimeTensor(const torch::Tensor &vS) {
  mtu::copyTensor(vS, vSPrimeTensor_);
}
80
/// @brief Copies \p vS into the second-input basis vectors tensor
/// (vS2Tensor_).
void ttk::MergeTreeNeuralLayer::setVS2Tensor(const torch::Tensor &vS) {
  mtu::copyTensor(vS, vS2Tensor_);
}
84
/// @brief Copies \p vS into the second-input output (prime) basis vectors
/// tensor (vS2PrimeTensor_).
void ttk::MergeTreeNeuralLayer::setVS2PrimeTensor(const torch::Tensor &vS) {
  mtu::copyTensor(vS, vS2PrimeTensor_);
}
88
89// ---------------------------------------------------------------------------
90// --- Init
91// ---------------------------------------------------------------------------
/// @brief Builds the merge-tree structure (nodes, pairings and arcs) backing
/// the output-basis origin \p originPrime, whose scalar values come from its
/// tensor.
/// @param originPrime tree to (re)build; its tensor provides the scalars.
/// @param isJT true for a join tree (birth/death swapped before writing
/// scalars).
/// @param baseOrigin reference tree whose structure may be copied when
/// initOriginPrimeStructByCopy_ is set.
void ttk::MergeTreeNeuralLayer::initOutputBasisTreeStructure(
  mtu::TorchMergeTree<float> &originPrime,
  bool isJT,
  mtu::TorchMergeTree<float> &baseOrigin) {
  // ----- Create scalars vector
  // Copy the tensor values (moved to CPU if needed) into a flat
  // [birth, death, birth, death, ...] vector; one node pair per 2 entries.
  torch::Tensor originTensor = originPrime.tensor;
  if(!originTensor.device().is_cpu())
    originTensor = originTensor.cpu();
  std::vector<float> scalarsVector(
    originTensor.data_ptr<float>(),
    originTensor.data_ptr<float>() + originTensor.numel());
  unsigned int noNodes = scalarsVector.size() / 2;
  std::vector<std::vector<ftm::idNode>> childrenFinal(noNodes);

  // ----- Init tree structure and modify scalars if necessary
  if(isPersistenceDiagram_) {
    // Persistence diagram: every pair is a direct child of the first pair.
    for(unsigned int i = 2; i < scalarsVector.size(); i += 2)
      childrenFinal[0].emplace_back(i / 2);
  } else {
    // --- Fix or swap min-max pair
    // Find the most persistent pair and move it to index 0 (root pair).
    float maxPers = std::numeric_limits<float>::lowest();
    unsigned int indMax = 0;
    for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
      if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) {
        maxPers = (scalarsVector[i + 1] - scalarsVector[i]);
        indMax = i;
      }
    }
    if(indMax != 0) {
      float temp = scalarsVector[0];
      scalarsVector[0] = scalarsVector[indMax];
      scalarsVector[indMax] = temp;
      temp = scalarsVector[1];
      scalarsVector[1] = scalarsVector[indMax + 1];
      scalarsVector[indMax + 1] = temp;
    }
    // Make every pair's values nested inside the root pair's interval.
    ftm::idNode refNode = 0;
    for(unsigned int i = 2; i < scalarsVector.size(); i += 2) {
      ftm::idNode node = i / 2;
      adjustNestingScalars(scalarsVector, node, refNode);
    }

    if(not initOriginPrimeStructByCopy_
       or (int) noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) {
      // --- Get possible children and parent relations
      // Pair i may parent pair j iff i's interval contains j's interval.
      std::vector<std::vector<ftm::idNode>> parents(noNodes), children(noNodes);
      for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
        for(unsigned int j = i; j < scalarsVector.size(); j += 2) {
          if(i == j)
            continue;
          unsigned int iN = i / 2, jN = j / 2;
          if(scalarsVector[i] <= scalarsVector[j]
             and scalarsVector[i + 1] >= scalarsVector[j + 1]) {
            // - i is parent of j
            parents[jN].emplace_back(iN);
            children[iN].emplace_back(jN);
          } else if(scalarsVector[i] >= scalarsVector[j]
                    and scalarsVector[i + 1] <= scalarsVector[j + 1]) {
            // - j is parent of i
            parents[iN].emplace_back(jN);
            children[jN].emplace_back(iN);
          }
        }
      }
      createBalancedBDT(parents, children, scalarsVector, childrenFinal);
    } else {
      // Copy the structure of baseOrigin restricted to its noNodes most
      // important pairs; map tree nodes to pairs sorted by decreasing
      // persistence (argsort of death - birth).
      ftm::MergeTree<float> mTreeTemp
        = ftm::copyMergeTree<float>(baseOrigin.mTree);
      bool useBD = true;
      keepMostImportantPairs<float>(&(mTreeTemp.tree), noNodes, useBD);
      torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2});
      torch::Tensor order = torch::argsort(
        (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1,
        true);
      std::vector<unsigned int> nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0);
      unsigned int nodeNum = 1;
      std::queue<ftm::idNode> queue;
      queue.emplace(mTreeTemp.tree.getRoot());
      // Breadth-first traversal assigning children and fixing nesting.
      while(!queue.empty()) {
        ftm::idNode node = queue.front();
        queue.pop();
        std::vector<ftm::idNode> children;
        mTreeTemp.tree.getChildren(node, children);
        for(auto &child : children) {
          queue.emplace(child);
          unsigned int tNode = nodeCorr[node];
          nodeCorr[child] = order[nodeNum].item<int>();
          ++nodeNum;
          unsigned int tChild = nodeCorr[child];
          childrenFinal[tNode].emplace_back(tChild);
          adjustNestingScalars(scalarsVector, tChild, tNode);
        }
      }
    }
  }

  // ----- Create new tree
  originPrime.mTree = ftm::createEmptyMergeTree<float>(scalarsVector.size());
  ftm::FTMTree_MT *tree = &(originPrime.mTree.tree);
  if(isJT) {
    // Join tree: swap birth/death so scalars follow the JT convention.
    for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
      float temp = scalarsVector[i];
      scalarsVector[i] = scalarsVector[i + 1];
      scalarsVector[i + 1] = temp;
    }
  }
  ftm::setTreeScalars<float>(originPrime.mTree, scalarsVector);

  // ----- Create tree structure
  // One node per scalar; nodes of the same pair reference each other through
  // their origin, then arcs are added from the children relation.
  originPrime.nodeCorr.clear();
  originPrime.nodeCorr.assign(
    scalarsVector.size(), std::numeric_limits<unsigned int>::max());
  for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
    tree->makeNode(i);
    tree->makeNode(i + 1);
    tree->getNode(i)->setOrigin(i + 1);
    tree->getNode(i + 1)->setOrigin(i);
    originPrime.nodeCorr[i] = (unsigned int)(i / 2);
  }
  for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
    unsigned int node = i / 2;
    for(auto &child : childrenFinal[node])
      tree->makeSuperArc(child * 2, i);
  }
  mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri);

  // Debug guard: print the pairs and pause when the generated tree contains
  // abnormally big values.
  if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) {
    std::stringstream ss;
    ss << originPrime.mTree.tree.printPairsFromTree<float>(true).str()
       << std::endl;
    ss << "isTreeHasBigValues(originPrime.mTree)" << std::endl;
    ss << "pause" << std::endl;
    printMsg(ss.str());
    std::cin.get();
  }
}
228
/// @brief Initializes the output basis: computes the output-basis origins
/// (originPrime_, and origin2Prime_ when useDoubleInput_ is set) from a
/// random projection of the input-basis origins, then derives the output
/// basis vectors.
/// @param dim number of rows of the projection for the first input.
/// @param dim2 number of rows of the projection for the second input.
/// @param baseTensor reference tensor used to seed the origin values when
/// initOriginPrimeValuesByCopy_ is set.
void ttk::MergeTreeNeuralLayer::initOutputBasis(
  const unsigned int dim,
  const unsigned int dim2,
  const torch::Tensor &baseTensor) {
  unsigned int originSize = origin_.tensor.sizes()[0];
  unsigned int origin2Size = 0;
  if(useDoubleInput_)
    origin2Size = origin2_.tensor.sizes()[0];

  // --- Compute output basis origin
  printMsg("Compute output basis origin", debug::Priority::DETAIL);
  // Computes tmt = w * baseTmt + b (Xavier-initialized w), shifts the result
  // to a valid position, optionally copies the most persistent base pairs,
  // then builds the tree structure and projects.
  auto initOutputBasisOrigin = [this, &baseTensor](
                                 torch::Tensor &w,
                                 mtu::TorchMergeTree<float> &tmt,
                                 mtu::TorchMergeTree<float> &baseTmt) {
    // - Create scalars
    torch::nn::init::xavier_normal_(w);
    torch::Tensor baseTmtTensor = baseTmt.tensor;
    if(normalizedWasserstein_)
      // Work on unnormalized tensor
      mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor, false);
    torch::Tensor b
      = torch::full({w.sizes()[0], 1}, 0.01,
                    torch::TensorOptions().device(baseTmtTensor.device()));
    tmt.tensor = (torch::matmul(w, baseTmtTensor) + b);
    // - Shift to keep mean birth and max pers
    mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor);
    // - Shift to avoid diagonal points
    mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor);
    //
    if(initOriginPrimeValuesByCopy_) {
      // Overwrite the most persistent generated pairs with the most
      // persistent base pairs (optionally blended by the randomness factor);
      // index 0 (root pair) is always included.
      auto baseTensorDiag = baseTensor.reshape({-1, 2});
      auto basePersDiag = (baseTensorDiag.index({Slice(), 1})
                           - baseTensorDiag.index({Slice(), 0}));
      auto tmtTensorDiag = tmt.tensor.reshape({-1, 2});
      auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1})
                       - tmtTensorDiag.index({Slice(1, None), 0}));
      int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]);
      auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))});
      auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1;
      auto zeros
        = torch::zeros(1, torch::TensorOptions().device(indexes.device()));
      indexes = torch::cat({zeros, indexes}).to(torch::kLong);
      if(initOriginPrimeValuesByCopyRandomness_ != 0) {
        topVal = (1 - initOriginPrimeValuesByCopyRandomness_) * topVal
                 + initOriginPrimeValuesByCopyRandomness_
                     * tmtTensorDiag.index({indexes});
      }
      tmtTensorDiag.index_put_({indexes}, topVal);
    }
    // - Create tree structure
    initOutputBasisTreeStructure(
      tmt, baseTmt.mTree.tree.isJoinTree<float>(), baseTmt);
    if(normalizedWasserstein_)
      // Normalize tensor
      mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor, true);
    // - Projection
    interpolationProjection(tmt);
  };
  torch::Tensor w = torch::zeros(
    {dim, originSize}, torch::TensorOptions().device(origin_.tensor.device()));
  initOutputBasisOrigin(w, originPrime_, origin_);
  torch::Tensor w2;
  if(useDoubleInput_) {
    w2 = torch::zeros({dim2, origin2Size},
                      torch::TensorOptions().device(origin2_.tensor.device()));
    initOutputBasisOrigin(w2, origin2Prime_, origin2_);
  }

  // --- Compute output basis vectors
  printMsg("Compute output basis vectors", debug::Priority::DETAIL);
  initOutputBasisVectors(w, w2);
}
302
303void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(torch::Tensor &w,
304 torch::Tensor &w2) {
305 vSPrimeTensor_ = torch::matmul(w, vSTensor_);
306 if(useDoubleInput_)
307 vS2PrimeTensor_ = torch::matmul(w2, vS2Tensor_);
308 if(normalizedWasserstein_) {
309 mtu::normalizeVectors(originPrime_.tensor, vSPrimeTensor_);
310 if(useDoubleInput_)
311 mtu::normalizeVectors(origin2Prime_.tensor, vS2PrimeTensor_);
312 }
313}
314
315void ttk::MergeTreeNeuralLayer::initOutputBasisVectors(unsigned int dim,
316 unsigned int dim2) {
317 unsigned int originSize = origin_.tensor.sizes()[0];
318 unsigned int origin2Size = 0;
319 if(useDoubleInput_)
320 origin2Size = origin2_.tensor.sizes()[0];
321 torch::Tensor w = torch::zeros({dim, originSize});
322 torch::nn::init::xavier_normal_(w);
323 torch::Tensor w2 = torch::zeros({dim2, origin2Size});
324 torch::nn::init::xavier_normal_(w2);
325 initOutputBasisVectors(w, w2);
326}
327
/// @brief Initializes the input-basis origin(s) by computing a barycenter of
/// the input trees (and of the second-input trees when useDoubleInput_ is
/// set), then converts the barycenter(s) to torch tensors.
/// @param[out] inputToBaryDistances distances from each input tree to the
/// barycenter (mixed with the second-input distances when applicable).
/// @param[out] baryMatchings / baryMatchings2 node matchings between each
/// input tree and the barycenter.
void ttk::MergeTreeNeuralLayer::initInputBasisOrigin(
  std::vector<ftm::MergeTree<float>> &treesToUse,
  std::vector<ftm::MergeTree<float>> &trees2ToUse,
  double barycenterSizeLimitPercent,
  unsigned int barycenterMaxNoPairs,
  unsigned int barycenterMaxNoPairs2,
  std::vector<double> &inputToBaryDistances,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &baryMatchings,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &baryMatchings2) {
  // Optionally seed the barycenter with a randomly chosen input tree
  // (index 0 used as the deterministic seed).
  int barycenterInitIndex = -1;
  if(initBarycenterRandom_) {
    std::random_device rd;
    std::default_random_engine rng(deterministic_ ? 0 : rd());
    barycenterInitIndex
      = std::uniform_int_distribution<>(0, treesToUse.size() - 1)(rng);
  }
  // With random init the pair limit is enforced during the barycenter
  // computation; otherwise pairs are pruned afterwards.
  int maxNoPairs = (initBarycenterRandom_ ? barycenterMaxNoPairs : 0);
  computeOneBarycenter<float>(treesToUse, origin_.mTree, baryMatchings,
                              inputToBaryDistances, barycenterSizeLimitPercent,
                              maxNoPairs, barycenterInitIndex,
                              initBarycenterOneIter_, useDoubleInput_, true);
  if(not initBarycenterRandom_ and barycenterMaxNoPairs > 0)
    keepMostImportantPairs<float>(
      &(origin_.mTree.tree), barycenterMaxNoPairs, true);
  if(useDoubleInput_) {
    // Same process for the second input, then mix both distance sets.
    std::vector<double> baryDistances2;
    int maxNoPairs2 = (initBarycenterRandom_ ? barycenterMaxNoPairs2 : 0);
    computeOneBarycenter<float>(trees2ToUse, origin2_.mTree, baryMatchings2,
                                baryDistances2, barycenterSizeLimitPercent,
                                maxNoPairs2, barycenterInitIndex,
                                initBarycenterOneIter_, useDoubleInput_, false);
    if(not initBarycenterRandom_ and barycenterMaxNoPairs2 > 0)
      keepMostImportantPairs<float>(
        &(origin2_.mTree.tree), barycenterMaxNoPairs2, true);
    for(unsigned int i = 0; i < inputToBaryDistances.size(); ++i)
      inputToBaryDistances[i]
        = mixDistances(inputToBaryDistances[i], baryDistances2[i]);
  }

  // Convert the barycenter(s) to tensor form (normalized when requested) and
  // move them to the GPU if enabled.
  mtu::getParentsVector(origin_.mTree, origin_.parentsOri);
  mtu::mergeTreeToTorchTensor<float>(
    origin_.mTree, origin_.tensor, origin_.nodeCorr, normalizedWasserstein_);
  if(useGpu_)
    origin_.tensor = origin_.tensor.cuda();
  if(useDoubleInput_) {
    mtu::getParentsVector(origin2_.mTree, origin2_.parentsOri);
    mtu::mergeTreeToTorchTensor<float>(origin2_.mTree, origin2_.tensor,
                                       origin2_.nodeCorr,
                                       normalizedWasserstein_);
    if(useGpu_)
      origin2_.tensor = origin2_.tensor.cuda();
  }
}
383
/// @brief Initializes \p noVectors basis vectors (vSTensor / vS2Tensor),
/// either randomly (randomAxesInit_: pseudo-inverse of a Xavier-initialized
/// matrix) or greedily one direction at a time, and fills the per-tree
/// projection coefficients in \p allAlphasInit.
/// NOTE(review): several statements in this listing appear truncated
/// (missing lines in the extracted source, e.g. the Geometry scaling calls
/// and the call producing `bestIndex`); the surrounding code is kept
/// verbatim — restore from the upstream file before modifying.
void ttk::MergeTreeNeuralLayer::initInputBasisVectors(
  std::vector<mtu::TorchMergeTree<float>> &tmTrees,
  std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
  std::vector<ftm::MergeTree<float>> &trees,
  std::vector<ftm::MergeTree<float>> &trees2,
  unsigned int noVectors,
  std::vector<torch::Tensor> &allAlphasInit,
  std::vector<double> &inputToBaryDistances,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &baryMatchings,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &baryMatchings2,
  mtu::TorchMergeTree<float> &origin,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vSTensor,
  torch::Tensor &vS2Tensor,
  bool useInputBasis) {
  if(randomAxesInit_) {
    // Random init: axes are the pseudo-inverse of a Xavier-initialized
    // matrix; alphas are drawn from a normal distribution.
    auto initRandomAxes = [&noVectors](mtu::TorchMergeTree<float> &originT,
                                       torch::Tensor &axes) {
      torch::Tensor w = torch::zeros({noVectors, originT.tensor.sizes()[0]});
      torch::nn::init::xavier_normal_(w);
      axes = torch::linalg_pinv(w);
    };
    initRandomAxes(origin, vSTensor);
    if(useGpu_)
      vSTensor = vSTensor.cuda();
    if(useDoubleInput_) {
      initRandomAxes(origin2, vS2Tensor);
      if(useGpu_)
        vS2Tensor = vS2Tensor.cuda();
    }
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int i = 0; i < trees.size(); ++i)
      allAlphasInit[i] = torch::randn({noVectors, 1});
    return;
  }

  // --- Initialized vectors projection function to avoid collinearity
  // Resets a candidate vector to zero when it is (nearly) collinear with an
  // already-computed vector (|dot product of normalized vectors| close to 1).
  auto initializedVectorsProjection
    = [=](int ttkNotUsed(_axeNumber),
          ftm::MergeTree<float> &ttkNotUsed(_barycenter),
          std::vector<std::vector<double>> &_v,
          std::vector<std::vector<double>> &ttkNotUsed(_v2),
          std::vector<std::vector<std::vector<double>>> &_vS,
          std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_v2s),
          ftm::MergeTree<float> &ttkNotUsed(_barycenter2),
          std::vector<std::vector<double>> &ttkNotUsed(_trees2V),
          std::vector<std::vector<double>> &ttkNotUsed(_trees2V2),
          std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2Vs),
          std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2V2s),
          bool ttkNotUsed(_useSecondInput),
          unsigned int ttkNotUsed(_noProjectionStep)) {
        std::vector<double> scaledV, scaledVSi;
        // NOTE(review): statement opening missing in this listing.
        scaledV, 1.0 / Geometry::magnitude(scaledV), scaledV);
        for(unsigned int i = 0; i < _vS.size(); ++i) {
          // NOTE(review): statement opening missing in this listing.
          scaledVSi, 1.0 / Geometry::magnitude(scaledVSi), scaledVSi);
          auto prod = Geometry::dotProduct(scaledV, scaledVSi);
          double tol = 0.01;
          if(prod <= -1.0 + tol or prod >= 1.0 - tol) {
            // Reset vector to initialize it again
            for(unsigned int j = 0; j < _v.size(); ++j)
              for(unsigned int k = 0; k < _v[j].size(); ++k)
                _v[j][k] = 0;
            break;
          }
        }
        return 0;
      };

  // --- Init vectors
  // Greedy loop: compute one new direction per iteration, append it to the
  // basis, then refresh each tree's distance to the spanned axes.
  std::vector<std::vector<double>> inputToAxesDistances;
  std::vector<std::vector<std::vector<double>>> vS, v2s, trees2Vs, trees2V2s;
  std::stringstream ss;
  for(unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) {
    ss.str("");
    ss << "Compute vectors " << vecNum;
    std::vector<std::vector<double>> v1, v2, trees2V1, trees2V2;
    int newVectorOffset = 0;
    bool projectInitializedVectors = true;
    // NOTE(review): the call computing `bestIndex` is truncated in this
    // listing; the following lines are its arguments.
    vecNum, origin.mTree, trees, origin2.mTree, trees2, v1, v2, trees2V1,
      trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings,
      baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s,
      projectInitializedVectors, initializedVectorsProjection);
    vS.emplace_back(v1);
    v2s.emplace_back(v2);
    trees2Vs.emplace_back(trees2V1);
    trees2V2s.emplace_back(trees2V2);

    ss.str("");
    ss << "bestIndex = " << bestIndex;

    // Update inputToAxesDistances
    printMsg("Update inputToAxesDistances", debug::Priority::VERBOSE);
    inputToAxesDistances.resize(1, std::vector<double>(trees.size()));
    if(bestIndex == -1 and normalizedWasserstein_) {
      mtu::normalizeVectors(origin, vS[vS.size() - 1]);
      if(useDoubleInput_)
        mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]);
    }
    mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor);
    if(useGpu_)
      vSTensor = vSTensor.cuda();
    if(useDoubleInput_) {
      mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor);
      if(useGpu_)
        vS2Tensor = vS2Tensor.cuda();
    }
    mtu::TorchMergeTree<float> dummyTmt;
    std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
      dummyBaryMatching2;
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int i = 0; i < trees.size(); ++i) {
      auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2[i]);
      if(not euclideanVectorsInit_) {
        // Distance via the assignment optimization; alphas grow by one row.
        unsigned int k = k_;
        auto newAlpha = torch::ones({1, 1});
        if(bestIndex == -1) {
          newAlpha = torch::zeros({1, 1});
        }
        allAlphasInit[i] = (allAlphasInit[i].defined()
                              ? torch::cat({allAlphasInit[i], newAlpha})
                              : newAlpha);
        torch::Tensor bestAlphas;
        bool isCalled = true;
        inputToAxesDistances[0][i]
          = assignmentOneData(tmTrees[i], tmt2ToUse, k, allAlphasInit[i],
                              bestAlphas, isCalled, useInputBasis);
        allAlphasInit[i] = bestAlphas.detach();
      } else {
        // Euclidean variant: solve alphas directly from the barycenter
        // matchings and measure the differentiable distance.
        auto &baryMatching2ToUse
          = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]);
        torch::Tensor alphas;
        computeAlphas(tmTrees[i], origin, vSTensor, origin, baryMatchings[i],
                      tmt2ToUse, origin2, vS2Tensor, origin2,
                      baryMatching2ToUse, alphas);
        mtu::TorchMergeTree<float> interpolated, interpolated2;
        getMultiInterpolation(origin, vSTensor, alphas, interpolated);
        if(useDoubleInput_)
          getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
        torch::Tensor tensorDist;
        bool doSqrt = true;
        getDifferentiableDistanceFromMatchings(
          interpolated, tmTrees[i], interpolated2, tmt2ToUse, baryMatchings[i],
          baryMatching2ToUse, tensorDist, doSqrt);
        inputToAxesDistances[0][i] = tensorDist.item<double>();
        allAlphasInit[i] = alphas.detach();
      }
    }
  }
}
548
549void ttk::MergeTreeNeuralLayer::initInputBasisVectors(
550 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
551 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
552 std::vector<ftm::MergeTree<float>> &trees,
553 std::vector<ftm::MergeTree<float>> &trees2,
554 unsigned int noVectors,
555 std::vector<torch::Tensor> &allAlphasInit,
556 std::vector<double> &inputToBaryDistances,
557 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
558 &baryMatchings,
559 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
560 &baryMatchings2,
561 bool useInputBasis) {
562 mtu::TorchMergeTree<float> &origin = (useInputBasis ? origin_ : originPrime_);
563 mtu::TorchMergeTree<float> &origin2
564 = (useInputBasis ? origin2_ : origin2Prime_);
565 torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_);
566 torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_);
567
568 initInputBasisVectors(tmTrees, tmTrees2, trees, trees2, noVectors,
569 allAlphasInit, inputToBaryDistances, baryMatchings,
570 baryMatchings2, origin, origin2, vSTensor, vS2Tensor,
571 useInputBasis);
572}
573
574void ttk::MergeTreeNeuralLayer::requires_grad(const bool requireGrad) {
575 origin_.tensor.requires_grad_(requireGrad);
576 originPrime_.tensor.requires_grad_(requireGrad);
577 vSTensor_.requires_grad_(requireGrad);
578 vSPrimeTensor_.requires_grad_(requireGrad);
579 if(useDoubleInput_) {
580 origin2_.tensor.requires_grad_(requireGrad);
581 origin2Prime_.tensor.requires_grad_(requireGrad);
582 vS2Tensor_.requires_grad_(requireGrad);
583 vS2PrimeTensor_.requires_grad_(requireGrad);
584 }
585}
586
587void ttk::MergeTreeNeuralLayer::cuda() {
588 origin_.tensor = origin_.tensor.cuda();
589 originPrime_.tensor = originPrime_.tensor.cuda();
590 vSTensor_ = vSTensor_.cuda();
591 vSPrimeTensor_ = vSPrimeTensor_.cuda();
592 if(useDoubleInput_) {
593 origin2_.tensor = origin2_.tensor.cuda();
594 origin2Prime_.tensor = origin2Prime_.tensor.cuda();
595 vS2Tensor_ = vS2Tensor_.cuda();
596 vS2PrimeTensor_ = vS2PrimeTensor_.cuda();
597 }
598}
599
600// ---------------------------------------------------------------------------
601// --- Interpolation
602// ---------------------------------------------------------------------------
/// @brief Projects invalid pairs (birth > death) onto the diagonal by
/// replacing both coordinates with their mean, editing the interpolation
/// tensor in place.
/// NOTE(review): relies on reshape()/detach() returning views that share the
/// tensor's storage (contiguous input) — confirm if non-contiguous tensors
/// can reach this point.
void ttk::MergeTreeNeuralLayer::interpolationDiagonalProjection(
  mtu::TorchMergeTree<float> &interpolation) {
  torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
  if(interpolation.tensor.requires_grad())
    diagTensor = diagTensor.detach(); // edit values without autograd tracking

  torch::Tensor birthTensor = diagTensor.index({Slice(), 0});
  torch::Tensor deathTensor = diagTensor.index({Slice(), 1});

  // Pairs below the diagonal are invalid.
  torch::Tensor indexer = (birthTensor > deathTensor);

  // Midpoint (b + d) / 2 of each invalid pair, duplicated for both coords.
  torch::Tensor allProj = (birthTensor + deathTensor) / 2.0;
  allProj = allProj.index({indexer});
  allProj = allProj.reshape({-1, 1});

  diagTensor.index_put_({indexer}, allProj);
}
620
/// @brief Clamps the (normalized) birth/death values of every pair except
/// the first one into [0, 1], editing the interpolation tensor in place.
/// NOTE(review): relies on reshape()/index() returning views sharing the
/// tensor's storage (contiguous input) — confirm upstream.
void ttk::MergeTreeNeuralLayer::interpolationNestingProjection(
  mtu::TorchMergeTree<float> &interpolation) {
  torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
  if(interpolation.tensor.requires_grad())
    diagTensor = diagTensor.detach(); // edit values without autograd tracking

  // Skip the first pair (index 0): its values are left untouched.
  torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0});
  torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1});

  // Clamp negative values up to 0.
  torch::Tensor birthIndexer = (birthTensor < 0);
  torch::Tensor deathIndexer = (deathTensor < 0);
  birthTensor.index_put_(
    {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer})));
  deathTensor.index_put_(
    {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer})));

  // Clamp values above 1 down to 1.
  birthIndexer = (birthTensor > 1);
  deathIndexer = (deathTensor > 1);
  birthTensor.index_put_(
    {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer})));
  deathTensor.index_put_(
    {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer})));
}
644
645void ttk::MergeTreeNeuralLayer::interpolationProjection(
646 mtu::TorchMergeTree<float> &interpolation) {
647 interpolationDiagonalProjection(interpolation);
648 if(normalizedWasserstein_)
649 interpolationNestingProjection(interpolation);
650
651 ftm::MergeTree<float> interpolationNew;
652 bool noRoot = mtu::torchTensorToMergeTree<float>(
653 interpolation, normalizedWasserstein_, interpolationNew);
654 if(noRoot)
655 printWrn("[interpolationProjection] no root found");
656 interpolation.mTree = copyMergeTree(interpolationNew);
657
658 persistenceThresholding<float>(&(interpolation.mTree.tree), 0.001);
659
660 if(isPersistenceDiagram_ and isThereMissingPairs(interpolation))
661 printWrn("[getMultiInterpolation] missing pairs");
662}
663
664void ttk::MergeTreeNeuralLayer::getMultiInterpolation(
665 const mtu::TorchMergeTree<float> &origin,
666 const torch::Tensor &vS,
667 torch::Tensor &alphas,
668 mtu::TorchMergeTree<float> &interpolation) {
669 mtu::copyTorchMergeTree<float>(origin, interpolation);
670 interpolation.tensor = origin.tensor + torch::matmul(vS, alphas);
671 interpolationProjection(interpolation);
672}
673
674// ---------------------------------------------------------------------------
675// --- Forward
676// ---------------------------------------------------------------------------
/// @brief Builds the tensors of the least-squares system yielding the
/// optimal alphas of \p tree on the basis (\p origin, \p vSTensor), given
/// the matching with the current interpolation.
/// @param[out] reorderedTreeTensor tree values reordered to the origin
/// layout.
/// @param[out] deltaOrigin shift of origin points unmatched in the tree.
/// @param[out] deltaA diagonal-projection shift of the basis vectors for
/// unmatched points.
/// @param[out] originTensor_f / vSTensor_f origin and basis tensors to use
/// in the solve.
void ttk::MergeTreeNeuralLayer::getAlphasOptimizationTensors(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &interpolated,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
  torch::Tensor &reorderedTreeTensor,
  torch::Tensor &deltaOrigin,
  torch::Tensor &deltaA,
  torch::Tensor &originTensor_f,
  torch::Tensor &vSTensor_f) {
  // Create matching indexing
  std::vector<int> tensorMatching;
  mtu::getTensorMatching(interpolated, tree, matching, tensorMatching);

  torch::Tensor indexes = torch::tensor(tensorMatching);
  // -1 marks origin points without a counterpart in the tree.
  torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1});

  dataReorderingGivenMatching(
    origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin);

  // Create axes projection given matching
  // For unmatched points, each vector's (birth, death) is replaced by its
  // midpoint duplicated (projection onto the diagonal); matched points keep
  // a zero shift thanks to the projIndexer mask.
  deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2});
  deltaA = (deltaA.index({Slice(), Slice(), 0})
            + deltaA.index({Slice(), Slice(), 1}))
           / 2.0;
  deltaA = torch::stack({deltaA, deltaA}, 2);
  if(!deltaA.device().is_cpu())
    projIndexer = projIndexer.to(deltaA.device());
  deltaA = deltaA * projIndexer;
  deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1);

  //
  originTensor_f = origin.tensor;
  vSTensor_f = vSTensor;
}
713
714void ttk::MergeTreeNeuralLayer::computeAlphas(
715 mtu::TorchMergeTree<float> &tree,
716 mtu::TorchMergeTree<float> &origin,
717 torch::Tensor &vSTensor,
718 mtu::TorchMergeTree<float> &interpolated,
719 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
720 mtu::TorchMergeTree<float> &tree2,
721 mtu::TorchMergeTree<float> &origin2,
722 torch::Tensor &vS2Tensor,
723 mtu::TorchMergeTree<float> &interpolated2,
724 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
725 torch::Tensor &alphasOut) {
726 torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f,
727 vSTensor_f;
728 getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching,
729 reorderedTreeTensor, deltaOrigin, deltaA,
730 originTensor_f, vSTensor_f);
731
732 if(useDoubleInput_) {
733 torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f,
734 vS2Tensor_f;
735 getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2,
736 matching2, reorderedTree2Tensor, deltaOrigin2,
737 deltaA2, origin2Tensor_f, vS2Tensor_f);
738 vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f});
739 deltaA = torch::cat({deltaA, deltaA2});
740 reorderedTreeTensor
741 = torch::cat({reorderedTreeTensor, reorderedTree2Tensor});
742 originTensor_f = torch::cat({originTensor_f, origin2Tensor_f});
743 deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2});
744 }
745
746 torch::Tensor r_axes = vSTensor_f - deltaA;
747 torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin;
748
749 // Pseudo inverse
750 auto driver = "gelsd";
751 bool is_cpu = r_axes.device().is_cpu();
752 auto device = r_axes.device();
753 if(!is_cpu) {
754 r_axes = r_axes.cpu();
755 r_data = r_data.cpu();
756 }
757 alphasOut
758 = std::get<0>(torch::linalg_lstsq(r_axes, r_data, c10::nullopt, driver));
759 if(!is_cpu)
760 alphasOut = alphasOut.to(device);
761
762 alphasOut.reshape({-1, 1});
763}
764
/// @brief Iteratively optimizes the projection coefficients (alphas) of
/// \p tree (and \p tree2 when useDoubleInput_ is set) on the selected basis,
/// alternating matching computation and alphas fitting for up to \p k
/// rounds.
/// @param alphasInit starting coefficients (may be overwritten on resets).
/// @param[out] bestMatching / bestMatching2 matchings of the best solution.
/// @param[out] bestAlphas coefficients of the best solution.
/// @param useInputBasis selects input basis (true) or output/prime basis.
/// @return the best (mixed) distance found.
float ttk::MergeTreeNeuralLayer::assignmentOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &tree2,
  unsigned int k,
  torch::Tensor &alphasInit,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
  torch::Tensor &bestAlphas,
  bool isCalled,
  bool useInputBasis) {
  // Select either the input basis or the output (prime) basis.
  mtu::TorchMergeTree<float> &origin = (useInputBasis ? origin_ : originPrime_);
  mtu::TorchMergeTree<float> &origin2
    = (useInputBasis ? origin2_ : origin2Prime_);
  torch::Tensor &vSTensor = (useInputBasis ? vSTensor_ : vSPrimeTensor_);
  torch::Tensor &vS2Tensor = (useInputBasis ? vS2Tensor_ : vS2PrimeTensor_);

  torch::Tensor alphas, oldAlphas;
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching, matching2;
  float bestDistance = std::numeric_limits<float>::max();
  mtu::TorchMergeTree<float> interpolated, interpolated2;
  unsigned int i = 0;
  // Restart the optimization from random alphas (used when the interpolated
  // tree degenerates to an empty tree).
  auto reset = [&]() {
    alphasInit = torch::randn_like(alphas);
    i = 0;
  };
  unsigned int noUpdate = 0;
  unsigned int noReset = 0;
  while(i < k) {
    if(i == 0) {
      // First iteration: use the provided init, or zeros if undefined.
      if(alphasInit.defined())
        alphas = alphasInit;
      else
        alphas = torch::zeros({vSTensor.sizes()[1], 1});
    } else {
      // Refine alphas from the previous matchings; stop when converged
      // (alphas unchanged after at least two refinements).
      computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2,
                    origin2, vS2Tensor, interpolated2, matching2, alphas);
      if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas)
         and i != 1) {
        break;
      }
    }
    mtu::copyTensor(alphas, oldAlphas);
    getMultiInterpolation(origin, vSTensor, alphas, interpolated);
    if(useDoubleInput_)
      getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
    // Degenerate interpolation: restart from random alphas.
    if(interpolated.mTree.tree.getRealNumberOfNodes() == 0
       or (useDoubleInput_
           and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) {
      ++noReset;
      if(noReset >= 100)
        printWrn("[assignmentOneData] noReset >= 100");
      reset();
      continue;
    }
    // Distance between the interpolation and the input (mixed with the
    // second input's distance when applicable).
    float distance;
    computeOneDistance<float>(interpolated.mTree, tree.mTree, matching,
                              distance, isCalled, useDoubleInput_);
    if(useDoubleInput_) {
      float distance2;
      computeOneDistance<float>(interpolated2.mTree, tree2.mTree, matching2,
                                distance2, isCalled, useDoubleInput_, false);
      distance = mixDistances<float>(distance, distance2);
    }
    // Keep the best solution; iteration 0 only produces initial matchings.
    if(distance < bestDistance and i != 0) {
      bestDistance = distance;
      bestMatching = matching;
      bestMatching2 = matching2;
      bestAlphas = alphas;
      noUpdate += 1;
    }
    i += 1;
  }
  if(noUpdate == 0)
    printErr("[assignmentOneData] noUpdate == 0");
  return bestDistance;
}
841
842float ttk::MergeTreeNeuralLayer::assignmentOneData(
843 mtu::TorchMergeTree<float> &tree,
844 mtu::TorchMergeTree<float> &tree2,
845 unsigned int k,
846 torch::Tensor &alphasInit,
847 torch::Tensor &bestAlphas,
848 bool isCalled,
849 bool useInputBasis) {
850 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> bestMatching,
851 bestMatching2;
852 return assignmentOneData(tree, tree2, k, alphasInit, bestMatching,
853 bestMatching2, bestAlphas, isCalled, useInputBasis);
854}
855
856void ttk::MergeTreeNeuralLayer::outputBasisReconstruction(
857 torch::Tensor &alphas,
858 mtu::TorchMergeTree<float> &out,
859 mtu::TorchMergeTree<float> &out2,
860 bool activate,
861 bool train) {
862 if(not activate_)
863 activate = false;
864 torch::Tensor act = (activate ? activation(alphas) : alphas);
865 if(dropout_ != 0.0 and train) {
866 torch::nn::Dropout model(torch::nn::DropoutOptions().p(dropout_));
867 act = model(act);
868 }
869 getMultiInterpolation(originPrime_, vSPrimeTensor_, act, out);
870 if(useDoubleInput_)
871 getMultiInterpolation(origin2Prime_, vS2PrimeTensor_, act, out2);
872}
873
874bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree<float> &tree,
875 mtu::TorchMergeTree<float> &tree2,
876 unsigned int k,
877 torch::Tensor &alphasInit,
878 mtu::TorchMergeTree<float> &out,
879 mtu::TorchMergeTree<float> &out2,
880 torch::Tensor &bestAlphas,
881 float &bestDistance,
882 bool train) {
883 bool goodOutput = false;
884 int noReset = 0;
885 while(not goodOutput) {
886 bool isCalled = true;
887 bestDistance
888 = assignmentOneData(tree, tree2, k, alphasInit, bestAlphas, isCalled);
889 outputBasisReconstruction(bestAlphas, out, out2, true, train);
890 goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0
891 and (not useDoubleInput_
892 or out2.mTree.tree.getRealNumberOfNodes() != 0));
893 if(not goodOutput) {
894 ++noReset;
895 if(noReset >= 100) {
896 printWrn("[forwardOneLayer] noReset >= 100");
897 return true;
898 }
899 alphasInit = torch::randn_like(alphasInit);
900 }
901 }
902 return false;
903}
904
905bool ttk::MergeTreeNeuralLayer::forward(mtu::TorchMergeTree<float> &tree,
906 mtu::TorchMergeTree<float> &tree2,
907 unsigned int k,
908 torch::Tensor &alphasInit,
909 mtu::TorchMergeTree<float> &out,
910 mtu::TorchMergeTree<float> &out2,
911 torch::Tensor &bestAlphas,
912 bool train) {
913 float bestDistance;
914 return forward(
915 tree, tree2, k, alphasInit, out, out2, bestAlphas, bestDistance, train);
916}
917
918// ---------------------------------------------------------------------------
919// --- Projection
920// ---------------------------------------------------------------------------
921void ttk::MergeTreeNeuralLayer::projectionStep() {
922 auto projectTree = [this](mtu::TorchMergeTree<float> &tmt) {
923 interpolationProjection(tmt);
924 tmt.tensor = tmt.tensor.detach();
925 tmt.tensor.requires_grad_(true);
926 };
927 projectTree(origin_);
928 projectTree(originPrime_);
929 if(useDoubleInput_) {
930 projectTree(origin2_);
931 projectTree(origin2Prime_);
932 }
933}
934
935// ---------------------------------------------------------------------------
936// --- Utils
937// ---------------------------------------------------------------------------
938void ttk::MergeTreeNeuralLayer::copyParams(
939 mtu::TorchMergeTree<float> &origin,
940 mtu::TorchMergeTree<float> &originPrime,
941 torch::Tensor &vS,
942 torch::Tensor &vSPrime,
943 mtu::TorchMergeTree<float> &origin2,
944 mtu::TorchMergeTree<float> &origin2Prime,
945 torch::Tensor &vS2,
946 torch::Tensor &vS2Prime,
947 bool get) {
948
949 // Source
950 mtu::TorchMergeTree<float> &srcOrigin = (get ? origin_ : origin);
951 mtu::TorchMergeTree<float> &srcOriginPrime
952 = (get ? originPrime_ : originPrime);
953 torch::Tensor &srcVS = (get ? vSTensor_ : vS);
954 torch::Tensor &srcVSPrime = (get ? vSPrimeTensor_ : vSPrime);
955 mtu::TorchMergeTree<float> &srcOrigin2 = (get ? origin2_ : origin2);
956 mtu::TorchMergeTree<float> &srcOrigin2Prime
957 = (get ? origin2Prime_ : origin2Prime);
958 torch::Tensor &srcVS2 = (get ? vS2Tensor_ : vS2);
959 torch::Tensor &srcVS2Prime = (get ? vS2PrimeTensor_ : vS2Prime);
960
961 // Destination
962 mtu::TorchMergeTree<float> &dstOrigin = (!get ? origin_ : origin);
963 mtu::TorchMergeTree<float> &dstOriginPrime
964 = (!get ? originPrime_ : originPrime);
965 torch::Tensor &dstVS = (!get ? vSTensor_ : vS);
966 torch::Tensor &dstVSPrime = (!get ? vSPrimeTensor_ : vSPrime);
967 mtu::TorchMergeTree<float> &dstOrigin2 = (!get ? origin2_ : origin2);
968 mtu::TorchMergeTree<float> &dstOrigin2Prime
969 = (!get ? origin2Prime_ : origin2Prime);
970 torch::Tensor &dstVS2 = (!get ? vS2Tensor_ : vS2);
971 torch::Tensor &dstVS2Prime = (!get ? vS2PrimeTensor_ : vS2Prime);
972
973 // Copy
974 mtu::copyTorchMergeTree(srcOrigin, dstOrigin);
975 mtu::copyTorchMergeTree(srcOriginPrime, dstOriginPrime);
976 mtu::copyTensor(srcVS, dstVS);
977 mtu::copyTensor(srcVSPrime, dstVSPrime);
978 if(useDoubleInput_) {
979 mtu::copyTorchMergeTree(srcOrigin2, dstOrigin2);
980 mtu::copyTorchMergeTree(srcOrigin2Prime, dstOrigin2Prime);
981 mtu::copyTensor(srcVS2, dstVS2);
982 mtu::copyTensor(srcVS2Prime, dstVS2Prime);
983 }
984}
985
986void ttk::MergeTreeNeuralLayer::adjustNestingScalars(
987 std::vector<float> &scalarsVector, ftm::idNode node, ftm::idNode refNode) {
988 float birth = scalarsVector[refNode * 2];
989 float death = scalarsVector[refNode * 2 + 1];
990 auto getSign = [](float v) { return (v > 0 ? 1 : -1); };
991 auto getPrecValue = [&getSign](float v, bool opp = false) {
992 return v * (1 + (opp ? -1 : 1) * getSign(v) * 1e-6);
993 };
994 // Shift scalars
995 if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) {
996 float diff = scalarsVector[node * 2 + 1] - getPrecValue(death, true);
997 scalarsVector[node * 2] -= diff;
998 scalarsVector[node * 2 + 1] -= diff;
999 } else if(scalarsVector[node * 2] < getPrecValue(birth)) {
1000 float diff = getPrecValue(birth) - scalarsVector[node * 2];
1001 scalarsVector[node * 2] += getPrecValue(diff);
1002 scalarsVector[node * 2 + 1] += getPrecValue(diff);
1003 }
1004 // Cut scalars
1005 if(scalarsVector[node * 2] < getPrecValue(birth))
1006 scalarsVector[node * 2] = getPrecValue(birth);
1007 if(scalarsVector[node * 2 + 1] > getPrecValue(death, true))
1008 scalarsVector[node * 2 + 1] = getPrecValue(death, true);
1009}
1010
// Builds a balanced branch decomposition tree (BDT) structure from the given
// parent/child candidate relations and the per-node (birth, death) scalars.
// The result is written to `childrenFinal` (one child list per node); the
// scalars may be adjusted (via adjustNestingScalars) so created structures
// remain properly nested. NOTE(review): `parents`/`children` are assumed to
// encode a DAG of allowed nesting relations over noNodes nodes — confirm
// against the caller.
void ttk::MergeTreeNeuralLayer::createBalancedBDT(
  std::vector<std::vector<ftm::idNode>> &parents,
  std::vector<std::vector<ftm::idNode>> &children,
  std::vector<float> &scalarsVector,
  std::vector<std::vector<ftm::idNode>> &childrenFinal) {
  // ----- Some variables
  // Two scalars (birth, death) per node.
  unsigned int noNodes = scalarsVector.size() / 2;
  childrenFinal.resize(noNodes);
  // Depth of a balanced merge tree over 2*noNodes scalars; the BDT has one
  // level less, and each BDT level corresponds to one "dimension" to build.
  int mtLevel = ceil(log(noNodes * 2) / log(2)) + 1;
  int bdtLevel = mtLevel - 1;
  int noDim = bdtLevel;

  // ----- Get node levels
  // Bottom-up pass: leaves get level 1, each parent gets
  // 1 + max(level of its children), processed once all children are done.
  std::vector<int> nodeLevels(noNodes, -1);
  std::queue<ftm::idNode> queueLevels;
  std::vector<int> noChildDone(noNodes, 0);
  for(unsigned int i = 0; i < children.size(); ++i) {
    if(children[i].size() == 0) {
      queueLevels.emplace(i);
      nodeLevels[i] = 1;
    }
  }
  while(!queueLevels.empty()) {
    ftm::idNode node = queueLevels.front();
    queueLevels.pop();
    for(auto &parent : parents[node]) {
      ++noChildDone[parent];
      nodeLevels[parent] = std::max(nodeLevels[parent], nodeLevels[node] + 1);
      if(noChildDone[parent] >= (int)children[parent].size())
        queueLevels.emplace(parent);
    }
  }

  // ----- Sort heuristic lambda
  // Orders the children of `nodeOrigin` so that the greedy search below
  // first tries nodes with few unprocessed relatives and high persistence
  // (persistence is subtracted, scaled by noNodes / reference persistence).
  auto sortChildren = [this, &parents, &scalarsVector, &noNodes](
                        ftm::idNode nodeOrigin, std::vector<bool> &nodeDone,
                        std::vector<std::vector<ftm::idNode>> &childrenT) {
    double refPers = scalarsVector[1] - scalarsVector[0];
    auto getRemaining = [&nodeDone](std::vector<ftm::idNode> &vec) {
      unsigned int remaining = 0;
      for(auto &e : vec)
        remaining += (not nodeDone[e]);
      return remaining;
    };
    std::vector<unsigned int> parentsRemaining(noNodes, 0),
      childrenRemaining(noNodes, 0);
    for(auto &child : childrenT[nodeOrigin]) {
      parentsRemaining[child] = getRemaining(parents[child]);
      childrenRemaining[child] = getRemaining(childrenT[child]);
    }
    TTK_PSORT(
      threadNumber_, childrenT[nodeOrigin].begin(), childrenT[nodeOrigin].end(),
      [&](ftm::idNode nodeI, ftm::idNode nodeJ) {
        double persI = scalarsVector[nodeI * 2 + 1] - scalarsVector[nodeI * 2];
        double persJ = scalarsVector[nodeJ * 2 + 1] - scalarsVector[nodeJ * 2];
        return parentsRemaining[nodeI] + childrenRemaining[nodeI]
                 - persI / refPers * noNodes
               < parentsRemaining[nodeJ] + childrenRemaining[nodeJ]
                   - persJ / refPers * noNodes;
      });
  };

  // ----- Greedy approach to find balanced BDT structures
  // Tries to assemble, under `nodeOrigin`, one sub-structure per dimension
  // 0.._dimToFound-1 (or the largest possible set when _searchMaxDim is
  // true). Marks used nodes in _nodeDone, found dimensions in _dimFound, and
  // records chosen edges in _childrenFinalOut. Returns true on success.
  const auto findStructGivenDim =
    [&children, &noNodes, &nodeLevels](
      ftm::idNode _nodeOrigin, int _dimToFound, bool _searchMaxDim,
      std::vector<bool> &_nodeDone, std::vector<bool> &_dimFound,
      std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
      // --- Recursive lambda
      // (self-reference passed as last parameter to allow recursion)
      auto findStructGivenDimImpl =
        [&children, &noNodes, &nodeLevels](
          ftm::idNode nodeOrigin, int dimToFound, bool searchMaxDim,
          std::vector<bool> &nodeDone, std::vector<bool> &dimFound,
          std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
          auto &findStructGivenDimRef) mutable {
          childrenFinalOut.resize(noNodes);
          // - Find structures
          // When searching the maximum dimension, start from the highest one
          // and decrease; otherwise build dimensions from 0 upwards.
          int dim = (searchMaxDim ? dimToFound - 1 : 0);
          unsigned int i = 0;
          //
          // Restart the child scan one dimension lower; returns true when
          // every non-root node is already used (nothing left to try).
          auto searchMaxDimReset = [&i, &dim, &nodeDone]() {
            --dim;
            i = 0;
            unsigned int noDone = 0;
            for(auto done : nodeDone)
              if(done)
                ++noDone;
            return noDone == nodeDone.size() - 1; // -1 for root
          };
          while(i < children[nodeOrigin].size()) {
            auto child = children[nodeOrigin][i];
            // Skip if child was already processed
            if(nodeDone[child]) {
              // If we have processed all children while searching for max
              // dim then restart at the beginning to find a lower dim
              if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
                if(searchMaxDimReset())
                  break;
              } else
                ++i;
              continue;
            }
            if(dim == 0) {
              // Base case
              // A dimension-0 structure is a single free child.
              childrenFinalOut[nodeOrigin].emplace_back(child);
              nodeDone[child] = true;
              dimFound[0] = true;
              if(dimToFound <= 1 or searchMaxDim)
                return true;
              ++dim;
            } else {
              // General case
              // Try to recursively build a dimension-`dim` structure rooted
              // at this child, on a copy of the done-flags so a failed
              // attempt leaves no side effects.
              std::vector<std::vector<ftm::idNode>> childrenFinalDim;
              std::vector<bool> nodeDoneDim;
              std::vector<bool> dimFoundDim(dim);
              bool found = false;
              if(nodeLevels[child] > dim) {
                nodeDoneDim = nodeDone;
                found = findStructGivenDimRef(child, dim, false, nodeDoneDim,
                                              dimFoundDim, childrenFinalDim,
                                              findStructGivenDimRef);
              }
              if(found) {
                // Commit the recursive result: merge its edges and done-flags.
                dimFound[dim] = true;
                childrenFinalOut[nodeOrigin].emplace_back(child);
                for(unsigned int j = 0; j < childrenFinalDim.size(); ++j)
                  for(auto &e : childrenFinalDim[j])
                    childrenFinalOut[j].emplace_back(e);
                nodeDone[child] = true;
                for(unsigned int j = 0; j < nodeDoneDim.size(); ++j)
                  nodeDone[j] = nodeDone[j] || nodeDoneDim[j];
                // Return if it is the last dim to found
                if(dim == dimToFound - 1 and not searchMaxDim)
                  return true;
                // Reset index if we search for the maximum dim
                if(searchMaxDim) {
                  if(searchMaxDimReset())
                    break;
                } else {
                  ++dim;
                }
                continue;
              } else if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
                // If we have processed all children while searching for max dim
                // then restart at the beginning to find a lower dim
                if(searchMaxDimReset())
                  break;
                continue;
              }
            }
            ++i;
          }
          return false;
        };
      return findStructGivenDimImpl(_nodeOrigin, _dimToFound, _searchMaxDim,
                                    _nodeDone, _dimFound, _childrenFinalOut,
                                    findStructGivenDimImpl);
    };
  std::vector<bool> dimFound(noDim - 1, false);
  std::vector<bool> nodeDone(noNodes, false);
  // Apply the sort heuristic before searching.
  for(unsigned int i = 0; i < children.size(); ++i)
    sortChildren(i, nodeDone, children);
  Timer t_find;
  ftm::idNode startNode = 0;
  findStructGivenDim(startNode, noDim, true, nodeDone, dimFound, childrenFinal);

  // ----- Greedy approach to create non found structures
  // For each dimension the search above could not find, synthesize a
  // structure by combining two structures of dimension-1 lower; the less
  // persistent one is re-parented under the more persistent one, and the
  // scalars of the moved subtree are adjusted to stay nested.
  const auto createStructGivenDim =
    [this, &children, &noNodes, &findStructGivenDim, &nodeLevels](
      int _nodeOrigin, int _dimToCreate, std::vector<bool> &_nodeDone,
      ftm::idNode &_structOrigin, std::vector<float> &_scalarsVectorOut,
      std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
      // --- Recursive lambda
      auto createStructGivenDimImpl =
        [this, &children, &noNodes, &findStructGivenDim, &nodeLevels](
          int nodeOrigin, int dimToCreate, std::vector<bool> &nodeDoneImpl,
          ftm::idNode &structOrigin, std::vector<float> &scalarsVectorOut,
          std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
          auto &createStructGivenDimRef) mutable {
          // Deduction of auto lambda type
          if(false)
            return;
          // - Find structures of lower dimension
          // Obtain two dimension-(dimToCreate-1) structures: first try to
          // find them among the free children, otherwise create them
          // recursively. max() of idNode marks "no structure found".
          int dimToFound = dimToCreate - 1;
          std::vector<std::vector<std::vector<ftm::idNode>>> childrenFinalT(2);
          std::array<ftm::idNode, 2> structOrigins;
          for(unsigned int n = 0; n < 2; ++n) {
            bool found = false;
            for(unsigned int i = 0; i < children[nodeOrigin].size(); ++i) {
              auto child = children[nodeOrigin][i];
              if(nodeDoneImpl[child])
                continue;
              if(dimToFound != 0) {
                if(nodeLevels[child] > dimToFound) {
                  std::vector<bool> dimFoundT(dimToFound, false);
                  childrenFinalT[n].clear();
                  childrenFinalT[n].resize(noNodes);
                  std::vector<bool> nodeDoneImplFind = nodeDoneImpl;
                  found = findStructGivenDim(child, dimToFound, false,
                                             nodeDoneImplFind, dimFoundT,
                                             childrenFinalT[n]);
                }
              } else
                found = true;
              if(found) {
                // Commit the found structure and mark its nodes as used.
                structOrigins[n] = child;
                nodeDoneImpl[child] = true;
                for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
                  for(auto &e : childrenFinalT[n][j]) {
                    childrenFinalOut[j].emplace_back(e);
                    nodeDoneImpl[e] = true;
                  }
                }
                break;
              }
            } // end for children[nodeOrigin]
            if(not found) {
              if(dimToFound <= 0) {
                // Nothing to build from: record "no structure".
                structOrigins[n] = std::numeric_limits<ftm::idNode>::max();
                continue;
              }
              // Recursively create the missing lower-dimension structure.
              childrenFinalT[n].clear();
              childrenFinalT[n].resize(noNodes);
              createStructGivenDimRef(
                nodeOrigin, dimToFound, nodeDoneImpl, structOrigins[n],
                scalarsVectorOut, childrenFinalT[n], createStructGivenDimRef);
              for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
                for(auto &e : childrenFinalT[n][j]) {
                  if(e == structOrigins[n])
                    continue;
                  childrenFinalOut[j].emplace_back(e);
                }
              }
            }
          } // end for n
          // - Combine both structures
          if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max()
             and structOrigins[1] == std::numeric_limits<ftm::idNode>::max()) {
            structOrigin = std::numeric_limits<ftm::idNode>::max();
            return;
          }
          // The structure whose origin has the larger persistence becomes
          // the parent; the other is re-parented under it.
          bool firstIsParent = true;
          if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max())
            firstIsParent = false;
          else if(structOrigins[1] == std::numeric_limits<ftm::idNode>::max())
            firstIsParent = true;
          else if(scalarsVectorOut[structOrigins[1] * 2 + 1]
                    - scalarsVectorOut[structOrigins[1] * 2]
                  > scalarsVectorOut[structOrigins[0] * 2 + 1]
                      - scalarsVectorOut[structOrigins[0] * 2])
            firstIsParent = false;
          structOrigin = (firstIsParent ? structOrigins[0] : structOrigins[1]);
          ftm::idNode modOrigin
            = (firstIsParent ? structOrigins[1] : structOrigins[0]);
          childrenFinalOut[nodeOrigin].emplace_back(structOrigin);
          if(modOrigin != std::numeric_limits<ftm::idNode>::max()) {
            childrenFinalOut[structOrigin].emplace_back(modOrigin);
            // BFS over the moved subtree to fix the nesting of its scalars.
            std::queue<std::array<ftm::idNode, 2>> queue;
            queue.emplace(std::array<ftm::idNode, 2>{modOrigin, structOrigin});
            while(!queue.empty()) {
              auto &nodeAndParent = queue.front();
              ftm::idNode node = nodeAndParent[0];
              ftm::idNode parent = nodeAndParent[1];
              queue.pop();
              adjustNestingScalars(scalarsVectorOut, node, parent);
              // Push children
              for(auto &child : childrenFinalOut[node])
                queue.emplace(std::array<ftm::idNode, 2>{child, node});
            }
          }
          return;
        };
      return createStructGivenDimImpl(
        _nodeOrigin, _dimToCreate, _nodeDone, _structOrigin, _scalarsVectorOut,
        _childrenFinalOut, createStructGivenDimImpl);
    };
  // Re-apply the sort heuristic with the updated done-flags.
  for(unsigned int i = 0; i < children.size(); ++i)
    sortChildren(i, nodeDone, children);
  Timer t_create;
  // Create every dimension the search phase failed to find.
  for(unsigned int i = 0; i < dimFound.size(); ++i) {
    if(dimFound[i])
      continue;
    ftm::idNode structOrigin;
    createStructGivenDim(
      startNode, i, nodeDone, structOrigin, scalarsVector, childrenFinal);
  }
}
1298
1299// ---------------------------------------------------------------------------
1300// --- Testing
1301// ---------------------------------------------------------------------------
1302bool ttk::MergeTreeNeuralLayer::isTreeHasBigValues(ftm::MergeTree<float> &mTree,
1303 float threshold) {
1304 bool found = false;
1305 for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) {
1306 if(mTree.tree.isNodeAlone(n))
1307 continue;
1308 auto birthDeath = mTree.tree.template getBirthDeath<float>(n);
1309 if(std::abs(std::get<0>(birthDeath)) > threshold
1310 or std::abs(std::get<1>(birthDeath)) > threshold) {
1311 found = true;
1312 break;
1313 }
1314 }
1315 return found;
1316}
1317#endif
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define TTK_PSORT(NTHREADS,...)
Parallel sort macro.
Definition OpenMP.h:46
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
int initVectors(int axeNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2, std::vector< std::vector< double > > &trees2V1, std::vector< std::vector< double > > &trees2V2, int newVectorOffset, std::vector< double > &inputToOriginDistances, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings2, std::vector< std::vector< double > > &inputToAxesDistances, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, bool projectInitializedVectors, F initializedVectorsProjection)
int scaleVector(const T *a, const T factor, T *out, const int &dimension=3)
Definition Geometry.cpp:647
T dotProduct(const T *vA0, const T *vA1, const T *vB0, const T *vB1)
Definition Geometry.cpp:388
int flattenMultiDimensionalVector(const std::vector< std::vector< T > > &a, std::vector< T > &out)
Definition Geometry.cpp:742
T magnitude(const T *v, const int &dimension=3)
Definition Geometry.cpp:509
T distance(const T *p0, const T *p1, const int &dimension=3)
Definition Geometry.cpp:362
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
MergeTree< dataType > copyMergeTree(const ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs=false)
MergeTree< dataType > createEmptyMergeTree(int scalarSize)
unsigned int idNode
Node index in vect_nodes_.
T end(std::pair< T, T > &p)
Definition ripser.cpp:503
T begin(std::pair< T, T > &p)
Definition ripser.cpp:499
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)