communicator.hpp
1 
5 #pragma once
6 
7 #include <cstdlib>
8 #include <memory>
9 #include <mutex>
10 #include <sstream>
11 #include <thread>
12 #include <unordered_map>
13 #include <vector>
14 
15 #include <rapidsmpf/config.hpp>
16 #include <rapidsmpf/error.hpp>
17 #include <rapidsmpf/memory/buffer.hpp>
18 #include <rapidsmpf/memory/buffer_resource.hpp>
19 
20 #include "rapidsmpf/utils.hpp"
21 
26 namespace rapidsmpf {
27 
/// Integer type used to identify a rank (32-bit signed).
using Rank = std::int32_t;

/// Operation identifier; stored in the upper bits of a `Tag`.
using OpID = std::uint8_t;

/// Stage identifier; stored in the lower bits of a `Tag`.
using StageID = std::uint8_t;

/**
 * @brief A tag used for identifying messages in a communication operation.
 *
 * A tag packs an operation ID and a stage ID into a single `StorageT` value:
 * the stage ID occupies the lowest `stage_id_bits` bits and the operation ID
 * the `op_id_bits` bits directly above them.
 */
class Tag {
  public:
    /// The physical data type to store the tag.
    using StorageT = std::int32_t;

    /// Number of bits for the stage ID.
    static constexpr int stage_id_bits{sizeof(StageID) * 8};

    /// Mask for the stage ID (the lowest `stage_id_bits` bits set).
    static constexpr StorageT stage_id_mask{(StorageT{1} << stage_id_bits) - 1};

    /// Number of bits for the operation ID.
    static constexpr int op_id_bits{sizeof(OpID) * 8};

    /// Mask for the operation ID (the `op_id_bits` bits above the stage ID).
    static constexpr StorageT op_id_mask{
        ((StorageT{1} << op_id_bits) - 1) << stage_id_bits
    };

    /**
     * @brief Constructs a tag.
     *
     * @param op The operation ID, placed in the upper bits.
     * @param stage The stage ID, placed in the lower bits.
     */
    constexpr Tag(OpID const op, StageID const stage)
        : tag_{
              static_cast<StorageT>(stage)
              | (static_cast<StorageT>(op) << stage_id_bits)
          } {}

    /**
     * @brief Returns the max number of bits used for the tag.
     *
     * @return Sum of the operation-ID and stage-ID bit widths.
     */
    [[nodiscard]] static constexpr std::size_t bit_length() noexcept {
        return static_cast<std::size_t>(stage_id_bits + op_id_bits);
    }

    /**
     * @brief Returns the max value of the tag.
     *
     * @return The value with the lowest `bit_length()` bits set.
     */
    [[nodiscard]] static constexpr StorageT max_value() noexcept {
        return (StorageT{1} << bit_length()) - 1;
    }

    /**
     * @brief Implicit conversion to the packed storage value.
     *
     * @return The packed tag value.
     */
    constexpr operator StorageT() const noexcept {
        return tag_;
    }

    /**
     * @brief Extracts the operation ID from the tag.
     *
     * @return The operation ID.
     */
    [[nodiscard]] constexpr OpID op() const noexcept {
        return static_cast<OpID>((tag_ & op_id_mask) >> stage_id_bits);
    }

    /**
     * @brief Extracts the stage ID from the tag.
     *
     * @return The stage ID.
     */
    [[nodiscard]] constexpr StageID stage() const noexcept {
        return static_cast<StageID>(tag_ & stage_id_mask);
    }

  private:
    StorageT const tag_;  ///< Packed (operation ID, stage ID) value.
};
135 
144  public:
/**
 * @brief Abstract base class for asynchronous operation within the communicator.
 *
 * The class is movable but neither copyable nor copy-assignable, and it has a
 * virtual destructor so derived futures can be destroyed through a `Future`
 * pointer.
 */
class Future {
  public:
    Future() = default;

    /// Not copyable.
    Future(Future const&) = delete;
    /// Not copy-assignable.
    Future& operator=(Future const&) = delete;

    /// Movable.
    Future(Future&&) = default;
    /// Move assignment.
    Future& operator=(Future&&) = default;

    virtual ~Future() noexcept = default;
};
165 
175  class Logger {
176  public:
/**
 * @brief Log verbosity levels.
 *
 * The numeric values increase with verbosity: `log()` discards a message
 * whose level is numerically greater than the logger's own level.
 */
enum class LOG_LEVEL : std::uint32_t {
    NONE = 0,
    PRINT = 1,
    WARN = 2,
    INFO = 3,
    DEBUG = 4,
    TRACE = 5
};

/**
 * @brief Log level names corresponding to the LOG_LEVEL enum.
 *
 * Indexed by the numeric value of the enumerator.
 */
static constexpr std::array<char const*, 6> LOG_LEVEL_NAMES{
    {"NONE", "PRINT", "WARN", "INFO", "DEBUG", "TRACE"}
};

/**
 * @brief Get the string name of a log level.
 *
 * @param level The log level.
 * @return The matching entry of `LOG_LEVEL_NAMES`, or "UNKNOWN" when the
 * numeric value of `level` is outside the table.
 */
static constexpr const char* level_name(LOG_LEVEL level) {
    auto const pos = static_cast<std::size_t>(level);
    if (pos >= LOG_LEVEL_NAMES.size()) {
        return "UNKNOWN";
    }
    return LOG_LEVEL_NAMES[pos];
}
208 
/// Virtual destructor, so a derived logger can be destroyed through a `Logger` pointer.
225  virtual ~Logger() noexcept = default;
226 
233  return level_;
234  }
235 
245  template <typename... Args>
246  void log(LOG_LEVEL level, Args const&... args) {
247  if (static_cast<std::uint32_t>(level_) < static_cast<std::uint32_t>(level)) {
248  return;
249  }
250  std::ostringstream ss;
251  (ss << ... << args);
252  do_log(level, std::move(ss));
253  }
254 
261  template <typename... Args>
262  void print(Args const&... args) {
263  log(LOG_LEVEL::PRINT, std::forward<Args const&>(args)...);
264  }
265 
272  template <typename... Args>
273  void warn(Args const&... args) {
274  log(LOG_LEVEL::WARN, std::forward<Args const&>(args)...);
275  }
276 
283  template <typename... Args>
284  void info(Args const&... args) {
285  log(LOG_LEVEL::INFO, std::forward<Args const&>(args)...);
286  }
287 
294  template <typename... Args>
295  void debug(Args const&... args) {
296  log(LOG_LEVEL::DEBUG, std::forward<Args const&>(args)...);
297  }
298 
305  template <typename... Args>
306  void trace(Args const&... args) {
307  log(LOG_LEVEL::TRACE, std::forward<Args const&>(args)...);
308  }
309 
310  protected:
316  virtual std::uint32_t get_thread_id() {
317  auto const tid = std::this_thread::get_id();
318 
319  // To avoid large IDs, we map the thread ID to an unique counter.
320  auto const [name, inserted] =
321  thread_id_names.insert({tid, thread_id_names_counter});
322  if (inserted) {
323  ++thread_id_names_counter;
324  }
325  return name->second;
326  }
327 
339  virtual void do_log(LOG_LEVEL level, std::ostringstream&& ss) {
340  std::ostringstream full_log_msg;
341  full_log_msg << "[" << level_name(level) << ":" << comm_->rank() << ":"
342  << get_thread_id() << ":" << Clock::now() << "] " << ss.str();
343  std::lock_guard<std::mutex> lock(mutex_);
344  std::cout << full_log_msg.str() << std::endl;
345  }
346 
353  return comm_;
354  }
355 
356  private:
/// Mutex locked by `do_log` while it writes a log line to `std::cout`.
357  std::mutex mutex_;
/// Communicator whose `rank()` is included in every log line written by `do_log`.
358  Communicator* comm_;
/// Verbosity threshold: `log` drops messages whose level is numerically greater.
359  LOG_LEVEL const level_;
360 
/// Next value `get_thread_id` assigns to a thread it has not seen before; starts at 0.
363  std::uint32_t thread_id_names_counter{0};
364 
/// Maps each `std::thread::id` seen by `get_thread_id` to the counter value assigned to it.
366  std::unordered_map<std::thread::id, std::uint32_t> thread_id_names;
367  };
368 
369  protected:
/// Default constructor; declared in the protected section, so only derived classes can call it.
370  Communicator() = default;
371 
372  public:
/// Virtual destructor, so a derived communicator can be destroyed through a `Communicator` pointer.
373  virtual ~Communicator() noexcept = default;
374 
/**
 * @brief Retrieves the rank of the current node.
 *
 * @return The rank.
 */
