using namespace torch::indexing;

#ifdef TTK_ENABLE_TORCH
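// Build the merge tree / persistence diagram structure of an output basis
// origin from the birth-death values stored in its tensor; the nesting of the
// pairs is either derived directly or copied from baseOrigin's structure.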
void ttk::MergeTreeAutoencoder::initOutputBasisTreeStructure(
  mtu::TorchMergeTree<float> &originPrime,
  mtu::TorchMergeTree<float> &baseOrigin) {
  std::vector<float> scalarsVector(
    originPrime.tensor.data_ptr<float>(),
    originPrime.tensor.data_ptr<float>() + originPrime.tensor.numel());
  unsigned int noNodes = scalarsVector.size() / 2;
  std::vector<std::vector<ftm::idNode>> childrenFinal(noNodes);

  if(isPersistenceDiagram_) {
    for(unsigned int i = 2; i < scalarsVector.size(); i += 2)
      childrenFinal[0].emplace_back(i / 2);

    float maxPers = std::numeric_limits<float>::lowest();
    unsigned int indMax = 0;
    for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
      if(maxPers < (scalarsVector[i + 1] - scalarsVector[i])) {
        maxPers = (scalarsVector[i + 1] - scalarsVector[i]);

    float temp = scalarsVector[0];
    scalarsVector[0] = scalarsVector[indMax];
    scalarsVector[indMax] = temp;
    temp = scalarsVector[1];
    scalarsVector[1] = scalarsVector[indMax + 1];
    scalarsVector[indMax + 1] = temp;

    for(unsigned int i = 2; i < scalarsVector.size(); i += 2) {

  if(not initOriginPrimeStructByCopy_
     or (int)noNodes > baseOrigin.mTree.tree.getRealNumberOfNodes()) {
    std::vector<std::vector<ftm::idNode>> parents(noNodes), children(noNodes);
    for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
      for(unsigned int j = i; j < scalarsVector.size(); j += 2) {
        unsigned int iN = i / 2, jN = j / 2;
        if(scalarsVector[i] <= scalarsVector[j]
           and scalarsVector[i + 1] >= scalarsVector[j + 1]) {
          parents[jN].emplace_back(iN);
          children[iN].emplace_back(jN);
        } else if(scalarsVector[i] >= scalarsVector[j]
                  and scalarsVector[i + 1] <= scalarsVector[j + 1]) {
          parents[iN].emplace_back(jN);
          children[jN].emplace_back(iN);

      parents, children, scalarsVector, childrenFinal, this->threadNumber_);

    ftm::MergeTree<float> mTreeTemp
      = ftm::copyMergeTree<float>(baseOrigin.mTree);
    keepMostImportantPairs<float>(&(mTreeTemp.tree), noNodes, useBD);
    torch::Tensor reshaped = torch::tensor(scalarsVector).reshape({-1, 2});
    torch::Tensor order = torch::argsort(
      (reshaped.index({Slice(), 1}) - reshaped.index({Slice(), 0})), -1,
    std::vector<unsigned int> nodeCorr(mTreeTemp.tree.getNumberOfNodes(), 0);
    unsigned int nodeNum = 1;
    std::queue<ftm::idNode> queue;
    queue.emplace(mTreeTemp.tree.getRoot());
    while(!queue.empty()) {
      std::vector<ftm::idNode> children;
      mTreeTemp.tree.getChildren(node, children);
      for(auto &child : children) {
        queue.emplace(child);
        unsigned int tNode = nodeCorr[node];
        nodeCorr[child] = order[nodeNum].item<int>();
        unsigned int tChild = nodeCorr[child];
        childrenFinal[tNode].emplace_back(tChild);

  originPrime.mTree = ftm::createEmptyMergeTree<float>(scalarsVector.size());
  ftm::FTMTree_MT *tree = &(originPrime.mTree.tree);
  for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
    float temp = scalarsVector[i];
    scalarsVector[i] = scalarsVector[i + 1];
    scalarsVector[i + 1] = temp;

  ftm::setTreeScalars<float>(originPrime.mTree, scalarsVector);

  originPrime.nodeCorr.clear();
  originPrime.nodeCorr.assign(
    scalarsVector.size(), std::numeric_limits<unsigned int>::max());
  for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
    tree->makeNode(i + 1);
    tree->getNode(i)->setOrigin(i + 1);
    tree->getNode(i + 1)->setOrigin(i);
    originPrime.nodeCorr[i] = (unsigned int)(i / 2);

  for(unsigned int i = 0; i < scalarsVector.size(); i += 2) {
    unsigned int node = i / 2;
    for(auto &child : childrenFinal[node])
      tree->makeSuperArc(child * 2, i);

  mtu::getParentsVector(originPrime.mTree, originPrime.parentsOri);

  if(isTreeHasBigValues(originPrime.mTree, bigValuesThreshold_)) {
    std::stringstream ss;
    ss << originPrime.mTree.tree.printPairsFromTree<float>(true).str()
    ss << "isTreeHasBigValues(originPrime.mTree)" << std::endl;
    ss << "pause" << std::endl;
void ttk::MergeTreeAutoencoder::initOutputBasis(unsigned int l,

  unsigned int originSize = origins_[l].tensor.sizes()[0];
  unsigned int origin2Size = 0;
    origin2Size = origins2_[l].tensor.sizes()[0];

  auto initOutputBasisOrigin = [this, &l](torch::Tensor &w,
                                          mtu::TorchMergeTree<float> &tmt,
                                          mtu::TorchMergeTree<float> &baseTmt) {
    torch::nn::init::xavier_normal_(w);
    torch::Tensor baseTmtTensor = baseTmt.tensor;
    if(normalizedWasserstein_)
      mtu::mergeTreeToTorchTensor(baseTmt.mTree, baseTmtTensor, false);
    torch::Tensor b = torch::fill(torch::zeros({w.sizes()[0], 1}), 0.01);
    tmt.tensor = (torch::matmul(w, baseTmtTensor) + b);
    mtu::meanBirthMaxPersShift(tmt.tensor, baseTmtTensor);
    mtu::belowDiagonalPointsShift(tmt.tensor, baseTmtTensor);

      = (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1);
    if(trackingLossWeight_ != 0 and l < endLayer) {
        = (l == 0 ? origins_[0].tensor : originsPrime_[l - 1].tensor);
      auto baseTensorDiag = baseTensor.reshape({-1, 2});
      auto basePersDiag = (baseTensorDiag.index({Slice(), 1})
                           - baseTensorDiag.index({Slice(), 0}));
      auto tmtTensorDiag = tmt.tensor.reshape({-1, 2});
      auto persDiag = (tmtTensorDiag.index({Slice(1, None), 1})
                       - tmtTensorDiag.index({Slice(1, None), 0}));
      int noK = std::min(baseTensorDiag.sizes()[0], tmtTensorDiag.sizes()[0]);
      auto topVal = baseTensorDiag.index({std::get<1>(basePersDiag.topk(noK))});
      auto indexes = std::get<1>(persDiag.topk(noK - 1)) + 1;
      indexes = torch::cat({torch::zeros(1), indexes}).to(torch::kLong);
      if(trackingLossInitRandomness_ != 0) {
        topVal = (1 - trackingLossInitRandomness_) * topVal
                 + trackingLossInitRandomness_ * tmtTensorDiag.index({indexes});
      tmtTensorDiag.index_put_({indexes}, topVal);

    initOutputBasisTreeStructure(
      tmt, baseTmt.mTree.tree.isJoinTree<float>(), baseTmt);
    if(normalizedWasserstein_)
      mtu::mergeTreeToTorchTensor(tmt.mTree, tmt.tensor, true);
    interpolationProjection(tmt);

  torch::Tensor w = torch::zeros({dim, originSize});
  initOutputBasisOrigin(w, originsPrime_[l], origins_[l]);
  if(useDoubleInput_) {
    w2 = torch::zeros({dim2, origin2Size});
    initOutputBasisOrigin(w2, origins2Prime_[l], origins2_[l]);

  initOutputBasisVectors(l, w, w2);
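// Initialize the output basis vectors of layer l by mapping the input basis
// vectors through w (and w2 for the optional second input), with optional
// normalization.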
void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l,
  vSPrimeTensor_[l] = torch::matmul(w, vSTensor_[l]);
    vS2PrimeTensor_[l] = torch::matmul(w2, vS2Tensor_[l]);
  if(normalizedWasserstein_) {
    mtu::normalizeVectors(originsPrime_[l].tensor, vSPrimeTensor_[l]);
      mtu::normalizeVectors(origins2Prime_[l].tensor, vS2PrimeTensor_[l]);
void ttk::MergeTreeAutoencoder::initOutputBasisVectors(unsigned int l,
  unsigned int originSize = origins_[l].tensor.sizes()[0];
  unsigned int origin2Size = 0;
    origin2Size = origins2_[l].tensor.sizes()[0];
  torch::Tensor w = torch::zeros({dim, originSize});
  torch::nn::init::xavier_normal_(w);
  torch::Tensor w2 = torch::zeros({dim2, origin2Size});
  torch::nn::init::xavier_normal_(w2);
  initOutputBasisVectors(l, w, w2);
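// Compute the input basis origin(s) of a layer as the barycenter of the input
// trees (and of the optional second input), then convert them to their tensor
// representation.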
void ttk::MergeTreeAutoencoder::initInputBasisOrigin(
  std::vector<ftm::MergeTree<float>> &treesToUse,
  std::vector<ftm::MergeTree<float>> &trees2ToUse,
  double barycenterSizeLimitPercent,
  unsigned int barycenterMaxNoPairs,
  unsigned int barycenterMaxNoPairs2,
  mtu::TorchMergeTree<float> &origin,
  mtu::TorchMergeTree<float> &origin2,
  std::vector<double> &inputToBaryDistances,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  computeOneBarycenter<float>(treesToUse, origin.mTree, baryMatchings,
                              inputToBaryDistances, barycenterSizeLimitPercent,
  if(barycenterMaxNoPairs > 0)
    keepMostImportantPairs<float>(
      &(origin.mTree.tree), barycenterMaxNoPairs, true);
  if(useDoubleInput_) {
    std::vector<double> baryDistances2;
    computeOneBarycenter<float>(trees2ToUse, origin2.mTree, baryMatchings2,
                                baryDistances2, barycenterSizeLimitPercent,
                                useDoubleInput_, false);
    if(barycenterMaxNoPairs2 > 0)
      keepMostImportantPairs<float>(
        &(origin2.mTree.tree), barycenterMaxNoPairs2, true);
    for(unsigned int i = 0; i < inputToBaryDistances.size(); ++i)
      inputToBaryDistances[i]
        = mixDistances(inputToBaryDistances[i], baryDistances2[i]);

  mtu::getParentsVector(origin.mTree, origin.parentsOri);
  mtu::mergeTreeToTorchTensor<float>(
    origin.mTree, origin.tensor, origin.nodeCorr, normalizedWasserstein_);
  if(useDoubleInput_) {
    mtu::getParentsVector(origin2.mTree, origin2.parentsOri);
    mtu::mergeTreeToTorchTensor<float>(
      origin2.mTree, origin2.tensor, origin2.nodeCorr, normalizedWasserstein_);
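// Initialize the input basis vectors (axes) of layer l and the per-input
// projection coefficients (alphas), either through the assignment-based
// optimization (assignmentOneData) or, when euclideanVectorsInit_ is set,
// through a direct least-squares fit (computeAlphas).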
void ttk::MergeTreeAutoencoder::initInputBasisVectors(
  std::vector<mtu::TorchMergeTree<float>> &tmTreesToUse,
  std::vector<mtu::TorchMergeTree<float>> &tmTrees2ToUse,
  std::vector<ftm::MergeTree<float>> &treesToUse,
  std::vector<ftm::MergeTree<float>> &trees2ToUse,
  mtu::TorchMergeTree<float> &origin,
  mtu::TorchMergeTree<float> &origin2,
  unsigned int noVectors,
  std::vector<std::vector<torch::Tensor>> &allAlphasInit,
  std::vector<double> &inputToBaryDistances,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  torch::Tensor &vSTensor,
  torch::Tensor &vS2Tensor) {
  auto initializedVectorsProjection
      ftm::MergeTree<float> &ttkNotUsed(_barycenter),
      std::vector<std::vector<double>> &_v,
      std::vector<std::vector<double>> &ttkNotUsed(_v2),
      std::vector<std::vector<std::vector<double>>> &_vS,
      std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_v2s),
      ftm::MergeTree<float> &ttkNotUsed(_barycenter2),
      std::vector<std::vector<double>> &ttkNotUsed(_trees2V),
      std::vector<std::vector<double>> &ttkNotUsed(_trees2V2),
      std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2Vs),
      std::vector<std::vector<std::vector<double>>> &ttkNotUsed(_trees2V2s),

    std::vector<double> scaledV, scaledVSi;
    for(unsigned int i = 0; i < _vS.size(); ++i) {
      if(prod <= -1.0 + tol or prod >= 1.0 - tol) {
        for(unsigned int j = 0; j < _v.size(); ++j)
          for(unsigned int k = 0; k < _v[j].size(); ++k)

  std::vector<std::vector<double>> inputToAxesDistances;
  std::vector<std::vector<std::vector<double>>> vS, v2s, trees2Vs, trees2V2s;
  std::stringstream ss;
  for(unsigned int vecNum = 0; vecNum < noVectors; ++vecNum) {
    ss << "Compute vectors " << vecNum;
    std::vector<std::vector<double>> v1, v2, trees2V1, trees2V2;
    int newVectorOffset = 0;
    bool projectInitializedVectors = true;
    int bestIndex = MergeTreeAxesAlgorithmBase::initVectors<float>(
      vecNum, origin.mTree, treesToUse, origin2.mTree, trees2ToUse, v1, v2,
      trees2V1, trees2V2, newVectorOffset, inputToBaryDistances, baryMatchings,
      baryMatchings2, inputToAxesDistances, vS, v2s, trees2Vs, trees2V2s,
      projectInitializedVectors, initializedVectorsProjection);
    v2s.emplace_back(v2);
    trees2Vs.emplace_back(trees2V1);
    trees2V2s.emplace_back(trees2V2);

    ss << "bestIndex = " << bestIndex;

    inputToAxesDistances.resize(1, std::vector<double>(treesToUse.size()));
    if(bestIndex == -1 and normalizedWasserstein_) {
      mtu::normalizeVectors(origin, vS[vS.size() - 1]);
        mtu::normalizeVectors(origin2, trees2Vs[vS.size() - 1]);
    mtu::axisVectorsToTorchTensor(origin.mTree, vS, vSTensor);
    if(useDoubleInput_) {
      mtu::axisVectorsToTorchTensor(origin2.mTree, trees2Vs, vS2Tensor);
    mtu::TorchMergeTree<float> dummyTmt;
    std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
    for(unsigned int i = 0; i < treesToUse.size(); ++i) {
      auto &tmt2ToUse = (not useDoubleInput_ ? dummyTmt : tmTrees2ToUse[i]);
      if(not euclideanVectorsInit_) {
        auto newAlpha = torch::ones({1, 1});
        if(bestIndex == -1) {
          newAlpha = torch::zeros({1, 1});
        allAlphasInit[i][l] = (allAlphasInit[i][l].defined()
                                 ? torch::cat({allAlphasInit[i][l], newAlpha})
        torch::Tensor bestAlphas;
        bool isCalled = true;
        inputToAxesDistances[0][i] = assignmentOneData(
          tmTreesToUse[i], origin, vSTensor, tmt2ToUse, origin2, vS2Tensor, k,
          allAlphasInit[i][l], bestAlphas, isCalled);
        allAlphasInit[i][l] = bestAlphas.detach();
        auto &baryMatching2ToUse
          = (not useDoubleInput_ ? dummyBaryMatching2 : baryMatchings2[i]);
        torch::Tensor alphas;
        computeAlphas(tmTreesToUse[i], origin, vSTensor, origin,
                      baryMatchings[i], tmt2ToUse, origin2, vS2Tensor, origin2,
                      baryMatching2ToUse, alphas);
        mtu::TorchMergeTree<float> interpolated, interpolated2;
        getMultiInterpolation(origin, vSTensor, alphas, interpolated);
          getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
        torch::Tensor tensorDist;
        getDifferentiableDistanceFromMatchings(
          interpolated, tmTreesToUse[i], interpolated2, tmt2ToUse,
          baryMatchings[i], baryMatching2ToUse, tensorDist, doSqrt);
        inputToAxesDistances[0][i] = tensorDist.item<double>();
        allAlphasInit[i][l] = alphas.detach();
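// Initialize one latent centroid per cluster label in clusterAsgn_ as the
// (detached, gradient-enabled) mean of the latent coefficients of its members.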
void ttk::MergeTreeAutoencoder::initClusteringLossParameters() {
  unsigned int l = getLatentLayerIndex();
  unsigned int noCentroids
    = std::set<unsigned int>(clusterAsgn_.begin(), clusterAsgn_.end()).size();
  latentCentroids_.resize(noCentroids);
  for(unsigned int c = 0; c < noCentroids; ++c) {
    unsigned int firstIndex = std::numeric_limits<unsigned int>::max();
    for(unsigned int i = 0; i < clusterAsgn_.size(); ++i) {
      if(clusterAsgn_[i] == c) {

    if(firstIndex >= allAlphas_.size()) {
      printWrn("no data found for cluster " + std::to_string(c));

    latentCentroids_[c] = allAlphas_[firstIndex][l].detach().clone();
    for(unsigned int i = 0; i < allAlphas_.size(); ++i) {
      if(clusterAsgn_[i] == c) {
        latentCentroids_[c] += allAlphas_[i][l];

    latentCentroids_[c] /= torch::tensor(noData);
    latentCentroids_[c] = latentCentroids_[c].detach();
    latentCentroids_[c].requires_grad_(true);
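// Initialize all network parameters (input/output origins and basis vectors)
// layer by layer, and optionally return the resulting reconstruction error.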
float ttk::MergeTreeAutoencoder::initParameters(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  bool computeReconstructionError) {

  noLayers_ = encoderNoLayers_ * 2 + 1 + 1;
  if(encoderNoLayers_ <= -1)
  std::vector<double> layersOriginPrimeSizePercent(noLayers_);
  std::vector<unsigned int> layersNoAxes(noLayers_);
    layersNoAxes[0] = numberOfAxes_;
    layersOriginPrimeSizePercent[0] = latentSpaceOriginPrimeSizePercent_;
      layersNoAxes[1] = inputNumberOfAxes_;
      layersOriginPrimeSizePercent[1] = barycenterSizeLimitPercent_;

    for(unsigned int l = 0; l < noLayers_ / 2; ++l) {
      double alpha = (double)(l) / (noLayers_ / 2 - 1);
        = (1 - alpha) * inputNumberOfAxes_ + alpha * numberOfAxes_;
      layersNoAxes[l] = noAxes;
      layersNoAxes[noLayers_ - 1 - l] = noAxes;
      double originPrimeSizePercent
        = (1 - alpha) * inputOriginPrimeSizePercent_
          + alpha * latentSpaceOriginPrimeSizePercent_;
      layersOriginPrimeSizePercent[l] = originPrimeSizePercent;
      layersOriginPrimeSizePercent[noLayers_ - 1 - l] = originPrimeSizePercent;

    if(scaleLayerAfterLatent_)
      layersNoAxes[noLayers_ / 2]
        = (layersNoAxes[noLayers_ / 2 - 1] + layersNoAxes[noLayers_ / 2 + 1])

  std::vector<ftm::FTMTree_MT *> ftmTrees(trees.size()),
    ftmTrees2(trees2.size());
  for(unsigned int i = 0; i < trees.size(); ++i)
    ftmTrees[i] = &(trees[i].mTree.tree);
  for(unsigned int i = 0; i < trees2.size(); ++i)
    ftmTrees2[i] = &(trees2[i].mTree.tree);
  auto sizeMetric = getSizeLimitMetric(ftmTrees);
  auto sizeMetric2 = getSizeLimitMetric(ftmTrees2);
  auto getDim = [](double _sizeMetric, double _percent) {
    unsigned int dim = std::max((int)(_sizeMetric * _percent / 100.0), 2) * 2;

  origins_.resize(noLayers_);
  originsPrime_.resize(noLayers_);
  vSTensor_.resize(noLayers_);
  vSPrimeTensor_.resize(noLayers_);
  if(trees2.size() != 0) {
    origins2_.resize(noLayers_);
    origins2Prime_.resize(noLayers_);
    vS2Tensor_.resize(noLayers_);
    vS2PrimeTensor_.resize(noLayers_);

  bool fullSymmetricAE = fullSymmetricAE_;
  bool outputBasisActivation = activateOutputInit_;

  std::vector<mtu::TorchMergeTree<float>> recs, recs2;
  std::vector<std::vector<torch::Tensor>> allAlphasInit(
    trees.size(), std::vector<torch::Tensor>(noLayers_));
  for(unsigned int l = 0; l < noLayers_; ++l) {
    std::stringstream ss;
    ss << "Init Layer " << l;

    if(l < (unsigned int)(noLayers_ / 2) or not fullSymmetricAE
       or (noLayers_ <= 2 and not fullSymmetricAE)) {
      std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
      for(unsigned int i = 0; i < trees.size(); ++i) {
        treesToUse.emplace_back((l == 0 ? trees[i].mTree : recs[i].mTree));
        if(trees2.size() != 0)
          trees2ToUse.emplace_back((l == 0 ? trees2[i].mTree : recs2[i].mTree));

      std::vector<double> inputToBaryDistances;
      std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
        baryMatchings, baryMatchings2;
      if(l != 0 or not origins_[0].tensor.defined()) {
        double sizeLimit = (l == 0 ? barycenterSizeLimitPercent_ : 0);
        unsigned int maxNoPairs
          = (l == 0 ? 0 : originsPrime_[l - 1].tensor.sizes()[0] / 2);
        unsigned int maxNoPairs2
          = (l == 0 or not useDoubleInput_
              : origins2Prime_[l - 1].tensor.sizes()[0] / 2);
        initInputBasisOrigin(treesToUse, trees2ToUse, sizeLimit, maxNoPairs,
                             maxNoPairs2, origins_[l], origins2_[l],
                             inputToBaryDistances, baryMatchings,
          baryMatchings_L0_ = baryMatchings;
          baryMatchings2_L0_ = baryMatchings2;
          inputToBaryDistances_L0_ = inputToBaryDistances;

        baryMatchings = baryMatchings_L0_;
        baryMatchings2 = baryMatchings2_L0_;
        inputToBaryDistances = inputToBaryDistances_L0_;

      printMsg("Compute origin time", 1, t_origin.getElapsedTime(),

      auto &tmTreesToUse = (l == 0 ? trees : recs);
      auto &tmTrees2ToUse = (l == 0 ? trees2 : recs2);
      initInputBasisVectors(
        tmTreesToUse, tmTrees2ToUse, treesToUse, trees2ToUse, origins_[l],
        origins2_[l], layersNoAxes[l], allAlphasInit, l, inputToBaryDistances,
        baryMatchings, baryMatchings2, vSTensor_[l], vS2Tensor_[l]);
      printMsg("Compute vectors time", 1, t_vectors.getElapsedTime(),

      unsigned int middle = noLayers_ / 2;
      unsigned int l_opp = middle - (l - middle + 1);
      mtu::copyTorchMergeTree(originsPrime_[l_opp], origins_[l]);
      mtu::copyTensor(vSPrimeTensor_[l_opp], vSTensor_[l]);
      if(trees2.size() != 0) {
        if(fullSymmetricAE) {
          mtu::copyTorchMergeTree(origins2Prime_[l_opp], origins2_[l]);
          mtu::copyTensor(vS2PrimeTensor_[l_opp], vS2Tensor_[l]);

      for(unsigned int i = 0; i < trees.size(); ++i)
        allAlphasInit[i][l] = allAlphasInit[i][l_opp];

    auto initOutputBasisSpecialCase
      = [this, &l, &layersNoAxes, &trees, &trees2]() {
        mtu::copyTorchMergeTree(origins_[0], originsPrime_[l]);
          mtu::copyTorchMergeTree(origins2_[0], origins2Prime_[l]);
        if(layersNoAxes[l] != layersNoAxes[0]) {
          std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
          for(unsigned int i = 0; i < trees.size(); ++i) {
            treesToUse.emplace_back(trees[i].mTree);
              trees2ToUse.emplace_back(trees2[i].mTree);
          std::vector<std::vector<torch::Tensor>> allAlphasInitT(
            trees.size(), std::vector<torch::Tensor>(noLayers_));
          initInputBasisVectors(
            trees, trees2, treesToUse, trees2ToUse, originsPrime_[l],
            origins2Prime_[l], layersNoAxes[l], allAlphasInitT, l,
            inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_,
            vSPrimeTensor_[l], vS2PrimeTensor_[l]);
          mtu::copyTensor(vSTensor_[0], vSPrimeTensor_[l]);
            mtu::copyTensor(vS2Tensor_[0], vS2PrimeTensor_[l]);

    if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
      initOutputBasisSpecialCase();
    } else if(l < (unsigned int)(noLayers_ / 2)) {
      unsigned int dim = getDim(sizeMetric, layersOriginPrimeSizePercent[l]);
      dim = std::min(dim, (unsigned int)origins_[l].tensor.sizes()[0]);
      unsigned int dim2 = getDim(sizeMetric2, layersOriginPrimeSizePercent[l]);
      if(trees2.size() != 0)
        dim2 = std::min(dim2, (unsigned int)origins2_[l].tensor.sizes()[0]);
      initOutputBasis(l, dim, dim2);

      unsigned int middle = noLayers_ / 2;
      unsigned int l_opp = middle - (l - middle + 1);
      mtu::copyTorchMergeTree(origins_[l_opp], originsPrime_[l]);
      if(trees2.size() != 0)
        mtu::copyTorchMergeTree(origins2_[l_opp], origins2Prime_[l]);
      if(l == (unsigned int)(noLayers_) / 2 and scaleLayerAfterLatent_) {
          = (trees2.size() != 0 ? origins2Prime_[l].tensor.sizes()[0] : 0);
        initOutputBasisVectors(l, originsPrime_[l].tensor.sizes()[0], dim2);

        mtu::copyTensor(vSTensor_[l_opp], vSPrimeTensor_[l]);
        if(trees2.size() != 0)
          mtu::copyTensor(vS2Tensor_[l_opp], vS2PrimeTensor_[l]);

    recs.resize(trees.size());
    recs2.resize(trees.size());
    unsigned int noReset = 0;
    while(i < trees.size()) {
      outputBasisReconstruction(originsPrime_[l], vSPrimeTensor_[l],
                                origins2Prime_[l], vS2PrimeTensor_[l],
                                allAlphasInit[i][l], recs[i], recs2[i],
                                outputBasisActivation);
      if(recs[i].mTree.tree.getRealNumberOfNodes() == 0) {
        if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
          initOutputBasisSpecialCase();
        } else if(l < (unsigned int)(noLayers_ / 2)) {
            getDim(sizeMetric, layersOriginPrimeSizePercent[l]),
            getDim(sizeMetric2, layersOriginPrimeSizePercent[l]));

          printErr("recs[i].mTree.tree.getRealNumberOfNodes() == 0");
          std::stringstream ssT;
          ssT << "layer " << l;
          return std::numeric_limits<float>::max();

        printWrn("[initParameters] noReset >= 100");
        return std::numeric_limits<float>::max();

  allAlphas_ = allAlphasInit;

  if(clusteringLossWeight_ != 0)
    initClusteringLossParameters();

  float error = 0.0, recLoss = 0.0;
  if(computeReconstructionError) {
    std::vector<unsigned int> indexes(trees.size());
    std::iota(indexes.begin(), indexes.end(), 0);
    std::vector<std::vector<torch::Tensor>> bestAlphas;
    std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
    std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
      matchings, matchings2;
      = forwardStep(trees, trees2, indexes, k, allAlphasInit,
                    computeReconstructionError, recs, recs2, bestAlphas,
                    layersOuts, layersOuts2, matchings, matchings2, recLoss);
      printWrn("[initParameters] forwardStep reset");
      return std::numeric_limits<float>::max();

    error = recLoss * reconstructionLossWeight_;
    if(metricLossWeight_ != 0) {
      torch::Tensor metricLoss;
      computeMetricLoss(layersOuts, layersOuts2, allAlphas_, distanceMatrix_,
                        indexes, metricLoss);
      baseRecLoss_ = std::numeric_limits<double>::max();
      metricLoss *= metricLossWeight_
                    * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
      error += metricLoss.item<float>();

    if(clusteringLossWeight_ != 0) {
      torch::Tensor clusteringLoss, asgn;
      computeClusteringLoss(allAlphas_, indexes, clusteringLoss, asgn);
      baseRecLoss_ = std::numeric_limits<double>::max();
      clusteringLoss *= clusteringLossWeight_
                        * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
      error += clusteringLoss.item<float>();

    if(trackingLossWeight_ != 0) {
      torch::Tensor trackingLoss;
      computeTrackingLoss(trackingLoss);
      trackingLoss *= trackingLossWeight_;
      error += trackingLoss.item<float>();
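// Run initParameters noInit_ times, keep the parameters of the best run, then
// enable gradients on all tensors that will be optimized.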
void ttk::MergeTreeAutoencoder::initStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2) {
  originsPrime_.clear();
  vSPrimeTensor_.clear();
  origins2Prime_.clear();
  vS2PrimeTensor_.clear();

  float bestError = std::numeric_limits<float>::max();
  std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
    bestVS2PrimeTensor, bestLatentCentroids;
  std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
    bestOrigins2, bestOrigins2Prime;
  std::vector<std::vector<torch::Tensor>> bestAlphasInit;
  for(unsigned int n = 0; n < noInit_; ++n) {
    float error = initParameters(trees, trees2, (noInit_ != 1));
    std::stringstream ss;
    ss << "Init error = " << error;
    if(error < bestError) {
      copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_,
                 origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_,
                 allAlphas_, bestOrigins, bestOriginsPrime, bestVSTensor,
                 bestVSPrimeTensor, bestOrigins2, bestOrigins2Prime,
                 bestVS2Tensor, bestVS2PrimeTensor, bestAlphasInit);
      bestLatentCentroids.resize(latentCentroids_.size());
      for(unsigned int i = 0; i < latentCentroids_.size(); ++i)
        mtu::copyTensor(latentCentroids_[i], bestLatentCentroids[i]);

    std::stringstream ss;
    ss << "Best init error = " << bestError;
    copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
               bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
               bestVS2PrimeTensor, bestAlphasInit, origins_, originsPrime_,
               vSTensor_, vSPrimeTensor_, origins2_, origins2Prime_, vS2Tensor_,
               vS2PrimeTensor_, allAlphas_);
    latentCentroids_.resize(bestLatentCentroids.size());
    for(unsigned int i = 0; i < bestLatentCentroids.size(); ++i)
      mtu::copyTensor(bestLatentCentroids[i], latentCentroids_[i]);

  for(unsigned int l = 0; l < noLayers_; ++l) {
    origins_[l].tensor.requires_grad_(true);
    originsPrime_[l].tensor.requires_grad_(true);
    vSTensor_[l].requires_grad_(true);
    vSPrimeTensor_[l].requires_grad_(true);
    if(trees2.size() != 0) {
      origins2_[l].tensor.requires_grad_(true);
      origins2Prime_[l].tensor.requires_grad_(true);
      vS2Tensor_[l].requires_grad_(true);
      vS2PrimeTensor_[l].requires_grad_(true);

    std::stringstream ss;
    if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_)) {
      ss << "origins_[" << l << "] has big values!" << std::endl;

    if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_)) {
      ss << "originsPrime_[" << l << "] has big values!" << std::endl;

    ss << "vS size = " << vSTensor_[l].sizes();
    ss << "vS' size = " << vSPrimeTensor_[l].sizes();
    if(trees2.size() != 0) {
      ss << "vS2 size = " << vS2Tensor_[l].sizes();
      ss << "vS2' size = " << vS2PrimeTensor_[l].sizes();

  if(clusteringLossWeight_ != 0)
    initClusteringLossParameters();
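// Project points with negative persistence (birth > death) onto the diagonal.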
void ttk::MergeTreeAutoencoder::interpolationDiagonalProjection(
  mtu::TorchMergeTree<float> &interpolation) {
  torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
  if(interpolation.tensor.requires_grad())
    diagTensor = diagTensor.detach();

  torch::Tensor birthTensor = diagTensor.index({Slice(), 0});
  torch::Tensor deathTensor = diagTensor.index({Slice(), 1});

  torch::Tensor indexer = (birthTensor > deathTensor);

  torch::Tensor allProj = (birthTensor + deathTensor) / 2.0;
  allProj = allProj.index({indexer});
  allProj = allProj.reshape({-1, 1});

  diagTensor.index_put_({indexer}, allProj);
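// Clamp normalized birth/death values of non-root pairs to the [0, 1] range.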
void ttk::MergeTreeAutoencoder::interpolationNestingProjection(
  mtu::TorchMergeTree<float> &interpolation) {
  torch::Tensor diagTensor = interpolation.tensor.reshape({-1, 2});
  if(interpolation.tensor.requires_grad())
    diagTensor = diagTensor.detach();

  torch::Tensor birthTensor = diagTensor.index({Slice(1, None), 0});
  torch::Tensor deathTensor = diagTensor.index({Slice(1, None), 1});

  torch::Tensor birthIndexer = (birthTensor < 0);
  torch::Tensor deathIndexer = (deathTensor < 0);
  birthTensor.index_put_(
    {birthIndexer}, torch::zeros_like(birthTensor.index({birthIndexer})));
  deathTensor.index_put_(
    {deathIndexer}, torch::zeros_like(deathTensor.index({deathIndexer})));

  birthIndexer = (birthTensor > 1);
  deathIndexer = (deathTensor > 1);
  birthTensor.index_put_(
    {birthIndexer}, torch::ones_like(birthTensor.index({birthIndexer})));
  deathTensor.index_put_(
    {deathIndexer}, torch::ones_like(deathTensor.index({deathIndexer})));
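// Apply the diagonal and nesting projections, rebuild the merge tree from the
// tensor and remove pairs with very low persistence.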
void ttk::MergeTreeAutoencoder::interpolationProjection(
  mtu::TorchMergeTree<float> &interpolation) {
  interpolationDiagonalProjection(interpolation);
  if(normalizedWasserstein_)
    interpolationNestingProjection(interpolation);

  ftm::MergeTree<float> interpolationNew;
  bool noRoot = mtu::torchTensorToMergeTree<float>(
    interpolation, normalizedWasserstein_, interpolationNew);
    printWrn("[interpolationProjection] no root found");

  persistenceThresholding<float>(&(interpolation.mTree.tree), 0.001);

  if(isThereMissingPairs(interpolation) and isPersistenceDiagram_)
    printWrn("[getMultiInterpolation] missing pairs");

void ttk::MergeTreeAutoencoder::getMultiInterpolation(
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &alphas,
  mtu::TorchMergeTree<float> &interpolation) {
  mtu::copyTorchMergeTree<float>(origin, interpolation);
  interpolation.tensor = origin.tensor + torch::matmul(vS, alphas);
  interpolationProjection(interpolation);
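// Assemble the tensors of the least-squares system used to estimate the
// projection coefficients (alphas) of a tree on a basis, given a matching.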
void ttk::MergeTreeAutoencoder::getAlphasOptimizationTensors(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &interpolated,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
  torch::Tensor &reorderedTreeTensor,
  torch::Tensor &deltaOrigin,
  torch::Tensor &deltaA,
  torch::Tensor &originTensor_f,
  torch::Tensor &vSTensor_f) {
  std::vector<int> tensorMatching;
  mtu::getTensorMatching(interpolated, tree, matching, tensorMatching);

  torch::Tensor indexes = torch::tensor(tensorMatching);
  torch::Tensor projIndexer = (indexes == -1).reshape({-1, 1});

  dataReorderingGivenMatching(
    origin, tree, projIndexer, indexes, reorderedTreeTensor, deltaOrigin);

  deltaA = vSTensor.transpose(0, 1).reshape({vSTensor.sizes()[1], -1, 2});
  deltaA = (deltaA.index({Slice(), Slice(), 0})
            + deltaA.index({Slice(), Slice(), 1}))
  deltaA = torch::stack({deltaA, deltaA}, 2);
  deltaA = deltaA * projIndexer;
  deltaA = deltaA.reshape({vSTensor.sizes()[1], -1}).transpose(0, 1);

  originTensor_f = origin.tensor;
  vSTensor_f = vSTensor;
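// Solve the least-squares system (torch::linalg::lstsq) giving the best
// projection coefficients of a tree (and optional second tree) on the basis.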
void ttk::MergeTreeAutoencoder::computeAlphas(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &interpolated,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  mtu::TorchMergeTree<float> &interpolated2,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
  torch::Tensor &alphasOut) {
  torch::Tensor reorderedTreeTensor, deltaOrigin, deltaA, originTensor_f,
  getAlphasOptimizationTensors(tree, origin, vSTensor, interpolated, matching,
                               reorderedTreeTensor, deltaOrigin, deltaA,
                               originTensor_f, vSTensor_f);
  if(useDoubleInput_) {
    torch::Tensor reorderedTree2Tensor, deltaOrigin2, deltaA2, origin2Tensor_f,
    getAlphasOptimizationTensors(tree2, origin2, vS2Tensor, interpolated2,
                                 matching2, reorderedTree2Tensor, deltaOrigin2,
                                 deltaA2, origin2Tensor_f, vS2Tensor_f);
    vSTensor_f = torch::cat({vSTensor_f, vS2Tensor_f});
    deltaA = torch::cat({deltaA, deltaA2});
      = torch::cat({reorderedTreeTensor, reorderedTree2Tensor});
    originTensor_f = torch::cat({originTensor_f, origin2Tensor_f});
    deltaOrigin = torch::cat({deltaOrigin, deltaOrigin2});

  torch::Tensor r_axes = vSTensor_f - deltaA;
  torch::Tensor r_data = reorderedTreeTensor - originTensor_f + deltaOrigin;

  auto driver = "gelsd";
    = std::get<0>(torch::linalg::lstsq(r_axes, r_data, c10::nullopt, driver));

  alphasOut.reshape({-1, 1});
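// Alternate matching (computeOneDistance) and coefficient estimation
// (computeAlphas) to project one input on the basis; returns the best distance
// found.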
float ttk::MergeTreeAutoencoder::assignmentOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  torch::Tensor &alphasInit,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
  torch::Tensor &bestAlphas,

  torch::Tensor alphas, oldAlphas;
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching, matching2;
  float bestDistance = std::numeric_limits<float>::max();
  mtu::TorchMergeTree<float> interpolated, interpolated2;

  auto reset = [&]() {
    alphasInit = torch::randn_like(alphas);

  unsigned int noUpdate = 0;
  unsigned int noReset = 0;
    if(alphasInit.defined())
      alphas = alphasInit;
      alphas = torch::zeros({vSTensor.sizes()[1], 1});
    computeAlphas(tree, origin, vSTensor, interpolated, matching, tree2,
                  origin2, vS2Tensor, interpolated2, matching2, alphas);
    if(oldAlphas.defined() and alphas.defined() and alphas.equal(oldAlphas)

    mtu::copyTensor(alphas, oldAlphas);
    getMultiInterpolation(origin, vSTensor, alphas, interpolated);
      getMultiInterpolation(origin2, vS2Tensor, alphas, interpolated2);
    if(interpolated.mTree.tree.getRealNumberOfNodes() == 0
        and interpolated2.mTree.tree.getRealNumberOfNodes() == 0)) {
        printWrn("[assignmentOneData] noReset >= 100");

    computeOneDistance<float>(interpolated.mTree, tree.mTree, matching,
                              distance, isCalled, useDoubleInput_);
    if(useDoubleInput_) {
      computeOneDistance<float>(interpolated2.mTree, tree2.mTree, matching2,
                                distance2, isCalled, useDoubleInput_, false);
      distance = mixDistances<float>(distance, distance2);

    if(distance < bestDistance and i != 0) {
      bestMatching = matching;
      bestMatching2 = matching2;
      bestAlphas = alphas;

    printErr("[assignmentOneData] noUpdate == 0");
  return bestDistance;
float ttk::MergeTreeAutoencoder::assignmentOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  torch::Tensor &alphasInit,
  torch::Tensor &bestAlphas,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> bestMatching,
  return assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k,
                           alphasInit, bestMatching, bestMatching2, bestAlphas,
torch::Tensor ttk::MergeTreeAutoencoder::activation(torch::Tensor &in) {
  switch(activationFunction_) {
      act = torch::nn::LeakyReLU()(in);
      act = torch::nn::ReLU()(in);

void ttk::MergeTreeAutoencoder::outputBasisReconstruction(
  mtu::TorchMergeTree<float> &originPrime,
  torch::Tensor &vSPrimeTensor,
  mtu::TorchMergeTree<float> &origin2Prime,
  torch::Tensor &vS2PrimeTensor,
  torch::Tensor &alphas,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,

  torch::Tensor act = (activate ? activation(alphas) : alphas);
  getMultiInterpolation(originPrime, vSPrimeTensor, act, out);
    getMultiInterpolation(origin2Prime, vS2PrimeTensor, act, out2);
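// Forward one input through one layer: project it on the input basis with
// assignmentOneData, then reconstruct it with the output basis; the returned
// flag is used as a reset indicator by the caller.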
bool ttk::MergeTreeAutoencoder::forwardOneLayer(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &originPrime,
  torch::Tensor &vSPrimeTensor,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  mtu::TorchMergeTree<float> &origin2Prime,
  torch::Tensor &vS2PrimeTensor,
  torch::Tensor &alphasInit,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,
  torch::Tensor &bestAlphas,
  float &bestDistance) {
  bool goodOutput = false;
  while(not goodOutput) {
    bool isCalled = true;
      = assignmentOneData(tree, origin, vSTensor, tree2, origin2, vS2Tensor, k,
                          alphasInit, bestAlphas, isCalled);
    outputBasisReconstruction(originPrime, vSPrimeTensor, origin2Prime,
                              vS2PrimeTensor, bestAlphas, out, out2);
    goodOutput = (out.mTree.tree.getRealNumberOfNodes() != 0
                  and (not useDoubleInput_
                       or out2.mTree.tree.getRealNumberOfNodes() != 0));
    if(not goodOutput) {
      if(noReset >= 100) {
        printWrn("[forwardOneLayer] noReset >= 100");

      alphasInit = torch::randn_like(alphasInit);

bool ttk::MergeTreeAutoencoder::forwardOneLayer(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &origin,
  torch::Tensor &vSTensor,
  mtu::TorchMergeTree<float> &originPrime,
  torch::Tensor &vSPrimeTensor,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &origin2,
  torch::Tensor &vS2Tensor,
  mtu::TorchMergeTree<float> &origin2Prime,
  torch::Tensor &vS2PrimeTensor,
  torch::Tensor &alphasInit,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,
  torch::Tensor &bestAlphas) {
  return forwardOneLayer(tree, origin, vSTensor, originPrime, vSPrimeTensor,
                         tree2, origin2, vS2Tensor, origin2Prime,
                         vS2PrimeTensor, k, alphasInit, out, out2, bestAlphas,
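// Forward one input through all layers, storing the intermediate
// reconstructions (outs) and the per-layer coefficients (dataAlphas); the
// returned flag signals a reset.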
bool ttk::MergeTreeAutoencoder::forwardOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &tree2,
  unsigned int treeIndex,
  std::vector<torch::Tensor> &alphasInit,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,
  std::vector<torch::Tensor> &dataAlphas,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2) {
  outs.resize(noLayers_ - 1);
  outs2.resize(noLayers_ - 1);
  dataAlphas.resize(noLayers_);
  for(unsigned int l = 0; l < noLayers_; ++l) {
    auto &treeToUse = (l == 0 ? tree : outs[l - 1]);
    auto &tree2ToUse = (l == 0 ? tree2 : outs2[l - 1]);
    auto &outToUse = (l != noLayers_ - 1 ? outs[l] : out);
    auto &out2ToUse = (l != noLayers_ - 1 ? outs2[l] : out2);
    bool reset = forwardOneLayer(
      treeToUse, origins_[l], vSTensor_[l], originsPrime_[l], vSPrimeTensor_[l],
      tree2ToUse, origins2_[l], vS2Tensor_[l], origins2Prime_[l],
      vS2PrimeTensor_[l], k, alphasInit[l], outToUse, out2ToUse, dataAlphas[l]);

      = [this, &treeIndex, &l](
          std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
          mtu::TorchMergeTree<float> &outT) {
          if(recs[treeIndex].size() > noLayers_)
            mtu::copyTorchMergeTree<float>(outT, recs[treeIndex][l + 1]);
            mtu::TorchMergeTree<float> tmt;
            mtu::copyTorchMergeTree<float>(outT, tmt);
            recs[treeIndex].emplace_back(tmt);

    updateRecs(recs_, outToUse);
      updateRecs(recs2_, out2ToUse);
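// Forward a batch of inputs (optionally in parallel with OpenMP) and
// accumulate the average reconstruction loss.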
bool ttk::MergeTreeAutoencoder::forwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<unsigned int> &indexes,
  std::vector<std::vector<torch::Tensor>> &allAlphasInit,
  bool computeReconstructionError,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<torch::Tensor>> &bestAlphas,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>

  outs.resize(trees.size());
  outs2.resize(trees.size());
  bestAlphas.resize(trees.size());
  layersOuts.resize(trees.size());
  layersOuts2.resize(trees.size());
  matchings.resize(trees.size());
    matchings2.resize(trees2.size());
  mtu::TorchMergeTree<float> dummyTMT;
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \
  if(parallelize_) reduction(||: reset) reduction(+:loss)
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    unsigned int i = indexes[ind];
    auto &tree2ToUse = (trees2.size() == 0 ? dummyTMT : trees2[i]);
      = forwardOneData(trees[i], tree2ToUse, i, k, allAlphasInit[i], outs[i],
                       outs2[i], bestAlphas[i], layersOuts[i], layersOuts2[i]);
    if(computeReconstructionError) {
      float iLoss = computeOneLoss(
        trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]);

    reset = reset || dReset;

  loss /= indexes.size();

bool ttk::MergeTreeAutoencoder::forwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<unsigned int> &indexes,
  std::vector<std::vector<torch::Tensor>> &allAlphasInit,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<torch::Tensor>> &bestAlphas) {
  std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts, layersOuts2;
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    matchings, matchings2;
  bool computeReconstructionError = false;
  return forwardStep(trees, trees2, indexes, k, allAlphasInit,
                     computeReconstructionError, outs, outs2, bestAlphas,
                     layersOuts, layersOuts2, matchings, matchings2, loss);
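// Backpropagate the reconstruction loss (MSE between output and reordered
// target tensors) and the optional metric, clustering and tracking losses;
// parameters with undefined or all-zero gradients are counted per layer before
// the gradients are cleared via optimizer.zero_grad().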
bool ttk::MergeTreeAutoencoder::backwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  torch::optim::Optimizer &optimizer,
  std::vector<unsigned int> &indexes,
  torch::Tensor &metricLoss,
  torch::Tensor &clusteringLoss,
  torch::Tensor &trackingLoss) {
  double totalLoss = 0;
  bool retainGraph = (metricLossWeight_ != 0 or clusteringLossWeight_ != 0
                      or trackingLossWeight_ != 0);
  if(reconstructionLossWeight_ != 0
     or (customLossDynamicWeight_ and retainGraph)) {
    std::vector<torch::Tensor> outTensors(indexes.size()),
      reorderedTensors(indexes.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      unsigned int i = indexes[ind];
      torch::Tensor reorderedTensor;
      dataReorderingGivenMatching(
        outs[i], trees[i], matchings[i], reorderedTensor);
      auto outTensor = outs[i].tensor;
      if(useDoubleInput_) {
        torch::Tensor reorderedTensor2;
        dataReorderingGivenMatching(
          outs2[i], trees2[i], matchings2[i], reorderedTensor2);
        outTensor = torch::cat({outTensor, outs2[i].tensor});
        reorderedTensor = torch::cat({reorderedTensor, reorderedTensor2});

      outTensors[ind] = outTensor;
      reorderedTensors[ind] = reorderedTensor;

    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      auto loss = torch::nn::functional::mse_loss(
        outTensors[ind], reorderedTensors[ind]);
      totalLoss += loss.item<float>();
      loss *= reconstructionLossWeight_;
      loss.backward({}, retainGraph);

  if(metricLossWeight_ != 0) {
    bool retainGraphMetricLoss
      = (clusteringLossWeight_ != 0 or trackingLossWeight_ != 0);
    metricLoss *= metricLossWeight_
                  * getCustomLossDynamicWeight(
                    totalLoss / indexes.size(), baseRecLoss2_);
    metricLoss.backward({}, retainGraphMetricLoss);

  if(clusteringLossWeight_ != 0) {
    bool retainGraphClusteringLoss = (trackingLossWeight_ != 0);
    clusteringLoss *= clusteringLossWeight_
                      * getCustomLossDynamicWeight(
                        totalLoss / indexes.size(), baseRecLoss2_);
    clusteringLoss.backward({}, retainGraphClusteringLoss);

  if(trackingLossWeight_ != 0) {
    trackingLoss *= trackingLossWeight_;
    trackingLoss.backward();

  for(unsigned int l = 0; l < noLayers_; ++l) {
    if(not origins_[l].tensor.grad().defined()
       or not origins_[l].tensor.grad().count_nonzero().is_nonzero())
      ++originsNoZeroGrad_[l];
    if(not originsPrime_[l].tensor.grad().defined()
       or not originsPrime_[l].tensor.grad().count_nonzero().is_nonzero())
      ++originsPrimeNoZeroGrad_[l];
    if(not vSTensor_[l].grad().defined()
       or not vSTensor_[l].grad().count_nonzero().is_nonzero())
    if(not vSPrimeTensor_[l].grad().defined()
       or not vSPrimeTensor_[l].grad().count_nonzero().is_nonzero())
      ++vSPrimeNoZeroGrad_[l];
    if(useDoubleInput_) {
      if(not origins2_[l].tensor.grad().defined()
         or not origins2_[l].tensor.grad().count_nonzero().is_nonzero())
        ++origins2NoZeroGrad_[l];
      if(not origins2Prime_[l].tensor.grad().defined()
         or not origins2Prime_[l].tensor.grad().count_nonzero().is_nonzero())
        ++origins2PrimeNoZeroGrad_[l];
      if(not vS2Tensor_[l].grad().defined()
         or not vS2Tensor_[l].grad().count_nonzero().is_nonzero())
        ++vS2NoZeroGrad_[l];
      if(not vS2PrimeTensor_[l].grad().defined()
         or not vS2PrimeTensor_[l].grad().count_nonzero().is_nonzero())
        ++vS2PrimeNoZeroGrad_[l];

  optimizer.zero_grad();
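// Re-project all origins onto the space of valid trees/diagrams and re-enable
// gradients on their (detached) tensors.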
void ttk::MergeTreeAutoencoder::projectionStep() {
  auto projectTree = [this](mtu::TorchMergeTree<float> &tmt) {
    interpolationProjection(tmt);
    tmt.tensor = tmt.tensor.detach();
    tmt.tensor.requires_grad_(true);

  for(unsigned int l = 0; l < noLayers_; ++l) {
    projectTree(origins_[l]);
    projectTree(originsPrime_[l]);
    if(useDoubleInput_) {
      projectTree(origins2_[l]);
      projectTree(origins2Prime_[l]);
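// Distance between one input and its reconstruction (computeOneDistance),
// mixed with the distance on the optional second input.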
float ttk::MergeTreeAutoencoder::computeOneLoss(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &out2,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2) {
  bool isCalled = true;
  computeOneDistance<float>(
    out.mTree, tree.mTree, matching, distance, isCalled, useDoubleInput_);
  if(useDoubleInput_) {
    computeOneDistance<float>(out2.mTree, tree2.mTree, matching2, distance2,
                              isCalled, useDoubleInput_, false);
    distance = mixDistances<float>(distance, distance2);

float ttk::MergeTreeAutoencoder::computeLoss(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<unsigned int> &indexes,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
  matchings.resize(trees.size());
    matchings2.resize(trees2.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_) \
  if(parallelize_) reduction(+:loss)
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    unsigned int i = indexes[ind];
    float iLoss = computeOneLoss(
      trees[i], outs[i], trees2[i], outs2[i], matchings[i], matchings2[i]);

  return loss / indexes.size();
bool ttk::MergeTreeAutoencoder::isBestLoss(float loss,
                                           unsigned int &cptBlocked) {
  bool isBestEnergy = false;
    isBestEnergy = true;
  return isBestEnergy;

bool ttk::MergeTreeAutoencoder::convergenceStep(float loss,
                                                unsigned int &cptBlocked) {
  double tol = oldLoss / 125.0;
  bool converged = std::abs(loss - oldLoss) < std::abs(tol);

  cptBlocked += (minLoss < loss) ? 1 : 0;
  converged = (cptBlocked >= 10 * 10);
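// Main training loop: initialization, batched forward/backward passes with the
// selected optimizer, projection steps, and convergence monitoring.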
1534void ttk::MergeTreeAutoencoder::fit(
1535 std::vector<ftm::MergeTree<float>> &trees,
1536 std::vector<ftm::MergeTree<float>> &trees2) {
1537 torch::set_num_threads(1);
1539 if(deterministic_) {
1541 bool m_torch_deterministic =
true;
1543 torch::manual_seed(m_seed);
1544 at::globalContext().setDeterministicCuDNN(m_torch_deterministic ?
true
1546 at::globalContext().setDeterministicAlgorithms(
1547 m_torch_deterministic ?
true : false, true);
1551 for(
unsigned int i = 0; i < trees.size(); ++i) {
1552 for(
unsigned int n = 0; n < trees[i].tree.getNumberOfNodes(); ++n) {
1553 if(trees[i].tree.isNodeAlone(n))
1555 auto birthDeath = trees[i].tree.template getBirthDeath<float>(n);
1557 = std::max(std::abs(std::get<0>(birthDeath)), bigValuesThreshold_);
1559 = std::max(std::abs(std::get<1>(birthDeath)), bigValuesThreshold_);
1562 bigValuesThreshold_ *= 100;
1565 std::vector<mtu::TorchMergeTree<float>> torchTrees, torchTrees2;
1566 mergeTreesToTorchTrees(trees, torchTrees, normalizedWasserstein_);
1567 mergeTreesToTorchTrees(trees2, torchTrees2, normalizedWasserstein_);
1569 auto initRecs = [](std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
1570 std::vector<mtu::TorchMergeTree<float>> &torchTreesT) {
1572 recs.resize(torchTreesT.size());
1573 for(
unsigned int i = 0; i < torchTreesT.size(); ++i) {
1574 mtu::TorchMergeTree<float> tmt;
1575 mtu::copyTorchMergeTree<float>(torchTreesT[i], tmt);
1576 recs[i].emplace_back(tmt);
1579 initRecs(recs_, torchTrees);
1581 initRecs(recs2_, torchTrees2);
1584 if(metricLossWeight_ != 0)
1585 getDistanceMatrix(torchTrees, torchTrees2, distanceMatrix_);
1589 initStep(torchTrees, torchTrees2);
1590 printMsg(
"Init", 1, t_init.getElapsedTime(), threadNumber_);
1593 std::vector<torch::Tensor> parameters;
1594 for(
unsigned int l = 0; l < noLayers_; ++l) {
1595 parameters.emplace_back(origins_[l].tensor);
1596 parameters.emplace_back(originsPrime_[l].tensor);
1597 parameters.emplace_back(vSTensor_[l]);
1598 parameters.emplace_back(vSPrimeTensor_[l]);
1599 if(trees2.size() != 0) {
1600 parameters.emplace_back(origins2_[l].tensor);
1601 parameters.emplace_back(origins2Prime_[l].tensor);
1602 parameters.emplace_back(vS2Tensor_[l]);
1603 parameters.emplace_back(vS2PrimeTensor_[l]);
1606 if(clusteringLossWeight_ != 0)
1607 for(
unsigned int i = 0; i < latentCentroids_.size(); ++i)
1608 parameters.emplace_back(latentCentroids_[i]);
1610 torch::optim::Optimizer *optimizer;
1612 auto adamOptions = torch::optim::AdamOptions(gradientStepSize_);
1613 adamOptions.betas(std::make_tuple(beta1_, beta2_));
1614 auto adamOptimizer = torch::optim::Adam(parameters, adamOptions);
1616 auto sgdOptions = torch::optim::SGDOptions(gradientStepSize_);
1617 auto sgdOptimizer = torch::optim::SGD(parameters, sgdOptions);
1619 auto rmspropOptions = torch::optim::RMSpropOptions(gradientStepSize_);
1620 auto rmspropOptimizer = torch::optim::RMSprop(parameters, rmspropOptions);
1622 switch(optimizer_) {
1624 optimizer = &sgdOptimizer;
1627 optimizer = &rmspropOptimizer;
1631 optimizer = &adamOptimizer;
1635 unsigned int batchSize = std::min(
1636 std::max((
int)(trees.size() * batchSize_), 1), (
int)trees.size());
1637 std::stringstream ssBatch;
1638 ssBatch <<
"batchSize = " << batchSize;
1640 unsigned int noBatch
1641 = trees.size() / batchSize + ((trees.size() % batchSize) != 0 ? 1 : 0);
1642 std::vector<std::vector<unsigned int>> allIndexes(noBatch);
1644 allIndexes[0].resize(trees.size());
1645 std::iota(allIndexes[0].
begin(), allIndexes[0].
end(), 0);
1647 auto rng = std::default_random_engine{};
1650 originsNoZeroGrad_.resize(noLayers_);
1651 originsPrimeNoZeroGrad_.resize(noLayers_);
1652 vSNoZeroGrad_.resize(noLayers_);
1653 vSPrimeNoZeroGrad_.resize(noLayers_);
1654 for(
unsigned int l = 0; l < noLayers_; ++l) {
1655 originsNoZeroGrad_[l] = 0;
1656 originsPrimeNoZeroGrad_[l] = 0;
1657 vSNoZeroGrad_[l] = 0;
1658 vSPrimeNoZeroGrad_[l] = 0;
1660 if(useDoubleInput_) {
1661 origins2NoZeroGrad_.resize(noLayers_);
1662 origins2PrimeNoZeroGrad_.resize(noLayers_);
1663 vS2NoZeroGrad_.resize(noLayers_);
1664 vS2PrimeNoZeroGrad_.resize(noLayers_);
1665 for(
unsigned int l = 0; l < noLayers_; ++l) {
1666 origins2NoZeroGrad_[l] = 0;
1667 origins2PrimeNoZeroGrad_[l] = 0;
1668 vS2NoZeroGrad_[l] = 0;
1669 vS2PrimeNoZeroGrad_[l] = 0;
1674 baseRecLoss_ = std::numeric_limits<double>::max();
1675 baseRecLoss2_ = std::numeric_limits<double>::max();
1676 unsigned int k = k_;
1677 float oldLoss, minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss;
1678 unsigned int cptBlocked, iteration = 0;
1679 auto initLoop = [&]() {
1681 minLoss = std::numeric_limits<float>::max();
1682 minRecLoss = minLoss;
1683 minMetricLoss = minLoss;
1684 minClustLoss = minLoss;
1685 minTrackLoss = minLoss;
1690 int convWinSize = 5;
1691 int noConverged = 0, noConvergedToGet = 10;
1692 std::vector<float> losses, metricLosses, clusteringLosses, trackingLosses;
1693 float windowLoss = 0;
1695 double assignmentTime = 0.0, updateTime = 0.0, projectionTime = 0.0,
1698 int bestIteration = 0;
1699 std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
1701 std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
1702 bestOrigins2, bestOrigins2Prime;
1703 std::vector<std::vector<torch::Tensor>> bestAlphasInit;
1704 std::vector<std::vector<mtu::TorchMergeTree<float>>> bestRecs, bestRecs2;
1705 double bestTime = 0;
1708 = [
this](
float loss,
float recLoss,
float metricLoss,
float clustLoss,
1709 float trackLoss,
int iterationT,
int iterationTT,
double time,
1711 std::stringstream prefix;
1713 std::stringstream ssBestLoss;
1714 ssBestLoss << prefix.str() <<
"loss is " << loss <<
" (iteration "
1715 << iterationT <<
" / " << iterationTT <<
") at time "
1717 printMsg(ssBestLoss.str(), priority);
1720 if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0
1721 or trackingLossWeight_ != 0) {
1723 ssBestLoss <<
"- Rec. " << prefix.str() <<
"loss = " << recLoss;
1724 printMsg(ssBestLoss.str(), priority);
1726 if(metricLossWeight_ != 0) {
1728 ssBestLoss <<
"- Metric " << prefix.str() <<
"loss = " << metricLoss;
1729 printMsg(ssBestLoss.str(), priority);
1731 if(clusteringLossWeight_ != 0) {
1733 ssBestLoss <<
"- Clust. " << prefix.str() <<
"loss = " << clustLoss;
1734 printMsg(ssBestLoss.str(), priority);
1736 if(trackingLossWeight_ != 0) {
1738 ssBestLoss <<
"- Track. " << prefix.str() <<
"loss = " << trackLoss;
1739 printMsg(ssBestLoss.str(), priority);
1745 bool converged =
false;
1746 while(not converged) {
1747 if(iteration % iterationGap_ == 0) {
1748 std::stringstream ss;
1749 ss <<
"Iteration " << iteration;
1754 bool forwardReset =
false;
1755 std::vector<float> iterationLosses, iterationMetricLosses,
1756 iterationClusteringLosses, iterationTrackingLosses;
1758 std::vector<unsigned int> indexes(trees.size());
1759 std::iota(indexes.begin(), indexes.end(), 0);
1760 std::shuffle(std::begin(indexes), std::end(indexes), rng);
1761 for(
unsigned int i = 0; i < allIndexes.size(); ++i) {
1762 unsigned int noProcessed = batchSize * i;
1763 unsigned int remaining = trees.size() - noProcessed;
1764 unsigned int size = std::min(batchSize, remaining);
1765 allIndexes[i].resize(size);
1766 for(
unsigned int j = 0; j < size; ++j)
1767 allIndexes[i][j] = indexes[noProcessed + j];
    for(unsigned batchNum = 0; batchNum < allIndexes.size(); ++batchNum) {
      auto &indexes = allIndexes[batchNum];

      // --- Forward pass
      Timer t_assignment;
      std::vector<mtu::TorchMergeTree<float>> outs, outs2;
      std::vector<std::vector<torch::Tensor>> bestAlphas;
      std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
        layersOuts2;
      std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
        matchings, matchings2;
      float loss;
      bool computeReconstructionError = reconstructionLossWeight_ != 0;
      bool reset
        = forwardStep(torchTrees, torchTrees2, indexes, k, allAlphas_,
                      computeReconstructionError, outs, outs2, bestAlphas,
                      layersOuts, layersOuts2, matchings, matchings2, loss);
      if(reset) {
        forwardReset = true;
        break;
      }
      for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
        unsigned int i = indexes[ind];
        for(unsigned int j = 0; j < bestAlphas[i].size(); ++j)
          mtu::copyTensor(bestAlphas[i][j], allAlphas_[i][j]);
      }
      assignmentTime += t_assignment.getElapsedTime();

      // --- Compute the loss terms
      Timer t_loss;
      losses.emplace_back(loss);
      iterationLosses.emplace_back(loss);
      torch::Tensor metricLoss;
      if(metricLossWeight_ != 0) {
        computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_,
                          indexes, metricLoss);
        float metricLossF = metricLoss.item<float>();
        metricLosses.emplace_back(metricLossF);
        iterationMetricLosses.emplace_back(metricLossF);
      }
      torch::Tensor clusteringLoss;
      if(clusteringLossWeight_ != 0) {
        torch::Tensor asgn;
        computeClusteringLoss(bestAlphas, indexes, clusteringLoss, asgn);
        float clusteringLossF = clusteringLoss.item<float>();
        clusteringLosses.emplace_back(clusteringLossF);
        iterationClusteringLosses.emplace_back(clusteringLossF);
      }
      torch::Tensor trackingLoss;
      if(trackingLossWeight_ != 0) {
        computeTrackingLoss(trackingLoss);
        float trackingLossF = trackingLoss.item<float>();
        trackingLosses.emplace_back(trackingLossF);
        iterationTrackingLosses.emplace_back(trackingLossF);
      }
      lossTime += t_loss.getElapsedTime();

      // --- Backward pass and parameter update
      Timer t_update;
      backwardStep(torchTrees, outs, matchings, torchTrees2, outs2, matchings2,
                   *optimizer, indexes, metricLoss, clusteringLoss,
                   trackingLoss);
      updateTime += t_update.getElapsedTime();

      // --- Projection
      Timer t_projection;
      projectionTime += t_projection.getElapsedTime();
    }
    if(forwardReset)
      printWrn("Forward reset!");
    // --- Aggregate the losses of this iteration
    float iterationRecLoss
      = torch::tensor(iterationLosses).mean().item<float>();
    float iterationLoss = reconstructionLossWeight_ * iterationRecLoss;
    float iterationMetricLoss = 0;
    if(metricLossWeight_ != 0) {
      iterationMetricLoss
        = torch::tensor(iterationMetricLosses).mean().item<float>();
      iterationLoss
        += metricLossWeight_
           * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_)
           * iterationMetricLoss;
    }
    float iterationClusteringLoss = 0;
    if(clusteringLossWeight_ != 0) {
      iterationClusteringLoss
        = torch::tensor(iterationClusteringLosses).mean().item<float>();
      iterationLoss
        += clusteringLossWeight_
           * getCustomLossDynamicWeight(iterationRecLoss, baseRecLoss_)
           * iterationClusteringLoss;
    }
    float iterationTrackingLoss = 0;
    if(trackingLossWeight_ != 0) {
      iterationTrackingLoss
        = torch::tensor(iterationTrackingLosses).mean().item<float>();
      iterationLoss += trackingLossWeight_ * iterationTrackingLoss;
    }
    printLoss(iterationLoss, iterationRecLoss, iterationMetricLoss,
              iterationClusteringLoss, iterationTrackingLoss, iteration,
              iteration, t_alg.getElapsedTime() - t_allVectorCopy_time_,
              debug::Priority::DETAIL);
    // --- Keep the best parameters found so far
    bool isBest = isBestLoss(iterationLoss, minLoss, cptBlocked);
    if(isBest) {
      Timer t_copy;
      bestIteration = iteration;
      copyParams(origins_, originsPrime_, vSTensor_, vSPrimeTensor_, origins2_,
                 origins2Prime_, vS2Tensor_, vS2PrimeTensor_, allAlphas_,
                 bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
                 bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
                 bestVS2PrimeTensor, bestAlphasInit);
      copyParams(recs_, bestRecs);
      copyParams(recs2_, bestRecs2);
      t_allVectorCopy_time_ += t_copy.getElapsedTime();
      bestTime = t_alg.getElapsedTime() - t_allVectorCopy_time_;
      minRecLoss = iterationRecLoss;
      minMetricLoss = iterationMetricLoss;
      minClustLoss = iterationClusteringLoss;
      minTrackLoss = iterationTrackingLoss;
      printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss,
                bestIteration, iteration, bestTime);
    }
    // --- Convergence (averaged over a window of iterations)
    windowLoss += iterationLoss;
    if((iteration + 1) % convWinSize == 0) {
      windowLoss /= convWinSize;
      converged = convergenceStep(windowLoss, oldLoss, minLoss, cptBlocked);
      windowLoss = 0;
      if(converged)
        ++noConverged;
      else
        noConverged = 0;
      converged = noConverged >= noConvergedToGet;
      if(converged and iteration < minIteration_)
        printMsg("convergence is detected but iteration < minIteration_",
                 debug::Priority::DETAIL);
      if(iteration < minIteration_)
        converged = false;
    }
    // --- Print timing and loss information
    if(iteration % iterationGap_ == 0) {
      printMsg("Assignment", 1, assignmentTime, threadNumber_);
      printMsg("Loss", 1, lossTime, threadNumber_);
      printMsg("Update", 1, updateTime, threadNumber_);
      printMsg("Projection", 1, projectionTime, threadNumber_);
      assignmentTime = 0.0;
      lossTime = 0.0;
      updateTime = 0.0;
      projectionTime = 0.0;

      std::stringstream ss;
      float loss = torch::tensor(losses).mean().item<float>();
      losses.clear();
      ss << "Rec. loss = " << loss;
      printMsg(ss.str());
      if(metricLossWeight_ != 0) {
        float metricLoss = torch::tensor(metricLosses).mean().item<float>();
        metricLosses.clear();
        ss.str("");
        ss << "Metric loss = " << metricLoss;
        printMsg(ss.str());
      }
      if(clusteringLossWeight_ != 0) {
        float clusteringLoss
          = torch::tensor(clusteringLosses).mean().item<float>();
        clusteringLosses.clear();
        ss.str("");
        ss << "Clust. loss = " << clusteringLoss;
        printMsg(ss.str());
      }
      if(trackingLossWeight_ != 0) {
        float trackingLoss = torch::tensor(trackingLosses).mean().item<float>();
        trackingLosses.clear();
        ss.str("");
        ss << "Track. loss = " << trackingLoss;
        printMsg(ss.str());
      }
      // Report parameters without gradient or with abnormal values
      for(unsigned int l = 0; l < noLayers_; ++l) {
        ss.str("");
        if(originsNoZeroGrad_[l] != 0)
          ss << originsNoZeroGrad_[l] << " originsNoZeroGrad_[" << l << "]"
             << std::endl;
        if(originsPrimeNoZeroGrad_[l] != 0)
          ss << originsPrimeNoZeroGrad_[l] << " originsPrimeNoZeroGrad_[" << l
             << "]" << std::endl;
        if(vSNoZeroGrad_[l] != 0)
          ss << vSNoZeroGrad_[l] << " vSNoZeroGrad_[" << l << "]" << std::endl;
        if(vSPrimeNoZeroGrad_[l] != 0)
          ss << vSPrimeNoZeroGrad_[l] << " vSPrimeNoZeroGrad_[" << l << "]"
             << std::endl;
        originsNoZeroGrad_[l] = 0;
        originsPrimeNoZeroGrad_[l] = 0;
        vSNoZeroGrad_[l] = 0;
        vSPrimeNoZeroGrad_[l] = 0;
        if(useDoubleInput_) {
          if(origins2NoZeroGrad_[l] != 0)
            ss << origins2NoZeroGrad_[l] << " origins2NoZeroGrad_[" << l << "]"
               << std::endl;
          if(origins2PrimeNoZeroGrad_[l] != 0)
            ss << origins2PrimeNoZeroGrad_[l] << " origins2PrimeNoZeroGrad_["
               << l << "]" << std::endl;
          if(vS2NoZeroGrad_[l] != 0)
            ss << vS2NoZeroGrad_[l] << " vS2NoZeroGrad_[" << l << "]"
               << std::endl;
          if(vS2PrimeNoZeroGrad_[l] != 0)
            ss << vS2PrimeNoZeroGrad_[l] << " vS2PrimeNoZeroGrad_[" << l << "]"
               << std::endl;
          origins2NoZeroGrad_[l] = 0;
          origins2PrimeNoZeroGrad_[l] = 0;
          vS2NoZeroGrad_[l] = 0;
          vS2PrimeNoZeroGrad_[l] = 0;
        }
        if(isTreeHasBigValues(origins_[l].mTree, bigValuesThreshold_))
          ss << "origins_[" << l << "] has big values!" << std::endl;
        if(isTreeHasBigValues(originsPrime_[l].mTree, bigValuesThreshold_))
          ss << "originsPrime_[" << l << "] has big values!" << std::endl;
        if(ss.rdbuf()->in_avail() != 0)
          printMsg(ss.str(), debug::Priority::DETAIL);
      }
    }
    // --- Stopping criterion: maximum number of iterations
    if(maxIteration_ != 0 and iteration >= maxIteration_) {
      converged = true;
    }
    ++iteration;
  }
  printLoss(minLoss, minRecLoss, minMetricLoss, minClustLoss, minTrackLoss,
            bestIteration, iteration, bestTime);
  bestLoss_ = minLoss;

  // --- Copy back the best parameters found during the optimization
  Timer t_copy;
  copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
             bestOrigins2, bestOrigins2Prime, bestVS2Tensor, bestVS2PrimeTensor,
             bestAlphasInit, origins_, originsPrime_, vSTensor_, vSPrimeTensor_,
             origins2_, origins2Prime_, vS2Tensor_, vS2PrimeTensor_,
             allAlphas_);
  copyParams(bestRecs, recs_);
  copyParams(bestRecs2, recs2_);
  t_allVectorCopy_time_ += t_copy.getElapsedTime();
  printMsg("Copy time", 1, t_allVectorCopy_time_, threadNumber_);
}
double ttk::MergeTreeAutoencoder::getCustomLossDynamicWeight(double recLoss,
                                                             double &baseLoss) {
  baseLoss = std::min(recLoss, baseLoss);
  if(customLossDynamicWeight_)
    return baseLoss;
  return 1.0;
}
void ttk::MergeTreeAutoencoder::getDistanceMatrix(
  std::vector<mtu::TorchMergeTree<float>> &tmts,
  std::vector<std::vector<float>> &distanceMatrix,
  bool useDoubleInput,
  bool isFirstInput) {
  distanceMatrix.clear();
  distanceMatrix.resize(tmts.size(), std::vector<float>(tmts.size(), 0));
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
  shared(distanceMatrix, tmts)
  {
#pragma omp single nowait
    {
#endif
      for(unsigned int i = 0; i < tmts.size(); ++i) {
        for(unsigned int j = i + 1; j < tmts.size(); ++j) {
#ifdef TTK_ENABLE_OPENMP
#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j)
          {
#endif
            std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
            float distance;
            bool isCalled = true;
            computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance,
                               isCalled, useDoubleInput, isFirstInput);
            distanceMatrix[i][j] = distance;
            distanceMatrix[j][i] = distance;
#ifdef TTK_ENABLE_OPENMP
          }
#endif
        }
      }
#ifdef TTK_ENABLE_OPENMP
    }
  }
#endif
}
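
// Illustrative sketch (assumption, not TTK API): the OpenMP task pattern used
// by getDistanceMatrix above, with the merge-tree distance replaced by an
// arbitrary callable so the snippet stays self-contained. A single thread
// creates one task per (i, j) pair while the rest of the team executes them;
// each task writes two distinct cells, so no synchronization is required.
#include <functional>
#include <vector>

static void fillSymmetricDistanceMatrix(
  unsigned int n,
  const std::function<float(unsigned int, unsigned int)> &dist,
  std::vector<std::vector<float>> &matrix,
  int threadNumber) {
  matrix.assign(n, std::vector<float>(n, 0));
#ifdef _OPENMP
#pragma omp parallel num_threads(threadNumber) shared(matrix)
#pragma omp single nowait
#endif
  for(unsigned int i = 0; i < n; ++i) {
    for(unsigned int j = i + 1; j < n; ++j) {
#ifdef _OPENMP
#pragma omp task shared(matrix) firstprivate(i, j)
#endif
      {
        const float d = dist(i, j);
        matrix[i][j] = d;
        matrix[j][i] = d;
      }
    }
  }
}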
void ttk::MergeTreeAutoencoder::getDistanceMatrix(
  std::vector<mtu::TorchMergeTree<float>> &tmts,
  std::vector<mtu::TorchMergeTree<float>> &tmts2,
  std::vector<std::vector<float>> &distanceMatrix) {
  getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_);
  if(useDoubleInput_) {
    std::vector<std::vector<float>> distanceMatrix2;
    getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_, false);
    mixDistancesMatrix<float>(distanceMatrix, distanceMatrix2);
  }
}
void ttk::MergeTreeAutoencoder::getDifferentiableDistanceFromMatchings(
  mtu::TorchMergeTree<float> &tree1,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &tree1_2,
  mtu::TorchMergeTree<float> &tree2_2,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
  torch::Tensor &tensorDist,
  bool doSqrt) {
  torch::Tensor reorderedITensor, reorderedJTensor;
  dataReorderingGivenMatching(
    tree1, tree2, matchings, reorderedITensor, reorderedJTensor);
  if(useDoubleInput_) {
    torch::Tensor reorderedI2Tensor, reorderedJ2Tensor;
    dataReorderingGivenMatching(
      tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor);
    reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor});
    reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor});
  }
  tensorDist = (reorderedITensor - reorderedJTensor).pow(2).sum();
  if(doSqrt)
    tensorDist = tensorDist.sqrt();
}
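
// Illustrative sketch (assumption): the differentiable distance above is the
// (optionally square-rooted) sum of squared differences between two tensors
// whose rows were put in correspondence by a matching. With libtorch, the
// gradient flows through the subtraction, pow and sum, which is what makes
// this distance usable inside a loss term.
#include <torch/torch.h>

static torch::Tensor squaredL2FromMatching(const torch::Tensor &a,
                                           const torch::Tensor &b,
                                           bool doSqrt) {
  // a and b are assumed to be already reordered so that row k of a is matched
  // with row k of b (the role played by dataReorderingGivenMatching above).
  torch::Tensor dist = (a - b).pow(2).sum();
  return doSqrt ? dist.sqrt() : dist;
}

// Possible usage:
//   torch::Tensor a = torch::rand({3, 2}, torch::requires_grad());
//   torch::Tensor b = torch::rand({3, 2});
//   auto d = squaredL2FromMatching(a, b, true);
//   d.backward(); // a.grad() now holds the gradient of the distance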
void ttk::MergeTreeAutoencoder::getDifferentiableDistance(
  mtu::TorchMergeTree<float> &tree1,
  mtu::TorchMergeTree<float> &tree2,
  mtu::TorchMergeTree<float> &tree1_2,
  mtu::TorchMergeTree<float> &tree2_2,
  torch::Tensor &tensorDist,
  bool isCalled,
  bool doSqrt) {
  std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matchings,
    matchings2;
  float distance;
  computeOneDistance<float>(
    tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_);
  if(useDoubleInput_) {
    float distance2;
    computeOneDistance<float>(tree1_2.mTree, tree2_2.mTree, matchings2,
                              distance2, isCalled, useDoubleInput_, false);
  }
  getDifferentiableDistanceFromMatchings(
    tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt);
}
void ttk::MergeTreeAutoencoder::getDifferentiableDistance(
  mtu::TorchMergeTree<float> &tree1,
  mtu::TorchMergeTree<float> &tree2,
  torch::Tensor &tensorDist,
  bool isCalled,
  bool doSqrt) {
  mtu::TorchMergeTree<float> tree1_2, tree2_2;
  getDifferentiableDistance(
    tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt);
}
void ttk::MergeTreeAutoencoder::getDifferentiableDistanceMatrix(
  std::vector<mtu::TorchMergeTree<float> *> &trees,
  std::vector<mtu::TorchMergeTree<float> *> &trees2,
  std::vector<std::vector<torch::Tensor>> &outDistMat) {
  outDistMat.resize(trees.size(), std::vector<torch::Tensor>(trees.size()));
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
  shared(trees, trees2, outDistMat)
  {
#pragma omp single nowait
    {
#endif
      for(unsigned int i = 0; i < trees.size(); ++i) {
        outDistMat[i][i] = torch::tensor(0);
        for(unsigned int j = i + 1; j < trees.size(); ++j) {
#ifdef TTK_ENABLE_OPENMP
#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j)
          {
#endif
            bool isCalled = true;
            bool doSqrt = false;
            torch::Tensor tensorDist;
            getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]),
                                      *(trees2[j]), tensorDist, isCalled,
                                      doSqrt);
            outDistMat[i][j] = tensorDist;
            outDistMat[j][i] = tensorDist;
#ifdef TTK_ENABLE_OPENMP
          }
#endif
        }
      }
#ifdef TTK_ENABLE_OPENMP
    }
  }
#endif
}
void ttk::MergeTreeAutoencoder::getAlphasTensor(
  std::vector<std::vector<torch::Tensor>> &alphas,
  std::vector<unsigned int> &indexes,
  unsigned int layerIndex,
  torch::Tensor &alphasOut) {
  alphasOut = alphas[indexes[0]][layerIndex].transpose(0, 1);
  for(unsigned int ind = 1; ind < indexes.size(); ++ind)
    alphasOut = torch::cat(
      {alphasOut, alphas[indexes[ind]][layerIndex].transpose(0, 1)});
}
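
// Illustrative sketch (assumption): getAlphasTensor stacks the latent
// coefficients of the selected items into one (noItems x dim) matrix. The
// helper below does the same stacking for a flat list of (dim x 1) column
// vectors; its name is hypothetical.
#include <torch/torch.h>
#include <vector>

static torch::Tensor stackCoefficients(
  const std::vector<torch::Tensor> &alphas) {
  std::vector<torch::Tensor> rows;
  rows.reserve(alphas.size());
  for(const auto &a : alphas)
    rows.emplace_back(a.transpose(0, 1)); // (dim, 1) -> (1, dim)
  return torch::cat(rows); // (noItems, dim), concatenated along dim 0
}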
void ttk::MergeTreeAutoencoder::computeMetricLoss(
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
  std::vector<std::vector<torch::Tensor>> alphas,
  std::vector<std::vector<float>> &baseDistanceMatrix,
  std::vector<unsigned int> &indexes,
  torch::Tensor &metricLoss) {
  auto layerIndex = getLatentLayerIndex();
  std::vector<std::vector<torch::Tensor>> losses(
    layersOuts.size(), std::vector<torch::Tensor>(layersOuts.size()));

  // Gather the latent representations of the trees in the batch
  std::vector<mtu::TorchMergeTree<float> *> trees, trees2;
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    unsigned int i = indexes[ind];
    trees.emplace_back(&(layersOuts[i][layerIndex]));
    if(useDoubleInput_)
      trees2.emplace_back(&(layersOuts2[i][layerIndex]));
  }

  // Distances in the latent space: either between decoded trees or between
  // (scaled, optionally activated) latent coefficients
  std::vector<std::vector<torch::Tensor>> outDistMat;
  torch::Tensor coefDistMat;
  if(customLossSpace_) {
    getDifferentiableDistanceMatrix(trees, trees2, outDistMat);
  } else {
    std::vector<std::vector<torch::Tensor>> scaledAlphas;
    createScaledAlphas(alphas, vSTensor_, scaledAlphas);
    torch::Tensor latentAlphas;
    getAlphasTensor(scaledAlphas, indexes, layerIndex, latentAlphas);
    if(customLossActivate_)
      latentAlphas = activation(latentAlphas);
    coefDistMat = torch::cdist(latentAlphas, latentAlphas).pow(2);
  }

  // Compare with the precomputed distance matrix of the input trees
  torch::Tensor maxLoss = torch::tensor(0);
  metricLoss = torch::tensor(0);
  float div = 0;
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    unsigned int i = indexes[ind];
    for(unsigned int ind2 = ind + 1; ind2 < indexes.size(); ++ind2) {
      unsigned int j = indexes[ind2];
      torch::Tensor toCompare
        = (customLossSpace_ ? outDistMat[i][j] : coefDistMat[ind][ind2]);
      torch::Tensor loss = torch::nn::MSELoss()(
        torch::tensor(baseDistanceMatrix[i][j]), toCompare);
      metricLoss = metricLoss + loss;
      maxLoss = torch::max(loss, maxLoss);
      div += 1;
    }
  }
  metricLoss = metricLoss / torch::tensor(div);
  if(normalizeMetricLoss_)
    metricLoss /= maxLoss;
}
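
// Illustrative sketch (assumption): the metric loss compares, for every pair
// of trees in the batch, a precomputed merge-tree distance with a
// differentiable distance (here reduced to squared distances between latent
// coefficients), using one MSE term per pair, averaged and optionally
// normalized by the largest pair loss. `metricLossSketch` is a hypothetical
// helper; it assumes n >= 2 and a non-zero maximum loss.
#include <torch/torch.h>

static torch::Tensor metricLossSketch(const torch::Tensor &latent, // (n, dim)
                                      const torch::Tensor &baseDist) { // (n, n)
  torch::Tensor latentDist = torch::cdist(latent, latent).pow(2);
  torch::Tensor loss = torch::tensor(0.0f);
  torch::Tensor maxLoss = torch::tensor(0.0f);
  int64_t n = latent.size(0), count = 0;
  for(int64_t i = 0; i < n; ++i)
    for(int64_t j = i + 1; j < n; ++j) {
      torch::Tensor pairLoss
        = torch::nn::MSELoss()(latentDist[i][j], baseDist[i][j]);
      loss = loss + pairLoss;
      maxLoss = torch::max(pairLoss, maxLoss);
      ++count;
    }
  loss = loss / count;
  return loss / maxLoss; // mirrors the normalizeMetricLoss_ branch above
}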
void ttk::MergeTreeAutoencoder::computeClusteringLoss(
  std::vector<std::vector<torch::Tensor>> &alphas,
  std::vector<unsigned int> &indexes,
  torch::Tensor &clusteringLoss,
  torch::Tensor &asgn) {
  // Stack the latent coefficients of the trees in the batch
  unsigned int layerIndex = getLatentLayerIndex();
  torch::Tensor latentAlphas;
  getAlphasTensor(alphas, indexes, layerIndex, latentAlphas);
  if(customLossActivate_)
    latentAlphas = activation(latentAlphas);

  // Soft assignments to the latent centroids
  torch::Tensor centroids = latentCentroids_[0].transpose(0, 1);
  for(unsigned int i = 1; i < latentCentroids_.size(); ++i)
    centroids = torch::cat({centroids, latentCentroids_[i].transpose(0, 1)});
  torch::Tensor dist = torch::cdist(latentAlphas, centroids);
  dist = dist * -clusteringLossTemp_;
  asgn = torch::nn::Softmax(1)(dist);

  // Ground-truth assignments as one-hot vectors
  std::vector<float> clusterAsgn;
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    clusterAsgn.emplace_back(clusterAsgn_[indexes[ind]]);
  }
  torch::Tensor realAsgn = torch::tensor(clusterAsgn).to(torch::kInt64);
  realAsgn
    = torch::nn::functional::one_hot(realAsgn, asgn.sizes()[1]).to(torch::kF32);

  clusteringLoss = torch::nn::KLDivLoss(
    torch::nn::KLDivLossOptions().reduction(torch::kBatchMean))(asgn, realAsgn);
}
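
// Illustrative sketch (assumption): the clustering loss turns the distances
// between the latent coefficients and the latent centroids into soft
// assignments with a temperature-scaled softmax, and compares them to the
// ground-truth clustering with a KL divergence (kBatchMean reduction), as in
// the code above. Note that the soft assignments are passed directly, as done
// above, even though KLDivLoss documents its input as log-probabilities.
#include <torch/torch.h>

static torch::Tensor clusteringLossSketch(
  const torch::Tensor &latent, // (n, dim)
  const torch::Tensor &centroids, // (k, dim)
  const torch::Tensor &labels, // (n), int64 cluster ids in [0, k)
  float temperature) {
  torch::Tensor dist = torch::cdist(latent, centroids) * -temperature;
  torch::Tensor asgn = torch::nn::Softmax(1)(dist); // (n, k) soft assignments
  torch::Tensor target
    = torch::nn::functional::one_hot(labels, centroids.size(0))
        .to(torch::kF32);
  return torch::nn::KLDivLoss(
    torch::nn::KLDivLossOptions().reduction(torch::kBatchMean))(asgn, target);
}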
void ttk::MergeTreeAutoencoder::computeTrackingLoss(
  torch::Tensor &trackingLoss) {
  unsigned int latentLayerIndex = getLatentLayerIndex() + 1;
  auto endLayer = (trackingLossDecoding_ ? noLayers_ : latentLayerIndex);
  std::vector<torch::Tensor> losses(endLayer);
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int l = 0; l < endLayer; ++l) {
    auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
    auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
    torch::Tensor tensorDist;
    bool isCalled = true, doSqrt = false;
    getDifferentiableDistance(tree1, tree2, tensorDist, isCalled, doSqrt);
    losses[l] = tensorDist;
  }
  trackingLoss = torch::tensor(0, torch::kFloat32);
  for(unsigned int i = 0; i < losses.size(); ++i)
    trackingLoss += losses[i];
}
void ttk::MergeTreeAutoencoder::createCustomRecs(
  std::vector<mtu::TorchMergeTree<float>> &origins,
  std::vector<mtu::TorchMergeTree<float>> &originsPrime) {
  if(customAlphas_.empty())
    return;

  bool initByTreesAlphas = not allAlphas_.empty();
  std::vector<torch::Tensor> allTreesAlphas;
  if(initByTreesAlphas) {
    allTreesAlphas.resize(allAlphas_[0].size());
    for(unsigned int l = 0; l < allTreesAlphas.size(); ++l) {
      allTreesAlphas[l] = allAlphas_[0][l].reshape({-1, 1});
      for(unsigned int i = 1; i < allAlphas_.size(); ++i)
        allTreesAlphas[l]
          = torch::cat({allTreesAlphas[l], allAlphas_[i][l]}, 1);
      allTreesAlphas[l] = allTreesAlphas[l].transpose(0, 1);
    }
  }

  unsigned int latLayer = getLatentLayerIndex();
  customRecs_.resize(customAlphas_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customAlphas_.size(); ++i) {
    torch::Tensor alphas = torch::tensor(customAlphas_[i]).reshape({-1, 1});
    torch::Tensor alphasWeight;
    if(initByTreesAlphas) {
      auto driver = "gelsd";
      alphasWeight = std::get<0>(torch::linalg::lstsq(
                                   allTreesAlphas[latLayer].transpose(0, 1),
                                   alphas, c10::nullopt, driver))
                       .transpose(0, 1);
    }
    std::vector<mtu::TorchMergeTree<float>> outs, outs2;
    auto noOuts = noLayers_ - latLayer;
    outs.resize(noOuts);
    outs2.resize(noOuts);
    mtu::TorchMergeTree<float> out, out2;
    outputBasisReconstruction(originsPrime[latLayer], vSPrimeTensor_[latLayer],
                              origins2Prime_[latLayer],
                              vS2PrimeTensor_[latLayer], alphas, outs[0],
                              outs2[0]);
    unsigned int k = 32;
    for(unsigned int l = latLayer + 1; l < noLayers_; ++l) {
      // Try several random coefficient initializations and keep the best one
      unsigned int noIter = (initByTreesAlphas ? 1 : 32);
      std::vector<torch::Tensor> allAlphasInit(noIter);
      torch::Tensor maxNorm;
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j) {
        allAlphasInit[j] = torch::randn({vSTensor_[l].sizes()[1], 1});
        auto norm = torch::linalg::vector_norm(
          allAlphasInit[j], 2, 0, false, c10::nullopt);
        if(j == 0 or maxNorm.item<float>() < norm.item<float>())
          maxNorm = norm;
      }
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j)
        allAlphasInit[j] /= maxNorm;
      float bestDistance = std::numeric_limits<float>::max();
      auto outIndex = l - latLayer;
      mtu::TorchMergeTree<float> outToUse;
      for(unsigned int j = 0; j < noIter; ++j) {
        torch::Tensor alphasInit, dataAlphas;
        if(initByTreesAlphas) {
          alphasInit
            = torch::matmul(alphasWeight, allTreesAlphas[l]).transpose(0, 1);
        } else {
          alphasInit = allAlphasInit[j];
        }
        float distance;
        forwardOneLayer(outs[outIndex - 1], origins[l], vSTensor_[l],
                        originsPrime[l], vSPrimeTensor_[l], outs2[outIndex - 1],
                        origins2_[l], vS2Tensor_[l], origins2Prime_[l],
                        vS2PrimeTensor_[l], k, alphasInit, outToUse,
                        outs2[outIndex], dataAlphas, distance);
        if(distance < bestDistance) {
          bestDistance = distance;
          mtu::copyTorchMergeTree<float>(
            outToUse, (l != noLayers_ - 1 ? outs[outIndex] : customRecs_[i]));
        }
      }
    }
  }

  // Match each custom reconstruction with the input origin
  customMatchings_.resize(customRecs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    bool isCalled = true;
    float distance;
    computeOneDistance<float>(origins[0].mTree, customRecs_[i].mTree,
                              customMatchings_[i], distance, isCalled,
                              useDoubleInput_);
  }

  mtu::TorchMergeTree<float> originCopy;
  mtu::copyTorchMergeTree<float>(origins[0], originCopy);
  postprocessingPipeline<float>(&(originCopy.mTree.tree));
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    postprocessingPipeline<float>(&(customRecs_[i].mTree.tree));
    if(not isPersistenceDiagram_) {
      convertBranchDecompositionMatching<float>(&(originCopy.mTree.tree),
                                                &(customRecs_[i].mTree.tree),
                                                customMatchings_[i]);
    }
  }
}
void ttk::MergeTreeAutoencoder::computeTrackingInformation() {
  unsigned int latentLayerIndex = getLatentLayerIndex() + 1;

  // Origins matchings
  originsMatchings_.resize(latentLayerIndex);
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int l = 0; l < latentLayerIndex; ++l) {
    auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
    auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
    bool isCalled = true;
    float distance;
    computeOneDistance<float>(tree1.mTree, tree2.mTree, originsMatchings_[l],
                              distance, isCalled, useDoubleInput_);
  }

  // Data matchings
  dataMatchings_.resize(latentLayerIndex);
  for(unsigned int l = 0; l < latentLayerIndex; ++l) {
    dataMatchings_[l].resize(recs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int i = 0; i < recs_.size(); ++i) {
      bool isCalled = true;
      float distance;
      auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
      computeOneDistance<float>(origin.mTree, recs_[i][l].mTree,
                                dataMatchings_[l][i], distance, isCalled,
                                useDoubleInput_);
    }
  }

  // Reconstruction matchings
  reconstMatchings_.resize(recs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < recs_.size(); ++i) {
    bool isCalled = true;
    float distance;
    auto l = recs_[i].size() - 1;
    computeOneDistance<float>(recs_[i][0].mTree, recs_[i][l].mTree,
                              reconstMatchings_[i], distance, isCalled,
                              useDoubleInput_);
  }
}
void ttk::MergeTreeAutoencoder::createScaledAlphas(
  std::vector<std::vector<torch::Tensor>> &alphas,
  std::vector<torch::Tensor> &vSTensor,
  std::vector<std::vector<torch::Tensor>> &scaledAlphas) {
  scaledAlphas.clear();
  scaledAlphas.resize(
    alphas.size(), std::vector<torch::Tensor>(alphas[0].size()));
  for(unsigned int l = 0; l < alphas[0].size(); ++l) {
    torch::Tensor scale = vSTensor[l].pow(2).sum(0).sqrt();
    for(unsigned int i = 0; i < alphas.size(); ++i) {
      scaledAlphas[i][l] = alphas[i][l] * scale.reshape({-1, 1});
    }
  }
}
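
// Illustrative sketch (assumption): createScaledAlphas rescales each
// coefficient by the Euclidean norm of the corresponding column of the basis
// vS, so that coefficients attached to short and long basis vectors become
// comparable. `scaleByBasisNorms` is a hypothetical helper mirroring the loop
// body above for a single coefficient vector.
#include <torch/torch.h>

static torch::Tensor scaleByBasisNorms(const torch::Tensor &alphas, // (k, 1)
                                       const torch::Tensor &vS) { // (dim, k)
  torch::Tensor scale = vS.pow(2).sum(0).sqrt(); // per-column norms, shape (k)
  return alphas * scale.reshape({-1, 1});
}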
void ttk::MergeTreeAutoencoder::createScaledAlphas() {
  createScaledAlphas(allAlphas_, vSTensor_, allScaledAlphas_);
}

void ttk::MergeTreeAutoencoder::createActivatedAlphas() {
  allActAlphas_ = allAlphas_;
  for(unsigned int i = 0; i < allActAlphas_.size(); ++i)
    for(unsigned int j = 0; j < allActAlphas_[i].size(); ++j)
      allActAlphas_[i][j] = activation(allActAlphas_[i][j]);
  createScaledAlphas(allActAlphas_, vSTensor_, allActScaledAlphas_);
}
void ttk::MergeTreeAutoencoder::copyParams(
  std::vector<mtu::TorchMergeTree<float>> &srcOrigins,
  std::vector<mtu::TorchMergeTree<float>> &srcOriginsPrime,
  std::vector<torch::Tensor> &srcVS,
  std::vector<torch::Tensor> &srcVSPrime,
  std::vector<mtu::TorchMergeTree<float>> &srcOrigins2,
  std::vector<mtu::TorchMergeTree<float>> &srcOrigins2Prime,
  std::vector<torch::Tensor> &srcVS2,
  std::vector<torch::Tensor> &srcVS2Prime,
  std::vector<std::vector<torch::Tensor>> &srcAlphas,
  std::vector<mtu::TorchMergeTree<float>> &dstOrigins,
  std::vector<mtu::TorchMergeTree<float>> &dstOriginsPrime,
  std::vector<torch::Tensor> &dstVS,
  std::vector<torch::Tensor> &dstVSPrime,
  std::vector<mtu::TorchMergeTree<float>> &dstOrigins2,
  std::vector<mtu::TorchMergeTree<float>> &dstOrigins2Prime,
  std::vector<torch::Tensor> &dstVS2,
  std::vector<torch::Tensor> &dstVS2Prime,
  std::vector<std::vector<torch::Tensor>> &dstAlphas) {
  dstOrigins.resize(noLayers_);
  dstOriginsPrime.resize(noLayers_);
  dstVS.resize(noLayers_);
  dstVSPrime.resize(noLayers_);
  dstAlphas.resize(srcAlphas.size(), std::vector<torch::Tensor>(noLayers_));
  if(useDoubleInput_) {
    dstOrigins2.resize(noLayers_);
    dstOrigins2Prime.resize(noLayers_);
    dstVS2.resize(noLayers_);
    dstVS2Prime.resize(noLayers_);
  }
  for(unsigned int l = 0; l < noLayers_; ++l) {
    mtu::copyTorchMergeTree(srcOrigins[l], dstOrigins[l]);
    mtu::copyTorchMergeTree(srcOriginsPrime[l], dstOriginsPrime[l]);
    mtu::copyTensor(srcVS[l], dstVS[l]);
    mtu::copyTensor(srcVSPrime[l], dstVSPrime[l]);
    if(useDoubleInput_) {
      mtu::copyTorchMergeTree(srcOrigins2[l], dstOrigins2[l]);
      mtu::copyTorchMergeTree(srcOrigins2Prime[l], dstOrigins2Prime[l]);
      mtu::copyTensor(srcVS2[l], dstVS2[l]);
      mtu::copyTensor(srcVS2Prime[l], dstVS2Prime[l]);
    }
    for(unsigned int i = 0; i < srcAlphas.size(); ++i)
      mtu::copyTensor(srcAlphas[i][l], dstAlphas[i][l]);
  }
}
void ttk::MergeTreeAutoencoder::copyParams(
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst) {
  dst.resize(src.size());
  for(unsigned int i = 0; i < src.size(); ++i) {
    dst[i].resize(src[i].size());
    for(unsigned int j = 0; j < src[i].size(); ++j)
      mtu::copyTorchMergeTree(src[i][j], dst[i][j]);
  }
}
unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() {
  unsigned int idx = noLayers_ / 2 - 1;
  return idx;
}
bool ttk::MergeTreeAutoencoder::isTreeHasBigValues(ftm::MergeTree<float> &mTree,
                                                   float threshold) {
  for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) {
    if(mTree.tree.isNodeAlone(n))
      continue;
    auto birthDeath = mTree.tree.template getBirthDeath<float>(n);
    if(std::abs(std::get<0>(birthDeath)) > threshold
       or std::abs(std::get<1>(birthDeath)) > threshold) {
      return true;
    }
  }
  return false;
}
void ttk::MergeTreeAutoencoder::execute(
  std::vector<ftm::MergeTree<float>> &trees,
  std::vector<ftm::MergeTree<float>> &trees2) {
#ifndef TTK_ENABLE_TORCH
  TTK_FORCE_USE(trees);
  TTK_FORCE_USE(trees2);
  printErr("This module requires Torch.");
#else
#ifdef TTK_ENABLE_OPENMP
  int ompNested = omp_get_nested();
  omp_set_nested(1);
#endif
  // --- Preprocessing
  Timer t_total;
  preprocessingTrees<float>(trees, treesNodeCorr_);
  if(trees2.size() != 0)
    preprocessingTrees<float>(trees2, trees2NodeCorr_);
  useDoubleInput_ = (trees2.size() != 0);

  auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_;
  printMsg("Total time", 1, totalTime, threadNumber_);
  hasComputedOnce_ = true;

  // --- End functions
  createScaledAlphas();
  createActivatedAlphas();
  computeTrackingInformation();

  // Correlation between latent coefficients and input branches
  auto latLayer = getLatentLayerIndex();
  std::vector<std::vector<double>> allTs;
  auto noGeod = allAlphas_[0][latLayer].sizes()[0];
  allTs.resize(noGeod);
  for(unsigned int i = 0; i < noGeod; ++i) {
    allTs[i].resize(allAlphas_.size());
    for(unsigned int j = 0; j < allAlphas_.size(); ++j)
      allTs[i][j] = allAlphas_[j][latLayer][i].item<double>();
  }
  computeBranchesCorrelationMatrix(origins_[0].mTree, trees, dataMatchings_[0],
                                   allTs, branchesCorrelationMatrix_,
                                   persCorrelationMatrix_);

  // Reconstruction of user-provided latent coefficients
  originsCopy_.resize(origins_.size());
  originsPrimeCopy_.resize(originsPrime_.size());
  for(unsigned int l = 0; l < origins_.size(); ++l) {
    mtu::copyTorchMergeTree<float>(origins_[l], originsCopy_[l]);
    mtu::copyTorchMergeTree<float>(originsPrime_[l], originsPrimeCopy_[l]);
  }
  createCustomRecs(originsCopy_, originsPrimeCopy_);

  // --- Postprocessing
  for(unsigned int i = 0; i < trees.size(); ++i)
    postprocessingPipeline<float>(&(trees[i].tree));
  for(unsigned int i = 0; i < trees2.size(); ++i)
    postprocessingPipeline<float>(&(trees2[i].tree));
  for(unsigned int l = 0; l < origins_.size(); ++l) {
    fillMergeTreeStructure(origins_[l]);
    postprocessingPipeline<float>(&(origins_[l].mTree.tree));
    fillMergeTreeStructure(originsPrime_[l]);
    postprocessingPipeline<float>(&(originsPrime_[l].mTree.tree));
  }
  for(unsigned int j = 0; j < recs_[0].size(); ++j) {
    for(unsigned int i = 0; i < recs_.size(); ++i) {
      postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
    }
  }

  if(not isPersistenceDiagram_) {
    for(unsigned int l = 0; l < originsMatchings_.size(); ++l) {
      auto &tree1 = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
      auto &tree2 = (l == 0 ? originsPrime_[0] : originsPrime_[l]);
      convertBranchDecompositionMatching<float>(
        &(tree1.mTree.tree), &(tree2.mTree.tree), originsMatchings_[l]);
    }
    for(unsigned int l = 0; l < dataMatchings_.size(); ++l) {
      for(unsigned int i = 0; i < recs_.size(); ++i) {
        auto &origin = (l == 0 ? origins_[0] : originsPrime_[l - 1]);
        convertBranchDecompositionMatching<float>(&(origin.mTree.tree),
                                                  &(recs_[i][l].mTree.tree),
                                                  dataMatchings_[l][i]);
      }
    }
    for(unsigned int i = 0; i < reconstMatchings_.size(); ++i) {
      auto l = recs_[i].size() - 1;
      convertBranchDecompositionMatching<float>(&(recs_[i][0].mTree.tree),
                                                &(recs_[i][l].mTree.tree),
                                                reconstMatchings_[i]);
    }
  }
#ifdef TTK_ENABLE_OPENMP
  omp_set_nested(ompNested);
#endif
#endif
}