TTK
Loading...
Searching...
No Matches
MergeTreeTemporalReduction.h
Go to the documentation of this file.
1
19
20#pragma once
21
22// ttk common includes
23#include <Debug.h>
24
25#include <FTMTreeUtils.h>
26#include <MergeTreeBarycenter.h>
27#include <MergeTreeBase.h>
28#include <MergeTreeDistance.h>
29
30namespace ttk {
31
36 class MergeTreeTemporalReduction : virtual public Debug,
37 public MergeTreeBase {
38 protected:
// Percentage of the input trees that temporalSubsampling() tries to remove
// (the resulting count is capped so that at least three trees remain).
39 double removalPercentage_ = 50.;
// When true, the reduction compares the scalar fields stored in fieldL2_
// with an L2 distance instead of comparing merge trees.
40 bool useL2Distance_ = false;
// Scalar fields indexed like the input trees (images[i] is used alongside
// mTrees[i] in temporalSubsampling); converted to dataType before use.
// NOTE(review): original line 42 is missing from this listing — presumably
// the useCustomTimeVariable_ flag read by execute(); confirm.
41 std::vector<std::vector<double>> fieldL2_;
// Time stamp of each input; execute() fills it with 0, 1, 2, ... unless a
// custom time variable is in use. Read by computeAlpha().
43 std::vector<double> timeVariable_;
44
45 public:
47
// Sets the percentage of input trees that temporalSubsampling() tries to
// remove; the resulting count is capped so that at least three trees remain.
// NOTE(review): the body line (original line 49) is missing from this
// listing — presumably `removalPercentage_ = rs;`; confirm against the source.
48 void setRemovalPercentage(double rs) {
50 }
51
52 void setUseL2Distance(bool useL2) {
53 useL2Distance_ = useL2;
54 }
55
/// @brief Euclidean (L2) distance between two scalar fields stored as flat
/// vectors.
/// @param img1 First field; its size determines how many values are compared.
/// @param img2 Second field; must hold at least img1.size() values unless
///   emptyFieldDistance is true, in which case it is not read at all.
/// @param emptyFieldDistance When true, measure the distance from img1 to the
///   all-zero field (i.e. the L2 norm of img1).
/// @return sqrt(sum_i (img1[i] - img2[i])^2), converted to dataType.
template <class dataType>
dataType computeL2Distance(std::vector<dataType> &img1,
                           std::vector<dataType> &img2,
                           bool emptyFieldDistance = false) {
  dataType distance = 0;
  for(std::size_t i = 0; i < img1.size(); ++i) {
    // Read img2 in place (no copy) and never materialise a zero field: the
    // difference to the zero field is simply img1[i].
    const dataType diff = emptyFieldDistance ? img1[i] : img1[i] - img2[i];
    // Plain multiplication instead of std::pow(diff, 2).
    distance += diff * diff;
  }
  return std::sqrt(distance);
}
75
/// @brief Weighted average (barycenter) of two scalar fields.
/// @param img1 First field; its size determines the size of the result.
/// @param img2 Second field; must hold at least img1.size() values.
/// @param alpha Weight of img1; img2 gets weight (1 - alpha). alpha == 1
///   yields img1 and alpha == 0 yields img2, matching computeAlpha(), which
///   returns 1 when the middle time stamp equals index1's.
/// @return barycenter[i] = alpha * img1[i] + (1 - alpha) * img2[i].
template <class dataType>
std::vector<dataType> computeL2Barycenter(std::vector<dataType> &img1,
                                          std::vector<dataType> &img2,
                                          double alpha) {
  const std::size_t noPoints = img1.size();
  std::vector<dataType> barycenter(noPoints);
  for(std::size_t i = 0; i < noPoints; ++i)
    // Convex combination: the two weighted terms are summed (the previous
    // version multiplied them, which is not an interpolation).
    barycenter[i]
      = static_cast<dataType>(alpha * img1[i] + (1 - alpha) * img2[i]);
  return barycenter;
}
89
// Distance between two merge trees, computed by a MergeTreeDistance object
// configured from this object's parameters (epsilons, parallelize_,
// keepSubtree_, useMinMaxPair_, thread count).
// setDistanceSquaredRoot(true) is requested, and pre/post-processing is
// disabled — presumably because execute() already preprocessed the trees;
// confirm.
// When emptyTreeDistance is true, setOnlyEmptyTreeDistance is enabled;
// execute() calls computeDistance(t, t, true), presumably to obtain a tree's
// distance to the empty tree — confirm against MergeTreeDistance.
// The computed matching is discarded; only the distance is returned.
// NOTE(review): original lines 91-92 (function name and the mTree1 / mTree2
// parameters), 95, 102 and 104-105 are missing from this listing.
90 template <class dataType>
93 bool emptyTreeDistance = false) {
94 MergeTreeDistance mergeTreeDistance;
96 mergeTreeDistance.setEpsilonTree1(epsilonTree1_);
97 mergeTreeDistance.setEpsilonTree2(epsilonTree2_);
98 mergeTreeDistance.setEpsilon2Tree1(epsilon2Tree1_);
99 mergeTreeDistance.setEpsilon2Tree2(epsilon2Tree2_);
100 mergeTreeDistance.setEpsilon3Tree1(epsilon3Tree1_);
101 mergeTreeDistance.setEpsilon3Tree2(epsilon3Tree2_);
103 mergeTreeDistance.setParallelize(parallelize_);
106 mergeTreeDistance.setKeepSubtree(keepSubtree_);
107 mergeTreeDistance.setUseMinMaxPair(useMinMaxPair_);
108 mergeTreeDistance.setThreadNumber(this->threadNumber_);
109 mergeTreeDistance.setDistanceSquaredRoot(true); // squared root
// Debug level is hard-coded here rather than inherited from this object.
110 mergeTreeDistance.setDebugLevel(2);
111 mergeTreeDistance.setPreprocess(false);
112 mergeTreeDistance.setPostprocess(false);
113 // mergeTreeDistance.setIsCalled(true);
114 mergeTreeDistance.setOnlyEmptyTreeDistance(emptyTreeDistance);
115
// Output matching required by execute(); not used by callers of this helper.
116 std::vector<std::tuple<ftm::idNode, ftm::idNode, double>> matching;
117 dataType distance
118 = mergeTreeDistance.execute<dataType>(mTree1, mTree2, matching);
119
120 return distance;
121 }
122
// Barycenter of two merge trees, computed by a MergeTreeBarycenter object
// configured from this object's parameters (assignment solver, epsilons,
// branch decomposition, parallelize_, keepSubtree_, useMinMaxPair_, thread
// count) with pre/post-processing disabled.
// alpha is forwarded to setAlpha(); presumably it is the weight of mTree1
// (computeAlpha returns 1 when the middle time equals mTree1's time) —
// confirm against MergeTreeBarycenter.
// Both trees are copied into a temporary vector for the call; the matchings
// to the barycenter are discarded.
// NOTE(review): original lines 124-125 (function name and the mTree1 /
// mTree2 parameters) and 137-138 are missing from this listing.
123 template <class dataType>
126 double alpha) {
127 MergeTreeBarycenter mergeTreeBarycenter;
128 mergeTreeBarycenter.setAssignmentSolver(assignmentSolverID_);
129 mergeTreeBarycenter.setEpsilonTree1(epsilonTree1_);
130 mergeTreeBarycenter.setEpsilonTree2(epsilonTree2_);
131 mergeTreeBarycenter.setEpsilon2Tree1(epsilon2Tree1_);
132 mergeTreeBarycenter.setEpsilon2Tree2(epsilon2Tree2_);
133 mergeTreeBarycenter.setEpsilon3Tree1(epsilon3Tree1_);
134 mergeTreeBarycenter.setEpsilon3Tree2(epsilon3Tree2_);
135 mergeTreeBarycenter.setBranchDecomposition(branchDecomposition_);
136 mergeTreeBarycenter.setParallelize(parallelize_);
139 mergeTreeBarycenter.setKeepSubtree(keepSubtree_);
140 mergeTreeBarycenter.setUseMinMaxPair(useMinMaxPair_);
141 mergeTreeBarycenter.setThreadNumber(this->threadNumber_);
142 mergeTreeBarycenter.setAlpha(alpha);
// Debug level is hard-coded here rather than inherited from this object.
143 mergeTreeBarycenter.setDebugLevel(2);
144 mergeTreeBarycenter.setPreprocess(false);
145 mergeTreeBarycenter.setPostprocess(false);
146 // mergeTreeBarycenter.setIsCalled(true);
147
148 std::vector<ftm::MergeTree<dataType>> intermediateTrees;
149 intermediateTrees.push_back(mTree1);
150 intermediateTrees.push_back(mTree2);
// One matching slot per input tree.
151 std::vector<std::vector<std::tuple<ftm::idNode, ftm::idNode, double>>>
152 outputMatchingBarycenter(2);
153 ftm::MergeTree<dataType> barycenter;
154 mergeTreeBarycenter.execute<dataType>(
155 intermediateTrees, outputMatchingBarycenter, barycenter);
156 return barycenter;
157 }
158
159 double computeAlpha(int index1, int middleIndex, int index2) {
160 index1 = timeVariable_[index1];
161 middleIndex = timeVariable_[middleIndex];
162 index2 = timeVariable_[index2];
163 return 1 - ((double)middleIndex - index1) / (index2 - index1);
164 }
165
// Greedy temporal subsampling. Removes toRemoved inputs, which is
// removalPercentage_ percent of mTrees, capped so that at least three remain.
//
// Each iteration walks over the inputs that are not yet removed. For every
// interior input (middleIndex) it:
//   - interpolates between its two remaining neighbours (index1, index2),
//     using weights from computeAlpha();
//   - scores the input by the distance between that interpolation and the
//     input itself;
//   - adds the same distance for every already-removed input lying between
//     index1 and index2.
// The lowest-scoring input is then removed. Tree distances and barycenters
// are used in tree mode; computeL2Distance / computeL2Barycenter on the
// fieldL2_ fields are used when useL2Distance_ is set. The first and last
// inputs are never candidates.
//
// Outputs:
//   removed       - indices appended in removal order.
//   barycenters   - (tree mode) at each removed index, its latest
//                   reconstruction; overwritten when its neighbours change.
//   barycentersL2 - (L2 mode) same, for the scalar fields.
// barycenters / barycentersL2 are indexed directly, so the caller must size
// them to mTrees.size() (execute() does).
//
// NOTE(review): original lines 168 (function name and the mTrees parameter),
// 194 and 233 (printMsg priority arguments) are missing from this listing.
// NOTE(review): in L2 mode the loop bounds come from mTrees.size() while
// images (from fieldL2_) are indexed, so fieldL2_ presumably must be at least
// as long as mTrees — confirm.
166 template <class dataType>
167 void
169 std::vector<int> &removed,
170 std::vector<ftm::MergeTree<dataType>> &barycenters,
171 std::vector<std::vector<dataType>> &barycentersL2) {
172 std::vector<bool> treeRemoved(mTrees.size(), false);
173
174 int toRemoved = mTrees.size() * removalPercentage_ / 100.;
// Keep at least three inputs.
// NOTE(review): if mTrees.size() < 3, `mTrees.size() - 3` wraps around; the
// (int) cast is relied upon to yield a negative value so that no iteration
// runs — confirm.
175 toRemoved = std::min(toRemoved, (int)(mTrees.size() - 3));
176
// Convert the L2 fields to dataType once (also done, unused, in tree mode).
177 std::vector<std::vector<dataType>> images(fieldL2_.size());
178 for(size_t i = 0; i < fieldL2_.size(); ++i)
179 for(size_t j = 0; j < fieldL2_[i].size(); ++j)
180 images[i].push_back(static_cast<dataType>(fieldL2_[i][j]));
181
// One removal per iteration.
182 for(int iter = 0; iter < toRemoved; ++iter) {
183 dataType bestCost = std::numeric_limits<dataType>::max();
184 int bestMiddleIndex = -1;
185 ftm::MergeTree<dataType> bestBarycenter;
186 std::vector<std::tuple<ftm::MergeTree<dataType>, int>>
187 bestBarycentersOnPath;
188 std::vector<dataType> bestBarycenterL2;
189 std::vector<std::tuple<std::vector<dataType>, int>>
190 bestBarycentersL2OnPath;
191
192 // Compute barycenter for each pair of trees
193 printMsg("Compute barycenter for each pair of trees",
195 unsigned int index1 = 0, index2 = 0;
// Slide the (index1, middleIndex, index2) window over the remaining
// inputs until the right neighbour reaches the last input.
196 while(index2 != mTrees.size() - 1) {
197
198 // Get index in the middle
199 int middleIndex = index1 + 1;
200 while(treeRemoved[middleIndex])
201 ++middleIndex;
202
203 // Get second index
204 index2 = middleIndex + 1;
205 while(treeRemoved[index2])
206 ++index2;
207
208 // Compute barycenter
209 printMsg("Compute barycenter", debug::Priority::VERBOSE);
210 double const alpha = computeAlpha(index1, middleIndex, index2);
211 ftm::MergeTree<dataType> barycenter;
212 std::vector<dataType> barycenterL2;
213 if(not useL2Distance_)
214 barycenter = computeBarycenter<dataType>(
215 mTrees[index1], mTrees[index2], alpha);
216 else
217 barycenterL2 = computeL2Barycenter<dataType>(
218 images[index1], images[index2], alpha);
219
220 // - Compute cost
221 // Compute distance with middleIndex
222 printMsg(
223 "Compute distance with middleIndex", debug::Priority::VERBOSE);
224 dataType cost;
225 if(not useL2Distance_)
226 cost = computeDistance<dataType>(barycenter, mTrees[middleIndex]);
227 else
228 cost
229 = computeL2Distance<dataType>(barycenterL2, images[middleIndex]);
230
231 // Compute distances of previously removed trees on the path
232 printMsg("Compute distances of previously removed trees",
234 std::vector<std::tuple<ftm::MergeTree<dataType>, int>>
235 barycentersOnPath;
236 std::vector<std::tuple<std::vector<dataType>, int>>
237 barycentersL2OnPath;
// i == 0 walks from middleIndex down to index1, i == 1 walks up to index2;
// every index strictly in between was removed in an earlier iteration.
238 for(unsigned int i = 0; i < 2; ++i) {
239 int const toReach = (i == 0 ? index1 : index2);
240 int const offset = (i == 0 ? -1 : 1);
241 int tIndex = middleIndex + offset;
242 while(tIndex != toReach) {
243
244 // Compute barycenter
245 double const alphaT = computeAlpha(index1, tIndex, index2);
246 ftm::MergeTree<dataType> barycenterP;
247 std::vector<dataType> barycenterPL2;
248 if(not useL2Distance_)
249 barycenterP = computeBarycenter<dataType>(
250 mTrees[index1], mTrees[index2], alphaT);
251 else
252 barycenterPL2 = computeL2Barycenter<dataType>(
253 images[index1], images[index2], alphaT);
254
255 // Compute distance
256 dataType costP;
257 if(not useL2Distance_)
258 costP = computeDistance<dataType>(barycenterP, mTrees[tIndex]);
259 else
260 costP
261 = computeL2Distance<dataType>(barycenterPL2, images[tIndex]);
262
263 // Save results
264 if(not useL2Distance_)
265 barycentersOnPath.push_back(
266 std::make_tuple(barycenterP, tIndex));
267 else
268 barycentersL2OnPath.push_back(
269 std::make_tuple(barycenterPL2, tIndex));
270 cost += costP;
271 tIndex += offset;
272 }
273 }
274
// Keep the cheapest candidate together with all of its reconstructions.
275 if(cost < bestCost) {
276 bestCost = cost;
277 bestMiddleIndex = middleIndex;
278 if(not useL2Distance_) {
279 bestBarycenter = barycenter;
280 bestBarycentersOnPath = barycentersOnPath;
281 } else {
282 bestBarycenterL2 = barycenterL2;
283 bestBarycentersL2OnPath = barycentersL2OnPath;
284 }
285 }
286
287 // Go to the next index
288 index1 = middleIndex;
289 }
290
291 // Removed the tree with the lowest cost
// NOTE(review): if no candidate's cost compared below
// numeric_limits<dataType>::max() (e.g. NaN costs), bestMiddleIndex is still
// -1 here and is used as an index — consider guarding.
292 printMsg(
293 "Removed the tree with the lowest cost", debug::Priority::VERBOSE);
294 removed.push_back(bestMiddleIndex);
295 treeRemoved[bestMiddleIndex] = true;
// Store the new reconstruction of the removed input and refresh those of
// the previously removed inputs between its neighbours.
296 if(not useL2Distance_) {
297 barycenters[bestMiddleIndex] = bestBarycenter;
298 for(auto &tup : bestBarycentersOnPath)
299 barycenters[std::get<1>(tup)] = std::get<0>(tup);
300 } else {
301 barycentersL2[bestMiddleIndex] = bestBarycenterL2;
302 for(auto &tup : bestBarycentersL2OnPath)
303 barycentersL2[std::get<1>(tup)] = std::get<0>(tup);
304 }
305 }
306 }
307
// Runs the temporal reduction on mTrees, or on the scalar fields in fieldL2_
// when useL2Distance_ is set.
//   - Tree mode: preprocesses every input tree first, then post-processes
//     both the output trees (allMT) and mTrees itself, so the caller's
//     trees are modified.
//   - Fills timeVariable_ with 0 .. mTrees.size()-1 unless a custom time
//     variable is in use.
//   - Appends the input trees to allMT in both modes. In tree mode these are
//     followed by the reconstructed trees of the removed indices, in
//     increasing index order.
//   - Appends to emptyTreeDistances one value per output entry:
//       * tree mode: one per allMT entry;
//       * L2 mode: one per input field plus one per reconstructed field.
//     Each value uses the "empty" flag of computeDistance /
//     computeL2Distance.
// Returns the indices of the removed inputs sorted ascending (the log
// message lists them in removal order).
// NOTE(review): original lines 319-321 (remaining preprocessingPipeline
// arguments) are missing from this listing.
// NOTE(review): allMT is appended to, not cleared. The "output size" message
// equals mTrees.size() - removed.size() only if allMT was empty on entry
// (tree mode) or fieldL2_.size() == mTrees.size() (L2 mode).
308 template <class dataType>
309 std::vector<int> execute(std::vector<ftm::MergeTree<dataType>> &mTrees,
310 std::vector<double> &emptyTreeDistances,
311 std::vector<ftm::MergeTree<dataType>> &allMT) {
312 Timer t_tempSub;
313
314 // --- Preprocessing
315 if(not useL2Distance_) {
316 treesNodeCorr_ = std::vector<std::vector<int>>(mTrees.size());
317 for(unsigned int i = 0; i < mTrees.size(); ++i) {
318 preprocessingPipeline<dataType>(mTrees[i], epsilonTree2_,
322 }
323 printTreesStats<dataType>(mTrees);
324 }
325
326 // --- Execute
// Sized up front because temporalSubsampling indexes them directly.
327 std::vector<ftm::MergeTree<dataType>> barycenters(mTrees.size());
328 std::vector<std::vector<dataType>> barycentersL2(mTrees.size());
329 std::vector<int> removed;
// Default time stamps: the input index.
330 if(not useCustomTimeVariable_) {
331 timeVariable_.clear();
332 for(size_t i = 0; i < mTrees.size(); ++i)
333 timeVariable_.push_back(i);
334 }
335 temporalSubsampling<dataType>(
336 mTrees, removed, barycenters, barycentersL2);
337
338 // --- Concatenate all trees/L2Images
339 std::vector<std::vector<dataType>> images(fieldL2_.size());
340 for(size_t i = 0; i < fieldL2_.size(); ++i)
341 for(size_t j = 0; j < fieldL2_[i].size(); ++j)
342 images[i].push_back(static_cast<dataType>(fieldL2_[i][j]));
343
344 for(auto &mt : mTrees)
345 allMT.push_back(mt);
346 std::vector<bool> removedB(mTrees.size(), false);
347 for(auto r : removed)
348 removedB[r] = true;
// Append the reconstructions of the removed inputs after the originals.
349 for(unsigned int i = 0; i < barycenters.size(); ++i)
350 if(removedB[i]) {
351 if(not useL2Distance_)
352 allMT.push_back(barycenters[i]);
353 else
354 images.push_back(barycentersL2[i]);
355 }
356
357 // --- Compute empty tree distances
358 unsigned int const distMatSize
359 = (not useL2Distance_ ? allMT.size() : images.size());
360 for(unsigned int i = 0; i < distMatSize; ++i) {
361 dataType distance;
362 if(not useL2Distance_)
363 distance = computeDistance<dataType>(allMT[i], allMT[i], true);
364 else
365 distance = computeL2Distance<dataType>(images[i], images[i], true);
366 emptyTreeDistances.push_back(distance);
367 }
368
369 // --- Postprocessing
370 if(not useL2Distance_) {
371 for(unsigned int i = 0; i < allMT.size(); ++i)
372 postprocessingPipeline<dataType>(&(allMT[i].tree));
373 for(unsigned int i = 0; i < mTrees.size(); ++i)
374 postprocessingPipeline<dataType>(&(mTrees[i].tree));
375 }
376
377 // --- Print results
378 std::stringstream ss, ss2, ss3;
379 ss << "input size = " << mTrees.size();
380 printMsg(ss.str());
// (distMatSize - mTrees.size()) is the number of appended reconstructions
// under the assumptions noted in the header comment.
381 ss2 << "output size = "
382 << mTrees.size() - (distMatSize - mTrees.size());
383 printMsg(ss2.str());
384 ss3 << "removed : ";
385 for(unsigned int i = 0; i < removed.size(); ++i) {
386 auto r = removed[i];
387 ss3 << r;
388 if(i < removed.size() - 1)
389 ss3 << ", ";
390 }
391 printMsg(ss3.str());
392
// The returned indices are sorted; the log above shows removal order.
393 sort(removed.begin(), removed.end());
394
395 printMsg("Encoding", 1, t_tempSub.getElapsedTime(), this->threadNumber_);
396
397 return removed;
398 }
399
400 }; // MergeTreeTemporalReduction class
401
402} // namespace ttk
virtual int setThreadNumber(const int threadNumber)
Definition BaseClass.h:80
Minimalist debugging class.
Definition Debug.h:88
virtual int setDebugLevel(const int &debugLevel)
Definition Debug.cpp:147
void setPreprocess(bool preproc)
void setPostprocess(bool postproc)
void execute(std::vector< ftm::MergeTree< dataType > > &trees, std::vector< double > &alphas, std::vector< std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > > &finalMatchings, ftm::MergeTree< dataType > &baryMergeTree, bool finalAsgnDoubleInput=false, bool finalAsgnFirstInput=true)
void setBranchDecomposition(bool useBD)
void setNormalizedWasserstein(bool normalizedWasserstein)
void setDistanceSquaredRoot(bool distanceSquaredRoot)
void setEpsilon3Tree1(double epsilon)
void setEpsilonTree1(double epsilon)
void setAssignmentSolver(int assignmentSolver)
void setEpsilon2Tree1(double epsilon)
void setEpsilonTree2(double epsilon)
void setPersistenceThreshold(double pt)
std::vector< std::vector< int > > treesNodeCorr_
void setEpsilon2Tree2(double epsilon)
void setKeepSubtree(bool keepSubtree)
void setUseMinMaxPair(bool useMinMaxPair)
void setEpsilon3Tree2(double epsilon)
void setParallelize(bool para)
void setOnlyEmptyTreeDistance(double only)
void setPreprocess(bool preproc)
void setPostprocess(bool postproc)
dataType execute(ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree2, std::vector< std::tuple< ftm::idNode, ftm::idNode, double > > &outputMatching)
dataType computeDistance(ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree2, bool emptyTreeDistance=false)
double computeAlpha(int index1, int middleIndex, int index2)
std::vector< int > execute(std::vector< ftm::MergeTree< dataType > > &mTrees, std::vector< double > &emptyTreeDistances, std::vector< ftm::MergeTree< dataType > > &allMT)
std::vector< std::vector< double > > fieldL2_
ftm::MergeTree< dataType > computeBarycenter(ftm::MergeTree< dataType > &mTree1, ftm::MergeTree< dataType > &mTree2, double alpha)
void temporalSubsampling(std::vector< ftm::MergeTree< dataType > > &mTrees, std::vector< int > &removed, std::vector< ftm::MergeTree< dataType > > &barycenters, std::vector< std::vector< dataType > > &barycentersL2)
std::vector< dataType > computeL2Barycenter(std::vector< dataType > &img1, std::vector< dataType > &img2, double alpha)
dataType computeL2Distance(std::vector< dataType > &img1, std::vector< dataType > &img2, bool emptyFieldDistance=false)
double getElapsedTime()
Definition Timer.h:15
The Topology ToolKit.
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)