34 enum class REGUL : std::uint8_t {
49#ifdef TTK_ENABLE_TORCH
52 std::vector<std::vector<double>>
const &points,
55 torch::Tensor computeLoss(
const torch::Tensor &latent);
58 const torch::Tensor input_;
59 const std::vector<std::vector<double>> &points_;
61 const torch::Reduction::Reduction reduction_{torch::Reduction::Mean};
62 const torch::DeviceType device{torch::kCPU};
63 torch::Tensor latent_;
68 std::array<torch::Tensor, 4>
69 inputCriticalPairIndices;
71 std::unique_ptr<PersistenceDiagramWarmRestartAuction<rpd::PersistencePair>>
75 void precomputeInputPersistence();
76 void computeLatent0Persistence(
rpd::EdgeSet &latent0PD)
const;
77 template <
typename PersistenceType>
78 void computeLatent0And1Persistence(PersistenceType &latentPD)
const;
79 void computeLatentCascades(
rpd::EdgeSets4 &latentCriticalAndCascades)
const;
82 inline torch::Tensor pairsToTorch(
const rpd::EdgeSet &edges)
const;
83 static inline torch::Tensor diffDistances(
const torch::Tensor &data,
84 const torch::Tensor &indices);
85 inline torch::Tensor diffEdgeSetMSE(
const torch::Tensor &indices)
const;
86 torch::Tensor diffPD(
const torch::Tensor &points,
88 const std::vector<unsigned> &indices)
const;
// diffRNGMML: sums six diffEdgeSetMSE terms.
//  - Three terms use the precomputed member tensors
//    inputCriticalPairIndices[0..2]. Entry 3 of that 4-element array is not
//    used here.
//  - Three terms use latentCritical[0..2], each converted with pairsToTorch.
// Requirement: EdgeSets must support operator[] for indices 0-2, and each
// element must be accepted by pairsToTorch (declared to take an
// rpd::EdgeSet).
// Returns: the torch::Tensor sum of the six terms.
91 template <
typename EdgeSets>
92 inline torch::Tensor diffRNGMML(
const EdgeSets &latentCritical)
const {
// Input-side terms, from the precomputed indices.
93 return diffEdgeSetMSE(inputCriticalPairIndices[0])
94 + diffEdgeSetMSE(inputCriticalPairIndices[1])
95 + diffEdgeSetMSE(inputCriticalPairIndices[2])
// Latent-side terms, converted on each call.
96 + diffEdgeSetMSE(pairsToTorch(latentCritical[0]))
97 + diffEdgeSetMSE(pairsToTorch(latentCritical[1]))
98 + diffEdgeSetMSE(pairsToTorch(latentCritical[2]));
100 torch::Tensor diffTopoAELoss()
const;
101 torch::Tensor diffTopoAELossDim1()
const;
102 torch::Tensor diffCascadeAELoss()
const;
103 torch::Tensor diffAsymmetricCascadeAELoss()
const;
107 std::vector<unsigned> &directMatchingLatent,
108 std::vector<unsigned> &directMatchingInput,
109 std::vector<unsigned> &diagonalMatchingLatent)
const;
110 torch::Tensor diffW1()
const;