5using namespace torch::indexing;
13#ifdef TTK_ENABLE_TORCH
// Initializes one latent-space centroid per cluster for the clustering loss:
// seeds each centroid from the first tree assigned to the cluster, accumulates
// the latent alphas of all members, averages, then detaches and re-enables
// gradients so the centroids become trainable parameters.
// NOTE(review): this block is extraction-garbled — original line numbers are
// fused into the code and interior lines are missing; verify against the
// original source file before relying on the exact control flow.
14void ttk::MergeTreeAutoencoder::initClusteringLossParameters() {
 // l = index of the latent (bottleneck) layer.
15 unsigned int l = getLatentLayerIndex();
 // Number of distinct cluster labels in clusterAsgn_.
16 unsigned int noCentroids
17 = std::set<unsigned int>(clusterAsgn_.begin(), clusterAsgn_.end()).size();
18 latentCentroids_.resize(noCentroids);
19 for(
unsigned int c = 0; c < noCentroids; ++c) {
 // Find the first data index assigned to cluster c (sentinel = max).
20 unsigned int firstIndex = std::numeric_limits<unsigned int>::max();
21 for(
unsigned int i = 0; i < clusterAsgn_.size(); ++i) {
22 if(clusterAsgn_[i] == c) {
 // Empty cluster: warn and (presumably) skip — TODO confirm, lines missing.
27 if(firstIndex >= allAlphas_.size()) {
28 printWrn(
"no data found for cluster " + std::to_string(c));
 // Seed centroid from the first member's latent alphas (detached copy).
31 latentCentroids_[c] = allAlphas_[firstIndex][l].detach().clone();
33 for(
unsigned int i = 0; i < allAlphas_.size(); ++i) {
36 if(clusterAsgn_[i] == c) {
37 latentCentroids_[c] += allAlphas_[i][l];
 // Average over cluster members, then make the centroid a leaf tensor
 // that participates in optimization.
41 latentCentroids_[c] /= torch::tensor(noData);
42 latentCentroids_[c] = latentCentroids_[c].detach();
43 latentCentroids_[c].requires_grad_(
true);
// (Re)initializes the output basis of a layer, dispatching to the special-case
// path for tiny networks (1 layer, or layer 1 of a 2-layer network) or to the
// generic encoder-side initialization; emits an error when a reconstructed
// tree ends up empty.
// NOTE(review): extraction-garbled — interior lines are missing (e.g. the
// declaration of `l` and the branch emitting printErr); verify against the
// original source.
47bool ttk::MergeTreeAutoencoder::initResetOutputBasis(
49 unsigned int layerNoAxes,
50 double layerOriginPrimeSizePercent,
51 std::vector<mtu::TorchMergeTree<float>> &trees,
52 std::vector<mtu::TorchMergeTree<float>> &trees2,
53 std::vector<bool> &isTrain) {
 // Degenerate architectures reuse the input basis directly.
55 if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
56 initOutputBasisSpecialCase(l, layerNoAxes, trees, trees2);
57 }
else if(l < (
unsigned int)(noLayers_ / 2)) {
 // Encoder half: build a fresh output basis for this layer.
58 initOutputBasis(l, layerOriginPrimeSizePercent, trees, trees2, isTrain);
60 printErr(
"recs[i].mTree.tree.getRealNumberOfNodes() == 0");
61 std::stringstream ssT;
// Special-case output-basis initialization for 1- or 2-layer networks: the
// layer's "prime" (output) origin and basis vectors are copied from layer 0's
// input side; if the requested axis count differs from layer 0's, the input
// basis vectors are recomputed first.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// interior lines missing; verify against the original source.
69void ttk::MergeTreeAutoencoder::initOutputBasisSpecialCase(
71 unsigned int layerNoAxes,
72 std::vector<mtu::TorchMergeTree<float>> &trees,
73 std::vector<mtu::TorchMergeTree<float>> &trees2) {
 // Output origin(s) mirror the input origin(s) of the first layer.
76 layers_[l].setOriginPrime(layers_[0].getOrigin());
78 layers_[l].setOrigin2Prime(layers_[0].getOrigin2());
 // Axis-count mismatch: recompute layer-0 input basis with layerNoAxes axes.
81 if(layerNoAxes != layers_[0].getVSTensor().sizes()[1]) {
83 std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
84 for(
unsigned int i = 0; i < trees.size(); ++i) {
85 treesToUse.emplace_back(trees[i].mTree);
87 trees2ToUse.emplace_back(trees2[i].mTree);
89 std::vector<torch::Tensor> allAlphasInitT(trees.size());
90 layers_[l].initInputBasisVectors(
91 trees, trees2, treesToUse, trees2ToUse, layerNoAxes, allAlphasInitT,
92 inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_,
false);
 // Output basis vectors mirror layer 0's input basis vectors.
94 layers_[l].setVSPrimeTensor(layers_[0].getVSTensor());
96 layers_[l].setVS2PrimeTensor(layers_[0].getVS2Tensor());
// Initializes all network parameters: decides the number of layers, derives
// per-layer axis counts and origin sizes by linear interpolation between the
// input and latent settings, initializes each layer's input/output bases
// (mirroring encoder layers onto decoder layers when the AE is symmetric),
// then runs one forward pass to compute the initial loss (reconstruction plus
// optional metric/clustering/tracking terms). Returns the initial error, or
// float max when initialization must be retried.
// NOTE(review): this block is extraction-garbled — original line numbers are
// fused into the code and many interior lines are missing (loop bodies,
// closing braces, early returns); verify against the original source before
// relying on the exact control flow.
100float ttk::MergeTreeAutoencoder::initParameters(
101 std::vector<mtu::TorchMergeTree<float>> &trees,
102 std::vector<mtu::TorchMergeTree<float>> &trees2,
103 std::vector<bool> &isTrain,
 // Total layers: encoder layers mirrored, plus latent and output layers.
108 noLayers_ = encoderNoLayers_ * 2 + 1 + 1;
109 if(encoderNoLayers_ <= -1)
111 std::vector<double> layersOriginPrimeSizePercent(noLayers_);
112 std::vector<unsigned int> layersNoAxes(noLayers_);
114 layersNoAxes[0] = numberOfAxes_;
115 layersOriginPrimeSizePercent[0] = latentSpaceOriginPrimeSizePercent_;
117 layersNoAxes[1] = inputNumberOfAxes_;
118 layersOriginPrimeSizePercent[1] = barycenterSizeLimitPercent_;
 // Interpolate axis counts / origin sizes symmetrically from both ends
 // toward the latent layer.
121 for(
unsigned int l = 0; l < noLayers_ / 2; ++l) {
122 double alpha = (double)(l) / (noLayers_ / 2 - 1);
124 = (1 - alpha) * inputNumberOfAxes_ + alpha * numberOfAxes_;
125 layersNoAxes[l] = noAxes;
126 layersNoAxes[noLayers_ - 1 - l] = noAxes;
127 double originPrimeSizePercent
128 = (1 - alpha) * inputOriginPrimeSizePercent_
129 + alpha * latentSpaceOriginPrimeSizePercent_;
130 layersOriginPrimeSizePercent[l] = originPrimeSizePercent;
131 layersOriginPrimeSizePercent[noLayers_ - 1 - l] = originPrimeSizePercent;
 // Optionally widen the layer right after the latent layer.
133 if(scaleLayerAfterLatent_)
134 layersNoAxes[noLayers_ / 2]
135 = (layersNoAxes[noLayers_ / 2 - 1] + layersNoAxes[noLayers_ / 2 + 1])
140 layers_.resize(noLayers_);
141 for(
unsigned int l = 0; l < layers_.size(); ++l) {
 // Tracking loss may require origin-prime values copied (with optional
 // randomness) up to the latent layer, or through the decoder too.
142 initOriginPrimeValuesByCopy_
143 = trackingLossWeight_ != 0
144 and l < (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1);
145 initOriginPrimeValuesByCopyRandomness_ = trackingLossInitRandomness_;
146 passLayerParameters(layers_[l]);
150 bool fullSymmetricAE = fullSymmetricAE_;
152 std::vector<mtu::TorchMergeTree<float>> recs, recs2;
153 std::vector<std::vector<torch::Tensor>> allAlphasInit(
154 trees.size(), std::vector<torch::Tensor>(noLayers_));
155 for(
unsigned int l = 0; l < noLayers_; ++l) {
157 std::stringstream ss;
158 ss <<
"Init Layer " << l;
 // Encoder half (or non-symmetric AE): initialize input basis from data.
162 if(l < (
unsigned int)(noLayers_ / 2) or not fullSymmetricAE
163 or (noLayers_ <= 2 and not fullSymmetricAE)) {
164 auto &treesToUse = (l == 0 ? trees : recs);
165 auto &trees2ToUse = (l == 0 ? trees2 : recs2);
167 l, layersNoAxes[l], treesToUse, trees2ToUse, isTrain, allAlphasInit);
 // Decoder half of a symmetric AE: mirror the opposite encoder layer.
172 unsigned int middle = noLayers_ / 2;
173 unsigned int l_opp = middle - (l - middle + 1);
174 layers_[l].setOrigin(layers_[l_opp].getOriginPrime());
175 layers_[l].setVSTensor(layers_[l_opp].getVSPrimeTensor());
176 if(trees2.size() != 0) {
177 if(fullSymmetricAE) {
178 layers_[l].setOrigin2(layers_[l_opp].getOrigin2Prime());
179 layers_[l].setVS2Tensor(layers_[l_opp].getVS2PrimeTensor());
182 for(
unsigned int i = 0; i < trees.size(); ++i)
183 allAlphasInit[i][l] = allAlphasInit[i][l_opp];
 // Output basis: special case for tiny networks, generic init on the
 // encoder side, mirrored from the opposite layer on the decoder side.
187 if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
189 initOutputBasisSpecialCase(l, layersNoAxes[l], trees, trees2);
190 }
else if(l < (
unsigned int)(noLayers_ / 2)) {
192 l, layersOriginPrimeSizePercent[l], trees, trees2, isTrain);
197 unsigned int middle = noLayers_ / 2;
198 unsigned int l_opp = middle - (l - middle + 1);
199 layers_[l].setOriginPrime(layers_[l_opp].getOrigin());
200 if(trees2.size() != 0)
201 layers_[l].setOrigin2Prime(layers_[l_opp].getOrigin2());
202 if(l == (
unsigned int)(noLayers_) / 2 and scaleLayerAfterLatent_) {
204 = (trees2.size() != 0 ? layers_[l].getOrigin2Prime().tensor.sizes()[0]
206 layers_[l].initOutputBasisVectors(
207 layers_[l].getOriginPrime().tensor.sizes()[0], dim2);
209 layers_[l].setVSPrimeTensor(layers_[l_opp].getVSTensor());
210 if(trees2.size() != 0)
211 layers_[l].setVS2PrimeTensor(layers_[l_opp].getVS2Tensor());
 // Reconstruct through this layer; a "full reset" aborts initialization.
216 bool fullReset = initGetReconstructed(
217 l, layersNoAxes[l], layersOriginPrimeSizePercent[l], trees, trees2,
218 isTrain, recs, recs2, allAlphasInit);
220 return std::numeric_limits<float>::max();
222 allAlphas_ = allAlphasInit;
225 if(clusteringLossWeight_ != 0)
226 initClusteringLossParameters();
 // Single forward pass to evaluate the initial loss.
229 float error = 0.0, recLoss = 0.0;
232 std::vector<unsigned int> indexes(trees.size());
233 std::iota(indexes.begin(), indexes.end(), 0);
236 std::vector<std::vector<torch::Tensor>> bestAlphas;
237 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
239 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
240 matchings, matchings2;
241 bool reset = forwardStep(trees, trees2, indexes, k, allAlphasInit,
242 computeError, recs, recs2, bestAlphas, layersOuts,
243 layersOuts2, matchings, matchings2, recLoss);
245 printWrn(
"[initParameters] forwardStep reset");
246 return std::numeric_limits<float>::max();
248 error = recLoss * reconstructionLossWeight_;
 // Optional custom loss terms, each scaled by its weight (and a dynamic
 // weight derived from the reconstruction loss).
249 if(metricLossWeight_ != 0) {
250 torch::Tensor metricLoss;
251 computeMetricLoss(layersOuts, layersOuts2, allAlphas_, distanceMatrix_,
252 indexes, metricLoss);
253 baseRecLoss_ = std::numeric_limits<double>::max();
254 metricLoss *= metricLossWeight_
255 * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
256 error += metricLoss.item<
float>();
258 if(clusteringLossWeight_ != 0) {
259 torch::Tensor clusteringLoss, asgn;
260 computeClusteringLoss(allAlphas_, indexes, clusteringLoss, asgn);
261 baseRecLoss_ = std::numeric_limits<double>::max();
262 clusteringLoss *= clusteringLossWeight_
263 * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
264 error += clusteringLoss.item<
float>();
266 if(trackingLossWeight_ != 0) {
267 torch::Tensor trackingLoss;
268 computeTrackingLoss(trackingLoss);
269 trackingLoss *= trackingLossWeight_;
270 error += trackingLoss.item<
float>();
// One optimizer backward step: computes the MSE reconstruction loss between
// each output tree and its matched/reordered input (concatenating the second
// tree when double input is used), backpropagates it, then backpropagates the
// weighted metric/clustering/tracking losses, retaining the graph only while
// later losses still need it.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// interior lines (including the final optimizer.step()/return, presumably)
// are missing; verify against the original source.
279bool ttk::MergeTreeAutoencoder::backwardStep(
280 std::vector<mtu::TorchMergeTree<float>> &trees,
281 std::vector<mtu::TorchMergeTree<float>> &outs,
282 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
284 std::vector<mtu::TorchMergeTree<float>> &trees2,
285 std::vector<mtu::TorchMergeTree<float>> &outs2,
286 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
288 std::vector<std::vector<torch::Tensor>> &
ttkNotUsed(alphas),
289 torch::optim::Optimizer &optimizer,
290 std::vector<unsigned int> &indexes,
292 std::vector<torch::Tensor> &torchCustomLoss) {
293 double totalLoss = 0;
 // The computation graph must be retained if any custom loss will also
 // call backward() after the reconstruction loss.
294 bool retainGraph = (metricLossWeight_ != 0 or clusteringLossWeight_ != 0
295 or trackingLossWeight_ != 0);
296 if(reconstructionLossWeight_ != 0
297 or (customLossDynamicWeight_ and retainGraph)) {
298 std::vector<torch::Tensor> outTensors(indexes.size()),
299 reorderedTensors(indexes.size());
300#ifdef TTK_ENABLE_OPENMP
301#pragma omp parallel for schedule(dynamic) \
302 num_threads(this->threadNumber_) if(parallelize_)
 // Reorder each input tree's data to align with its matched output tree.
304 for(
unsigned int ind = 0; ind < indexes.size(); ++ind) {
305 unsigned int i = indexes[ind];
306 torch::Tensor reorderedTensor;
307 dataReorderingGivenMatching(
308 outs[i], trees[i], matchings[i], reorderedTensor);
309 auto outTensor = outs[i].tensor;
310 if(useDoubleInput_) {
311 torch::Tensor reorderedTensor2;
312 dataReorderingGivenMatching(
313 outs2[i], trees2[i], matchings2[i], reorderedTensor2);
314 outTensor = torch::cat({outTensor, outs2[i].tensor});
315 reorderedTensor = torch::cat({reorderedTensor, reorderedTensor2});
317 outTensors[ind] = outTensor;
318 reorderedTensors[ind] = reorderedTensor;
 // Sequential backward: per-tree MSE loss, accumulated then weighted.
320 for(
unsigned int ind = 0; ind < indexes.size(); ++ind) {
321 auto loss = torch::nn::functional::mse_loss(
322 outTensors[ind], reorderedTensors[ind]);
326 totalLoss += loss.item<
float>();
327 loss *= reconstructionLossWeight_;
328 loss.backward({}, retainGraph);
 // Custom losses: each backward() retains the graph only if a later
 // custom loss still needs it.
331 if(metricLossWeight_ != 0) {
332 bool retainGraphMetricLoss
333 = (clusteringLossWeight_ != 0 or trackingLossWeight_ != 0);
334 torchCustomLoss[0] *= metricLossWeight_
335 * getCustomLossDynamicWeight(
336 totalLoss / indexes.size(), baseRecLoss2_);
337 torchCustomLoss[0].backward({}, retainGraphMetricLoss);
339 if(clusteringLossWeight_ != 0) {
340 bool retainGraphClusteringLoss = (trackingLossWeight_ != 0);
341 torchCustomLoss[1] *= clusteringLossWeight_
342 * getCustomLossDynamicWeight(
343 totalLoss / indexes.size(), baseRecLoss2_);
344 torchCustomLoss[1].backward({}, retainGraphClusteringLoss);
346 if(trackingLossWeight_ != 0) {
347 torchCustomLoss[2] *= trackingLossWeight_;
348 torchCustomLoss[2].backward();
351 for(
unsigned int l = 0; l < noLayers_; ++l)
355 optimizer.zero_grad();
// Computes the (merge-tree edit) distance between one input tree and its
// reconstruction, mixing in the distance for the second tree pair when double
// input is enabled. Returns the distance as the per-sample loss.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// interior lines (distance declarations, return) missing; verify against the
// original source.
362float ttk::MergeTreeAutoencoder::computeOneLoss(
363 mtu::TorchMergeTree<float> &tree,
364 mtu::TorchMergeTree<float> &out,
365 mtu::TorchMergeTree<float> &tree2,
366 mtu::TorchMergeTree<float> &out2,
367 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
368 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
369 std::vector<torch::Tensor> &
ttkNotUsed(alphas),
372 bool isCalled =
true;
374 computeOneDistance<float>(
375 out.mTree, tree.mTree, matching, distance, isCalled, useDoubleInput_);
376 if(useDoubleInput_) {
378 computeOneDistance<float>(out2.mTree, tree2.mTree, matching2, distance2,
379 isCalled, useDoubleInput_,
false);
 // Combine the two distances into a single loss value.
380 distance = mixDistances<float>(distance, distance2);
// Resets the dynamic-weight baselines and, when the metric loss is enabled,
// precomputes the pairwise distance matrix of the input trees.
// NOTE(review): extraction-garbled — closing lines missing; verify against
// the original source.
389void ttk::MergeTreeAutoencoder::customInit(
390 std::vector<mtu::TorchMergeTree<float>> &torchTrees,
391 std::vector<mtu::TorchMergeTree<float>> &torchTrees2) {
 // Baselines start at max so the first reconstruction loss becomes the base.
392 baseRecLoss_ = std::numeric_limits<double>::max();
393 baseRecLoss2_ = std::numeric_limits<double>::max();
395 if(metricLossWeight_ != 0)
396 getDistanceMatrix(torchTrees, torchTrees2, distanceMatrix_);
// Appends the trainable latent centroids to the optimizer's parameter list
// when the clustering loss is active.
// NOTE(review): extraction-garbled — closing lines missing; verify against
// the original source.
399void ttk::MergeTreeAutoencoder::addCustomParameters(
400 std::vector<torch::Tensor> &parameters) {
401 if(clusteringLossWeight_ != 0)
402 for(
unsigned int i = 0; i < latentCentroids_.size(); ++i)
403 parameters.emplace_back(latentCentroids_[i]);
// Evaluates the enabled custom losses (metric, clustering, tracking), storing
// the torch tensors in torchCustomLoss[0..2] and appending the float values to
// both the per-gap and per-iteration loss histories.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing; verify against the original source.
406void ttk::MergeTreeAutoencoder::computeCustomLosses(
407 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
408 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
409 std::vector<std::vector<torch::Tensor>> &bestAlphas,
410 std::vector<unsigned int> &indexes,
413 std::vector<std::vector<float>> &gapCustomLosses,
414 std::vector<std::vector<float>> &iterationCustomLosses,
415 std::vector<torch::Tensor> &torchCustomLoss) {
 // Three slots: [0]=metric, [1]=clustering, [2]=tracking.
416 if(gapCustomLosses.empty())
417 gapCustomLosses.resize(3);
418 if(iterationCustomLosses.empty())
419 iterationCustomLosses.resize(3);
420 torchCustomLoss.resize(3);
422 if(metricLossWeight_ != 0) {
423 computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_,
424 indexes, torchCustomLoss[0]);
425 float metricLossF = torchCustomLoss[0].item<
float>();
426 gapCustomLosses[0].emplace_back(metricLossF);
427 iterationCustomLosses[0].emplace_back(metricLossF);
430 if(clusteringLossWeight_ != 0) {
432 computeClusteringLoss(bestAlphas, indexes, torchCustomLoss[1], asgn);
433 float clusteringLossF = torchCustomLoss[1].item<
float>();
434 gapCustomLosses[1].emplace_back(clusteringLossF);
435 iterationCustomLosses[1].emplace_back(clusteringLossF);
438 if(trackingLossWeight_ != 0) {
439 computeTrackingLoss(torchCustomLoss[2]);
440 float trackingLossF = torchCustomLoss[2].item<
float>();
441 gapCustomLosses[2].emplace_back(trackingLossF);
442 iterationCustomLosses[2].emplace_back(trackingLossF);
// Aggregates the iteration's total loss: weighted reconstruction loss plus
// the mean of each enabled custom loss over the iteration, each scaled by its
// weight (metric/clustering also by the dynamic weight). Also records each
// component into iterationCustomLoss for reporting.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing; verify against the original source.
446float ttk::MergeTreeAutoencoder::computeIterationTotalLoss(
448 std::vector<std::vector<float>> &iterationCustomLosses,
449 std::vector<float> &iterationCustomLoss) {
450 iterationCustomLoss.emplace_back(iterationLoss);
451 float iterationTotalLoss = reconstructionLossWeight_ * iterationLoss;
453 float iterationMetricLoss = 0;
454 if(metricLossWeight_ != 0) {
 // Mean of the metric-loss samples collected during the iteration.
456 = torch::tensor(iterationCustomLosses[0]).mean().item<
float>();
459 * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_)
460 * iterationMetricLoss;
462 iterationCustomLoss.emplace_back(iterationMetricLoss);
464 float iterationClusteringLoss = 0;
465 if(clusteringLossWeight_ != 0) {
466 iterationClusteringLoss
467 = torch::tensor(iterationCustomLosses[1]).mean().item<
float>();
469 += clusteringLossWeight_
470 * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_)
471 * iterationClusteringLoss;
473 iterationCustomLoss.emplace_back(iterationClusteringLoss);
475 float iterationTrackingLoss = 0;
476 if(trackingLossWeight_ != 0) {
477 iterationTrackingLoss
478 = torch::tensor(iterationCustomLosses[2]).mean().item<
float>();
479 iterationTotalLoss += trackingLossWeight_ * iterationTrackingLoss;
481 iterationCustomLoss.emplace_back(iterationTrackingLoss);
482 return iterationTotalLoss;
// Prints each stored loss component (reconstruction, metric, clustering,
// tracking) with the given prefix, only for the loss terms that are enabled.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing; verify against the original source.
485void ttk::MergeTreeAutoencoder::printCustomLosses(
486 std::vector<float> &customLoss,
487 std::stringstream &prefix,
491 std::stringstream ssBestLoss;
 // Reconstruction loss is printed only when at least one custom loss is
 // active (otherwise it is presumably reported elsewhere — TODO confirm).
492 if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0
493 or trackingLossWeight_ != 0) {
495 ssBestLoss <<
"- Rec. " << prefix.str() <<
"loss = " << customLoss[0];
496 printMsg(ssBestLoss.str(), priority);
498 if(metricLossWeight_ != 0) {
500 ssBestLoss <<
"- Metric " << prefix.str() <<
"loss = " << customLoss[1];
501 printMsg(ssBestLoss.str(), priority);
503 if(clusteringLossWeight_ != 0) {
505 ssBestLoss <<
"- Clust. " << prefix.str() <<
"loss = " << customLoss[2];
506 printMsg(ssBestLoss.str(), priority);
508 if(trackingLossWeight_ != 0) {
510 ssBestLoss <<
"- Track. " << prefix.str() <<
"loss = " << customLoss[3];
511 printMsg(ssBestLoss.str(), priority);
// Prints the losses accumulated over the last "gap" of iterations: the
// reconstruction loss plus the mean of each enabled custom loss, clearing
// each gap accumulator after reading it.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing; verify against the original source.
515void ttk::MergeTreeAutoencoder::printGapLoss(
516 float loss, std::vector<std::vector<float>> &gapCustomLosses) {
517 std::stringstream ss;
518 ss <<
"Rec. loss = " << loss;
520 if(metricLossWeight_ != 0) {
521 float metricLoss = torch::tensor(gapCustomLosses[0]).mean().item<
float>();
522 gapCustomLosses[0].clear();
524 ss <<
"Metric loss = " << metricLoss;
527 if(clusteringLossWeight_ != 0) {
529 = torch::tensor(gapCustomLosses[1]).mean().item<
float>();
530 gapCustomLosses[1].clear();
532 ss <<
"Clust. loss = " << clusteringLoss;
535 if(trackingLossWeight_ != 0) {
536 float trackingLoss = torch::tensor(gapCustomLosses[2]).mean().item<
float>();
537 gapCustomLosses[2].clear();
539 ss <<
"Track. loss = " << trackingLoss;
// Updates the running minimum reconstruction loss (baseLoss) and, when the
// dynamic-weight option is on, returns a weight derived from it (returned
// expression not visible in this extract — TODO confirm).
// NOTE(review): extraction-garbled — interior/return lines missing; verify
// against the original source.
547double ttk::MergeTreeAutoencoder::getCustomLossDynamicWeight(
double recLoss,
549 baseLoss = std::min(recLoss, baseLoss);
550 if(customLossDynamicWeight_)
// Metric loss: compares the pairwise distances in the latent space (either a
// differentiable tree-distance matrix, or squared euclidean distances between
// latent coefficients) against the precomputed input-space distance matrix,
// via per-pair MSE; averages and optionally normalizes by the max pair loss.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing (e.g. `div`, loss declaration); verify against
// the original source.
556void ttk::MergeTreeAutoencoder::computeMetricLoss(
557 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
558 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
559 std::vector<std::vector<torch::Tensor>> alphas,
560 std::vector<std::vector<float>> &baseDistanceMatrix,
561 std::vector<unsigned int> &indexes,
562 torch::Tensor &metricLoss) {
563 auto layerIndex = getLatentLayerIndex();
564 std::vector<std::vector<torch::Tensor>> losses(
565 layersOuts.size(), std::vector<torch::Tensor>(layersOuts.size()));
 // Collect pointers to the latent-layer outputs of each selected tree.
567 std::vector<mtu::TorchMergeTree<float> *> trees, trees2;
568 for(
unsigned int ind = 0; ind < indexes.size(); ++ind) {
569 unsigned int i = indexes[ind];
570 trees.emplace_back(&(layersOuts[i][layerIndex]));
572 trees2.emplace_back(&(layersOuts2[i][layerIndex]));
 // Two modes: distances between reconstructed trees (customLossSpace_) or
 // distances between latent coefficient vectors.
575 std::vector<std::vector<torch::Tensor>> outDistMat;
576 torch::Tensor coefDistMat;
577 if(customLossSpace_) {
578 getDifferentiableDistanceMatrix(trees, trees2, outDistMat);
580 std::vector<std::vector<torch::Tensor>> scaledAlphas;
581 createScaledAlphas(alphas, scaledAlphas);
582 torch::Tensor latentAlphas;
583 getAlphasTensor(scaledAlphas, indexes, layerIndex, latentAlphas);
584 if(customLossActivate_)
585 latentAlphas = activation(latentAlphas);
586 coefDistMat = torch::cdist(latentAlphas, latentAlphas).pow(2);
589 torch::Tensor maxLoss = torch::tensor(0);
590 metricLoss = torch::tensor(0);
 // Accumulate MSE between base distance and latent distance per pair.
592 for(
unsigned int ind = 0; ind < indexes.size(); ++ind) {
593 unsigned int i = indexes[ind];
594 for(
unsigned int ind2 = ind + 1; ind2 < indexes.size(); ++ind2) {
595 unsigned int j = indexes[ind2];
597 torch::Tensor toCompare
598 = (customLossSpace_ ? outDistMat[i][j] : coefDistMat[ind][ind2]);
599 loss = torch::nn::MSELoss()(
600 torch::tensor(baseDistanceMatrix[i][j]), toCompare);
601 metricLoss = metricLoss + loss;
602 maxLoss = torch::max(loss, maxLoss);
606 metricLoss = metricLoss / torch::tensor(div);
607 if(normalizeMetricLoss_)
608 metricLoss /= maxLoss;
// Clustering loss: soft-assigns each latent coefficient vector to the learned
// centroids via a temperature-scaled softmax over negative distances, then
// measures KL divergence against the one-hot ground-truth cluster assignment.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing; verify against the original source.
611void ttk::MergeTreeAutoencoder::computeClusteringLoss(
612 std::vector<std::vector<torch::Tensor>> &alphas,
613 std::vector<unsigned int> &indexes,
614 torch::Tensor &clusteringLoss,
615 torch::Tensor &asgn) {
617 unsigned int layerIndex = getLatentLayerIndex();
618 torch::Tensor latentAlphas;
619 getAlphasTensor(alphas, indexes, layerIndex, latentAlphas);
620 if(customLossActivate_)
621 latentAlphas = activation(latentAlphas);
 // Stack centroids row-wise for the pairwise distance computation.
622 torch::Tensor centroids = latentCentroids_[0].transpose(0, 1);
623 for(
unsigned int i = 1; i < latentCentroids_.size(); ++i)
624 centroids = torch::cat({centroids, latentCentroids_[i].transpose(0, 1)});
625 torch::Tensor dist = torch::cdist(latentAlphas, centroids);
 // Soft assignment: softmax of -temperature * distance per row.
628 dist = dist * -clusteringLossTemp_;
629 asgn = torch::nn::Softmax(1)(dist);
 // Ground truth: one-hot encoding of the stored cluster labels.
630 std::vector<float> clusterAsgn;
631 for(
unsigned int ind = 0; ind < indexes.size(); ++ind) {
632 clusterAsgn.emplace_back(clusterAsgn_[indexes[ind]]);
634 torch::Tensor realAsgn = torch::tensor(clusterAsgn).to(torch::kInt64);
636 = torch::nn::functional::one_hot(realAsgn, asgn.sizes()[1]).to(torch::kF32);
 // NOTE(review): KLDivLoss expects log-probabilities as input; `asgn` here
 // is a plain softmax — confirm against the original source whether a log
 // is applied in the missing lines.
639 clusteringLoss = torch::nn::KLDivLoss(
640 torch::nn::KLDivLossOptions().reduction(torch::kBatchMean))(asgn, realAsgn);
// Tracking loss: sums the differentiable distances between consecutive layer
// origins (layer l-1's origin-prime vs layer l's origin-prime, starting from
// layer 0's origin), over the encoder or, optionally, the whole network.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// some interior lines missing (tree1/tree2 declarations); verify against the
// original source.
643void ttk::MergeTreeAutoencoder::computeTrackingLoss(
644 torch::Tensor &trackingLoss) {
645 unsigned int latentLayerIndex = getLatentLayerIndex() + 1;
646 auto endLayer = (trackingLossDecoding_ ? noLayers_ : latentLayerIndex);
647 std::vector<torch::Tensor> losses(endLayer);
648#ifdef TTK_ENABLE_OPENMP
649#pragma omp parallel for schedule(dynamic) \
650 num_threads(this->threadNumber_) if(parallelize_)
652 for(
unsigned int l = 0; l < endLayer; ++l) {
654 = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime());
656 = (l == 0 ? layers_[0].getOriginPrime() : layers_[l].getOriginPrime());
657 torch::Tensor tensorDist;
658 bool isCalled =
true, doSqrt =
false;
659 getDifferentiableDistance(tree1, tree2, tensorDist, isCalled, doSqrt);
660 losses[l] = tensorDist;
 // Reduce per-layer distances into a single scalar loss.
662 trackingLoss = torch::tensor(0, torch::kFloat32);
663 for(
unsigned int i = 0; i < losses.size(); ++i)
664 trackingLoss += losses[i];
// Reconstructs merge trees from user-supplied latent coefficients
// (customAlphas_): decodes each coefficient vector through the decoder
// layers, either guided by a least-squares fit to the training alphas or by
// multiple random initializations keeping the best-distance result, then
// matches each reconstruction to the layer-0 origin and post-processes it.
// NOTE(review): extraction-garbled — original line numbers fused into code,
// many interior lines missing (early return, distances, loop closures);
// verify against the original source.
670void ttk::MergeTreeAutoencoder::createCustomRecs() {
671 if(customAlphas_.empty())
 // When training alphas exist, build one (dim x noTrees) matrix per layer.
674 bool initByTreesAlphas = not allAlphas_.empty();
675 std::vector<torch::Tensor> allTreesAlphas;
676 if(initByTreesAlphas) {
677 allTreesAlphas.resize(allAlphas_[0].size());
678 for(
unsigned int l = 0; l < allTreesAlphas.size(); ++l) {
679 allTreesAlphas[l] = allAlphas_[0][l].reshape({-1, 1});
680 for(
unsigned int i = 1; i < allAlphas_.size(); ++i)
682 = torch::cat({allTreesAlphas[l], allAlphas_[i][l]}, 1);
683 allTreesAlphas[l] = allTreesAlphas[l].transpose(0, 1);
687 unsigned int latLayer = getLatentLayerIndex();
688 customRecs_.resize(customAlphas_.size());
689#ifdef TTK_ENABLE_OPENMP
690#pragma omp parallel for schedule(dynamic) \
691 num_threads(this->threadNumber_) if(parallelize_)
693 for(
unsigned int i = 0; i < customAlphas_.size(); ++i) {
694 torch::Tensor alphas = torch::tensor(customAlphas_[i]).reshape({-1, 1});
 // Least-squares weights expressing the custom alphas in terms of the
 // training alphas ("gelsd" = SVD-based LAPACK driver).
696 torch::Tensor alphasWeight;
697 if(initByTreesAlphas) {
698 auto driver =
"gelsd";
699 alphasWeight = std::get<0>(torch::linalg_lstsq(
700 allTreesAlphas[latLayer].transpose(0, 1),
701 alphas, c10::nullopt, driver))
 // Decode from the latent layer through the remaining layers.
706 std::vector<mtu::TorchMergeTree<float>> outs, outs2;
707 auto noOuts = noLayers_ - latLayer;
709 outs2.resize(noOuts);
710 mtu::TorchMergeTree<float> out, out2;
711 layers_[latLayer].outputBasisReconstruction(alphas, outs[0], outs2[0]);
714 for(
unsigned int l = latLayer + 1; l < noLayers_; ++l) {
 // Multiple random inits (32) unless guided by training alphas (1),
 // normalized by the largest norm among them.
715 unsigned int noIter = (initByTreesAlphas ? 1 : 32);
716 std::vector<torch::Tensor> allAlphasInit(noIter);
717 torch::Tensor maxNorm;
718 for(
unsigned int j = 0; j < allAlphasInit.size(); ++j) {
720 = torch::randn({layers_[l].getVSTensor().sizes()[1], 1});
721 auto norm = torch::linalg_vector_norm(
722 allAlphasInit[j], 2, 0,
false, c10::nullopt);
723 if(j == 0 or maxNorm.item<
float>() < norm.item<
float>())
726 for(
unsigned int j = 0; j < allAlphasInit.size(); ++j)
727 allAlphasInit[j] /= maxNorm;
 // Keep the forward result with the smallest distance.
728 float bestDistance = std::numeric_limits<float>::max();
729 auto outIndex = l - latLayer;
730 mtu::TorchMergeTree<float> outToUse;
731 for(
unsigned int j = 0; j < noIter; ++j) {
732 torch::Tensor alphasInit, dataAlphas;
733 if(initByTreesAlphas) {
735 = torch::matmul(alphasWeight, allTreesAlphas[l]).transpose(0, 1);
737 alphasInit = allAlphasInit[j];
740 layers_[l].forward(outs[outIndex - 1], outs2[outIndex - 1], k,
741 alphasInit, outToUse, outs2[outIndex], dataAlphas,
743 if(distance < bestDistance) {
745 mtu::copyTorchMergeTree<float>(
746 outToUse, (l != noLayers_ - 1 ? outs[outIndex] : customRecs_[i]));
 // Match every custom reconstruction against the layer-0 origin tree.
752 customMatchings_.resize(customRecs_.size());
753#ifdef TTK_ENABLE_OPENMP
754#pragma omp parallel for schedule(dynamic) \
755 num_threads(this->threadNumber_) if(parallelize_)
757 for(
unsigned int i = 0; i < customRecs_.size(); ++i) {
758 bool isCalled =
true;
760 computeOneDistance<float>(layers_[0].getOrigin().mTree,
761 customRecs_[i].mTree, customMatchings_[i],
762 distance, isCalled, useDoubleInput_);
 // Post-process reconstructions (and matchings, unless persistence
 // diagrams) on a post-processed copy of the origin.
765 mtu::TorchMergeTree<float> originCopy;
766 mtu::copyTorchMergeTree<float>(layers_[0].getOrigin(), originCopy);
767 postprocessingPipeline<float>(&(originCopy.mTree.tree));
768 for(
unsigned int i = 0; i < customRecs_.size(); ++i) {
769 fixTreePrecisionScalars(customRecs_[i].mTree);
770 postprocessingPipeline<float>(&(customRecs_[i].mTree.tree));
771 if(not isPersistenceDiagram_) {
772 convertBranchDecompositionMatching<float>(&(originCopy.mTree.tree),
773 &(customRecs_[i].mTree.tree),
774 customMatchings_[i]);
// Returns the index of the latent (bottleneck) layer, computed from the
// total layer count (return/adjustment lines not visible in this extract).
// NOTE(review): extraction-garbled — trailing lines missing; verify against
// the original source.
782unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() {
783 unsigned int idx = noLayers_ / 2 - 1;
// Copies the latent centroids between the working and "best" sets: get=true
// saves latentCentroids_ into bestLatentCentroids_, get=false restores them.
// NOTE(review): extraction-garbled — closing lines missing; verify against
// the original source.
789void ttk::MergeTreeAutoencoder::copyCustomParams(
bool get) {
790 auto &srcLatentCentroids = (get ? latentCentroids_ : bestLatentCentroids_);
791 auto &dstLatentCentroids = (!get ? latentCentroids_ : bestLatentCentroids_);
792 dstLatentCentroids.resize(srcLatentCentroids.size());
793 for(
unsigned int i = 0; i < dstLatentCentroids.size(); ++i)
794 mtu::copyTensor(srcLatentCentroids[i], dstLatentCentroids[i]);
// End-of-execution hook: computes tracking information past the latent layer
// and the correlation matrix at the latent layer.
// NOTE(review): extraction-garbled — parameter list and surrounding lines
// missing; verify against the original source.
800void ttk::MergeTreeAutoencoder::executeEndFunction(
804 computeTrackingInformation(getLatentLayerIndex() + 1);
806 computeCorrelationMatrix(trees, getLatentLayerIndex());
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
void setDebugMsgPrefix(const std::string &prefix)
T distance(const T *p0, const T *p1, const int &dimension=3)
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)