TTK
Loading...
Searching...
No Matches
TopologicalLoss.cpp
Go to the documentation of this file.
1#include "TopologicalLoss.h"
2
3#ifdef TTK_ENABLE_TORCH
4
5using namespace torch::indexing;
6
7ttk::TopologicalLoss::TopologicalLoss(
8 const torch::Tensor &input,
9 const std::vector<std::vector<double>> &points,
10 REGUL regul)
11 : input_(input), points_(points), regul_(regul),
12 device(input.device().type()) {
13 precomputeInputPersistence();
14}
15
16torch::Tensor ttk::TopologicalLoss::computeLoss(const torch::Tensor &latent) {
17 latent_ = latent;
18 latentDimension = latent.size(1);
19
20 if(regul_ == REGUL::TOPOAE)
21 return diffTopoAELoss();
22 else if(regul_ == REGUL::TOPOAE_DIM1)
23 return diffTopoAELossDim1();
24 else if(regul_ == REGUL::CASCADE)
25 return diffCascadeAELoss();
26 else if(regul_ == REGUL::ASYMMETRIC_CASCADE)
27 return diffAsymmetricCascadeAELoss();
28 else if(regul_ == REGUL::W_DIM1)
29 return diffW1() + diffTopoAELoss();
30
31 return torch::zeros({1}, device);
32}
33
// Precomputes, once, the input-side persistence data used by the selected
// regularization regul_ (computeLoss only updates the latent side):
//  - TOPOAE and W_DIM1: the 0-dimensional critical edges of the input point
//    cloud (Ripser up to dimension 0), stored in inputCriticalPairIndices[0];
//  - TOPOAE_DIM1: the three critical edge sets returned by Ripser up to
//    dimension 1, stored in inputCriticalPairIndices[0..2];
//  - W_DIM1 (in addition to the first case): the input diagram inputPD
//    (Ripser up to dimension 1) and the auction object later used by
//    performAuction;
//  - CASCADE and ASYMMETRIC_CASCADE: the four edge sets produced by
//    PairCellsWithOracle (see the [0]..[3] comment below), stored in
//    inputCriticalPairIndices[0..3].
// Edge sets are converted into index tensors on `device` by pairsToTorch.
34void ttk::TopologicalLoss::precomputeInputPersistence() {
35 if(regul_ == REGUL::TOPOAE || regul_ == REGUL::W_DIM1) {
36 rpd::EdgeSets3 inputCritical;
37 ripser::ripser(points_, inputCritical, rpd::inf, 0, false);
38 inputCriticalPairIndices = {pairsToTorch(inputCritical[0])};
39 }
// NOTE(review): the branches below assign inputCriticalPairIndices[i]
// directly, which is only valid if that container already holds at least
// 3 (resp. 4) slots (e.g. a fixed-size std::array) — confirm in the header.
40 if(regul_ == REGUL::TOPOAE_DIM1) {
41 rpd::EdgeSets3 inputCritical;
42 ripser::ripser(points_, inputCritical, rpd::inf, 1, false);
43 for(int i = 0; i <= 2; ++i)
44 inputCriticalPairIndices[i] = pairsToTorch(inputCritical[i]);
45 } else if(regul_ == REGUL::W_DIM1) {
46 ripser::ripser(points_, inputPD, rpd::inf, 1, false);
47 auction = std::make_unique<
// NOTE(review): the rest of this statement is missing from this listing
// (its numbering skips from 47 to 49); restore it from the upstream source.
49 } else if(regul_ == REGUL::CASCADE || regul_ == REGUL::ASYMMETRIC_CASCADE) {
50 // first compute the PD with Ripser
// NOTE(review): the listing skips line 51 here; presumably it fills inputPD,
// which PairCellsWithOracle consumes below — confirm.
52 // use it to quickly compute the cascade
53 rpd::PairCellsWithOracle pc(points_, inputPD, false, false);
54 pc.run();
55
56 rpd::EdgeSets4 inputCriticalAndCascade; // [0] is MST, [1] is RNG-MST, [2]
57 // is MML, [3] is strict cascade
58 pc.getCascades(inputCriticalAndCascade);
59 for(int i = 0; i <= 3; ++i)
60 inputCriticalPairIndices[i] = pairsToTorch(inputCriticalAndCascade[i]);
61 }
62}
63
64void ttk::TopologicalLoss::computeLatent0Persistence(
65 rpd::EdgeSet &latent0PD) const {
66#ifdef TTK_ENABLE_CGAL
67 if(latentDimension == 2)
68 rpd::FastRipsPersistenceDiagram2(
69 latent_.cpu().data_ptr<float>(), latent_.size(0))
70 .compute0Persistence(latent0PD);
71 else {
72 rpd::EdgeSets3 latentPD;
73 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
74 latent_.size(1), latentPD, rpd::inf, 0, false);
75 latent0PD = latentPD[0];
76 }
77#else
78 rpd::EdgeSets3 latentPD;
79 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
80 latent_.size(1), latentPD, rpd::inf, 0, false);
81 latent0PD = latentPD[0];
82#endif
83}
84
85template <typename PersistenceType>
86void ttk::TopologicalLoss::computeLatent0And1Persistence(
87 PersistenceType &latentPD) const {
88#ifdef TTK_ENABLE_CGAL
89 if(latentDimension == 2)
90 rpd::FastRipsPersistenceDiagram2(
91 latent_.cpu().data_ptr<float>(), latent_.size(0))
92 .computeRips0And1Persistence(latentPD, false, false);
93 else
94 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
95 latent_.size(1), latentPD, rpd::inf, 1, false);
96#else
97 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
98 latent_.size(1), latentPD, rpd::inf, 1, false);
99#endif
100}
101template void ttk::TopologicalLoss::computeLatent0And1Persistence(
102 rpd::MultidimensionalDiagram &latentPD) const;
103template void ttk::TopologicalLoss::computeLatent0And1Persistence(
104 rpd::EdgeSets3 &latentPD) const;
105
// Computes the critical edges and cascade edges of the current latent point
// cloud latent_ into latentCriticalAndCascades (an rpd::EdgeSets4).
// With CGAL and a 2D latent space, FastRipsPersistenceDiagram2 fills it
// directly; otherwise Ripser is run up to dimension 1 to obtain latentPD,
// which is then passed to PairCellsWithOracle together with the latent points.
// NOTE(review): the raw pointer from latent_.cpu().data_ptr<float>() carries
// no strides, so latent_ presumably has to be contiguous — confirm.
106void ttk::TopologicalLoss::computeLatentCascades(
107 rpd::EdgeSets4 &latentCriticalAndCascades) const {
108#ifdef TTK_ENABLE_CGAL
109 if(latentDimension == 2)
110 rpd::FastRipsPersistenceDiagram2(
111 latent_.cpu().data_ptr<float>(), latent_.size(0))
112 .computeRips0And1Persistence(latentCriticalAndCascades, false, false);
113 else {
// NOTE(review): the listing skips line 114 here; presumably it declares
// latentPD (filled by Ripser below) — restore it from the upstream source.
115 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
116 latent_.size(1), latentPD, rpd::inf, 1, false);
// NOTE(review): latent_.cpu() yields a temporary destroyed at the end of the
// declaration of pc. If latent_ is not on the CPU and PairCellsWithOracle
// keeps the raw pointer instead of copying the points, pc.run() would read
// freed memory — confirm, or hoist a single CPU copy of latent_ for the whole
// function (which would also avoid calling .cpu() repeatedly).
117 rpd::PairCellsWithOracle pc(latent_.cpu().data_ptr<float>(),
118 latent_.size(0), latent_.size(1), latentPD,
119 false);
120 pc.run();
121 pc.getCascades(latentCriticalAndCascades);
122 }
123#else
// NOTE(review): the listing skips line 124 here; presumably it declares
// latentPD (filled by Ripser below) — restore it from the upstream source.
125 ripser::ripser(latent_.cpu().data_ptr<float>(), latent_.size(0),
126 latent_.size(1), latentPD, rpd::inf, 1, false);
// NOTE(review): same temporary-lifetime concern as in the branch above.
127 rpd::PairCellsWithOracle pc(latent_.cpu().data_ptr<float>(), latent_.size(0),
128 latent_.size(1), latentPD, false);
129 pc.run();
130 pc.getCascades(latentCriticalAndCascades);
131#endif
132}
133
134/*** TopoAE-like distances ***/
135
136torch::Tensor ttk::TopologicalLoss::diffTopoAELoss() const {
137 rpd::EdgeSet latent0Critical;
138 computeLatent0Persistence(latent0Critical);
139
140 return diffEdgeSetMSE(inputCriticalPairIndices[0])
141 + diffEdgeSetMSE(pairsToTorch(latent0Critical));
142}
143
144torch::Tensor ttk::TopologicalLoss::diffTopoAELossDim1() const {
145 rpd::EdgeSets3 latentCritical;
146 computeLatent0And1Persistence(latentCritical);
147 return diffRNGMML(latentCritical);
148}
149
150torch::Tensor ttk::TopologicalLoss::diffCascadeAELoss() const {
151 rpd::EdgeSets4 latentCriticalAndCascade;
152 computeLatentCascades(latentCriticalAndCascade);
153
154 return diffRNGMML(latentCriticalAndCascade)
155 + diffEdgeSetMSE(inputCriticalPairIndices[3])
156 + diffEdgeSetMSE(pairsToTorch(latentCriticalAndCascade[rpd::CASC1]));
157}
158
159torch::Tensor ttk::TopologicalLoss::diffAsymmetricCascadeAELoss() const {
160 rpd::EdgeSets3 latentCritical;
161 computeLatent0And1Persistence(latentCritical);
162
163 return diffRNGMML(latentCritical)
164 + diffEdgeSetMSE(inputCriticalPairIndices[3]);
165}
166
167/*** Wasserstein distances ***/
168
169void ttk::TopologicalLoss::performAuction(
170 const rpd::Diagram &latentPD,
171 std::vector<unsigned> &directMatchingLatent,
172 std::vector<unsigned> &directMatchingInput,
173 std::vector<unsigned> &diagonalMatchingLatent) const {
174 auction->setNewBidder(latentPD);
175 std::vector<MatchingType> matchings;
176 auction->runAuction(matchings);
177
178 for(const MatchingType &m : matchings) {
179 if(std::get<0>(m) < 0)
180 break;
181 else {
182 if(std::get<1>(m) >= 0) {
183 directMatchingLatent.push_back(std::get<0>(m));
184 directMatchingInput.push_back(std::get<1>(m));
185 } else
186 diagonalMatchingLatent.push_back(std::get<0>(m));
187 }
188 }
189}
190
191torch::Tensor ttk::TopologicalLoss::diffW1() const {
192 std::vector<rpd::Diagram> latentPD(0);
193 computeLatent0And1Persistence(latentPD);
194
195 if(latentPD[1].empty())
196 return torch::zeros(1, device);
197
198 std::vector<unsigned> directMatchingLatent(0);
199 std::vector<unsigned> directMatchingInput(0);
200 std::vector<unsigned> diagonalMatchingLatent(0);
201 performAuction(latentPD[1], directMatchingLatent, directMatchingInput,
202 diagonalMatchingLatent);
203
204 torch::Tensor directMatchedInputPD
205 = torch::zeros({int(directMatchingInput.size()), 2});
206 float *inputData = directMatchedInputPD.data_ptr<float>();
207 for(unsigned i = 0; i < directMatchingInput.size(); ++i) {
208 inputData[2 * i]
209 = float(inputPD[1][directMatchingInput[i]].first.second); // birth
210 inputData[2 * i + 1]
211 = float(inputPD[1][directMatchingInput[i]].second.second); // death
212 }
213 directMatchedInputPD = directMatchedInputPD.to(device);
214
215 const torch::Tensor directMatchedLatentPD
216 = diffPD(latent_, latentPD[1], directMatchingLatent);
217 const torch::Tensor diagonalMatchedLatentPD
218 = diffPD(latent_, latentPD[1], diagonalMatchingLatent);
219 const torch::Tensor diagProj
220 = sqrt(2) / 2
221 * (diagonalMatchedLatentPD.index({Slice(), 1})
222 - diagonalMatchedLatentPD.index({Slice(), 0}));
223
224 const torch::Tensor costs = torch::cat(
225 {(directMatchedInputPD - directMatchedLatentPD).norm(2, 1), diagProj});
226 return costs.norm(2);
227}
228
229/*** tensor tools ***/
230
231torch::Tensor
232 ttk::TopologicalLoss::pairsToTorch(const rpd::EdgeSet &edges) const {
233 const torch::Tensor pairIndices
234 = torch::zeros({(int)edges.size(), 2}, torch::kInt);
235 int *data = pairIndices.data_ptr<int>();
236 for(unsigned i = 0; i < edges.size(); ++i) {
237 data[2 * i] = edges[i].first;
238 data[2 * i + 1] = edges[i].second;
239 }
240 return pairIndices.to(device);
241}
242
243torch::Tensor
244 ttk::TopologicalLoss::diffDistances(const torch::Tensor &data,
245 const torch::Tensor &indices) {
246 const torch::Tensor points = data.index({indices});
247 return (points.index({Slice(), 0}) - points.index({Slice(), 1})).norm(2, 1);
248}
249
250torch::Tensor
251 ttk::TopologicalLoss::diffEdgeSetMSE(const torch::Tensor &indices) const {
252 if(indices.size(0) > 0) // compute only if the edge set is non-empty
253 return torch::mse_loss(diffDistances(latent_, indices),
254 diffDistances(input_, indices), reduction_);
255 else
256 return torch::zeros({1}, device);
257}
258
259torch::Tensor
260 ttk::TopologicalLoss::diffPD(const torch::Tensor &points,
261 const rpd::Diagram &PD,
262 const std::vector<unsigned int> &indices) const {
263 torch::Tensor edgesIndices
264 = torch::zeros({int(indices.size()), 2, 2}, torch::kInt);
265 int *data = edgesIndices.data_ptr<int>();
266 for(unsigned i = 0; i < indices.size(); ++i) {
267 const unsigned index = indices[i];
268 data[4 * i] = PD[index].first.first[0];
269 data[4 * i + 1] = PD[index].first.first[1];
270 data[4 * i + 2] = PD[index].second.first[0];
271 data[4 * i + 3] = PD[index].second.first[1];
272 }
273 edgesIndices = edgesIndices.to(device);
274 const torch::Tensor PDPoints = points.index({edgesIndices});
275 return (PDPoints.index({Slice(), Slice(), 0})
276 - PDPoints.index({Slice(), Slice(), 1}))
277 .norm(2, 2);
278}
279
280#endif
TTK base class that computes the Wasserstein distance between two persistence diagrams.
TTK base class that partially executes on a Rips complex the PairCells persistence algorithm where ne...
static void callOracle(const PointCloud &points, MultidimensionalDiagram &oracle, double threshold=inf, bool distanceMatrix=false)
void ripser(std::vector< std::vector< value_t > > points, PersistenceType &ph, value_t threshold, index_t dim_max, bool distanceMatrix, bool criticalEdgesOnly=true, bool infinitePairs=true, coefficient_t modulus=2)
Definition ripser.cpp:1136
std::array< EdgeSet, 4 > EdgeSets4
constexpr value_t inf
std::vector< Edge > EdgeSet
std::vector< Diagram > MultidimensionalDiagram
std::array< EdgeSet, 3 > EdgeSets3
std::vector< PersistencePair > Diagram
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.