All Classes Namespaces Functions Variables Typedefs Enumerations Friends
worker.h
1 
5 #pragma once
6 
7 #include <functional>
8 #include <memory>
9 #include <mutex>
10 #include <queue>
11 #include <string>
12 #include <thread>
13 #include <utility>
14 
15 #include <ucp/api/ucp.h>
16 
17 #include <ucxx/component.h>
18 #include <ucxx/constructors.h>
19 #include <ucxx/context.h>
20 #include <ucxx/delayed_submission.h>
21 #include <ucxx/future.h>
22 #include <ucxx/inflight_requests.h>
23 #include <ucxx/notifier.h>
24 #include <ucxx/typedefs.h>
25 #include <ucxx/worker_progress_thread.h>
26 
27 namespace ucxx {
28 
29 class Address;
30 class Buffer;
31 class Endpoint;
32 class Listener;
33 class RequestAm;
34 
35 namespace internal {
36 class AmData;
37 } // namespace internal
38 
45 class Worker : public Component {
46  private:
47  ucp_worker_h _handle{nullptr};
48  int _epollFileDescriptor{-1};
49  int _workerFileDescriptor{-1};
50  std::mutex _inflightRequestsMutex{};
51  std::unique_ptr<InflightRequests> _inflightRequests{
52  std::make_unique<InflightRequests>()};
53  std::mutex
54  _inflightRequestsToCancelMutex{};
55  std::unique_ptr<InflightRequests> _inflightRequestsToCancel{
56  std::make_unique<InflightRequests>()};
57  WorkerProgressThread _progressThread{};
58  std::thread::id _progressThreadId{};
59  std::function<void(void*)> _progressThreadStartCallback{
60  nullptr};
61  void* _progressThreadStartCallbackArg{
62  nullptr};
63  std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
64  nullptr};
65 
66  friend std::shared_ptr<RequestAm> createRequestAm(
67  std::shared_ptr<Endpoint> endpoint,
68  const std::variant<data::AmSend, data::AmReceive> requestData,
69  const bool enablePythonFuture,
70  RequestCallbackUserFunction callbackFunction,
71  RequestCallbackUserData callbackData);
72 
73  protected:
75  false};
76  std::mutex _futuresPoolMutex{};
77  std::queue<std::shared_ptr<Future>>
79  std::shared_ptr<Notifier> _notifier{nullptr};
80  std::shared_ptr<internal::AmData>
82 
83  private:
90  void drainWorkerTagRecv();
91 
106  [[nodiscard]] std::shared_ptr<RequestAm> getAmRecv(
107  ucp_ep_h ep, std::function<std::shared_ptr<RequestAm>()> createAmRecvRequestFunction);
108 
115  void stopProgressThreadNoWarn();
116 
127  [[nodiscard]] std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);
128 
136  bool progressPending();
137 
138  protected:
156  explicit Worker(std::shared_ptr<Context> context,
157  const bool enableDelayedSubmission = false,
158  const bool enableFuture = false);
159 
160  public:
161  Worker() = delete;
162  Worker(const Worker&) = delete;
163  Worker& operator=(Worker const&) = delete;
164  Worker(Worker&& o) = delete;
165  Worker& operator=(Worker&& o) = delete;
166 
191  friend std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
192  const bool enableDelayedSubmission,
193  const bool enableFuture);
194 
198  virtual ~Worker();
199 
215  [[nodiscard]] ucp_worker_h getHandle();
216 
225  [[nodiscard]] std::string getInfo();
226 
256 
273 
284  bool arm();
285 
313  bool progressWorkerEvent(const int epollTimeout = -1);
314 
349  void signal();
350 
372  bool waitProgress();
373 
388  bool progressOnce();
389 
405  bool progress();
406 
423  void registerDelayedSubmission(std::shared_ptr<Request> request,
425 
454  uint64_t period = 0);
455 
483  uint64_t period = 0);
484 
492  [[nodiscard]] bool isDelayedRequestSubmissionEnabled() const;
493 
501  [[nodiscard]] bool isFutureEnabled() const;
502 
514  virtual void populateFuturesPool();
515 
525  virtual void clearFuturesPool();
526 
538  [[nodiscard]] virtual std::shared_ptr<Future> getFuture();
539 
554  [[nodiscard]] virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs);
555 
568  virtual void runRequestNotifier();
569 
578 
589  void setProgressThreadStartCallback(std::function<void(void*)> callback, void* callbackArg);
590 
603  void startProgressThread(const bool pollingMode = false, const int epollTimeout = 1);
604 
614 
622  [[nodiscard]] bool isProgressThreadRunning();
623 
631  [[nodiscard]] std::thread::id getProgressThreadId();
632 
653  size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);
654 
668 
681  void removeInflightRequest(const Request* const request);
682 
713  [[nodiscard]] std::pair<bool, TagRecvInfo> tagProbe(const Tag tag,
714  const TagMask tagMask = TagMaskFull);
715 
740  [[nodiscard]] std::shared_ptr<Request> tagRecv(
741  void* buffer,
742  size_t length,
743  Tag tag,
744  TagMask tagMask,
745  const bool enableFuture = false,
746  RequestCallbackUserFunction callbackFunction = nullptr,
747  RequestCallbackUserData callbackData = nullptr);
748 
760  [[nodiscard]] std::shared_ptr<Address> getAddress();
761 
786  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromHostname(
787  std::string ipAddress, uint16_t port, bool endpointErrorHandling = true);
788 
817  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(
818  std::shared_ptr<Address> address, bool endpointErrorHandling = true);
819 
837  [[nodiscard]] std::shared_ptr<Listener> createListener(uint16_t port,
838  ucp_listener_conn_callback_t callback,
839  void* callbackArgs);
840 
868  void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator);
869 
908 
928  [[nodiscard]] bool amProbe(const ucp_ep_h endpointHandle) const;
929 
951  [[nodiscard]] std::shared_ptr<Request> flush(
952  const bool enablePythonFuture = false,
953  RequestCallbackUserFunction callbackFunction = nullptr,
954  RequestCallbackUserData callbackData = nullptr);
955 };
956 
957 } // namespace ucxx
Information of an Active Message receiver callback.
Definition: typedefs.h:169
A UCXX component class to prevent early destruction of parent object.
Definition: component.h:17
Base type for a UCXX transfer request.
Definition: request.h:38
A thread to progress a ucxx::Worker.
Definition: worker_progress_thread.h:48
Component encapsulating a UCP worker.
Definition: worker.h:45
bool amProbe(const ucp_ep_h endpointHandle) const
Check for uncaught active messages.
int getEpollFileDescriptor()
Get the epoll file descriptor associated with the worker.
bool progress()
Progress the worker until all communication events are completed.
void registerAmReceiverCallback(AmReceiverCallbackInfo info, AmReceiverCallbackType callback)
Register receiver callback for active messages.
void signal()
Signal the worker that an event happened.
virtual void populateFuturesPool()
Populate the futures pool.
bool registerGenericPre(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread before progressing.
void setProgressThreadStartCallback(std::function< void(void *)> callback, void *callbackArg)
Set callback to be executed at the progress thread start.
bool progressOnce()
Progress the worker only once.
void startProgressThread(const bool pollingMode=false, const int epollTimeout=1)
Start the progress thread.
bool isProgressThreadRunning()
Inquire if worker has a progress thread running.
friend std::shared_ptr< RequestAm > createRequestAm(std::shared_ptr< Endpoint > endpoint, const std::variant< data::AmSend, data::AmReceive > requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData)
void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator)
Register allocator for active messages.
void stopProgressThread()
Stop the progress thread.
Worker(std::shared_ptr< Context > context, const bool enableDelayedSubmission=false, const bool enableFuture=false)
Protected constructor of ucxx::Worker.
std::shared_ptr< Endpoint > createEndpointFromHostname(std::string ipAddress, uint16_t port, bool endpointErrorHandling=true)
Create endpoint to worker listening on specific IP and port.
std::shared_ptr< Listener > createListener(uint16_t port, ucp_listener_conn_callback_t callback, void *callbackArgs)
Listen for remote connections on given port.
std::queue< std::shared_ptr< Future > > _futuresPool
Futures pool to prevent running out of fresh futures.
Definition: worker.h:78
bool isDelayedRequestSubmissionEnabled() const
Inquire if worker has been created with delayed submission enabled.
virtual ~Worker()
ucxx::Worker destructor.
virtual void clearFuturesPool()
Clear the futures pool.
ucp_worker_h getHandle()
Get the underlying ucp_worker_h handle.
std::mutex _futuresPoolMutex
Mutex to access the futures pool.
Definition: worker.h:76
virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs)
Block until a request event.
void scheduleRequestCancel(TrackedRequestsPtr trackedRequests)
Schedule cancelation of inflight requests.
bool arm()
Arm the UCP worker.
std::shared_ptr< Request > flush(const bool enablePythonFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a flush operation.
std::shared_ptr< Endpoint > createEndpointFromWorkerAddress(std::shared_ptr< Address > address, bool endpointErrorHandling=true)
Create endpoint to worker located at UCX address.
bool waitProgress()
Block until an event has happened, then progresses.
virtual void runRequestNotifier()
Notify futures of each completed communication request.
bool registerGenericPost(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread after progressing.
void removeInflightRequest(const Request *const request)
Remove reference to request from internal container.
std::shared_ptr< Notifier > _notifier
Notifier object.
Definition: worker.h:79
void registerDelayedSubmission(std::shared_ptr< Request > request, DelayedSubmissionCallbackType callback)
Register delayed request submission.
std::thread::id getProgressThreadId()
Get the progress thread ID.
size_t cancelInflightRequests(uint64_t period=0, uint64_t maxAttempts=1)
Cancel inflight requests.
bool _enableFuture
Boolean identifying whether the worker was created with future capability.
Definition: worker.h:74
bool isFutureEnabled() const
Inquire if worker has been created with future support.
std::pair< bool, TagRecvInfo > tagProbe(const Tag tag, const TagMask tagMask=TagMaskFull)
Check for uncaught tag messages.
virtual std::shared_ptr< Future > getFuture()
Get a future from the pool.
virtual void stopRequestNotifierThread()
Signal the notifier to terminate.
std::shared_ptr< Address > getAddress()
Get the address of the UCX worker object.
friend std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Constructor of shared_ptr<ucxx::Worker>.
bool progressWorkerEvent(const int epollTimeout=-1)
Progress worker event while in blocking progress mode.
std::shared_ptr< internal::AmData > _amData
Worker data made available to Active Messages callback.
Definition: worker.h:81
std::shared_ptr< Request > tagRecv(void *buffer, size_t length, Tag tag, TagMask tagMask, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation.
void initBlockingProgressMode()
Initialize blocking progress mode.
std::string getInfo()
Get information about the underlying ucp_worker_h object.
Definition: address.h:15
std::function< void(ucs_status_t, std::shared_ptr< void >)> RequestCallbackUserFunction
A user-defined function to execute as part of a ucxx::Request callback.
Definition: typedefs.h:103
std::shared_ptr< void > RequestCallbackUserData
Data for the user-defined function provided to the ucxx::Request callback.
Definition: typedefs.h:111
std::function< void()> DelayedSubmissionCallbackType
A user-defined function to execute as part of delayed submission callback.
Definition: delayed_submission.h:32
std::function< std::shared_ptr< Buffer >(size_t)> AmAllocatorType
Custom Active Message allocator type.
Definition: typedefs.h:135
std::unique_ptr< TrackedRequests > TrackedRequestsPtr
Pre-defined type for a pointer to a container of tracked requests.
Definition: inflight_requests.h:44
RequestNotifierWaitState
The state with which a wait operation completed.
Definition: notifier.h:26
TagMask
Strong type for a UCP tag mask.
Definition: typedefs.h:66
Tag
Strong type for a UCP tag.
Definition: typedefs.h:58
std::function< void(std::shared_ptr< Request >, ucp_ep_h)> AmReceiverCallbackType
Active Message receiver callback.
Definition: typedefs.h:144