diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4f2c19a5a..4e1196aab 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -135,7 +135,9 @@ add_library( src/request.cpp src/request_am.cpp src/request_data.cpp + src/request_flush.cpp src/request_helper.cpp + src/request_mem.cpp src/request_stream.cpp src/request_tag.cpp src/request_tag_multi.cpp diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index cc6012885..f04326b51 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -23,6 +23,8 @@ class Notifier; class RemoteKey; class Request; class RequestAm; +class RequestFlush; +class RequestMem; class RequestStream; class RequestTag; class RequestTagMulti; @@ -75,6 +77,12 @@ std::shared_ptr createRequestAm( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); +std::shared_ptr createRequestFlush(std::shared_ptr endpointOrWorker, + const data::Flush requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestStream( std::shared_ptr endpoint, const std::variant requestData, @@ -87,6 +95,13 @@ std::shared_ptr createRequestTag( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); +std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestTagMulti( std::shared_ptr endpoint, const std::variant requestData, diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 489e2aa53..81cced641 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -342,6 +342,130 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + /** + * 
@brief Enqueue a memory put operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the remote data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the destination remote memory address to write to. + * @param[in] rkey the remote memory key associated with the remote memory + * address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memPut(void* buffer, + size_t length, + uint64_t remote_addr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue a memory put operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the remote data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. 
+ * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteKey the remote memory key associated with the remote memory + * address. + * @param[in] remoteAddrOffset the destination remote memory address offset where to + * start writing to, `0` means start writing from beginning + * of the base address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memPut(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddrOffset = 0, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue a memory get operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the local data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the source remote memory address to read from. + * @param[in] rkey the remote memory key associated with the remote memory + * address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. 
 + */ + std::shared_ptr memGet(void* buffer, + size_t length, + uint64_t remoteAddr, + ucp_rkey_h rkey, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue a memory get operation. + * + * Enqueue a memory operation, returning a `std::shared` that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of the + * transfer must be verified from the resulting request object before both local and + * remote data can be released and the local data can be consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[out] buffer a raw pointer where the received data should be stored. + * @param[in] length the size in bytes of the message to be received. + * @param[in] remoteKey the remote memory key associated with the remote memory + * address. + * @param[in] remoteAddrOffset the source remote memory address offset where to + * start reading from, `0` means start reading from + * beginning of the base address. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr memGet(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddrOffset = 0, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue a stream send operation. * @@ -511,6 +635,31 @@ class Endpoint : public Component { const TagMask tagMask, const bool enablePythonFuture); + /** + * @brief Enqueue a flush operation. 
 + * + * Enqueue request to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote + * Memory Access) operations on the endpoint, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + std::shared_ptr flush(const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Get `ucxx::Worker` component from a worker or listener object. * diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 8143dad78..db3ab7267 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -66,6 +66,82 @@ class AmReceive { AmReceive(); }; +/** + * @brief Data for a flush operation. + * + * Type identifying a flush operation and containing data specific to this request type. + */ +class Flush { + public: + /** + * @brief Constructor for flush-specific data. + * + * Construct an object containing flush-specific data. + */ + Flush(); +}; + +/** + * @brief Data for a memory send. 
+ * + * Type identifying a memory send operation and containing data specific to this request type. + */ +class MemPut { + public: + const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< The length of the message. + const uint64_t _remoteAddr{0}; ///< Remote memory address to write to. + const ucp_rkey_h _rkey{}; ///< UCX remote key associated with the remote memory address. + + /** + * @brief Constructor for memory-specific data. + * + * Construct an object containing memory-specific data. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] remoteAddr the destination remote memory address to write to. + * @param[in] rkey the remote memory key associated with the remote memory address. + */ + explicit MemPut(const decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); + + MemPut() = delete; +}; + +/** + * @brief Data for a memory receive. + * + * Type identifying a memory receive operation and containing data specific to this request + * type. + */ +class MemGet { + public: + void* _buffer{nullptr}; ///< The raw pointer where received data should be stored. + const size_t _length{0}; ///< The length of the message. + const uint64_t _remoteAddr{0}; ///< Remote memory address to read from. + const ucp_rkey_h _rkey{}; ///< UCX remote key associated with the remote memory address. + + /** + * @brief Constructor for memory-specific data. + * + * Construct an object containing memory-specific data. + * + * @param[out] buffer a raw pointer to the received data. + * @param[in] length the size in bytes of the tag message to be received. + * @param[in] remoteAddr the source remote memory address to read from. + * @param[in] rkey the remote memory key associated with the remote memory address. 
+ */ + explicit MemGet(decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_remoteAddr) remoteAddr, + const decltype(_rkey) rkey); + + MemGet() = delete; +}; + /** * @brief Data for a Stream send. * @@ -127,7 +203,7 @@ class TagSend { const ::ucxx::Tag _tag{0}; ///< Tag to match /** - * @brief Constructor for tag/multi-buffer tag-specific data. + * @brief Constructor for tag-specific data. * * Construct an object containing tag-specific data. * @@ -156,7 +232,7 @@ class TagReceive { const ::ucxx::TagMask _tagMask{0}; ///< Tag mask to use /** - * @brief Constructor send tag-specific data. + * @brief Constructor for tag-specific data. * * Construct an object containing send tag-specific data. * @@ -231,6 +307,9 @@ class TagMultiReceive { using RequestData = std::variant +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +/** + * @brief Flush a UCP endpoint or worker. + * + * Flush outstanding AMO (Atomic Memory Operation) and RMA (Remote Memory Access) operations + * on a UCP endpoint or worker. + */ +class RequestFlush : public Request { + private: + /** + * @brief Private constructor of `ucxx::RequestFlush`. + * + * This is the internal implementation of `ucxx::RequestFlush` constructor, made private + * not to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::flush()` + * - `ucxx::Worker::flush()` + * - `ucxx::createRequestFlush()` + * + * @throws ucxx::Error if `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. 
+ * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestFlush(std::shared_ptr endpointOrWorker, + const data::Flush requestData, + const std::string operationName, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr`. + * + * The constructor for a `std::shared_ptr` object, creating a request + * to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote Memory Access) + * operations on a UCP endpoint or worker, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * @throws ucxx::Error `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. 
+ * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestFlush( + std::shared_ptr endpointOrWorker, + const data::Flush requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Create and submit a flush request. + * + * This is the method that should be called to actually submit a flush request. It is + * meant to be called from `populateDelayedSubmission()`, which is decided at the + * discretion of `std::shared_ptr`. See `populateDelayedSubmission()` for + * more details. + */ + void request(); + + /** + * @brief Callback executed by UCX when a flush request is completed. + * + * Callback executed by UCX when a flush request is completed, that will dispatch + * `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void flushCallback(void* request, ucs_status_t status, void* arg); +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/request_mem.h b/cpp/include/ucxx/request_mem.h new file mode 100644 index 000000000..31cb0d1c3 --- /dev/null +++ b/cpp/include/ucxx/request_mem.h @@ -0,0 +1,145 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +class Buffer; + +namespace internal { +class RecvMemMessage; +} // namespace internal + +class RequestMem : public Request { + private: + friend class internal::RecvMemMessage; + + uint64_t _remote_addr{}; + ucp_rkey_h _rkey{}; + + /** + * @brief Private constructor of `ucxx::RequestMem`. + * + * This is the internal implementation of `ucxx::RequestMem` constructor, made private not + * to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::memGet()` + * - `ucxx::Endpoint::memPut()` + * - `ucxx::createRequestMem()` + * + * @throws ucxx::Error if `endpoint` is not a valid + * `std::shared_ptr`. + * + * @param[in] endpoint the `std::shared_ptr` parent component. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestMem(std::shared_ptr endpoint, + const std::variant requestData, + const std::string operationName, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr`. 
+ * + * The constructor for a `std::shared_ptr` object, creating a get or put + * request, returning a pointer to a request object that can be later awaited and checked + * for errors. This is a non-blocking operation, and the status of the transfer must be + * verified from the resulting request object before both the local and remote data can + * be released and the local data (on get operations) or remote data (on put operations) + * can be consumed. + * + * @throws ucxx::Error if `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Callback executed by UCX when a memory put request is completed. + * + * Callback executed by UCX when a memory put request is completed, that will dispatch + * `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. 
+ * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void memPutCallback(void* request, ucs_status_t status, void* arg); + + /** + * @brief Callback executed by UCX when a memory get request is completed. + * + * Callback executed by UCX when a memory get request is completed, that will dispatch + * `ucxx::Request::callback()`. + * + * WARNING: This is not intended to be called by the user, but it currently needs to be + * a public method so that UCX may access it. In future changes this will be moved to + * an internal object and remove this method from the public API. + * + * @param[in] request the UCX request pointer. + * @param[in] status the completion status of the request. + * @param[in] arg the pointer to the `ucxx::Request` object that created the + * transfer, effectively `this` pointer as seen by `request()`. + */ + static void memGetCallback(void* request, ucs_status_t status, void* arg); + + /** + * @brief Create and submit a memory get or put request. + * + * This is the method that should be called to actually submit memory request get or put + * request. It is meant to be called from `populateDelayedSubmission()`, which is decided + * at the discretion of `std::shared_ptr`. See `populateDelayedSubmission()` + * for more details. + */ + void request(); +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index f6ad17cce..e1c23f9ae 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -819,6 +819,31 @@ class Worker : public Component { * @returns `true` if any uncaught messages were received, `false` otherwise. */ bool amProbe(const ucp_ep_h endpointHandle) const; + + /** + * @brief Enqueue a flush operation. 
 + * + * Enqueue request to flush outstanding AMO (Atomic Memory Operation) and RMA (Remote + * Memory Access) operations on the worker, returning a pointer to a request object that + * can be later awaited and checked for errors. This is a non-blocking operation, and its + * status must be verified from the resulting request object to confirm the flush + * operation has completed successfully. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. 
+ */ + std::shared_ptr flush(const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); }; } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index cddc31bee..304ef6a85 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -15,8 +15,11 @@ #include #include #include +#include #include #include +#include +#include #include #include #include @@ -353,6 +356,74 @@ std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, endpoint, data::AmReceive(), enablePythonFuture, callbackFunction, callbackData)); } +std::shared_ptr Endpoint::memGet(void* buffer, + size_t length, + uint64_t remoteAddr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem(endpoint, + data::MemGet(buffer, length, remoteAddr, rkey), + enablePythonFuture, + callbackFunction, + callbackData)); +} + +std::shared_ptr Endpoint::memGet(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddressOffset, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem( + endpoint, + data::MemGet( + buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), + enablePythonFuture, + callbackFunction, + callbackData)); +} + +std::shared_ptr Endpoint::memPut(void* buffer, + size_t length, + uint64_t remoteAddr, + ucp_rkey_h rkey, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return 
registerInflightRequest(createRequestMem(endpoint, + data::MemPut(buffer, length, remoteAddr, rkey), + enablePythonFuture, + callbackFunction, + callbackData)); +} + +std::shared_ptr Endpoint::memPut(void* buffer, + size_t length, + std::shared_ptr remoteKey, + uint64_t remoteAddressOffset, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestMem( + endpoint, + data::MemPut( + buffer, length, remoteKey->getBaseAddress() + remoteAddressOffset, remoteKey->getHandle()), + enablePythonFuture, + callbackFunction, + callbackData)); +} + std::shared_ptr Endpoint::streamSend(void* buffer, size_t length, const bool enablePythonFuture) @@ -422,6 +493,15 @@ std::shared_ptr Endpoint::tagMultiRecv(const Tag tag, createRequestTagMulti(endpoint, data::TagMultiReceive(tag, tagMask), enablePythonFuture)); } +std::shared_ptr Endpoint::flush(const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestFlush( + endpoint, data::Flush(), enablePythonFuture, callbackFunction, callbackData)); +} + std::shared_ptr Endpoint::getWorker() { return ::ucxx::getWorker(_parent); } void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index 3bc1efd46..0a51b8506 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -6,6 +6,7 @@ #include +#include #include #include @@ -20,6 +21,21 @@ AmSend::AmSend(const void* buffer, const size_t length, const ucs_memory_type me AmReceive::AmReceive() {} +Flush::Flush() {} + +MemPut::MemPut(const void* buffer, + const size_t length, + const uint64_t remoteAddr, + const ucp_rkey_h rkey) + : _buffer(buffer), 
_length(length), _remoteAddr(remoteAddr), _rkey(rkey) +{ +} + +MemGet::MemGet(void* buffer, const size_t length, const uint64_t remoteAddr, const ucp_rkey_h rkey) + : _buffer(buffer), _length(length), _remoteAddr(remoteAddr), _rkey(rkey) +{ +} + StreamSend::StreamSend(const void* buffer, const size_t length) : _buffer(buffer), _length(length) { /** diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp new file mode 100644 index 000000000..5bc0ea9e9 --- /dev/null +++ b/cpp/src/request_flush.cpp @@ -0,0 +1,119 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestFlush( + std::shared_ptr endpointOrWorker, + const data::Flush requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + auto req = std::shared_ptr(new RequestFlush( + endpointOrWorker, requestData, "flush", enablePythonFuture, callbackFunction, callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. 
+ req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + +RequestFlush::RequestFlush(std::shared_ptr endpointOrWorker, + const data::Flush requestData, + const std::string operationName, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpointOrWorker, requestData, operationName, enablePythonFuture) +{ + if (_endpoint == nullptr && _worker == nullptr) + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); + + _callback = callbackFunction; + _callbackData = callbackData; +} + +void RequestFlush::flushCallback(void* request, ucs_status_t status, void* arg) +{ + Request* req = reinterpret_cast(arg); + ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "flush", "flushCallback"); + return req->callback(request, status); +} + +void RequestFlush::request() +{ + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, .user_data = this}; + + void* request = nullptr; + + param.cb.send = flushCallback; + if (_endpoint != nullptr) + request = ucp_ep_flush_nbx(_endpoint->getHandle(), ¶m); + else if (_worker != nullptr) + request = ucp_worker_flush_nbx(_worker->getHandle(), ¶m); + else + throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); + + std::lock_guard lock(_mutex); + _request = request; +} + +static void logPopulateDelayedSubmission() {} + +void RequestFlush::populateDelayedSubmission() +{ + if (_endpoint != nullptr && _endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return; + } else if (_worker != nullptr && _worker->getHandle() == nullptr) { + ucxx_warn("Worker was closed before it could be flushed"); + Request::callback(this, UCS_ERR_CANCELED); + return; + } + + request(); + + 
std::string flushComponent = "unknown"; + if (_endpoint != nullptr) + flushComponent = "endpoint"; + else if (_worker != nullptr) + flushComponent = "worker"; + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s), future: %p, future handle: %p", + flushComponent.c_str(), + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "populateDelayedSubmission, flush (%s)", + flushComponent.c_str()); + + process(); +} + +} // namespace ucxx diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp new file mode 100644 index 000000000..682e90842 --- /dev/null +++ b/cpp/src/request_mem.cpp @@ -0,0 +1,196 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestMem( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + std::shared_ptr req = std::visit( + data::dispatch{ + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemPut memPut) { + return std::shared_ptr(new RequestMem( + endpoint, memPut, "memPut", enablePythonFuture, callbackFunction, callbackData)); + }, + [&endpoint, &enablePythonFuture, &callbackFunction, &callbackData](data::MemGet memGet) { + return std::shared_ptr(new RequestMem( + endpoint, memGet, "memGet", enablePythonFuture, callbackFunction, callbackData)); + }, + }, + requestData); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later 
on, so that we don't need the GIL here.
+  req->_worker->registerDelayedSubmission(
+    req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get()));
+
+  return req;
+}
+
+RequestMem::RequestMem(std::shared_ptr<Endpoint> endpoint,
+                       const std::variant<data::MemPut, data::MemGet> requestData,
+                       const std::string operationName,
+                       const bool enablePythonFuture,
+                       RequestCallbackUserFunction callbackFunction,
+                       RequestCallbackUserData callbackData)
+  : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture)
+{
+  std::visit(data::dispatch{
+               [this](data::MemPut memPut) {
+                 if (_endpoint == nullptr)
+                   throw ucxx::Error("A valid endpoint is required to send memory messages.");
+               },
+               [this](data::MemGet memGet) {
+                 if (_endpoint == nullptr)
+                   throw ucxx::Error("A valid endpoint is required to receive memory messages.");
+               },
+               [](auto) { throw std::runtime_error("Unreachable"); },
+             },
+             requestData);
+
+  _callback     = callbackFunction;
+  _callbackData = callbackData;
+}
+
+void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg)
+{
+  Request* req = reinterpret_cast<Request*>(arg);
+  ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memPut", "memPutCallback");
+  return req->callback(request, status);
+}
+
+void RequestMem::memGetCallback(void* request, ucs_status_t status, void* arg)
+{
+  Request* req = reinterpret_cast<Request*>(arg);
+  ucxx_trace_req_f(req->getOwnerString().c_str(), nullptr, request, "memGet", "memGetCallback");
+  return req->callback(request, status);
+}
+
+void RequestMem::request()
+{
+  ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+                                               UCP_OP_ATTR_FIELD_FLAGS |
+                                               UCP_OP_ATTR_FIELD_USER_DATA,
+                               .flags        = UCP_AM_SEND_FLAG_REPLY,
+                               .datatype     = ucp_dt_make_contig(1),
+                               .user_data    = this};
+
+  void* request = nullptr;
+
+  std::visit(data::dispatch{
+               [this, &request, &param](data::MemPut memPut) {
+                 param.cb.send = memPutCallback;
+                 request       = ucp_put_nbx(_endpoint->getHandle(),
+                                             memPut._buffer,
+ 
memPut._length,
+                                             memPut._remoteAddr,
+                                             memPut._rkey,
+                                             &param);
+               },
+               [this, &request, &param](data::MemGet memGet) {
+                 param.cb.send = memGetCallback;
+                 request       = ucp_get_nbx(_endpoint->getHandle(),
+                                             memGet._buffer,
+                                             memGet._length,
+                                             memGet._remoteAddr,
+                                             memGet._rkey,
+                                             &param);
+               },
+               [](auto) { throw std::runtime_error("Unreachable"); },
+             },
+             _requestData);
+
+  std::lock_guard lock(_mutex);
+  _request = request;
+}
+
+static void logPopulateDelayedSubmission() {}
+
+void RequestMem::populateDelayedSubmission()
+{
+  bool terminate =
+    std::visit(data::dispatch{
+                 [this](data::MemPut memPut) {
+                   if (_endpoint->getHandle() == nullptr) {
+                     ucxx_warn("Endpoint was closed before message could be sent");
+                     Request::callback(this, UCS_ERR_CANCELED);
+                     return true;
+                   }
+                   return false;
+                 },
+                 [this](data::MemGet memGet) {
+                   if (_worker->getHandle() == nullptr) {
+                     ucxx_warn("Endpoint was closed before message could be received");
+                     Request::callback(this, UCS_ERR_CANCELED);
+                     return true;
+                   }
+                   return false;
+                 },
+                 [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); },
+               },
+               _requestData);
+  if (terminate) return;
+
+  request();
+
+  auto log =
+    [this](
+      const void* buffer, const size_t length, const uint64_t remoteAddr, const ucp_rkey_h rkey) {
+      if (_enablePythonFuture)
+        ucxx_trace_req_f(
+          _ownerString.c_str(),
+          this,
+          _request,
+          _operationName.c_str(),
+          "populateDelayedSubmission, buffer: %p, size: %lu, remoteAddr: 0x%lx, rkey: %p, "
+          "future: %p, future handle: %p",
+          buffer,
+          length,
+          remoteAddr,
+          rkey,
+          _future.get(),
+          _future->getHandle());
+      else
+        ucxx_trace_req_f(
+          _ownerString.c_str(),
+          this,
+          _request,
+          _operationName.c_str(),
+          "populateDelayedSubmission, buffer: %p, size: %lu, remoteAddr: 0x%lx, rkey: %p",
+          buffer,
+          length,
+          remoteAddr,
+          rkey);
+    };
+
+  std::visit(data::dispatch{
+               [this, &log](data::MemPut memPut) {
+                 log(memPut._buffer, memPut._length, memPut._remoteAddr, memPut._rkey);
+               },
+               [this,
&log](data::MemGet memGet) { + log(memGet._buffer, memGet._length, memGet._remoteAddr, memGet._rkey); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + + process(); +} + +} // namespace ucxx diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index 9bdc1f487..f8dfc8e2a 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -137,8 +137,6 @@ void RequestTag::request() _request = request; } -static void logPopulateDelayedSubmission() {} - void RequestTag::populateDelayedSubmission() { bool terminate = diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 300804e9d..aea6f3312 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -591,4 +592,13 @@ bool Worker::amProbe(const ucp_ep_h endpointHandle) const return _amData->_recvPool.find(endpointHandle) != _amData->_recvPool.end(); } +std::shared_ptr Worker::flush(const bool enableFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto worker = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest( + createRequestFlush(worker, data::Flush(), enableFuture, callbackFunction, callbackData)); +} + } // namespace ucxx diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 5055653d2..a818bf131 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include @@ -14,6 +16,9 @@ #include #include "include/utils.h" +#include "ucxx/buffer.h" +#include "ucxx/constructors.h" +#include "ucxx/utils/ucx.h" namespace { @@ -21,6 +26,8 @@ using ::testing::Combine; using ::testing::ContainerEq; using ::testing::Values; +typedef std::vector DataContainerType; + class RequestTest : public ::testing::TestWithParam< std::tuple> { protected: @@ -39,8 +46,8 @@ class RequestTest : public ::testing::TestWithParam< size_t 
_rndvThresh{8192}; size_t _numBuffers{0}; - std::vector> _send; - std::vector> _recv; + std::vector _send; + std::vector _recv; std::vector> _sendBuffer; std::vector> _recvBuffer; std::vector _sendPtr{nullptr}; @@ -306,6 +313,176 @@ TEST_P(RequestTest, TagUserCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, MemoryGet) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + // Fill memory handle with send data + memcpy(reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(_recvPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetPreallocated) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, _sendPtr[0]); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(_recvPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetWithOffset) +{ + if (_bufferType == ucxx::BufferType::RMM) + 
GTEST_SKIP() << "CUDA memory support not implemented yet"; + if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; + allocate(); + + size_t offset = 1; + size_t offsetBytes = offset * sizeof(_send[0][0]); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + // Fill memory handle with send data + memcpy(reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memGet(reinterpret_cast(_recvPtr[0]) + offsetBytes, + _messageSize - offsetBytes, + remoteKey, + offsetBytes)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert offset data correctness + auto recvOffset = DataContainerType(_recv[0].begin() + offset, _recv[0].end()); + auto sendOffset = DataContainerType(_send[0].begin() + offset, _send[0].end()); + ASSERT_THAT(recvOffset, sendOffset); +} + +TEST_P(RequestTest, MemoryPut) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(_sendPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + // Copy memory handle data to receive buffer + memcpy(_recvPtr[0], reinterpret_cast(memoryHandle->getBaseAddress()), _messageSize); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], 
ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutPreallocated) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, _recvPtr[0]); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(_sendPtr[0], _messageSize, remoteKey)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutWithOffset) +{ + if (_bufferType == ucxx::BufferType::RMM) + GTEST_SKIP() << "CUDA memory support not implemented yet"; + if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; + allocate(); + + size_t offset = 1; + size_t offsetBytes = offset * sizeof(_send[0][0]); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back(_ep->memPut(reinterpret_cast(_sendPtr[0]) + offsetBytes, + _messageSize - offsetBytes, + remoteKey, + offsetBytes)); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + // Copy memory handle data to receive buffer + memcpy(reinterpret_cast(_recvPtr[0]), + reinterpret_cast(memoryHandle->getBaseAddress()), + _messageSize); + + copyResults(); + + // Assert offset data correctness + auto recvOffset = DataContainerType(_recv[0].begin() + offset, _recv[0].end()); + auto sendOffset = DataContainerType(_send[0].begin() + offset, 
_send[0].end()); + ASSERT_THAT(recvOffset, sendOffset); +} + INSTANTIATE_TEST_SUITE_P(ProgressModes, RequestTest, Combine(Values(ucxx::BufferType::Host), diff --git a/cpp/tests/utils.cpp b/cpp/tests/utils.cpp index ec063fec0..ba84d33bb 100644 --- a/cpp/tests/utils.cpp +++ b/cpp/tests/utils.cpp @@ -16,14 +16,12 @@ void createCudaContextCallback(void* callbackArg) std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode) { - if (progressMode == ProgressMode::Polling) - return std::bind(std::mem_fn(&ucxx::Worker::progress), worker); - else if (progressMode == ProgressMode::Blocking) - return std::bind(std::mem_fn(&ucxx::Worker::progressWorkerEvent), worker, -1); - else if (progressMode == ProgressMode::Wait) - return std::bind(std::mem_fn(&ucxx::Worker::waitProgress), worker); - else - return std::function(); + switch (progressMode) { + case ProgressMode::Polling: return [worker]() { worker->progress(); }; + case ProgressMode::Blocking: return [worker]() { worker->progressWorkerEvent(-1); }; + case ProgressMode::Wait: return [worker]() { worker->waitProgress(); }; + default: return []() {}; + } } bool loopWithTimeout(std::chrono::milliseconds timeout, std::function f) diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 4cc610561..f41071d09 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -146,6 +146,54 @@ TEST_P(WorkerProgressTest, ProgressAm) ASSERT_EQ(recvAbstract[0], send[0]); } +TEST_P(WorkerProgressTest, ProgressMemoryGet) +{ + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + std::vector send{123}; + std::vector recv(1); + + size_t messageSize = send.size() * sizeof(int); + + auto memoryHandle = _context->createMemoryHandle(messageSize, send.data()); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + 
requests.push_back( + ep->memGet(recv.data(), messageSize, remoteKey->getBaseAddress(), remoteKey->getHandle())); + requests.push_back(_worker->flush()); + waitRequests(_worker, requests, _progressWorker); + + ASSERT_EQ(recv[0], send[0]); +} + +TEST_P(WorkerProgressTest, ProgressMemoryPut) +{ + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + std::vector send{123}; + std::vector recv(1); + + size_t messageSize = send.size() * sizeof(int); + + auto memoryHandle = _context->createMemoryHandle(messageSize, recv.data()); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + requests.push_back( + ep->memPut(send.data(), messageSize, remoteKey->getBaseAddress(), remoteKey->getHandle())); + requests.push_back(_worker->flush()); + waitRequests(_worker, requests, _progressWorker); + + ASSERT_EQ(recv[0], send[0]); +} + TEST_P(WorkerProgressTest, ProgressStream) { auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress());