7#include <vtkDataArray.h>
9#include <vtkFloatArray.h>
10#include <vtkInformation.h>
11#include <vtkObjectFactory.h>
12#include <vtkPointData.h>
13#include <vtkSmartPointer.h>
15#include <vtkUnsignedIntArray.h>
// Configure the filter with three input ports and four output ports.
37 this->SetNumberOfInputPorts(3);
38 this->SetNumberOfOutputPorts(4);
49 vtkInformation *info) {
// First branch (its condition line is not part of this excerpt, presumably
// port 0): the input must be a vtkMultiBlockDataSet and is not flagged
// optional.
51 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(),
"vtkMultiBlockDataSet");
52 }
else if(port == 1) {
// Port 1 also takes a vtkMultiBlockDataSet, but is flagged optional.
53 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(),
"vtkMultiBlockDataSet");
54 info->Set(vtkAlgorithm::INPUT_IS_OPTIONAL(), 1);
55 }
else if(port == 2) {
// Port 2 takes a vtkTable and is flagged optional as well.
56 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(),
"vtkTable");
57 info->Set(vtkAlgorithm::INPUT_IS_OPTIONAL(), 1);
79 vtkInformation *info) {
// Each of the four output ports (0 through 3) is declared as producing a
// vtkMultiBlockDataSet.
80 if(port == 0 or port == 1 or port == 2 or port == 3) {
81 info->Set(vtkDataObject::DATA_TYPE_NAME(),
"vtkMultiBlockDataSet");
102 vtkInformationVector **inputVector,
103 vtkInformationVector *outputVector) {
// When TTK_ENABLE_TORCH is not defined, an error stating that Torch is
// required is printed (the rest of this branch is not part of this excerpt).
104#ifndef TTK_ENABLE_TORCH
107 printErr(
"This filter requires Torch.");
// Fetch the three inputs: multi-block datasets on ports 0 and 1, and a table
// on port 2.
113 auto blocks = vtkMultiBlockDataSet::GetData(inputVector[0], 0);
114 auto blocks2 = vtkMultiBlockDataSet::GetData(inputVector[1], 0);
115 auto table = vtkTable::GetData(inputVector[2], 0);
// Tree containers for the first and the second input.
120 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> inputTrees, inputTrees2;
// Cluster assignment: taken from input array 0 and copied entry by entry,
// converted to int, into clusterAsgn_.
126 vtkAbstractArray *clusterAsgn;
128 clusterAsgn = this->GetInputArrayToProcess(0, inputVector);
132 clusterAsgn_[i] = clusterAsgn->GetVariantValue(i).ToInt();
137 "You must provide a table column in info input to use clustering loss");
141 std::stringstream ss;
// Reset the visualization data when treesNodes is non-empty and the first
// block of inputTrees[0] is not treesNodes[0], or when treesNodes2 and
// inputTrees2 differ in size.
// NOTE(review): inputTrees[0] is read without a visible emptiness check on
// inputTrees -- confirm it cannot be empty when treesNodes is non-empty.
149 if((treesNodes.size() != 0 and inputTrees[0]->GetBlock(0) != treesNodes[0])
150 or (treesNodes2.size() != inputTrees2.size()))
151 resetDataVisualization();
161 printMsg(
"Computation with normalized Wasserstein.");
163 printMsg(
"Computation without normalized Wasserstein.");
// Hand over to run() and forward its return value.
165 return run(outputVector, inputTrees, inputTrees2);
// What follows is only compiled when TTK_ENABLE_TORCH is defined.
169#ifdef TTK_ENABLE_TORCH
171 vtkInformationVector *outputVector,
// Two stages, both called with the same arguments: runCompute(), then
// runOutput().
174 runCompute(outputVector, inputTrees, inputTrees2);
175 runOutput(outputVector, inputTrees, inputTrees2);
180 vtkInformationVector *
ttkNotUsed(outputVector),
// Intermediate merge trees (float scalar type); the declaration and the call
// that fills them continue on lines absent from this excerpt.
186 std::vector<ttk::ftm::MergeTree<float>> intermediateMTrees,
191 inputTrees, intermediateMTrees, treesNodes, treesArcs, treesSegmentation,
192 useSecondPairsType, DiagramPairTypes);
// Second input: the visible arguments mirror the first call, using the
// "2"-suffixed members and the negated useSecondPairsType flag.
198 auto &inputTrees2ToUse
201 treesNodes2, treesArcs2, treesSegmentation2,
202 !useSecondPairsType, DiagramPairTypes);
// The number of trees of each input is passed to setDataVisualization().
206 const int numInputs = intermediateMTrees.size();
207 const int numInputs2 = intermediateMTrees2.size();
208 setDataVisualization(numInputs, numInputs2);
// Run the computation on both sets of merge trees.
213 execute(intermediateMTrees, intermediateMTrees2);
216 intermediateMTrees, intermediateDTrees);
223 vtkInformationVector *outputVector,
// Output ports, all retrieved as multi-block datasets:
// 0 = origins, 1 = vectors, 2 = coefficients, 3 = data.
231 auto output_origins = vtkMultiBlockDataSet::GetData(outputVector, 0);
232 auto output_vectors = vtkMultiBlockDataSet::GetData(outputVector, 1);
233 auto output_coef = vtkMultiBlockDataSet::GetData(outputVector, 2);
234 auto output_data = vtkMultiBlockDataSet::GetData(outputVector, 3);
// Matching vectors passed to ttk::wnn::makeMatchingVectors() (part of its
// argument list is not visible in this excerpt).
239 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVectorT,
240 invOriginsMatchingVectorT;
241 std::vector<std::vector<std::vector<ttk::ftm::idNode>>>
242 invDataMatchingVectorT;
243 std::vector<std::vector<ttk::ftm::idNode>> invReconstMatchingVectorT;
244 ttk::wnn::makeMatchingVectors(
246 invOriginsMatchingVectorT,
dataMatchings_, recs_, invDataMatchingVectorT,
// Containers handed to ttk::wnn::computeTrackingInformation() together with
// the origins and their matching vectors.
249 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVector;
250 std::vector<std::vector<double>> originsPersPercent, originsPersDiff;
251 std::vector<double> originPersPercent, originPersDiff;
252 std::vector<int> originPersistenceOrder;
253 ttk::wnn::computeTrackingInformation(
254 originsCopy_, originsPrimeCopy_, originsMatchingVectorT,
256 originsPersPercent, originsPersDiff, originPersPercent, originPersDiff,
257 originPersistenceOrder);
// Data output: receives the matching vectors and the per-origin persistence
// information declared above.
262 ttk::wnn::makeDataOutput(
264 invDataMatchingVectorT, invReconstMatchingVectorT, originsMatchingVectorT,
265 originsMatchingVector, originsPersPercent, originsPersDiff,
// Origins output (port 0).
273 ttk::wnn::makeOriginsOutput(
274 output_origins, originsCopy_, originsPrimeCopy_, originPersPercent,
275 originPersDiff, originPersistenceOrder, originsMatchingVector,
// Coefficients output (port 2).
282 ttk::wnn::makeCoefficientsOutput(output_coef, allAlphas_, allScaledAlphas_,
283 allActAlphas_, allActScaledAlphas_,
// One vtkDoubleArray per parameter name, added to the field data of the
// coefficients output (the line filling the array is not in this excerpt).
// NOTE(review): paramName is taken by value, so every name is copied once per
// iteration; a reference would avoid the copy.
287 std::vector<std::string> paramNames;
289 for(
auto paramName : paramNames) {
290 vtkNew<vtkDoubleArray> array{};
291 array->SetName(paramName.c_str());
293 output_coef->GetFieldData()->AddArray(array);
// Field-data array "activate", holding the value of activate_.
295 vtkNew<vtkDoubleArray> arrayActivate{};
296 arrayActivate->SetName(
"activate");
297 arrayActivate->InsertNextTuple1(
activate_);
298 output_coef->GetFieldData()->AddArray(arrayActivate);
// Field-data array "activationFunction" (the line inserting its value is not
// in this excerpt).
299 vtkNew<vtkDoubleArray> arrayActivateFunction{};
300 arrayActivateFunction->SetName(
"activationFunction");
302 output_coef->GetFieldData()->AddArray(arrayActivateFunction);
// Field-data array "DiagramPairTypes", holding the value of DiagramPairTypes.
304 vtkNew<vtkIntArray> diagramPairTypesArray{};
305 diagramPairTypesArray->SetName(
"DiagramPairTypes");
306 diagramPairTypesArray->InsertNextTuple1(DiagramPairTypes);
307 output_coef->GetFieldData()->AddArray(diagramPairTypesArray);
// Three-level container of matching vectors, indexed below as [l][i].
312 std::vector<std::vector<std::vector<ttk::ftm::idNode>>> dataMatchingVectorT(
// Visit every [l][i] entry of dataMatchingVectorT.
314 for(
unsigned int l = 0; l < dataMatchingVectorT.size(); ++l) {
316 for(
unsigned int i = 0; i < dataMatchingVectorT[l].size(); ++i) {
// The origin is originsCopy_[0] for l == 0 and originsPrimeCopy_[l - 1]
// for every later l.
317 auto &origin = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]);
320 dataMatchingVectorT[l][i]);
// The vectors output holds two blocks: `vectors` (block 0) and `vectorsPrime`
// (block 1), both multi-block datasets.
323 output_vectors->SetNumberOfBlocks(2);
324 vtkSmartPointer<vtkMultiBlockDataSet> vectors
325 = vtkSmartPointer<vtkMultiBlockDataSet>::New();
327 vtkSmartPointer<vtkMultiBlockDataSet> vectorsPrime
328 = vtkSmartPointer<vtkMultiBlockDataSet>::New();
329 vectorsPrime->SetNumberOfBlocks(
noLayers_);
// One pair of tables (vectors / vectorsPrime) is built for each layer l.
330 for(
unsigned int l = 0; l <
noLayers_; ++l) {
331 vtkSmartPointer<vtkTable> vectorsTable = vtkSmartPointer<vtkTable>::New();
332 vtkSmartPointer<vtkTable> vectorsPrimeTable
333 = vtkSmartPointer<vtkTable>::New();
// One float column per index v along dimension 1 of the layer's VS tensor;
// each column has one row per index along dimension 0.
// NOTE(review): the VSPrime tensor is indexed with the same v, whose bound
// comes from the VS tensor -- assumes both tensors have the same size along
// dimension 1; confirm.
334 for(
unsigned int v = 0; v < layers_[l].getVSTensor().sizes()[1]; ++v) {
336 vtkNew<vtkFloatArray> vectorArray{};
338 layers_[l].getVSTensor().sizes()[1], v, 0, 0,
false);
339 vectorArray->SetName(name.c_str());
340 vectorArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]);
341 for(
unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i)
342 vectorArray->SetTuple1(i, layers_[l].getVSTensor()[i][v].item<
float>());
343 vectorsTable->AddColumn(vectorArray);
// Same column for the prime table, filled from the VSPrime tensor.
345 vtkNew<vtkFloatArray> vectorPrimeArray{};
347 layers_[l].getVSTensor().sizes()[1], v, 0, 0,
false);
348 vectorPrimeArray->SetName(name2.c_str());
349 vectorPrimeArray->SetNumberOfTuples(
350 layers_[l].getVSPrimeTensor().sizes()[0]);
351 for(
unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i)
352 vectorPrimeArray->SetTuple1(
353 i, layers_[l].getVSPrimeTensor()[i][v].item<
float>());
354 vectorsPrimeTable->AddColumn(vectorPrimeArray);
// "revNodeCorr" column of the vectors table, filled from
// getReverseTorchNodeCorr(originsCopy_[l], ...), one row per index along
// dimension 0 of the VS tensor.
357 vtkNew<vtkUnsignedIntArray> revNodeCorrArray{};
358 revNodeCorrArray->SetName(
"revNodeCorr");
359 revNodeCorrArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]);
360 std::vector<unsigned int> revNodeCorr;
361 getReverseTorchNodeCorr(originsCopy_[l], revNodeCorr);
362 for(
unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i)
363 revNodeCorrArray->SetTuple1(i, revNodeCorr[i]);
364 vectorsTable->AddColumn(revNodeCorrArray);
// Counterpart for the prime table, built from originsPrimeCopy_[l]; the
// column carries the same name, "revNodeCorr".
366 vtkNew<vtkUnsignedIntArray> revNodeCorrPrimeArray{};
367 revNodeCorrPrimeArray->SetNumberOfTuples(
368 layers_[l].getVSPrimeTensor().sizes()[0]);
369 revNodeCorrPrimeArray->SetName(
"revNodeCorr");
370 std::vector<unsigned int> revNodeCorrPrime;
371 getReverseTorchNodeCorr(originsPrimeCopy_[l], revNodeCorrPrime);
372 for(
unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i)
373 revNodeCorrPrimeArray->SetTuple1(i, revNodeCorrPrime[i]);
374 vectorsPrimeTable->AddColumn(revNodeCorrPrimeArray);
// Helper: appends an int column named "nextOriginMatching" to the given
// table, one row per entry of the matching vector (node ids cast to int).
376 auto addOriginMatchingArray
377 = [&](vtkSmartPointer<vtkTable> &vectorsTableT,
378 std::vector<ttk::ftm::idNode> &originMatchingVector) {
379 vtkNew<vtkIntArray> matchingArray{};
380 matchingArray->SetNumberOfTuples(originMatchingVector.size());
381 matchingArray->SetName(
"nextOriginMatching");
382 for(
unsigned int i = 0; i < originMatchingVector.size(); ++i)
383 matchingArray->SetTuple1(i, (
int)originMatchingVector[i]);
384 vectorsTableT->AddColumn(matchingArray);
// The vectors table gets the matching of layer l; the prime table gets the
// matching of layer l + 1 when such an entry exists.
// NOTE(review): size() - 1 is an unsigned expression and wraps around for an
// empty vector -- confirm originsMatchingVectorT is never empty here.
387 addOriginMatchingArray(vectorsTable, originsMatchingVectorT[l]);
388 if(l < originsMatchingVectorT.size() - 1)
389 addOriginMatchingArray(vectorsPrimeTable, originsMatchingVectorT[l + 1]);
// Helper: appends one int column per entry i of dataMatchingVector (the lines
// building the column name are not in this excerpt).
391 auto addDataMatchingArray
392 = [&](vtkSmartPointer<vtkTable> &vectorsTableT,
393 std::vector<std::vector<ttk::ftm::idNode>> &dataMatchingVector) {
394 for(
unsigned int i = 0; i < dataMatchingVector.size(); ++i) {
395 vtkNew<vtkIntArray> matchingArray{};
396 matchingArray->SetNumberOfTuples(dataMatchingVector[i].size());
397 std::stringstream ss;
400 matchingArray->SetName(ss.str().c_str());
401 for(
unsigned int j = 0; j < dataMatchingVector[i].size(); ++j)
402 matchingArray->SetTuple1(j, (
int)dataMatchingVector[i][j]);
403 vectorsTableT->AddColumn(matchingArray);
// Same layer / next-layer pattern as for the origin matchings above.
407 addDataMatchingArray(vectorsTable, dataMatchingVectorT[l]);
408 if(l < dataMatchingVectorT.size() - 1)
409 addDataMatchingArray(vectorsPrimeTable, dataMatchingVectorT[l + 1]);
// One int column per entry of invReconstMatchingVectorT, added to the prime
// table; the column name starts with "reconstMatching" (any guarding
// condition and the rest of the name are not in this excerpt).
412 for(
unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) {
413 vtkNew<vtkIntArray> matchingArray{};
414 matchingArray->SetNumberOfTuples(invReconstMatchingVectorT[i].size());
415 std::stringstream ss;
416 ss <<
"reconstMatching"
418 matchingArray->SetName(ss.str().c_str());
419 for(
unsigned int j = 0; j < invReconstMatchingVectorT[i].size(); ++j)
420 matchingArray->SetTuple1(j, (
int)invReconstMatchingVectorT[i][j]);
421 vectorsPrimeTable->AddColumn(matchingArray);
// Store both tables as block l and name the blocks after the layer index.
// NOTE(review): ss is reused for the second name; the line between SetBlock
// and the second insertion is not in this excerpt -- presumably it clears the
// stream, otherwise the name would read "Vectors<l>VectorsPrime<l>". Confirm.
426 vectors->SetBlock(l, vectorsTable);
427 std::stringstream ss;
428 ss <<
"Vectors" << l;
429 vectors->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str());
430 vectorsPrime->SetBlock(l, vectorsPrimeTable);
432 ss <<
"VectorsPrime" << l;
433 vectorsPrime->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str());
// Attach the two top-level blocks and name them.
// NOTE(review): num is used for both GetMetaData calls; the line between them
// is not in this excerpt -- presumably it sets num to 1. Confirm.
435 output_vectors->SetBlock(0, vectors);
436 output_vectors->SetBlock(1, vectorsPrime);
437 unsigned int num = 0;
438 output_vectors->GetMetaData(num)->Set(vtkCompositeDataSet::NAME(),
"Vectors");
440 output_vectors->GetMetaData(num)->Set(
441 vtkCompositeDataSet::NAME(),
"VectorsPrime");
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
TTK VTK-filter that wraps the ttk::MergeTreeAutoencoder module.
int RequestData(vtkInformation *request, vtkInformationVector **inputVector, vtkInformationVector *outputVector) override
ttkMergeTreeAutoencoder()
int FillOutputPortInformation(int port, vtkInformation *info) override
int run(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int FillInputPortInformation(int port, vtkInformation *info) override
int runCompute(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int runOutput(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
double clusteringLossWeight_
double mixtureCoefficient_
void getParamNames(std::vector< std::string > &paramNames)
bool normalizedWasserstein_
std::vector< std::vector< int > > treesNodeCorr_
bool branchDecomposition_
bool isPersistenceDiagram_
double getParamValueFromName(std::string &paramName)
unsigned int activationFunction_
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
std::vector< unsigned int > clusterAsgn_
std::vector< std::vector< double > > persCorrelationMatrix_
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
std::string getTableVectorName(int noAxes, int axeNum, int vId, int vComp, bool isSecondInput)
std::string getTableTreeName(int noTrees, int treeNum)
void getMatchingVector(const ftm::MergeTree< dataType > &barycenter, const ftm::MergeTree< dataType > &tree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matchings, std::vector< ftm::idNode > &matchingVector)
bool constructTrees(std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< MergeTree< dataType > > &intermediateTrees, std::vector< vtkUnstructuredGrid * > &treesNodes, std::vector< vtkUnstructuredGrid * > &treesArcs, std::vector< vtkDataSet * > &treesSegmentation, const std::vector< bool > &useSecondPairsTypeVec, int diagramPairTypes=0)
void loadBlocks(std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, vtkMultiBlockDataSet *blocks)
void mergeTreesTemplateToDouble(std::vector< MergeTree< dataType > > &mts, std::vector< MergeTree< double > > &newMts)
vtkStandardNewMacro(ttkMergeTreeAutoencoder)
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)