worker.h
1 
5 #pragma once
6 
7 #include <functional>
8 #include <memory>
9 #include <mutex>
10 #include <queue>
11 #include <string>
12 #include <thread>
13 #include <utility>
14 #include <variant>
15 
16 #include <ucp/api/ucp.h>
17 
18 #include <ucxx/buffer.h>
19 #include <ucxx/component.h>
20 #include <ucxx/constructors.h>
21 #include <ucxx/context.h>
22 #include <ucxx/delayed_submission.h>
23 #include <ucxx/future.h>
24 #include <ucxx/inflight_requests.h>
25 #include <ucxx/notifier.h>
26 #include <ucxx/typedefs.h>
27 #include <ucxx/worker_progress_thread.h>
28 
29 namespace ucxx {
30 
// Forward declarations of types referenced by the Worker interface below.
// WorkerBuilder: builder class for constructing std::shared_ptr<ucxx::Worker> objects.
31 namespace experimental {
32 class WorkerBuilder;
33 } // namespace experimental
34 
    // Address: component encapsulating the address of a UCP worker.
35 class Address;
    // Buffer: used below as the pointee of the Active Message allocator's return type.
36 class Buffer;
    // Endpoint: component encapsulating a UCP endpoint.
37 class Endpoint;
    // Listener: component encapsulating a UCP listener.
38 class Listener;
    // RequestAm: request type returned by createRequestAm() and Worker::getAmRecv().
39 class RequestAm;
40 
    // AmData: worker data made available to the Active Messages callback
    // (held by Worker through a shared_ptr member).
41 namespace internal {
42 class AmData;
43 } // namespace internal
44 
/// Component encapsulating a UCP worker.
///
/// Derives from Component. The default constructor and all copy/move operations are
/// deleted, the parameterized constructor is protected, and createWorker() is declared
/// a friend.
///
/// NOTE(review): this listing has gaps in its embedded line numbers and several
/// declarations are truncated or absent. The member index further down the page lists
/// initBlockingProgressMode(), getEpollFileDescriptor(), registerGenericPre(),
/// registerGenericPost(), stopRequestNotifierThread(), stopProgressThread() and
/// registerAmReceiverCallback(), whose full declarations do not appear here.
51 class Worker : public Component {
52  private:
    /// Underlying ucp_worker_h handle (exposed via getHandle()); nullptr by default.
53  ucp_worker_h _handle{nullptr};
    /// Epoll file descriptor; -1 by default.
54  int _epollFileDescriptor{-1};
    /// Worker file descriptor; -1 by default.
55  int _workerFileDescriptor{-1};
    /// Mutex, presumably guarding _inflightRequests -- confirm in the implementation.
56  std::mutex _inflightRequestsMutex{};
    /// Container of inflight requests, allocated at member initialization.
57  std::unique_ptr<InflightRequests> _inflightRequests{
58  std::make_unique<InflightRequests>()};
    /// Mutex, presumably guarding _inflightRequestsToCancel -- confirm in the implementation.
59  std::mutex
60  _inflightRequestsToCancelMutex{};
    /// Container of inflight requests scheduled for cancelation (see scheduleRequestCancel()),
    /// allocated at member initialization.
61  std::unique_ptr<InflightRequests> _inflightRequestsToCancel{
62  std::make_unique<InflightRequests>()};
    /// The thread used to progress this worker.
63  WorkerProgressThread _progressThread{};
    /// Progress thread ID (returned by getProgressThreadId(), presumably -- confirm).
64  std::thread::id _progressThreadId{};
    /// Callback to be executed at the progress thread start; nullptr by default.
65  std::function<void(void*)> _progressThreadStartCallback{
66  nullptr};
    /// Argument for the progress thread start callback; nullptr by default.
67  void* _progressThreadStartCallbackArg{
68  nullptr};
    /// Collection of delayed submissions; nullptr by default.
69  std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
70  nullptr};
71 
    /// Grants createRequestAm() access to the non-public members of Worker.
72  friend std::shared_ptr<RequestAm> createRequestAm(
73  std::shared_ptr<Endpoint> endpoint,
74  const std::variant<data::AmSend, data::AmReceive> requestData,
75  const bool enablePythonFuture,
76  RequestCallbackUserFunction callbackFunction,
77  RequestCallbackUserData callbackData);
78 
79  protected:
    // NOTE(review): the declaration line for this initializer (line 80) is missing from
    // the listing. The member index places `bool _enableFuture` at line 80: boolean
    // identifying whether the worker was created with future capability -- confirm.
81  false};
    // NOTE(review): the declaration line for this initializer (line 82) is missing from
    // the listing. The member index places `bool _enableRequestAttributes` at line 82:
    // whether request attributes (e.g. UCP debug info) are queried for each request -- confirm.
