TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoder.h
Go to the documentation of this file.
1
21
22#pragma once
23
24// ttk common includes
25#include <Debug.h>
26#include <Geometry.h>
29#include <MergeTreeTorchUtils.h>
30
31#ifdef TTK_ENABLE_TORCH
32#include <torch/torch.h>
33#endif
34
35namespace ttk {
36
// MergeTreeAutoencoder: class declaration.
//
// Visible contract:
//  - Virtually inherits ttk::Debug. The base-class list continues on a source
//    line that is absent from this listing; the member functions marked
//    `override` below override virtual functions declared in that other
//    base class (or classes).
//  - All member functions visible here, and all torch-typed data members,
//    are declared only when TTK_ENABLE_TORCH is defined.
//
// NOTE(review): several source lines (e.g. 42, 46-47, 49-56, 59-61, 76, 80)
// are missing from this listing; these comments describe only the visible
// declarations.
41 class MergeTreeAutoencoder : virtual public Debug,
43
44 protected:
45 // Model hyper-parameters
// Defaults to 16. Presumably the number of axes of the input layer/basis --
// TODO confirm against the implementation.
48 unsigned int inputNumberOfAxes_ = 16;
// Custom-loss switches, both disabled by default.
// NOTE(review): the exact difference between "Space" and "Activate" is not
// visible in this header.
57 bool customLossSpace_ = false;
58 bool customLossActivate_ = false;
62
63 // Old hyper-parameters
// Legacy option, disabled by default. NOTE(review): confirm whether it is
// still read by the implementation.
64 bool fullSymmetricAE_ = false;
65
66#ifdef TTK_ENABLE_TORCH
67 // Model optimized parameters
// Latent-space centroids, plus (presumably) the best snapshot of them seen
// during optimization -- TODO confirm.
68 std::vector<torch::Tensor> bestLatentCentroids_, latentCentroids_;
69
// Tensor copies; presumably saved/restored by copyCustomParams(bool get) --
// TODO confirm.
70 std::vector<torch::Tensor> vSTensorCopy_, vSPrimeTensorCopy_;
71
// Custom reconstructions; presumably populated by createCustomRecs().
72 std::vector<mtu::TorchMergeTree<float>> customRecs_;
73#endif
74
75 // Filled by the algorithm
77 std::vector<std::vector<float>> distanceMatrix_, customAlphas_;
78
79 public:
81
82#ifdef TTK_ENABLE_TORCH
83 // -----------------------------------------------------------------------
84 // --- Init
85 // -----------------------------------------------------------------------
// Initializes the state used by the clustering loss (see
// computeClusteringLoss).
86 void initClusteringLossParameters();
87
// Base-class override. Takes the layer index `l`, its number of axes, and a
// size percentage for the output origin; the tree vectors and isTrain flags
// are passed by non-const reference. The meaning of the returned bool is not
// visible here -- TODO confirm.
88 bool initResetOutputBasis(unsigned int l,
89 unsigned int layerNoAxes,
90 double layerOriginPrimeSizePercent,
91 std::vector<mtu::TorchMergeTree<float>> &trees,
92 std::vector<mtu::TorchMergeTree<float>> &trees2,
93 std::vector<bool> &isTrain) override;
94
// Non-virtual helper for the output-basis initialization of layer `l`;
// presumably called from initResetOutputBasis -- TODO confirm.
95 void initOutputBasisSpecialCase(
96 unsigned int l,
97 unsigned int layerNoAxes,
98 std::vector<mtu::TorchMergeTree<float>> &trees,
99 std::vector<mtu::TorchMergeTree<float>> &trees2);
100
// Base-class override; returns a float (presumably an error value when
// computeError is true -- TODO confirm).
// NOTE(review): default arguments on virtual functions are bound to the
// static type of the call, so a call through a base-class reference uses the
// base declaration's default, not this one.
101 float initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
102 std::vector<mtu::TorchMergeTree<float>> &trees2,
103 std::vector<bool> &isTrain,
104 bool computeError = false) override;
105
106 // -----------------------------------------------------------------------
107 // --- Backward
108 // -----------------------------------------------------------------------
// Base-class override performing one backward/optimization step with the
// given torch optimizer. Inputs, outputs, matchings and alphas are passed for
// both tree sets (trees/trees2); custom loss terms are passed in
// torchCustomLoss. The meaning of the returned bool is not visible here --
// TODO confirm.
109 bool backwardStep(
110 std::vector<mtu::TorchMergeTree<float>> &trees,
111 std::vector<mtu::TorchMergeTree<float>> &outs,
112 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
113 &matchings,
114 std::vector<mtu::TorchMergeTree<float>> &trees2,
115 std::vector<mtu::TorchMergeTree<float>> &outs2,
116 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
117 &matchings2,
118 std::vector<std::vector<torch::Tensor>> &alphas,
119 torch::optim::Optimizer &optimizer,
120 std::vector<unsigned int> &indexes,
121 std::vector<bool> &isTrain,
122 std::vector<torch::Tensor> &torchCustomLoss) override;
123
124 // -----------------------------------------------------------------------
125 // --- Convergence
126 // -----------------------------------------------------------------------
// Base-class override returning the loss of a single tree (index treeIndex)
// as a float, given the tree, its output, the matchings and the alphas.
127 float computeOneLoss(
128 mtu::TorchMergeTree<float> &tree,
129 mtu::TorchMergeTree<float> &out,
130 mtu::TorchMergeTree<float> &tree2,
131 mtu::TorchMergeTree<float> &out2,
132 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
133 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
134 std::vector<torch::Tensor> &alphas,
135 unsigned int treeIndex) override;
136
137 // -----------------------------------------------------------------------
138 // --- Main Functions
139 // -----------------------------------------------------------------------
// Base-class override: autoencoder-specific initialization hook for the
// torch trees.
140 void
141 customInit(std::vector<mtu::TorchMergeTree<float>> &torchTrees,
142 std::vector<mtu::TorchMergeTree<float>> &torchTrees2) override;
143
// Base-class override. `parameters` is a non-const reference, presumably so
// that the autoencoder's extra optimizable tensors can be added to it --
// TODO confirm.
144 void addCustomParameters(std::vector<torch::Tensor> &parameters) override;
145
// Base-class override computing the custom loss terms for `iteration`.
// Outputs are presumably written to gapCustomLosses, iterationCustomLosses
// and torchCustomLoss (all non-const references) -- TODO confirm.
146 void computeCustomLosses(
147 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
148 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
149 std::vector<std::vector<torch::Tensor>> &bestAlphas,
150 std::vector<unsigned int> &indexes,
151 std::vector<bool> &isTrain,
152 unsigned int iteration,
153 std::vector<std::vector<float>> &gapCustomLosses,
154 std::vector<std::vector<float>> &iterationCustomLosses,
155 std::vector<torch::Tensor> &torchCustomLoss) override;
156
// Base-class override returning the iteration's total loss as a float,
// starting from iterationLoss and the custom losses.
157 float computeIterationTotalLoss(
158 float iterationLoss,
159 std::vector<std::vector<float>> &iterationCustomLosses,
160 std::vector<float> &iterationCustomLoss) override;
161
// Base-class override printing the custom losses with the given prefix and
// debug priority.
// NOTE(review): the default argument on this virtual override is bound
// statically; calls through a base-class reference use the base's default.
162 void printCustomLosses(std::vector<float> &customLoss,
163 std::stringstream &prefix,
164 const debug::Priority &priority
165 = debug::Priority::INFO) override;
166
// Base-class override printing the gap loss together with the custom gap
// losses.
167 void
168 printGapLoss(float loss,
169 std::vector<std::vector<float>> &gapCustomLosses) override;
170
171 // -----------------------------------------------------------------------
172 // --- Custom Losses
173 // -----------------------------------------------------------------------
// Returns a dynamic weight for a custom loss given the reconstruction loss.
// baseLoss is a non-const reference and may be updated by the call.
174 double getCustomLossDynamicWeight(double recLoss, double &baseLoss);
175
// Computes the metric loss into `metricLoss`, presumably comparing latent
// distances against baseDistanceMatrix -- TODO confirm.
// NOTE(review): `alphas` is taken by value, unlike every sibling method, so
// each call copies the nested vectors of tensor handles; consider `const &`.
176 void computeMetricLoss(
177 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
178 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
179 std::vector<std::vector<torch::Tensor>> alphas,
180 std::vector<std::vector<float>> &baseDistanceMatrix,
181 std::vector<unsigned int> &indexes,
182 torch::Tensor &metricLoss);
183
// Computes the clustering loss into `clusteringLoss`; `asgn` is also an
// output (presumably cluster assignments -- TODO confirm).
184 void computeClusteringLoss(std::vector<std::vector<torch::Tensor>> &alphas,
185 std::vector<unsigned int> &indexes,
186 torch::Tensor &clusteringLoss,
187 torch::Tensor &asgn);
188
// Computes the tracking loss into `trackingLoss`.
189 void computeTrackingLoss(torch::Tensor &trackingLoss);
190
191 // ---------------------------------------------------------------------------
192 // --- End Functions
193 // ---------------------------------------------------------------------------
// Builds the custom reconstructions (presumably stored in customRecs_).
194 void createCustomRecs();
195
196 // -----------------------------------------------------------------------
197 // --- Utils
198 // -----------------------------------------------------------------------
// Returns the index of the latent layer.
199 unsigned int getLatentLayerIndex();
200
// Base-class override; the `get` flag presumably selects the copy direction
// (restore vs. save) of the custom parameters -- TODO confirm.
201 void copyCustomParams(bool get) override;
202
// NOTE(review): this is the second "Main Functions" section header in this
// class; consider renaming it so the two sections are distinguishable.
203 // ---------------------------------------------------------------------------
204 // --- Main Functions
205 // ---------------------------------------------------------------------------
// Base-class override run at the end of execution on the input merge trees.
206 void
207 executeEndFunction(std::vector<ftm::MergeTree<float>> &trees,
208 std::vector<ftm::MergeTree<float>> &trees2) override;
209#endif
210 }; // MergeTreeAutoencoder class
211
212} // namespace ttk
std::vector< std::vector< float > > distanceMatrix_
std::vector< std::vector< float > > customAlphas_
TTK base package defining the standard types.