progress_thread.hpp
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include <chrono>
8 #include <condition_variable>
9 #include <cstdint>
10 #include <mutex>
11 #include <unordered_map>
12 
13 #include <rapidsmpf/communicator/communicator.hpp>
14 #include <rapidsmpf/pausable_thread_loop.hpp>
15 #include <rapidsmpf/statistics.hpp>
16 
17 namespace rapidsmpf {
18 
28  public:
/**
 * @brief The progress state of a function, can be either `InProgress` or `Done`.
 *
 * The underlying type is `bool`; `InProgress` is the first enumerator (value
 * `false`) and `Done` is the second (value `true`).
 */
32  enum ProgressState : bool {
33  InProgress,  ///< The function has not completed yet.
34  Done,  ///< The function has completed.
35  };
36 
/**
 * @brief The sequential index of a function within a ProgressThread.
 */
41  using FunctionIndex = std::uint64_t;
42 
/**
 * @brief The address of a ProgressThread instance, stored as an unsigned integer
 * wide enough to hold a pointer value.
 */
47  using ProgressThreadAddress = std::uintptr_t;
48 
53  struct FunctionID {
56  };
58 
67  FunctionID() = default;
68 
75  constexpr FunctionID(ProgressThreadAddress thread_addr, FunctionIndex index)
76  : thread_address(thread_addr), function_index(index) {}
77 
83  [[nodiscard]] constexpr bool is_valid() const {
85  }
86  };
87 
/**
 * @brief The function type supported by ProgressThread: a callable taking no
 * arguments and returning the progress state of the function.
 */
93  using Function = std::function<ProgressState()>;
94 
/**
 * @brief Pairs a `Function` with a flag recording whether it has completed.
 *
 * Instances are the mapped values of the `functions_` container declared in
 * the private section of the enclosing class.
 */
98  class FunctionState {
99  public:
/**
 * @brief Construct state of a function.
 *
 * @param function The callable, accepted as an rvalue reference.
 */
105  explicit FunctionState(Function&& function);
106 
/**
 * @brief Execute the function.
 *
 * NOTE(review): presumably records the returned state in `is_done` — confirm
 * in the implementation file.
 */
112  void operator()();
113 
114  Function function;  ///< The wrapped callable.
115  bool is_done{false};  ///< Whether the function has completed; initially false.
116  };
117 
129  Communicator::Logger& logger,
130  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
131  Duration sleep = std::chrono::microseconds{1}
132  );
133 
/**
 * @brief Destructor; only declared in this header.
 *
 * NOTE(review): presumably stops the thread before destruction — confirm in
 * the implementation file.
 */
134  ~ProgressThread();
135 
/**
 * @brief Stop the thread, blocking until all functions are done.
 */
139  void stop();
140 
152 
/**
 * @brief Remove a function and stop processing it as part of the event loop.
 *
 * @param function_id The ID of the function to remove.
 * NOTE(review): presumably the ID returned when the function was added — confirm.
 */
163  void remove_function(FunctionID function_id);
164 
/**
 * @brief Pause the progress thread.
 */
170  void pause();
171 
/**
 * @brief Resume the progress thread.
 */
175  void resume();
176 
/**
 * @brief Check if the progress thread is currently running.
 *
 * @return The running status as a `bool`.
 * NOTE(review): whether a paused thread counts as running is not visible here — confirm.
 */
182  bool is_running() const;
183 
184  private:
/**
 * @brief The event loop of the progress thread (private).
 *
 * NOTE(review): presumably invokes the registered functions on each iteration —
 * confirm in the implementation file.
 */
191  void event_loop();
192 
// Private state. NOTE(review): listing line 193 is not visible here, so a
// member declared there may be missing from this view.
194  Communicator::Logger& logger_;  // Logger held by reference (not owned).
195  std::shared_ptr<Statistics> statistics_;  // Shared handle to a Statistics instance.
196  bool is_thread_initialized_{false};  // Initially false; presumably set once the thread starts — confirm.
197  bool active_{false};  // Initially false; exact meaning not visible here — confirm.
198  mutable std::mutex mutex_;  // `mutable`, so const member functions can lock it.
199  std::condition_variable cv_;  // Presumably waited on together with mutex_ — confirm.
200  FunctionIndex next_function_id_{0};  // Starts at 0; presumably the index given to the next added function — confirm.
201  std::unordered_map<FunctionIndex, FunctionState> functions_;  // FunctionState entries keyed by FunctionIndex.
202 };
203 
204 } // namespace rapidsmpf
A logger base class for handling different levels of log messages.
FunctionState(Function &&function)
Construct state of a function.
bool is_done
Whether the function has completed.
void operator()()
Execute the function.
A progress thread that can execute arbitrary functions.
FunctionID add_function(Function &&function)
Insert a function to process as part of the event loop.
bool is_running() const
Check if the progress thread is currently running.
void pause()
Pause the progress thread.
void resume()
Resume the progress thread.
ProgressThread(Communicator::Logger &logger, std::shared_ptr< Statistics > statistics=Statistics::disabled(), Duration sleep=std::chrono::microseconds{1})
Construct a new progress thread that can handle multiple functions.
void remove_function(FunctionID function_id)
Remove a function and stop processing it as part of the event loop.
void stop()
Stop the thread, blocking until all functions are done.
std::uint64_t FunctionIndex
The sequential index of a function within a ProgressThread.
std::function< ProgressState()> Function
The function type supported by ProgressThread, returning the progress state of the function.
std::uintptr_t ProgressThreadAddress
The address of a ProgressThread instance.
ProgressState
The progress state of a function, can be either InProgress or Done.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
A thread loop that can be paused, resumed, and stopped.
The unique ID of a function registered with ProgressThread. Composed of the ProgressThread address and the sequential index of the function.
FunctionID()=default
Construct a FunctionID with an invalid address.
constexpr FunctionID(ProgressThreadAddress thread_addr, FunctionIndex index)
Construct a new FunctionID.
constexpr bool is_valid() const
Check if the FunctionID is valid.
FunctionIndex function_index
The sequential index of the function.
ProgressThreadAddress thread_address
The address of the ProgressThread instance.