mpi.hpp
1 
5 #pragma once
6 
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
10 
11 #include <mpi.h>
12 
13 #include <rmm/device_buffer.hpp>
14 
15 #include <rapidsmpf/communicator/communicator.hpp>
16 #include <rapidsmpf/error.hpp>
17 #include <rapidsmpf/progress_thread.hpp>
18 
19 namespace rapidsmpf {
20 
namespace mpi {

/**
 * @brief Helper to initialize MPI with threading support.
 *
 * @param argc Pointer to the argument count received by `main`.
 * @param argv Pointer to the argument vector received by `main`.
 */
void init(int* argc, char*** argv);

/**
 * @brief Check if MPI is initialized.
 *
 * @return `true` if MPI is initialized, otherwise `false`.
 */
bool is_initialized();

/**
 * @brief Check the return code of an MPI call.
 *
 * Evaluates @p call exactly once and forwards its result, together with the
 * current source file and line, to `rapidsmpf::mpi::detail::check_mpi_error`.
 *
 * @param call An expression (typically an MPI function call) yielding an `int`
 * error code.
 */
#define RAPIDSMPF_MPI(call) \
    rapidsmpf::mpi::detail::check_mpi_error((call), __FILE__, __LINE__)

namespace detail {
/**
 * @brief Check an MPI error code and report it together with its source location.
 *
 * NOTE(review): the failure behavior (e.g. what happens for codes other than
 * `MPI_SUCCESS`) lives in the implementation file; confirm there.
 *
 * @param error_code The error code returned by an MPI call.
 * @param file Source file of the call site (normally `__FILE__`).
 * @param line Source line of the call site (normally `__LINE__`).
 */
void check_mpi_error(int error_code, char const* file, int line);
}  // namespace detail
}  // namespace mpi
64 
73 class MPI final : public Communicator {
74  public:
81  class Future : public Communicator::Future {
82  friend class MPI;
83 
84  public:
91  Future(MPI_Request req, std::unique_ptr<Buffer> data_buffer)
92  : req_{req}, data_buffer_{std::move(data_buffer)} {}
93 
104  MPI_Request req, std::unique_ptr<std::vector<std::uint8_t>> synced_host_data
105  )
106  : req_{std::move(req)}, synced_host_data_{std::move(synced_host_data)} {}
107 
108  ~Future() noexcept override = default;
109 
110  private:
111  MPI_Request req_;
112  // TODO: these buffers are mutually exclusive and looks similar to
113  // Buffer::storage_.
114  std::unique_ptr<Buffer> data_buffer_;
115  // Dedicated storage for host data that is valid at the time of construction.
116  std::unique_ptr<std::vector<std::uint8_t>> synced_host_data_;
117  };
118 
126  MPI(MPI_Comm comm,
127  config::Options options,
128  std::shared_ptr<ProgressThread> progress_thread);
129 
130  ~MPI() noexcept override = default;
131 
135  [[nodiscard]] Rank rank() const override {
136  return rank_;
137  }
138 
142  [[nodiscard]] Rank nranks() const override {
143  return nranks_;
144  }
145 
151  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
152  std::unique_ptr<std::vector<std::uint8_t>> msg, Rank rank, Tag tag
153  ) override;
154 
155  // clang-format off
161  // clang-format on
162  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
163  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
164  ) override;
165 
171  [[nodiscard]] std::unique_ptr<Communicator::Future> recv(
172  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
173  ) override;
174 
175  // clang-format off
181  // clang-format on
182  [[nodiscard]] std::unique_ptr<Communicator::Future> recv_sync_host_data(
183  Rank rank, Tag tag, std::unique_ptr<std::vector<std::uint8_t>> synced_buffer
184  ) override;
185 
189  [[nodiscard]] std::pair<std::unique_ptr<std::vector<std::uint8_t>>, Rank> recv_any(
190  Tag tag
191  ) override;
192 
196  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> recv_from(
197  Rank src, Tag tag
198  ) override;
202  std::pair<
203  std::vector<std::unique_ptr<Communicator::Future>>,
204  std::vector<std::size_t>>
205  test_some(std::vector<std::unique_ptr<Communicator::Future>>& future_vector) override;
206 
207  // clang-format off
211  // clang-format on
212  std::vector<std::size_t> test_some(
213  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
214  future_map
215  ) override;
216 
218  bool test(std::unique_ptr<Communicator::Future>& future) override;
220  std::vector<std::unique_ptr<Buffer>> wait_all(
221  std::vector<std::unique_ptr<Communicator::Future>>&& futures
222  ) override;
223 
227  [[nodiscard]] std::unique_ptr<Buffer> wait(
228  std::unique_ptr<Communicator::Future> future
229  ) override;
230 
234  [[nodiscard]] std::unique_ptr<Buffer> release_data(
235  std::unique_ptr<Communicator::Future> future
236  ) override;
237 
241  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> release_sync_host_data(
242  std::unique_ptr<Communicator::Future> future
243  ) override;
244 
248  [[nodiscard]] std::shared_ptr<Communicator::Logger> const& logger() override {
249  return logger_;
250  }
251 
255  [[nodiscard]] std::shared_ptr<ProgressThread> const&
256  progress_thread() const override {
257  return progress_thread_;
258  }
259 
263  [[nodiscard]] std::string str() const override;
264 
265  private:
266  MPI_Comm comm_;
267  Rank rank_;
268  Rank nranks_;
269  std::shared_ptr<Logger> logger_;
270  std::shared_ptr<ProgressThread> progress_thread_;
271 };
272 
273 
274 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:47
Abstract base class for asynchronous operation within the communicator.
Abstract base class for a communication mechanism between nodes.
Represents the future result of an MPI operation.
Definition: mpi.hpp:81
Future(MPI_Request req, std::unique_ptr< std::vector< std::uint8_t >> synced_host_data)
Construct a Future from synchronized host data.
Definition: mpi.hpp:103
Future(MPI_Request req, std::unique_ptr< Buffer > data_buffer)
Construct a Future from a data buffer.
Definition: mpi.hpp:91
MPI communicator class that implements the Communicator interface.
Definition: mpi.hpp:73
std::shared_ptr< Communicator::Logger > const & logger() override
Retrieves the logger associated with this communicator.
Definition: mpi.hpp:248
std::unique_ptr< Communicator::Future > send(std::unique_ptr< std::vector< std::uint8_t >> msg, Rank rank, Tag tag) override
Sends a host message to a specific rank.
std::vector< std::size_t > test_some(std::unordered_map< std::size_t, std::unique_ptr< Communicator::Future >> const &future_map) override
Tests for completion of multiple futures in a map.
std::unique_ptr< std::vector< std::uint8_t > > recv_from(Rank src, Tag tag) override
Receives a message from a specific rank (blocking).
std::vector< std::unique_ptr< Buffer > > wait_all(std::vector< std::unique_ptr< Communicator::Future >> &&futures) override
Wait for completion of all futures and return their data buffers.
std::unique_ptr< Buffer > wait(std::unique_ptr< Communicator::Future > future) override
Wait for a future to complete and return the data buffer.
std::pair< std::vector< std::unique_ptr< Communicator::Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Communicator::Future >> &future_vector) override
Tests for completion of multiple futures.
std::shared_ptr< ProgressThread > const & progress_thread() const override
Retrieves the progress thread associated with this communicator.
Definition: mpi.hpp:256
std::unique_ptr< Communicator::Future > send(std::unique_ptr< Buffer > msg, Rank rank, Tag tag) override
Sends a message (device or host) to a specific rank. Use release_data to obtain the data buffer again...
bool test(std::unique_ptr< Communicator::Future > &future) override
Test for completion of a single future.
std::unique_ptr< Communicator::Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< std::uint8_t >> synced_buffer) override
Receives a message from a specific rank to an allocated (synchronized) host buffer....
Rank nranks() const override
Retrieves the total number of ranks.
Definition: mpi.hpp:142
std::unique_ptr< std::vector< std::uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future) override
Retrieves synchronized host data associated with a completed future. When the future is completed,...
Rank rank() const override
Retrieves the rank of the current node.
Definition: mpi.hpp:135
std::pair< std::unique_ptr< std::vector< std::uint8_t > >, Rank > recv_any(Tag tag) override
Receives a message from any rank (blocking).
std::string str() const override
Provides a string representation of the communicator.
std::unique_ptr< Communicator::Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer) override
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future) override
Retrieves data associated with a completed future.
A progress thread that can execute arbitrary functions.
A tag used for identifying messages in a communication operation.
void init(int *argc, char ***argv)
Helper to initialize MPI with threading support.
bool is_initialized()
Check if MPI is initialized.
RAPIDS Multi-Processor interfaces.
Definition: backend.hpp:13
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).