5using namespace torch::indexing;
// Constructor: stores the input tensor, the point coordinates and the
// regularization mode (input_, points_, regul_), records the device type of
// `input` in `device`, then calls precomputeInputPersistence().
// NOTE(review): the `regul` parameter declaration and the closing brace are
// not visible in this excerpt.
7ttk::TopologicalLoss::TopologicalLoss(
8 const torch::Tensor &input,
9 const std::vector<std::vector<double>> &points,
11 : input_(input), points_(points), regul_(regul),
12 device(input.device().type()) {
13 precomputeInputPersistence();
// Entry point for the loss on a latent tensor. Caches latent.size(1) in
// latentDimension, then returns one of the differentiable losses below.
// NOTE(review): the conditions selecting each return statement are not
// visible in this excerpt (presumably a dispatch on regul_ -- confirm).
16torch::Tensor ttk::TopologicalLoss::computeLoss(
const torch::Tensor &latent) {
18 latentDimension = latent.size(1);
21 return diffTopoAELoss();
23 return diffTopoAELossDim1();
25 return diffCascadeAELoss();
27 return diffAsymmetricCascadeAELoss();
// This variant sums two terms: diffW1() and diffTopoAELoss().
29 return diffW1() + diffTopoAELoss();
// Last visible return: a one-element zero tensor created on `device`.
31 return torch::zeros({1}, device);
// Precomputes, according to regul_, the input-side data used by the losses:
// edge-index tensors in inputCriticalPairIndices and, for W_DIM1, the
// `auction` object.
// NOTE(review): the calls that fill inputCritical / inputCriticalAndCascade
// (and inputPD, if it is set here) are not visible in this excerpt.
34void ttk::TopologicalLoss::precomputeInputPersistence() {
35 if(regul_ == REGUL::TOPOAE || regul_ == REGUL::W_DIM1) {
// Only the first edge set is converted and kept for these two modes.
38 inputCriticalPairIndices = {pairsToTorch(inputCritical[0])};
40 if(regul_ == REGUL::TOPOAE_DIM1) {
// Converts edge sets 0, 1 and 2 (inclusive upper bound).
43 for(
int i = 0; i <= 2; ++i)
44 inputCriticalPairIndices[i] = pairsToTorch(inputCritical[i]);
45 }
else if(regul_ == REGUL::W_DIM1) {
// Allocates the auction object; its template/constructor arguments are cut
// off in this excerpt.
47 auction = std::make_unique<
49 }
else if(regul_ == REGUL::CASCADE || regul_ == REGUL::ASYMMETRIC_CASCADE) {
58 pc.getCascades(inputCriticalAndCascade);
// Converts edge sets 0 through 3 (inclusive upper bound).
59 for(
int i = 0; i <= 3; ++i)
60 inputCriticalPairIndices[i] = pairsToTorch(inputCriticalAndCascade[i]);
// Fills latent0PD with dimension-0 persistence data for latent_.
// When latentDimension == 2, FastRipsPersistenceDiagram2 is run on a CPU copy
// of latent_ read as float. The other visible paths compute latentPD with a
// maximum dimension argument of 0 and copy latentPD[0] into latent0PD.
// NOTE(review): the parameter list, the branch conditions and the names of the
// functions called in the non-2D paths are not visible in this excerpt.
64void ttk::TopologicalLoss::computeLatent0Persistence(
67 if(latentDimension == 2)
68 rpd::FastRipsPersistenceDiagram2(
69 latent_.cpu().data_ptr<
float>(), latent_.size(0))
70 .compute0Persistence(latent0PD);
74 latent_.size(1), latentPD,
rpd::inf, 0,
false);
75 latent0PD = latentPD[0];
80 latent_.size(1), latentPD,
rpd::inf, 0,
false);
81 latent0PD = latentPD[0];
// Fills latentPD (type given by the PersistenceType template parameter) with
// persistence data for latent_, using a maximum dimension argument of 1.
// When latentDimension == 2, FastRipsPersistenceDiagram2 is run on a CPU copy
// of latent_ read as float.
// NOTE(review): the names of the functions called in the non-2D paths and
// their branch conditions are not visible in this excerpt.
85template <
typename PersistenceType>
86void ttk::TopologicalLoss::computeLatent0And1Persistence(
87 PersistenceType &latentPD)
const {
89 if(latentDimension == 2)
90 rpd::FastRipsPersistenceDiagram2(
91 latent_.cpu().data_ptr<
float>(), latent_.size(0))
92 .computeRips0And1Persistence(latentPD,
false,
false);
95 latent_.size(1), latentPD,
rpd::inf, 1,
false);
98 latent_.size(1), latentPD,
rpd::inf, 1,
false);
// Explicit instantiations of the template above; the argument types are cut
// off in this excerpt.
101template void ttk::TopologicalLoss::computeLatent0And1Persistence(
103template void ttk::TopologicalLoss::computeLatent0And1Persistence(
// Fills latentCriticalAndCascades for latent_. Under TTK_ENABLE_CGAL with
// latentDimension == 2, FastRipsPersistenceDiagram2 writes the result
// directly. The other visible paths first compute latentPD, then call
// pc.getCascades(latentCriticalAndCascades).
// NOTE(review): the parameter list, the matching #else/#endif, the
// construction of `pc` and the names of the called functions are not visible
// in this excerpt.
106void ttk::TopologicalLoss::computeLatentCascades(
108#ifdef TTK_ENABLE_CGAL
109 if(latentDimension == 2)
110 rpd::FastRipsPersistenceDiagram2(
111 latent_.cpu().data_ptr<
float>(), latent_.size(0))
112 .computeRips0And1Persistence(latentCriticalAndCascades,
false,
false);
116 latent_.size(1), latentPD,
rpd::inf, 1,
false);
118 latent_.size(0), latent_.size(1), latentPD,
121 pc.getCascades(latentCriticalAndCascades);
126 latent_.size(1), latentPD,
rpd::inf, 1,
false);
128 latent_.size(1), latentPD,
false);
130 pc.getCascades(latentCriticalAndCascades);
// Returns the sum of two diffEdgeSetMSE terms: one over the precomputed input
// edge set (inputCriticalPairIndices[0]) and one over the edge set produced
// for the latent space by computeLatent0Persistence.
// NOTE(review): the declaration of latent0Critical is not visible here.
136torch::Tensor ttk::TopologicalLoss::diffTopoAELoss()
const {
138 computeLatent0Persistence(latent0Critical);
140 return diffEdgeSetMSE(inputCriticalPairIndices[0])
141 + diffEdgeSetMSE(pairsToTorch(latent0Critical));
// Computes the latent persistence data with computeLatent0And1Persistence and
// returns diffRNGMML applied to it.
// NOTE(review): the declaration of latentCritical and the definition of
// diffRNGMML are not visible in this excerpt.
144torch::Tensor ttk::TopologicalLoss::diffTopoAELossDim1()
const {
146 computeLatent0And1Persistence(latentCritical);
147 return diffRNGMML(latentCritical);
// Returns the sum of three terms: diffRNGMML on the latent critical/cascade
// data, diffEdgeSetMSE over the input edge set at index 3, and diffEdgeSetMSE
// over the latent edge set at index rpd::CASC1.
// NOTE(review): the declaration of latentCriticalAndCascade is not visible.
150torch::Tensor ttk::TopologicalLoss::diffCascadeAELoss()
const {
152 computeLatentCascades(latentCriticalAndCascade);
154 return diffRNGMML(latentCriticalAndCascade)
155 + diffEdgeSetMSE(inputCriticalPairIndices[3])
156 + diffEdgeSetMSE(pairsToTorch(latentCriticalAndCascade[
rpd::CASC1]));
// Same as diffCascadeAELoss but without a latent-side cascade term: the latent
// data comes from computeLatent0And1Persistence, and only the input edge set
// at index 3 is passed to diffEdgeSetMSE.
// NOTE(review): the declaration of latentCritical is not visible here.
159torch::Tensor ttk::TopologicalLoss::diffAsymmetricCascadeAELoss()
const {
161 computeLatent0And1Persistence(latentCritical);
163 return diffRNGMML(latentCritical)
164 + diffEdgeSetMSE(inputCriticalPairIndices[3]);
// Runs the auction with latentPD as the new bidder and sorts the resulting
// matchings into the three output index vectors.
// NOTE(review): the first parameter (latentPD) and the loop header over
// `matchings` are not visible in this excerpt.
169void ttk::TopologicalLoss::performAuction(
171 std::vector<unsigned> &directMatchingLatent,
172 std::vector<unsigned> &directMatchingInput,
173 std::vector<unsigned> &diagonalMatchingLatent)
const {
174 auction->setNewBidder(latentPD);
175 std::vector<MatchingType> matchings;
176 auction->runAuction(matchings);
// Negative first index: handled separately; the statement is not visible
// (presumably skipped -- confirm).
179 if(std::get<0>(m) < 0)
// Non-negative second index: record both indices of the matched pair.
182 if(std::get<1>(m) >= 0) {
183 directMatchingLatent.push_back(std::get<0>(m));
184 directMatchingInput.push_back(std::get<1>(m));
// Otherwise (presumably the else branch): record only the first index.
186 diagonalMatchingLatent.push_back(std::get<0>(m));
// Builds a differentiable cost from the matching between latentPD[1] and
// inputPD[1] found by performAuction.
// Returns a one-element zero tensor on `device` when latentPD[1] is empty;
// otherwise returns the 2-norm of the concatenated per-pair costs.
// NOTE(review): several lines (assignment targets in the fill loop, the
// factor applied to diagProj, closing braces) are not visible in this excerpt.
191torch::Tensor ttk::TopologicalLoss::diffW1()
const {
192 std::vector<rpd::Diagram> latentPD(0);
193 computeLatent0And1Persistence(latentPD);
195 if(latentPD[1].empty())
196 return torch::zeros(1, device);
198 std::vector<unsigned> directMatchingLatent(0);
199 std::vector<unsigned> directMatchingInput(0);
200 std::vector<unsigned> diagonalMatchingLatent(0);
201 performAuction(latentPD[1], directMatchingLatent, directMatchingInput,
202 diagonalMatchingLatent);
// One row of two values per directly matched input pair, filled through the
// raw float pointer before the tensor is moved to `device`.
204 torch::Tensor directMatchedInputPD
205 = torch::zeros({int(directMatchingInput.size()), 2});
206 float *inputData = directMatchedInputPD.data_ptr<
float>();
207 for(
unsigned i = 0; i < directMatchingInput.size(); ++i) {
// The two values read are .first.second and .second.second of the matched
// input pair; the assignment targets are not visible.
209 = float(inputPD[1][directMatchingInput[i]].first.second);
211 = float(inputPD[1][directMatchingInput[i]].second.second);
213 directMatchedInputPD = directMatchedInputPD.to(device);
// Latent-side values are rebuilt from latent_ through diffPD.
215 const torch::Tensor directMatchedLatentPD
216 = diffPD(latent_, latentPD[1], directMatchingLatent);
217 const torch::Tensor diagonalMatchedLatentPD
218 = diffPD(latent_, latentPD[1], diagonalMatchingLatent);
// Column 1 minus column 0 of the diagonal-matched pairs, times a factor that
// is not visible in this excerpt.
219 const torch::Tensor diagProj
221 * (diagonalMatchedLatentPD.index({Slice(), 1})
222 - diagonalMatchedLatentPD.index({Slice(), 0}));
// Row-wise 2-norms of the direct differences, concatenated with diagProj.
224 const torch::Tensor costs = torch::cat(
225 {(directMatchedInputPD - directMatchedLatentPD).norm(2, 1), diagProj});
226 return costs.norm(2);
// Converts an edge set into a kInt tensor with one row of two values per
// edge (first, second) and returns it moved to `device`.
// NOTE(review): the return-type line and closing braces are not visible in
// this excerpt.
232 ttk::TopologicalLoss::pairsToTorch(
const rpd::EdgeSet &edges)
const {
233 const torch::Tensor pairIndices
234 = torch::zeros({(int)edges.size(), 2}, torch::kInt);
// The tensor is filled through its raw int pointer, two entries per edge.
235 int *data = pairIndices.data_ptr<
int>();
236 for(
unsigned i = 0; i < edges.size(); ++i) {
237 data[2 * i] = edges[i].first;
238 data[2 * i + 1] = edges[i].second;
240 return pairIndices.to(device);
// Indexes `data` with `indices`, then returns the 2-norm (taken along
// dimension 1) of the difference between entry 0 and entry 1 of each indexed
// pair.
// NOTE(review): assumes `indices` holds one pair of row indices per row, as
// produced by pairsToTorch -- confirm. The return-type line is not visible.
244 ttk::TopologicalLoss::diffDistances(
const torch::Tensor &data,
245 const torch::Tensor &indices) {
246 const torch::Tensor points = data.index({indices});
247 return (points.index({Slice(), 0}) - points.index({Slice(), 1})).norm(2, 1);
// Returns torch::mse_loss between diffDistances computed on latent_ and on
// input_ for the same `indices`, using the reduction stored in reduction_.
// When `indices` has no rows, returns a one-element zero tensor on `device`.
// NOTE(review): the return-type line and closing brace are not visible in
// this excerpt.
251 ttk::TopologicalLoss::diffEdgeSetMSE(
const torch::Tensor &indices)
const {
252 if(indices.size(0) > 0)
253 return torch::mse_loss(diffDistances(latent_, indices),
254 diffDistances(input_, indices), reduction_);
256 return torch::zeros({1}, device);
// For each entry of `indices`, gathers from PD the two point indices stored
// in .first.first and the two stored in .second.first, indexes `points` with
// them, and returns an expression built on the difference between the two
// points of each gathered pair.
// NOTE(review): the PD parameter declaration, the return-type line and the
// tail of the return expression are not visible in this excerpt.
260 ttk::TopologicalLoss::diffPD(
const torch::Tensor &points,
262 const std::vector<unsigned int> &indices)
const {
// Index tensor of shape {indices.size(), 2, 2}, filled through a raw int
// pointer with four entries per selected pair.
263 torch::Tensor edgesIndices
264 = torch::zeros({int(indices.size()), 2, 2}, torch::kInt);
265 int *data = edgesIndices.data_ptr<
int>();
266 for(
unsigned i = 0; i < indices.size(); ++i) {
267 const unsigned index = indices[i];
268 data[4 * i] = PD[index].first.first[0];
269 data[4 * i + 1] = PD[index].first.first[1];
270 data[4 * i + 2] = PD[index].second.first[0];
271 data[4 * i + 3] = PD[index].second.first[1];
273 edgesIndices = edgesIndices.to(device);
274 const torch::Tensor PDPoints = points.index({edgesIndices});
275 return (PDPoints.index({Slice(), Slice(), 0})
276 - PDPoints.index({Slice(), Slice(), 1}))
TTK base class that computes the Wasserstein distance between two persistence diagrams.
TTK base class that partially executes on a Rips complex the PairCells persistence algorithm where ne...
static void callOracle(const PointCloud &points, MultidimensionalDiagram &oracle, double threshold=inf, bool distanceMatrix=false)
void ripser(std::vector< std::vector< value_t > > points, PersistenceType &ph, value_t threshold, index_t dim_max, bool distanceMatrix, bool criticalEdgesOnly=true, bool infinitePairs=true, coefficient_t modulus=2)
std::array< EdgeSet, 4 > EdgeSets4
std::vector< Edge > EdgeSet
std::vector< Diagram > MultidimensionalDiagram
std::array< EdgeSet, 3 > EdgeSets3
std::vector< PersistencePair > Diagram
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.