TTK
Loading...
Searching...
No Matches
MergeTreeTorchUtils.h
Go to the documentation of this file.
1
7
8#pragma once
9
10#include <Debug.h>
11#include <FTMTreeUtils.h>
12#include <FTMTree_MT.h>
13#include <Geometry.h>
14#include <MergeTreeUtils.h>
15
16#ifdef TTK_ENABLE_TORCH
17#include <torch/torch.h>
18#endif
19
20namespace ttk {
21
22 namespace mtu {
23
24#ifdef TTK_ENABLE_TORCH
31 void copyTensor(torch::Tensor &a, torch::Tensor &b);
32
33 template <typename dataType>
34 struct TorchMergeTree {
36 torch::Tensor tensor;
37 std::vector<unsigned int> nodeCorr;
38 std::vector<ftm::idNode> parentsOri;
39 };
40
47 void getDeltaProjTensor(torch::Tensor &diagTensor,
48 torch::Tensor &deltaProjTensor);
49
68 void dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
69 mtu::TorchMergeTree<float> &tree2,
70 torch::Tensor &tree1ProjIndexer,
71 torch::Tensor &tree2ReorderingIndexes,
72 torch::Tensor &tree2ReorderedTensor,
73 torch::Tensor &tree2DeltaProjTensor,
74 torch::Tensor &tree1ReorderedTensor,
75 torch::Tensor &tree2ProjIndexer,
76 bool doubleReordering = true);
77
92 void dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
93 mtu::TorchMergeTree<float> &tree2,
94 torch::Tensor &tree1ProjIndexer,
95 torch::Tensor &tree2ReorderingIndexes,
96 torch::Tensor &tree2ReorderedTensor,
97 torch::Tensor &tree2DeltaProjTensor);
98
109 void dataReorderingGivenMatching(
110 mtu::TorchMergeTree<float> &tree,
111 mtu::TorchMergeTree<float> &tree2,
112 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
113 torch::Tensor &tree1ReorderedTensor,
114 torch::Tensor &tree2ReorderedTensor,
115 bool doubleReordering = true);
116
125 void dataReorderingGivenMatching(
126 mtu::TorchMergeTree<float> &tree,
127 mtu::TorchMergeTree<float> &tree2,
128 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
129 torch::Tensor &tree2ReorderedTensor);
130
137 void meanBirthShift(torch::Tensor &diagTensor,
138 torch::Tensor &diagBaseTensor);
139
147 void meanBirthMaxPersShift(torch::Tensor &tensor,
148 torch::Tensor &baseTensor);
149
157 void belowDiagonalPointsShift(torch::Tensor &tensor,
158 torch::Tensor &backupTensor);
159
166 void normalizeVectors(torch::Tensor &originTensor,
167 torch::Tensor &vectorsTensor);
168
175 void normalizeVectors(mtu::TorchMergeTree<float> &origin,
176 std::vector<std::vector<double>> &vectors);
177
/**
 * @brief Returns the index of the latent layer.
 * NOTE(review): defined outside this header; which network/layer list the
 * index refers to is not visible here — confirm in the implementation.
 */
unsigned int
  getLatentLayerIndex();
