TTK
Loading...
Searching...
No Matches
MergeTreeTorchUtils.h
Go to the documentation of this file.
1
7
8#pragma once
9
10#include <Debug.h>
11#include <FTMTreeUtils.h>
12#include <FTMTree_MT.h>
13#include <Geometry.h>
14#include <MergeTreeUtils.h>
15
16#ifdef TTK_ENABLE_TORCH
17#include <torch/torch.h>
18#endif
19
20namespace ttk {
21
22 namespace mtu {
23
24#ifdef TTK_ENABLE_TORCH
/// Copies tensor `a` into tensor `b`.
/// NOTE(review): the definition is not visible in this header. The
/// direction (a = source, b = destination) is taken from
/// copyTorchMergeTree, which calls copyTensor(tmTree.tensor, out.tensor).
/// Whether `b` ends up sharing storage or autograd history with `a` should
/// be confirmed against the definition.
31 void copyTensor(const torch::Tensor &a, torch::Tensor &b);
32
33 template <typename dataType>
34 struct TorchMergeTree {
36 torch::Tensor tensor;
37 std::vector<unsigned int> nodeCorr;
38 std::vector<ftm::idNode> parentsOri;
39 };
40
/// NOTE(review): only the declaration is visible here. Both parameters are
/// non-const references. Judging by the names, `deltaProjTensor` is
/// presumably filled with the diagonal projection of the pairs stored in
/// `diagTensor` — confirm against the definition.
47 void getDeltaProjTensor(torch::Tensor &diagTensor,
48 torch::Tensor &deltaProjTensor);
49
/// Reorders the tensor data of `tree2` (and optionally `tree`) according to a
/// matching between the two trees.
/// NOTE(review): only the declarations are visible in this header, so the
/// descriptions below are inferred from parameter names. This overload
/// outputs projection indexers / reordering indexes and reordered tensors for
/// both trees, plus a delta-projection tensor for `tree2`.
/// `doubleReordering` (default true) presumably controls whether the
/// tree-1 outputs (`tree1ReorderedTensor`, `tree2ProjIndexer`) are
/// also computed — confirm.
68 void dataReorderingGivenMatching(const mtu::TorchMergeTree<float> &tree,
69 const mtu::TorchMergeTree<float> &tree2,
70 torch::Tensor &tree1ProjIndexer,
71 torch::Tensor &tree2ReorderingIndexes,
72 torch::Tensor &tree2ReorderedTensor,
73 torch::Tensor &tree2DeltaProjTensor,
74 torch::Tensor &tree1ReorderedTensor,
75 torch::Tensor &tree2ProjIndexer,
76 bool doubleReordering = true);
77
/// Overload without the tree-1 outputs. Presumably equivalent to the overload
/// above with doubleReordering = false — confirm against the definition.
92 void dataReorderingGivenMatching(const mtu::TorchMergeTree<float> &tree,
93 const mtu::TorchMergeTree<float> &tree2,
94 torch::Tensor &tree1ProjIndexer,
95 torch::Tensor &tree2ReorderingIndexes,
96 torch::Tensor &tree2ReorderedTensor,
97 torch::Tensor &tree2DeltaProjTensor);
98
/// Overload taking an explicit node matching. Each tuple is (node of `tree`,
/// node of `tree2`, double), the same layout getTensorMatching reads; the
/// double is presumably the matching cost. Judging by parameter names, it
/// writes reordered tensors for both trees; see above for `doubleReordering`.
109 void dataReorderingGivenMatching(
110 const mtu::TorchMergeTree<float> &tree,
111 const mtu::TorchMergeTree<float> &tree2,
112 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
113 torch::Tensor &tree1ReorderedTensor,
114 torch::Tensor &tree2ReorderedTensor,
115 bool doubleReordering = true);
116
/// Matching-based overload that only produces the reordered tensor of
/// `tree2`.
125 void dataReorderingGivenMatching(
126 const mtu::TorchMergeTree<float> &tree,
127 const mtu::TorchMergeTree<float> &tree2,
128 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
129 torch::Tensor &tree2ReorderedTensor);
130
/// NOTE(review): the definitions of the helpers below are not visible in
/// this header; every description is inferred from names and signatures
/// only and should be confirmed.
///
/// Presumably shifts the births in `diagTensor` using a mean birth derived
/// from `diagBaseTensor`.
137 void meanBirthShift(torch::Tensor &diagTensor,
138 torch::Tensor &diagBaseTensor);
139
/// Presumably a birth / max-persistence based shift of `tensor` relative to
/// `baseTensor`.
147 void meanBirthMaxPersShift(torch::Tensor &tensor,
148 torch::Tensor &baseTensor);
149
/// Presumably moves the points of `tensor` that fall below the diagonal,
/// using `backupTensor` as reference.
157 void belowDiagonalPointsShift(torch::Tensor &tensor,
158 torch::Tensor &backupTensor);
159
/// Normalizes `vectorsTensor`; `originTensor` is presumably the origin the
/// vectors are attached to.
166 void normalizeVectors(torch::Tensor &originTensor,
167 torch::Tensor &vectorsTensor);
168
/// Same operation for per-node vectors stored as std::vector; `origin` is
/// presumably the tree the vectors are attached to.
175 void normalizeVectors(mtu::TorchMergeTree<float> &origin,
176 std::vector<std::vector<double>> &vectors);
177
/// Returns a layer index, presumably that of the latent layer of the network
/// these utilities serve.
183 unsigned int getLatentLayerIndex();
184
/// Returns whether `interpolation` has missing pairs; what counts as
/// "missing" is not visible here.
190 bool isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation);
191
198 template <class dataType>
199 void copyTorchMergeTree(const TorchMergeTree<dataType> &tmTree,
200 TorchMergeTree<dataType> &out) {
201 out.mTree = ftm::copyMergeTree<dataType>(tmTree.mTree);
202 copyTensor(tmTree.tensor, out.tensor);
203 out.nodeCorr = tmTree.nodeCorr;
204 out.parentsOri = tmTree.parentsOri;
205 }
206
219 template <class dataType>
220 void mergeTreeToTorchTensor(ftm::MergeTree<dataType> &mTree,
221 torch::Tensor &tensor,
222 std::vector<unsigned int> &nodeCorr,
223 bool normalize,
224 unsigned int *revNodeCorr = nullptr,
225 unsigned int revNodeCorrSize = 0) {
226 nodeCorr.clear();
227 nodeCorr.assign(mTree.tree.getNumberOfNodes(),
228 std::numeric_limits<unsigned int>::max());
229 std::vector<dataType> scalars;
230 std::queue<ftm::idNode> queue;
231 if(revNodeCorrSize == 0)
232 queue.emplace(mTree.tree.getRoot());
233 else {
234 for(unsigned int i = 0; i < revNodeCorrSize; i += 2)
235 queue.emplace(revNodeCorr[i]);
236 }
237 unsigned int cpt = 0;
238 while(!queue.empty()) {
239 ftm::idNode node = queue.front();
240 queue.pop();
241
243 &(mTree.tree), node, normalize);
244 auto birth = std::get<0>(birthDeath);
245 auto death = std::get<1>(birthDeath);
246 scalars.emplace_back(birth);
247 scalars.emplace_back(death);
248
249 nodeCorr[node] = cpt;
250 ++cpt;
251
252 if(revNodeCorrSize == 0) {
253 std::vector<ftm::idNode> children;
254 mTree.tree.getChildren(node, children);
255 for(auto child : children)
256 queue.emplace(child);
257 }
258 }
259 tensor = torch::tensor(scalars).reshape({-1, 1});
260 }
261
269 template <class dataType>
270 void mergeTreeToTorchTensor(ftm::MergeTree<dataType> &mTree,
271 torch::Tensor &tensor,
272 bool normalize) {
273 std::vector<unsigned int> nodeCorr;
274 mergeTreeToTorchTensor<dataType>(mTree, tensor, nodeCorr, normalize);
275 }
276
283 template <class dataType>
284 void getParentsVector(ftm::MergeTree<dataType> &mTree,
285 std::vector<ftm::idNode> &parents) {
286 parents.resize(mTree.tree.getNumberOfNodes());
287 std::fill(parents.begin(), parents.end(),
288 std::numeric_limits<ftm::idNode>::max());
289 std::queue<ftm::idNode> queue;
290 queue.emplace(mTree.tree.getRoot());
291 while(!queue.empty()) {
292 ftm::idNode node = queue.front();
293 queue.pop();
294 if(!mTree.tree.isRoot(node))
295 parents[node] = mTree.tree.getParentSafe(node);
296 std::vector<ftm::idNode> children;
297 mTree.tree.getChildren(node, children);
298 for(auto child : children)
299 queue.emplace(child);
300 }
301 }
302
310 template <class dataType>
311 void mergeTreeToTorchTree(ftm::MergeTree<dataType> &mTree,
312 TorchMergeTree<dataType> &out,
313 bool normalize) {
314 out.mTree = copyMergeTree(mTree);
315 getParentsVector(out.mTree, out.parentsOri);
316 mergeTreeToTorchTensor<dataType>(
317 out.mTree, out.tensor, out.nodeCorr, normalize);
318 }
319
327 template <class dataType>
328 void mergeTreesToTorchTrees(std::vector<ftm::MergeTree<dataType>> &mTrees,
329 std::vector<TorchMergeTree<dataType>> &out,
330 bool normalize) {
331 out.resize(mTrees.size());
332 for(unsigned int i = 0; i < mTrees.size(); ++i)
333 mergeTreeToTorchTree<dataType>(mTrees[i], out[i], normalize);
334 }
335
346 template <class dataType>
347 void mergeTreeToTorchTree(ftm::MergeTree<dataType> &mTree,
348 TorchMergeTree<dataType> &out,
349 bool normalize,
350 unsigned int *revNodeCorr,
351 unsigned int revNodeCorrSize) {
352 out.mTree = copyMergeTree(mTree);
353 getParentsVector(out.mTree, out.parentsOri);
354 mergeTreeToTorchTensor<dataType>(out.mTree, out.tensor, out.nodeCorr,
355 normalize, revNodeCorr, revNodeCorrSize);
356 }
357
368 template <class dataType>
369 void mergeTreesToTorchTrees(std::vector<ftm::MergeTree<dataType>> &mTrees,
370 std::vector<TorchMergeTree<dataType>> &out,
371 bool normalize,
372 std::vector<unsigned int *> &allRevNodeCorr,
373 std::vector<unsigned int> &allRevNodeCorrSize) {
374 out.resize(mTrees.size());
375 for(unsigned int i = 0; i < mTrees.size(); ++i)
376 mergeTreeToTorchTree<dataType>(mTrees[i], out[i], normalize,
377 allRevNodeCorr[i],
378 allRevNodeCorrSize[i]);
379 }
380
388 template <class dataType>
389 bool torchTensorToMergeTree(TorchMergeTree<dataType> &tmt,
390 bool normalized,
391 ftm::MergeTree<dataType> &mTreeOut) {
392 std::vector<unsigned int> &nodeCorr = tmt.nodeCorr;
393 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
394
395 mTreeOut = ttk::ftm::copyMergeTree<dataType>(tmt.mTree);
396 if(parentsOri.empty())
397 std::cout << "[torchTensorToMergeTree] parentsOri.empty()" << std::endl;
398 for(unsigned int i = 0; i < parentsOri.size(); ++i)
399 if(parentsOri[i] < mTreeOut.tree.getNumberOfNodes()
400 and mTreeOut.tree.isRoot(i))
401 mTreeOut.tree.setParent(i, parentsOri[i]);
402 ftm::idNode root = mTreeOut.tree.getRoot();
403 if(root >= mTreeOut.tree.getNumberOfNodes())
404 return true;
405
406 bool isJT = tmt.mTree.tree.template isJoinTree<dataType>();
407 torch::Tensor tensor = tmt.tensor;
408 if(!tensor.device().is_cpu())
409 tensor = tensor.cpu();
410 std::vector<dataType> tensorVec(
411 tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
412 std::vector<dataType> scalarsVector;
413 ttk::ftm::getTreeScalars<dataType>(mTreeOut, scalarsVector);
414 std::queue<ftm::idNode> queue;
415 queue.emplace(root);
416 while(!queue.empty()) {
417 ftm::idNode node = queue.front();
418 queue.pop();
419
420 auto birthValue = tensorVec[nodeCorr[node] * 2];
421 auto deathValue = tensorVec[nodeCorr[node] * 2 + 1];
422 if(normalized and !mTreeOut.tree.isRoot(node)) {
423 ftm::idNode nodeParent = mTreeOut.tree.getParentSafe(node);
424 ftm::idNode nodeParentOrigin
425 = mTreeOut.tree.getNode(nodeParent)->getOrigin();
426 ftm::idNode birthParentNode = (isJT ? nodeParentOrigin : nodeParent);
427 ftm::idNode deathParentNode = (isJT ? nodeParent : nodeParentOrigin);
428 auto birthParent = scalarsVector[birthParentNode];
429 auto deathParent = scalarsVector[deathParentNode];
430 birthValue = birthValue * (deathParent - birthParent) + birthParent;
431 deathValue = deathValue * (deathParent - birthParent) + birthParent;
432 }
433 ftm::idNode nodeOrigin = mTreeOut.tree.getNode(node)->getOrigin();
434 scalarsVector[node] = (isJT ? deathValue : birthValue);
435 scalarsVector[nodeOrigin] = (isJT ? birthValue : deathValue);
436
437 std::vector<ftm::idNode> children;
438 mTreeOut.tree.getChildren(node, children);
439 for(auto &child : children)
440 queue.emplace(child);
441 }
442 ftm::setTreeScalars<dataType>(mTreeOut, scalarsVector);
443 return false;
444 }
445
452 template <class dataType>
453 void fillMergeTreeStructure(TorchMergeTree<dataType> &tmt) {
454 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
455 for(unsigned int i = 0; i < parentsOri.size(); ++i)
456 if(parentsOri[i] < tmt.mTree.tree.getNumberOfNodes()
457 and tmt.mTree.tree.isRoot(i))
458 tmt.mTree.tree.setParent(i, parentsOri[i]);
459 }
460
468 template <class dataType>
469 void getReverseTorchNodeCorr(TorchMergeTree<dataType> &tmt,
470 std::vector<unsigned int> &revNodeCorr) {
471 revNodeCorr.clear();
472 revNodeCorr.assign(
473 tmt.tensor.sizes()[0], std::numeric_limits<unsigned int>::max());
474 for(unsigned int i = 0; i < tmt.nodeCorr.size(); ++i) {
475 if(tmt.nodeCorr[i] != std::numeric_limits<unsigned int>::max()) {
476 revNodeCorr[tmt.nodeCorr[i] * 2] = i;
477 revNodeCorr[tmt.nodeCorr[i] * 2 + 1] = i;
478 }
479 }
480 }
481
489 template <class dataType>
490 void axisVectorToTorchTensor(ftm::MergeTree<dataType> &mTree,
491 std::vector<std::vector<double>> &v,
492 torch::Tensor &tensor) {
493 std::vector<double> v_flatten;
494 std::queue<ftm::idNode> queue;
495 queue.emplace(mTree.tree.getRoot());
496 while(!queue.empty()) {
497 ftm::idNode node = queue.front();
498 queue.pop();
499
500 v_flatten.emplace_back(v[node][0]);
501 v_flatten.emplace_back(v[node][1]);
502
503 std::vector<ftm::idNode> children;
504 mTree.tree.getChildren(node, children);
505 for(auto child : children)
506 queue.emplace(child);
507 }
508 tensor = torch::tensor(v_flatten);
509 tensor = tensor.reshape({-1, 1});
510 }
511
519 template <class dataType>
520 void axisVectorsToTorchTensor(
522 std::vector<std::vector<std::vector<double>>> &vS,
523 torch::Tensor &tensor) {
524 std::vector<torch::Tensor> allTensors;
525 for(auto &v : vS) {
526 torch::Tensor t;
527 axisVectorToTorchTensor(mTree, v, t);
528 allTensors.emplace_back(t);
529 }
530 tensor = torch::cat(allTensors, 1);
531 }
532
541 template <class dataType>
542 void getTensorMatching(
543 const TorchMergeTree<dataType> &a,
544 const TorchMergeTree<dataType> &b,
545 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
546 std::vector<int> &tensorMatching) {
547 tensorMatching.clear();
548 tensorMatching.resize((int)(a.tensor.sizes()[0] / 2), -1);
549 for(auto &match : matching) {
550 auto match1 = std::get<0>(match);
551 auto match2 = std::get<1>(match);
552 if(a.nodeCorr[match1] < tensorMatching.size())
553 tensorMatching[a.nodeCorr[match1]] = b.nodeCorr[match2];
554 }
555 }
556
565 template <class dataType>
566 void getInverseTensorMatching(
567 const TorchMergeTree<dataType> &a,
568 const TorchMergeTree<dataType> &b,
569 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
570 std::vector<int> &tensorMatching) {
571 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> invMatching(
572 matching.size());
573 for(unsigned int i = 0; i < matching.size(); ++i)
574 invMatching[i]
575 = std::make_tuple(std::get<1>(matching[i]), std::get<0>(matching[i]),
576 std::get<2>(matching[i]));
577 getTensorMatching(b, a, invMatching, tensorMatching);
578 }
579#endif
580
581 } // namespace mtu
582
583} // namespace ttk
Node * getNode(idNode nodeId) const
Definition FTMTree_MT.h:393
void getChildren(idNode nodeId, std::vector< idNode > &res) const
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
void setParent(idNode nodeId, idNode newParentNodeId)
idNode getRoot() const
idNode getParentSafe(idNode nodeId) const
bool isRoot(idNode nodeId) const
SimplexId getOrigin() const
Definition FTMNode.h:64
void setTreeScalars(MergeTree< dataType > &mergeTree, std::vector< dataType > &scalarsVector)
void getTreeScalars(const ftm::FTMTree_MT *tree, std::vector< dataType > &scalarsVector)
MergeTree< dataType > copyMergeTree(const ftm::FTMTree_MT *tree, bool doSplitMultiPersPairs=false)
unsigned int idNode
Node index in vect_nodes_.
TTK base package defining the standard types.
std::tuple< dataType, dataType > getParametrizedBirthDeath(ftm::FTMTree_MT *tree, ftm::idNode node, bool normalize)
coefficient_t normalize(const coefficient_t n, const coefficient_t modulus)
Definition ripser.cpp:171
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:906