// PermCompare: comparator over integer indices into a weight array.
// operator()(a, b) compares weights[a] and weights[b] with the wrapped
// comparator, so sorting an index array with it orders the indices by the
// values they reference (psort_split uses it to order gathered samples).
// NOTE(review): excerpt -- the struct header, member declarations and closing
// braces are elided; the leading numbers are original line numbers.
46 template <
typename T,
class _Compare>
// Stores the weight pointer and the value comparator.
53 PermCompare(T *w, _Compare c) : weights(w),
comp(c) {
55 bool operator()(
int a,
int b) {
// True when the value indexed by a precedes the value indexed by b.
56 return comp(weights[a], weights[b]);
// psort_split: choose the global split points of the distributed sort.
// Each rank passes its locally sorted range [first, last). On return,
// right_ends[p][r] holds the offset in rank r's local array at which
// destination partition p starts. The caller zero-initialises the rows, and
// row localSize is set to dist below. alltoall() turns consecutive rows into
// send/receive counts.
// NOTE(review): excerpt -- several parameters (dist, comp, nThreads,
// localSize, localRank, judging from the call in parallel_sort) and many
// body lines are elided; the leading numbers are original line numbers.
60 template <
class _Compare,
typename _RandomAccessIter>
61 void psort_split(_RandomAccessIter first,
62 _RandomAccessIter last,
65 std::vector<std::vector<ttk::SimplexId>> &right_ends,
66 MPI_Datatype &MPI_valueType,
67 MPI_Datatype &MPI_distanceType,
71 MPI_Comm &localComm) {
74 typename std::iterator_traits<_RandomAccessIter>::value_type;
76 int n_real = localSize;
77 for(
int i = 0; i < localSize; ++i)
// Last row of right_ends: every rank's full local element count.
83 std::copy(dist, dist + localSize, right_ends[localSize].
begin());
// Global split targets are the prefix sums of dist. Partition p ends after
// dist[0] + ... + dist[p] elements in global order, so partition p receives
// dist[p] elements.
88 std::vector<ttk::SimplexId> targets(localSize - 1);
89 std::partial_sum(dist, dist + (localSize - 1), targets.data());
// Active subproblems: d_ranges[k] is a slice of the local data, and
// t_ranges[k] is the slice of `targets` still to be placed inside it.
// Initially one subproblem covers all data and all targets.
92 std::vector<std::pair<_RandomAccessIter, _RandomAccessIter>> d_ranges(
94 std::vector<std::pair<ttk::SimplexId *, ttk::SimplexId *>> t_ranges(
96 d_ranges[0] = std::make_pair(first, last);
98 = std::make_pair(targets.data(), targets.data() + localSize - 1);
// subdist[k][r]: number of rank r's elements inside subproblem k.
// outleft[k][r]: offset of subproblem k's start in rank r's local array,
// i.e. how many of rank r's elements lie to its left.
100 std::vector<std::vector<ttk::SimplexId>> subdist(
101 localSize - 1, std::vector<ttk::SimplexId>(localSize));
102 std::copy(dist, dist + localSize, subdist[0].
begin());
105 std::vector<std::vector<ttk::SimplexId>> outleft(
106 localSize - 1, std::vector<ttk::SimplexId>(localSize, 0));
// Iterate while subproblems remain. A new subproblem is only created below
// when it still owns targets.
108 for(
int n_act = 1; n_act > 0;) {
110 for(
int k = 0; k < n_act; ++k) {
111 assert(subdist[k][localRank] == d_ranges[k].second - d_ranges[k].first);
// One sample per (rank, subproblem), stored at medians[rank * n_act + k].
// A rank whose subproblem starts at `last` contributes its last element.
118 std::vector<dataType> medians(localSize * n_act);
119 for(
int k = 0; k < n_act; ++k) {
120 if(d_ranges[k].first != last) {
121 dataType *ptr = &(*d_ranges[k].first);
123 medians[localRank * n_act + k] = ptr[index];
125 medians[localRank * n_act + k] = *(last - 1);
// Share all samples with every rank in localComm.
128 MPI_Allgather(MPI_IN_PLACE, n_act, MPI_valueType, &medians[0], n_act,
129 MPI_valueType, localComm);
// For each subproblem, pick a query value. The samples are sorted by value,
// then walked while subtracting each rank's element count (subdist) from
// `mid`, which selects a count-weighted median. Ranks with no elements in
// the subproblem are skipped.
132 std::vector<dataType> queries(n_act);
134 std::vector<ttk::SimplexId> ms_perm(n_real);
135 for(
int k = 0; k < n_act; ++k) {
137 for(
int i = 0; i < n_real; ++i)
138 ms_perm[i] = i * n_act + k;
139 TTK_PSORT(nThreads, ms_perm.data(), ms_perm.data() + n_real,
140 PermCompare<dataType, _Compare>(&medians[0], comp));
147 for(
int i = 0; i < n_real; ++i) {
148 if(subdist[k][ms_perm[i] / n_act] == 0)
151 mid -= subdist[k][ms_perm[i] / n_act];
153 query_ind = ms_perm[i];
158 assert(query_ind >= 0);
159 queries[k] = medians[query_ind];
// Locate each query inside the local subproblem and record both bounds as
// offsets from `first`. The elided call takes (first, last, value, comp)
// and returns an iterator pair -- presumably std::equal_range.
162 std::vector<ttk::SimplexId> ind_local(2 * n_act);
164 for(
int k = 0; k < n_act; ++k) {
165 std::pair<_RandomAccessIter, _RandomAccessIter> ind_local_p
167 d_ranges[k].first, d_ranges[k].second, queries[k], comp);
169 ind_local[2 * k] = ind_local_p.first - first;
170 ind_local[2 * k + 1] = ind_local_p.second - first;
// Gather every rank's bounds. ind_global[k] sums them into the global
// positions of the query's lower and upper bound.
173 std::vector<ttk::SimplexId> ind_all(2 * n_act * localSize);
174 MPI_Allgather(ind_local.data(), 2 * n_act, MPI_distanceType,
175 ind_all.data(), 2 * n_act, MPI_distanceType, localComm);
177 std::vector<std::pair<ttk::SimplexId, ttk::SimplexId>> ind_global(n_act);
178 for(
int k = 0; k < n_act; ++k) {
179 ind_global[k] = std::make_pair(0, 0);
180 for(
int i = 0; i < localSize; ++i) {
181 ind_global[k].first += ind_all[2 * (i * n_act + k)];
182 ind_global[k].second += ind_all[2 * (i * n_act + k) + 1];
// Storage for the next generation of subproblems.
187 std::vector<std::pair<_RandomAccessIter, _RandomAccessIter>> d_ranges_x(
189 std::vector<std::pair<ttk::SimplexId *, ttk::SimplexId *>> t_ranges_x(
191 std::vector<std::vector<ttk::SimplexId>> subdist_x(
192 localSize - 1, std::vector<ttk::SimplexId>(localSize));
193 std::vector<std::vector<ttk::SimplexId>> outleft_x(
194 localSize - 1, std::vector<ttk::SimplexId>(localSize, 0));
// For each subproblem, split_low and split_high are found by elided searches
// in t_ranges[k] against the global bounds. They separate three groups:
// targets left of the query, targets the query value itself can satisfy,
// and targets right of the query.
197 for(
int k = 0; k < n_act; ++k) {
199 t_ranges[k].first, t_ranges[k].second, ind_global[k].first);
201 t_ranges[k].first, t_ranges[k].second, ind_global[k].second);
// For a target s that falls on the query value, hand out the surplus
// `excess` across ranks in rank order. Each rank gives at most its own
// elements between its lower and upper bound.
210 for(
int i = 0; i < localSize; ++i) {
212 = std::min(ind_all[2 * (i * n_act + k)] + excess,
213 ind_all[2 * (i * n_act + k) + 1]);
214 right_ends[(s - targets.data()) + 1][i] = amount;
215 excess -= amount - ind_all[2 * (i * n_act + k)];
// Targets left of the query recurse into [range start, local lower bound).
219 if((split_low - t_ranges[k].first) > 0) {
220 t_ranges_x[n_act_x] = std::make_pair(t_ranges[k].first, split_low);
223 = std::make_pair(d_ranges[k].first, first + ind_local[2 * k]);
224 for(
int i = 0; i < localSize; ++i) {
225 subdist_x[n_act_x][i]
226 = ind_all[2 * (i * n_act + k)] - outleft[k][i];
227 outleft_x[n_act_x][i] = outleft[k][i];
// Targets right of the query recurse into [local upper bound, range end).
232 if((t_ranges[k].second - split_high) > 0) {
233 t_ranges_x[n_act_x] = std::make_pair(split_high, t_ranges[k].second);
236 = std::make_pair(first + ind_local[2 * k + 1], d_ranges[k].second);
237 for(
int i = 0; i < localSize; ++i) {
238 subdist_x[n_act_x][i] = outleft[k][i] + subdist[k][i]
239 - ind_all[2 * (i * n_act + k) + 1];
240 outleft_x[n_act_x][i] = ind_all[2 * (i * n_act + k) + 1];
// Promote the next generation to the active set.
246 t_ranges = t_ranges_x;
247 d_ranges = d_ranges_x;
// alltoall: exchange elements so that each rank receives its destination
// partition.
//  - Send/receive counts come from consecutive rows of right_ends, as filled
//    by psort_split.
//  - Received elements are written into trans_data.
//  - boundaries[r] gets the offset of rank r's chunk in trans_data, and
//    boundaries[localSize] = n_loc_.
// If every count fits in an int, a single MPI_Alltoallv is used. Otherwise
// the data moves in rounds of at most messageSize elements per peer, because
// MPI count and displacement arguments are ints.
// NOTE(review): excerpt -- some parameters (boundaries, localSize, localRank,
// judging from the call in parallel_sort) and many body lines are elided.
260 template <
typename dataType>
261 static void alltoall(std::vector<std::vector<ttk::SimplexId>> &right_ends,
262 std::vector<dataType> &data,
263 std::vector<dataType> &trans_data,
265 MPI_Datatype &MPI_valueType,
266 MPI_Datatype &MPI_distanceType,
269 MPI_Comm &localComm) {
276 if(n_loc_ > INT_MAX) {
279 int n_loc =
static_cast<int>(n_loc_);
// Per-peer counts:
//  - scount: this rank's elements in partition i (column localRank).
//  - rcount: rank i's elements in partition localRank (column i).
// The elided body of the INT_MAX test presumably sets overflowInt, which is
// the OpenMP reduction variable.
282 std::vector<ttk::SimplexId> send_counts(localSize);
283 std::vector<ttk::SimplexId> recv_counts(localSize);
284#pragma omp parallel for reduction(+ : overflowInt)
285 for(
int i = 0; i < localSize; ++i) {
287 = right_ends[i + 1][localRank] - right_ends[i][localRank];
289 = right_ends[localRank + 1][i] - right_ends[localRank][i];
290 if(scount > INT_MAX || rcount > INT_MAX) {
293 send_counts[i] = scount;
294 recv_counts[i] = rcount;
// NOTE(review): if overflowInt is a char, MPI_SUM on MPI_CHAR is not a
// predefined reduction the MPI standard allows. MPI_CHAR is a character
// type; MPI_SIGNED_CHAR or MPI_INT are the integer types. A char-sized sum
// of per-rank flags can also wrap back to 0 on large communicators, which
// would wrongly select the int fast path. Consider an int flag reduced with
// MPI_MAX or MPI_LOR on MPI_INT.
296 MPI_Allreduce(MPI_IN_PLACE, &overflowInt, 1, MPI_CHAR, MPI_SUM, localComm);
298 MPI_IN_PLACE, &chunkSize, 1, MPI_distanceType, MPI_MAX, localComm);
// Fast path: all counts fit in int. Use one MPI_Alltoallv, with exclusive
// prefix sums of the counts as displacements.
299 if(overflowInt == 0) {
300 std::vector<int> send_counts_int(localSize);
301 std::vector<int> send_disps_int(localSize);
302 std::vector<int> recv_counts_int(localSize);
303 std::vector<int> recv_disps_int(localSize);
305 for(
int i = 0; i < localSize; i++) {
306 send_counts_int[i] =
static_cast<int>(send_counts[i]);
307 recv_counts_int[i] =
static_cast<int>(recv_counts[i]);
310 recv_disps_int[0] = 0;
311 std::partial_sum(recv_counts_int.data(),
312 recv_counts_int.data() + localSize - 1,
313 recv_disps_int.data() + 1);
315 send_disps_int[0] = 0;
316 std::partial_sum(send_counts_int.data(),
317 send_counts_int.data() + localSize - 1,
318 send_disps_int.data() + 1);
320 MPI_Alltoallv(data.data(), send_counts_int.data(), send_disps_int.data(),
321 MPI_valueType, trans_data.data(), recv_counts_int.data(),
322 recv_disps_int.data(), MPI_valueType, localComm);
324 for(
int i = 0; i < localSize; ++i)
// Fallback path: compute 64-bit displacements, then exchange in rounds.
334 std::vector<ttk::SimplexId> send_disps(localSize);
335 std::vector<ttk::SimplexId> recv_disps(localSize);
337 std::partial_sum(send_counts.data(), send_counts.data() + localSize - 1,
338 send_disps.data() + 1);
340 std::partial_sum(recv_counts.data(), recv_counts.data() + localSize - 1,
341 recv_disps.data() + 1);
// Per-round int counts/displacements and staging buffers. messageSize caps
// each peer's share, so one round's total stays below INT_MAX.
// NOTE(review): both staging buffers are allocated with INT_MAX elements
// (INT_MAX * sizeof(dataType) bytes each) regardless of the data size.
// localSize * messageSize elements would be enough for one round.
345 std::vector<int> partial_recv_count(localSize, 0);
346 std::vector<int> partial_send_count(localSize, 0);
347 std::vector<int> partial_recv_displs(localSize, 0);
348 std::vector<int> partial_send_displs(localSize, 0);
349 std::vector<dataType> send_buffer_64bits(INT_MAX);
350 std::vector<dataType> recv_buffer_64bits(INT_MAX);
351 ttk::SimplexId messageSize = std::max(INT_MAX / localSize - 1, 1);
// NOTE(review): send_buffer_64bits is shrunk to size 0 here, yet the
// std::copy below writes to send_buffer_64bits.begin() + offset. Unless an
// elided line resizes it again, that copy writes past size(). This is
// undefined behavior even though the capacity is still allocated -- verify.
354 send_buffer_64bits.resize(0);
356 for(
int i = 0; i < localSize; i++) {
// The [i - 1] accesses below presumably sit under an elided i > 0 guard.
358 partial_send_displs[i]
359 = partial_send_displs[i - 1] + partial_send_count[i - 1];
360 partial_recv_displs[i]
361 = partial_recv_displs[i - 1] + partial_recv_count[i - 1];
// Slice for peer i in round `count`: up to messageSize elements starting
// at count * messageSize. Outgoing elements are staged contiguously.
363 if(send_counts[i] - count * messageSize > 0) {
365 if(send_counts[i] - count * messageSize > messageSize) {
366 partial_send_count[i] = messageSize;
368 partial_send_count[i] = send_counts[i] - count * messageSize;
370 std::copy(data.begin() + send_disps[i] + count * messageSize,
371 data.begin() + send_disps[i] + count * messageSize
372 + partial_send_count[i],
373 send_buffer_64bits.begin() + partial_send_displs[i]);
375 partial_send_count[i] = 0;
377 if(recv_counts[i] - count * messageSize > 0) {
378 if(recv_counts[i] - count * messageSize > messageSize) {
379 partial_recv_count[i] = messageSize;
381 partial_recv_count[i] = recv_counts[i] - count * messageSize;
384 partial_recv_count[i] = 0;
// One bounded exchange for this round.
387 MPI_Alltoallv(send_buffer_64bits.data(), partial_send_count.data(),
388 partial_send_displs.data(), MPI_valueType,
389 recv_buffer_64bits.data(), partial_recv_count.data(),
390 partial_recv_displs.data(), MPI_valueType, localComm);
// Copy each received slice to its final place in trans_data.
392 for(
int i = 0; i < localSize; i++) {
393 if(partial_recv_count[i] > 0) {
394 std::copy(recv_buffer_64bits.begin() + partial_recv_displs[i],
395 recv_buffer_64bits.begin() + partial_recv_displs[i]
396 + partial_recv_count[i],
397 trans_data.begin() + recv_disps[i] + count * messageSize);
// NOTE(review): MPI_INTEGER is the Fortran INTEGER datatype. For a C/C++
// int counter the matching datatype is MPI_INT -- confirm moreToSend's type.
403 MPI_IN_PLACE, &moreToSend, 1, MPI_INTEGER, MPI_SUM, localComm);
// Merge boundaries: receive displacements, plus the total element count.
406 for(
int i = 0; i < localSize; ++i)
407 boundaries[i] = recv_disps[i];
408 boundaries[localSize] = n_loc_;
// psort_merge: merge the localSize sorted runs received by alltoall.
// Run r occupies [disps[r], disps[r + 1]) of `in`. Runs are merged pairwise
// with a doubling stride, alternating between the `in` and `out` buffers
// (bufs[0] = in, bufs[1] = out). locs[r] records which buffer currently
// holds run r. Every exit path below writes the final result to `out`.
// NOTE(review): excerpt -- the disps, comp and localSize parameters and
// several control lines are elided.
413 template <
class _Compare,
typename _RandomAccessIter>
414 void psort_merge(_RandomAccessIter in,
415 _RandomAccessIter out,
418 _Compare oppositeComp,
// Nothing to merge (guarding condition elided, presumably a single run):
// copy everything straight to `out`.
422 std::copy(in, in + disps[localSize], out);
426 _RandomAccessIter bufs[2] = {in, out};
428 std::vector<ttk::SimplexId> locs(localSize, 0);
// Presumably leaves the pairing loop once the stride spans all runs. The
// last two runs are then handled by the final merge below.
433 if(stride >= localSize)
// Merge run i with run i + next into the opposite buffer of run i, then
// record that run i (now extending to end_ind) lives there.
440 std::merge(bufs[locs[i]] + disps[i], bufs[locs[i]] + disps[i + next],
441 bufs[locs[i + next]] + disps[i + next],
442 bufs[locs[i + next]] + disps[end_ind],
443 bufs[1 - locs[i]] + disps[i], comp);
444 locs[i] = 1 - locs[i];
// Final merge of [0, disps[next]) and [disps[next], disps[localSize]) into
// `out`, with the first run read from `in`.
// NOTE(review): when locs[next] == 1, the second input and the output both
// live in `out`. std::merge requires the output not to overlap either
// input, so this relies on the write position never overtaking the read
// position. The pairwise merge above has the same concern when
// locs[i + next] != locs[i]. Confirm.
453 std::merge(in, in + disps[next], bufs[locs[next]] + disps[next],
454 bufs[locs[next]] + disps[localSize], out, comp);
455 }
else if(locs[next] == 0) {
// First run already in `out`, second still in `in`. Merge backwards from
// the end of `out` so the first run's unread elements are not overwritten.
// Reverse iterators invert the ordering, hence the opposite comparator
// (presumably oppositeComp, passed on the elided argument line).
458 std::reverse_iterator<_RandomAccessIter>(in + disps[localSize]),
459 std::reverse_iterator<_RandomAccessIter>(in + disps[next]),
460 std::reverse_iterator<_RandomAccessIter>(out + disps[next]),
461 std::reverse_iterator<_RandomAccessIter>(out),
462 std::reverse_iterator<_RandomAccessIter>(out + disps[localSize]),
// Both runs already in `out`: merge in place.
466 std::inplace_merge(out, out + disps[next], out + disps[localSize], comp);
// parallel_sort: distributed sort of `data` across the ranks of
// ttk::MPIcomm_. It runs in four steps:
//  1. sort locally;
//  2. choose global splitters (psort_split);
//  3. exchange elements (alltoall);
//  4. merge the received sorted runs back into `data` (psort_merge).
// NOTE(review): excerpt -- parameters such as comp and nThreads, and several
// body lines, are elided.
486 template <
typename dataType,
typename _Compare>
487 void parallel_sort(std::vector<dataType> &data,
489 _Compare oppositeComp,
490 std::vector<ttk::SimplexId> &dist,
491 MPI_Datatype &MPI_valueType,
492 MPI_Datatype &MPI_distanceType,
494 int isEmpty = (data.size() == 0);
// Ranks with and without data go to separate communicators (color isEmpty).
// Using the global rank as the key keeps the original rank order.
// NOTE(review): MPI_Comm_free(&localComm) is not visible in this excerpt --
// confirm the communicator is released.
498 MPI_Comm_split(ttk::MPIcomm_, isEmpty,
ttk::MPIrank_, &localComm);
499 MPI_Comm_rank(localComm, &localRank);
500 MPI_Comm_size(localComm, &localSize);
// Make dist hold one entry per rank of localComm. The MPI_Allgather fills it
// with every rank's dataSize (computed on an elided line).
504 if(
static_cast<int>(dist.size()) != localSize) {
505 dist.resize(localSize);
507 MPI_Allgather(&dataSize, 1, MPI_distanceType, dist.data(), 1,
508 MPI_distanceType, localComm);
// Local sort first: psort_split searches and psort_merge merges sorted runs.
512 TTK_PSORT(nThreads, data.begin(), data.end(), comp);
// right_ends[p][r]: start offset of destination partition p in rank r's
// local data. Row 0 stays zero.
518 std::vector<std::vector<ttk::SimplexId>> right_ends(
519 localSize + 1, std::vector<ttk::SimplexId>(localSize, 0));
520 psort_split<_Compare>(data.begin(), data.end(), dist.data(), comp,
521 right_ends, MPI_valueType, MPI_distanceType, nThreads,
522 localSize, localRank, localComm);
526 std::vector<dataType> trans_data(n_loc);
528 std::vector<ttk::SimplexId> boundaries(localSize + 1);
529 alltoall(right_ends, data, trans_data, boundaries.data(), MPI_valueType,
530 MPI_distanceType, localSize, localRank, localComm);
// Merge the received runs from trans_data back into data. psort_merge
// writes boundaries[localSize] elements into data.data(), so data must be
// at least that long. n_loc is computed on an elided line -- confirm.
532 psort_merge<_Compare>(trans_data.data(), data.data(), boundaries.data(),
533 comp, oppositeComp, localSize);
#define TTK_FORCE_USE(x)
Force the compiler to use the function/method parameter.
#define TTK_PSORT(NTHREADS,...)
Parallel sort macro.
bool comp(const PersistencePair a, const PersistencePair b)
int SimplexId
Identifier type for simplices of any dimension.
COMMON_EXPORTS int MPIrank_
T end(std::pair< T, T > &p)
T begin(std::pair< T, T > &p)