mpi.hpp
1 
5 #pragma once
6 
7 #include <cstdlib>
8 #include <memory>
9 #include <vector>
10 
11 #include <mpi.h>
12 
13 #include <rmm/device_buffer.hpp>
14 
15 #include <rapidsmpf/communicator/communicator.hpp>
16 #include <rapidsmpf/error.hpp>
17 
18 namespace rapidsmpf {
19 
/**
 * @brief Helpers for initializing MPI and for checking the result of MPI calls.
 */
namespace mpi {

/**
 * @brief Helper to initialize MPI with threading support.
 *
 * @param argc Pointer to the argument count.
 * @param argv Pointer to the argument vector.
 *
 * NOTE(review): presumably these are `main`'s `argc`/`argv`, forwarded to the
 * MPI initialization routine -- confirm against the implementation.
 */
void init(int* argc, char*** argv);

/**
 * @brief Check if MPI is initialized.
 *
 * @return Whether MPI is initialized.
 */
bool is_initialized();

/**
 * @brief Evaluate an MPI call and hand its return code to the error checker.
 *
 * Expands to a call of `rapidsmpf::mpi::detail::check_mpi_error()` with the
 * result of `call` and the file name and line number of the expansion site.
 *
 * @param call Expression yielding an MPI error code (evaluated exactly once).
 */
#define RAPIDSMPF_MPI(call) \
    rapidsmpf::mpi::detail::check_mpi_error((call), __FILE__, __LINE__)

namespace detail {
/**
 * @brief Check an MPI error code on behalf of `RAPIDSMPF_MPI`.
 *
 * NOTE(review): the failure behavior (log, throw, abort) is defined in the
 * implementation file and is not visible from this header.
 *
 * @param error_code Return code of the MPI call being checked.
 * @param file Source file name of the call site.
 * @param line Line number of the call site.
 */
void check_mpi_error(int error_code, const char* file, int line);
}  // namespace detail
}  // namespace mpi
63 
72 class MPI final : public Communicator {
73  public:
80  class Future : public Communicator::Future {
81  friend class MPI;
82 
83  public:
90  Future(MPI_Request req, std::unique_ptr<Buffer> data_buffer)
91  : req_{req}, data_buffer_{std::move(data_buffer)} {}
92 
102  Future(MPI_Request req, std::unique_ptr<std::vector<uint8_t>> synced_host_data)
103  : req_{std::move(req)}, synced_host_data_{std::move(synced_host_data)} {}
104 
105  ~Future() noexcept override = default;
106 
107  private:
108  MPI_Request req_;
109  // TODO: these buffers are mutually exclusive and looks similar to
110  // Buffer::storage_.
111  std::unique_ptr<Buffer> data_buffer_;
112  // Dedicated storage for host data that is valid at the time of construction.
113  std::unique_ptr<std::vector<uint8_t>> synced_host_data_;
114  };
115 
122  MPI(MPI_Comm comm, config::Options options);
123 
124  ~MPI() noexcept override = default;
125 
129  [[nodiscard]] Rank rank() const override {
130  return rank_;
131  }
132 
136  [[nodiscard]] Rank nranks() const override {
137  return nranks_;
138  }
139 
143  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
144  std::unique_ptr<std::vector<uint8_t>> msg, Rank rank, Tag tag
145  ) override;
146 
147  // clang-format off
151  // clang-format on
152  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
153  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
154  ) override;
155 
159  [[nodiscard]] std::unique_ptr<Communicator::Future> recv(
160  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
161  ) override;
162 
163  // clang-format off
167  // clang-format on
168  [[nodiscard]] std::unique_ptr<Communicator::Future> recv_sync_host_data(
169  Rank rank, Tag tag, std::unique_ptr<std::vector<uint8_t>> synced_buffer
170  ) override;
171 
175  [[nodiscard]] std::pair<std::unique_ptr<std::vector<uint8_t>>, Rank> recv_any(
176  Tag tag
177  ) override;
178 
182  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> recv_from(
183  Rank src, Tag tag
184  ) override;
188  std::pair<
189  std::vector<std::unique_ptr<Communicator::Future>>,
190  std::vector<std::size_t>>
191  test_some(std::vector<std::unique_ptr<Communicator::Future>>& future_vector) override;
192 
193  // clang-format off
197  // clang-format on
198  std::vector<std::size_t> test_some(
199  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
200  future_map
201  ) override;
202 
206  [[nodiscard]] std::unique_ptr<Buffer> wait(
207  std::unique_ptr<Communicator::Future> future
208  ) override;
209 
213  [[nodiscard]] std::unique_ptr<Buffer> release_data(
214  std::unique_ptr<Communicator::Future> future
215  ) override;
216 
220  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> release_sync_host_data(
221  std::unique_ptr<Communicator::Future> future
222  ) override;
223 
227  [[nodiscard]] Logger& logger() override {
228  return logger_;
229  }
230 
234  [[nodiscard]] std::string str() const override;
235 
236  private:
237  MPI_Comm comm_;
238  Rank rank_;
239  Rank nranks_;
240  Logger logger_;
241 };
242 
243 
244 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:53
Abstract base class for asynchronous operation within the communicator.
A logger base class for handling different levels of log messages.
Abstract base class for a communication mechanism between nodes.
Represents the future result of an MPI operation.
Definition: mpi.hpp:80
Future(MPI_Request req, std::unique_ptr< std::vector< uint8_t >> synced_host_data)
Construct a Future from synchronized host data.
Definition: mpi.hpp:102
Future(MPI_Request req, std::unique_ptr< Buffer > data_buffer)
Construct a Future from a data buffer.
Definition: mpi.hpp:90
MPI communicator class that implements the Communicator interface.
Definition: mpi.hpp:72
std::vector< std::size_t > test_some(std::unordered_map< std::size_t, std::unique_ptr< Communicator::Future >> const &future_map) override
Tests for completion of multiple futures in a map.
std::unique_ptr< Communicator::Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< uint8_t >> synced_buffer) override
Receives a message from a specific rank to an allocated (synchronized) host buffer....
std::unique_ptr< Buffer > wait(std::unique_ptr< Communicator::Future > future) override
Wait for a future to complete and return the data buffer.
std::pair< std::vector< std::unique_ptr< Communicator::Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Communicator::Future >> &future_vector) override
Tests for completion of multiple futures.
std::unique_ptr< Communicator::Future > send(std::unique_ptr< Buffer > msg, Rank rank, Tag tag) override
Sends a message (device or host) to a specific rank.
std::pair< std::unique_ptr< std::vector< uint8_t > >, Rank > recv_any(Tag tag) override
Receives a message from any rank (blocking).
Logger & logger() override
Retrieves the logger associated with this communicator.
Definition: mpi.hpp:227
Rank nranks() const override
Retrieves the total number of ranks.
Definition: mpi.hpp:136
Rank rank() const override
Retrieves the rank of the current node.
Definition: mpi.hpp:129
std::unique_ptr< Communicator::Future > send(std::unique_ptr< std::vector< uint8_t >> msg, Rank rank, Tag tag) override
Sends a host message to a specific rank.
std::unique_ptr< std::vector< uint8_t > > recv_from(Rank src, Tag tag) override
Receives a message from a specific rank (blocking).
std::unique_ptr< std::vector< uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future) override
Retrieves synchronized host data associated with a completed future. When the future is completed,...
std::string str() const override
Provides a string representation of the communicator.
std::unique_ptr< Communicator::Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer) override
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future) override
Retrieves data associated with a completed future.
A tag used for identifying messages in a communication operation.
void init(int *argc, char ***argv)
Helper to initialize MPI with threading support.
bool is_initialized()
Check if MPI is initialized.