TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoder.cpp
Go to the documentation of this file.
3#include <cmath>
4
5#ifdef TTK_ENABLE_TORCH
6using namespace torch::indexing;
7#endif
8
10 // inherited from Debug: prefix will be printed at the beginning of every msg
11 this->setDebugMsgPrefix("MergeTreeAutoencoder");
12}
13
14#ifdef TTK_ENABLE_TORCH
15// ---------------------------------------------------------------------------
16// --- Init
17// ---------------------------------------------------------------------------
// Builds the merge-tree structure (node pairing and parent/child arcs) for the
// output-basis origin `originPrime`, whose scalars come from its flat tensor of
// (birth, death) pairs. `baseOrigin` is the reference tree used when the
// structure is initialized by copy; `isJT` requests join-tree scalar ordering.
// NOTE(review): this listing was extracted from generated documentation and a
// few original source lines are missing (see inline notes); the code is kept
// verbatim.
18void ttk::MergeTreeAutoencoder::initOutputBasisTreeStructure(
 19 mtu::TorchMergeTree<float> &originPrime,
 20 bool isJT,
 21 mtu::TorchMergeTree<float> &baseOrigin) {
 22 // ----- Create scalars vector
 // Copy the tensor values into a flat host-side vector; entries come in
 // (birth, death) pairs, so there are numel()/2 pairs/nodes.
 23 std::vector<float> scalarsVector(
 24 originPrime.tensor.data_ptr<float>(),
 25 originPrime.tensor.data_ptr<float>() + originPrime.tensor.numel());
 26 unsigned int noNodes = scalarsVector.size() / 2;
 27 std::vector<std::vector<ftm::idNode>> childrenFinal(noNodes);
 28
 29 // ----- Init tree structure and modify scalars if necessary
 30 if(isPersistenceDiagram_) {
 // Persistence-diagram case: every pair is attached as a child of pair 0.
 31 for(unsigned int i = 2; i < scalarsVector.size(); i += 2)
 32 childrenFinal[0].emplace_back(i / 2);
 33 } else {
 34 // --- Fix or swap min-max pair
 // Find the pair with maximal persistence (death - birth) ...
 35 float maxPers = std::numeric_limits<float>::lowest();
 36 unsigned int indMax = 0;
 37 for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
 38 if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) {
 39 maxPers = (scalarsVector[i + 1] - scalarsVector[i]);
 40 indMax = i;
 41 }
 42 }
 // ... and swap it into slot 0 so the first pair is the root (min-max) pair.
 43 if(indMax != 0) {
 44 float temp = scalarsVector[0];
 45 scalarsVector[0] = scalarsVector[indMax];
 46 scalarsVector[indMax] = temp;
 47 temp = scalarsVector[1];
 48 scalarsVector[1] = scalarsVector[indMax + 1];
 49 scalarsVector[indMax + 1] = temp;
 50 }
 // Adjust every other pair's scalars w.r.t. the root pair (node 0);
 // wae::adjustNestingScalars presumably clamps the pair inside the
 // reference pair's range — TODO confirm against its implementation.
 51 ftm::idNode refNode = 0;
 52 for(unsigned int i = 2; i < scalarsVector.size(); i += 2) {
 53 ftm::idNode node = i / 2;
 54 wae::adjustNestingScalars(scalarsVector, node, refNode);
 55 }
 56
 // Either compute the structure from scratch, or copy it from `baseOrigin`
 // (only possible when the base tree has at least `noNodes` real nodes).
 57 if(not initOriginPrimeStructByCopy_
 58 or (int) noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) {
 59 // --- Get possible children and parent relations
 // Pair i can be the parent of pair j iff i's interval nests j's interval.
 60 std::vector<std::vector<ftm::idNode>> parents(noNodes), children(noNodes);
 61 for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
 62 for(unsigned int j = i; j < scalarsVector.size(); j += 2) {
 63 if(i == j)
 64 continue;
 65 unsigned int iN = i / 2, jN = j / 2;
 66 if(scalarsVector[i] <= scalarsVector[j]
 67 and scalarsVector[i + 1] >= scalarsVector[j + 1]) {
 68 // - i is parent of j
 69 parents[jN].emplace_back(iN);
 70 children[iN].emplace_back(jN);
 71 } else if(scalarsVector[i] >= scalarsVector[j]
 72 and scalarsVector[i + 1] <= scalarsVector[j + 1]) {
 73 // - j is parent of i
 74 parents[iN].emplace_back(jN);
 75 children[jN].emplace_back(iN);
 76 }
 77 }
 78 }
 // NOTE(review): the original line 79 (the name of the call whose argument
 // list follows) is missing from this extraction — verify against the TTK
 // repository before editing this statement.
 80 parents, children, scalarsVector, childrenFinal, this->threadNumber_);
 81 } else {
 // Copy branch: reuse the base tree's structure restricted to its
 // `noNodes` most important pairs.
 82 ftm::MergeTree<float> mTreeTemp
 83 = ftm::copyMergeTree<float>(baseOrigin.mTree);
 84 bool useBD = true;
 85 keepMostImportantPairs<float>(&(mTreeTemp.tree), noNodes, useBD);
 // Rank the new pairs by persistence (descending) ...
 86 torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2});
 87 torch::Tensor order = torch::argsort(
 88 (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1,
 89 true);
 // ... then walk the base tree breadth-first, mapping each visited node to
 // the next pair in persistence order and replicating the arcs.
 90 std::vector<unsigned int> nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0);
 91 unsigned int nodeNum = 1;
 92 std::queue<ftm::idNode> queue;
 93 queue.emplace(mTreeTemp.tree.getRoot());
 94 while(!queue.empty()) {
 95 ftm::idNode node = queue.front();
 96 queue.pop();
 97 std::vector<ftm::idNode> children;
 98 mTreeTemp.tree.getChildren(node, children);
 99 for(auto &child : children) {
 100 queue.emplace(child);
 101 unsigned int tNode = nodeCorr[node];
 102 nodeCorr[child] = order[nodeNum].item<int>();
 103 ++nodeNum;
 104 unsigned int tChild = nodeCorr[child];
 105 childrenFinal[tNode].emplace_back(tChild);
 106 wae::adjustNestingScalars(scalarsVector, tChild, tNode);
 107 }
 108 }
 109 }
 110 }
 111
 112 // ----- Create new tree
 113 originPrime.mTree = ftm::createEmptyMergeTree<float>(scalarsVector.size());
 114 ftm::FTMTree_MT *tree = &(originPrime.mTree.tree);
 // Join trees store pairs in (death, birth) order: swap each pair in place.
 115 if(isJT) {
 116 for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
 117 float temp = scalarsVector[i];
 118 scalarsVector[i] = scalarsVector[i + 1];
 119 scalarsVector[i + 1] = temp;
 120 }
 121 }
 122 ftm::setTreeScalars<float>(originPrime.mTree, scalarsVector);
 123
 124 // ----- Create tree structure
 // Create two nodes per pair, link them as each other's origin, and record
 // the node correspondence (pair index) for even vertices only.
 125 originPrime.nodeCorr.clear();
 126 originPrime.nodeCorr.assign(
 127 scalarsVector.size(), std::numeric_limits<unsigned int>::max());
 128 for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
 129 tree->makeNode(i);
 130 tree->makeNode(i + 1);
 131 tree->getNode(i)->setOrigin(i + 1);
 132 tree->getNode(i + 1)->setOrigin(i);
 133 originPrime.nodeCorr[i] = (unsigned int)(i / 2);
 134 }
 // Materialize the parent/child relations computed above as superarcs.
 135 for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
 136 unsigned int node = i / 2;
 137 for(auto &child : childrenFinal[node])
 138 tree->makeSuperArc(child * 2, i);
 139 }
 140 mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri);
 141
 // Debug guard: on suspiciously large values, print the pairs and block on
 // stdin (std::cin.get()) so the run can be inspected.
 142 if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) {
 143 std::stringstream ss;
 144 ss << originPrime.mTree.tree.printPairsFromTree<float>(true).str()
 145 << std::endl;
 146 ss << "isTreeHasBigValues(originPrime.mTree)" << std::endl;
 147 ss << "pause" << std::endl;
 148 printMsg(ss.str());
 149 std::cin.get();
 150 }
 151}
152
// Initializes the output basis of layer `l`: the origin trees are produced by
// a random affine map of the layer's input-basis origins, and the basis
// vectors are then derived from the same maps (w, w2). `dim`/`dim2` are the
// output dimensions (tensor row counts) for the first/second input.
153void ttk::MergeTreeAutoencoder::initOutputBasis(unsigned int l,
 154 unsigned int dim,
 155 unsigned int dim2) {
 156 unsigned int originSize = origins_[l].tensor.sizes()[0];
 157 unsigned int origin2Size = 0;
 158 if(useDoubleInput_)
 159 origin2Size = origins2_[l].tensor.sizes()[0];
 160
 161 // --- Compute output basis origin
 162 printMsg("Compute output basis origin", debug::Priority::DETAIL);
 // Shared helper: fills `tmt` (the output origin) from `baseTmt` (the input
 // origin) through the randomly initialized linear map `w`.
 163 auto initOutputBasisOrigin = [this, &l](torch::Tensor &w,
 164 mtu::TorchMergeTree<float> &tmt,
 165 mtu::TorchMergeTree<float> &baseTmt) {
 166 // - Create scalars
 167 torch::nn::init::xavier_normal_(w);
 168 torch::Tensor baseTmtTensor = baseTmt.tensor;
 169 if(normalizedWasserstein_)
 170 // Work on unnormalized tensor
 171 mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor, false);
 // Affine map with a small constant bias (0.01).
 172 torch::Tensor b = torch::fill(torch::zeros({w.sizes()[0], 1}), 0.01);
 173 tmt.tensor = (torch::matmul(w, baseTmtTensor) + b);
 174 // - Shift to keep mean birth and max pers
 175 mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor);
 176 // - Shift to avoid diagonal points
 177 mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor);
 178 //
 // Tracking-loss init: overwrite the most persistent points of the new
 // origin with the corresponding top-persistence points of the previous
 // layer's origin (or the input origin for l == 0). Only applied up to the
 // latent layer unless tracking is also enforced while decoding.
 179 auto endLayer
 180 = (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1);
 181 if(trackingLossWeight_ != 0 and l < endLayer) {
 182 auto baseTensor
 183 = (l == 0 ? origins_[0].tensor : originsPrime_[l - 1].tensor);
 184 auto baseTensorDiag = baseTensor.reshape({-1, 2});
 185 auto basePersDiag = (baseTensorDiag.index({Slice(), 1})
 186 - baseTensorDiag.index({Slice(), 0}));
 187 auto tmtTensorDiag = tmt.tensor.reshape({-1, 2});
 // Persistence of all points except the first (index 0 is forced into
 // `indexes` below, so it is excluded from the topk ranking here).
 188 auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1})
 189 - tmtTensorDiag.index({Slice(1, None), 0}));
 190 int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]);
 191 auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))});
 // +1 compensates the Slice(1, None) offset above.
 192 auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1;
 193 indexes = torch::cat({torch::zeros(1), indexes}).to(torch::kLong);
 // Optionally blend the copied points with the randomly generated ones.
 194 if(trackingLossInitRandomness_ != 0) {
 195 topVal = (1 - trackingLossInitRandomness_) * topVal
 196 + trackingLossInitRandomness_ * tmtTensorDiag.index({indexes});
 197 }
 198 tmtTensorDiag.index_put_({indexes}, topVal);
 199 }
 200 // - Create tree structure
 201 initOutputBasisTreeStructure(
 202 tmt, baseTmt.mTree.tree.isJoinTree<float>(), baseTmt);
 203 if(normalizedWasserstein_)
 204 // Normalize tensor
 205 mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor, true);
 206 // - Projection
 207 interpolationProjection(tmt);
 208 };
 // Apply the helper to the first input, and to the second input if any.
 209 torch::Tensor w = torch::zeros({dim, originSize});
 210 initOutputBasisOrigin(w, originsPrime_[l], origins_[l]);
 211 torch::Tensor w2;
 212 if(useDoubleInput_) {
 213 w2 = torch::zeros({dim2, origin2Size});
 214 initOutputBasisOrigin(w2, origins2Prime_[l], origins2_[l]);
 215 }
 216
 217 // --- Compute output basis vectors
 218 printMsg("Compute output basis vectors", debug::Priority::DETAIL);
 // Reuse the same linear maps to transport the input basis vectors.
 219 initOutputBasisVectors(l, w, w2);
 220}
221
222void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l,
223 torch::Tensor &w,
224 torch::Tensor &w2) {
225 vSPrimeTensor_[l] = torch::matmul(w, vSTensor_[l]);
226 if(useDoubleInput_)
227 vS2PrimeTensor_[l] = torch::matmul(w2, vS2Tensor_[l]);
228 if(normalizedWasserstein_) {
229 mtu::normalizeVectors(originsPrime_[l].tensor, vSPrimeTensor_[l]);
230 if(useDoubleInput_)
231 mtu::normalizeVectors(origins2Prime_[l].tensor, vS2PrimeTensor_[l]);
232 }
233}
234
235void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l,
236 unsigned int dim,
237 unsigned int dim2) {
238 unsigned int originSize = origins_[l].tensor.sizes()[0];
239 unsigned int origin2Size = 0;
240 if(useDoubleInput_)
241 origin2Size = origins2_[l].tensor.sizes()[0];
242 torch::Tensor w = torch::zeros({dim, originSize});
243 torch::nn::init::xavier_normal_(w);
244 torch::Tensor w2 = torch::zeros({dim2, origin2Size});
245 torch::nn::init::xavier_normal_(w2);
246 initOutputBasisVectors(l, w, w2);
247}
248
// Computes the input-basis origin(s) of a layer as the barycenter of the given
// trees (and of the second trees when double input is enabled), optionally
// truncated to a maximum number of pairs, and converts the results to their
// tensor representation. Also outputs, per input tree, its distance to the
// barycenter and the barycenter matchings.
249void ttk::MergeTreeAutoencoder::initInputBasisOrigin(
 250 std::vector<ftm::MergeTree<float>> &treesToUse,
 251 std::vector<ftm::MergeTree<float>> &trees2ToUse,
 252 double barycenterSizeLimitPercent,
 253 unsigned int barycenterMaxNoPairs,
 254 unsigned int barycenterMaxNoPairs2,
 255 mtu::TorchMergeTree<float> &origin,
 256 mtu::TorchMergeTree<float> &origin2,
 257 std::vector<double> &inputToBaryDistances,
 258 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 259 &baryMatchings,
 260 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 261 &baryMatchings2) {
 // Barycenter of the first input; fills matchings and distances as a side
 // effect.
 262 computeOneBarycenter<float>(treesToUse, origin.mTree, baryMatchings,
 263 inputToBaryDistances, barycenterSizeLimitPercent,
 264 useDoubleInput_);
 // Optionally truncate the barycenter to its most persistent pairs.
 265 if(barycenterMaxNoPairs > 0)
 266 keepMostImportantPairs<float>(
 267 &(origin.mTree.tree), barycenterMaxNoPairs, true);
 268 if(useDoubleInput_) {
 // Same computation for the second input ...
 269 std::vector<double> baryDistances2;
 270 computeOneBarycenter<float>(trees2ToUse, origin2.mTree, baryMatchings2,
 271 baryDistances2, barycenterSizeLimitPercent,
 272 useDoubleInput_, false);
 273 if(barycenterMaxNoPairs2 > 0)
 274 keepMostImportantPairs<float>(
 275 &(origin2.mTree.tree), barycenterMaxNoPairs2, true);
 // ... then combine the two per-tree distances into one.
 276 for(unsigned int i = 0; i < inputToBaryDistances.size(); ++i)
 277 inputToBaryDistances[i]
 278 = mixDistances(inputToBaryDistances[i], baryDistances2[i]);
 279 }
 280
 // Cache parent relations and build the (optionally normalized) tensor
 // representation of each origin.
 281 mtu::getParentsVector(origin.mTree, origin.parentsOri);
 282 mtu::mergeTreeToTorchTensor<float>(
 283 origin.mTree, origin.tensor, origin.nodeCorr, normalizedWasserstein_);
 284 if(useDoubleInput_) {
 285 mtu::getParentsVector(origin2.mTree, origin2.parentsOri);
 286 mtu::mergeTreeToTorchTensor<float>(
 287 origin2.mTree, origin2.tensor, origin2.nodeCorr, normalizedWasserstein_);
 288 }
 289}