83  false};
    /// Mutex to access the futures pool.
84  std::mutex _futuresPoolMutex{};
    /// Futures pool to prevent running out of fresh futures.
    // NOTE(review): the declarator line (line 86, `_futuresPool` per the member index)
    // is missing from this listing.
85  std::queue<std::shared_ptr<Future>>
    /// Notifier object; nullptr by default.
87  std::shared_ptr<Notifier> _notifier{nullptr};
    /// Worker data made available to Active Messages callback.
    // NOTE(review): the declarator line (line 89, `_amData` per the member index)
    // is missing from this listing.
88  std::shared_ptr<internal::AmData>
    /// Preferred buffer type for CUDA allocations; BufferType::Invalid by default.
90  BufferType _cudaBufferType{BufferType::Invalid};
91 
92  private:
    // NOTE(review): no description for the private helpers below survives in this
    // listing; the remarks on them are inferred from names and signatures only.

    /// Presumably drains pending tag receives on the worker -- confirm.
99  void drainWorkerTagRecv();
100 
    /// Returns an Active Message receive request for endpoint handle `ep`; takes a
    /// factory callable that produces a RequestAm. Exact semantics not shown here.
115  [[nodiscard]] std::shared_ptr<RequestAm> getAmRecv(
116  ucp_ep_h ep, std::function<std::shared_ptr<RequestAm>()> createAmRecvRequestFunction);
117 
    /// Presumably stops the progress thread without emitting a warning -- confirm.
124  void stopProgressThreadNoWarn();
125 
    /// Registers `request` as inflight and returns a request pointer
    /// (presumably the same request -- confirm).
136  [[nodiscard]] std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);
137 
    /// Presumably progresses pending work; meaning of the bool result not shown here.
145  bool progressPending();
146 
    /// Presumably sets the preferred buffer type for CUDA allocations
    /// (counterpart of getCudaBufferType()) -- confirm.
158  void setCudaBufferType(BufferType bufferType);
159 
160  protected:
    /// Protected constructor of ucxx::Worker.
    ///
    /// @param context                  context the worker is created from.
    /// @param enableDelayedSubmission  defaults to false; see isDelayedRequestSubmissionEnabled().
    /// @param enableFuture             defaults to false; see isFutureEnabled().
178  explicit Worker(std::shared_ptr<Context> context,
179  const bool enableDelayedSubmission = false,
180  const bool enableFuture = false);
181 
182  public:
    // Default construction, copying and moving are all disabled.
183  Worker() = delete;
184  Worker(const Worker&) = delete;
185  Worker& operator=(Worker const&) = delete;
186  Worker(Worker&& o) = delete;
187  Worker& operator=(Worker&& o) = delete;
188 
    /// Friend declaration for ucxx::createWorker with parameters.
195  friend std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
196  const bool enableDelayedSubmission,
197  const bool enableFuture);
198 
203 
    /// ucxx::Worker destructor.
207  virtual ~Worker();
208 
    /// Get the underlying ucp_worker_h handle.
224  [[nodiscard]] ucp_worker_h getHandle();
225 
    /// Get information about the underlying ucp_worker_h object.
234  [[nodiscard]] std::string getInfo();
235 
265 
282 
    /// Arm the UCP worker.
293  bool arm();
294 
    /// Progress worker event while in blocking progress mode.
    ///
    /// @param epollTimeout  defaults to -1 (units and meaning of -1 not shown here;
    ///                      presumably an epoll timeout -- confirm).
