TTK
Loading...
Searching...
No Matches
AssignmentMunkresImpl.h
Go to the documentation of this file.
1#pragma once
2
3#include <AssignmentMunkres.h>
4
5#include <tuple>
6
7template <typename dataType>
9 std::vector<MatchingType> &matchings) {
10 int step = 1;
11 int iter = 0;
12 const int maxIter = 100000;
13 bool done = false;
14 Timer t;
15
16 std::vector<std::vector<dataType>> inputMatrix(
17 this->rowSize, std::vector<dataType>(this->colSize));
18 copyInputMatrix(inputMatrix);
19
20 while(!done) {
21 ++iter;
22 this->printMsg(
23 "Step " + std::to_string(step) + ", Iteration " + std::to_string(iter),
24 debug::Priority::DETAIL);
25
26 if(iter > 20 && (iter % (int)std::round((double)maxIter / 5.0) == 0)) {
27 double const progress
28 = std::round(100.0 * (double)iter / (double)maxIter);
29 this->printMsg("Progress", progress / 100.0, t.getElapsedTime());
30 }
31
32 if(iter > maxIter) {
33 // showCostMatrix<dataType>();
34 // showMaskMatrix();
35
36 this->printMsg("Failed to converge after " + std::to_string(maxIter)
37 + " iterations. Aborting.");
38
39 step = 7;
40 // Abort. Still found something
41 // though not optimal.
42 }
43
44 // Show intermediary matrices:
45 // showCostMatrix<dataType>();
46 // showMaskMatrix();
47
48 switch(step) {
49 case 1:
50 stepOne(step);
51 break;
52 case 2:
53 stepTwo(step);
54 break;
55 case 3:
56 stepThree(step);
57 break;
58 case 4:
59 stepFour(step);
60 break;
61 case 5:
62 stepFive(step);
63 break;
64 case 6:
65 stepSix(step);
66 break;
67 case 7:
68 stepSeven(step);
69 done = true;
70 break;
71 default:
72 break;
73 }
74 }
75
76 this->computeAffectationCost(inputMatrix);
77 this->affect(matchings, inputMatrix);
78 this->clear();
79
80 return 0;
81}
82
83// Preprocess cost matrix.
84template <typename dataType>
85int ttk::AssignmentMunkres<dataType>::stepOne(int &step) // ~ 0% perf
86{
87 double minInCol;
88 std::vector<std::vector<dataType>> *C
90
91 // Benefit from the matrix sparsity.
92 dataType maxVal = std::numeric_limits<dataType>::max();
93 for(int r = 0; r < this->rowSize - 1; ++r) {
94 rowLimitsPlus[r] = -1;
95 rowLimitsMinus[r] = -1;
96 }
97 for(int c = 0; c < this->colSize - 1; ++c) {
98 colLimitsPlus[c] = -1;
99 colLimitsMinus[c] = -1;
100 }
101
102 int droppedMinus = 0;
103 int droppedPlus = 0;
104
105 for(int r = 0; r < this->rowSize - 1; ++r) {
106 for(int c = 0; c < this->colSize - 1; ++c)
107 if((*C)[r][c] != maxVal) {
108 rowLimitsMinus[r] = c; // Included
109 break;
110 }
111 if(rowLimitsMinus[r] == -1) {
112 ++droppedMinus;
113 rowLimitsMinus[r] = 0;
114 } // Included
115
116 for(int c = this->colSize - 2; c >= 0; --c)
117 if((*C)[r][c] != maxVal) {
118 rowLimitsPlus[r] = c + 1; // Not included
119 break;
120 }
121 if(rowLimitsPlus[r] == -1) {
122 ++droppedPlus;
123 rowLimitsPlus[r] = this->colSize - 1;
124 } // Not included
125 }
126
127 if(droppedMinus > 0) {
128 this->printMsg(
129 "Unexpected non-assignable row [minus], dropping optimisation for "
130 + std::to_string(droppedMinus) + " row(s).",
131 debug::Priority::DETAIL);
132 }
133
134 if(droppedPlus > 0) {
135 this->printMsg(
136 "Unexpected non-assignable row [plus], dropping optimisation for "
137 + std::to_string(droppedPlus) + " row(s).",
138 debug::Priority::DETAIL);
139 }
140
141 droppedMinus = 0;
142 droppedPlus = 0;
143
144 for(int c = 0; c < this->colSize - 1; ++c) {
145 for(int r = 0; r < this->rowSize - 1; ++r)
146 if((*C)[r][c] != maxVal) {
147 colLimitsMinus[c] = r; // Inclusive
148 break;
149 }
150 for(int r = this->rowSize - 1; r >= 0; --r)
151 if((*C)[r][c] != maxVal) {
152 colLimitsPlus[c] = r + 1; // Exclusive.
153 break;
154 }
155 if(colLimitsPlus[c] == -1) {
156 ++droppedPlus;
157 colLimitsMinus[c] = 0;
158 }
159 if(colLimitsMinus[c] == -1) {
160 ++droppedMinus;
161 colLimitsMinus[c] = this->rowSize;
162 }
163 }
164
165 if(droppedMinus > 0) {
166 this->printMsg(
167 "Unexpected non-assignable row [minus], dropping optimisation for "
168 + std::to_string(droppedMinus) + " row(s).",
169 debug::Priority::DETAIL);
170 }
171
172 if(droppedPlus > 0) {
173 this->printMsg(
174 "Unexpected non-assignable row [plus], dropping optimisation for "
175 + std::to_string(droppedPlus) + " row(s).",
176 debug::Priority::DETAIL);
177 }
178
179 rowLimitsMinus[this->rowSize - 1] = 0;
180 rowLimitsPlus[this->rowSize - 1] = this->colSize - 1;
181
182 // Remove last column (except the last element) from all other columns.
183 // The last column will then be ignored during the solving.
184 for(int r = 0; r < this->rowSize - 1; ++r) {
185 dataType lastElement = (*C)[r][this->colSize - 1];
186 for(int c = 0; c < this->colSize - 1; ++c) {
187 (*C)[r][c] -= lastElement;
188 }
189 }
190
191 // Subtract minimum value in every column except the last.
192 for(int c = 0; c < this->colSize - 1; ++c) {
193 minInCol = (*C)[0][c];
194
195 for(int r = 0; r < this->rowSize; ++r)
196 if((*C)[r][c] < minInCol)
197 minInCol = (*C)[r][c];
198
199 for(int r = 0; r < this->rowSize; ++r)
200 (*C)[r][c] -= minInCol;
201 }
202
203 step = 2;
204 return 0;
205}
206
207// Find a zero in the matrix,
208// star it if it is the only one in its row and col.
209template <typename dataType>
210int ttk::AssignmentMunkres<dataType>::stepTwo(int &step) // ~ 0% perf
211{
212 std::vector<std::vector<dataType>> *C
213 = AssignmentSolver<dataType>::getCostMatrixPointer();
214
215 for(int r = 0; r < this->rowSize - 1; ++r) {
216 for(int c = 0; c < this->colSize - 1; ++c) {
217 if(!rowCover[r] && !colCover[c] && isZero((*C)[r][c])) {
218 M[r][c] = 1;
219 // Temporarily cover row and column to find independent zeros.
220 rowCover[r] = true;
221 colCover[c] = true;
222 }
223 }
224
225 // Don't account for last column.
226 }
227
228 for(int c = 0; c < this->colSize - 1; ++c)
229 if(isZero((*C)[this->rowSize - 1][c]) && !colCover[c]) {
230 M[this->rowSize - 1][c] = 1;
231 // Don't ban last row where elements are all independent.
232 colCover[c] = true;
233 }
234
235 // Remove coverings (temporarily used to find independent zeros).
236 for(int r = 0; r < this->rowSize; ++r)
237 rowCover[r] = false;
238
239 for(int c = 0; c < this->colSize - 1; ++c)
240 colCover[c] = false;
241
242 step = 3;
243 return 0;
244}
245
246// Check column coverings.
247// If all columns are starred (1 star only per column is possible)
248// then the algorithm is terminated.
249template <typename dataType>
250int ttk::AssignmentMunkres<dataType>::stepThree(int &step) // ~ 10% perf
251{
252 for(int r = 0; r < this->rowSize; ++r) {
253 const int start = rowLimitsMinus[r];
254 const int end = rowLimitsPlus[r];
255 for(int c = start; c < end; ++c)
256 if(M[r][c] == 1)
257 colCover[c] = true;
258 }
259
260 int processedCols = 0;
261
262 for(int c = 0; c < this->colSize - 1; ++c)
263 if(colCover[c])
264 ++processedCols;
265
266 if(processedCols >= this->colSize - 1)
267 step = 7; // end algorithm
268 else
269 step = 4; // follow prime scheme
270 return 0;
271}
272
273// Find a non covered zero, prime it
274// . if current row is last or has no starred zero -> step 5
275// . else, cover row and uncover the col with a star
276// Repeat until there are no uncovered zero left
277// Save smallest uncovered value then -> step 6
278template <typename dataType>
279int ttk::AssignmentMunkres<dataType>::stepFour(int &step) // ~ 45% perf
280{
281 int row = -1;
282 int col = -1;
283 bool done = false;
284
285 while(!done) {
286 findZero(row, col);
287
288 if(row == -1) {
289 done = true;
290 step = 6;
291 }
292
293 else {
294 M[row][col] = 2;
295 const int colOfStarInRow = findStarInRow(row);
296 // If a star was found and it is not in the last row
297 if(colOfStarInRow > -1 && row < this->rowSize - 1) {
298 rowCover[row] = true;
299 colCover[colOfStarInRow] = false;
300 }
301
302 else {
303 done = true;
304 step = 5;
305 pathRow0 = row;
306 pathCol0 = col;
307 }
308 }
309 }
310
311 return 0;
312}
313
314template <typename dataType>
316 const int start = rowLimitsMinus[row];
317 const int end = rowLimitsPlus[row];
318 for(int c = start; c < end; ++c)
319 if(M[row][c] == 1)
320 return c;
321 return -1;
322}
323
324template <typename dataType>
325int ttk::AssignmentMunkres<dataType>::findZero(int &row, int &col) {
326 auto *C = AssignmentSolver<dataType>::getCostMatrixPointer();
327
328 row = -1;
329 col = -1;
330
331 while(createdZeros.size() > 0) {
332 const std::pair<int, int> zero = createdZeros.back();
333 const int f = zero.first;
334 const int s = zero.second;
335 createdZeros.pop_back();
336 if(!rowCover[f] && !colCover[s]) {
337 row = f;
338 col = s;
339 return 0;
340 }
341 }
342
343 for(int r = 0; r < this->rowSize; ++r) {
344 const int start = rowLimitsMinus[r];
345 const int end = rowLimitsPlus[r];
346 if(rowCover[r])
347 continue;
348
349 for(int c = start; c < end; ++c) {
350 if(colCover[c])
351 continue;
352 if((*C)[r][c] == (dataType)0) {
353 row = r;
354 col = c;
355 return 0;
356 }
357 }
358 }
359
360 this->printMsg("Zero not found.", debug::Priority::DETAIL);
361
362 return 0;
363}
364
365// Make path of alternating primed and starred zeros
366// 1. uncovered primed found at step 4
367// 2. same column, starred (if any)
368// 3. same row, primed (always one)
369// 4. continue until a primed zero has no starred zero in its column
370// Unstar each starred zero in the series, star each primed zero
371// in the series,
372// erase all primes, uncover every line, return to step 3.
373template <typename dataType>
374int ttk::AssignmentMunkres<dataType>::stepFive(int &step) // ~ 10% perf
375{
376 {
377 int r;
378 int c;
379
380 pathCount = 1;
381 path[pathCount - 1][0] = pathRow0;
382 path[pathCount - 1][1] = pathCol0;
383
384 bool done = false;
385 while(!done) {
386 r = findStarInCol(path[pathCount - 1][1]);
387 if(r == -1)
388 done = true;
389
390 else {
391 ++pathCount;
392 path[pathCount - 1][0] = r;
393 path[pathCount - 1][1] = path[pathCount - 2][1];
394
395 c = findPrimeInRow(path[pathCount - 1][0]);
396 if(c == -1) {
397 this->printWrn("Did not find an expected prime.");
398 }
399 ++pathCount;
400 path[pathCount - 1][0] = path[pathCount - 2][0];
401 path[pathCount - 1][1] = c;
402 }
403 }
404 }
405
406 // process path
407 for(int p = 0; p < pathCount; ++p) {
408 if(M[path[p][0]][path[p][1]] == 1)
409 M[path[p][0]][path[p][1]] = 0;
410 else
411 M[path[p][0]][path[p][1]] = 1;
412 }
413
414 // clear covers
415 for(int r = 0; r < this->rowSize; ++r)
416 rowCover[r] = false;
417 for(int c = 0; c < this->colSize - 1; ++c)
418 colCover[c] = false;
419
420 // erase primes
421 for(int r = 0; r < this->rowSize; ++r) {
422 const int start = rowLimitsMinus[r];
423 const int end = rowLimitsPlus[r];
424 for(int c = start; c < end; ++c)
425 if(M[r][c] == 2)
426 M[r][c] = 0;
427 }
428
429 step = 3;
430 return 0;
431}
432
433template <typename dataType>
435 const int start = colLimitsMinus[col];
436 const int end = colLimitsPlus[col];
437 for(int r = start; r < end; ++r)
438 if(M[r][col] == 1)
439 return r;
440
441 if(M[this->rowSize - 1][col] == 1)
442 return (this->rowSize - 1);
443 return -1;
444}
445
446template <typename dataType>
448 const int start = rowLimitsMinus[row];
449 const int end = rowLimitsPlus[row];
450 for(int c = start; c < end; ++c)
451 if(M[row][c] == 2)
452 return c;
453 return -1;
454}
455
456// Add smallest value to every element of each covered row,
457// subtract it from every element of each uncovered col.
458// Return to step 4 without altering any stars/primes/covers.
459template <typename dataType>
460int ttk::AssignmentMunkres<dataType>::stepSix(int &step) // ~ 35% perf
461{
462 auto *C = AssignmentSolver<dataType>::getCostMatrixPointer();
463
464 dataType minVal = std::numeric_limits<dataType>::max();
465
466 // find smallest
467 for(int r = 0; r < this->rowSize; ++r) {
468 if(rowCover[r])
469 continue;
470
471 const int start = rowLimitsMinus[r];
472 const int end = rowLimitsPlus[r];
473
474 for(int c = start; c < end; ++c) {
475 if(colCover[c])
476 continue;
477 if((*C)[r][c] < minVal)
478 minVal = (*C)[r][c];
479 }
480 }
481
482 createdZeros.clear();
483
484 // add and subtract
485 for(int r = 0; r < this->rowSize; ++r) {
486
487 const int start = rowLimitsMinus[r];
488 const int end = rowLimitsPlus[r];
489
490 for(int c = start; c < end; ++c) {
491 if(rowCover[r])
492 (*C)[r][c] = (*C)[r][c] + minVal;
493 if(!colCover[c]) {
494 (*C)[r][c] = (*C)[r][c] - minVal;
495 if(isZero((*C)[r][c])) {
496 createdZeros.emplace_back(r, c);
497 }
498 }
499 }
500 }
501
502 step = 4;
503 return 0;
504}
505
506template <typename dataType>
508 this->printMsg("Step 7 over.", debug::Priority::DETAIL);
509 return 0;
510}
511
512template <typename dataType>
514 std::vector<MatchingType> &matchings,
515 const std::vector<std::vector<dataType>> &C) {
516 const int nbC = this->colSize;
517 const int nbR = this->rowSize;
518
519 matchings.clear();
520
521 for(int r = 0; r < nbR; ++r)
522 for(int c = 0; c < nbC; ++c)
523 if(M[r][c] == 1) {
524 matchings.emplace_back(r, c, C[r][c]);
525 // Use row cover to match to last column diagonal.
526 if(r < nbR - 1)
527 rowCover[r] = true;
528 }
529
530 // Clear row cover
531 for(int r = 0; r < nbR - 1; ++r) {
532 // Match to diagonal.
533 if(!rowCover[r]) {
534 matchings.emplace_back(r, nbC - 1, C[r][nbC - 1]);
535 }
536 // Ensure row covers are cleared.
537 else {
538 rowCover[r] = false;
539 }
540 }
541
542 return 0;
543}
544
545template <typename dataType>
547 const std::vector<std::vector<dataType>> &C) {
548 const int nbC = this->colSize;
549 const int nbR = this->rowSize;
550
551 dataType total = 0;
552
553 for(int r = 0; r < nbR; ++r)
554 for(int c = 0; c < nbC; ++c)
555 if(M[r][c] == 1) {
556 total += C[r][c];
557 }
558
559 this->printMsg("Total cost: " + std::to_string(total));
560
561 return 0;
562}
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition BaseClass.h:47
int run(std::vector< MatchingType > &matchings) override
double getElapsedTime()
Definition Timer.h:15
T end(std::pair< T, T > &p)
Definition ripserpy.cpp:483
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)