TTK
Loading...
Searching...
No Matches
DimensionReductionModel.h
Go to the documentation of this file.
1
17
18#pragma once
19
20#ifdef TTK_ENABLE_TORCH
21
22#include <torch/torch.h>
23
24namespace ttk {
25
29 class DimensionReductionModel : public torch::nn::Module {
30 public:
31 virtual torch::Tensor forward(torch::Tensor const &x) = 0;
32 virtual torch::Tensor encode(torch::Tensor const &x) = 0;
33 virtual torch::Tensor decode(torch::Tensor const &x) = 0;
34 }; // DimensionReductionModel class
35
40 class AutoEncoder : public DimensionReductionModel {
41 public:
42 AutoEncoder(int inputDim,
43 int latentDim,
44 const std::string &layersDescription,
45 const std::string &activation = "ReLU",
46 bool useBN = true);
47
48 inline torch::Tensor forward(torch::Tensor const &x) override {
49 return decoder->forward(encoder->forward(x));
50 }
51
52 inline torch::Tensor encode(torch::Tensor const &x) override {
53 return encoder->forward(x);
54 }
55
56 inline torch::Tensor decode(torch::Tensor const &x) override {
57 return decoder->forward(x);
58 }
59
60 static bool isStringValid(const std::string &s);
61
62 private:
63 torch::nn::Sequential encoder;
64 torch::nn::Sequential decoder;
65 }; // AutoEncoder class
66
71 class AutoDecoder : public DimensionReductionModel {
72 public:
73 AutoDecoder(int inputDim,
74 int inputSize,
75 int latentDim,
76 const std::string &layersDescription,
77 const std::string &activation = "ReLU",
78 bool useBN = true);
79
80 inline torch::Tensor forward(torch::Tensor const & /*x*/) override {
81 return decoder->forward(latent);
82 }
83
84 inline torch::Tensor encode(torch::Tensor const & /*x*/) override {
85 return latent;
86 }
87
88 inline torch::Tensor decode(torch::Tensor const & /*x*/) override {
89 return decoder->forward(latent);
90 }
91
92 private:
93 torch::Tensor latent;
94 torch::nn::Sequential decoder;
95 }; // AutoDecoder class
96
101 class DirectOptimization : public DimensionReductionModel {
102 public:
103 DirectOptimization(int inputSize, int latentDim);
104
105 inline torch::Tensor forward(torch::Tensor const &x) override {
106 return x;
107 }
108
109 inline torch::Tensor encode(torch::Tensor const &x) override {
110 input = x;
111 return latent;
112 }
113
114 inline torch::Tensor decode(torch::Tensor const & /*x*/) override {
115 return input;
116 }
117
118 private:
119 torch::Tensor latent;
120 torch::Tensor input;
121 }; // DirectOptimization class
122
128 class ConvolutionalAutoEncoder : public DimensionReductionModel {
129 public:
130 ConvolutionalAutoEncoder(int imageSide,
131 int latentDim,
132 const std::string &layersDescription,
133 bool useBN);
134
135 inline torch::Tensor forward(torch::Tensor const &x) override {
136 return decoder->forward(encoder->forward(x));
137 }
138
139 inline torch::Tensor encode(torch::Tensor const &x) override {
140 return encoder->forward(x);
141 }
142
143 inline torch::Tensor decode(torch::Tensor const &x) override {
144 return decoder->forward(x);
145 }
146
147 static bool isStringValid(const std::string &s);
148
149 private:
150 torch::nn::Sequential encoder;
151 torch::nn::Sequential decoder;
152 }; // ConvolutionalAutoEncoder class
153
154} // namespace ttk
155
156#endif
TTK base class for dimension reduction models.
TTK base package defining the standard types.