142#ifdef TTK_ENABLE_TORCH
// Layer objects of the network (presumably one entry per layer — confirm).
144 std::vector<MergeTreeNeuralLayer> layers_;
// Coefficient tensors ("alphas") and their scaled / activated variants, each
// stored as a vector of vectors of tensors. NOTE(review): the meaning of the
// two index levels (per-input vs per-layer) is not visible here — confirm.
147 std::vector<std::vector<torch::Tensor>> allAlphas_, allScaledAlphas_,
148 allActAlphas_, allActScaledAlphas_;
// Reconstructed merge trees for the first and second inputs, as a vector of
// vectors (presumably indexed per layer and per input — confirm).
149 std::vector<std::vector<mtu::TorchMergeTree<float>>> recs_, recs2_;
// Saved copies of the origin / origin-prime trees. NOTE(review): presumably
// used to snapshot and restore parameters (see copyParams) — confirm.
151 std::vector<mtu::TorchMergeTree<float>> originsCopy_, originsPrimeCopy_;
155 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
158 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>>
165 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
180#ifdef TTK_ENABLE_TORCH
/**
 * Initializes the input basis of layer `l`.
 *
 * NOTE(review): parameter roles below are inferred from names only; the
 * definition is not part of this excerpt — confirm.
 * All container arguments are non-const references, so the callee may
 * modify any of them.
 *
 * @param l              layer index.
 * @param layerNoAxes    number of basis axes for this layer (assumed).
 * @param trees          input merge trees.
 * @param trees2         second set of input merge trees (assumed paired with
 *                       `trees`).
 * @param isTrain        per-input flags (assumed: which inputs are used for
 *                       training).
 * @param allAlphasInit  receives the initial coefficient tensors (assumed
 *                       output parameter).
 */
197 void initInputBasis(
unsigned int l,
198 unsigned int layerNoAxes,
199 std::vector<mtu::TorchMergeTree<float>> &trees,
200 std::vector<mtu::TorchMergeTree<float>> &trees2,
201 std::vector<bool> &isTrain,
202 std::vector<std::vector<torch::Tensor>> &allAlphasInit);
/**
 * Initializes the output basis of layer `l`.
 *
 * NOTE(review): parameter roles are inferred from names only — confirm
 * against the definition.
 *
 * @param l                            layer index.
 * @param layerOriginPrimeSizePercent  size of the output origin expressed
 *                                     as a percentage (assumed; unit and
 *                                     reference size not visible here).
 * @param trees                        input merge trees (non-const ref).
 * @param trees2                       second set of input merge trees.
 * @param isTrain                      per-input flags (assumed training
 *                                     mask).
 */
217 void initOutputBasis(
unsigned int l,
218 double layerOriginPrimeSizePercent,
219 std::vector<mtu::TorchMergeTree<float>> &trees,
220 std::vector<mtu::TorchMergeTree<float>> &trees2,
221 std::vector<bool> &isTrain);
248 initResetOutputBasis(
unsigned int l,
249 unsigned int layerNoAxes,
250 double layerOriginPrimeSizePercent,
251 std::vector<mtu::TorchMergeTree<float>> &trees,
252 std::vector<mtu::TorchMergeTree<float>> &trees2,
253 std::vector<bool> &isTrain)
277 bool initGetReconstructed(
279 unsigned int layerNoAxes,
280 double layerOriginPrimeSizePercent,
281 std::vector<mtu::TorchMergeTree<float>> &trees,
282 std::vector<mtu::TorchMergeTree<float>> &trees2,
283 std::vector<bool> &isTrain,
284 std::vector<mtu::TorchMergeTree<float>> &recs,
285 std::vector<mtu::TorchMergeTree<float>> &recs2,
286 std::vector<std::vector<torch::Tensor>> &allAlphasInit);
305 initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
306 std::vector<mtu::TorchMergeTree<float>> &trees2,
307 std::vector<bool> &isTrain,
308 bool computeError =
false)
/**
 * Initialization step run with the input trees before optimization
 * (presumably initializes all layers' parameters — confirm).
 *
 * @param trees    input merge trees (non-const ref).
 * @param trees2   second set of input merge trees (non-const ref).
 * @param isTrain  per-input flags (assumed training mask).
 */
