// Constructs the auto-encoder from a textual description of its dense layers.
//
// layersDescription is split on whitespace and every token is converted with
// std::stoi. The resulting sizes are placed between inputDim and latentDim to
// form the list of layer widths `dims`. The encoder chains Linear layers that
// follow `dims` from front to back; the decoder follows `dims` from back to
// front. Both Sequential containers are finally registered as sub-modules
// under the names "encoder" and "decoder".
//
// NOTE(review): latentDim is read below but its declaration is not visible in
// this parameter list, and the end of the signature as well as the closing
// delimiters of the inner blocks are not visible either -- confirm against
// the complete definition.
6ttk::AutoEncoder::AutoEncoder(
int inputDim,
8 const std::string &layersDescription,
9 const std::string &activation,
// Tokenize the description on whitespace.
12 std::istringstream iss(layersDescription);
13 const std::vector<std::string> hiddenDimsParsed(
14 std::istream_iterator<std::string>{iss},
15 std::istream_iterator<std::string>());
// dims = [inputDim, hidden sizes..., latentDim]; std::stoi throws on a token
// that does not start with a number.
16 std::vector<unsigned> dims(1, inputDim);
17 for(
const std::string &s : hiddenDimsParsed)
18 dims.push_back(std::stoi(s));
19 dims.push_back(latentDim);
// n is the index of the last entry of dims (the latent width). dims always
// holds at least two entries, so the unsigned subtractions below cannot wrap.
21 const int n = dims.size() - 1;
// First layer of each stack: the encoder starts from the input width, the
// decoder starts from the latent width.
22 encoder = torch::nn::Sequential(torch::nn::Linear(dims[0], dims[1]));
23 decoder = torch::nn::Sequential(torch::nn::Linear(dims[n], dims[n - 1]));
// One iteration per hidden width; the loop does not run when the description
// is empty. The decoder indexes dims mirrored (n - i) with respect to the
// encoder (i).
24 for(
unsigned i = 1; i < dims.size() - 1; ++i) {
// Only "ReLU" and "Tanh" are recognized in the visible code; no activation
// module is pushed here for any other value.
25 if(activation ==
"ReLU") {
26 encoder->push_back(torch::nn::ReLU());
27 decoder->push_back(torch::nn::ReLU());
28 }
else if(activation ==
"Tanh") {
29 encoder->push_back(torch::nn::Tanh());
30 decoder->push_back(torch::nn::Tanh());
// NOTE(review): the delimiters around the BatchNorm1d pushes are not visible
// here, so whether they are guarded by a condition cannot be told -- confirm.
33 encoder->push_back(torch::nn::BatchNorm1d(dims[i]));
34 decoder->push_back(torch::nn::BatchNorm1d(dims[n - i]));
// Next Linear layer of each stack.
36 encoder->push_back(torch::nn::Linear(dims[i], dims[i + 1]));
37 decoder->push_back(torch::nn::Linear(dims[n - i], dims[n - (i + 1)]));
// Register both stacks as sub-modules of this module.
39 register_module(
"encoder", encoder);
40 register_module(
"decoder", decoder);
43bool ttk::AutoEncoder::isStringValid(
const std::string &s) {
44 return std::regex_match(s, std::regex(
"([0-9]+( )*)*"));
// Constructs the auto-decoder: a trainable latent tensor plus a dense decoder
// described by layersDescription.
//
// layersDescription is split on whitespace and every token is converted with
// std::stoi. The resulting sizes are placed between inputDim and latentDim to
// form the list of layer widths `dims`, which the decoder follows from back
// (latent width) to front (input width). The latent tensor is registered as a
// parameter named "latent" and the decoder as a sub-module named "decoder".
//
// NOTE(review): inputSize and latentDim are read below but their declarations
// are not visible in this parameter list, and the end of the signature and
// the closing delimiters are not visible either -- confirm against the
// complete definition.
47ttk::AutoDecoder::AutoDecoder(
int inputDim,
50 const std::string &layersDescription,
51 const std::string &activation,
// Tokenize the description on whitespace.
54 std::istringstream iss(layersDescription);
55 const std::vector<std::string> hiddenDimsParsed(
56 std::istream_iterator<std::string>{iss},
57 std::istream_iterator<std::string>());
// dims = [inputDim, hidden sizes..., latentDim]; std::stoi throws on a token
// that does not start with a number.
58 std::vector<unsigned> dims(1, inputDim);
59 for(
const std::string &s : hiddenDimsParsed)
60 dims.push_back(std::stoi(s));
61 dims.push_back(latentDim);
// Randomly initialized tensor of shape {inputSize, latentDim}.
63 latent = torch::rand({inputSize, latentDim});
// n is the index of the last entry of dims (the latent width).
65 const int n = dims.size() - 1;
// The decoder starts from the latent width and walks dims backwards.
66 decoder = torch::nn::Sequential(torch::nn::Linear(dims[n], dims[n - 1]));
// One iteration per hidden width; the loop does not run when the description
// is empty.
67 for(
unsigned i = 1; i < dims.size() - 1; ++i) {
// Only "ReLU" and "Tanh" are recognized in the visible code; no activation
// module is pushed here for any other value.
68 if(activation ==
"ReLU")
69 decoder->push_back(torch::nn::ReLU());
70 else if(activation ==
"Tanh")
71 decoder->push_back(torch::nn::Tanh());
// NOTE(review): the line preceding the BatchNorm1d push is not visible here,
// so whether it is guarded by a condition cannot be told -- confirm.
73 decoder->push_back(torch::nn::BatchNorm1d(dims[n - i]));
74 decoder->push_back(torch::nn::Linear(dims[n - i], dims[n - (i + 1)]));
// Expose the latent tensor (third argument: requires_grad) and the decoder
// through the module.
76 register_parameter(
"latent", latent,
true);
77 register_module(
"decoder", decoder);
80ttk::DirectOptimization::DirectOptimization(
int inputSize,
int latentDim) {
81 latent = torch::rand({inputSize, latentDim});
82 register_parameter(
"latent", latent,
true);
// Constructs a convolutional auto-encoder from a textual layer description.
//
// layersDescription is split on whitespace. A token of the form
// c<channels>/<stride> adds a convolutional layer with that many output
// channels and that stride; every other token is converted with std::stoi and
// adds a dense layer of that width. The encoder is made of the convolutional
// layers followed by the dense layers; the decoder applies dense layers in
// reverse order followed by transposed convolutions in reverse order. Both
// are registered as sub-modules named "encoder" and "decoder".
//
// NOTE(review): imageSide and latentDim are read below but their declarations
// are not visible in this parameter list, and several statements of this
// definition are only partially visible -- confirm against the complete
// definition.
85ttk::ConvolutionalAutoEncoder::ConvolutionalAutoEncoder(
88 const std::string &layersDescription,
// Tokenize the description on whitespace.
90 std::istringstream iss(layersDescription);
91 const std::vector<std::string> hiddenDimsParsed(
92 std::istream_iterator<std::string>{iss},
93 std::istream_iterator<std::string>());
// denseLayersSizes[0] starts as the placeholder -1. The channel list starts
// with 1, matching the single channel of the Unflatten shape
// {1, imageSide, imageSide} used below.
94 std::vector<int> denseLayersSizes = {-1};
95 std::vector<int> convolutionalLayersChannels = {1};
96 std::vector<int> convolutionalLayersStrides;
97 for(
const std::string &s : hiddenDimsParsed) {
98 if(std::regex_match(s, std::regex(
"c[0-9]+/[0-9]+"))) {
// Channel count: the second argument of substr is a length, so the extracted
// text also contains the '/' character; std::stoi stops at the first
// non-digit, which still yields the number written after 'c'.
99 convolutionalLayersChannels.push_back(
100 std::stoi(s.substr(1, s.find(
'/'))));
// Stride: the number written after '/'.
101 convolutionalLayersStrides.push_back(
102 std::stoi(s.substr(s.find(
'/') + 1, s.length())));
// Dense layer width. NOTE(review): the line separating this statement from
// the branch above is not visible here -- presumably an else branch; confirm.
104 denseLayersSizes.push_back(std::stoi(s));
106 denseLayersSizes.push_back(latentDim);
// Side length of the feature map, updated after every convolution below.
107 int convolutionalOutputImageSide = imageSide;
// Reshape every flat input row into a {1, imageSide, imageSide} image.
111 encoder->push_back(torch::nn::Unflatten(
112 torch::nn::UnflattenOptions(1, {1, imageSide, imageSide})));
// One Conv2d (kernel size 3) + ReLU per convolutional token.
114 for(
unsigned c = 0; c < convolutionalLayersChannels.size() - 1; ++c) {
115 encoder->push_back(torch::nn::Conv2d(
116 torch::nn::Conv2dOptions(
117 convolutionalLayersChannels[c], convolutionalLayersChannels[c + 1], 3)
119 .stride(convolutionalLayersStrides[c])));
120 encoder->push_back(torch::nn::ReLU());
// Integer division. NOTE(review): a token with stride 0 (e.g. "c8/0") matches
// the pattern above and would make this a division by zero -- confirm that
// callers exclude it.
121 convolutionalOutputImageSide /= convolutionalLayersStrides[c];
124 encoder->push_back(torch::nn::Flatten());
// Number of values left after the convolutions: last channel count times the
// squared feature-map side. NOTE(review): the assignment target is not
// visible here -- presumably the -1 placeholder in denseLayersSizes[0];
// confirm.
128 = convolutionalLayersChannels[convolutionalLayersChannels.size() - 1]
129 * convolutionalOutputImageSide * convolutionalOutputImageSide;
// n is the index of the last entry of denseLayersSizes (the latent width).
130 const int n = denseLayersSizes.size() - 1;
// First dense layer of each stack. NOTE(review): the calls receiving these
// two Linear modules are not visible here -- confirm.
132 torch::nn::Linear(denseLayersSizes[0], denseLayersSizes[1]));
134 torch::nn::Linear(denseLayersSizes[n], denseLayersSizes[n - 1]));
// Remaining dense layers; the decoder indexes the sizes mirrored (n - i) with
// respect to the encoder (i).
135 for(
unsigned i = 1; i < denseLayersSizes.size() - 1; ++i) {
136 encoder->push_back(torch::nn::ReLU());
137 decoder->push_back(torch::nn::ReLU());
// NOTE(review): the line preceding the BatchNorm1d pushes is not visible
// here, so whether they are guarded by a condition cannot be told -- confirm.
139 encoder->push_back(torch::nn::BatchNorm1d(denseLayersSizes[i]));
140 decoder->push_back(torch::nn::BatchNorm1d(denseLayersSizes[n - i]));
143 torch::nn::Linear(denseLayersSizes[i], denseLayersSizes[i + 1]));
144 decoder->push_back(torch::nn::Linear(
145 denseLayersSizes[n - i], denseLayersSizes[n - (i + 1)]));
// Reshape the dense output back into {channels, side, side} feature maps.
150 decoder->push_back(torch::nn::Unflatten(torch::nn::UnflattenOptions(
151 1, {convolutionalLayersChannels[convolutionalLayersChannels.size() - 1],
152 convolutionalOutputImageSide, convolutionalOutputImageSide})));
// Walk the convolutional layers backwards: ReLU followed by a transposed
// convolution (kernel size 3) using the stride of the matching encoder layer.
154 for(
unsigned c = convolutionalLayersChannels.size() - 1; c > 0; --c) {
155 decoder->push_back(torch::nn::ReLU());
156 decoder->push_back(torch::nn::ConvTranspose2d(
157 torch::nn::ConvTranspose2dOptions(
158 convolutionalLayersChannels[c], convolutionalLayersChannels[c - 1], 3)
160 .stride(convolutionalLayersStrides[c - 1])
161 .output_padding(1)));
// Flatten the decoder output.
164 decoder->push_back(torch::nn::Flatten());
// Register both stacks as sub-modules of this module.
166 register_module(
"encoder", encoder);
167 register_module(
"decoder", decoder);
170bool ttk::ConvolutionalAutoEncoder::isStringValid(
const std::string &s) {
171 return std::regex_match(s, std::regex(
"(c[0-9]+/[0-9]+( )*)*([0-9]+( )*)*"));