50#ifdef TTK_ENABLE_TORCH
// Torch-backed state, only compiled when TTK_ENABLE_TORCH is defined.
// Four tensors and four merge trees following a plain / Prime / 2 / 2Prime
// naming pattern; each member has a get*/set* accessor declared further down.
// NOTE(review): presumably the vS* tensors are basis vectors attached to the
// origin* trees, with the "2" variants serving the second set of input trees
// (cf. tmTrees2 parameters below) — TODO confirm against the definitions.
52 torch::Tensor vSTensor_, vSPrimeTensor_, vS2Tensor_, vS2PrimeTensor_;
53 mtu::TorchMergeTree<float> origin_, originPrime_, origin2_, origin2Prime_;
59#ifdef TTK_ENABLE_TORCH
/**
 * @brief Read-only accessors for the Torch-backed members.
 *
 * Every getter is a const member function returning a const reference, so
 * callers can inspect a tree or tensor without copying it and without
 * modifying it through the returned reference.
 * Presumably each returns the member with the matching name
 * (getOrigin() -> origin_, getVSTensor() -> vSTensor_, ...); the definitions
 * are not visible here — confirm before relying on it.
 */
63 const mtu::TorchMergeTree<float> &getOrigin()
const;
65 const mtu::TorchMergeTree<float> &getOriginPrime()
const;
67 const mtu::TorchMergeTree<float> &getOrigin2()
const;
69 const mtu::TorchMergeTree<float> &getOrigin2Prime()
const;
71 const torch::Tensor &getVSTensor()
const;
73 const torch::Tensor &getVSPrimeTensor()
const;
75 const torch::Tensor &getVS2Tensor()
const;
77 const torch::Tensor &getVS2PrimeTensor()
const;
/**
 * @brief Setters for the Torch-backed members.
 *
 * Each setter takes its argument by const reference and returns void.
 * Presumably each stores the argument into the member with the matching name
 * (setOrigin(tmt) -> origin_, setVSTensor(vS) -> vSTensor_, ...).
 * NOTE(review): copying a torch::Tensor copies a reference-counted handle that
 * shares storage with the source; confirm whether the definitions clone() when
 * an independent copy is intended.
 */
79 void setOrigin(
const mtu::TorchMergeTree<float> &tmt);
81 void setOriginPrime(
const mtu::TorchMergeTree<float> &tmt);
83 void setOrigin2(
const mtu::TorchMergeTree<float> &tmt);
85 void setOrigin2Prime(
const mtu::TorchMergeTree<float> &tmt);
87 void setVSTensor(
const torch::Tensor &vS);
89 void setVSPrimeTensor(
const torch::Tensor &vS);
91 void setVS2Tensor(
const torch::Tensor &vS);
93 void setVS2PrimeTensor(
const torch::Tensor &vS);
/**
 * @brief Per its name, initialises the tree structure of an output-basis
 * origin.
 *
 * @param originPrime tree taken by non-const reference — presumably the tree
 *        being initialised (TODO confirm).
 * @param baseOrigin tree taken by non-const reference — presumably the
 *        structure originPrime is derived from (TODO confirm).
 */
108 void initOutputBasisTreeStructure(mtu::TorchMergeTree<float> &originPrime,
110 mtu::TorchMergeTree<float> &baseOrigin);
/**
 * @brief Per its name, initialises the output basis.
 *
 * @param dim first size argument (const unsigned int); presumably relates to
 *        the non-"2" members (originPrime_ / vSPrimeTensor_) — TODO confirm.
 * @param dim2 second size argument; presumably relates to the "2" members.
 * @param baseTensor read-only tensor (const reference); its role and expected
 *        shape are not established by this declaration.
 */
124 void initOutputBasis(
const unsigned int dim,
125 const unsigned int dim2,
126 const torch::Tensor &baseTensor);
/**
 * @brief Per its name, initialises the output basis vectors from the given
 * tensors.
 *
 * @param w, w2 tensors taken by non-const reference — presumably the initial
 *        values for the first and second ("2") output basis; TODO confirm
 *        whether they are read, written, or both.
 */
137 void initOutputBasisVectors(torch::Tensor &w, torch::Tensor &w2);
/**
 * @brief Overload of initOutputBasisVectors taking only sizes.
 *
 * @param dim, dim2 sizes passed by value; presumably the vectors are created
 *        internally from them — TODO confirm.
 */