290
// Initializes `noVectors` input-basis vectors for layer `l` around `origin`
// (and `origin2` for the second input), filling the tensors `vSTensor` /
// `vS2Tensor` and the per-input coefficients `allAlphasInit[i][l]`.
// NOTE(review): this listing was extracted from generated documentation and a
// few original source lines are missing (see inline notes); the code is kept
// verbatim.
291void ttk::MergeTreeAutoencoder::initInputBasisVectors(
 292 std::vector<mtu::TorchMergeTree<float>> &tmTreesToUse,
 293 std::vector<mtu::TorchMergeTree<float>> &tmTrees2ToUse,
 294 std::vector<ftm::MergeTree<float>> &treesToUse,
 295 std::vector<ftm::MergeTree<float>> &trees2ToUse,
 296 mtu::TorchMergeTree<float> &origin,
 297 mtu::TorchMergeTree<float> &origin2,
 298 unsigned int noVectors,
 299 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
 300 unsigned int l,
 301 std::vector<double> &inputToBaryDistances,
 302 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 303 &baryMatchings,
 304 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 305 &baryMatchings2,
 306 torch::Tensor &vSTensor,
 307 torch::Tensor &vS2Tensor) {
 308 // --- Initialized vectors projection function to avoid collinearity
 // Callback handed to initVectors: zeroes out a candidate vector when it is
 // (nearly) collinear with an already-accepted one, forcing re-initialization.
 309 auto initializedVectorsProjection
 310 = [=](int ttkNotUsed(_axeNumber),
 311 ftm::MergeTree<float> &ttkNotUsed(_barycenter),
 312 std::vector<std::vector<double>> &_v,
 313 std::vector<std::vector<double>> &ttkNotUsed(_v2),
 314 std::vector<std::vector<std::vector<double>>> &_vS,
 315 std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_v2s),
 316 ftm::MergeTree<float> &ttkNotUsed(_barycenter2),
 317 std::vector<std::vector<double>> &ttkNotUsed(_trees2V),
 318 std::vector<std::vector<double>> &ttkNotUsed(_trees2V2),
 319 std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2Vs),
 320 std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2V2s),
 321 bool ttkNotUsed(_useSecondInput),
 322 unsigned int ttkNotUsed(_noProjectionStep)) {
 323 std::vector<double> scaledV, scaledVSi;
 // NOTE(review): original lines 324-325 (the statements flattening _v into
 // scaledV and opening this scale call) are missing from this extraction.
 326 scaledV, 1.0 / Geometry::magnitude(scaledV), scaledV);
 327 for(unsigned int i = 0; i < _vS.size(); ++i) {
 // NOTE(review): original lines 328-329 (flattening _vS[i] into scaledVSi
 // and opening this scale call) are missing from this extraction.
 330 scaledVSi, 1.0 / Geometry::magnitude(scaledVSi), scaledVSi);
 // Dot product of two unit vectors ~ cosine; near +/-1 means collinear.
 331 auto prod = Geometry::dotProduct(scaledV, scaledVSi);
 332 double tol = 0.01;
 333 if(prod <= -1.0 + tol or prod >= 1.0 - tol) {
 334 // Reset vector to initialize it again
 335 for(unsigned int j = 0; j < _v.size(); ++j)
 336 for(unsigned int k = 0; k < _v[j].size(); ++k)
 337 _v[j][k] = 0;
 338 break;
 339 }
 340 }
 341 return 0;
 342 };
 343
 344 // --- Init vectors
 // Vectors are created one at a time; after each new vector the per-input
 // coefficients and distances to the spanned axes are refreshed.
 345 std::vector<std::vector<double>> inputToAxesDistances;
 346 std::vector<std::vector<std::vector<double>>> vS, v2s, trees2Vs, trees2V2s;
 347 std::stringstream ss;
 348 for(unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) {
 349 ss.str("");
 350 ss << "Compute vectors " << vecNum;
 // NOTE(review): original line 351 (presumably a printMsg of ss) is missing
 // from this extraction.
 352 std::vector<std::vector<double>> v1, v2, trees2V1, trees2V2;
 353 int newVectorOffset = 0;
 354 bool projectInitializedVectors = true;
 355 int bestIndex = MergeTreeAxesAlgorithmBase::initVectors<float>(
 356 vecNum, origin.mTree, treesToUse, origin2.mTree, trees2ToUse, v1, v2,
 357 trees2V1, trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings,
 358 baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s,
 359 projectInitializedVectors, initializedVectorsProjection);
 360 vS.emplace_back(v1);
 361 v2s.emplace_back(v2);
 362 trees2Vs.emplace_back(trees2V1);
 363 trees2V2s.emplace_back(trees2V2);
 364
 365 ss.str("");
 366 ss << "bestIndex = " << bestIndex;
 // NOTE(review): original line 367 (presumably a printMsg of ss) is missing
 // from this extraction.
 368
 369 // Update inputToAxesDistances
 370 printMsg("Update inputToAxesDistances", debug::Priority::VERBOSE);
 371 inputToAxesDistances.resize(1, std::vector<double>(treesToUse.size()));
 // A bestIndex of -1 marks a freshly generated vector: normalize it when
 // normalized Wasserstein is enabled.
 372 if(bestIndex == -1 and normalizedWasserstein_) {
 373 mtu::normalizeVectors(origin, vS[vS.size() - 1]);
 374 if(useDoubleInput_)
 375 mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]);
 376 }
 377 mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor);
 378 if(useDoubleInput_) {
 379 mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor);
 380 }
 // Placeholders used when there is no second input.
 381 mtu::TorchMergeTree<float> dummyTmt;
 382 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
 383 dummyBaryMatching2;
 384#ifdef TTK_ENABLE_OPENMP
 385#pragma omp parallel for schedule(dynamic) \
 386 num_threads(this->threadNumber_) if(parallelize_)
 387#endif
 388 for(unsigned int i = 0; i < treesToUse.size(); ++i) {
 389 auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2ToUse[i]);
 390 if(not euclideanVectorsInit_) {
 // Assignment-based init: extend the coefficient vector with a new entry
 // (1 for a reused vector, 0 for a fresh one) and optimize it.
 391 unsigned int k = k_;
 392 auto newAlpha = torch::ones({1, 1});
 393 if(bestIndex == -1) {
 394 newAlpha = torch::zeros({1, 1});
 395 }
 396 allAlphasInit[i][l] = (allAlphasInit[i][l].defined()
 397 ? torch::cat({allAlphasInit[i][l], newAlpha})
 398 : newAlpha);
 399 torch::Tensor bestAlphas;
 400 bool isCalled = true;
 401 inputToAxesDistances[0][i] = assignmentOneData(
 402 tmTreesToUse[i], origin, vSTensor, tmt2ToUse, origin2, vS2Tensor, k,
 403 allAlphasInit[i][l], bestAlphas, isCalled);
 404 allAlphasInit[i][l] = bestAlphas.detach();
 405 } else {
 // Euclidean init: solve for the coefficients directly, then measure the
 // differentiable distance between the interpolation and the input.
 406 auto &baryMatching2ToUse
 407 = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]);
 408 torch::Tensor alphas;
 409 computeAlphas(tmTreesToUse[i], origin, vSTensor, origin,
 410 baryMatchings[i], tmt2ToUse, origin2, vS2Tensor, origin2,
 411 baryMatching2ToUse, alphas);
 412 mtu::TorchMergeTree<float> interpolated, interpolated2;
 413 getMultiInterpolation(origin, vSTensor, alphas, interpolated);
 414 if(useDoubleInput_)
 415 getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
 416 torch::Tensor tensorDist;
 417 bool doSqrt = true;
 418 getDifferentiableDistanceFromMatchings(
 419 interpolated, tmTreesToUse[i], interpolated2, tmt2ToUse,
 420 baryMatchings[i], baryMatching2ToUse, tensorDist, doSqrt);
 421 inputToAxesDistances[0][i] = tensorDist.item<double>();
 422 allAlphasInit[i][l] = alphas.detach();
 423 }
 424 }
 425 }
 426}
