TTK
Loading...
Searching...
No Matches
AssignmentMunkresImpl.h
Go to the documentation of this file.
1#pragma once
2
3#include <AssignmentMunkres.h>
4
5#include <tuple>
6
7template <typename dataType>
9 std::vector<MatchingType> &matchings) {
10 int step = 1;
11 int iter = 0;
12 int maxIter = 100000;
13 bool done = false;
14 Timer t;
15
16 std::vector<std::vector<dataType>> inputMatrix(
17 this->rowSize, std::vector<dataType>(this->colSize));
18 copyInputMatrix(inputMatrix);
19
20 while(!done) {
21 ++iter;
22 this->printMsg(
23 "Step " + std::to_string(step) + ", Iteration " + std::to_string(iter),
24 debug::Priority::DETAIL);
25
26 if(iter > 20 && (iter % (int)std::round((double)maxIter / 5.0) == 0)) {
27 double progress = std::round(100.0 * (double)iter / (double)maxIter);
28 this->printMsg("Progress", progress / 100.0, t.getElapsedTime());
29 }
30
31 if(iter > maxIter) {
32 // showCostMatrix<dataType>();
33 // showMaskMatrix();
34
35 this->printMsg("Failed to converge after " + std::to_string(maxIter)
36 + " iterations. Aborting.");
37
38 step = 7;
39 // Abort. Still found something
40 // though not optimal.
41 }
42
43 // Show intermediary matrices:
44 // showCostMatrix<dataType>();
45 // showMaskMatrix();
46
47 switch(step) {
48 case 1:
49 stepOne(step);
50 break;
51 case 2:
52 stepTwo(step);
53 break;
54 case 3:
55 stepThree(step);
56 break;
57 case 4:
58 stepFour(step);
59 break;
60 case 5:
61 stepFive(step);
62 break;
63 case 6:
64 stepSix(step);
65 break;
66 case 7:
67 stepSeven(step);
68 done = true;
69 break;
70 default:
71 break;
72 }
73 }
74
75 this->computeAffectationCost(inputMatrix);
76 this->affect(matchings, inputMatrix);
77 this->clear();
78
79 return 0;
80}
81
82// Preprocess cost matrix.
83template <typename dataType>
84int ttk::AssignmentMunkres<dataType>::stepOne(int &step) // ~ 0% perf
85{
86 double minInCol;
87 std::vector<std::vector<dataType>> *C
89
90 // Benefit from the matrix sparsity.
91 dataType maxVal = std::numeric_limits<dataType>::max();
92 for(int r = 0; r < this->rowSize - 1; ++r) {
93 rowLimitsPlus[r] = -1;
94 rowLimitsMinus[r] = -1;
95 }
96 for(int c = 0; c < this->colSize - 1; ++c) {
97 colLimitsPlus[c] = -1;
98 colLimitsMinus[c] = -1;
99 }
100
101 int droppedMinus = 0;
102 int droppedPlus = 0;
103
104 for(int r = 0; r < this->rowSize - 1; ++r) {
105 for(int c = 0; c < this->colSize - 1; ++c)
106 if((*C)[r][c] != maxVal) {
107 rowLimitsMinus[r] = c; // Included
108 break;
109 }
110 if(rowLimitsMinus[r] == -1) {
111 ++droppedMinus;
112 rowLimitsMinus[r] = 0;
113 } // Included
114
115 for(int c = this->colSize - 2; c >= 0; --c)
116 if((*C)[r][c] != maxVal) {
117 rowLimitsPlus[r] = c + 1; // Not included
118 break;
119 }
120 if(rowLimitsPlus[r] == -1) {
121 ++droppedPlus;
122 rowLimitsPlus[r] = this->colSize - 1;
123 } // Not included
124 }
125
126 if(droppedMinus > 0) {
127 this->printMsg(
128 "Unexpected non-assignable row [minus], dropping optimisation for "
129 + std::to_string(droppedMinus) + " row(s).",
130 debug::Priority::DETAIL);
131 }
132
133 if(droppedPlus > 0) {
134 this->printMsg(
135 "Unexpected non-assignable row [plus], dropping optimisation for "
136 + std::to_string(droppedPlus) + " row(s).",
137 debug::Priority::DETAIL);
138 }
139
140 droppedMinus = 0;
141 droppedPlus = 0;
142
143 for(int c = 0; c < this->colSize - 1; ++c) {
144 for(int r = 0; r < this->rowSize - 1; ++r)
145 if((*C)[r][c] != maxVal) {
146 colLimitsMinus[c] = r; // Inclusive
147 break;
148 }
149 for(int r = this->rowSize - 1; r >= 0; --r)
150 if((*C)[r][c] != maxVal) {
151 colLimitsPlus[c] = r + 1; // Exclusive.
152 break;
153 }
154 if(colLimitsPlus[c] == -1) {
155 ++droppedPlus;
156 colLimitsMinus[c] = 0;
157 }
158 if(colLimitsMinus[c] == -1) {
159 ++droppedMinus;
160 colLimitsMinus[c] = this->rowSize;
161 }
162 }
163
164 if(droppedMinus > 0) {
165 this->printMsg(
166 "Unexpected non-assignable row [minus], dropping optimisation for "
167 + std::to_string(droppedMinus) + " row(s).",
168 debug::Priority::DETAIL);
169 }
170
171 if(droppedPlus > 0) {
172 this->printMsg(
173 "Unexpected non-assignable row [plus], dropping optimisation for "
174 + std::to_string(droppedPlus) + " row(s).",
175 debug::Priority::DETAIL);
176 }
177
178 rowLimitsMinus[this->rowSize - 1] = 0;
179 rowLimitsPlus[this->rowSize - 1] = this->colSize - 1;
180
181 // Remove last column (except the last element) from all other columns.
182 // The last column will then be ignored during the solving.
183 for(int r = 0; r < this->rowSize - 1; ++r) {
184 dataType lastElement = (*C)[r][this->colSize - 1];
185 for(int c = 0; c < this->colSize - 1; ++c) {
186 (*C)[r][c] -= lastElement;
187 }
188 }
189
190 // Subtract minimum value in every column except the last.
191 for(int c = 0; c < this->colSize - 1; ++c) {
192 minInCol = (*C)[0][c];
193
194 for(int r = 0; r < this->rowSize; ++r)
195 if((*C)[r][c] < minInCol)
196 minInCol = (*C)[r][c];
197
198 for(int r = 0; r < this->rowSize; ++r)
199 (*C)[r][c] -= minInCol;
200 }
201
202 step = 2;
203 return 0;
204}
205
206// Find a zero in the matrix,
207// star it if it is the only one in its row and col.
208template <typename dataType>
209int ttk::AssignmentMunkres<dataType>::stepTwo(int &step) // ~ 0% perf
210{
211 std::vector<std::vector<dataType>> *C
212 = AssignmentSolver<dataType>::getCostMatrixPointer();
213
214 for(int r = 0; r < this->rowSize - 1; ++r) {
215 for(int c = 0; c < this->colSize - 1; ++c) {
216 if(!rowCover[r] && !colCover[c] && isZero((*C)[r][c])) {
217 M[r][c] = 1;
218 // Temporarily cover row and column to find independent zeros.
219 rowCover[r] = true;
220 colCover[c] = true;
221 }
222 }
223
224 // Don't account for last column.
225 }
226
227 for(int c = 0; c < this->colSize - 1; ++c)
228 if(isZero((*C)[this->rowSize - 1][c]) && !colCover[c]) {
229 M[this->rowSize - 1][c] = 1;
230 // Don't ban last row where elements are all independent.
231 colCover[c] = true;
232 }
233
234 // Remove coverings (temporarily used to find independent zeros).
235 for(int r = 0; r < this->rowSize; ++r)
236 rowCover[r] = false;
237
238 for(int c = 0; c < this->colSize - 1; ++c)
239 colCover[c] = false;
240
241 step = 3;
242 return 0;
243}
244
245// Check column coverings.
246// If all columns are starred (1 star only per column is possible)
247// then the algorithm is terminated.
248template <typename dataType>
249int ttk::AssignmentMunkres<dataType>::stepThree(int &step) // ~ 10% perf
250{
251 for(int r = 0; r < this->rowSize; ++r) {
252 int start = rowLimitsMinus[r];
253 int end = rowLimitsPlus[r];
254 for(int c = start; c < end; ++c)
255 if(M[r][c] == 1)
256 colCover[c] = true;
257 }
258
259 int processedCols = 0;
260
261 for(int c = 0; c < this->colSize - 1; ++c)
262 if(colCover[c])
263 ++processedCols;
264
265 if(processedCols >= this->colSize - 1)
266 step = 7; // end algorithm
267 else
268 step = 4; // follow prime scheme
269 return 0;
270}
271
272// Find a non covered zero, prime it
273// . if current row is last or has no starred zero -> step 5
274// . else, cover row and uncover the col with a star
275// Repeat until there are no uncovered zero left
276// Save smallest uncovered value then -> step 6
277template <typename dataType>
278int ttk::AssignmentMunkres<dataType>::stepFour(int &step) // ~ 45% perf
279{
280 int row = -1;
281 int col = -1;
282 bool done = false;
283
284 while(!done) {
285 findZero(row, col);
286
287 if(row == -1) {
288 done = true;
289 step = 6;
290 }
291
292 else {
293 M[row][col] = 2;
294 int colOfStarInRow = findStarInRow(row);
295 // If a star was found and it is not in the last row
296 if(colOfStarInRow > -1 && row < this->rowSize - 1) {
297 rowCover[row] = true;
298 colCover[colOfStarInRow] = false;
299 }
300
301 else {
302 done = true;
303 step = 5;
304 pathRow0 = row;
305 pathCol0 = col;
306 }
307 }
308 }
309
310 return 0;
311}
312
// Return the column of the starred zero (M == 1) in `row`, searching only the
// row's [rowLimitsMinus[row], rowLimitsPlus[row]) span; -1 when there is none.
// NOTE(review): the line holding this function's signature is missing from
// this listing; the body matches the findStarInRow(row) call made in stepFour.
313template <typename dataType>
315 int start = rowLimitsMinus[row];
316 int end = rowLimitsPlus[row];
317 for(int c = start; c < end; ++c)
318 if(M[row][c] == 1)
319 return c;
320 return -1;
321}
322
323template <typename dataType>
324int ttk::AssignmentMunkres<dataType>::findZero(int &row, int &col) {
325 auto *C = AssignmentSolver<dataType>::getCostMatrixPointer();
326
327 row = -1;
328 col = -1;
329
330 while(createdZeros.size() > 0) {
331 std::pair<int, int> zero = createdZeros.back();
332 int f = zero.first;
333 int s = zero.second;
334 createdZeros.pop_back();
335 if(!rowCover[f] && !colCover[s]) {
336 row = f;
337 col = s;
338 return 0;
339 }
340 }
341
342 for(int r = 0; r < this->rowSize; ++r) {
343 int start = rowLimitsMinus[r];
344 int end = rowLimitsPlus[r];
345 if(rowCover[r])
346 continue;
347
348 for(int c = start; c < end; ++c) {
349 if(colCover[c])
350 continue;
351 if((*C)[r][c] == (dataType)0) {
352 row = r;
353 col = c;
354 return 0;
355 }
356 }
357 }
358
359 this->printMsg("Zero not found.", debug::Priority::DETAIL);
360
361 return 0;
362}
363
364// Make path of alternating primed and starred zeros
365// 1. uncovered primed found at step 4
366// 2. same column, starred (if any)
367// 3. same row, primed (always one)
368// 4. continue until a primed zero has no starred zero in its column
369// Unstar each starred zero in the series, star each primed zero
370// in the series,
371// erase all primes, uncover every line, return to step 3.
372template <typename dataType>
373int ttk::AssignmentMunkres<dataType>::stepFive(int &step) // ~ 10% perf
374{
375 {
376 int r;
377 int c;
378
379 pathCount = 1;
380 path[pathCount - 1][0] = pathRow0;
381 path[pathCount - 1][1] = pathCol0;
382
383 bool done = false;
384 while(!done) {
385 r = findStarInCol(path[pathCount - 1][1]);
386 if(r == -1)
387 done = true;
388
389 else {
390 ++pathCount;
391 path[pathCount - 1][0] = r;
392 path[pathCount - 1][1] = path[pathCount - 2][1];
393
394 c = findPrimeInRow(path[pathCount - 1][0]);
395 if(c == -1) {
396 this->printWrn("Did not find an expected prime.");
397 }
398 ++pathCount;
399 path[pathCount - 1][0] = path[pathCount - 2][0];
400 path[pathCount - 1][1] = c;
401 }
402 }
403 }
404
405 // process path
406 for(int p = 0; p < pathCount; ++p) {
407 if(M[path[p][0]][path[p][1]] == 1)
408 M[path[p][0]][path[p][1]] = 0;
409 else
410 M[path[p][0]][path[p][1]] = 1;
411 }
412
413 // clear covers
414 for(int r = 0; r < this->rowSize; ++r)
415 rowCover[r] = false;
416 for(int c = 0; c < this->colSize - 1; ++c)
417 colCover[c] = false;
418
419 // erase primes
420 for(int r = 0; r < this->rowSize; ++r) {
421 int start = rowLimitsMinus[r];
422 int end = rowLimitsPlus[r];
423 for(int c = start; c < end; ++c)
424 if(M[r][c] == 2)
425 M[r][c] = 0;
426 }
427
428 step = 3;
429 return 0;
430}
431
// Return the row of the starred zero (M == 1) in column `col`, or -1 when the
// column holds none. Rows in [colLimitsMinus[col], colLimitsPlus[col]) are
// searched first; the last row is then checked explicitly.
// NOTE(review): the line holding this function's signature is missing from
// this listing; the body matches the findStarInCol(...) call made in stepFive.
432template <typename dataType>
434 int start = colLimitsMinus[col];
435 int end = colLimitsPlus[col];
436 for(int r = start; r < end; ++r)
437 if(M[r][col] == 1)
438 return r;
439
440 if(M[this->rowSize - 1][col] == 1)
441 return (this->rowSize - 1);
442 return -1;
443}
444
// Return the column of the primed zero (M == 2) in `row`, searching only the
// row's [rowLimitsMinus[row], rowLimitsPlus[row]) span; -1 when there is none.
// NOTE(review): the line holding this function's signature is missing from
// this listing; the body matches the findPrimeInRow(...) call made in stepFive.
445template <typename dataType>
447 int start = rowLimitsMinus[row];
448 int end = rowLimitsPlus[row];
449 for(int c = start; c < end; ++c)
450 if(M[row][c] == 2)
451 return c;
452 return -1;
453}
454
455// Add smallest value to every element of each covered row,
456// subtract it from every element of each uncovered col.
457// Return to step 4 without altering any stars/primes/covers.
458template <typename dataType>
459int ttk::AssignmentMunkres<dataType>::stepSix(int &step) // ~ 35% perf
460{
461 auto *C = AssignmentSolver<dataType>::getCostMatrixPointer();
462
463 dataType minVal = std::numeric_limits<dataType>::max();
464
465 // find smallest
466 for(int r = 0; r < this->rowSize; ++r) {
467 if(rowCover[r])
468 continue;
469
470 int start = rowLimitsMinus[r];
471 int end = rowLimitsPlus[r];
472
473 for(int c = start; c < end; ++c) {
474 if(colCover[c])
475 continue;
476 if((*C)[r][c] < minVal)
477 minVal = (*C)[r][c];
478 }
479 }
480
481 createdZeros.clear();
482
483 // add and subtract
484 for(int r = 0; r < this->rowSize; ++r) {
485
486 int start = rowLimitsMinus[r];
487 int end = rowLimitsPlus[r];
488
489 for(int c = start; c < end; ++c) {
490 if(rowCover[r])
491 (*C)[r][c] = (*C)[r][c] + minVal;
492 if(!colCover[c]) {
493 (*C)[r][c] = (*C)[r][c] - minVal;
494 if(isZero((*C)[r][c])) {
495 createdZeros.emplace_back(r, c);
496 }
497 }
498 }
499 }
500
501 step = 4;
502 return 0;
503}
504
// Step 7: final step of the state machine; it only emits a DETAIL message and
// returns 0. run() stops looping right after calling it.
// NOTE(review): the line holding this function's signature is missing from
// this listing (the footer's ttkNotUsed reference suggests the unused `step`
// parameter is wrapped in it — confirm).
505template <typename dataType>
507 this->printMsg("Step 7 over.", debug::Priority::DETAIL);
508 return 0;
509}
510
// Turn the star matrix M into the output `matchings` (cleared first).
// Every starred cell (M[r][c] == 1) yields a matching (r, c, C[r][c]). Every
// row except the last that holds no star is matched to the last column, as
// (r, nbC - 1, C[r][nbC - 1]). run() passes its untouched copy of the input
// costs as C. rowCover is used as scratch storage to mark starred rows and is
// reset to false for those rows before returning. Returns 0.
// NOTE(review): the line holding this function's signature is missing from
// this listing; the parameters match the this->affect(matchings, inputMatrix)
// call in run().
511template <typename dataType>
513 std::vector<MatchingType> &matchings,
514 const std::vector<std::vector<dataType>> &C) {
515 int nbC = this->colSize;
516 int nbR = this->rowSize;
517
518 matchings.clear();
519
520 for(int r = 0; r < nbR; ++r)
521 for(int c = 0; c < nbC; ++c)
522 if(M[r][c] == 1) {
523 matchings.emplace_back(r, c, C[r][c]);
524 // Use row cover to match to last column diagonal.
525 if(r < nbR - 1)
526 rowCover[r] = true;
527 }
528
529 // Match unstarred rows to the last column, and clear the covers set above.
530 for(int r = 0; r < nbR - 1; ++r) {
531 // Match to diagonal.
532 if(!rowCover[r]) {
533 matchings.emplace_back(r, nbC - 1, C[r][nbC - 1]);
534 }
535 // Ensure row covers are cleared.
536 else {
537 rowCover[r] = false;
538 }
539 }
540
541 return 0;
542}
543
// Sum C[r][c] over all starred cells (M[r][c] == 1), print the total through
// printMsg, and return 0. run() passes its untouched copy of the input costs
// as C.
// NOTE(review): the rows that affect() matches to the last column (rows
// without a star) are not starred, so their cost seems to be excluded from
// this total — confirm whether that is intended.
// NOTE(review): the line holding this function's signature is missing from
// this listing.
544template <typename dataType>
546 const std::vector<std::vector<dataType>> &C) {
547 int nbC = this->colSize;
548 int nbR = this->rowSize;
549
550 dataType total = 0;
551
552 for(int r = 0; r < nbR; ++r)
553 for(int c = 0; c < nbC; ++c)
554 if(M[r][c] == 1) {
555 total += C[r][c];
556 }
557
558 this->printMsg("Total cost: " + std::to_string(total));
559
560 return 0;
561}
#define ttkNotUsed(x)
Mark function/method parameters that are not used in the function body at all.
Definition: BaseClass.h:47
int run(std::vector< MatchingType > &matchings) override
double getElapsedTime()
Definition: Timer.h:15
printMsg(debug::output::GREEN+"                           "+debug::output::ENDCOLOR+debug::output::GREEN+"▒"+debug::output::ENDCOLOR+debug::output::GREEN+"▒▒▒▒▒▒▒▒▒▒▒▒▒░"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)