From 449de5ee3d580c5efae7090f65a7de207ea43f6a Mon Sep 17 00:00:00 2001 From: Michael Demoret Date: Mon, 8 Jan 2024 12:13:48 -0700 Subject: [PATCH 1/9] Added get/put memory access wrappers on Endpoint --- cpp/CMakeLists.txt | 6 +- cpp/include/ucxx/constructors.h | 11 ++ cpp/include/ucxx/endpoint.h | 16 ++ cpp/include/ucxx/request_mem.h | 180 ++++++++++++++++++++ cpp/src/endpoint.cpp | 41 +++++ cpp/src/request_mem.cpp | 293 ++++++++++++++++++++++++++++++++ 6 files changed, 545 insertions(+), 2 deletions(-) create mode 100644 cpp/include/ucxx/request_mem.h create mode 100644 cpp/src/request_mem.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4c4fc5666..ef3e14db8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -77,6 +77,7 @@ rapids_cmake_support_conda_env(conda_env MODIFY_PREFIX_PATH) # * compiler options ------------------------------------------------------------------------------ rapids_find_package( ucx REQUIRED + GLOBAL_TARGETS ucx::ucp ucx::ucs ucx::uct BUILD_EXPORT_SET ucxx-exports INSTALL_EXPORT_SET ucxx-exports ) @@ -119,6 +120,7 @@ add_library( src/request.cpp src/request_am.cpp src/request_helper.cpp + src/request_mem.cpp src/request_stream.cpp src/request_tag.cpp src/request_tag_multi.cpp @@ -149,9 +151,9 @@ target_compile_options( # Specify include paths for the current target and dependents target_include_directories( ucxx - PUBLIC "$" + PUBLIC "$" "$" - PRIVATE "$" + PRIVATE "$" INTERFACE "$" ) diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 0d0876dc9..e40ea7966 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -20,6 +20,7 @@ class Listener; class Notifier; class Request; class RequestAm; +class RequestMem; class RequestStream; class RequestTag; class RequestTagMulti; @@ -74,6 +75,16 @@ std::shared_ptr createRequestStream(std::shared_ptr end size_t length, const bool enablePythonFuture); +std::shared_ptr createRequestMem(std::shared_ptr endpoint, + bool send, + void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestTag(std::shared_ptr endpointOrWorker, bool send, void* buffer, diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 59c920637..7843afe5a 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -265,6 +265,22 @@ class Endpoint : public Component { */ void setCloseCallback(std::function closeCallback, void* closeCallbackArg); + std::shared_ptr memGet(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + std::shared_ptr memPut(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue an active message send operation. * diff --git a/cpp/include/ucxx/request_mem.h b/cpp/include/ucxx/request_mem.h new file mode 100644 index 000000000..85b11f5c0 --- /dev/null +++ b/cpp/include/ucxx/request_mem.h @@ -0,0 +1,180 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +class Buffer; + +namespace internal { +class RecvMemMessage; +} // namespace internal + +class RequestMem : public Request { + private: + friend class internal::RecvMemMessage; + + uint64_t _remote_addr{}; + ucp_rkey_h _rkey{}; + + /** + * @brief Private constructor of `ucxx::RequestAm` send. + * + * This is the internal implementation of `ucxx::RequestAm` send constructor, made private + * not to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::amSend()` + * - `ucxx::createRequestAmSend()` + * + * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. + * + * @param[in] endpoint the parent endpoint. + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the active message to be sent. + * @param[in] memoryType the memory type of the buffer. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestMem(std::shared_ptr endpoint, + bool send, + void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr` send. + * + * The constructor for a `std::shared_ptr` object, creating a send active + * message request, returning a pointer to a request object that can be later awaited and + * checked for errors. This is a non-blocking operation, and the status of the transfer + * must be verified from the resulting request object before the data can be + * released. + * + * @throws ucxx::Error if `endpoint` is not a valid + * `std::shared_ptr`. + * + * @param[in] endpoint the parent endpoint. + * @param[in] buffer a raw pointer to the data to be transferred. + * @param[in] length the size in bytes of the tag message to be transferred. + * @param[in] memoryType the memory type of the buffer. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestMem(std::shared_ptr endpoint, + bool send, + void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Callback executed by UCX when a tag send request is completed. + * + * Callback executed by UCX when a tag send request is completed, that will dispatch + * `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void memPutCallback(void* request, ucs_status_t status, void* arg); + + /** + * @brief Callback executed by UCX when a tag receive request is completed. + * + * Callback executed by UCX when a tag receive request is completed, that will dispatch + * `ucxx::RequestTag::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] info information of the completed transfer provided by UCX, includes + * length of message received used to verify for truncation. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void memGetCallback(void* request, ucs_status_t status, void* arg); + + /** + * @brief Create and submit an active message send request. + * + * This is the method that should be called to actually submit an active message send + * request. It is meant to be called from `populateDelayedSubmission()`, which is decided + * at the discretion of `std::shared_ptr`. See `populateDelayedSubmission()` + * for more details. + */ + void request(); + + // /** + // * @brief Receive callback registered by `ucxx::Worker`. + // * + // * This is the receive callback registered by the `ucxx::Worker` to handle incoming active + // * messages. For each incoming active message, a proper buffer will be allocated based on + // * the header sent by the remote endpoint using the default allocator or one registered by + // * the user via `ucxx::Worker::registerAmAllocator()`. Following that, the message is + // * immediately received onto the buffer and a `UCS_OK` or the proper error status is set + // * onto the `RequestAm` that is returned to the user, or will be later handled by another + // * callback when the message is ready. If the callback is executed when a user has already + // * requested received of the active message, the buffer and status will be set on the + // * earliest request, otherwise a new request is created and saved in a pool that will be + // * already populated and ready for consumption or waiting for the internal callback when + // * requested. + // * + // * This is always called by `ucp_worker_progress()`, and thus will happen in the same + // * thread that is called from, when using the worker progress thread, this is called from + // * the progress thread. + // * + // * param[in,out] arg pointer to the `AmData` object held by the `ucxx::Worker` who + // * registered this callback. + // * param[in] header pointer to the header containing the sender buffer's memory type. + // * param[in] header_length length in bytes of the receive header. + // * param[in] data pointer to the buffer containing the remote endpoint's send data. + // * param[in] length the length in bytes of the message to be received. + // * param[in] param UCP parameters of the active message being received. + // */ + // static ucs_status_t recvCallback(void* arg, + // const void* header, + // size_t header_length, + // void* data, + // size_t length, + // const ucp_am_recv_param_t* param); +}; + +} // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index e795f71cf..55ed4b5c9 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -302,6 +303,46 @@ size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) return canceled; } +std::shared_ptr Endpoint::memGet(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem(endpoint, + false, + buffer, + length, + remote_addr, + rkey, + enablePythonFuture, + callbackFunction, + callbackData)); +} + +std::shared_ptr Endpoint::memPut(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem(endpoint, + true, + buffer, + length, + remote_addr, + rkey, + enablePythonFuture, + callbackFunction, + callbackData)); +} + std::shared_ptr Endpoint::amSend(void* buffer, size_t length, ucs_memory_type_t memoryType, diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp new file mode 100644 index 000000000..c70bd0ac9 --- /dev/null +++ b/cpp/src/request_mem.cpp @@ -0,0 +1,293 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestMem(std::shared_ptr endpoint, + bool send, + void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + auto req = std::shared_ptr(new RequestMem(endpoint, + send, + buffer, + length, + remote_addr, + rkey, + enablePythonFuture, + callbackFunction, + callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + +RequestMem::RequestMem(std::shared_ptr endpoint, + bool send, + void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpoint, + std::make_shared(send, buffer, length, 0), + std::string(send ? "memSend" : "memRecv"), + enablePythonFuture) +{ + if (_endpoint == nullptr) + throw ucxx::Error("An endpoint is required to perform remote memory put/get messages"); + + _callback = callbackFunction; + _callbackData = callbackData; +} + +// static void _memSendCallback(void* request, ucs_status_t status, void* user_data) +// { +// Request* req = reinterpret_cast(user_data); +// ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memSend", "_memSendCallback"); +// req->callback(request, status); +// } + +// static void _recvCompletedCallback(void* request, +// ucs_status_t status, +// size_t length, +// void* user_data) +// { +// internal::RecvMemMessage* recvMemMessage = static_cast(user_data); +// ucxx_trace_req_f( +// recvMemMessage->_request->getOwnerString().c_str(), request, "memRecv", "memRecvCallback"); +// recvMemMessage->callback(request, status); +// } + +// ucs_status_t RequestMem::recvCallback(void* arg, +// const void* header, +// size_t header_length, +// void* data, +// size_t length, +// const ucp_am_recv_param_t* param) +// { +// internal::AmData* memData = static_cast(arg); +// auto worker = memData->_worker.lock(); +// auto& ownerString = memData->_ownerString; +// auto& recvPool = memData->_recvPool; +// auto& recvWait = memData->_recvWait; + +// if ((param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP) == 0) +// ucxx_error("UCP_AM_RECV_ATTR_FIELD_REPLY_EP not set"); + +// ucp_ep_h ep = param->reply_ep; + +// bool is_rndv = param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV; + +// std::shared_ptr buf{nullptr}; +// auto allocatorType = *static_cast(header); + +// std::shared_ptr req{nullptr}; + +// { +// std::lock_guard lock(memData->_mutex); + +// auto reqs = recvWait.find(ep); +// if (reqs != recvWait.end() && !reqs->second.empty()) { +// req = reqs->second.front(); +// reqs->second.pop(); +// ucxx_trace_req("memRecv recvWait: %p", req.get()); +// } else { +// req = std::shared_ptr( +// new RequestMem(worker, worker->isFutureEnabled(), nullptr, nullptr)); +// auto [queue, _] = recvPool.try_emplace(ep, std::queue>()); +// queue->second.push(req); +// ucxx_trace_req("memRecv recvPool: %p", req.get()); +// } +// } + +// if (is_rndv) { +// if (memData->_allocators.find(allocatorType) == memData->_allocators.end()) { +// // TODO: Is a hard failure better? +// // ucxx_debug("Unsupported memory type %d", allocatorType); +// // internal::RecvMemMessage recvMemMessage(memData, ep, req, nullptr); +// // recvMemMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); +// // return UCS_ERR_UNSUPPORTED; + +// ucxx_trace_req("No allocator registered for memory type %d, falling back to host memory.", +// allocatorType); +// allocatorType = UCS_MEMORY_TYPE_HOST; +// } + +// std::shared_ptr buf = memData->_allocators.at(allocatorType)(length); + +// auto recvMemMessage = std::make_shared(memData, ep, req, buf); + +// ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | +// UCP_OP_ATTR_FIELD_USER_DATA | +// UCP_OP_ATTR_FLAG_NO_IMM_CMPL, +// .cb = {.recv_am = _recvCompletedCallback}, +// .user_data = recvMemMessage.get()}; + +// ucs_status_ptr_t status = +// ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param); + +// if (req->_enablePythonFuture) +// ucxx_trace_req_f(ownerString.c_str(), +// status, +// "memRecv rndv", +// "ep %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", +// ep, +// buf->data(), +// length, +// req->_future.get(), +// req->_future->getHandle()); +// else +// ucxx_trace_req_f(ownerString.c_str(), +// status, +// "memRecv rndv", +// "ep %p, buffer %p, size %lu, recvCallback", +// ep, +// buf->data(), +// length); + +// if (req->isCompleted()) { +// // The request completed/errored immediately +// ucs_status_t s = UCS_PTR_STATUS(status); +// recvMemMessage->callback(nullptr, s); + +// return s; +// } else { +// // The request will be handled by the callback +// recvMemMessage->setUcpRequest(status); +// memData->_registerInflightRequest(req); + +// { +// std::lock_guard lock(memData->_mutex); +// memData->_recvMemMessageMap.emplace(req.get(), recvMemMessage); +// } + +// return UCS_INPROGRESS; +// } +// } else { +// std::shared_ptr buf = memData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); +// if (length > 0) memcpy(buf->data(), data, length); + +// if (req->_enablePythonFuture) +// ucxx_trace_req_f(ownerString.c_str(), +// nullptr, +// "memRecv eager", +// "ep: %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", +// ep, +// buf->data(), +// length, +// req->_future.get(), +// req->_future->getHandle()); +// else +// ucxx_trace_req_f(ownerString.c_str(), +// nullptr, +// "memRecv eager", +// "ep: %p, buffer %p, size %lu, recvCallback", +// ep, +// buf->data(), +// length); + +// internal::RecvMemMessage recvMemMessage(memData, ep, req, buf); +// recvMemMessage.callback(nullptr, UCS_OK); +// return UCS_OK; +// } +// } + +void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg) +{ + Request* req = reinterpret_cast(arg); + ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memSend", "memPutCallback"); + return req->callback(request, status); +} + +void RequestMem::memGetCallback(void* request, ucs_status_t status, void* arg) +{ + Request* req = reinterpret_cast(arg); + ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memRecv", "memGetCallback"); + return req->callback(request, status); +} + +void RequestMem::request() +{ + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_FLAGS | + UCP_OP_ATTR_FIELD_USER_DATA, + .flags = UCP_AM_SEND_FLAG_REPLY, + .datatype = ucp_dt_make_contig(1), + .user_data = this}; + + void* request = nullptr; + + if (_delayedSubmission->_send) { + param.cb.send = memPutCallback; + request = ucp_put_nbx(_endpoint->getHandle(), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _remote_addr, + _rkey, + ¶m); + + std::lock_guard lock(_mutex); + _request = request; + } else { + param.cb.send = memGetCallback; + request = ucp_get_nbx(_endpoint->getHandle(), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _remote_addr, + _rkey, + ¶m); + } + + std::lock_guard lock(_mutex); + _request = request; +} + +void RequestMem::populateDelayedSubmission() +{ + request(); + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, populateDelayedSubmission", + _delayedSubmission->_buffer, + _delayedSubmission->_length); + + process(); +} + +} // namespace ucxx From 1e01fb4686df2e2d7730008b7b69729134f3845d Mon Sep 17 00:00:00 2001 From: Michael Demoret Date: Mon, 8 Jan 2024 14:28:33 -0700 Subject: [PATCH 2/9] Changes from testing --- cpp/src/request_mem.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index c70bd0ac9..446e7ef09 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -55,7 +55,9 @@ RequestMem::RequestMem(std::shared_ptr endpoint, : Request(endpoint, std::make_shared(send, buffer, length, 0), std::string(send ? "memSend" : "memRecv"), - enablePythonFuture) + enablePythonFuture), + _remote_addr(remote_addr), + _rkey(rkey) { if (_endpoint == nullptr) throw ucxx::Error("An endpoint is required to perform remote memory put/get messages"); From f21d093b52995788e19f7aa4b5f63c00d2015adc Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 7 Feb 2024 09:22:12 -0800 Subject: [PATCH 3/9] Update to new `RequestData` interface and update docs --- cpp/include/ucxx/constructors.h | 15 +-- cpp/include/ucxx/endpoint.h | 76 ++++++++++--- cpp/include/ucxx/request_data.h | 67 ++++++++++- cpp/include/ucxx/request_mem.h | 123 +++++++------------- cpp/src/endpoint.cpp | 65 +++++------ cpp/src/request_data.cpp | 17 +++ cpp/src/request_mem.cpp | 193 ++++++++++++++++++++------------ cpp/src/request_tag.cpp | 2 - 8 files changed, 344 insertions(+), 214 deletions(-) diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 8c8b86c3e..2c3a272f3 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -76,15 +76,12 @@ std::shared_ptr createRequestTag( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); -std::shared_ptr createRequestMem(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); +std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); std::shared_ptr createRequestTagMulti( std::shared_ptr endpoint, diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 53464a83c..87b9b4995 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -288,22 +288,6 @@ class Endpoint : public Component { */ void setCloseCallback(std::function closeCallback, void* closeCallbackArg); - std::shared_ptr memGet(void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr); - - std::shared_ptr memPut(void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr); - /** * @brief Enqueue an active message send operation. * @@ -358,6 +342,66 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + /** + * @brief Enqueue a memory put operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the remote data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the destination remote memory address to write to. + * @param[in] rkey the remote memory key associated with the remote memory + * address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memPut(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue a memory get operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the local data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the source remote memory address to read from. + * @param[in] rkey the remote memory key associated with the remote memory + * address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memGet(void* buffer, + size_t length, + uint64_t remoteAddr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue a stream send operation. * diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 8143dad78..200a8b6a2 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -66,6 +66,67 @@ class AmReceive { AmReceive(); }; +/** + * @brief Data for a memory send. + * + * Type identifying a memory send operation and containing data specific to this request type. + */ +class MemSend { + public: + const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< The length of the message. + const uint64_t _remoteAddr{0}; ///< Remote memory address to write to. + const ucp_rkey_h _rkey{}; ///< UCX remote key associated with the remote memory address. + + /** + * @brief Constructor for memory-specific data. + * + * Construct an object containing memory-specific data. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the destination remote memory address to write to. + * @param[in] rkey the remote memory key associated with the remote memory address. + */ + explicit MemSend(const decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); + + MemSend() = delete; +}; + +/** + * @brief Data for a memory receive. + * + * Type identifying a memory receive operation and containing data specific to this request + * type. + */ +class MemReceive { + public: + void* _buffer{nullptr}; ///< The raw pointer where received data should be stored. + const size_t _length{0}; ///< The length of the message. + const uint64_t _remoteAddr{0}; ///< Remote memory address to read from. + const ucp_rkey_h _rkey{}; ///< UCX remote key associated with the remote memory address. + + /** + * @brief Constructor for memory-specific data. + * + * Construct an object containing memory-specific data. + * + * @param[out] buffer a raw pointer to the received data. + * @param[in] length the size in bytes of the tag message to be received. + * @param[in] remoteAddr the source remote memory address to read from. + * @param[in] rkey the remote memory key associated with the remote memory address. + */ + explicit MemReceive(decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); + + MemReceive() = delete; +}; + /** * @brief Data for a Stream send. * @@ -127,7 +188,7 @@ class TagSend { const ::ucxx::Tag _tag{0}; ///< Tag to match /** - * @brief Constructor for tag/multi-buffer tag-specific data. + * @brief Constructor for tag-specific data. * * Construct an object containing tag-specific data. * @@ -156,7 +217,7 @@ class TagReceive { const ::ucxx::TagMask _tagMask{0}; ///< Tag mask to use /** - * @brief Constructor send tag-specific data. + * @brief Constructor for tag-specific data. * * Construct an object containing send tag-specific data. * @@ -231,6 +292,8 @@ class TagMultiReceive { using RequestData = std::variant +#include #include #include @@ -28,78 +29,76 @@ class RequestMem : public Request { ucp_rkey_h _rkey{}; /** - * @brief Private constructor of `ucxx::RequestAm` send. + * @brief Private constructor of `ucxx::RequestMem`. * - * This is the internal implementation of `ucxx::RequestAm` send constructor, made private - * not to be called directly. This constructor is made private to ensure all UCXX objects + * This is the internal implementation of `ucxx::RequestMem` constructor, made private not + * to be called directly. This constructor is made private to ensure all UCXX objects * are shared pointers and the correct lifetime management of each one. * * Instead the user should use one of the following: * - * - `ucxx::Endpoint::amSend()` - * - `ucxx::createRequestAmSend()` + * - `ucxx::Endpoint::memGet()` + * - `ucxx::Endpoint::memPut()` + * - `ucxx::createRequestMem()` * - * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. + * @throws ucxx::Error if send is `true` and `endpointOrWorker` is not a + * `std::shared_ptr`. * - * @param[in] endpoint the parent endpoint. - * @param[in] buffer a raw pointer to the data to be sent. - * @param[in] length the size in bytes of the active message to be sent. - * @param[in] memoryType the memory type of the buffer. + * @param[in] endpoint the `std::shared_ptr` parent component. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. * @param[in] callbackData user-defined data to pass to the `callbackFunction`. */ RequestMem(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); public: /** - * @brief Constructor for `std::shared_ptr` send. + * @brief Constructor for `std::shared_ptr`. * - * The constructor for a `std::shared_ptr` object, creating a send active - * message request, returning a pointer to a request object that can be later awaited and - * checked for errors. This is a non-blocking operation, and the status of the transfer - * must be verified from the resulting request object before the data can be - * released. + * The constructor for a `std::shared_ptr` object, creating a get or put + * request, returning a pointer to a request object that can be later awaited and checked + * for errors. This is a non-blocking operation, and the status of the transfer must be + * verified from the resulting request object before both the local and remote data can + * be released and the local data (on get operations) or remote data (on put operations) + * can be consumed. * - * @throws ucxx::Error if `endpoint` is not a valid - * `std::shared_ptr`. + * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. * - * @param[in] endpoint the parent endpoint. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the tag message to be transferred. - * @param[in] memoryType the memory type of the buffer. + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. * @param[in] callbackData user-defined data to pass to the `callbackFunction`. * - * @returns The `shared_ptr` object + * @returns The `shared_ptr` object */ - friend std::shared_ptr createRequestMem(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); + friend std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); virtual void populateDelayedSubmission(); /** - * @brief Callback executed by UCX when a tag send request is completed. + * @brief Callback executed by UCX when a memory put request is completed. * - * Callback executed by UCX when a tag send request is completed, that will dispatch + * Callback executed by UCX when a memory put request is completed, that will dispatch * `ucxx::Request::callback()`. * * WARNING: This is not intended to be called by the user, but it currently needs to be @@ -114,10 +113,10 @@ class RequestMem : public Request { static void memPutCallback(void* request, ucs_status_t status, void* arg); /** - * @brief Callback executed by UCX when a tag receive request is completed. + * @brief Callback executed by UCX when a memory get request is completed. * - * Callback executed by UCX when a tag receive request is completed, that will dispatch - * `ucxx::RequestTag::callback()`. + * Callback executed by UCX when a memory get request is completed, that will dispatch + * `ucxx::Request::callback()`. * * WARNING: This is not intended to be called by the user, but it currently needs to be * a public method so that UCX may access it. In future changes this will be moved to @@ -125,56 +124,20 @@ class RequestMem : public Request { * * @param[in] request the UCX request pointer. * @param[in] status the completion status of the request. - * @param[in] info information of the completed transfer provided by UCX, includes - * length of message received used to verify for truncation. * @param[in] arg the pointer to the `ucxx::Request` object that created the * transfer, effectively `this` pointer as seen by `request()`. */ static void memGetCallback(void* request, ucs_status_t status, void* arg); /** - * @brief Create and submit an active message send request. + * @brief Create and submit a memory get or put request. * - * This is the method that should be called to actually submit an active message send + * This is the method that should be called to actually submit memory request get or put * request. It is meant to be called from `populateDelayedSubmission()`, which is decided * at the discretion of `std::shared_ptr`. See `populateDelayedSubmission()` * for more details. */ void request(); - - // /** - // * @brief Receive callback registered by `ucxx::Worker`. - // * - // * This is the receive callback registered by the `ucxx::Worker` to handle incoming active - // * messages. For each incoming active message, a proper buffer will be allocated based on - // * the header sent by the remote endpoint using the default allocator or one registered by - // * the user via `ucxx::Worker::registerAmAllocator()`. Following that, the message is - // * immediately received onto the buffer and a `UCS_OK` or the proper error status is set - // * onto the `RequestAm` that is returned to the user, or will be later handled by another - // * callback when the message is ready. If the callback is executed when a user has already - // * requested received of the active message, the buffer and status will be set on the - // * earliest request, otherwise a new request is created and saved in a pool that will be - // * already populated and ready for consumption or waiting for the internal callback when - // * requested. - // * - // * This is always called by `ucp_worker_progress()`, and thus will happen in the same - // * thread that is called from, when using the worker progress thread, this is called from - // * the progress thread. - // * - // * param[in,out] arg pointer to the `AmData` object held by the `ucxx::Worker` who - // * registered this callback. - // * param[in] header pointer to the header containing the sender buffer's memory type. - // * param[in] header_length length in bytes of the receive header. - // * param[in] data pointer to the buffer containing the remote endpoint's send data. - // * param[in] length the length in bytes of the message to be received. - // * param[in] param UCP parameters of the active message being received. - // */ - // static ucs_status_t recvCallback(void* arg, - // const void* header, - // size_t header_length, - // void* data, - // size_t length, - // const ucp_am_recv_param_t* param); }; } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 82742e4e3..9582c9e27 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -330,68 +330,61 @@ size_t Endpoint::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) return canceled; } -std::shared_ptr Endpoint::memGet(void* buffer, +std::shared_ptr Endpoint::amSend(void* buffer, size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, + ucs_memory_type_t memoryType, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestMem(endpoint, - false, - buffer, - length, - remote_addr, - rkey, - enablePythonFuture, - callbackFunction, - callbackData)); + return registerInflightRequest(createRequestAm(endpoint, + data::AmSend(buffer, length, memoryType), + enablePythonFuture, + callbackFunction, + callbackData)); } -std::shared_ptr Endpoint::memPut(void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture, +std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestMem(endpoint, - true, - buffer, - length, - remote_addr, - rkey, - enablePythonFuture, - callbackFunction, - callbackData)); + return registerInflightRequest(createRequestAm( + endpoint, data::AmReceive(), enablePythonFuture, callbackFunction, callbackData)); } -std::shared_ptr Endpoint::amSend(void* buffer, +std::shared_ptr Endpoint::memGet(void* buffer, size_t length, - ucs_memory_type_t memoryType, + uint64_t remoteAddr, + ucp_rkey_h rkey, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestAm(endpoint, - data::AmSend(buffer, length, memoryType), - enablePythonFuture, - callbackFunction, - callbackData)); + return registerInflightRequest( + createRequestMem(endpoint, + data::MemReceive(buffer, length, remoteAddr, rkey), + enablePythonFuture, + callbackFunction, + callbackData)); } -std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, +std::shared_ptr Endpoint::memPut(void* buffer, + size_t length, + uint64_t remoteAddr, + ucp_rkey_h rkey, + const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestAm( - endpoint, data::AmReceive(), enablePythonFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestMem(endpoint, + data::MemSend(buffer, length, remoteAddr, rkey), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::streamSend(void* buffer, diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index 3bc1efd46..ad08658ce 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -6,6 +6,7 @@ #include +#include #include #include @@ -20,6 +21,22 @@ AmSend::AmSend(const void* buffer, const size_t length, const ucs_memory_type me AmReceive::AmReceive() {} +MemSend::MemSend(const void* buffer, + const size_t length, + const uint64_t remoteAddr, + const ucp_rkey_h rkey) + : _buffer(buffer), _length(length), _remoteAddr(remoteAddr), _rkey(rkey) +{ +} + +MemReceive::MemReceive(void* buffer, + const size_t length, + const uint64_t remoteAddr, + const ucp_rkey_h rkey) + : _buffer(buffer), _length(length), _remoteAddr(remoteAddr), _rkey(rkey) +{ +} + StreamSend::StreamSend(const void* buffer, const size_t length) : _buffer(buffer), _length(length) { /** diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index 446e7ef09..e2592a600 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -14,25 +14,26 @@ namespace ucxx { -std::shared_ptr createRequestMem(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr) +std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) { - auto req = std::shared_ptr(new RequestMem(endpoint, - send, - buffer, - length, - remote_addr, - rkey, - enablePythonFuture, - callbackFunction, - callbackData)); + std::shared_ptr req = std::visit( + data::dispatch{ + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemSend memSend) { + return std::shared_ptr(new RequestMem( + endpoint, memSend, "memSend", enablePythonFuture, callbackFunction, callbackData)); + }, + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData]( + data::MemReceive memReceive) { + return std::shared_ptr(new RequestMem( + endpoint, memReceive, "memRecv", enablePythonFuture, callbackFunction, callbackData)); + }, + }, + requestData); // A delayed notification request is not populated immediately, instead it is // delayed to allow the worker progress thread to set its status, and more @@ -44,23 +45,25 @@ std::shared_ptr createRequestMem(std::shared_ptr endpoint, } RequestMem::RequestMem(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - uint64_t remote_addr, - ucp_rkey_h rkey, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpoint, - std::make_shared(send, buffer, length, 0), - std::string(send ? "memSend" : "memRecv"), - enablePythonFuture), - _remote_addr(remote_addr), - _rkey(rkey) + : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture) { - if (_endpoint == nullptr) - throw ucxx::Error("An endpoint is required to perform remote memory put/get messages"); + std::visit(data::dispatch{ + [this](data::MemSend memSend) { + if (_endpoint == nullptr) + throw ucxx::Error("A valid endpoint is required to send memory messages."); + }, + [this](data::MemReceive memReceive) { + if (_endpoint == nullptr) + throw ucxx::Error("A valid endpoint is required to receive memory messages."); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + requestData); _callback = callbackFunction; _callbackData = callbackData; @@ -221,14 +224,14 @@ RequestMem::RequestMem(std::shared_ptr endpoint, void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg) { Request* req = reinterpret_cast(arg); - ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memSend", "memPutCallback"); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memSend", "memPutCallback"); return req->callback(request, status); } void RequestMem::memGetCallback(void* request, ucs_status_t status, void* arg) { Request* req = reinterpret_cast(arg); - ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memRecv", "memGetCallback"); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memRecv", "memGetCallback"); return req->callback(request, status); } @@ -243,51 +246,103 @@ void RequestMem::request() void* request = nullptr; - if (_delayedSubmission->_send) { - param.cb.send = memPutCallback; - request = ucp_put_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _remote_addr, - _rkey, - ¶m); - - std::lock_guard lock(_mutex); - _request = request; - } else { - param.cb.send = memGetCallback; - request = ucp_get_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _remote_addr, - _rkey, - ¶m); - } + std::visit(data::dispatch{ + [this, &request, ¶m](data::MemSend memSend) { + param.cb.send = memPutCallback; + request = ucp_put_nbx(_endpoint->getHandle(), + memSend._buffer, + memSend._length, + memSend._remoteAddr, + memSend._rkey, + ¶m); + }, + [this, &request, ¶m](data::MemReceive memReceive) { + param.cb.send = memGetCallback; + request = ucp_get_nbx(_endpoint->getHandle(), + memReceive._buffer, + memReceive._length, + memReceive._remoteAddr, + memReceive._rkey, + ¶m); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); std::lock_guard lock(_mutex); _request = request; } +static void logPopulateDelayedSubmission() {} + void RequestMem::populateDelayedSubmission() { + bool terminate = + std::visit(data::dispatch{ + [this](data::MemSend memSend) { + if (_endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before message could be sent"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [this](data::MemReceive memReceive) { + if (_worker->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before message could be received"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + if (terminate) return; + request(); - if (_enablePythonFuture) - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _future.get(), - _future->getHandle()); - else - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, populateDelayedSubmission", - _delayedSubmission->_buffer, - _delayedSubmission->_length); + auto log = + [this]( + const void* buffer, const size_t length, const uint64_t remoteAddr, const ucp_rkey_h rkey) { + if (_enablePythonFuture) + ucxx_trace_req_f( + _ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, buffer: %p, size: %lu, remoteAddr: 0x%lx, rkey: %p, " + "future: %p, future handle: %p", + buffer, + length, + remoteAddr, + rkey, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f( + _ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, buffer: %p, size: %lu, remoteAddr: 0x%lx, rkey: %p", + buffer, + length, + remoteAddr, + rkey); + }; + + std::visit( + data::dispatch{ + [this, &log](data::MemSend memSend) { + log(memSend._buffer, memSend._length, memSend._remoteAddr, memSend._rkey); + }, + [this, &log](data::MemReceive memReceive) { + log(memReceive._buffer, memReceive._length, memReceive._remoteAddr, memReceive._rkey); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); process(); } diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index 9bdc1f487..f8dfc8e2a 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -137,8 +137,6 @@ void RequestTag::request() _request = request; } -static void logPopulateDelayedSubmission() {} - void RequestTag::populateDelayedSubmission() { bool terminate = From d9d006d7e89b788066825789ecd9166b94479291 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 6 Feb 2024 13:42:10 -0800 Subject: [PATCH 4/9] Fix default return for `getProgressFunction()` test function --- cpp/tests/utils.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/cpp/tests/utils.cpp b/cpp/tests/utils.cpp index ec063fec0..ba84d33bb 100644 --- a/cpp/tests/utils.cpp +++ b/cpp/tests/utils.cpp @@ -16,14 +16,12 @@ void createCudaContextCallback(void* callbackArg) std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode) { - if (progressMode == ProgressMode::Polling) - return std::bind(std::mem_fn(&ucxx::Worker::progress), worker); - else if (progressMode == ProgressMode::Blocking) - return std::bind(std::mem_fn(&ucxx::Worker::progressWorkerEvent), worker, -1); - else if (progressMode == ProgressMode::Wait) - return std::bind(std::mem_fn(&ucxx::Worker::waitProgress), worker); - else - return std::function(); + switch (progressMode) { + case ProgressMode::Polling: return [worker]() { worker->progress(); }; + case ProgressMode::Blocking: return [worker]() { worker->progressWorkerEvent(-1); }; + case ProgressMode::Wait: return [worker]() { worker->waitProgress(); }; + default: return []() {}; + } } bool loopWithTimeout(std::chrono::milliseconds timeout, std::function f) From fb1f5dfcd0747d8b678ae2f2688ba5bf18441d39 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 12 Feb 2024 02:51:09 -0800 Subject: [PATCH 5/9] Add `ucxx::RequestFlush` --- cpp/CMakeLists.txt | 1 + cpp/include/ucxx/constructors.h | 7 ++ cpp/include/ucxx/endpoint.h | 25 +++++ cpp/include/ucxx/request_data.h | 16 ++++ cpp/include/ucxx/request_flush.h | 126 +++++++++++++++++++++++++ cpp/include/ucxx/worker.h | 25 +++++ cpp/src/endpoint.cpp | 10 ++ cpp/src/request_data.cpp | 2 + cpp/src/request_flush.cpp | 157 +++++++++++++++++++++++++++++++ cpp/src/worker.cpp | 10 ++ 10 files changed, 379 insertions(+) create mode 100644 cpp/include/ucxx/request_flush.h create mode 100644 cpp/src/request_flush.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index fe27fbbdf..f4b881618 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -133,6 +133,7 @@ add_library( src/request.cpp src/request_am.cpp src/request_data.cpp + src/request_flush.cpp src/request_helper.cpp src/request_mem.cpp src/request_stream.cpp diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 2c3a272f3..c393e4a1e 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -21,6 +21,7 @@ class Listener; class Notifier; class Request; class RequestAm; +class RequestFlush; class RequestMem; class RequestStream; class RequestTag; @@ -64,6 +65,12 @@ std::shared_ptr createRequestAm( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); +std::shared_ptr createRequestFlush(std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestStream( std::shared_ptr endpoint, const std::variant requestData, diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 87b9b4995..165454790 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -571,6 +571,31 @@ class Endpoint : public Component { const TagMask tagMask, const bool enablePythonFuture); + /** + * @brief Enqueue a flush operation. + * + * Enqueue request to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote + * Memory Access) operations on the endpoint, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr flush(const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Get `ucxx::Worker` component from a worker or listener object. * diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 200a8b6a2..75b6eda46 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -66,6 +66,21 @@ class AmReceive { AmReceive(); }; +/** + * @brief Data for a flush operation. + * + * Type identifying a flush operation and containing data specific to this request type. + */ +class Flush { + public: + /** + * @brief Constructor for flush-specific data. + * + * Construct an object containing flush-specific data. + */ + Flush(); +}; + /** * @brief Data for a memory send. * @@ -292,6 +307,7 @@ class TagMultiReceive { using RequestData = std::variant +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +/** + * @brief Flush a UCP endpoint or worker. + * + * Flush outstanding AMO (Atomic Memory Operation) and RMA (Remote Memory Access) operations + * on a UCP endpoint or worker. + */ +class RequestFlush : public Request { + private: + /** + * @brief Private constructor of `ucxx::RequestFlush`. + * + * This is the internal implementation of `ucxx::RequestFlush` constructor, made private + * not to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::flush()` + * - `ucxx::Worker::flush()` + * - `ucxx::createRequestFlush()` + * + * @throws ucxx::Error if `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestFlush(std::shared_ptr endpointOrWorker, + const std::variant requestData, + const std::string operationName, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr`. + * + * The constructor for a `std::shared_ptr` object, creating a request + * to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote Memory Access) + * operations on a UCP endpoint or worker, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * @throws ucxx::Error `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestFlush( + std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Create and submit a flush request. + * + * This is the method that should be called to actually submit a flush request. It is + * meant to be called from `populateDelayedSubmission()`, which is decided at the + * discretion of `std::shared_ptr`. See `populateDelayedSubmission()` for + * more details. + */ + void request(); + + /** + * @brief Callback executed by UCX when a flush request is completed. + * + * Callback executed by UCX when a flush request is completed, that will dispatch + * `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void flushCallback(void* request, ucs_status_t status, void* arg); +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index f6ad17cce..e1c23f9ae 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -819,6 +819,31 @@ class Worker : public Component { * @returns `true` if any uncaught messages were received, `false` otherwise. */ bool amProbe(const ucp_ep_h endpointHandle) const; + + /** + * @brief Enqueue a flush operation. + + * Enqueue request to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote + * Memory Access) operations on the worker, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr flush(const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); }; } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 9582c9e27..9f70da448 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -456,6 +457,15 @@ std::shared_ptr Endpoint::tagMultiRecv(const Tag tag, createRequestTagMulti(endpoint, data::TagMultiReceive(tag, tagMask), enablePythonFuture)); } +std::shared_ptr Endpoint::flush(const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestFlush( + endpoint, data::Flush(), enablePythonFuture, callbackFunction, callbackData)); +} + std::shared_ptr Endpoint::getWorker() { return ::ucxx::getWorker(_parent); } void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index ad08658ce..7a52627b4 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -21,6 +21,8 @@ AmSend::AmSend(const void* buffer, const size_t length, const ucs_memory_type me AmReceive::AmReceive() {} +Flush::Flush() {} + MemSend::MemSend(const void* buffer, const size_t length, const uint64_t remoteAddr, diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp new file mode 100644 index 000000000..bcba03702 --- /dev/null +++ b/cpp/src/request_flush.cpp @@ -0,0 +1,157 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestFlush( + std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + std::shared_ptr req = std::visit( + data::dispatch{ + [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( + data::Flush flush) { + return std::shared_ptr(new RequestFlush( + endpointOrWorker, flush, "flush", enablePythonFuture, callbackFunction, callbackData)); + }, + }, + requestData); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + +RequestFlush::RequestFlush(std::shared_ptr endpointOrWorker, + const std::variant requestData, + const std::string operationName, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture) +{ + std::visit( + data::dispatch{ + [this](data::Flush) { + if (_endpoint == nullptr && _worker == nullptr) + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + requestData); + + _callback = callbackFunction; + _callbackData = callbackData; +} + +void RequestFlush::flushCallback(void* request, ucs_status_t status, void* arg) +{ + Request* req = reinterpret_cast(arg); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "flush", "flushCallback"); + return req->callback(request, status); +} + +void RequestFlush::request() +{ + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, .user_data = this}; + + void* request = nullptr; + + std::visit( + data::dispatch{ + [this, &request, ¶m](data::Flush) { + param.cb.send = flushCallback; + if (_endpoint != nullptr) + request = ucp_ep_flush_nbx(_endpoint->getHandle(), ¶m); + else if (_worker != nullptr) + request = ucp_worker_flush_nbx(_worker->getHandle(), ¶m); + else + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + + std::lock_guard lock(_mutex); + _request = request; +} + +static void logPopulateDelayedSubmission() {} + +void RequestFlush::populateDelayedSubmission() +{ + bool terminate = + std::visit(data::dispatch{ + [this](data::Flush flush) { + if (_endpoint != nullptr && _endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } else if (_worker != nullptr && _worker->getHandle() == nullptr) { + ucxx_warn("Worker was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + if (terminate) return; + + request(); + + auto log = [this]() { + std::string flushComponent = "unknown"; + if (_endpoint != nullptr) + flushComponent = "endpoint"; + else if (_worker != nullptr) + flushComponent = "worker"; + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s), future: %p, future handle: %p", + flushComponent.c_str(), + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s)", + flushComponent.c_str()); + }; + + std::visit(data::dispatch{ + [this, &log](data::Flush flush) { log(); }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + + process(); +} + +} // namespace ucxx diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 300804e9d..aea6f3312 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -591,4 +592,13 @@ bool Worker::amProbe(const ucp_ep_h endpointHandle) const return _amData->_recvPool.find(endpointHandle) != _amData->_recvPool.end(); } +std::shared_ptr Worker::flush(const bool enableFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto worker = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest( + createRequestFlush(worker, data::Flush(), enableFuture, callbackFunction, callbackData)); +} + } // namespace ucxx From 7954db9e3188ffc3a65d37f1b217161556ae6f6f Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 12 Feb 2024 07:16:25 -0800 Subject: [PATCH 6/9] Add direct `ucxx::Endpoint` interfaces to `ucxx::RemoteKey` --- cpp/include/ucxx/endpoint.h | 64 +++++++++++++++++++++++++++++++++++++ cpp/src/endpoint.cpp | 37 +++++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 165454790..81cced641 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -372,6 +372,38 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + /** + * @brief Enqueue a memory put operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the remote data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteKey the remote memory key associated with the remote memory + * address. + * @param[in] remoteAddrOffset the destination remote memory address offset where to + * start writing to, `0` means start writing from beginning + * of the base address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memPut(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddrOffset = 0, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue a memory get operation. * @@ -402,6 +434,38 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + /** + * @brief Enqueue a memory get operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the local data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteKey the remote memory key associated with the remote memory + * address. + * @param[in] remoteAddrOffset the destination remote memory address offset where to + * start reading from, `0` means start writing from + * beginning of the base address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memGet(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddrOffset = 0, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue a stream send operation. * diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 9f70da448..e2015ea54 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -372,6 +373,24 @@ std::shared_ptr Endpoint::memGet(void* buffer, callbackData)); } +std::shared_ptr Endpoint::memGet(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddressOffset, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem( + endpoint, + data::MemReceive( + buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), + enablePythonFuture, + callbackFunction, + callbackData)); +} + std::shared_ptr Endpoint::memPut(void* buffer, size_t length, uint64_t remoteAddr, @@ -388,6 +407,24 @@ std::shared_ptr Endpoint::memPut(void* buffer, callbackData)); } +std::shared_ptr Endpoint::memPut(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddressOffset, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem( + endpoint, + data::MemSend( + buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), + enablePythonFuture, + callbackFunction, + callbackData)); +} + std::shared_ptr Endpoint::streamSend(void* buffer, size_t length, const bool enablePythonFuture) From e025e9cffc3523ef3507f9d4ced332864c4b76f6 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 9 Feb 2024 12:43:15 -0800 Subject: [PATCH 7/9] Add `ucxx::RequestMem` tests --- cpp/tests/request.cpp | 181 +++++++++++++++++++++++++++++++++++++++++- cpp/tests/worker.cpp | 48 +++++++++++ 2 files changed, 227 insertions(+), 2 deletions(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 5055653d2..a818bf131 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include @@ -14,6 +16,9 @@ #include #include "include/utils.h" +#include "ucxx/buffer.h" +#include "ucxx/constructors.h" +#include "ucxx/utils/ucx.h" namespace { @@ -21,6 +26,8 @@ using ::testing::Combine; using ::testing::ContainerEq; using ::testing::Values; +typedef std::vector DataContainerType; + class RequestTest : public ::testing::TestWithParam< std::tuple> { protected: @@ -39,8 +46,8 @@ class RequestTest : public ::testing::TestWithParam< size_t _rndvThresh{8192}; size_t _numBuffers{0}; - std::vector> _send; - std::vector> _recv; + std::vector _send; + std::vector _recv; std::vector> _sendBuffer; std::vector> _recvBuffer; std::vector _sendPtr{nullptr}; @@ -306,6 +313,176 @@ TEST_P(RequestTest, TagUserCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, MemoryGet) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + // Fill memory handle with send data + memcpy(reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(_recvPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetPreallocated) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, _sendPtr[0]); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(_recvPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetWithOffset) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; + allocate(); + + size_t offset = 1; + size_t offsetBytes = offset * sizeof(_send[0][0]); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + // Fill memory handle with send data + memcpy(reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(reinterpret_cast(_recvPtr[0]) + offsetBytes, + _messageSize - offsetBytes, + remoteKey, + offsetBytes)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert offset data correctness + auto recvOffset = DataContainerType(_recv[0].begin() + offset, _recv[0].end()); + auto sendOffset = DataContainerType(_send[0].begin() + offset, _send[0].end()); + ASSERT_THAT(recvOffset, sendOffset); +} + +TEST_P(RequestTest, MemoryPut) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(_sendPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + // Copy memory handle data to receive buffer + memcpy(_recvPtr[0], reinterpret_cast(memoryHandle->getBaseAddress()), _messageSize); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutPreallocated) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, _recvPtr[0]); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(_sendPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutWithOffset) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; + allocate(); + + size_t offset = 1; + size_t offsetBytes = offset * sizeof(_send[0][0]); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(reinterpret_cast(_sendPtr[0]) + offsetBytes, + _messageSize - offsetBytes, + remoteKey, + offsetBytes)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + // Copy memory handle data to receive buffer + memcpy(reinterpret_cast(_recvPtr[0]), + reinterpret_cast(memoryHandle->getBaseAddress()), + _messageSize); + + copyResults(); + + // Assert offset data correctness + auto recvOffset = DataContainerType(_recv[0].begin() + offset, _recv[0].end()); + auto sendOffset = DataContainerType(_send[0].begin() + offset, _send[0].end()); + ASSERT_THAT(recvOffset, sendOffset); +} + INSTANTIATE_TEST_SUITE_P(ProgressModes, RequestTest, Combine(Values(ucxx::BufferType::Host), diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 4cc610561..f41071d09 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -146,6 +146,54 @@ TEST_P(WorkerProgressTest, ProgressAm) ASSERT_EQ(recvAbstract[0], send[0]); } +TEST_P(WorkerProgressTest, ProgressMemoryGet) +{ + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + std::vector send{123}; + std::vector recv(1); + + size_t messageSize = send.size() * sizeof(int); + + auto memoryHandle = _context->createMemoryHandle(messageSize, send.data()); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back( + ep->memGet(recv.data(), messageSize, remoteKey->getBaseAddress(), remoteKey->getHandle())); + requests.push_back(_worker->flush()); + waitRequests(_worker, requests, _progressWorker); + + ASSERT_EQ(recv[0], send[0]); +} + +TEST_P(WorkerProgressTest, ProgressMemoryPut) +{ + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + std::vector send{123}; + std::vector recv(1); + + size_t messageSize = send.size() * sizeof(int); + + auto memoryHandle = _context->createMemoryHandle(messageSize, recv.data()); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back( + ep->memPut(send.data(), messageSize, remoteKey->getBaseAddress(), remoteKey->getHandle())); + requests.push_back(_worker->flush()); + waitRequests(_worker, requests, _progressWorker); + + ASSERT_EQ(recv[0], send[0]); +} + TEST_P(WorkerProgressTest, ProgressStream) { auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); From c592ed5fbd8be3bd4f7ed9b0f26851c803662cae Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 12 Feb 2024 07:33:19 -0800 Subject: [PATCH 8/9] Make `MemGet`/`MemPut` naming consistent --- cpp/include/ucxx/constructors.h | 2 +- cpp/include/ucxx/request_data.h | 28 ++-- cpp/include/ucxx/request_mem.h | 10 +- cpp/src/endpoint.cpp | 17 ++- cpp/src/request_data.cpp | 13 +- cpp/src/request_mem.cpp | 218 +++++--------------------------- 6 files changed, 66 insertions(+), 222 deletions(-) diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index c393e4a1e..817c09760 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -85,7 +85,7 @@ std::shared_ptr createRequestTag( std::shared_ptr createRequestMem( std::shared_ptr endpoint, - const std::variant requestData, + const std::variant requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 75b6eda46..db3ab7267 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -86,7 +86,7 @@ class Flush { * * Type identifying a memory send operation and containing data specific to this request type. */ -class MemSend { +class MemPut { public: const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. const size_t _length{0}; ///< The length of the message. @@ -103,12 +103,12 @@ class MemSend { * @param[in] remoteAddr the destination remote memory address to write to. * @param[in] rkey the remote memory key associated with the remote memory address. */ - explicit MemSend(const decltype(_buffer) buffer, - const decltype(_length) length, - const decltype(_remoteAddr) remoteAddr, - const decltype(_rkey) rkey); + explicit MemPut(const decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); - MemSend() = delete; + MemPut() = delete; }; /** @@ -117,7 +117,7 @@ class MemSend { * Type identifying a memory receive operation and containing data specific to this request * type. */ -class MemReceive { +class MemGet { public: void* _buffer{nullptr}; ///< The raw pointer where received data should be stored. const size_t _length{0}; ///< The length of the message. @@ -134,12 +134,12 @@ class MemReceive { * @param[in] remoteAddr the source remote memory address to read from. * @param[in] rkey the remote memory key associated with the remote memory address. */ - explicit MemReceive(decltype(_buffer) buffer, - const decltype(_length) length, - const decltype(_remoteAddr) remoteAddr, - const decltype(_rkey) rkey); + explicit MemGet(decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); - MemReceive() = delete; + MemGet() = delete; }; /** @@ -308,8 +308,8 @@ using RequestData = std::variant`. * * @param[in] endpoint the `std::shared_ptr` parent component. @@ -55,7 +55,7 @@ class RequestMem : public Request { * @param[in] callbackData user-defined data to pass to the `callbackFunction`. */ RequestMem(std::shared_ptr endpoint, - const std::variant requestData, + const std::variant requestData, const std::string operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, @@ -72,7 +72,9 @@ class RequestMem : public Request { * be released and the local data (on get operations) or remote data (on put operations) * can be consumed. * - * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. + * @throws ucxx::Error if `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. * * @param[in] endpointOrWorker the parent component, which may either be a * `std::shared_ptr` or @@ -88,7 +90,7 @@ class RequestMem : public Request { */ friend std::shared_ptr createRequestMem( std::shared_ptr endpoint, - const std::variant requestData, + const std::variant requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index e2015ea54..304ef6a85 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -365,12 +365,11 @@ std::shared_ptr Endpoint::memGet(void* buffer, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest( - createRequestMem(endpoint, - data::MemReceive(buffer, length, remoteAddr, rkey), - enablePythonFuture, - callbackFunction, - callbackData)); + return registerInflightRequest(createRequestMem(endpoint, + data::MemGet(buffer, length, remoteAddr, rkey), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::memGet(void* buffer, @@ -384,7 +383,7 @@ std::shared_ptr Endpoint::memGet(void* buffer, auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, - data::MemReceive( + data::MemGet( buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), enablePythonFuture, callbackFunction, @@ -401,7 +400,7 @@ std::shared_ptr Endpoint::memPut(void* buffer, { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem(endpoint, - data::MemSend(buffer, length, remoteAddr, rkey), + data::MemPut(buffer, length, remoteAddr, rkey), enablePythonFuture, callbackFunction, callbackData)); @@ -418,7 +417,7 @@ std::shared_ptr Endpoint::memPut(void* buffer, auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, - data::MemSend( + data::MemPut( buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), enablePythonFuture, callbackFunction, diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index 7a52627b4..0a51b8506 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -23,18 +23,15 @@ AmReceive::AmReceive() {} Flush::Flush() {} -MemSend::MemSend(const void* buffer, - const size_t length, - const uint64_t remoteAddr, - const ucp_rkey_h rkey) +MemPut::MemPut(const void* buffer, + const size_t length, + const uint64_t remoteAddr, + const ucp_rkey_h rkey) : _buffer(buffer), _length(length), _remoteAddr(remoteAddr), _rkey(rkey) { } -MemReceive::MemReceive(void* buffer, - const size_t length, - const uint64_t remoteAddr, - const ucp_rkey_h rkey) +MemGet::MemGet(void* buffer, const size_t length, const uint64_t remoteAddr, const ucp_rkey_h rkey) : _buffer(buffer), _length(length), _remoteAddr(remoteAddr), _rkey(rkey) { } diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index e2592a600..682e90842 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -16,21 +16,20 @@ namespace ucxx { std::shared_ptr createRequestMem( std::shared_ptr endpoint, - const std::variant requestData, + const std::variant requestData, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr) { std::shared_ptr req = std::visit( data::dispatch{ - [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemSend memSend) { + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemPut memPut) { return std::shared_ptr(new RequestMem( - endpoint, memSend, "memSend", enablePythonFuture, callbackFunction, callbackData)); + endpoint, memPut, "memPut", enablePythonFuture, callbackFunction, callbackData)); }, - [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData]( - data::MemReceive memReceive) { + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemGet memGet) { return std::shared_ptr(new RequestMem( - endpoint, memReceive, "memRecv", enablePythonFuture, callbackFunction, callbackData)); + endpoint, memGet, "memGet", enablePythonFuture, callbackFunction, callbackData)); }, }, requestData); @@ -45,7 +44,7 @@ std::shared_ptr createRequestMem( } RequestMem::RequestMem(std::shared_ptr endpoint, - const std::variant requestData, + const std::variant requestData, const std::string operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, @@ -53,11 +52,11 @@ RequestMem::RequestMem(std::shared_ptr endpoint, : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture) { std::visit(data::dispatch{ - [this](data::MemSend memSend) { + [this](data::MemPut memPut) { if (_endpoint == nullptr) throw ucxx::Error("A valid endpoint is required to send memory messages."); }, - [this](data::MemReceive memReceive) { + [this](data::MemGet memGet) { if (_endpoint == nullptr) throw ucxx::Error("A valid endpoint is required to receive memory messages."); }, @@ -69,169 +68,17 @@ RequestMem::RequestMem(std::shared_ptr endpoint, _callbackData = callbackData; } -// static void _memSendCallback(void* request, ucs_status_t status, void* user_data) -// { -// Request* req = reinterpret_cast(user_data); -// ucxx_trace_req_f(req->getOwnerString().c_str(), request, "memSend", "_memSendCallback"); -// req->callback(request, status); -// } - -// static void _recvCompletedCallback(void* request, -// ucs_status_t status, -// size_t length, -// void* user_data) -// { -// internal::RecvMemMessage* recvMemMessage = static_cast(user_data); -// ucxx_trace_req_f( -// recvMemMessage->_request->getOwnerString().c_str(), request, "memRecv", "memRecvCallback"); -// recvMemMessage->callback(request, status); -// } - -// ucs_status_t RequestMem::recvCallback(void* arg, -// const void* header, -// size_t header_length, -// void* data, -// size_t length, -// const ucp_am_recv_param_t* param) -// { -// internal::AmData* memData = static_cast(arg); -// auto worker = memData->_worker.lock(); -// auto& ownerString = memData->_ownerString; -// auto& recvPool = memData->_recvPool; -// auto& recvWait = memData->_recvWait; - -// if ((param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP) == 0) -// ucxx_error("UCP_AM_RECV_ATTR_FIELD_REPLY_EP not set"); - -// ucp_ep_h ep = param->reply_ep; - -// bool is_rndv = param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV; - -// std::shared_ptr buf{nullptr}; -// auto allocatorType = *static_cast(header); - -// std::shared_ptr req{nullptr}; - -// { -// std::lock_guard lock(memData->_mutex); - -// auto reqs = recvWait.find(ep); -// if (reqs != recvWait.end() && !reqs->second.empty()) { -// req = reqs->second.front(); -// reqs->second.pop(); -// ucxx_trace_req("memRecv recvWait: %p", req.get()); -// } else { -// req = std::shared_ptr( -// new RequestMem(worker, worker->isFutureEnabled(), nullptr, nullptr)); -// auto [queue, _] = recvPool.try_emplace(ep, std::queue>()); -// queue->second.push(req); -// ucxx_trace_req("memRecv recvPool: %p", req.get()); -// } -// } - -// if (is_rndv) { -// if (memData->_allocators.find(allocatorType) == memData->_allocators.end()) { -// // TODO: Is a hard failure better? -// // ucxx_debug("Unsupported memory type %d", allocatorType); -// // internal::RecvMemMessage recvMemMessage(memData, ep, req, nullptr); -// // recvMemMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); -// // return UCS_ERR_UNSUPPORTED; - -// ucxx_trace_req("No allocator registered for memory type %d, falling back to host memory.", -// allocatorType); -// allocatorType = UCS_MEMORY_TYPE_HOST; -// } - -// std::shared_ptr buf = memData->_allocators.at(allocatorType)(length); - -// auto recvMemMessage = std::make_shared(memData, ep, req, buf); - -// ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | -// UCP_OP_ATTR_FIELD_USER_DATA | -// UCP_OP_ATTR_FLAG_NO_IMM_CMPL, -// .cb = {.recv_am = _recvCompletedCallback}, -// .user_data = recvMemMessage.get()}; - -// ucs_status_ptr_t status = -// ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param); - -// if (req->_enablePythonFuture) -// ucxx_trace_req_f(ownerString.c_str(), -// status, -// "memRecv rndv", -// "ep %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", -// ep, -// buf->data(), -// length, -// req->_future.get(), -// req->_future->getHandle()); -// else -// ucxx_trace_req_f(ownerString.c_str(), -// status, -// "memRecv rndv", -// "ep %p, buffer %p, size %lu, recvCallback", -// ep, -// buf->data(), -// length); - -// if (req->isCompleted()) { -// // The request completed/errored immediately -// ucs_status_t s = UCS_PTR_STATUS(status); -// recvMemMessage->callback(nullptr, s); - -// return s; -// } else { -// // The request will be handled by the callback -// recvMemMessage->setUcpRequest(status); -// memData->_registerInflightRequest(req); - -// { -// std::lock_guard lock(memData->_mutex); -// memData->_recvMemMessageMap.emplace(req.get(), recvMemMessage); -// } - -// return UCS_INPROGRESS; -// } -// } else { -// std::shared_ptr buf = memData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); -// if (length > 0) memcpy(buf->data(), data, length); - -// if (req->_enablePythonFuture) -// ucxx_trace_req_f(ownerString.c_str(), -// nullptr, -// "memRecv eager", -// "ep: %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", -// ep, -// buf->data(), -// length, -// req->_future.get(), -// req->_future->getHandle()); -// else -// ucxx_trace_req_f(ownerString.c_str(), -// nullptr, -// "memRecv eager", -// "ep: %p, buffer %p, size %lu, recvCallback", -// ep, -// buf->data(), -// length); - -// internal::RecvMemMessage recvMemMessage(memData, ep, req, buf); -// recvMemMessage.callback(nullptr, UCS_OK); -// return UCS_OK; -// } -// } - void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg) { Request* req = reinterpret_cast(arg); - ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memSend", "memPutCallback"); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memPut", "memPutCallback"); return req->callback(request, status); } void RequestMem::memGetCallback(void* request, ucs_status_t status, void* arg) { Request* req = reinterpret_cast(arg); - ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memRecv", "memGetCallback"); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memGet", "memGetCallback"); return req->callback(request, status); } @@ -247,22 +94,22 @@ void RequestMem::request() void* request = nullptr; std::visit(data::dispatch{ - [this, &request, ¶m](data::MemSend memSend) { + [this, &request, ¶m](data::MemPut memPut) { param.cb.send = memPutCallback; request = ucp_put_nbx(_endpoint->getHandle(), - memSend._buffer, - memSend._length, - memSend._remoteAddr, - memSend._rkey, + memPut._buffer, + memPut._length, + memPut._remoteAddr, + memPut._rkey, ¶m); }, - [this, &request, ¶m](data::MemReceive memReceive) { + [this, &request, ¶m](data::MemGet memGet) { param.cb.send = memGetCallback; request = ucp_get_nbx(_endpoint->getHandle(), - memReceive._buffer, - memReceive._length, - memReceive._remoteAddr, - memReceive._rkey, + memGet._buffer, + memGet._length, + memGet._remoteAddr, + memGet._rkey, ¶m); }, [](auto) { throw std::runtime_error("Unreachable"); }, @@ -279,7 +126,7 @@ void RequestMem::populateDelayedSubmission() { bool terminate = std::visit(data::dispatch{ - [this](data::MemSend memSend) { + [this](data::MemPut memPut) { if (_endpoint->getHandle() == nullptr) { ucxx_warn("Endpoint was closed before message could be sent"); Request::callback(this, UCS_ERR_CANCELED); @@ -287,7 +134,7 @@ void RequestMem::populateDelayedSubmission() } return false; }, - [this](data::MemReceive memReceive) { + [this](data::MemGet memGet) { if (_worker->getHandle() == nullptr) { ucxx_warn("Endpoint was closed before message could be received"); Request::callback(this, UCS_ERR_CANCELED); @@ -332,17 +179,16 @@ void RequestMem::populateDelayedSubmission() rkey); }; - std::visit( - data::dispatch{ - [this, &log](data::MemSend memSend) { - log(memSend._buffer, memSend._length, memSend._remoteAddr, memSend._rkey); - }, - [this, &log](data::MemReceive memReceive) { - log(memReceive._buffer, memReceive._length, memReceive._remoteAddr, memReceive._rkey); - }, - [](auto) { throw std::runtime_error("Unreachable"); }, - }, - _requestData); + std::visit(data::dispatch{ + [this, &log](data::MemPut memPut) { + log(memPut._buffer, memPut._length, memPut._remoteAddr, memPut._rkey); + }, + [this, &log](data::MemGet memGet) { + log(memGet._buffer, memGet._length, memGet._remoteAddr, memGet._rkey); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); process(); } From 04ec51dad7cc5dbde4cf71540efbdd020b8ab5f4 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 19 Feb 2024 07:41:00 -0800 Subject: [PATCH 9/9] Avoid `std::variant` for single type --- cpp/include/ucxx/constructors.h | 2 +- cpp/include/ucxx/request_flush.h | 4 +- cpp/src/request_flush.cpp | 128 +++++++++++-------------------- 3 files changed, 48 insertions(+), 86 deletions(-) diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 817c09760..e50e3ee62 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -66,7 +66,7 @@ std::shared_ptr createRequestAm( RequestCallbackUserData callbackData); std::shared_ptr createRequestFlush(std::shared_ptr endpointOrWorker, - const std::variant requestData, + const data::Flush requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); diff --git a/cpp/include/ucxx/request_flush.h b/cpp/include/ucxx/request_flush.h index c637d49d4..9502dfe47 100644 --- a/cpp/include/ucxx/request_flush.h +++ b/cpp/include/ucxx/request_flush.h @@ -53,7 +53,7 @@ class RequestFlush : public Request { * @param[in] callbackData user-defined data to pass to the `callbackFunction`. */ RequestFlush(std::shared_ptr endpointOrWorker, - const std::variant requestData, + const data::Flush requestData, const std::string operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, @@ -88,7 +88,7 @@ class RequestFlush : public Request { */ friend std::shared_ptr createRequestFlush( std::shared_ptr endpointOrWorker, - const std::variant requestData, + const data::Flush requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp index bcba03702..5bc0ea9e9 100644 --- a/cpp/src/request_flush.cpp +++ b/cpp/src/request_flush.cpp @@ -16,20 +16,13 @@ namespace ucxx { std::shared_ptr createRequestFlush( std::shared_ptr endpointOrWorker, - const std::variant requestData, + const data::Flush requestData, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr) { - std::shared_ptr req = std::visit( - data::dispatch{ - [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( - data::Flush flush) { - return std::shared_ptr(new RequestFlush( - endpointOrWorker, flush, "flush", enablePythonFuture, callbackFunction, callbackData)); - }, - }, - requestData); + auto req = std::shared_ptr(new RequestFlush( + endpointOrWorker, requestData, "flush", enablePythonFuture, callbackFunction, callbackData)); // A delayed notification request is not populated immediately, instead it is // delayed to allow the worker progress thread to set its status, and more @@ -41,22 +34,15 @@ std::shared_ptr createRequestFlush( } RequestFlush::RequestFlush(std::shared_ptr endpointOrWorker, - const std::variant requestData, + const data::Flush requestData, const std::string operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture) + : Request(endpointOrWorker, requestData, operationName, enablePythonFuture) { - std::visit( - data::dispatch{ - [this](data::Flush) { - if (_endpoint == nullptr && _worker == nullptr) - throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); - }, - [](auto) { throw std::runtime_error("Unreachable"); }, - }, - requestData); + if (_endpoint == nullptr && _worker == nullptr) + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); _callback = callbackFunction; _callbackData = callbackData; @@ -76,20 +62,13 @@ void RequestFlush::request() void* request = nullptr; - std::visit( - data::dispatch{ - [this, &request, ¶m](data::Flush) { - param.cb.send = flushCallback; - if (_endpoint != nullptr) - request = ucp_ep_flush_nbx(_endpoint->getHandle(), ¶m); - else if (_worker != nullptr) - request = ucp_worker_flush_nbx(_worker->getHandle(), ¶m); - else - throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); - }, - [](auto) { throw std::runtime_error("Unreachable"); }, - }, - _requestData); + param.cb.send = flushCallback; + if (_endpoint != nullptr) + request = ucp_ep_flush_nbx(_endpoint->getHandle(), ¶m); + else if (_worker != nullptr) + request = ucp_worker_flush_nbx(_worker->getHandle(), ¶m); + else + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); std::lock_guard lock(_mutex); _request = request; @@ -99,57 +78,40 @@ static void logPopulateDelayedSubmission() {} void RequestFlush::populateDelayedSubmission() { - bool terminate = - std::visit(data::dispatch{ - [this](data::Flush flush) { - if (_endpoint != nullptr && _endpoint->getHandle() == nullptr) { - ucxx_warn("Endpoint was closed before it could be flushed"); - Request::callback(this, UCS_ERR_CANCELED); - return true; - } else if (_worker != nullptr && _worker->getHandle() == nullptr) { - ucxx_warn("Worker was closed before it could be flushed"); - Request::callback(this, UCS_ERR_CANCELED); - return true; - } - return false; - }, - [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, - }, - _requestData); - if (terminate) return; + if (_endpoint != nullptr && _endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return; + } else if (_worker != nullptr && _worker->getHandle() == nullptr) { + ucxx_warn("Worker was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return; + } request(); - auto log = [this]() { - std::string flushComponent = "unknown"; - if (_endpoint != nullptr) - flushComponent = "endpoint"; - else if (_worker != nullptr) - flushComponent = "worker"; - - if (_enablePythonFuture) - ucxx_trace_req_f(_ownerString.c_str(), - this, - _request, - _operationName.c_str(), - "populateDelayedSubmission, flush (%s), future: %p, future handle: %p", - flushComponent.c_str(), - _future.get(), - _future->getHandle()); - else - ucxx_trace_req_f(_ownerString.c_str(), - this, - _request, - _operationName.c_str(), - "populateDelayedSubmission, flush (%s)", - flushComponent.c_str()); - }; - - std::visit(data::dispatch{ - [this, &log](data::Flush flush) { log(); }, - [](auto) { throw std::runtime_error("Unreachable"); }, - }, - _requestData); + std::string flushComponent = "unknown"; + if (_endpoint != nullptr) + flushComponent = "endpoint"; + else if (_worker != nullptr) + flushComponent = "worker"; + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s), future: %p, future handle: %p", + flushComponent.c_str(), + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s)", + flushComponent.c_str()); process(); }