TTK
Loading...
Searching...
No Matches
ttk::MergeTreeNeuralNetwork Class Reference

#include <MergeTreeNeuralNetwork.h>

Inheritance diagram for ttk::MergeTreeNeuralNetwork:
ttk::Debug ttk::MergeTreeNeuralBase ttk::BaseClass ttk::Debug ttk::MergeTreeAxesAlgorithmBase ttk::BaseClass ttk::Debug ttk::MergeTreeBase ttk::BaseClass ttk::Debug ttk::BaseClass ttk::MergeTreeAutoencoder ttk::MergeTreeAutoencoderDecoding ttkMergeTreeAutoencoder ttkMergeTreeAutoencoderDecoding

Public Member Functions

 MergeTreeNeuralNetwork ()
 
void execute (std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
 
- Public Member Functions inherited from ttk::Debug
 Debug ()
 
 ~Debug () override
 
virtual int setDebugLevel (const int &debugLevel)
 
int setWrapper (const Wrapper *wrapper) override
 
int printMsg (const std::string &msg, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printMsg (const std::vector< std::string > &msgs, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printErr (const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
 
int printWrn (const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
 
int printMsg (const std::string &msg, const double &progress, const double &time, const int &threads, const double &memory, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::PERFORMANCE, std::ostream &stream=std::cout) const
 
int printMsg (const std::string &msg, const double &progress, const double &time, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::PERFORMANCE, std::ostream &stream=std::cout) const
 
int printMsg (const std::string &msg, const double &progress, const double &time, const int &threads, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::PERFORMANCE, std::ostream &stream=std::cout) const
 
int printMsg (const std::string &msg, const double &progress, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::PERFORMANCE, std::ostream &stream=std::cout) const
 
int printMsg (const std::string &msg, const double &progress, const debug::Priority &priority, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printMsg (const std::vector< std::vector< std::string > > &rows, const debug::Priority &priority=debug::Priority::INFO, const bool hasHeader=true, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printMsg (const debug::Separator &separator, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::INFO, std::ostream &stream=std::cout) const
 
int printMsg (const debug::Separator &separator, const debug::Priority &priority, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printMsg (const std::string &msg, const debug::Separator &separator, const debug::LineMode &lineMode=debug::LineMode::NEW, const debug::Priority &priority=debug::Priority::INFO, std::ostream &stream=std::cout) const
 
void setDebugMsgPrefix (const std::string &prefix)
 
- Public Member Functions inherited from ttk::BaseClass
 BaseClass ()
 
virtual ~BaseClass ()=default
 
int getThreadNumber () const
 
virtual int setThreadNumber (const int threadNumber)
 
- Public Member Functions inherited from ttk::MergeTreeNeuralBase
 MergeTreeNeuralBase ()
 
- Public Member Functions inherited from ttk::MergeTreeAxesAlgorithmBase
 MergeTreeAxesAlgorithmBase ()
 
void setDeterministic (const bool deterministic)
 
void setNumberOfProjectionSteps (const unsigned int k)
 
void setBarycenterSizeLimitPercent (const double barycenterSizeLimitPercent)
 
void setProbabilisticVectorsInit (const bool probabilisticVectorsInit)
 
template<class dataType>
void computeOneDistance (const ftm::MergeTree< dataType > &tree1, const ftm::MergeTree< dataType > &tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, dataType &distance, bool isCalled=false, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void computeOneDistance (const ftm::MergeTree< dataType > &tree1, const ftm::MergeTree< dataType > &tree2, dataType &distance, bool isCalled=false, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void initVectorFromMatching (ftm::MergeTree< dataType > &barycenter, ftm::MergeTree< dataType > &tree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matching, std::vector< std::vector< double > > &v)
 
template<class dataType>
void initRandomVector (ftm::MergeTree< dataType > &barycenter, std::vector< std::vector< double > > &v, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s)
 
template<class dataType, typename F>
int initVectors (int axeNumber, ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &barycenter2, std::vector< ftm::MergeTree< dataType > > &trees2, std::vector< std::vector< double > > &v1, std::vector< std::vector< double > > &v2, std::vector< std::vector< double > > &trees2V1, std::vector< std::vector< double > > &trees2V2, int newVectorOffset, std::vector< double > &inputToOriginDistances, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings2, std::vector< std::vector< double > > &inputToAxesDistances, std::vector< std::vector< std::vector< double > > > &vS, std::vector< std::vector< std::vector< double > > > &v2s, std::vector< std::vector< std::vector< double > > > &trees2Vs, std::vector< std::vector< std::vector< double > > > &trees2V2s, bool projectInitializedVectors, F initializedVectorsProjection)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< double > &finalDistances, double barycenterSizeLimitPercent, unsigned int barycenterMaximumNumberOfPairs, int barycenterInitIndex, bool oneIter, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< double > &finalDistances, double barycenterSizeLimitPercent, unsigned int barycenterMaximumNumberOfPairs, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< double > &finalDistances, double barycenterSizeLimitPercent, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings, std::vector< double > &finalDistances, bool useDoubleInput=false, bool isFirstInput=true)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &matchings)
 
template<class dataType>
void computeOneBarycenter (std::vector< ftm::MergeTree< dataType > > &trees, ftm::MergeTree< dataType > &baryMergeTree)
 
template<class dataType>
void preprocessingTrees (std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< int > > &nodeCorr, bool useMinMaxPairT=true)
 
template<class dataType>
void preprocessingTrees (std::vector< ftm::MergeTree< dataType > > &trees, bool useMinMaxPairT=true)
 
template<class dataType>
std::tuple< dataType, dataType > getParametrizedBirthDeath (ftm::FTMTree_MT *tree, ftm::idNode node)
 
template<class dataType>
void computeBranchesCorrelationMatrix (const ftm::MergeTree< dataType > &barycenter, std::vector< ftm::MergeTree< dataType > > &trees, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &baryMatchings, std::vector< std::vector< double > > &allTs, std::vector< std::vector< double > > &branchesCorrelationMatrix, std::vector< std::vector< double > > &persCorrelationMatrix)
 
- Public Member Functions inherited from ttk::MergeTreeBase
 MergeTreeBase ()
 
void setAssignmentSolver (int assignmentSolver)
 
void setEpsilon1UseFarthestSaddle (bool b)
 
void setEpsilonTree1 (double epsilon)
 
void setEpsilonTree2 (double epsilon)
 
void setEpsilon2Tree1 (double epsilon)
 
void setEpsilon2Tree2 (double epsilon)
 
void setEpsilon3Tree1 (double epsilon)
 
void setEpsilon3Tree2 (double epsilon)
 
void setPersistenceThreshold (double pt)
 
void setParallelize (bool para)
 
void setNodePerTask (int npt)
 
void setBranchDecomposition (bool useBD)
 
void setNormalizedWasserstein (bool normalizedWasserstein)
 
void setKeepSubtree (bool keepSubtree)
 
void setNonMatchingWeight (double weight)
 
void setBarycenterMergeTree (bool imt)
 
void setDistanceSquaredRoot (bool distanceSquaredRoot)
 
void setUseMinMaxPair (bool useMinMaxPair)
 
void setDeleteMultiPersPairs (bool deleteMultiPersPairsT)
 
void setCleanTree (bool clean)
 
void setIsPersistenceDiagram (bool isPD)
 
void setJoinSplitMixtureCoefficient (const double mixtureCoefficient)
 
void setUseDoubleInput (const bool useDoubleInput)
 
std::vector< std::vector< int > > getTreesNodeCorr ()
 
double mixDistancesMinMaxPairWeight (bool isFirstInput)
 
double mixDistancesWeight (bool isFirstInput)
 
template<class dataType>
double mixDistances (dataType distance1, dataType distance2)
 
template<class dataType>
void mixDistancesMatrix (std::vector< std::vector< dataType > > &distanceMatrix, std::vector< std::vector< dataType > > &distanceMatrix2)
 
template<class dataType>
void mergeSaddle (ftm::FTMTree_MT *tree, double epsilon, std::vector< std::vector< ftm::idNode > > &treeNodeMerged, bool mergeByPersistence=false)
 
template<class dataType>
void persistenceMerging (ftm::FTMTree_MT *tree, double epsilon2, double epsilon3=100)
 
void deletePersistenceDiagramsPairs (ftm::FTMTree_MT *tree, std::vector< ftm::idNode > &nodes)
 
template<class dataType>
void keepMostImportantPairs (ftm::FTMTree_MT *tree, int n, bool useBD)
 
template<class dataType>
void persistenceThresholding (ftm::FTMTree_MT *tree, double persistenceThresholdT, std::vector< ftm::idNode > &deletedNodes)
 
template<class dataType>
void persistenceThresholding (ftm::FTMTree_MT *tree, std::vector< ftm::idNode > &deletedNodes)
 
template<class dataType>
void persistenceThresholding (ftm::FTMTree_MT *tree, double persistenceThresholdT)
 
template<class dataType>
void persistenceThresholding (ftm::FTMTree_MT *tree)
 
template<class dataType>
void verifyOrigins (ftm::FTMTree_MT *tree)
 
template<class dataType>
void preprocessTree (ftm::FTMTree_MT *tree, bool deleteInconsistentNodes=true)
 
template<class dataType>
ftm::FTMTree_MT * computeBranchDecomposition (ftm::FTMTree_MT *tree, std::vector< std::vector< ftm::idNode > > &treeNodeMerged)
 
template<class dataType>
void dontUseMinMaxPair (ftm::FTMTree_MT *tree)
 
void verifyPairsTree (ftm::FTMTree_MT *tree)
 
template<class dataType>
void deleteMultiPersPairs (ftm::FTMTree_MT *tree, bool useBD)
 
template<class dataType>
void preprocessingPipeline (ftm::MergeTree< dataType > &mTree, double epsilonTree, double epsilon2Tree, double epsilon3Tree, bool branchDecompositionT, bool useMinMaxPairT, bool cleanTreeT, double persistenceThreshold, std::vector< int > &nodeCorr, bool deleteInconsistentNodes=true)
 
template<class dataType>
void preprocessingPipeline (ftm::MergeTree< dataType > &mTree, double epsilonTree, double epsilon2Tree, double epsilon3Tree, bool branchDecompositionT, bool useMinMaxPairT, bool cleanTreeT, std::vector< int > &nodeCorr, bool deleteInconsistentNodes=true)
 
void reverseNodeCorr (ftm::FTMTree_MT *tree, std::vector< int > &nodeCorr)
 
template<class dataType>
void mtFlattening (ftm::MergeTree< dataType > &mt)
 
template<class dataType>
void mtsFlattening (std::vector< ftm::MergeTree< dataType > > &mts)
 
double getSizeLimitMetric (std::vector< ftm::FTMTree_MT * > &trees)
 
template<class dataType>
void copyMinMaxPair (ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree2, bool setOrigins=false)
 
template<class dataType>
std::tuple< int, dataType > fixMergedRootOrigin (ftm::FTMTree_MT *tree)
 
template<class dataType>
void branchDecompositionToTree (ftm::FTMTree_MT *tree)
 
template<class dataType>
void putBackMergedNodes (ftm::FTMTree_MT *tree)
 
template<class dataType>
void postprocessingPipeline (ftm::FTMTree_MT *tree)
 
template<class dataType>
void convertBranchDecompositionMatching (ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
 
template<class dataType>
void convertBranchDecompositionMatching (ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode > > &outputMatching)
 
template<class dataType>
void identifyRealMatching (ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, std::vector< std::tuple< ftm::idNode, ftm::idNode > > &outputMatching, std::vector< std::tuple< ftm::idNode, ftm::idNode, bool > > &realMatching)
 
template<class dataType>
dataType computeDistance (dataType x1, dataType x2, dataType y1, dataType y2, double power=2)
 
template<class dataType>
dataType deleteCost (const ftm::FTMTree_MT *tree, ftm::idNode nodeId)
 
template<class dataType>
dataType insertCost (const ftm::FTMTree_MT *tree, ftm::idNode nodeId)
 
template<class dataType>
dataType relabelCostOnly (const ftm::FTMTree_MT *tree1, ftm::idNode nodeId1, const ftm::FTMTree_MT *tree2, ftm::idNode nodeId2)
 
template<class dataType>
dataType relabelCost (const ftm::FTMTree_MT *tree1, ftm::idNode nodeId1, const ftm::FTMTree_MT *tree2, ftm::idNode nodeId2)
 
void getParamNames (std::vector< std::string > &paramNames)
 
double getParamValueFromName (std::string &paramName)
 
void setParamValueFromName (std::string &paramName, double value)
 
void getTreesStats (std::vector< ftm::FTMTree_MT * > &trees, std::array< double, 3 > &stats)
 
void printTreesStats (std::vector< ftm::FTMTree_MT * > &trees)
 
template<class dataType>
void printTreesStats (std::vector< ftm::MergeTree< dataType > > &trees)
 
template<class dataType>
void printTableVector (std::vector< std::vector< dataType > > &table)
 
template<class dataType>
void printTable (dataType *table, int nRows, int nCols)
 
void printMatching (std::vector< MatchingType > &matchings)
 
void printMatching (std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matchings)
 
void printMatching (std::vector< std::tuple< ftm::idNode, ftm::idNode > > &matchings)
 
template<class dataType>
void printPairs (std::vector< std::tuple< SimplexId, SimplexId, dataType > > &treePairs)
 
template<class dataType>
void printOutputMatching (std::vector< std::tuple< ftm::idNode, ftm::idNode > > &outputMatching, ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2, bool computeCosts=true)
 

Protected Attributes

unsigned int minIteration_ = 0
 
unsigned int maxIteration_ = 0
 
unsigned int iterationGap_ = 100
 
double batchSize_ = 1
 
int optimizer_ = 0
 
double gradientStepSize_ = 0.1
 
double beta1_ = 0.9
 
double beta2_ = 0.999
 
unsigned int noInit_ = 4
 
bool activateOutputInit_ = false
 
double originPrimeSizePercent_ = 15
 
double trainTestSplit_ = 1.0
 
bool shuffleBeforeSplit_ = true
 
bool createOutput_ = true
 
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
 
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
 
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > customMatchings_
 
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
 
unsigned noLayers_
 
float bestLoss_
 
std::vector< unsigned int > clusterAsgn_
 
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings_L0_
 
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > baryMatchings2_L0_
 
std::vector< double > inputToBaryDistances_L0_
 
std::vector< std::vector< double > > branchesCorrelationMatrix_
 
std::vector< std::vector< double > > persCorrelationMatrix_
 
double t_allVectorCopy_time_ = 0.0
 
std::vector< unsigned int > originsNoZeroGrad_
 
std::vector< unsigned int > originsPrimeNoZeroGrad_
 
std::vector< unsigned int > vSNoZeroGrad_
 
std::vector< unsigned int > vSPrimeNoZeroGrad_
 
std::vector< unsigned int > origins2NoZeroGrad_
 
std::vector< unsigned int > origins2PrimeNoZeroGrad_
 
std::vector< unsigned int > vS2NoZeroGrad_
 
std::vector< unsigned int > vS2PrimeNoZeroGrad_
 
- Protected Attributes inherited from ttk::Debug
int debugLevel_
 
std::string debugMsgPrefix_
 
std::string debugMsgNamePrefix_
 
- Protected Attributes inherited from ttk::BaseClass
bool lastObject_
 
int threadNumber_
 
Wrapper * wrapper_
 
- Protected Attributes inherited from ttk::MergeTreeNeuralBase
double dropout_ = 0.0
 
bool euclideanVectorsInit_ = false
 
bool randomAxesInit_ = false
 
bool initBarycenterRandom_ = false
 
bool initBarycenterOneIter_ = false
 
bool initOriginPrimeStructByCopy_ = true
 
bool initOriginPrimeValuesByCopy_ = true
 
double initOriginPrimeValuesByCopyRandomness_ = 0.0
 
bool activate_ = true
 
unsigned int activationFunction_ = 1
 
bool useGpu_ = false
 
float bigValuesThreshold_ = 0
 
- Protected Attributes inherited from ttk::MergeTreeAxesAlgorithmBase
bool deterministic_ = true
 
unsigned int numberOfAxes_ = 2
 
unsigned int k_ = 16
 
double barycenterSizeLimitPercent_ = 20.0
 
bool probabilisticVectorsInit_ = false
 
std::vector< std::vector< int > > trees2NodeCorr_
 
- Protected Attributes inherited from ttk::MergeTreeBase
int assignmentSolverID_ = 0
 
bool epsilon1UseFarthestSaddle_ = false
 
double epsilonTree1_ = 0
 
double epsilonTree2_ = 0
 
double epsilon2Tree1_ = 100
 
double epsilon2Tree2_ = 100
 
double epsilon3Tree1_ = 100
 
double epsilon3Tree2_ = 100
 
double persistenceThreshold_ = 0
 
bool barycenterMergeTree_ = false
 
bool useMinMaxPair_ = true
 
bool deleteMultiPersPairs_ = false
 
bool branchDecomposition_ = true
 
int wassersteinPower_ = 2
 
bool normalizedWasserstein_ = true
 
bool keepSubtree_ = false
 
double nonMatchingWeight_ = 1.0
 
bool distanceSquaredRoot_ = true
 
bool useFullMerge_ = false
 
bool isPersistenceDiagram_ = false
 
bool convertToDiagram_ = false
 
double mixtureCoefficient_ = 0.5
 
bool useDoubleInput_ = false
 
bool parallelize_ = true
 
int nodePerTask_ = 32
 
bool cleanTree_ = true
 
std::vector< std::vector< int > > treesNodeCorr_
 

Additional Inherited Members

- Protected Member Functions inherited from ttk::Debug
int printMsgInternal (const std::string &msg, const std::string &right, const std::string &filler, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
 
int printMsgInternal (const std::string &msg, const debug::Priority &priority, const debug::LineMode &lineMode, std::ostream &stream=std::cout) const
 
int welcomeMsg (std::ostream &stream)
 
- Static Protected Attributes inherited from ttk::Debug
static COMMON_EXPORTS debug::LineMode lastLineMode = ttk::debug::LineMode::NEW
 

Detailed Description

The MergeTreeNeuralNetwork class provides methods to define a neural network able to process merge trees or persistence diagrams.

Author
Mathieu Pont mathi.nosp@m.eu.p.nosp@m.ont@l.nosp@m.ip6..nosp@m.fr
Date
2023.

This module defines the MergeTreeNeuralNetwork class providing functions to define a neural network able to process merge trees or persistence diagrams.

This is an abstract class, to implement a derived class you need to define the following functions:

  • "initParameters" : initializes the network, such as its different layers. One strategy to initialize a sequence of layers consists of initializing the first layer with the input topological representations, then passing them through this layer to initialize the second one, and so on. A simple loop (whose number of iterations corresponds to the number of layers) can do this using the "initInputBasis" and "initOutputBasis" functions, then the "initGetReconstructed" function to pass the representations through the layer that has just been initialized.
  • "initResetOutputBasis" : please refer to the documentation of this function.
  • "customInit" : called just before the "initStep" function (which calls the "initParameters" function); it is intended to perform custom operations depending on the architecture and the optimization you want to define (such as computing the distance matrix for the metric loss in the autoencoder case). This function can be empty.
  • "backwardStep" : optimizes the parameters of the network. A loss using differentiable torch operations should be computed from the output of some layers of the network (usually the output of the last layer, but it can be any other layer). You can use either the torch coordinates of the representations in a layer or their torch tensors to compute the loss. Then use the torch::Tensor "backward" function to compute the gradients, then the torch::optim::Optimizer "step" function to update the model parameters; after this, the torch::optim::Optimizer "zero_grad" function should be called to reset the gradients. If you have correctly created the MergeTreeNeuralLayer objects (refer to the corresponding class documentation), essentially by calling the "requires_grad" function (with true as parameter) for each layer after initializing its parameters, then backpropagation of the loss gradient through the layers will be handled automatically.
  • "addCustomParameters" : adds custom parameters to the parameter list that will be given to the optimizer, depending on the architecture and the optimization you want to define (such as the centroids for the cluster loss in the autoencoder case). This function can be empty.
  • "computeOneLoss" : computes the loss for one input topological representation. The loss computed here does not need to be differentiable because it will only be used to print it in the console and to check the convergence of the method (i.e. it is not called in the "backwardStep" function).
  • "computeCustomLosses" : computes custom losses for all input topological representations depending on the architecture and the optimization you want to define (such as the clustering and the metric losses in the autoencoder case). Like "computeOneLoss", the losses do not need to be differentiable because they will only be used to print them in the console and to check the convergence of the method. This function can be empty.
  • "computeIterationTotalLoss"
  • "printCustomLosses" : prints the custom losses depending on the architecture and the optimization you want to define (such as the clustering and the metric losses in the autoencoder case). This function can be empty.
  • "printGapLoss" : prints the "gap" loss, the aggregated loss over iterationGap_ iterations.
  • "copyCustomParams" : copies the custom parameters (for instance to save them when a better loss is reached during the optimization) depending on the architecture and the optimization you want to define (such as the centroids for the cluster loss in the autoencoder case). This function can be empty.
  • "executeEndFunction" : does specific operations at the end of the optimization, such as calling the "computeTrackingInformation" and the "computeCorrelationMatrix" functions.

Related publication:
"Wasserstein Auto-Encoders of Merge Trees (and Persistence Diagrams)"
Mathieu Pont, Julien Tierny.
IEEE Transactions on Visualization and Computer Graphics, 2023

Definition at line 106 of file MergeTreeNeuralNetwork.h.

Constructor & Destructor Documentation

◆ MergeTreeNeuralNetwork()

ttk::MergeTreeNeuralNetwork::MergeTreeNeuralNetwork ( )

Definition at line 8 of file MergeTreeNeuralNetwork.cpp.

Member Function Documentation

◆ execute()

void ttk::MergeTreeNeuralNetwork::execute ( std::vector< ftm::MergeTree< float > > & trees,
std::vector< ftm::MergeTree< float > > & trees2 )

Definition at line 1210 of file MergeTreeNeuralNetwork.cpp.

Member Data Documentation

◆ activateOutputInit_

bool ttk::MergeTreeNeuralNetwork::activateOutputInit_ = false
protected

Definition at line 131 of file MergeTreeNeuralNetwork.h.

◆ baryMatchings2_L0_

std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > ttk::MergeTreeNeuralNetwork::baryMatchings2_L0_
protected

Definition at line 166 of file MergeTreeNeuralNetwork.h.

◆ baryMatchings_L0_

std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > ttk::MergeTreeNeuralNetwork::baryMatchings_L0_
protected

Definition at line 166 of file MergeTreeNeuralNetwork.h.

◆ batchSize_

double ttk::MergeTreeNeuralNetwork::batchSize_ = 1
protected

Definition at line 117 of file MergeTreeNeuralNetwork.h.

◆ bestLoss_

float ttk::MergeTreeNeuralNetwork::bestLoss_
protected

Definition at line 163 of file MergeTreeNeuralNetwork.h.

◆ beta1_

double ttk::MergeTreeNeuralNetwork::beta1_ = 0.9
protected

Definition at line 126 of file MergeTreeNeuralNetwork.h.

◆ beta2_

double ttk::MergeTreeNeuralNetwork::beta2_ = 0.999
protected

Definition at line 127 of file MergeTreeNeuralNetwork.h.

◆ branchesCorrelationMatrix_

std::vector<std::vector<double> > ttk::MergeTreeNeuralNetwork::branchesCorrelationMatrix_
protected

Definition at line 168 of file MergeTreeNeuralNetwork.h.

◆ clusterAsgn_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::clusterAsgn_
protected

Definition at line 164 of file MergeTreeNeuralNetwork.h.

◆ createOutput_

bool ttk::MergeTreeNeuralNetwork::createOutput_ = true
protected

Definition at line 140 of file MergeTreeNeuralNetwork.h.

◆ customMatchings_

std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > ttk::MergeTreeNeuralNetwork::customMatchings_
protected

Definition at line 156 of file MergeTreeNeuralNetwork.h.

◆ dataMatchings_

std::vector< std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > > ttk::MergeTreeNeuralNetwork::dataMatchings_
protected

Definition at line 159 of file MergeTreeNeuralNetwork.h.

◆ gradientStepSize_

double ttk::MergeTreeNeuralNetwork::gradientStepSize_ = 0.1
protected

Definition at line 124 of file MergeTreeNeuralNetwork.h.

◆ inputToBaryDistances_L0_

std::vector<double> ttk::MergeTreeNeuralNetwork::inputToBaryDistances_L0_
protected

Definition at line 167 of file MergeTreeNeuralNetwork.h.

◆ iterationGap_

unsigned int ttk::MergeTreeNeuralNetwork::iterationGap_ = 100
protected

Definition at line 115 of file MergeTreeNeuralNetwork.h.

◆ maxIteration_

unsigned int ttk::MergeTreeNeuralNetwork::maxIteration_ = 0
protected

Definition at line 113 of file MergeTreeNeuralNetwork.h.

◆ minIteration_

unsigned int ttk::MergeTreeNeuralNetwork::minIteration_ = 0
protected

Definition at line 111 of file MergeTreeNeuralNetwork.h.

◆ noInit_

unsigned int ttk::MergeTreeNeuralNetwork::noInit_ = 4
protected

Definition at line 129 of file MergeTreeNeuralNetwork.h.

◆ noLayers_

unsigned ttk::MergeTreeNeuralNetwork::noLayers_
protected

Definition at line 162 of file MergeTreeNeuralNetwork.h.

◆ optimizer_

int ttk::MergeTreeNeuralNetwork::optimizer_ = 0
protected

Definition at line 122 of file MergeTreeNeuralNetwork.h.

◆ originPrimeSizePercent_

double ttk::MergeTreeNeuralNetwork::originPrimeSizePercent_ = 15
protected

Definition at line 134 of file MergeTreeNeuralNetwork.h.

◆ origins2NoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::origins2NoZeroGrad_
protected

Definition at line 174 of file MergeTreeNeuralNetwork.h.

◆ origins2PrimeNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::origins2PrimeNoZeroGrad_
protected

Definition at line 175 of file MergeTreeNeuralNetwork.h.

◆ originsMatchings_

std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > ttk::MergeTreeNeuralNetwork::originsMatchings_
protected

Definition at line 156 of file MergeTreeNeuralNetwork.h.

◆ originsNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::originsNoZeroGrad_
protected

Definition at line 173 of file MergeTreeNeuralNetwork.h.

◆ originsPrimeNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::originsPrimeNoZeroGrad_
protected

Definition at line 173 of file MergeTreeNeuralNetwork.h.

◆ persCorrelationMatrix_

std::vector<std::vector<double> > ttk::MergeTreeNeuralNetwork::persCorrelationMatrix_
protected

Definition at line 169 of file MergeTreeNeuralNetwork.h.

◆ reconstMatchings_

std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double> > > ttk::MergeTreeNeuralNetwork::reconstMatchings_
protected

Definition at line 156 of file MergeTreeNeuralNetwork.h.

◆ shuffleBeforeSplit_

bool ttk::MergeTreeNeuralNetwork::shuffleBeforeSplit_ = true
protected

Definition at line 138 of file MergeTreeNeuralNetwork.h.

◆ t_allVectorCopy_time_

double ttk::MergeTreeNeuralNetwork::t_allVectorCopy_time_ = 0.0
protected

Definition at line 172 of file MergeTreeNeuralNetwork.h.

◆ trainTestSplit_

double ttk::MergeTreeNeuralNetwork::trainTestSplit_ = 1.0
protected

Definition at line 136 of file MergeTreeNeuralNetwork.h.

◆ vS2NoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::vS2NoZeroGrad_
protected

Definition at line 175 of file MergeTreeNeuralNetwork.h.

◆ vS2PrimeNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::vS2PrimeNoZeroGrad_
protected

Definition at line 175 of file MergeTreeNeuralNetwork.h.

◆ vSNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::vSNoZeroGrad_
protected

Definition at line 174 of file MergeTreeNeuralNetwork.h.

◆ vSPrimeNoZeroGrad_

std::vector<unsigned int> ttk::MergeTreeNeuralNetwork::vSPrimeNoZeroGrad_
protected

Definition at line 174 of file MergeTreeNeuralNetwork.h.


The documentation for this class was generated from the following files: