worker.h
1 
5 #pragma once
6 
7 #include <functional>
8 #include <memory>
9 #include <mutex>
10 #include <queue>
11 #include <string>
12 #include <thread>
13 #include <utility>
14 #include <variant>
15 
16 #include <ucp/api/ucp.h>
17 
18 #include <ucxx/component.h>
19 #include <ucxx/constructors.h>
20 #include <ucxx/context.h>
21 #include <ucxx/delayed_submission.h>
22 #include <ucxx/future.h>
23 #include <ucxx/inflight_requests.h>
24 #include <ucxx/notifier.h>
25 #include <ucxx/typedefs.h>
26 #include <ucxx/worker_progress_thread.h>
27 
28 namespace ucxx {
29 
30 namespace experimental {
31 class WorkerBuilder;
32 } // namespace experimental
33 
34 class Address;
35 class Buffer;
36 class Endpoint;
37 class Listener;
38 class RequestAm;
39 
40 namespace internal {
41 class AmData;
42 } // namespace internal
43 
50 class Worker : public Component {
51  private:
52  ucp_worker_h _handle{nullptr};
53  int _epollFileDescriptor{-1};
54  int _workerFileDescriptor{-1};
55  std::mutex _inflightRequestsMutex{};
56  std::unique_ptr<InflightRequests> _inflightRequests{
57  std::make_unique<InflightRequests>()};
58  std::mutex
59  _inflightRequestsToCancelMutex{};
60  std::unique_ptr<InflightRequests> _inflightRequestsToCancel{
61  std::make_unique<InflightRequests>()};
62  WorkerProgressThread _progressThread{};
63  std::thread::id _progressThreadId{};
64  std::function<void(void*)> _progressThreadStartCallback{
65  nullptr};
66  void* _progressThreadStartCallbackArg{
67  nullptr};
68  std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
69  nullptr};
70 
71  friend std::shared_ptr<RequestAm> createRequestAm(
72  std::shared_ptr<Endpoint> endpoint,
73  const std::variant<data::AmSend, data::AmReceive> requestData,
74  const bool enablePythonFuture,
75  RequestCallbackUserFunction callbackFunction,
76  RequestCallbackUserData callbackData);
77 
78  protected:
80  false};
81  std::mutex _futuresPoolMutex{};
82  std::queue<std::shared_ptr<Future>>
84  std::shared_ptr<Notifier> _notifier{nullptr};
85  std::shared_ptr<internal::AmData>
87 
88  private:
95  void drainWorkerTagRecv();
96 
111  [[nodiscard]] std::shared_ptr<RequestAm> getAmRecv(
112  ucp_ep_h ep, std::function<std::shared_ptr<RequestAm>()> createAmRecvRequestFunction);
113 
120  void stopProgressThreadNoWarn();
121 
132  [[nodiscard]] std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);
133 
141  bool progressPending();
142 
143  protected:
161  explicit Worker(std::shared_ptr<Context> context,
162  const bool enableDelayedSubmission = false,
163  const bool enableFuture = false);
164 
165  public:
166  Worker() = delete;
167  Worker(const Worker&) = delete;
168  Worker& operator=(Worker const&) = delete;
169  Worker(Worker&& o) = delete;
170  Worker& operator=(Worker&& o) = delete;
171 
178  friend std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
179  const bool enableDelayedSubmission,
180  const bool enableFuture);
181 
186 
190  virtual ~Worker();
191 
207  [[nodiscard]] ucp_worker_h getHandle();
208 
217  [[nodiscard]] std::string getInfo();
218 
248 
265 
276  bool arm();
277 
305  bool progressWorkerEvent(const int epollTimeout = -1);
306 
341  void signal();
342 
364  bool waitProgress();
365 
380  bool progressOnce();
381 
397  bool progress();
398 
415  void registerDelayedSubmission(std::shared_ptr<Request> request,
417 
446  uint64_t period = 0);
447 
475  uint64_t period = 0);
476 
484  [[nodiscard]] bool isDelayedRequestSubmissionEnabled() const;
485 
493  [[nodiscard]] bool isFutureEnabled() const;
494 
506  virtual void populateFuturesPool();
507 
517  virtual void clearFuturesPool();
518 
530  [[nodiscard]] virtual std::shared_ptr<Future> getFuture();
531 
546  [[nodiscard]] virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs);
547 
560  virtual void runRequestNotifier();
561 
570 
581  void setProgressThreadStartCallback(std::function<void(void*)> callback, void* callbackArg);
582 
595  void startProgressThread(const bool pollingMode = false, const int epollTimeout = 1);
596 
606 
614  [[nodiscard]] bool isProgressThreadRunning();
615 
623  [[nodiscard]] std::thread::id getProgressThreadId();
624 
645  size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);
646 
660 
670  void removeInflightRequest(std::shared_ptr<Request> request);
671 
714  [[nodiscard]] std::shared_ptr<TagProbeInfo> tagProbe(const Tag tag,
715  const TagMask tagMask = TagMaskFull,
716  const bool remove = false) const;
717 
749  [[nodiscard]] std::shared_ptr<Request> tagRecv(
750  void* buffer,
751  size_t length,
752  Tag tag,
753  TagMask tagMask,
754  const bool enableFuture = false,
755  RequestCallbackUserFunction callbackFunction = nullptr,
756  RequestCallbackUserData callbackData = nullptr);
757 
781  [[nodiscard]] std::shared_ptr<Request> tagRecvWithHandle(
782  void* buffer,
783  std::shared_ptr<TagProbeInfo> probeInfo,
784  const bool enableFuture = false,
785  RequestCallbackUserFunction callbackFunction = nullptr,
786  RequestCallbackUserData callbackData = nullptr);
787 
799  [[nodiscard]] std::shared_ptr<Address> getAddress();
800 
825  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromHostname(
826  std::string ipAddress, uint16_t port, bool endpointErrorHandling = true);
827 
856  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(
857  std::shared_ptr<Address> address, bool endpointErrorHandling = true);
858 
876  [[nodiscard]] std::shared_ptr<Listener> createListener(uint16_t port,
877  ucp_listener_conn_callback_t callback,
878  void* callbackArgs);
879 
907  void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator);
908 
944 
964  [[nodiscard]] bool amProbe(const ucp_ep_h endpointHandle) const;
965 
993  [[nodiscard]] std::shared_ptr<Request> flush(
994  const bool enablePythonFuture = false,
995  RequestCallbackUserFunction callbackFunction = nullptr,
996  RequestCallbackUserData callbackData = nullptr);
997 };
998 
1020 std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
1021  const bool enableDelayedSubmission,
1022  const bool enableFuture);
1023 
1024 } // namespace ucxx
1025 
1026 // Include experimental features
1027 #include <ucxx/experimental/worker_builder.h>
Information about an Active Message receiver callback.
Definition: typedefs.h:160
A UCXX component class to prevent early destruction of parent object.
Definition: component.h:17
A thread to progress a ucxx::Worker.
Definition: worker_progress_thread.h:48
Component encapsulating a UCP worker.
Definition: worker.h:50
bool amProbe(const ucp_ep_h endpointHandle) const
Check for uncaught active messages.
int getEpollFileDescriptor()
Get the epoll file descriptor associated with the worker.
bool progress()
Progress the worker until all communication events are completed.
void registerAmReceiverCallback(AmReceiverCallbackInfo info, AmReceiverCallbackType callback)
Register receiver callback for active messages.
void signal()
Signal the worker that an event happened.
virtual void populateFuturesPool()
Populate the futures pool.
bool registerGenericPre(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread before progressing.
void setProgressThreadStartCallback(std::function< void(void *)> callback, void *callbackArg)
Set callback to be executed at the progress thread start.
bool progressOnce()
Progress the worker only once.
void startProgressThread(const bool pollingMode=false, const int epollTimeout=1)
Start the progress thread.
bool isProgressThreadRunning()
Inquire if worker has a progress thread running.
friend std::shared_ptr< RequestAm > createRequestAm(std::shared_ptr< Endpoint > endpoint, const std::variant< data::AmSend, data::AmReceive > requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData)
void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator)
Register allocator for active messages.
std::shared_ptr< Request > tagRecvWithHandle(void *buffer, std::shared_ptr< TagProbeInfo > probeInfo, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation using a message handle.
void stopProgressThread()
Stop the progress thread.
Worker(std::shared_ptr< Context > context, const bool enableDelayedSubmission=false, const bool enableFuture=false)
Protected constructor of ucxx::Worker.
std::shared_ptr< Endpoint > createEndpointFromHostname(std::string ipAddress, uint16_t port, bool endpointErrorHandling=true)
Create endpoint to worker listening on specific IP and port.
std::shared_ptr< Listener > createListener(uint16_t port, ucp_listener_conn_callback_t callback, void *callbackArgs)
Listen for remote connections on given port.
std::queue< std::shared_ptr< Future > > _futuresPool
Futures pool to prevent running out of fresh futures.
Definition: worker.h:83
bool isDelayedRequestSubmissionEnabled() const
Inquire if worker has been created with delayed submission enabled.
virtual ~Worker()
ucxx::Worker destructor.
virtual void clearFuturesPool()
Clear the futures pool.
ucp_worker_h getHandle()
Get the underlying ucp_worker_h handle.
std::mutex _futuresPoolMutex
Mutex to access the futures pool.
Definition: worker.h:81
virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs)
Block until a request event.
void scheduleRequestCancel(TrackedRequestsPtr trackedRequests)
Schedule cancelation of inflight requests.
bool arm()
Arm the UCP worker.
std::shared_ptr< Request > flush(const bool enablePythonFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a flush operation.
std::shared_ptr< Endpoint > createEndpointFromWorkerAddress(std::shared_ptr< Address > address, bool endpointErrorHandling=true)
Create endpoint to worker located at UCX address.
bool waitProgress()
Block until an event has happened, then progresses.
std::shared_ptr< TagProbeInfo > tagProbe(const Tag tag, const TagMask tagMask=TagMaskFull, const bool remove=false) const
Check for uncaught tag messages.
virtual void runRequestNotifier()
Notify futures of each completed communication request.
bool registerGenericPost(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread after progressing.
std::shared_ptr< Notifier > _notifier
Notifier object.
Definition: worker.h:84
void registerDelayedSubmission(std::shared_ptr< Request > request, DelayedSubmissionCallbackType callback)
Register delayed request submission.
std::thread::id getProgressThreadId()
Get the progress thread ID.
size_t cancelInflightRequests(uint64_t period=0, uint64_t maxAttempts=1)
Cancel inflight requests.
bool _enableFuture
Boolean identifying whether the worker was created with future capability.
Definition: worker.h:79
bool isFutureEnabled() const
Inquire if worker has been created with future support.
virtual std::shared_ptr< Future > getFuture()
Get a future from the pool.
virtual void stopRequestNotifierThread()
Signal the notifier to terminate.
std::shared_ptr< Address > getAddress()
Get the address of the UCX worker object.
friend std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Friend declaration for ucxx::createWorker with parameters.
void removeInflightRequest(std::shared_ptr< Request > request)
Remove reference to request from internal container.
bool progressWorkerEvent(const int epollTimeout=-1)
Progress worker event while in blocking progress mode.
std::shared_ptr< internal::AmData > _amData
Worker data made available to Active Messages callback.
Definition: worker.h:86
std::shared_ptr< Request > tagRecv(void *buffer, size_t length, Tag tag, TagMask tagMask, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation.
void initBlockingProgressMode()
Initialize blocking progress mode.
std::string getInfo()
Get information about the underlying ucp_worker_h object.
Builder class for constructing std::shared_ptr<ucxx::Worker> objects.
Definition: worker_builder.h:42
Definition: address.h:15
std::function< void(ucs_status_t, std::shared_ptr< void >)> RequestCallbackUserFunction
A user-defined function to execute as part of a ucxx::Request callback.
Definition: typedefs.h:96
std::shared_ptr< void > RequestCallbackUserData
Data for the user-defined function provided to the ucxx::Request callback.
Definition: typedefs.h:104
std::function< void()> DelayedSubmissionCallbackType
A user-defined function to execute as part of delayed submission callback.
Definition: delayed_submission.h:32
std::function< std::shared_ptr< Buffer >(size_t)> AmAllocatorType
Custom Active Message allocator type.
Definition: typedefs.h:128
std::unique_ptr< TrackedRequests > TrackedRequestsPtr
Pre-defined type for a pointer to a container of tracked requests.
Definition: inflight_requests.h:46
std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Constructor of shared_ptr<ucxx::Worker> with parameters.
RequestNotifierWaitState
The state with which a wait operation completed.
Definition: notifier.h:26
TagMask
Strong type for a UCP tag mask.
Definition: typedefs.h:73
Tag
Strong type for a UCP tag.
Definition: typedefs.h:65
std::function< void(std::shared_ptr< Request >, ucp_ep_h)> AmReceiverCallbackType
Active Message receiver callback.
Definition: typedefs.h:137