TTK
Loading...
Searching...
No Matches
PersistenceDiagramClustering.cpp
Go to the documentation of this file.
2
4 std::vector<DiagramType> &intermediateDiagrams,
5 std::vector<double> &weights,
6 DiagramType &barycenter,
7 std::vector<std::vector<MatchingType>> &matchings,
8 const ttk::Debug &dbg,
9 const bool ProgBarycenter) {
10
11 std::vector<DiagramType> final_centroids;
12 std::vector<std::vector<std::vector<MatchingType>>> all_matchings;
13
14 bool useCustomWeights = true;
15
16 if(weights.size() == 0) {
17 dbg.printMsg("Uniform weights");
18 useCustomWeights = false;
19 } else if(weights.size() != intermediateDiagrams.size()) {
20 dbg.printErr("Number of weights != Number of inputs");
21 dbg.printErr("Defaulting to uniform barycenter");
22 useCustomWeights = false;
23 } else {
24 std::stringstream msg;
25 double sum = 0.0;
26 // msg << "Weights: ";
27 for(auto w : weights) {
28 // msg << " " << w;
29 sum += w;
30 }
31 // msg << " | sum: " << sum;
32 // dbg.printMsg(msg.str());
33 if(sum != 1) {
34 dbg.printWrn("Sum of weights different from 1");
35 // dbg.printErr("Defaulting to uniform barycenter");
36 // useCustomWeights = false;
37 }
38 }
39
40 PersistenceDiagramClustering barycenterComputer{};
41 barycenterComputer.setForceUseOfAlgorithm(true);
42 barycenterComputer.setUseCustomWeights(useCustomWeights);
43 barycenterComputer.setUseInterruptible(true);
44 barycenterComputer.setUseProgressive(ProgBarycenter);
45 barycenterComputer.setCustomWeights(&weights);
46 barycenterComputer.setDebugLevel(0);
47 barycenterComputer.setTimeLimit(10);
48 barycenterComputer.setThreadNumber(1);
49 barycenterComputer.execute(
50 intermediateDiagrams, final_centroids, all_matchings);
51 barycenter = std::move(final_centroids[0]);
52 matchings = std::move(all_matchings[0]);
53}
54
56 std::vector<DiagramType> &intermediateDiagrams,
57 std::vector<DiagramType> &final_centroids,
58 std::vector<std::vector<std::vector<MatchingType>>> &all_matchings) {
59
60 const int numberOfInputs_ = intermediateDiagrams.size();
61 Timer tm;
62
63 printMsg("Clustering " + std::to_string(numberOfInputs_) + " diagrams in "
64 + std::to_string(NumberOfClusters) + " cluster(s).");
65
66 std::vector<DiagramType> data_min(numberOfInputs_);
67 std::vector<DiagramType> data_sad(numberOfInputs_);
68 std::vector<DiagramType> data_max(numberOfInputs_);
69
70 std::vector<std::vector<int>> data_min_idx(numberOfInputs_);
71 std::vector<std::vector<int>> data_sad_idx(numberOfInputs_);
72 std::vector<std::vector<int>> data_max_idx(numberOfInputs_);
73
74 std::vector<int> inv_clustering(numberOfInputs_);
75
76 bool do_min = false;
77 bool do_sad = false;
78 bool do_max = false;
79
80 // Create diagrams for min, saddle and max persistence pairs
81 for(int i = 0; i < numberOfInputs_; i++) {
82 DiagramType &CTDiagram = intermediateDiagrams[i];
83
84 for(size_t j = 0; j < CTDiagram.size(); ++j) {
85 auto &t = CTDiagram[j];
86
87 ttk::CriticalType const nt1 = t.birth.type;
88 ttk::CriticalType const nt2 = t.death.type;
89
90 double const dt = t.persistence();
91 // if (abs<double>(dt) < zeroThresh) continue;
92 if(dt > 0) {
95 if(PairTypeClustering == 2) {
96 data_max[i].push_back(t);
97 data_max_idx[i].push_back(j);
98 do_max = true;
99 } else {
100 data_min[i].push_back(t);
101 data_min_idx[i].push_back(j);
102 do_min = true;
103 }
104 } else {
107 data_max[i].push_back(t);
108 data_max_idx[i].push_back(j);
109 do_max = true;
110 }
113 data_min[i].push_back(t);
114 data_min_idx[i].push_back(j);
115 do_min = true;
116 }
120 && nt2 == ttk::CriticalType::Saddle1)) {
121 data_sad[i].push_back(t);
122 data_sad_idx[i].push_back(j);
123 do_sad = true;
124 }
125 }
126 }
127 }
128 }
129
130 std::stringstream msg;
131 switch(PairTypeClustering) {
132 case(0):
133 msg << "Only MIN-SAD Pairs";
134 do_max = false;
135 do_sad = false;
136 break;
137 case(1):
138 msg << "Only SAD-SAD Pairs";
139 do_max = false;
140 do_min = false;
141 break;
142 case(2):
143 msg << "Only SAD-MAX Pairs";
144 do_min = false;
145 do_sad = false;
146 break;
147 default:
148 msg << "All critical pairs: "
149 "global clustering";
150 break;
151 }
152 printMsg(msg.str());
153
154 std::vector<std::vector<std::vector<std::vector<MatchingType>>>>
155 all_matchings_per_type_and_cluster;
156 PDClustering KMeans{};
157 KMeans.setNumberOfInputs(numberOfInputs_);
158 KMeans.setWasserstein(WassersteinMetric);
159 KMeans.setUseProgressive(UseProgressive);
160 KMeans.setAccelerated(UseAccelerated);
161 KMeans.setUseKDTree(true);
162 KMeans.setTimeLimit(TimeLimit);
163 KMeans.setGeometricalFactor(Alpha);
164 KMeans.setLambda(Lambda);
165 KMeans.setDeterministic(Deterministic);
166 KMeans.setForceUseOfAlgorithm(ForceUseOfAlgorithm);
167 KMeans.setDebugLevel(debugLevel_);
168 KMeans.setDeltaLim(DeltaLim);
169 KMeans.setUseDeltaLim(UseAdditionalPrecision);
170 KMeans.setDistanceWritingOptions(DistanceWritingOptions);
171 KMeans.setKMeanspp(UseKmeansppInit);
172 KMeans.setK(NumberOfClusters);
173 KMeans.setDiagrams(&data_min, &data_sad, &data_max);
174 KMeans.setDos(do_min, do_sad, do_max);
175 KMeans.setNonMatchingWeight(NonMatchingWeight);
176
177 KMeans.setUseCustomWeights(UseCustomWeights);
178 KMeans.setCustomWeights(CustomWeights);
179
180 inv_clustering
181 = KMeans.execute(final_centroids, all_matchings_per_type_and_cluster);
182 std::vector<std::vector<int>> centroids_sizes = KMeans.get_centroids_sizes();
183
184 this->distances = KMeans.getDistances();
185
187 //
188 std::vector<int> cluster_size;
189 std::vector<int> idxInCluster(numberOfInputs_);
190
191 for(int j = 0; j < numberOfInputs_; ++j) {
192 size_t const c = inv_clustering[j];
193 if(c + 1 > cluster_size.size()) {
194 cluster_size.resize(c + 1);
195 cluster_size[c] = 1;
196 idxInCluster[j] = 0;
197 } else {
198 cluster_size[c]++;
199 idxInCluster[j] = cluster_size[c] - 1;
200 }
201 }
202
203 bool removeDuplicateGlobalPair = false;
204 if(NumberOfClusters > 1 and do_min and do_max) {
205 removeDuplicateGlobalPair = true;
206 }
207
208 all_matchings.resize(NumberOfClusters);
209 for(int c = 0; c < NumberOfClusters; c++) {
210 all_matchings[c].resize(numberOfInputs_);
211 if(removeDuplicateGlobalPair) {
212 centroids_sizes[c][0] -= 1;
213 }
214 }
215 for(int i = 0; i < numberOfInputs_; i++) {
216 size_t const c = inv_clustering[i];
217
218 if(do_min) {
219 for(size_t j = 0;
220 j < all_matchings_per_type_and_cluster[c][0][idxInCluster[i]].size();
221 j++) {
223 = all_matchings_per_type_and_cluster[c][0][idxInCluster[i]][j];
224 int const bidder_id = std::get<0>(t);
225 if(bidder_id < (int)data_min[i].size()) {
226 if(bidder_id < 0) { // matching with a diagonal point
227 std::get<0>(t) = -1;
228 } else {
229 std::get<0>(t) = data_min_idx[i][bidder_id];
230 }
231 // cout<<" IDS : "<<bidder_id<<" "<<std::get<0>(t)<<endl;
232 if(std::get<1>(t) < 0) {
233 std::get<1>(t) = -1;
234 }
235 all_matchings[inv_clustering[i]][i].push_back(t);
236 }
237 }
238 }
239
240 if(do_sad) {
241 for(size_t j = 0;
242 j < all_matchings_per_type_and_cluster[c][1][idxInCluster[i]].size();
243 j++) {
245 = all_matchings_per_type_and_cluster[c][1][idxInCluster[i]][j];
246 int const bidder_id = std::get<0>(t);
247 if(bidder_id < (int)data_sad[i].size()) {
248 if(bidder_id < 0) { // matching with a diagonal point
249 std::get<0>(t) = -1;
250 } else {
251 std::get<0>(t) = data_sad_idx[i][bidder_id];
252 }
253 if(std::get<1>(t) >= 0) {
254 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0];
255 } else {
256 std::get<1>(t) = -1;
257 }
258 all_matchings[inv_clustering[i]][i].push_back(t);
259 }
260 }
261 }
262
263 if(do_max) {
264 for(size_t j = 0;
265 j < all_matchings_per_type_and_cluster[c][2][idxInCluster[i]].size();
266 j++) {
268 = all_matchings_per_type_and_cluster[c][2][idxInCluster[i]][j];
269 int const bidder_id = std::get<0>(t);
270 if(bidder_id < (int)data_max[i].size()) {
271 if(bidder_id < 0) { // matching with a diagonal point
272 std::get<0>(t) = -1;
273 } else {
274 std::get<0>(t) = data_max_idx[i][bidder_id];
275 }
276 // std::get<0>(t) = data_max_idx[i][bidder_id];
277 if(std::get<1>(t) > 0) {
278 std::get<1>(t)
279 = std::get<1>(t) + centroids_sizes[c][0] + centroids_sizes[c][1];
280 } else if(std::get<1>(t) == 0) {
281 if(!removeDuplicateGlobalPair) {
282 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0]
283 + centroids_sizes[c][1];
284 }
285 } else {
286 std::get<1>(t) = -1;
287 }
288 all_matchings[inv_clustering[i]][i].push_back(t);
289 }
290 }
291 }
292 }
293
294 printMsg("Complete", 1, tm.getElapsedTime(), threadNumber_);
295 return inv_clustering;
296}
Minimalist debugging class.
Definition Debug.h:88
int debugLevel_
Definition Debug.h:379
int printWrn(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:159
int printMsg(const std::string &msg, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
Definition Debug.h:118
int printErr(const std::string &msg, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cerr) const
Definition Debug.h:149
int setNumberOfInputs(int numberOfInputs)
TTK processing package for the computation of Wasserstein barycenters and K-Means clusterings of a set of persistence diagrams.
std::vector< int > execute(std::vector< DiagramType > &intermediateDiagrams, std::vector< DiagramType > &centroids, std::vector< std::vector< std::vector< MatchingType > > > &all_matchings)
void setForceUseOfAlgorithm(bool forceUseOfAlgorithm)
double getElapsedTime()
Definition Timer.h:15
CriticalType
default value for critical index
Definition DataTypes.h:88
void computeWeightedBarycenter(std::vector< DiagramType > &intermediateDiagrams, std::vector< double > &weights, DiagramType &barycenter, std::vector< std::vector< MatchingType > > &matchings, const ttk::Debug &dbg, const bool ProgBarycenter)
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.
std::vector< PersistencePair > DiagramType
Persistence Diagram type as a vector of Persistence pairs.
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/| (_) |"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)