321 void initStep(std::vector<mtu::TorchMergeTree<float>> &trees,
322 std::vector<mtu::TorchMergeTree<float>> &trees2,
323 std::vector<bool> &isTrain);
363 bool forwardOneData(mtu::TorchMergeTree<float> &tree,
364 mtu::TorchMergeTree<float> &tree2,
365 unsigned int treeIndex,
367 std::vector<torch::Tensor> &alphasInit,
368 mtu::TorchMergeTree<float> &out,
369 mtu::TorchMergeTree<float> &out2,
370 std::vector<torch::Tensor> &dataAlphas,
371 std::vector<mtu::TorchMergeTree<float>> &outs,
372 std::vector<mtu::TorchMergeTree<float>> &outs2,
409 std::vector<mtu::TorchMergeTree<float>> &trees,
410 std::vector<mtu::TorchMergeTree<float>> &trees2,
411 std::vector<unsigned int> &indexes,
412 std::vector<bool> &isTrain,
414 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
416 std::vector<mtu::TorchMergeTree<float>> &outs,
417 std::vector<mtu::TorchMergeTree<float>> &outs2,
418 std::vector<std::vector<torch::Tensor>> &bestAlphas,
419 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
420 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
421 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
423 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
459 std::vector<mtu::TorchMergeTree<float>> &trees,
460 std::vector<mtu::TorchMergeTree<float>> &trees2,
461 std::vector<unsigned int> &indexes,
463 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
465 std::vector<mtu::TorchMergeTree<float>> &outs,
466 std::vector<mtu::TorchMergeTree<float>> &outs2,
467 std::vector<std::vector<torch::Tensor>> &bestAlphas,
468 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
469 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
470 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
472 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
506 virtual bool backwardStep(
507 std::vector<mtu::TorchMergeTree<float>> &trees,
508 std::vector<mtu::TorchMergeTree<float>> &outs,
509 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
511 std::vector<mtu::TorchMergeTree<float>> &trees2,
512 std::vector<mtu::TorchMergeTree<float>> &outs2,
513 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
515 std::vector<std::vector<torch::Tensor>> &alphas,
516 torch::optim::Optimizer &optimizer,
517 std::vector<unsigned int> &indexes,
518 std::vector<bool> &isTrain,
519 std::vector<torch::Tensor> &torchCustomLoss)
// Takes no arguments and returns nothing, so it can only act on member state.
// NOTE(review): presumably projects the parameters back onto their valid
// domain after an optimizer update — confirm against the definition.
529 void projectionStep();
555 virtual float computeOneLoss(
556 mtu::TorchMergeTree<float> &tree,
557 mtu::TorchMergeTree<float> &out,
558 mtu::TorchMergeTree<float> &tree2,
559 mtu::TorchMergeTree<float> &out2,
560 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
561 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
562 std::vector<torch::Tensor> &alphas,
563 unsigned int treeIndex)
/**
 * Compares `loss` against the best loss seen so far.
 *
 * `minLoss` and `cptBlocked` are taken by non-const reference, so the callee
 * may update them.
 * NOTE(review): presumably returns true and records `loss` in `minLoss` when
 * it improves, otherwise increments the "blocked" counter `cptBlocked` —
 * confirm against the definition.
 *
 * @param loss        loss value of the current iteration.
 * @param minLoss     best loss so far (in/out).
 * @param cptBlocked  counter of non-improving iterations (in/out, assumed).
 * @return whether `loss` is considered the best so far (assumed).
 */
577 bool isBestLoss(
float loss,
float &minLoss,
unsigned int &cptBlocked);
591 bool convergenceStep(
float loss,
594 unsigned int &cptBlocked);
610 customInit(std::vector<mtu::TorchMergeTree<float>> &torchTrees,
611 std::vector<mtu::TorchMergeTree<float>> &torchTrees2)
// Virtual hook that receives the parameter tensor list by non-const
// reference. NOTE(review): presumably lets subclasses append extra trainable
// tensors for the optimizer — confirm against the definition.
624 virtual void addCustomParameters(std::vector<torch::Tensor> &parameters)
653 virtual void computeCustomLosses(
654 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
655 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
656 std::vector<std::vector<torch::Tensor>> &bestAlphas,
657 std::vector<unsigned int> &indexes,
658 std::vector<bool> &isTrain,
659 unsigned int iteration,
660 std::vector<std::vector<float>> &gapCustomLosses,
661 std::vector<std::vector<float>> &iterationCustomLosses,
662 std::vector<torch::Tensor> &torchCustomLoss)
678 virtual float computeIterationTotalLoss(
680 std::vector<std::vector<float>> &iterationCustomLosses,
681 std::vector<float> &iterationCustomLoss)
694 virtual void printCustomLosses(std::vector<float> &customLoss,
695 std::stringstream &prefix,
710 virtual void printGapLoss(
float loss,
711 std::vector<std::vector<float>> &gapCustomLosses)
// Computes tracking information up to layer `endLayer`.
// NOTE(review): whether `endLayer` is inclusive or exclusive, and where the
// result is stored, is not visible here — confirm.
737 void computeTrackingInformation(
unsigned int endLayer);
759 createScaledAlphas(std::vector<std::vector<torch::Tensor>> &alphas,
760 std::vector<std::vector<torch::Tensor>> &scaledAlphas);
// Zero-argument overload of createScaledAlphas.
// NOTE(review): presumably fills allScaledAlphas_ from allAlphas_ via the
// two-argument overload — confirm.
766 void createScaledAlphas();
// Builds the activated alpha tensors from member state (no arguments).
// NOTE(review): presumably fills allActAlphas_ / allActScaledAlphas_ —
// confirm.
773 void createActivatedAlphas();
778 void copyParams(std::vector<mtu::TorchMergeTree<float>> &origins,
779 std::vector<mtu::TorchMergeTree<float>> &originsPrime,
780 std::vector<torch::Tensor> &vS,
781 std::vector<torch::Tensor> &vSPrime,
782 std::vector<mtu::TorchMergeTree<float>> &origins2,
783 std::vector<mtu::TorchMergeTree<float>> &origins2Prime,
784 std::vector<torch::Tensor> &vS2,
785 std::vector<torch::Tensor> &vS2Prime,
786 std::vector<std::vector<torch::Tensor>> &srcAlphas,
787 std::vector<std::vector<torch::Tensor>> &dstAlphas,
// Overload of copyParams for nested vectors of merge trees.
// NOTE(review): copy direction assumed from the names (src -> dst); both are
// non-const references — confirm.
790 void copyParams(std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
791 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst);
// Pure virtual (`= 0`): every concrete subclass must implement it.
// NOTE(review): the meaning of `get` is not visible here — presumably it
// selects the copy direction (save vs. restore of subclass-specific
// parameters); confirm.
801 virtual void copyCustomParams(
bool get) = 0;
/**
 * Gathers alpha tensors of layer `layerIndex` into `alphasOut`, restricted by
 * `indexes` and the `toGet` flags.
 *
 * NOTE(review): the selection semantics (e.g. whether `toGet` is indexed like
 * `indexes` or like `alphas`) and the layout of `alphasOut` are not visible
 * here — confirm against the definition.
 */
815 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
816 std::vector<unsigned int> &indexes,
817 std::vector<bool> &toGet,
818 unsigned int layerIndex,
819 torch::Tensor &alphasOut);
// Overload of getAlphasTensor without a `toGet` filter: gathers the alphas of
// layer `layerIndex` for the entries listed in `indexes` into `alphasOut`
// (assumed from the parameter names — confirm).
831 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
832 std::vector<unsigned int> &indexes,
833 unsigned int layerIndex,
834 torch::Tensor &alphasOut);
// Overload of getAlphasTensor with no index selection: presumably gathers the
// alphas of layer `layerIndex` for all entries into `alphasOut` — confirm.
845 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
846 unsigned int layerIndex,
847 torch::Tensor &alphasOut);
// Checks the gradients of layer `l` (presumably for all-zero gradients, as a
// debugging aid — confirm). `checkOutputBasis` defaults to true; passing
// false presumably skips the output-basis parameters.
852 void checkZeroGrad(
unsigned int l,
bool checkOutputBasis =
true);
855 float threshold = 10000);