184
190 bool isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation);
191
198 template <class dataType>
199 void copyTorchMergeTree(TorchMergeTree<dataType> &tmTree,
200 TorchMergeTree<dataType> &out) {
201 out.mTree = ftm::copyMergeTree<dataType>(tmTree.mTree);
202 copyTensor(tmTree.tensor, out.tensor);
203 out.nodeCorr = tmTree.nodeCorr;
204 out.parentsOri = tmTree.parentsOri;
205 }
206
219 template <class dataType>
220 void mergeTreeToTorchTensor(ftm::MergeTree<dataType> &mTree,
221 torch::Tensor &tensor,
222 std::vector<unsigned int> &nodeCorr,
223 bool normalize,
224 unsigned int *revNodeCorr = nullptr,
225 unsigned int revNodeCorrSize = 0) {
226 nodeCorr.clear();
227 nodeCorr.assign(mTree.tree.getNumberOfNodes(),
228 std::numeric_limits<unsigned int>::max());
229 std::vector<dataType> scalars;
230 std::queue<ftm::idNode> queue;
231 if(revNodeCorrSize == 0)
232 queue.emplace(mTree.tree.getRoot());
233 else {
234 for(unsigned int i = 0; i < revNodeCorrSize; i += 2)
235 queue.emplace(revNodeCorr[i]);
236 }
237 unsigned int cpt = 0;
238 while(!queue.empty()) {
239 ftm::idNode node = queue.front();
240 queue.pop();
241
242 auto birthDeath = ttk::getParametrizedBirthDeath<dataType>(
243 &(mTree.tree), node, normalize);
244 auto birth = std::get<0>(birthDeath);
245 auto death = std::get<1>(birthDeath);
246 scalars.emplace_back(birth);
247 scalars.emplace_back(death);
248
249 nodeCorr[node] = cpt;
250 ++cpt;
251
252 if(revNodeCorrSize == 0) {
253 std::vector<ftm::idNode> children;
254 mTree.tree.getChildren(node, children);
255 for(auto child : children)
256 queue.emplace(child);
257 }
258 }
259 tensor = torch::tensor(scalars).reshape({-1, 1});
260 }
261
269 template <class dataType>
270 void mergeTreeToTorchTensor(ftm::MergeTree<dataType> &mTree,
271 torch::Tensor &tensor,
272 bool normalize) {
273 std::vector<unsigned int> nodeCorr;
274 mergeTreeToTorchTensor<dataType>(mTree, tensor, nodeCorr, normalize);
275 }
276
283 template <class dataType>
284 void getParentsVector(ftm::MergeTree<dataType> &mTree,
285 std::vector<ftm::idNode> &parents) {
286 parents.resize(mTree.tree.getNumberOfNodes());
287 std::fill(parents.begin(), parents.end(),
288 std::numeric_limits<ftm::idNode>::max());
289 std::queue<ftm::idNode> queue;
290 queue.emplace(mTree.tree.getRoot());
291 while(!queue.empty()) {
292 ftm::idNode node = queue.front();
293 queue.pop();
294 if(!mTree.tree.isRoot(node))
295 parents[node] = mTree.tree.getParentSafe(node);
296 std::vector<ftm::idNode> children;
297 mTree.tree.getChildren(node, children);
298 for(auto child : children)
299 queue.emplace(child);
300 }
301 }
302
310 template <class dataType>
311 void mergeTreeToTorchTree(ftm::MergeTree<dataType> &mTree,
312 TorchMergeTree<dataType> &out,
313 bool normalize) {
314 out.mTree = copyMergeTree(mTree);
315 getParentsVector(out.mTree, out.parentsOri);
316 mergeTreeToTorchTensor<dataType>(
317 out.mTree, out.tensor, out.nodeCorr, normalize);
318 }
319
327 template <class dataType>
328 void mergeTreesToTorchTrees(std::vector<ftm::MergeTree<dataType>> &mTrees,
329 std::vector<TorchMergeTree<dataType>> &out,
330 bool normalize) {
331 out.resize(mTrees.size());
332 for(unsigned int i = 0; i < mTrees.size(); ++i)
333 mergeTreeToTorchTree<dataType>(mTrees[i], out[i], normalize);
334 }
335
346 template <class dataType>
347 void mergeTreeToTorchTree(ftm::MergeTree<dataType> &mTree,
348 TorchMergeTree<dataType> &out,
349 bool normalize,
350 unsigned int *revNodeCorr,
351 unsigned int revNodeCorrSize) {
352 out.mTree = copyMergeTree(mTree);
353 getParentsVector(out.mTree, out.parentsOri);
354 mergeTreeToTorchTensor<dataType>(out.mTree, out.tensor, out.nodeCorr,
355 normalize, revNodeCorr, revNodeCorrSize);
356 }
357
368 template <class dataType>
369 void mergeTreesToTorchTrees(std::vector<ftm::MergeTree<dataType>> &mTrees,
370 std::vector<TorchMergeTree<dataType>> &out,
371 bool normalize,
372 std::vector<unsigned int *> &allRevNodeCorr,
373 std::vector<unsigned int> &allRevNodeCorrSize) {
374 out.resize(mTrees.size());
375 for(unsigned int i = 0; i < mTrees.size(); ++i)
376 mergeTreeToTorchTree<dataType>(mTrees[i], out[i], normalize,
377 allRevNodeCorr[i],
378 allRevNodeCorrSize[i]);
379 }
380
388 template <class dataType>
389 bool torchTensorToMergeTree(TorchMergeTree<dataType> &tmt,
390 bool normalized,
391 ftm::MergeTree<dataType> &mTreeOut) {
392 std::vector<unsigned int> &nodeCorr = tmt.nodeCorr;
393 torch::Tensor &tensor = tmt.tensor;
394 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
395
396 mTreeOut = ttk::ftm::copyMergeTree<dataType>(tmt.mTree);
397 if(parentsOri.empty())
398 std::cout << "[torchTensorToMergeTree] parentsOri.empty()" << std::endl;
399 for(unsigned int i = 0; i < parentsOri.size(); ++i)
400 if(parentsOri[i] < mTreeOut.tree.getNumberOfNodes()
401 and mTreeOut.tree.isRoot(i))
402 mTreeOut.tree.setParent(i, parentsOri[i]);
403 ftm::idNode root = mTreeOut.tree.getRoot();
404 if(root >= mTreeOut.tree.getNumberOfNodes())
405 return true;
406
407 bool isJT = tmt.mTree.tree.template isJoinTree<dataType>();
408 std::vector<dataType> tensorVec(
409 tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
410 std::vector<dataType> scalarsVector;
411 ttk::ftm::getTreeScalars<dataType>(mTreeOut, scalarsVector);
412 std::queue<ftm::idNode> queue;
413 queue.emplace(root);
414 while(!queue.empty()) {
415 ftm::idNode node = queue.front();
416 queue.pop();
417
418 auto birthValue = tensorVec[nodeCorr[node] * 2];
419 auto deathValue = tensorVec[nodeCorr[node] * 2 + 1];
420 if(normalized and !mTreeOut.tree.isRoot(node)) {
421 ftm::idNode nodeParent = mTreeOut.tree.getParentSafe(node);
422 ftm::idNode nodeParentOrigin
423 = mTreeOut.tree.getNode(nodeParent)->getOrigin();
424 ftm::idNode birthParentNode = (isJT ? nodeParentOrigin : nodeParent);
425 ftm::idNode deathParentNode = (isJT ? nodeParent : nodeParentOrigin);
426 auto birthParent = scalarsVector[birthParentNode];
427 auto deathParent = scalarsVector[deathParentNode];
428 birthValue = birthValue * (deathParent - birthParent) + birthParent;
429 deathValue = deathValue * (deathParent - birthParent) + birthParent;
430 }
431 ftm::idNode nodeOrigin = mTreeOut.tree.getNode(node)->getOrigin();
432 scalarsVector[node] = (isJT ? deathValue : birthValue);
433 scalarsVector[nodeOrigin] = (isJT ? birthValue : deathValue);
434
435 std::vector<ftm::idNode> children;
436 mTreeOut.tree.getChildren(node, children);
437 for(auto &child : children)
438 queue.emplace(child);
439 }
440 ftm::setTreeScalars<dataType>(mTreeOut, scalarsVector);
441 return false;
442 }
443
450 template <class dataType>
451 void fillMergeTreeStructure(TorchMergeTree<dataType> &tmt) {
452 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
453 for(unsigned int i = 0; i < parentsOri.size(); ++i)
454 if(parentsOri[i] < tmt.mTree.tree.getNumberOfNodes()
455 and tmt.mTree.tree.isRoot(i))
456 tmt.mTree.tree.setParent(i, parentsOri[i]);
457 }
458
466 template <class dataType>
467 void getReverseTorchNodeCorr(TorchMergeTree<dataType> &tmt,
468 std::vector<unsigned int> &revNodeCorr) {
469 revNodeCorr.clear();
470 revNodeCorr.assign(
471 tmt.tensor.sizes()[0], std::numeric_limits<unsigned int>::max());
472 for(unsigned int i = 0; i < tmt.nodeCorr.size(); ++i) {
473 if(tmt.nodeCorr[i] != std::numeric_limits<unsigned int>::max()) {
474 revNodeCorr[tmt.nodeCorr[i] * 2] = i;
475 revNodeCorr[tmt.nodeCorr[i] * 2 + 1] = i;
476 }
477 }
478 }
479
487 template <class dataType>
488 void axisVectorToTorchTensor(ftm::MergeTree<dataType> &mTree,
489 std::vector<std::vector<double>> &v,
490 torch::Tensor &tensor) {
491 std::vector<double> v_flatten;
492 std::queue<ftm::idNode> queue;
493 queue.emplace(mTree.tree.getRoot());
494 while(!queue.empty()) {
495 ftm::idNode node = queue.front();
496 queue.pop();
497
498 v_flatten.emplace_back(v[node][0]);
499 v_flatten.emplace_back(v[node][1]);
500
501 std::vector<ftm::idNode> children;
502 mTree.tree.getChildren(node, children);
503 for(auto child : children)
504 queue.emplace(child);
505 }
506 tensor = torch::tensor(v_flatten);
507 tensor = tensor.reshape({-1, 1});
508 }
509
517 template <class dataType>
518 void axisVectorsToTorchTensor(
520 std::vector<std::vector<std::vector<double>>> &vS,
521 torch::Tensor &tensor) {
522 std::vector<torch::Tensor> allTensors;
523 for(auto &v : vS) {
524 torch::Tensor t;
525 axisVectorToTorchTensor(mTree, v, t);
526 allTensors.emplace_back(t);
527 }
528 tensor = torch::cat(allTensors, 1);
529 }
530
539 template <class dataType>
540 void getTensorMatching(
541 TorchMergeTree<dataType> &a,
542 TorchMergeTree<dataType> &b,
543 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
544 std::vector<int> &tensorMatching) {
545 tensorMatching.clear();
546 tensorMatching.resize((int)(a.tensor.sizes()[0] / 2), -1);
547 for(auto &match : matching) {
548 auto match1 = std::get<0>(match);
549 auto match2 = std::get<1>(match);
550 if(a.nodeCorr[match1] < tensorMatching.size())
551 tensorMatching[a.nodeCorr[match1]] = b.nodeCorr[match2];
552 }
553 }
554
563 template <class dataType>
564 void getInverseTensorMatching(
565 TorchMergeTree<dataType> &a,
566 TorchMergeTree<dataType> &b,
567 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
568 std::vector<int> &tensorMatching) {
569 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> invMatching(
570 matching.size());
571 for(unsigned int i = 0; i < matching.size(); ++i)
572 invMatching[i]
573 = std::make_tuple(std::get<1>(matching[i]), std::get<0>(matching[i]),
574 std::get<2>(matching[i]));
575 getTensorMatching(b, a, invMatching, tensorMatching);
576 }
577#endif
578
579 } // namespace mtu
580
581} // namespace ttk
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
void setParent(idNode nodeId, idNode newParentNodeId)
bool isRoot(idNode nodeId)
idNode getParentSafe(idNode nodeId)
void getChildren(idNode nodeId, std::vector< idNode > &res)
Node * getNode(idNode nodeId)
Definition FTMTree_MT.h:393
SimplexId getOrigin() const
Definition FTMNode.h:64
unsigned int idNode
Node index in vect_nodes_.
The Topology ToolKit.
coefficient_t normalize(const coefficient_t n, const coefficient_t modulus)
Definition ripserpy.cpp:159
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:903