TTK
Loading...
Searching...
No Matches
PersistenceDiagramBarycenter.cpp
Go to the documentation of this file.
2
4 std::vector<DiagramType> &intermediateDiagrams,
5 DiagramType &barycenter,
6 std::vector<std::vector<std::vector<MatchingType>>> &all_matchings) {
7
8 Timer tm;
9
10 printMsg("Computing Barycenter of " + std::to_string(numberOfInputs_)
11 + " diagrams.");
12
13 std::vector<DiagramType> data_min(numberOfInputs_);
14 std::vector<DiagramType> data_sad(numberOfInputs_);
15 std::vector<DiagramType> data_max(numberOfInputs_);
16
17 std::vector<std::vector<int>> data_min_idx(numberOfInputs_);
18 std::vector<std::vector<int>> data_sad_idx(numberOfInputs_);
19 std::vector<std::vector<int>> data_max_idx(numberOfInputs_);
20
21 bool do_min = false;
22 bool do_sad = false;
23 bool do_max = false;
24
25 // Create diagrams for min, saddle and max persistence pairs
26 for(int i = 0; i < numberOfInputs_; i++) {
27 DiagramType &CTDiagram = intermediateDiagrams[i];
28
29 for(size_t j = 0; j < CTDiagram.size(); ++j) {
30 PersistencePair &t = CTDiagram[j];
31
34
35 double dt = t.persistence();
36 // if (abs<double>(dt) < zeroThresh) continue;
37 if(dt > 0) {
40 data_max[i].push_back(t);
41 data_max_idx[i].push_back(j);
42 do_max = true;
43 } else {
46 data_max[i].push_back(t);
47 data_max_idx[i].push_back(j);
48 do_max = true;
49 }
52 data_min[i].push_back(t);
53 data_min_idx[i].push_back(j);
54 do_min = true;
55 }
59 && nt2 == ttk::CriticalType::Saddle1)) {
60 data_sad[i].push_back(t);
61 data_sad_idx[i].push_back(j);
62 do_sad = true;
63 }
64 }
65 }
66 }
67 }
68
69 DiagramType barycenter_min;
70 DiagramType barycenter_sad;
71 DiagramType barycenter_max;
72
73 std::vector<std::vector<MatchingType>> matching_min, matching_sad,
74 matching_max;
75
76 double total_cost = 0, min_cost = 0, sad_cost = 0, max_cost = 0;
77 /*omp_set_num_threads(1);
78 #ifdef TTK_ENABLE_OPENMP
79 #pragma omp parallel sections
80 #endif
81 {
82 #ifdef TTK_ENABLE_OPENMP
83 #pragma omp section
84 #endif
85 {*/
86 if(do_min) {
87 printMsg("Computing Minima barycenter...");
88 PDBarycenter bary_min{};
90 bary_min.setWasserstein(wasserstein_);
91 bary_min.setNumberOfInputs(numberOfInputs_);
92 bary_min.setDiagramType(0);
93 bary_min.setUseProgressive(use_progressive_);
94 bary_min.setGeometricalFactor(alpha_);
95 bary_min.setDebugLevel(debugLevel_);
96 bary_min.setDeterministic(deterministic_);
97 bary_min.setLambda(lambda_);
98 bary_min.setMethod(method_);
99 bary_min.setEarlyStoppage(early_stoppage_);
100 bary_min.setEpsilonDecreases(epsilon_decreases_);
101 bary_min.setReinitPrices(reinit_prices_);
102 bary_min.setDiagrams(&data_min);
103 matching_min = bary_min.execute(barycenter_min);
104 min_cost = bary_min.getCost();
105 total_cost += min_cost;
106 }
107 /*}
108
109 #ifdef TTK_ENABLE_OPENMP
110 #pragma omp section
111 #endif
112 {*/
113 if(do_sad) {
114 printMsg("Computing Saddles barycenter...");
115 PDBarycenter bary_sad{};
117 bary_sad.setWasserstein(wasserstein_);
118 bary_sad.setNumberOfInputs(numberOfInputs_);
119 bary_sad.setDiagramType(1);
120 bary_sad.setUseProgressive(use_progressive_);
121 bary_sad.setGeometricalFactor(alpha_);
122 bary_sad.setLambda(lambda_);
123 bary_sad.setDebugLevel(debugLevel_);
124 bary_sad.setMethod(method_);
125 bary_sad.setEarlyStoppage(early_stoppage_);
126 bary_sad.setEpsilonDecreases(epsilon_decreases_);
127 bary_sad.setDeterministic(deterministic_);
128 bary_sad.setReinitPrices(reinit_prices_);
129 bary_sad.setDiagrams(&data_sad);
130 matching_sad = bary_sad.execute(barycenter_sad);
131 sad_cost = bary_sad.getCost();
132 total_cost += sad_cost;
133 }
134 /*}
135
136 #ifdef TTK_ENABLE_OPENMP
137 #pragma omp section
138 #endif
139 {*/
140 if(do_max) {
141 printMsg("Computing Maxima barycenter...");
142 PDBarycenter bary_max{};
144 bary_max.setWasserstein(wasserstein_);
145 bary_max.setNumberOfInputs(numberOfInputs_);
146 bary_max.setDiagramType(2);
147 bary_max.setUseProgressive(use_progressive_);
148 bary_max.setGeometricalFactor(alpha_);
149 bary_max.setLambda(lambda_);
150 bary_max.setMethod(method_);
151 bary_max.setDebugLevel(debugLevel_);
152 bary_max.setEarlyStoppage(early_stoppage_);
153 bary_max.setDeterministic(deterministic_);
154 bary_max.setEpsilonDecreases(epsilon_decreases_);
155 bary_max.setReinitPrices(reinit_prices_);
156 bary_max.setDiagrams(&data_max);
157 matching_max = bary_max.execute(barycenter_max);
158 max_cost = bary_max.getCost();
159 total_cost += max_cost;
160 }
161 //}
162 //}
163
164 // Reconstruct matchings
165 all_matchings.resize(1);
166 all_matchings[0].resize(numberOfInputs_);
167 for(int i = 0; i < numberOfInputs_; i++) {
168
169 if(do_min) {
170 for(size_t j = 0; j < matching_min[i].size(); j++) {
171 MatchingType t = matching_min[i][j];
172 int bidder_id = std::get<0>(t);
173 std::get<0>(t) = data_min_idx[i][bidder_id];
174 if(std::get<1>(t) < 0) {
175 std::get<1>(t) = -1;
176 }
177 all_matchings[0][i].push_back(t);
178 }
179 }
180
181 if(do_sad) {
182 for(size_t j = 0; j < matching_sad[i].size(); j++) {
183 MatchingType t = matching_sad[i][j];
184 int bidder_id = std::get<0>(t);
185 std::get<0>(t) = data_sad_idx[i][bidder_id];
186 if(std::get<1>(t) >= 0) {
187 std::get<1>(t) = std::get<1>(t) + barycenter_min.size();
188 } else {
189 std::get<1>(t) = -1;
190 }
191 all_matchings[0][i].push_back(t);
192 }
193 }
194
195 if(do_max) {
196 for(size_t j = 0; j < matching_max[i].size(); j++) {
197 MatchingType t = matching_max[i][j];
198 int bidder_id = std::get<0>(t);
199 std::get<0>(t) = data_max_idx[i][bidder_id];
200 if(std::get<1>(t) >= 0) {
201 std::get<1>(t)
202 = std::get<1>(t) + barycenter_min.size() + barycenter_sad.size();
203 } else {
204 std::get<1>(t) = -1;
205 }
206 all_matchings[0][i].push_back(t);
207 }
208 }
209 }
210 // Reconstruct barcenter
211 for(size_t j = 0; j < barycenter_min.size(); j++) {
212 const auto &dt = barycenter_min[j];
213 barycenter.push_back(dt);
214 }
215 for(size_t j = 0; j < barycenter_sad.size(); j++) {
216 const auto &dt = barycenter_sad[j];
217 barycenter.push_back(dt);
218 }
219 for(size_t j = 0; j < barycenter_max.size(); j++) {
220 const auto &dt = barycenter_max[j];
221 barycenter.push_back(dt);
222 }
223
224 // Recreate 3D critical coordinates of barycentric points
225 std::vector<int> number_of_matchings_for_point(barycenter.size());
226 std::vector<float> cords_x1(barycenter.size());
227 std::vector<float> cords_y1(barycenter.size());
228 std::vector<float> cords_z1(barycenter.size());
229 std::vector<float> cords_x2(barycenter.size());
230 std::vector<float> cords_y2(barycenter.size());
231 std::vector<float> cords_z2(barycenter.size());
232 for(unsigned i = 0; i < barycenter.size(); i++) {
233 number_of_matchings_for_point[i] = 0;
234 cords_x1[i] = 0;
235 cords_y1[i] = 0;
236 cords_z1[i] = 0;
237 cords_x2[i] = 0;
238 cords_y2[i] = 0;
239 cords_z2[i] = 0;
240 }
241
242 for(unsigned i = 0; i < all_matchings[0].size(); i++) {
243 DiagramType &CTDiagram = intermediateDiagrams[i];
244 for(unsigned j = 0; j < all_matchings[0][i].size(); j++) {
245 MatchingType t = all_matchings[0][i][j];
246 int bidder_id = std::get<0>(t);
247 int bary_id = std::get<1>(t);
248
249 const auto &bidder = CTDiagram[bidder_id];
250 number_of_matchings_for_point[bary_id] += 1;
251 cords_x1[bary_id] += bidder.birth.coords[0];
252 cords_y1[bary_id] += bidder.birth.coords[1];
253 cords_z1[bary_id] += bidder.birth.coords[2];
254 cords_x2[bary_id] += bidder.death.coords[0];
255 cords_y2[bary_id] += bidder.death.coords[1];
256 cords_z2[bary_id] += bidder.death.coords[2];
257 }
258 }
259
260 for(unsigned i = 0; i < barycenter.size(); i++) {
261 if(number_of_matchings_for_point[i] > 0) {
262 barycenter[i].birth.coords[0]
263 = cords_x1[i] / number_of_matchings_for_point[i];
264 barycenter[i].birth.coords[1]
265 = cords_y1[i] / number_of_matchings_for_point[i];
266 barycenter[i].birth.coords[2]
267 = cords_z1[i] / number_of_matchings_for_point[i];
268 barycenter[i].death.coords[0]
269 = cords_x2[i] / number_of_matchings_for_point[i];
270 barycenter[i].death.coords[1]
271 = cords_y2[i] / number_of_matchings_for_point[i];
272 barycenter[i].death.coords[2]
273 = cords_z2[i] / number_of_matchings_for_point[i];
274 }
275 }
276
277 printMsg("Min-saddle cost : " + std::to_string(min_cost));
278 printMsg("Saddle-saddle cost : " + std::to_string(sad_cost));
279 printMsg("Saddle-max cost : " + std::to_string(max_cost));
280 printMsg("Total cost : " + std::to_string(total_cost));
281 printMsg("Complete", 1, tm.getElapsedTime(), threadNumber_);
282}
int threadNumber_
Definition: BaseClass.h:95
virtual int setThreadNumber(const int threadNumber)
Definition: BaseClass.h:80
int debugLevel_
Definition: Debug.h:379
int printMsg(const std::string &msg, const debug::Priority &priority=debug::Priority::INFO, const debug::LineMode &lineMode=debug::LineMode::NEW, std::ostream &stream=std::cout) const
Definition: Debug.h:118
void execute(std::vector< DiagramType > &intermediateDiagrams, DiagramType &barycenter, std::vector< std::vector< std::vector< MatchingType > > > &all_matchings)
double getElapsedTime()
Definition: Timer.h:15
CriticalType
default value for critical index
Definition: DataTypes.h:80
std::vector< PersistencePair > DiagramType
Persistence Diagram type as a vector of Persistence pairs.
std::tuple< int, int, double > MatchingType
Matching between two Persistence Diagram pairs.
ttk::CriticalVertex birth
double persistence() const
Return the topological persistence of the pair.
ttk::CriticalVertex death