communicator.hpp
1 
5 #pragma once
6 
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
14 
15 #include <rapidsmpf/buffer/buffer.hpp>
16 #include <rapidsmpf/buffer/resource.hpp>
17 #include <rapidsmpf/config.hpp>
18 #include <rapidsmpf/error.hpp>
19 
24 namespace rapidsmpf {
25 
/**
 * @brief The rank of a node within a communicator.
 */
using Rank = std::int32_t;

/**
 * @brief Operation identifier; stored in a `Tag` in the bits just above the
 * `StageID`.
 */
using OpID = std::uint8_t;

/**
 * @brief Stage identifier; stored in the lowest bits of a `Tag`.
 */
using StageID = std::uint8_t;

/**
 * @brief A tag used for identifying messages in a communication operation.
 *
 * Packs an `OpID` and a `StageID` into a single integer:
 *
 *     tag = (op << stage_id_bits) | stage
 *
 * so the stage ID occupies the low `stage_id_bits` bits and the operation ID the
 * next `op_id_bits` bits. A `Tag` converts implicitly to `StorageT`.
 */
class Tag {
  public:
    /// @brief The physical data type to store the tag.
    using StorageT = std::int32_t;

    /// @brief Number of bits for the stage ID.
    static constexpr int stage_id_bits{sizeof(StageID) * 8};

    /// @brief Mask for the stage ID.
    static constexpr StorageT stage_id_mask{(1 << stage_id_bits) - 1};

    /// @brief Number of bits for the operation ID.
    static constexpr int op_id_bits{sizeof(OpID) * 8};

    // The masks and max_value() compute `(1 << width) - 1` on a signed int. That
    // is only well defined while `width` stays below the sign bit, so fail with a
    // readable message if OpID/StageID are ever widened too far.
    static_assert(
        op_id_bits + stage_id_bits < static_cast<int>(sizeof(StorageT) * 8) - 1,
        "OpID and StageID together are too wide for Tag::StorageT"
    );

    /// @brief Mask for the operation ID.
    static constexpr StorageT op_id_mask{
        ((1 << (op_id_bits + stage_id_bits)) - 1) ^ stage_id_mask
    };

    /**
     * @brief Constructs a tag.
     *
     * @param op The operation ID.
     * @param stage The stage ID.
     */
    constexpr Tag(OpID const op, StageID const stage)
        : tag_{
              (static_cast<StorageT>(op) << stage_id_bits) | static_cast<StorageT>(stage)
          } {}

    /**
     * @brief Returns the max number of bits used for the tag.
     *
     * @return `op_id_bits + stage_id_bits`.
     */
    [[nodiscard]] static constexpr std::size_t bit_length() noexcept {
        return static_cast<std::size_t>(op_id_bits + stage_id_bits);
    }

    /**
     * @brief Returns the max value of the tag.
     *
     * @return `2^bit_length() - 1`, i.e. the value with every op and stage bit set.
     */
    [[nodiscard]] static constexpr StorageT max_value() noexcept {
        return (1 << bit_length()) - 1;
    }

    /**
     * @brief Returns the packed tag value.
     *
     * @return The tag as a `StorageT`.
     */
    constexpr operator StorageT() const noexcept {
        return tag_;
    }

    /**
     * @brief Extracts the operation ID from the tag.
     *
     * @return The operation ID.
     */
    [[nodiscard]] constexpr OpID op() const noexcept {
        // Explicit cast: the masked and shifted value always fits in OpID.
        return static_cast<OpID>((tag_ & op_id_mask) >> stage_id_bits);
    }

    /**
     * @brief Extracts the stage ID from the tag.
     *
     * @return The stage ID.
     */
    [[nodiscard]] constexpr StageID stage() const noexcept {
        // Explicit cast: the masked value always fits in StageID.
        return static_cast<StageID>(tag_ & stage_id_mask);
    }

  private:
    // Deliberately not `const`: a const data member deletes copy/move assignment,
    // which would stop a Tag from being reassigned or stored in containers that
    // require assignable elements. Immutability is still ensured by `private`.
    StorageT tag_;
};
133 
142  public:
149  class Future {
150  public:
151  Future() = default;
152  virtual ~Future() noexcept = default;
153  Future(Future&&) = default;
159  Future& operator=(Future&&) = default;
160  Future(Future const&) = delete;
161  Future& operator=(Future const&) = delete;
162  };
163 
173  class Logger {
174  public:
180  enum class LOG_LEVEL : std::uint32_t {
181  NONE = 0,
182  PRINT,
183  WARN,
184  INFO,
185  DEBUG,
186  TRACE
187  };
188 
192  static constexpr std::array<char const*, 6> LOG_LEVEL_NAMES{
193  "NONE", "PRINT", "WARN", "INFO", "DEBUG", "TRACE"
194  };
195 
202  static constexpr const char* level_name(LOG_LEVEL level) {
203  auto index = static_cast<std::size_t>(level);
204  return index < LOG_LEVEL_NAMES.size() ? LOG_LEVEL_NAMES[index] : "UNKNOWN";
205  }
206 
223  virtual ~Logger() noexcept = default;
224 
231  return level_;
232  }
233 
243  template <typename... Args>
244  void log(LOG_LEVEL level, Args const&... args) {
245  if (static_cast<std::uint32_t>(level_) < static_cast<std::uint32_t>(level)) {
246  return;
247  }
248  std::ostringstream ss;
249  (ss << ... << args);
250  do_log(level, std::move(ss));
251  }
252 
259  template <typename... Args>
260  void print(Args const&... args) {
261  log(LOG_LEVEL::PRINT, std::forward<Args const&>(args)...);
262  }
263 
270  template <typename... Args>
271  void warn(Args const&... args) {
272  log(LOG_LEVEL::WARN, std::forward<Args const&>(args)...);
273  }
274 
281  template <typename... Args>
282  void info(Args const&... args) {
283  log(LOG_LEVEL::INFO, std::forward<Args const&>(args)...);
284  }
285 
292  template <typename... Args>
293  void debug(Args const&... args) {
294  log(LOG_LEVEL::DEBUG, std::forward<Args const&>(args)...);
295  }
296 
303  template <typename... Args>
304  void trace(Args const&... args) {
305  log(LOG_LEVEL::TRACE, std::forward<Args const&>(args)...);
306  }
307 
308  protected:
314  virtual std::uint32_t get_thread_id() {
315  auto const tid = std::this_thread::get_id();
316 
317  // To avoid large IDs, we map the thread ID to an unique counter.
318  auto const [name, inserted] =
319  thread_id_names.insert({tid, thread_id_names_counter});
320  if (inserted) {
321  ++thread_id_names_counter;
322  }
323  return name->second;
324  }
325 
337  virtual void do_log(LOG_LEVEL level, std::ostringstream&& ss) {
338  std::ostringstream full_log_msg;
339  full_log_msg << "[" << level_name(level) << ":" << comm_->rank() << ":"
340  << get_thread_id() << "] " << ss.str();
341  std::lock_guard<std::mutex> lock(mutex_);
342  std::cout << full_log_msg.str() << std::endl;
343  }
344 
351  return comm_;
352  }
353 
354  private:
355  std::mutex mutex_;
356  Communicator* comm_;
357  LOG_LEVEL const level_;
358 
361  std::uint32_t thread_id_names_counter{0};
362 
364  std::unordered_map<std::thread::id, std::uint32_t> thread_id_names;
365  };
366 
367  protected:
368  Communicator() = default;
369 
370  public:
371  virtual ~Communicator() noexcept = default;
372 
377  [[nodiscard]] virtual Rank rank() const = 0;
378 
383  [[nodiscard]] virtual Rank nranks() const = 0;
384 
396  [[nodiscard]] virtual std::unique_ptr<Future> send(
397  std::unique_ptr<std::vector<uint8_t>> msg, Rank rank, Tag tag
398  ) = 0;
399 
414  [[nodiscard]] virtual std::unique_ptr<Future> send(
415  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
416  ) = 0;
417 
433  [[nodiscard]] virtual std::unique_ptr<Future> recv(
434  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
435  ) = 0;
436 
447  [[nodiscard]] virtual std::unique_ptr<Future> recv_sync_host_data(
448  Rank rank, Tag tag, std::unique_ptr<std::vector<uint8_t>> synced_buffer
449  ) = 0;
450 
460  [[nodiscard]] virtual std::pair<std::unique_ptr<std::vector<uint8_t>>, Rank> recv_any(
461  Tag tag
462  ) = 0;
463 
474  [[nodiscard]] virtual std::unique_ptr<std::vector<uint8_t>> recv_from(
475  Rank src, Tag tag
476  ) = 0;
477 
485  [[nodiscard]] virtual std::
486  pair<std::vector<std::unique_ptr<Future>>, std::vector<std::size_t>>
487  test_some(std::vector<std::unique_ptr<Future>>& future_vector) = 0;
488 
495  std::vector<std::size_t> virtual test_some(
496  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
497  future_map
498  ) = 0;
499 
507  [[nodiscard]] virtual std::unique_ptr<Buffer> wait(
508  std::unique_ptr<Future> future
509  ) = 0;
510 
519  [[nodiscard]] std::unique_ptr<Buffer> virtual release_data(
520  std::unique_ptr<Communicator::Future> future
521  ) = 0;
522 
533  [[nodiscard]] std::unique_ptr<std::vector<uint8_t>> virtual release_sync_host_data(
534  std::unique_ptr<Communicator::Future> future
535  ) = 0;
536 
541  [[nodiscard]] virtual Logger& logger() = 0;
542 
547  [[nodiscard]] virtual std::string str() const = 0;
548 };
549 
// These flags live in a header. `inline` (C++17) makes each one a single entity
// shared by all translation units. A plain namespace-scope `constexpr` variable
// has internal linkage, so every TU would get its own copy.

