ucxx.hpp
1 
#pragma once

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <ucxx/api.h>

#include <rmm/device_buffer.hpp>

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/config.hpp>
#include <rapidsmpf/error.hpp>
18 
19 namespace rapidsmpf {
20 
21 namespace ucxx {
22 
23 
24 using HostPortPair =
25  std::pair<std::string, uint16_t>;
27 using RemoteAddress = std::variant<
28  HostPortPair,
29  std::shared_ptr<::ucxx::Address>>;
31 
35 enum class ProgressMode : std::uint8_t {
36  Blocking = 0,
37  Polling,
38  ThreadBlocking,
39  ThreadPolling
40 };
41 
49  public:
50  RemoteAddress address;
51  Rank rank{};
52 };
53 
54 class SharedResources;
55 
64  public:
72  InitializedRank(std::shared_ptr<SharedResources> shared_resources);
73 
74  std::shared_ptr<SharedResources> shared_resources_{
75  nullptr
76  };
77 };
78 
96 std::unique_ptr<rapidsmpf::ucxx::InitializedRank> init(
97  std::shared_ptr<::ucxx::Worker> worker,
98  Rank nranks,
99  std::optional<RemoteAddress> remote_address,
100  config::Options options
101 );
102 
111 class UCXX final : public Communicator {
112  public:
119  class Future : public Communicator::Future {
120  friend class UCXX;
121 
122  public:
129  Future(std::shared_ptr<::ucxx::Request> req, std::unique_ptr<Buffer> data_buffer)
130  : req_{std::move(req)}, data_buffer_{std::move(data_buffer)} {}
131 
142  std::shared_ptr<::ucxx::Request> req,
143  std::unique_ptr<std::vector<uint8_t>> synced_host_data
144  )
145  : req_{std::move(req)}, synced_host_data_{std::move(synced_host_data)} {}
146 
147  ~Future() noexcept override = default;
148 
149  private:
150  std::shared_ptr<::ucxx::Request> req_;
151  std::unique_ptr<Buffer> data_buffer_;
152  // Dedicated storage for host data that is valid at the time of construction.
153  std::unique_ptr<std::vector<uint8_t>> synced_host_data_;
154  };
155 
165  UCXX(std::unique_ptr<InitializedRank> ucxx_initialized_rank, config::Options options);
166 
167  ~UCXX() noexcept override;
168 
172  [[nodiscard]] Rank rank() const override;
173 
177  [[nodiscard]] Rank nranks() const override;
178 
182  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
183  std::unique_ptr<std::vector<uint8_t>> msg, Rank rank, Tag tag
184  ) override;
185 
186  // clang-format off
190  // clang-format on
191  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
192  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
193  ) override;
194 
198  [[nodiscard]] std::unique_ptr<Communicator::Future> recv(
199  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
200  ) override;
201 
202  // clang-format off
206  // clang-format on
207  [[nodiscard]] std::unique_ptr<Communicator::Future> recv_sync_host_data(
208  Rank rank, Tag tag, std::unique_ptr<std::vector<uint8_t>> synced_buffer
209  ) override;
210 
217  [[nodiscard]] std::pair<std::unique_ptr<std::vector<uint8_t>>, Rank> recv_any(
218  Tag tag
219  ) override;
220 
227  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> recv_from(
228  Rank src, Tag tag
229  ) override;
230 
236  std::pair<
237  std::vector<std::unique_ptr<Communicator::Future>>,
238  std::vector<std::size_t>>
239  test_some(std::vector<std::unique_ptr<Communicator::Future>>& future_vector) override;
240 
241  // clang-format off
247  // clang-format on
248  std::vector<std::size_t> test_some(
249  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
250  future_map
251  ) override;
252 
258  [[nodiscard]] std::unique_ptr<Buffer> wait(
259  std::unique_ptr<Communicator::Future> future
260  ) override;
261 
265  [[nodiscard]] std::unique_ptr<Buffer> release_data(
266  std::unique_ptr<Communicator::Future> future
267  ) override;
268 
272  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> release_sync_host_data(
273  std::unique_ptr<Communicator::Future> future
274  ) override;
275 
279  [[nodiscard]] Logger& logger() override {
280  return logger_;
281  }
282 
286  [[nodiscard]] std::string str() const override;
287 
295  void barrier();
296 
303 
315  std::shared_ptr<UCXX> split();
316 
317  private:
318  std::shared_ptr<SharedResources> shared_resources_;
319  config::Options options_;
320  Logger logger_;
321 
322  std::shared_ptr<::ucxx::Endpoint> get_endpoint(Rank rank);
323  void progress_worker();
324 };
325 
326 } // namespace ucxx
327 
328 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:53
Abstract base class for asynchronous operation within the communicator.
A logger base class for handling different levels of log messages.
Abstract base class for a communication mechanism between nodes.
A tag used for identifying messages in a communication operation.
Manages configuration options for RapidsMPF operations.
Definition: config.hpp:124
A UCXX initialized rank.
Definition: ucxx.hpp:63
InitializedRank(std::shared_ptr< SharedResources > shared_resources)
Construct an initialized UCXX rank.
std::shared_ptr< SharedResources > shared_resources_
Opaque object created by init().
Definition: ucxx.hpp:74
Storage for a listener address.
Definition: ucxx.hpp:48
RemoteAddress address
Hostname/port pair or UCXX address.
Definition: ucxx.hpp:50
Rank rank
The rank of the listener.
Definition: ucxx.hpp:51
Represents the future result of a UCXX operation.
Definition: ucxx.hpp:119
Future(std::shared_ptr<::ucxx::Request > req, std::unique_ptr< Buffer > data_buffer)
Construct a Future from a data buffer.
Definition: ucxx.hpp:129
Future(std::shared_ptr<::ucxx::Request > req, std::unique_ptr< std::vector< uint8_t >> synced_host_data)
Construct a Future from synchronized host data.
Definition: ucxx.hpp:141
UCXX communicator class that implements the Communicator interface.
Definition: ucxx.hpp:111
ListenerAddress listener_address()
Get address to which listener is bound.
std::pair< std::unique_ptr< std::vector< uint8_t > >, Rank > recv_any(Tag tag) override
Receives a message from any rank (blocking).
std::unique_ptr< Communicator::Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer) override
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
std::unique_ptr< Communicator::Future > send(std::unique_ptr< std::vector< uint8_t >> msg, Rank rank, Tag tag) override
Sends a host message to a specific rank.
std::unique_ptr< std::vector< uint8_t > > recv_from(Rank src, Tag tag) override
Receives a message from a specific rank (blocking).
std::string str() const override
Provides a string representation of the communicator.
Rank rank() const override
Retrieves the rank of the current node.
std::unique_ptr< std::vector< uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future) override
Retrieves synchronized host data associated with a completed future. When the future is completed,...
Logger & logger() override
Retrieves the logger associated with this communicator.
Definition: ucxx.hpp:279
void barrier()
Barrier to synchronize all ranks.
std::unique_ptr< Communicator::Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< uint8_t >> synced_buffer) override
Receives a message from a specific rank to an allocated (synchronized) host buffer....
std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future) override
Retrieves data associated with a completed future.
Rank nranks() const override
Retrieves the total number of ranks.
std::shared_ptr< UCXX > split()
Creates a new communicator with a single rank.
std::pair< std::vector< std::unique_ptr< Communicator::Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Communicator::Future >> &future_vector) override
Tests for completion of multiple futures.
std::unique_ptr< Buffer > wait(std::unique_ptr< Communicator::Future > future) override
Wait for a future to complete and return the data buffer.