427
428void ttk::MergeTreeAutoencoder::initClusteringLossParameters() {
429 unsigned int l = getLatentLayerIndex();
430 unsigned int noCentroids
431 = std::set<unsigned int>(clusterAsgn_.begin(), clusterAsgn_.end()).size();
432 latentCentroids_.resize(noCentroids);
433 for(unsigned int c = 0; c < noCentroids; ++c) {
434 unsigned int firstIndex = std::numeric_limits<unsigned int>::max();
435 for(unsigned int i = 0; i < clusterAsgn_.size(); ++i) {
436 if(clusterAsgn_[i] == c) {
437 firstIndex = i;
438 break;
439 }
440 }
441 if(firstIndex >= allAlphas_.size()) {
442 printWrn("no data found for cluster " + std::to_string(c));
443 // TODO init random centroid
444 }
445 latentCentroids_[c] = allAlphas_[firstIndex][l].detach().clone();
446 float noData = 1;
447 for(unsigned int i = 0; i < allAlphas_.size(); ++i) {
448 if(i == firstIndex)
449 continue;
450 if(clusterAsgn_[i] == c) {
451 latentCentroids_[c] += allAlphas_[i][l];
452 ++noData;
453 }
454 }
455 latentCentroids_[c] /= torch::tensor(noData);
456 latentCentroids_[c] = latentCentroids_[c].detach();
457 latentCentroids_[c].requires_grad_(true);
458 }
459}
460
// Initializes every network parameter (input/output origins and basis vectors
// of each layer) from the input trees, and optionally runs one forward pass to
// return the initial weighted loss. Returns FLT_MAX when the initialization
// had to be aborted (degenerate reconstruction or forward reset).
// NOTE(review): this listing was extracted from generated documentation and a
// few original source lines are missing (see inline notes); the code is kept
// verbatim.
461float ttk::MergeTreeAutoencoder::initParameters(
 462 std::vector<mtu::TorchMergeTree<float>> &trees,
 463 std::vector<mtu::TorchMergeTree<float>> &trees2,
 464 bool computeReconstructionError) {
 465 // ----- Init variables
 466 // noLayers_ = number of encoder layers + number of decoder layers + the
 467 // latent layer + the output layer
 468 noLayers_ = encoderNoLayers_ * 2 + 1 + 1;
 469 if(encoderNoLayers_ <= -1)
 470 noLayers_ = 1;
 // Per-layer geometry: number of axes and origin size, linearly interpolated
 // between the input-layer and latent-layer settings for deep models.
 471 std::vector<double> layersOriginPrimeSizePercent(noLayers_);
 472 std::vector<unsigned int> layersNoAxes(noLayers_);
 473 if(noLayers_ <= 2) {
 474 layersNoAxes[0] = numberOfAxes_;
 475 layersOriginPrimeSizePercent[0] = latentSpaceOriginPrimeSizePercent_;
 476 if(noLayers_ == 2) {
 477 layersNoAxes[1] = inputNumberOfAxes_;
 478 layersOriginPrimeSizePercent[1] = barycenterSizeLimitPercent_;
 479 }
 480 } else {
 481 for(unsigned int l = 0; l < noLayers_ / 2; ++l) {
 482 double alpha = (double)(l) / (noLayers_ / 2 - 1);
 483 unsigned int noAxes
 484 = (1 - alpha) * inputNumberOfAxes_ + alpha * numberOfAxes_;
 485 layersNoAxes[l] = noAxes;
 486 layersNoAxes[noLayers_ - 1 - l] = noAxes;
 487 double originPrimeSizePercent
 488 = (1 - alpha) * inputOriginPrimeSizePercent_
 489 + alpha * latentSpaceOriginPrimeSizePercent_;
 490 layersOriginPrimeSizePercent[l] = originPrimeSizePercent;
 491 layersOriginPrimeSizePercent[noLayers_ - 1 - l] = originPrimeSizePercent;
 492 }
 493 if(scaleLayerAfterLatent_)
 494 layersNoAxes[noLayers_ / 2]
 495 = (layersNoAxes[noLayers_ / 2 - 1] + layersNoAxes[noLayers_ / 2 + 1])
 496 / 2.0;
 497 }
 498
 // Collect raw FTM tree pointers to compute the size metrics used to derive
 // per-layer origin dimensions (getDim yields an even value >= 4).
 499 std::vector<ftm::FTMTree_MT *> ftmTrees(trees.size()),
 500 ftmTrees2(trees2.size());
 501 for(unsigned int i = 0; i < trees.size(); ++i)
 502 ftmTrees[i] = &(trees[i].mTree.tree);
 503 for(unsigned int i = 0; i < trees2.size(); ++i)
 504 ftmTrees2[i] = &(trees2[i].mTree.tree);
 505 auto sizeMetric = getSizeLimitMetric(ftmTrees);
 506 auto sizeMetric2 = getSizeLimitMetric(ftmTrees2);
 507 auto getDim = [](double _sizeMetric, double _percent) {
 508 unsigned int dim = std::max((int)(_sizeMetric * _percent / 100.0), 2) * 2;
 509 return dim;
 510 };
 511
 512 // ----- Resize parameters
 513 origins_.resize(noLayers_);
 514 originsPrime_.resize(noLayers_);
 515 vSTensor_.resize(noLayers_);
 516 vSPrimeTensor_.resize(noLayers_);
 517 if(trees2.size() != 0) {
 518 origins2_.resize(noLayers_);
 519 origins2Prime_.resize(noLayers_);
 520 vS2Tensor_.resize(noLayers_);
 521 vS2PrimeTensor_.resize(noLayers_);
 522 }
 523
 524 // ----- Compute parameters of each layer
 525 bool fullSymmetricAE = fullSymmetricAE_;
 526 bool outputBasisActivation = activateOutputInit_;
 527
 // recs/recs2 hold the reconstructions of layer l, which feed layer l+1.
 528 std::vector<mtu::TorchMergeTree<float>> recs, recs2;
 529 std::vector<std::vector<torch::Tensor>> allAlphasInit(
 530 trees.size(), std::vector<torch::Tensor>(noLayers_));
 531 for(unsigned int l = 0; l < noLayers_; ++l) {
 // NOTE(review): original line 532 is missing from this extraction.
 533 std::stringstream ss;
 534 ss << "Init Layer " << l;
 // NOTE(review): original line 535 (presumably a printMsg of ss) is missing
 // from this extraction.
 536
 537 // --- Init Input Basis
 // Encoder layers (and all layers when not fully symmetric) compute their
 // own input basis; decoder layers of a fully symmetric AE copy it below.
 538 if(l < (unsigned int)(noLayers_ / 2) or not fullSymmetricAE
 539 or (noLayers_ <= 2 and not fullSymmetricAE)) {
 540 // TODO is there a way to avoid copy of merge trees?
 541 std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
 542 for(unsigned int i = 0; i < trees.size(); ++i) {
 543 treesToUse.emplace_back((l == 0 ? trees[i].mTree : recs[i].mTree));
 544 if(trees2.size() != 0)
 545 trees2ToUse.emplace_back((l == 0 ? trees2[i].mTree : recs2[i].mTree));
 546 }
 547
 548 // - Compute origin
 549 printMsg("Compute origin...", debug::Priority::DETAIL);
 550 Timer t_origin;
 551 std::vector<double> inputToBaryDistances;
 552 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 553 baryMatchings, baryMatchings2;
 // Reuse cached layer-0 results when available (multiple init runs).
 554 if(l != 0 or not origins_[0].tensor.defined()) {
 555 double sizeLimit = (l == 0 ? barycenterSizeLimitPercent_ : 0);
 556 unsigned int maxNoPairs
 557 = (l == 0 ? 0 : originsPrime_[l - 1].tensor.sizes()[0] / 2);
 558 unsigned int maxNoPairs2
 559 = (l == 0 or not useDoubleInput_
 560 ? 0
 561 : origins2Prime_[l - 1].tensor.sizes()[0] / 2);
 562 initInputBasisOrigin(treesToUse, trees2ToUse, sizeLimit, maxNoPairs,
 563 maxNoPairs2, origins_[l], origins2_[l],
 564 inputToBaryDistances, baryMatchings,
 565 baryMatchings2);
 566 if(l == 0) {
 567 baryMatchings_L0_ = baryMatchings;
 568 baryMatchings2_L0_ = baryMatchings2;
 569 inputToBaryDistances_L0_ = inputToBaryDistances;
 570 }
 571 } else {
 572 baryMatchings = baryMatchings_L0_;
 573 baryMatchings2 = baryMatchings2_L0_;
 574 inputToBaryDistances = inputToBaryDistances_L0_;
 575 }
 // NOTE(review): original line 577 (the end of this printMsg call) is
 // missing from this extraction.
 576 printMsg("Compute origin time", 1, t_origin.getElapsedTime(),
 578
 579 // - Compute vectors
 580 printMsg("Compute vectors...", debug::Priority::DETAIL);
 581 Timer t_vectors;
 582 auto &tmTreesToUse = (l == 0 ? trees : recs);
 583 auto &tmTrees2ToUse = (l == 0 ? trees2 : recs2);
 584 initInputBasisVectors(
 585 tmTreesToUse, tmTrees2ToUse, treesToUse, trees2ToUse, origins_[l],
 586 origins2_[l], layersNoAxes[l], allAlphasInit, l, inputToBaryDistances,
 587 baryMatchings, baryMatchings2, vSTensor_[l], vS2Tensor_[l]);
 // NOTE(review): original line 589 (the end of this printMsg call) is
 // missing from this extraction.
 588 printMsg("Compute vectors time", 1, t_vectors.getElapsedTime(),
 590 } else {
 591 // - Copy output tensors of the opposite layer (full symmetric init)
 592 printMsg(
 593 "Copy output tensors of the opposite layer", debug::Priority::DETAIL);
 594 unsigned int middle = noLayers_ / 2;
 595 unsigned int l_opp = middle - (l - middle + 1);
 596 mtu::copyTorchMergeTree(originsPrime_[l_opp], origins_[l]);
 597 mtu::copyTensor(vSPrimeTensor_[l_opp], vSTensor_[l]);
 598 if(trees2.size() != 0) {
 599 if(fullSymmetricAE) {
 600 mtu::copyTorchMergeTree(origins2Prime_[l_opp], origins2_[l]);
 601 mtu::copyTensor(vS2PrimeTensor_[l_opp], vS2Tensor_[l]);
 602 }
 603 }
 604 for(unsigned int i = 0; i < trees.size(); ++i)
 605 allAlphasInit[i][l] = allAlphasInit[i][l_opp];
 606 }
 607
 608 // --- Init Output Basis
 // Special case (last layer of a 1- or 2-layer model): the output basis is
 // the layer-0 input basis (copied, or re-initialized if the axis count
 // differs).
 609 auto initOutputBasisSpecialCase
 610 = [this, &l, &layersNoAxes, &trees, &trees2]() {
 611 // - Compute Origin
 612 printMsg("Compute output basis origin", debug::Priority::DETAIL);
 613 mtu::copyTorchMergeTree(origins_[0], originsPrime_[l]);
 614 if(useDoubleInput_)
 615 mtu::copyTorchMergeTree(origins2_[0], origins2Prime_[l]);
 616 // - Compute vectors
 617 printMsg("Compute output basis vectors", debug::Priority::DETAIL);
 618 if(layersNoAxes[l] != layersNoAxes[0]) {
 619 // TODO is there a way to avoid copy of merge trees?
 620 std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
 621 for(unsigned int i = 0; i < trees.size(); ++i) {
 622 treesToUse.emplace_back(trees[i].mTree);
 623 if(useDoubleInput_)
 624 trees2ToUse.emplace_back(trees2[i].mTree);
 625 }
 626 std::vector<std::vector<torch::Tensor>> allAlphasInitT(
 627 trees.size(), std::vector<torch::Tensor>(noLayers_));
 628 initInputBasisVectors(
 629 trees, trees2, treesToUse, trees2ToUse, originsPrime_[l],
 630 origins2Prime_[l], layersNoAxes[l], allAlphasInitT, l,
 631 inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_,
 632 vSPrimeTensor_[l], vS2PrimeTensor_[l]);
 633 } else {
 634 mtu::copyTensor(vSTensor_[0], vSPrimeTensor_[l]);
 635 if(useDoubleInput_)
 636 mtu::copyTensor(vS2Tensor_[0], vS2PrimeTensor_[l]);
 637 }
 638 };
 639
 640 if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
 641 // -- Special case
 642 initOutputBasisSpecialCase();
 643 } else if(l < (unsigned int)(noLayers_ / 2)) {
 // Encoder layer: random output basis, capped by the input origin size.
 644 unsigned int dim = getDim(sizeMetric, layersOriginPrimeSizePercent[l]);
 645 dim = std::min(dim, (unsigned int)origins_[l].tensor.sizes()[0]);
 646 unsigned int dim2 = getDim(sizeMetric2, layersOriginPrimeSizePercent[l]);
 647 if(trees2.size() != 0)
 648 dim2 = std::min(dim2, (unsigned int)origins2_[l].tensor.sizes()[0]);
 649 initOutputBasis(l, dim, dim2);
 650 } else {
 651 // - Copy input tensors of the opposite layer (symmetric init)
 652 printMsg(
 653 "Copy input tensors of the opposite layer", debug::Priority::DETAIL);
 654 unsigned int middle = noLayers_ / 2;
 655 unsigned int l_opp = middle - (l - middle + 1);
 656 mtu::copyTorchMergeTree(origins_[l_opp], originsPrime_[l]);
 657 if(trees2.size() != 0)
 658 mtu::copyTorchMergeTree(origins2_[l_opp], origins2Prime_[l]);
 659 if(l == (unsigned int)(noLayers_) / 2 and scaleLayerAfterLatent_) {
 660 unsigned int dim2
 661 = (trees2.size() != 0 ? origins2Prime_[l].tensor.sizes()[0] : 0);
 662 initOutputBasisVectors(l, originsPrime_[l].tensor.sizes()[0], dim2);
 663 } else {
 664 mtu::copyTensor(vSTensor_[l_opp], vSPrimeTensor_[l]);
 665 if(trees2.size() != 0)
 666 mtu::copyTensor(vS2Tensor_[l_opp], vS2PrimeTensor_[l]);
 667 }
 668 }
 669
 670 // --- Get reconstructed
 // Reconstruct every input through this layer; an empty reconstruction
 // triggers a re-initialization of the output basis (restarting i at 0),
 // bounded by 100 resets.
 671 printMsg("Get reconstructed", debug::Priority::DETAIL);
 672 recs.resize(trees.size());
 673 recs2.resize(trees.size());
 674 unsigned int i = 0;
 675 unsigned int noReset = 0;
 676 while(i < trees.size()) {
 677 outputBasisReconstruction(originsPrime_[l], vSPrimeTensor_[l],
 678 origins2Prime_[l], vS2PrimeTensor_[l],
 679 allAlphasInit[i][l], recs[i], recs2[i],
 680 outputBasisActivation);
 681 if(recs[i].mTree.tree.getRealNumberOfNodes() == 0) {
 682 printMsg("Reset output basis", debug::Priority::DETAIL);
 683 if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
 684 initOutputBasisSpecialCase();
 685 } else if(l < (unsigned int)(noLayers_ / 2)) {
 686 initOutputBasis(l,
 687 getDim(sizeMetric, layersOriginPrimeSizePercent[l]),
 688 getDim(sizeMetric2, layersOriginPrimeSizePercent[l]));
 689 } else {
 690 printErr("recs[i].mTree.tree.getRealNumberOfNodes() == 0");
 691 std::stringstream ssT;
 692 ssT << "layer " << l;
 693 printWrn(ssT.str());
 694 return std::numeric_limits<float>::max();
 695 }
 696 i = 0;
 697 ++noReset;
 698 if(noReset >= 100) {
 699 printWrn("[initParameters] noReset >= 100");
 700 return std::numeric_limits<float>::max();
 701 }
 702 }
 703 ++i;
 704 }
 705 }
 706 allAlphas_ = allAlphasInit;
 707
 708 // Init clustering parameters if needed
 709 if(clusteringLossWeight_ != 0)
 710 initClusteringLossParameters();
 711
 712 // Compute error
 // Optional forward pass: sum the enabled, weighted loss terms.
 713 float error = 0.0, recLoss = 0.0;
 714 if(computeReconstructionError) {
 715 printMsg("Compute error", debug::Priority::DETAIL);
 716 std::vector<unsigned int> indexes(trees.size());
 717 std::iota(indexes.begin(), indexes.end(), 0);
 718 // TODO forward only if necessary
 719 unsigned int k = k_;
 720 std::vector<std::vector<torch::Tensor>> bestAlphas;
 721 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
 722 layersOuts2;
 723 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
 724 matchings, matchings2;
 725 bool reset
 726 = forwardStep(trees, trees2, indexes, k, allAlphasInit,
 727 computeReconstructionError, recs, recs2, bestAlphas,
 728 layersOuts, layersOuts2, matchings, matchings2, recLoss);
 729 if(reset) {
 730 printWrn("[initParameters] forwardStep reset");
 731 return std::numeric_limits<float>::max();
 732 }
 733 error = recLoss * reconstructionLossWeight_;
 734 if(metricLossWeight_ != 0) {
 735 torch::Tensor metricLoss;
 736 computeMetricLoss(layersOuts, layersOuts2, allAlphas_, distanceMatrix_,
 737 indexes, metricLoss);
 738 baseRecLoss_ = std::numeric_limits<double>::max();
 739 metricLoss *= metricLossWeight_
 740 * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
 741 error += metricLoss.item<float>();
 742 }
 743 if(clusteringLossWeight_ != 0) {
 744 torch::Tensor clusteringLoss, asgn;
 745 computeClusteringLoss(allAlphas_, indexes, clusteringLoss, asgn);
 746 baseRecLoss_ = std::numeric_limits<double>::max();
 747 clusteringLoss *= clusteringLossWeight_
 748 * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
 749 error += clusteringLoss.item<float>();
 750 }
 751 if(trackingLossWeight_ != 0) {
 752 torch::Tensor trackingLoss;
 753 computeTrackingLoss(trackingLoss);
 754 trackingLoss *= trackingLossWeight_;
 755 error += trackingLoss.item<float>();
 756 }
 757 }
 758 return error;
 759}
760
// Runs the full initialization: clears all parameters, performs noInit_
// independent initializations (keeping the one with the lowest error), enables
// gradients on every retained parameter tensor, and logs per-layer summaries.
// NOTE(review): this listing was extracted from generated documentation and
// one original source line is missing (see inline note); the code is kept
// verbatim.
761void ttk::MergeTreeAutoencoder::initStep(
 762 std::vector<mtu::TorchMergeTree<float>> &trees,
 763 std::vector<mtu::TorchMergeTree<float>> &trees2) {
 // Start from a clean slate: drop any parameters from a previous run.
 764 origins_.clear();
 765 originsPrime_.clear();
 766 vSTensor_.clear();
 767 vSPrimeTensor_.clear();
 768 origins2_.clear();
 769 origins2Prime_.clear();
 770 vS2Tensor_.clear();
 771 vS2PrimeTensor_.clear();
 772
 773 float bestError = std::numeric_limits<float>::max();
 774 std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
 775 bestVS2PrimeTensor, bestLatentCentroids;
 776 std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
 777 bestOrigins2, bestOrigins2Prime;
 778 std::vector<std::vector<torch::Tensor>> bestAlphasInit;
 // Multiple random restarts; the error is only computed (and compared) when
 // there is more than one restart.
 779 for(unsigned int n = 0; n < noInit_; ++n) {
 780 // Init parameters
 781 float error = initParameters(trees, trees2, (noInit_ != 1));
 782 // Save best parameters
 783 if(noInit_ != 1) {
 784 std::stringstream ss;
 785 ss << "Init error = " << error;
 786 printMsg(ss.str());
 787 if(error < bestError) {
 788 bestError = error;
 789 copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_,
 790 origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_,
 791 allAlphas_, bestOrigins, bestOriginsPrime, bestVSTensor,
 792 bestVSPrimeTensor, bestOrigins2, bestOrigins2Prime,
 793 bestVS2Tensor, bestVS2PrimeTensor, bestAlphasInit);
 794 bestLatentCentroids.resize(latentCentroids_.size());
 795 for(unsigned int i = 0; i < latentCentroids_.size(); ++i)
 796 mtu::copyTensor(latentCentroids_[i], bestLatentCentroids[i]);
 797 }
 798 }
 799 }
 800 // TODO this copy can be avoided if initParameters takes dummy tensors to fill
 801 // as parameters and then copy to the member tensors when a better init is
 802 // found.
 803 if(noInit_ != 1) {
 804 // Put back best parameters
 805 std::stringstream ss;
 806 ss << "Best init error = " << bestError;
 807 printMsg(ss.str());
 808 copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
 809 bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
 810 bestVS2PrimeTensor, bestAlphasInit, origins_, originsPrime_,
 811 vSTensor_, vSPrimeTensor_, origins2_, origins2Prime_, vS2Tensor_,
 812 vS2PrimeTensor_, allAlphas_);
 813 latentCentroids_.resize(bestLatentCentroids.size());
 814 for(unsigned int i = 0; i < bestLatentCentroids.size(); ++i)
 815 mtu::copyTensor(bestLatentCentroids[i], latentCentroids_[i]);
 816 }
 817
 // Mark every retained parameter tensor as trainable.
 818 for(unsigned int l = 0; l < noLayers_; ++l) {
 819 origins_[l].tensor.requires_grad_(true);
 820 originsPrime_[l].tensor.requires_grad_(true);
 821 vSTensor_[l].requires_grad_(true);
 822 vSPrimeTensor_[l].requires_grad_(true);
 823 if(trees2.size() != 0) {
 824 origins2_[l].tensor.requires_grad_(true);
 825 origins2Prime_[l].tensor.requires_grad_(true);
 826 vS2Tensor_[l].requires_grad_(true);
 827 vS2PrimeTensor_[l].requires_grad_(true);
 828 }
 829
 830 // Print
 // NOTE(review): original line 831 is missing from this extraction.
 832 std::stringstream ss;
 833 ss << "Layer " << l;
 834 printMsg(ss.str());
 // Diagnostic dump: warn about suspiciously large values in the origins
 // and report the basis tensor sizes of the layer.
 835 if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_)) {
 836 ss.str("");
 837 ss << "origins_[" << l << "] has big values!" << std::endl;
 838 printMsg(ss.str());
 839 wae::printPairs(origins_[l].mTree);
 840 }
 841 if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_)) {
 842 ss.str("");
 843 ss << "originsPrime_[" << l << "] has big values!" << std::endl;
 844 printMsg(ss.str());
 845 wae::printPairs(originsPrime_[l].mTree);
 846 }
 847 ss.str("");
 848 ss << "vS size = " << vSTensor_[l].sizes();
 849 printMsg(ss.str());
 850 ss.str("");
 851 ss << "vS' size = " << vSPrimeTensor_[l].sizes();
 852 printMsg(ss.str());
 853 if(trees2.size() != 0) {
 854 ss.str("");
 855 ss << "vS2 size = " << vS2Tensor_[l].sizes();
 856 printMsg(ss.str());
 857 ss.str("");
 858 ss << "vS2' size = " << vS2PrimeTensor_[l].sizes();
 859 printMsg(ss.str());
 860 }
 861 }
 862
 863 // Init Clustering Loss Parameters
 864 if(clusteringLossWeight_ != 0)
 865 initClusteringLossParameters();
 866}