379  [[nodiscard]] virtual Rank rank() const = 0;
380 
/**
 * @brief Retrieves the total number of ranks.
 *
 * @return The number of ranks.
 */
385  [[nodiscard]] virtual Rank nranks() const = 0;
386 
/**
 * @brief Sends a host message to a specific rank.
 *
 * @param msg The message to send, as a byte vector; ownership is passed to the call.
 * @param rank The rank to send to.
 * @param tag The tag identifying the message.
 * @return A future representing the send operation.
 */
398  [[nodiscard]] virtual std::unique_ptr<Future> send(
399  std::unique_ptr<std::vector<uint8_t>> msg, Rank rank, Tag tag
400  ) = 0;
401 
/**
 * @brief Sends a message held in a `Buffer` to a specific rank.
 *
 * @param msg The buffer to send; ownership is passed to the call.
 * @param rank The rank to send to.
 * @param tag The tag identifying the message.
 * @return A future representing the send operation.
 */
416  [[nodiscard]] virtual std::unique_ptr<Future> send(
417  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
418  ) = 0;
419 
/**
 * @brief Receives a message from a specific rank to a buffer.
 *
 * Use `release_data` to extract the data once the returned future has completed.
 *
 * @param rank The rank to receive from.
 * @param tag The tag identifying the message.
 * @param recv_buffer The buffer to receive into; ownership is passed to the call.
 * @return A future representing the receive operation.
 */
