TTK
Loading...
Searching...
No Matches
BranchMappingDistance.h
Go to the documentation of this file.
1
13
14#pragma once
15
16#include <set>
17#include <vector>
18
19#include <algorithm>
20#include <cfloat>
21#include <chrono>
22#include <cmath>
23#include <iostream>
24#include <limits>
25#include <set>
26#include <stack>
27#include <tuple>
28#include <vector>
29
30// ttk common includes
31#include <AssignmentAuction.h>
33#include <AssignmentMunkres.h>
34#include <Debug.h>
35#include <FTMTree_MT.h>
36
37namespace ttk {
38
39 class BranchMappingDistance : virtual public Debug {
40
41 private:
42 int baseMetric_ = 0;
43 int assignmentSolverID_ = 0;
44 bool squared_ = false;
45
46 template <class dataType>
47 inline dataType editCost_Wasserstein1(int n1,
48 int p1,
49 int n2,
50 int p2,
51 ftm::FTMTree_MT *tree1,
52 ftm::FTMTree_MT *tree2) {
53 dataType d;
54 if(n1 < 0) {
55 dataType b1 = tree2->getValue<dataType>(n2);
56 dataType d1 = tree2->getValue<dataType>(p2);
57 dataType b2 = (b1 + d1) * 0.5;
58 dataType d2 = (b1 + d1) * 0.5;
59 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
60 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
61 d = db + dd;
62 } else if(n2 < 0) {
63 dataType b1 = tree1->getValue<dataType>(n1);
64 dataType d1 = tree1->getValue<dataType>(p1);
65 dataType b2 = (b1 + d1) * 0.5;
66 dataType d2 = (b1 + d1) * 0.5;
67 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
68 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
69 d = db + dd;
70 } else {
71 dataType b1 = tree1->getValue<dataType>(n1);
72 dataType d1 = tree1->getValue<dataType>(p1);
73 dataType b2 = tree2->getValue<dataType>(n2);
74 dataType d2 = tree2->getValue<dataType>(p2);
75 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
76 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
77 d = db + dd;
78 }
79 return squared_ ? d * d : d;
80 }
81
82 template <class dataType>
83 inline dataType editCost_Wasserstein2(int n1,
84 int p1,
85 int n2,
86 int p2,
87 ftm::FTMTree_MT *tree1,
88 ftm::FTMTree_MT *tree2) {
89 dataType d;
90 if(n1 < 0) {
91 dataType b1 = tree2->getValue<dataType>(n2);
92 dataType d1 = tree2->getValue<dataType>(p2);
93 dataType b2 = (b1 + d1) * 0.5;
94 dataType d2 = (b1 + d1) * 0.5;
95 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
96 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
97 d = std::sqrt(db * db + dd * dd);
98 } else if(n2 < 0) {
99 dataType b1 = tree1->getValue<dataType>(n1);
100 dataType d1 = tree1->getValue<dataType>(p1);
101 dataType b2 = (b1 + d1) * 0.5;
102 dataType d2 = (b1 + d1) * 0.5;
103 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
104 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
105 d = std::sqrt(db * db + dd * dd);
106 } else {
107 dataType b1 = tree1->getValue<dataType>(n1);
108 dataType d1 = tree1->getValue<dataType>(p1);
109 dataType b2 = tree2->getValue<dataType>(n2);
110 dataType d2 = tree2->getValue<dataType>(p2);
111 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
112 dataType dd = d1 > d2 ? d1 - d2 : d2 - d1;
113 d = std::sqrt(db * db + dd * dd);
114 }
115 return squared_ ? d * d : d;
116 }
117
118 template <class dataType>
119 inline dataType editCost_Persistence(int n1,
120 int p1,
121 int n2,
122 int p2,
123 ftm::FTMTree_MT *tree1,
124 ftm::FTMTree_MT *tree2) {
125 dataType d;
126 if(n1 < 0) {
127 dataType b1 = tree2->getValue<dataType>(n2);
128 dataType d1 = tree2->getValue<dataType>(p2);
129 d = d1 > b1 ? d1 - b1 : b1 - d1;
130 } else if(n2 < 0) {
131 dataType b1 = tree1->getValue<dataType>(n1);
132 dataType d1 = tree1->getValue<dataType>(p1);
133 d = d1 > b1 ? d1 - b1 : b1 - d1;
134 } else {
135 dataType b1 = tree1->getValue<dataType>(n1);
136 dataType d1 = tree1->getValue<dataType>(p1);
137 dataType b2 = tree2->getValue<dataType>(n2);
138 dataType d2 = tree2->getValue<dataType>(p2);
139 dataType dist1 = d1 > b1 ? d1 - b1 : b1 - d1;
140 dataType dist2 = d2 > b2 ? d2 - b2 : b2 - d2;
141 d = dist1 > dist2 ? dist1 - dist2 : dist2 - dist1;
142 }
143 return squared_ ? d * d : d;
144 }
145
146 template <class dataType>
147 inline dataType editCost_Shifting(int n1,
148 int p1,
149 int n2,
150 int p2,
151 ftm::FTMTree_MT *tree1,
152 ftm::FTMTree_MT *tree2) {
153 dataType d;
154 if(n1 < 0) {
155 dataType b1 = tree2->getValue<dataType>(n2);
156 dataType d1 = tree2->getValue<dataType>(p2);
157 d = d1 > b1 ? d1 - b1 : b1 - d1;
158 } else if(n2 < 0) {
159 dataType b1 = tree1->getValue<dataType>(n1);
160 dataType d1 = tree1->getValue<dataType>(p1);
161 d = d1 > b1 ? d1 - b1 : b1 - d1;
162 } else {
163 dataType b1 = tree1->getValue<dataType>(n1);
164 dataType d1 = tree1->getValue<dataType>(p1);
165 dataType b2 = tree2->getValue<dataType>(n2);
166 dataType d2 = tree2->getValue<dataType>(p2);
167 dataType pers1 = d1 > b1 ? d1 - b1 : b1 - d1;
168 dataType pers2 = d2 > b2 ? d2 - b2 : b2 - d2;
169 dataType db = b1 > b2 ? b1 - b2 : b2 - b1;
170 dataType dp = pers1 > pers2 ? pers1 - pers2 : pers2 - pers1;
171 d = db + dp;
172 }
173 return squared_ ? d * d : d;
174 }
175
176 public:
178 this->setDebugMsgPrefix(
179 "MergeTreeDistance"); // inherited from Debug: prefix will be printed at
180 // the beginning of every msg
181 }
182 ~BranchMappingDistance() override = default;
183
184 void setBaseMetric(int m) {
185 baseMetric_ = m;
186 }
187
188 void setAssignmentSolver(int assignmentSolver) {
189 assignmentSolverID_ = assignmentSolver;
190 }
191
192 void setSquared(bool s) {
193 squared_ = s;
194 }
195
196 template <class dataType>
198 ftm::FTMTree_MT *tree2) {
199
200 // initialize memoization tables
201
202 std::vector<std::vector<int>> predecessors1(tree1->getNumberOfNodes());
203 std::vector<std::vector<int>> predecessors2(tree2->getNumberOfNodes());
204 int const rootID1 = tree1->getRoot();
205 int const rootID2 = tree2->getRoot();
206 std::vector<int> preorder1(tree1->getNumberOfNodes());
207 std::vector<int> preorder2(tree2->getNumberOfNodes());
208
209 int depth1 = 0;
210 int depth2 = 0;
211 std::stack<int> stack;
212 stack.push(rootID1);
213 int count = tree1->getNumberOfNodes() - 1;
214 while(!stack.empty()) {
215 int const nIdx = stack.top();
216 stack.pop();
217 preorder1[count] = nIdx;
218 count--;
219 depth1 = std::max((int)predecessors1[nIdx].size(), depth1);
220 std::vector<ftm::idNode> children;
221 tree1->getChildren(nIdx, children);
222 for(int const cIdx : children) {
223 stack.push(cIdx);
224 predecessors1[cIdx].reserve(predecessors1[nIdx].size() + 1);
225 predecessors1[cIdx].insert(predecessors1[cIdx].end(),
226 predecessors1[nIdx].begin(),
227 predecessors1[nIdx].end());
228 predecessors1[cIdx].push_back(nIdx);
229 }
230 }
231 stack.push(rootID2);
232 count = tree2->getNumberOfNodes() - 1;
233 while(!stack.empty()) {
234 int const nIdx = stack.top();
235 stack.pop();
236 preorder2[count] = nIdx;
237 count--;
238 depth2 = std::max((int)predecessors2[nIdx].size(), depth2);
239 std::vector<ftm::idNode> children;
240 tree2->getChildren(nIdx, children);
241 for(int const cIdx : children) {
242 stack.push(cIdx);
243 predecessors2[cIdx].reserve(predecessors2[nIdx].size() + 1);
244 predecessors2[cIdx].insert(predecessors2[cIdx].end(),
245 predecessors2[nIdx].begin(),
246 predecessors2[nIdx].end());
247 predecessors2[cIdx].push_back(nIdx);
248 }
249 }
250
251 size_t nn1 = tree1->getNumberOfNodes();
252 size_t nn2 = tree2->getNumberOfNodes();
253 size_t const dim1 = 1;
254 size_t const dim2 = (nn1 + 1) * dim1;
255 size_t const dim3 = (depth1 + 1) * dim2;
256 size_t const dim4 = (nn2 + 1) * dim3;
257
258 std::vector<dataType> memT((nn1 + 1) * (depth1 + 1) * (nn2 + 1)
259 * (depth2 + 1));
260
261 memT[nn1 + 0 * dim2 + nn2 * dim3 + 0 * dim4] = 0;
262 for(size_t i = 0; i < nn1; i++) {
263 int curr1 = preorder1[i];
264 std::vector<ftm::idNode> children1;
265 tree1->getChildren(curr1, children1);
266 for(size_t l = 1; l <= predecessors1[preorder1[i]].size(); l++) {
267 int parent1 = predecessors1[preorder1[i]]
268 [predecessors1[preorder1[i]].size() - l];
269
270 //-----------------------------------------------------------------------
271 // If first subtree has only one branch, return deletion cost of this
272 // branch
273 if(tree1->getNumberOfChildren(curr1) == 0) {
274 memT[curr1 + l * dim2 + nn2 * dim3 + 0 * dim4]
275 = this->baseMetric_ == 0 ? editCost_Wasserstein1<dataType>(
276 curr1, parent1, -1, -1, tree1, tree2)
277 : this->baseMetric_ == 1 ? editCost_Wasserstein2<dataType>(
278 curr1, parent1, -1, -1, tree1, tree2)
279 : this->baseMetric_ == 2
280 ? editCost_Persistence<dataType>(
281 curr1, parent1, -1, -1, tree1, tree2)
282 : editCost_Shifting<dataType>(
283 curr1, parent1, -1, -1, tree1, tree2);
284 }
285 //-----------------------------------------------------------------------
286 // If first subtree has more than one branch, try all decompositions
287 else {
288 dataType c = std::numeric_limits<dataType>::max();
289 for(auto child1_mb : children1) {
290 dataType c_
291 = memT[child1_mb + (l + 1) * dim2 + nn2 * dim3 + 0 * dim4];
292 for(auto child1 : children1) {
293 if(child1 == child1_mb) {
294 continue;
295 }
296 c_ += memT[child1 + 1 * dim2 + nn2 * dim3 + 0 * dim4];
297 }
298 c = std::min(c, c_);
299 }
300 memT[curr1 + l * dim2 + nn2 * dim3 + 0 * dim4] = c;
301 }
302 }
303 }
304 for(size_t j = 0; j < nn2; j++) {
305 int curr2 = preorder2[j];
306 std::vector<ftm::idNode> children2;
307 tree2->getChildren(curr2, children2);
308 for(size_t l = 1; l <= predecessors2[preorder2[j]].size(); l++) {
309 int parent2 = predecessors2[preorder2[j]]
310 [predecessors2[preorder2[j]].size() - l];
311
312 //-----------------------------------------------------------------------
313 // If first subtree has only one branch, return deletion cost of this
314 // branch
315 if(tree2->getNumberOfChildren(curr2) == 0) {
316 memT[nn1 + 0 * dim2 + curr2 * dim3 + l * dim4]
317 = this->baseMetric_ == 0 ? editCost_Wasserstein1<dataType>(
318 -1, -1, curr2, parent2, tree1, tree2)
319 : this->baseMetric_ == 1 ? editCost_Wasserstein2<dataType>(
320 -1, -1, curr2, parent2, tree1, tree2)
321 : this->baseMetric_ == 2
322 ? editCost_Persistence<dataType>(
323 -1, -1, curr2, parent2, tree1, tree2)
324 : editCost_Shifting<dataType>(
325 -1, -1, curr2, parent2, tree1, tree2);
326 }
327 //-----------------------------------------------------------------------
328 // If first subtree has more than one branch, try all decompositions
329 else {
330 dataType c = std::numeric_limits<dataType>::max();
331 for(auto child2_mb : children2) {
332 dataType c_
333 = memT[nn1 + 0 * dim2 + child2_mb * dim3 + (l + 1) * dim4];
334 for(auto child2 : children2) {
335 if(child2 == child2_mb) {
336 continue;
337 }
338 c_ += memT[nn1 + 0 * dim2 + child2 * dim3 + 1 * dim4];
339 }
340 c = std::min(c, c_);
341 }
342 memT[nn1 + 0 * dim2 + curr2 * dim3 + l * dim4] = c;
343 }
344 }
345 }
346
347 for(size_t i = 0; i < nn1; i++) {
348 int curr1 = preorder1[i];
349 std::vector<ftm::idNode> children1;
350 tree1->getChildren(curr1, children1);
351 for(size_t j = 0; j < nn2; j++) {
352 int curr2 = preorder2[j];
353 std::vector<ftm::idNode> children2;
354 tree2->getChildren(curr2, children2);
355 for(size_t l1 = 1; l1 <= predecessors1[preorder1[i]].size(); l1++) {
356 int parent1
357 = predecessors1[preorder1[i]]
358 [predecessors1[preorder1[i]].size() - l1];
359 for(size_t l2 = 1; l2 <= predecessors2[preorder2[j]].size(); l2++) {
360 int parent2
361 = predecessors2[preorder2[j]]
362 [predecessors2[preorder2[j]].size() - l2];
363
364 //===============================================================================
365 // If both trees not empty, find optimal edit operation
366
367 //---------------------------------------------------------------------------
368 // If both trees only have one branch, return edit cost between
369 // the two branches
370 if(tree1->getNumberOfChildren(curr1) == 0
371 and tree2->getNumberOfChildren(curr2) == 0) {
372 memT[curr1 + l1 * dim2 + curr2 * dim3 + l2 * dim4]
373 = this->baseMetric_ == 0 ? editCost_Wasserstein1<dataType>(
374 curr1, parent1, curr2, parent2, tree1, tree2)
375 : this->baseMetric_ == 1 ? editCost_Wasserstein2<dataType>(
376 curr1, parent1, curr2, parent2, tree1, tree2)
377 : this->baseMetric_ == 2
378 ? editCost_Persistence<dataType>(
379 curr1, parent1, curr2, parent2, tree1, tree2)
380 : editCost_Shifting<dataType>(
381 curr1, parent1, curr2, parent2, tree1, tree2);
382 }
383 //---------------------------------------------------------------------------
384 // If first tree only has one branch, try all decompositions of
385 // second tree
386 else if(children1.size() == 0) {
387 dataType d = std::numeric_limits<dataType>::max();
388 for(auto child2_mb : children2) {
389 dataType d_ = memT[curr1 + l1 * dim2 + child2_mb * dim3
390 + (l2 + 1) * dim4];
391 for(auto child2 : children2) {
392 if(child2 == child2_mb) {
393 continue;
394 }
395 d_ += memT[nn1 + 0 * dim2 + child2 * dim3 + 1 * dim4];
396 }
397 d = std::min(d, d_);
398 }
399 memT[curr1 + l1 * dim2 + curr2 * dim3 + l2 * dim4] = d;
400 }
401 //---------------------------------------------------------------------------
402 // If second tree only has one branch, try all decompositions of
403 // first tree
404 else if(children2.size() == 0) {
405 dataType d = std::numeric_limits<dataType>::max();
406 for(auto child1_mb : children1) {
407 dataType d_ = memT[child1_mb + (l1 + 1) * dim2 + curr2 * dim3
408 + l2 * dim4];
409 for(auto child1 : children1) {
410 if(child1 == child1_mb) {
411 continue;
412 }
413 d_ += memT[child1 + 1 * dim2 + nn2 * dim3 + 0 * dim4];
414 }
415 d = std::min(d, d_);
416 }
417 memT[curr1 + l1 * dim2 + curr2 * dim3 + l2 * dim4] = d;
418 }
419 //---------------------------------------------------------------------------
420 // If both trees have more than one branch, try all decompositions
421 // of both trees
422 else {
423 dataType d = std::numeric_limits<dataType>::max();
424 //-----------------------------------------------------------------------
425 // Try all possible main branches of first tree (child1_mb) and
426 // all possible main branches of second tree (child2_mb) Then
427 // try all possible matchings of subtrees
428 if(children1.size() == 2 && children2.size() == 2) {
429 int const child11 = children1[0];
430 int const child12 = children1[1];
431 int const child21 = children2[0];
432 int const child22 = children2[1];
433 d = std::min<dataType>(
434 d,
435 memT[child11 + (l1 + 1) * dim2 + child21 * dim3
436 + (l2 + 1) * dim4]
437 + memT[child12 + 1 * dim2 + child22 * dim3 + 1 * dim4]);
438 d = std::min<dataType>(
439 d,
440 memT[child12 + (l1 + 1) * dim2 + child22 * dim3
441 + (l2 + 1) * dim4]
442 + memT[child11 + 1 * dim2 + child21 * dim3 + 1 * dim4]);
443 d = std::min<dataType>(
444 d,
445 memT[child11 + (l1 + 1) * dim2 + child22 * dim3
446 + (l2 + 1) * dim4]
447 + memT[child12 + 1 * dim2 + child21 * dim3 + 1 * dim4]);
448 d = std::min<dataType>(
449 d,
450 memT[child12 + (l1 + 1) * dim2 + child21 * dim3
451 + (l2 + 1) * dim4]
452 + memT[child11 + 1 * dim2 + child22 * dim3 + 1 * dim4]);
453 } else {
454 for(auto child1_mb : children1) {
455 auto topo1_ = children1;
456 topo1_.erase(
457 std::remove(topo1_.begin(), topo1_.end(), child1_mb),
458 topo1_.end());
459 for(auto child2_mb : children2) {
460 auto topo2_ = children2;
461 topo2_.erase(
462 std::remove(topo2_.begin(), topo2_.end(), child2_mb),
463 topo2_.end());
464
465 auto f = [&](unsigned r, unsigned c) {
466 int const c1 = r < topo1_.size() ? topo1_[r] : -1;
467 int const c2 = c < topo2_.size() ? topo2_[c] : -1;
468 return memT[c1 + 1 * dim2 + c2 * dim3 + 1 * dim4];
469 };
470 int size = std::max(topo1_.size(), topo2_.size()) + 1;
471 auto costMatrix = std::vector<std::vector<dataType>>(
472 size, std::vector<dataType>(size, 0));
473 std::vector<MatchingType> matching;
474 for(int r = 0; r < size; r++) {
475 for(int c = 0; c < size; c++) {
476 costMatrix[r][c] = f(r, c);
477 }
478 }
479
480 AssignmentSolver<dataType> *assignmentSolver;
481 AssignmentExhaustive<dataType> solverExhaustive;
482 AssignmentMunkres<dataType> solverMunkres;
483 AssignmentAuction<dataType> solverAuction;
484 switch(assignmentSolverID_) {
485 case 1:
486 solverExhaustive = AssignmentExhaustive<dataType>();
487 assignmentSolver = &solverExhaustive;
488 break;
489 case 2:
490 solverMunkres = AssignmentMunkres<dataType>();
491 assignmentSolver = &solverMunkres;
492 break;
493 case 0:
494 default:
495 solverAuction = AssignmentAuction<dataType>();
496 assignmentSolver = &solverAuction;
497 }
498 assignmentSolver->setInput(costMatrix);
499 assignmentSolver->setBalanced(true);
500 assignmentSolver->run(matching);
501 dataType d_ = memT[child1_mb + (l1 + 1) * dim2
502 + child2_mb * dim3 + (l2 + 1) * dim4];
503 for(auto m : matching)
504 d_ += std::get<2>(m);
505 d = std::min(d, d_);
506 }
507 }
508 }
509 //-----------------------------------------------------------------------
510 // Try to continue main branch on one child of first tree and
511 // delete all other subtrees Then match continued branch to
512 // current branch in second tree
513 for(auto child1_mb : children1) {
514 dataType d_ = memT[child1_mb + (l1 + 1) * dim2 + curr2 * dim3
515 + l2 * dim4];
516 for(auto child1 : children1) {
517 if(child1 == child1_mb) {
518 continue;
519 }
520 d_ += memT[child1 + 1 * dim2 + nn2 * dim3 + 0 * dim4];
521 }
522 d = std::min(d, d_);
523 }
524 //-----------------------------------------------------------------------
525 // Try to continue main branch on one child of second tree and
526 // delete all other subtrees Then match continued branch to
527 // current branch in first tree
528 for(auto child2_mb : children2) {
529 dataType d_ = memT[curr1 + l1 * dim2 + child2_mb * dim3
530 + (l2 + 1) * dim4];
531 for(auto child2 : children2) {
532 if(child2 == child2_mb) {
533 continue;
534 }
535 d_ += memT[nn1 + 0 * dim2 + child2 * dim3 + 1 * dim4];
536 }
537 d = std::min(d, d_);
538 }
539 memT[curr1 + l1 * dim2 + curr2 * dim3 + l2 * dim4] = d;
540 }
541 }
542 }
543 }
544 }
545
546 std::vector<ftm::idNode> children1;
547 tree1->getChildren(rootID1, children1);
548 std::vector<ftm::idNode> children2;
549 tree2->getChildren(rootID2, children2);
550
551 dataType res
552 = memT[children1[0] + 1 * dim2 + children2[0] * dim3 + 1 * dim4];
553
554 return squared_ ? std::sqrt(res) : res;
555 }
556 };
557
558} // namespace ttk
virtual int run(std::vector< MatchingType > &matchings)=0
virtual int setInput(std::vector< std::vector< dataType > > &C_)
virtual void setBalanced(bool balanced)
dataType editDistance_branch(ftm::FTMTree_MT *tree1, ftm::FTMTree_MT *tree2)
~BranchMappingDistance() override=default
void setAssignmentSolver(int assignmentSolver)
Minimalist debugging class.
Definition Debug.h:88
void setDebugMsgPrefix(const std::string &prefix)
Definition Debug.h:364
const scalarType & getValue(SimplexId nodeId) const
Definition FTMTree_MT.h:339
idNode getNumberOfNodes() const
Definition FTMTree_MT.h:389
int getNumberOfChildren(idNode nodeId)
void getChildren(idNode nodeId, std::vector< idNode > &res)
The Topology ToolKit.
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:483