TTK
Loading...
Searching...
No Matches
TopologicalLoss.h
Go to the documentation of this file.
1
20
21#pragma once
22
23#include <PairCellsWithOracle.h>
26
27#ifdef TTK_ENABLE_TORCH
28#include <torch/torch.h>
29#endif
30
31namespace ttk {
33 public:
48
49#ifdef TTK_ENABLE_TORCH
50
51 TopologicalLoss(const torch::Tensor &input,
52 std::vector<std::vector<double>> const &points,
53 REGUL regul);
54
55 torch::Tensor computeLoss(const torch::Tensor &latent);
56
57 private:
58 const torch::Tensor input_;
59 const std::vector<std::vector<double>> &points_;
60 const REGUL regul_;
61 const torch::Reduction::Reduction reduction_{torch::Reduction::Mean};
62 const torch::DeviceType device{torch::kCPU};
63 torch::Tensor latent_;
64 int latentDimension;
65
66 /* persistence containers */
68 std::array<torch::Tensor, 4>
69 inputCriticalPairIndices; // [0] is MST, [1] is RNG-MST, [2] is MML, [3]
70 // is strict cascade if required
71 std::unique_ptr<PersistenceDiagramWarmRestartAuction<rpd::PersistencePair>>
72 auction{nullptr};
73
74 /* persistence computation methods */
75 void precomputeInputPersistence();
76 void computeLatent0Persistence(rpd::EdgeSet &latent0PD) const;
77 template <typename PersistenceType>
78 void computeLatent0And1Persistence(PersistenceType &latentPD) const;
79 void computeLatentCascades(rpd::EdgeSets4 &latentCriticalAndCascades) const;
80
81 /* tensor tools */
82 inline torch::Tensor pairsToTorch(const rpd::EdgeSet &edges) const;
83 static inline torch::Tensor diffDistances(const torch::Tensor &data,
84 const torch::Tensor &indices);
85 inline torch::Tensor diffEdgeSetMSE(const torch::Tensor &indices) const;
86 torch::Tensor diffPD(const torch::Tensor &points,
87 const rpd::Diagram &PD,
88 const std::vector<unsigned> &indices) const;
89
90 /* TopoAE-like distances */
91 template <typename EdgeSets>
92 inline torch::Tensor diffRNGMML(const EdgeSets &latentCritical) const {
93 return diffEdgeSetMSE(inputCriticalPairIndices[0])
94 + diffEdgeSetMSE(inputCriticalPairIndices[1])
95 + diffEdgeSetMSE(inputCriticalPairIndices[2])
96 + diffEdgeSetMSE(pairsToTorch(latentCritical[0]))
97 + diffEdgeSetMSE(pairsToTorch(latentCritical[1]))
98 + diffEdgeSetMSE(pairsToTorch(latentCritical[2]));
99 }
100 torch::Tensor diffTopoAELoss() const;
101 torch::Tensor diffTopoAELossDim1() const;
102 torch::Tensor diffCascadeAELoss() const;
103 torch::Tensor diffAsymmetricCascadeAELoss() const;
104
105 /* Wasserstein distances */
106 void performAuction(const rpd::Diagram &latentPD,
107 std::vector<unsigned> &directMatchingLatent,
108 std::vector<unsigned> &directMatchingInput,
109 std::vector<unsigned> &diagonalMatchingLatent) const;
110 torch::Tensor diffW1() const;
111
112#endif
113 };
114} // namespace ttk
TTK base class for representing differentiable topological losses to be used in dimension reduction.
EdgeSets4: alias for std::array<EdgeSet, 4>
EdgeSet: alias for std::vector<Edge>
MultidimensionalDiagram: alias for std::vector<Diagram>
Diagram: alias for std::vector<PersistencePair>
TTK base package defining the standard types.