TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoder.cpp
Go to the documentation of this file.
2#include <cmath>
3
4#ifdef TTK_ENABLE_TORCH
5using namespace torch::indexing;
6#endif
7
9 // inherited from Debug: prefix will be printed at the beginning of every msg
10 this->setDebugMsgPrefix("MergeTreeAutoencoder");
11}
12
13#ifdef TTK_ENABLE_TORCH
14void ttk::MergeTreeAutoencoder::initClusteringLossParameters() {
15 unsigned int l = getLatentLayerIndex();
16 unsigned int noCentroids
17 = std::set<unsigned int>(clusterAsgn_.begin(), clusterAsgn_.end()).size();
18 latentCentroids_.resize(noCentroids);
19 for(unsigned int c = 0; c < noCentroids; ++c) {
20 unsigned int firstIndex = std::numeric_limits<unsigned int>::max();
21 for(unsigned int i = 0; i < clusterAsgn_.size(); ++i) {
22 if(clusterAsgn_[i] == c) {
23 firstIndex = i;
24 break;
25 }
26 }
27 if(firstIndex >= allAlphas_.size()) {
28 printWrn("no data found for cluster " + std::to_string(c));
29 // TODO init random centroid
30 }
31 latentCentroids_[c] = allAlphas_[firstIndex][l].detach().clone();
32 float noData = 1;
33 for(unsigned int i = 0; i < allAlphas_.size(); ++i) {
34 if(i == firstIndex)
35 continue;
36 if(clusterAsgn_[i] == c) {
37 latentCentroids_[c] += allAlphas_[i][l];
38 ++noData;
39 }
40 }
41 latentCentroids_[c] /= torch::tensor(noData);
42 latentCentroids_[c] = latentCentroids_[c].detach();
43 latentCentroids_[c].requires_grad_(true);
44 }
45}
46
47bool ttk::MergeTreeAutoencoder::initResetOutputBasis(
48 unsigned int l,
49 unsigned int layerNoAxes,
50 double layerOriginPrimeSizePercent,
51 std::vector<mtu::TorchMergeTree<float>> &trees,
52 std::vector<mtu::TorchMergeTree<float>> &trees2,
53 std::vector<bool> &isTrain) {
54 printMsg("Reset output basis", debug::Priority::DETAIL);
55 if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
56 initOutputBasisSpecialCase(l, layerNoAxes, trees, trees2);
57 } else if(l < (unsigned int)(noLayers_ / 2)) {
58 initOutputBasis(l, layerOriginPrimeSizePercent, trees, trees2, isTrain);
59 } else {
60 printErr("recs[i].mTree.tree.getRealNumberOfNodes() == 0");
61 std::stringstream ssT;
62 ssT << "layer " << l;
63 printWrn(ssT.str());
64 return true;
65 }
66 return false;
67}
68
69void ttk::MergeTreeAutoencoder::initOutputBasisSpecialCase(
70 unsigned int l,
71 unsigned int layerNoAxes,
72 std::vector<mtu::TorchMergeTree<float>> &trees,
73 std::vector<mtu::TorchMergeTree<float>> &trees2) {
74 // - Compute Origin
75 printMsg("Compute output basis origin", debug::Priority::DETAIL);
76 layers_[l].setOriginPrime(layers_[0].getOrigin());
77 if(useDoubleInput_)
78 layers_[l].setOrigin2Prime(layers_[0].getOrigin2());
79 // - Compute vectors
80 printMsg("Compute output basis vectors", debug::Priority::DETAIL);
81 if(layerNoAxes != layers_[0].getVSTensor().sizes()[1]) {
82 // TODO is there a way to avoid copy of merge trees?
83 std::vector<ftm::MergeTree<float>> treesToUse, trees2ToUse;
84 for(unsigned int i = 0; i < trees.size(); ++i) {
85 treesToUse.emplace_back(trees[i].mTree);
86 if(useDoubleInput_)
87 trees2ToUse.emplace_back(trees2[i].mTree);
88 }
89 std::vector<torch::Tensor> allAlphasInitT(trees.size());
90 layers_[l].initInputBasisVectors(
91 trees, trees2, treesToUse, trees2ToUse, layerNoAxes, allAlphasInitT,
92 inputToBaryDistances_L0_, baryMatchings_L0_, baryMatchings2_L0_, false);
93 } else {
94 layers_[l].setVSPrimeTensor(layers_[0].getVSTensor());
95 if(useDoubleInput_)
96 layers_[l].setVS2PrimeTensor(layers_[0].getVS2Tensor());
97 }
98}
99
/// Initializes the parameters (input/output bases and per-tree coefficients)
/// of every layer of the network, then optionally evaluates the initial
/// total loss.
/// @param trees input merge trees (torch representation).
/// @param trees2 optional second input (used when useDoubleInput_ is set).
/// @param isTrain per-tree flag: part of the training set or not.
/// @param computeError when true, run a forward pass and return the error.
/// @return the initial error, or float max to signal that the initialization
///   failed and should be restarted by the caller.
float ttk::MergeTreeAutoencoder::initParameters(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<bool> &isTrain,
  bool computeError) {
  // ----- Init variables
  // noLayers_ = number of encoder layers + number of decoder layers + the
  // latent layer + the output layer
  noLayers_ = encoderNoLayers_ * 2 + 1 + 1;
  if(encoderNoLayers_ <= -1)
    noLayers_ = 1;
  std::vector<double> layersOriginPrimeSizePercent(noLayers_);
  std::vector<unsigned int> layersNoAxes(noLayers_);
  if(noLayers_ <= 2) {
    // Degenerate architectures: a single layer (latent only) or latent +
    // output layer.
    layersNoAxes[0] = numberOfAxes_;
    layersOriginPrimeSizePercent[0] = latentSpaceOriginPrimeSizePercent_;
    if(noLayers_ == 2) {
      layersNoAxes[1] = inputNumberOfAxes_;
      layersOriginPrimeSizePercent[1] = barycenterSizeLimitPercent_;
    }
  } else {
    // Linearly interpolate the number of axes and the origin size from the
    // input layer to the latent layer; decoder layers mirror the encoder.
    for(unsigned int l = 0; l < noLayers_ / 2; ++l) {
      double alpha = (double)(l) / (noLayers_ / 2 - 1);
      unsigned int noAxes
        = (1 - alpha) * inputNumberOfAxes_ + alpha * numberOfAxes_;
      layersNoAxes[l] = noAxes;
      layersNoAxes[noLayers_ - 1 - l] = noAxes;
      double originPrimeSizePercent
        = (1 - alpha) * inputOriginPrimeSizePercent_
          + alpha * latentSpaceOriginPrimeSizePercent_;
      layersOriginPrimeSizePercent[l] = originPrimeSizePercent;
      layersOriginPrimeSizePercent[noLayers_ - 1 - l] = originPrimeSizePercent;
    }
    // The layer right after the latent one may be given an intermediate
    // number of axes (average of its two neighbors).
    if(scaleLayerAfterLatent_)
      layersNoAxes[noLayers_ / 2]
        = (layersNoAxes[noLayers_ / 2 - 1] + layersNoAxes[noLayers_ / 2 + 1])
          / 2.0;
  }

  // ----- Resize parameters
  layers_.resize(noLayers_);
  for(unsigned int l = 0; l < layers_.size(); ++l) {
    // Origin-prime copy-initialization is only enabled for the layers
    // affected by the tracking loss.
    initOriginPrimeValuesByCopy_
      = trackingLossWeight_ != 0
        and l < (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1);
    initOriginPrimeValuesByCopyRandomness_ = trackingLossInitRandomness_;
    passLayerParameters(layers_[l]);
  }

  // ----- Compute parameters of each layer
  bool fullSymmetricAE = fullSymmetricAE_;

  std::vector<mtu::TorchMergeTree<float>> recs, recs2;
  std::vector<std::vector<torch::Tensor>> allAlphasInit(
    trees.size(), std::vector<torch::Tensor>(noLayers_));
  for(unsigned int l = 0; l < noLayers_; ++l) {
    std::stringstream ss;
    ss << "Init Layer " << l;
    // NOTE(review): two source lines were lost by the documentation
    // extraction here — ss is built but no print call is visible; presumably
    // a printMsg(ss.str()) and a separator were present. Confirm against the
    // repository source.

    // --- Init Input Basis
    // Layer 0 consumes the input trees; deeper layers consume the
    // reconstructions of the previous layer (recs/recs2).
    if(l < (unsigned int)(noLayers_ / 2) or not fullSymmetricAE
       or (noLayers_ <= 2 and not fullSymmetricAE)) {
      auto &treesToUse = (l == 0 ? trees : recs);
      auto &trees2ToUse = (l == 0 ? trees2 : recs2);
      initInputBasis(
        l, layersNoAxes[l], treesToUse, trees2ToUse, isTrain, allAlphasInit);
    } else {
      // - Copy output tensors of the opposite layer (full symmetric init)
      printMsg(
        "Copy output tensors of the opposite layer", debug::Priority::DETAIL);
      unsigned int middle = noLayers_ / 2;
      unsigned int l_opp = middle - (l - middle + 1);
      layers_[l].setOrigin(layers_[l_opp].getOriginPrime());
      layers_[l].setVSTensor(layers_[l_opp].getVSPrimeTensor());
      if(trees2.size() != 0) {
        if(fullSymmetricAE) {
          layers_[l].setOrigin2(layers_[l_opp].getOrigin2Prime());
          layers_[l].setVS2Tensor(layers_[l_opp].getVS2PrimeTensor());
        }
      }
      for(unsigned int i = 0; i < trees.size(); ++i)
        allAlphasInit[i][l] = allAlphasInit[i][l_opp];
    }

    // --- Init Output Basis
    if((noLayers_ == 2 and l == 1) or noLayers_ == 1) {
      // -- Special case
      initOutputBasisSpecialCase(l, layersNoAxes[l], trees, trees2);
    } else if(l < (unsigned int)(noLayers_ / 2)) {
      initOutputBasis(
        l, layersOriginPrimeSizePercent[l], trees, trees2, isTrain);
    } else {
      // - Copy input tensors of the opposite layer (symmetric init)
      printMsg(
        "Copy input tensors of the opposite layer", debug::Priority::DETAIL);
      unsigned int middle = noLayers_ / 2;
      unsigned int l_opp = middle - (l - middle + 1);
      layers_[l].setOriginPrime(layers_[l_opp].getOrigin());
      if(trees2.size() != 0)
        layers_[l].setOrigin2Prime(layers_[l_opp].getOrigin2());
      if(l == (unsigned int)(noLayers_) / 2 and scaleLayerAfterLatent_) {
        // The layer after the latent one got its own axis count: its output
        // vectors must be created rather than copied.
        unsigned int dim2
          = (trees2.size() != 0 ? layers_[l].getOrigin2Prime().tensor.sizes()[0]
                                : 0);
        layers_[l].initOutputBasisVectors(
          layers_[l].getOriginPrime().tensor.sizes()[0], dim2);
      } else {
        layers_[l].setVSPrimeTensor(layers_[l_opp].getVSTensor());
        if(trees2.size() != 0)
          layers_[l].setVS2PrimeTensor(layers_[l_opp].getVS2Tensor());
      }
    }

    // --- Get reconstructed
    // Reconstruct the trees of this layer; they feed the next layer's input
    // basis initialization. A fullReset means the whole init must restart.
    bool fullReset = initGetReconstructed(
      l, layersNoAxes[l], layersOriginPrimeSizePercent[l], trees, trees2,
      isTrain, recs, recs2, allAlphasInit);
    if(fullReset)
      return std::numeric_limits<float>::max();
  }
  allAlphas_ = allAlphasInit;

  // Init clustering parameters if needed
  if(clusteringLossWeight_ != 0)
    initClusteringLossParameters();

  // Compute error
  float error = 0.0, recLoss = 0.0;
  if(computeError) {
    printMsg("Compute error", debug::Priority::DETAIL);
    std::vector<unsigned int> indexes(trees.size());
    std::iota(indexes.begin(), indexes.end(), 0);
    // TODO forward only if necessary
    unsigned int k = k_;
    std::vector<std::vector<torch::Tensor>> bestAlphas;
    std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
      layersOuts2;
    std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
      matchings, matchings2;
    bool reset = forwardStep(trees, trees2, indexes, k, allAlphasInit,
                             computeError, recs, recs2, bestAlphas, layersOuts,
                             layersOuts2, matchings, matchings2, recLoss);
    if(reset) {
      printWrn("[initParameters] forwardStep reset");
      return std::numeric_limits<float>::max();
    }
    // Total error = weighted sum of the reconstruction loss and of each
    // enabled custom loss (same weighting scheme as during training).
    error = recLoss * reconstructionLossWeight_;
    if(metricLossWeight_ != 0) {
      torch::Tensor metricLoss;
      computeMetricLoss(layersOuts, layersOuts2, allAlphas_, distanceMatrix_,
                        indexes, metricLoss);
      baseRecLoss_ = std::numeric_limits<double>::max();
      metricLoss *= metricLossWeight_
                    * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
      error += metricLoss.item<float>();
    }
    if(clusteringLossWeight_ != 0) {
      torch::Tensor clusteringLoss, asgn;
      computeClusteringLoss(allAlphas_, indexes, clusteringLoss, asgn);
      baseRecLoss_ = std::numeric_limits<double>::max();
      clusteringLoss *= clusteringLossWeight_
                        * getCustomLossDynamicWeight(recLoss, baseRecLoss_);
      error += clusteringLoss.item<float>();
    }
    if(trackingLossWeight_ != 0) {
      torch::Tensor trackingLoss;
      computeTrackingLoss(trackingLoss);
      trackingLoss *= trackingLossWeight_;
      error += trackingLoss.item<float>();
    }
  }
  return error;
}
275
276// ---------------------------------------------------------------------------
277// --- Backward
278// ---------------------------------------------------------------------------
/// Performs one optimization step: accumulates the gradients of the
/// reconstruction loss and of every enabled custom loss (metric, clustering,
/// tracking), then updates all parameters through the optimizer.
/// @return always false (no reset condition is triggered here).
bool ttk::MergeTreeAutoencoder::backwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings2,
  std::vector<std::vector<torch::Tensor>> &ttkNotUsed(alphas),
  torch::optim::Optimizer &optimizer,
  std::vector<unsigned int> &indexes,
  std::vector<bool> &ttkNotUsed(isTrain),
  std::vector<torch::Tensor> &torchCustomLoss) {
  double totalLoss = 0;
  // The graph must be retained after the reconstruction backward pass when a
  // custom loss will also call backward() afterwards.
  bool retainGraph = (metricLossWeight_ != 0 or clusteringLossWeight_ != 0
                      or trackingLossWeight_ != 0);
  // The reconstruction pass is also needed when only the dynamic weighting
  // of the custom losses requires the current reconstruction loss value.
  if(reconstructionLossWeight_ != 0
     or (customLossDynamicWeight_ and retainGraph)) {
    std::vector<torch::Tensor> outTensors(indexes.size()),
      reorderedTensors(indexes.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      unsigned int i = indexes[ind];
      // Reorder the input tree data to match the output tree given the
      // matching, so both tensors can be compared entry-wise.
      torch::Tensor reorderedTensor;
      dataReorderingGivenMatching(
        outs[i], trees[i], matchings[i], reorderedTensor);
      auto outTensor = outs[i].tensor;
      if(useDoubleInput_) {
        torch::Tensor reorderedTensor2;
        dataReorderingGivenMatching(
          outs2[i], trees2[i], matchings2[i], reorderedTensor2);
        outTensor = torch::cat({outTensor, outs2[i].tensor});
        reorderedTensor = torch::cat({reorderedTensor, reorderedTensor2});
      }
      outTensors[ind] = outTensor;
      reorderedTensors[ind] = reorderedTensor;
    }
    // Backward calls are performed sequentially, outside the parallel
    // region; gradients accumulate over the trees.
    for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
      auto loss = torch::nn::functional::mse_loss(
        outTensors[ind], reorderedTensors[ind]);
      // Same as next loss with a factor of 1 / n where n is the number of nodes
      // in the output
      // auto loss = (outTensors[ind] - reorderedTensors[ind]).pow(2).sum();
      totalLoss += loss.item<float>();
      loss *= reconstructionLossWeight_;
      loss.backward({}, retainGraph);
    }
  }
  // Each custom loss retains the graph only if another custom loss still
  // needs it afterwards.
  if(metricLossWeight_ != 0) {
    bool retainGraphMetricLoss
      = (clusteringLossWeight_ != 0 or trackingLossWeight_ != 0);
    torchCustomLoss[0] *= metricLossWeight_
                          * getCustomLossDynamicWeight(
                            totalLoss / indexes.size(), baseRecLoss2_);
    torchCustomLoss[0].backward({}, retainGraphMetricLoss);
  }
  if(clusteringLossWeight_ != 0) {
    bool retainGraphClusteringLoss = (trackingLossWeight_ != 0);
    torchCustomLoss[1] *= clusteringLossWeight_
                          * getCustomLossDynamicWeight(
                            totalLoss / indexes.size(), baseRecLoss2_);
    torchCustomLoss[1].backward({}, retainGraphClusteringLoss);
  }
  if(trackingLossWeight_ != 0) {
    torchCustomLoss[2] *= trackingLossWeight_;
    torchCustomLoss[2].backward();
  }

  // Sanity pass over all layers before stepping.
  // NOTE(review): behavior inferred from the name checkZeroGrad — confirm.
  for(unsigned int l = 0; l < noLayers_; ++l)
    checkZeroGrad(l);

  optimizer.step();
  optimizer.zero_grad();
  return false;
}
358
359// ---------------------------------------------------------------------------
360// --- Convergence
361// ---------------------------------------------------------------------------
362float ttk::MergeTreeAutoencoder::computeOneLoss(
363 mtu::TorchMergeTree<float> &tree,
364 mtu::TorchMergeTree<float> &out,
365 mtu::TorchMergeTree<float> &tree2,
366 mtu::TorchMergeTree<float> &out2,
367 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
368 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
369 std::vector<torch::Tensor> &ttkNotUsed(alphas),
370 unsigned int ttkNotUsed(treeIndex)) {
371 float loss = 0;
372 bool isCalled = true;
373 float distance;
374 computeOneDistance<float>(
375 out.mTree, tree.mTree, matching, distance, isCalled, useDoubleInput_);
376 if(useDoubleInput_) {
377 float distance2;
378 computeOneDistance<float>(out2.mTree, tree2.mTree, matching2, distance2,
379 isCalled, useDoubleInput_, false);
380 distance = mixDistances<float>(distance, distance2);
381 }
382 loss += distance * distance;
383 return loss;
384}
385
386// ---------------------------------------------------------------------------
387// --- Main Functions
388// ---------------------------------------------------------------------------
389void ttk::MergeTreeAutoencoder::customInit(
390 std::vector<mtu::TorchMergeTree<float>> &torchTrees,
391 std::vector<mtu::TorchMergeTree<float>> &torchTrees2) {
392 baseRecLoss_ = std::numeric_limits<double>::max();
393 baseRecLoss2_ = std::numeric_limits<double>::max();
394 // ----- Init Metric Loss
395 if(metricLossWeight_ != 0)
396 getDistanceMatrix(torchTrees, torchTrees2, distanceMatrix_);
397}
398
399void ttk::MergeTreeAutoencoder::addCustomParameters(
400 std::vector<torch::Tensor> &parameters) {
401 if(clusteringLossWeight_ != 0)
402 for(unsigned int i = 0; i < latentCentroids_.size(); ++i)
403 parameters.emplace_back(latentCentroids_[i]);
404}
405
/// Computes every enabled custom loss and records its scalar value in the
/// per-gap and per-iteration histories.
/// Index contract for all three output vectors: 0 = metric, 1 = clustering,
/// 2 = tracking.
void ttk::MergeTreeAutoencoder::computeCustomLosses(
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
  std::vector<std::vector<torch::Tensor>> &bestAlphas,
  std::vector<unsigned int> &indexes,
  std::vector<bool> &ttkNotUsed(isTrain),
  unsigned int ttkNotUsed(iteration),
  std::vector<std::vector<float>> &gapCustomLosses,
  std::vector<std::vector<float>> &iterationCustomLosses,
  std::vector<torch::Tensor> &torchCustomLoss) {
  // Lazily size the histories (they persist across calls).
  if(gapCustomLosses.empty())
    gapCustomLosses.resize(3);
  if(iterationCustomLosses.empty())
    iterationCustomLosses.resize(3);
  torchCustomLoss.resize(3);
  // - Metric Loss
  if(metricLossWeight_ != 0) {
    computeMetricLoss(layersOuts, layersOuts2, bestAlphas, distanceMatrix_,
                      indexes, torchCustomLoss[0]);
    float metricLossF = torchCustomLoss[0].item<float>();
    gapCustomLosses[0].emplace_back(metricLossF);
    iterationCustomLosses[0].emplace_back(metricLossF);
  }
  // - Clustering Loss
  if(clusteringLossWeight_ != 0) {
    torch::Tensor asgn;
    computeClusteringLoss(bestAlphas, indexes, torchCustomLoss[1], asgn);
    float clusteringLossF = torchCustomLoss[1].item<float>();
    gapCustomLosses[1].emplace_back(clusteringLossF);
    iterationCustomLosses[1].emplace_back(clusteringLossF);
  }
  // - Tracking Loss
  if(trackingLossWeight_ != 0) {
    computeTrackingLoss(torchCustomLoss[2]);
    float trackingLossF = torchCustomLoss[2].item<float>();
    gapCustomLosses[2].emplace_back(trackingLossF);
    iterationCustomLosses[2].emplace_back(trackingLossF);
  }
}
445
446float ttk::MergeTreeAutoencoder::computeIterationTotalLoss(
447 float iterationLoss,
448 std::vector<std::vector<float>> &iterationCustomLosses,
449 std::vector<float> &iterationCustomLoss) {
450 iterationCustomLoss.emplace_back(iterationLoss);
451 float iterationTotalLoss = reconstructionLossWeight_ * iterationLoss;
452 // Metric
453 float iterationMetricLoss = 0;
454 if(metricLossWeight_ != 0) {
455 iterationMetricLoss
456 = torch::tensor(iterationCustomLosses[0]).mean().item<float>();
457 iterationTotalLoss
458 += metricLossWeight_
459 * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_)
460 * iterationMetricLoss;
461 }
462 iterationCustomLoss.emplace_back(iterationMetricLoss);
463 // Clustering
464 float iterationClusteringLoss = 0;
465 if(clusteringLossWeight_ != 0) {
466 iterationClusteringLoss
467 = torch::tensor(iterationCustomLosses[1]).mean().item<float>();
468 iterationTotalLoss
469 += clusteringLossWeight_
470 * getCustomLossDynamicWeight(iterationLoss, baseRecLoss_)
471 * iterationClusteringLoss;
472 }
473 iterationCustomLoss.emplace_back(iterationClusteringLoss);
474 // Tracking
475 float iterationTrackingLoss = 0;
476 if(trackingLossWeight_ != 0) {
477 iterationTrackingLoss
478 = torch::tensor(iterationCustomLosses[2]).mean().item<float>();
479 iterationTotalLoss += trackingLossWeight_ * iterationTrackingLoss;
480 }
481 iterationCustomLoss.emplace_back(iterationTrackingLoss);
482 return iterationTotalLoss;
483}
484
485void ttk::MergeTreeAutoencoder::printCustomLosses(
486 std::vector<float> &customLoss,
487 std::stringstream &prefix,
488 const debug::Priority &priority) {
489 if(priority != debug::Priority::VERBOSE)
490 prefix.str("");
491 std::stringstream ssBestLoss;
492 if(metricLossWeight_ != 0 or clusteringLossWeight_ != 0
493 or trackingLossWeight_ != 0) {
494 ssBestLoss.str("");
495 ssBestLoss << "- Rec. " << prefix.str() << "loss = " << customLoss[0];
496 printMsg(ssBestLoss.str(), priority);
497 }
498 if(metricLossWeight_ != 0) {
499 ssBestLoss.str("");
500 ssBestLoss << "- Metric " << prefix.str() << "loss = " << customLoss[1];
501 printMsg(ssBestLoss.str(), priority);
502 }
503 if(clusteringLossWeight_ != 0) {
504 ssBestLoss.str("");
505 ssBestLoss << "- Clust. " << prefix.str() << "loss = " << customLoss[2];
506 printMsg(ssBestLoss.str(), priority);
507 }
508 if(trackingLossWeight_ != 0) {
509 ssBestLoss.str("");
510 ssBestLoss << "- Track. " << prefix.str() << "loss = " << customLoss[3];
511 printMsg(ssBestLoss.str(), priority);
512 }
513}
514
515void ttk::MergeTreeAutoencoder::printGapLoss(
516 float loss, std::vector<std::vector<float>> &gapCustomLosses) {
517 std::stringstream ss;
518 ss << "Rec. loss = " << loss;
519 printMsg(ss.str());
520 if(metricLossWeight_ != 0) {
521 float metricLoss = torch::tensor(gapCustomLosses[0]).mean().item<float>();
522 gapCustomLosses[0].clear();
523 ss.str("");
524 ss << "Metric loss = " << metricLoss;
525 printMsg(ss.str());
526 }
527 if(clusteringLossWeight_ != 0) {
528 float clusteringLoss
529 = torch::tensor(gapCustomLosses[1]).mean().item<float>();
530 gapCustomLosses[1].clear();
531 ss.str("");
532 ss << "Clust. loss = " << clusteringLoss;
533 printMsg(ss.str());
534 }
535 if(trackingLossWeight_ != 0) {
536 float trackingLoss = torch::tensor(gapCustomLosses[2]).mean().item<float>();
537 gapCustomLosses[2].clear();
538 ss.str("");
539 ss << "Track. loss = " << trackingLoss;
540 printMsg(ss.str());
541 }
542}
543
544// ---------------------------------------------------------------------------
545// --- Custom Losses
546// ---------------------------------------------------------------------------
547double ttk::MergeTreeAutoencoder::getCustomLossDynamicWeight(double recLoss,
548 double &baseLoss) {
549 baseLoss = std::min(recLoss, baseLoss);
550 if(customLossDynamicWeight_)
551 return baseLoss;
552 else
553 return 1.0;
554}
555
556void ttk::MergeTreeAutoencoder::computeMetricLoss(
557 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
558 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
559 std::vector<std::vector<torch::Tensor>> alphas,
560 std::vector<std::vector<float>> &baseDistanceMatrix,
561 std::vector<unsigned int> &indexes,
562 torch::Tensor &metricLoss) {
563 auto layerIndex = getLatentLayerIndex();
564 std::vector<std::vector<torch::Tensor>> losses(
565 layersOuts.size(), std::vector<torch::Tensor>(layersOuts.size()));
566
567 std::vector<mtu::TorchMergeTree<float> *> trees, trees2;
568 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
569 unsigned int i = indexes[ind];
570 trees.emplace_back(&(layersOuts[i][layerIndex]));
571 if(useDoubleInput_)
572 trees2.emplace_back(&(layersOuts2[i][layerIndex]));
573 }
574
575 std::vector<std::vector<torch::Tensor>> outDistMat;
576 torch::Tensor coefDistMat;
577 if(customLossSpace_) {
578 getDifferentiableDistanceMatrix(trees, trees2, outDistMat);
579 } else {
580 std::vector<std::vector<torch::Tensor>> scaledAlphas;
581 createScaledAlphas(alphas, scaledAlphas);
582 torch::Tensor latentAlphas;
583 getAlphasTensor(scaledAlphas, indexes, layerIndex, latentAlphas);
584 if(customLossActivate_)
585 latentAlphas = activation(latentAlphas);
586 coefDistMat = torch::cdist(latentAlphas, latentAlphas).pow(2);
587 }
588
589 torch::Tensor maxLoss = torch::tensor(0);
590 metricLoss = torch::tensor(0);
591 float div = 0;
592 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
593 unsigned int i = indexes[ind];
594 for(unsigned int ind2 = ind + 1; ind2 < indexes.size(); ++ind2) {
595 unsigned int j = indexes[ind2];
596 torch::Tensor loss;
597 torch::Tensor toCompare
598 = (customLossSpace_ ? outDistMat[i][j] : coefDistMat[ind][ind2]);
599 loss = torch::nn::MSELoss()(
600 torch::tensor(baseDistanceMatrix[i][j]), toCompare);
601 metricLoss = metricLoss + loss;
602 maxLoss = torch::max(loss, maxLoss);
603 ++div;
604 }
605 }
606 metricLoss = metricLoss / torch::tensor(div);
607 if(normalizeMetricLoss_)
608 metricLoss /= maxLoss;
609}
610
/// Computes the clustering loss: the KL divergence between the soft
/// assignment of each selected tree to the latent centroids and its one-hot
/// ground-truth cluster label (clusterAsgn_).
void ttk::MergeTreeAutoencoder::computeClusteringLoss(
  std::vector<std::vector<torch::Tensor>> &alphas,
  std::vector<unsigned int> &indexes,
  torch::Tensor &clusteringLoss,
  torch::Tensor &asgn) {
  // Compute distance matrix
  unsigned int layerIndex = getLatentLayerIndex();
  torch::Tensor latentAlphas;
  getAlphasTensor(alphas, indexes, layerIndex, latentAlphas);
  if(customLossActivate_)
    latentAlphas = activation(latentAlphas);
  // Stack the centroids into a single matrix (one centroid per row).
  torch::Tensor centroids = latentCentroids_[0].transpose(0, 1);
  for(unsigned int i = 1; i < latentCentroids_.size(); ++i)
    centroids = torch::cat({centroids, latentCentroids_[i].transpose(0, 1)});
  torch::Tensor dist = torch::cdist(latentAlphas, centroids);

  // Compute softmax and one hot real asgn
  // Closer centroids get higher probability; clusteringLossTemp_ controls
  // the sharpness of the soft assignment.
  dist = dist * -clusteringLossTemp_;
  asgn = torch::nn::Softmax(1)(dist);
  std::vector<float> clusterAsgn;
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    clusterAsgn.emplace_back(clusterAsgn_[indexes[ind]]);
  }
  torch::Tensor realAsgn = torch::tensor(clusterAsgn).to(torch::kInt64);
  realAsgn
    = torch::nn::functional::one_hot(realAsgn, asgn.sizes()[1]).to(torch::kF32);

  // Compute KL div.
  // NOTE(review): torch::nn::KLDivLoss expects its first argument to be
  // log-probabilities, but asgn here is a plain softmax output — confirm
  // whether a log-softmax was intended.
  clusteringLoss = torch::nn::KLDivLoss(
    torch::nn::KLDivLossOptions().reduction(torch::kBatchMean))(asgn, realAsgn);
}
642
/// Computes the tracking loss: the sum over the layers of the differentiable
/// distance between the origins of consecutive layers, keeping basis origins
/// close to each other across the network.
void ttk::MergeTreeAutoencoder::computeTrackingLoss(
  torch::Tensor &trackingLoss) {
  unsigned int latentLayerIndex = getLatentLayerIndex() + 1;
  // Only track through the encoder unless decoding layers are requested too.
  auto endLayer = (trackingLossDecoding_ ? noLayers_ : latentLayerIndex);
  std::vector<torch::Tensor> losses(endLayer);
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int l = 0; l < endLayer; ++l) {
    // Layer 0 compares its input origin with its output origin; deeper
    // layers compare their output origin with the previous layer's.
    auto &tree1
      = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime());
    auto &tree2
      = (l == 0 ? layers_[0].getOriginPrime() : layers_[l].getOriginPrime());
    torch::Tensor tensorDist;
    bool isCalled = true, doSqrt = false;
    getDifferentiableDistance(tree1, tree2, tensorDist, isCalled, doSqrt);
    losses[l] = tensorDist;
  }
  // Sum the per-layer distances sequentially, outside the parallel region.
  trackingLoss = torch::tensor(0, torch::kFloat32);
  for(unsigned int i = 0; i < losses.size(); ++i)
    trackingLoss += losses[i];
}
666
667// ---------------------------------------------------------------------------
668// --- End Functions
669// ---------------------------------------------------------------------------
/// Reconstructs a merge tree for each user-provided set of latent
/// coordinates (customAlphas_), decodes it through the remaining layers, and
/// matches each reconstruction against the input basis origin.
/// Results go to customRecs_ and customMatchings_.
void ttk::MergeTreeAutoencoder::createCustomRecs() {
  if(customAlphas_.empty())
    return;

  // When per-tree coefficients are available, stack them per layer into a
  // (noTrees x dim) matrix; they are used below to derive the coefficients
  // of deeper layers via least squares instead of random restarts.
  bool initByTreesAlphas = not allAlphas_.empty();
  std::vector<torch::Tensor> allTreesAlphas;
  if(initByTreesAlphas) {
    allTreesAlphas.resize(allAlphas_[0].size());
    for(unsigned int l = 0; l < allTreesAlphas.size(); ++l) {
      allTreesAlphas[l] = allAlphas_[0][l].reshape({-1, 1});
      for(unsigned int i = 1; i < allAlphas_.size(); ++i)
        allTreesAlphas[l]
          = torch::cat({allTreesAlphas[l], allAlphas_[i][l]}, 1);
      allTreesAlphas[l] = allTreesAlphas[l].transpose(0, 1);
    }
  }

  unsigned int latLayer = getLatentLayerIndex();
  customRecs_.resize(customAlphas_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customAlphas_.size(); ++i) {
    torch::Tensor alphas = torch::tensor(customAlphas_[i]).reshape({-1, 1});

    // Express the custom coordinates as a combination of the trees' latent
    // coordinates (least squares solve).
    torch::Tensor alphasWeight;
    if(initByTreesAlphas) {
      auto driver = "gelsd";
      alphasWeight = std::get<0>(torch::linalg_lstsq(
                                   allTreesAlphas[latLayer].transpose(0, 1),
                                   alphas, c10::nullopt, driver))
                       .transpose(0, 1);
    }

    // Reconst latent
    std::vector<mtu::TorchMergeTree<float>> outs, outs2;
    auto noOuts = noLayers_ - latLayer;
    outs.resize(noOuts);
    outs2.resize(noOuts);
    mtu::TorchMergeTree<float> out, out2;
    layers_[latLayer].outputBasisReconstruction(alphas, outs[0], outs2[0]);
    // Decoding
    unsigned int k = 32;
    for(unsigned int l = latLayer + 1; l < noLayers_; ++l) {
      // Without per-tree coefficients, try several random normalized
      // initializations and keep the best one.
      unsigned int noIter = (initByTreesAlphas ? 1 : 32);
      std::vector<torch::Tensor> allAlphasInit(noIter);
      torch::Tensor maxNorm;
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j) {
        allAlphasInit[j]
          = torch::randn({layers_[l].getVSTensor().sizes()[1], 1});
        auto norm = torch::linalg_vector_norm(
          allAlphasInit[j], 2, 0, false, c10::nullopt);
        if(j == 0 or maxNorm.item<float>() < norm.item<float>())
          maxNorm = norm;
      }
      // Normalize all candidates by the largest norm among them.
      for(unsigned int j = 0; j < allAlphasInit.size(); ++j)
        allAlphasInit[j] /= maxNorm;
      float bestDistance = std::numeric_limits<float>::max();
      auto outIndex = l - latLayer;
      mtu::TorchMergeTree<float> outToUse;
      for(unsigned int j = 0; j < noIter; ++j) {
        torch::Tensor alphasInit, dataAlphas;
        if(initByTreesAlphas) {
          // Propagate the least-squares weights to this layer's coordinates.
          alphasInit
            = torch::matmul(alphasWeight, allTreesAlphas[l]).transpose(0, 1);
        } else {
          alphasInit = allAlphasInit[j];
        }
        float distance;
        layers_[l].forward(outs[outIndex - 1], outs2[outIndex - 1], k,
                           alphasInit, outToUse, outs2[outIndex], dataAlphas,
                           distance);
        // Keep the best candidate; the last layer writes the final rec.
        if(distance < bestDistance) {
          bestDistance = distance;
          mtu::copyTorchMergeTree<float>(
            outToUse, (l != noLayers_ - 1 ? outs[outIndex] : customRecs_[i]));
        }
      }
    }
  }

  // Match each custom reconstruction to the input basis origin.
  customMatchings_.resize(customRecs_.size());
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_)
#endif
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    bool isCalled = true;
    float distance;
    computeOneDistance<float>(layers_[0].getOrigin().mTree,
                              customRecs_[i].mTree, customMatchings_[i],
                              distance, isCalled, useDoubleInput_);
  }

  // Post-process a copy of the origin (the member itself is left untouched)
  // and every reconstruction; convert matchings back from the branch
  // decomposition when working with merge trees.
  mtu::TorchMergeTree<float> originCopy;
  mtu::copyTorchMergeTree<float>(layers_[0].getOrigin(), originCopy);
  postprocessingPipeline<float>(&(originCopy.mTree.tree));
  for(unsigned int i = 0; i < customRecs_.size(); ++i) {
    fixTreePrecisionScalars(customRecs_[i].mTree);
    postprocessingPipeline<float>(&(customRecs_[i].mTree.tree));
    if(not isPersistenceDiagram_) {
      convertBranchDecompositionMatching<float>(&(originCopy.mTree.tree),
                                                &(customRecs_[i].mTree.tree),
                                                customMatchings_[i]);
    }
  }
}
778
779// ---------------------------------------------------------------------------
780// --- Utils
781// ---------------------------------------------------------------------------
782unsigned int ttk::MergeTreeAutoencoder::getLatentLayerIndex() {
783 unsigned int idx = noLayers_ / 2 - 1;
784 if(idx > noLayers_) // unsigned negativeness
785 idx = 0;
786 return idx;
787}
788
789void ttk::MergeTreeAutoencoder::copyCustomParams(bool get) {
790 auto &srcLatentCentroids = (get ? latentCentroids_ : bestLatentCentroids_);
791 auto &dstLatentCentroids = (!get ? latentCentroids_ : bestLatentCentroids_);
792 dstLatentCentroids.resize(srcLatentCentroids.size());
793 for(unsigned int i = 0; i < dstLatentCentroids.size(); ++i)
794 mtu::copyTensor(srcLatentCentroids[i], dstLatentCentroids[i]);
795}
796
797// ---------------------------------------------------------------------------
798// --- Main Functions
799// ---------------------------------------------------------------------------
800void ttk::MergeTreeAutoencoder::executeEndFunction(
801 std::vector<ftm::MergeTree<float>> &trees,
802 std::vector<ftm::MergeTree<float>> &ttkNotUsed(trees2)) {
803 // Tracking
804 computeTrackingInformation(getLatentLayerIndex() + 1);
805 // Correlation
806 computeCorrelationMatrix(trees, getLatentLayerIndex());
807 // Custom recs
808 createCustomRecs();
809}
810#endif
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
T distance(const T *p0, const T *p1, const int &dimension=3)
Definition Geometry.cpp:362
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)