TTK
Loading...
Searching...
No Matches
MergeTreeNeuralNetwork.h
Go to the documentation of this file.
1
86
87#pragma once
88
89// ttk common includes
90#include <Debug.h>
91#include <Geometry.h>
92#include <MergeTreeNeuralBase.h>
94#include <MergeTreeTorchUtils.h>
95
96#ifdef TTK_ENABLE_TORCH
97#include <torch/torch.h>
98#endif
99
100namespace ttk {
101
// NOTE(review): this block is text extracted from a Doxygen source-browser
// page, not the header itself. Every code line carries its original line
// number as a prefix (e.g. "106 class ..."), original indentation was lost,
// and several code lines were dropped during extraction -- visible as gaps
// in that numbering (131, 134, 138, 156, 159, 163, 166, 169, 172, 174-175,
// 178, 697). The comment blocks that preceded most declarations (e.g.
// 102-105, 184-196) are also missing, presumably Doxygen doc comments.
// This text does not compile as-is; restore it from the upstream header
// rather than hand-repairing it. Notes below mark each lost code line.
//
// MergeTreeNeuralNetwork: abstract class (it declares pure virtual members)
// deriving virtually from Debug and from MergeTreeNeuralBase. It stores the
// optimization settings (iteration limits, print gap, batch size, optimizer
// choice, learning rate, Adam betas, number of initializations, train/test
// split). When TTK_ENABLE_TORCH is defined, it also declares the training
// machinery, grouped into Init / Forward / Backward / Projection /
// Convergence / Main / End / Utils / Testing sections. Model-specific pieces
// are left to derived classes as pure virtual hooks.
106 class MergeTreeNeuralNetwork : virtual public Debug,
107 public MergeTreeNeuralBase {
108
109 protected:
110 // Minimum number of iterations to run
111 unsigned int minIteration_ = 0;
112 // Maximum number of iterations to run
113 unsigned int maxIteration_ = 0;
114 // Number of iterations between each print
115 unsigned int iterationGap_ = 100;
116 // Batch size between 0 and 1
117 double batchSize_ = 1;
118 // Optimizer
119 // 0 : Adam
120 // 1 : Stochastic Gradient Descent
121 // 2 : RMS Prop
122 int optimizer_ = 0;
123 // Gradient Step/Learning rate
124 double gradientStepSize_ = 0.1;
125 // Adam parameters
126 double beta1_ = 0.9;
127 double beta2_ = 0.999;
128 // Number of initializations to do (the best one is kept)
129 unsigned int noInit_ = 4;
130 // If activation functions should be used during the initialization
// NOTE(review): the declaration this comment documents (line 131) was lost
// in extraction; its name is not recoverable from this text.
132 // Limit in the size of the origin in output basis as a percentage of the
133 // input total number of nodes
// NOTE(review): the declaration for the comment above (line 134) was lost in
// extraction; its name is not recoverable from this text.
135 // Proportion between the train set and the validation/test set
136 double trainTestSplit_ = 1.0;
137 // If the input data should be shuffled before being split
// NOTE(review): the declaration for the comment above (line 138) was lost in
// extraction; its name is not recoverable from this text.
139
// NOTE(review): undocumented here; presumably toggles whether output data is
// created -- TODO confirm against the implementation.
140 bool createOutput_ = true;
141
142#ifdef TTK_ENABLE_TORCH
143 // Model optimized parameters
144 std::vector<MergeTreeNeuralLayer> layers_;
145
146 // Filled by the algorithm
147 std::vector<std::vector<torch::Tensor>> allAlphas_, allScaledAlphas_,
148 allActAlphas_, allActScaledAlphas_;
149 std::vector<std::vector<mtu::TorchMergeTree<float>>> recs_, recs2_;
150
151 std::vector<mtu::TorchMergeTree<float>> originsCopy_, originsPrimeCopy_;
152#endif
153
154 // Tracking matchings
155 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// NOTE(review): line 156 (this declaration's declarator list) was lost in
// extraction. The member list at the end of this page shows five members of
// this type: originsMatchings_, reconstMatchings_, customMatchings_,
// baryMatchings_L0_ and baryMatchings2_L0_. They are split between this
// declaration and the one at line 165. Going by names, the first three
// likely belong here -- confirm upstream.
157 std::vector<
158 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>>
// NOTE(review): line 159 was lost; presumably `dataMatchings_;`, the only
// member of this type in the member list at the end of this page.
160
161 // Filled by the algorithm
// NOTE(review): no in-class initializer; presumably assigned before use.
162 unsigned noLayers_;
// NOTE(review): the declaration at line 163 was lost in extraction; its name
// is not recoverable from this text.
164 std::vector<unsigned int> clusterAsgn_;
165 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
// NOTE(review): line 166 (declarator list) was lost; see the note after
// line 155 for the candidate member names.
167 std::vector<double> inputToBaryDistances_L0_;
168 std::vector<std::vector<double>> branchesCorrelationMatrix_,
// NOTE(review): line 169 was lost; presumably `persCorrelationMatrix_;`,
// the remaining member of this type in the member list at the end of this
// page.
170
171 // Testing
// NOTE(review): the declaration at line 172 was lost in extraction; its name
// is not recoverable from this text.
173 std::vector<unsigned int> originsNoZeroGrad_, originsPrimeNoZeroGrad_,
// NOTE(review): lines 174-175 were lost. The member list at the end of this
// page also names vSNoZeroGrad_, vSPrimeNoZeroGrad_, origins2NoZeroGrad_,
// origins2PrimeNoZeroGrad_, vS2NoZeroGrad_ and vS2PrimeNoZeroGrad_ with this
// type. They presumably complete this declarator list.
176
177 public:
// NOTE(review): line 178 was lost in extraction; its content is not
// recoverable from this text.
179
180#ifdef TTK_ENABLE_TORCH
181 // -----------------------------------------------------------------------
182 // --- Init
183 // -----------------------------------------------------------------------
// Input-basis initialization for layer `l` with `layerNoAxes` axes.
// `allAlphasInit` is a non-const reference, presumably filled with initial
// coefficients -- TODO confirm.
197 void initInputBasis(unsigned int l,
198 unsigned int layerNoAxes,
199 std::vector<mtu::TorchMergeTree<float>> &trees,
200 std::vector<mtu::TorchMergeTree<float>> &trees2,
201 std::vector<bool> &isTrain,
202 std::vector<std::vector<torch::Tensor>> &allAlphasInit);
203
// Output-basis initialization for layer `l`. `layerOriginPrimeSizePercent`
// is presumably the per-layer value of the origin-size limit documented on
// the members above -- TODO confirm.
217 void initOutputBasis(unsigned int l,
218 double layerOriginPrimeSizePercent,
219 std::vector<mtu::TorchMergeTree<float>> &trees,
220 std::vector<mtu::TorchMergeTree<float>> &trees2,
221 std::vector<bool> &isTrain);
222
// Pure virtual hook returning bool. Its meaning (presumably a reset of the
// output basis of layer `l`) is defined by derived classes.
247 virtual bool
248 initResetOutputBasis(unsigned int l,
249 unsigned int layerNoAxes,
250 double layerOriginPrimeSizePercent,
251 std::vector<mtu::TorchMergeTree<float>> &trees,
252 std::vector<mtu::TorchMergeTree<float>> &trees2,
253 std::vector<bool> &isTrain)
254 = 0;
255
// Returns bool. `recs`/`recs2` and `allAlphasInit` are non-const references,
// presumably filled with the reconstructions and coefficients obtained after
// initializing layer `l` -- TODO confirm.
277 bool initGetReconstructed(
278 unsigned int l,
279 unsigned int layerNoAxes,
280 double layerOriginPrimeSizePercent,
281 std::vector<mtu::TorchMergeTree<float>> &trees,
282 std::vector<mtu::TorchMergeTree<float>> &trees2,
283 std::vector<bool> &isTrain,
284 std::vector<mtu::TorchMergeTree<float>> &recs,
285 std::vector<mtu::TorchMergeTree<float>> &recs2,
286 std::vector<std::vector<torch::Tensor>> &allAlphasInit);
287
// Pure virtual hook returning a float, presumably an error value when
// `computeError` is true. `computeError` defaults to false.
304 virtual float
305 initParameters(std::vector<mtu::TorchMergeTree<float>> &trees,
306 std::vector<mtu::TorchMergeTree<float>> &trees2,
307 std::vector<bool> &isTrain,
308 bool computeError = false)
309 = 0;
310
// Initialization step over `trees`/`trees2`. `isTrain` presumably flags
// which inputs belong to the training set.
321 void initStep(std::vector<mtu::TorchMergeTree<float>> &trees,
322 std::vector<mtu::TorchMergeTree<float>> &trees2,
323 std::vector<bool> &isTrain);
324
// Presumably forwards this network's settings to `layer` -- TODO confirm.
332 void passLayerParameters(MergeTreeNeuralLayer &layer);
333
334 // -----------------------------------------------------------------------
335 // --- Forward
336 // -----------------------------------------------------------------------
// Forward pass for a single input (`treeIndex`). Results come back through
// the non-const reference parameters (`out`, `out2`, `dataAlphas`, `outs`,
// `outs2`). `train` defaults to false.
363 bool forwardOneData(mtu::TorchMergeTree<float> &tree,
364 mtu::TorchMergeTree<float> &tree2,
365 unsigned int treeIndex,
366 unsigned int k,
367 std::vector<torch::Tensor> &alphasInit,
368 mtu::TorchMergeTree<float> &out,
369 mtu::TorchMergeTree<float> &out2,
370 std::vector<torch::Tensor> &dataAlphas,
371 std::vector<mtu::TorchMergeTree<float>> &outs,
372 std::vector<mtu::TorchMergeTree<float>> &outs2,
373 bool train = false);
374
// Forward pass over the inputs selected by `indexes`. Outputs, per-layer
// outputs, matchings, `loss` and `testLoss` are returned through reference
// parameters.
408 bool forwardStep(
409 std::vector<mtu::TorchMergeTree<float>> &trees,
410 std::vector<mtu::TorchMergeTree<float>> &trees2,
411 std::vector<unsigned int> &indexes,
412 std::vector<bool> &isTrain,
413 unsigned int k,
414 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
415 bool computeError,
416 std::vector<mtu::TorchMergeTree<float>> &outs,
417 std::vector<mtu::TorchMergeTree<float>> &outs2,
418 std::vector<std::vector<torch::Tensor>> &bestAlphas,
419 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
420 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
421 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
422 &matchings,
423 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
424 &matchings2,
425 float &loss,
426 float &testLoss);
427
// Overload of the forwardStep above without the `isTrain` and `testLoss`
// parameters.
458 bool forwardStep(
459 std::vector<mtu::TorchMergeTree<float>> &trees,
460 std::vector<mtu::TorchMergeTree<float>> &trees2,
461 std::vector<unsigned int> &indexes,
462 unsigned int k,
463 std::vector<std::vector<torch::Tensor>> &allAlphasInit,
464 bool computeError,
465 std::vector<mtu::TorchMergeTree<float>> &outs,
466 std::vector<mtu::TorchMergeTree<float>> &outs2,
467 std::vector<std::vector<torch::Tensor>> &bestAlphas,
468 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
469 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
470 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
471 &matchings,
472 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
473 &matchings2,
474 float &loss);
475
476 // -----------------------------------------------------------------------
477 // --- Backward
478 // -----------------------------------------------------------------------
// Pure virtual backward step. It receives the torch optimizer and any
// custom loss tensors; the implementation is left to derived classes.
506 virtual bool backwardStep(
507 std::vector<mtu::TorchMergeTree<float>> &trees,
508 std::vector<mtu::TorchMergeTree<float>> &outs,
509 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
510 &matchings,
511 std::vector<mtu::TorchMergeTree<float>> &trees2,
512 std::vector<mtu::TorchMergeTree<float>> &outs2,
513 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
514 &matchings2,
515 std::vector<std::vector<torch::Tensor>> &alphas,
516 torch::optim::Optimizer &optimizer,
517 std::vector<unsigned int> &indexes,
518 std::vector<bool> &isTrain,
519 std::vector<torch::Tensor> &torchCustomLoss)
520 = 0;
521
522 // -----------------------------------------------------------------------
523 // --- Projection
524 // -----------------------------------------------------------------------
// Parameterless projection step; its behavior is not visible in this header.
529 void projectionStep();
530
531 // -----------------------------------------------------------------------
532 // --- Convergence
533 // -----------------------------------------------------------------------
// Pure virtual hook: loss of a single input (`treeIndex`) given its output
// and matchings.
555 virtual float computeOneLoss(
556 mtu::TorchMergeTree<float> &tree,
557 mtu::TorchMergeTree<float> &out,
558 mtu::TorchMergeTree<float> &tree2,
559 mtu::TorchMergeTree<float> &out2,
560 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching,
561 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> &matching2,
562 std::vector<torch::Tensor> &alphas,
563 unsigned int treeIndex)
564 = 0;
565
// `minLoss` and `cptBlocked` are in/out references, presumably updated when
// `loss` improves on `minLoss` -- TODO confirm.
577 bool isBestLoss(float loss, float &minLoss, unsigned int &cptBlocked);
578
// Returns bool, presumably "converged". `oldLoss`, `minLoss` and
// `cptBlocked` are in/out references.
591 bool convergenceStep(float loss,
592 float &oldLoss,
593 float &minLoss,
594 unsigned int &cptBlocked);
595
596 // -----------------------------------------------------------------------
597 // --- Main Functions
598 // -----------------------------------------------------------------------
// Pure virtual hook: model-specific initialization on the torch trees.
609 virtual void
610 customInit(std::vector<mtu::TorchMergeTree<float>> &torchTrees,
611 std::vector<mtu::TorchMergeTree<float>> &torchTrees2)
612 = 0;
613
// Pure virtual hook: lets derived classes append tensors to `parameters`.
624 virtual void addCustomParameters(std::vector<torch::Tensor> &parameters)
625 = 0;
626
// Pure virtual hook: model-specific losses for the current `iteration`,
// returned through the reference parameters.
653 virtual void computeCustomLosses(
654 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts,
655 std::vector<std::vector<mtu::TorchMergeTree<float>>> &layersOuts2,
656 std::vector<std::vector<torch::Tensor>> &bestAlphas,
657 std::vector<unsigned int> &indexes,
658 std::vector<bool> &isTrain,
659 unsigned int iteration,
660 std::vector<std::vector<float>> &gapCustomLosses,
661 std::vector<std::vector<float>> &iterationCustomLosses,
662 std::vector<torch::Tensor> &torchCustomLoss)
663 = 0;
664
// Pure virtual hook: combines `iterationLoss` with the custom losses into a
// float total.
678 virtual float computeIterationTotalLoss(
679 float iterationLoss,
680 std::vector<std::vector<float>> &iterationCustomLosses,
681 std::vector<float> &iterationCustomLoss)
682 = 0;
683
// Pure virtual hook for printing custom losses.
// NOTE(review): line 697 was lost in extraction. As shown, this declaration
// has no closing ')' and may also have lost a default argument for
// `priority`; restore it from upstream.
694 virtual void printCustomLosses(std::vector<float> &customLoss,
695 std::stringstream &prefix,
696 const debug::Priority &priority
698 = 0;
699
// Pure virtual hook for printing the loss every print gap.
710 virtual void printGapLoss(float loss,
711 std::vector<std::vector<float>> &gapCustomLosses)
712 = 0;
713
// Training entry point; takes ftm::MergeTree<float> inputs (not torch trees).
721 void fit(std::vector<ftm::MergeTree<float>> &trees,
722 std::vector<ftm::MergeTree<float>> &trees2);
723
724 // ---------------------------------------------------------------------------
725 // --- End Functions
726 // ---------------------------------------------------------------------------
737 void computeTrackingInformation(unsigned int endLayer);
738
747 void computeCorrelationMatrix(std::vector<ftm::MergeTree<float>> &trees,
748 unsigned int layer);
749
// Computes `scaledAlphas` from `alphas` (both passed by reference).
758 void
759 createScaledAlphas(std::vector<std::vector<torch::Tensor>> &alphas,
760 std::vector<std::vector<torch::Tensor>> &scaledAlphas);
761
// Parameterless overload; presumably operates on the allAlphas_ /
// allScaledAlphas_ members -- TODO confirm.
766 void createScaledAlphas();
767
773 void createActivatedAlphas();
774
775 // -----------------------------------------------------------------------
776 // --- Utils
777 // -----------------------------------------------------------------------
// Copies basis parameters and alphas. `get` presumably selects the copy
// direction -- TODO confirm.
778 void copyParams(std::vector<mtu::TorchMergeTree<float>> &origins,
779 std::vector<mtu::TorchMergeTree<float>> &originsPrime,
780 std::vector<torch::Tensor> &vS,
781 std::vector<torch::Tensor> &vSPrime,
782 std::vector<mtu::TorchMergeTree<float>> &origins2,
783 std::vector<mtu::TorchMergeTree<float>> &origins2Prime,
784 std::vector<torch::Tensor> &vS2,
785 std::vector<torch::Tensor> &vS2Prime,
786 std::vector<std::vector<torch::Tensor>> &srcAlphas,
787 std::vector<std::vector<torch::Tensor>> &dstAlphas,
788 bool get);
789
// Overload copying a nested vector of torch merge trees from `src` to `dst`.
790 void copyParams(std::vector<std::vector<mtu::TorchMergeTree<float>>> &src,
791 std::vector<std::vector<mtu::TorchMergeTree<float>>> &dst);
792
// Pure virtual hook: derived classes copy their own extra parameters.
801 virtual void copyCustomParams(bool get) = 0;
802
// Gathers alphas of layer `layerIndex` for the inputs in `indexes` (filtered
// by `toGet`) into the tensor `alphasOut`.
815 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
816 std::vector<unsigned int> &indexes,
817 std::vector<bool> &toGet,
818 unsigned int layerIndex,
819 torch::Tensor &alphasOut);
820
// Overload without the `toGet` filter.
831 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
832 std::vector<unsigned int> &indexes,
833 unsigned int layerIndex,
834 torch::Tensor &alphasOut);
835
// Overload without `indexes` and `toGet`.
845 void getAlphasTensor(std::vector<std::vector<torch::Tensor>> &alphas,
846 unsigned int layerIndex,
847 torch::Tensor &alphasOut);
848
849 // -----------------------------------------------------------------------
850 // --- Testing
851 // -----------------------------------------------------------------------
// Debug check on layer `l`; `checkOutputBasis` defaults to true.
852 void checkZeroGrad(unsigned int l, bool checkOutputBasis = true);
853
// Debug check on `mTree` against `threshold` (default 10000).
854 bool isTreeHasBigValues(const ftm::MergeTree<float> &mTree,
855 float threshold = 10000);
856
857 // ---------------------------------------------------------------------------
858 // --- Main Functions
859 // ---------------------------------------------------------------------------
// Pure virtual hook for derived-class end-of-run processing.
871 virtual void executeEndFunction(std::vector<ftm::MergeTree<float>> &trees,
872 std::vector<ftm::MergeTree<float>> &trees2)
873 = 0;
874#endif
875
// Declared outside the TTK_ENABLE_TORCH guard, so it exists in builds
// without torch as well.
876 void execute(std::vector<ftm::MergeTree<float>> &trees,
877 std::vector<ftm::MergeTree<float>> &trees2);
878 }; // MergeTreeNeuralNetwork class
879
880} // namespace ttk
std::vector< unsigned int > origins2PrimeNoZeroGrad_
std::vector< unsigned int > originsNoZeroGrad_
std::vector< unsigned int > vSPrimeNoZeroGrad_
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
std::vector< unsigned int > clusterAsgn_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings2_L0_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings_L0_
std::vector< std::vector< double > > persCorrelationMatrix_
std::vector< double > inputToBaryDistances_L0_
std::vector< std::vector< double > > branchesCorrelationMatrix_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > customMatchings_
std::vector< unsigned int > origins2NoZeroGrad_
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
std::vector< unsigned int > originsPrimeNoZeroGrad_
std::vector< unsigned int > vS2NoZeroGrad_
std::vector< unsigned int > vS2PrimeNoZeroGrad_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
std::vector< unsigned int > vSNoZeroGrad_
TTK base package defining the standard types.