progress_thread.hpp
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include <chrono>
8 #include <condition_variable>
9 #include <cstdint>
10 #include <mutex>
11 #include <unordered_map>
12 
13 #include <rapidsmpf/pausable_thread_loop.hpp>
14 #include <rapidsmpf/statistics.hpp>
15 
16 namespace rapidsmpf {
17 
27  public:
31  enum ProgressState : bool {
32  InProgress,
33  Done,
34  };
35 
40  using FunctionIndex = std::uint64_t;
41 
46  using ProgressThreadAddress = std::uintptr_t;
47 
52  struct FunctionID {
55  };
57 
66  FunctionID() = default;
67 
74  constexpr FunctionID(ProgressThreadAddress thread_addr, FunctionIndex index)
75  : thread_address(thread_addr), function_index(index) {}
76 
82  [[nodiscard]] constexpr bool is_valid() const {
84  }
85  };
86 
92  using Function = std::function<ProgressState()>;
93 
// Holds a registered Function together with a flag recording whether it has completed.
97 class FunctionState {
98 public:
// Construct state of a function; the function is taken by rvalue reference.
104 explicit FunctionState(Function&& function);
105 
// Execute the function.
// NOTE(review): the body is not visible here; presumably it invokes `function`
// and sets `is_done` once it returns ProgressState::Done — confirm in the .cpp.
111 void operator()();
112 
// The function to execute.
113 Function function;
// Whether the function has completed.
114 bool is_done{false};
115 };
116 
127  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
128  Duration sleep = std::chrono::microseconds{1}
129  );
130 
131  ~ProgressThread();
132 
136  void stop();
137 
149 
160  void remove_function(FunctionID function_id);
161 
167  void pause();
168 
172  void resume();
173 
179  bool is_running() const;
180 
184  std::shared_ptr<Statistics> statistics() const noexcept;
185 
186  private:
193  void event_loop();
194 
195  std::shared_ptr<Statistics> statistics_;
196  bool is_thread_initialized_{false};
197  bool active_{false};
198  mutable std::mutex mutex_;
199  std::condition_variable cv_;
200  FunctionIndex next_function_id_{0};
201  std::unordered_map<FunctionIndex, FunctionState> functions_;
202 
203  // Keep `thread_` as the last member so the progress thread cannot run
204  // before all other members have been fully initialized.
205  detail::PausableThreadLoop thread_;
206 };
207 
208 } // namespace rapidsmpf
FunctionState(Function &&function)
Construct state of a function.
bool is_done
Whether the function has completed.
void operator()()
Execute the function.
A progress thread that can execute arbitrary functions.
FunctionID add_function(Function &&function)
Insert a function to process as part of the event loop.
bool is_running() const
Check if the progress thread is currently running.
void pause()
Pause the progress thread.
void resume()
Resume the progress thread.
std::shared_ptr< Statistics > statistics() const noexcept
Get the Statistics instance associated with this progress thread.
void remove_function(FunctionID function_id)
Remove a function and stop processing it as part of the event loop.
void stop()
Stop the thread, blocking until all functions are done.
std::uint64_t FunctionIndex
The sequential index of a function within a ProgressThread.
std::function< ProgressState()> Function
The function type supported by ProgressThread; each invocation returns the current progress state of the function.
std::uintptr_t ProgressThreadAddress
The address of a ProgressThread instance.
ProgressThread(std::shared_ptr< Statistics > statistics=Statistics::disabled(), Duration sleep=std::chrono::microseconds{1})
Construct a new progress thread that can handle multiple functions.
ProgressState
The progress state of a function: either InProgress or Done.
Tracks statistics across rapidsmpf operations.
Definition: statistics.hpp:62
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
RAPIDS Multi-Processor interfaces.
Definition: backend.hpp:13
std::chrono::duration< double > Duration
Alias for a duration type representing time in seconds as a double.
Definition: misc.hpp:33
The unique ID of a function registered with ProgressThread, composed of the ProgressThread address and the sequential index of the function.
FunctionID()=default
Construct a FunctionID with an invalid address.
constexpr FunctionID(ProgressThreadAddress thread_addr, FunctionIndex index)
Construct a new FunctionID.
constexpr bool is_valid() const
Check if the FunctionID is valid.
FunctionIndex function_index
The sequential index of the function.
ProgressThreadAddress thread_address
The address of the ProgressThread instance.