TTK
Loading...
Searching...
No Matches
DimensionReductionModel.cpp
Go to the documentation of this file.
#include <iterator>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
3
4#ifdef TTK_ENABLE_TORCH
5
6ttk::AutoEncoder::AutoEncoder(int inputDim,
7 int latentDim,
8 const std::string &layersDescription,
9 const std::string &activation,
10 bool useBN) {
11
12 std::istringstream iss(layersDescription);
13 const std::vector<std::string> hiddenDimsParsed(
14 std::istream_iterator<std::string>{iss},
15 std::istream_iterator<std::string>());
16 std::vector<unsigned> dims(1, inputDim);
17 for(const std::string &s : hiddenDimsParsed)
18 dims.push_back(std::stoi(s));
19 dims.push_back(latentDim);
20
21 const int n = dims.size() - 1;
22 encoder = torch::nn::Sequential(torch::nn::Linear(dims[0], dims[1]));
23 decoder = torch::nn::Sequential(torch::nn::Linear(dims[n], dims[n - 1]));
24 for(unsigned i = 1; i < dims.size() - 1; ++i) {
25 if(activation == "ReLU") {
26 encoder->push_back(torch::nn::ReLU());
27 decoder->push_back(torch::nn::ReLU());
28 } else if(activation == "Tanh") {
29 encoder->push_back(torch::nn::Tanh());
30 decoder->push_back(torch::nn::Tanh());
31 }
32 if(useBN) {
33 encoder->push_back(torch::nn::BatchNorm1d(dims[i]));
34 decoder->push_back(torch::nn::BatchNorm1d(dims[n - i]));
35 }
36 encoder->push_back(torch::nn::Linear(dims[i], dims[i + 1]));
37 decoder->push_back(torch::nn::Linear(dims[n - i], dims[n - (i + 1)]));
38 }
39 register_module("encoder", encoder);
40 register_module("decoder", decoder);
41}
42
43bool ttk::AutoEncoder::isStringValid(const std::string &s) {
44 return std::regex_match(s, std::regex("([0-9]+( )*)*"));
45}
46
47ttk::AutoDecoder::AutoDecoder(int inputDim,
48 int inputSize,
49 int latentDim,
50 const std::string &layersDescription,
51 const std::string &activation,
52 bool useBN) {
53
54 std::istringstream iss(layersDescription);
55 const std::vector<std::string> hiddenDimsParsed(
56 std::istream_iterator<std::string>{iss},
57 std::istream_iterator<std::string>());
58 std::vector<unsigned> dims(1, inputDim);
59 for(const std::string &s : hiddenDimsParsed)
60 dims.push_back(std::stoi(s));
61 dims.push_back(latentDim);
62
63 latent = torch::rand({inputSize, latentDim});
64
65 const int n = dims.size() - 1;
66 decoder = torch::nn::Sequential(torch::nn::Linear(dims[n], dims[n - 1]));
67 for(unsigned i = 1; i < dims.size() - 1; ++i) {
68 if(activation == "ReLU")
69 decoder->push_back(torch::nn::ReLU());
70 else if(activation == "Tanh")
71 decoder->push_back(torch::nn::Tanh());
72 if(useBN)
73 decoder->push_back(torch::nn::BatchNorm1d(dims[n - i]));
74 decoder->push_back(torch::nn::Linear(dims[n - i], dims[n - (i + 1)]));
75 }
76 register_parameter("latent", latent, true);
77 register_module("decoder", decoder);
78}
79
80ttk::DirectOptimization::DirectOptimization(int inputSize, int latentDim) {
81 latent = torch::rand({inputSize, latentDim});
82 register_parameter("latent", latent, true);
83}
84
85ttk::ConvolutionalAutoEncoder::ConvolutionalAutoEncoder(
86 int imageSide,
87 int latentDim,
88 const std::string &layersDescription,
89 bool useBN) {
90 std::istringstream iss(layersDescription);
91 const std::vector<std::string> hiddenDimsParsed(
92 std::istream_iterator<std::string>{iss},
93 std::istream_iterator<std::string>());
94 std::vector<int> denseLayersSizes = {-1};
95 std::vector<int> convolutionalLayersChannels = {1};
96 std::vector<int> convolutionalLayersStrides;
97 for(const std::string &s : hiddenDimsParsed) {
98 if(std::regex_match(s, std::regex("c[0-9]+/[0-9]+"))) {
99 convolutionalLayersChannels.push_back(
100 std::stoi(s.substr(1, s.find('/'))));
101 convolutionalLayersStrides.push_back(
102 std::stoi(s.substr(s.find('/') + 1, s.length())));
103 } else
104 denseLayersSizes.push_back(std::stoi(s));
105 }
106 denseLayersSizes.push_back(latentDim);
107 int convolutionalOutputImageSide = imageSide;
108
109 /*** convolutional encoder ***/
111 encoder->push_back(torch::nn::Unflatten(
112 torch::nn::UnflattenOptions(1, {1, imageSide, imageSide})));
114 for(unsigned c = 0; c < convolutionalLayersChannels.size() - 1; ++c) {
115 encoder->push_back(torch::nn::Conv2d(
116 torch::nn::Conv2dOptions(
117 convolutionalLayersChannels[c], convolutionalLayersChannels[c + 1], 3)
118 .padding(1)
119 .stride(convolutionalLayersStrides[c])));
120 encoder->push_back(torch::nn::ReLU());
121 convolutionalOutputImageSide /= convolutionalLayersStrides[c];
122 }
124 encoder->push_back(torch::nn::Flatten());
125
126 /*** dense encoder / decoder ***/
127 denseLayersSizes[0]
128 = convolutionalLayersChannels[convolutionalLayersChannels.size() - 1]
129 * convolutionalOutputImageSide * convolutionalOutputImageSide;
130 const int n = denseLayersSizes.size() - 1;
131 encoder->push_back(
132 torch::nn::Linear(denseLayersSizes[0], denseLayersSizes[1]));
133 decoder->push_back(
134 torch::nn::Linear(denseLayersSizes[n], denseLayersSizes[n - 1]));
135 for(unsigned i = 1; i < denseLayersSizes.size() - 1; ++i) {
136 encoder->push_back(torch::nn::ReLU());
137 decoder->push_back(torch::nn::ReLU());
138 if(useBN) {
139 encoder->push_back(torch::nn::BatchNorm1d(denseLayersSizes[i]));
140 decoder->push_back(torch::nn::BatchNorm1d(denseLayersSizes[n - i]));
141 }
142 encoder->push_back(
143 torch::nn::Linear(denseLayersSizes[i], denseLayersSizes[i + 1]));
144 decoder->push_back(torch::nn::Linear(
145 denseLayersSizes[n - i], denseLayersSizes[n - (i + 1)]));
146 }
147
148 /*** convolutional decoder ***/
150 decoder->push_back(torch::nn::Unflatten(torch::nn::UnflattenOptions(
151 1, {convolutionalLayersChannels[convolutionalLayersChannels.size() - 1],
152 convolutionalOutputImageSide, convolutionalOutputImageSide})));
154 for(unsigned c = convolutionalLayersChannels.size() - 1; c > 0; --c) {
155 decoder->push_back(torch::nn::ReLU());
156 decoder->push_back(torch::nn::ConvTranspose2d(
157 torch::nn::ConvTranspose2dOptions(
158 convolutionalLayersChannels[c], convolutionalLayersChannels[c - 1], 3)
159 .padding(1)
160 .stride(convolutionalLayersStrides[c - 1])
161 .output_padding(1)));
162 }
164 decoder->push_back(torch::nn::Flatten());
165
166 register_module("encoder", encoder);
167 register_module("decoder", decoder);
168}
169
170bool ttk::ConvolutionalAutoEncoder::isStringValid(const std::string &s) {
171 return std::regex_match(s, std::regex("(c[0-9]+/[0-9]+( )*)*([0-9]+( )*)*"));
172}
173
174#endif