TTK
Loading...
Searching...
No Matches
MergeTreeNeuralBase.cpp
Go to the documentation of this file.
2#include <cmath>
3
4#ifdef TTK_ENABLE_TORCH
5using namespace torch::indexing;
6#endif
7
9 // inherited from Debug: prefix will be printed at the beginning of every msg
10 this->setDebugMsgPrefix("MergeTreeNeuralBase");
11}
12
13#ifdef TTK_ENABLE_TORCH
14// -----------------------------------------------------------------------
15// --- Setter
16// -----------------------------------------------------------------------
17void ttk::MergeTreeNeuralBase::setDropout(const double dropout) {
18 dropout_ = dropout;
19}
20
21void ttk::MergeTreeNeuralBase::setEuclideanVectorsInit(
22 const bool euclideanVectorsInit) {
23 euclideanVectorsInit_ = euclideanVectorsInit;
24}
25
26void ttk::MergeTreeNeuralBase::setRandomAxesInit(const bool randomAxesInit) {
27 randomAxesInit_ = randomAxesInit;
28}
29
30void ttk::MergeTreeNeuralBase::setInitBarycenterRandom(
31 const bool initBarycenterRandom) {
32 initBarycenterRandom_ = initBarycenterRandom;
33}
34
35void ttk::MergeTreeNeuralBase::setInitBarycenterOneIter(
36 const bool initBarycenterOneIter) {
37 initBarycenterOneIter_ = initBarycenterOneIter;
38}
39
40void ttk::MergeTreeNeuralBase::setInitOriginPrimeStructByCopy(
41 const bool initOriginPrimeStructByCopy) {
42 initOriginPrimeStructByCopy_ = initOriginPrimeStructByCopy;
43}
44
45void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopy(
46 const bool initOriginPrimeValuesByCopy) {
47 initOriginPrimeValuesByCopy_ = initOriginPrimeValuesByCopy;
48}
49
50void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopyRandomness(
51 const double initOriginPrimeValuesByCopyRandomness) {
52 initOriginPrimeValuesByCopyRandomness_
53 = initOriginPrimeValuesByCopyRandomness;
54}
55
56void ttk::MergeTreeNeuralBase::setActivate(const bool activate) {
57 activate_ = activate;
58}
59
60void ttk::MergeTreeNeuralBase::setActivationFunction(
61 const unsigned int activationFunction) {
62 activationFunction_ = activationFunction;
63}
64
65void ttk::MergeTreeNeuralBase::setUseGpu(const bool useGpu) {
66 useGpu_ = useGpu;
67}
68
69void ttk::MergeTreeNeuralBase::setBigValuesThreshold(
70 const float bigValuesThreshold) {
71 bigValuesThreshold_ = bigValuesThreshold;
72}
73
74// -----------------------------------------------------------------------
75// --- Utils
76// -----------------------------------------------------------------------
77torch::Tensor ttk::MergeTreeNeuralBase::activation(torch::Tensor &in) {
78 torch::Tensor act;
79 switch(activationFunction_) {
80 case 1:
81 act = torch::nn::LeakyReLU()(in);
82 break;
83 case 0:
84 default:
85 act = torch::nn::ReLU()(in);
86 }
87 return act;
88}
89
90void ttk::MergeTreeNeuralBase::fixTreePrecisionScalars(
91 ftm::MergeTree<float> &mTree) {
92 double eps = 1e-6;
93 auto shiftSubtree
94 = [&mTree, &eps](ftm::idNode node, ftm::idNode birthNodeParent,
95 ftm::idNode deathNodeParent, std::vector<float> &scalars,
96 bool invalidBirth, bool invalidDeath) {
97 std::queue<ftm::idNode> queue;
98 queue.emplace(node);
99 while(!queue.empty()) {
100 ftm::idNode nodeT = queue.front();
101 queue.pop();
102 auto birthDeathNode = mTree.tree.getBirthDeathNode<float>(node);
103 auto birthNode = std::get<0>(birthDeathNode);
104 auto deathNode = std::get<1>(birthDeathNode);
105 if(invalidBirth)
106 scalars[birthNode] = scalars[birthNodeParent] + 2 * eps;
107 if(invalidDeath)
108 scalars[deathNode] = scalars[deathNodeParent] - 2 * eps;
109 std::vector<ftm::idNode> children;
110 mTree.tree.getChildren(nodeT, children);
111 for(auto &child : children)
112 queue.emplace(child);
113 }
114 };
115 std::vector<float> scalars;
116 getTreeScalars(mTree, scalars);
117 std::queue<ftm::idNode> queue;
118 auto root = mTree.tree.getRoot();
119 queue.emplace(root);
120 while(!queue.empty()) {
121 ftm::idNode node = queue.front();
122 queue.pop();
123 auto birthDeathNode = mTree.tree.getBirthDeathNode<float>(node);
124 auto birthNode = std::get<0>(birthDeathNode);
125 auto deathNode = std::get<1>(birthDeathNode);
126 auto birthDeathNodeParent
127 = mTree.tree.getBirthDeathNode<float>(mTree.tree.getParentSafe(node));
128 auto birthNodeParent = std::get<0>(birthDeathNodeParent);
129 auto deathNodeParent = std::get<1>(birthDeathNodeParent);
130 bool invalidBirth = (scalars[birthNode] <= scalars[birthNodeParent] + eps);
131 bool invalidDeath = (scalars[deathNode] >= scalars[deathNodeParent] - eps);
132 if(!mTree.tree.isRoot(node) and (invalidBirth or invalidDeath))
133 shiftSubtree(node, birthNodeParent, deathNodeParent, scalars,
134 invalidBirth, invalidDeath);
135 std::vector<ftm::idNode> children;
136 mTree.tree.getChildren(node, children);
137 for(auto &child : children)
138 queue.emplace(child);
139 }
140 ftm::setTreeScalars<float>(mTree, scalars);
141}
142
144 bool useBD) {
145 std::stringstream ss;
146 if(mTree.tree.getRealNumberOfNodes() != 0)
147 ss = mTree.tree.template printPairsFromTree<float>(useBD);
148 else {
149 std::vector<bool> nodeDone(mTree.tree.getNumberOfNodes(), false);
150 for(unsigned int i = 0; i < mTree.tree.getNumberOfNodes(); ++i) {
151 if(nodeDone[i])
152 continue;
153 std::tuple<ftm::idNode, ftm::idNode, float> pair
154 = std::make_tuple(i, mTree.tree.getNode(i)->getOrigin(),
155 mTree.tree.getNodePersistence<float>(i));
156 ss << std::get<0>(pair) << " ("
157 << mTree.tree.getValue<float>(std::get<0>(pair)) << ") _ ";
158 ss << std::get<1>(pair) << " ("
159 << mTree.tree.getValue<float>(std::get<1>(pair)) << ") _ ";
160 ss << std::get<2>(pair) << std::endl;
161 nodeDone[i] = true;
162 nodeDone[mTree.tree.getNode(i)->getOrigin()] = true;
163 }
164 }
165 ss << std::endl;
166 std::cout << ss.str();
167}
168
169// -----------------------------------------------------------------------
170// --- Distance
171// -----------------------------------------------------------------------
172void ttk::MergeTreeNeuralBase::getDistanceMatrix(
173 const std::vector<mtu::TorchMergeTree<float>> &tmts,
174 std::vector<std::vector<float>> &distanceMatrix,
175 bool useDoubleInput,
176 bool isFirstInput) {
177 distanceMatrix.clear();
178 distanceMatrix.resize(tmts.size(), std::vector<float>(tmts.size(), 0));
179#ifdef TTK_ENABLE_OPENMP
180#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
181 shared(distanceMatrix, tmts)
182 {
183#pragma omp single nowait
184 {
185#endif
186 for(unsigned int i = 0; i < tmts.size(); ++i) {
187 for(unsigned int j = i + 1; j < tmts.size(); ++j) {
188#ifdef TTK_ENABLE_OPENMP
189#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j)
190 {
191#endif
192 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
193 float distance;
194 bool isCalled = true;
195 computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance,
196 isCalled, useDoubleInput, isFirstInput);
198 distanceMatrix[i][j] = distance;
199 distanceMatrix[j][i] = distance;
200#ifdef TTK_ENABLE_OPENMP
201 } // pragma omp task
202#endif
203 }
204 }
205#ifdef TTK_ENABLE_OPENMP
206#pragma omp taskwait
207 } // pragma omp single nowait
208 } // pragma omp parallel
209#endif
210}
211
212void ttk::MergeTreeNeuralBase::getDistanceMatrix(
213 const std::vector<mtu::TorchMergeTree<float>> &tmts,
214 const std::vector<mtu::TorchMergeTree<float>> &tmts2,
215 std::vector<std::vector<float>> &distanceMatrix) {
216 getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_);
217 if(useDoubleInput_) {
218 std::vector<std::vector<float>> distanceMatrix2;
219 getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_, false);
220 mixDistancesMatrix<float>(distanceMatrix, distanceMatrix2);
221 }
222}
223
224void ttk::MergeTreeNeuralBase::getDifferentiableDistanceFromMatchings(
225 const mtu::TorchMergeTree<float> &tree1,
226 const mtu::TorchMergeTree<float> &tree2,
227 const mtu::TorchMergeTree<float> &tree1_2,
228 const mtu::TorchMergeTree<float> &tree2_2,
229 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
230 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
231 torch::Tensor &tensorDist,
232 bool doSqrt) {
233 torch::Tensor reorderedITensor, reorderedJTensor;
234 dataReorderingGivenMatching(
235 tree1, tree2, matchings, reorderedITensor, reorderedJTensor);
236 if(useDoubleInput_) {
237 torch::Tensor reorderedI2Tensor, reorderedJ2Tensor;
238 dataReorderingGivenMatching(
239 tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor);
240 reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor});
241 reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor});
242 }
243 tensorDist = (reorderedITensor - reorderedJTensor).pow(2).sum();
244 if(doSqrt)
245 tensorDist = tensorDist.sqrt();
246}
247
248void ttk::MergeTreeNeuralBase::getDifferentiableDistance(
249 const mtu::TorchMergeTree<float> &tree1,
250 const mtu::TorchMergeTree<float> &tree2,
251 const mtu::TorchMergeTree<float> &tree1_2,
252 const mtu::TorchMergeTree<float> &tree2_2,
253 torch::Tensor &tensorDist,
254 bool isCalled,
255 bool doSqrt) {
256 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matchings,
257 matchings2;
258 float distance;
259 computeOneDistance<float>(
260 tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_);
261 if(useDoubleInput_) {
262 float distance2;
263 computeOneDistance<float>(tree1_2.mTree, tree2_2.mTree, matchings2,
264 distance2, isCalled, useDoubleInput_, false);
265 }
266 getDifferentiableDistanceFromMatchings(
267 tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt);
268}
269
270void ttk::MergeTreeNeuralBase::getDifferentiableDistance(
271 const mtu::TorchMergeTree<float> &tree1,
272 const mtu::TorchMergeTree<float> &tree2,
273 torch::Tensor &tensorDist,
274 bool isCalled,
275 bool doSqrt) {
276 mtu::TorchMergeTree<float> tree1_2, tree2_2;
277 getDifferentiableDistance(
278 tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt);
279}
280
281void ttk::MergeTreeNeuralBase::getDifferentiableDistanceMatrix(
282 const std::vector<mtu::TorchMergeTree<float> *> &trees,
283 const std::vector<mtu::TorchMergeTree<float> *> &trees2,
284 std::vector<std::vector<torch::Tensor>> &outDistMat) {
285 outDistMat.resize(trees.size(), std::vector<torch::Tensor>(trees.size()));
286#ifdef TTK_ENABLE_OPENMP
287#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
288 shared(trees, trees2, outDistMat)
289 {
290#pragma omp single nowait
291 {
292#endif
293 for(unsigned int i = 0; i < trees.size(); ++i) {
294 outDistMat[i][i] = torch::tensor(0);
295 for(unsigned int j = i + 1; j < trees.size(); ++j) {
296#ifdef TTK_ENABLE_OPENMP
297#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j)
298 {
299#endif
300 bool isCalled = true;
301 bool doSqrt = false;
302 torch::Tensor tensorDist;
303 getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]),
304 *(trees2[j]), tensorDist, isCalled,
305 doSqrt);
306 outDistMat[i][j] = tensorDist;
307 outDistMat[j][i] = tensorDist;
308#ifdef TTK_ENABLE_OPENMP
309 } // pragma omp task
310#endif
311 }
312 }
313#ifdef TTK_ENABLE_OPENMP
314#pragma omp taskwait
315 } // pragma omp single nowait
316 } // pragma omp parallel
317#endif
318}
319#endif
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
void printPairs(std::vector< std::tuple< SimplexId, SimplexId, dataType > > &treePairs)
T1 pow(const T1 val, const T2 n)
Definition Geometry.h:456
T distance(const T *p0, const T *p1, const int &dimension=3)
Definition Geometry.cpp:362
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
void getTreeScalars(const ftm::FTMTree_MT *tree, std::vector< dataType > &scalarsVector)
unsigned int idNode
Node index in vect_nodes_.