TTK
Loading...
Searching...
No Matches
MergeTreeNeuralNetwork.cpp
Go to the documentation of this file.
2#include <cmath>
3
4#ifdef TTK_ENABLE_TORCH
5using namespace torch::indexing;
6#endif
7
9 // inherited from Debug: prefix will be printed at the beginning of every msg
10 this->setDebugMsgPrefix("MergeTreeNeuralNetwork");
11}
12
13#ifdef TTK_ENABLE_TORCH
14// -----------------------------------------------------------------------
15// --- Init
16// -----------------------------------------------------------------------
17void ttk::MergeTreeNeuralNetwork::initInputBasis(
18 unsigned int l,
19 unsigned int layerNoAxes,
20 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
21 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
22 std::vector<bool> &ttkNotUsed(isTrain),
23 std::vector<std::vector<torch::Tensor>> &allAlphasInit) {
24 // TODO is there a way to avoid copy of merge trees?
25 std::vector<ftm::MergeTree<float>> trees, trees2;
26 for(unsigned int i = 0; i < tmTrees.size(); ++i) {
27 trees.emplace_back(tmTrees[i].mTree);
28 if(useDoubleInput_)
29 trees2.emplace_back(tmTrees2[i].mTree);
30 }
31
32 // - Compute origin
33 printMsg("Compute origin...", debug::Priority::DETAIL);
34 Timer t_origin;
35 std::vector<double> inputToBaryDistances;
36 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
37 baryMatchings, baryMatchings2;
38 if(l != 0 or not layers_[0].getOrigin().tensor.defined()) {
39 double sizeLimit = (l == 0 ? barycenterSizeLimitPercent_ : 0);
40 unsigned int maxNoPairs
41 = (l == 0 ? 0 : layers_[l - 1].getOriginPrime().tensor.sizes()[0] / 2);
42 unsigned int maxNoPairs2
43 = (l == 0 or not useDoubleInput_
44 ? 0
45 : layers_[l - 1].getOrigin2Prime().tensor.sizes()[0] / 2);
46 layers_[l].initInputBasisOrigin(trees, trees2, sizeLimit, maxNoPairs,
47 maxNoPairs2, inputToBaryDistances,
48 baryMatchings, baryMatchings2);
49 if(l == 0) {
50 baryMatchings_L0_ = baryMatchings;
51 baryMatchings2_L0_ = baryMatchings2;
52 inputToBaryDistances_L0_ = inputToBaryDistances;
53 }
54 } else {
55 baryMatchings = baryMatchings_L0_;
56 baryMatchings2 = baryMatchings2_L0_;
57 inputToBaryDistances = inputToBaryDistances_L0_;
58 }
59 printMsg("Compute origin time", 1, t_origin.getElapsedTime(), threadNumber_,
61
62 // - Compute vectors
63 printMsg("Compute vectors...", debug::Priority::DETAIL);
64 Timer t_vectors;
65 std::vector<torch::Tensor> allAlphasInitT(tmTrees.size());
66 layers_[l].initInputBasisVectors(
67 tmTrees, tmTrees2, trees, trees2, layerNoAxes, allAlphasInitT,
68 inputToBaryDistances, baryMatchings, baryMatchings2);
69 for(unsigned int i = 0; i < allAlphasInitT.size(); ++i)
70 allAlphasInit[i][l] = allAlphasInitT[i];
71 printMsg("Compute vectors time", 1, t_vectors.getElapsedTime(), threadNumber_,
73}
74
75void ttk::MergeTreeNeuralNetwork::initOutputBasis(
76 unsigned int l,
77 double layerOriginPrimeSizePercent,
78 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
79 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
80 std::vector<bool> &ttkNotUsed(isTrain)) {
81 std::vector<ftm::FTMTree_MT *> ftmTrees(tmTrees.size()),
82 ftmTrees2(tmTrees2.size());
83 for(unsigned int i = 0; i < tmTrees.size(); ++i)
84 ftmTrees[i] = &(tmTrees[i].mTree.tree);
85 for(unsigned int i = 0; i < tmTrees2.size(); ++i)
86 ftmTrees2[i] = &(tmTrees2[i].mTree.tree);
87 auto sizeMetric = getSizeLimitMetric(ftmTrees);
88 auto sizeMetric2 = getSizeLimitMetric(ftmTrees2);
89 auto getDim = [](double _sizeMetric, double _percent) {
90 unsigned int dim = std::max((int)(_sizeMetric * _percent / 100.0), 2) * 2;
91 return dim;
92 };
93
94 unsigned int dim = getDim(sizeMetric, layerOriginPrimeSizePercent);
95 dim = std::min(dim, (unsigned int)layers_[l].getOrigin().tensor.sizes()[0]);
96 unsigned int dim2 = getDim(sizeMetric2, layerOriginPrimeSizePercent);
97 if(useDoubleInput_)
98 dim2
99 = std::min(dim2, (unsigned int)layers_[l].getOrigin2().tensor.sizes()[0]);
100 auto baseTensor = (l == 0 ? layers_[0].getOrigin().tensor
101 : layers_[l - 1].getOriginPrime().tensor);
102 layers_[l].initOutputBasis(dim, dim2, baseTensor);
103}
104
// Reconstructs each input tree through the output basis of layer l. When a
// reconstruction comes out empty, the output basis is re-initialized (via
// initResetOutputBasis) and the scan restarts. Returns true when a full
// reset of the whole initialization is required (or after 100 basis resets).
bool ttk::MergeTreeNeuralNetwork::initGetReconstructed(
  unsigned int l,
  unsigned int layerNoAxes,
  double layerOriginPrimeSizePercent,
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<bool> &isTrain,
  std::vector<mtu::TorchMergeTree<float>> &recs,
  std::vector<mtu::TorchMergeTree<float>> &recs2,
  std::vector<std::vector<torch::Tensor>> &allAlphasInit) {
  printMsg("Get reconstructed", debug::Priority::DETAIL);
  recs.resize(trees.size());
  recs2.resize(trees.size());
  unsigned int i = 0;
  unsigned int noReset = 0;
  while(i < trees.size()) {
    layers_[l].outputBasisReconstruction(
      allAlphasInit[i][l], recs[i], recs2[i], activateOutputInit_);
    // An empty reconstruction means the current output basis is unusable.
    if(recs[i].mTree.tree.getRealNumberOfNodes() == 0) {
      bool fullReset = initResetOutputBasis(
        l, layerNoAxes, layerOriginPrimeSizePercent, trees, trees2, isTrain);
      if(fullReset)
        return true;
      // NOTE(review): i is set to 0 here but the ++i below immediately moves
      // it to 1, so index 0 is not re-reconstructed with the reset basis —
      // confirm whether this off-by-one is intended.
      i = 0;
      ++noReset;
      if(noReset >= 100) {
        // Give up after too many basis resets and request a full reset.
        printWrn("[initParameters] noReset >= 100");
        return true;
      }
    }
    ++i;
  }
  return false;
}
139
// Runs the whole initialization procedure: tries noInit_ random parameter
// initializations, keeps the one with the lowest init error, restores it,
// then enables gradients on every layer and prints a per-layer summary.
void ttk::MergeTreeNeuralNetwork::initStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<bool> &isTrain) {
  layers_.clear();

  float bestError = std::numeric_limits<float>::max();
  std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
    bestVS2PrimeTensor, bestLatentCentroids;
  std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
    bestOrigins2, bestOrigins2Prime;
  std::vector<std::vector<torch::Tensor>> bestAlphasInit;
  for(unsigned int n = 0; n < noInit_; ++n) {
    // Init parameters
    float error = initParameters(trees, trees2, isTrain, (noInit_ != 1));
    // Save best parameters (only needed when several inits are tried)
    if(noInit_ != 1) {
      std::stringstream ss;
      ss << "Init error = " << error;
      printMsg(ss.str());
      if(error < bestError) {
        bestError = error;
        // true: copy FROM the member parameters INTO the best-.. buffers.
        copyParams(bestOrigins, bestOriginsPrime, bestVSTensor,
                   bestVSPrimeTensor, bestOrigins2, bestOrigins2Prime,
                   bestVS2Tensor, bestVS2PrimeTensor, allAlphas_,
                   bestAlphasInit, true);
        copyCustomParams(true);
      }
    }
  }
  // TODO this copy can be avoided if initParameters takes dummy tensors to fill
  // as parameters and then copy to the member tensors when a better init is
  // found.
  if(noInit_ != 1) {
    // Put back best parameters
    std::stringstream ss;
    ss << "Best init error = " << bestError;
    printMsg(ss.str());
    // false: copy FROM the best-.. buffers back INTO the member parameters.
    copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
               bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
               bestVS2PrimeTensor, bestAlphasInit, allAlphas_, false);
    copyCustomParams(false);
  }

  for(unsigned int l = 0; l < noLayers_; ++l) {
    // Enable gradient tracking now that the initial values are fixed.
    layers_[l].requires_grad(true);

    // Print
    std::stringstream ss;
    ss << "Layer " << l;
    printMsg(ss.str());
    // Warn when origins carry suspiciously large scalar values (debugging).
    if(isTreeHasBigValues(layers_[l].getOrigin().mTree, bigValuesThreshold_)) {
      ss.str("");
      ss << "origins_[" << l << "] has big values!" << std::endl;
      printMsg(ss.str());
      printPairs(layers_[l].getOrigin().mTree);
    }
    if(isTreeHasBigValues(
         layers_[l].getOriginPrime().mTree, bigValuesThreshold_)) {
      ss.str("");
      ss << "originsPrime_[" << l << "] has big values!" << std::endl;
      printMsg(ss.str());
      printPairs(layers_[l].getOriginPrime().mTree);
    }
    ss.str("");
    ss << "vS size = " << layers_[l].getVSTensor().sizes();
    printMsg(ss.str());
    ss.str("");
    ss << "vS' size = " << layers_[l].getVSPrimeTensor().sizes();
    printMsg(ss.str());
    if(trees2.size() != 0) {
      ss.str("");
      ss << "vS2 size = " << layers_[l].getVS2Tensor().sizes();
      printMsg(ss.str());
      ss.str("");
      ss << "vS2' size = " << layers_[l].getVS2PrimeTensor().sizes();
      printMsg(ss.str());
    }
  }
}
221
// Forwards the network-level configuration to a single layer so that every
// layer behaves consistently with the network's settings.
void ttk::MergeTreeNeuralNetwork::passLayerParameters(
  MergeTreeNeuralLayer &layer) {
  // Initialization and activation options
  layer.setDropout(dropout_);
  layer.setEuclideanVectorsInit(euclideanVectorsInit_);
  layer.setRandomAxesInit(randomAxesInit_);
  layer.setInitBarycenterRandom(initBarycenterRandom_);
  layer.setInitBarycenterOneIter(initBarycenterOneIter_);
  layer.setInitOriginPrimeStructByCopy(initOriginPrimeStructByCopy_);
  layer.setInitOriginPrimeValuesByCopy(initOriginPrimeValuesByCopy_);
  layer.setInitOriginPrimeValuesByCopyRandomness(
    initOriginPrimeValuesByCopyRandomness_);
  layer.setActivate(activate_);
  layer.setActivationFunction(activationFunction_);
  layer.setUseGpu(useGpu_);
  layer.setBigValuesThreshold(bigValuesThreshold_);

  // Projection / barycenter options
  layer.setDeterministic(deterministic_);
  layer.setNumberOfProjectionSteps(k_);
  layer.setBarycenterSizeLimitPercent(barycenterSizeLimitPercent_);
  layer.setProbabilisticVectorsInit(probabilisticVectorsInit_);

  // Distance computation options
  layer.setNormalizedWasserstein(normalizedWasserstein_);
  layer.setAssignmentSolver(assignmentSolverID_);
  layer.setNodePerTask(nodePerTask_);
  layer.setUseDoubleInput(useDoubleInput_);
  layer.setJoinSplitMixtureCoefficient(mixtureCoefficient_);
  layer.setIsPersistenceDiagram(isPersistenceDiagram_);

  // Debug / parallelism options
  layer.setDebugLevel(debugLevel_);
  layer.setThreadNumber(threadNumber_);
}
253
254// ---------------------------------------------------------------------------
255// --- Forward
256// ---------------------------------------------------------------------------
// Forwards one merge tree (and optionally its second, join/split, tree
// when double input is used) through all layers. alphasInit provides the
// initial projection coefficients per layer; intermediate outputs go to
// outs/outs2, the final output to out/out2, and the chosen coefficients to
// dataAlphas. Also refreshes the member reconstruction caches recs_/recs2_
// at treeIndex. Returns true when a layer requested a reset.
bool ttk::MergeTreeNeuralNetwork::forwardOneData(
  mtu::TorchMergeTree<float> &tree,
  mtu::TorchMergeTree<float> &tree2,
  unsigned int treeIndex,
  unsigned int k,
  std::vector<torch::Tensor> &alphasInit,
  mtu::TorchMergeTree<float> &out,
  mtu::TorchMergeTree<float> &out2,
  std::vector<torch::Tensor> &dataAlphas,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  bool train) {
  outs.resize(noLayers_ - 1);
  outs2.resize(noLayers_ - 1);
  dataAlphas.resize(noLayers_);
  for(unsigned int l = 0; l < noLayers_; ++l) {
    // The first layer reads the input tree, deeper layers read the previous
    // layer's output; the last layer writes to the final output.
    auto &treeToUse = (l == 0 ? tree : outs[l - 1]);
    auto &tree2ToUse = (l == 0 ? tree2 : outs2[l - 1]);
    auto &outToUse = (l != noLayers_ - 1 ? outs[l] : out);
    auto &out2ToUse = (l != noLayers_ - 1 ? outs2[l] : out2);
    bool reset = layers_[l].forward(treeToUse, tree2ToUse, k, alphasInit[l],
                                    outToUse, out2ToUse, dataAlphas[l], train);
    if(reset)
      return true;
    // Update recs
    // Copy this layer's output into the reconstruction cache: overwrite the
    // existing entry when the cache is already full, append otherwise.
    auto updateRecs
      = [this, &treeIndex, &l](
          std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
          mtu::TorchMergeTree<float> &outT) {
          if(recs[treeIndex].size() > noLayers_)
            mtu::copyTorchMergeTree<float>(outT, recs[treeIndex][l + 1]);
          else {
            mtu::TorchMergeTree<float> tmt;
            mtu::copyTorchMergeTree<float>(outT, tmt);
            recs[treeIndex].emplace_back(tmt);
          }
        };
    updateRecs(recs_, outToUse);
    if(useDoubleInput_)
      updateRecs(recs2_, out2ToUse);
  }
  return false;
}
300
// Forwards the trees selected by indexes through the whole network, in
// parallel when enabled. When computeError is set, accumulates the loss of
// train samples into loss and of test samples into testLoss, each averaged
// over its own sample count. Returns true when any forward pass requested a
// reset.
bool ttk::MergeTreeNeuralNetwork::forwardStep(
  std::vector<mtu::TorchMergeTree<float>> &trees,
  std::vector<mtu::TorchMergeTree<float>> &trees2,
  std::vector<unsigned int> &indexes,
  std::vector<bool> &isTrain,
  unsigned int k,
  std::vector<std::vector<torch::Tensor>> &allAlphasInit,
  bool computeError,
  std::vector<mtu::TorchMergeTree<float>> &outs,
  std::vector<mtu::TorchMergeTree<float>> &outs2,
  std::vector<std::vector<torch::Tensor>> &bestAlphas,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
  std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings,
  std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
    &matchings2,
  float &loss,
  float &testLoss) {
  loss = 0;
  testLoss = 0;
  outs.resize(trees.size());
  outs2.resize(trees.size());
  bestAlphas.resize(trees.size());
  layersOuts.resize(trees.size());
  layersOuts2.resize(trees.size());
  matchings.resize(trees.size());
  if(useDoubleInput_)
    matchings2.resize(trees2.size());
  // Placeholder passed when no second input tree is available.
  mtu::TorchMergeTree<float> dummyTMT;
  bool reset = false;
  unsigned int noTrainLoss = 0, noTestLoss = 0;
  // Each index is processed independently; reset flags are OR-reduced and
  // the train loss is sum-reduced across threads.
#ifdef TTK_ENABLE_OPENMP
#pragma omp parallel for schedule(dynamic) \
  num_threads(this->threadNumber_) if(parallelize_) reduction(|| : reset) \
  reduction(+ : loss)
#endif
  for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
    unsigned int i = indexes[ind];
    auto &tree2ToUse = (trees2.size() == 0 ? dummyTMT : trees2[i]);
    bool dReset = forwardOneData(trees[i], tree2ToUse, i, k, allAlphasInit[i],
                                 outs[i], outs2[i], bestAlphas[i],
                                 layersOuts[i], layersOuts2[i], isTrain[i]);
    if(computeError) {
      float iLoss
        = computeOneLoss(trees[i], outs[i], trees2[i], outs2[i], matchings[i],
                         matchings2[i], bestAlphas[i], i);
      // Route the loss to the train or test accumulator.
      if(isTrain[i]) {
        loss += iLoss;
        ++noTrainLoss;
      } else {
        testLoss += iLoss;
        ++noTestLoss;
      }
    }
    if(dReset)
      reset = reset || dReset;
  }
  // Average losses over their respective sample counts.
  if(noTrainLoss != 0)
    loss /= noTrainLoss;
  if(noTestLoss != 0)
    testLoss /= noTestLoss;
  return reset;
}
365
366bool ttk::MergeTreeNeuralNetwork::forwardStep(
367 std::vector<mtu::TorchMergeTree<float>> &trees,
368 std::vector<mtu::TorchMergeTree<float>> &trees2,
369 std::vector<unsigned int> &indexes,
370 unsigned int k,
371 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
372 bool computeError,
373 std::vector<mtu::TorchMergeTree<float>> &outs,
374 std::vector<mtu::TorchMergeTree<float>> &outs2,
375 std::vector<std::vector<torch::Tensor>> &bestAlphas,
376 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
377 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
378 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
379 &matchings,
380 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
381 &matchings2,
382 float &loss) {
383 std::vector<bool> isTrain(trees.size(), false);
384 float tempLoss;
385 return forwardStep(trees, trees2, indexes, isTrain, k, allAlphasInit,
386 computeError, outs, outs2, bestAlphas, layersOuts,
387 layersOuts2, matchings, matchings2, tempLoss, loss);
388}
389
390// ---------------------------------------------------------------------------
391// --- Projection
392// ---------------------------------------------------------------------------
393void ttk::MergeTreeNeuralNetwork::projectionStep() {
394 for(unsigned int l = 0; l < noLayers_; ++l)
395 layers_[l].projectionStep();
396}
397
398// -----------------------------------------------------------------------
399// --- Convergence
400// -----------------------------------------------------------------------
401bool ttk::MergeTreeNeuralNetwork::isBestLoss(float loss,
402 float &minLoss,
403 unsigned int &cptBlocked) {
404 bool isBestEnergy = false;
405 if(loss + ENERGY_COMPARISON_TOLERANCE < minLoss) {
406 minLoss = loss;
407 cptBlocked = 0;
408 isBestEnergy = true;
409 }
410 return isBestEnergy;
411}
412
413bool ttk::MergeTreeNeuralNetwork::convergenceStep(float loss,
414 float &oldLoss,
415 float &minLoss,
416 unsigned int &cptBlocked) {
417 double tol = oldLoss / 125.0;
418 bool converged = std::abs(loss - oldLoss) < std::abs(tol);
419 oldLoss = loss;
420 if(not converged) {
421 cptBlocked += (minLoss < loss) ? 1 : 0;
422 converged = (cptBlocked >= 10 * 10);
423 if(converged)
425 }
426 return converged;
427}
428
429// -----------------------------------------------------------------------
430// --- Main Functions
431// -----------------------------------------------------------------------
432void ttk::MergeTreeNeuralNetwork::fit(
433 std::vector<ftm::MergeTree<float>> &trees,
434 std::vector<ftm::MergeTree<float>> &trees2) {
435 torch::set_num_threads(1);
436 if(useGpu_) {
437 if(torch::cuda::device_count() > 0 and torch::cuda::is_available())
438 printMsg("Computation with GPU support.");
439 else {
440 printMsg("Disabling GPU support because no device were found.");
441 useGpu_ = false;
442 // TODO cache useGpu parameter to be in accordance with ParaView GUI
443 }
444 } else {
445 printMsg("Computation without GPU support.");
446 }
447 // ----- Determinism
448 if(deterministic_) {
449 int m_seed = 0;
450 bool m_torch_deterministic = true;
451 srand(m_seed);
452 torch::manual_seed(m_seed);
453 at::globalContext().setDeterministicCuDNN(m_torch_deterministic ? true
454 : false);
455 if(not useGpu_)
456 at::globalContext().setDeterministicAlgorithms(
457 m_torch_deterministic ? true : false, true);
458 }
459
460 // ----- Testing
461 for(unsigned int i = 0; i < trees.size(); ++i) {
462 for(unsigned int n = 0; n < trees[i].tree.getNumberOfNodes(); ++n) {
463 if(trees[i].tree.isNodeAlone(n))
464 continue;
465 auto birthDeath = trees[i].tree.template getBirthDeath<float>(n);
466 bigValuesThreshold_
467 = std::max(std::abs(std::get<0>(birthDeath)), bigValuesThreshold_);
468 bigValuesThreshold_
469 = std::max(std::abs(std::get<1>(birthDeath)), bigValuesThreshold_);
470 }
471 }
472 bigValuesThreshold_ *= 100;
473
474 // ----- Convert MergeTree to TorchMergeTree
475 std::vector<mtu::TorchMergeTree<float>> torchTrees, torchTrees2;
476 mergeTreesToTorchTrees(trees, torchTrees, normalizedWasserstein_);
477 mergeTreesToTorchTrees(trees2, torchTrees2, normalizedWasserstein_);
478 if(useGpu_) {
479 for(unsigned i = 0; i < torchTrees.size(); ++i)
480 torchTrees[i].tensor = torchTrees[i].tensor.cuda();
481 for(unsigned i = 0; i < torchTrees2.size(); ++i)
482 torchTrees2[i].tensor = torchTrees2[i].tensor.cuda();
483 }
484
485 auto initRecs = [](std::vector<std::vector<mtu::TorchMergeTree<float>>> &recs,
486 std::vector<mtu::TorchMergeTree<float>> &torchTreesT) {
487 recs.clear();
488 recs.resize(torchTreesT.size());
489 for(unsigned int i = 0; i < torchTreesT.size(); ++i) {
490 mtu::TorchMergeTree<float> tmt;
491 mtu::copyTorchMergeTree<float>(torchTreesT[i], tmt);
492 recs[i].emplace_back(tmt);
493 }
494 };
495 initRecs(recs_, torchTrees);
496 if(useDoubleInput_)
497 initRecs(recs2_, torchTrees2);
498
499 // --- Train/Test Split
500 unsigned int trainSize = std::min(
501 std::max((int)(trees.size() * trainTestSplit_), 1), (int)trees.size());
502 std::vector<unsigned int> trainIndexes(trees.size()), testIndexes;
503 std::iota(trainIndexes.begin(), trainIndexes.end(), 0);
504 std::random_device rd;
505 std::default_random_engine rng(deterministic_ ? 0 : rd());
506 bool trainTestSplitted = trainSize != trees.size();
507 if(trainTestSplitted) {
508 if(shuffleBeforeSplit_)
509 std::shuffle(trainIndexes.begin(), trainIndexes.end(), rng);
510 testIndexes.insert(
511 testIndexes.end(), trainIndexes.begin() + trainSize, trainIndexes.end());
512 trainIndexes.resize(trainSize);
513 }
514 std::vector<bool> isTrain(trees.size(), true);
515 for(auto &ind : testIndexes)
516 isTrain[ind] = false;
517
518 // ----- Custom Init
519 customInit(torchTrees, torchTrees2);
520
521 // ----- Init Model Parameters
522 Timer t_init;
523 initStep(torchTrees, torchTrees2, isTrain);
524 printMsg("Init", 1, t_init.getElapsedTime(), threadNumber_);
525
526 // --- Init optimizer
527 std::vector<torch::Tensor> parameters;
528 for(unsigned int l = 0; l < noLayers_; ++l) {
529 parameters.emplace_back(layers_[l].getOrigin().tensor);
530 parameters.emplace_back(layers_[l].getOriginPrime().tensor);
531 parameters.emplace_back(layers_[l].getVSTensor());
532 parameters.emplace_back(layers_[l].getVSPrimeTensor());
533 if(trees2.size() != 0) {
534 parameters.emplace_back(layers_[l].getOrigin2().tensor);
535 parameters.emplace_back(layers_[l].getOrigin2Prime().tensor);
536 parameters.emplace_back(layers_[l].getVS2Tensor());
537 parameters.emplace_back(layers_[l].getVS2PrimeTensor());
538 }
539 }
540 addCustomParameters(parameters);
541
542 torch::optim::Optimizer *optimizer;
543 // - Init Adam
544 auto adamOptions = torch::optim::AdamOptions(gradientStepSize_);
545 adamOptions.betas(std::make_tuple(beta1_, beta2_));
546 auto adamOptimizer = torch::optim::Adam(parameters, adamOptions);
547 // - Init SGD optimizer
548 auto sgdOptions = torch::optim::SGDOptions(gradientStepSize_);
549 auto sgdOptimizer = torch::optim::SGD(parameters, sgdOptions);
550 // -Init RMSprop optimizer
551 auto rmspropOptions = torch::optim::RMSpropOptions(gradientStepSize_);
552 auto rmspropOptimizer = torch::optim::RMSprop(parameters, rmspropOptions);
553 // - Set optimizer pointer
554 switch(optimizer_) {
555 case 1:
556 optimizer = &sgdOptimizer;
557 break;
558 case 2:
559 optimizer = &rmspropOptimizer;
560 break;
561 case 0:
562 default:
563 optimizer = &adamOptimizer;
564 }
565
566 // --- Print train/test split
567 if(trainTestSplitted) {
568 std::stringstream ss;
569 ss << "trainSize = " << trainIndexes.size() << " / " << trees.size();
570 printMsg(ss.str());
571 ss.str("");
572 ss << "testSize = " << testIndexes.size() << " / " << trees.size();
573 printMsg(ss.str());
574 }
575
576 // --- Init batches indexes
577 unsigned int batchSize
578 = std::min(std::max((int)(trainIndexes.size() * batchSize_), 1),
579 (int)trainIndexes.size());
580 std::stringstream ssBatch;
581 ssBatch << "batchSize = " << batchSize;
582 printMsg(ssBatch.str());
583 unsigned int noBatch = trainIndexes.size() / batchSize
584 + ((trainIndexes.size() % batchSize) != 0 ? 1 : 0);
585 std::vector<std::vector<unsigned int>> allIndexes(noBatch);
586 if(noBatch == 1) {
587 // Yes, trees.size() below is correct and it is not trainIndexes.size(), the
588 // goal is to forward everyone (even test data) if noBatch == 1 to benefit
589 // from full parallelism, but only train data will be used for backward.
590 allIndexes[0].resize(trees.size());
591 std::iota(allIndexes[0].begin(), allIndexes[0].end(), 0);
592 }
593
594 // ----- Testing
595 originsNoZeroGrad_.resize(noLayers_);
596 originsPrimeNoZeroGrad_.resize(noLayers_);
597 vSNoZeroGrad_.resize(noLayers_);
598 vSPrimeNoZeroGrad_.resize(noLayers_);
599 for(unsigned int l = 0; l < noLayers_; ++l) {
600 originsNoZeroGrad_[l] = 0;
601 originsPrimeNoZeroGrad_[l] = 0;
602 vSNoZeroGrad_[l] = 0;
603 vSPrimeNoZeroGrad_[l] = 0;
604 }
605 if(useDoubleInput_) {
606 origins2NoZeroGrad_.resize(noLayers_);
607 origins2PrimeNoZeroGrad_.resize(noLayers_);
608 vS2NoZeroGrad_.resize(noLayers_);
609 vS2PrimeNoZeroGrad_.resize(noLayers_);
610 for(unsigned int l = 0; l < noLayers_; ++l) {
611 origins2NoZeroGrad_[l] = 0;
612 origins2PrimeNoZeroGrad_[l] = 0;
613 vS2NoZeroGrad_[l] = 0;
614 vS2PrimeNoZeroGrad_[l] = 0;
615 }
616 }
617
618 // ----- Init Variables
619 unsigned int k = k_;
620 float oldLoss, minLoss, minTestLoss;
621 std::vector<float> minCustomLoss;
622 unsigned int cptBlocked, iteration = 0;
623 auto initLoop = [&]() {
624 oldLoss = -1;
625 minLoss = std::numeric_limits<float>::max();
626 minTestLoss = std::numeric_limits<float>::max();
627 cptBlocked = 0;
628 iteration = 0;
629 };
630 initLoop();
631 int convWinSize = 5;
632 int noConverged = 0, noConvergedToGet = 10;
633 std::vector<float> gapLosses, gapTestLosses;
634 std::vector<std::vector<float>> gapCustomLosses;
635 float windowLoss = 0;
636
637 double assignmentTime = 0.0, updateTime = 0.0, projectionTime = 0.0,
638 lossTime = 0.0;
639
640 int bestIteration = 0;
641 std::vector<torch::Tensor> bestVSTensor, bestVSPrimeTensor, bestVS2Tensor,
642 bestVS2PrimeTensor;
643 std::vector<mtu::TorchMergeTree<float>> bestOrigins, bestOriginsPrime,
644 bestOrigins2, bestOrigins2Prime;
645 std::vector<std::vector<torch::Tensor>> bestAlphasInit;
646 std::vector<std::vector<mtu::TorchMergeTree<float>>> bestRecs, bestRecs2;
647 double bestTime = 0;
648
649 auto printLoss = [this, trainTestSplitted](
650 float loss, float testLoss, std::vector<float> &customLoss,
651 int iterationT, int iterationTT, double time,
652 const debug::Priority &priority = debug::Priority::INFO) {
653 std::stringstream prefix;
654 prefix << (priority == debug::Priority::VERBOSE ? "Iter " : "Best ");
655 std::stringstream ssBestLoss;
656 ssBestLoss << prefix.str() << "loss is " << loss << " (iteration "
657 << iterationT << " / " << iterationTT << ") at time " << time;
658 printMsg(ssBestLoss.str(), priority);
659 if(trainTestSplitted) {
660 ssBestLoss.str("");
661 ssBestLoss << prefix.str() << "test loss is " << testLoss;
662 printMsg(ssBestLoss.str(), priority);
663 }
664 printCustomLosses(customLoss, prefix, priority);
665 };
666
667 auto copyAlphas = [this](std::vector<std::vector<torch::Tensor>> &alphas,
668 std::vector<unsigned int> &indexes) {
669 for(unsigned int ind = 0; ind < indexes.size(); ++ind) {
670 unsigned int i = indexes[ind];
671 for(unsigned int j = 0; j < alphas[i].size(); ++j)
672 mtu::copyTensor(alphas[i][j], allAlphas_[i][j]);
673 }
674 };
675
676 // ----- Algorithm
677 Timer t_alg;
678 bool converged = false;
679 while(not converged) {
680 if(iteration % iterationGap_ == 0) {
681 std::stringstream ss;
682 ss << "Iteration " << iteration;
684 printMsg(ss.str());
685 }
686
687 bool forwardReset = false;
688 std::vector<float> iterationLosses, iterationTestLosses;
689 std::vector<std::vector<float>> iterationCustomLosses;
690 if(noBatch != 1) {
691 std::vector<unsigned int> indexes = trainIndexes;
692 std::shuffle(std::begin(indexes), std::end(indexes), rng);
693 for(unsigned int i = 0; i < allIndexes.size(); ++i) {
694 unsigned int noProcessed = batchSize * i;
695 unsigned int remaining = trainIndexes.size() - noProcessed;
696 unsigned int size = std::min(batchSize, remaining);
697 allIndexes[i].resize(size);
698 for(unsigned int j = 0; j < size; ++j)
699 allIndexes[i][j] = indexes[noProcessed + j];
700 }
701 }
702 for(unsigned batchNum = 0; batchNum < allIndexes.size(); ++batchNum) {
703 auto &indexes = allIndexes[batchNum];
704
705 // --- Forward
706 Timer t_assignment;
707 std::vector<mtu::TorchMergeTree<float>> outs, outs2;
708 std::vector<std::vector<torch::Tensor>> bestAlphas;
709 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
710 layersOuts2;
711 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
712 matchings, matchings2;
713 float loss, testLoss;
714 bool computeError = true;
715 forwardReset
716 = forwardStep(torchTrees, torchTrees2, indexes, isTrain, k, allAlphas_,
717 computeError, outs, outs2, bestAlphas, layersOuts,
718 layersOuts2, matchings, matchings2, loss, testLoss);
719 if(forwardReset)
720 break;
721 copyAlphas(bestAlphas, indexes);
722 assignmentTime += t_assignment.getElapsedTime();
723
724 // --- Loss
725 Timer t_loss;
726 gapLosses.emplace_back(loss);
727 iterationLosses.emplace_back(loss);
728 if(noBatch == 1 and trainTestSplitted) {
729 gapTestLosses.emplace_back(testLoss);
730 iterationTestLosses.emplace_back(testLoss);
731 }
732 std::vector<torch::Tensor> torchCustomLoss;
733 computeCustomLosses(layersOuts, layersOuts2, bestAlphas, indexes, isTrain,
734 iteration, gapCustomLosses, iterationCustomLosses,
735 torchCustomLoss);
736 lossTime += t_loss.getElapsedTime();
737
738 // --- Backward
739 Timer t_update;
740 backwardStep(torchTrees, outs, matchings, torchTrees2, outs2, matchings2,
741 bestAlphas, *optimizer, indexes, isTrain, torchCustomLoss);
742 updateTime += t_update.getElapsedTime();
743
744 // --- Projection
745 Timer t_projection;
746 projectionStep();
747 projectionTime += t_projection.getElapsedTime();
748 } // end batch
749
750 if(noBatch != 1 and trainTestSplitted) {
751 std::vector<mtu::TorchMergeTree<float>> outs, outs2;
752 std::vector<std::vector<torch::Tensor>> bestAlphas;
753 std::vector<std::vector<mtu::TorchMergeTree<float>>> layersOuts,
754 layersOuts2;
755 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
756 matchings, matchings2;
757 float loss, testLoss;
758 bool computeError = true;
759 forwardStep(torchTrees, torchTrees2, testIndexes, isTrain, k, allAlphas_,
760 computeError, outs, outs2, bestAlphas, layersOuts,
761 layersOuts2, matchings, matchings2, loss, testLoss);
762 copyAlphas(bestAlphas, testIndexes);
763 gapTestLosses.emplace_back(testLoss);
764 iterationTestLosses.emplace_back(testLoss);
765 std::vector<torch::Tensor> torchCustomLoss;
766 computeCustomLosses(layersOuts, layersOuts2, bestAlphas, testIndexes,
767 isTrain, iteration, gapCustomLosses,
768 iterationCustomLosses, torchCustomLoss);
769 }
770
771 if(forwardReset) {
772 // TODO better manage reset by init new parameters and start again for
773 // example (should not happen anymore)
774 printWrn("Forward reset!");
775 break;
776 }
777
778 // --- Get iteration loss
779 // TODO an approximation is made here if batch size != 1 because the
780 // iteration loss that will be printed will not be exact, we need to do a
781 // forward step and compute loss with the whole dataset
782 float iterationLoss = torch::tensor(iterationLosses).mean().item<float>();
783 float iterationTestLoss
784 = torch::tensor(iterationTestLosses).mean().item<float>();
785 std::vector<float> iterationCustomLoss;
786 float iterationTotalLoss = computeIterationTotalLoss(
787 iterationLoss, iterationCustomLosses, iterationCustomLoss);
788 printLoss(iterationTotalLoss, iterationTestLoss, iterationCustomLoss,
789 iteration, iteration,
790 t_alg.getElapsedTime() - t_allVectorCopy_time_,
792
793 // --- Update best parameters
794 bool isBest = false;
795 if(not trainTestSplitted)
796 isBest = isBestLoss(iterationTotalLoss, minLoss, cptBlocked);
797 else {
798 // TODO generalize these lines when accuracy is not the metric computed or
799 // evaluated
800 if(minCustomLoss.empty())
801 isBest = true;
802 else {
803 float minusAcc = -iterationCustomLoss[1];
804 float minMinusAcc = -minCustomLoss[1];
805 isBest = isBestLoss(minusAcc, minMinusAcc, cptBlocked);
806 }
807 }
808 if(isBest) {
809 Timer t_copy;
810 bestIteration = iteration;
811 copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
812 bestOrigins2, bestOrigins2Prime, bestVS2Tensor,
813 bestVS2PrimeTensor, allAlphas_, bestAlphasInit, true);
814 copyCustomParams(true);
815 copyParams(recs_, bestRecs);
816 copyParams(recs2_, bestRecs2);
817 t_allVectorCopy_time_ += t_copy.getElapsedTime();
818 bestTime = t_alg.getElapsedTime() - t_allVectorCopy_time_;
819 minCustomLoss = iterationCustomLoss;
820 if(trainTestSplitted) {
821 minLoss = iterationTotalLoss;
822 minTestLoss = iterationTestLoss;
823 }
824 printLoss(minLoss, minTestLoss, minCustomLoss, bestIteration, iteration,
825 bestTime, debug::Priority::DETAIL);
826 }
827
828 // --- Convergence
829 windowLoss += iterationTotalLoss;
830 if((iteration + 1) % convWinSize == 0) {
831 windowLoss /= convWinSize;
832 converged = convergenceStep(windowLoss, oldLoss, minLoss, cptBlocked);
833 windowLoss = 0;
834 if(converged) {
835 ++noConverged;
836 } else
837 noConverged = 0;
838 converged = noConverged >= noConvergedToGet;
839 if(converged and iteration < minIteration_)
840 printMsg("convergence is detected but iteration < minIteration_",
842 if(iteration < minIteration_)
843 converged = false;
844 if(converged)
845 break;
846 }
847
848 // --- Print
849 if(iteration % iterationGap_ == 0) {
850 printMsg("Assignment", 1, assignmentTime, threadNumber_);
851 printMsg("Loss", 1, lossTime, threadNumber_);
852 printMsg("Update", 1, updateTime, threadNumber_);
853 printMsg("Projection", 1, projectionTime, threadNumber_);
854 assignmentTime = 0.0;
855 lossTime = 0.0;
856 updateTime = 0.0;
857 projectionTime = 0.0;
858 float loss = torch::tensor(gapLosses).mean().item<float>();
859 gapLosses.clear();
860 float testLoss = torch::tensor(gapTestLosses).mean().item<float>();
861 gapTestLosses.clear();
862 if(trainTestSplitted) {
863 std::stringstream ss;
864 ss << "Test Loss = " << testLoss;
865 printMsg(ss.str());
866 }
867 printGapLoss(loss, gapCustomLosses);
868
869 // Verify grad and big values (testing)
870 for(unsigned int l = 0; l < noLayers_; ++l) {
871 std::stringstream ss;
872 if(originsNoZeroGrad_[l] != 0)
873 ss << originsNoZeroGrad_[l] << " originsNoZeroGrad_[" << l << "]"
874 << std::endl;
875 if(originsPrimeNoZeroGrad_[l] != 0)
876 ss << originsPrimeNoZeroGrad_[l] << " originsPrimeNoZeroGrad_[" << l
877 << "]" << std::endl;
878 if(vSNoZeroGrad_[l] != 0)
879 ss << vSNoZeroGrad_[l] << " vSNoZeroGrad_[" << l << "]" << std::endl;
880 if(vSPrimeNoZeroGrad_[l] != 0)
881 ss << vSPrimeNoZeroGrad_[l] << " vSPrimeNoZeroGrad_[" << l << "]"
882 << std::endl;
883 originsNoZeroGrad_[l] = 0;
884 originsPrimeNoZeroGrad_[l] = 0;
885 vSNoZeroGrad_[l] = 0;
886 vSPrimeNoZeroGrad_[l] = 0;
887 if(useDoubleInput_) {
888 if(origins2NoZeroGrad_[l] != 0)
889 ss << origins2NoZeroGrad_[l] << " origins2NoZeroGrad_[" << l << "]"
890 << std::endl;
891 if(origins2PrimeNoZeroGrad_[l] != 0)
892 ss << origins2PrimeNoZeroGrad_[l] << " origins2PrimeNoZeroGrad_["
893 << l << "]" << std::endl;
894 if(vS2NoZeroGrad_[l] != 0)
895 ss << vS2NoZeroGrad_[l] << " vS2NoZeroGrad_[" << l << "]"
896 << std::endl;
897 if(vS2PrimeNoZeroGrad_[l] != 0)
898 ss << vS2PrimeNoZeroGrad_[l] << " vS2PrimeNoZeroGrad_[" << l << "]"
899 << std::endl;
900 origins2NoZeroGrad_[l] = 0;
901 origins2PrimeNoZeroGrad_[l] = 0;
902 vS2NoZeroGrad_[l] = 0;
903 vS2PrimeNoZeroGrad_[l] = 0;
904 }
905 if(isTreeHasBigValues(
906 layers_[l].getOrigin().mTree, bigValuesThreshold_))
907 ss << "origins_[" << l << "] has big values!" << std::endl;
908 if(isTreeHasBigValues(
909 layers_[l].getOriginPrime().mTree, bigValuesThreshold_))
910 ss << "originsPrime_[" << l << "] has big values!" << std::endl;
911 if(ss.rdbuf()->in_avail() != 0)
913 }
914 }
915
916 ++iteration;
917 if(maxIteration_ != 0 and iteration >= maxIteration_) {
918 printMsg("iteration >= maxIteration_", debug::Priority::DETAIL);
919 break;
920 }
921 }
923 printLoss(
924 minLoss, minTestLoss, minCustomLoss, bestIteration, iteration, bestTime);
926 bestLoss_ = (trainTestSplitted ? minTestLoss : minLoss);
927
928 Timer t_copy;
929 copyParams(bestOrigins, bestOriginsPrime, bestVSTensor, bestVSPrimeTensor,
930 bestOrigins2, bestOrigins2Prime, bestVS2Tensor, bestVS2PrimeTensor,
931 bestAlphasInit, allAlphas_, false);
932 copyCustomParams(false);
933 copyParams(bestRecs, recs_);
934 copyParams(bestRecs2, recs2_);
935 t_allVectorCopy_time_ += t_copy.getElapsedTime();
936 printMsg("Copy time", 1, t_allVectorCopy_time_, threadNumber_);
937}
938
939// ---------------------------------------------------------------------------
940// --- End Functions
941// ---------------------------------------------------------------------------
942void ttk::MergeTreeNeuralNetwork::computeTrackingInformation(
943 unsigned int endLayer) {
944 originsMatchings_.resize(endLayer);
945#ifdef TTK_ENABLE_OPENMP
946#pragma omp parallel for schedule(dynamic) \
947 num_threads(this->threadNumber_) if(parallelize_)
948#endif
949 for(unsigned int l = 0; l < endLayer; ++l) {
950 auto &tree1
951 = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime());
952 auto &tree2
953 = (l == 0 ? layers_[0].getOriginPrime() : layers_[l].getOriginPrime());
954 bool isCalled = true;
955 float distance;
956 computeOneDistance<float>(tree1.mTree, tree2.mTree, originsMatchings_[l],
957 distance, isCalled, useDoubleInput_);
958 }
959
960 // Data matchings
961 ++endLayer;
962 dataMatchings_.resize(endLayer);
963 for(unsigned int l = 0; l < endLayer; ++l) {
964 dataMatchings_[l].resize(recs_.size());
965#ifdef TTK_ENABLE_OPENMP
966#pragma omp parallel for schedule(dynamic) \
967 num_threads(this->threadNumber_) if(parallelize_)
968#endif
969 for(unsigned int i = 0; i < recs_.size(); ++i) {
970 bool isCalled = true;
971 float distance;
972 auto &origin
973 = (l == 0 ? layers_[0].getOrigin() : layers_[l - 1].getOriginPrime());
974 computeOneDistance<float>(origin.mTree, recs_[i][l].mTree,
975 dataMatchings_[l][i], distance, isCalled,
976 useDoubleInput_);
977 }
978 }
979
980 // Reconst matchings
981 reconstMatchings_.resize(recs_.size());
982#ifdef TTK_ENABLE_OPENMP
983#pragma omp parallel for schedule(dynamic) \
984 num_threads(this->threadNumber_) if(parallelize_)
985#endif
986 for(unsigned int i = 0; i < recs_.size(); ++i) {
987 bool isCalled = true;
988 float distance;
989 auto l = recs_[i].size() - 1;
990 computeOneDistance<float>(recs_[i][0].mTree, recs_[i][l].mTree,
991 reconstMatchings_[i], distance, isCalled,
992 useDoubleInput_);
993 }
994}
995
996void ttk::MergeTreeNeuralNetwork::computeCorrelationMatrix(
997 std::vector<ftm::MergeTree<float>> &trees, unsigned int layer) {
998 std::vector<std::vector<double>> allTs;
999 auto noGeod = allAlphas_[0][layer].sizes()[0];
1000 allTs.resize(noGeod);
1001 for(unsigned int i = 0; i < noGeod; ++i) {
1002 allTs[i].resize(allAlphas_.size());
1003 for(unsigned int j = 0; j < allAlphas_.size(); ++j)
1004 allTs[i][j] = allAlphas_[j][layer][i].item<double>();
1005 }
1006 computeBranchesCorrelationMatrix(
1007 layers_[0].getOrigin().mTree, trees, dataMatchings_[0], allTs,
1008 branchesCorrelationMatrix_, persCorrelationMatrix_);
1009}
1010
1011void ttk::MergeTreeNeuralNetwork::createScaledAlphas(
1012 std::vector<std::vector<torch::Tensor>> &alphas,
1013 std::vector<std::vector<torch::Tensor>> &scaledAlphas) {
1014 scaledAlphas.clear();
1015 scaledAlphas.resize(
1016 alphas.size(), std::vector<torch::Tensor>(alphas[0].size()));
1017 for(unsigned int l = 0; l < alphas[0].size(); ++l) {
1018 torch::Tensor scale = layers_[l].getVSTensor().pow(2).sum(0).sqrt();
1019#ifdef TTK_ENABLE_OPENMP
1020#pragma omp parallel for schedule(dynamic) \
1021 num_threads(this->threadNumber_) if(parallelize_)
1022#endif
1023 for(unsigned int i = 0; i < alphas.size(); ++i) {
1024 scaledAlphas[i][l] = alphas[i][l] * scale.reshape({-1, 1});
1025 }
1026 }
1027}
1028
1029void ttk::MergeTreeNeuralNetwork::createScaledAlphas() {
1030 createScaledAlphas(allAlphas_, allScaledAlphas_);
1031}
1032
1033void ttk::MergeTreeNeuralNetwork::createActivatedAlphas() {
1034 allActAlphas_ = allAlphas_;
1035#ifdef TTK_ENABLE_OPENMP
1036#pragma omp parallel for schedule(dynamic) \
1037 num_threads(this->threadNumber_) if(parallelize_)
1038#endif
1039 for(unsigned int i = 0; i < allActAlphas_.size(); ++i)
1040 for(unsigned int j = 0; j < allActAlphas_[i].size(); ++j)
1041 allActAlphas_[i][j] = activation(allActAlphas_[i][j]);
1042 createScaledAlphas(allActAlphas_, allActScaledAlphas_);
1043}
1044
1045// ---------------------------------------------------------------------------
1046// --- Utils
1047// ---------------------------------------------------------------------------
1048void ttk::MergeTreeNeuralNetwork::copyParams(
1049 std::vector<mtu::TorchMergeTree<float>> &origins,
1050 std::vector<mtu::TorchMergeTree<float>> &originsPrime,
1051 std::vector<torch::Tensor> &vS,
1052 std::vector<torch::Tensor> &vSPrime,
1053 std::vector<mtu::TorchMergeTree<float>> &origins2,
1054 std::vector<mtu::TorchMergeTree<float>> &origins2Prime,
1055 std::vector<torch::Tensor> &vS2,
1056 std::vector<torch::Tensor> &vS2Prime,
1057 std::vector<std::vector<torch::Tensor>> &srcAlphas,
1058 std::vector<std::vector<torch::Tensor>> &dstAlphas,
1059 bool get) {
1060 dstAlphas.resize(srcAlphas.size(), std::vector<torch::Tensor>(noLayers_));
1061 if(get) {
1062 origins.resize(noLayers_);
1063 originsPrime.resize(noLayers_);
1064 vS.resize(noLayers_);
1065 vSPrime.resize(noLayers_);
1066 if(useDoubleInput_) {
1067 origins2.resize(noLayers_);
1068 origins2Prime.resize(noLayers_);
1069 vS2.resize(noLayers_);
1070 vS2Prime.resize(noLayers_);
1071 }
1072 }
1073 for(unsigned int l = 0; l < noLayers_; ++l) {
1074 layers_[l].copyParams(origins[l], originsPrime[l], vS[l], vSPrime[l],
1075 origins2[l], origins2Prime[l], vS2[l], vS2Prime[l],
1076 get);
1077#ifdef TTK_ENABLE_OPENMP
1078#pragma omp parallel for schedule(dynamic) \
1079 num_threads(this->threadNumber_) if(parallelize_)
1080#endif
1081 for(unsigned int i = 0; i < srcAlphas.size(); ++i)
1082 mtu::copyTensor(srcAlphas[i][l], dstAlphas[i][l]);
1083 }
1084}
1085
1086void ttk::MergeTreeNeuralNetwork::copyParams(
1087 std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
1088 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst) {
1089 dst.resize(src.size());
1090#ifdef TTK_ENABLE_OPENMP
1091#pragma omp parallel for schedule(dynamic) \
1092 num_threads(this->threadNumber_) if(parallelize_)
1093#endif
1094 for(unsigned int i = 0; i < src.size(); ++i) {
1095 dst[i].resize(src[i].size());
1096 for(unsigned int j = 0; j < src[i].size(); ++j)
1097 mtu::copyTorchMergeTree(src[i][j], dst[i][j]);
1098 }
1099}
1100
1101void ttk::MergeTreeNeuralNetwork::getAlphasTensor(
1102 std::vector<std::vector<torch::Tensor>> &alphas,
1103 std::vector<unsigned int> &indexes,
1104 std::vector<bool> &toGet,
1105 unsigned int layerIndex,
1106 torch::Tensor &alphasOut) {
1107 unsigned int beg = 0;
1108 while(not toGet[indexes[beg]])
1109 ++beg;
1110 alphasOut = alphas[indexes[beg]][layerIndex].transpose(0, 1);
1111 for(unsigned int ind = beg + 1; ind < indexes.size(); ++ind) {
1112 if(not toGet[indexes[ind]])
1113 continue;
1114 alphasOut = torch::cat(
1115 {alphasOut, alphas[indexes[ind]][layerIndex].transpose(0, 1)});
1116 }
1117}
1118
1119void ttk::MergeTreeNeuralNetwork::getAlphasTensor(
1120 std::vector<std::vector<torch::Tensor>> &alphas,
1121 std::vector<unsigned int> &indexes,
1122 unsigned int layerIndex,
1123 torch::Tensor &alphasOut) {
1124 std::vector<bool> toGet(indexes.size(), true);
1125 getAlphasTensor(alphas, indexes, toGet, layerIndex, alphasOut);
1126}
1127
1128void ttk::MergeTreeNeuralNetwork::getAlphasTensor(
1129 std::vector<std::vector<torch::Tensor>> &alphas,
1130 unsigned int layerIndex,
1131 torch::Tensor &alphasOut) {
1132 std::vector<unsigned int> indexes(alphas.size());
1133 std::iota(indexes.begin(), indexes.end(), 0);
1134 getAlphasTensor(alphas, indexes, layerIndex, alphasOut);
1135}
1136
1137// ---------------------------------------------------------------------------
1138// --- Testing
1139// ---------------------------------------------------------------------------
1140void ttk::MergeTreeNeuralNetwork::checkZeroGrad(unsigned int l,
1141 bool checkOutputBasis) {
1142 if(not layers_[l].getOrigin().tensor.grad().defined()
1143 or not layers_[l].getOrigin().tensor.grad().count_nonzero().is_nonzero())
1144 ++originsNoZeroGrad_[l];
1145 if(not layers_[l].getVSTensor().grad().defined()
1146 or not layers_[l].getVSTensor().grad().count_nonzero().is_nonzero())
1147 ++vSNoZeroGrad_[l];
1148 if(checkOutputBasis) {
1149 if(not layers_[l].getOriginPrime().tensor.grad().defined()
1150 or not layers_[l]
1151 .getOriginPrime()
1152 .tensor.grad()
1153 .count_nonzero()
1154 .is_nonzero())
1155 ++originsPrimeNoZeroGrad_[l];
1156 if(not layers_[l].getVSPrimeTensor().grad().defined()
1157 or not layers_[l].getVSPrimeTensor().grad().count_nonzero().is_nonzero())
1158 ++vSPrimeNoZeroGrad_[l];
1159 }
1160 if(useDoubleInput_) {
1161 if(not layers_[l].getOrigin2().tensor.grad().defined()
1162 or not layers_[l]
1163 .getOrigin2()
1164 .tensor.grad()
1165 .count_nonzero()
1166 .is_nonzero())
1167 ++origins2NoZeroGrad_[l];
1168 if(not layers_[l].getVS2Tensor().grad().defined()
1169 or not layers_[l].getVS2Tensor().grad().count_nonzero().is_nonzero())
1170 ++vS2NoZeroGrad_[l];
1171 if(checkOutputBasis) {
1172 if(not layers_[l].getOrigin2Prime().tensor.grad().defined()
1173 or not layers_[l]
1174 .getOrigin2Prime()
1175 .tensor.grad()
1176 .count_nonzero()
1177 .is_nonzero())
1178 ++origins2PrimeNoZeroGrad_[l];
1179 if(not layers_[l].getVS2PrimeTensor().grad().defined()
1180 or not layers_[l]
1181 .getVS2PrimeTensor()
1182 .grad()
1183 .count_nonzero()
1184 .is_nonzero())
1185 ++vS2PrimeNoZeroGrad_[l];
1186 }
1187 }
1188}
1189
1190bool ttk::MergeTreeNeuralNetwork::isTreeHasBigValues(
1191 const ftm::MergeTree<float> &mTree, float threshold) {
1192 bool found = false;
1193 for(unsigned int n = 0; n < mTree.tree.getNumberOfNodes(); ++n) {
1194 if(mTree.tree.isNodeAlone(n))
1195 continue;
1196 auto birthDeath = mTree.tree.template getBirthDeath<float>(n);
1197 if(std::abs(std::get<0>(birthDeath)) > threshold
1198 or std::abs(std::get<1>(birthDeath)) > threshold) {
1199 found = true;
1200 break;
1201 }
1202 }
1203 return found;
1204}
1205#endif
1206
1207// ---------------------------------------------------------------------------
1208// --- Main Functions
1209// ---------------------------------------------------------------------------
1211 std::vector<ftm::MergeTree<float>> &trees,
1212 std::vector<ftm::MergeTree<float>> &trees2) {
1213#ifndef TTK_ENABLE_TORCH
1214 TTK_FORCE_USE(trees);
1215 TTK_FORCE_USE(trees2);
1216 printErr("This module requires Torch.");
1217#else
1218#ifdef TTK_ENABLE_OPENMP
1219 int ompNested = omp_get_nested();
1220 omp_set_nested(1);
1221#endif
1222 // makeExponentialExample(trees, trees2);
1223
1224 // --- Preprocessing
1225 Timer t_preprocess;
1227 if(trees2.size() != 0)
1229 printMsg("Preprocessing", 1, t_preprocess.getElapsedTime(), threadNumber_);
1230 useDoubleInput_ = (trees2.size() != 0);
1231
1232 // --- Fit neural network
1233 Timer t_total;
1234 fit(trees, trees2);
1235 auto totalTime = t_total.getElapsedTime() - t_allVectorCopy_time_;
1237 printMsg("Total time", 1, totalTime, threadNumber_);
1238
1239 // --- End functions
1240 Timer t_end;
1241 createScaledAlphas();
1242 createActivatedAlphas();
1243 executeEndFunction(trees, trees2);
1244 printMsg("End functions", 1, t_end.getElapsedTime(), threadNumber_);
1245
1246 // --- Postprocessing
1247 if(createOutput_) {
1248#ifdef TTK_ENABLE_OPENMP
1249#pragma omp parallel for schedule(dynamic) \
1250 num_threads(this->threadNumber_) if(parallelize_)
1251#endif
1252 for(unsigned int i = 0; i < trees.size(); ++i)
1253 postprocessingPipeline<float>(&(trees[i].tree));
1254#ifdef TTK_ENABLE_OPENMP
1255#pragma omp parallel for schedule(dynamic) \
1256 num_threads(this->threadNumber_) if(parallelize_)
1257#endif
1258 for(unsigned int i = 0; i < trees2.size(); ++i)
1259 postprocessingPipeline<float>(&(trees2[i].tree));
1260
1261 originsCopy_.resize(layers_.size());
1262 originsPrimeCopy_.resize(layers_.size());
1263#ifdef TTK_ENABLE_OPENMP
1264#pragma omp parallel for schedule(dynamic) \
1265 num_threads(this->threadNumber_) if(parallelize_)
1266#endif
1267 for(unsigned int l = 0; l < layers_.size(); ++l) {
1268 mtu::copyTorchMergeTree<float>(layers_[l].getOrigin(), originsCopy_[l]);
1269 mtu::copyTorchMergeTree<float>(
1270 layers_[l].getOriginPrime(), originsPrimeCopy_[l]);
1271 }
1272#ifdef TTK_ENABLE_OPENMP
1273#pragma omp parallel for schedule(dynamic) \
1274 num_threads(this->threadNumber_) if(parallelize_)
1275#endif
1276 for(unsigned int l = 0; l < originsCopy_.size(); ++l) {
1277 fillMergeTreeStructure(originsCopy_[l]);
1278 postprocessingPipeline<float>(&(originsCopy_[l].mTree.tree));
1279 fillMergeTreeStructure(originsPrimeCopy_[l]);
1280 postprocessingPipeline<float>(&(originsPrimeCopy_[l].mTree.tree));
1281 }
1282#ifdef TTK_ENABLE_OPENMP
1283#pragma omp parallel for schedule(dynamic) \
1284 num_threads(this->threadNumber_) if(parallelize_)
1285#endif
1286 for(unsigned int i = 0; i < recs_.size(); ++i) {
1287 for(unsigned int j = 0; j < recs_[i].size(); ++j) {
1288 fixTreePrecisionScalars(recs_[i][j].mTree);
1289 postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
1290 }
1291 }
1292 }
1293
1294 if(not isPersistenceDiagram_) {
1295 for(unsigned int l = 0; l < originsMatchings_.size(); ++l) {
1296 auto &tree1 = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]);
1297 auto &tree2 = (l == 0 ? originsPrimeCopy_[0] : originsPrimeCopy_[l]);
1299 &(tree1.mTree.tree), &(tree2.mTree.tree), originsMatchings_[l]);
1300 }
1301#ifdef TTK_ENABLE_OPENMP
1302#pragma omp parallel for schedule(dynamic) \
1303 num_threads(this->threadNumber_) if(parallelize_)
1304#endif
1305 for(unsigned int i = 0; i < recs_.size(); ++i) {
1306 for(unsigned int l = 0; l < dataMatchings_.size(); ++l) {
1307 auto &origin = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]);
1308 convertBranchDecompositionMatching<float>(&(origin.mTree.tree),
1309 &(recs_[i][l].mTree.tree),
1310 dataMatchings_[l][i]);
1311 }
1312 }
1313 for(unsigned int i = 0; i < reconstMatchings_.size(); ++i) {
1314 auto l = recs_[i].size() - 1;
1315 convertBranchDecompositionMatching<float>(&(recs_[i][0].mTree.tree),
1316 &(recs_[i][l].mTree.tree),
1318 }
1319 }
1320#ifdef TTK_ENABLE_OPENMP
1321 omp_set_nested(ompNested);
1322#endif
1323#endif
1324}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
#define ENERGY_COMPARISON_TOLERANCE
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
std::vector< std::vector< int > > trees2NodeCorr_
void preprocessingTrees(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< int > > &nodeCorr, bool useMinMaxPairT=true)
void convertBranchDecompositionMatching(ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
void postprocessingPipeline(ftm::FTMTree_MT *tree)
std::vector< std::vector< int > > treesNodeCorr_
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
double getElapsedTime()
Definition Timer.h:15
T distance(const T *p0, const T *p1, const int &dimension=3)
Definition Geometry.cpp:362
T end(std::pair< T, T > &p)
Definition ripser.cpp:503
T begin(std::pair< T, T > &p)
Definition ripser.cpp:499
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)