diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7d0e02d93..97bfc6046 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -141,6 +141,7 @@ add_library( src/request.cpp src/request_am.cpp src/request_data.cpp + src/request_endpoint_close.cpp src/request_flush.cpp src/request_helper.cpp src/request_mem.cpp diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index f04326b51..b05132549 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -23,6 +23,7 @@ class Notifier; class RemoteKey; class Request; class RequestAm; +class RequestEndpointClose; class RequestFlush; class RequestMem; class RequestStream; @@ -77,6 +78,13 @@ std::shared_ptr createRequestAm( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); +std::shared_ptr createRequestEndpointClose( + std::shared_ptr endpoint, + const data::EndpointClose requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestFlush(std::shared_ptr endpointOrWorker, const data::Flush requestData, const bool enablePythonFuture, diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 81cced641..c97721cef 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -48,10 +48,14 @@ struct EpParamsDeleter { * callback to modify the `ucxx::Endpoint` with information relevant to the error occurred. */ struct ErrorCallbackData { + Endpoint* + _endpoint; ///< Pointer to the `ucxx::Endpoint` that owns this object, used only for logging. + std::unique_ptr + _mutex; ///< Mutex used to prevent race conditions with `ucxx::Endpoint::setCloseCallback()`. ucs_status_t status; ///< Endpoint status std::shared_ptr inflightRequests; ///< Endpoint inflight requests - std::function closeCallback; ///< Close callback to call - void* closeCallbackArg; ///< Argument to be passed to close callback + EndpointCloseCallbackUserFunction closeCallback; ///< Close callback to call + EndpointCloseCallbackUserData closeCallbackArg; ///< Argument to be passed to close callback std::shared_ptr worker; ///< Worker the endpoint has been created from }; @@ -216,9 +220,10 @@ class Endpoint : public Component { /** * @brief Check whether the endpoint is still alive. * - * Check whether the endpoint is still alive, generally `true` until `close()` is called - * the endpoint errors and the error handling procedure is executed. Always `true` if - * endpoint error handling is disabled. + * Check whether the endpoint is still alive, generally `true` until `closeBlocking()` is + * called, `close()` is called and the returned request completes or the endpoint errors + * and the error handling procedure is executed. Always `true` if endpoint error handling + * is disabled. * * @returns whether the endpoint is still alive if endpoint enables error handling, always * returns `true` if error handling is disabled. @@ -253,9 +258,32 @@ class Endpoint : public Component { /** * @brief Cancel inflight requests. * - * Cancel inflight requests, returning the total number of requests that were canceled. - * This is usually executed by `close()`, when pending requests will no longer be able - * to complete. + * Cancel inflight requests, returning the total number of requests that were scheduled + * for cancelation. After the requests are scheduled for cancelation, the caller must + * progress the worker and check the result of `getCancelingSize()`, all requests are only + * canceled when `getCancelingSize()` returns `0`. + * + * @returns Number of requests that were scheduled for cancelation. + */ + size_t cancelInflightRequests(); + + /** + * @brief Check the number of inflight requests being canceled. + * + * Check the number of inflight requests that were scheduled for cancelation with + * `cancelInflightRequests()` who have not yet completed cancelation. To ensure their + * cancelation is completed, the worker must be progressed until this method returns `0`. + * + * @returns Number of requests that are in process of cancelation. + */ + size_t getCancelingSize() const; + + /** + * @brief Cancel inflight requests. + * + * Cancel inflight requests and block until all requests complete cancelation, returning + * the total number of requests that were canceled. This is usually executed by + * `closeBlocking()`, when pending requests will no longer be able to complete. * * If the parent worker is running a progress thread, a maximum timeout may be specified * for which the close operation will wait. This can be particularly important for cases @@ -271,7 +299,7 @@ class Endpoint : public Component { * * @returns Number of requests that were canceled. */ - size_t cancelInflightRequests(uint64_t period = 0, uint64_t maxAttempts = 1); + size_t cancelInflightRequestsBlocking(uint64_t period = 0, uint64_t maxAttempts = 1); /** * @brief Register a user-defined callback to call when endpoint closes. @@ -286,7 +314,8 @@ class Endpoint : public Component { * receiving a single opaque pointer. * @param[in] closeCallbackArg pointer to optional user-allocated callback argument. */ - void setCloseCallback(std::function closeCallback, void* closeCallbackArg); + void setCloseCallback(EndpointCloseCallbackUserFunction closeCallback, + EndpointCloseCallbackUserData closeCallbackArg); /** * @brief Enqueue an active message send operation. @@ -684,11 +713,53 @@ class Endpoint : public Component { */ static void errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status); + /** + * @brief Enqueue a non-blocking endpoint close operation. + * + * Enqueue a non-blocking endpoint close operation, which will close the endpoint without + * requiring to destroy the object. This may be useful when other + * `std::shared_ptr` objects are still alive, such as inflight transfers. + * + * This method returns a `std::shared` that can be later awaited and + * checked for errors. This is a non-blocking operation, and the status of closing the + * endpoint must be verified from the resulting request object before the + * `std::shared_ptr` can be safely destroyed and the UCP endpoint assumed + * inactive (closed). + * + * If the endpoint was created with error handling support, the error callback will be + * executed, implying the user-defined callback will also be executed. + * + * If a user-defined callback is specified via the `callbackFunction` argument then that + * callback will be executed, if not then the callback registered with `setCloseCallback()` + * will be executed, if neither was specified then no user-defined callback will be + * executed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the close operation has completed. Requires UCXX Python support. + * + * @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any + * inflight requests prior to submitting the UCP close request. Before scheduling the + * endpoint close request, the caller must first call `cancelInflightRequests()` and + * progress the worker until `getCancelingSize()` returns `0`. + * + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr close(const bool enablePythonFuture = false, + EndpointCloseCallbackUserFunction callbackFunction = nullptr, + EndpointCloseCallbackUserData callbackData = nullptr); + /** * @brief Close the endpoint while keeping the object alive. * - * Close the endpoint without requiring to destroy the object. This may be useful when - * `std::shared_ptr` objects are still alive. + * Close the endpoint without requiring to destroy the object, blocking until the + * operation completes. This may be useful when `std::shared_ptr` objects + * are still alive. * * If the endpoint was created with error handling support, the error callback will be * executed, implying the user-defined callback will also be executed if one was @@ -705,9 +776,8 @@ class Endpoint : public Component { * operation will wait for. * @param[in] maxAttempts maximum number of attempts to close endpoint, only applicable * if worker is running a progress thread and `period > 0`. - * */ - void close(uint64_t period = 0, uint64_t maxAttempts = 1); + void closeBlocking(uint64_t period = 0, uint64_t maxAttempts = 1); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index db3ab7267..3bae50789 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -66,6 +66,27 @@ class AmReceive { AmReceive(); }; +/** + * @brief Data for an endpoint close operation. + * + * Type identifying an endpoint close operation and containing data specific to this request + * type. + */ +class EndpointClose { + public: + const bool _force{false}; ///< Whether to force endpoint closing. + /** + * @brief Constructor for endpoint close-specific data. + * + * Construct an object containing endpoint close-specific data. + * + * @param[in] force force endpoint close if `true`, flush otherwise. + */ + explicit EndpointClose(const decltype(_force) force); + + EndpointClose() = delete; +}; + /** * @brief Data for a flush operation. * @@ -307,6 +328,7 @@ class TagMultiReceive { using RequestData = std::variant +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +/** + * @brief Send or receive a message with the UCX Tag API. + * + * Close a UCP endpoint, using non-blocking UCP call `ucp_ep_close_nbx`. + */ +class RequestEndpointClose : public Request { + private: + /** + * @brief Private constructor of `ucxx::RequestEndpointClose`. + * + * This is the internal implementation of `ucxx::RequestEndpointClose` constructor, made + * private not to be called directly. This constructor is made private to ensure all UCXX + * objects are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::close()` + * - `ucxx::createRequestEndpointClose()` + * + * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. + * + * @param[in] endpoint the `std::shared_ptr` parent component. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestEndpointClose(std::shared_ptr endpoint, + const data::EndpointClose requestData, + const std::string operationName, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr`. + * + * The constructor for a `std::shared_ptr` object, creating a + * request to close a UCP endpoint, returning a pointer to the request object that can be + * later awaited and checked for errors. This is a non-blocking operation, and its status + * must be verified from the resulting request object to confirm the close operation has + * completed successfully. + * + * @throws ucxx::Error `endpoint` is not a valid `std::shared_ptr`. + * + * @param[in] endpoint the `std::shared_ptr` parent component. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns The `shared_ptr` object. + */ + friend std::shared_ptr createRequestEndpointClose( + std::shared_ptr endpoint, + const data::EndpointClose requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Create and submit an endpoint close request. + * + * This is the method that should be called to actually submit an endpoint close request. + * It is meant to be called from `populateDelayedSubmission()`, which is decided at the + * discretion of `std::shared_ptr`. See `populateDelayedSubmission()` for + * more details. + */ + void request(); + + /** + * @brief Callback executed by UCX when an endpoint close request is completed. + * + * Callback executed by UCX when an endpoint close request is completed, that will + * dispatch `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void endpointCloseCallback(void* request, ucs_status_t status, void* arg); +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 958103b8d..4831437c8 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -95,6 +95,22 @@ typedef std::function)> RequestCallback */ typedef std::shared_ptr RequestCallbackUserData; +/** + * @brief A user-defined function to execute after an endpoint closes. + * + * A user-defined function to execute after an endpoint closes, allowing execution of custom + * code after such event. + */ +typedef RequestCallbackUserFunction EndpointCloseCallbackUserFunction; + +/** + * @brief Data for the user-defined function provided to endpoint close callback. + * + * Data passed to the user-defined function provided to the endpoint close callback, which + * the custom user-defined function may act upon. + */ +typedef RequestCallbackUserData EndpointCloseCallbackUserData; + /** * @brief Custom Active Message allocator type. * diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 304ef6a85..b09814388 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include @@ -18,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -57,8 +60,12 @@ Endpoint::Endpoint(std::shared_ptr workerOrListener, setParent(workerOrListener); - _callbackData = std::make_unique( - (ErrorCallbackData){.status = UCS_OK, .inflightRequests = _inflightRequests, .worker = worker}); + _callbackData = + std::make_unique(ErrorCallbackData{._endpoint = this, + ._mutex = std::make_unique(), + .status = UCS_INPROGRESS, + .inflightRequests = _inflightRequests, + .worker = worker}); params->err_mode = (endpointErrorHandling ? UCP_ERR_HANDLING_MODE_PEER : UCP_ERR_HANDLING_MODE_NONE); @@ -142,58 +149,76 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptr Endpoint::close(const bool enablePythonFuture, + EndpointCloseCallbackUserFunction callbackFunction, + EndpointCloseCallbackUserData callbackData) { - if (_handle == nullptr) return; + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + bool force = _endpointErrorHandling; + + auto combineCallbacksFunction = [this, &callbackFunction, &callbackData]( + ucs_status_t status, EndpointCloseCallbackUserData unused) { + _callbackData->status = status; + if (callbackFunction) callbackFunction(status, callbackData); + if (_callbackData->closeCallback) + _callbackData->closeCallback(status, _callbackData->closeCallbackArg); + }; + + return registerInflightRequest(createRequestEndpointClose( + endpoint, data::EndpointClose(force), enablePythonFuture, combineCallbacksFunction, nullptr)); +} + +void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) +{ + if (_callbackData->status != UCS_INPROGRESS || _handle == nullptr) return; - size_t canceled = cancelInflightRequests(3000000000 /* 3s */, 3); + size_t canceled = cancelInflightRequestsBlocking(3000000000 /* 3s */, 3); ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, canceled %lu requests", __func__, this, _handle, canceled); - // Close the endpoint - unsigned closeMode = UCP_EP_CLOSE_MODE_FORCE; - if (_endpointErrorHandling && _callbackData->status != UCS_OK) { - // We force close endpoint if endpoint error handling is enabled and - // the endpoint status is not UCS_OK - closeMode = UCP_EP_CLOSE_MODE_FORCE; - } + ucp_request_param_t param{}; + if (_endpointErrorHandling) + param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE}; auto worker = ::ucxx::getWorker(_parent); ucs_status_ptr_t status; if (worker->isProgressThreadRunning()) { bool closeSuccess = false; + bool submitted = false; for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) { - utils::CallbackNotifier callbackNotifierPre{}; - worker->registerGenericPre([this, &callbackNotifierPre, &status, closeMode]() { - status = ucp_ep_close_nb(_handle, closeMode); - callbackNotifierPre.set(); - }); - if (!callbackNotifierPre.wait(period)) continue; + if (!submitted) { + utils::CallbackNotifier callbackNotifierPre{}; + worker->registerGenericPre([this, &callbackNotifierPre, &status, ¶m]() { + status = ucp_ep_close_nbx(_handle, ¶m); + callbackNotifierPre.set(); + }); + if (!callbackNotifierPre.wait(period)) continue; + submitted = true; + } - while (UCS_PTR_IS_PTR(status)) { + if (_callbackData->status == UCS_INPROGRESS) { utils::CallbackNotifier callbackNotifierPost{}; worker->registerGenericPost([this, &callbackNotifierPost, &status]() { - ucs_status_t s = ucp_request_check_status(status); - if (UCS_PTR_STATUS(s) != UCS_INPROGRESS) { - ucp_request_free(status); - _callbackData->status = UCS_PTR_STATUS(s); - if (UCS_PTR_STATUS(status) != UCS_OK) { - ucxx_error( - "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, error while closing " - "endpoint: %s", - __func__, - this, - _handle, - ucs_status_string(UCS_PTR_STATUS(status))); + if (UCS_PTR_IS_PTR(status)) { + ucs_status_t s; + if ((s = ucp_request_check_status(status)) != UCS_INPROGRESS) { + _callbackData->status = s; } + } else if (UCS_PTR_STATUS(status) != UCS_OK) { + ucxx_error( + "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, Error while closing endpoint: %s", + __func__, + this, + _handle, + ucs_status_string(UCS_PTR_STATUS(status))); } callbackNotifierPost.set(); @@ -213,12 +238,11 @@ void Endpoint::close(uint64_t period, uint64_t maxAttempts) _handle); } } else { - status = ucp_ep_close_nb(_handle, closeMode); + status = ucp_ep_close_nbx(_handle, ¶m); if (UCS_PTR_IS_PTR(status)) { ucs_status_t s; while ((s = ucp_request_check_status(status)) == UCS_INPROGRESS) worker->progress(); - ucp_request_free(status); _callbackData->status = s; } else if (UCS_PTR_STATUS(status) != UCS_OK) { ucxx_error( @@ -231,12 +255,14 @@ void Endpoint::close(uint64_t period, uint64_t maxAttempts) } ucxx_trace("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _handle); + if (UCS_PTR_IS_PTR(status)) ucp_request_free(status); + if (_callbackData->closeCallback) { ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, calling user close callback", __func__, this, _handle); - _callbackData->closeCallback(_callbackData->closeCallbackArg); + _callbackData->closeCallback(_callbackData->status, _callbackData->closeCallbackArg); _callbackData->closeCallback = nullptr; _callbackData->closeCallbackArg = nullptr; } @@ -250,14 +276,14 @@ bool Endpoint::isAlive() const { if (!_endpointErrorHandling) return true; - return _callbackData->status == UCS_OK; + return _callbackData->status == UCS_INPROGRESS; } void Endpoint::raiseOnError() { ucs_status_t status = _callbackData->status; - if (status == UCS_OK || !_endpointErrorHandling) return; + if (status == UCS_OK || status == UCS_INPROGRESS || !_endpointErrorHandling) return; std::string statusString{ucs_status_string(status)}; std::stringstream errorMsgStream; @@ -266,8 +292,11 @@ void Endpoint::raiseOnError() utils::ucsErrorThrow(status, errorMsgStream.str()); } -void Endpoint::setCloseCallback(std::function closeCallback, void* closeCallbackArg) +void Endpoint::setCloseCallback(EndpointCloseCallbackUserFunction closeCallback, + EndpointCloseCallbackUserData closeCallbackArg) { + std::lock_guard lock(*_callbackData->_mutex); + _callbackData->closeCallback = closeCallback; _callbackData->closeCallbackArg = closeCallbackArg; } @@ -277,11 +306,11 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptrisCompleted()) _inflightRequests->insert(request); /** - * If the endpoint errored while the request was being submitted, the error - * handler may have been called already and we need to register any new requests - * for cancelation, including the present one. + * If the endpoint closed or errored while the request was being submitted, the error + * handler may have been called already and we need to register any new requests for + * cancelation, including the present one. */ - if (_callbackData->status != UCS_OK) + if (_callbackData->status != UCS_INPROGRESS) _callbackData->worker->scheduleRequestCancel(_inflightRequests->release()); return request; @@ -292,7 +321,9 @@ void Endpoint::removeInflightRequest(const Request* const request) _inflightRequests->remove(request); } -size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) +size_t Endpoint::cancelInflightRequests() { return _inflightRequests->cancelAll(); } + +size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAttempts) { auto worker = ::ucxx::getWorker(this->_parent); size_t canceled = 0; @@ -332,6 +363,8 @@ size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) return canceled; } +size_t Endpoint::getCancelingSize() const { return _inflightRequests->getCancelingSize(); } + std::shared_ptr Endpoint::amSend(void* buffer, size_t length, ucs_memory_type_t memoryType, @@ -509,11 +542,17 @@ void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) ErrorCallbackData* data = reinterpret_cast(arg); data->status = status; data->worker->scheduleRequestCancel(data->inflightRequests->release()); - if (data->closeCallback) { - ucxx_debug("ucxx::Endpoint::%s, UCP handle: %p, calling user close callback", __func__, ep); - data->closeCallback(data->closeCallbackArg); - data->closeCallback = nullptr; - data->closeCallbackArg = nullptr; + { + std::lock_guard lock(*data->_mutex); + if (data->closeCallback) { + ucxx_debug("ucxx::Endpoint::%s: %p, UCP handle: %p, calling user close callback", + __func__, + data->_endpoint, + ep); + data->closeCallback(status, data->closeCallbackArg); + data->closeCallback = nullptr; + data->closeCallbackArg = nullptr; + } } // Connection reset and timeout often represent just a normal remote diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index ff6078aed..9862c45e0 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -56,8 +56,8 @@ void InflightRequests::remove(const Request* const request) if (search != _trackedRequests->_inflight->end()) { /** * If this is the last request to hold `std::shared_ptr` erasing it - * may cause the `ucxx::Endpoint`s destructor and subsequently the `close()` method - * to be called which will in turn call `cancelAll()` and attempt to take the + * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` + * method to be called which will in turn call `cancelAll()` and attempt to take the * mutexes. For this reason we should make a temporary copy of the request being * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then * destroy the object upon this method's return. diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index 0a51b8506..afc3b7926 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -21,6 +21,8 @@ AmSend::AmSend(const void* buffer, const size_t length, const ucs_memory_type me AmReceive::AmReceive() {} +EndpointClose::EndpointClose(const bool force) : _force(force) {} + Flush::Flush() {} MemPut::MemPut(const void* buffer, diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp new file mode 100644 index 000000000..537f4f2fd --- /dev/null +++ b/cpp/src/request_endpoint_close.cpp @@ -0,0 +1,108 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include "ucxx/request_data.h" +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestEndpointClose( + std::shared_ptr endpoint, + const data::EndpointClose requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + std::shared_ptr req = + std::shared_ptr(new RequestEndpointClose( + endpoint, requestData, "endpointClose", enablePythonFuture, callbackFunction, callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + +RequestEndpointClose::RequestEndpointClose(std::shared_ptr endpoint, + const data::EndpointClose requestData, + const std::string operationName, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpoint, requestData, operationName, enablePythonFuture) +{ + if (_endpoint == nullptr && _worker == nullptr) + throw ucxx::Error("A valid endpoint or worker is required for a close operation."); + + _callback = callbackFunction; + _callbackData = callbackData; +} + +void RequestEndpointClose::endpointCloseCallback(void* request, ucs_status_t status, void* arg) +{ + Request* req = reinterpret_cast(arg); + ucxx_trace_req_f( + req->getOwnerString().c_str(), nullptr, request, "endpointClose", "endpointCloseCallback"); + return req->callback(request, status); +} + +void RequestEndpointClose::request() +{ + void* request = nullptr; + + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, .user_data = this}; + if (std::get(_requestData)._force) param.flags = UCP_EP_CLOSE_FLAG_FORCE; + param.cb.send = endpointCloseCallback; + if (_endpoint != nullptr) + request = ucp_ep_close_nbx(_endpoint->getHandle(), ¶m); + else + throw ucxx::Error("A valid endpoint or worker is required for a close operation."); + + std::lock_guard lock(_mutex); + _request = request; +} + +static void logPopulateDelayedSubmission() {} + +void RequestEndpointClose::populateDelayedSubmission() +{ + if (_endpoint != nullptr && _endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint is already closed"); + Request::callback(this, UCS_ERR_CANCELED); + return; + } + + request(); + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, endpoint close, future: %p, future handle: %p", + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, endpoint close"); + + process(); +} + +} // namespace ucxx diff --git a/cpp/tests/listener.cpp b/cpp/tests/listener.cpp index 8cd8a7939..80d15751d 100644 --- a/cpp/tests/listener.cpp +++ b/cpp/tests/listener.cpp @@ -3,6 +3,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include +#include #include #include @@ -19,6 +21,7 @@ struct ListenerContainer { std::shared_ptr listener{nullptr}; std::shared_ptr endpoint{nullptr}; bool transferCompleted{false}; + bool endpointErrorHandling{false}; }; typedef std::shared_ptr ListenerContainerPtr; @@ -32,24 +35,35 @@ static void listenerCallback(ucp_conn_request_h connRequest, void* arg) listenerContainer->status = ucp_conn_request_query(connRequest, &attr); if (listenerContainer->status != UCS_OK) return; - listenerContainer->endpoint = - listenerContainer->listener->createEndpointFromConnRequest(connRequest); + listenerContainer->endpoint = listenerContainer->listener->createEndpointFromConnRequest( + connRequest, listenerContainer->endpointErrorHandling); } -class ListenerTest : public ::testing::Test { +class ListenerTestBase { protected: std::shared_ptr _context{ ucxx::createContext({}, ucxx::Context::defaultFeatureFlags)}; std::shared_ptr _worker{nullptr}; - - virtual void SetUp() { _worker = _context->createWorker(); } + bool _endpointErrorHandling{true}; ListenerContainerPtr createListenerContainer() { - auto listenerContainer = std::make_shared(); - listenerContainer->worker = _worker; + auto listenerContainer = std::make_shared(); + listenerContainer->worker = _worker; + listenerContainer->endpointErrorHandling = _endpointErrorHandling; return listenerContainer; } +}; + +class ListenerTest : public ListenerTestBase, + public ::testing::Test, + public ::testing::WithParamInterface { + protected: + virtual void SetUp() + { + _endpointErrorHandling = GetParam(); + _worker = _context->createWorker(); + } virtual std::shared_ptr createListener(ListenerContainerPtr listenerContainer) { @@ -59,17 +73,25 @@ class ListenerTest : public ::testing::Test { } }; -class ListenerPortTest : public ListenerTest, public ::testing::WithParamInterface { +class ListenerPortTest : public ListenerTestBase, + public ::testing::Test, + public ::testing::WithParamInterface { protected: + uint16_t _port; + virtual void SetUp() + { + _port = GetParam(); + _worker = _context->createWorker(); + } virtual std::shared_ptr createListener(ListenerContainerPtr listenerContainer) { - auto listener = _worker->createListener(GetParam(), listenerCallback, listenerContainer.get()); + auto listener = _worker->createListener(_port, listenerCallback, listenerContainer.get()); listenerContainer->listener = listener; return listener; } }; -TEST_F(ListenerTest, HandleIsValid) +TEST_P(ListenerTest, HandleIsValid) { auto listenerContainer = createListenerContainer(); auto listener = createListener(listenerContainer); @@ -78,21 +100,7 @@ TEST_F(ListenerTest, HandleIsValid) ASSERT_TRUE(listener->getHandle() != nullptr); } -TEST_P(ListenerPortTest, Port) -{ - auto listenerContainer = createListenerContainer(); - auto listener = createListener(listenerContainer); - _worker->progress(); - - if (GetParam() == 0) - ASSERT_GE(listener->getPort(), 1024); - else - ASSERT_EQ(listener->getPort(), 12345); -} - -INSTANTIATE_TEST_SUITE_P(PortAssignment, ListenerPortTest, ::testing::Values(0, 12345)); - -TEST_F(ListenerTest, EndpointSendRecv) +TEST_P(ListenerTest, EndpointSendRecv) { auto listenerContainer = createListenerContainer(); auto listener = createListener(listenerContainer); @@ -100,7 +108,8 @@ TEST_F(ListenerTest, EndpointSendRecv) progress(); - auto ep = _worker->createEndpointFromHostname("127.0.0.1", listener->getPort()); + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); while (listenerContainer->endpoint == nullptr) progress(); @@ -125,13 +134,14 @@ TEST_F(ListenerTest, EndpointSendRecv) std::vector buf{0}; } -TEST_F(ListenerTest, IsAlive) +TEST_P(ListenerTest, IsAlive) { auto listenerContainer = createListenerContainer(); auto listener = createListener(listenerContainer); _worker->progress(); - auto ep = _worker->createEndpointFromHostname("127.0.0.1", listener->getPort()); + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); while (listenerContainer->endpoint == nullptr) _worker->progress(); @@ -149,16 +159,20 @@ TEST_F(ListenerTest, IsAlive) return !ep->isAlive(); }); - ASSERT_FALSE(ep->isAlive()); + if (_endpointErrorHandling) + ASSERT_FALSE(ep->isAlive()); + else + ASSERT_TRUE(ep->isAlive()); } -TEST_F(ListenerTest, RaiseOnError) +TEST_P(ListenerTest, RaiseOnError) { auto listenerContainer = createListenerContainer(); auto listener = createListener(listenerContainer); _worker->progress(); - auto ep = _worker->createEndpointFromHostname("127.0.0.1", listener->getPort()); + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); while (listenerContainer->endpoint == nullptr) _worker->progress(); @@ -174,34 +188,163 @@ TEST_F(ListenerTest, RaiseOnError) return false; }); - EXPECT_THROW(ep->raiseOnError(), ucxx::Error); + if (_endpointErrorHandling) EXPECT_THROW(ep->raiseOnError(), ucxx::Error); } -TEST_F(ListenerTest, CloseCallback) +TEST_P(ListenerTest, EndpointCloseCallback) { auto listenerContainer = createListenerContainer(); auto listener = createListener(listenerContainer); _worker->progress(); - auto ep = _worker->createEndpointFromHostname("127.0.0.1", listener->getPort()); + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); - bool isClosed = false; - ep->setCloseCallback([](void* isClosed) { *reinterpret_cast(isClosed) = true; }, - reinterpret_cast(&isClosed)); + struct CallbackData { + ucs_status_t status{UCS_INPROGRESS}; + bool closed{false}; + }; + + auto callbackData = std::make_shared(); + ep->setCloseCallback( + [](ucs_status_t status, ucxx::EndpointCloseCallbackUserData callbackData) { + auto cbData = std::static_pointer_cast(callbackData); + cbData->status = status; + cbData->closed = true; + }, + callbackData); while (listenerContainer->endpoint == nullptr) _worker->progress(); - ASSERT_FALSE(isClosed); + ASSERT_FALSE(callbackData->closed); + ASSERT_EQ(callbackData->status, UCS_INPROGRESS); listenerContainer->endpoint = nullptr; - loopWithTimeout(std::chrono::milliseconds(5000), [this, &isClosed]() { + loopWithTimeout(std::chrono::milliseconds(5000), [this, &callbackData]() { _worker->progress(); - return isClosed; + return callbackData->closed; }); - ASSERT_TRUE(isClosed); + ASSERT_TRUE(callbackData->closed); + EXPECT_NE(callbackData->status, UCS_INPROGRESS); } +bool checkRequestWithTimeout(std::chrono::milliseconds timeout, + std::shared_ptr worker, + std::shared_ptr closeRequest) +{ + auto startTime = std::chrono::system_clock::now(); + auto endTime = startTime + std::chrono::milliseconds(timeout); + + while (std::chrono::system_clock::now() < endTime) { + worker->progress(); + if (closeRequest->isCompleted()) return true; + } + return false; +} + +TEST_P(ListenerTest, EndpointNonBlockingClose) +{ + auto listenerContainer = createListenerContainer(); + auto listener = createListener(listenerContainer); + _worker->progress(); + + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); + + while (listenerContainer->endpoint == nullptr) + _worker->progress(); + + auto closeRequest = ep->close(); + + /** + * FIXME: For some reason the code below calls `_worker->progress()` from within + * `_worker->progress()`, which is invalid in UCX. The `checkRequestWithTimeout` below + * which is functionally equivalent has no such problem. The lambda seems to behave in + * unexpected way here. The issue also goes away if in `Endpoint::close()` the + * line `if (callbackFunction) callbackFunction(status, callbackData);` is commented + * out from the `combineCallbacksFunction` lambda, even when no callback is specified + * to `ep->close()` above. + */ + // auto f = [this, &closeRequest]() { + // _worker->progress(); + // return closeRequest->isCompleted(); + // }; + // loopWithTimeout(std::chrono::milliseconds(5000), f); + + checkRequestWithTimeout(std::chrono::milliseconds(5000), _worker, closeRequest); + + if (_endpointErrorHandling) + ASSERT_FALSE(ep->isAlive()); + else + ASSERT_TRUE(ep->isAlive()); + EXPECT_NE(closeRequest->getStatus(), UCS_INPROGRESS); +} + +TEST_P(ListenerTest, EndpointNonBlockingCloseWithCallbacks) +{ + auto listenerContainer = createListenerContainer(); + auto listener = createListener(listenerContainer); + _worker->progress(); + + auto closeCallback = [](ucs_status_t status, ucxx::EndpointCloseCallbackUserData data) { + auto dataStatus = std::static_pointer_cast(data); + *dataStatus = status; + }; + auto closeCallbackEndpoint = std::make_shared(UCS_INPROGRESS); + auto closeCallbackRequest = std::make_shared(UCS_INPROGRESS); + + auto ep = + _worker->createEndpointFromHostname("127.0.0.1", listener->getPort(), _endpointErrorHandling); + ep->setCloseCallback(closeCallback, closeCallbackEndpoint); + + while (listenerContainer->endpoint == nullptr) + _worker->progress(); + + auto closeRequest = ep->close(false, closeCallback, closeCallbackRequest); + + /** + * FIXME: For some reason the code below calls `_worker->progress()` from within + * `_worker->progress()`, which is invalid in UCX. The `checkRequestWithTimeout` below + * which is functionally equivalent has no such problem. The lambda seems to behave in + * unexpected way here. The issue also goes away if in `Endpoint::close()` the + * line `if (callbackFunction) callbackFunction(status, callbackData);` is commented + * out from the `combineCallbacksFunction` lambda, even when no callback is specified + * to `ep->close()` above. + */ + // auto f = [this, &closeRequest]() { + // _worker->progress(); + // return closeRequest->isCompleted(); + // }; + // loopWithTimeout(std::chrono::milliseconds(5000), f); + + checkRequestWithTimeout(std::chrono::milliseconds(5000), _worker, closeRequest); + + if (_endpointErrorHandling) + ASSERT_FALSE(ep->isAlive()); + else + ASSERT_TRUE(ep->isAlive()); + EXPECT_NE(closeRequest->getStatus(), UCS_INPROGRESS); + ASSERT_NE(*closeCallbackEndpoint, UCS_INPROGRESS); + ASSERT_NE(*closeCallbackRequest, UCS_INPROGRESS); +} + +INSTANTIATE_TEST_SUITE_P(EndpointErrorHandling, ListenerTest, ::testing::Values(true)); + +TEST_P(ListenerPortTest, Port) +{ + auto listenerContainer = createListenerContainer(); + auto listener = createListener(listenerContainer); + _worker->progress(); + + if (GetParam() == 0) + ASSERT_GE(listener->getPort(), 1024); + else + ASSERT_EQ(listener->getPort(), 12345); +} + +INSTANTIATE_TEST_SUITE_P(PortAssignment, ListenerPortTest, ::testing::Values(0, 12345)); + } // namespace diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index 7f09dbea9..e3a3d5632 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/_lib/libucxx.pyx @@ -20,6 +20,7 @@ from libcpp.memory cimport ( make_shared, make_unique, shared_ptr, + static_pointer_cast, unique_ptr, ) from libcpp.string cimport string @@ -1082,9 +1083,10 @@ cdef class UCXBufferRequests: return self.py_buffers -cdef void _endpoint_close_callback(void *args) with gil: +cdef void _endpoint_close_callback(ucs_status_t status, shared_ptr[void] args) with gil: """Callback function called when UCXEndpoint closes or errors""" - cdef dict cb_data = args + cdef shared_ptr[uintptr_t] cb_data_ptr = static_pointer_cast[uintptr_t, void](args) + cdef dict cb_data = cb_data_ptr.get()[0] try: cb_data['cb_func']( @@ -1102,6 +1104,7 @@ cdef class UCXEndpoint(): bint _cuda_support bint _enable_python_future dict _close_cb_data + shared_ptr[uintptr_t] _close_cb_data_ptr def __init__(self) -> None: raise TypeError("UCXListener cannot be instantiated directly.") @@ -1239,12 +1242,22 @@ cdef class UCXEndpoint(): return alive - def close(self, uint64_t period=0, uint64_t max_attempts=1) -> None: + def close(self) -> None: + cdef shared_ptr[Request] req + + with nogil: + req = self._endpoint.get().close( + self._enable_python_future + ) + + return UCXRequest(&req, self._enable_python_future) + + def close_blocking(self, uint64_t period=0, uint64_t max_attempts=1) -> None: cdef uint64_t c_period = period cdef uint64_t c_max_attempts = max_attempts with nogil: - self._endpoint.get().close(c_period, c_max_attempts) + self._endpoint.get().closeBlocking(c_period, c_max_attempts) def am_probe(self) -> bool: cdef ucp_ep_h handle @@ -1480,13 +1493,17 @@ cdef class UCXEndpoint(): "cb_args": cb_args, "cb_kwargs": cb_kwargs, } + self._close_cb_data_ptr = make_shared[uintptr_t]( + self._close_cb_data + ) - cdef function[void(void*)]* func_close_callback = ( - new function[void(void*)](_endpoint_close_callback) + cdef function[void(ucs_status_t, shared_ptr[void])]* func_close_callback = ( + new function[void(ucs_status_t, shared_ptr[void])](_endpoint_close_callback) ) with nogil: self._endpoint.get().setCloseCallback( - deref(func_close_callback), self._close_cb_data + deref(func_close_callback), + static_pointer_cast[void, uintptr_t](self._close_cb_data_ptr) ) del func_close_callback @@ -1500,8 +1517,8 @@ cdef class UCXEndpoint(): endpoint = self._endpoint.get() if endpoint != nullptr: endpoint.setCloseCallback( - nullptr, - nullptr, + nullptr, + nullptr, ) diff --git a/python/ucxx/_lib/ucxx_api.pxd b/python/ucxx/_lib/ucxx_api.pxd index 3bb40e3a4..b1cea3016 100644 --- a/python/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/_lib/ucxx_api.pxd @@ -270,7 +270,10 @@ cdef extern from "" namespace "ucxx" nogil: cdef cppclass Endpoint(Component): ucp_ep_h getHandle() - void close(uint64_t period, uint64_t maxAttempts) + shared_ptr[Request] close( + bint enable_python_future + ) except +raise_py_error + void closeBlocking(uint64_t period, uint64_t maxAttempts) shared_ptr[Request] amSend( void* buffer, size_t length, @@ -309,7 +312,8 @@ cdef extern from "" namespace "ucxx" nogil: bint isAlive() void raiseOnError() except +raise_py_error void setCloseCallback( - function[void(void*)] close_callback, void* close_callback_arg + function[void(ucs_status_t, shared_ptr[void])] close_callback, + shared_ptr[void] close_callback_arg ) shared_ptr[Worker] getWorker() diff --git a/python/ucxx/_lib_async/endpoint.py b/python/ucxx/_lib_async/endpoint.py index 926b7af5b..5fee43c9e 100644 --- a/python/ucxx/_lib_async/endpoint.py +++ b/python/ucxx/_lib_async/endpoint.py @@ -109,7 +109,7 @@ def abort(self, period=10**10, max_attempts=1): if self._ep is not None: logger.debug("Endpoint.abort(): 0x%x" % self.uid) # Wait for a maximum of `period` ns - self._ep.close(period=period, max_attempts=max_attempts) + self._ep.close_blocking(period=period, max_attempts=max_attempts) self._ep.remove_close_callback() self._ep = None self._ctx = None