TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoderDecoding.cpp
Go to the documentation of this file.
3
5 // inherited from Debug: prefix will be printed at the beginning of every msg
6 this->setDebugMsgPrefix("MergeTreeAutoencoderDecoding");
7}
8
10 std::vector<ttk::ftm::MergeTree<float>> &originsTrees,
11 std::vector<ttk::ftm::MergeTree<float>> &originsPrimeTrees,
12 std::vector<unsigned int *> &allRevNodeCorr,
13 std::vector<unsigned int *> &allRevNodeCorrPrime,
14 std::vector<unsigned int> &allRevNodeCorrSize,
15 std::vector<unsigned int> &allRevNodeCorrPrimeSize) {
16#ifndef TTK_ENABLE_TORCH
17 TTK_FORCE_USE(originsTrees);
18 TTK_FORCE_USE(originsPrimeTrees);
19 TTK_FORCE_USE(allRevNodeCorr);
20 TTK_FORCE_USE(allRevNodeCorrPrime);
21 TTK_FORCE_USE(allRevNodeCorrSize);
22 TTK_FORCE_USE(allRevNodeCorrPrimeSize);
23 printErr("This module requires Torch.");
24#else
25 // --- Preprocessing
26 if(not isPersistenceDiagram_) {
27 for(unsigned int i = 0; i < originsPrimeTrees.size(); ++i) {
28 bool const useMinMax = true;
29 bool const cleanTree = false;
30 bool const pt = 0.0;
31 std::vector<int> nodeCorr;
32 preprocessingPipeline<float>(originsTrees[i], 0.0, 100.0, 100.0,
33 branchDecomposition_, useMinMax, cleanTree,
34 pt, nodeCorr, false);
35 preprocessingPipeline<float>(originsPrimeTrees[i], 0.0, 100.0, 100.0,
36 branchDecomposition_, useMinMax, cleanTree,
37 pt, nodeCorr, false);
38 }
39 }
40 mergeTreesToTorchTrees(originsTrees, origins_, normalizedWasserstein_,
41 allRevNodeCorr, allRevNodeCorrSize);
42 mergeTreesToTorchTrees(originsPrimeTrees, originsPrime_,
43 normalizedWasserstein_, allRevNodeCorrPrime,
44 allRevNodeCorrPrimeSize);
45
46 // --- Execute
47 if(allAlphas_[0].size() != originsPrime_.size()) {
48 customAlphas_.resize(allAlphas_.size());
49 for(unsigned int i = 0; i < customAlphas_.size(); ++i)
50 customAlphas_[i] = std::vector<float>(
51 allAlphas_[i][0].data_ptr<float>(),
52 allAlphas_[i][0].data_ptr<float>() + allAlphas_[i][0].numel());
53 allAlphas_.clear();
54 createCustomRecs(origins_, originsPrime_);
55 } else {
56 recs_.resize(allAlphas_.size());
57 for(unsigned int i = 0; i < recs_.size(); ++i) {
58 recs_[i].resize(allAlphas_[i].size());
59 for(unsigned int l = 0; l < allAlphas_[i].size(); ++l) {
60 torch::Tensor act
61 = (activate_ ? activation(allAlphas_[i][l]) : allAlphas_[i][l]);
62 getMultiInterpolation(
63 originsPrime_[l], vSPrimeTensor_[l], act, recs_[i][l]);
64 }
65 }
66 }
67
68 // --- Postprocessing
69 for(unsigned int l = 0; l < origins_.size(); ++l) {
70 postprocessingPipeline<float>(&(origins_[l].mTree.tree));
71 postprocessingPipeline<float>(&(originsPrime_[l].mTree.tree));
72 }
73 if(!recs_.empty()) {
74 for(unsigned int j = 0; j < recs_[0].size(); ++j) {
75 for(unsigned int i = 0; i < recs_.size(); ++i) {
76 wae::fixTreePrecisionScalars(recs_[i][j].mTree);
77 postprocessingPipeline<float>(&(recs_[i][j].mTree.tree));
78 }
79 }
80 }
81#endif
82}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
void execute(std::vector< ttk::ftm::MergeTree< float > > &originsTrees, std::vector< ttk::ftm::MergeTree< float > > &originsPrimeTrees, std::vector< unsigned int * > &allRevNodeCorr, std::vector< unsigned int * > &allRevNodeCorrPrime, std::vector< unsigned int > &allRevNodeCorrSize, std::vector< unsigned int > &allRevNodeCorrPrimeSize)
void fixTreePrecisionScalars(ftm::MergeTree< float > &mTree)
Fix the scalars of a merge tree to ensure that the nesting condition is respected.