TTK
Loading...
Searching...
No Matches
ttkMergeTreeAutoencoder.cpp
Go to the documentation of this file.
4#include <ttkMergeTreeUtils.h>
6
7#include <vtkDataArray.h>
8#include <vtkDataSet.h>
9#include <vtkFloatArray.h>
10#include <vtkInformation.h>
11#include <vtkObjectFactory.h>
12#include <vtkPointData.h>
13#include <vtkSmartPointer.h>
14#include <vtkTable.h>
15#include <vtkUnsignedIntArray.h>
16
17#include <ttkMacros.h>
18#include <ttkUtils.h>
19
20// A VTK macro that enables the instantiation of this class via ::New()
21// You do not have to modify this
23
37 this->SetNumberOfInputPorts(3);
38 this->SetNumberOfOutputPorts(4);
39}
40
49 vtkInformation *info) {
50 if(port == 0) {
51 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkMultiBlockDataSet");
52 } else if(port == 1) {
53 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkMultiBlockDataSet");
54 info->Set(vtkAlgorithm::INPUT_IS_OPTIONAL(), 1);
55 } else if(port == 2) {
56 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkTable");
57 info->Set(vtkAlgorithm::INPUT_IS_OPTIONAL(), 1);
58 } else
59 return 0;
60 return 1;
61}
62
79 vtkInformation *info) {
80 if(port == 0 or port == 1 or port == 2 or port == 3) {
81 info->Set(vtkDataObject::DATA_TYPE_NAME(), "vtkMultiBlockDataSet");
82 } else {
83 return 0;
84 }
85 return 1;
86}
87
102 vtkInformationVector **inputVector,
103 vtkInformationVector *outputVector) {
104#ifndef TTK_ENABLE_TORCH
105 TTK_FORCE_USE(inputVector);
106 TTK_FORCE_USE(outputVector);
107 printErr("This filter requires Torch.");
108 return 0;
109#else
110 // ------------------------------------------------------------------------------------
111 // --- Get input object from input vector
112 // ------------------------------------------------------------------------------------
113 auto blocks = vtkMultiBlockDataSet::GetData(inputVector[0], 0);
114 auto blocks2 = vtkMultiBlockDataSet::GetData(inputVector[1], 0);
115 auto table = vtkTable::GetData(inputVector[2], 0);
116
117 // ------------------------------------------------------------------------------------
118 // --- Load blocks
119 // ------------------------------------------------------------------------------------
120 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> inputTrees, inputTrees2;
121 ttk::ftm::loadBlocks(inputTrees, blocks);
122 ttk::ftm::loadBlocks(inputTrees2, blocks2);
123
124 // Load table
125 clusterAsgn_.clear();
126 vtkAbstractArray *clusterAsgn;
127 if(table) {
128 clusterAsgn = this->GetInputArrayToProcess(0, inputVector);
129 if(clusterAsgn) {
130 clusterAsgn_.resize(clusterAsgn->GetNumberOfValues());
131 for(unsigned int i = 0; i < clusterAsgn_.size(); ++i)
132 clusterAsgn_[i] = clusterAsgn->GetVariantValue(i).ToInt();
133 }
134 }
135 if((not table or not clusterAsgn) and clusteringLossWeight_ != 0) {
136 printErr(
137 "You must provide a table column in info input to use clustering loss");
138 return 0;
139 }
140 if(clusteringLossWeight_ != 0) {
141 std::stringstream ss;
142 for(auto &e : clusterAsgn_)
143 ss << e << " ";
144 printMsg(ss.str());
145 }
146
147 // ------------------------------------------------------------------------------------
148 // If we have already computed once but the input has changed
149 if((treesNodes.size() != 0 and inputTrees[0]->GetBlock(0) != treesNodes[0])
150 or (treesNodes2.size() != inputTrees2.size()))
151 resetDataVisualization();
152
153 // Parameters
155 if(not normalizedWasserstein_) {
156 oldEpsilonTree1 = epsilonTree1_;
157 epsilonTree1_ = 100;
158 } else
159 epsilonTree1_ = oldEpsilonTree1;
161 printMsg("Computation with normalized Wasserstein.");
162 else
163 printMsg("Computation without normalized Wasserstein.");
164
165 return run(outputVector, inputTrees, inputTrees2);
166#endif
167}
168
169#ifdef TTK_ENABLE_TORCH
171 vtkInformationVector *outputVector,
172 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &inputTrees,
173 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &inputTrees2) {
174 runCompute(outputVector, inputTrees, inputTrees2);
175 runOutput(outputVector, inputTrees, inputTrees2);
176 return 1;
177}
178
180 vtkInformationVector *ttkNotUsed(outputVector),
181 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &inputTrees,
182 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &inputTrees2) {
183 // ------------------------------------------------------------------------------------
184 // --- Construct trees
185 // ------------------------------------------------------------------------------------
186 std::vector<ttk::ftm::MergeTree<float>> intermediateMTrees,
187 intermediateMTrees2;
188
189 bool useSecondPairsType = (mixtureCoefficient_ == 0);
191 inputTrees, intermediateMTrees, treesNodes, treesArcs, treesSegmentation,
192 useSecondPairsType, DiagramPairTypes);
193 // If merge trees are provided in input and normalization is not asked
197 or (mixtureCoefficient_ != 0 and mixtureCoefficient_ != 1)) {
198 auto &inputTrees2ToUse
199 = (not isPersistenceDiagram_ ? inputTrees2 : inputTrees);
200 ttk::ftm::constructTrees<float>(inputTrees2ToUse, intermediateMTrees2,
201 treesNodes2, treesArcs2, treesSegmentation2,
202 !useSecondPairsType, DiagramPairTypes);
203 }
205
206 const int numInputs = intermediateMTrees.size();
207 const int numInputs2 = intermediateMTrees2.size();
208 setDataVisualization(numInputs, numInputs2);
209
210 // ------------------------------------------------------------------------------------
211 // --- Call base
212 // ------------------------------------------------------------------------------------
213 execute(intermediateMTrees, intermediateMTrees2);
214
216 intermediateMTrees, intermediateDTrees);
217
218 return 1;
219}
220
221// TODO manage double input
223 vtkInformationVector *outputVector,
224 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &inputTrees,
225 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> &ttkNotUsed(inputTrees2)) {
226 if(not createOutput_)
227 return 1;
228 // ------------------------------------------------------------------------------------
229 // --- Create output
230 // ------------------------------------------------------------------------------------
231 auto output_origins = vtkMultiBlockDataSet::GetData(outputVector, 0);
232 auto output_vectors = vtkMultiBlockDataSet::GetData(outputVector, 1);
233 auto output_coef = vtkMultiBlockDataSet::GetData(outputVector, 2);
234 auto output_data = vtkMultiBlockDataSet::GetData(outputVector, 3);
235
236 // ------------------------------------------
237 // --- Tracking information
238 // ------------------------------------------
239 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVectorT,
240 invOriginsMatchingVectorT;
241 std::vector<std::vector<std::vector<ttk::ftm::idNode>>>
242 invDataMatchingVectorT;
243 std::vector<std::vector<ttk::ftm::idNode>> invReconstMatchingVectorT;
244 ttk::wnn::makeMatchingVectors(
245 originsMatchings_, originsCopy_, originsPrimeCopy_, originsMatchingVectorT,
246 invOriginsMatchingVectorT, dataMatchings_, recs_, invDataMatchingVectorT,
247 reconstMatchings_, invReconstMatchingVectorT);
248
249 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVector;
250 std::vector<std::vector<double>> originsPersPercent, originsPersDiff;
251 std::vector<double> originPersPercent, originPersDiff;
252 std::vector<int> originPersistenceOrder;
253 ttk::wnn::computeTrackingInformation(
254 originsCopy_, originsPrimeCopy_, originsMatchingVectorT,
255 invOriginsMatchingVectorT, isPersistenceDiagram_, originsMatchingVector,
256 originsPersPercent, originsPersDiff, originPersPercent, originPersDiff,
257 originPersistenceOrder);
258
259 // ------------------------------------------
260 // --- Data
261 // ------------------------------------------
262 ttk::wnn::makeDataOutput(
263 output_data, recs_, 1, treesSegmentation, persCorrelationMatrix_,
264 invDataMatchingVectorT, invReconstMatchingVectorT, originsMatchingVectorT,
265 originsMatchingVector, originsPersPercent, originsPersDiff,
266 originPersistenceOrder, treesNodes, treesNodeCorr_, bestLoss_,
268 this->debugLevel_);
269
270 // ------------------------------------------
271 // --- Origins
272 // ------------------------------------------
273 ttk::wnn::makeOriginsOutput(
274 output_origins, originsCopy_, originsPrimeCopy_, originPersPercent,
275 originPersDiff, originPersistenceOrder, originsMatchingVector,
276 originsPersPercent, originsPersDiff, mixtureCoefficient_,
278
279 // ------------------------------------------
280 // --- Coefficients
281 // ------------------------------------------
282 ttk::wnn::makeCoefficientsOutput(output_coef, allAlphas_, allScaledAlphas_,
283 allActAlphas_, allActScaledAlphas_,
284 clusterAsgn_, recs_, inputTrees);
285
286 // Field Data Input Parameters
287 std::vector<std::string> paramNames;
288 getParamNames(paramNames);
289 for(auto paramName : paramNames) {
290 vtkNew<vtkDoubleArray> array{};
291 array->SetName(paramName.c_str());
292 array->InsertNextTuple1(getParamValueFromName(paramName));
293 output_coef->GetFieldData()->AddArray(array);
294 }
295 vtkNew<vtkDoubleArray> arrayActivate{};
296 arrayActivate->SetName("activate");
297 arrayActivate->InsertNextTuple1(activate_);
298 output_coef->GetFieldData()->AddArray(arrayActivate);
299 vtkNew<vtkDoubleArray> arrayActivateFunction{};
300 arrayActivateFunction->SetName("activationFunction");
301 arrayActivateFunction->InsertNextTuple1(activationFunction_);
302 output_coef->GetFieldData()->AddArray(arrayActivateFunction);
303
304 vtkNew<vtkIntArray> diagramPairTypesArray{};
305 diagramPairTypesArray->SetName("DiagramPairTypes");
306 diagramPairTypesArray->InsertNextTuple1(DiagramPairTypes);
307 output_coef->GetFieldData()->AddArray(diagramPairTypesArray);
308
309 // ------------------------------------------
310 // --- Axes Vectors
311 // ------------------------------------------
312 std::vector<std::vector<std::vector<ttk::ftm::idNode>>> dataMatchingVectorT(
313 dataMatchings_.size());
314 for(unsigned int l = 0; l < dataMatchingVectorT.size(); ++l) {
315 dataMatchingVectorT[l].resize(dataMatchings_[l].size());
316 for(unsigned int i = 0; i < dataMatchingVectorT[l].size(); ++i) {
317 auto &origin = (l == 0 ? originsCopy_[0] : originsPrimeCopy_[l - 1]);
318 ttk::axa::getMatchingVector(origin.mTree, recs_[i][l].mTree,
319 dataMatchings_[l][i],
320 dataMatchingVectorT[l][i]);
321 }
322 }
323 output_vectors->SetNumberOfBlocks(2);
324 vtkSmartPointer<vtkMultiBlockDataSet> vectors
325 = vtkSmartPointer<vtkMultiBlockDataSet>::New();
326 vectors->SetNumberOfBlocks(noLayers_);
327 vtkSmartPointer<vtkMultiBlockDataSet> vectorsPrime
328 = vtkSmartPointer<vtkMultiBlockDataSet>::New();
329 vectorsPrime->SetNumberOfBlocks(noLayers_);
330 for(unsigned int l = 0; l < noLayers_; ++l) {
331 vtkSmartPointer<vtkTable> vectorsTable = vtkSmartPointer<vtkTable>::New();
332 vtkSmartPointer<vtkTable> vectorsPrimeTable
333 = vtkSmartPointer<vtkTable>::New();
334 for(unsigned int v = 0; v < layers_[l].getVSTensor().sizes()[1]; ++v) {
335 // Vs
336 vtkNew<vtkFloatArray> vectorArray{};
337 std::string name = ttk::axa::getTableVectorName(
338 layers_[l].getVSTensor().sizes()[1], v, 0, 0, false);
339 vectorArray->SetName(name.c_str());
340 vectorArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]);
341 for(unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i)
342 vectorArray->SetTuple1(i, layers_[l].getVSTensor()[i][v].item<float>());
343 vectorsTable->AddColumn(vectorArray);
344 // Vs Prime
345 vtkNew<vtkFloatArray> vectorPrimeArray{};
346 std::string name2 = ttk::axa::getTableVectorName(
347 layers_[l].getVSTensor().sizes()[1], v, 0, 0, false);
348 vectorPrimeArray->SetName(name2.c_str());
349 vectorPrimeArray->SetNumberOfTuples(
350 layers_[l].getVSPrimeTensor().sizes()[0]);
351 for(unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i)
352 vectorPrimeArray->SetTuple1(
353 i, layers_[l].getVSPrimeTensor()[i][v].item<float>());
354 vectorsPrimeTable->AddColumn(vectorPrimeArray);
355 }
356 // Rev node corr
357 vtkNew<vtkUnsignedIntArray> revNodeCorrArray{};
358 revNodeCorrArray->SetName("revNodeCorr");
359 revNodeCorrArray->SetNumberOfTuples(layers_[l].getVSTensor().sizes()[0]);
360 std::vector<unsigned int> revNodeCorr;
361 getReverseTorchNodeCorr(originsCopy_[l], revNodeCorr);
362 for(unsigned int i = 0; i < layers_[l].getVSTensor().sizes()[0]; ++i)
363 revNodeCorrArray->SetTuple1(i, revNodeCorr[i]);
364 vectorsTable->AddColumn(revNodeCorrArray);
365 // Rev node corr prime
366 vtkNew<vtkUnsignedIntArray> revNodeCorrPrimeArray{};
367 revNodeCorrPrimeArray->SetNumberOfTuples(
368 layers_[l].getVSPrimeTensor().sizes()[0]);
369 revNodeCorrPrimeArray->SetName("revNodeCorr");
370 std::vector<unsigned int> revNodeCorrPrime;
371 getReverseTorchNodeCorr(originsPrimeCopy_[l], revNodeCorrPrime);
372 for(unsigned int i = 0; i < layers_[l].getVSPrimeTensor().sizes()[0]; ++i)
373 revNodeCorrPrimeArray->SetTuple1(i, revNodeCorrPrime[i]);
374 vectorsPrimeTable->AddColumn(revNodeCorrPrimeArray);
375 // Origins Matchings
376 auto addOriginMatchingArray
377 = [&](vtkSmartPointer<vtkTable> &vectorsTableT,
378 std::vector<ttk::ftm::idNode> &originMatchingVector) {
379 vtkNew<vtkIntArray> matchingArray{};
380 matchingArray->SetNumberOfTuples(originMatchingVector.size());
381 matchingArray->SetName("nextOriginMatching");
382 for(unsigned int i = 0; i < originMatchingVector.size(); ++i)
383 matchingArray->SetTuple1(i, (int)originMatchingVector[i]);
384 vectorsTableT->AddColumn(matchingArray);
385 };
386 if(l == 0)
387 addOriginMatchingArray(vectorsTable, originsMatchingVectorT[l]);
388 if(l < originsMatchingVectorT.size() - 1)
389 addOriginMatchingArray(vectorsPrimeTable, originsMatchingVectorT[l + 1]);
390 // Data Matchings
391 auto addDataMatchingArray
392 = [&](vtkSmartPointer<vtkTable> &vectorsTableT,
393 std::vector<std::vector<ttk::ftm::idNode>> &dataMatchingVector) {
394 for(unsigned int i = 0; i < dataMatchingVector.size(); ++i) {
395 vtkNew<vtkIntArray> matchingArray{};
396 matchingArray->SetNumberOfTuples(dataMatchingVector[i].size());
397 std::stringstream ss;
398 ss << "matching"
399 << ttk::axa::getTableTreeName(dataMatchingVector.size(), i);
400 matchingArray->SetName(ss.str().c_str());
401 for(unsigned int j = 0; j < dataMatchingVector[i].size(); ++j)
402 matchingArray->SetTuple1(j, (int)dataMatchingVector[i][j]);
403 vectorsTableT->AddColumn(matchingArray);
404 }
405 };
406 if(l == 0)
407 addDataMatchingArray(vectorsTable, dataMatchingVectorT[l]);
408 if(l < dataMatchingVectorT.size() - 1)
409 addDataMatchingArray(vectorsPrimeTable, dataMatchingVectorT[l + 1]);
410 // Reconst Matchings
411 if(l == noLayers_ - 1) {
412 for(unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) {
413 vtkNew<vtkIntArray> matchingArray{};
414 matchingArray->SetNumberOfTuples(invReconstMatchingVectorT[i].size());
415 std::stringstream ss;
416 ss << "reconstMatching"
417 << ttk::axa::getTableTreeName(invReconstMatchingVectorT.size(), i);
418 matchingArray->SetName(ss.str().c_str());
419 for(unsigned int j = 0; j < invReconstMatchingVectorT[i].size(); ++j)
420 matchingArray->SetTuple1(j, (int)invReconstMatchingVectorT[i][j]);
421 vectorsPrimeTable->AddColumn(matchingArray);
422 }
423 }
424
425 // Add new block
426 vectors->SetBlock(l, vectorsTable);
427 std::stringstream ss;
428 ss << "Vectors" << l;
429 vectors->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str());
430 vectorsPrime->SetBlock(l, vectorsPrimeTable);
431 ss.str("");
432 ss << "VectorsPrime" << l;
433 vectorsPrime->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str());
434 }
435 output_vectors->SetBlock(0, vectors);
436 output_vectors->SetBlock(1, vectorsPrime);
437 unsigned int num = 0;
438 output_vectors->GetMetaData(num)->Set(vtkCompositeDataSet::NAME(), "Vectors");
439 num = 1;
440 output_vectors->GetMetaData(num)->Set(
441 vtkCompositeDataSet::NAME(), "VectorsPrime");
442
443 // return success
444 return 1;
445}
446#endif
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
TTK VTK-filter that wraps the ttk::MergeTreeAutoencoder module.
int RequestData(vtkInformation *request, vtkInformationVector **inputVector, vtkInformationVector *outputVector) override
int FillOutputPortInformation(int port, vtkInformation *info) override
int run(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int FillInputPortInformation(int port, vtkInformation *info) override
int runCompute(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int runOutput(vtkInformationVector *outputVector, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees2)
int debugLevel_
Definition Debug.h:379
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
void getParamNames(std::vector< std::string > &paramNames)
std::vector< std::vector< int > > treesNodeCorr_
double getParamValueFromName(std::string &paramName)
void execute(std::vector< ftm::MergeTree< float > > &trees, std::vector< ftm::MergeTree< float > > &trees2)
std::vector< unsigned int > clusterAsgn_
std::vector< std::vector< double > > persCorrelationMatrix_
std::vector< std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > > dataMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > originsMatchings_
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > reconstMatchings_
std::string getTableVectorName(int noAxes, int axeNum, int vId, int vComp, bool isSecondInput)
std::string getTableTreeName(int noTrees, int treeNum)
void getMatchingVector(const ftm::MergeTree< dataType > &barycenter, const ftm::MergeTree< dataType > &tree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matchings, std::vector< ftm::idNode > &matchingVector)
bool constructTrees(std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, std::vector< MergeTree< dataType > > &intermediateTrees, std::vector< vtkUnstructuredGrid * > &treesNodes, std::vector< vtkUnstructuredGrid * > &treesArcs, std::vector< vtkDataSet * > &treesSegmentation, const std::vector< bool > &useSecondPairsTypeVec, int diagramPairTypes=0)
void loadBlocks(std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, vtkMultiBlockDataSet *blocks)
void mergeTreesTemplateToDouble(std::vector< MergeTree< dataType > > &mts, std::vector< MergeTree< double > > &newMts)
vtkStandardNewMacro(ttkMergeTreeAutoencoder)
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)