435  [[nodiscard]] virtual std::unique_ptr<Future> recv(
436  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
437  ) = 0;
438 
/**
 * @brief Receives a message from a specific rank to an allocated (synchronized)
 * host buffer.
 *
 * Use `release_sync_host_data` to extract the data once the returned future
 * has completed.
 *
 * @param rank The rank to receive from.
 * @param tag The tag identifying the message.
 * @param synced_buffer The host byte vector to receive into; ownership is
 * passed to the call.
 * @return A future representing the receive operation.
 */
449  [[nodiscard]] virtual std::unique_ptr<Future> recv_sync_host_data(
450  Rank rank, Tag tag, std::unique_ptr<std::vector<uint8_t>> synced_buffer
451  ) = 0;
452 
/**
 * @brief Receives a message from any rank (blocking).
 *
 * @param tag The tag identifying the message.
 * @return A pair of the received bytes and a rank (presumably the rank the
 * message came from; confirm against the implementations).
 */
462  [[nodiscard]] virtual std::pair<std::unique_ptr<std::vector<uint8_t>>, Rank> recv_any(
463  Tag tag
464  ) = 0;
465 
/**
 * @brief Receives a message from a specific rank (blocking).
 *
 * @param src The rank to receive from.
 * @param tag The tag identifying the message.
 * @return The received bytes.
 */
476  [[nodiscard]] virtual std::unique_ptr<std::vector<uint8_t>> recv_from(
477  Rank src, Tag tag
478  ) = 0;
479 
/**
 * @brief Tests for completion of multiple futures.
 *
 * @param future_vector The futures to test. It is taken by non-const
 * reference, so the call may modify it.
 * @return A pair of a vector of futures and a vector of indices (presumably
 * the completed futures and their positions in `future_vector`; confirm
 * against the implementations).
 */
487  [[nodiscard]] virtual std::
488  pair<std::vector<std::unique_ptr<Future>>, std::vector<std::size_t>>
489  test_some(std::vector<std::unique_ptr<Future>>& future_vector) = 0;
490 
497  std::vector<std::size_t> virtual test_some(
498  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
499  future_map
500  ) = 0;
501 
/**
 * @brief Wait for a future to complete and return the data buffer.
 *
 * @param future The future to wait on; ownership is passed to the call.
 * @return The data buffer.
 */
509  [[nodiscard]] virtual std::unique_ptr<Buffer> wait(
510  std::unique_ptr<Future> future
511  ) = 0;
512 
521  [[nodiscard]] std::unique_ptr<Buffer> virtual release_data(
522  std::unique_ptr<Communicator::Future> future
523  ) = 0;
524 
535  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> virtual release_sync_host_data(
536  std::unique_ptr<Communicator::Future> future
537  ) = 0;
538 
/**
 * @brief Retrieves the logger associated with this communicator.
 *
 * @return Reference to the logger.
 */
543  [[nodiscard]] virtual Logger& logger() = 0;
544 
/**
 * @brief Provides a string representation of the communicator.
 *
 * @return The string representation.
 */
549  [[nodiscard]] virtual std::string str() const = 0;
550 };
551 
/// Whether the `RAPIDSMPF_HAVE_UCXX` macro is defined when this header is compiled.
/// Declared `inline` so every translation unit refers to the same single variable.
#ifdef RAPIDSMPF_HAVE_UCXX
inline constexpr bool COMM_HAVE_UCXX = true;
#else
inline constexpr bool COMM_HAVE_UCXX = false;
#endif

