KaMPIng 0.1.1
Flexible and (near) zero-overhead C++ bindings for MPI
Loading...
Searching...
No Matches
sort.hpp
1// This file is part of KaMPIng.
2//
3// Copyright 2024 The KaMPIng Authors
4//
5// KaMPIng is free software : you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
6// License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
7// version. KaMPIng is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
8// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
9// for more details.
10//
11// You should have received a copy of the GNU Lesser General Public License along with KaMPIng. If not, see
12// <https://www.gnu.org/licenses/>.
13
14#pragma once
15
16#include <algorithm>
17#include <cmath>
18#include <cstddef>
19#include <iterator>
20#include <numeric>
21#include <random>
22#include <vector>
23
24#include <kamping/utils/flatten.hpp>
25
26#include "kamping/collectives/allgather.hpp"
27#include "kamping/collectives/alltoall.hpp"
28#include "kamping/collectives/scan.hpp"
29#include "kamping/mpi_ops.hpp"
31#include "kamping/plugin/plugin_helpers.hpp"
32
33namespace kamping::plugin {
34
35/// @brief Plugin that adds a canonical sample sort to the communicator.
36/// @tparam Type of the communicator that is extended by the plugin.
37/// @tparam DefaultContainerType Default container type of the original communicator.
38template <typename Comm, template <typename...> typename DefaultContainerType>
39class SampleSort : public plugin::PluginBase<Comm, DefaultContainerType, SampleSort> {
40public:
41 /// @brief Sort the vector based on a binary comparison function (std::less by default).
42 ///
43 /// The order of equal elements is not guaranteed to be preserved. The binary comparison function has to be \c true
44 /// if the first argument is less than the second.
45 /// @tparam T Type of elements to be sorted.
46 /// @tparam Allocator Allocator of the vector.
47 /// @tparam Compare Type of the binary comparison function (\c std::less<T> by default).
48 /// @param data Vector containing the data to be sorted.
49 /// @param comp Binary comparison function used to determine the order of elements.
50 template <typename T, typename Allocator, typename Compare = std::less<T>>
51 void sort(std::vector<T, Allocator>& data, Compare comp = Compare{}) {
52 auto& self = this->to_communicator();
53 size_t const oversampling_ratio = 16 * static_cast<size_t>(std::log2(self.size())) + (data.size() > 0 ? 1 : 0);
54 std::vector<T> local_samples(oversampling_ratio);
55 std::sample(
56 data.begin(),
57 data.end(),
58 local_samples.begin(),
60 std::mt19937{self.rank() + self.size()}
61 );
62
63 auto global_samples = self.allgatherv(send_buf(local_samples));
64 pick_splitters(self.size() - 1, oversampling_ratio, global_samples, comp);
65 auto buckets = build_buckets(data.begin(), data.end(), global_samples, comp);
66 data = with_flattened(buckets).call([&](auto... flattened) { return self.alltoallv(std::move(flattened)...); });
67 std::sort(data.begin(), data.end(), comp);
68 }
69
70 /// @brief Sort the elements in [begin, end) using a binary comparison function (std::less by default).
71 ///
72 /// The order of equal elements in not guaranteed to be preserved. The binary comparison function has to be \c true
73 /// if the first argument is less than the second.
74 /// @tparam RandomIt Iterator type of the container containing the elements that are sorted.
75 /// @tparam OutputIt Iterator type of the output iterator.
76 /// @tparam Compare Type of the binary comparison function (\c std::less<> by default).
77 /// @param begin Start of the range of elements to sort.
78 /// @param end Element after the last element to be sorted.
79 /// @param out Output iterator used to output the sorted elements.
80 /// @param comp Binary comparison function used to determine the order of elements.
81 template <
82 typename RandomIt,
83 typename OutputIt,
84 typename Compare = std::less<typename std::iterator_traits<RandomIt>::value_type>>
86 using ValueType = typename std::iterator_traits<RandomIt>::value_type;
87
88 auto& self = this->to_communicator();
89 size_t const local_size = asserting_cast<size_t>(std::distance(begin, end));
90 size_t const oversampling_ratio = 16 * static_cast<size_t>(std::log2(self.size())) + (local_size > 0 ? 1 : 0);
91 std::vector<ValueType> local_samples(oversampling_ratio);
92 std::sample(
93 begin,
94 end,
95 local_samples.begin(),
97 std::mt19937{asserting_cast<std::mt19937::result_type>(self.rank() + self.size())}
98 );
99
100 auto global_samples = self.allgatherv(send_buf(local_samples));
101 pick_splitters(self.size() - 1, oversampling_ratio, global_samples, comp);
102 auto buckets = build_buckets(begin, end, global_samples, comp);
103 auto data =
104 with_flattened(buckets).call([&](auto... flattened) { return self.alltoallv(std::move(flattened)...); });
105 std::sort(data.begin(), data.end(), comp);
106 std::copy(data.begin(), data.end(), out);
107 }
108
109private:
110 /// @brief Picks spliters from a global list of splitters.
111 /// @tparam T Type of elements to be sorted (and of splitters)
112 /// @tparam Compare Type of the binary comparison function used to determine order of elements.
113 /// @param num_splitters Number of splitters that should be selected.
114 /// @param oversampling_ratio Ratio at which local splitters are sampled.
115 /// @param global_samples List of all (global) samples. Functions as out parameter where the picked samples are
116 /// stored in.
117 /// @param comp Binary comparison function used to determine order of elements.
118 template <typename T, typename Compare>
119 void pick_splitters(size_t num_splitters, size_t oversampling_ratio, std::vector<T>& global_samples, Compare comp) {
120 std::sort(global_samples.begin(), global_samples.end(), comp);
121 for (size_t i = 0; i < num_splitters; i++) {
122 global_samples[i] = global_samples[oversampling_ratio * (i + 1)];
123 }
124 global_samples.resize(num_splitters);
125 }
126
127 /// @brief Build buckets for a set of elements based on a set of splitters.
128 /// @tparam RandomIt Iterator type used to iterate through the set of elements.
129 /// @tparam T Type of elements.
130 /// @tparam Compare Type of binary comparison function used to determine order of elements.
131 /// @param begin Iterator to the beginning of the elements.
132 /// @param end Iterator pointing behind the laste element.
133 /// @param splitters
134 template <typename RandomIt, typename T, typename Compare>
135 auto build_buckets(RandomIt begin, RandomIt end, std::vector<T>& splitters, Compare comp)
136 -> std::vector<std::vector<T>> {
137 static_assert(
138 std::is_same_v<T, typename std::iterator_traits<RandomIt>::value_type>,
139 "Iterator value type and splitters do not match "
140 );
141 std::vector<std::vector<T>> buckets(splitters.size() + 1);
142 for (auto it = begin; it != end; ++it) {
143 auto const bound = std::upper_bound(splitters.begin(), splitters.end(), *it, comp);
144 buckets[asserting_cast<size_t>(std::distance(splitters.begin(), bound))].push_back(*it);
145 }
146 return buckets;
147 }
148};
149
150} // namespace kamping::plugin
STL-compatible allocator for requesting memory using the builtin MPI allocator.
Definition allocator.hpp:32
Plugin that adds a canonical sample sort to the communicator.
Definition sort.hpp:39
void sort(RandomIt begin, RandomIt end, OutputIt out, Compare comp=Compare{})
Sort the elements in [begin, end) using a binary comparison function (std::less by default).
Definition sort.hpp:85
void sort(std::vector< T, Allocator > &data, Compare comp=Compare{})
Sort the vector based on a binary comparison function (std::less by default).
Definition sort.hpp:51
auto send_buf(internal::ignore_t< Data > ignore)
Generates a dummy send buf that wraps a nullptr.
Definition named_parameters.hpp:51
Definitions for builtin MPI operations.
Factory methods for buffer wrappers.
Helper class for using CRTP for mixins. Which are used to implement kamping plugins.
Definition plugin_helpers.hpp:32