84#ifdef TTK_ENABLE_TORCH
86 std::vector<torch::Tensor> vSTensor_, vSPrimeTensor_, vS2Tensor_,
87 vS2PrimeTensor_, latentCentroids_;
88 std::vector<mtu::TorchMergeTree<float>> origins_, originsPrime_, origins2_,
91 std::vector<mtu::TorchMergeTree<float>> originsCopy_, originsPrimeCopy_;
94 std::vector<std::vector<torch::Tensor>> allAlphas_, allScaledAlphas_,
95 allActAlphas_, allActScaledAlphas_;
96 std::vector<std::vector<mtu::TorchMergeTree<float>>> recs_, recs2_;
97 std::vector<mtu::TorchMergeTree<float>> customRecs_;
106 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
111 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
114 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>>
125#ifdef TTK_ENABLE_TORCH
126 std::vector<mtu::TorchMergeTree<float>> initOrigins_, initOriginsPrime_,
134#ifdef TTK_ENABLE_TORCH
138 void initOutputBasisTreeStructure(mtu::TorchMergeTree<float> &originPrime,
140 mtu::TorchMergeTree<float> &baseOrigin);
142 void initOutputBasis(
unsigned int l,
unsigned int dim,
unsigned int dim2);
144 void initOutputBasisVectors(
unsigned int l,
148 void initOutputBasisVectors(
unsigned int l,
152 void initInputBasisOrigin(
155 double barycenterSizeLimitPercent,
156 unsigned int barycenterMaxNoPairs,
157 unsigned int barycenterMaxNoPairs2,
158 mtu::TorchMergeTree<float> &origin,
159 mtu::TorchMergeTree<float> &origin2,
160 std::vector<double> &inputToBaryDistances,
161 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
163 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
166 void initInputBasisVectors(
167 std::vector<mtu::TorchMergeTree<float>> &tmTreesToUse,
168 std::vector<mtu::TorchMergeTree<float>> &tmTrees2ToUse,
171 mtu::TorchMergeTree<float> &origin,
172 mtu::TorchMergeTree<float> &origin2,
173 unsigned int noVectors,
174 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
176 std::vector<double> &inputToBaryDistances,
177 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
179 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
181 torch::Tensor &vSTensor,
182 torch::Tensor &vS2Tensor);
184 void initClusteringLossParameters();
186 float initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
187 std::vector<mtu::TorchMergeTree<float>> &trees2,
188 bool computeReconstructionError =
false);
190 void initStep(std::vector<mtu::TorchMergeTree<float>> &trees,
191 std::vector<mtu::TorchMergeTree<float>> &trees2);
196 void interpolationDiagonalProjection(
197 mtu::TorchMergeTree<float> &interpolationTensor);
200 interpolationNestingProjection(mtu::TorchMergeTree<float> &interpolation);
202 void interpolationProjection(mtu::TorchMergeTree<float> &interpolation);
204 void getMultiInterpolation(mtu::TorchMergeTree<float> &origin,
206 torch::Tensor &alphas,
207 mtu::TorchMergeTree<float> &interpolation);
212 void getAlphasOptimizationTensors(
213 mtu::TorchMergeTree<float> &tree,
214 mtu::TorchMergeTree<float> &origin,
215 torch::Tensor &vSTensor,
216 mtu::TorchMergeTree<float> &interpolated,
217 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
218 torch::Tensor &reorderedTreeTensor,
219 torch::Tensor &deltaOrigin,
220 torch::Tensor &deltaA,
221 torch::Tensor &originTensor_f,
222 torch::Tensor &vSTensor_f);
225 mtu::TorchMergeTree<float> &tree,
226 mtu::TorchMergeTree<float> &origin,
227 torch::Tensor &vSTensor,
228 mtu::TorchMergeTree<float> &interpolated,
229 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
230 mtu::TorchMergeTree<float> &tree2,
231 mtu::TorchMergeTree<float> &origin2,
232 torch::Tensor &vS2Tensor,
233 mtu::TorchMergeTree<float> &interpolated2,
234 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
235 torch::Tensor &alphasOut);
237 float assignmentOneData(
238 mtu::TorchMergeTree<float> &tree,
239 mtu::TorchMergeTree<float> &origin,
240 torch::Tensor &vSTensor,
241 mtu::TorchMergeTree<float> &tree2,
242 mtu::TorchMergeTree<float> &origin2,
243 torch::Tensor &vS2Tensor,
245 torch::Tensor &alphasInit,
246 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching,
247 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &bestMatching2,
248 torch::Tensor &bestAlphas,
249 bool isCalled =
false);
251 float assignmentOneData(mtu::TorchMergeTree<float> &tree,
252 mtu::TorchMergeTree<float> &origin,
253 torch::Tensor &vSTensor,
254 mtu::TorchMergeTree<float> &tree2,
255 mtu::TorchMergeTree<float> &origin2,
256 torch::Tensor &vS2Tensor,
258 torch::Tensor &alphasInit,
259 torch::Tensor &bestAlphas,
260 bool isCalled =
false);
262 torch::Tensor activation(torch::Tensor &in);
264 void outputBasisReconstruction(mtu::TorchMergeTree<float> &originPrime,
265 torch::Tensor &vSPrimeTensor,
266 mtu::TorchMergeTree<float> &origin2Prime,
267 torch::Tensor &vS2PrimeTensor,
268 torch::Tensor &alphas,
269 mtu::TorchMergeTree<float> &out,
270 mtu::TorchMergeTree<float> &out2,
271 bool activate =
true);
273 bool forwardOneLayer(mtu::TorchMergeTree<float> &tree,
274 mtu::TorchMergeTree<float> &origin,
275 torch::Tensor &vSTensor,
276 mtu::TorchMergeTree<float> &originPrime,
277 torch::Tensor &vSPrimeTensor,
278 mtu::TorchMergeTree<float> &tree2,
279 mtu::TorchMergeTree<float> &origin2,
280 torch::Tensor &vS2Tensor,
281 mtu::TorchMergeTree<float> &origin2Prime,
282 torch::Tensor &vS2PrimeTensor,
284 torch::Tensor &alphasInit,
285 mtu::TorchMergeTree<float> &out,
286 mtu::TorchMergeTree<float> &out2,
287 torch::Tensor &bestAlphas,
288 float &bestDistance);
290 bool forwardOneLayer(mtu::TorchMergeTree<float> &tree,
291 mtu::TorchMergeTree<float> &origin,
292 torch::Tensor &vSTensor,
293 mtu::TorchMergeTree<float> &originPrime,
294 torch::Tensor &vSPrimeTensor,
295 mtu::TorchMergeTree<float> &tree2,
296 mtu::TorchMergeTree<float> &origin2,
297 torch::Tensor &vS2Tensor,
298 mtu::TorchMergeTree<float> &origin2Prime,
299 torch::Tensor &vS2PrimeTensor,
301 torch::Tensor &alphasInit,
302 mtu::TorchMergeTree<float> &out,
303 mtu::TorchMergeTree<float> &out2,
304 torch::Tensor &bestAlphas);
306 bool forwardOneData(mtu::TorchMergeTree<float> &tree,
307 mtu::TorchMergeTree<float> &tree2,
308 unsigned int treeIndex,
310 std::vector<torch::Tensor> &alphasInit,
311 mtu::TorchMergeTree<float> &out,
312 mtu::TorchMergeTree<float> &out2,
313 std::vector<torch::Tensor> &dataAlphas,
314 std::vector<mtu::TorchMergeTree<float>> &outs,
315 std::vector<mtu::TorchMergeTree<float>> &outs2);
318 std::vector<mtu::TorchMergeTree<float>> &trees,
319 std::vector<mtu::TorchMergeTree<float>> &trees2,
320 std::vector<unsigned int> &indexes,
322 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
323 bool computeReconstructionError,
324 std::vector<mtu::TorchMergeTree<float>> &outs,
325 std::vector<mtu::TorchMergeTree<float>> &outs2,
326 std::vector<std::vector<torch::Tensor>> &bestAlphas,
327 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
328 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
329 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
331 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
335 bool forwardStep(std::vector<mtu::TorchMergeTree<float>> &trees,
336 std::vector<mtu::TorchMergeTree<float>> &trees2,
337 std::vector<unsigned int> &indexes,
339 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
340 std::vector<mtu::TorchMergeTree<float>> &outs,
341 std::vector<mtu::TorchMergeTree<float>> &outs2,
342 std::vector<std::vector<torch::Tensor>> &bestAlphas);
348 std::vector<mtu::TorchMergeTree<float>> &trees,
349 std::vector<mtu::TorchMergeTree<float>> &outs,
350 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
352 std::vector<mtu::TorchMergeTree<float>> &trees2,
353 std::vector<mtu::TorchMergeTree<float>> &outs2,
354 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
356 torch::optim::Optimizer &optimizer,
357 std::vector<unsigned int> &indexes,
358 torch::Tensor &metricLoss,
359 torch::Tensor &clusteringLoss,
360 torch::Tensor &trackingLoss);
365 void projectionStep();
370 float computeOneLoss(
371 mtu::TorchMergeTree<float> &tree,
372 mtu::TorchMergeTree<float> &out,
373 mtu::TorchMergeTree<float> &tree2,
374 mtu::TorchMergeTree<float> &out2,
375 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
376 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2);
379 std::vector<mtu::TorchMergeTree<float>> &trees,
380 std::vector<mtu::TorchMergeTree<float>> &outs,
381 std::vector<mtu::TorchMergeTree<float>> &trees2,
382 std::vector<mtu::TorchMergeTree<float>> &outs2,
383 std::vector<unsigned int> &indexes,
384 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
386 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
389 bool isBestLoss(
float loss,
float &minLoss,
unsigned int &cptBlocked);
391 bool convergenceStep(
float loss,
394 unsigned int &cptBlocked);
405 double getCustomLossDynamicWeight(
double recLoss,
double &baseLoss);
407 void getDistanceMatrix(std::vector<mtu::TorchMergeTree<float>> &tmts,
408 std::vector<std::vector<float>> &distanceMatrix,
409 bool useDoubleInput =
false,
410 bool isFirstInput =
true);
412 void getDistanceMatrix(std::vector<mtu::TorchMergeTree<float>> &tmts,
413 std::vector<mtu::TorchMergeTree<float>> &tmts2,
414 std::vector<std::vector<float>> &distanceMatrix);
416 void getDifferentiableDistanceFromMatchings(
417 mtu::TorchMergeTree<float> &tree1,
418 mtu::TorchMergeTree<float> &tree2,
419 mtu::TorchMergeTree<float> &tree1_2,
420 mtu::TorchMergeTree<float> &tree2_2,
421 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings,
422 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matchings2,
423 torch::Tensor &tensorDist,
426 void getDifferentiableDistance(mtu::TorchMergeTree<float> &tree1,
427 mtu::TorchMergeTree<float> &tree2,
428 mtu::TorchMergeTree<float> &tree1_2,
429 mtu::TorchMergeTree<float> &tree2_2,
430 torch::Tensor &tensorDist,
434 void getDifferentiableDistance(mtu::TorchMergeTree<float> &tree1,
435 mtu::TorchMergeTree<float> &tree2,
436 torch::Tensor &tensorDist,
440 void getDifferentiableDistanceMatrix(
441 std::vector<mtu::TorchMergeTree<float> *> &trees,
442 std::vector<mtu::TorchMergeTree<float> *> &trees2,
443 std::vector<std::vector<torch::Tensor>> &outDistMat);
445 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
446 std::vector<unsigned int> &indexes,
447 unsigned int layerIndex,
448 torch::Tensor &alphasOut);
450 void computeMetricLoss(
451 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
452 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
453 std::vector<std::vector<torch::Tensor>> alphas,
454 std::vector<std::vector<float>> &baseDistanceMatrix,
455 std::vector<unsigned int> &indexes,
456 torch::Tensor &metricLoss);
458 void computeClusteringLoss(std::vector<std::vector<torch::Tensor>> &alphas,
459 std::vector<unsigned int> &indexes,
460 torch::Tensor &clusteringLoss,
461 torch::Tensor &asgn);
463 void computeTrackingLoss(torch::Tensor &trackingLoss);
469 createCustomRecs(std::vector<mtu::TorchMergeTree<float>> &origins,
470 std::vector<mtu::TorchMergeTree<float>> &originsPrime);
472 void computeTrackingInformation();
475 createScaledAlphas(std::vector<std::vector<torch::Tensor>> &alphas,
476 std::vector<torch::Tensor> &vSTensor,
477 std::vector<std::vector<torch::Tensor>> &scaledAlphas);
479 void createScaledAlphas();
481 void createActivatedAlphas();
486 void copyParams(std::vector<mtu::TorchMergeTree<float>> &srcOrigins,
487 std::vector<mtu::TorchMergeTree<float>> &srcOriginsPrime,
488 std::vector<torch::Tensor> &srcVS,
489 std::vector<torch::Tensor> &srcVSPrime,
490 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2,
491 std::vector<mtu::TorchMergeTree<float>> &srcOrigins2Prime,
492 std::vector<torch::Tensor> &srcVS2,
493 std::vector<torch::Tensor> &srcVS2Prime,
494 std::vector<std::vector<torch::Tensor>> &srcAlphas,
495 std::vector<mtu::TorchMergeTree<float>> &dstOrigins,
496 std::vector<mtu::TorchMergeTree<float>> &dstOriginsPrime,
497 std::vector<torch::Tensor> &dstVS,
498 std::vector<torch::Tensor> &dstVSPrime,
499 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2,
500 std::vector<mtu::TorchMergeTree<float>> &dstOrigins2Prime,
501 std::vector<torch::Tensor> &dstVS2,
502 std::vector<torch::Tensor> &dstVS2Prime,
503 std::vector<std::vector<torch::Tensor>> &dstAlphas);
505 void copyParams(std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
506 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst);
508 unsigned int getLatentLayerIndex();
514 float threshold = 10000);