TTK
Loading...
Searching...
No Matches
MergeTreeNeuralBase.h
Go to the documentation of this file.
1
15
16#pragma once
17
18// ttk common includes
19#include <Debug.h>
20#include <Geometry.h>
22#include <MergeTreeTorchUtils.h>
23
24#ifdef TTK_ENABLE_TORCH
25#include <torch/torch.h>
26#endif
27
28namespace ttk {
29
// Declares model hyper-parameter members and, when TTK_ENABLE_TORCH is
// defined, their setters plus activation, precision-fixing, printing and
// distance helpers working on merge trees. Declarations only: no function
// bodies appear here.
35 class MergeTreeNeuralBase : virtual public Debug,
37 // NOTE(review): the rest of the base-class list and the opening brace are not visible in this listing.
38 protected:
39 // ======== Model hyper-parameters
40 // Dropout to use when training.
41 double dropout_ = 0.0;
42 // If the vectors should be initialized using euclidean distance (between
43 // vectors representing the topological abstractions ordered given the
44 // assignment to the barycenter), faster but less accurate than using
45 // Wasserstein distance.
47 // If the vectors should be initialized randomly.
48 bool randomAxesInit_ = false;
49 // When computing the origin of the input basis, if the barycenter algorithm
50 // should be initialized randomly (instead of with the topological
51 // representation minimizing the distance to the set), faster but less accurate.
53 // When computing the origin of the input basis, if the barycenter algorithm
54 // should run for only one iteration, faster but less accurate.
56 // If the structure of the origin of the output basis should be initialized
57 // by copying the structure of the input basis.
59 // If the scalar values of the origin of the output basis should be
60 // initialized by copying the values of the input basis.
62 // Value between 0 and 1 allowing to add some randomness to the values of
63 // the origin of the output basis when initOriginPrimeValuesByCopy_ is set
64 // to true.
66 // If activation functions should be used.
67 bool activate_ = true;
68 // Choice of the activation function
69 // 0 : ReLU
70 // 1 : Leaky ReLU
71 unsigned int activationFunction_ = 1;
72
73 bool useGpu_ = false; // presumably toggles GPU execution — confirm in the implementation
74
75 // ======== Testing
77
78 public:
80
81#ifdef TTK_ENABLE_TORCH
82 // -----------------------------------------------------------------------
83 // --- Setter (declarations only; everything up to the matching #endif is
83 // --- compiled only when TTK_ENABLE_TORCH is defined)
84 // -----------------------------------------------------------------------
85 void setDropout(const double dropout);
86
87 void setEuclideanVectorsInit(const bool euclideanVectorsInit);
88
89 void setRandomAxesInit(const bool randomAxesInit);
90
91 void setInitBarycenterRandom(const bool initBarycenterRandom);
92
93 void setInitBarycenterOneIter(const bool initBarycenterOneIter);
94
95 void setInitOriginPrimeStructByCopy(const bool initOriginPrimeStructByCopy);
96
97 void setInitOriginPrimeValuesByCopy(const bool initOriginPrimeValuesByCopy);
98
99 void setInitOriginPrimeValuesByCopyRandomness(
100 const double initOriginPrimeValuesByCopyRandomness);
101
102 void setActivate(const bool activate);
103
104 void setActivationFunction(const unsigned int activationFunction);
105
106 void setUseGpu(const bool useGpu);
107
108 void setBigValuesThreshold(const float bigValuesThreshold); // NOTE(review): no matching member is visible in this listing — confirm where the threshold is stored
109
110 // -----------------------------------------------------------------------
111 // --- Utils
112 // -----------------------------------------------------------------------
113 torch::Tensor activation(torch::Tensor &in); // presumably applies the function selected by activationFunction_ — confirm in the implementation
114
121 void fixTreePrecisionScalars(ftm::MergeTree<float> &mTree);
122
130 void printPairs(const ftm::MergeTree<float> &mTree, bool useBD = true);
131
132 // -----------------------------------------------------------------------
133 // --- Distance
134 // -----------------------------------------------------------------------
135 void getDistanceMatrix(const std::vector<mtu::TorchMergeTree<float>> &tmts,
136 std::vector<std::vector<float>> &distanceMatrix, // non-const reference: presumably the output — confirm
137 bool useDoubleInput = false,
138 bool isFirstInput = true);
139
140 void getDistanceMatrix(const std::vector<mtu::TorchMergeTree<float>> &tmts,
141 const std::vector<mtu::TorchMergeTree<float>> &tmts2,
142 std::vector<std::vector<float>> &distanceMatrix);
143
144 void getDifferentiableDistanceFromMatchings(
145 const mtu::TorchMergeTree<float> &tree1,
146 const mtu::TorchMergeTree<float> &tree2,
147 const mtu::TorchMergeTree<float> &tree1_2,
148 const mtu::TorchMergeTree<float> &tree2_2,
149 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
150 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
151 torch::Tensor &tensorDist, // non-const reference: presumably receives the resulting distance — confirm
152 bool doSqrt);
153
154 void getDifferentiableDistance(const mtu::TorchMergeTree<float> &tree1,
155 const mtu::TorchMergeTree<float> &tree2,
156 const mtu::TorchMergeTree<float> &tree1_2,
157 const mtu::TorchMergeTree<float> &tree2_2,
158 torch::Tensor &tensorDist,
159 bool isCalled,
160 bool doSqrt);
161
162 void getDifferentiableDistance(const mtu::TorchMergeTree<float> &tree1,
163 const mtu::TorchMergeTree<float> &tree2,
164 torch::Tensor &tensorDist,
165 bool isCalled,
166 bool doSqrt);
167
168 void getDifferentiableDistanceMatrix(
169 const std::vector<mtu::TorchMergeTree<float> *> &trees,
170 const std::vector<mtu::TorchMergeTree<float> *> &trees2,
171 std::vector<std::vector<torch::Tensor>> &outDistMat); // non-const reference: presumably the output matrix — confirm
172#endif
173 }; // MergeTreeNeuralBase class
174
175} // namespace ttk
void printPairs(std::vector< std::tuple< SimplexId, SimplexId, dataType > > &treePairs)
TTK base package defining the standard types.