TTK
Loading...
Searching...
No Matches
MergeTreeAutoencoderUtils.cpp
Go to the documentation of this file.
2
4 double eps = 1e-6;
5 auto shiftSubtree
6 = [&mTree, &eps](ftm::idNode node, ftm::idNode birthNodeParent,
7 ftm::idNode deathNodeParent, std::vector<float> &scalars,
8 bool invalidBirth, bool invalidDeath) {
9 std::queue<ftm::idNode> queue;
10 queue.emplace(node);
11 while(!queue.empty()) {
12 ftm::idNode nodeT = queue.front();
13 queue.pop();
14 auto birthDeathNode = mTree.tree.getBirthDeathNode<float>(node);
15 auto birthNode = std::get<0>(birthDeathNode);
16 auto deathNode = std::get<1>(birthDeathNode);
17 if(invalidBirth)
18 scalars[birthNode] = scalars[birthNodeParent] + 2 * eps;
19 if(invalidDeath)
20 scalars[deathNode] = scalars[deathNodeParent] - 2 * eps;
21 std::vector<ftm::idNode> children;
22 mTree.tree.getChildren(nodeT, children);
23 for(auto &child : children)
24 queue.emplace(child);
25 }
26 };
27 std::vector<float> scalars;
28 getTreeScalars(mTree, scalars);
29 std::queue<ftm::idNode> queue;
30 auto root = mTree.tree.getRoot();
31 queue.emplace(root);
32 while(!queue.empty()) {
33 ftm::idNode node = queue.front();
34 queue.pop();
35 auto birthDeathNode = mTree.tree.getBirthDeathNode<float>(node);
36 auto birthNode = std::get<0>(birthDeathNode);
37 auto deathNode = std::get<1>(birthDeathNode);
38 auto birthDeathNodeParent
39 = mTree.tree.getBirthDeathNode<float>(mTree.tree.getParentSafe(node));
40 auto birthNodeParent = std::get<0>(birthDeathNodeParent);
41 auto deathNodeParent = std::get<1>(birthDeathNodeParent);
42 bool invalidBirth = (scalars[birthNode] <= scalars[birthNodeParent] + eps);
43 bool invalidDeath = (scalars[deathNode] >= scalars[deathNodeParent] - eps);
44 if(!mTree.tree.isRoot(node) and (invalidBirth or invalidDeath))
45 shiftSubtree(node, birthNodeParent, deathNodeParent, scalars,
46 invalidBirth, invalidDeath);
47 std::vector<ftm::idNode> children;
48 mTree.tree.getChildren(node, children);
49 for(auto &child : children)
50 queue.emplace(child);
51 }
52 ftm::setTreeScalars<float>(mTree, scalars);
53}
54
55void ttk::wae::adjustNestingScalars(std::vector<float> &scalarsVector,
56 ftm::idNode node,
57 ftm::idNode refNode) {
58 float birth = scalarsVector[refNode * 2];
59 float death = scalarsVector[refNode * 2 + 1];
60 auto getSign = [](float v) { return (v > 0 ? 1 : -1); };
61 auto getPrecValue = [&getSign](float v, bool opp = false) {
62 return v * (1 + (opp ? -1 : 1) * getSign(v) * 1e-6);
63 };
64 // Shift scalars
65 if(scalarsVector[node * 2 + 1] > getPrecValue(death, true)) {
66 float diff = scalarsVector[node * 2 + 1] - getPrecValue(death, true);
67 scalarsVector[node * 2] -= diff;
68 scalarsVector[node * 2 + 1] -= diff;
69 } else if(scalarsVector[node * 2] < getPrecValue(birth)) {
70 float diff = getPrecValue(birth) - scalarsVector[node * 2];
71 scalarsVector[node * 2] += getPrecValue(diff);
72 scalarsVector[node * 2 + 1] += getPrecValue(diff);
73 }
74 // Cut scalars
75 if(scalarsVector[node * 2] < getPrecValue(birth))
76 scalarsVector[node * 2] = getPrecValue(birth);
77 if(scalarsVector[node * 2 + 1] > getPrecValue(death, true))
78 scalarsVector[node * 2 + 1] = getPrecValue(death, true);
79}
80
82 std::vector<std::vector<ftm::idNode>> &parents,
83 std::vector<std::vector<ftm::idNode>> &children,
84 std::vector<float> &scalarsVector,
85 std::vector<std::vector<ftm::idNode>> &childrenFinal,
86 int threadNumber) {
87 // ----- Some variables
88 unsigned int noNodes = scalarsVector.size() / 2;
89 childrenFinal.resize(noNodes);
90 int mtLevel = ceil(log(noNodes * 2) / log(2)) + 1;
91 int bdtLevel = mtLevel - 1;
92 int noDim = bdtLevel;
93
94 // ----- Get node levels
95 std::vector<int> nodeLevels(noNodes, -1);
96 std::queue<ftm::idNode> queueLevels;
97 std::vector<int> noChildDone(noNodes, 0);
98 for(unsigned int i = 0; i < children.size(); ++i) {
99 if(children[i].size() == 0) {
100 queueLevels.emplace(i);
101 nodeLevels[i] = 1;
102 }
103 }
104 while(!queueLevels.empty()) {
105 ftm::idNode node = queueLevels.front();
106 queueLevels.pop();
107 for(auto &parent : parents[node]) {
108 ++noChildDone[parent];
109 nodeLevels[parent] = std::max(nodeLevels[parent], nodeLevels[node] + 1);
110 if(noChildDone[parent] >= (int)children[parent].size())
111 queueLevels.emplace(parent);
112 }
113 }
114
115 // ----- Sort heuristic lambda
116 auto sortChildren = [&parents, &scalarsVector, &noNodes, &threadNumber](
117 ftm::idNode nodeOrigin, std::vector<bool> &nodeDone,
118 std::vector<std::vector<ftm::idNode>> &childrenT) {
119 double refPers = scalarsVector[1] - scalarsVector[0];
120 auto getRemaining = [&nodeDone](std::vector<ftm::idNode> &vec) {
121 unsigned int remaining = 0;
122 for(auto &e : vec)
123 remaining += (not nodeDone[e]);
124 return remaining;
125 };
126 std::vector<unsigned int> parentsRemaining(noNodes, 0),
127 childrenRemaining(noNodes, 0);
128 for(auto &child : childrenT[nodeOrigin]) {
129 parentsRemaining[child] = getRemaining(parents[child]);
130 childrenRemaining[child] = getRemaining(childrenT[child]);
131 }
132 TTK_FORCE_USE(threadNumber);
133 TTK_PSORT(
134 threadNumber, childrenT[nodeOrigin].begin(), childrenT[nodeOrigin].end(),
135 [&](ftm::idNode nodeI, ftm::idNode nodeJ) {
136 double persI = scalarsVector[nodeI * 2 + 1] - scalarsVector[nodeI * 2];
137 double persJ = scalarsVector[nodeJ * 2 + 1] - scalarsVector[nodeJ * 2];
138 return parentsRemaining[nodeI] + childrenRemaining[nodeI]
139 - persI / refPers * noNodes
140 < parentsRemaining[nodeJ] + childrenRemaining[nodeJ]
141 - persJ / refPers * noNodes;
142 });
143 };
144
145 // ----- Greedy approach to find balanced BDT structures
146 const auto findStructGivenDim =
147 [&children, &noNodes, &nodeLevels](
148 ftm::idNode _nodeOrigin, int _dimToFound, bool _searchMaxDim,
149 std::vector<bool> &_nodeDone, std::vector<bool> &_dimFound,
150 std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
151 // --- Recursive lambda
152 auto findStructGivenDimImpl =
153 [&children, &noNodes, &nodeLevels](
154 ftm::idNode nodeOrigin, int dimToFound, bool searchMaxDim,
155 std::vector<bool> &nodeDone, std::vector<bool> &dimFound,
156 std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
157 auto &findStructGivenDimRef) mutable {
158 childrenFinalOut.resize(noNodes);
159 // - Find structures
160 int dim = (searchMaxDim ? dimToFound - 1 : 0);
161 unsigned int i = 0;
162 //
163 auto searchMaxDimReset = [&i, &dim, &nodeDone]() {
164 --dim;
165 i = 0;
166 unsigned int noDone = 0;
167 for(auto done : nodeDone)
168 if(done)
169 ++noDone;
170 return noDone == nodeDone.size() - 1; // -1 for root
171 };
172 while(i < children[nodeOrigin].size()) {
173 auto child = children[nodeOrigin][i];
174 // Skip if child was already processed
175 if(nodeDone[child]) {
176 // If we have processed all children while searching for max
177 // dim then restart at the beginning to find a lower dim
178 if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
179 if(searchMaxDimReset())
180 break;
181 } else
182 ++i;
183 continue;
184 }
185 if(dim == 0) {
186 // Base case
187 childrenFinalOut[nodeOrigin].emplace_back(child);
188 nodeDone[child] = true;
189 dimFound[0] = true;
190 if(dimToFound <= 1 or searchMaxDim)
191 return true;
192 ++dim;
193 } else {
194 // General case
195 std::vector<std::vector<ftm::idNode>> childrenFinalDim;
196 std::vector<bool> nodeDoneDim;
197 std::vector<bool> dimFoundDim(dim);
198 bool found = false;
199 if(nodeLevels[child] > dim) {
200 nodeDoneDim = nodeDone;
201 found = findStructGivenDimRef(child, dim, false, nodeDoneDim,
202 dimFoundDim, childrenFinalDim,
203 findStructGivenDimRef);
204 }
205 if(found) {
206 dimFound[dim] = true;
207 childrenFinalOut[nodeOrigin].emplace_back(child);
208 for(unsigned int j = 0; j < childrenFinalDim.size(); ++j)
209 for(auto &e : childrenFinalDim[j])
210 childrenFinalOut[j].emplace_back(e);
211 nodeDone[child] = true;
212 for(unsigned int j = 0; j < nodeDoneDim.size(); ++j)
213 nodeDone[j] = nodeDone[j] || nodeDoneDim[j];
214 // Return if it is the last dim to found
215 if(dim == dimToFound - 1 and not searchMaxDim)
216 return true;
217 // Reset index if we search for the maximum dim
218 if(searchMaxDim) {
219 if(searchMaxDimReset())
220 break;
221 } else {
222 ++dim;
223 }
224 continue;
225 } else if(searchMaxDim and i == children[nodeOrigin].size() - 1) {
226 // If we have processed all children while searching for max
227 // dim then restart at the beginning to find a lower dim
228 if(searchMaxDimReset())
229 break;
230 continue;
231 }
232 }
233 ++i;
234 }
235 return false;
236 };
237 return findStructGivenDimImpl(_nodeOrigin, _dimToFound, _searchMaxDim,
238 _nodeDone, _dimFound, _childrenFinalOut,
239 findStructGivenDimImpl);
240 };
241 std::vector<bool> dimFound(noDim - 1, false);
242 std::vector<bool> nodeDone(noNodes, false);
243 for(unsigned int i = 0; i < children.size(); ++i)
244 sortChildren(i, nodeDone, children);
245 Timer t_find;
246 ftm::idNode startNode = 0;
247 findStructGivenDim(startNode, noDim, true, nodeDone, dimFound, childrenFinal);
248
249 // ----- Greedy approach to create non found structures
250 const auto createStructGivenDim =
251 [&children, &noNodes, &findStructGivenDim, &nodeLevels](
252 int _nodeOrigin, int _dimToCreate, std::vector<bool> &_nodeDone,
253 ftm::idNode &_structOrigin, std::vector<float> &_scalarsVectorOut,
254 std::vector<std::vector<ftm::idNode>> &_childrenFinalOut) {
255 // --- Recursive lambda
256 auto createStructGivenDimImpl =
257 [&children, &noNodes, &findStructGivenDim, &nodeLevels](
258 int nodeOrigin, int dimToCreate, std::vector<bool> &nodeDoneImpl,
259 ftm::idNode &structOrigin, std::vector<float> &scalarsVectorOut,
260 std::vector<std::vector<ftm::idNode>> &childrenFinalOut,
261 auto &createStructGivenDimRef) mutable {
262 // Deduction of auto lambda type
263 if(false)
264 return;
265 // - Find structures of lower dimension
266 int dimToFound = dimToCreate - 1;
267 std::vector<std::vector<std::vector<ftm::idNode>>> childrenFinalT(2);
268 std::array<ftm::idNode, 2> structOrigins;
269 for(unsigned int n = 0; n < 2; ++n) {
270 bool found = false;
271 for(unsigned int i = 0; i < children[nodeOrigin].size(); ++i) {
272 auto child = children[nodeOrigin][i];
273 if(nodeDoneImpl[child])
274 continue;
275 if(dimToFound != 0) {
276 if(nodeLevels[child] > dimToFound) {
277 std::vector<bool> dimFoundT(dimToFound, false);
278 childrenFinalT[n].clear();
279 childrenFinalT[n].resize(noNodes);
280 std::vector<bool> nodeDoneImplFind = nodeDoneImpl;
281 found = findStructGivenDim(child, dimToFound, false,
282 nodeDoneImplFind, dimFoundT,
283 childrenFinalT[n]);
284 }
285 } else
286 found = true;
287 if(found) {
288 structOrigins[n] = child;
289 nodeDoneImpl[child] = true;
290 for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
291 for(auto &e : childrenFinalT[n][j]) {
292 childrenFinalOut[j].emplace_back(e);
293 nodeDoneImpl[e] = true;
294 }
295 }
296 break;
297 }
298 } // end for children[nodeOrigin]
299 if(not found) {
300 if(dimToFound <= 0) {
301 structOrigins[n] = std::numeric_limits<ftm::idNode>::max();
302 continue;
303 }
304 childrenFinalT[n].clear();
305 childrenFinalT[n].resize(noNodes);
306 createStructGivenDimRef(
307 nodeOrigin, dimToFound, nodeDoneImpl, structOrigins[n],
308 scalarsVectorOut, childrenFinalT[n], createStructGivenDimRef);
309 for(unsigned int j = 0; j < childrenFinalT[n].size(); ++j) {
310 for(auto &e : childrenFinalT[n][j]) {
311 if(e == structOrigins[n])
312 continue;
313 childrenFinalOut[j].emplace_back(e);
314 }
315 }
316 }
317 } // end for n
318 // - Combine both structures
319 if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max()
320 and structOrigins[1] == std::numeric_limits<ftm::idNode>::max()) {
321 structOrigin = std::numeric_limits<ftm::idNode>::max();
322 return;
323 }
324 bool firstIsParent = true;
325 if(structOrigins[0] == std::numeric_limits<ftm::idNode>::max())
326 firstIsParent = false;
327 else if(structOrigins[1] == std::numeric_limits<ftm::idNode>::max())
328 firstIsParent = true;
329 else if(scalarsVectorOut[structOrigins[1] * 2 + 1]
330 - scalarsVectorOut[structOrigins[1] * 2]
331 > scalarsVectorOut[structOrigins[0] * 2 + 1]
332 - scalarsVectorOut[structOrigins[0] * 2])
333 firstIsParent = false;
334 structOrigin = (firstIsParent ? structOrigins[0] : structOrigins[1]);
335 ftm::idNode modOrigin
336 = (firstIsParent ? structOrigins[1] : structOrigins[0]);
337 childrenFinalOut[nodeOrigin].emplace_back(structOrigin);
338 if(modOrigin != std::numeric_limits<ftm::idNode>::max()) {
339 childrenFinalOut[structOrigin].emplace_back(modOrigin);
340 std::queue<std::array<ftm::idNode, 2>> queue;
341 queue.emplace(std::array<ftm::idNode, 2>{modOrigin, structOrigin});
342 while(!queue.empty()) {
343 auto &nodeAndParent = queue.front();
344 ftm::idNode node = nodeAndParent[0];
345 ftm::idNode parent = nodeAndParent[1];
346 queue.pop();
347 adjustNestingScalars(scalarsVectorOut, node, parent);
348 // Push children
349 for(auto &child : childrenFinalOut[node])
350 queue.emplace(std::array<ftm::idNode, 2>{child, node});
351 }
352 }
353 return;
354 };
355 return createStructGivenDimImpl(
356 _nodeOrigin, _dimToCreate, _nodeDone, _structOrigin, _scalarsVectorOut,
357 _childrenFinalOut, createStructGivenDimImpl);
358 };
359 for(unsigned int i = 0; i < children.size(); ++i)
360 sortChildren(i, nodeDone, children);
361 Timer t_create;
362 for(unsigned int i = 0; i < dimFound.size(); ++i) {
363 if(dimFound[i])
364 continue;
365 ftm::idNode structOrigin;
366 createStructGivenDim(
367 startNode, i, nodeDone, structOrigin, scalarsVector, childrenFinal);
368 }
369}
370
372 std::stringstream ss;
373 if(mTree.tree.getRealNumberOfNodes() != 0)
374 ss = mTree.tree.template printPairsFromTree<float>(useBD);
375 else {
376 std::vector<bool> nodeDone(mTree.tree.getNumberOfNodes(), false);
377 for(unsigned int i = 0; i < mTree.tree.getNumberOfNodes(); ++i) {
378 if(nodeDone[i])
379 continue;
380 std::tuple<ftm::idNode, ftm::idNode, float> pair
381 = std::make_tuple(i, mTree.tree.getNode(i)->getOrigin(),
382 mTree.tree.getNodePersistence<float>(i));
383 ss << std::get<0>(pair) << " ("
384 << mTree.tree.getValue<float>(std::get<0>(pair)) << ") _ ";
385 ss << std::get<1>(pair) << " ("
386 << mTree.tree.getValue<float>(std::get<1>(pair)) << ") _ ";
387 ss << std::get<2>(pair) << std::endl;
388 nodeDone[i] = true;
389 nodeDone[mTree.tree.getNode(i)->getOrigin()] = true;
390 }
391 }
392 ss << std::endl;
393 std::cout << ss.str();
394}
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
Definition BaseClass.h:57
#define TTK_PSORT(NTHREADS,...)
Parallel sort macro.
Definition OpenMP.h:46
const scalarType & getValue(SimplexId nodeId) const
Definition FTMTree_MT.h:339
dataType getNodePersistence(idNode nodeId)
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
bool isRoot(idNode nodeId)
std::tuple< ftm::idNode, ftm::idNode > getBirthDeathNode(idNode nodeId)
idNode getParentSafe(idNode nodeId)
void getChildren(idNode nodeId, std::vector< idNode > &res)
Node * getNode(idNode nodeId)
Definition FTMTree_MT.h:393
SimplexId getOrigin() const
Definition FTMNode.h:64
unsigned int idNode
Node index in vect_nodes_.
void printPairs(ftm::MergeTree< float > &mTree, bool useBD=true)
Util function to print pairs of a merge tree.
void adjustNestingScalars(std::vector< float > &scalarsVector, ftm::idNode node, ftm::idNode refNode)
Shift and clip the birth/death scalars of a node so that its pair is nested within the pair of a reference node.
void createBalancedBDT(std::vector< std::vector< ftm::idNode > > &parents, std::vector< std::vector< ftm::idNode > > &children, std::vector< float > &scalarsVector, std::vector< std::vector< ftm::idNode > > &childrenFinal, int threadNumber=1)
Create a balanced BDT structure (for output basis initialization).
void fixTreePrecisionScalars(ftm::MergeTree< float > &mTree)
Fix the scalars of a merge tree to ensure that the nesting condition is respected.
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:483
T begin(std::pair< T, T > &p)
Definition ripserpy.cpp:479
ftm::FTMTree_MT tree
Definition FTMTree_MT.h:903