322  bool progressWorkerEvent(const int epollTimeout = -1);
323 
    /// Signal the worker that an event happened.
358  void signal();
359 
    /// Block until an event has happened, then progress.
381  bool waitProgress();
382 
    /// Progress the worker only once.
397  bool progressOnce();
398 
    /// Progress the worker until all communication events are completed.
414  bool progress();
415 
    /// Register delayed request submission.
    // NOTE(review): the continuation line of this declaration (line 433) is missing
    // from the listing; the member index shows a second parameter
    // `DelayedSubmissionCallbackType callback`.
432  void registerDelayedSubmission(std::shared_ptr<Request> request,
434 
    // NOTE(review): tail of a declaration whose first line (462) is missing;
    // presumably registerGenericPre(DelayedSubmissionCallbackType callback,
    // uint64_t period = 0) from the member index -- confirm.
463  uint64_t period = 0);
464 
    // NOTE(review): tail of a declaration whose first line (491) is missing;
    // presumably registerGenericPost(DelayedSubmissionCallbackType callback,
    // uint64_t period = 0) from the member index -- confirm.
492  uint64_t period = 0);
493 
    /// Inquire if worker has been created with delayed submission enabled.
501  [[nodiscard]] bool isDelayedRequestSubmissionEnabled() const;
502 
    /// Inquire if worker has been created with future support.
510  [[nodiscard]] bool isFutureEnabled() const;
511 
    /// Inquire if worker has been created with request attributes querying enabled.
523  [[nodiscard]] bool isRequestAttributesEnabled() const noexcept;
524 
    /// Get the preferred buffer type for CUDA allocations.
534  [[nodiscard]] BufferType getCudaBufferType() const;
535 
    /// Populate the futures pool. Virtual, so subclasses may override.
547  virtual void populateFuturesPool();
548 
    /// Clear the futures pool. Virtual, so subclasses may override.
558  virtual void clearFuturesPool();
559 
    /// Get a future from the pool. Virtual, so subclasses may override.
571  [[nodiscard]] virtual std::shared_ptr<Future> getFuture();
572 
    /// Block until a request event.
    ///
    /// @param periodNs  presumably a wait period in nanoseconds (from the name) -- confirm.
    /// @return the state with which the wait operation completed.
587  [[nodiscard]] virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs);
588 
    /// Notify futures of each completed communication request.
601  virtual void runRequestNotifier();
602 
611 
    /// Set callback to be executed at the progress thread start.
    ///
    /// @param callback     callable taking a single void* argument.
    /// @param callbackArg  presumably the pointer handed to `callback` -- confirm.
622  void setProgressThreadStartCallback(std::function<void(void*)> callback, void* callbackArg);
623 
    /// Start the progress thread.
    ///
    /// @param pollingMode   defaults to false.
    /// @param epollTimeout  defaults to 1 (units not shown in this listing).
636  void startProgressThread(const bool pollingMode = false, const int epollTimeout = 1);
637 
647 
    /// Inquire if worker has a progress thread running.
655  [[nodiscard]] bool isProgressThreadRunning();
656 
    /// Get the progress thread ID.
664  [[nodiscard]] std::thread::id getProgressThreadId();
665 
    /// Cancel inflight requests.
    ///
    /// @param period       defaults to 0 (units not shown in this listing).
    /// @param maxAttempts  defaults to 1.
    /// @return a count (size_t); presumably the number of requests canceled -- confirm.
686  size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);
687 
    /// Schedule cancelation of inflight requests.
    ///
    /// @param trackedRequests  container for transferring tracked requests between
    ///                         InflightRequests instances.
700  void scheduleRequestCancel(TrackedRequests trackedRequests);
701 
    /// Remove reference to request from internal container.
711  void removeInflightRequest(std::shared_ptr<Request> request);
712 
    /// Check for uncaught tag messages.
    ///
    /// @param tag      tag to probe for.
    /// @param tagMask  defaults to TagMaskFull.
    /// @param remove   defaults to false.
    /// @return information about the probed tag message.