867
868// ---------------------------------------------------------------------------
869// --- Interpolation
870// ---------------------------------------------------------------------------
871void ttk::MergeTreeAutoencoder::interpolationDiagonalProjection(
872 mtu::TorchMergeTree<float> &interpolation) {
873 torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
874 if(interpolation.tensor.requires_grad())
875 diagTensor = diagTensor.detach();
876
877 torch::Tensor birthTensor = diagTensor.index({Slice(), 0});
878 torch::Tensor deathTensor = diagTensor.index({Slice(), 1});
879
880 torch::Tensor indexer = (birthTensor > deathTensor);
881
882 torch::Tensor allProj = (birthTensor + deathTensor) / 2.0;
883 allProj = allProj.index({indexer});
884 allProj = allProj.reshape({-1, 1});
885
886 diagTensor.index_put_({indexer}, allProj);
887}
888
889void ttk::MergeTreeAutoencoder::interpolationNestingProjection(
890 mtu::TorchMergeTree<float> &interpolation) {
891 torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
892 if(interpolation.tensor.requires_grad())
893 diagTensor = diagTensor.detach();
894
895 torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0});
896 torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1});
897
898 torch::Tensor birthIndexer = (birthTensor < 0);
899 torch::Tensor deathIndexer = (deathTensor < 0);
900 birthTensor.index_put_(
901 {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer})));
902 deathTensor.index_put_(
903 {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer})));
904
905 birthIndexer = (birthTensor > 1);
906 deathIndexer = (deathTensor > 1);
907 birthTensor.index_put_(
908 {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer})));
909 deathTensor.index_put_(
910 {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer})));
911}
912
913void ttk::MergeTreeAutoencoder::interpolationProjection(
914 mtu::TorchMergeTree<float> &interpolation) {
915 interpolationDiagonalProjection(interpolation);
916 if(normalizedWasserstein_)
917 interpolationNestingProjection(interpolation);
918
919 ftm::MergeTree<float> interpolationNew;
920 bool noRoot = mtu::torchTensorToMergeTree<float>(
921 interpolation, normalizedWasserstein_, interpolationNew);
922 if(noRoot)
923 printWrn("[interpolationProjection] no root found");
924 interpolation.mTree = copyMergeTree(interpolationNew);
925
926 persistenceThresholding<float>(&(interpolation.mTree.tree), 0.001);
927
928 if(isThereMissingPairs(interpolation) and isPersistenceDiagram_)
929 printWrn("[getMultiInterpolation] missing pairs");
930}
931
932void ttk::MergeTreeAutoencoder::getMultiInterpolation(
933 mtu::TorchMergeTree<float> &origin,
934 torch::Tensor &vS,
935 torch::Tensor &alphas,
936 mtu::TorchMergeTree<float> &interpolation) {
937 mtu::copyTorchMergeTree<float>(origin, interpolation);
938 interpolation.tensor = origin.tensor + torch::matmul(vS, alphas);
939 interpolationProjection(interpolation);
940}
941
942// ---------------------------------------------------------------------------
943// --- Forward
944// ---------------------------------------------------------------------------
void ttk::MergeTreeAutoencoder::getAlphasOptimizationTensors(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &interpolated,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
  torch::Tensor &reorderedTreeTensor,
  torch::Tensor &deltaOrigin,
  torch::Tensor &deltaA,
  torch::Tensor &originTensor_f,
  torch::Tensor &vSTensor_f) {
  // Build the tensors needed by the least-squares solve in computeAlphas:
  // the input tree tensor reordered to follow the matching with the
  // interpolated tree, plus the diagonal-projection corrections
  // (deltaOrigin, deltaA) applied where no match exists.
  // Create matching indexing
  std::vector<int> tensorMatching;
  mtu::getTensorMatching(interpolated, tree, matching, tensorMatching);

  torch::Tensor indexes = torch::tensor(tensorMatching);
  // projIndexer flags the entries with no match (-1) in the matching vector
  torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1});

  dataReorderingGivenMatching(
    origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin);

  // Create axes projection given matching
  // View each basis vector as (birth, death) pairs, take the pair midpoint,
  // duplicate it for both coordinates (i.e. project onto the diagonal)...
  deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2});
  deltaA = (deltaA.index({Slice(), Slice(), 0})
            + deltaA.index({Slice(), Slice(), 1}))
           / 2.0;
  deltaA = torch::stack({deltaA, deltaA}, 2);
  // ...and keep that correction only for the unmatched entries.
  deltaA = deltaA * projIndexer;
  deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1);

  // Pass through the raw origin/basis tensors used by the caller's system.
  originTensor_f = origin.tensor;
  vSTensor_f = vSTensor;
}
979
980void ttk::MergeTreeAutoencoder::computeAlphas(
981 mtu::TorchMergeTree<float> &tree,
982 mtu::TorchMergeTree<float> &origin,
983 torch::Tensor &vSTensor,
984 mtu::TorchMergeTree<float> &interpolated,
985 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
986 mtu::TorchMergeTree<float> &tree2,
987 mtu::TorchMergeTree<float> &origin2,
988 torch::Tensor &vS2Tensor,
989 mtu::TorchMergeTree<float> &interpolated2,
990 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
991 torch::Tensor &alphasOut) {
992 torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f,
993 vSTensor_f;
994 getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching,
995 reorderedTreeTensor, deltaOrigin, deltaA,
996 originTensor_f, vSTensor_f);
997
998 if(useDoubleInput_) {
999 torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f,
1000 vS2Tensor_f;
1001 getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2,
1002 matching2, reorderedTree2Tensor, deltaOrigin2,
1003 deltaA2, origin2Tensor_f, vS2Tensor_f);
1004 vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f});
1005 deltaA = torch::cat({deltaA, deltaA2});
1006 reorderedTreeTensor
1007 = torch::cat({reorderedTreeTensor, reorderedTree2Tensor});
1008 originTensor_f = torch::cat({originTensor_f, origin2Tensor_f});
1009 deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2});
1010 }
1011
1012 torch::Tensor r_axes = vSTensor_f - deltaA;
1013 torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin;
1014
1015 // Pseudo inverse
1016 auto driver = "gelsd";
1017 alphasOut
1018 = std::get<0>(torch::linalg::lstsq(r_axes, r_data, c10::nullopt, driver));
1019
1020 alphasOut.reshape({-1, 1});
1021}
1022
float ttk::MergeTreeAutoencoder::assignmentOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  unsigned int k,
  torch::Tensor &alphasInit,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
  torch::Tensor &bestAlphas,
  bool isCalled) {
  // Alternating optimization (at most k rounds): compute the basis
  // coordinates (alphas) given the current matching, interpolate the basis
  // at those coordinates, then recompute the optimal matching between the
  // interpolation and the input tree. Outputs the best matching(s) and
  // coordinates found; returns the corresponding (mixed) distance.
  torch::Tensor alphas, oldAlphas;
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching, matching2;
  float bestDistance = std::numeric_limits<float>::max();
  mtu::TorchMergeTree<float> interpolated, interpolated2;
  unsigned int i = 0;
  // Restart from random coordinates (used when the interpolation degenerates
  // to an empty tree); note this also overwrites the caller's alphasInit.
  auto reset = [&]() {
    alphasInit = torch::randn_like(alphas);
    i = 0;
  };
  unsigned int noUpdate = 0;
  unsigned int noReset = 0;
  while(i < k) {
    if(i == 0) {
      // First round: start from the provided initialization, or zeros.
      if(alphasInit.defined())
        alphas = alphasInit;
      else
        alphas = torch::zeros({vSTensor.sizes()[1], 1});
    } else {
      computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2,
                    origin2, vS2Tensor, interpolated2, matching2, alphas);
      // Fixed point reached: the coordinates did not change, stop early.
      if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas)
         and i != 1) {
        break;
      }
    }
    mtu::copyTensor(alphas, oldAlphas);
    getMultiInterpolation(origin, vSTensor, alphas, interpolated);
    if(useDoubleInput_)
      getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
    // Degenerate interpolation (no real nodes): restart with new random
    // coordinates instead of computing a distance to an empty tree.
    if(interpolated.mTree.tree.getRealNumberOfNodes() == 0
       or (useDoubleInput_
           and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) {
      ++noReset;
      if(noReset >= 100)
        printWrn("[assignmentOneData] noReset >= 100");
      reset();
      continue;
    }
    float distance;
    computeOneDistance<float>(interpolated.mTree, tree.mTree, matching,
                              distance, isCalled, useDoubleInput_);
    if(useDoubleInput_) {
      float distance2;
      computeOneDistance<float>(interpolated2.mTree, tree2.mTree, matching2,
                                distance2, isCalled, useDoubleInput_, false);
      distance = mixDistances<float>(distance, distance2);
    }
    // Track the best solution; the first round (i == 0) never updates it.
    if(distance < bestDistance and i != 0) {
      bestDistance = distance;
      bestMatching = matching;
      bestMatching2 = matching2;
      bestAlphas = alphas;
      noUpdate += 1;
    }
    i += 1;
  }
  if(noUpdate == 0)
    printErr("[assignmentOneData] noUpdate == 0");
  return bestDistance;
}
1096
1097float ttk::MergeTreeAutoencoder::assignmentOneData(
1098 mtu::TorchMergeTree<float> &tree,
1099 mtu::TorchMergeTree<float> &origin,
1100 torch::Tensor &vSTensor,
1101 mtu::TorchMergeTree<float> &tree2,
1102 mtu::TorchMergeTree<float> &origin2,
1103 torch::Tensor &vS2Tensor,
1104 unsigned int k,
1105 torch::Tensor &alphasInit,
1106 torch::Tensor &bestAlphas,
1107 bool isCalled) {
1108 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> bestMatching,
1109 bestMatching2;
1110 return assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k,
1111 alphasInit, bestMatching, bestMatching2, bestAlphas,
1112 isCalled);
1113}
1114
1115torch::Tensor ttk::MergeTreeAutoencoder::activation(torch::Tensor &in) {
1116 torch::Tensor act;
1117 switch(activationFunction_) {
1118 case 1:
1119 act = torch::nn::LeakyReLU()(in);
1120 break;
1121 case 0:
1122 default:
1123 act = torch::nn::ReLU()(in);
1124 }
1125 return act;
1126}
1127
1128void ttk::MergeTreeAutoencoder::outputBasisReconstruction(
1129 mtu::TorchMergeTree<float> &originPrime,
1130 torch::Tensor &vSPrimeTensor,
1131 mtu::TorchMergeTree<float> &origin2Prime,
1132 torch::Tensor &vS2PrimeTensor,
1133 torch::Tensor &alphas,
1134 mtu::TorchMergeTree<float> &out,
1135 mtu::TorchMergeTree<float> &out2,
1136 bool activate) {
1137 if(not activate_)
1138 activate = false;
1139 torch::Tensor act = (activate ? activation(alphas) : alphas);
1140 getMultiInterpolation(originPrime, vSPrimeTensor, act, out);
1141 if(useDoubleInput_)
1142 getMultiInterpolation(origin2Prime, vS2PrimeTensor, act, out2);
1143}
1144
1145bool ttk::MergeTreeAutoencoder::forwardOneLayer(
1146 mtu::TorchMergeTree<float> &tree,
1147 mtu::TorchMergeTree<float> &origin,
1148 torch::Tensor &vSTensor,
1149 mtu::TorchMergeTree<float> &originPrime,
1150 torch::Tensor &vSPrimeTensor,
1151 mtu::TorchMergeTree<float> &tree2,
1152 mtu::TorchMergeTree<float> &origin2,
1153 torch::Tensor &vS2Tensor,
1154 mtu::TorchMergeTree<float> &origin2Prime,
1155 torch::Tensor &vS2PrimeTensor,
1156 unsigned int k,
1157 torch::Tensor &alphasInit,
1158 mtu::TorchMergeTree<float> &out,
1159 mtu::TorchMergeTree<float> &out2,
1160 torch::Tensor &bestAlphas,
1161 float &bestDistance) {
1162 bool goodOutput = false;
1163 int noReset = 0;
1164 while(not goodOutput) {
1165 bool isCalled = true;
1166 bestDistance
1167 = assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k,
1168 alphasInit, bestAlphas, isCalled);
1169 outputBasisReconstruction(originPrime, vSPrimeTensor, origin2Prime,
1170 vS2PrimeTensor, bestAlphas, out, out2);
1171 goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0
1172 and (not useDoubleInput_
1173 or out2.mTree.tree.getRealNumberOfNodes() != 0));
1174 if(not goodOutput) {
1175 ++noReset;
1176 if(noReset >= 100) {
1177 printWrn("[forwardOneLayer] noReset >= 100");
1178 return true;
1179 }
1180 alphasInit = torch::randn_like(alphasInit);
1181 }
1182 }
1183 return false;
1184}
1185
1186bool ttk::MergeTreeAutoencoder::forwardOneLayer(
1187 mtu::TorchMergeTree<float> &tree,
1188 mtu::TorchMergeTree<float> &origin,
1189 torch::Tensor &vSTensor,
1190 mtu::TorchMergeTree<float> &originPrime,
1191 torch::Tensor &vSPrimeTensor,
1192 mtu::TorchMergeTree<float> &tree2,
1193 mtu::TorchMergeTree<float> &origin2,
1194 torch::Tensor &vS2Tensor,
1195 mtu::TorchMergeTree<float> &origin2Prime,
1196 torch::Tensor &vS2PrimeTensor,
1197 unsigned int k,
1198 torch::Tensor &alphasInit,
1199 mtu::TorchMergeTree<float> &out,
1200 mtu::TorchMergeTree<float> &out2,
1201 torch::Tensor &bestAlphas) {
1202 float bestDistance;
1203 return forwardOneLayer(tree, origin, vSTensor, originPrime, vSPrimeTensor,
1204 tree2, origin2, vS2Tensor, origin2Prime,
1205 vS2PrimeTensor, k, alphasInit, out, out2, bestAlphas,
1206 bestDistance);
1207}
1208
bool ttk::MergeTreeAutoencoder::forwardOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &tree2,
  unsigned int treeIndex,
  unsigned int k,
  std::vector<torch::Tensor> &alphasInit,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,
  std::vector<torch::Tensor> &dataAlphas,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2) {
  // Forward one input through every layer: the output of layer l-1 feeds
  // layer l. Intermediate reconstructions go to outs/outs2, the last-layer
  // reconstruction to out/out2, and the per-layer coordinates to dataAlphas.
  // Returns true if a layer failed to produce a valid output (reset).
  outs.resize(noLayers_ - 1);
  outs2.resize(noLayers_ - 1);
  dataAlphas.resize(noLayers_);
  for(unsigned int l = 0; l < noLayers_; ++l) {
    // Layer 0 consumes the input tree; later layers consume the previous
    // layer's output. The last layer writes to out/out2.
    auto &treeToUse = (l == 0 ? tree : outs[l - 1]);
    auto &tree2ToUse = (l == 0 ? tree2 : outs2[l - 1]);
    auto &outToUse = (l != noLayers_ - 1 ? outs[l] : out);
    auto &out2ToUse = (l != noLayers_ - 1 ? outs2[l] : out2);
    bool reset = forwardOneLayer(
      treeToUse, origins_[l], vSTensor_[l], originsPrime_[l], vSPrimeTensor_[l],
      tree2ToUse, origins2_[l], vS2Tensor_[l], origins2Prime_[l],
      vS2PrimeTensor_[l], k, alphasInit[l], outToUse, out2ToUse, dataAlphas[l]);
    if(reset)
      return true;
    // Update recs
    // Store this layer's reconstruction in recs[treeIndex]: slot 0 holds the
    // input itself (set at init), slot l + 1 the output of layer l. Slots are
    // appended on the first pass and overwritten on later passes.
    auto updateRecs
      = [this, &treeIndex, &l](
          std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
          mtu::TorchMergeTree<float> &outT) {
        if(recs[treeIndex].size() > noLayers_)
          mtu::copyTorchMergeTree<float>(outT, recs[treeIndex][l + 1]);
        else {
          mtu::TorchMergeTree<float> tmt;
          mtu::copyTorchMergeTree<float>(outT, tmt);
          recs[treeIndex].emplace_back(tmt);
        }
      };
    updateRecs(recs_, outToUse);
    if(useDoubleInput_)
      updateRecs(recs2_, out2ToUse);
  }
  return false;
}
1253
1254bool ttk::MergeTreeAutoencoder::forwardStep(
1255 std::vector<mtu::TorchMergeTree<float>> &trees,
1256 std::vector<mtu::TorchMergeTree<float>> &trees2,
1257 std::vector<unsigned int> &indexes,
1258 unsigned int k,
1259 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
1260 bool computeReconstructionError,
1261 std::vector<mtu::TorchMergeTree<float>> &outs,
1262 std::vector<mtu::TorchMergeTree<float>> &outs2,
1263 std::vector<std::vector<torch::Tensor>> &bestAlphas,
1264 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
1265 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
1266 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1267 &matchings,
1268 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1269 &matchings2,
1270 float &loss) {
1271 loss = 0;
1272 outs.resize(trees.size());
1273 outs2.resize(trees.size());
1274 bestAlphas.resize(trees.size());
1275 layersOuts.resize(trees.size());
1276 layersOuts2.resize(trees.size());
1277 matchings.resize(trees.size());
1278 if(useDoubleInput_)
1279 matchings2.resize(trees2.size());
1280 mtu::TorchMergeTree<float> dummyTMT;
1281 bool reset = false;
1282#ifdef TTK_ENABLE_OPENMP
1283#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \
1284 if(parallelize_) reduction(||: reset) reduction(+:loss)
1285#endif
1286 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
1287 unsigned int i = indexes[ind];
1288 auto &tree2ToUse = (trees2.size() == 0 ? dummyTMT : trees2[i]);
1289 bool dReset
1290 = forwardOneData(trees[i], tree2ToUse, i, k, allAlphasInit[i], outs[i],
1291 outs2[i], bestAlphas[i], layersOuts[i], layersOuts2[i]);
1292 if(computeReconstructionError) {
1293 float iLoss = computeOneLoss(
1294 trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]);
1295 loss += iLoss;
1296 }
1297 if(dReset)
1298 reset = reset || dReset;
1299 }
1300 loss /= indexes.size();
1301 return reset;
1302}
1303
1304bool ttk::MergeTreeAutoencoder::forwardStep(
1305 std::vector<mtu::TorchMergeTree<float>> &trees,
1306 std::vector<mtu::TorchMergeTree<float>> &trees2,
1307 std::vector<unsigned int> &indexes,
1308 unsigned int k,
1309 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
1310 std::vector<mtu::TorchMergeTree<float>> &outs,
1311 std::vector<mtu::TorchMergeTree<float>> &outs2,
1312 std::vector<std::vector<torch::Tensor>> &bestAlphas) {
1313 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts, layersOuts2;
1314 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1315 matchings, matchings2;
1316 bool computeReconstructionError = false;
1317 float loss;
1318 return forwardStep(trees, trees2, indexes, k, allAlphasInit,
1319 computeReconstructionError, outs, outs2, bestAlphas,
1320 layersOuts, layersOuts2, matchings, matchings2, loss);
1321}
1322
1323// ---------------------------------------------------------------------------
1324// --- Backward
1325// ---------------------------------------------------------------------------
bool ttk::MergeTreeAutoencoder::backwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings2,
  torch::optim::Optimizer &optimizer,
  std::vector<unsigned int> &indexes,
  torch::Tensor &metricLoss,
  torch::Tensor &clusteringLoss,
  torch::Tensor &trackingLoss) {
  // Backpropagate every enabled (weighted) loss term, then run one optimizer
  // step. Each backward() retains the graph as long as a later loss term
  // still needs to backpropagate through it.
  double totalLoss = 0;
  bool retainGraph = (metricLossWeight_ != 0 or clusteringLossWeight_ != 0
                      or trackingLossWeight_ != 0);
  if(reconstructionLossWeight_ != 0
     or (customLossDynamicWeight_ and retainGraph)) {
    std::vector<torch::Tensor> outTensors(indexes.size()),
      reorderedTensors(indexes.size());
    // Reorder each input tensor to follow its matching with the output
    // (done in parallel, backward() calls are kept sequential below).
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      unsigned int i = indexes[ind];
      torch::Tensor reorderedTensor;
      dataReorderingGivenMatching(
        outs[i], trees[i], matchings[i], reorderedTensor);
      auto outTensor = outs[i].tensor;
      if(useDoubleInput_) {
        // Concatenate both tree types into a single pair of tensors.
        torch::Tensor reorderedTensor2;
        dataReorderingGivenMatching(
          outs2[i], trees2[i], matchings2[i], reorderedTensor2);
        outTensor = torch::cat({outTensor, outs2[i].tensor});
        reorderedTensor = torch::cat({reorderedTensor, reorderedTensor2});
      }
      outTensors[ind] = outTensor;
      reorderedTensors[ind] = reorderedTensor;
    }
    // Sequential backward of the per-data reconstruction losses.
    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      auto loss = torch::nn::functional::mse_loss(
        outTensors[ind], reorderedTensors[ind]);
      // Same as next loss with a factor of 1 / n where n is the number of nodes
      // in the output
      // auto loss = (outTensors[ind] - reorderedTensors[ind]).pow(2).sum();
      totalLoss += loss.item<float>();
      loss *= reconstructionLossWeight_;
      loss.backward({}, retainGraph);
    }
  }
  // Custom loss terms: each is scaled by its weight (and, for metric and
  // clustering, a dynamic weight derived from the mean reconstruction loss).
  if(metricLossWeight_ != 0) {
    bool retainGraphMetricLoss
      = (clusteringLossWeight_ != 0 or trackingLossWeight_ != 0);
    metricLoss *= metricLossWeight_
                  * getCustomLossDynamicWeight(
                    totalLoss / indexes.size(), baseRecLoss2_);
    metricLoss.backward({}, retainGraphMetricLoss);
  }
  if(clusteringLossWeight_ != 0) {
    bool retainGraphClusteringLoss = (trackingLossWeight_ != 0);
    clusteringLoss *= clusteringLossWeight_
                      * getCustomLossDynamicWeight(
                        totalLoss / indexes.size(), baseRecLoss2_);
    clusteringLoss.backward({}, retainGraphClusteringLoss);
  }
  if(trackingLossWeight_ != 0) {
    trackingLoss *= trackingLossWeight_;
    trackingLoss.backward();
  }

  // Diagnostics: count, per layer, the backward passes that produced no
  // gradient (or an all-zero gradient) for each parameter group.
  for(unsigned int l = 0; l < noLayers_; ++l) {
    if(not origins_[l].tensor.grad().defined()
       or not origins_[l].tensor.grad().count_nonzero().is_nonzero())
      ++originsNoZeroGrad_[l];
    if(not originsPrime_[l].tensor.grad().defined()
       or not originsPrime_[l].tensor.grad().count_nonzero().is_nonzero())
      ++originsPrimeNoZeroGrad_[l];
    if(not vSTensor_[l].grad().defined()
       or not vSTensor_[l].grad().count_nonzero().is_nonzero())
      ++vSNoZeroGrad_[l];
    if(not vSPrimeTensor_[l].grad().defined()
       or not vSPrimeTensor_[l].grad().count_nonzero().is_nonzero())
      ++vSPrimeNoZeroGrad_[l];
    if(useDoubleInput_) {
      if(not origins2_[l].tensor.grad().defined()
         or not origins2_[l].tensor.grad().count_nonzero().is_nonzero())
        ++origins2NoZeroGrad_[l];
      if(not origins2Prime_[l].tensor.grad().defined()
         or not origins2Prime_[l].tensor.grad().count_nonzero().is_nonzero())
        ++origins2PrimeNoZeroGrad_[l];
      if(not vS2Tensor_[l].grad().defined()
         or not vS2Tensor_[l].grad().count_nonzero().is_nonzero())
        ++vS2NoZeroGrad_[l];
      if(not vS2PrimeTensor_[l].grad().defined()
         or not vS2PrimeTensor_[l].grad().count_nonzero().is_nonzero())
        ++vS2PrimeNoZeroGrad_[l];
    }
  }

  optimizer.step();
  optimizer.zero_grad();
  return false;
}
1431
1432// ---------------------------------------------------------------------------
1433// --- Projection
1434// ---------------------------------------------------------------------------
1435void ttk::MergeTreeAutoencoder::projectionStep() {
1436 auto projectTree = [this](mtu::TorchMergeTree<float> &tmt) {
1437 interpolationProjection(tmt);
1438 tmt.tensor = tmt.tensor.detach();
1439 tmt.tensor.requires_grad_(true);
1440 };
1441 for(unsigned int l = 0; l < noLayers_; ++l) {
1442 projectTree(origins_[l]);
1443 projectTree(originsPrime_[l]);
1444 if(useDoubleInput_) {
1445 projectTree(origins2_[l]);
1446 projectTree(origins2Prime_[l]);
1447 }
1448 }
1449}
1450
1451// ---------------------------------------------------------------------------
1452// --- Convergence
1453// ---------------------------------------------------------------------------
1454float ttk::MergeTreeAutoencoder::computeOneLoss(
1455 mtu::TorchMergeTree<float> &tree,
1456 mtu::TorchMergeTree<float> &out,
1457 mtu::TorchMergeTree<float> &tree2,
1458 mtu::TorchMergeTree<float> &out2,
1459 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
1460 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2) {
1461 float loss = 0;
1462 bool isCalled = true;
1463 float distance;
1464 computeOneDistance<float>(
1465 out.mTree, tree.mTree, matching, distance, isCalled, useDoubleInput_);
1466 if(useDoubleInput_) {
1467 float distance2;
1468 computeOneDistance<float>(out2.mTree, tree2.mTree, matching2, distance2,
1469 isCalled, useDoubleInput_, false);
1470 distance = mixDistances<float>(distance, distance2);
1471 }
1472 loss += distance * distance;
1473 return loss;
1474}
1475
1476float ttk::MergeTreeAutoencoder::computeLoss(
1477 std::vector<mtu::TorchMergeTree<float>> &trees,
1478 std::vector<mtu::TorchMergeTree<float>> &outs,
1479 std::vector<mtu::TorchMergeTree<float>> &trees2,
1480 std::vector<mtu::TorchMergeTree<float>> &outs2,
1481 std::vector<unsigned int> &indexes,
1482 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1483 &matchings,
1484 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1485 &matchings2) {
1486 float loss = 0;
1487 matchings.resize(trees.size());
1488 if(useDoubleInput_)
1489 matchings2.resize(trees2.size());
1490#ifdef TTK_ENABLE_OPENMP
1491#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \
1492 if(parallelize_) reduction(+:loss)
1493#endif
1494 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
1495 unsigned int i = indexes[ind];
1496 float iLoss = computeOneLoss(
1497 trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]);
1498 loss += iLoss;
1499 }
1500 return loss / indexes.size();
1501}
1502
1503bool ttk::MergeTreeAutoencoder::isBestLoss(float loss,
1504 float &minLoss,
1505 unsigned int &cptBlocked) {
1506 bool isBestEnergy = false;
1507 if(loss + ENERGY_COMPARISON_TOLERANCE < minLoss) {
1508 minLoss = loss;
1509 cptBlocked = 0;
1510 isBestEnergy = true;
1511 }
1512 return isBestEnergy;
1513}
1514
1515bool ttk::MergeTreeAutoencoder::convergenceStep(float loss,
1516 float &oldLoss,
1517 float &minLoss,
1518 unsigned int &cptBlocked) {
1519 double tol = oldLoss / 125.0;
1520 bool converged = std::abs(loss - oldLoss) < std::abs(tol);
1521 oldLoss = loss;
1522 if(not converged) {
1523 cptBlocked += (minLoss < loss) ? 1 : 0;
1524 converged = (cptBlocked >= 10 * 10);
1525 if(converged)
1526 printMsg("Blocked!", debug::Priority::DETAIL);
1527 }
1528 return converged;
1529}
1530
1531// ---------------------------------------------------------------------------
1532// --- Main Functions
1533// ---------------------------------------------------------------------------
1534void ttk::MergeTreeAutoencoder::fit(
1535 std::vector<ftm::MergeTree<float>> &trees,
1536 std::vector<ftm::MergeTree<float>> &trees2) {
1537 torch::set_num_threads(1);
1538 // ----- Determinism
1539 if(deterministic_) {
1540 int m_seed = 0;
1541 bool m_torch_deterministic = true;
1542 srand(m_seed);
1543 torch::manual_seed(m_seed);
1544 at::globalContext().setDeterministicCuDNN(m_torch_deterministic ? true
1545 : false);
1546 at::globalContext().setDeterministicAlgorithms(
1547 m_torch_deterministic ? true : false, true);
1548 }
1549
1550 // ----- Testing
1551 for(unsigned int i = 0; i < trees.size(); ++i) {
1552 for(unsigned int n = 0; n < trees[i].tree.getNumberOfNodes(); ++n) {
1553 if(trees[i].tree.isNodeAlone(n))
1554 continue;
1555 auto birthDeath = trees[i].tree.template getBirthDeath<float>(n);
1556 bigValuesThreshold_
1557 = std::max(std::abs(std::get<0>(birthDeath)), bigValuesThreshold_);
1558 bigValuesThreshold_
1559 = std::max(std::abs(std::get<1>(birthDeath)), bigValuesThreshold_);
1560 }
1561 }
1562 bigValuesThreshold_ *= 100;
1563
1564 // ----- Convert MergeTree to TorchMergeTree
1565 std::vector<mtu::TorchMergeTree<float>> torchTrees, torchTrees2;
1566 mergeTreesToTorchTrees(trees, torchTrees, normalizedWasserstein_);
1567 mergeTreesToTorchTrees(trees2, torchTrees2, normalizedWasserstein_);
1568
1569 auto initRecs = [](std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
1570 std::vector<mtu::TorchMergeTree<float>> &torchTreesT) {
1571 recs.clear();
1572 recs.resize(torchTreesT.size());
1573 for(unsigned int i = 0; i < torchTreesT.size(); ++i) {
1574 mtu::TorchMergeTree<float> tmt;
1575 mtu::copyTorchMergeTree<float>(torchTreesT[i], tmt);
1576 recs[i].emplace_back(tmt);
1577 }
1578 };
1579 initRecs(recs_, torchTrees);
1580 if(useDoubleInput_)
1581 initRecs(recs2_, torchTrees2);
1582
1583 // ----- Init Metric Loss
1584 if(metricLossWeight_ != 0)
1585 getDistanceMatrix(torchTrees, torchTrees2, distanceMatrix_);
1586
1587 // ----- Init Model Parameters
1588 Timer t_init;
1589 initStep(torchTrees, torchTrees2);
1590 printMsg("Init", 1, t_init.getElapsedTime(), threadNumber_);
1591
1592 // --- Init optimizer
1593 std::vector<torch::Tensor> parameters;
1594 for(unsigned int l = 0; l < noLayers_; ++l) {
1595 parameters.emplace_back(origins_[l].tensor);
1596 parameters.emplace_back(originsPrime_[l].tensor);
1597 parameters.emplace_back(vSTensor_[l]);
1598 parameters.emplace_back(vSPrimeTensor_[l]);
1599 if(trees2.size() != 0) {
1600 parameters.emplace_back(origins2_[l].tensor);
1601 parameters.emplace_back(origins2Prime_[l].tensor);
1602 parameters.emplace_back(vS2Tensor_[l]);
1603 parameters.emplace_back(vS2PrimeTensor_[l]);
1604 }
1605 }
1606 if(clusteringLossWeight_ != 0)
1607 for(unsigned int i = 0; i < latentCentroids_.size(); ++i)
1608 parameters.emplace_back(latentCentroids_[i]);
1609
1610 torch::optim::Optimizer *optimizer;
1611 // - Init Adam
1612 auto adamOptions = torch::optim::AdamOptions(gradientStepSize_);
1613 adamOptions.betas(std::make_tuple(beta1_, beta2_));
1614 auto adamOptimizer = torch::optim::Adam(parameters, adamOptions);
1615 // - Init SGD optimizer
1616 auto sgdOptions = torch::optim::SGDOptions(gradientStepSize_);
1617 auto sgdOptimizer = torch::optim::SGD(parameters, sgdOptions);
1618 // -Init RMSprop optimizer
1619 auto rmspropOptions = torch::optim::RMSpropOptions(gradientStepSize_);
1620 auto rmspropOptimizer = torch::optim::RMSprop(parameters, rmspropOptions);
1621 // - Set optimizer pointer
1622 switch(optimizer_) {
1623 case 1:
1624 optimizer = &sgdOptimizer;
1625 break;
1626 case 2:
1627 optimizer = &rmspropOptimizer;
1628 break;
1629 case 0:
1630 default:
1631 optimizer = &adamOptimizer;
1632 }
1633
1634 // --- Init batches indexes
1635 unsigned int batchSize = std::min(
1636 std::max((int)(trees.size() * batchSize_), 1), (int)trees.size());
1637 std::stringstream ssBatch;
1638 ssBatch << "batchSize = " << batchSize;
1639 printMsg(ssBatch.str());
1640 unsigned int noBatch
1641 = trees.size() / batchSize + ((trees.size() % batchSize) != 0 ? 1 : 0);
1642 std::vector<std::vector<unsigned int>> allIndexes(noBatch);
1643 if(noBatch == 1) {
1644 allIndexes[0].resize(trees.size());
1645 std::iota(allIndexes[0].begin(), allIndexes[0].end(), 0);
1646 }
1647 auto rng = std::default_random_engine{};
1648
1649 // ----- Testing
1650 originsNoZeroGrad_.resize(noLayers_);
1651 originsPrimeNoZeroGrad_.resize(noLayers_);
1652 vSNoZeroGrad_.resize(noLayers_);
1653 vSPrimeNoZeroGrad_.resize(noLayers_);
1654 for(unsigned int l = 0; l < noLayers_; ++l) {
1655 originsNoZeroGrad_[l] = 0;
1656 originsPrimeNoZeroGrad_[l] = 0;
1657 vSNoZeroGrad_[l] = 0;
1658 vSPrimeNoZeroGrad_[l] = 0;
1659 }
1660 if(useDoubleInput_) {
1661 origins2NoZeroGrad_.resize(noLayers_);
1662 origins2PrimeNoZeroGrad_.resize(noLayers_);
1663 vS2NoZeroGrad_.resize(noLayers_);
1664 vS2PrimeNoZeroGrad_.resize(noLayers_);
1665 for(unsigned int l = 0; l < noLayers_; ++l) {
1666 origins2NoZeroGrad_[l] = 0;
1667 origins2PrimeNoZeroGrad_[l] = 0;
1668 vS2NoZeroGrad_[l] = 0;
1669 vS2PrimeNoZeroGrad_[l] = 0;
1670 }
1671 }
1672
1673 // ----- Init Variables
1674 baseRecLoss_ = std::numeric_limits<double>::max();
1675 baseRecLoss2_ = std::numeric_limits<double>::max();
1676 unsigned int k = k_;
1677 float oldLoss, minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss;
1678 unsigned int cptBlocked, iteration = 0;
1679 auto initLoop = [&]() {
1680 oldLoss = -1;
1681 minLoss = std::numeric_limits<float>::max();
1682 minRecLoss = minLoss;
1683 minMetricLoss = minLoss;
1684 minClustLoss = minLoss;
1685 minTrackLoss = minLoss;
1686 cptBlocked = 0;
1687 iteration = 0;
1688 };
1689 initLoop();
1690 int convWinSize = 5;
1691 int noConverged = 0, noConvergedToGet = 10;
1692 std::vector<float> losses, metricLosses, clusteringLosses, trackingLosses;
1693 float windowLoss = 0;
1694
1695 double assignmentTime = 0.0, updateTime = 0.0, projectionTime = 0.0,
1696 lossTime = 0.0;
1697
1698 int bestIteration = 0;
1699 std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
1700 bestVS2PrimeTensor;
1701 std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
1702 bestOrigins2, bestOrigins2Prime;
1703 std::vector<std::vector<torch::Tensor>> bestAlphasInit;
1704 std::vector<std::vector<mtu::TorchMergeTree<float>>> bestRecs, bestRecs2;
1705 double bestTime = 0;
1706
1707 auto printLoss
1708 = [this](float loss, float recLoss, float metricLoss, float clustLoss,
1709 float trackLoss, int iterationT, int iterationTT, double time,
1710 const debug::Priority &priority = debug::Priority::INFO) {
1711 std::stringstream prefix;
1712 prefix << (priority == debug::Priority::VERBOSE ? "Iter " : "Best ");
1713 std::stringstream ssBestLoss;
1714 ssBestLoss << prefix.str() << "loss is " << loss << " (iteration "
1715 << iterationT << " / " << iterationTT << ") at time "
1716 << time;
1717 printMsg(ssBestLoss.str(), priority);
1718 if(priority != debug::Priority::VERBOSE)
1719 prefix.str("");
1720 if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0
1721 or trackingLossWeight_ != 0) {
1722 ssBestLoss.str("");
1723 ssBestLoss << "- Rec. " << prefix.str() << "loss = " << recLoss;
1724 printMsg(ssBestLoss.str(), priority);
1725 }
1726 if(metricLossWeight_ != 0) {
1727 ssBestLoss.str("");
1728 ssBestLoss << "- Metric " << prefix.str() << "loss = " << metricLoss;
1729 printMsg(ssBestLoss.str(), priority);
1730 }
1731 if(clusteringLossWeight_ != 0) {
1732 ssBestLoss.str("");
1733 ssBestLoss << "- Clust. " << prefix.str() << "loss = " << clustLoss;
1734 printMsg(ssBestLoss.str(), priority);
1735 }
1736 if(trackingLossWeight_ != 0) {
1737 ssBestLoss.str("");
1738 ssBestLoss << "- Track. " << prefix.str() << "loss = " << trackLoss;
1739 printMsg(ssBestLoss.str(), priority);
1740 }
1741 };
1742
1743 // ----- Algorithm
1744 Timer t_alg;
1745 bool converged = false;
1746 while(not converged) {
1747 if(iteration % iterationGap_ == 0) {
1748 std::stringstream ss;
1749 ss << "Iteration " << iteration;
1751 printMsg(ss.str());
1752 }
1753
1754 bool forwardReset = false;
1755 std::vector<float> iterationLosses, iterationMetricLosses,
1756 iterationClusteringLosses, iterationTrackingLosses;
1757 if(noBatch != 1) {
1758 std::vector<unsigned int> indexes(trees.size());
1759 std::iota(indexes.begin(), indexes.end(), 0);
1760 std::shuffle(std::begin(indexes), std::end(indexes), rng);
1761 for(unsigned int i = 0; i < allIndexes.size(); ++i) {
1762 unsigned int noProcessed = batchSize * i;
1763 unsigned int remaining = trees.size() - noProcessed;
1764 unsigned int size = std::min(batchSize, remaining);
1765 allIndexes[i].resize(size);
1766 for(unsigned int j = 0; j < size; ++j)
1767 allIndexes[i][j] = indexes[noProcessed + j];
1768 }
1769 }
1770 for(unsigned batchNum = 0; batchNum < allIndexes.size(); ++batchNum) {
1771 auto &indexes = allIndexes[batchNum];
1772
1773 // --- Assignment
1774 Timer t_assignment;
1775 std::vector<mtu::TorchMergeTree<float>> outs, outs2;
1776 std::vector<std::vector<torch::Tensor>> bestAlphas;
1777 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
1778 layersOuts2;
1779 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
1780 matchings, matchings2;
1781 float loss;
1782 bool computeReconstructionError = reconstructionLossWeight_ != 0;
1783 forwardReset
1784 = forwardStep(torchTrees, torchTrees2, indexes, k, allAlphas_,
1785 computeReconstructionError, outs, outs2, bestAlphas,
1786 layersOuts, layersOuts2, matchings, matchings2, loss);
1787 if(forwardReset)
1788 break;
1789 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
1790 unsigned int i = indexes[ind];
1791 for(unsigned int j = 0; j < bestAlphas[i].size(); ++j)
1792 mtu::copyTensor(bestAlphas[i][j], allAlphas_[i][j]);
1793 }
1794 assignmentTime += t_assignment.getElapsedTime();
1795
1796 // --- Loss
1797 Timer t_loss;
1798 losses.emplace_back(loss);
1799 iterationLosses.emplace_back(loss);
1800 // - Metric Loss
1801 torch::Tensor metricLoss;
1802 if(metricLossWeight_ != 0) {
1803 computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_,
1804 indexes, metricLoss);
1805 float metricLossF = metricLoss.item<float>();
1806 metricLosses.emplace_back(metricLossF);
1807 iterationMetricLosses.emplace_back(metricLossF);
1808 }
1809 // - Clustering Loss
1810 torch::Tensor clusteringLoss;
1811 if(clusteringLossWeight_ != 0) {
1812 torch::Tensor asgn;
1813 computeClusteringLoss(bestAlphas, indexes, clusteringLoss, asgn);
1814 float clusteringLossF = clusteringLoss.item<float>();
1815 clusteringLosses.emplace_back(clusteringLossF);
1816 iterationClusteringLosses.emplace_back(clusteringLossF);
1817 }
1818 // - Tracking Loss
1819 torch::Tensor trackingLoss;
1820 if(trackingLossWeight_ != 0) {
1821 computeTrackingLoss(trackingLoss);
1822 float trackingLossF = trackingLoss.item<float>();
1823 trackingLosses.emplace_back(trackingLossF);
1824 iterationTrackingLosses.emplace_back(trackingLossF);
1825 }
1826 lossTime += t_loss.getElapsedTime();
1827
1828 // --- Update
1829 Timer t_update;
1830 backwardStep(torchTrees, outs, matchings, torchTrees2, outs2, matchings2,
1831 *optimizer, indexes, metricLoss, clusteringLoss,
1832 trackingLoss);
1833 updateTime += t_update.getElapsedTime();
1834
1835 // --- Projection
1836 Timer t_projection;
1837 projectionStep();
1838 projectionTime += t_projection.getElapsedTime();
1839 }
1840
1841 if(forwardReset) {
1842 // TODO better manage reset by init new parameters and start again for
1843 // example (should not happen anymore)
1844 printWrn("Forward reset!");
1845 break;
1846 }
1847
1848 // --- Get iteration loss
1849 // TODO an approximation is made here if batch size != 1 because the
1850 // iteration loss will not be exact, we need to do a forward step and
1851 // compute loss with the whole dataset
1852 /*if(batchSize_ != 1)
1853 printWrn("iteration loss approximation (batchSize_ != 1)");*/
1854 float iterationRecLoss
1855 = torch::tensor(iterationLosses).mean().item<float>();
1856 float iterationLoss = reconstructionLossWeight_ * iterationRecLoss;
1857 float iterationMetricLoss = 0;
1858 if(metricLossWeight_ != 0) {
1859 iterationMetricLoss
1860 = torch::tensor(iterationMetricLosses).mean().item<float>();
1861 iterationLoss
1862 += metricLossWeight_
1863 * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_)
1864 * iterationMetricLoss;
1865 }
1866 float iterationClusteringLoss = 0;
1867 if(clusteringLossWeight_ != 0) {
1868 iterationClusteringLoss
1869 = torch::tensor(iterationClusteringLosses).mean().item<float>();
1870 iterationLoss
1871 += clusteringLossWeight_
1872 * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_)
1873 * iterationClusteringLoss;
1874 }
1875 float iterationTrackingLoss = 0;
1876 if(trackingLossWeight_ != 0) {
1877 iterationTrackingLoss
1878 = torch::tensor(iterationTrackingLosses).mean().item<float>();
1879 iterationLoss += trackingLossWeight_ * iterationTrackingLoss;
1880 }
1881 printLoss(iterationLoss, iterationRecLoss, iterationMetricLoss,
1882 iterationClusteringLoss, iterationTrackingLoss, iteration,
1883 iteration, t_alg.getElapsedTime() - t_allVectorCopy_time_,
1885
1886 // --- Update best parameters
1887 bool isBest = isBestLoss(iterationLoss, minLoss, cptBlocked);
1888 if(isBest) {
1889 Timer t_copy;
1890 bestIteration = iteration;
1891 copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_, origins2_,
1892 origins2Prime_, vS2Tensor_, vS2PrimeTensor_, allAlphas_,
1893 bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
1894 bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
1895 bestVS2PrimeTensor, bestAlphasInit);
1896 copyParams(recs_, bestRecs);
1897 copyParams(recs2_, bestRecs2);
1898 t_allVectorCopy_time_ += t_copy.getElapsedTime();
1899 bestTime = t_alg.getElapsedTime() - t_allVectorCopy_time_;
1900 minRecLoss = iterationRecLoss;
1901 minMetricLoss = iterationMetricLoss;
1902 minClustLoss = iterationClusteringLoss;
1903 minTrackLoss = iterationTrackingLoss;
1904 printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss,
1905 bestIteration, iteration, bestTime, debug::Priority::DETAIL);
1906 }
1907
1908 // --- Convergence
1909 windowLoss += iterationLoss;
1910 if((iteration + 1) % convWinSize == 0) {
1911 windowLoss /= convWinSize;
1912 converged = convergenceStep(windowLoss, oldLoss, minLoss, cptBlocked);
1913 windowLoss = 0;
1914 if(converged) {
1915 ++noConverged;
1916 } else
1917 noConverged = 0;
1918 converged = noConverged >= noConvergedToGet;
1919 if(converged and iteration < minIteration_)
1920 printMsg("convergence is detected but iteration < minIteration_",
1922 if(iteration < minIteration_)
1923 converged = false;
1924 if(converged)
1925 break;
1926 }
1927
1928 // --- Print
1929 if(iteration % iterationGap_ == 0) {
1930 printMsg("Assignment", 1, assignmentTime, threadNumber_);
1931 printMsg("Loss", 1, lossTime, threadNumber_);
1932 printMsg("Update", 1, updateTime, threadNumber_);
1933 printMsg("Projection", 1, projectionTime, threadNumber_);
1934 assignmentTime = 0.0;
1935 lossTime = 0.0;
1936 updateTime = 0.0;
1937 projectionTime = 0.0;
1938 std::stringstream ss;
1939 float loss = torch::tensor(losses).mean().item<float>();
1940 losses.clear();
1941 ss << "Rec. loss = " << loss;
1942 printMsg(ss.str());
1943 if(metricLossWeight_ != 0) {
1944 float metricLoss = torch::tensor(metricLosses).mean().item<float>();
1945 metricLosses.clear();
1946 ss.str("");
1947 ss << "Metric loss = " << metricLoss;
1948 printMsg(ss.str());
1949 }
1950 if(clusteringLossWeight_ != 0) {
1951 float clusteringLoss
1952 = torch::tensor(clusteringLosses).mean().item<float>();
1953 clusteringLosses.clear();
1954 ss.str("");
1955 ss << "Clust. loss = " << clusteringLoss;
1956 printMsg(ss.str());
1957 }
1958 if(trackingLossWeight_ != 0) {
1959 float trackingLoss = torch::tensor(trackingLosses).mean().item<float>();
1960 trackingLosses.clear();
1961 ss.str("");
1962 ss << "Track. loss = " << trackingLoss;
1963 printMsg(ss.str());
1964 }
1965
1966 // Verify grad and big values (testing)
1967 for(unsigned int l = 0; l < noLayers_; ++l) {
1968 ss.str("");
1969 if(originsNoZeroGrad_[l] != 0)
1970 ss << originsNoZeroGrad_[l] << " originsNoZeroGrad_[" << l << "]"
1971 << std::endl;
1972 if(originsPrimeNoZeroGrad_[l] != 0)
1973 ss << originsPrimeNoZeroGrad_[l] << " originsPrimeNoZeroGrad_[" << l
1974 << "]" << std::endl;
1975 if(vSNoZeroGrad_[l] != 0)
1976 ss << vSNoZeroGrad_[l] << " vSNoZeroGrad_[" << l << "]" << std::endl;
1977 if(vSPrimeNoZeroGrad_[l] != 0)
1978 ss << vSPrimeNoZeroGrad_[l] << " vSPrimeNoZeroGrad_[" << l << "]"
1979 << std::endl;
1980 originsNoZeroGrad_[l] = 0;
1981 originsPrimeNoZeroGrad_[l] = 0;
1982 vSNoZeroGrad_[l] = 0;
1983 vSPrimeNoZeroGrad_[l] = 0;
1984 if(useDoubleInput_) {
1985 if(origins2NoZeroGrad_[l] != 0)
1986 ss << origins2NoZeroGrad_[l] << " origins2NoZeroGrad_[" << l << "]"
1987 << std::endl;
1988 if(origins2PrimeNoZeroGrad_[l] != 0)
1989 ss << origins2PrimeNoZeroGrad_[l] << " origins2PrimeNoZeroGrad_["
1990 << l << "]" << std::endl;
1991 if(vS2NoZeroGrad_[l] != 0)
1992 ss << vS2NoZeroGrad_[l] << " vS2NoZeroGrad_[" << l << "]"
1993 << std::endl;
1994 if(vS2PrimeNoZeroGrad_[l] != 0)
1995 ss << vS2PrimeNoZeroGrad_[l] << " vS2PrimeNoZeroGrad_[" << l << "]"
1996 << std::endl;
1997 origins2NoZeroGrad_[l] = 0;
1998 origins2PrimeNoZeroGrad_[l] = 0;
1999 vS2NoZeroGrad_[l] = 0;
2000 vS2PrimeNoZeroGrad_[l] = 0;
2001 }
2002 if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_))
2003 ss << "origins_[" << l << "] has big values!" << std::endl;
2004 if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_))
2005 ss << "originsPrime_[" << l << "] has big values!" << std::endl;
2006 if(ss.rdbuf()->in_avail() != 0)
2008 }
2009 }
2010
2011 ++iteration;
2012 if(maxIteration_ != 0 and iteration >= maxIteration_) {
2013 printMsg("iteration >= maxIteration_", debug::Priority::DETAIL);
2014 break;
2015 }
2016 }
2018 printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss,
2019 bestIteration, iteration, bestTime);
2021 bestLoss_ = minLoss;
2022
2023 Timer t_copy;
2024 copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
2025 bestOrigins2, bestOrigins2Prime, bestVS2Tensor, bestVS2PrimeTensor,
2026 bestAlphasInit, origins_, originsPrime_, vSTensor_, vSPrimeTensor_,
2027 origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_,
2028 allAlphas_);
2029 copyParams(bestRecs, recs_);
2030 copyParams(bestRecs2, recs2_);
2031 t_allVectorCopy_time_ += t_copy.getElapsedTime();
2032 printMsg("Copy time", 1, t_allVectorCopy_time_, threadNumber_);
2033}
2034
2035// ---------------------------------------------------------------------------
2036// --- Custom Losses
2037// ---------------------------------------------------------------------------
2038double ttk::MergeTreeAutoencoder::getCustomLossDynamicWeight(double recLoss,
2039 double &baseLoss) {
2040 baseLoss = std::min(recLoss, baseLoss);
2041 if(customLossDynamicWeight_)
2042 return baseLoss;
2043 else
2044 return 1.0;
2045}
2046
/// Compute the pairwise distance matrix between a set of merge trees.
///
/// \param tmts input trees.
/// \param distanceMatrix output: symmetric tmts.size() x tmts.size() matrix,
///   diagonal left at 0.
/// \param useDoubleInput forwarded to computeOneDistance (two-input mode,
///   e.g. join and split trees processed together).
/// \param isFirstInput forwarded to computeOneDistance, tells whether these
///   trees are the first of the two inputs.
///
/// When OpenMP is enabled, each (i, j) pair is computed in its own task
/// spawned from a single producer thread.
void ttk::MergeTreeAutoencoder::getDistanceMatrix(
  std::vector<mtu::TorchMergeTree<float>> &tmts,
  std::vector<std::vector<float>> &distanceMatrix,
  bool useDoubleInput,
  bool isFirstInput) {
  distanceMatrix.clear();
  distanceMatrix.resize(tmts.size(), std::vector<float>(tmts.size(), 0));
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
  shared(distanceMatrix, tmts)
  {
#pragma omp single nowait
    {
#endif
      // only the upper triangle is computed; the result is mirrored
      for(unsigned int i = 0; i < tmts.size(); ++i) {
        for(unsigned int j = i + 1; j < tmts.size(); ++j) {
#ifdef TTK_ENABLE_OPENMP
#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j)
          {
#endif
            std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
            float distance;
            bool isCalled = true;
            computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance,
                               isCalled, useDoubleInput, isFirstInput);
            // tasks write disjoint cells, no synchronization needed
            distanceMatrix[i][j] = distance;
            distanceMatrix[j][i] = distance;
#ifdef TTK_ENABLE_OPENMP
          } // pragma omp task
#endif
        }
      }
#ifdef TTK_ENABLE_OPENMP
#pragma omp taskwait
    } // pragma omp single nowait
  } // pragma omp parallel
