// TTK - MergeTreeNeuralLayer.h
1
28
29#pragma once
30
31// ttk common includes
32#include <Debug.h>
33#include <Geometry.h>
34#include <MergeTreeNeuralBase.h>
35#include <MergeTreeTorchUtils.h>
36
37#ifdef TTK_ENABLE_TORCH
38#include <torch/torch.h>
39#endif
40
41namespace ttk {
42
47 class MergeTreeNeuralLayer : virtual public Debug,
48 public MergeTreeNeuralBase {
49
50#ifdef TTK_ENABLE_TORCH
51 // Layer parameters
52 torch::Tensor vSTensor_, vSPrimeTensor_, vS2Tensor_, vS2PrimeTensor_;
53 mtu::TorchMergeTree<float> origin_, originPrime_, origin2_, origin2Prime_;
54#endif
55
56 public:
58
59#ifdef TTK_ENABLE_TORCH
60 // -----------------------------------------------------------------------
61 // --- Getter/Setter
62 // -----------------------------------------------------------------------
63 const mtu::TorchMergeTree<float> &getOrigin() const;
64
65 const mtu::TorchMergeTree<float> &getOriginPrime() const;
66
67 const mtu::TorchMergeTree<float> &getOrigin2() const;
68
69 const mtu::TorchMergeTree<float> &getOrigin2Prime() const;
70
71 const torch::Tensor &getVSTensor() const;
72
73 const torch::Tensor &getVSPrimeTensor() const;
74
75 const torch::Tensor &getVS2Tensor() const;
76
77 const torch::Tensor &getVS2PrimeTensor() const;
78
79 void setOrigin(const mtu::TorchMergeTree<float> &tmt);
80
81 void setOriginPrime(const mtu::TorchMergeTree<float> &tmt);
82
83 void setOrigin2(const mtu::TorchMergeTree<float> &tmt);
84
85 void setOrigin2Prime(const mtu::TorchMergeTree<float> &tmt);
86
87 void setVSTensor(const torch::Tensor &vS);
88
89 void setVSPrimeTensor(const torch::Tensor &vS);
90
91 void setVS2Tensor(const torch::Tensor &vS);
92
93 void setVS2PrimeTensor(const torch::Tensor &vS);
94
95 // -----------------------------------------------------------------------
96 // --- Init
97 // -----------------------------------------------------------------------
108 void initOutputBasisTreeStructure(mtu::TorchMergeTree<float> &originPrime,
109 bool isJT,
110 mtu::TorchMergeTree<float> &baseOrigin);
111
124 void initOutputBasis(const unsigned int dim,
125 const unsigned int dim2,
126 const torch::Tensor &baseTensor);
127
137 void initOutputBasisVectors(torch::Tensor &w, torch::Tensor &w2);
138
147 void initOutputBasisVectors(unsigned int dim, unsigned int dim2);
148
171 void initInputBasisOrigin(
172 std::vector<ftm::MergeTree<float>> &treesToUse,
173 std::vector<ftm::MergeTree<float>> &trees2ToUse,
174 double barycenterSizeLimitPercent,
175 unsigned int barycenterMaxNoPairs,
176 unsigned int barycenterMaxNoPairs2,
177 std::vector<double> &inputToBaryDistances,
178 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
179 &baryMatchings,
180 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
181 &baryMatchings2);
182
213 void initInputBasisVectors(
214 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
215 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
216 std::vector<ftm::MergeTree<float>> &trees,
217 std::vector<ftm::MergeTree<float>> &trees2,
218 unsigned int noVectors,
219 std::vector<torch::Tensor> &allAlphasInit,
220 std::vector<double> &inputToBaryDistances,
221 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
222 &baryMatchings,
223 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
224 &baryMatchings2,
225 mtu::TorchMergeTree<float> &origin,
226 mtu::TorchMergeTree<float> &origin2,
227 torch::Tensor &vSTensor,
228 torch::Tensor &vS2Tensor,
229 bool useInputBasis = true);
230
256 void initInputBasisVectors(
257 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
258 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
259 std::vector<ftm::MergeTree<float>> &trees,
260 std::vector<ftm::MergeTree<float>> &trees2,
261 unsigned int noVectors,
262 std::vector<torch::Tensor> &allAlphasInit,
263 std::vector<double> &inputToBaryDistances,
264 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
265 &baryMatchings,
266 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
267 &baryMatchings2,
268 bool useInputBasis = true);
269
270 void requires_grad(const bool requireGrad);
271
272 void cuda();
273
274 // -----------------------------------------------------------------------
275 // --- Interpolation
276 // -----------------------------------------------------------------------
284 void interpolationDiagonalProjection(
285 mtu::TorchMergeTree<float> &interpolationTensor);
286
294 void
295 interpolationNestingProjection(mtu::TorchMergeTree<float> &interpolation);
296
304 void interpolationProjection(mtu::TorchMergeTree<float> &interpolation);
305
316 void getMultiInterpolation(const mtu::TorchMergeTree<float> &origin,
317 const torch::Tensor &vS,
318 torch::Tensor &alphas,
319 mtu::TorchMergeTree<float> &interpolation);
320
321 // -----------------------------------------------------------------------
322 // --- Forward
323 // -----------------------------------------------------------------------
351 void getAlphasOptimizationTensors(
352 mtu::TorchMergeTree<float> &tree,
353 mtu::TorchMergeTree<float> &origin,
354 torch::Tensor &vSTensor,
355 mtu::TorchMergeTree<float> &interpolated,
356 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
357 torch::Tensor &reorderedTreeTensor,
358 torch::Tensor &deltaOrigin,
359 torch::Tensor &deltaA,
360 torch::Tensor &originTensor_f,
361 torch::Tensor &vSTensor_f);
362
387 void computeAlphas(
388 mtu::TorchMergeTree<float> &tree,
389 mtu::TorchMergeTree<float> &origin,
390 torch::Tensor &vSTensor,
391 mtu::TorchMergeTree<float> &interpolated,
392 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
393 mtu::TorchMergeTree<float> &tree2,
394 mtu::TorchMergeTree<float> &origin2,
395 torch::Tensor &vS2Tensor,
396 mtu::TorchMergeTree<float> &interpolated2,
397 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
398 torch::Tensor &alphasOut);
399
426 float assignmentOneData(
427 mtu::TorchMergeTree<float> &tree,
428 mtu::TorchMergeTree<float> &tree2,
429 unsigned int k,
430 torch::Tensor &alphasInit,
431 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
432 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
433 torch::Tensor &bestAlphas,
434 bool isCalled = false,
435 bool useInputBasis = true);
436
459 float assignmentOneData(mtu::TorchMergeTree<float> &tree,
460 mtu::TorchMergeTree<float> &tree2,
461 unsigned int k,
462 torch::Tensor &alphasInit,
463 torch::Tensor &bestAlphas,
464 bool isCalled = false,
465 bool useInputBasis = true);
466
479 void outputBasisReconstruction(torch::Tensor &alphas,
480 mtu::TorchMergeTree<float> &out,
481 mtu::TorchMergeTree<float> &out2,
482 bool activate = true,
483 bool train = false);
484
507 bool forward(mtu::TorchMergeTree<float> &tree,
508 mtu::TorchMergeTree<float> &tree2,
509 unsigned int k,
510 torch::Tensor &alphasInit,
511 mtu::TorchMergeTree<float> &out,
512 mtu::TorchMergeTree<float> &out2,
513 torch::Tensor &bestAlphas,
514 float &bestDistance,
515 bool train = false);
516
537 bool forward(mtu::TorchMergeTree<float> &tree,
538 mtu::TorchMergeTree<float> &tree2,
539 unsigned int k,
540 torch::Tensor &alphasInit,
541 mtu::TorchMergeTree<float> &out,
542 mtu::TorchMergeTree<float> &out2,
543 torch::Tensor &bestAlphas,
544 bool train = false);
545
546 // -----------------------------------------------------------------------
547 // --- Projection
548 // -----------------------------------------------------------------------
553 void projectionStep();
554
555 // -----------------------------------------------------------------------
556 // --- Utils
557 // -----------------------------------------------------------------------
558 void copyParams(mtu::TorchMergeTree<float> &origin,
559 mtu::TorchMergeTree<float> &originPrime,
560 torch::Tensor &vS,
561 torch::Tensor &vSPrime,
562 mtu::TorchMergeTree<float> &origin2,
563 mtu::TorchMergeTree<float> &origin2Prime,
564 torch::Tensor &vS2,
565 torch::Tensor &vS2Prime,
566 bool get);
567
576 void adjustNestingScalars(std::vector<float> &scalarsVector,
577 ftm::idNode node,
578 ftm::idNode refNode);
579
590 void
591 createBalancedBDT(std::vector<std::vector<ftm::idNode>> &parents,
592 std::vector<std::vector<ftm::idNode>> &children,
593 std::vector<float> &scalarsVector,
594 std::vector<std::vector<ftm::idNode>> &childrenFinal);
595
596 // -----------------------------------------------------------------------
597 // --- Testing
598 // -----------------------------------------------------------------------
599 bool isTreeHasBigValues(ftm::MergeTree<float> &mTree,
600 float threshold = 10000);
601#endif
602 }; // MergeTreeNeuralLayer class
603
604} // namespace ttk
// ftm::idNode (unsigned int): node index in vect_nodes_.
// ttk: TTK base package defining the standard types.