66#ifdef TTK_ENABLE_TORCH
68 std::vector<torch::Tensor> bestLatentCentroids_, latentCentroids_;
70 std::vector<torch::Tensor> vSTensorCopy_, vSPrimeTensorCopy_;
72 std::vector<mtu::TorchMergeTree<float>> customRecs_;
82#ifdef TTK_ENABLE_TORCH
86 void initClusteringLossParameters();
88 bool initResetOutputBasis(
unsigned int l,
89 unsigned int layerNoAxes,
90 double layerOriginPrimeSizePercent,
91 std::vector<mtu::TorchMergeTree<float>> &trees,
92 std::vector<mtu::TorchMergeTree<float>> &trees2,
93 std::vector<bool> &isTrain)
override;
// Special-case initialization of an output basis from the two sets of input
// trees (when it applies is not visible here — TODO confirm).
// NOTE(review): this declaration looks truncated. The embedded listing
// numbers jump from 95 to 97, so one line (possibly a parameter before
// `layerNoAxes`) is missing; restore it from the upstream header so the
// declaration matches its definition.
95 void initOutputBasisSpecialCase(
97 unsigned int layerNoAxes,
98 std::vector<mtu::TorchMergeTree<float>> &trees,
99 std::vector<mtu::TorchMergeTree<float>> &trees2);
101 float initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
102 std::vector<mtu::TorchMergeTree<float>> &trees2,
103 std::vector<bool> &isTrain,
104 bool computeError =
false)
override;
110 std::vector<mtu::TorchMergeTree<float>> &trees,
111 std::vector<mtu::TorchMergeTree<float>> &outs,
112 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
114 std::vector<mtu::TorchMergeTree<float>> &trees2,
115 std::vector<mtu::TorchMergeTree<float>> &outs2,
116 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
118 std::vector<std::vector<torch::Tensor>> &alphas,
119 torch::optim::Optimizer &optimizer,
120 std::vector<unsigned int> &indexes,
121 std::vector<bool> &isTrain,
122 std::vector<torch::Tensor> &torchCustomLoss)
override;
127 float computeOneLoss(
128 mtu::TorchMergeTree<float> &tree,
129 mtu::TorchMergeTree<float> &out,
130 mtu::TorchMergeTree<float> &tree2,
131 mtu::TorchMergeTree<float> &out2,
132 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
133 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
134 std::vector<torch::Tensor> &alphas,
135 unsigned int treeIndex)
override;
// Custom initialization hook overriding the base class, given the two sets
// of input Torch merge trees.
// NOTE(review): the return-type line (listing line 140) is missing, so this
// declaration starts at `customInit(`; restore it from the upstream header.
141 customInit(std::vector<mtu::TorchMergeTree<float>> &torchTrees,
142 std::vector<mtu::TorchMergeTree<float>> &torchTrees2)
override;
144 void addCustomParameters(std::vector<torch::Tensor> ¶meters)
override;
146 void computeCustomLosses(
147 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
148 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
149 std::vector<std::vector<torch::Tensor>> &bestAlphas,
150 std::vector<unsigned int> &indexes,
151 std::vector<bool> &isTrain,
152 unsigned int iteration,
153 std::vector<std::vector<float>> &gapCustomLosses,
154 std::vector<std::vector<float>> &iterationCustomLosses,
155 std::vector<torch::Tensor> &torchCustomLoss)
override;
// Override combining per-iteration losses into a total, returned as a float.
// NOTE(review): listing line 158 is missing between the function name and
// `iterationCustomLosses`, so the first parameter is likely absent; restore
// it from the upstream header so the declaration matches its definition.
157 float computeIterationTotalLoss(
159 std::vector<std::vector<float>> &iterationCustomLosses,
160 std::vector<float> &iterationCustomLoss)
override;
// Prints the custom losses in `customLoss` using `prefix` as output prefix
// (presumably — confirm against the definition).
// NOTE(review): truncated declaration. Listing lines 164-167 are absent, so
// the remaining parameters and the closing `)`/`;` are missing here.
162 void printCustomLosses(std::vector<float> &customLoss,
163 std::stringstream &prefix,
// Override printing a loss value together with `gapCustomLosses`
// (presumably the custom losses accumulated over a gap of iterations).
// NOTE(review): the line before `printGapLoss(` (listing line 167,
// presumably the return type) is missing; restore it from the upstream
// header.
168 printGapLoss(
float loss,
169 std::vector<std::vector<float>> &gapCustomLosses)
override;
// Computes a dynamic weight for the custom losses.
//   recLoss   presumably the current reconstruction loss
//   baseLoss  non-const reference, so the call may update it
// Returns a double (presumably the weight to apply).
double getCustomLossDynamicWeight(double recLoss, double &baseLoss);
176 void computeMetricLoss(
177 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
178 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
179 std::vector<std::vector<torch::Tensor>> alphas,
180 std::vector<std::vector<float>> &baseDistanceMatrix,
181 std::vector<unsigned int> &indexes,
182 torch::Tensor &metricLoss);
184 void computeClusteringLoss(std::vector<std::vector<torch::Tensor>> &alphas,
185 std::vector<unsigned int> &indexes,
186 torch::Tensor &clusteringLoss,
187 torch::Tensor &asgn);
189 void computeTrackingLoss(torch::Tensor &trackingLoss);
// Builds the custom reconstructions (presumably stored in customRecs_ —
// TODO confirm).
void createCustomRecs();
// Returns the index of the latent layer.
unsigned int getLatentLayerIndex();
201 void copyCustomParams(
bool get)
override;