libucxx  0.37.00
All Classes Namespaces Functions Variables Typedefs Enumerations Friends
worker.h
1 
5 #pragma once
6 
7 #include <functional>
8 #include <memory>
9 #include <mutex>
10 #include <queue>
11 #include <string>
12 #include <thread>
13 
14 #include <ucp/api/ucp.h>
15 
16 #include <ucxx/component.h>
17 #include <ucxx/constructors.h>
18 #include <ucxx/context.h>
19 #include <ucxx/delayed_submission.h>
20 #include <ucxx/future.h>
21 #include <ucxx/inflight_requests.h>
22 #include <ucxx/notifier.h>
23 #include <ucxx/typedefs.h>
24 #include <ucxx/worker_progress_thread.h>
25 
26 namespace ucxx {
27 
28 class Address;
29 class Buffer;
30 class Endpoint;
31 class Listener;
32 class RequestAm;
33 
34 namespace internal {
35 class AmData;
36 } // namespace internal
37 
// Component encapsulating a UCP worker.
//
// NOTE(review): this is a Doxygen source-browser rendering, not compilable
// source. Every line keeps its original worker.h line number as a prefix, and
// the renderer dropped several lines (presumably doc comments, plus some
// declaration lines). Members that appear in the generated member index but
// are missing from this listing:
//   initBlockingProgressMode(), registerGenericPre(), registerGenericPost(),
//   isDelayedRequestSubmissionEnabled(), stopRequestNotifierThread(),
//   stopProgressThread(), isProgressThreadRunning(), scheduleRequestCancel(),
//   registerAmReceiverCallback(),
// and the names of the fields _enableFuture, _futuresPool and _amData. Notes
// next to their surviving fragments below mark where those fields belong.
44 class Worker : public Component {
45  private:
// Underlying UCP worker handle (presumably what getHandle() returns); defaults to null.
46  ucp_worker_h _handle{nullptr};
// File descriptors default to -1 (unset). Presumably populated by
// initBlockingProgressMode() for epoll-based progress -- confirm.
47  int _epollFileDescriptor{-1};
48  int _workerFileDescriptor{-1};
// Container of in-flight requests; presumably guarded by _inflightRequestsMutex.
49  std::mutex _inflightRequestsMutex{};
50  std::unique_ptr<InflightRequests> _inflightRequests{
51  std::make_unique<InflightRequests>()};
// In-flight requests awaiting cancellation, with their own mutex (presumably
// filled by scheduleRequestCancel() -- confirm).
52  std::mutex
53  _inflightRequestsToCancelMutex{};
54  std::unique_ptr<InflightRequests> _inflightRequestsToCancel{
55  std::make_unique<InflightRequests>()};
// Progress thread object (defaults to null) and its thread ID.
56  std::shared_ptr<WorkerProgressThread> _progressThread{nullptr};
57  std::thread::id _progressThreadId{};
// Callback and argument to run when the progress thread starts
// (presumably set via setProgressThreadStartCallback()).
58  std::function<void(void*)> _progressThreadStartCallback{
59  nullptr};
60  void* _progressThreadStartCallbackArg{
61  nullptr};
// Delayed-submission collection; defaults to null. Presumably created only
// when delayed submission is enabled -- confirm.
62  std::shared_ptr<DelayedSubmissionCollection> _delayedSubmissionCollection{
63  nullptr};
64 
// createRequestAm is granted private access to the worker.
// NOTE(review): std::variant is used below, but <variant> is not among this
// header's visible #includes. Consider including it directly instead of
// relying on a transitive include.
65  friend std::shared_ptr<RequestAm> createRequestAm(
66  std::shared_ptr<Endpoint> endpoint,
67  const std::variant<data::AmSend, data::AmReceive> requestData,
68  const bool enablePythonFuture,
69  RequestCallbackUserFunction callbackFunction,
70  RequestCallbackUserData callbackData);
71 
72  protected:
// NOTE(review): line 73 was elided. Together with the next line it declares
// `bool _enableFuture` (initialized to false): whether the worker was created
// with future capability.
74  false};
// Mutex to access the futures pool.
75  std::mutex _futuresPoolMutex{};
// NOTE(review): line 77 was elided. It completes this declaration as
// `_futuresPool`, the pool of futures kept so fresh futures do not run out.
76  std::queue<std::shared_ptr<Future>>
// Notifier object; defaults to null.
78  std::shared_ptr<Notifier> _notifier{nullptr};
// NOTE(review): line 80 was elided. It completes this declaration as
// `_amData`, worker data made available to the Active Messages callback.
79  std::shared_ptr<internal::AmData>
81 
82  private:
// Private helpers; their documentation was elided from this rendering.
89  void drainWorkerTagRecv();
90 
105  std::shared_ptr<RequestAm> getAmRecv(
106  ucp_ep_h ep, std::function<std::shared_ptr<RequestAm>()> createAmRecvRequestFunction);
107 
114  void stopProgressThreadNoWarn();
115 
126  std::shared_ptr<Request> registerInflightRequest(std::shared_ptr<Request> request);
127 
135  bool progressPending();
136 
137  protected:
// Protected constructor; createWorker() is the way to obtain a shared_ptr<Worker>.
155  explicit Worker(std::shared_ptr<Context> context,
156  const bool enableDelayedSubmission = false,
157  const bool enableFuture = false);
158 
159  public:
// Not default-constructible, copyable, or movable.
160  Worker() = delete;
161  Worker(const Worker&) = delete;
162  Worker& operator=(Worker const&) = delete;
163  Worker(Worker&& o) = delete;
164  Worker& operator=(Worker&& o) = delete;
165 
// Factory that constructs a shared_ptr<ucxx::Worker>.
190  friend std::shared_ptr<Worker> createWorker(std::shared_ptr<Context> context,
191  const bool enableDelayedSubmission,
192  const bool enableFuture);
193 
// ucxx::Worker destructor.
197  virtual ~Worker();
198 
// Get the underlying ucp_worker_h handle.
214  ucp_worker_h getHandle();
215 
// Get information about the underlying ucp_worker_h object.
224  std::string getInfo();
225 
255 
// Arm the UCP worker.
266  bool arm();
267 
// Progress worker event while in blocking progress mode (epollTimeout defaults to -1).
295  bool progressWorkerEvent(const int epollTimeout = -1);
296 
// Signal the worker that an event happened.
331  void signal();
332 
// Block until an event has happened, then progress the worker.
354  bool waitProgress();
355 
// Progress the worker only once.
370  bool progressOnce();
371 
// Progress the worker until all communication events are completed.
387  bool progress();
388 
// Register delayed request submission.
// NOTE(review): line 406 was elided. Per the generated signature it reads
// `DelayedSubmissionCallbackType callback);`.
405  void registerDelayedSubmission(std::shared_ptr<Request> request,
407 
424 
440 
449 
// Inquire whether the worker was created with future support.
457  bool isFutureEnabled() const;
458 
// Populate the future pool.
470  virtual void populateFuturesPool();
471 
// Get a future from the pool.
483  virtual std::shared_ptr<Future> getFuture();
484 
// Block until a request event.
499  virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs);
500 
// Notify futures of each completed communication request.
513  virtual void runRequestNotifier();
514 
523 
// Set the callback to be executed when the progress thread starts.
534  void setProgressThreadStartCallback(std::function<void(void*)> callback, void* callbackArg);
535 
// Start the progress thread.
548  void startProgressThread(const bool pollingMode = false, const int epollTimeout = 1);
549 
559 
568 
// Get the progress thread ID.
576  std::thread::id getProgressThreadId();
577 
// Cancel inflight requests.
598  size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);
599 
613 
// Remove the reference to a request from the internal container.
626  void removeInflightRequest(const Request* const request);
627 
// Check for uncaught tag messages.
647  bool tagProbe(const Tag tag);
648 
// Enqueue a tag receive operation.
673  std::shared_ptr<Request> tagRecv(void* buffer,
674  size_t length,
675  Tag tag,
676  TagMask tagMask,
677  const bool enableFuture = false,
678  RequestCallbackUserFunction callbackFunction = nullptr,
679  RequestCallbackUserData callbackData = nullptr);
680 
// Get the address of the UCX worker object.
692  std::shared_ptr<Address> getAddress();
693 
// Create an endpoint to a worker listening on a specific IP and port.
718  std::shared_ptr<Endpoint> createEndpointFromHostname(std::string ipAddress,
719  uint16_t port,
720  bool endpointErrorHandling = true);
721 
// Create an endpoint to a worker located at a UCX address.
750  std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Address> address,
751  bool endpointErrorHandling = true);
752 
// Listen for remote connections on the given port.
770  std::shared_ptr<Listener> createListener(uint16_t port,
771  ucp_listener_conn_callback_t callback,
772  void* callbackArgs);
773 
// Register an allocator for active messages.
801  void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator);
802 
841 
// Check for uncaught active messages.
861  bool amProbe(const ucp_ep_h endpointHandle) const;
862 
// Enqueue a flush operation.
884  std::shared_ptr<Request> flush(const bool enablePythonFuture = false,
885  RequestCallbackUserFunction callbackFunction = nullptr,
886  RequestCallbackUserData callbackData = nullptr);
887 };
888 
889 } // namespace ucxx
AmReceiverCallbackInfo
Information of an Active Message receiver callback.
Definition: typedefs.h:154
ucxx::Component
A UCXX component class to prevent early destruction of a parent object.
Definition: component.h:17
ucxx::Request
Base type for a UCXX transfer request.
Definition: request.h:38
ucxx::Worker
Component encapsulating a UCP worker.
Definition: worker.h:44
bool amProbe(const ucp_ep_h endpointHandle) const
Check for uncaught active messages.
bool progress()
Progress the worker until all communication events are completed.
void registerAmReceiverCallback(AmReceiverCallbackInfo info, AmReceiverCallbackType callback)
Register receiver callback for active messages.
void signal()
Signal the worker that an event happened.
virtual void populateFuturesPool()
Populate the future pool.
void setProgressThreadStartCallback(std::function< void(void *)> callback, void *callbackArg)
Set callback to be executed at the progress thread start.
bool progressOnce()
Progress the worker only once.
void startProgressThread(const bool pollingMode=false, const int epollTimeout=1)
Start the progress thread.
bool isProgressThreadRunning()
Inquire if worker has a progress thread running.
friend std::shared_ptr< RequestAm > createRequestAm(std::shared_ptr< Endpoint > endpoint, const std::variant< data::AmSend, data::AmReceive > requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData)
bool tagProbe(const Tag tag)
Check for uncaught tag messages.
void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator)
Register allocator for active messages.
void stopProgressThread()
Stop the progress thread.
Worker(std::shared_ptr< Context > context, const bool enableDelayedSubmission=false, const bool enableFuture=false)
Protected constructor of ucxx::Worker.
std::shared_ptr< Endpoint > createEndpointFromHostname(std::string ipAddress, uint16_t port, bool endpointErrorHandling=true)
Create endpoint to worker listening on specific IP and port.
std::shared_ptr< Listener > createListener(uint16_t port, ucp_listener_conn_callback_t callback, void *callbackArgs)
Listen for remote connections on given port.
std::queue< std::shared_ptr< Future > > _futuresPool
Futures pool to prevent running out of fresh futures.
Definition: worker.h:77
bool isDelayedRequestSubmissionEnabled() const
Inquire if worker has been created with delayed submission enabled.
virtual ~Worker()
ucxx::Worker destructor.
ucp_worker_h getHandle()
Get the underlying ucp_worker_h handle.
std::mutex _futuresPoolMutex
Mutex to access the futures pool.
Definition: worker.h:75
virtual RequestNotifierWaitState waitRequestNotifier(uint64_t periodNs)
Block until a request event.
void scheduleRequestCancel(TrackedRequestsPtr trackedRequests)
Schedule cancellation of inflight requests.
bool arm()
Arm the UCP worker.
std::shared_ptr< Request > flush(const bool enablePythonFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a flush operation.
std::shared_ptr< Endpoint > createEndpointFromWorkerAddress(std::shared_ptr< Address > address, bool endpointErrorHandling=true)
Create endpoint to worker located at UCX address.
bool waitProgress()
Block until an event has happened, then progress the worker.
virtual void runRequestNotifier()
Notify futures of each completed communication request.
void removeInflightRequest(const Request *const request)
Remove reference to request from internal container.
std::shared_ptr< Notifier > _notifier
Notifier object.
Definition: worker.h:78
void registerDelayedSubmission(std::shared_ptr< Request > request, DelayedSubmissionCallbackType callback)
Register delayed request submission.
std::thread::id getProgressThreadId()
Get the progress thread ID.
size_t cancelInflightRequests(uint64_t period=0, uint64_t maxAttempts=1)
Cancel inflight requests.
void registerGenericPost(DelayedSubmissionCallbackType callback)
Register callback to be executed in progress thread after progressing.
bool _enableFuture
Boolean identifying whether the worker was created with future capability.
Definition: worker.h:73
bool isFutureEnabled() const
Inquire if worker has been created with future support.
virtual std::shared_ptr< Future > getFuture()
Get a future from the pool.
virtual void stopRequestNotifierThread()
Signal the notifier to terminate.
std::shared_ptr< Address > getAddress()
Get the address of the UCX worker object.
friend std::shared_ptr< Worker > createWorker(std::shared_ptr< Context > context, const bool enableDelayedSubmission, const bool enableFuture)
Constructor of shared_ptr<ucxx::Worker>.
bool progressWorkerEvent(const int epollTimeout=-1)
Progress worker event while in blocking progress mode.
std::shared_ptr< internal::AmData > _amData
Worker data made available to Active Messages callback.
Definition: worker.h:80
void registerGenericPre(DelayedSubmissionCallbackType callback)
Register callback to be executed in progress thread before progressing.
std::shared_ptr< Request > tagRecv(void *buffer, size_t length, Tag tag, TagMask tagMask, const bool enableFuture=false, RequestCallbackUserFunction callbackFunction=nullptr, RequestCallbackUserData callbackData=nullptr)
Enqueue a tag receive operation.
void initBlockingProgressMode()
Initialize blocking progress mode.
std::string getInfo()
Get information about the underlying ucp_worker_h object.
ucxx::Address
Definition: address.h:15
std::function< void(ucs_status_t, std::shared_ptr< void >)> RequestCallbackUserFunction
A user-defined function to execute as part of a ucxx::Request callback.
Definition: typedefs.h:89
std::shared_ptr< void > RequestCallbackUserData
Data for the user-defined function provided to the ucxx::Request callback.
Definition: typedefs.h:97
std::function< void()> DelayedSubmissionCallbackType
A user-defined function to execute as part of delayed submission callback.
Definition: delayed_submission.h:31
std::function< std::shared_ptr< Buffer >(size_t)> AmAllocatorType
Custom Active Message allocator type.
Definition: typedefs.h:121
std::unique_ptr< TrackedRequests > TrackedRequestsPtr
Pre-defined type for a pointer to a container of tracked requests.
Definition: inflight_requests.h:54
std::function< void(std::shared_ptr< Request >)> AmReceiverCallbackType
Active Message receiver callback.
Definition: typedefs.h:129
RequestNotifierWaitState
The state with which a wait operation completed.
Definition: notifier.h:26
TagMask
Strong type for a UCP tag mask.
Definition: typedefs.h:66
Tag
Strong type for a UCP tag.
Definition: typedefs.h:58