Merged
27 commits
97a9c8c
Add `ucxx::RequestEndpointClose` to support non-blocking close
pentschev Feb 14, 2024
89337af
Endpoint status fixes
pentschev Feb 15, 2024
8c4fda3
Test disabled endpoint error handling
pentschev Feb 15, 2024
dba8c2d
Update Python's `set_close_callback`
pentschev Feb 15, 2024
2c133ea
Fix Endpoint's inflight request registration
pentschev Feb 16, 2024
f3de55e
Rename Endpoint close method to `closeBlocking()`
pentschev Feb 16, 2024
63458f4
Add non-blocking inflight request cancelation to `ucxx::Endpoint`
pentschev Feb 16, 2024
030b868
Delete `EndpointClose` default constructor
pentschev Feb 16, 2024
a9482d5
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Feb 16, 2024
978af0a
Fix `ucxx::Endpoint::closeBlocking()`
pentschev Feb 16, 2024
6dcf644
Fix `_endpoint_close_callback`
pentschev Feb 16, 2024
19c5e02
Avoid `std::variant` for single type
pentschev Feb 19, 2024
0fcab9d
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Feb 20, 2024
1c7a3e6
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Feb 20, 2024
e5a2581
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Feb 22, 2024
0fc4faf
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Feb 26, 2024
d4cefc2
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Mar 11, 2024
66ef5b2
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Mar 13, 2024
3784f3d
Fix copyright headers and linting
pentschev Mar 14, 2024
0a1aaf2
Fix `ucxx::Endpoint::close()` docstring
pentschev Mar 14, 2024
72bd88f
Remove lingering debug output
pentschev Mar 14, 2024
f6079eb
Merge remote-tracking branch 'origin/request-endpoint-close' into req…
pentschev Mar 14, 2024
e56d569
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Mar 14, 2024
088254b
Fix linting
pentschev Mar 14, 2024
78ce20e
Merge remote-tracking branch 'upstream/branch-0.37' into request-endp…
pentschev Mar 14, 2024
66b9c74
Update casting in `remove_close_callback()`
pentschev Mar 14, 2024
eb8e708
Prevent race condition with endpoint error callback and its setter
pentschev Mar 15, 2024
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Expand Up @@ -141,6 +141,7 @@ add_library(
src/request.cpp
src/request_am.cpp
src/request_data.cpp
src/request_endpoint_close.cpp
src/request_flush.cpp
src/request_helper.cpp
src/request_mem.cpp
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/ucxx/constructors.h
Expand Up @@ -23,6 +23,7 @@ class Notifier;
class RemoteKey;
class Request;
class RequestAm;
class RequestEndpointClose;
class RequestFlush;
class RequestMem;
class RequestStream;
Expand Down Expand Up @@ -77,6 +78,13 @@ std::shared_ptr<RequestAm> createRequestAm(
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

std::shared_ptr<RequestEndpointClose> createRequestEndpointClose(
std::shared_ptr<Endpoint> endpoint,
const data::EndpointClose requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

std::shared_ptr<RequestFlush> createRequestFlush(std::shared_ptr<Component> endpointOrWorker,
const data::Flush requestData,
const bool enablePythonFuture,
Expand Down
98 changes: 84 additions & 14 deletions cpp/include/ucxx/endpoint.h
Expand Up @@ -48,10 +48,14 @@ struct EpParamsDeleter {
* callback to modify the `ucxx::Endpoint` with information relevant to the error that occurred.
*/
struct ErrorCallbackData {
Endpoint*
_endpoint; ///< Pointer to the `ucxx::Endpoint` that owns this object, used only for logging.
std::unique_ptr<std::mutex>
_mutex; ///< Mutex used to prevent race conditions with `ucxx::Endpoint::setCloseCallback()`.
ucs_status_t status; ///< Endpoint status
std::shared_ptr<InflightRequests> inflightRequests; ///< Endpoint inflight requests
std::function<void(void*)> closeCallback; ///< Close callback to call
void* closeCallbackArg; ///< Argument to be passed to close callback
EndpointCloseCallbackUserFunction closeCallback; ///< Close callback to call
EndpointCloseCallbackUserData closeCallbackArg; ///< Argument to be passed to close callback
std::shared_ptr<Worker> worker; ///< Worker the endpoint has been created from
};

Expand Down Expand Up @@ -216,9 +220,10 @@ class Endpoint : public Component {
/**
* @brief Check whether the endpoint is still alive.
*
* Check whether the endpoint is still alive, generally `true` until `close()` is called
* the endpoint errors and the error handling procedure is executed. Always `true` if
* endpoint error handling is disabled.
* Check whether the endpoint is still alive, generally `true` until `closeBlocking()` is
* called, `close()` is called and the returned request completes or the endpoint errors
* and the error handling procedure is executed. Always `true` if endpoint error handling
* is disabled.
*
* @returns whether the endpoint is still alive if endpoint enables error handling, always
* returns `true` if error handling is disabled.
Expand Down Expand Up @@ -253,9 +258,32 @@ class Endpoint : public Component {
/**
* @brief Cancel inflight requests.
*
* Cancel inflight requests, returning the total number of requests that were canceled.
* This is usually executed by `close()`, when pending requests will no longer be able
* to complete.
* Cancel inflight requests, returning the total number of requests that were scheduled
* for cancelation. After the requests are scheduled for cancelation, the caller must
* progress the worker and check the result of `getCancelingSize()`; the requests are
* fully canceled only once `getCancelingSize()` returns `0`.
*
* @returns Number of requests that were scheduled for cancelation.
*/
size_t cancelInflightRequests();

/**
* @brief Check the number of inflight requests being canceled.
*
* Check the number of inflight requests that were scheduled for cancelation with
* `cancelInflightRequests()` and have not yet completed cancelation. To ensure their
* cancelation is completed, the worker must be progressed until this method returns `0`.
*
* @returns Number of requests that are in process of cancelation.
*/
size_t getCancelingSize() const;

/**
* @brief Cancel inflight requests.
*
* Cancel inflight requests and block until all requests complete cancelation, returning
* the total number of requests that were canceled. This is usually executed by
* `closeBlocking()`, when pending requests will no longer be able to complete.
*
* If the parent worker is running a progress thread, a maximum timeout may be specified
* for which the close operation will wait. This can be particularly important for cases
Expand All @@ -271,7 +299,7 @@ class Endpoint : public Component {
*
* @returns Number of requests that were canceled.
*/
size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1);
size_t cancelInflightRequestsBlocking(uint64_t period = 0, uint64_t maxAttempts = 1);

/**
* @brief Register a user-defined callback to call when endpoint closes.
Expand All @@ -286,7 +314,8 @@ class Endpoint : public Component {
* receiving a single opaque pointer.
* @param[in] closeCallbackArg pointer to optional user-allocated callback argument.
*/
void setCloseCallback(std::function<void(void*)> closeCallback, void* closeCallbackArg);
void setCloseCallback(EndpointCloseCallbackUserFunction closeCallback,
EndpointCloseCallbackUserData closeCallbackArg);

/**
* @brief Enqueue an active message send operation.
Expand Down Expand Up @@ -684,11 +713,53 @@ class Endpoint : public Component {
*/
static void errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status);

/**
* @brief Enqueue a non-blocking endpoint close operation.
*
* Enqueue a non-blocking endpoint close operation, which will close the endpoint without
* requiring the object to be destroyed. This may be useful when other
* `std::shared_ptr<ucxx::Request>` objects are still alive, such as inflight transfers.
*
* This method returns a `std::shared_ptr<ucxx::Request>` that can be later awaited and
* checked for errors. This is a non-blocking operation, and the status of closing the
* endpoint must be verified from the resulting request object before the
* `std::shared_ptr<ucxx::Endpoint>` can be safely destroyed and the UCP endpoint assumed
* inactive (closed).
*
* If the endpoint was created with error handling support, the error callback will be
* executed, implying the user-defined callback will also be executed.
*
* If a user-defined callback is specified via the `callbackFunction` argument, that
* callback will be executed. Otherwise, the callback registered with `setCloseCallback()`
* will be executed. If neither was specified, no user-defined callback will be
* executed.
*
* Using a Python future may be requested by specifying `enablePythonFuture`. If a
* Python future is requested, the Python application must then await on this future to
* ensure the close operation has completed. Requires UCXX Python support.
*
* @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any
* inflight requests prior to submitting the UCP close request. Before scheduling the
* endpoint close request, the caller must first call `cancelInflightRequests()` and
* progress the worker until `getCancelingSize()` returns `0`.
*
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns Request to be subsequently checked for the completion and its state.
*/
std::shared_ptr<Request> close(const bool enablePythonFuture = false,
EndpointCloseCallbackUserFunction callbackFunction = nullptr,
EndpointCloseCallbackUserData callbackData = nullptr);

/**
* @brief Close the endpoint while keeping the object alive.
*
* Close the endpoint without requiring to destroy the object. This may be useful when
* `std::shared_ptr<ucxx::Request>` objects are still alive.
* Close the endpoint without requiring the object to be destroyed, blocking until the
* operation completes. This may be useful when `std::shared_ptr<ucxx::Request>` objects
* are still alive.
*
* If the endpoint was created with error handling support, the error callback will be
* executed, implying the user-defined callback will also be executed if one was
Expand All @@ -705,9 +776,8 @@ class Endpoint : public Component {
* operation will wait for.
* @param[in] maxAttempts maximum number of attempts to close endpoint, only applicable
* if worker is running a progress thread and `period > 0`.
*
*/
void close(uint64_t period = 0, uint64_t maxAttempts = 1);
void closeBlocking(uint64_t period = 0, uint64_t maxAttempts = 1);
};

} // namespace ucxx
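The warning in the `close()` docstring above describes an ordering: schedule cancelation, progress until `getCancelingSize()` returns `0`, and only then submit the close request and progress until it completes. The sketch below walks through that ordering with stand-in types. `MockEndpoint`, `MockRequest`, `progress()` and `closeNonBlocking()` are hypothetical names introduced for illustration only; they mimic the documented contract and are not the UCXX classes.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Hypothetical stand-ins, NOT the UCXX classes: they only mimic the documented
// contract of cancelInflightRequests(), getCancelingSize() and close().
struct MockRequest {
  bool completed{false};
  bool isCompleted() const { return completed; }
};

class MockEndpoint {
 public:
  explicit MockEndpoint(std::size_t inflight) : _inflight(inflight) {}

  // Schedule every inflight request for cancelation; nothing is canceled yet.
  std::size_t cancelInflightRequests()
  {
    const std::size_t scheduled = _inflight;
    _canceling += scheduled;
    _inflight = 0;
    return scheduled;
  }

  // Requests scheduled for cancelation that have not finished canceling.
  std::size_t getCancelingSize() const { return _canceling; }

  // Submit the close request; it completes on a later progress() call.
  std::shared_ptr<MockRequest> close()
  {
    // The caller is expected to have drained cancelations first.
    assert(_inflight == 0 && _canceling == 0);
    _closeRequest = std::make_shared<MockRequest>();
    return _closeRequest;
  }

  // Stand-in for progressing the worker: retire one canceling request per
  // call, or complete a pending close request once nothing is canceling.
  void progress()
  {
    if (_canceling > 0) {
      --_canceling;
    } else if (_closeRequest) {
      _closeRequest->completed = true;
    }
  }

 private:
  std::size_t _inflight{0};
  std::size_t _canceling{0};
  std::shared_ptr<MockRequest> _closeRequest{nullptr};
};

// Cancel, drain, close, wait. Returns the number of progress() calls made.
std::size_t closeNonBlocking(MockEndpoint& endpoint)
{
  std::size_t iterations = 0;

  endpoint.cancelInflightRequests();
  while (endpoint.getCancelingSize() > 0) {
    endpoint.progress();
    ++iterations;
  }

  auto request = endpoint.close();
  while (!request->isCompleted()) {
    endpoint.progress();
    ++iterations;
  }
  return iterations;
}
```

In a real application the two loops would progress the parent worker rather than the endpoint, and the close request's status would be checked for errors after completion.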
22 changes: 22 additions & 0 deletions cpp/include/ucxx/request_data.h
Expand Up @@ -66,6 +66,27 @@ class AmReceive {
AmReceive();
};

/**
* @brief Data for an endpoint close operation.
*
* Type identifying an endpoint close operation and containing data specific to this request
* type.
*/
class EndpointClose {
public:
const bool _force{false}; ///< Whether to force endpoint closing.
/**
* @brief Constructor for endpoint close-specific data.
*
* Construct an object containing endpoint close-specific data.
*
* @param[in] force force endpoint close if `true`, flush otherwise.
*/
explicit EndpointClose(const decltype(_force) force);

EndpointClose() = delete;
};

/**
* @brief Data for a flush operation.
*
Expand Down Expand Up @@ -307,6 +328,7 @@ class TagMultiReceive {
using RequestData = std::variant<std::monostate,
AmSend,
AmReceive,
EndpointClose,
Flush,
MemPut,
MemGet,
Expand Down
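`data::EndpointClose` deletes its default constructor, yet it is listed as an alternative of `RequestData`. A `std::variant` default-constructs its first alternative, so placing `std::monostate` first keeps the variant default-constructible regardless. The sketch below shows this in miniature; the `sketch` namespace, the reduced alternative list, the empty `Flush` class and the `describe()` helper are illustrative assumptions, not the UCXX definitions.

```cpp
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical miniature of the request-data pattern; the names mirror the
// diff, but these are simplified stand-ins rather than the UCXX definitions.
namespace sketch {

class EndpointClose {
 public:
  const bool _force{false};  // Whether to force endpoint closing.

  explicit EndpointClose(const bool force) : _force(force) {}

  EndpointClose() = delete;  // A close request must state `force` explicitly.
};

class Flush {};

// std::monostate comes first, so RequestData{} is valid even though
// EndpointClose cannot be default-constructed.
using RequestData = std::variant<std::monostate, EndpointClose, Flush>;

// Helper to build a visitor out of lambdas (common C++17 idiom).
template <class... Ts>
struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts>
Overloaded(Ts...) -> Overloaded<Ts...>;

// Dispatch on the alternative currently held by the variant.
inline std::string describe(const RequestData& data)
{
  return std::visit(
    Overloaded{[](std::monostate) { return std::string("empty"); },
               [](const EndpointClose& close) {
                 return std::string(close._force ? "close (force)" : "close (flush)");
               },
               [](const Flush&) { return std::string("flush"); }},
    data);
}

}  // namespace sketch
```

Calling `sketch::describe(sketch::EndpointClose(true))` selects the `EndpointClose` branch, while a default-constructed `sketch::RequestData` holds `std::monostate`.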
115 changes: 115 additions & 0 deletions cpp/include/ucxx/request_endpoint_close.h
@@ -0,0 +1,115 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
#include <memory>
#include <string>
#include <utility>

#include <ucp/api/ucp.h>

#include <ucxx/delayed_submission.h>
#include <ucxx/request.h>
#include <ucxx/typedefs.h>

namespace ucxx {

/**
* @brief Close a UCP endpoint.
*
* Close a UCP endpoint, using non-blocking UCP call `ucp_ep_close_nbx`.
*/
class RequestEndpointClose : public Request {
private:
/**
* @brief Private constructor of `ucxx::RequestEndpointClose`.
*
* This is the internal implementation of `ucxx::RequestEndpointClose` constructor, made
* private so that it cannot be called directly. This ensures all UCXX objects are shared
* pointers with correct lifetime management.
*
* Instead the user should use one of the following:
*
* - `ucxx::Endpoint::close()`
* - `ucxx::createRequestEndpointClose()`
*
* @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpoint the `std::shared_ptr<Endpoint>` parent component.
* @param[in] requestData container of the specified message type, including all
* type-specific data.
* @param[in] operationName a human-readable operation name to help identifying
* requests by their types when UCXX logging is enabled.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*/
RequestEndpointClose(std::shared_ptr<Endpoint> endpoint,
const data::EndpointClose requestData,
const std::string operationName,
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

public:
/**
* @brief Constructor for `std::shared_ptr<ucxx::RequestEndpointClose>`.
*
* The constructor for a `std::shared_ptr<ucxx::RequestEndpointClose>` object, creating a
* request to close a UCP endpoint, returning a pointer to the request object that can be
* later awaited and checked for errors. This is a non-blocking operation, and its status
* must be verified from the resulting request object to confirm the close operation has
* completed successfully.
*
* @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr<ucxx::Endpoint>`.
*
* @param[in] endpoint the `std::shared_ptr<Endpoint>` parent component.
* @param[in] requestData container of the specified message type, including all
* type-specific data.
* @param[in] enablePythonFuture whether a python future should be created and
* subsequently notified.
* @param[in] callbackFunction user-defined callback function to call upon completion.
* @param[in] callbackData user-defined data to pass to the `callbackFunction`.
*
* @returns The `shared_ptr<ucxx::RequestEndpointClose>` object.
*/
friend std::shared_ptr<RequestEndpointClose> createRequestEndpointClose(
std::shared_ptr<Endpoint> endpoint,
const data::EndpointClose requestData,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData);

virtual void populateDelayedSubmission();

/**
* @brief Create and submit an endpoint close request.
*
* This is the method that should be called to actually submit an endpoint close request.
* It is meant to be called from `populateDelayedSubmission()`, which is decided at the
* discretion of `std::shared_ptr<ucxx::Worker>`. See `populateDelayedSubmission()` for
* more details.
*/
void request();

/**
* @brief Callback executed by UCX when an endpoint close request is completed.
*
* Callback executed by UCX when an endpoint close request is completed, that will
* dispatch `ucxx::Request::callback()`.
*
* WARNING: This is not intended to be called by the user, but it currently needs to be
* a public method so that UCX may access it. In future changes this will be moved to
* an internal object and this method will be removed from the public API.
*
* @param[in] request the UCX request pointer.
* @param[in] status the completion status of the request.
* @param[in] arg the pointer to the `ucxx::Request` object that created the
* transfer, effectively `this` pointer as seen by `request()`.
*/
static void endpointCloseCallback(void* request, ucs_status_t status, void* arg);
};

} // namespace ucxx
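`endpointCloseCallback()` is `static` and takes a `void* arg` because a C library cannot call a C++ member function directly; the object passes `this` as the user argument and the static function casts it back. The following is a minimal sketch of that trampoline pattern, assuming a made-up C-style API: `Status`, `CompletionCallback`, `fakeCloseNbx()` and `CloseRequest` are hypothetical stand-ins and do not reproduce UCX or UCXX signatures.

```cpp
namespace sketch {

// Stand-in for ucs_status_t with only the two values used here.
enum Status { STATUS_OK = 0, STATUS_INPROGRESS = 1 };

// Shape of a C completion callback: an opaque request handle, the completion
// status, and the user pointer supplied at submission time.
using CompletionCallback = void (*)(void* request, Status status, void* arg);

// Hypothetical C-style API: completes immediately and hands `arg` back to the
// callback untouched. A real library would typically call back later, from
// its progress function.
inline void fakeCloseNbx(CompletionCallback callback, void* arg)
{
  callback(nullptr, STATUS_OK, arg);
}

class CloseRequest {
 public:
  // Submit the operation, passing `this` as the opaque user argument.
  void request() { fakeCloseNbx(endpointCloseCallback, this); }

  // Static trampoline: recover the object from `arg` and update its state.
  static void endpointCloseCallback(void* /*request*/, Status status, void* arg)
  {
    auto* self       = static_cast<CloseRequest*>(arg);
    self->_status    = status;
    self->_completed = true;
  }

  bool isCompleted() const { return _completed; }
  Status getStatus() const { return _status; }

 private:
  bool _completed{false};
  Status _status{STATUS_INPROGRESS};
};

}  // namespace sketch
```

The object must outlive the pending operation, since the library holds only a raw pointer to it; shared-pointer ownership of request objects is one way to help guarantee that.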
16 changes: 16 additions & 0 deletions cpp/include/ucxx/typedefs.h
Expand Up @@ -95,6 +95,22 @@ typedef std::function<void(ucs_status_t, std::shared_ptr<void>)> RequestCallback
*/
typedef std::shared_ptr<void> RequestCallbackUserData;

/**
* @brief A user-defined function to execute after an endpoint closes.
*
* A user-defined function to execute after an endpoint closes, allowing execution of custom
* code after such event.
*/
typedef RequestCallbackUserFunction EndpointCloseCallbackUserFunction;

/**
* @brief Data for the user-defined function provided to endpoint close callback.
*
* Data passed to the user-defined function provided to the endpoint close callback, which
* the custom user-defined function may act upon.
*/
typedef RequestCallbackUserData EndpointCloseCallbackUserData;

/**
* @brief Custom Active Message allocator type.
*
Expand Down
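Because the endpoint close callback types alias the request callback types, a close callback receives a status and a type-erased `std::shared_ptr<void>`. The sketch below shows one way such a callback might recover its data. It re-declares the aliases locally with `int` standing in for `ucs_status_t`; `CloseLog`, `recordClose()` and `fireCloseCallback()` are hypothetical helpers added for illustration.

```cpp
#include <functional>
#include <memory>
#include <utility>

namespace sketch {

using status_t = int;  // Stand-in for ucs_status_t, an assumption of this sketch.

// Local re-declarations mirroring the aliases shown in the diff.
using RequestCallbackUserFunction       = std::function<void(status_t, std::shared_ptr<void>)>;
using RequestCallbackUserData           = std::shared_ptr<void>;
using EndpointCloseCallbackUserFunction = RequestCallbackUserFunction;
using EndpointCloseCallbackUserData     = RequestCallbackUserData;

// Hypothetical user data: counts invocations and remembers the last status.
struct CloseLog {
  int calls{0};
  status_t lastStatus{-1};
};

// Example user callback: cast the type-erased pointer back to its real type.
inline void recordClose(status_t status, EndpointCloseCallbackUserData data)
{
  auto log = std::static_pointer_cast<CloseLog>(data);
  ++log->calls;
  log->lastStatus = status;
}

// What a caller of the callback might do: skip empty callbacks, otherwise
// forward the status and the user data.
inline void fireCloseCallback(const EndpointCloseCallbackUserFunction& callback,
                              EndpointCloseCallbackUserData data,
                              status_t status)
{
  if (callback) { callback(status, std::move(data)); }
}

}  // namespace sketch
```

Holding the user data in a `std::shared_ptr<void>` keeps it alive until the callback has run, which a raw `void*` argument cannot guarantee.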