5using namespace torch::indexing;
13#ifdef TTK_ENABLE_TORCH
// Setter for the dropout option.
// NOTE(review): the body and closing brace are not present in this listing;
// presumably it stores the value in dropout_, like the neighbouring setters.
17void ttk::MergeTreeNeuralBase::setDropout(
const double dropout) {
// Stores the euclideanVectorsInit flag in euclideanVectorsInit_.
21void ttk::MergeTreeNeuralBase::setEuclideanVectorsInit(
22 const bool euclideanVectorsInit) {
23 euclideanVectorsInit_ = euclideanVectorsInit;
// Stores the randomAxesInit flag in randomAxesInit_.
26void ttk::MergeTreeNeuralBase::setRandomAxesInit(
const bool randomAxesInit) {
27 randomAxesInit_ = randomAxesInit;
// Stores the initBarycenterRandom flag in initBarycenterRandom_.
30void ttk::MergeTreeNeuralBase::setInitBarycenterRandom(
31 const bool initBarycenterRandom) {
32 initBarycenterRandom_ = initBarycenterRandom;
// Stores the initBarycenterOneIter flag in initBarycenterOneIter_.
35void ttk::MergeTreeNeuralBase::setInitBarycenterOneIter(
36 const bool initBarycenterOneIter) {
37 initBarycenterOneIter_ = initBarycenterOneIter;
// Stores the initOriginPrimeStructByCopy flag in initOriginPrimeStructByCopy_.
40void ttk::MergeTreeNeuralBase::setInitOriginPrimeStructByCopy(
41 const bool initOriginPrimeStructByCopy) {
42 initOriginPrimeStructByCopy_ = initOriginPrimeStructByCopy;
// Stores the initOriginPrimeValuesByCopy flag in initOriginPrimeValuesByCopy_.
45void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopy(
46 const bool initOriginPrimeValuesByCopy) {
47 initOriginPrimeValuesByCopy_ = initOriginPrimeValuesByCopy;
// Stores the given randomness amount in
// initOriginPrimeValuesByCopyRandomness_.
50void ttk::MergeTreeNeuralBase::setInitOriginPrimeValuesByCopyRandomness(
51 const double initOriginPrimeValuesByCopyRandomness) {
52 initOriginPrimeValuesByCopyRandomness_
53 = initOriginPrimeValuesByCopyRandomness;
// Setter for the activate flag.
// NOTE(review): the body and closing brace are not present in this listing;
// presumably it stores the value in activate_, like the neighbouring setters.
56void ttk::MergeTreeNeuralBase::setActivate(
const bool activate) {
// Stores the activation-function selector in activationFunction_; the
// value is later switched on by activation().
60void ttk::MergeTreeNeuralBase::setActivationFunction(
61 const unsigned int activationFunction) {
62 activationFunction_ = activationFunction;
// Setter for the useGpu flag.
// NOTE(review): the body and closing brace are not present in this listing;
// presumably it stores the value in useGpu_, like the neighbouring setters.
65void ttk::MergeTreeNeuralBase::setUseGpu(
const bool useGpu) {
// Stores the bigValuesThreshold value in bigValuesThreshold_.
69void ttk::MergeTreeNeuralBase::setBigValuesThreshold(
70 const float bigValuesThreshold) {
71 bigValuesThreshold_ = bigValuesThreshold;
// Applies the activation selected by activationFunction_ to `in` and keeps
// the result in `act`. The listing shows a LeakyReLU branch and a ReLU
// branch; the case labels, any default branch and the return statement are
// not present here.
// NOTE(review): a torch::nn module is constructed on every call; the
// functional forms (torch::leaky_relu / torch::relu) would avoid that.
// `in` is only read in the visible lines, so it could likely be taken by
// const reference — confirm against the header before changing.
77torch::Tensor ttk::MergeTreeNeuralBase::activation(torch::Tensor &in) {
79 switch(activationFunction_) {
81 act = torch::nn::LeakyReLU()(in);
85 act = torch::nn::ReLU()(in);
// Adjusts the node scalars of a merge tree so that, for every non-root node,
// its birth lies above its parent's birth by more than eps and its death lies
// below its parent's death by more than eps. Violating subtrees are shifted
// to parent +/- 2*eps.
// NOTE(review): `mTree`, `eps`, the lambda header of `shiftSubtree`, the
// queue seeding/pop lines and the final write-back of `scalars` are not
// present in this listing.
90void ttk::MergeTreeNeuralBase::fixTreePrecisionScalars(
95 ftm::idNode deathNodeParent, std::vector<float> &scalars,
96 bool invalidBirth,
bool invalidDeath) {
// shiftSubtree: breadth-first walk below the offending node.
97 std::queue<ftm::idNode> queue;
99 while(!queue.empty()) {
// NOTE(review): the pair is looked up for `node`, but children are taken
// from `nodeT` (presumably the dequeued node). So every iteration rewrites
// the same pair instead of walking the subtree; the outer BFS below appears
// to re-check descendants one by one. Confirm which behaviour is intended.
102 auto birthDeathNode = mTree.tree.getBirthDeathNode<
float>(node);
103 auto birthNode = std::get<0>(birthDeathNode);
104 auto deathNode = std::get<1>(birthDeathNode);
// Presumably guarded by invalidBirth / invalidDeath on lines not captured.
106 scalars[birthNode] = scalars[birthNodeParent] + 2 * eps;
108 scalars[deathNode] = scalars[deathNodeParent] - 2 * eps;
109 std::vector<ftm::idNode> children;
110 mTree.tree.getChildren(nodeT, children);
111 for(
auto &child : children)
112 queue.emplace(child);
// Main pass: `scalars` is filled from the tree on lines not in this listing.
115 std::vector<float> scalars;
117 std::queue<ftm::idNode> queue;
118 auto root = mTree.tree.getRoot();
// Breadth-first pass: compare each node's pair with its parent's pair.
120 while(!queue.empty()) {
123 auto birthDeathNode = mTree.tree.getBirthDeathNode<
float>(node);
124 auto birthNode = std::get<0>(birthDeathNode);
125 auto deathNode = std::get<1>(birthDeathNode);
126 auto birthDeathNodeParent
127 = mTree.tree.getBirthDeathNode<
float>(mTree.tree.getParentSafe(node));
128 auto birthNodeParent = std::get<0>(birthDeathNodeParent);
129 auto deathNodeParent = std::get<1>(birthDeathNodeParent);
// A child pair must sit strictly inside its parent's pair with an eps margin.
130 bool invalidBirth = (scalars[birthNode] <= scalars[birthNodeParent] + eps);
131 bool invalidDeath = (scalars[deathNode] >= scalars[deathNodeParent] - eps);
// The root has no parent pair to compare against, so it is never shifted.
132 if(!mTree.tree.isRoot(node) and (invalidBirth or invalidDeath))
133 shiftSubtree(node, birthNodeParent, deathNodeParent, scalars,
134 invalidBirth, invalidDeath);
135 std::vector<ftm::idNode> children;
136 mTree.tree.getChildren(node, children);
137 for(
auto &child : children)
138 queue.emplace(child);
// Body of a pair-printing routine (its signature is not in this listing;
// presumably printPairs(mTree, useBD)). It builds a textual dump of `mTree`
// in `ss` and writes it to std::cout.
145 std::stringstream ss;
// Non-empty trees start from printPairsFromTree's own formatting.
146 if(mTree.tree.getRealNumberOfNodes() != 0)
147 ss = mTree.tree.template printPairsFromTree<float>(useBD);
// Marks nodes that are already printed as someone's origin; the check that
// reads it is not present in this listing.
149 std::vector<bool> nodeDone(mTree.tree.getNumberOfNodes(),
false);
// Emits one line per node: "id (value) _ origin (value) _ persistence".
150 for(
unsigned int i = 0; i < mTree.tree.getNumberOfNodes(); ++i) {
153 std::tuple<ftm::idNode, ftm::idNode, float> pair
154 = std::make_tuple(i, mTree.tree.getNode(i)->getOrigin(),
155 mTree.tree.getNodePersistence<
float>(i));
156 ss << std::get<0>(pair) <<
" ("
157 << mTree.tree.getValue<
float>(std::get<0>(pair)) <<
") _ ";
158 ss << std::get<1>(pair) <<
" ("
159 << mTree.tree.getValue<
float>(std::get<1>(pair)) <<
") _ ";
// NOTE(review): std::endl flushes the stringstream; '\n' would suffice.
160 ss << std::get<2>(pair) << std::endl;
162 nodeDone[mTree.tree.getNode(i)->getOrigin()] =
true;
166 std::cout << ss.str();
// Computes pairwise distances between the trees in `tmts` into
// `distanceMatrix`. The matrix is cleared and reset to a
// tmts.size() x tmts.size() grid of zeros. Each pair i < j is then evaluated
// with computeOneDistance, as an OpenMP task when OpenMP is enabled.
// NOTE(review): the remaining parameter line (useDoubleInput, isFirstInput,
// judging by the call below) and the lines that store `distance` into the
// matrix are not present in this listing.
172void ttk::MergeTreeNeuralBase::getDistanceMatrix(
173 const std::vector<mtu::TorchMergeTree<float>> &tmts,
174 std::vector<std::vector<float>> &distanceMatrix,
177 distanceMatrix.clear();
178 distanceMatrix.resize(tmts.size(), std::vector<float>(tmts.size(), 0));
179#ifdef TTK_ENABLE_OPENMP
180#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
181 shared(distanceMatrix, tmts)
183#pragma omp single nowait
186 for(
unsigned int i = 0; i < tmts.size(); ++i) {
187 for(
unsigned int j = i + 1; j < tmts.size(); ++j) {
188#ifdef TTK_ENABLE_OPENMP
189#pragma omp task UNTIED() shared(distanceMatrix, tmts) firstprivate(i, j)
// Per-task locals: each pair gets its own matching buffer.
192 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
194 bool isCalled =
true;
195 computeOneDistance(tmts[i].mTree, tmts[j].mTree, matching, distance,
196 isCalled, useDoubleInput, isFirstInput);
200#ifdef TTK_ENABLE_OPENMP
205#ifdef TTK_ENABLE_OPENMP
// Distance matrix for the (optionally two-input) setting. It first computes
// the matrix for `tmts`. When useDoubleInput_ is set, it also computes one for
// `tmts2`, passing `false` as the trailing argument (which the single-list
// overload appears to receive as isFirstInput), and merges the two with
// mixDistancesMatrix.
// NOTE(review): the closing braces are not present in this listing.
212void ttk::MergeTreeNeuralBase::getDistanceMatrix(
213 const std::vector<mtu::TorchMergeTree<float>> &tmts,
214 const std::vector<mtu::TorchMergeTree<float>> &tmts2,
215 std::vector<std::vector<float>> &distanceMatrix) {
216 getDistanceMatrix(tmts, distanceMatrix, useDoubleInput_);
217 if(useDoubleInput_) {
218 std::vector<std::vector<float>> distanceMatrix2;
219 getDistanceMatrix(tmts2, distanceMatrix2, useDoubleInput_,
false);
220 mixDistancesMatrix<float>(distanceMatrix, distanceMatrix2);
// Builds a differentiable distance from precomputed matchings. The tensors of
// tree1/tree2 are reordered according to `matchings`. When useDoubleInput_ is
// set, the reordered tensors of tree1_2/tree2_2 (via `matchings2`) are
// concatenated. The result is the sum of squared differences, followed by a
// square root.
// NOTE(review): the last parameter line and the condition guarding the
// sqrt (presumably doSqrt, cf. the caller passing doSqrt) are not present in
// this listing.
224void ttk::MergeTreeNeuralBase::getDifferentiableDistanceFromMatchings(
225 const mtu::TorchMergeTree<float> &tree1,
226 const mtu::TorchMergeTree<float> &tree2,
227 const mtu::TorchMergeTree<float> &tree1_2,
228 const mtu::TorchMergeTree<float> &tree2_2,
229 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
230 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
231 torch::Tensor &tensorDist,
233 torch::Tensor reorderedITensor, reorderedJTensor;
234 dataReorderingGivenMatching(
235 tree1, tree2, matchings, reorderedITensor, reorderedJTensor);
236 if(useDoubleInput_) {
237 torch::Tensor reorderedI2Tensor, reorderedJ2Tensor;
238 dataReorderingGivenMatching(
239 tree1_2, tree2_2, matchings2, reorderedI2Tensor, reorderedJ2Tensor);
240 reorderedITensor = torch::cat({reorderedITensor, reorderedI2Tensor});
241 reorderedJTensor = torch::cat({reorderedJTensor, reorderedJ2Tensor});
243 tensorDist = (reorderedITensor - reorderedJTensor).
pow(2).sum();
// NOTE(review): the derivative of sqrt at 0 is infinite. If the reordered
// tensors are identical, backprop through this line gives NaN gradients
// (inf * 0). Consider clamping or adding a small epsilon if identical inputs
// can occur — confirm.
245 tensorDist = tensorDist.sqrt();
// Computes matchings between tree1 and tree2 with computeOneDistance. When
// useDoubleInput_ is set, it also matches tree1_2 and tree2_2, with `false`
// as the trailing argument. It then builds the differentiable distance from
// those matchings.
// NOTE(review): the trailing parameter line(s) (isCalled, doSqrt, judging by
// their use below) and the declarations of distance/matchings2/distance2 are
// not present in this listing.
248void ttk::MergeTreeNeuralBase::getDifferentiableDistance(
249 const mtu::TorchMergeTree<float> &tree1,
250 const mtu::TorchMergeTree<float> &tree2,
251 const mtu::TorchMergeTree<float> &tree1_2,
252 const mtu::TorchMergeTree<float> &tree2_2,
253 torch::Tensor &tensorDist,
256 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matchings,
259 computeOneDistance<float>(
260 tree1.mTree, tree2.mTree, matchings, distance, isCalled, useDoubleInput_);
261 if(useDoubleInput_) {
263 computeOneDistance<float>(tree1_2.mTree, tree2_2.mTree, matchings2,
264 distance2, isCalled, useDoubleInput_,
false);
266 getDifferentiableDistanceFromMatchings(
267 tree1, tree2, tree1_2, tree2_2, matchings, matchings2, tensorDist, doSqrt);
// Single-input convenience overload. It forwards to the four-tree overload
// with default-constructed placeholder trees for the second input.
// NOTE(review): if useDoubleInput_ is true, the forwarded call would match
// these empty placeholders. Confirm callers only use this overload when
// useDoubleInput_ is false. The isCalled/doSqrt parameter line is not
// present in this listing.
270void ttk::MergeTreeNeuralBase::getDifferentiableDistance(
271 const mtu::TorchMergeTree<float> &tree1,
272 const mtu::TorchMergeTree<float> &tree2,
273 torch::Tensor &tensorDist,
276 mtu::TorchMergeTree<float> tree1_2, tree2_2;
277 getDifferentiableDistance(
278 tree1, tree2, tree1_2, tree2_2, tensorDist, isCalled, doSqrt);
// Fills `outDistMat` with differentiable pairwise distances between `trees`,
// with `trees2` as the second input. Each pair i < j is computed once, as an
// OpenMP task when OpenMP is enabled, and stored symmetrically.
// NOTE(review): unlike getDistanceMatrix, there is no clear() before
// resize(). If `outDistMat` already holds rows, they keep their old length,
// and outDistMat[i][j] can index past a shorter row. Clearing first would
// match getDistanceMatrix.
// NOTE(review): *(trees2[i]) is dereferenced unconditionally, so `trees2`
// must hold at least trees.size() valid pointers even when the second input
// is unused.
281void ttk::MergeTreeNeuralBase::getDifferentiableDistanceMatrix(
282 const std::vector<mtu::TorchMergeTree<float> *> &trees,
283 const std::vector<mtu::TorchMergeTree<float> *> &trees2,
284 std::vector<std::vector<torch::Tensor>> &outDistMat) {
285 outDistMat.resize(trees.size(), std::vector<torch::Tensor>(trees.size()));
286#ifdef TTK_ENABLE_OPENMP
287#pragma omp parallel num_threads(this->threadNumber_) if(parallelize_) \
288 shared(trees, trees2, outDistMat)
290#pragma omp single nowait
293 for(
unsigned int i = 0; i < trees.size(); ++i) {
// NOTE(review): torch::tensor(0) is integer-typed, while off-diagonal entries
// come from float computations. This may matter if entries are later
// stacked or combined — confirm.
294 outDistMat[i][i] = torch::tensor(0);
295 for(
unsigned int j = i + 1; j < trees.size(); ++j) {
296#ifdef TTK_ENABLE_OPENMP
297#pragma omp task UNTIED() shared(trees, trees2, outDistMat) firstprivate(i, j)
300 bool isCalled =
true;
302 torch::Tensor tensorDist;
303 getDifferentiableDistance(*(trees[i]), *(trees[j]), *(trees2[i]),
304 *(trees2[j]), tensorDist, isCalled,
// The distance is symmetric; the same tensor is stored in both cells.
306 outDistMat[i][j] = tensorDist;
307 outDistMat[j][i] = tensorDist;
308#ifdef TTK_ENABLE_OPENMP
313#ifdef TTK_ENABLE_OPENMP
void setDebugMsgPrefix(const std::string &prefix)
void printPairs(std::vector< std::tuple< SimplexId, SimplexId, dataType > > &treePairs)
T1 pow(const T1 val, const T2 n)
T distance(const T *p0, const T *p1, const int &dimension=3)
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
void getTreeScalars(const ftm::FTMTree_MT *tree, std::vector< dataType > &scalarsVector)
unsigned int idNode
Node index in vect_nodes_.