TTK
Loading...
Searching...
No Matches
PersistenceDiagramClustering.cpp
Go to the documentation of this file.
2
// Clustering routine for persistence diagrams.
//
// NOTE(review): this text is a documentation-page listing. The integer at the
// start of each line is the listing's own line number, not code, and the
// indentation has been stripped. Gaps in that numbering (3, 41-42, 53-54,
// 59-60, 65-67, 130, 166, 188, 211) mean source lines are missing here,
// including the function's name/return-type line and several branch
// conditions. Restore them from the real source before compiling.
//
// What the visible code does:
//  1. Splits every input diagram into per-category diagrams (min / saddle /
//     max), remembering for each copied pair its index in the input diagram.
//  2. Configures a PDClustering object from this object's settings and runs it.
//  3. Rewrites the per-type matchings returned by the clustering so that the
//     first tuple element refers to the original input diagram and the second
//     is offset by the sizes of the preceding centroid types, then stores them
//     in all_matchings[cluster][input].
//
// Parameters:
//  intermediateDiagrams  input diagrams; its size defines the number of inputs.
//  final_centroids       forwarded to KMeans.execute (presumably filled with
//                        the resulting centroids - confirm in PDClustering).
//  all_matchings         resized to NumberOfClusters x numberOfInputs_ and
//                        filled with the remapped matchings.
//
// Returns inv_clustering, the vector produced by KMeans.execute; its entry for
// input i is used below as the cluster index of that input.
4 std::vector<DiagramType> &intermediateDiagrams,
5 std::vector<DiagramType> &final_centroids,
6 std::vector<std::vector<std::vector<MatchingType>>> &all_matchings) {
7
8 const int numberOfInputs_ = intermediateDiagrams.size();
9 Timer tm;
10
11 printMsg("Clustering " + std::to_string(numberOfInputs_) + " diagrams in "
12 + std::to_string(NumberOfClusters) + " cluster(s).");
13
// One sub-diagram per input and per category.
14 std::vector<DiagramType> data_min(numberOfInputs_);
15 std::vector<DiagramType> data_sad(numberOfInputs_);
16 std::vector<DiagramType> data_max(numberOfInputs_);
17
// data_X_idx[i][k] is the index, in intermediateDiagrams[i], of the k-th pair
// stored in data_X[i]; used after clustering to map matching ids back.
18 std::vector<std::vector<int>> data_min_idx(numberOfInputs_);
19 std::vector<std::vector<int>> data_sad_idx(numberOfInputs_);
20 std::vector<std::vector<int>> data_max_idx(numberOfInputs_);
21
22 std::vector<int> inv_clustering(numberOfInputs_);
23
// Each flag becomes true as soon as one pair is stored in that category.
24 bool do_min = false;
25 bool do_sad = false;
26 bool do_max = false;
27
28 // Create diagrams for min, saddle and max persistence pairs
29 for(int i = 0; i < numberOfInputs_; i++) {
30 DiagramType &CTDiagram = intermediateDiagrams[i];
31
32 for(size_t j = 0; j < CTDiagram.size(); ++j) {
33 auto &t = CTDiagram[j];
34
// Critical types of the pair's birth and death points; they drive the
// (partly missing) classification conditions below.
35 ttk::CriticalType const nt1 = t.birth.type;
36 ttk::CriticalType const nt2 = t.death.type;
37
38 double const dt = t.persistence();
39 // if (abs<double>(dt) < zeroThresh) continue;
// Only pairs with strictly positive persistence enter the block below.
40 if(dt > 0) {
// NOTE(review): listing lines 41-42 are missing; the condition that opens
// this branch is not visible. Inside it, the pair goes to the max diagram
// when PairTypeClustering == 2 and to the min diagram otherwise.
43 if(PairTypeClustering == 2) {
44 data_max[i].push_back(t);
45 data_max_idx[i].push_back(j);
46 do_max = true;
47 } else {
48 data_min[i].push_back(t);
49 data_min_idx[i].push_back(j);
50 do_min = true;
51 }
52 } else {
// NOTE(review): listing lines 53-54 are missing (condition guarding the
// max insertion).
55 data_max[i].push_back(t);
56 data_max_idx[i].push_back(j);
57 do_max = true;
58 }
// NOTE(review): listing lines 59-60 are missing (condition guarding the
// min insertion).
61 data_min[i].push_back(t);
62 data_min_idx[i].push_back(j);
63 do_min = true;
64 }
// NOTE(review): listing lines 65-67 are missing; only the tail of the
// saddle condition (nt2 == Saddle1) is visible.
68 && nt2 == ttk::CriticalType::Saddle1)) {
69 data_sad[i].push_back(t);
70 data_sad_idx[i].push_back(j);
71 do_sad = true;
72 }
73 }
74 }
75 }
76 }
77
// PairTypeClustering values 0, 1 and 2 restrict the clustering to a single
// category by clearing the two other flags; any other value leaves the flags
// as computed above.
78 std::stringstream msg;
79 switch(PairTypeClustering) {
80 case(0):
81 msg << "Only MIN-SAD Pairs";
82 do_max = false;
83 do_sad = false;
84 break;
85 case(1):
86 msg << "Only SAD-SAD Pairs";
87 do_max = false;
88 do_min = false;
89 break;
90 case(2):
91 msg << "Only SAD-MAX Pairs";
92 do_min = false;
93 do_sad = false;
94 break;
95 default:
96 msg << "All critical pairs: "
97 "global clustering";
98 break;
99 }
100 printMsg(msg.str());
101
// Indexed below as [cluster][type 0/1/2][rank of the input in its cluster][k].
102 std::vector<std::vector<std::vector<std::vector<MatchingType>>>>
103 all_matchings_per_type_and_cluster;
// Forward this object's settings to the clustering backend. Note that the
// KD-tree option is hard-coded to true rather than read from a member.
104 PDClustering KMeans{};
105 KMeans.setNumberOfInputs(numberOfInputs_);
106 KMeans.setWasserstein(WassersteinMetric);
107 KMeans.setUseProgressive(UseProgressive);
108 KMeans.setAccelerated(UseAccelerated);
109 KMeans.setUseKDTree(true);
110 KMeans.setTimeLimit(TimeLimit);
111 KMeans.setGeometricalFactor(Alpha);
112 KMeans.setLambda(Lambda);
113 KMeans.setDeterministic(Deterministic);
114 KMeans.setForceUseOfAlgorithm(ForceUseOfAlgorithm);
115 KMeans.setDebugLevel(debugLevel_);
116 KMeans.setDeltaLim(DeltaLim);
117 KMeans.setUseDeltaLim(UseAdditionalPrecision);
118 KMeans.setDistanceWritingOptions(DistanceWritingOptions);
119 KMeans.setKMeanspp(UseKmeansppInit);
120 KMeans.setK(NumberOfClusters);
121 KMeans.setDiagrams(&data_min, &data_sad, &data_max);
122 KMeans.setDos(do_min, do_sad, do_max);
123 KMeans.setNonMatchingWeight(NonMatchingWeight);
124 inv_clustering
125 = KMeans.execute(final_centroids, all_matchings_per_type_and_cluster);
// centroids_sizes[c][0] and [c][1] are used below as id offsets for cluster c
// (presumably the min and saddle centroid sizes - confirm in PDClustering).
126 std::vector<std::vector<int>> centroids_sizes = KMeans.get_centroids_sizes();
127
128 this->distances = KMeans.getDistances();
129
// NOTE(review): listing line 130 is missing.
131 //
// For every input j, compute idxInCluster[j]: its rank (in input order) among
// the inputs assigned to the same cluster. cluster_size grows on demand;
// resize() value-initialises new counters to 0, so a cluster first seen after
// a higher-numbered one still starts at rank 0 through the else branch.
132 std::vector<int> cluster_size;
133 std::vector<int> idxInCluster(numberOfInputs_);
134
135 for(int j = 0; j < numberOfInputs_; ++j) {
136 size_t const c = inv_clustering[j];
137 if(c + 1 > cluster_size.size()) {
138 cluster_size.resize(c + 1);
139 cluster_size[c] = 1;
140 idxInCluster[j] = 0;
141 } else {
142 cluster_size[c]++;
143 idxInCluster[j] = cluster_size[c] - 1;
144 }
145 }
146
// Enabled only when several clusters are requested and both the min and max
// categories are active. When enabled, the first size entry of every cluster
// is reduced by one (which shifts the saddle/max offsets applied below) and a
// max matching whose second id is 0 keeps id 0 instead of being offset.
// Presumably this merges a pair present in both the min and max centroids -
// TODO confirm against PDClustering.
147 bool removeDuplicateGlobalPair = false;
148 if(NumberOfClusters > 1 and do_min and do_max) {
149 removeDuplicateGlobalPair = true;
150 }
151
152 all_matchings.resize(NumberOfClusters);
153 for(int c = 0; c < NumberOfClusters; c++) {
154 all_matchings[c].resize(numberOfInputs_);
155 if(removeDuplicateGlobalPair) {
156 centroids_sizes[c][0] -= 1;
157 }
158 }
// Remap the matchings of every input i (assigned to cluster c) and append
// them to all_matchings[c][i], category by category.
159 for(int i = 0; i < numberOfInputs_; i++) {
160 size_t const c = inv_clustering[i];
161
// Min category (type index 0): second id is kept as is when non-negative.
162 if(do_min) {
163 for(size_t j = 0;
164 j < all_matchings_per_type_and_cluster[c][0][idxInCluster[i]].size();
165 j++) {
// NOTE(review): listing line 166 (declaration of `t`) is missing.
167 = all_matchings_per_type_and_cluster[c][0][idxInCluster[i]][j];
168 int const bidder_id = std::get<0>(t);
// Matchings whose first id is >= data_min[i].size() are not stored.
169 if(bidder_id < (int)data_min[i].size()) {
170 if(bidder_id < 0) { // matching with a diagonal point
171 std::get<0>(t) = -1;
172 } else {
// Translate the sub-diagram index into the input diagram index.
173 std::get<0>(t) = data_min_idx[i][bidder_id];
174 }
175 // cout<<" IDS : "<<bidder_id<<" "<<std::get<0>(t)<<endl;
// Any negative second id is normalised to -1.
176 if(std::get<1>(t) < 0) {
177 std::get<1>(t) = -1;
178 }
179 all_matchings[inv_clustering[i]][i].push_back(t);
180 }
181 }
182 }
183
// Saddle category (type index 1): non-negative second ids are shifted by
// centroids_sizes[c][0].
184 if(do_sad) {
185 for(size_t j = 0;
186 j < all_matchings_per_type_and_cluster[c][1][idxInCluster[i]].size();
187 j++) {
// NOTE(review): listing line 188 (declaration of `t`) is missing.
189 = all_matchings_per_type_and_cluster[c][1][idxInCluster[i]][j];
190 int const bidder_id = std::get<0>(t);
191 if(bidder_id < (int)data_sad[i].size()) {
192 if(bidder_id < 0) { // matching with a diagonal point
193 std::get<0>(t) = -1;
194 } else {
195 std::get<0>(t) = data_sad_idx[i][bidder_id];
196 }
197 if(std::get<1>(t) >= 0) {
198 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0];
199 } else {
200 std::get<1>(t) = -1;
201 }
202 all_matchings[inv_clustering[i]][i].push_back(t);
203 }
204 }
205 }
206
// Max category (type index 2): positive second ids are shifted by
// centroids_sizes[c][0] + centroids_sizes[c][1]; id 0 is shifted the same way
// only when removeDuplicateGlobalPair is false, otherwise it stays 0.
207 if(do_max) {
208 for(size_t j = 0;
209 j < all_matchings_per_type_and_cluster[c][2][idxInCluster[i]].size();
210 j++) {
// NOTE(review): listing line 211 (declaration of `t`) is missing.
212 = all_matchings_per_type_and_cluster[c][2][idxInCluster[i]][j];
213 int const bidder_id = std::get<0>(t);
214 if(bidder_id < (int)data_max[i].size()) {
215 if(bidder_id < 0) { // matching with a diagonal point
216 std::get<0>(t) = -1;
217 } else {
218 std::get<0>(t) = data_max_idx[i][bidder_id];
219 }
220 // std::get<0>(t) = data_max_idx[i][bidder_id];
221 if(std::get<1>(t) > 0) {
222 std::get<1>(t)
223 = std::get<1>(t) + centroids_sizes[c][0] + centroids_sizes[c][1];
224 } else if(std::get<1>(t) == 0) {
225 if(!removeDuplicateGlobalPair) {
226 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0]
227 + centroids_sizes[c][1];
228 }
229 } else {
230 std::get<1>(t) = -1;
231 }
232 all_matchings[inv_clustering[i]][i].push_back(t);
233 }
234 }
235 }
236 }
237
238 printMsg("Complete", 1, tm.getElapsedTime(), threadNumber_);
239 return inv_clustering;
240}
int debugLevel_
Definition Debug.h:379
int setNumberOfInputs(int numberOfInputs)
std::vector< int > execute(std::vector< DiagramType > &intermediateDiagrams, std::vector< DiagramType > &centroids, std::vector< std::vector< std::vector< MatchingType > > > &all_matchings)
double getElapsedTime()
Definition Timer.h:15
CriticalType
default value for critical index
Definition DataTypes.h:80
std::vector< PersistencePair > DiagramType
Persistence Diagram type as a vector of Persistence pairs.
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.
printMsg(debug::output::BOLD+" | | | | | . \\ | | (__| | / __/| |_| / __/|__ _|"+debug::output::ENDCOLOR, debug::Priority::PERFORMANCE, debug::LineMode::NEW, stream)