finish_counter.hpp
1 
5 #pragma once
6 
7 #include <array>
8 #include <chrono>
9 #include <condition_variable>
10 #include <functional>
11 #include <mutex>
12 #include <optional>
13 #include <unordered_map>
14 #include <vector>
15 
16 #include <rapidsmpf/communicator/communicator.hpp>
17 #include <rapidsmpf/shuffler/chunk.hpp>
18 #include <rapidsmpf/shuffler/postbox.hpp>
19 #include <rapidsmpf/utils.hpp>
20 
27 namespace rapidsmpf::shuffler {
28 
36 namespace detail {
37 
46  public:
56  using FinishedCallback = std::function<void(PartID)>;
57 
67  Rank nranks,
68  std::vector<PartID> const& local_partitions,
69  FinishedCallback&& finished_callback = nullptr
70  );
71 
72  ~FinishCounter() = default;
73 
86  void move_goalpost(PartID pid, ChunkID nchunks);
87 
100 
106  [[nodiscard]] bool all_finished() const;
107 
128  PartID wait_any(std::optional<std::chrono::milliseconds> timeout = {});
129 
149  void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
150 
155  [[nodiscard]] std::string str() const;
156 
157  private:
158  Rank const nranks_;
159  PartID
160  n_unfinished_partitions_;
162 
164  struct PartitionInfo {
165  Rank rank_count{0};
166  ChunkID chunk_goal{0};
168  ChunkID finished_chunk_count{
169  0
170  };
172 
173  constexpr PartitionInfo() = default;
174 
175  constexpr void move_goalpost(ChunkID nchunks, Rank nranks) {
176  RAPIDSMPF_EXPECTS(nchunks != 0, "the goalpost was moved by 0 chunks");
177  RAPIDSMPF_EXPECTS(
178  ++rank_count <= nranks, "the goalpost was moved more than one per rank"
179  );
180  chunk_goal += nchunks;
181  }
182 
183  constexpr void add_finished_chunk(Rank nranks) {
184  finished_chunk_count++;
185  // only throw if rank_count == nranks
186  RAPIDSMPF_EXPECTS(
187  (rank_count < nranks) || (finished_chunk_count <= chunk_goal),
188  "finished chunk exceeds the goal"
189  );
190  }
191 
192  // The partition is finished if the goalpost has been set by all ranks
193  // and the number of finished chunks has reached the goal.
194  [[nodiscard]] constexpr bool is_finished(Rank nranks) const {
195  return rank_count == nranks && finished_chunk_count == chunk_goal;
196  }
197 
198  [[nodiscard]] constexpr ChunkID data_chunk_goal() const {
199  // there will always be a control message from each rank indicating how many
200  // chunks it's sending. Chunk goal contains this control message for each
201  // rank. Therefore, to get the data chunk goal, we need to subtract the number
202  // of ranks that have reported their chunk count from the chunk goal.
203  return chunk_goal - static_cast<ChunkID>(rank_count);
204  }
205  };
206 
207  // The goalpost of each partition. The goal is a rank counter to track how many ranks
208  // have reported their goal, and a chunk counter that specifies the goal. It is only
209  // when all ranks have reported their goal that the goalpost is final.
210  std::unordered_map<PartID, PartitionInfo> goalposts_;
211 
212  mutable std::mutex mutex_; // TODO: use a shared_mutex lock?
213  mutable std::condition_variable wait_cv_;
214 
215  FinishedCallback finished_callback_ =
216  nullptr;
217 };
218 
219 } // namespace detail
220 
231 inline std::ostream& operator<<(std::ostream& os, detail::FinishCounter const& obj) {
232  os << obj.str();
233  return os;
234 }
235 
236 } // namespace rapidsmpf::shuffler
Helper to tally the finish status of a shuffle.
void wait_on(PartID pid, std::optional< std::chrono::milliseconds > timeout={})
Wait for a specific partition to be finished (blocking). Optionally a timeout (in ms) can be provided...
std::function< void(PartID)> FinishedCallback
Callback function type called when a partition is finished.
std::string str() const
Returns a description of this instance.
PartID wait_any(std::optional< std::chrono::milliseconds > timeout={})
Returns the partition ID of a finished partition that hasn't been waited on (blocking)....
FinishCounter(Rank nranks, std::vector< PartID > const &local_partitions, FinishedCallback &&finished_callback=nullptr)
Construct a finish counter.
void move_goalpost(PartID pid, ChunkID nchunks)
Move the goalpost for a specific rank and partition.
void add_finished_chunk(PartID pid)
Add a finished chunk to a partition counter.
bool all_finished() const
Returns whether all partitions are finished (non-blocking).
std::uint64_t ChunkID
The globally unique ID of a chunk.
Definition: chunk.hpp:29
Shuffler interfaces.
Definition: chunk.hpp:15
std::uint32_t PartID
Partition ID, which goes from 0 to the total number of partitions.
Definition: chunk.hpp:22
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.