channel.hpp
1 
6 #pragma once
7 
#include <concepts>
#include <cstddef>
#include <limits>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
13 
14 #include <rapidsmpf/error.hpp>
15 #include <rapidsmpf/streaming/core/message.hpp>
16 #include <rapidsmpf/streaming/core/node.hpp>
17 #include <rapidsmpf/streaming/core/spillable_messages.hpp>
18 
19 #include <coro/coro.hpp>
20 #include <coro/semaphore.hpp>
21 
22 namespace rapidsmpf::streaming {
23 
24 class Context;
25 
/**
 * @brief An awaitable semaphore to manage acquisition and release of finite resources.
 *
 * The template maximum is `std::numeric_limits<std::ptrdiff_t>::max()`, so it does
 * not bound the count in practice; the effective number of slots is the starting
 * count chosen by the user (e.g. `ThrottlingAdaptor` passes its `max_tickets`).
 */
using Semaphore = coro::semaphore<std::numeric_limits<std::ptrdiff_t>::max()>;
30 
37 class Channel {
38  friend Context;
39 
40  public:
50  coro::task<bool> send(Message msg);
51 
60  coro::task<Message> receive();
61 
70  Node drain(std::unique_ptr<coro::thread_pool>& executor);
71 
80 
86  [[nodiscard]] bool empty() const noexcept;
87 
88  private:
89  Channel(std::shared_ptr<SpillableMessages> spillable_messages)
90  : sm_{std::move(spillable_messages)} {}
91 
92  coro::ring_buffer<SpillableMessages::MessageId, 1> rb_;
93  std::shared_ptr<SpillableMessages> sm_;
94 };
95 
104  private:
106  class Ticket {
107  public:
108  Ticket& operator=(Ticket const&) = delete;
109  Ticket(Ticket const&) = delete;
110  Ticket& operator=(Ticket&&) = default;
111  Ticket(Ticket&&) = default;
112  ~Ticket() = default;
113 
153  [[nodiscard]] coro::task<std::pair<bool, coro::task<void>>> send(Message msg) {
154  RAPIDSMPF_EXPECTS(ch_, "Ticket has already been used", std::logic_error);
155  auto sent = co_await ch_->send(std::move(msg));
156  ch_ = nullptr;
157  if (sent) {
158  co_return {sent, semaphore_->release()};
159  } else {
160  // If the channel is closed we want to wake any waiters so shutdown the
161  // semaphore.
162  co_await semaphore_->shutdown();
163  co_return {sent, []() -> coro::task<void> { co_return; }()};
164  }
165  }
166 
173  Ticket(Channel* channel, Semaphore* semaphore)
174  : ch_{channel}, semaphore_{semaphore} {};
175 
176  private:
177  Channel* ch_;
178  Semaphore* semaphore_;
179  };
180 
181  public:
217  std::shared_ptr<Channel> channel, std::ptrdiff_t max_tickets
218  )
219  : ch_{std::move(channel)}, semaphore_(max_tickets) {
220  RAPIDSMPF_EXPECTS(
221  max_tickets > 0, "ThrottlingAdaptor must have at least one ticket"
222  );
223  }
224 
235  [[nodiscard]] coro::task<Ticket> acquire() {
236  auto result = co_await semaphore_.acquire();
237  RAPIDSMPF_EXPECTS(
238  result == coro::semaphore_acquire_result::acquired,
239  "Semaphore was shutdown",
240  std::runtime_error
241  );
242  co_return {ch_.get(), &semaphore_};
243  }
244 
 private:
    /// The throttled channel; tickets hold a non-owning pointer to it.
    std::shared_ptr<Channel> ch_;
    /// Counts available tickets; tickets hold a non-owning pointer to it.
    Semaphore semaphore_;
248 };
249 
263  public:
273  explicit ShutdownAtExit(std::vector<std::shared_ptr<Channel>> channels)
274  : channels_{std::move(channels)} {
275  for (auto& ch : channels_) {
276  RAPIDSMPF_EXPECTS(ch, "channel cannot be null", std::invalid_argument);
277  }
278  }
279 
292  template <class... T>
293  explicit ShutdownAtExit(T&&... channels)
294  requires(std::convertible_to<T, std::shared_ptr<Channel>> && ...)
295  : ShutdownAtExit(
296  std::vector<std::shared_ptr<Channel>>{std::forward<T>(channels)...}
297  ) {}
298 
    // Non-copyable, non-movable.
    // A copy would make two destructors shut down the same channels. Moves are also
    // disabled, presumably to keep the guard tied to the scope that created it.
    ShutdownAtExit(ShutdownAtExit const&) = delete;
    ShutdownAtExit& operator=(ShutdownAtExit const&) = delete;
    ShutdownAtExit(ShutdownAtExit&&) = delete;
    ShutdownAtExit& operator=(ShutdownAtExit&&) = delete;
304 
310  ~ShutdownAtExit() noexcept {
311  for (auto& ch : channels_) {
312  coro::sync_wait(ch->shutdown());
313  }
314  }
315 
 private:
    /// Non-null channels that the destructor shuts down, in this order.
    std::vector<std::shared_ptr<Channel>> channels_;
318 };
319 
320 } // namespace rapidsmpf::streaming
A coroutine-based channel for sending and receiving messages asynchronously.
Definition: channel.hpp:37
bool empty() const noexcept
Check whether the channel is empty.
Node drain(std::unique_ptr< coro::thread_pool > &executor)
Drains all pending messages from the channel and shuts it down.
Node shutdown()
Immediately shuts down the channel.
coro::task< Message > receive()
Asynchronously receive a message from the channel.
coro::task< bool > send(Message msg)
Asynchronously send a message into the channel.
Context for nodes (coroutines) in rapidsmpf.
Definition: context.hpp:25
Type-erased message wrapper around a payload.
Definition: message.hpp:27
Helper RAII class to shut down channels when they go out of scope.
Definition: channel.hpp:262
~ShutdownAtExit() noexcept
Destructor that synchronously shuts down all channels.
Definition: channel.hpp:310
ShutdownAtExit(T &&... channels) requires(std::convertible_to< T, std::shared_ptr< Channel >> &&...)
Variadic convenience constructor.
Definition: channel.hpp:293
ShutdownAtExit(std::vector< std::shared_ptr< Channel >> channels)
Construct from a vector of channel handles.
Definition: channel.hpp:273
Container for individually spillable messages.
An adaptor to throttle access to a channel.
Definition: channel.hpp:103
coro::task< Ticket > acquire()
Obtain a ticket to send a message.
Definition: channel.hpp:235
ThrottlingAdaptor(std::shared_ptr< Channel > channel, std::ptrdiff_t max_tickets)
Create an adaptor that throttles sends into a channel.
Definition: channel.hpp:216
coro::task< void > Node
Alias for a node in a streaming pipeline.
Definition: node.hpp:18
coro::semaphore< std::numeric_limits< std::ptrdiff_t >::max()> Semaphore
An awaitable semaphore to manage acquisition and release of finite resources.
Definition: channel.hpp:29