4 std::vector<DiagramType> &intermediateDiagrams,
5 std::vector<double> &weights,
7 std::vector<std::vector<MatchingType>> &matchings,
9 const bool ProgBarycenter) {
11 std::vector<DiagramType> final_centroids;
12 std::vector<std::vector<std::vector<MatchingType>>> all_matchings;
14 bool useCustomWeights =
true;
16 if(weights.size() == 0) {
18 useCustomWeights =
false;
19 }
else if(weights.size() != intermediateDiagrams.size()) {
20 dbg.
printErr(
"Number of weights != Number of inputs");
21 dbg.
printErr(
"Defaulting to uniform barycenter");
22 useCustomWeights =
false;
24 std::stringstream msg;
27 for(
auto w : weights) {
34 dbg.
printWrn(
"Sum of weights different from 1");
42 barycenterComputer.setUseCustomWeights(useCustomWeights);
43 barycenterComputer.setUseInterruptible(
true);
44 barycenterComputer.setUseProgressive(ProgBarycenter);
45 barycenterComputer.setCustomWeights(&weights);
46 barycenterComputer.setDebugLevel(0);
47 barycenterComputer.setTimeLimit(10);
48 barycenterComputer.setThreadNumber(1);
49 barycenterComputer.execute(
50 intermediateDiagrams, final_centroids, all_matchings);
51 barycenter = std::move(final_centroids[0]);
52 matchings = std::move(all_matchings[0]);
56 std::vector<DiagramType> &intermediateDiagrams,
57 std::vector<DiagramType> &final_centroids,
58 std::vector<std::vector<std::vector<MatchingType>>> &all_matchings) {
60 const int numberOfInputs_ = intermediateDiagrams.size();
63 printMsg(
"Clustering " + std::to_string(numberOfInputs_) +
" diagrams in "
66 std::vector<DiagramType> data_min(numberOfInputs_);
67 std::vector<DiagramType> data_sad(numberOfInputs_);
68 std::vector<DiagramType> data_max(numberOfInputs_);
70 std::vector<std::vector<int>> data_min_idx(numberOfInputs_);
71 std::vector<std::vector<int>> data_sad_idx(numberOfInputs_);
72 std::vector<std::vector<int>> data_max_idx(numberOfInputs_);
74 std::vector<int> inv_clustering(numberOfInputs_);
81 for(
int i = 0; i < numberOfInputs_; i++) {
84 for(
size_t j = 0; j < CTDiagram.size(); ++j) {
85 auto &t = CTDiagram[j];
90 double const dt = t.persistence();
96 data_max[i].push_back(t);
97 data_max_idx[i].push_back(j);
100 data_min[i].push_back(t);
101 data_min_idx[i].push_back(j);
107 data_max[i].push_back(t);
108 data_max_idx[i].push_back(j);
113 data_min[i].push_back(t);
114 data_min_idx[i].push_back(j);
121 data_sad[i].push_back(t);
122 data_sad_idx[i].push_back(j);
130 std::stringstream msg;
133 msg <<
"Only MIN-SAD Pairs";
138 msg <<
"Only SAD-SAD Pairs";
143 msg <<
"Only SAD-MAX Pairs";
148 msg <<
"All critical pairs: "
154 std::vector<std::vector<std::vector<std::vector<MatchingType>>>>
155 all_matchings_per_type_and_cluster;
161 KMeans.setUseKDTree(
true);
163 KMeans.setGeometricalFactor(
Alpha);
173 KMeans.setDiagrams(&data_min, &data_sad, &data_max);
174 KMeans.setDos(do_min, do_sad, do_max);
181 = KMeans.execute(final_centroids, all_matchings_per_type_and_cluster);
182 std::vector<std::vector<int>> centroids_sizes = KMeans.get_centroids_sizes();
188 std::vector<int> cluster_size;
189 std::vector<int> idxInCluster(numberOfInputs_);
191 for(
int j = 0; j < numberOfInputs_; ++j) {
192 size_t const c = inv_clustering[j];
193 if(c + 1 > cluster_size.size()) {
194 cluster_size.resize(c + 1);
199 idxInCluster[j] = cluster_size[c] - 1;
203 bool removeDuplicateGlobalPair =
false;
205 removeDuplicateGlobalPair =
true;
210 all_matchings[c].resize(numberOfInputs_);
211 if(removeDuplicateGlobalPair) {
212 centroids_sizes[c][0] -= 1;
215 for(
int i = 0; i < numberOfInputs_; i++) {
216 size_t const c = inv_clustering[i];
220 j < all_matchings_per_type_and_cluster[c][0][idxInCluster[i]].size();
223 = all_matchings_per_type_and_cluster[c][0][idxInCluster[i]][j];
224 int const bidder_id = std::get<0>(t);
225 if(bidder_id < (
int)data_min[i].size()) {
229 std::get<0>(t) = data_min_idx[i][bidder_id];
232 if(std::get<1>(t) < 0) {
235 all_matchings[inv_clustering[i]][i].push_back(t);
242 j < all_matchings_per_type_and_cluster[c][1][idxInCluster[i]].size();
245 = all_matchings_per_type_and_cluster[c][1][idxInCluster[i]][j];
246 int const bidder_id = std::get<0>(t);
247 if(bidder_id < (
int)data_sad[i].size()) {
251 std::get<0>(t) = data_sad_idx[i][bidder_id];
253 if(std::get<1>(t) >= 0) {
254 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0];
258 all_matchings[inv_clustering[i]][i].push_back(t);
265 j < all_matchings_per_type_and_cluster[c][2][idxInCluster[i]].size();
268 = all_matchings_per_type_and_cluster[c][2][idxInCluster[i]][j];
269 int const bidder_id = std::get<0>(t);
270 if(bidder_id < (
int)data_max[i].size()) {
274 std::get<0>(t) = data_max_idx[i][bidder_id];
277 if(std::get<1>(t) > 0) {
279 = std::get<1>(t) + centroids_sizes[c][0] + centroids_sizes[c][1];
280 }
else if(std::get<1>(t) == 0) {
281 if(!removeDuplicateGlobalPair) {
282 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0]
283 + centroids_sizes[c][1];
288 all_matchings[inv_clustering[i]][i].push_back(t);
295 return inv_clustering;