TTK
Loading...
Searching...
No Matches
PersistenceDiagramClustering.cpp
Go to the documentation of this file.
2
// Clustering entry point: splits every input diagram into per-type
// sub-diagrams, runs PDClustering on them, then merges the per-type matchings
// back into `all_matchings` and returns the cluster assignment of each input.
// NOTE(review): this listing skips some numbered lines (3, 41-42, 53-54,
// 59-60, 65-67, 129, 165, 187, 210); the line carrying the return type and
// function name, several `if` conditions and the `t` declarations are
// therefore not visible here.
//
// intermediateDiagrams: input diagrams, one per input (read only here).
// final_centroids:      forwarded to KMeans.execute().
// all_matchings:        resized to [NumberOfClusters][numberOfInputs_] and
//                       filled with the remapped matchings.
4 std::vector<DiagramType> &intermediateDiagrams,
5 std::vector<DiagramType> &final_centroids,
6 std::vector<std::vector<std::vector<MatchingType>>> &all_matchings) {
7
8 const int numberOfInputs_ = intermediateDiagrams.size();
9 Timer tm;
10
11 printMsg("Clustering " + std::to_string(numberOfInputs_) + " diagrams in "
12 + std::to_string(NumberOfClusters) + " cluster(s).");
13
// Per-input sub-diagrams, one set per pair category (min / sad / max).
14 std::vector<DiagramType> data_min(numberOfInputs_);
15 std::vector<DiagramType> data_sad(numberOfInputs_);
16 std::vector<DiagramType> data_max(numberOfInputs_);
17
// data_X_idx[i][k] is the index, in intermediateDiagrams[i], of the pair
// stored at data_X[i][k]; used below to translate matchings back.
18 std::vector<std::vector<int>> data_min_idx(numberOfInputs_);
19 std::vector<std::vector<int>> data_sad_idx(numberOfInputs_);
20 std::vector<std::vector<int>> data_max_idx(numberOfInputs_);
21
// Filled from KMeans.execute(); indexed by input, value used as cluster id.
22 std::vector<int> inv_clustering(numberOfInputs_);
23
// Set to true as soon as at least one pair lands in the matching category.
24 bool do_min = false;
25 bool do_sad = false;
26 bool do_max = false;
27
28 // Create diagrams for min, saddle and max persistence pairs
29 for(int i = 0; i < numberOfInputs_; i++) {
30 DiagramType &CTDiagram = intermediateDiagrams[i];
31
32 for(size_t j = 0; j < CTDiagram.size(); ++j) {
33 auto &t = CTDiagram[j];
34
35 ttk::CriticalType nt1 = t.birth.type;
36 ttk::CriticalType nt2 = t.death.type;
37
38 double dt = t.persistence();
39 // if (abs<double>(dt) < zeroThresh) continue;
// Only pairs with strictly positive persistence are dispatched.
// NOTE(review): the conditions choosing between the branches below sit on
// listing lines that are missing here (41-42, 53-54, 59-60, 65-67);
// presumably they test nt1/nt2 -- confirm against the real file.
40 if(dt > 0) {
43 if(PairTypeClustering == 2) {
44 data_max[i].push_back(t);
45 data_max_idx[i].push_back(j);
46 do_max = true;
47 } else {
48 data_min[i].push_back(t);
49 data_min_idx[i].push_back(j);
50 do_min = true;
51 }
52 } else {
55 data_max[i].push_back(t);
56 data_max_idx[i].push_back(j);
57 do_max = true;
58 }
61 data_min[i].push_back(t);
62 data_min_idx[i].push_back(j);
63 do_min = true;
64 }
68 && nt2 == ttk::CriticalType::Saddle1)) {
69 data_sad[i].push_back(t);
70 data_sad_idx[i].push_back(j);
71 do_sad = true;
72 }
73 }
74 }
75 }
76 }
77
// PairTypeClustering 0/1/2 restricts the run to a single category by
// clearing the two other flags; any other value leaves the flags untouched.
78 std::stringstream msg;
79 switch(PairTypeClustering) {
80 case(0):
81 msg << "Only MIN-SAD Pairs";
82 do_max = false;
83 do_sad = false;
84 break;
85 case(1):
86 msg << "Only SAD-SAD Pairs";
87 do_max = false;
88 do_min = false;
89 break;
90 case(2):
91 msg << "Only SAD-MAX Pairs";
92 do_min = false;
93 do_sad = false;
94 break;
95 default:
96 msg << "All critical pairs: "
97 "global clustering";
98 break;
99 }
100 printMsg(msg.str());
101
// Indexed below as [cluster][category 0/1/2][rank of input in cluster][k].
102 std::vector<std::vector<std::vector<std::vector<MatchingType>>>>
103 all_matchings_per_type_and_cluster;
// Forward this object's settings to the PDClustering instance, then run it.
104 PDClustering KMeans{};
105 KMeans.setNumberOfInputs(numberOfInputs_);
106 KMeans.setWasserstein(WassersteinMetric);
107 KMeans.setUseProgressive(UseProgressive);
108 KMeans.setAccelerated(UseAccelerated);
109 KMeans.setUseKDTree(true);
110 KMeans.setTimeLimit(TimeLimit);
111 KMeans.setGeometricalFactor(Alpha);
112 KMeans.setLambda(Lambda);
113 KMeans.setDeterministic(Deterministic);
114 KMeans.setForceUseOfAlgorithm(ForceUseOfAlgorithm);
115 KMeans.setDebugLevel(debugLevel_);
116 KMeans.setDeltaLim(DeltaLim);
117 KMeans.setUseDeltaLim(UseAdditionalPrecision);
118 KMeans.setDistanceWritingOptions(DistanceWritingOptions);
119 KMeans.setKMeanspp(UseKmeansppInit);
120 KMeans.setK(NumberOfClusters);
121 KMeans.setDiagrams(&data_min, &data_sad, &data_max);
122 KMeans.setDos(do_min, do_sad, do_max);
123 inv_clustering
124 = KMeans.execute(final_centroids, all_matchings_per_type_and_cluster);
// centroids_sizes[c][k] is used below as the index offset contributed by
// category k of cluster c (presumably the pair count of that centroid part).
125 std::vector<std::vector<int>> centroids_sizes = KMeans.get_centroids_sizes();
126
127 this->distances = KMeans.getDistances();
128
130 //
// cluster_size[c]: number of inputs seen so far in cluster c.
// idxInCluster[j]: rank of input j among the inputs of its own cluster,
// counted in input order.
131 std::vector<int> cluster_size;
132 std::vector<int> idxInCluster(numberOfInputs_);
133
134 for(int j = 0; j < numberOfInputs_; ++j) {
135 size_t c = inv_clustering[j];
136 if(c + 1 > cluster_size.size()) {
// First time a cluster id this large is met: grow the table (new slots are
// zero-initialised by resize) and register input j as its first member.
137 cluster_size.resize(c + 1);
138 cluster_size[c] = 1;
139 idxInCluster[j] = 0;
140 } else {
141 cluster_size[c]++;
142 idxInCluster[j] = cluster_size[c] - 1;
143 }
144 }
145
// Enabled when several clusters are requested and both min and max
// categories are processed.
// NOTE(review): presumably one pair is shared by the min and max parts of
// each centroid and must be counted once -- confirm.
146 bool removeDuplicateGlobalPair = false;
147 if(NumberOfClusters > 1 and do_min and do_max) {
148 removeDuplicateGlobalPair = true;
149 }
150
151 all_matchings.resize(NumberOfClusters);
152 for(int c = 0; c < NumberOfClusters; c++) {
153 all_matchings[c].resize(numberOfInputs_);
154 if(removeDuplicateGlobalPair) {
// Shrinks the min-part offset by one for every cluster.
155 centroids_sizes[c][0] -= 1;
156 }
157 }
// Merge the per-category matchings of every input into
// all_matchings[cluster][input]. In each tuple, element 0 is translated from
// a sub-diagram index to the index in the original input diagram, and
// element 1 is shifted by the sizes of the preceding centroid categories.
// Negative ids are normalised to -1. Matchings whose element 0 is not below
// the sub-diagram size are dropped.
158 for(int i = 0; i < numberOfInputs_; i++) {
159 size_t c = inv_clustering[i];
160
161 if(do_min) {
162 for(size_t j = 0;
163 j < all_matchings_per_type_and_cluster[c][0][idxInCluster[i]].size();
164 j++) {
// NOTE(review): the declaration of `t` (listing line 165) is missing here.
166 = all_matchings_per_type_and_cluster[c][0][idxInCluster[i]][j];
167 int bidder_id = std::get<0>(t);
168 if(bidder_id < (int)data_min[i].size()) {
169 if(bidder_id < 0) { // matching with a diagonal point
170 std::get<0>(t) = -1;
171 } else {
172 std::get<0>(t) = data_min_idx[i][bidder_id];
173 }
174 // cout<<" IDS : "<<bidder_id<<" "<<std::get<0>(t)<<endl;
// Min category comes first: non-negative element 1 needs no offset.
175 if(std::get<1>(t) < 0) {
176 std::get<1>(t) = -1;
177 }
178 all_matchings[inv_clustering[i]][i].push_back(t);
179 }
180 }
181 }
182
183 if(do_sad) {
184 for(size_t j = 0;
185 j < all_matchings_per_type_and_cluster[c][1][idxInCluster[i]].size();
186 j++) {
// NOTE(review): the declaration of `t` (listing line 187) is missing here.
188 = all_matchings_per_type_and_cluster[c][1][idxInCluster[i]][j];
189 int bidder_id = std::get<0>(t);
190 if(bidder_id < (int)data_sad[i].size()) {
191 if(bidder_id < 0) { // matching with a diagonal point
192 std::get<0>(t) = -1;
193 } else {
194 std::get<0>(t) = data_sad_idx[i][bidder_id];
195 }
// Saddle category: shift element 1 past the (possibly reduced) min part.
196 if(std::get<1>(t) >= 0) {
197 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0];
198 } else {
199 std::get<1>(t) = -1;
200 }
201 all_matchings[inv_clustering[i]][i].push_back(t);
202 }
203 }
204 }
205
206 if(do_max) {
207 for(size_t j = 0;
208 j < all_matchings_per_type_and_cluster[c][2][idxInCluster[i]].size();
209 j++) {
// NOTE(review): the declaration of `t` (listing line 210) is missing here.
211 = all_matchings_per_type_and_cluster[c][2][idxInCluster[i]][j];
212 int bidder_id = std::get<0>(t);
213 if(bidder_id < (int)data_max[i].size()) {
214 if(bidder_id < 0) { // matching with a diagonal point
215 std::get<0>(t) = -1;
216 } else {
217 std::get<0>(t) = data_max_idx[i][bidder_id];
218 }
219 // std::get<0>(t) = data_max_idx[i][bidder_id];
// Max category: shift element 1 past the min and saddle parts. Element 1
// equal to 0 is left at 0 when removeDuplicateGlobalPair is set, and is
// offset like the others when it is not.
220 if(std::get<1>(t) > 0) {
221 std::get<1>(t)
222 = std::get<1>(t) + centroids_sizes[c][0] + centroids_sizes[c][1];
223 } else if(std::get<1>(t) == 0) {
224 if(!removeDuplicateGlobalPair) {
225 std::get<1>(t) = std::get<1>(t) + centroids_sizes[c][0]
226 + centroids_sizes[c][1];
227 }
228 } else {
229 std::get<1>(t) = -1;
230 }
231 all_matchings[inv_clustering[i]][i].push_back(t);
232 }
233 }
234 }
235 }
236
237 printMsg("Complete", 1, tm.getElapsedTime(), threadNumber_);
// Cluster id assigned to each input, in input order.
238 return inv_clustering;
239}
int threadNumber_
Definition: BaseClass.h:95
int debugLevel_
Definition: Debug.h:379
int printMsg(const std::string &msg, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
Definition: Debug.h:118
int setNumberOfInputs(int numberOfInputs)
Definition: PDClustering.h:144
std::vector< int > execute(std::vector< DiagramType > &intermediateDiagrams, std::vector< DiagramType > &centroids, std::vector< std::vector< std::vector< MatchingType > > > &all_matchings)
double getElapsedTime()
Definition: Timer.h:15
CriticalType
default value for critical index
Definition: DataTypes.h:80
std::vector< PersistencePair > DiagramType
Persistence Diagram type as a vector of Persistence pairs.
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.