20#ifdef TTK_ENABLE_TORCH
22#include <torch/torch.h>
31 virtual torch::Tensor forward(torch::Tensor
const &x) = 0;
32 virtual torch::Tensor encode(torch::Tensor
const &x) = 0;
33 virtual torch::Tensor decode(torch::Tensor
const &x) = 0;
42 AutoEncoder(
int inputDim,
44 const std::string &layersDescription,
45 const std::string &activation =
"ReLU",
48 inline torch::Tensor forward(torch::Tensor
const &x)
override {
49 return decoder->forward(encoder->forward(x));
52 inline torch::Tensor encode(torch::Tensor
const &x)
override {
53 return encoder->forward(x);
56 inline torch::Tensor decode(torch::Tensor
const &x)
override {
57 return decoder->forward(x);
60 static bool isStringValid(
const std::string &s);
63 torch::nn::Sequential encoder;
64 torch::nn::Sequential decoder;
73 AutoDecoder(
int inputDim,
76 const std::string &layersDescription,
77 const std::string &activation =
"ReLU",
80 inline torch::Tensor forward(torch::Tensor
const & )
override {
81 return decoder->forward(latent);
84 inline torch::Tensor encode(torch::Tensor
const & )
override {
88 inline torch::Tensor decode(torch::Tensor
const & )
override {
89 return decoder->forward(latent);
94 torch::nn::Sequential decoder;
103 DirectOptimization(
int inputSize,
int latentDim);
105 inline torch::Tensor forward(torch::Tensor
const &x)
override {
109 inline torch::Tensor encode(torch::Tensor
const &x)
override {
114 inline torch::Tensor decode(torch::Tensor
const & )
override {
119 torch::Tensor latent;
130 ConvolutionalAutoEncoder(
int imageSide,
132 const std::string &layersDescription,
135 inline torch::Tensor forward(torch::Tensor
const &x)
override {
136 return decoder->forward(encoder->forward(x));
139 inline torch::Tensor encode(torch::Tensor
const &x)
override {
140 return encoder->forward(x);
143 inline torch::Tensor decode(torch::Tensor
const &x)
override {
144 return decoder->forward(x);
147 static bool isStringValid(
const std::string &s);
150 torch::nn::Sequential encoder;
151 torch::nn::Sequential decoder;
TTK base class for dimension-reduction models.
TTK base package defining the standard types.