24#ifdef TTK_ENABLE_TORCH
31 void copyTensor(torch::Tensor &a, torch::Tensor &b);
33 template <
typename dataType>
34 struct TorchMergeTree {
37 std::vector<unsigned int> nodeCorr;
38 std::vector<ftm::idNode> parentsOri;
47 void getDeltaProjTensor(torch::Tensor &diagTensor,
48 torch::Tensor &deltaProjTensor);
68 void dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
69 mtu::TorchMergeTree<float> &tree2,
70 torch::Tensor &tree1ProjIndexer,
71 torch::Tensor &tree2ReorderingIndexes,
72 torch::Tensor &tree2ReorderedTensor,
73 torch::Tensor &tree2DeltaProjTensor,
74 torch::Tensor &tree1ReorderedTensor,
75 torch::Tensor &tree2ProjIndexer,
76 bool doubleReordering =
true);
92 void dataReorderingGivenMatching(mtu::TorchMergeTree<float> &tree,
93 mtu::TorchMergeTree<float> &tree2,
94 torch::Tensor &tree1ProjIndexer,
95 torch::Tensor &tree2ReorderingIndexes,
96 torch::Tensor &tree2ReorderedTensor,
97 torch::Tensor &tree2DeltaProjTensor);
109 void dataReorderingGivenMatching(
110 mtu::TorchMergeTree<float> &tree,
111 mtu::TorchMergeTree<float> &tree2,
112 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
113 torch::Tensor &tree1ReorderedTensor,
114 torch::Tensor &tree2ReorderedTensor,
115 bool doubleReordering =
true);
125 void dataReorderingGivenMatching(
126 mtu::TorchMergeTree<float> &tree,
127 mtu::TorchMergeTree<float> &tree2,
128 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
129 torch::Tensor &tree2ReorderedTensor);
137 void meanBirthShift(torch::Tensor &diagTensor,
138 torch::Tensor &diagBaseTensor);
147 void meanBirthMaxPersShift(torch::Tensor &tensor,
148 torch::Tensor &baseTensor);
157 void belowDiagonalPointsShift(torch::Tensor &tensor,
158 torch::Tensor &backupTensor);
166 void normalizeVectors(torch::Tensor &originTensor,
167 torch::Tensor &vectorsTensor);
175 void normalizeVectors(mtu::TorchMergeTree<float> &origin,
176 std::vector<std::vector<double>> &vectors);
183 unsigned int getLatentLayerIndex();
190 bool isThereMissingPairs(mtu::TorchMergeTree<float> &interpolation);
198 template <
class dataType>
199 void copyTorchMergeTree(TorchMergeTree<dataType> &tmTree,
200 TorchMergeTree<dataType> &out) {
201 out.mTree = ftm::copyMergeTree<dataType>(tmTree.mTree);
202 copyTensor(tmTree.tensor, out.tensor);
203 out.nodeCorr = tmTree.nodeCorr;
204 out.parentsOri = tmTree.parentsOri;
219 template <
class dataType>
221 torch::Tensor &tensor,
222 std::vector<unsigned int> &nodeCorr,
224 unsigned int *revNodeCorr =
nullptr,
225 unsigned int revNodeCorrSize = 0) {
228 std::numeric_limits<unsigned int>::max());
229 std::vector<dataType> scalars;
230 std::queue<ftm::idNode> queue;
231 if(revNodeCorrSize == 0)
234 for(
unsigned int i = 0; i < revNodeCorrSize; i += 2)
235 queue.emplace(revNodeCorr[i]);
237 unsigned int cpt = 0;
238 while(!queue.empty()) {
242 auto birthDeath = ttk::getParametrizedBirthDeath<dataType>(
244 auto birth = std::get<0>(birthDeath);
245 auto death = std::get<1>(birthDeath);
246 scalars.emplace_back(birth);
247 scalars.emplace_back(death);
249 nodeCorr[node] = cpt;
252 if(revNodeCorrSize == 0) {
253 std::vector<ftm::idNode> children;
255 for(
auto child : children)
256 queue.emplace(child);
259 tensor = torch::tensor(scalars).reshape({-1, 1});
269 template <
class dataType>
271 torch::Tensor &tensor,
273 std::vector<unsigned int> nodeCorr;
274 mergeTreeToTorchTensor<dataType>(mTree, tensor, nodeCorr,
normalize);
283 template <
class dataType>
285 std::vector<ftm::idNode> &parents) {
287 std::fill(parents.begin(), parents.end(),
288 std::numeric_limits<ftm::idNode>::max());
289 std::queue<ftm::idNode> queue;
291 while(!queue.empty()) {
296 std::vector<ftm::idNode> children;
298 for(
auto child : children)
299 queue.emplace(child);
310 template <
class dataType>
312 TorchMergeTree<dataType> &out,
314 out.mTree = copyMergeTree(mTree);
315 getParentsVector(out.mTree, out.parentsOri);
316 mergeTreeToTorchTensor<dataType>(
317 out.mTree, out.tensor, out.nodeCorr,
normalize);
327 template <
class dataType>
329 std::vector<TorchMergeTree<dataType>> &out,
331 out.resize(mTrees.size());
332 for(
unsigned int i = 0; i < mTrees.size(); ++i)
333 mergeTreeToTorchTree<dataType>(mTrees[i], out[i],
normalize);
346 template <
class dataType>
348 TorchMergeTree<dataType> &out,
350 unsigned int *revNodeCorr,
351 unsigned int revNodeCorrSize) {
352 out.mTree = copyMergeTree(mTree);
353 getParentsVector(out.mTree, out.parentsOri);
354 mergeTreeToTorchTensor<dataType>(out.mTree, out.tensor, out.nodeCorr,
355 normalize, revNodeCorr, revNodeCorrSize);
368 template <
class dataType>
370 std::vector<TorchMergeTree<dataType>> &out,
372 std::vector<unsigned int *> &allRevNodeCorr,
373 std::vector<unsigned int> &allRevNodeCorrSize) {
374 out.resize(mTrees.size());
375 for(
unsigned int i = 0; i < mTrees.size(); ++i)
376 mergeTreeToTorchTree<dataType>(mTrees[i], out[i],
normalize,
378 allRevNodeCorrSize[i]);
388 template <
class dataType>
389 bool torchTensorToMergeTree(TorchMergeTree<dataType> &tmt,
392 std::vector<unsigned int> &nodeCorr = tmt.nodeCorr;
393 torch::Tensor &tensor = tmt.tensor;
394 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
396 mTreeOut = ttk::ftm::copyMergeTree<dataType>(tmt.mTree);
397 if(parentsOri.empty())
398 std::cout <<
"[torchTensorToMergeTree] parentsOri.empty()" << std::endl;
399 for(
unsigned int i = 0; i < parentsOri.size(); ++i)
407 bool isJT = tmt.mTree.tree.template isJoinTree<dataType>();
408 std::vector<dataType> tensorVec(
409 tensor.data_ptr<
float>(), tensor.data_ptr<
float>() + tensor.numel());
410 std::vector<dataType> scalarsVector;
411 ttk::ftm::getTreeScalars<dataType>(mTreeOut, scalarsVector);
412 std::queue<ftm::idNode> queue;
414 while(!queue.empty()) {
418 auto birthValue = tensorVec[nodeCorr[node] * 2];
419 auto deathValue = tensorVec[nodeCorr[node] * 2 + 1];
420 if(normalized and !mTreeOut.
tree.
isRoot(node)) {
424 ftm::idNode birthParentNode = (isJT ? nodeParentOrigin : nodeParent);
425 ftm::idNode deathParentNode = (isJT ? nodeParent : nodeParentOrigin);
426 auto birthParent = scalarsVector[birthParentNode];
427 auto deathParent = scalarsVector[deathParentNode];
428 birthValue = birthValue * (deathParent - birthParent) + birthParent;
429 deathValue = deathValue * (deathParent - birthParent) + birthParent;
432 scalarsVector[node] = (isJT ? deathValue : birthValue);
433 scalarsVector[nodeOrigin] = (isJT ? birthValue : deathValue);
435 std::vector<ftm::idNode> children;
437 for(
auto &child : children)
438 queue.emplace(child);
440 ftm::setTreeScalars<dataType>(mTreeOut, scalarsVector);
450 template <
class dataType>
451 void fillMergeTreeStructure(TorchMergeTree<dataType> &tmt) {
452 std::vector<ftm::idNode> &parentsOri = tmt.parentsOri;
453 for(
unsigned int i = 0; i < parentsOri.size(); ++i)
454 if(parentsOri[i] < tmt.mTree.tree.getNumberOfNodes()
455 and tmt.mTree.tree.isRoot(i))
456 tmt.mTree.tree.setParent(i, parentsOri[i]);
466 template <
class dataType>
467 void getReverseTorchNodeCorr(TorchMergeTree<dataType> &tmt,
468 std::vector<unsigned int> &revNodeCorr) {
471 tmt.tensor.sizes()[0], std::numeric_limits<unsigned int>::max());
472 for(
unsigned int i = 0; i < tmt.nodeCorr.size(); ++i) {
473 if(tmt.nodeCorr[i] != std::numeric_limits<unsigned int>::max()) {
474 revNodeCorr[tmt.nodeCorr[i] * 2] = i;
475 revNodeCorr[tmt.nodeCorr[i] * 2 + 1] = i;
487 template <
class dataType>
489 std::vector<std::vector<double>> &v,
490 torch::Tensor &tensor) {
491 std::vector<double> v_flatten;
492 std::queue<ftm::idNode> queue;
494 while(!queue.empty()) {
498 v_flatten.emplace_back(v[node][0]);
499 v_flatten.emplace_back(v[node][1]);
501 std::vector<ftm::idNode> children;
503 for(
auto child : children)
504 queue.emplace(child);
506 tensor = torch::tensor(v_flatten);
507 tensor = tensor.reshape({-1, 1});
517 template <
class dataType>
518 void axisVectorsToTorchTensor(
520 std::vector<std::vector<std::vector<double>>> &vS,
521 torch::Tensor &tensor) {
522 std::vector<torch::Tensor> allTensors;
525 axisVectorToTorchTensor(mTree, v, t);
526 allTensors.emplace_back(t);
528 tensor = torch::cat(allTensors, 1);
539 template <
class dataType>
540 void getTensorMatching(
541 TorchMergeTree<dataType> &a,
542 TorchMergeTree<dataType> &b,
543 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
544 std::vector<int> &tensorMatching) {
545 tensorMatching.clear();
546 tensorMatching.resize((
int)(a.tensor.sizes()[0] / 2), -1);
547 for(
auto &match : matching) {
548 auto match1 = std::get<0>(match);
549 auto match2 = std::get<1>(match);
550 if(a.nodeCorr[match1] < tensorMatching.size())
551 tensorMatching[a.nodeCorr[match1]] = b.nodeCorr[match2];
563 template <
class dataType>
564 void getInverseTensorMatching(
565 TorchMergeTree<dataType> &a,
566 TorchMergeTree<dataType> &b,
567 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
568 std::vector<int> &tensorMatching) {
569 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> invMatching(
571 for(
unsigned int i = 0; i < matching.size(); ++i)
573 = std::make_tuple(std::get<1>(matching[i]), std::get<0>(matching[i]),
574 std::get<2>(matching[i]));
575 getTensorMatching(b, a, invMatching, tensorMatching);