channel.hpp
1 
6 #pragma once
7 
8 #include <cstddef>
9 #include <limits>
10 #include <memory>
11 #include <stdexcept>
12 #include <utility>
13 
14 #include <coro/coro.hpp>
15 #include <coro/queue.hpp>
16 #include <coro/semaphore.hpp>
17 
18 #include <rapidsmpf/error.hpp>
19 #include <rapidsmpf/streaming/core/actor.hpp>
20 #include <rapidsmpf/streaming/core/coro_executor.hpp>
21 #include <rapidsmpf/streaming/core/message.hpp>
22 #include <rapidsmpf/streaming/core/spillable_messages.hpp>
23 
24 namespace rapidsmpf::streaming {
25 
26 class Context;
27 
31 using Semaphore = coro::semaphore<std::numeric_limits<std::ptrdiff_t>::max()>;
32 
52 class Channel {
53  friend Context;
54 
55  public:
67  [[nodiscard]] coro::task<bool> send(Message msg);
68 
79  [[nodiscard]] coro::task<Message> receive();
80 
98  [[nodiscard]] coro::task<bool> send_metadata(Message msg);
99 
108  [[nodiscard]] coro::task<Message> receive_metadata();
109 
123  [[nodiscard]] Actor drain_metadata(std::shared_ptr<CoroThreadPoolExecutor> executor);
124 
137  [[nodiscard]] Actor drain(std::shared_ptr<CoroThreadPoolExecutor> executor);
138 
147  [[nodiscard]] Actor shutdown();
148 
159  [[nodiscard]] Actor shutdown_metadata();
160 
166  [[nodiscard]] bool empty() const noexcept;
167 
173  [[nodiscard]] bool is_shutdown() const noexcept;
174 
175  private:
176  Channel(std::shared_ptr<SpillableMessages> spillable_messages)
177  : sm_{std::move(spillable_messages)} {}
178 
179  coro::ring_buffer<SpillableMessages::MessageId, 1> rb_;
180  std::shared_ptr<SpillableMessages> sm_;
181  coro::queue<Message> metadata_;
182 };
183 
192  private:
194  class Ticket {
195  public:
196  Ticket& operator=(Ticket const&) = delete;
197  Ticket(Ticket const&) = delete;
198  Ticket& operator=(Ticket&&) = default;
199  Ticket(Ticket&&) = default;
200  ~Ticket() = default;
201 
241  [[nodiscard]] coro::task<std::pair<bool, coro::task<void>>> send(Message msg) {
242  RAPIDSMPF_EXPECTS(ch_, "Ticket has already been used", std::logic_error);
243  auto sent = co_await ch_->send(std::move(msg));
244  ch_ = nullptr;
245  if (sent) {
246  co_return {sent, semaphore_->release()};
247  } else {
248  // If the channel is closed we want to wake any waiters so shutdown the
249  // semaphore.
250  co_await semaphore_->shutdown();
251  co_return {sent, []() -> coro::task<void> { co_return; }()};
252  }
253  }
254 
261  Ticket(Channel* channel, Semaphore* semaphore)
262  : ch_{channel}, semaphore_{semaphore} {};
263 
264  private:
265  Channel* ch_;
266  Semaphore* semaphore_;
267  };
268 
269  public:
305  std::shared_ptr<Channel> channel, std::ptrdiff_t max_tickets
306  )
307  : ch_{std::move(channel)}, semaphore_(max_tickets) {
308  RAPIDSMPF_EXPECTS(
309  max_tickets > 0, "ThrottlingAdaptor must have at least one ticket"
310  );
311  }
312 
323  [[nodiscard]] coro::task<Ticket> acquire() {
324  auto result = co_await semaphore_.acquire();
325  RAPIDSMPF_EXPECTS(
326  result == coro::semaphore_acquire_result::acquired,
327  "Semaphore was shutdown",
328  std::runtime_error
329  );
330  co_return {ch_.get(), &semaphore_};
331  }
332 
333  private:
334  std::shared_ptr<Channel> ch_;
335  Semaphore semaphore_;
336 };
337 
351  public:
361  explicit ShutdownAtExit(std::vector<std::shared_ptr<Channel>> channels)
362  : channels_{std::move(channels)} {
363  for (auto& ch : channels_) {
364  RAPIDSMPF_EXPECTS(ch, "channel cannot be null", std::invalid_argument);
365  }
366  }
367 
380  template <class... T>
381  explicit ShutdownAtExit(T&&... channels)
382  requires(std::convertible_to<T, std::shared_ptr<Channel>> && ...)
383  : ShutdownAtExit(
384  std::vector<std::shared_ptr<Channel>>{std::forward<T>(channels)...}
385  ) {}
386 
387  // Non-copyable, non-movable.
388  ShutdownAtExit(ShutdownAtExit const&) = delete;
389  ShutdownAtExit& operator=(ShutdownAtExit const&) = delete;
390  ShutdownAtExit(ShutdownAtExit&&) = delete;
391  ShutdownAtExit& operator=(ShutdownAtExit&&) = delete;
392 
398  ~ShutdownAtExit() noexcept {
399  for (auto& ch : channels_) {
400  coro::sync_wait(ch->shutdown());
401  }
402  }
403 
404  private:
405  std::vector<std::shared_ptr<Channel>> channels_;
406 };
407 
408 } // namespace rapidsmpf::streaming
A coroutine-based channel for sending and receiving messages asynchronously.
Definition: channel.hpp:52
bool empty() const noexcept
Check whether the channel is empty.
Actor shutdown_metadata()
Immediately shuts down the metadata channel.
coro::task< bool > send_metadata(Message msg)
Asynchronously send a metadata message into the channel.
Actor drain_metadata(std::shared_ptr< CoroThreadPoolExecutor > executor)
Drains all pending metadata messages from the channel and shuts down the metadata channel.
Actor drain(std::shared_ptr< CoroThreadPoolExecutor > executor)
Drains all pending messages from the channel and shuts it down.
bool is_shutdown() const noexcept
Check whether the channel is shut down.
coro::task< Message > receive_metadata()
Asynchronously receive a metadata message from the channel.
coro::task< Message > receive()
Asynchronously receive a message from the channel.
coro::task< bool > send(Message msg)
Asynchronously send a message into the channel.
Actor shutdown()
Immediately shuts down the channel.
Context for actors (coroutines) in rapidsmpf.
Definition: context.hpp:41
Type-erased message wrapper around a payload.
Definition: message.hpp:27
Helper RAII class to shut down channels when they go out of scope.
Definition: channel.hpp:350
~ShutdownAtExit() noexcept
Destructor that synchronously shuts down all channels.
Definition: channel.hpp:398
ShutdownAtExit(T &&... channels) requires(std::convertible_to< T, std::shared_ptr< Channel >> &&...)
Variadic convenience constructor.
Definition: channel.hpp:381
ShutdownAtExit(std::vector< std::shared_ptr< Channel >> channels)
Construct from a vector of channel handles.
Definition: channel.hpp:361
Container for individually spillable messages.
An adaptor to throttle access to a channel.
Definition: channel.hpp:191
coro::task< Ticket > acquire()
Obtain a ticket to send a message.
Definition: channel.hpp:323
ThrottlingAdaptor(std::shared_ptr< Channel > channel, std::ptrdiff_t max_tickets)
Create an adaptor that throttles sends into a channel.
Definition: channel.hpp:304
coro::task< void > Actor
Alias for an actor in a streaming graph.
Definition: actor.hpp:18
coro::semaphore< std::numeric_limits< std::ptrdiff_t >::max()> Semaphore
An awaitable semaphore to manage acquisition and release of finite resources.
Definition: channel.hpp:31