diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 8416b51e6..7a3cd8604 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -376,6 +376,52 @@ class Endpoint : public Component { RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + /** + * @brief Enqueue an active message send operation with explicit policy parameters. + * + * This overload extends `amSend()` with explicit UCX datatype/flags controls and receive + * allocation policy metadata while keeping callback behavior identical to the legacy API. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the message to be sent. + * @param[in] params active message send parameters. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + [[nodiscard]] std::shared_ptr amSend( + const void* const buffer, + const size_t length, + const AmSendParams& params, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue an active message send operation with IOV datatype. + * + * This overload submits `UCP_DATATYPE_IOV` active message sends. + * + * @param[in] iov vector of IOV segments to be sent. + * @param[in] params active message send parameters. Datatype must be + * `UCP_DATATYPE_IOV`. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. 
+ * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. + */ + [[nodiscard]] std::shared_ptr amSend( + std::vector iov, + const AmSendParams& params, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue an active message receive operation. * diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h index 5900a4e79..71f81b79d 100644 --- a/cpp/include/ucxx/internal/request_am.h +++ b/cpp/include/ucxx/internal/request_am.h @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -57,12 +58,14 @@ class RecvAmMessage { * @param[in] request request to be later notified/delivered to user. * @param[in] buffer buffer containing the received data * @param[in] receiverCallback receiver callback to execute when request completes. + * @param[in] userHeader user-defined header associated with the received message. */ RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback = AmReceiverCallbackType()); + AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(), + std::vector userHeader = {}); /** * @brief Set the UCP request. diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 9e237bf0e..6ad7607bd 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -224,6 +224,17 @@ class Request : public Component { * @return The received buffer (if applicable) or `nullptr`. 
*/ [[nodiscard]] virtual std::shared_ptr getRecvBuffer(); + + /** + * @brief Get the received user header. + * + * This method is used to get the user-defined header bytes for applicable derived classes + * (e.g., `RequestAm` receive operations), in all other cases this will return an empty + * string. + * + * @return The received user header (if applicable) or an empty string. + */ + [[nodiscard]] virtual std::string getRecvHeader(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index 533610a7e..b84477107 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -1,11 +1,13 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once +#include #include #include #include +#include #include @@ -31,8 +33,9 @@ class RequestAm : public Request { private: friend class internal::RecvAmMessage; - std::string _header{}; ///< Retain copy of header for send requests as workaround for - ///< https://github.com/openucx/ucx/issues/10424 + std::vector _header{}; ///< Retain copy of header bytes for send requests as + ///< workaround for + ///< https://github.com/openucx/ucx/issues/10424 /** * @brief Private constructor of `ucxx::RequestAm`. @@ -161,6 +164,8 @@ class RequestAm : public Request { const ucp_am_recv_param_t* param); [[nodiscard]] std::shared_ptr getRecvBuffer() override; + + [[nodiscard]] std::string getRecvHeader() override; }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 057936484..addf8b680 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -1,11 +1,12 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once #include #include +#include #include #include @@ -28,11 +29,19 @@ namespace data { */ class AmSend { public: - const void* const _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. - const size_t _length{0}; ///< The length of the message. + const void* const _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< Message length in bytes (contiguous datatype only). + const std::vector _iov{}; ///< Segments for IOV datatype. + const size_t _count{0}; ///< Count passed to `ucp_am_send_nbx`: byte count + ///< for contiguous, number of IOV segments for IOV. + const uint32_t _flags{UCP_AM_SEND_FLAG_REPLY}; ///< UCP AM send flags. + const ucp_datatype_t _datatype{ucp_dt_make_contig(1)}; ///< UCP datatype. const ucs_memory_type_t _memoryType{UCS_MEMORY_TYPE_HOST}; ///< Memory type used on the operation + const AmSendMemoryTypePolicy _memoryTypePolicy{ + AmSendMemoryTypePolicy::FallbackToHost}; ///< Receiver allocation policy. const std::optional _receiverCallbackInfo{ std::nullopt}; ///< Owner name and unique identifier of the receiver callback. + const std::vector _userHeader{}; ///< Opaque user-defined header bytes. /** * @brief Constructor for Active Message-specific send data. @@ -41,14 +50,21 @@ class AmSend { * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the message to be sent. - * @param[in] memoryType the memory type of the buffer. - * @param[in] receiverCallbackInfo the owner name and unique identifier of the receiver - callback. + * @param[in] params send parameters controlling datatype/flags/policies. 
*/ explicit AmSend(const decltype(_buffer) buffer, const decltype(_length) length, - const decltype(_memoryType) memoryType = UCS_MEMORY_TYPE_HOST, - const decltype(_receiverCallbackInfo) receiverCallbackInfo = std::nullopt); + const AmSendParams& params = AmSendParams{}); + + /** + * @brief Constructor for Active Message-specific send data using IOV datatype. + * + * Construct an object containing Active Message-specific send data for `UCP_DATATYPE_IOV`. + * + * @param[in] iov vector of IOV segments to send. + * @param[in] params send parameters controlling datatype/flags/policies. + */ + explicit AmSend(decltype(_iov) iov, const AmSendParams& params = AmSendParams{}); AmSend() = delete; }; @@ -62,6 +78,7 @@ class AmSend { class AmReceive { public: std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer + std::vector _userHeader{}; ///< User-defined header bytes from the sender. /** * @brief Constructor for Active Message-specific receive data. diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index bda188dc0..63685fb3c 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -1,16 +1,21 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #pragma once #include +#include +#include #include #include #include #include +#include #include +#include #include +#include #include @@ -154,8 +159,8 @@ typedef uint64_t AmReceiverCallbackIdType; */ class AmReceiverCallbackInfo { public: - const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback - const AmReceiverCallbackIdType id; ///< The unique identifier of the callback + AmReceiverCallbackOwnerType owner; ///< The owner name of the callback + AmReceiverCallbackIdType id; ///< The unique identifier of the callback AmReceiverCallbackInfo() = delete; @@ -168,6 +173,64 @@ class AmReceiverCallbackInfo { AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id); }; +/** + * @brief Policy used to allocate receive buffers for Active Messages. + * + * Active Message receive allocations can be strict (error if no allocator is registered for + * sender-provided memory type) or permissive (fallback to host allocation). + */ +enum class AmSendMemoryTypePolicy { + FallbackToHost = 0, ///< If no allocator exists for memory type, fallback to host memory. + ErrorOnUnsupported, ///< If no allocator exists for memory type, fail with unsupported error. +}; + +/** + * @brief Parameters controlling Active Message send behavior. + * + * This object is used by the extended Active Message API to expose UCX send knobs without + * breaking existing callers. + */ +struct AmSendParams { + uint32_t flags{UCP_AM_SEND_FLAG_REPLY}; ///< UCP AM send flags. + ucp_datatype_t datatype{ucp_dt_make_contig(1)}; ///< Datatype used by `ucp_am_send_nbx`. + ucs_memory_type_t memoryType{UCS_MEMORY_TYPE_HOST}; ///< Sender memory type hint. + AmSendMemoryTypePolicy memoryTypePolicy{ + AmSendMemoryTypePolicy::FallbackToHost}; ///< Receiver allocation policy. + std::optional receiverCallbackInfo{ + std::nullopt}; ///< Optional receiver callback metadata. 
+ std::vector userHeader{}; ///< Opaque user-defined header bytes. This is serialized + ///< into the AM header parameter of `ucp_am_send_nbx`, + ///< which is subject to transport-level size limits. For + ///< TCP, the default segment size is ~8 KiB + ///< (`UCX_TCP_TX_SEG_SIZE` / `UCX_TCP_RX_SEG_SIZE`). + ///< Headers that exceed the transport limit will cause a + ///< fatal UCX error. Keep user headers small + ///< (recommended < 4 KiB) or increase segment size env + ///< vars as needed. + + /** + * @brief Set opaque user header bytes from raw pointer. + * + * @param[in] data pointer to input bytes, may be `nullptr` iff `size == 0`. + * @param[in] size number of bytes in input. + */ + void setUserHeader(const void* data, size_t size) + { + if (size > 0 && data == nullptr) + throw std::invalid_argument( + "AmSendParams::setUserHeader received null data with non-zero size"); + userHeader.resize(size); + if (size > 0) memcpy(userHeader.data(), data, size); + } + + /** + * @brief Convenience overload to set user header from string-like views. + * + * @param[in] data view of opaque bytes. + */ + void setUserHeader(std::string_view data) { setUserHeader(data.data(), data.size()); } +}; + /** * @brief Serialized form of a remote key. 
* diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 876a2d86c..4c20bed4c 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -449,14 +449,41 @@ std::shared_ptr Endpoint::amSend( const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) +{ + auto params = AmSendParams{}; + params.memoryType = memoryType; + params.receiverCallbackInfo = receiverCallbackInfo; + + return amSend(buffer, length, params, enablePythonFuture, callbackFunction, callbackData); +} + +std::shared_ptr Endpoint::amSend(const void* const buffer, + const size_t length, + const AmSendParams& params, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest( - createRequestAm(endpoint, - data::AmSend(buffer, length, memoryType, receiverCallbackInfo), - enablePythonFuture, - callbackFunction, - callbackData)); + return registerInflightRequest(createRequestAm(endpoint, + data::AmSend(buffer, length, params), + enablePythonFuture, + callbackFunction, + callbackData)); +} + +std::shared_ptr Endpoint::amSend(std::vector iov, + const AmSendParams& params, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestAm(endpoint, + data::AmSend(std::move(iov), params), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index 88db30363..e2eb65c25 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -9,6 +9,8 @@ #include #include +#include +#include namespace ucxx { @@ -18,11 +20,15 @@ 
RecvAmMessage::RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback) + AmReceiverCallbackType receiverCallback, + std::vector userHeader) : _amData(amData), _ep(ep), _request(request) { std::visit(data::dispatch{ - [this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; }, + [this, buffer, &userHeader](data::AmReceive& amReceive) { + amReceive._buffer = buffer; + amReceive._userHeader = std::move(userHeader); + }, [](auto) { throw std::runtime_error("Unreachable"); }, }, _request->_requestData); diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 2283e27f2..1551e4f23 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -242,4 +242,6 @@ const std::string& Request::getOwnerString() const { return _ownerString; } std::shared_ptr Request::getRecvBuffer() { return nullptr; } +std::string Request::getRecvHeader() { return {}; } + } // namespace ucxx diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 1c2995a65..3d4ad2dfb 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -2,10 +2,9 @@ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ -#include +#include #include #include -#include #include #include #include @@ -28,18 +27,20 @@ AmReceiverCallbackInfo::AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType { } -typedef std::string AmHeaderSerialized; +typedef std::vector AmHeaderSerialized; struct AmHeader { ucs_memory_type_t memoryType; + AmSendMemoryTypePolicy memoryTypePolicy; std::optional receiverCallbackInfo; + std::vector userHeader; ///< Opaque user-defined header bytes. 
- static AmHeader deserialize(const std::string_view serialized) + static AmHeader deserialize(const std::byte* serialized, size_t serializedSize) { size_t offset{0}; auto decode = [&offset, &serialized](void* data, size_t bytes) { - memcpy(data, serialized.data() + offset, bytes); + memcpy(data, serialized + offset, bytes); offset += bytes; }; @@ -49,6 +50,7 @@ struct AmHeader { bool hasReceiverCallback{false}; decode(&hasReceiverCallback, sizeof(hasReceiverCallback)); + std::optional receiverCallbackInfo = std::nullopt; if (hasReceiverCallback) { size_t ownerSize{0}; decode(&ownerSize, sizeof(ownerSize)); @@ -59,11 +61,30 @@ struct AmHeader { AmReceiverCallbackIdType id{}; decode(&id, sizeof(id)); - return AmHeader{.memoryType = memoryType, - .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id)}; + receiverCallbackInfo = AmReceiverCallbackInfo(owner, id); } - return AmHeader{.memoryType = memoryType, .receiverCallbackInfo = std::nullopt}; + AmSendMemoryTypePolicy memoryTypePolicy = AmSendMemoryTypePolicy::FallbackToHost; + if (offset + sizeof(uint8_t) <= serializedSize) { + uint8_t serializedMemoryTypePolicy{0}; + decode(&serializedMemoryTypePolicy, sizeof(serializedMemoryTypePolicy)); + memoryTypePolicy = static_cast(serializedMemoryTypePolicy); + } + + std::vector userHeader{}; + if (offset + sizeof(size_t) <= serializedSize) { + size_t userHeaderSize{0}; + decode(&userHeaderSize, sizeof(userHeaderSize)); + if (userHeaderSize > 0 && userHeaderSize <= serializedSize - offset) { /* subtract rather than add: a wire-supplied size near SIZE_MAX cannot wrap the bound check (offset <= serializedSize holds here) */ + userHeader.resize(userHeaderSize); + decode(userHeader.data(), userHeaderSize); + } + } + + return AmHeader{.memoryType = memoryType, + .memoryTypePolicy = memoryTypePolicy, + .receiverCallbackInfo = receiverCallbackInfo, + .userHeader = std::move(userHeader)}; } const AmHeaderSerialized serialize() const @@ -73,9 +94,12 @@ struct AmHeader { const size_t ownerSize = (receiverCallbackInfo) ? 
receiverCallbackInfo->owner.size() : 0; const size_t amReceiverCallbackInfoSize = (receiverCallbackInfo) ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) : 0; - const size_t totalSize = - sizeof(memoryType) + sizeof(hasReceiverCallback) + amReceiverCallbackInfoSize; - std::string serialized(totalSize, 0); + const uint8_t serializedMemoryTypePolicy = static_cast(memoryTypePolicy); + const size_t userHeaderSize = userHeader.size(); + const size_t totalSize = sizeof(memoryType) + sizeof(hasReceiverCallback) + + amReceiverCallbackInfoSize + sizeof(serializedMemoryTypePolicy) + + sizeof(userHeaderSize) + userHeaderSize; + std::vector serialized(totalSize, std::byte{0}); auto encode = [&offset, &serialized](void const* data, size_t bytes) { memcpy(serialized.data() + offset, data, bytes); @@ -89,6 +113,9 @@ struct AmHeader { encode(receiverCallbackInfo->owner.c_str(), ownerSize); encode(&receiverCallbackInfo->id, sizeof(receiverCallbackInfo->id)); } + encode(&serializedMemoryTypePolicy, sizeof(serializedMemoryTypePolicy)); + encode(&userHeaderSize, sizeof(userHeaderSize)); + if (userHeaderSize > 0) { encode(userHeader.data(), userHeaderSize); } return serialized; } @@ -224,8 +251,7 @@ ucs_status_t RequestAm::recvCallback(void* arg, bool is_rndv = param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV; std::shared_ptr buf{nullptr}; - auto amHeader = - AmHeader::deserialize(std::string_view(static_cast(header), header_length)); + auto amHeader = AmHeader::deserialize(static_cast(header), header_length); auto receiverCallback = [&amHeader, &amData]() { if (amHeader.receiverCallbackInfo) { try { @@ -278,15 +304,18 @@ ucs_status_t RequestAm::recvCallback(void* arg, if (is_rndv) { if (amData->_allocators.find(amHeader.memoryType) == amData->_allocators.end()) { - // TODO: Is a hard failure better? 
- // ucxx_debug("Unsupported memory type %d", amHeader.memoryType); - // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); - // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); - // return UCS_ERR_UNSUPPORTED; - - ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", - amHeader.memoryType); - amHeader.memoryType = UCS_MEMORY_TYPE_HOST; + if (amHeader.memoryTypePolicy == AmSendMemoryTypePolicy::ErrorOnUnsupported) { + ucxx_debug("No allocator registered for memory type %u and strict policy is active", + amHeader.memoryType); + internal::RecvAmMessage recvAmMessage( + amData, ep, req, nullptr, receiverCallback, amHeader.userHeader); + recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); + return UCS_ERR_UNSUPPORTED; + } else { + ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", + amHeader.memoryType); + amHeader.memoryType = UCS_MEMORY_TYPE_HOST; + } } try { @@ -295,8 +324,8 @@ ucs_status_t RequestAm::recvCallback(void* arg, ucxx_debug("Exception calling allocator: %s", e.what()); } - auto recvAmMessage = - std::make_shared(amData, ep, req, buf, receiverCallback); + auto recvAmMessage = std::make_shared( + amData, ep, req, buf, receiverCallback, amHeader.userHeader); ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | @@ -355,7 +384,8 @@ ucs_status_t RequestAm::recvCallback(void* arg, } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); - internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback); + internal::RecvAmMessage recvAmMessage( + amData, ep, req, buf, receiverCallback, amHeader.userHeader); if (buf == nullptr) { ucxx_debug("Failed to allocate %lu bytes of memory", length); recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY); @@ -400,28 +430,47 @@ std::shared_ptr RequestAm::getRecvBuffer() _requestData); } +std::string RequestAm::getRecvHeader() +{ + return 
std::visit(data::dispatch{ + [](const data::AmReceive& amReceive) { + if (amReceive._userHeader.empty()) return std::string{}; + return std::string( + reinterpret_cast(amReceive._userHeader.data()), + amReceive._userHeader.size()); + }, + [](auto) -> std::string { return {}; }, + }, + _requestData); +} + void RequestAm::request() { std::visit( data::dispatch{ - [this](data::AmSend amSend) { - ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_FLAGS | - UCP_OP_ATTR_FIELD_USER_DATA, - .flags = UCP_AM_SEND_FLAG_REPLY, - .datatype = ucp_dt_make_contig(1), - .user_data = this}; - - param.cb.send = _amSendCallback; - AmHeader header = {.memoryType = amSend._memoryType, - .receiverCallbackInfo = amSend._receiverCallbackInfo}; - _header = header.serialize(); - void* request = ucp_am_send_nbx(_endpoint->getHandle(), + [this](const data::AmSend& amSend) { + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA, + .flags = amSend._flags, + .datatype = amSend._datatype, + .user_data = this}; + + param.cb.send = _amSendCallback; + AmHeader header = {.memoryType = amSend._memoryType, + .memoryTypePolicy = amSend._memoryTypePolicy, + .receiverCallbackInfo = amSend._receiverCallbackInfo, + .userHeader = amSend._userHeader}; + _header = header.serialize(); + const void* sendBuffer = (amSend._datatype == UCP_DATATYPE_IOV) + ? 
reinterpret_cast(amSend._iov.data()) : amSend._buffer; + void* request = ucp_am_send_nbx(_endpoint->getHandle(), 0, - _header.data(), + reinterpret_cast(_header.data()), _header.size(), - amSend._buffer, - amSend._length, + sendBuffer, + amSend._count, &param); std::lock_guard lock(_mutex); @@ -477,7 +526,14 @@ void RequestAm::populateDelayedSubmission() std::visit(data::dispatch{ - [this, &log](data::AmSend amSend) { - log(amSend._buffer, amSend._length, amSend._memoryType); + [this, &log](const data::AmSend& amSend) { /* by const-ref, matching request(): avoids copying _iov and _userHeader just to log */ + if (amSend._datatype == UCP_DATATYPE_IOV) { + size_t totalLength{0}; + for (const auto& segment : amSend._iov) + totalLength += segment.length; + log(amSend._iov.data(), totalLength, amSend._memoryType); + } else { + log(amSend._buffer, amSend._length, amSend._memoryType); + } }, [](auto) { throw std::runtime_error("Unreachable"); }, }, diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp index d4abb0bbc..2c3bfc49a 100644 --- a/cpp/src/request_data.cpp +++ b/cpp/src/request_data.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #include @@ -19,15 +19,46 @@ namespace ucxx { namespace data { -AmSend::AmSend(const void* const buffer, - const size_t length, - const ucs_memory_type memoryType, - const std::optional receiverCallbackInfo) +AmSend::AmSend(const void* const buffer, const size_t length, const AmSendParams& params) : _buffer(buffer), _length(length), - _memoryType(memoryType), - _receiverCallbackInfo(receiverCallbackInfo) + _iov(), + _count(length), + _flags(params.flags), + _datatype(params.datatype), + _memoryType(params.memoryType), + _memoryTypePolicy(params.memoryTypePolicy), + _receiverCallbackInfo(params.receiverCallbackInfo), + _userHeader(params.userHeader) +{ + if (_datatype != ucp_dt_make_contig(1)) + throw std::runtime_error("Contiguous AM send requires datatype `ucp_dt_make_contig(1)`."); + + if (_buffer == nullptr && _length > 0) + throw std::runtime_error("Buffer cannot be a nullptr when length is > 0."); +} + +AmSend::AmSend(std::vector iov, const AmSendParams& params) + : _buffer(nullptr), + _length(0), + _iov(std::move(iov)), + _count(_iov.size()), + _flags(params.flags), + _datatype(params.datatype), + _memoryType(params.memoryType), + _memoryTypePolicy(params.memoryTypePolicy), + _receiverCallbackInfo(params.receiverCallbackInfo), + _userHeader(params.userHeader) { + if (_datatype != UCP_DATATYPE_IOV) + throw std::runtime_error("IOV AM send requires datatype `UCP_DATATYPE_IOV`."); + + if (_iov.empty()) throw std::runtime_error("IOV cannot be empty."); + + for (const auto& segment : _iov) { + if (segment.buffer == nullptr && segment.length > 0) + throw std::runtime_error("IOV segment buffer cannot be nullptr when segment length is > 0."); + } } AmReceive::AmReceive() {} diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index fa7fb1c18..c0b5dcd33 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -1,10 +1,11 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include #include #include +#include #include #include #include @@ -202,6 +203,98 @@ TEST_P(RequestTest, ProgressAm) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressAmIovHost) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType != UCS_MEMORY_TYPE_HOST) { + GTEST_SKIP() << "IOV test uses host buffers for deterministic validation"; + } + + const size_t messageLength = std::max(4, _messageLength); + std::vector send(messageLength); + std::iota(send.begin(), send.end(), 0); + + const size_t firstSegmentLength = messageLength / 2; + const size_t secondSegmentLength = messageLength - firstSegmentLength; + std::vector iov(2); + iov[0].buffer = send.data(); + iov[0].length = firstSegmentLength * sizeof(int); + iov[1].buffer = send.data() + firstSegmentLength; + iov[1].length = secondSegmentLength * sizeof(int); + + auto amSendParams = ucxx::AmSendParams{}; + amSendParams.datatype = UCP_DATATYPE_IOV; + amSendParams.memoryType = UCS_MEMORY_TYPE_HOST; + + std::vector> requests; + requests.push_back(_ep->amSend(iov, amSendParams)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + auto recvBuffer = recvReq->getRecvBuffer(); + ASSERT_EQ(recvBuffer->getType(), ucxx::BufferType::Host); + ASSERT_EQ(recvBuffer->getSize(), messageLength * sizeof(int)); + + std::vector recv(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + messageLength); + ASSERT_THAT(recv, ContainerEq(send)); +} + +TEST_P(RequestTest, ProgressAmIovValidation) +{ + auto amSendParams = ucxx::AmSendParams{}; + amSendParams.datatype = UCP_DATATYPE_IOV; + amSendParams.memoryType = UCS_MEMORY_TYPE_HOST; + + EXPECT_THROW(std::ignore = 
_ep->amSend(std::vector{}, amSendParams), + std::runtime_error); + + std::vector iovWithNullBuffer(1); + iovWithNullBuffer[0].buffer = nullptr; + iovWithNullBuffer[0].length = 16; + EXPECT_THROW(std::ignore = _ep->amSend(iovWithNullBuffer, amSendParams), std::runtime_error); + + std::vector send{1, 2, 3, 4}; + std::vector validIov(1); + validIov[0].buffer = send.data(); + validIov[0].length = send.size() * sizeof(send[0]); + + auto wrongDatatypeParams = amSendParams; + wrongDatatypeParams.datatype = ucp_dt_make_contig(1); + EXPECT_THROW(std::ignore = _ep->amSend(validIov, wrongDatatypeParams), std::runtime_error); +} + +TEST_P(RequestTest, ProgressAmMemoryTypePolicyStrict) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + const size_t bytes = std::max(_rndvThresh + 128, sizeof(int)); + std::vector send(bytes, 42); + + auto amSendParams = ucxx::AmSendParams{}; + amSendParams.memoryType = UCS_MEMORY_TYPE_CUDA; + amSendParams.memoryTypePolicy = ucxx::AmSendMemoryTypePolicy::ErrorOnUnsupported; + + std::vector> requests; + requests.push_back(_ep->amSend(send.data(), send.size(), amSendParams)); + requests.push_back(_ep->amRecv()); + + // Wait for completion without calling checkError(), since the receive request + // is expected to complete with UCS_ERR_UNSUPPORTED. + while (!requests[0]->isCompleted() || !requests[1]->isCompleted()) + _progressWorker(); + + // When the receiver rejects a rendezvous transfer, UCX propagates the error to + // both sides, so the send may also complete with UCS_ERR_UNSUPPORTED. 
+ ASSERT_EQ(requests[1]->getStatus(), UCS_ERR_UNSUPPORTED); +} + TEST_P(RequestTest, ProgressAmReceiverCallback) { if (_progressMode == ProgressMode::Wait) { @@ -264,6 +357,107 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressAmUserHeader) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType != UCS_MEMORY_TYPE_HOST) { + GTEST_SKIP() << "User header test uses host buffers only"; + } + + allocate(1, false); + + const std::string sentHeader = std::string("test-header-payload-") + '\x00' + '\x01' + '\x02' + '\xff'; /* appended per byte: a single C-string literal would be truncated at the embedded NUL */ + + auto amSendParams = ucxx::AmSendParams{}; + amSendParams.setUserHeader(sentHeader); + + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, amSendParams)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + ASSERT_EQ(recvReq->getRecvHeader(), sentHeader); + + _recvPtr[0] = recvReq->getRecvBuffer()->data(); + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, ProgressAmIovUserHeader) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType != UCS_MEMORY_TYPE_HOST) { + GTEST_SKIP() << "IOV user header test uses host buffers only"; + } + + const size_t messageLength = std::max(4, _messageLength); + std::vector send(messageLength); + std::iota(send.begin(), send.end(), 0); + + const size_t firstSegmentLength = messageLength / 2; + const size_t secondSegmentLength = messageLength - firstSegmentLength; + std::vector iov(2); + iov[0].buffer = send.data(); + iov[0].length = firstSegmentLength * sizeof(int); + iov[1].buffer = send.data() + firstSegmentLength; + iov[1].length = secondSegmentLength * sizeof(int); + + const std::string sentHeader = 
"iov-user-header-data"; + + auto amSendParams = ucxx::AmSendParams{}; + amSendParams.datatype = UCP_DATATYPE_IOV; + amSendParams.memoryType = UCS_MEMORY_TYPE_HOST; + amSendParams.setUserHeader(sentHeader); + + std::vector> requests; + requests.push_back(_ep->amSend(iov, amSendParams)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + auto recvBuffer = recvReq->getRecvBuffer(); + ASSERT_EQ(recvBuffer->getType(), ucxx::BufferType::Host); + ASSERT_EQ(recvBuffer->getSize(), messageLength * sizeof(int)); + ASSERT_EQ(recvReq->getRecvHeader(), sentHeader); + + std::vector recv(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + messageLength); + ASSERT_THAT(recv, ContainerEq(send)); +} + +TEST_P(RequestTest, ProgressAmEmptyUserHeader) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType != UCS_MEMORY_TYPE_HOST) { + GTEST_SKIP() << "User header test uses host buffers only"; + } + + allocate(1, false); + + // Send without user header (default empty) + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + ASSERT_EQ(recvReq->getRecvHeader(), std::string{}); + + _recvPtr[0] = recvReq->getRecvBuffer()->data(); + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressStream) { allocate();