#endif
}
2086
2087void ttk::MergeTreeAutoencoder::getDistanceMatrix(
2088 std::vector<mtu::TorchMergeTree<float>> &tmts,
2089 std::vector<mtu::TorchMergeTree<float>> &tmts2,
2090 std::vector<std::vector<float>> &distanceMatrix) {
2091 getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_);
2092 if(useDoubleInput_) {
2093 std::vector<std::vector<float>> distanceMatrix2;
2094 getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_, false);
2095 mixDistancesMatrix<float>(distanceMatrix, distanceMatrix2);
2096 }
2097}
2098
2099void ttk::MergeTreeAutoencoder::getDifferentiableDistanceFromMatchings(
2100 mtu::TorchMergeTree<float> &tree1,
2101 mtu::TorchMergeTree<float> &tree2,
2102 mtu::TorchMergeTree<float> &tree1_2,
2103 mtu::TorchMergeTree<float> &tree2_2,
2104 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
2105 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
2106 torch::Tensor &tensorDist,
2107 bool doSqrt) {
2108 torch::Tensor reorderedITensor, reorderedJTensor;
2109 dataReorderingGivenMatching(
2110 tree1, tree2, matchings, reorderedITensor, reorderedJTensor);
2111 if(useDoubleInput_) {
2112 torch::Tensor reorderedI2Tensor, reorderedJ2Tensor;
2113 dataReorderingGivenMatching(
2114 tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor);
2115 reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor});
2116 reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor});
2117 }
2118 tensorDist = (reorderedITensor - reorderedJTensor).pow(2).sum();
2119 if(doSqrt)
2120 tensorDist = tensorDist.sqrt();
2121}
2122
2123void ttk::MergeTreeAutoencoder::getDifferentiableDistance(
2124 mtu::TorchMergeTree<float> &tree1,
2125 mtu::TorchMergeTree<float> &tree2,
2126 mtu::TorchMergeTree<float> &tree1_2,
2127 mtu::TorchMergeTree<float> &tree2_2,
2128 torch::Tensor &tensorDist,
2129 bool isCalled,
2130 bool doSqrt) {
2131 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matchings,
2132 matchings2;
2133 float distance;
2134 computeOneDistance<float>(
2135 tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_);
2136 if(useDoubleInput_) {
2137 float distance2;
2138 computeOneDistance<float>(tree1_2.mTree, tree2_2.mTree, matchings2,
2139 distance2, isCalled, useDoubleInput_, false);
2140 }
2141 getDifferentiableDistanceFromMatchings(
2142 tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt);
2143}
2144
2145void ttk::MergeTreeAutoencoder::getDifferentiableDistance(
2146 mtu::TorchMergeTree<float> &tree1,
2147 mtu::TorchMergeTree<float> &tree2,
2148 torch::Tensor &tensorDist,
2149 bool isCalled,
2150 bool doSqrt) {
2151 mtu::TorchMergeTree<float> tree1_2, tree2_2;
2152 getDifferentiableDistance(
2153 tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt);
2154}
2155
/// Compute the pairwise matrix of differentiable distance tensors between a
/// set of merge trees.
///
/// \param trees first-input trees (non-owning pointers).
/// \param trees2 second-input trees, used when useDoubleInput_ is set.
///   NOTE(review): trees2[i] is dereferenced unconditionally below even when
///   useDoubleInput_ is false — callers appear to pass an empty vector in
///   that case; confirm indexing is safe for all call sites.
/// \param outDistMat output: symmetric matrix of 0-dim tensors, diagonal set
///   to the constant tensor 0. Squared distances (doSqrt = false).
///
/// When OpenMP is enabled, each (i, j) pair is computed in its own task.
void ttk::MergeTreeAutoencoder::getDifferentiableDistanceMatrix(
  std::vector<mtu::TorchMergeTree<float> *> &trees,
  std::vector<mtu::TorchMergeTree<float> *> &trees2,
  std::vector<std::vector<torch::Tensor>> &outDistMat) {
  outDistMat.resize(trees.size(), std::vector<torch::Tensor>(trees.size()));
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
  shared(trees, trees2, outDistMat)
  {
#pragma omp single nowait
    {
#endif
      for(unsigned int i = 0; i < trees.size(); ++i) {
        outDistMat[i][i] = torch::tensor(0);
        for(unsigned int j = i + 1; j < trees.size(); ++j) {
#ifdef TTK_ENABLE_OPENMP
#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j)
          {
#endif
            bool isCalled = true;
            bool doSqrt = false;
            torch::Tensor tensorDist;
            getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]),
                                      *(trees2[j]), tensorDist, isCalled,
                                      doSqrt);
            // tasks write disjoint cells; the matrix is symmetric
            outDistMat[i][j] = tensorDist;
            outDistMat[j][i] = tensorDist;
#ifdef TTK_ENABLE_OPENMP
          } // pragma omp task
#endif
        }
      }
