TTK
Loading...
Searching...
No Matches
TopologicalDimensionReduction.h
Go to the documentation of this file.
1
39
40#pragma once
41
42// ttk includes
43#include <Debug.h>
45#include <TopologicalLoss.h>
46
47namespace ttk {
48

/// Base class for dimension reduction under topological constraints.
/// This page's brief describes it as a class that "embeds points into 2D,
/// under topological constraints". Inherits virtually from ttk::Debug.
///
/// Everything between `#ifdef TTK_ENABLE_TORCH` and `#endif` below (the
/// constructor, execute(), all data members and all helpers) is compiled
/// only when that macro is defined; without it, only the declarations that
/// precede the guard remain.
55 class TopologicalDimensionReduction : virtual public Debug {
56
57 public:
  /// Optimizer selector; the chosen value is kept in the `Optimizer` member.
  /// NOTE(review): presumably consumed by initializeOptimizer() to build
  /// `torchOptimizer` — confirm in the implementation file.
58 enum class OPTIMIZER : std::uint8_t {
60 ADAM = 0,
62 SGD = 1,
64 LBFGS = 2,
65 };
66
  /// Model-type selector; the chosen value is kept in the `ModelType` member.
  /// NOTE(review): only DIRECT = 2 is visible in this rendering — the embedded
  /// line numbers jump from 67 to 73, so the remaining enumerators (presumably
  /// values 0 and 1) are missing here. Confirm against the original header.
67 enum class MODEL : std::uint8_t {
73 DIRECT = 2,
74 };
75
77
  // NOTE(review): `REGUL` (used below for the `method` parameter and the
  // `Method` member) is not declared anywhere in this view; embedded line 76
  // is missing, and presumably holds an alias or it comes from
  // TopologicalLoss.h — confirm. Likewise, no torch header and no declaration
  // of DimensionReductionModel are visible here (embedded line 44 is missing);
  // presumably they arrive through the includes.
78#ifdef TTK_ENABLE_TORCH
79
  // Constructor. NOTE(review): the line carrying the constructor's name and
  // opening parenthesis (embedded line 80) is missing from this rendering.
  // Each parameter from `numberOfComponents` onward has a same-named
  // protected const member below (first letter capitalised), which it
  // presumably initialises. `deterministic` and `seed` have no backing
  // member; presumably they are applied during construction (e.g. seeding /
  // deterministic-mode setup) — confirm in the implementation file.
81 bool deterministic,
82 int seed,
83 int numberOfComponents,
84 int epochs,
85 double learningRate,
86 OPTIMIZER optimizer,
87 REGUL method,
88 MODEL modelType,
89 const std::string &architecture,
90 const std::string &activation,
91 int batchSize,
92 bool batchNormalization,
93 double regCoefficient,
94 bool inputIsImages,
95 bool preOptimize,
96 int preOptimizeEpochs);
97
  /// Runs the reduction and writes the result into `outputEmbedding`.
  /// @param outputEmbedding output parameter, taken by non-const reference.
  ///        NOTE(review): presumably one inner vector per output component
  ///        (or per point) — confirm the layout in the implementation.
  /// @param inputMatrix read-only, flat input values.
  ///        NOTE(review): presumably `n` rows stored contiguously — confirm.
  /// @param n NOTE(review): presumably the number of input points — confirm.
  /// @return int status. NOTE(review): TTK convention is 0 on success —
  ///         confirm in the implementation.
109 int execute(std::vector<std::vector<double>> &outputEmbedding,
110 const std::vector<double> &inputMatrix,
111 size_t n);
112
113 protected:
  // Configuration captured at construction. Every member is const, so it
  // cannot change after the constructor has run (and the class cannot be
  // copy-assigned).
114 const int NumberOfComponents;
115 const int Epochs;
116 const double LearningRate;
117 const OPTIMIZER Optimizer;
118 const REGUL Method;
119 const MODEL ModelType;
120 const bool InputIsImages;
121 const std::string Architecture;
122 const std::string Activation;
123 const int BatchSize;
124 const bool BatchNormalization;
125 const double RegCoefficient;
126 const bool PreOptimize;
127 const int PreOptimizeEpochs;
128
129 private:
  // Torch device type used by this instance; initialised to CPU here.
130 torch::DeviceType device{torch::kCPU};
  // Owned working objects, all empty until initialised. NOTE(review):
  // presumably populated by initializeModel() / initializeOptimizer() and
  // the implementation of execute() — confirm.
131 std::unique_ptr<DimensionReductionModel> model{nullptr};
132 std::unique_ptr<torch::optim::Optimizer> torchOptimizer{nullptr};
133 std::unique_ptr<TopologicalLoss> topologicalLossContainer{nullptr};
134
  // Returns int; NOTE(review): presumably a status code, and the parameters
  // presumably describe the input (point count / dimension) — confirm.
135 int initializeModel(int inputSize, int inputDimension);
136 void initializeOptimizer();
137
  // NOTE(review): the training helpers below are `const`, but constness
  // applies only to the unique_ptr members themselves, not to the objects
  // they point to. These methods can therefore still modify the model,
  // optimizer and loss state; `const` here does not mean "side-effect free".
138 void preOptimize(const torch::Tensor &input,
139 const torch::Tensor &target) const;
140
141 void optimize(const torch::Tensor &input) const;
142 void optimizeSimple(const torch::Tensor &input) const;
143
  // NOTE(review): declared `inline` with no body in this header. An inline
  // function must be defined in every translation unit that odr-uses it, so
  // it can only be called from the file that defines it (presumably the .cpp).
144 inline void printLoss(int epoch, int maxEpoch, double loss) const;
145
146#endif
147
148 }; // TopologicalDimensionReduction class
149
150} // namespace ttk
TTK base class that embeds points into 2D, under topological constraints.
TTK base package defining the standard types.