147 void initOutputBasisVectors(
unsigned int dim,
unsigned int dim2);
171 void initInputBasisOrigin(
174 double barycenterSizeLimitPercent,
175 unsigned int barycenterMaxNoPairs,
176 unsigned int barycenterMaxNoPairs2,
177 std::vector<double> &inputToBaryDistances,
178 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
180 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
/**
 * @brief Per its name, initialises the input basis vectors. This overload
 * takes explicit origin / origin2 trees and vSTensor / vS2Tensor tensors
 * (the overload below does not).
 *
 * @param tmTrees first set of input trees (non-const reference).
 * @param tmTrees2 second set of input trees (non-const reference).
 * @param noVectors presumably the number of basis vectors to create.
 * @param allAlphasInit vector of tensors, presumably per-input initial
 *        coefficients — TODO confirm whether it is an input or an output.
 * @param inputToBaryDistances vector of doubles, presumably per-input
 *        distances to a barycenter (per the name).
 * The two std::vector<std::vector<std::tuple<idNode, idNode, double>>>
 * parameters are presumably per-input node matchings with a value per pair —
 * TODO confirm.
 * @param origin, origin2 trees taken by non-const reference, presumably the
 *        origins the basis is built around.
 * @param vSTensor, vS2Tensor tensors taken by non-const reference, presumably
 *        receiving the basis vectors.
 * @param useInputBasis defaults to true.
 */
213 void initInputBasisVectors(
214 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
215 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
218 unsigned int noVectors,
219 std::vector<torch::Tensor> &allAlphasInit,
220 std::vector<double> &inputToBaryDistances,
221 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
223 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
225 mtu::TorchMergeTree<float> &origin,
226 mtu::TorchMergeTree<float> &origin2,
227 torch::Tensor &vSTensor,
228 torch::Tensor &vS2Tensor,
229 bool useInputBasis =
true);
/**
 * @brief Overload of initInputBasisVectors without origin, origin2, vSTensor
 * and vS2Tensor arguments.
 *
 * Presumably it operates on the class's own members (origin_, origin2_,
 * vSTensor_, vS2Tensor_) instead — TODO confirm against the definition.
 * It takes the input tree sets tmTrees / tmTrees2, noVectors,
 * allAlphasInit, inputToBaryDistances and two nested vectors of
 * (idNode, idNode, double) tuples; useInputBasis defaults to true.
 */
256 void initInputBasisVectors(
257 std::vector<mtu::TorchMergeTree<float>> &tmTrees,
258 std::vector<mtu::TorchMergeTree<float>> &tmTrees2,
261 unsigned int noVectors,
262 std::vector<torch::Tensor> &allAlphasInit,
263 std::vector<double> &inputToBaryDistances,
264 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
266 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
268 bool useInputBasis =
true);
/**
 * @brief Per its name, enables or disables gradient tracking.
 *
 * @param requireGrad presumably applied (e.g. via requires_grad_()) to the
 *        Torch-backed members such as the vS* tensors and origin* trees —
 *        TODO confirm which members are affected.
 */
270 void requires_grad(
const bool requireGrad);
/**
 * @brief Projection helpers applied to an interpolated tree.
 *
 * Each takes a mtu::TorchMergeTree<float> by non-const reference, so it may
 * modify the tree in place. Per their names, interpolationDiagonalProjection
 * presumably deals with projection onto the diagonal,
 * interpolationNestingProjection with a nesting constraint, and
 * interpolationProjection presumably combines the two — TODO confirm against
 * the definitions.
 */
284 void interpolationDiagonalProjection(
285 mtu::TorchMergeTree<float> &interpolationTensor);
295 interpolationNestingProjection(mtu::TorchMergeTree<float> &interpolation);
304 void interpolationProjection(mtu::TorchMergeTree<float> &interpolation);
/**
 * @brief Per its name, builds an interpolated tree from an origin, a set of
 * vectors and their coefficients.
 *
 * @param origin read-only origin tree.
 * @param vS read-only tensor, presumably the basis vectors.
 * @param alphas coefficients, taken by non-const reference.
 *        NOTE(review): if the definition only reads alphas, a const reference
 *        would document that — confirm.
 * @param interpolation tree taken by non-const reference, presumably the
 *        output.
 */
316 void getMultiInterpolation(
const mtu::TorchMergeTree<float> &origin,
317 const torch::Tensor &vS,
318 torch::Tensor &alphas,
319 mtu::TorchMergeTree<float> &interpolation);
/**
 * @brief Per its name, prepares the tensors used to optimise the alphas
 * (coefficients) for one tree.
 *
 * Presumed inputs: tree, origin, vSTensor, interpolated and matching (a list
 * of (idNode, idNode, double) tuples).
 * Presumed outputs, by name: reorderedTreeTensor, deltaOrigin, deltaA,
 * originTensor_f, vSTensor_f.
 * Every argument is a non-const reference, so the declaration alone does not
 * distinguish inputs from outputs — TODO confirm against the definition.
 */
351 void getAlphasOptimizationTensors(
352 mtu::TorchMergeTree<float> &tree,
353 mtu::TorchMergeTree<float> &origin,
354 torch::Tensor &vSTensor,
355 mtu::TorchMergeTree<float> &interpolated,
356 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
357 torch::Tensor &reorderedTreeTensor,
358 torch::Tensor &deltaOrigin,
359 torch::Tensor &deltaA,
360 torch::Tensor &originTensor_f,
361 torch::Tensor &vSTensor_f);
388 mtu::TorchMergeTree<float> &tree,
389 mtu::TorchMergeTree<float> &origin,
390 torch::Tensor &vSTensor,
391 mtu::TorchMergeTree<float> &interpolated,
392 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
393 mtu::TorchMergeTree<float> &tree2,
394 mtu::TorchMergeTree<float> &origin2,
395 torch::Tensor &vS2Tensor,
396 mtu::TorchMergeTree<float> &interpolated2,
397 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
398 torch::Tensor &alphasOut);
/**
 * @brief Per its name, computes the assignment (coefficients) of one input
 * given as the tree pair (tree, tree2).
 *
 * @param alphasInit presumably the initial coefficients.
 * @param bestMatching, bestMatching2 lists of (idNode, idNode, double)
 *        tuples, presumably receiving the matchings of the best result.
 * @param bestAlphas presumably receives the best coefficients found.
 * @param isCalled defaults to false.
 * @param useInputBasis defaults to true.
 * @return a float — presumably the cost of the best assignment (TODO confirm).
 */
