ucxx.hpp
1 
5 #pragma once
6 
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <ucxx/api.h>

#include <rmm/device_buffer.hpp>

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/config.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/progress_thread.hpp>
19 
20 namespace rapidsmpf {
21 
22 namespace ucxx {
23 
24 
/// @brief Hostname/port pair: `first` is the hostname, `second` is the port.
using HostPortPair = std::pair<std::string, std::uint16_t>;
28 using RemoteAddress = std::variant<
29  HostPortPair,
30  std::shared_ptr<::ucxx::Address>>;
32 
/**
 * @brief Progress mode selector, stored in a single byte.
 *
 * NOTE(review): the enumerator names suggest blocking vs. polling progress,
 * either inline or on a dedicated thread — confirm against the implementation.
 */
enum class ProgressMode : std::uint8_t {
    Blocking = 0,    ///< Numeric value 0; the remaining enumerators count up from here.
    Polling,         ///< Numeric value 1.
    ThreadBlocking,  ///< Numeric value 2.
    ThreadPolling    ///< Numeric value 3.
};
42 
50  public:
51  RemoteAddress address;
52  Rank rank{};
53 };
54 
// Forward declaration: only handles to SharedResources are needed in this header.
class SharedResources;

/**
 * @brief A UCXX initialized rank.
 */
class InitializedRank {
  public:
    /**
     * @brief Construct an initialized UCXX rank.
     *
     * @param shared_resources Shared handle to the resources backing this rank.
     */
    InitializedRank(std::shared_ptr<SharedResources> shared_resources);

    /// Opaque object created by `init()`.
    std::shared_ptr<SharedResources> shared_resources_{nullptr};
};
79 
97 std::unique_ptr<rapidsmpf::ucxx::InitializedRank> init(
98  std::shared_ptr<::ucxx::Worker> worker,
99  Rank nranks,
100  std::optional<RemoteAddress> remote_address,
101  config::Options options
102 );
103 
112 class UCXX final : public Communicator {
113  public:
120  class Future : public Communicator::Future {
121  friend class UCXX;
122 
123  public:
130  Future(std::shared_ptr<::ucxx::Request> req, std::unique_ptr<Buffer> data_buffer)
131  : req_{std::move(req)}, data_buffer_{std::move(data_buffer)} {}
132 
143  std::shared_ptr<::ucxx::Request> req,
144  std::unique_ptr<std::vector<std::uint8_t>> synced_host_data
145  )
146  : req_{std::move(req)}, synced_host_data_{std::move(synced_host_data)} {}
147 
148  ~Future() noexcept override = default;
149 
150  private:
151  std::shared_ptr<::ucxx::Request> req_;
152  std::unique_ptr<Buffer> data_buffer_;
153  // Dedicated storage for host data that is valid at the time of construction.
154  std::unique_ptr<std::vector<std::uint8_t>> synced_host_data_;
155  };
156 
168  std::unique_ptr<InitializedRank> ucxx_initialized_rank,
169  config::Options options,
170  std::shared_ptr<ProgressThread> progress_thread
171  );
172 
173  ~UCXX() noexcept override;
174 
178  [[nodiscard]] Rank rank() const override;
179 
183  [[nodiscard]] Rank nranks() const override;
184 
188  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
189  std::unique_ptr<std::vector<std::uint8_t>> msg, Rank rank, Tag tag
190  ) override;
191 
192  // clang-format off
196  // clang-format on
197  [[nodiscard]] std::unique_ptr<Communicator::Future> send(
198  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
199  ) override;
200 
204  [[nodiscard]] std::unique_ptr<Communicator::Future> recv(
205  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
206  ) override;
207 
208  // clang-format off
212  // clang-format on
213  [[nodiscard]] std::unique_ptr<Communicator::Future> recv_sync_host_data(
214  Rank rank, Tag tag, std::unique_ptr<std::vector<std::uint8_t>> synced_buffer
215  ) override;
216 
223  [[nodiscard]] std::pair<std::unique_ptr<std::vector<std::uint8_t>>, Rank> recv_any(
224  Tag tag
225  ) override;
226 
233  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> recv_from(
234  Rank src, Tag tag
235  ) override;
236 
242  std::pair<
243  std::vector<std::unique_ptr<Communicator::Future>>,
244  std::vector<std::size_t>>
245  test_some(std::vector<std::unique_ptr<Communicator::Future>>& future_vector) override;
246 
247  // clang-format off
253  // clang-format on
254  std::vector<std::size_t> test_some(
255  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
256  future_map
257  ) override;
258 
260  bool test(std::unique_ptr<Communicator::Future>& future) override;
262  std::vector<std::unique_ptr<Buffer>> wait_all(
263  std::vector<std::unique_ptr<Communicator::Future>>&& futures
264  ) override;
265 
271  [[nodiscard]] std::unique_ptr<Buffer> wait(
272  std::unique_ptr<Communicator::Future> future
273  ) override;
274 
278  [[nodiscard]] std::unique_ptr<Buffer> release_data(
279  std::unique_ptr<Communicator::Future> future
280  ) override;
281 
285  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> release_sync_host_data(
286  std::unique_ptr<Communicator::Future> future
287  ) override;
288 
292  [[nodiscard]] std::shared_ptr<Communicator::Logger> const& logger() override {
293  return logger_;
294  }
295 
299  [[nodiscard]] std::shared_ptr<ProgressThread> const&
300  progress_thread() const override {
301  return progress_thread_;
302  }
303 
307  [[nodiscard]] std::string str() const override;
308 
316  void barrier();
317 
324 
336  std::shared_ptr<UCXX> split();
337 
338  private:
339  std::shared_ptr<SharedResources> shared_resources_;
340  config::Options options_;
341  std::shared_ptr<Logger> logger_;
342  std::shared_ptr<ProgressThread> progress_thread_;
343 
344  std::shared_ptr<::ucxx::Endpoint> get_endpoint(Rank rank);
345  void progress_worker();
346 };
347 
348 } // namespace ucxx
349 
350 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:47
Abstract base class for asynchronous operation within the communicator.
A logger base class for handling different levels of log messages.
Abstract base class for a communication mechanism between nodes.
A progress thread that can execute arbitrary functions.
A tag used for identifying messages in a communication operation.
Manages configuration options for RapidsMPF operations.
Definition: config.hpp:140
A UCXX initialized rank.
Definition: ucxx.hpp:64
InitializedRank(std::shared_ptr< SharedResources > shared_resources)
Construct an initialized UCXX rank.
std::shared_ptr< SharedResources > shared_resources_
Opaque object created by init().
Definition: ucxx.hpp:75
Storage for a listener address.
Definition: ucxx.hpp:49
RemoteAddress address
Hostname/port pair or UCXX address.
Definition: ucxx.hpp:51
Rank rank
The rank of the listener.
Definition: ucxx.hpp:52
Represents the future result of a UCXX operation.
Definition: ucxx.hpp:120
Future(std::shared_ptr<::ucxx::Request > req, std::unique_ptr< std::vector< std::uint8_t >> synced_host_data)
Construct a Future from synchronized host data.
Definition: ucxx.hpp:142
Future(std::shared_ptr<::ucxx::Request > req, std::unique_ptr< Buffer > data_buffer)
Construct a Future from a data buffer.
Definition: ucxx.hpp:130
UCXX communicator class that implements the Communicator interface.
Definition: ucxx.hpp:112
std::shared_ptr< Communicator::Logger > const & logger() override
Retrieves the logger associated with this communicator.
Definition: ucxx.hpp:292
bool test(std::unique_ptr< Communicator::Future > &future) override
Test for completion of a single future.
ListenerAddress listener_address()
Get address to which listener is bound.
std::shared_ptr< ProgressThread > const & progress_thread() const override
Retrieves the progress thread associated with this communicator.
Definition: ucxx.hpp:300
std::unique_ptr< Communicator::Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer) override
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
std::string str() const override
Provides a string representation of the communicator.
Rank rank() const override
Retrieves the rank of the current node.
std::pair< std::unique_ptr< std::vector< std::uint8_t > >, Rank > recv_any(Tag tag) override
Receives a message from any rank (blocking).
std::unique_ptr< std::vector< std::uint8_t > > recv_from(Rank src, Tag tag) override
Receives a message from a specific rank (blocking).
void barrier()
Barrier to synchronize all ranks.
std::unique_ptr< Communicator::Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< std::uint8_t >> synced_buffer) override
Receives a message from a specific rank to an allocated (synchronized) host buffer....
std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future) override
Retrieves data associated with a completed future.
std::unique_ptr< Communicator::Future > send(std::unique_ptr< std::vector< std::uint8_t >> msg, Rank rank, Tag tag) override
Sends a host message to a specific rank.
Rank nranks() const override
Retrieves the total number of ranks.
std::unique_ptr< std::vector< std::uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future) override
Retrieves synchronized host data associated with a completed future. When the future is completed,...
std::shared_ptr< UCXX > split()
Creates a new communicator with a single rank.
std::vector< std::unique_ptr< Buffer > > wait_all(std::vector< std::unique_ptr< Communicator::Future >> &&futures) override
Wait for completion of all futures and return their data buffers.
std::pair< std::vector< std::unique_ptr< Communicator::Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Communicator::Future >> &future_vector) override
Tests for completion of multiple futures.
std::unique_ptr< Buffer > wait(std::unique_ptr< Communicator::Future > future) override
Wait for a future to complete and return the data buffer.
RAPIDS Multi-Processor interfaces.
Definition: backend.hpp:13
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).