#ifdef TTK_ENABLE_OPENMP
#pragma omp taskwait
    } // pragma omp single nowait
  } // pragma omp parallel
#endif
}
2194
2195void ttk::MergeTreeAutoencoder::getAlphasTensor(
2196 std::vector<std::vector<torch::Tensor>> &alphas,
2197 std::vector<unsigned int> &indexes,
2198 unsigned int layerIndex,
2199 torch::Tensor &alphasOut) {
2200 alphasOut = alphas[indexes[0]][layerIndex].transpose(0, 1);
2201 for(unsigned int ind = 1; ind < indexes.size(); ++ind)
2202 alphasOut = torch::cat(
2203 {alphasOut, alphas[indexes[ind]][layerIndex].transpose(0, 1)});
2204}
2205
2206void ttk::MergeTreeAutoencoder::computeMetricLoss(
2207 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
2208 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
2209 std::vector<std::vector<torch::Tensor>> alphas,
2210 std::vector<std::vector<float>> &baseDistanceMatrix,
2211 std::vector<unsigned int> &indexes,
2212 torch::Tensor &metricLoss) {
2213 auto layerIndex = getLatentLayerIndex();
2214 std::vector<std::vector<torch::Tensor>> losses(
2215 layersOuts.size(), std::vector<torch::Tensor>(layersOuts.size()));
2216
2217 std::vector<mtu::TorchMergeTree<float> *> trees, trees2;
2218 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
2219 unsigned int i = indexes[ind];
2220 trees.emplace_back(&(layersOuts[i][layerIndex]));
2221 if(useDoubleInput_)
2222 trees2.emplace_back(&(layersOuts2[i][layerIndex]));
2223 }
2224
2225 std::vector<std::vector<torch::Tensor>> outDistMat;
2226 torch::Tensor coefDistMat;
2227 if(customLossSpace_) {
2228 getDifferentiableDistanceMatrix(trees, trees2, outDistMat);
2229 } else {
2230 std::vector<std::vector<torch::Tensor>> scaledAlphas;
2231 createScaledAlphas(alphas, vSTensor_, scaledAlphas);
2232 torch::Tensor latentAlphas;
2233 getAlphasTensor(scaledAlphas, indexes, layerIndex, latentAlphas);
2234 if(customLossActivate_)
2235 latentAlphas = activation(latentAlphas);
2236 coefDistMat = torch::cdist(latentAlphas, latentAlphas).pow(2);
2237 }
2238
2239 torch::Tensor maxLoss = torch::tensor(0);
2240 metricLoss = torch::tensor(0);
2241 float div = 0;
2242 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
2243 unsigned int i = indexes[ind];
2244 for(unsigned int ind2 = ind + 1; ind2 < indexes.size(); ++ind2) {
2245 unsigned int j = indexes[ind2];
2246 torch::Tensor loss;
2247 torch::Tensor toCompare
2248 = (customLossSpace_ ? outDistMat[i][j] : coefDistMat[ind][ind2]);
2249 loss = torch::nn::MSELoss()(
2250 torch::tensor(baseDistanceMatrix[i][j]), toCompare);
2251 metricLoss = metricLoss + loss;
2252 maxLoss = torch::max(loss, maxLoss);
2253 ++div;
2254 }
2255 }
2256 metricLoss = metricLoss / torch::tensor(div);
2257 if(normalizeMetricLoss_)
2258 metricLoss /= maxLoss;
2259}
2260
/// Compute the clustering loss: KL divergence between the soft assignment of
/// each batch tree to the latent centroids and its ground-truth (one-hot)
/// cluster assignment.
///
/// \param alphas per-tree, per-layer coefficients.
/// \param indexes batch: dataset indexes of the trees to consider.
/// \param clusteringLoss output loss tensor.
/// \param asgn output: soft (softmax) assignment of each tree to each
///   centroid.
void ttk::MergeTreeAutoencoder::computeClusteringLoss(
  std::vector<std::vector<torch::Tensor>> &alphas,
  std::vector<unsigned int> &indexes,
  torch::Tensor &clusteringLoss,
  torch::Tensor &asgn) {
  // Compute distance matrix
  unsigned int layerIndex = getLatentLayerIndex();
  torch::Tensor latentAlphas;
  getAlphasTensor(alphas, indexes, layerIndex, latentAlphas);
  if(customLossActivate_)
    latentAlphas = activation(latentAlphas);
  // stack all centroids into a single (noCentroids x dim) tensor
  torch::Tensor centroids = latentCentroids_[0].transpose(0, 1);
  for(unsigned int i = 1; i < latentCentroids_.size(); ++i)
    centroids = torch::cat({centroids, latentCentroids_[i].transpose(0, 1)});
  torch::Tensor dist = torch::cdist(latentAlphas, centroids);

  // Compute softmax and one hot real asgn
  // temperature-scaled softmax: closer centroid -> higher probability
  dist = dist * -clusteringLossTemp_;
  asgn = torch::nn::Softmax(1)(dist);
  std::vector<float> clusterAsgn;
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    clusterAsgn.emplace_back(clusterAsgn_[indexes[ind]]);
  }
  torch::Tensor realAsgn = torch::tensor(clusterAsgn).to(torch::kInt64);
  realAsgn
    = torch::nn::functional::one_hot(realAsgn, asgn.sizes()[1]).to(torch::kF32);

  // Compute KL div.
  // NOTE(review): torch::nn::KLDivLoss expects its *input* in log-space
  // (log-probabilities), but asgn here is a plain softmax output — confirm
  // whether a log_softmax / asgn.log() was intended.
  clusteringLoss = torch::nn::KLDivLoss(
    torch::nn::KLDivLossOptions().reduction(torch::kBatchMean))(asgn, realAsgn);
}
2292
2293void ttk::MergeTreeAutoencoder::computeTrackingLoss(
2294 torch::Tensor &trackingLoss) {
2295 unsigned int latentLayerIndex = getLatentLayerIndex() + 1;
2296 auto endLayer = (trackingLossDecoding_ ? noLayers_ : latentLayerIndex);
2297 std::vector<torch::Tensor> losses(endLayer);
2298#ifdef TTK_ENABLE_OPENMP
2299#pragma omp parallel for schedule(dynamic) \
2300 num_threads(this->threadNumber_) if(parallelize_)
2301#endif
2302 for(unsigned int l = 0; l < endLayer; ++l) {
2303 auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
2304 auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
2305 torch::Tensor tensorDist;
2306 bool isCalled = true, doSqrt = false;
2307 getDifferentiableDistance(tree1, tree2, tensorDist, isCalled, doSqrt);
2308 losses[l] = tensorDist;
2309 }
2310 trackingLoss = torch::tensor(0, torch::kFloat32);
2311 for(unsigned int i = 0; i < losses.size(); ++i)
2312 trackingLoss += losses[i];
2313}
2314
2315// ---------------------------------------------------------------------------
2316// --- End Functions
2317// ---------------------------------------------------------------------------
/// Reconstruct merge trees from user-provided latent coefficients
/// (customAlphas_): decode each coefficient vector through the decoder
/// layers, then compute matchings against the input origin and postprocess
/// the reconstructions.
///
/// \param origins per-layer input origins of the network.
/// \param originsPrime per-layer output origins of the network.
void ttk::MergeTreeAutoencoder::createCustomRecs(
  std::vector<mtu::TorchMergeTree<float>> &origins,
  std::vector<mtu::TorchMergeTree<float>> &originsPrime) {
  if(customAlphas_.empty())
    return;

  // When the coefficients of the training trees are available, stack them
  // per layer ((noTrees x dim) tensors) to initialize the decoding below.
  bool initByTreesAlphas = not allAlphas_.empty();
  std::vector<torch::Tensor> allTreesAlphas;
  if(initByTreesAlphas) {
    allTreesAlphas.resize(allAlphas_[0].size());
    for(unsigned int l = 0; l < allTreesAlphas.size(); ++l) {
      allTreesAlphas[l] = allAlphas_[0][l].reshape({-1, 1});
      for(unsigned int i = 1; i < allAlphas_.size(); ++i)
        allTreesAlphas[l]
          = torch::cat({allTreesAlphas[l], allAlphas_[i][l]}, 1);
      allTreesAlphas[l] = allTreesAlphas[l].transpose(0, 1);
    }
  }

  unsigned int latLayer = getLatentLayerIndex();
  customRecs_.resize(customAlphas_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customAlphas_.size(); ++i) {
    torch::Tensor alphas = torch::tensor(customAlphas_[i]).reshape({-1, 1});

    // Express the custom coefficients as a linear combination of the
    // training trees' latent coefficients (least squares solve).
    torch::Tensor alphasWeight;
    if(initByTreesAlphas) {
      auto driver = "gelsd";
      alphasWeight = std::get<0>(torch::linalg::lstsq(
                                   allTreesAlphas[latLayer].transpose(0, 1),
                                   alphas, c10::nullopt, driver))
                       .transpose(0, 1);
    }

    // Reconst latent
    std::vector<mtu::TorchMergeTree<float>> outs, outs2;
    auto noOuts = noLayers_ - latLayer;
    outs.resize(noOuts);
    outs2.resize(noOuts);
    mtu::TorchMergeTree<float> out, out2;
    outputBasisReconstruction(originsPrime[latLayer], vSPrimeTensor_[latLayer],
                              origins2Prime_[latLayer],
                              vS2PrimeTensor_[latLayer], alphas, outs[0],
                              outs2[0]);
    // Decoding
    unsigned int k = 32;
    for(unsigned int l = latLayer + 1; l < noLayers_; ++l) {
      // Without training coefficients, try several random initializations
      // (normalized by the largest norm) and keep the best reconstruction.
      unsigned int noIter = (initByTreesAlphas ? 1 : 32);
      std::vector<torch::Tensor> allAlphasInit(noIter);
      torch::Tensor maxNorm;
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j) {
        allAlphasInit[j] = torch::randn({vSTensor_[l].sizes()[1], 1});
        auto norm = torch::linalg::vector_norm(
          allAlphasInit[j], 2, 0, false, c10::nullopt);
        if(j == 0 or maxNorm.item<float>() < norm.item<float>())
          maxNorm = norm;
      }
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j)
        allAlphasInit[j] /= maxNorm;
      float bestDistance = std::numeric_limits<float>::max();
      auto outIndex = l - latLayer;
      mtu::TorchMergeTree<float> outToUse;
      for(unsigned int j = 0; j < noIter; ++j) {
        torch::Tensor alphasInit, dataAlphas;
        if(initByTreesAlphas) {
          alphasInit
            = torch::matmul(alphasWeight, allTreesAlphas[l]).transpose(0, 1);
        } else {
          alphasInit = allAlphasInit[j];
        }
        float distance;
        forwardOneLayer(outs[outIndex - 1], origins[l], vSTensor_[l],
                        originsPrime[l], vSPrimeTensor_[l], outs2[outIndex - 1],
                        origins2_[l], vS2Tensor_[l], origins2Prime_[l],
                        vS2PrimeTensor_[l], k, alphasInit, outToUse,
                        outs2[outIndex], dataAlphas, distance);
        // keep the best output; the last layer writes the final custom rec.
        if(distance < bestDistance) {
          bestDistance = distance;
          mtu::copyTorchMergeTree<float>(
            outToUse, (l != noLayers_ - 1 ? outs[outIndex] : customRecs_[i]));
        }
      }
    }
  }

  // Match each reconstruction against the input origin.
  customMatchings_.resize(customRecs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    bool isCalled = true;
    float distance;
    computeOneDistance<float>(origins[0].mTree, customRecs_[i].mTree,
                              customMatchings_[i], distance, isCalled,
                              useDoubleInput_);
  }

  // Postprocess a copy of the origin (the original must stay untouched) and
  // each reconstruction; convert matchings out of branch-decomposition
  // space for merge trees.
  mtu::TorchMergeTree<float> originCopy;
  mtu::copyTorchMergeTree<float>(origins[0], originCopy);
  postprocessingPipeline<float>(&(originCopy.mTree.tree));
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    wae::fixTreePrecisionScalars(customRecs_[i].mTree);
    postprocessingPipeline<float>(&(customRecs_[i].mTree.tree));
    if(not isPersistenceDiagram_) {
      convertBranchDecompositionMatching<float>(&(originCopy.mTree.tree),
                                                &(customRecs_[i].mTree.tree),
                                                customMatchings_[i]);
    }
  }
}
2432
/// Compute the matchings used for tracking outputs: origin-to-origin
/// matchings across consecutive layers, origin-to-reconstruction matchings
/// for each input tree at each layer, and input-to-final-reconstruction
/// matchings.
void ttk::MergeTreeAutoencoder::computeTrackingInformation() {
  unsigned int latentLayerIndex = getLatentLayerIndex() + 1;
  // Origins matchings: layer 0 pairs origins_[0] with originsPrime_[0];
  // layer l pairs originsPrime_[l-1] with originsPrime_[l].
  originsMatchings_.resize(latentLayerIndex);
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int l = 0; l < latentLayerIndex; ++l) {
    auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
    auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
    bool isCalled = true;
    float distance;
    computeOneDistance<float>(tree1.mTree, tree2.mTree, originsMatchings_[l],
                              distance, isCalled, useDoubleInput_);
  }

  // Data matchings
  // one extra layer here: match each tree's per-layer reconstruction
  // against the corresponding layer origin
  ++latentLayerIndex;
  dataMatchings_.resize(latentLayerIndex);
  for(unsigned int l = 0; l < latentLayerIndex; ++l) {
    dataMatchings_[l].resize(recs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int i = 0; i < recs_.size(); ++i) {
      bool isCalled = true;
      float distance;
      auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
      computeOneDistance<float>(origin.mTree, recs_[i][l].mTree,
                                dataMatchings_[l][i], distance, isCalled,
                                useDoubleInput_);
    }
  }

  // Reconst matchings
  // match each input tree (recs_[i][0]) with its final reconstruction
  reconstMatchings_.resize(recs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < recs_.size(); ++i) {
    bool isCalled = true;
    float distance;
    auto l = recs_[i].size() - 1;
    computeOneDistance<float>(recs_[i][0].mTree, recs_[i][l].mTree,
                              reconstMatchings_[i], distance, isCalled,
                              useDoubleInput_);
  }
}
2483
2484void ttk::MergeTreeAutoencoder::createScaledAlphas(
2485 std::vector<std::vector<torch::Tensor>> &alphas,
2486 std::vector<torch::Tensor> &vSTensor,
2487 std::vector<std::vector<torch::Tensor>> &scaledAlphas) {
2488 scaledAlphas.clear();
2489 scaledAlphas.resize(
2490 alphas.size(), std::vector<torch::Tensor>(alphas[0].size()));
2491 for(unsigned int l = 0; l < alphas[0].size(); ++l) {
2492 torch::Tensor scale = vSTensor[l].pow(2).sum(0).sqrt();
2493 for(unsigned int i = 0; i < alphas.size(); ++i) {
2494 scaledAlphas[i][l] = alphas[i][l] * scale.reshape({-1, 1});
2495 }
2496 }
2497}
2498
/// Convenience overload: scale the stored coefficients (allAlphas_) with the
/// stored bases (vSTensor_) into allScaledAlphas_.
void ttk::MergeTreeAutoencoder::createScaledAlphas() {
  createScaledAlphas(allAlphas_, vSTensor_, allScaledAlphas_);
}
2502
2503void ttk::MergeTreeAutoencoder::createActivatedAlphas() {
2504 allActAlphas_ = allAlphas_;
2505 for(unsigned int i = 0; i < allActAlphas_.size(); ++i)
2506 for(unsigned int j = 0; j < allActAlphas_[i].size(); ++j)
2507 allActAlphas_[i][j] = activation(allActAlphas_[i][j]);
2508 createScaledAlphas(allActAlphas_, vSTensor_, allActScaledAlphas_);
2509}
2510
2511// ---------------------------------------------------------------------------
2512// --- Utils
2513// ---------------------------------------------------------------------------
2514void ttk::MergeTreeAutoencoder::copyParams(
2515 std::vector<mtu::TorchMergeTree<float>> &srcOrigins,
2516 std::vector<mtu::TorchMergeTree<float>> &srcOriginsPrime,
2517 std::vector<torch::Tensor> &srcVS,
2518 std::vector<torch::Tensor> &srcVSPrime,
2519 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2,
2520 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2Prime,
2521 std::vector<torch::Tensor> &srcVS2,
2522 std::vector<torch::Tensor> &srcVS2Prime,
2523 std::vector<std::vector<torch::Tensor>> &srcAlphas,
2524 std::vector<mtu::TorchMergeTree<float>> &dstOrigins,
2525 std::vector<mtu::TorchMergeTree<float>> &dstOriginsPrime,
2526 std::vector<torch::Tensor> &dstVS,
2527 std::vector<torch::Tensor> &dstVSPrime,
2528 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2,
2529 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2Prime,
2530 std::vector<torch::Tensor> &dstVS2,
2531 std::vector<torch::Tensor> &dstVS2Prime,
2532 std::vector<std::vector<torch::Tensor>> &dstAlphas) {
2533 dstOrigins.resize(noLayers_);
2534 dstOriginsPrime.resize(noLayers_);
2535 dstVS.resize(noLayers_);
2536 dstVSPrime.resize(noLayers_);
2537 dstAlphas.resize(srcAlphas.size(), std::vector<torch::Tensor>(noLayers_));
2538 if(useDoubleInput_) {
2539 dstOrigins2.resize(noLayers_);
2540 dstOrigins2Prime.resize(noLayers_);
2541 dstVS2.resize(noLayers_);
2542 dstVS2Prime.resize(noLayers_);
2543 }
2544 for(unsigned int l = 0; l < noLayers_; ++l) {
2545 mtu::copyTorchMergeTree(srcOrigins[l], dstOrigins[l]);
2546 mtu::copyTorchMergeTree(srcOriginsPrime[l], dstOriginsPrime[l]);
2547 mtu::copyTensor(srcVS[l], dstVS[l]);
2548 mtu::copyTensor(srcVSPrime[l], dstVSPrime[l]);
2549 if(useDoubleInput_) {
2550 mtu::copyTorchMergeTree(srcOrigins2[l], dstOrigins2[l]);
2551 mtu::copyTorchMergeTree(srcOrigins2Prime[l], dstOrigins2Prime[l]);
2552 mtu::copyTensor(srcVS2[l], dstVS2[l]);
2553 mtu::copyTensor(srcVS2Prime[l], dstVS2Prime[l]);
2554 }
2555 for(unsigned int i = 0; i < srcAlphas.size(); ++i)
2556 mtu::copyTensor(srcAlphas[i][l], dstAlphas[i][l]);
2557 }
2558}
2559
2560void ttk::MergeTreeAutoencoder::copyParams(
2561 std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
2562 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst) {
2563 dst.resize(src.size());
2564 for(unsigned int i = 0; i < src.size(); ++i) {
2565 dst[i].resize(src[i].size());
2566 for(unsigned int j = 0; j < src[i].size(); ++j)
2567 mtu::copyTorchMergeTree(src[i][j], dst[i][j]);
2568 }
2569}
2570
2571unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() {
2572 unsigned int idx = noLayers_ / 2 - 1;
2573 if(idx > noLayers_) // unsigned negativeness
2574 idx = 0;
2575 return idx;
2576}
2577
2578// ---------------------------------------------------------------------------
2579// --- Testing
2580// ---------------------------------------------------------------------------
2581bool ttk::MergeTreeAutoencoder::isTreeHasBigValues(ftm::MergeTree<float> &mTree,
2582 float threshold) {
2583 bool found = false;
2584 for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) {
2585 if(mTree.tree.isNodeAlone(n))
2586 continue;
2587 auto birthDeath = mTree.tree.template getBirthDeath<float>(n);
2588 if(std::abs(std::get<0>(birthDeath)) > threshold
2589 or std::abs(std::get<1>(birthDeath)) > threshold) {
2590 found = true;
2591 break;
2592 }
2593 }
2594 return found;
2595}
2596#endif
2597
2598// ---------------------------------------------------------------------------
2599// --- Main Functions
2600// ---------------------------------------------------------------------------
2601
/// Main entry point: preprocess the input trees, fit the autoencoder, then
/// compute derived outputs (scaled/activated coefficients, tracking and
/// correlation information, custom reconstructions) and postprocess all
/// trees for output.
///
/// \param trees first input ensemble of merge trees (modified in place).
/// \param trees2 optional second ensemble; empty disables double-input mode.
// NOTE(review): the signature line below was garbled in this rendering and
// is restored from the function's declaration — confirm against the header.
void ttk::MergeTreeAutoencoder::execute(
  std::vector<ftm::MergeTree<float>> &trees,
  std::vector<ftm::MergeTree<float>> &trees2) {
#ifndef TTK_ENABLE_TORCH
  TTK_FORCE_USE(trees);
  TTK_FORCE_USE(trees2);
  printErr("This module requires Torch.");
#else
#ifdef TTK_ENABLE_OPENMP
  // Save the nested-parallelism state and enable it; restored at the end.
  int ompNested = omp_get_nested();
  omp_set_nested(1);
#endif
  // --- Preprocessing
  Timer t_preprocess;
  preprocessingTrees<float>(trees, treesNodeCorr_);
  if(trees2.size() != 0)
    preprocessingTrees<float>(trees2, trees2NodeCorr_);
  printMsg("Preprocessing", 1, t_preprocess.getElapsedTime(), threadNumber_);
  useDoubleInput_ = (trees2.size() != 0);

  // --- Fit autoencoder
  Timer t_total;
  fit(trees, trees2);
  // Report the fit time minus the time spent on bookkeeping vector copies.
  auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_;
  printMsg("Total time", 1, totalTime, threadNumber_);
  hasComputedOnce_ = true;

  // --- End functions
  createScaledAlphas();
  createActivatedAlphas();
  computeTrackingInformation();
  // Correlation
  // Gather, for each geodesic of the latent layer, the coordinate of every
  // input tree on that geodesic: allTs[geodesic][tree].
  auto latLayer = getLatentLayerIndex();
  std::vector<std::vector<double>> allTs;
  auto noGeod = allAlphas_[0][latLayer].sizes()[0];
  allTs.resize(noGeod);
  for(unsigned int i = 0; i < noGeod; ++i) {
    allTs[i].resize(allAlphas_.size());
    for(unsigned int j = 0; j < allAlphas_.size(); ++j)
      allTs[i][j] = allAlphas_[j][latLayer][i].item<double>();
  }
  computeBranchesCorrelationMatrix(origins_[0].mTree, trees, dataMatchings_[0],
                                   allTs, branchesCorrelationMatrix_,
                                   persCorrelationMatrix_);
  // Custom recs
  // Copy the origins before postprocessing mutates them below.
  originsCopy_.resize(origins_.size());
  originsPrimeCopy_.resize(originsPrime_.size());
  for(unsigned int l = 0; l < origins_.size(); ++l) {
    mtu::copyTorchMergeTree<float>(origins_[l], originsCopy_[l]);
    mtu::copyTorchMergeTree<float>(originsPrime_[l], originsPrimeCopy_[l]);
  }
  createCustomRecs(originsCopy_, originsPrimeCopy_);

  // --- Postprocessing
  if(createOutput_) {
    for(unsigned int i = 0; i < trees.size(); ++i)
      postprocessingPipeline<float>(&(trees[i].tree));
    for(unsigned int i = 0; i < trees2.size(); ++i)
      postprocessingPipeline<float>(&(trees2[i].tree));
    for(unsigned int l = 0; l < origins_.size(); ++l) {
      fillMergeTreeStructure(origins_[l]);
      postprocessingPipeline<float>(&(origins_[l].mTree.tree));
      fillMergeTreeStructure(originsPrime_[l]);
      postprocessingPipeline<float>(&(originsPrime_[l].mTree.tree));
    }
    for(unsigned int j = 0; j < recs_[0].size(); ++j) {
      for(unsigned int i = 0; i < recs_.size(); ++i) {
        // Re-enforce the nesting condition after float rounding.
        wae::fixTreePrecisionScalars(recs_[i][j].mTree);
        postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
      }
    }
  }

  if(not isPersistenceDiagram_) {
    // Convert all matchings from branch-decomposition space back to the
    // original merge-tree node space.
    for(unsigned int l = 0; l < originsMatchings_.size(); ++l) {
      auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
      auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
      convertBranchDecompositionMatching<float>(
        &(tree1.mTree.tree), &(tree2.mTree.tree), originsMatchings_[l]);
    }
    for(unsigned int l = 0; l < dataMatchings_.size(); ++l) {
      for(unsigned int i = 0; i < recs_.size(); ++i) {
        auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
        convertBranchDecompositionMatching<float>(&(origin.mTree.tree),
                                                  &(recs_[i][l].mTree.tree),
                                                  dataMatchings_[l][i]);
      }
    }
    // Matchings between each input tree and its final reconstruction.
    for(unsigned int i = 0; i < reconstMatchings_.size(); ++i) {
      auto l = recs_[i].size() - 1;
      convertBranchDecompositionMatching<float>(&(recs_[i][0].mTree.tree),
                                                &(recs_[i][l].mTree.tree),
                                                reconstMatchings_[i]);
    }
  }
#ifdef TTK_ENABLE_OPENMP
  omp_set_nested(ompNested);
#endif
#endif
}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define ENERGY_COMPARISON_TOLERANCE
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
double getElapsedTime()
Definition Timer.h:15
int scaleVector(const T *a, const T factor, T *out, const int &dimension=3)
Definition Geometry.cpp:647
T dotProduct(const T *vA0, const T *vA1, const T *vB0, const T *vB1)
Definition Geometry.cpp:388
int flattenMultiDimensionalVector(const std::vector< std::vector< T > > &a, std::vector< T > &out)
Definition Geometry.cpp:742
T1 pow(const T1 val, const T2 n)
Definition Geometry.h:456
T magnitude(const T *v, const int &dimension=3)
Definition Geometry.cpp:509
T distance(const T *p0, const T *p1, const int &dimension=3)
Definition Geometry.cpp:362
MergeTree< dataType > copyMergeTree(ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs=false)
unsigned int idNode
Node index in vect_nodes_.
void printPairs(ftm::MergeTree< float > &mTree, bool useBD=true)
Util function to print pairs of a merge tree.
void adjustNestingScalars(std::vector< float > &scalarsVector, ftm::idNode node, ftm::idNode refNode)
Fix the scalars of a merge tree to ensure that the nesting condition is respected.
void createBalancedBDT(std::vector< std::vector< ftm::idNode > > &parents, std::vector< std::vector< ftm::idNode > > &children, std::vector< float > &scalarsVector, std::vector< std::vector< ftm::idNode > > &childrenFinal, int threadNumber=1)
Create a balanced BDT structure (for output basis initialization).
void fixTreePrecisionScalars(ftm::MergeTree< float > &mTree)
Fix the scalars of a merge tree to ensure that the nesting condition is respected.
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:483
T begin(std::pair< T, T > &p)
Definition ripserpy.cpp:479
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)