12 std::vector<unsigned int *> &allRevNodeCorr,
13 std::vector<unsigned int *> &allRevNodeCorrPrime,
14 std::vector<unsigned int> &allRevNodeCorrSize,
15 std::vector<unsigned int> &allRevNodeCorrPrimeSize) {
16#ifndef TTK_ENABLE_TORCH
23 printErr(
"This module requires Torch.");
26 if(not isPersistenceDiagram_) {
27 for(
unsigned int i = 0; i < originsPrimeTrees.size(); ++i) {
28 bool const useMinMax =
true;
29 bool const cleanTree =
false;
31 std::vector<int> nodeCorr;
32 preprocessingPipeline<float>(originsTrees[i], 0.0, 100.0, 100.0,
33 branchDecomposition_, useMinMax, cleanTree,
35 preprocessingPipeline<float>(originsPrimeTrees[i], 0.0, 100.0, 100.0,
36 branchDecomposition_, useMinMax, cleanTree,
40 mergeTreesToTorchTrees(originsTrees, origins_, normalizedWasserstein_,
41 allRevNodeCorr, allRevNodeCorrSize);
42 mergeTreesToTorchTrees(originsPrimeTrees, originsPrime_,
43 normalizedWasserstein_, allRevNodeCorrPrime,
44 allRevNodeCorrPrimeSize);
47 if(allAlphas_[0].size() != originsPrime_.size()) {
48 customAlphas_.resize(allAlphas_.size());
49 for(
unsigned int i = 0; i < customAlphas_.size(); ++i)
50 customAlphas_[i] = std::vector<float>(
51 allAlphas_[i][0].data_ptr<float>(),
52 allAlphas_[i][0].data_ptr<float>() + allAlphas_[i][0].numel());
54 createCustomRecs(origins_, originsPrime_);
56 recs_.resize(allAlphas_.size());
57 for(
unsigned int i = 0; i < recs_.size(); ++i) {
58 recs_[i].resize(allAlphas_[i].size());
59 for(
unsigned int l = 0; l < allAlphas_[i].size(); ++l) {
61 = (activate_ ? activation(allAlphas_[i][l]) : allAlphas_[i][l]);
62 getMultiInterpolation(
63 originsPrime_[l], vSPrimeTensor_[l], act, recs_[i][l]);
69 for(
unsigned int l = 0; l < origins_.size(); ++l) {
70 postprocessingPipeline<float>(&(origins_[l].mTree.tree));
71 postprocessingPipeline<float>(&(originsPrime_[l].mTree.tree));
74 for(
unsigned int j = 0; j < recs_[0].size(); ++j) {
75 for(
unsigned int i = 0; i < recs_.size(); ++i) {
77 postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
/// @brief Builds torch-tree versions of two sets of input merge trees and
///        reconstructs trees from the learned coefficients.
///
/// NOTE(review): this description is taken from the partial body that shares
/// this exact parameter list. Confirm it belongs to this declaration. In that
/// body:
///  - when not a persistence diagram, each tree pair is passed through
///    preprocessingPipeline<float>;
///  - both sets are converted with mergeTreesToTorchTrees into origins_ and
///    originsPrime_;
///  - when allAlphas_[0].size() != originsPrime_.size(), custom
///    reconstructions are created via createCustomRecs;
///  - recs_ is filled with getMultiInterpolation over allAlphas_ (optionally
///    passed through activation());
///  - finally, postprocessingPipeline<float> is applied to the origins and
///    the reconstructions.
/// Without TTK_ENABLE_TORCH, the body prints "This module requires Torch.".
///
/// @param originsTrees            Input merge trees converted into origins_.
/// @param originsPrimeTrees       Input merge trees converted into originsPrime_.
/// @param allRevNodeCorr          Filled by mergeTreesToTorchTrees for
///                                originsTrees. Presumably reverse node
///                                correspondences; ownership of the raw
///                                pointers is not visible here (TODO confirm).
/// @param allRevNodeCorrPrime     Same as allRevNodeCorr, for originsPrimeTrees.
/// @param allRevNodeCorrSize      Filled alongside allRevNodeCorr (presumably
///                                the length of each array; verify).
/// @param allRevNodeCorrPrimeSize Filled alongside allRevNodeCorrPrime
///                                (presumably the length of each array; verify).
void execute(std::vector< ttk::ftm::MergeTree< float > > &originsTrees, std::vector< ttk::ftm::MergeTree< float > > &originsPrimeTrees, std::vector< unsigned int * > &allRevNodeCorr, std::vector< unsigned int * > &allRevNodeCorrPrime, std::vector< unsigned int > &allRevNodeCorrSize, std::vector< unsigned int > &allRevNodeCorrPrimeSize)