TTK
Loading...
Searching...
No Matches
ttkMergeTreeAutoencoderDecoding.cpp
Go to the documentation of this file.
5#include <ttkMergeTreeUtils.h>
6
7#include <vtkInformation.h>
8
9#include <vtkDataArray.h>
10#include <vtkDataSet.h>
11#include <vtkMultiBlockDataSet.h>
12#include <vtkObjectFactory.h>
13#include <vtkPointData.h>
14#include <vtkSmartPointer.h>
15#include <vtkTable.h>
16
17#include <ttkMacros.h>
18#include <ttkUtils.h>
19
20// A VTK macro that enables the instantiation of this class via ::New()
21// You do not have to modify this
23
37 this->SetNumberOfInputPorts(3);
38 this->SetNumberOfOutputPorts(1);
39}
40
49 int port, vtkInformation *info) {
50 if(port == 0 or port == 1 or port == 2) {
51 info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkMultiBlockDataSet");
52 return 1;
53 }
54 return 0;
55}
56
73 int port, vtkInformation *info) {
74 if(port == 0) {
76 return 1;
77 }
78 return 0;
79}
80
95 vtkInformation *ttkNotUsed(request),
96 vtkInformationVector **inputVector,
97 vtkInformationVector *outputVector) {
98#ifndef TTK_ENABLE_TORCH
99 TTK_FORCE_USE(inputVector);
100 TTK_FORCE_USE(outputVector);
101 printErr("This filter requires Torch.");
102 return 0;
103#else
104 // --------------------------------------------------------------------------
105 // --- Read Input
106 // --------------------------------------------------------------------------
107 auto blockBary = vtkMultiBlockDataSet::GetData(inputVector[0], 0);
108 auto vectors = vtkMultiBlockDataSet::GetData(inputVector[1], 0);
109 auto coefficients = vtkMultiBlockDataSet::GetData(inputVector[2], 0);
110
111 // Parameters
112 printMsg("Load parameters from field data.");
113 std::vector<std::string> paramNames;
114 getParamNames(paramNames);
115 paramNames.emplace_back("activate");
116 paramNames.emplace_back("activationFunction");
117 auto fd = coefficients->GetFieldData();
118 if(not fd->GetArray("activate"))
119 fd = coefficients->GetBlock(0)->GetFieldData();
120 for(auto paramName : paramNames) {
121 auto array = fd->GetArray(paramName.c_str());
122 if(array) {
123 double const value = array->GetTuple1(0);
124 if(paramName == "activate")
125 activate_ = value;
126 else if(paramName == "activationFunction")
127 activationFunction_ = value;
128 else
129 setParamValueFromName(paramName, value);
130 } else
131 printMsg(" - " + paramName + " was not found in the field data.");
132 auto stringValue = std::to_string(
133 (paramName == "activate" ? activate_
134 : (paramName == "activationFunction"
136 : getParamValueFromName(paramName))));
137 printMsg(" - " + paramName + " = " + stringValue);
138 }
140 printMsg("Computation with normalized Wasserstein.");
141 else
142 printMsg("Computation without normalized Wasserstein.");
143
144 // -----------------
145 // Origins
146 // -----------------
147 std::vector<vtkSmartPointer<vtkMultiBlockDataSet>> origins, originsPrime;
149 origins, vtkMultiBlockDataSet::SafeDownCast((blockBary->GetBlock(0))));
151 originsPrime, vtkMultiBlockDataSet::SafeDownCast(blockBary->GetBlock(1)));
152
153 std::vector<ttk::ftm::MergeTree<float>> originsTrees, originsPrimeTrees;
154 std::vector<vtkUnstructuredGrid *> originsTreeNodes, originsTreeArcs,
155 originsPrimeTreeNodes, originsPrimeTreeArcs;
156 std::vector<vtkDataSet *> originsTreeSegmentations,
157 originsPrimeTreeSegmentations;
158
159 bool useSadMaxPairs = (mixtureCoefficient_ == 0);
160 isPersistenceDiagram_ = ttk::ftm::constructTrees<float>(
161 origins, originsTrees, originsTreeNodes, originsTreeArcs,
162 originsTreeSegmentations, useSadMaxPairs);
163 ttk::ftm::constructTrees<float>(originsPrime, originsPrimeTrees,
164 originsTreeNodes, originsTreeArcs,
165 originsTreeSegmentations, useSadMaxPairs);
166 // If merge trees are provided in input and normalization is not asked
169
170 // -----------------
171 // Number of axes per layer
172 // -----------------
173 auto vS = vtkMultiBlockDataSet::SafeDownCast(vectors->GetBlock(0));
174 noLayers_ = vS->GetNumberOfBlocks();
175 std::vector<unsigned int> allNoAxes(noLayers_);
176 for(unsigned int l = 0; l < noLayers_; ++l) {
177 auto layerVectorsTable = vtkTable::SafeDownCast(vS->GetBlock(l));
178 unsigned int maxBoundNoAxes = 1;
179 while(not layerVectorsTable->GetColumnByName(
180 ttk::axa::getTableVectorName(maxBoundNoAxes, 0, 0, 0, false).c_str()))
181 maxBoundNoAxes *= 10;
182 unsigned int noAxes = 1;
183 while(layerVectorsTable->GetColumnByName(
184 ttk::axa::getTableVectorName(maxBoundNoAxes, noAxes, 0, 0, false)
185 .c_str()))
186 noAxes += 1;
187 allNoAxes[l] = noAxes;
188 }
189
190 // -----------------
191 // Coefficients
192 // -----------------
193 auto numberOfInputs
194 = vtkTable::SafeDownCast(coefficients->GetBlock(0))->GetNumberOfRows();
195 auto noCoefs = coefficients->GetNumberOfBlocks();
196 bool customRec = (noCoefs != noLayers_);
197 allAlphas_.resize(numberOfInputs, std::vector<torch::Tensor>(noCoefs));
198#ifdef TTK_ENABLE_OPENMP
199#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_)
200#endif
201 for(unsigned int l = 0; l < noCoefs; ++l) {
202 auto layerCoefficientsTable
203 = vtkTable::SafeDownCast(coefficients->GetBlock(l));
204 auto noAxes = (customRec ? allNoAxes[getLatentLayerIndex()] : allNoAxes[l]);
205 std::vector<std::vector<float>> alphas(
206 numberOfInputs, std::vector<float>(noAxes));
207 for(unsigned int g = 0; g < noAxes; ++g) {
208 auto array = layerCoefficientsTable->GetColumnByName(
209 ttk::axa::getTableCoefficientName(noAxes, g).c_str());
210 for(unsigned int i = 0; i < numberOfInputs; ++i)
211 alphas[i][g] = array->GetVariantValue(i).ToFloat();
212 }
213 for(unsigned int i = 0; i < numberOfInputs; ++i)
214 allAlphas_[i][l] = torch::tensor(alphas[i]).reshape({-1, 1});
215 }
216
217 // -----------------
218 // Vectors
219 // -----------------
220 vSTensor_.resize(noLayers_);
221 vSPrimeTensor_.resize(noLayers_);
222 auto vSPrime = vtkMultiBlockDataSet::SafeDownCast(vectors->GetBlock(1));
223 std::vector<unsigned int *> allRevNodeCorr(noLayers_),
224 allRevNodeCorrPrime(noLayers_);
225 std::vector<unsigned int> allRevNodeCorrSize(noLayers_),
226 allRevNodeCorrPrimeSize(noLayers_);
227#ifdef TTK_ENABLE_OPENMP
228#pragma omp parallel for schedule(dynamic) num_threads(this->threadNumber_)
229#endif
230 for(unsigned int l = 0; l < vSTensor_.size(); ++l) {
231 auto layerVectorsTable = vtkTable::SafeDownCast(vS->GetBlock(l));
232 auto layerVectorsPrimeTable = vtkTable::SafeDownCast(vSPrime->GetBlock(l));
233 auto noRows = layerVectorsTable->GetNumberOfRows();
234 auto noRows2 = layerVectorsPrimeTable->GetNumberOfRows();
235 std::vector<float> vSTensor(noRows * allNoAxes[l]),
236 vSPrimeTensor(noRows2 * allNoAxes[l]);
237 for(unsigned int v = 0; v < allNoAxes[l]; ++v) {
238 std::string name
239 = ttk::axa::getTableVectorName(allNoAxes[l], v, 0, 0, false);
240 for(unsigned int i = 0; i < noRows; ++i)
241 vSTensor[i * allNoAxes[l] + v]
242 = layerVectorsTable->GetColumnByName(name.c_str())
243 ->GetVariantValue(i)
244 .ToFloat();
245 for(unsigned int i = 0; i < noRows2; ++i)
246 vSPrimeTensor[i * allNoAxes[l] + v]
247 = layerVectorsPrimeTable->GetColumnByName(name.c_str())
248 ->GetVariantValue(i)
249 .ToFloat();
250 }
251 vSTensor_[l] = torch::tensor(vSTensor).reshape({noRows, allNoAxes[l]});
252 vSPrimeTensor_[l]
253 = torch::tensor(vSPrimeTensor).reshape({noRows2, allNoAxes[l]});
254 allRevNodeCorr[l]
255 = ttkUtils::GetPointer<unsigned int>(vtkDataArray::SafeDownCast(
256 layerVectorsTable->GetColumnByName("revNodeCorr")));
257 allRevNodeCorrSize[l] = noRows;
258 allRevNodeCorrPrime[l]
259 = ttkUtils::GetPointer<unsigned int>(vtkDataArray::SafeDownCast(
260 layerVectorsPrimeTable->GetColumnByName("revNodeCorr")));
261 allRevNodeCorrPrimeSize[l] = noRows2;
262 }
263
264 // -----------------
265 // Call base
266 // -----------------
267 execute(originsTrees, originsPrimeTrees, allRevNodeCorr, allRevNodeCorrPrime,
268 allRevNodeCorrSize, allRevNodeCorrPrimeSize);
269
270 // --------------------------------------------------------------------------
271 // --- Create Output
272 // --------------------------------------------------------------------------
273 auto output_data = vtkMultiBlockDataSet::GetData(outputVector, 0);
274
275 // ------------------------------------------
276 // --- Read Matchings
277 // ------------------------------------------
278 auto originsMatchingSize = getLatentLayerIndex() + 1;
279 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVectorT(
280 originsMatchingSize),
281 invOriginsMatchingVectorT = originsMatchingVectorT;
282 for(unsigned int l = 0; l < originsMatchingVectorT.size(); ++l) {
283 auto array = vtkTable::SafeDownCast(
284 (l == 0 ? vS : vSPrime)->GetBlock((l == 0 ? l : l - 1)))
285 ->GetColumnByName("nextOriginMatching");
286 originsMatchingVectorT[l].clear();
287 originsMatchingVectorT[l].resize(array->GetNumberOfTuples());
288 for(unsigned int i = 0; i < originsMatchingVectorT[l].size(); ++i)
289 originsMatchingVectorT[l][i] = array->GetVariantValue(i).ToUnsignedInt();
290 reverseMatchingVector<float>(originsPrime_[l].mTree,
291 originsMatchingVectorT[l],
292 invOriginsMatchingVectorT[l]);
293 }
294 auto dataMatchingSize = getLatentLayerIndex() + 2;
295 std::vector<std::vector<std::vector<ttk::ftm::idNode>>> dataMatchingVectorT(
296 dataMatchingSize),
297 invDataMatchingVectorT = dataMatchingVectorT;
298 std::vector<std::vector<ttk::ftm::idNode>> invReconstMatchingVectorT(
299 numberOfInputs);
300 if(not customRec) {
301 for(unsigned int l = 0; l < dataMatchingVectorT.size(); ++l) {
302 dataMatchingVectorT[l].resize(numberOfInputs);
303 invDataMatchingVectorT[l].resize(numberOfInputs);
304 for(unsigned int i = 0; i < dataMatchingVectorT[l].size(); ++i) {
305 std::stringstream ss;
306 ss << "matching"
307 << ttk::axa::getTableTreeName(dataMatchingVectorT[l].size(), i);
308 auto array = vtkTable::SafeDownCast(
309 (l == 0 ? vS : vSPrime)->GetBlock((l == 0 ? l : l - 1)))
310 ->GetColumnByName(ss.str().c_str());
311 dataMatchingVectorT[l][i].clear();
312 dataMatchingVectorT[l][i].resize(array->GetNumberOfTuples());
313 for(unsigned int j = 0; j < dataMatchingVectorT[l][i].size(); ++j)
314 dataMatchingVectorT[l][i][j]
315 = array->GetVariantValue(j).ToUnsignedInt();
316 auto noNodes
317 = (l == 0 ? vtkTable::SafeDownCast(coefficients->GetBlock(0))
318 ->GetColumnByName("treeNoNodes")
319 ->GetVariantValue(i)
320 .ToUnsignedInt()
321 : recs_[i][l - 1].mTree.tree.getNumberOfNodes());
323 noNodes, dataMatchingVectorT[l][i], invDataMatchingVectorT[l][i]);
324 }
325 }
326 for(unsigned int i = 0; i < invReconstMatchingVectorT.size(); ++i) {
327 std::stringstream ss;
328 ss << "reconstMatching"
329 << ttk::axa::getTableTreeName(invReconstMatchingVectorT.size(), i);
330 auto array = vtkTable::SafeDownCast(vSPrime->GetBlock(noLayers_ - 1))
331 ->GetColumnByName(ss.str().c_str());
332 invReconstMatchingVectorT[i].resize(array->GetNumberOfTuples());
333 for(unsigned int j = 0; j < invReconstMatchingVectorT[i].size(); ++j)
334 invReconstMatchingVectorT[i][j]
335 = array->GetVariantValue(j).ToUnsignedInt();
336 }
337 }
338
339 // ------------------------------------------
340 // --- Tracking information
341 // ------------------------------------------
342 std::vector<std::vector<ttk::ftm::idNode>> originsMatchingVector;
343 std::vector<std::vector<double>> originsPersPercent, originsPersDiff;
344 std::vector<double> originPersPercent, originPersDiff;
345 std::vector<int> originPersistenceOrder;
346 ttk::wae::computeTrackingInformation(
347 origins_, originsPrime_, originsMatchingVectorT, invOriginsMatchingVectorT,
348 isPersistenceDiagram_, originsMatchingVector, originsPersPercent,
349 originsPersDiff, originPersPercent, originPersDiff, originPersistenceOrder);
350
351 // ------------------------------------------
352 // --- Data
353 // ------------------------------------------
354 if(!recs_.empty()) {
355 output_data->SetNumberOfBlocks(1);
358 data->SetNumberOfBlocks(recs_[0].size());
361 dataSeg->SetNumberOfBlocks(recs_.size());
362 for(unsigned int l = 0; l < recs_[0].size(); ++l) {
365 out_layer_i->SetNumberOfBlocks(recs_.size());
366 std::vector<ttk::ftm::MergeTree<float> *> trees(recs_.size());
367 for(unsigned int i = 0; i < recs_.size(); ++i)
368 trees[i] = &(recs_[i][l].mTree);
369
370 // Custom arrays
371 std::vector<std::vector<std::tuple<std::string, std::vector<int>>>>
372 customIntArrays(recs_.size());
373 std::vector<std::vector<std::tuple<std::string, std::vector<double>>>>
374 customDoubleArrays(recs_.size());
375 unsigned int lShift = 1;
376 ttk::wae::computeCustomArrays(
377 recs_, persCorrelationMatrix_, invDataMatchingVectorT,
378 invReconstMatchingVectorT, originsMatchingVector,
379 originsMatchingVectorT, originsPersPercent, originsPersDiff,
380 originPersistenceOrder, l, lShift, customIntArrays, customDoubleArrays);
381
382 // Create output
383 ttk::wae::makeManyOutput(trees, out_layer_i, customIntArrays,
384 customDoubleArrays, mixtureCoefficient_,
386 this->debugLevel_);
387 data->SetBlock(l, out_layer_i);
388 std::stringstream ss;
389 ss << "Layer" << l;
390 data->GetMetaData(l)->Set(vtkCompositeDataSet::NAME(), ss.str());
391 }
392 output_data->SetBlock(0, data);
393 unsigned int num = 0;
394 std::stringstream ss;
395 ss << "layers" << (isPersistenceDiagram_ ? "Diagrams" : "Trees");
396 output_data->GetMetaData(num)->Set(vtkCompositeDataSet::NAME(), ss.str());
397 }
398
399 if(!customRecs_.empty()) {
400 std::vector<std::vector<std::tuple<std::string, std::vector<int>>>>
401 customRecsIntArrays(customRecs_.size());
402 std::vector<std::vector<std::tuple<std::string, std::vector<double>>>>
403 customRecsDoubleArrays(customRecs_.size());
404 std::vector<ttk::ftm::MergeTree<float> *> trees(customRecs_.size());
405 for(unsigned int i = 0; i < customRecs_.size(); ++i)
406 trees[i] = &(customRecs_[i].mTree);
407 std::vector<std::vector<int>> customOriginPersOrder(customRecs_.size());
408 for(unsigned int i = 0; i < customRecs_.size(); ++i) {
409 trees[i] = &(customRecs_[i].mTree);
410 std::vector<ttk::ftm::idNode> matchingVector;
411 getInverseMatchingVector(origins_[0].mTree, customRecs_[i].mTree,
412 customMatchings_[i], matchingVector);
413 customOriginPersOrder[i].resize(
414 customRecs_[i].mTree.tree.getNumberOfNodes());
415 for(unsigned int j = 0; j < matchingVector.size(); ++j) {
416 if(matchingVector[j] < originPersistenceOrder.size())
417 customOriginPersOrder[i][j]
418 = originPersistenceOrder[matchingVector[j]];
419 else
420 customOriginPersOrder[i][j] = -1;
421 }
422 std::string name4{"OriginPersOrder"};
423 customRecsIntArrays[i].emplace_back(
424 std::make_tuple(name4, customOriginPersOrder[i]));
425 }
428 ttk::wae::makeManyOutput(trees, dataCustom, customRecsIntArrays,
429 customRecsDoubleArrays, mixtureCoefficient_,
431 this->debugLevel_);
432 output_data->SetBlock(0, dataCustom);
433 }
434
435 // return success
436 return 1;
437#endif
438}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
static vtkInformationIntegerKey * SAME_DATA_TYPE_AS_INPUT_PORT()
TTK VTK-filter that wraps the ttk::MergeTreeAutoencoderDecoding module.
int FillInputPortInformation(int port, vtkInformation *info) override
int FillOutputPortInformation(int port, vtkInformation *info) override
int RequestData(vtkInformation *request, vtkInformationVector **inputVector, vtkInformationVector *outputVector) override
int debugLevel_
Definition Debug.h:379
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
void execute(std::vector< ttk::ftm::MergeTree< float > > &originsTrees, std::vector< ttk::ftm::MergeTree< float > > &originsPrimeTrees, std::vector< unsigned int * > &allRevNodeCorr, std::vector< unsigned int * > &allRevNodeCorrPrime, std::vector< unsigned int > &allRevNodeCorrSize, std::vector< unsigned int > &allRevNodeCorrPrimeSize)
std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > customMatchings_
std::vector< std::vector< double > > persCorrelationMatrix_
void reverseMatchingVector(unsigned int noNodes, std::vector< ftm::idNode > &matchingVector, std::vector< ftm::idNode > &invMatchingVector)
void getInverseMatchingVector(ftm::MergeTree< dataType > &barycenter, ftm::MergeTree< dataType > &tree, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &matchings, std::vector< ftm::idNode > &matchingVector)
void setParamValueFromName(std::string &paramName, double value)
void getParamNames(std::vector< std::string > &paramNames)
double getParamValueFromName(std::string &paramName)
std::string getTableVectorName(int noAxes, int axeNum, int vId, int vComp, bool isSecondInput)
std::string getTableTreeName(int noTrees, int treeNum)
std::string getTableCoefficientName(int noAxes, int axeNum)
void loadBlocks(std::vector< vtkSmartPointer< vtkMultiBlockDataSet > > &inputTrees, vtkMultiBlockDataSet *blocks)
vtkStandardNewMacro(ttkMergeTreeAutoencoderDecoding)
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)