81#ifdef TTK_ENABLE_TORCH
/// Setter for the dropout option (definition is outside this excerpt).
/// @param dropout new dropout value; accepted range is not visible here — TODO confirm.
void setDropout(const double dropout);
/// Setter for the euclideanVectorsInit flag (definition is outside this excerpt).
/// @param euclideanVectorsInit new flag value; its effect on initialization is not visible here.
void setEuclideanVectorsInit(const bool euclideanVectorsInit);
/// Setter for the randomAxesInit flag (definition is outside this excerpt).
/// @param randomAxesInit new flag value; its effect is not visible here.
void setRandomAxesInit(const bool randomAxesInit);
/// Setter for the initBarycenterRandom flag (definition is outside this excerpt).
/// @param initBarycenterRandom new flag value; its effect is not visible here.
void setInitBarycenterRandom(const bool initBarycenterRandom);
/// Setter for the initBarycenterOneIter flag (definition is outside this excerpt).
/// @param initBarycenterOneIter new flag value; its effect is not visible here.
void setInitBarycenterOneIter(const bool initBarycenterOneIter);
/// Setter for the initOriginPrimeStructByCopy flag (definition is outside this excerpt).
/// @param initOriginPrimeStructByCopy new flag value; its effect is not visible here.
void setInitOriginPrimeStructByCopy(const bool initOriginPrimeStructByCopy);
/// Setter for the initOriginPrimeValuesByCopy flag (definition is outside this excerpt).
/// @param initOriginPrimeValuesByCopy new flag value; its effect is not visible here.
void setInitOriginPrimeValuesByCopy(const bool initOriginPrimeValuesByCopy);
/// Setter for the initOriginPrimeValuesByCopyRandomness option (definition is outside this excerpt).
/// @param initOriginPrimeValuesByCopyRandomness new value; accepted range is not visible here — TODO confirm.
void setInitOriginPrimeValuesByCopyRandomness(
  const double initOriginPrimeValuesByCopyRandomness);
/// Setter for the activate flag (definition is outside this excerpt).
/// @param activate new flag value; presumably toggles use of activation() — TODO confirm.
void setActivate(const bool activate);
/// Setter for the activationFunction option (definition is outside this excerpt).
/// @param activationFunction numeric selector; the mapping from values to functions
///        is not visible here — TODO confirm.
void setActivationFunction(const unsigned int activationFunction);
/// Setter for the useGpu flag (definition is outside this excerpt).
/// @param useGpu new flag value; device-selection behavior is not visible here.
void setUseGpu(const bool useGpu);
/// Setter for the bigValuesThreshold option (definition is outside this excerpt).
/// @param bigValuesThreshold new threshold; units and accepted range are not visible here.
void setBigValuesThreshold(const float bigValuesThreshold);
113 torch::Tensor activation(torch::Tensor &in);
135 void getDistanceMatrix(
const std::vector<mtu::TorchMergeTree<float>> &tmts,
136 std::vector<std::vector<float>> &distanceMatrix,
137 bool useDoubleInput =
false,
138 bool isFirstInput =
true);
140 void getDistanceMatrix(
const std::vector<mtu::TorchMergeTree<float>> &tmts,
141 const std::vector<mtu::TorchMergeTree<float>> &tmts2,
142 std::vector<std::vector<float>> &distanceMatrix);
144 void getDifferentiableDistanceFromMatchings(
145 const mtu::TorchMergeTree<float> &tree1,
146 const mtu::TorchMergeTree<float> &tree2,
147 const mtu::TorchMergeTree<float> &tree1_2,
148 const mtu::TorchMergeTree<float> &tree2_2,
149 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
150 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
151 torch::Tensor &tensorDist,
154 void getDifferentiableDistance(
const mtu::TorchMergeTree<float> &tree1,
155 const mtu::TorchMergeTree<float> &tree2,
156 const mtu::TorchMergeTree<float> &tree1_2,
157 const mtu::TorchMergeTree<float> &tree2_2,
158 torch::Tensor &tensorDist,
162 void getDifferentiableDistance(
const mtu::TorchMergeTree<float> &tree1,
163 const mtu::TorchMergeTree<float> &tree2,
164 torch::Tensor &tensorDist,
168 void getDifferentiableDistanceMatrix(
169 const std::vector<mtu::TorchMergeTree<float> *> &trees,
170 const std::vector<mtu::TorchMergeTree<float> *> &trees2,
171 std::vector<std::vector<torch::Tensor>> &outDistMat);