/// Whether the `RAPIDSMPF_HAVE_MPI` macro is defined when this header is compiled.
/// Declared `inline` so every translation unit refers to the same single variable.
#ifdef RAPIDSMPF_HAVE_MPI
inline constexpr bool COMM_HAVE_MPI = true;
#else
inline constexpr bool COMM_HAVE_MPI = false;
#endif
565 
576 inline std::ostream& operator<<(std::ostream& os, Communicator const& obj) {
577  os << obj.str();
578  return os;
579 }
580 
581 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:46
Abstract base class for asynchronous operation within the communicator.
Future(Future const &)=delete
Not copyable.
Future & operator=(Future &&)=default
Move assignment.
Future(Future &&)=default
Movable.
Future & operator=(Future const &)=delete
Not copy-assignable.
A logger base class for handling different levels of log messages.
void print(Args const &... args)
Logs a print message.
virtual void do_log(LOG_LEVEL level, std::ostringstream &&ss)
Handles the logging of a message.
virtual std::uint32_t get_thread_id()
Returns a unique thread ID for the current thread.
void debug(Args const &... args)
Logs a debug message.
static constexpr const char * level_name(LOG_LEVEL level)
Get the string name of a log level.
Communicator * get_communicator() const
Get the communicator used by the logger.
Logger(Communicator *comm, config::Options options)
Construct a new logger.
void info(Args const &... args)
Logs an informational message.
static constexpr std::array< char const *, 6 > LOG_LEVEL_NAMES
Log level names corresponding to the LOG_LEVEL enum.
LOG_LEVEL
Log verbosity levels.
void log(LOG_LEVEL level, Args const &... args)
Logs a message using the specified verbosity level.
void warn(Args const &... args)
Logs a warning message.
void trace(Args const &... args)
Logs a trace message.
LOG_LEVEL verbosity_level() const
Get the verbosity level of the logger.
Abstract base class for a communication mechanism between nodes.
virtual Rank nranks() const =0
Retrieves the total number of ranks.
virtual std::string str() const =0
Provides a string representation of the communicator.
virtual std::unique_ptr< std::vector< uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves synchronized host data associated with a completed future. When the future is completed,...
virtual std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves data associated with a completed future.
virtual std::unique_ptr< std::vector< uint8_t > > recv_from(Rank src, Tag tag)=0
Receives a message from a specific rank (blocking).
virtual Rank rank() const =0
Retrieves the rank of the current node.
virtual std::unique_ptr< Future > send(std::unique_ptr< std::vector< uint8_t >> msg, Rank rank, Tag tag)=0
Sends a host message to a specific rank.
virtual std::pair< std::unique_ptr< std::vector< uint8_t > >, Rank > recv_any(Tag tag)=0
Receives a message from any rank (blocking).
virtual std::unique_ptr< Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< uint8_t >> synced_buffer)=0
Receives a message from a specific rank to an allocated (synchronized) host buffer....
virtual std::pair< std::vector< std::unique_ptr< Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Future >> &future_vector)=0
Tests for completion of multiple futures.
virtual std::unique_ptr< Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer)=0
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
virtual Logger & logger()=0
Retrieves the logger associated with this communicator.
virtual std::unique_ptr< Buffer > wait(std::unique_ptr< Future > future)=0
Wait for a future to complete and return the data buffer.
A tag used for identifying messages in a communication operation.
static constexpr StorageT stage_id_mask
Mask for the stage ID.
constexpr Tag(OpID const op, StageID const stage)
Constructs a tag.
constexpr OpID op() const noexcept
Extracts the operation ID from the tag.
static constexpr size_t bit_length() noexcept
Returns the max number of bits used for the tag.
constexpr StageID stage() const noexcept
Extracts the stage ID from the tag.
static constexpr int stage_id_bits
Number of bits for the stage ID.
static constexpr StorageT max_value() noexcept
Returns the max value of the tag.
static constexpr int op_id_bits
Number of bits for the operation ID.
std::int32_t StorageT
The physical data type to store the tag.
static constexpr StorageT op_id_mask
Mask for the operation ID.
Manages configuration options for RapidsMPF operations.
Definition: config.hpp:124
std::ostream & operator<<(std::ostream &os, cuda_stream_view stream)
std::int32_t Rank
Type alias for communicator::Rank.
Definition: types.hpp:14