755  [[nodiscard]] std::shared_ptr<TagProbeInfo> tagProbe(const Tag tag,
756  const TagMask tagMask = TagMaskFull,
757  const bool remove = false) const;
758 
    /// Enqueue a tag receive operation.
    ///
    /// @param buffer            destination buffer (raw pointer; ownership not shown here).
    /// @param length            length associated with `buffer` (units not shown here).
    /// @param tag               tag to receive on.
    /// @param tagMask           mask applied to the tag.
    /// @param enableFuture      defaults to false.
    /// @param callbackFunction  user-defined function to execute as part of the request
    ///                          callback; defaults to nullptr.
    /// @param callbackData      data for the user-defined function; defaults to nullptr.
790  [[nodiscard]] std::shared_ptr<Request> tagRecv(
791  void* buffer,
792  size_t length,
793  Tag tag,
794  TagMask tagMask,
795  const bool enableFuture = false,
796  RequestCallbackUserFunction callbackFunction = nullptr,
797  RequestCallbackUserData callbackData = nullptr);
798 
    /// Enqueue a tag receive operation using a message handle.
    ///
    /// @param buffer            destination buffer (raw pointer; ownership not shown here).
    /// @param probeInfo         information about a probed tag message (see tagProbe()).
    /// @param enableFuture      defaults to false.
    /// @param callbackFunction  user-defined function to execute as part of the request
    ///                          callback; defaults to nullptr.
    /// @param callbackData      data for the user-defined function; defaults to nullptr.
822  [[nodiscard]] std::shared_ptr<Request> tagRecvWithHandle(
823  void* buffer,
824  std::shared_ptr<TagProbeInfo> probeInfo,
825  const bool enableFuture = false,
826  RequestCallbackUserFunction callbackFunction = nullptr,
827  RequestCallbackUserData callbackData = nullptr);
828 
    /// Get the address of the UCX worker object.
840  [[nodiscard]] std::shared_ptr<Address> getAddress();
841 
    /// Create endpoint to worker listening on specific IP and port.
    ///
    /// @param endpointErrorHandling  defaults to true.
866  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromHostname(
867  std::string ipAddress, uint16_t port, bool endpointErrorHandling = true);
868 
    /// Create endpoint to worker located at UCX address.
    ///
    /// @param endpointErrorHandling  defaults to true.
897  [[nodiscard]] std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(
898  std::shared_ptr<Address> address, bool endpointErrorHandling = true);
899 
    /// Listen for remote connections on given port.
    ///
    /// @param callback      UCP listener connection callback.
    /// @param callbackArgs  presumably the user argument handed to `callback` -- confirm.
917  [[nodiscard]] std::shared_ptr<Listener> createListener(uint16_t port,
918  ucp_listener_conn_callback_t callback,
919  void* callbackArgs);
920 
    /// Register allocator for active messages.
    ///
    /// @param memoryType  UCS memory type the allocator is registered for.
    /// @param allocator   custom Active Message allocator.
948  void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator);
949 
985 
    /// Check for uncaught active messages.
1005  [[nodiscard]] bool amProbe(const ucp_ep_h endpointHandle) const;
1006 
    /// Enqueue a flush operation.
    ///
    /// @param enablePythonFuture  defaults to false.
    /// @param callbackFunction    user-defined function to execute as part of the request
    ///                            callback; defaults to nullptr.
    /// @param callbackData        data for the user-defined function; defaults to nullptr.
1034  [[nodiscard]] std::shared_ptr<Request> flush(
1035  const bool enablePythonFuture = false,
1036  RequestCallbackUserFunction callbackFunction = nullptr,
1037  RequestCallbackUserData callbackData = nullptr);
1038 
    /// Worker attributes reported by ucp_worker_query.
1042  struct Attributes {
    /// Thread mode; defaults to UCS_THREAD_MODE_MULTI.
1044  ucs_thread_mode_t threadMode{UCS_THREAD_MODE_MULTI};
    /// Presumably the maximum Active Message header size (from the name); defaults to 0.
1046  size_t maxAmHeader{0};
    /// Name string; empty by default.
1048  std::string name{};
    /// Presumably the maximum debug string size (from the name); defaults to 0.
1050  size_t maxDebugString{0};
1051  };
1052 
    /// Get the worker's attributes.