/// @brief Compile-time flag mirroring whether `RAPIDSMPF_HAVE_UCXX` is defined.
#ifdef RAPIDSMPF_HAVE_UCXX
inline constexpr bool COMM_HAVE_UCXX = true;
#else
inline constexpr bool COMM_HAVE_UCXX = false;
#endif

/// @brief Compile-time flag mirroring whether `RAPIDSMPF_HAVE_MPI` is defined.
#ifdef RAPIDSMPF_HAVE_MPI
inline constexpr bool COMM_HAVE_MPI = true;
#else
inline constexpr bool COMM_HAVE_MPI = false;
#endif
563 
574 inline std::ostream& operator<<(std::ostream& os, Communicator const& obj) {
575  os << obj.str();
576  return os;
577 }
578 
579 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:53
Abstract base class for an asynchronous operation within the communicator.
Future(Future const &)=delete
Not copyable.
Future & operator=(Future &&)=default
Move assignment.
Future(Future &&)=default
Movable.
Future & operator=(Future const &)=delete
Not copy-assignable.
A logger base class for handling different levels of log messages.
void print(Args const &... args)
Logs a print message.
virtual void do_log(LOG_LEVEL level, std::ostringstream &&ss)
Handles the logging of a message.
virtual std::uint32_t get_thread_id()
Returns a unique thread ID for the current thread.
void debug(Args const &... args)
Logs a debug message.
static constexpr const char * level_name(LOG_LEVEL level)
Get the string name of a log level.
Communicator * get_communicator() const
Get the communicator used by the logger.
Logger(Communicator *comm, config::Options options)
Construct a new logger.
void info(Args const &... args)
Logs an informational message.
static constexpr std::array< char const *, 6 > LOG_LEVEL_NAMES
Log level names corresponding to the LOG_LEVEL enum.
LOG_LEVEL
Log verbosity levels.
void log(LOG_LEVEL level, Args const &... args)
Logs a message using the specified verbosity level.
void warn(Args const &... args)
Logs a warning message.
void trace(Args const &... args)
Logs a trace message.
LOG_LEVEL verbosity_level() const
Get the verbosity level of the logger.
Abstract base class for a communication mechanism between nodes.
virtual Rank nranks() const =0
Retrieves the total number of ranks.
virtual std::string str() const =0
Provides a string representation of the communicator.
virtual std::unique_ptr< std::vector< uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves synchronized host data associated with a completed future. When the future is completed,...
virtual std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves data associated with a completed future.
virtual std::unique_ptr< std::vector< uint8_t > > recv_from(Rank src, Tag tag)=0
Receives a message from a specific rank (blocking).
virtual Rank rank() const =0
Retrieves the rank of the current node.
virtual std::unique_ptr< Future > send(std::unique_ptr< std::vector< uint8_t >> msg, Rank rank, Tag tag)=0
Sends a host message to a specific rank.
virtual std::pair< std::unique_ptr< std::vector< uint8_t > >, Rank > recv_any(Tag tag)=0
Receives a message from any rank (blocking).
virtual std::unique_ptr< Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< uint8_t >> synced_buffer)=0
Receives a message from a specific rank to an allocated (synchronized) host buffer....
virtual std::pair< std::vector< std::unique_ptr< Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Future >> &future_vector)=0
Tests for completion of multiple futures.
virtual std::unique_ptr< Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer)=0
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
virtual Logger & logger()=0
Retrieves the logger associated with this communicator.
virtual std::unique_ptr< Buffer > wait(std::unique_ptr< Future > future)=0
Wait for a future to complete and return the data buffer.
A tag used for identifying messages in a communication operation.
static constexpr StorageT stage_id_mask
Mask for the stage ID.
constexpr Tag(OpID const op, StageID const stage)
Constructs a tag.
constexpr OpID op() const noexcept
Extracts the operation ID from the tag.
static constexpr size_t bit_length() noexcept
Returns the max number of bits used for the tag.
constexpr StageID stage() const noexcept
Extracts the stage ID from the tag.
static constexpr int stage_id_bits
Number of bits for the stage ID.
static constexpr StorageT max_value() noexcept
Returns the max value of the tag.
static constexpr int op_id_bits
Number of bits for the operation ID.
std::int32_t StorageT
The physical data type to store the tag.
static constexpr StorageT op_id_mask
Mask for the operation ID.
Manages configuration options for RapidsMPF operations.
Definition: config.hpp:124
std::ostream & operator<<(std::ostream &os, cuda_stream_view stream)
std::int32_t Rank
Type alias for communicator::Rank.
Definition: bootstrap.hpp:20