communicator.hpp
1 
5 #pragma once
6 
7 #include <cstdlib>
8 #include <memory>
9 #include <mutex>
10 #include <sstream>
11 #include <stdexcept>
12 #include <thread>
13 #include <unordered_map>
14 #include <vector>
15 
16 #include <rapidsmpf/config.hpp>
17 #include <rapidsmpf/error.hpp>
18 #include <rapidsmpf/memory/buffer.hpp>
19 #include <rapidsmpf/progress_thread.hpp>
20 
25 namespace rapidsmpf {
26 
/// @brief The rank of a node (e.g. the rank of an MPI process), or world size
/// (total number of ranks).
using Rank = std::int32_t;

/// @brief Operation ID defined by the user, used to tell concurrently running
/// operations apart.
using OpID = std::int32_t;

/// @brief Identifier for a stage of a communication operation.
using StageID = std::int32_t;
54 
83 class Tag {
84  public:
89  using StorageT = std::int32_t;
90 
92  static constexpr int stage_id_bits{3};
93 
95  static constexpr StorageT stage_id_mask{(1 << stage_id_bits) - 1};
96 
98  static constexpr int op_id_bits{20};
99 
101  static constexpr StorageT op_id_mask{
102  ((1 << (op_id_bits + stage_id_bits)) - 1) ^ stage_id_mask
103  };
104 
114  constexpr Tag(OpID const op, StageID const stage)
115  : tag_{
116  (static_cast<StorageT>(op) << stage_id_bits) | static_cast<StorageT>(stage)
117  } {
118  RAPIDSMPF_EXPECTS(
119  stage >= 0 && stage < (1 << stage_id_bits),
120  "Invalid stage value",
121  std::overflow_error
122  );
123  RAPIDSMPF_EXPECTS(
124  op >= 0 && op < (1 << op_id_bits), "Invalid OpID value", std::overflow_error
125  );
126  }
127 
132  [[nodiscard]] static constexpr std::size_t bit_length() noexcept {
133  return op_id_bits + stage_id_bits;
134  }
135 
140  [[nodiscard]] static constexpr StorageT max_value() noexcept {
141  return (1 << bit_length()) - 1;
142  }
143 
148  constexpr operator StorageT() const noexcept {
149  return tag_;
150  }
151 
156  [[nodiscard]] constexpr OpID op() const noexcept {
157  return (tag_ & op_id_mask) >> stage_id_bits;
158  }
159 
164  [[nodiscard]] constexpr StageID stage() const noexcept {
165  return tag_ & stage_id_mask;
166  }
167 
168  private:
169  StorageT const tag_;
170 };
171 
189  public:
/**
 * @brief Abstract base class for asynchronous operation within the communicator.
 *
 * A future is move-only: it can be default-constructed and moved, but it can
 * neither be copy-constructed nor copy-assigned.
 */
class Future {
  public:
    Future() = default;

    /// @brief Not copyable.
    Future(Future const&) = delete;

    /// @brief Not copy-assignable.
    Future& operator=(Future const&) = delete;

    /// @brief Movable.
    Future(Future&&) = default;

    /**
     * @brief Move assignment.
     *
     * @return A reference to this object.
     */
    Future& operator=(Future&&) = default;

    virtual ~Future() noexcept = default;
};
210 
220  class Logger {
221  public:
/**
 * @brief Log verbosity levels.
 *
 * The numeric values increase with verbosity; `log()` compares them to decide
 * whether a message is emitted. The order must match `LOG_LEVEL_NAMES`.
 */
enum class LOG_LEVEL : std::uint32_t {
    NONE = 0,
    PRINT = 1,
    WARN = 2,
    INFO = 3,
    DEBUG = 4,
    TRACE = 5
};
235 
/// @brief Log level names corresponding to the LOG_LEVEL enum, indexed by the
/// enumerator's numeric value.
static constexpr std::array<char const*, 6> LOG_LEVEL_NAMES{
    {"NONE", "PRINT", "WARN", "INFO", "DEBUG", "TRACE"}
};
242 
249  static constexpr char const* level_name(LOG_LEVEL level) {
250  auto index = static_cast<std::size_t>(level);
251  return index < LOG_LEVEL_NAMES.size() ? LOG_LEVEL_NAMES[index] : "UNKNOWN";
252  }
253 
270  virtual ~Logger() noexcept = default;
271 
278  return level_;
279  }
280 
290  template <typename... Args>
291  void log(LOG_LEVEL level, Args const&... args) {
292  if (static_cast<std::uint32_t>(level_) < static_cast<std::uint32_t>(level)) {
293  return;
294  }
295  std::ostringstream ss;
296  (ss << ... << args);
297  do_log(level, std::move(ss));
298  }
299 
306  template <typename... Args>
307  void print(Args const&... args) {
308  log(LOG_LEVEL::PRINT, std::forward<Args const&>(args)...);
309  }
310 
317  template <typename... Args>
318  void warn(Args const&... args) {
319  log(LOG_LEVEL::WARN, std::forward<Args const&>(args)...);
320  }
321 
328  template <typename... Args>
329  void info(Args const&... args) {
330  log(LOG_LEVEL::INFO, std::forward<Args const&>(args)...);
331  }
332 
339  template <typename... Args>
340  void debug(Args const&... args) {
341  log(LOG_LEVEL::DEBUG, std::forward<Args const&>(args)...);
342  }
343 
350  template <typename... Args>
351  void trace(Args const&... args) {
352  log(LOG_LEVEL::TRACE, std::forward<Args const&>(args)...);
353  }
354 
355  protected:
361  virtual std::uint32_t get_thread_id() {
362  auto const tid = std::this_thread::get_id();
363 
364  // To avoid large IDs, we map the thread ID to an unique counter.
365  auto const [name, inserted] =
366  thread_id_names.insert({tid, thread_id_names_counter});
367  if (inserted) {
368  ++thread_id_names_counter;
369  }
370  return name->second;
371  }
372 
384  virtual void do_log(LOG_LEVEL level, std::ostringstream&& ss) {
385  std::ostringstream full_log_msg;
386  full_log_msg << "[" << level_name(level) << ":" << rank_ << ":"
387  << get_thread_id() << ":" << Clock::now() << "] " << ss.str();
388  std::lock_guard<std::mutex> lock(mutex_);
389  std::cout << full_log_msg.str() << std::endl;
390  }
391 
392  private:
393  std::mutex mutex_;
394  Rank rank_;
395  LOG_LEVEL const level_;
396 
399  std::uint32_t thread_id_names_counter{0};
400 
402  std::unordered_map<std::thread::id, std::uint32_t> thread_id_names;
403  };
404 
405  protected:
406  Communicator() = default;
407 
408  public:
409  virtual ~Communicator() noexcept = default;
410 
415  [[nodiscard]] virtual Rank rank() const = 0;
416 
421  [[nodiscard]] virtual Rank nranks() const = 0;
422 
439  [[nodiscard]] virtual std::unique_ptr<Future> send(
440  std::unique_ptr<std::vector<std::uint8_t>> msg, Rank rank, Tag tag
441  ) = 0;
442 
461  [[nodiscard]] virtual std::unique_ptr<Future> send(
462  std::unique_ptr<Buffer> msg, Rank rank, Tag tag
463  ) = 0;
464 
484  [[nodiscard]] virtual std::unique_ptr<Future> recv(
485  Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
486  ) = 0;
487 
500  [[nodiscard]] virtual std::unique_ptr<Future> recv_sync_host_data(
501  Rank rank, Tag tag, std::unique_ptr<std::vector<std::uint8_t>> synced_buffer
502  ) = 0;
503 
513  [[nodiscard]] virtual std::pair<std::unique_ptr<std::vector<std::uint8_t>>, Rank>
514  recv_any(Tag tag) = 0;
515 
526  [[nodiscard]] virtual std::unique_ptr<std::vector<std::uint8_t>> recv_from(
527  Rank src, Tag tag
528  ) = 0;
529 
537  [[nodiscard]] virtual std::
538  pair<std::vector<std::unique_ptr<Future>>, std::vector<std::size_t>>
539  test_some(std::vector<std::unique_ptr<Future>>& future_vector) = 0;
540 
547  std::vector<std::size_t> virtual test_some(
548  std::unordered_map<std::size_t, std::unique_ptr<Communicator::Future>> const&
549  future_map
550  ) = 0;
551 
559  [[nodiscard]] virtual bool test(std::unique_ptr<Communicator::Future>& future) = 0;
560 
567  [[nodiscard]] virtual std::vector<std::unique_ptr<Buffer>> wait_all(
568  std::vector<std::unique_ptr<Communicator::Future>>&& futures
569  ) = 0;
570 
578  [[nodiscard]] virtual std::unique_ptr<Buffer> wait(
579  std::unique_ptr<Future> future
580  ) = 0;
581 
590  [[nodiscard]] std::unique_ptr<Buffer> virtual release_data(
591  std::unique_ptr<Communicator::Future> future
592  ) = 0;
593 
604  [[nodiscard]] std::
605  unique_ptr<std::vector<std::uint8_t>> virtual release_sync_host_data(
606  std::unique_ptr<Communicator::Future> future
607  ) = 0;
608 
613  [[nodiscard]] virtual std::shared_ptr<Communicator::Logger> const& logger() = 0;
614 
619  [[nodiscard]] virtual std::shared_ptr<ProgressThread> const&
620  progress_thread() const = 0;
621 
626  [[nodiscard]] virtual std::string str() const = 0;
627 };
628 
/// @brief Whether RapidsMPF was built with the UCXX Communicator.
///
/// Declared `inline` so that every translation unit including this header
/// refers to the same entity instead of a private internal-linkage copy.
#ifdef RAPIDSMPF_HAVE_UCXX
inline constexpr bool COMM_HAVE_UCXX = true;
#else
inline constexpr bool COMM_HAVE_UCXX = false;
#endif
635 
/// @brief Whether RapidsMPF was built with the MPI Communicator.
///
/// Declared `inline` so that every translation unit including this header
/// refers to the same entity instead of a private internal-linkage copy.
#ifdef RAPIDSMPF_HAVE_MPI
inline constexpr bool COMM_HAVE_MPI = true;
#else
inline constexpr bool COMM_HAVE_MPI = false;
#endif
642 
653 inline std::ostream& operator<<(std::ostream& os, Communicator const& obj) {
654  os << obj.str();
655  return os;
656 }
657 
658 } // namespace rapidsmpf
Buffer representing device or host memory.
Definition: buffer.hpp:47
Abstract base class for asynchronous operation within the communicator.
Future(Future const &)=delete
Not copyable.
Future & operator=(Future &&)=default
Move assignment.
Future(Future &&)=default
Movable.
Future & operator=(Future const &)=delete
Not copy-assignable.
A logger base class for handling different levels of log messages.
void print(Args const &... args)
Logs a print message.
Logger(Rank rank, config::Options options)
Construct a new logger.
virtual void do_log(LOG_LEVEL level, std::ostringstream &&ss)
Handles the logging of a message.
virtual std::uint32_t get_thread_id()
Returns a unique thread ID for the current thread.
void debug(Args const &... args)
Logs a debug message.
void info(Args const &... args)
Logs an informational message.
static constexpr std::array< char const *, 6 > LOG_LEVEL_NAMES
Log level names corresponding to the LOG_LEVEL enum.
LOG_LEVEL
Log verbosity levels.
void log(LOG_LEVEL level, Args const &... args)
Logs a message using the specified verbosity level.
void warn(Args const &... args)
Logs a warning message.
static constexpr char const * level_name(LOG_LEVEL level)
Get the string name of a log level.
void trace(Args const &... args)
Logs a trace message.
LOG_LEVEL verbosity_level() const
Get the verbosity level of the logger.
Abstract base class for a communication mechanism between nodes.
virtual std::unique_ptr< std::vector< std::uint8_t > > recv_from(Rank src, Tag tag)=0
Receives a message from a specific rank (blocking).
virtual Rank nranks() const =0
Retrieves the total number of ranks.
virtual std::string str() const =0
Provides a string representation of the communicator.
virtual std::unique_ptr< Buffer > release_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves data associated with a completed future.
virtual Rank rank() const =0
Retrieves the rank of the current node.
virtual std::shared_ptr< ProgressThread > const & progress_thread() const =0
Retrieves the progress thread associated with this communicator.
virtual bool test(std::unique_ptr< Communicator::Future > &future)=0
Test for completion of a single future.
virtual std::unique_ptr< Future > recv_sync_host_data(Rank rank, Tag tag, std::unique_ptr< std::vector< std::uint8_t >> synced_buffer)=0
Receives a message from a specific rank to an allocated (synchronized) host buffer....
virtual std::vector< std::unique_ptr< Buffer > > wait_all(std::vector< std::unique_ptr< Communicator::Future >> &&futures)=0
Wait for completion of all futures and return their data buffers.
virtual std::pair< std::vector< std::unique_ptr< Future > >, std::vector< std::size_t > > test_some(std::vector< std::unique_ptr< Future >> &future_vector)=0
Tests for completion of multiple futures.
virtual std::unique_ptr< std::vector< std::uint8_t > > release_sync_host_data(std::unique_ptr< Communicator::Future > future)=0
Retrieves synchronized host data associated with a completed future. When the future is completed,...
virtual std::unique_ptr< Future > recv(Rank rank, Tag tag, std::unique_ptr< Buffer > recv_buffer)=0
Receives a message from a specific rank to a buffer. Use release_data to extract the data out of the ...
virtual std::unique_ptr< Buffer > wait(std::unique_ptr< Future > future)=0
Wait for a future to complete and return the data buffer.
virtual std::shared_ptr< Communicator::Logger > const & logger()=0
Retrieves the logger associated with this communicator.
virtual std::unique_ptr< Future > send(std::unique_ptr< std::vector< std::uint8_t >> msg, Rank rank, Tag tag)=0
Sends a host message to a specific rank.
virtual std::pair< std::unique_ptr< std::vector< std::uint8_t > >, Rank > recv_any(Tag tag)=0
Receives a message from any rank (blocking).
A progress thread that can execute arbitrary functions.
A tag used for identifying messages in a communication operation.
static constexpr StorageT stage_id_mask
Mask for the stage ID.
constexpr Tag(OpID const op, StageID const stage)
Constructs a tag.
constexpr OpID op() const noexcept
Extracts the operation ID from the tag.
constexpr StageID stage() const noexcept
Extracts the stage ID from the tag.
static constexpr int stage_id_bits
Number of bits for the stage ID.
static constexpr StorageT max_value() noexcept
Returns the max value of the tag.
static constexpr int op_id_bits
Number of bits for the operation ID.
static constexpr std::size_t bit_length() noexcept
Returns the max number of bits used for the tag.
std::int32_t StorageT
The physical data type to store the tag.
static constexpr StorageT op_id_mask
Mask for the operation ID.
Manages configuration options for RapidsMPF operations.
Definition: config.hpp:140
RAPIDS Multi-Processor interfaces.
Definition: backend.hpp:13
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t StageID
Identifier for a stage of a communication operation.
constexpr bool COMM_HAVE_MPI
Whether RapidsMPF was built with the MPI Communicator.
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
std::ostream & operator<<(std::ostream &os, Communicator const &obj)
Overloads the stream insertion operator for the Communicator class.
constexpr bool COMM_HAVE_UCXX
Whether RapidsMPF was built with the UCXX Communicator.