426 float assignmentOneData(
427 mtu::TorchMergeTree<float> &tree,
428 mtu::TorchMergeTree<float> &tree2,
430 torch::Tensor &alphasInit,
431 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
432 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
433 torch::Tensor &bestAlphas,
434 bool isCalled =
false,
435 bool useInputBasis =
true);
/**
 * @brief Overload of assignmentOneData without the bestMatching /
 * bestMatching2 arguments.
 *
 * @param alphasInit presumably the initial coefficients.
 * @param bestAlphas presumably receives the best coefficients found.
 * @param isCalled defaults to false.
 * @param useInputBasis defaults to true.
 * @return a float — presumably the cost of the best assignment (TODO confirm).
 */
459 float assignmentOneData(mtu::TorchMergeTree<float> &tree,
460 mtu::TorchMergeTree<float> &tree2,
462 torch::Tensor &alphasInit,
463 torch::Tensor &bestAlphas,
464 bool isCalled =
false,
465 bool useInputBasis =
true);
479 void outputBasisReconstruction(torch::Tensor &alphas,
480 mtu::TorchMergeTree<float> &out,
481 mtu::TorchMergeTree<float> &out2,
482 bool activate =
true,
507 bool forward(mtu::TorchMergeTree<float> &tree,
508 mtu::TorchMergeTree<float> &tree2,
510 torch::Tensor &alphasInit,
511 mtu::TorchMergeTree<float> &out,
512 mtu::TorchMergeTree<float> &out2,
513 torch::Tensor &bestAlphas,
537 bool forward(mtu::TorchMergeTree<float> &tree,
538 mtu::TorchMergeTree<float> &tree2,
540 torch::Tensor &alphasInit,
541 mtu::TorchMergeTree<float> &out,
542 mtu::TorchMergeTree<float> &out2,
543 torch::Tensor &bestAlphas,
/**
 * @brief Per its name, performs one projection step.
 *
 * Takes no arguments and returns void, so it presumably acts on the class's
 * own members (e.g. the Torch-backed trees/tensors) — TODO confirm.
 */
553 void projectionStep();
558 void copyParams(mtu::TorchMergeTree<float> &origin,
559 mtu::TorchMergeTree<float> &originPrime,
561 torch::Tensor &vSPrime,
562 mtu::TorchMergeTree<float> &origin2,
563 mtu::TorchMergeTree<float> &origin2Prime,
565 torch::Tensor &vS2Prime,
576 void adjustNestingScalars(std::vector<float> &scalarsVector,
/**
 * @brief Per its name, creates a balanced BDT from parent/child node lists.
 *
 * @param parents, children nested vectors of ftm::idNode (non-const
 *        references), presumably per-node parent and child lists.
 * @param scalarsVector vector of floats (non-const reference), presumably the
 *        node scalar values.
 * @param childrenFinal nested vector of ftm::idNode, presumably the output
 *        child lists of the balanced tree — TODO confirm.
 */
591 createBalancedBDT(std::vector<std::vector<ftm::idNode>> &parents,
592 std::vector<std::vector<ftm::idNode>> &children,
593 std::vector<float> &scalarsVector,
594 std::vector<std::vector<ftm::idNode>> &childrenFinal);
600 float threshold = 10000);