TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoderDecoding.cpp
Go to the documentation of this file.
3
5 // inherited from Debug: prefix will be printed at the beginning of every msg
6 this->setDebugMsgPrefix("MergeTreeAutoencoderDecoding");
7}
8
// Parameter list and body of execute(). The line carrying the function name
// and opening parenthesis is not present in this listing.
// The first two parameters are the merge trees handed to the preprocessing
// and conversion steps below; the remaining four are forwarded untouched to
// mergeTreesToTorchTrees().
10 std::vector<ttk::ftm::MergeTree<float>> &originsTrees,
11 std::vector<ttk::ftm::MergeTree<float>> &originsPrimeTrees,
12 std::vector<unsigned int *> &allRevNodeCorr,
13 std::vector<unsigned int *> &allRevNodeCorrPrime,
14 std::vector<unsigned int> &allRevNodeCorrSize,
15 std::vector<unsigned int> &allRevNodeCorrPrimeSize) {
16#ifndef TTK_ENABLE_TORCH
// Build without Torch: mark every parameter as used so the compiler does not
// warn, report the error, and do nothing else.
17 TTK_FORCE_USE(originsTrees);
18 TTK_FORCE_USE(originsPrimeTrees);
19 TTK_FORCE_USE(allRevNodeCorr);
20 TTK_FORCE_USE(allRevNodeCorrPrime);
21 TTK_FORCE_USE(allRevNodeCorrSize);
22 TTK_FORCE_USE(allRevNodeCorrPrimeSize);
23 printErr("This module requires Torch.");
24#else
25 // --- Preprocessing
// Only run when the inputs are not flagged as persistence diagrams.
26 if(not isPersistenceDiagram_) {
// NOTE(review): the bound is originsPrimeTrees.size() but originsTrees[i] is
// indexed as well -- assumes both vectors have the same length; confirm.
27 for(unsigned int i = 0; i < originsPrimeTrees.size(); ++i) {
28 bool const useMinMax = true;
29 bool const cleanTree = false;
// NOTE(review): pt is passed as preprocessingPipeline's persistenceThreshold,
// which is declared as double, yet it is declared bool here. 0.0 converts to
// false and back to 0.0 at the call, so the value is unaffected, but the type
// looks unintended.
30 bool const pt = 0.0;
// Receives the node correspondence from both calls below and is discarded at
// the end of each iteration.
31 std::vector<int> nodeCorr;
// The literals 0.0, 100.0, 100.0 fill the epsilonTree, epsilon2Tree and
// epsilon3Tree parameters; the trailing false is deleteInconsistentNodes.
32 preprocessingPipeline<float>(originsTrees[i], 0.0, 100.0, 100.0,
33 branchDecomposition_, useMinMax, cleanTree,
34 pt, nodeCorr, false);
35 preprocessingPipeline<float>(originsPrimeTrees[i], 0.0, 100.0, 100.0,
36 branchDecomposition_, useMinMax, cleanTree,
37 pt, nodeCorr, false);
38 }
39 }
// Presumably fills originsCopy_ / originsPrimeCopy_ with the torch-side
// representation of the input trees -- helper not visible here; confirm.
40 mergeTreesToTorchTrees(originsTrees, originsCopy_, normalizedWasserstein_,
41 allRevNodeCorr, allRevNodeCorrSize);
42 mergeTreesToTorchTrees(originsPrimeTrees, originsPrimeCopy_,
43 normalizedWasserstein_, allRevNodeCorrPrime,
44 allRevNodeCorrPrimeSize);
// One layer per noLayers_; each layer receives its origin, origin prime and
// the matching VS / VS prime tensors from the *Copy_ members at index l.
// NOTE(review): assumes each of those members holds at least noLayers_
// entries -- nothing here checks it.
45 layers_.resize(noLayers_);
46 for(unsigned int l = 0; l < layers_.size(); ++l) {
47 layers_[l].setOrigin(originsCopy_[l]);
48 layers_[l].setVSTensor(vSTensorCopy_[l]);
49 layers_[l].setOriginPrime(originsPrimeCopy_[l]);
50 layers_[l].setVSPrimeTensor(vSPrimeTensorCopy_[l]);
// NOTE(review): listing lines 51, 52 and 54 are missing from this extract;
// the line below is the tail of a statement whose beginning is not visible.
53 and l < (trackingLossDecoding_ ? noLayers_ : getLatentLayerIndex() + 1);
55 passLayerParameters(layers_[l]);
56 }
57
58 // --- Execute
// NOTE(review): allAlphas_[0] is read without an emptiness check -- assumes
// allAlphas_ is non-empty on entry; confirm against the caller.
59 if(allAlphas_[0].size() != originsPrimeCopy_.size()) {
// Size mismatch: copy the numel() floats of the first tensor of every entry
// into customAlphas_, drop allAlphas_, and let createCustomRecs() take over.
// NOTE(review): the raw-pointer copy presumably expects a contiguous float
// tensor -- confirm.
60 customAlphas_.resize(allAlphas_.size());
61 for(unsigned int i = 0; i < customAlphas_.size(); ++i)
62 customAlphas_[i] = std::vector<float>(
63 allAlphas_[i][0].data_ptr<float>(),
64 allAlphas_[i][0].data_ptr<float>() + allAlphas_[i][0].numel());
65 allAlphas_.clear();
66 createCustomRecs();
67 } else {
// Sizes match: recs_ mirrors the shape of allAlphas_, with one entry per
// (input i, layer l) pair.
68 recs_.resize(allAlphas_.size());
69 for(unsigned int i = 0; i < recs_.size(); ++i) {
70 recs_[i].resize(allAlphas_[i].size());
71 for(unsigned int l = 0; l < allAlphas_[i].size(); ++l) {
// Apply activation() to the coefficients only when activate_ is set.
72 torch::Tensor act
73 = (activate_ ? activation(allAlphas_[i][l]) : allAlphas_[i][l]);
// recs_[i][l] is passed as the last argument, presumably as the output of
// the interpolation -- confirm in getMultiInterpolation().
74 layers_[l].getMultiInterpolation(layers_[l].getOriginPrime(),
75 layers_[l].getVSPrimeTensor(), act,
76 recs_[i][l]);
77 }
78 }
79 }
80
81 // --- Postprocessing
// Run postprocessingPipeline on the tree of every origin and origin prime.
82 for(unsigned int l = 0; l < originsCopy_.size(); ++l) {
83 postprocessingPipeline<float>(&(originsCopy_[l].mTree.tree));
84 postprocessingPipeline<float>(&(originsPrimeCopy_[l].mTree.tree));
85 }
86 if(!recs_.empty()) {
// NOTE(review): recs_[0].size() bounds j for every i -- assumes all rows of
// recs_ have the same length.
87 for(unsigned int j = 0; j < recs_[0].size(); ++j) {
88 for(unsigned int i = 0; i < recs_.size(); ++i) {
89 fixTreePrecisionScalars(recs_[i][j].mTree);
90 postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
91 }
92 }
93 }
94#endif
95}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
void execute(std::vector< ttk::ftm::MergeTree< float > > &originsTrees, std::vector< ttk::ftm::MergeTree< float > > &originsPrimeTrees, std::vector< unsigned int * > &allRevNodeCorr, std::vector< unsigned int * > &allRevNodeCorrPrime, std::vector< unsigned int > &allRevNodeCorrSize, std::vector< unsigned int > &allRevNodeCorrPrimeSize)
std::vector< std::vector< float > > customAlphas_
void preprocessingPipeline(ftm::MergeTree< dataType > &mTree, double epsilonTree, double epsilon2Tree, double epsilon3Tree, bool branchDecompositionT, bool useMinMaxPairT, bool cleanTreeT, double persistenceThreshold, std::vector< int > &nodeCorr, bool deleteInconsistentNodes=true)
void postprocessingPipeline(ftm::FTMTree_MT *tree)