1063  [[nodiscard]] Attributes queryAttributes() const;
1064 };
1065 
/// Constructor of shared_ptr<ucxx::Worker> with parameters.
///
/// Declared a friend of Worker, whose parameterized constructor is protected.
///
/// @param context                  context the worker is created from.
/// @param enableDelayedSubmission  presumably forwarded to the Worker constructor
///                                 parameter of the same name -- confirm.
/// @param enableFuture             presumably forwarded to the Worker constructor
///                                 parameter of the same name -- confirm.
1087 std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
1088  const bool enableDelayedSubmission,
1089  const bool enableFuture);
1090 
1091 } // namespace ucxx
1092 
1093 // Include experimental features
1094 #include <ucxx/experimental/worker_builder.h>
Component encapsulating the address of a UCP worker.
Definition: address.h:24
Information of an Active Message receiver callback.
Definition: typedefs.h:224
A UCXX component class to prevent early destruction of parent object.
Definition: component.h:17
Component encapsulating a UCP endpoint.
Definition: endpoint.h:50
Represent a future that may be notified by a specialized notifier.
Definition: future.h:21
Component encapsulating a UCP listener.
Definition: listener.h:23
Base type for a UCXX transfer request.
Definition: request.h:39
Information about probed tag message.
Definition: tag_probe.h:59
A thread to progress a ucxx::Worker.
Definition: worker_progress_thread.h:48
Component encapsulating a UCP worker.
Definition: worker.h:51
bool amProbe(const ucp_ep_h endpointHandle) const
Check for uncaught active messages.
int getEpollFileDescriptor()
Get the epoll file descriptor associated with the worker.
bool progress()
Progress the worker until all communication events are completed.
void registerAmReceiverCallback(AmReceiverCallbackInfo info, AmReceiverCallbackType callback)
Register receiver callback for active messages.
bool isRequestAttributesEnabled() const noexcept
Inquire if worker has been created with request attributes querying enabled.
bool _enableRequestAttributes
Whether request attributes (e.g. UCP debug info) are queried for each request.
Definition: worker.h:82
void signal()
Signal the worker that an event happened.
virtual void populateFuturesPool()
Populate the futures pool.
bool registerGenericPre(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread before progressing.
void setProgressThreadStartCallback(std::function< void(void *)> callback, void *callbackArg)
Set callback to be executed at the progress thread start.
bool progressOnce()
Progress the worker only once.
void startProgressThread(const bool pollingMode=false, const int epollTimeout=1)
Start the progress thread.
bool isProgressThreadRunning()
Inquire if worker has a progress thread running.
friend std::shared_ptr< RequestAm > createRequestAm(std::shared_ptr< Endpoint > endpoint, const std::variant< data::AmSend, data::AmReceive > requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData)
void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator)
Register allocator for active messages.
std::shared_ptr< Request > tagRecvWithHandle(void *buffer, std::shared_ptr< TagProbeInfo > probeInfo, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation using a message handle.
void stopProgressThread()
Stop the progress thread.
Worker(std::shared_ptr< Context > context, const bool enableDelayedSubmission=false, const bool enableFuture=false)
Protected constructor of ucxx::Worker.
std::shared_ptr< Endpoint > createEndpointFromHostname(std::string ipAddress, uint16_t port, bool endpointErrorHandling=true)
Create endpoint to worker listening on specific IP and port.
std::shared_ptr< Listener > createListener(uint16_t port, ucp_listener_conn_callback_t callback, void *callbackArgs)
Listen for remote connections on given port.
std::queue< std::shared_ptr< Future > > _futuresPool
Futures pool to prevent running out of fresh futures.
Definition: worker.h:86
BufferType _cudaBufferType
Preferred buffer type for CUDA allocations.
Definition: worker.h:90
bool isDelayedRequestSubmissionEnabled() const
Inquire if worker has been created with delayed submission enabled.
virtual ~Worker()
ucxx::Worker destructor.
virtual void clearFuturesPool()
Clear the futures pool.
ucp_worker_h getHandle()
Get the underlying ucp_worker_h handle.
std::mutex _futuresPoolMutex
Mutex to access the futures pool.
Definition: worker.h:84
virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs)
Block until a request event.
bool arm()
Arm the UCP worker.
std::shared_ptr< Request > flush(const bool enablePythonFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a flush operation.
std::shared_ptr< Endpoint > createEndpointFromWorkerAddress(std::shared_ptr< Address > address, bool endpointErrorHandling=true)
Create endpoint to worker located at UCX address.
bool waitProgress()
Block until an event has happened, then progress.
std::shared_ptr< TagProbeInfo > tagProbe(const Tag tag, const TagMask tagMask=TagMaskFull, const bool remove=false) const
Check for uncaught tag messages.
virtual void runRequestNotifier()
Notify futures of each completed communication request.
bool registerGenericPost(DelayedSubmissionCallbackType callback, uint64_t period=0)
Register callback to be executed in progress thread after progressing.
std::shared_ptr< Notifier > _notifier
Notifier object.
Definition: worker.h:87
void registerDelayedSubmission(std::shared_ptr< Request > request, DelayedSubmissionCallbackType callback)
Register delayed request submission.
std::thread::id getProgressThreadId()
Get the progress thread ID.
size_t cancelInflightRequests(uint64_t period=0, uint64_t maxAttempts=1)
Cancel inflight requests.
bool _enableFuture
Boolean identifying whether the worker was created with future capability.
Definition: worker.h:80
bool isFutureEnabled() const
Inquire if worker has been created with future support.
virtual std::shared_ptr< Future > getFuture()
Get a future from the pool.
Attributes queryAttributes() const
Get the worker's attributes.
virtual void stopRequestNotifierThread()
Signal the notifier to terminate.
std::shared_ptr< Address > getAddress()
Get the address of the UCX worker object.
friend std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Friend declaration for ucxx::createWorker with parameters.
void removeInflightRequest(std::shared_ptr< Request > request)
Remove reference to request from internal container.
bool progressWorkerEvent(const int epollTimeout=-1)
Progress worker event while in blocking progress mode.
BufferType getCudaBufferType() const
Get the preferred buffer type for CUDA allocations.
std::shared_ptr< internal::AmData > _amData
Worker data made available to Active Messages callback.
Definition: worker.h:89
std::shared_ptr< Request > tagRecv(void *buffer, size_t length, Tag tag, TagMask tagMask, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation.
void scheduleRequestCancel(TrackedRequests trackedRequests)
Schedule cancelation of inflight requests.
void initBlockingProgressMode()
Initialize blocking progress mode.
std::string getInfo()
Get information about the underlying ucp_worker_h object.
Builder class for constructing std::shared_ptr<ucxx::Worker> objects.
Definition: worker_builder.h:44
Definition: address.h:16
std::function< void(ucs_status_t, std::shared_ptr< void >)> RequestCallbackUserFunction
A user-defined function to execute as part of a ucxx::Request callback.
Definition: typedefs.h:97
std::shared_ptr< void > RequestCallbackUserData
Data for the user-defined function provided to the ucxx::Request callback.
Definition: typedefs.h:105
std::function< void()> DelayedSubmissionCallbackType
A user-defined function to execute as part of delayed submission callback.
Definition: delayed_submission.h:32
std::function< std::shared_ptr< Buffer >(size_t)> AmAllocatorType
Custom Active Message allocator type.
Definition: typedefs.h:129
std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Constructor of shared_ptr<ucxx::Worker> with parameters.
RequestNotifierWaitState
The state with which a wait operation completed.
Definition: notifier.h:26
TagMask
Strong type for a UCP tag mask.
Definition: typedefs.h:74
Tag
Strong type for a UCP tag.
Definition: typedefs.h:66
std::function< void(std::shared_ptr< Request >, ucp_ep_h)> AmReceiverCallbackType
Active Message receiver callback.
Definition: typedefs.h:138
BufferType
The type of a buffer.
Definition: buffer.h:22
Container for transferring tracked requests between InflightRequests instances.
Definition: inflight_requests.h:23
Worker attributes reported by ucp_worker_query.
Definition: worker.h:1042