TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoder.h
Go to the documentation of this file.
1
21
22#pragma once
23
24// ttk common includes
25#include <Debug.h>
26#include <Geometry.h>
28#include <MergeTreeTorchUtils.h>
29
30#ifdef TTK_ENABLE_TORCH
31#include <torch/torch.h>
32#endif
33
34namespace ttk {
35
// NOTE(review): this listing is a Doxygen source-browser rendering, not the
// header itself. Each line carries a fused original line number, indentation
// is collapsed, and some source lines are absent (the numbering jumps, e.g.
// 40 -> 42, 47 -> 50, 60 -> 67, 101 -> 104, 131 -> 133). Regenerate from the
// real MergeTreeAutoencoder.h before compiling or editing declarations.
//
// MergeTreeAutoencoder: class in namespace ttk that derives virtually from
// Debug. The base-class list is cut after "Debug," because source line 41 is
// missing here. Of the visible member functions, all except execute() are
// declared only when TTK_ENABLE_TORCH is defined.
40 class MergeTreeAutoencoder : virtual public Debug,
42
43 protected:
// Flag, initially false. By its name it records whether a computation has
// already run — confirm where it is set and read.
45 bool hasComputedOnce_ = false;
46
47 // Model hyper-parameters;
// NOTE(review): source lines 48-49, 51-52, 61-66, 69 and 71-74 are absent
// from this listing, so the hyper-parameter list below is incomplete.
50 unsigned int inputNumberOfAxes_ = 16;
53 unsigned int minIteration_ = 0;
54 unsigned int maxIteration_ = 0;
55 unsigned int iterationGap_ = 100;
// Stored as a double, so it may be a fraction of the dataset rather than an
// absolute count — confirm in the implementation.
56 double batchSize_ = 1;
// Optimizer selector; the value-to-optimizer mapping is not visible here.
57 int optimizer_ = 0;
58 double gradientStepSize_ = 0.1;
// Presumably Adam-style moment-decay coefficients — confirm.
59 double beta1_ = 0.9;
60 double beta2_ = 0.999;
67 bool customLossSpace_ = false;
68 bool customLossActivate_ = false;
70 unsigned int noInit_ = 4;
// Activation settings. activationFunction_ is an index whose mapping is not
// visible here (see the activation() declaration below).
75 bool activate_ = true;
76 unsigned int activationFunction_ = 1;
77 bool activateOutputInit_ = false;
78
79 bool createOutput_ = true;
80
81 // Old hyper-parameters
82 bool fullSymmetricAE_ = false;
83
84#ifdef TTK_ENABLE_TORCH
85 // Model optimized parameters
// Torch-only state. Presumably one entry per layer, since several methods
// below take a layer index l — confirm.
86 std::vector<torch::Tensor> vSTensor_, vSPrimeTensor_, vS2Tensor_,
87 vS2PrimeTensor_, latentCentroids_;
88 std::vector<mtu::TorchMergeTree<float>> origins_, originsPrime_, origins2_,
89 origins2Prime_;
90
// Copies of origins_ / originsPrime_, presumably filled via copyParams() —
// confirm.
91 std::vector<mtu::TorchMergeTree<float>> originsCopy_, originsPrimeCopy_;
92
93 // Filled by the algorithm
94 std::vector<std::vector<torch::Tensor>> allAlphas_, allScaledAlphas_,
95 allActAlphas_, allActScaledAlphas_;
96 std::vector<std::vector<mtu::TorchMergeTree<float>>> recs_, recs2_;
97 std::vector<mtu::TorchMergeTree<float>> customRecs_;
98#endif
99
100 // Filled by the algorithm
// NOTE(review): noLayers_ has no in-class initialiser, unlike the
// hyper-parameters above. Its value is indeterminate unless a constructor or
// the algorithm assigns it first. Source lines 102-103 are absent here.
101 unsigned noLayers_;
104 std::vector<unsigned int> clusterAsgn_;
105 std::vector<std::vector<float>> distanceMatrix_, customAlphas_;
106 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// NOTE(review): the declarator (source line 107) is missing from this
// listing. Presumably baryMatchings_L0_, baryMatchings2_L0_, which Doxygen's
// symbol list shows with this type — confirm.
108 std::vector<double> inputToBaryDistances_L0_;
109
110 // Tracking matchings
111 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// NOTE(review): declarators on source line 112 are missing. Presumably
// originsMatchings_, reconstMatchings_, customMatchings_ — confirm.
113 std::vector<
114 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>>
// NOTE(review): declarator on source line 115 is missing. Presumably
// dataMatchings_ — confirm.
116 std::vector<std::vector<double>> branchesCorrelationMatrix_,
// NOTE(review): continuation on source line 117 is missing. Presumably
// persCorrelationMatrix_ — confirm.
118
119 // Testing
// NOTE(review): source lines 120 and 122-123 are missing. The continuation
// presumably declares vSNoZeroGrad_, vSPrimeNoZeroGrad_, origins2NoZeroGrad_,
// origins2PrimeNoZeroGrad_, vS2NoZeroGrad_, vS2PrimeNoZeroGrad_ — confirm.
121 std::vector<unsigned int> originsNoZeroGrad_, originsPrimeNoZeroGrad_,
124 bool outputInit_ = true;
125#ifdef TTK_ENABLE_TORCH
126 std::vector<mtu::TorchMergeTree<float>> initOrigins_, initOriginsPrime_,
127 initRecs_;
128#endif
130
131 public:
// NOTE(review): source line 132 is absent; anything declared there is not
// visible in this listing.
133
// Everything from here to the matching #endif (source line 515) is declared
// only when TTK_ENABLE_TORCH is defined.
134#ifdef TTK_ENABLE_TORCH
135 // -----------------------------------------------------------------------
136 // --- Init
137 // -----------------------------------------------------------------------
138 void initOutputBasisTreeStructure(mtu::TorchMergeTree<float> &originPrime,
139 bool isJT,
140 mtu::TorchMergeTree<float> &baseOrigin);
141
142 void initOutputBasis(unsigned int l, unsigned int dim, unsigned int dim2);
143
// Two overloads for layer l: this one receives tensors w / w2, the next one
// receives the dimensions dim / dim2 instead.
144 void initOutputBasisVectors(unsigned int l,
145 torch::Tensor &w,
146 torch::Tensor &w2);
147
148 void initOutputBasisVectors(unsigned int l,
149 unsigned int dim,
150 unsigned int dim2);
151
152 void initInputBasisOrigin(
153 std::vector<ftm::MergeTree<float>> &treesToUse,
154 std::vector<ftm::MergeTree<float>> &trees2ToUse,
155 double barycenterSizeLimitPercent,
156 unsigned int barycenterMaxNoPairs,
157 unsigned int barycenterMaxNoPairs2,
158 mtu::TorchMergeTree<float> &origin,
159 mtu::TorchMergeTree<float> &origin2,
160 std::vector<double> &inputToBaryDistances,
161 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
162 &baryMatchings,
163 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
164 &baryMatchings2);
165
166 void initInputBasisVectors(
167 std::vector<mtu::TorchMergeTree<float>> &tmTreesToUse,
168 std::vector<mtu::TorchMergeTree<float>> &tmTrees2ToUse,
169 std::vector<ftm::MergeTree<float>> &treesToUse,
170 std::vector<ftm::MergeTree<float>> &trees2ToUse,
171 mtu::TorchMergeTree<float> &origin,
172 mtu::TorchMergeTree<float> &origin2,
173 unsigned int noVectors,
174 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
175 unsigned int l,
176 std::vector<double> &inputToBaryDistances,
177 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
178 &baryMatchings,
179 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
180 &baryMatchings2,
181 torch::Tensor &vSTensor,
182 torch::Tensor &vS2Tensor);
183
184 void initClusteringLossParameters();
185
// Returns a float; computeReconstructionError defaults to false.
186 float initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
187 std::vector<mtu::TorchMergeTree<float>> &trees2,
188 bool computeReconstructionError = false);
189
190 void initStep(std::vector<mtu::TorchMergeTree<float>> &trees,
191 std::vector<mtu::TorchMergeTree<float>> &trees2);
192
193 // -----------------------------------------------------------------------
194 // --- Interpolation
195 // -----------------------------------------------------------------------
196 void interpolationDiagonalProjection(
197 mtu::TorchMergeTree<float> &interpolationTensor);
198
199 void
200 interpolationNestingProjection(mtu::TorchMergeTree<float> &interpolation);
201
202 void interpolationProjection(mtu::TorchMergeTree<float> &interpolation);
203
204 void getMultiInterpolation(mtu::TorchMergeTree<float> &origin,
205 torch::Tensor &vS,
206 torch::Tensor &alphas,
207 mtu::TorchMergeTree<float> &interpolation);
208
209 // -----------------------------------------------------------------------
210 // --- Forward
211 // -----------------------------------------------------------------------
212 void getAlphasOptimizationTensors(
213 mtu::TorchMergeTree<float> &tree,
214 mtu::TorchMergeTree<float> &origin,
215 torch::Tensor &vSTensor,
216 mtu::TorchMergeTree<float> &interpolated,
217 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
218 torch::Tensor &reorderedTreeTensor,
219 torch::Tensor &deltaOrigin,
220 torch::Tensor &deltaA,
221 torch::Tensor &originTensor_f,
222 torch::Tensor &vSTensor_f);
223
224 void computeAlphas(
225 mtu::TorchMergeTree<float> &tree,
226 mtu::TorchMergeTree<float> &origin,
227 torch::Tensor &vSTensor,
228 mtu::TorchMergeTree<float> &interpolated,
229 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
230 mtu::TorchMergeTree<float> &tree2,
231 mtu::TorchMergeTree<float> &origin2,
232 torch::Tensor &vS2Tensor,
233 mtu::TorchMergeTree<float> &interpolated2,
234 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
235 torch::Tensor &alphasOut);
236
237 float assignmentOneData(
238 mtu::TorchMergeTree<float> &tree,
239 mtu::TorchMergeTree<float> &origin,
240 torch::Tensor &vSTensor,
241 mtu::TorchMergeTree<float> &tree2,
242 mtu::TorchMergeTree<float> &origin2,
243 torch::Tensor &vS2Tensor,
244 unsigned int k,
245 torch::Tensor &alphasInit,
246 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
247 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
248 torch::Tensor &bestAlphas,
249 bool isCalled = false);
250
// Overload of the above without the bestMatching / bestMatching2 parameters.
251 float assignmentOneData(mtu::TorchMergeTree<float> &tree,
252 mtu::TorchMergeTree<float> &origin,
253 torch::Tensor &vSTensor,
254 mtu::TorchMergeTree<float> &tree2,
255 mtu::TorchMergeTree<float> &origin2,
256 torch::Tensor &vS2Tensor,
257 unsigned int k,
258 torch::Tensor &alphasInit,
259 torch::Tensor &bestAlphas,
260 bool isCalled = false);
261
// Applies the activation configured by activate_ / activationFunction_,
// presumably — only the signature is visible here.
262 torch::Tensor activation(torch::Tensor &in);
263
264 void outputBasisReconstruction(mtu::TorchMergeTree<float> &originPrime,
265 torch::Tensor &vSPrimeTensor,
266 mtu::TorchMergeTree<float> &origin2Prime,
267 torch::Tensor &vS2PrimeTensor,
268 torch::Tensor &alphas,
269 mtu::TorchMergeTree<float> &out,
270 mtu::TorchMergeTree<float> &out2,
271 bool activate = true);
272
273 bool forwardOneLayer(mtu::TorchMergeTree<float> &tree,
274 mtu::TorchMergeTree<float> &origin,
275 torch::Tensor &vSTensor,
276 mtu::TorchMergeTree<float> &originPrime,
277 torch::Tensor &vSPrimeTensor,
278 mtu::TorchMergeTree<float> &tree2,
279 mtu::TorchMergeTree<float> &origin2,
280 torch::Tensor &vS2Tensor,
281 mtu::TorchMergeTree<float> &origin2Prime,
282 torch::Tensor &vS2PrimeTensor,
283 unsigned int k,
284 torch::Tensor &alphasInit,
285 mtu::TorchMergeTree<float> &out,
286 mtu::TorchMergeTree<float> &out2,
287 torch::Tensor &bestAlphas,
288 float &bestDistance);
289
// Overload of the above without the bestDistance out-parameter.
290 bool forwardOneLayer(mtu::TorchMergeTree<float> &tree,
291 mtu::TorchMergeTree<float> &origin,
292 torch::Tensor &vSTensor,
293 mtu::TorchMergeTree<float> &originPrime,
294 torch::Tensor &vSPrimeTensor,
295 mtu::TorchMergeTree<float> &tree2,
296 mtu::TorchMergeTree<float> &origin2,
297 torch::Tensor &vS2Tensor,
298 mtu::TorchMergeTree<float> &origin2Prime,
299 torch::Tensor &vS2PrimeTensor,
300 unsigned int k,
301 torch::Tensor &alphasInit,
302 mtu::TorchMergeTree<float> &out,
303 mtu::TorchMergeTree<float> &out2,
304 torch::Tensor &bestAlphas);
305
306 bool forwardOneData(mtu::TorchMergeTree<float> &tree,
307 mtu::TorchMergeTree<float> &tree2,
308 unsigned int treeIndex,
309 unsigned int k,
310 std::vector<torch::Tensor> &alphasInit,
311 mtu::TorchMergeTree<float> &out,
312 mtu::TorchMergeTree<float> &out2,
313 std::vector<torch::Tensor> &dataAlphas,
314 std::vector<mtu::TorchMergeTree<float>> &outs,
315 std::vector<mtu::TorchMergeTree<float>> &outs2);
316
317 bool forwardStep(
318 std::vector<mtu::TorchMergeTree<float>> &trees,
319 std::vector<mtu::TorchMergeTree<float>> &trees2,
320 std::vector<unsigned int> &indexes,
321 unsigned int k,
322 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
323 bool computeReconstructionError,
324 std::vector<mtu::TorchMergeTree<float>> &outs,
325 std::vector<mtu::TorchMergeTree<float>> &outs2,
326 std::vector<std::vector<torch::Tensor>> &bestAlphas,
327 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
328 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
329 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
330 &matchings,
331 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
332 &matchings2,
333 float &loss);
334
// Reduced overload. It has no computeReconstructionError flag and no
// layersOuts, layersOuts2, matchings, matchings2 or loss parameters.
335 bool forwardStep(std::vector<mtu::TorchMergeTree<float>> &trees,
336 std::vector<mtu::TorchMergeTree<float>> &trees2,
337 std::vector<unsigned int> &indexes,
338 unsigned int k,
339 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
340 std::vector<mtu::TorchMergeTree<float>> &outs,
341 std::vector<mtu::TorchMergeTree<float>> &outs2,
342 std::vector<std::vector<torch::Tensor>> &bestAlphas);
343
344 // -----------------------------------------------------------------------
345 // --- Backward
346 // -----------------------------------------------------------------------
// Receives the LibTorch optimizer by reference, together with the metric,
// clustering and tracking loss tensors.
347 bool backwardStep(
348 std::vector<mtu::TorchMergeTree<float>> &trees,
349 std::vector<mtu::TorchMergeTree<float>> &outs,
350 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
351 &matchings,
352 std::vector<mtu::TorchMergeTree<float>> &trees2,
353 std::vector<mtu::TorchMergeTree<float>> &outs2,
354 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
355 &matchings2,
356 torch::optim::Optimizer &optimizer,
357 std::vector<unsigned int> &indexes,
358 torch::Tensor &metricLoss,
359 torch::Tensor &clusteringLoss,
360 torch::Tensor &trackingLoss);
361
362 // -----------------------------------------------------------------------
363 // --- Projection
364 // -----------------------------------------------------------------------
365 void projectionStep();
366
367 // -----------------------------------------------------------------------
368 // --- Convergence
369 // -----------------------------------------------------------------------
370 float computeOneLoss(
371 mtu::TorchMergeTree<float> &tree,
372 mtu::TorchMergeTree<float> &out,
373 mtu::TorchMergeTree<float> &tree2,
374 mtu::TorchMergeTree<float> &out2,
375 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
376 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2);
377
378 float computeLoss(
379 std::vector<mtu::TorchMergeTree<float>> &trees,
380 std::vector<mtu::TorchMergeTree<float>> &outs,
381 std::vector<mtu::TorchMergeTree<float>> &trees2,
382 std::vector<mtu::TorchMergeTree<float>> &outs2,
383 std::vector<unsigned int> &indexes,
384 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
385 &matchings,
386 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
387 &matchings2);
388
// minLoss and cptBlocked are in-out state shared with convergenceStep().
389 bool isBestLoss(float loss, float &minLoss, unsigned int &cptBlocked);
390
391 bool convergenceStep(float loss,
392 float &oldLoss,
393 float &minLoss,
394 unsigned int &cptBlocked);
395
396 // -----------------------------------------------------------------------
397 // --- Main Functions
398 // -----------------------------------------------------------------------
399 void fit(std::vector<ftm::MergeTree<float>> &trees,
400 std::vector<ftm::MergeTree<float>> &trees2);
401
402 // -----------------------------------------------------------------------
403 // --- Custom Losses
404 // -----------------------------------------------------------------------
405 double getCustomLossDynamicWeight(double recLoss, double &baseLoss);
406
407 void getDistanceMatrix(std::vector<mtu::TorchMergeTree<float>> &tmts,
408 std::vector<std::vector<float>> &distanceMatrix,
409 bool useDoubleInput = false,
410 bool isFirstInput = true);
411
// Overload taking two tree vectors (tmts, tmts2) and no flags.
412 void getDistanceMatrix(std::vector<mtu::TorchMergeTree<float>> &tmts,
413 std::vector<mtu::TorchMergeTree<float>> &tmts2,
414 std::vector<std::vector<float>> &distanceMatrix);
415
416 void getDifferentiableDistanceFromMatchings(
417 mtu::TorchMergeTree<float> &tree1,
418 mtu::TorchMergeTree<float> &tree2,
419 mtu::TorchMergeTree<float> &tree1_2,
420 mtu::TorchMergeTree<float> &tree2_2,
421 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
422 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
423 torch::Tensor &tensorDist,
424 bool doSqrt);
425
426 void getDifferentiableDistance(mtu::TorchMergeTree<float> &tree1,
427 mtu::TorchMergeTree<float> &tree2,
428 mtu::TorchMergeTree<float> &tree1_2,
429 mtu::TorchMergeTree<float> &tree2_2,
430 torch::Tensor &tensorDist,
431 bool isCalled,
432 bool doSqrt);
433
// Overload of the above without the tree1_2 / tree2_2 parameters.
434 void getDifferentiableDistance(mtu::TorchMergeTree<float> &tree1,
435 mtu::TorchMergeTree<float> &tree2,
436 torch::Tensor &tensorDist,
437 bool isCalled,
438 bool doSqrt);
439
440 void getDifferentiableDistanceMatrix(
441 std::vector<mtu::TorchMergeTree<float> *> &trees,
442 std::vector<mtu::TorchMergeTree<float> *> &trees2,
443 std::vector<std::vector<torch::Tensor>> &outDistMat);
444
445 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
446 std::vector<unsigned int> &indexes,
447 unsigned int layerIndex,
448 torch::Tensor &alphasOut);
449
// NOTE(review): `alphas` is taken by value. Every other vector parameter in
// this class is taken by reference, so each call copies the nested vectors.
// Consider `std::vector<std::vector<torch::Tensor>> &alphas`; the header and
// the definition must change together.
450 void computeMetricLoss(
451 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
452 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
453 std::vector<std::vector<torch::Tensor>> alphas,
454 std::vector<std::vector<float>> &baseDistanceMatrix,
455 std::vector<unsigned int> &indexes,
456 torch::Tensor &metricLoss);
457
458 void computeClusteringLoss(std::vector<std::vector<torch::Tensor>> &alphas,
459 std::vector<unsigned int> &indexes,
460 torch::Tensor &clusteringLoss,
461 torch::Tensor &asgn);
462
463 void computeTrackingLoss(torch::Tensor &trackingLoss);
464
465 // ---------------------------------------------------------------------------
466 // --- End Functions
467 // ---------------------------------------------------------------------------
468 void
469 createCustomRecs(std::vector<mtu::TorchMergeTree<float>> &origins,
470 std::vector<mtu::TorchMergeTree<float>> &originsPrime);
471
472 void computeTrackingInformation();
473
// Two overloads: explicit inputs/outputs here, and a parameterless version
// below that presumably works on the member containers — confirm.
474 void
475 createScaledAlphas(std::vector<std::vector<torch::Tensor>> &alphas,
476 std::vector<torch::Tensor> &vSTensor,
477 std::vector<std::vector<torch::Tensor>> &scaledAlphas);
478
479 void createScaledAlphas();
480
481 void createActivatedAlphas();
482
483 // -----------------------------------------------------------------------
484 // --- Utils
485 // -----------------------------------------------------------------------
// Takes a src* container and a matching dst* container for each kind of
// parameter. Presumably it copies src into dst; only the signature is
// visible here.
486 void copyParams(std::vector<mtu::TorchMergeTree<float>> &srcOrigins,
487 std::vector<mtu::TorchMergeTree<float>> &srcOriginsPrime,
488 std::vector<torch::Tensor> &srcVS,
489 std::vector<torch::Tensor> &srcVSPrime,
490 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2,
491 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2Prime,
492 std::vector<torch::Tensor> &srcVS2,
493 std::vector<torch::Tensor> &srcVS2Prime,
494 std::vector<std::vector<torch::Tensor>> &srcAlphas,
495 std::vector<mtu::TorchMergeTree<float>> &dstOrigins,
496 std::vector<mtu::TorchMergeTree<float>> &dstOriginsPrime,
497 std::vector<torch::Tensor> &dstVS,
498 std::vector<torch::Tensor> &dstVSPrime,
499 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2,
500 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2Prime,
501 std::vector<torch::Tensor> &dstVS2,
502 std::vector<torch::Tensor> &dstVS2Prime,
503 std::vector<std::vector<torch::Tensor>> &dstAlphas);
504
505 void copyParams(std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
506 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst);
507
508 unsigned int getLatentLayerIndex();
509
510 // -----------------------------------------------------------------------
511 // --- Testing
512 // -----------------------------------------------------------------------
// threshold defaults to 10000.
513 bool isTreeHasBigValues(ftm::MergeTree<float> &mTree,
514 float threshold = 10000);
515#endif
516
517 // ---------------------------------------------------------------------------
518 // --- Main Functions
519 // ---------------------------------------------------------------------------
// Entry point. It is declared outside the TTK_ENABLE_TORCH guard, so it
// exists in every build; its behaviour without Torch is not visible here.
520 void execute(std::vector<ftm::MergeTree<float>> &trees,
521 std::vector<ftm::MergeTree<float>> &trees2);
522 }; // MergeTreeAutoencoder class
523
524} // namespace ttk
Minimalist debugging class.
Definition Debug.h:88
std::vector< std::vector< double > > branchesCorrelationMatrix_
std::vector< unsigned int > origins2NoZeroGrad_
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
std::vector< unsigned int > clusterAsgn_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
std::vector< unsigned int > originsPrimeNoZeroGrad_
std::vector< double > inputToBaryDistances_L0_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > customMatchings_
std::vector< std::vector< double > > persCorrelationMatrix_
std::vector< std::vector< float > > distanceMatrix_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings2_L0_
std::vector< unsigned int > originsNoZeroGrad_
std::vector< unsigned int > vS2NoZeroGrad_
std::vector< unsigned int > vS2PrimeNoZeroGrad_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings_L0_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
std::vector< std::vector< float > > customAlphas_
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
std::vector< unsigned int > origins2PrimeNoZeroGrad_
std::vector< unsigned int > vSPrimeNoZeroGrad_
std::vector< unsigned int > vSNoZeroGrad_
The Topology ToolKit.