From b7066571d8471eb632ba93286381c5dc2bd152b3 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 12:55:45 -0700 Subject: [PATCH 01/14] Support user deciding when to receive AM data --- cpp/include/ucxx/request_am.h | 14 ++- cpp/include/ucxx/request_data.h | 4 +- cpp/include/ucxx/typedefs.h | 32 ++++- cpp/src/request_am.cpp | 200 ++++++++++++++++++++------------ 4 files changed, 173 insertions(+), 77 deletions(-) diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index a1aae9655..d8bf654ce 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -154,6 +154,18 @@ class RequestAm : public Request { const ucp_am_recv_param_t* param); [[nodiscard]] std::shared_ptr getRecvBuffer() override; + + /** + * @brief Get the Active Message data pointer and length for delayed receive operations. + * + * This method returns the AM data pointer and length that was stored when delayReceive + * was enabled and the message was received via the receive callback. This data can be used + * by the user to manually call ucp_am_recv_data_nbx when they are ready to receive the data. + * + * @returns The AmData struct containing the void* data pointer and length, or std::nullopt + * if this was not a delayed receive operation. + */ + [[nodiscard]] std::optional getAmData(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index ca5b68bbc..66348a2e4 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -61,6 +61,8 @@ class AmSend { class AmReceive { public: std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer + std::optional<::ucxx::AmData> _amData{ + std::nullopt}; ///< The AM data pointer and length for delayed receiving /** * @brief Constructor for Active Message-specific receive data. diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index b4c76ed25..1a74e0a79 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -173,6 +173,27 @@ typedef uint64_t AmReceiverCallbackIdType; */ typedef const std::string AmReceiverCallbackInfoSerialized; +/** + * @brief Active Message data information for delayed receiving. + * + * Structure containing the Active Message data pointer and length that will be used + * when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually. + */ +struct AmData { + void* data; ///< The Active Message data pointer from the receive callback + size_t length; ///< The length of the Active Message data + + AmData() : data(nullptr), length(0) {} + + /** + * @brief Construct an AmData object. + * + * @param[in] data The Active Message data pointer from the receive callback. + * @param[in] length The length of the Active Message data. + */ + AmData(void* data, size_t length) : data(data), length(length) {} +}; + /** * @brief Information of an Active Message receiver callback. * @@ -182,16 +203,21 @@ class AmReceiverCallbackInfo { public: const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback const AmReceiverCallbackIdType id; ///< The unique identifier of the callback + const bool delayReceive; ///< Whether to delay receiving data (user-controlled) AmReceiverCallbackInfo() = delete; /** * @brief Construct an AmReceiverCallbackInfo object. * - * @param[in] owner The owner name of the callback. - * @param[in] id The unique identifier of the callback. 
+ * @param[in] owner The owner name of the callback. + * @param[in] id The unique identifier of the callback. + * @param[in] delayReceive Whether to delay receiving data, allowing user to control when + * ucp_am_recv_data_nbx is called. */ - AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id); + AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, + AmReceiverCallbackIdType id, + bool delayReceive = false); }; /** diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 6444989e9..a7968923e 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -22,8 +22,9 @@ namespace ucxx { AmReceiverCallbackInfo::AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, - AmReceiverCallbackIdType id) - : owner(owner), id(id) + AmReceiverCallbackIdType id, + bool delayReceive) + : owner(owner), id(id), delayReceive(delayReceive) { } @@ -58,8 +59,11 @@ struct AmHeader { AmReceiverCallbackIdType id{}; decode(&id, sizeof(id)); + bool delayReceive{false}; + decode(&delayReceive, sizeof(delayReceive)); + return AmHeader{.memoryType = memoryType, - .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id)}; + .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id, delayReceive)}; } return AmHeader{.memoryType = memoryType, .receiverCallbackInfo = std::nullopt}; @@ -71,7 +75,9 @@ struct AmHeader { bool hasReceiverCallback{static_cast(receiverCallbackInfo)}; const size_t ownerSize = (receiverCallbackInfo) ? receiverCallbackInfo->owner.size() : 0; const size_t amReceiverCallbackInfoSize = - (receiverCallbackInfo) ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) : 0; + (receiverCallbackInfo) ? 
sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) + + sizeof(receiverCallbackInfo->delayReceive) + : 0; const size_t totalSize = sizeof(memoryType) + sizeof(hasReceiverCallback) + amReceiverCallbackInfoSize; std::string serialized(totalSize, 0); @@ -87,6 +93,7 @@ struct AmHeader { encode(&ownerSize, sizeof(ownerSize)); encode(receiverCallbackInfo->owner.c_str(), ownerSize); encode(&receiverCallbackInfo->id, sizeof(receiverCallbackInfo->id)); + encode(&receiverCallbackInfo->delayReceive, sizeof(receiverCallbackInfo->delayReceive)); } return serialized; @@ -258,80 +265,119 @@ ucs_status_t RequestAm::recvCallback(void* arg, } if (is_rndv) { - if (amData->_allocators.find(amHeader.memoryType) == amData->_allocators.end()) { - // TODO: Is a hard failure better? - // ucxx_debug("Unsupported memory type %d", amHeader.memoryType); - // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); - // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); - // return UCS_ERR_UNSUPPORTED; - - ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", - amHeader.memoryType); - amHeader.memoryType = UCS_MEMORY_TYPE_HOST; - } - - try { - buf = amData->_allocators.at(amHeader.memoryType)(length); - } catch (const std::exception& e) { - ucxx_debug("Exception calling allocator: %s", e.what()); - } - - auto recvAmMessage = - std::make_shared(amData, ep, req, buf, receiverCallback); - - ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_USER_DATA | - UCP_OP_ATTR_FLAG_NO_IMM_CMPL, - .cb = {.recv_am = _recvCompletedCallback}, - .user_data = recvAmMessage.get()}; - - if (buf == nullptr) { - ucxx_debug("Failed to allocate %lu bytes of memory", length); - recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY); - return UCS_ERR_NO_MEMORY; - } + // Check if delayed receive is requested for rendezvous messages + bool shouldDelayReceive = + amHeader.receiverCallbackInfo && 
amHeader.receiverCallbackInfo->delayReceive; + + if (shouldDelayReceive) { + // For delayed receive: don't allocate buffer, don't call ucp_am_recv_data_nbx, + // store AM data pointer, return UCS_INPROGRESS + auto& amReceiveData = std::get(req->_requestData); + amReceiveData._amData = AmData(data, length); + + if (req->_enablePythonFuture) + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + nullptr, + "amRecv rndv delayed", + "recvCallback, ep: %p, data: %p, size: %lu, future: %p, future handle: %p", + ep, + data, + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + nullptr, + "amRecv rndv delayed", + "recvCallback, ep: %p, data: %p, size: %lu", + ep, + data, + length); + + // Execute receiver callback if present + if (receiverCallback) { receiverCallback(req, ep); } - ucs_status_ptr_t status = - ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam); + return UCS_INPROGRESS; + } else { + // Normal rendezvous receive path + if (amData->_allocators.find(amHeader.memoryType) == amData->_allocators.end()) { + // TODO: Is a hard failure better? 
+ // ucxx_debug("Unsupported memory type %d", amHeader.memoryType); + // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); + // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); + // return UCS_ERR_UNSUPPORTED; + + ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", + amHeader.memoryType); + amHeader.memoryType = UCS_MEMORY_TYPE_HOST; + } - if (req->_enablePythonFuture) - ucxx_trace_req_f(ownerString.c_str(), - req.get(), - status, - "amRecv rndv", - "recvCallback, ep: %p, buffer: %p, size: %lu, future: %p, future handle: %p", - ep, - buf->data(), - length, - req->_future.get(), - req->_future->getHandle()); - else - ucxx_trace_req_f(ownerString.c_str(), - req.get(), - status, - "amRecv rndv", - "recvCallback, ep: %p, buffer: %p, size: %lu", - ep, - buf->data(), - length); + try { + buf = amData->_allocators.at(amHeader.memoryType)(length); + } catch (const std::exception& e) { + ucxx_debug("Exception calling allocator: %s", e.what()); + } - if (req->isCompleted()) { - // The request completed/errored immediately - ucs_status_t s = UCS_PTR_STATUS(status); - recvAmMessage->callback(nullptr, s); + auto recvAmMessage = + std::make_shared(amData, ep, req, buf, receiverCallback); - return s; - } else { - // The request will be handled by the callback - recvAmMessage->setUcpRequest(status); - amData->_registerInflightRequest(req); + ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | + UCP_OP_ATTR_FLAG_NO_IMM_CMPL, + .cb = {.recv_am = _recvCompletedCallback}, + .user_data = recvAmMessage.get()}; - { - std::lock_guard lock(amData->_mutex); - amData->_recvAmMessageMap.emplace(req.get(), recvAmMessage); + if (buf == nullptr) { + ucxx_debug("Failed to allocate %lu bytes of memory", length); + recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY); + return UCS_ERR_NO_MEMORY; } - return UCS_INPROGRESS; + ucs_status_ptr_t status = + 
ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam); + + if (req->_enablePythonFuture) + ucxx_trace_req_f( + ownerString.c_str(), + req.get(), + status, + "amRecv rndv", + "recvCallback, ep: %p, buffer: %p, size: %lu, future: %p, future handle: %p", + ep, + buf->data(), + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + status, + "amRecv rndv", + "recvCallback, ep: %p, buffer: %p, size: %lu", + ep, + buf->data(), + length); + + if (req->isCompleted()) { + // The request completed/errored immediately + ucs_status_t s = UCS_PTR_STATUS(status); + recvAmMessage->callback(nullptr, s); + + return s; + } else { + // The request will be handled by the callback + recvAmMessage->setUcpRequest(status); + amData->_registerInflightRequest(req); + + { + std::lock_guard lock(amData->_mutex); + amData->_recvAmMessageMap.emplace(req.get(), recvAmMessage); + } + + return UCS_INPROGRESS; + } } } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); @@ -381,6 +427,16 @@ std::shared_ptr RequestAm::getRecvBuffer() _requestData); } +std::optional RequestAm::getAmData() +{ + return std::visit( + data::dispatch{ + [](data::AmReceive amReceive) { return amReceive._amData; }, + [](auto) -> std::optional { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + void RequestAm::request() { std::visit( From 0490c1d65a8c8944eb4462f9c02ac2135036e2bd Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 13:12:30 -0700 Subject: [PATCH 02/14] Add tests --- cpp/tests/request.cpp | 150 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index acbffc323..41ab4b312 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -14,6 +14,7 @@ #include #include +#include #include "include/utils.h" #include "ucxx/buffer.h" @@ -264,6 +265,155 @@ TEST_P(RequestTest, 
ProgressAmReceiverCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + // Delayed receive only works with rendezvous messages + if (_messageSize < _rndvThresh) { + GTEST_SKIP() << "Delayed receive only works with rendezvous messages (messageSize >= " + << _rndvThresh << ")"; + } + + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); +#endif + } + + // Define AM receiver callback's owner and id for callback with delayReceive enabled + ucxx::AmReceiverCallbackInfo receiverCallbackInfo("TestApp", 0, true); // delayReceive = true + + // Mutex required for blocking progress mode + std::mutex mutex; + + // Storage for the received request and manual receive completion + std::shared_ptr receivedRequest{nullptr}; + std::shared_ptr manualRecvBuffer{nullptr}; + bool manualReceiveCompleted = false; + ucs_status_t manualReceiveStatus = UCS_OK; + + // Callback to handle completion of manual ucp_am_recv_data_nbx + auto manualRecvCallback = [](void* request, ucs_status_t status, size_t length, void* user_data) { + auto* data = static_cast*>(user_data); + *(std::get<0>(*data)) = true; // manualReceiveCompleted + *(std::get<1>(*data)) = status; // manualReceiveStatus + if (request != nullptr) { ucp_request_free(request); } + }; + + auto callbackUserData = std::make_tuple(&manualReceiveCompleted, &manualReceiveStatus); + + // Define AM receiver callback and register with worker + auto callback = ucxx::AmReceiverCallbackType( + [this, + &receivedRequest, + &manualRecvBuffer, + &manualRecvCallback, + &callbackUserData, + &mutex, + &manualReceiveStatus, + 
&manualReceiveCompleted](std::shared_ptr req, ucp_ep_h) { + { + std::lock_guard lock(mutex); + receivedRequest = req; + + // Cast to RequestAm to access getAmData() method + auto requestAm = std::dynamic_pointer_cast(req); + ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; + + // Get the AM data pointer and length for delayed receive + auto amData = requestAm->getAmData(); + ASSERT_TRUE(amData.has_value()) << "AmData should be available for delayed receive"; + + // Manually allocate buffer for receiving the data + if (_memoryType == UCS_MEMORY_TYPE_HOST) { + manualRecvBuffer = std::make_shared(amData->length); +#if UCXX_ENABLE_RMM + } else if (_memoryType == UCS_MEMORY_TYPE_CUDA) { + manualRecvBuffer = std::make_shared(amData->length); +#endif + } else { + FAIL() << "Unsupported memory type for test"; + } + + // Manually call ucp_am_recv_data_nbx to receive the data + ucp_request_param_t requestParam = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb = {.recv_am = manualRecvCallback}, + .user_data = &callbackUserData}; + + ucs_status_ptr_t status = ucp_am_recv_data_nbx(_worker->getHandle(), + amData->data, + manualRecvBuffer->data(), + amData->length, + &requestParam); + + if (UCS_PTR_IS_ERR(status)) { + manualReceiveStatus = UCS_PTR_STATUS(status); + manualReceiveCompleted = true; + } else if (status == nullptr) { + // Completed immediately + manualReceiveCompleted = true; + manualReceiveStatus = UCS_OK; + } + // else: will be completed by callback + } + }); + _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + + allocate(1, false); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + // Wait for the AM receiver callback to be called + while (receivedRequest == nullptr) + _progressWorker(); + + // Wait for 
the manual receive to complete + while (!manualReceiveCompleted) + _progressWorker(); + + // Verify manual receive completed successfully + ASSERT_EQ(manualReceiveStatus, UCS_OK) << "Manual receive should complete successfully"; + + { + std::lock_guard lock(mutex); + + // Cast to RequestAm to access getAmData() method + auto requestAm = std::dynamic_pointer_cast(receivedRequest); + ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; + + // Verify that the received request has AM data but no regular receive buffer + ASSERT_TRUE(requestAm->getAmData().has_value()) + << "Delayed receive request should have AM data"; + ASSERT_EQ(requestAm->getRecvBuffer(), nullptr) + << "Delayed receive request should not have a receive buffer"; + + // Set up the manually received data for verification + _recvPtr[0] = manualRecvBuffer->data(); + + // Verify buffer type matches expectation + ASSERT_THAT(manualRecvBuffer->getType(), + (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) + ? 
_bufferType + : ucxx::BufferType::Host); + } + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressStream) { allocate(); From 90bc52444a916f104b0fec3564609756a9a6b8d5 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 13:39:04 -0700 Subject: [PATCH 03/14] Change to `receiveData` --- cpp/include/ucxx/request_am.h | 31 +++++++++--- cpp/src/request_am.cpp | 69 ++++++++++++++++++++++++-- cpp/tests/request.cpp | 93 +++++++++++------------------------ 3 files changed, 120 insertions(+), 73 deletions(-) diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index d8bf654ce..e6c61eb20 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -156,16 +156,33 @@ class RequestAm : public Request { [[nodiscard]] std::shared_ptr getRecvBuffer() override; /** - * @brief Get the Active Message data pointer and length for delayed receive operations. + * @brief Receive delayed Active Message data into a user-provided buffer. * - * This method returns the AM data pointer and length that was stored when delayReceive - * was enabled and the message was received via the receive callback. This data can be used - * by the user to manually call ucp_am_recv_data_nbx when they are ready to receive the data. + * This method is used for delayed receive operations where delayReceive was enabled. + * It takes a user-provided buffer and internally calls ucp_am_recv_data_nbx to receive + * the AM data that was stored when the message first arrived. Returns a Request object + * that the user can wait on for completion. * - * @returns The AmData struct containing the void* data pointer and length, or std::nullopt - * if this was not a delayed receive operation. + * @param[in] buffer The buffer to receive the AM data into. Must be large enough to hold + * the AM data (length available via the original AM callback). 
+ * + * @returns A shared_ptr to a Request object that can be waited on for completion. + * Returns nullptr if this was not a delayed receive operation or AM data + * is not available. + */ + [[nodiscard]] std::shared_ptr receiveData(std::shared_ptr buffer); + + /** + * @brief Get the length of delayed Active Message data. + * + * This method returns the length of the AM data that was stored when delayReceive + * was enabled and the message was received via the receive callback. This allows + * users to allocate appropriately sized buffers before calling receiveData(). + * + * @returns The length of the AM data in bytes, or 0 if this was not a delayed + * receive operation or AM data is not available. */ - [[nodiscard]] std::optional getAmData(); + [[nodiscard]] size_t getAmDataLength(); }; } // namespace ucxx diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index a7968923e..bcb0256d0 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -427,16 +427,79 @@ std::shared_ptr RequestAm::getRecvBuffer() _requestData); } -std::optional RequestAm::getAmData() +std::shared_ptr RequestAm::receiveData(std::shared_ptr buffer) { return std::visit( data::dispatch{ - [](data::AmReceive amReceive) { return amReceive._amData; }, - [](auto) -> std::optional { throw std::runtime_error("Unreachable"); }, + [this, &buffer](data::AmReceive amReceive) -> std::shared_ptr { + if (!amReceive._amData.has_value()) { + // No AM data available - not a delayed receive operation + return nullptr; + } + + auto& amData = amReceive._amData.value(); + + // Validate buffer size + if (buffer->getSize() < amData.length) { + throw std::runtime_error("Buffer too small for AM data"); + } + + // Create a new RequestAm for the delayed receive operation + auto request = std::shared_ptr(new RequestAm( + _worker, data::AmReceive(), "amDelayedReceive", _enablePythonFuture, nullptr, nullptr)); + + // Store the buffer in the request + auto& requestAmReceiveData = 
std::get(request->_requestData); + requestAmReceiveData._buffer = buffer; + + // Simple completion callback that directly calls request completion + auto callback = + [](void* ucpRequest, ucs_status_t status, size_t /*length*/, void* userData) { + auto* req = static_cast(userData); + req->callback(ucpRequest, status); + }; + + // Set up UCP request parameters + ucp_request_param_t requestParam = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb = {.recv_am = callback}, + .user_data = request.get()}; + + // Call ucp_am_recv_data_nbx + ucs_status_ptr_t status = ucp_am_recv_data_nbx( + _worker->getHandle(), amData.data, buffer->data(), amData.length, &requestParam); + + if (UCS_PTR_IS_ERR(status)) { + request->callback(nullptr, UCS_PTR_STATUS(status)); + } else if (status == nullptr) { + // Completed immediately + request->callback(nullptr, UCS_OK); + } else { + // Will be completed by callback - store UCP request + request->_request = status; + } + + return request; + }, + [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, }, _requestData); } +size_t RequestAm::getAmDataLength() +{ + return std::visit(data::dispatch{ + [](data::AmReceive amReceive) -> size_t { + if (!amReceive._amData.has_value()) { + return 0; // No AM data available + } + return amReceive._amData.value().length; + }, + [](auto) -> size_t { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + void RequestAm::request() { std::visit( diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 41ab4b312..c14fbe5bb 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -271,12 +271,6 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; } - // Delayed receive only works with rendezvous messages - if (_messageSize < _rndvThresh) { - GTEST_SKIP() << "Delayed receive only works with rendezvous messages (messageSize 
>= " - << _rndvThresh << ")"; - } - if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { #if !UCXX_ENABLE_RMM GTEST_SKIP() << "UCXX was not built with RMM support"; @@ -293,76 +287,45 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) // Mutex required for blocking progress mode std::mutex mutex; - // Storage for the received request and manual receive completion + // Storage for the received request and receive operation std::shared_ptr receivedRequest{nullptr}; std::shared_ptr manualRecvBuffer{nullptr}; - bool manualReceiveCompleted = false; - ucs_status_t manualReceiveStatus = UCS_OK; - - // Callback to handle completion of manual ucp_am_recv_data_nbx - auto manualRecvCallback = [](void* request, ucs_status_t status, size_t length, void* user_data) { - auto* data = static_cast*>(user_data); - *(std::get<0>(*data)) = true; // manualReceiveCompleted - *(std::get<1>(*data)) = status; // manualReceiveStatus - if (request != nullptr) { ucp_request_free(request); } - }; - - auto callbackUserData = std::make_tuple(&manualReceiveCompleted, &manualReceiveStatus); + std::shared_ptr receiveDataRequest{nullptr}; // Define AM receiver callback and register with worker auto callback = ucxx::AmReceiverCallbackType( - [this, - &receivedRequest, - &manualRecvBuffer, - &manualRecvCallback, - &callbackUserData, - &mutex, - &manualReceiveStatus, - &manualReceiveCompleted](std::shared_ptr req, ucp_ep_h) { + [this, &receivedRequest, &manualRecvBuffer, &receiveDataRequest, &mutex]( + std::shared_ptr req, ucp_ep_h) { { std::lock_guard lock(mutex); receivedRequest = req; - // Cast to RequestAm to access getAmData() method + // Cast to RequestAm to access receiveData() method auto requestAm = std::dynamic_pointer_cast(req); ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; - // Get the AM data pointer and length for delayed receive - auto amData = requestAm->getAmData(); - ASSERT_TRUE(amData.has_value()) << "AmData 
should be available for delayed receive"; + // Get the actual message length from the delayed receive data + size_t messageLength = requestAm->getAmDataLength(); + ASSERT_GT(messageLength, 0) + << "AM data length should be greater than 0 for delayed receive"; + ASSERT_EQ(messageLength, _messageSize) + << "AM data length should match the sent message size"; - // Manually allocate buffer for receiving the data + // Allocate buffer based on the actual message length if (_memoryType == UCS_MEMORY_TYPE_HOST) { - manualRecvBuffer = std::make_shared(amData->length); + manualRecvBuffer = std::make_shared(messageLength); #if UCXX_ENABLE_RMM } else if (_memoryType == UCS_MEMORY_TYPE_CUDA) { - manualRecvBuffer = std::make_shared(amData->length); + manualRecvBuffer = std::make_shared(messageLength); #endif } else { FAIL() << "Unsupported memory type for test"; } - // Manually call ucp_am_recv_data_nbx to receive the data - ucp_request_param_t requestParam = { - .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, - .cb = {.recv_am = manualRecvCallback}, - .user_data = &callbackUserData}; - - ucs_status_ptr_t status = ucp_am_recv_data_nbx(_worker->getHandle(), - amData->data, - manualRecvBuffer->data(), - amData->length, - &requestParam); - - if (UCS_PTR_IS_ERR(status)) { - manualReceiveStatus = UCS_PTR_STATUS(status); - manualReceiveCompleted = true; - } else if (status == nullptr) { - // Completed immediately - manualReceiveCompleted = true; - manualReceiveStatus = UCS_OK; - } - // else: will be completed by callback + // Use the new receiveData() API to receive the AM data + receiveDataRequest = requestAm->receiveData(manualRecvBuffer); + ASSERT_NE(receiveDataRequest, nullptr) + << "receiveData should return a valid request for delayed receive"; } }); _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); @@ -378,23 +341,27 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) while (receivedRequest == nullptr) _progressWorker(); 
- // Wait for the manual receive to complete - while (!manualReceiveCompleted) + // Wait for the receiveData request to be created + while (receiveDataRequest == nullptr) + _progressWorker(); + + // Wait for the receive data request to complete + while (!receiveDataRequest->isCompleted()) _progressWorker(); - // Verify manual receive completed successfully - ASSERT_EQ(manualReceiveStatus, UCS_OK) << "Manual receive should complete successfully"; + // Verify receive data request completed successfully + ASSERT_TRUE(receiveDataRequest->isCompleted()) << "Receive data request should be completed"; + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) + << "Receive data request should complete without error"; { std::lock_guard lock(mutex); - // Cast to RequestAm to access getAmData() method + // Cast to RequestAm to verify it's the correct type auto requestAm = std::dynamic_pointer_cast(receivedRequest); ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; - // Verify that the received request has AM data but no regular receive buffer - ASSERT_TRUE(requestAm->getAmData().has_value()) - << "Delayed receive request should have AM data"; + // Verify that the original delayed receive request has no regular receive buffer ASSERT_EQ(requestAm->getRecvBuffer(), nullptr) << "Delayed receive request should not have a receive buffer"; From ca30f43f3e337e8b2238a23f0e29b00a5c2a98b9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 22:04:09 -0700 Subject: [PATCH 04/14] Allow determining whether receive was delayed --- cpp/include/ucxx/request_am.h | 13 ++++ cpp/src/request_am.cpp | 10 +++ cpp/tests/request.cpp | 113 ++++++++++++++++++++++------------ 3 files changed, 98 insertions(+), 38 deletions(-) diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index e6c61eb20..903f739ba 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -183,6 +183,19 @@ class RequestAm : public 
Request { * receive operation or AM data is not available. */ [[nodiscard]] size_t getAmDataLength(); + + /** + * @brief Check if this request uses delayed receive. + * + * This method returns true if this request is a delayed receive operation where + * delayReceive was enabled and AM data is stored for later retrieval. If true, + * users should use getAmDataLength() and receiveData() to retrieve the data. + * If false, users should use getRecvBuffer() to access immediately received data. + * + * @returns True if this is a delayed receive operation with stored AM data, + * false for immediate/eager receives or if no AM data is available. + */ + [[nodiscard]] bool isDelayedReceive(); }; } // namespace ucxx diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index bcb0256d0..e9900cd09 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -500,6 +500,16 @@ size_t RequestAm::getAmDataLength() _requestData); } +bool RequestAm::isDelayedReceive() +{ + return std::visit( + data::dispatch{ + [](data::AmReceive amReceive) -> bool { return amReceive._amData.has_value(); }, + [](auto) -> bool { return false; }, + }, + _requestData); +} + void RequestAm::request() { std::visit( diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index c14fbe5bb..0cb289746 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -304,28 +304,41 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) auto requestAm = std::dynamic_pointer_cast(req); ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; - // Get the actual message length from the delayed receive data - size_t messageLength = requestAm->getAmDataLength(); - ASSERT_GT(messageLength, 0) - << "AM data length should be greater than 0 for delayed receive"; - ASSERT_EQ(messageLength, _messageSize) - << "AM data length should match the sent message size"; - - // Allocate buffer based on the actual message length - if (_memoryType == UCS_MEMORY_TYPE_HOST) { - 
manualRecvBuffer = std::make_shared(messageLength); + // Check if this is a delayed receive operation + if (requestAm->isDelayedReceive()) { + // Delayed receive: use getAmDataLength() and receiveData() API + size_t messageLength = requestAm->getAmDataLength(); + ASSERT_GT(messageLength, 0) + << "AM data length should be greater than 0 for delayed receive"; + ASSERT_EQ(messageLength, _messageSize) + << "AM data length should match the sent message size"; + + // Allocate buffer based on the actual message length + if (_memoryType == UCS_MEMORY_TYPE_HOST) { + manualRecvBuffer = std::make_shared(messageLength); #if UCXX_ENABLE_RMM - } else if (_memoryType == UCS_MEMORY_TYPE_CUDA) { - manualRecvBuffer = std::make_shared(messageLength); + } else if (_memoryType == UCS_MEMORY_TYPE_CUDA) { + manualRecvBuffer = std::make_shared(messageLength); #endif + } else { + FAIL() << "Unsupported memory type for test"; + } + + // Use the new receiveData() API to receive the AM data + receiveDataRequest = requestAm->receiveData(manualRecvBuffer); + ASSERT_NE(receiveDataRequest, nullptr) + << "receiveData should return a valid request for delayed receive"; } else { - FAIL() << "Unsupported memory type for test"; + // Immediate/eager receive: data is already available via getRecvBuffer() + manualRecvBuffer = requestAm->getRecvBuffer(); + ASSERT_NE(manualRecvBuffer, nullptr) + << "getRecvBuffer should return valid buffer for immediate receive"; + ASSERT_EQ(manualRecvBuffer->getSize(), _messageSize) + << "Received buffer size should match sent message size"; + + // For immediate receives, there's no additional request to wait on + receiveDataRequest = nullptr; } - - // Use the new receiveData() API to receive the AM data - receiveDataRequest = requestAm->receiveData(manualRecvBuffer); - ASSERT_NE(receiveDataRequest, nullptr) - << "receiveData should return a valid request for delayed receive"; } }); _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); @@ -341,18 +354,25 @@ 
TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) while (receivedRequest == nullptr) _progressWorker(); - // Wait for the receiveData request to be created - while (receiveDataRequest == nullptr) - _progressWorker(); + // Cast to check the receive type + auto requestAm = std::dynamic_pointer_cast(receivedRequest); + ASSERT_NE(requestAm, nullptr); - // Wait for the receive data request to complete - while (!receiveDataRequest->isCompleted()) - _progressWorker(); + if (requestAm->isDelayedReceive()) { + // For delayed receive: wait for receiveData request to be created and completed + while (receiveDataRequest == nullptr) + _progressWorker(); - // Verify receive data request completed successfully - ASSERT_TRUE(receiveDataRequest->isCompleted()) << "Receive data request should be completed"; - ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) - << "Receive data request should complete without error"; + // Wait for the receive data request to complete + while (!receiveDataRequest->isCompleted()) + _progressWorker(); + + // Verify receive data request completed successfully + ASSERT_TRUE(receiveDataRequest->isCompleted()) << "Receive data request should be completed"; + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) + << "Receive data request should complete without error"; + } + // For immediate receives, no additional waiting needed - data is already available { std::lock_guard lock(mutex); @@ -361,18 +381,35 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) auto requestAm = std::dynamic_pointer_cast(receivedRequest); ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; - // Verify that the original delayed receive request has no regular receive buffer - ASSERT_EQ(requestAm->getRecvBuffer(), nullptr) - << "Delayed receive request should not have a receive buffer"; + if (requestAm->isDelayedReceive()) { + // Verify that the original delayed receive request has no regular receive buffer + 
ASSERT_EQ(requestAm->getRecvBuffer(), nullptr) + << "Delayed receive request should not have a receive buffer"; + + // Verify we have the manually received data + ASSERT_NE(manualRecvBuffer, nullptr) << "Manual receive buffer should be allocated"; + + // Verify buffer type matches expectation for delayed receive + ASSERT_THAT(manualRecvBuffer->getType(), + (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) + ? _bufferType + : ucxx::BufferType::Host); + } else { + // For immediate receives, the buffer should be available from getRecvBuffer() + ASSERT_NE(manualRecvBuffer, nullptr) << "Immediate receive buffer should be available"; + ASSERT_EQ(manualRecvBuffer, requestAm->getRecvBuffer()) + << "Buffer should match the one from getRecvBuffer()"; + + // Verify buffer type matches expectation for immediate receive + ASSERT_THAT(manualRecvBuffer->getType(), + (_registerCustomAmAllocator && _messageSize >= _rndvThresh && + _memoryType == UCS_MEMORY_TYPE_CUDA) + ? _bufferType + : ucxx::BufferType::Host); + } - // Set up the manually received data for verification + // Set up the received data for verification _recvPtr[0] = manualRecvBuffer->data(); - - // Verify buffer type matches expectation - ASSERT_THAT(manualRecvBuffer->getType(), - (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) - ? 
_bufferType - : ucxx::BufferType::Host); } copyResults(); From 83657398a745de6849f22900e706e03146987cce Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 22:12:18 -0700 Subject: [PATCH 05/14] Skip custom memory allocator tests where it isn't used --- cpp/tests/request.cpp | 101 ++++++++++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 28 deletions(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 0cb289746..73759f381 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -33,8 +33,8 @@ using ::testing::Values; typedef std::vector DataContainerType; -class RequestTest : public ::testing::TestWithParam< - std::tuple> { +class RequestTestBase : public ::testing::TestWithParam< + std::tuple> { protected: std::shared_ptr _context{nullptr}; std::shared_ptr _worker{nullptr}; @@ -164,7 +164,11 @@ class RequestTest : public ::testing::TestWithParam< } }; -TEST_P(RequestTest, ProgressAm) +class RequestTestAmAllocator : public RequestTestBase {}; + +class RequestTestNoAmAllocator : public RequestTestBase {}; + +TEST_P(RequestTestAmAllocator, ProgressAm) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -203,7 +207,7 @@ TEST_P(RequestTest, ProgressAm) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressAmReceiverCallback) +TEST_P(RequestTestAmAllocator, ProgressAmReceiverCallback) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -265,13 +269,13 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) +TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not 
possible"; } - if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { + if (_memoryType == UCS_MEMORY_TYPE_CUDA) { #if !UCXX_ENABLE_RMM GTEST_SKIP() << "UCXX was not built with RMM support"; #else @@ -391,9 +395,7 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) // Verify buffer type matches expectation for delayed receive ASSERT_THAT(manualRecvBuffer->getType(), - (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) - ? _bufferType - : ucxx::BufferType::Host); + (_memoryType == UCS_MEMORY_TYPE_CUDA) ? _bufferType : ucxx::BufferType::Host); } else { // For immediate receives, the buffer should be available from getRecvBuffer() ASSERT_NE(manualRecvBuffer, nullptr) << "Immediate receive buffer should be available"; @@ -402,8 +404,7 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) // Verify buffer type matches expectation for immediate receive ASSERT_THAT(manualRecvBuffer->getType(), - (_registerCustomAmAllocator && _messageSize >= _rndvThresh && - _memoryType == UCS_MEMORY_TYPE_CUDA) + (_messageSize >= _rndvThresh && _memoryType == UCS_MEMORY_TYPE_CUDA) ? 
_bufferType : ucxx::BufferType::Host); } @@ -418,7 +419,7 @@ TEST_P(RequestTest, ProgressAmReceiverCallbackDelayedReceive) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressStream) +TEST_P(RequestTestNoAmAllocator, ProgressStream) { allocate(); @@ -439,7 +440,7 @@ TEST_P(RequestTest, ProgressStream) } } -TEST_P(RequestTest, ProgressTag) +TEST_P(RequestTestNoAmAllocator, ProgressTag) { allocate(); @@ -455,7 +456,7 @@ TEST_P(RequestTest, ProgressTag) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressTagMulti) +TEST_P(RequestTestNoAmAllocator, ProgressTagMulti) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -502,7 +503,7 @@ TEST_P(RequestTest, ProgressTagMulti) ASSERT_THAT(_recv[i], ContainerEq(_send[i])); } -TEST_P(RequestTest, TagUserCallback) +TEST_P(RequestTestNoAmAllocator, TagUserCallback) { allocate(); @@ -536,7 +537,7 @@ TEST_P(RequestTest, TagUserCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, TagUserCallbackDiscardReturn) +TEST_P(RequestTestNoAmAllocator, TagUserCallbackDiscardReturn) { allocate(); @@ -577,7 +578,7 @@ TEST_P(RequestTest, TagUserCallbackDiscardReturn) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGet) +TEST_P(RequestTestNoAmAllocator, MemoryGet) { allocate(); @@ -604,7 +605,7 @@ TEST_P(RequestTest, MemoryGet) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGetPreallocated) +TEST_P(RequestTestNoAmAllocator, MemoryGetPreallocated) { allocate(); @@ -627,7 +628,7 @@ TEST_P(RequestTest, MemoryGetPreallocated) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGetWithOffset) +TEST_P(RequestTestNoAmAllocator, MemoryGetWithOffset) { if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; allocate(); @@ -663,7 +664,7 @@ TEST_P(RequestTest, MemoryGetWithOffset) 
ASSERT_THAT(recvOffset, sendOffset); } -TEST_P(RequestTest, MemoryPut) +TEST_P(RequestTestNoAmAllocator, MemoryPut) { allocate(); @@ -690,7 +691,7 @@ TEST_P(RequestTest, MemoryPut) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryPutPreallocated) +TEST_P(RequestTestNoAmAllocator, MemoryPutPreallocated) { allocate(); @@ -713,7 +714,7 @@ TEST_P(RequestTest, MemoryPutPreallocated) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryPutWithOffset) +TEST_P(RequestTestNoAmAllocator, MemoryPutWithOffset) { if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; allocate(); @@ -749,8 +750,9 @@ TEST_P(RequestTest, MemoryPutWithOffset) ASSERT_THAT(recvOffset, sendOffset); } -INSTANTIATE_TEST_SUITE_P(ProgressModes, - RequestTest, +// Tests that support custom AM allocator +INSTANTIATE_TEST_SUITE_P(HostProgressModes, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::Host), Values(false), Values(false), @@ -761,8 +763,8 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes, ProgressMode::ThreadBlocking), Values(0, 1, 1024, 2048, 1048576))); -INSTANTIATE_TEST_SUITE_P(DelayedSubmission, - RequestTest, +INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::Host), Values(false), Values(true), @@ -771,7 +773,7 @@ INSTANTIATE_TEST_SUITE_P(DelayedSubmission, #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, - RequestTest, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::RMM), Values(false, true), Values(false), @@ -783,7 +785,7 @@ INSTANTIATE_TEST_SUITE_P(RMMProgressModes, Values(0, 1, 1024, 2048, 1048576))); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, - RequestTest, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::RMM), Values(false, true), Values(true), @@ -791,4 +793,47 @@ INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, Values(0, 1, 1024, 2048, 1048576))); #endif +// Tests that do NOT support custom AM 
allocator (always false for _registerCustomAmAllocator) +INSTANTIATE_TEST_SUITE_P(HostProgressModes, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::Host), + Values(false), // Never use custom AM allocator for these tests + Values(false), + Values(ProgressMode::Polling, + ProgressMode::Blocking, + // ProgressMode::Wait, // Hangs on Stream + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576))); + +INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::Host), + Values(false), // Never use custom AM allocator for these tests + Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576))); + +#if UCXX_ENABLE_RMM +INSTANTIATE_TEST_SUITE_P(RMMProgressModes, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::RMM), + Values(false), // Never use custom AM allocator for these tests + Values(false), + Values(ProgressMode::Polling, + ProgressMode::Blocking, + // ProgressMode::Wait, // Hangs on Stream + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576))); + +INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::RMM), + Values(false), // Never use custom AM allocator for these tests + Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576))); +#endif + } // namespace From 002fe768d9c85a16204083a0d746e29bf7019108 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 28 Jul 2025 22:38:45 -0700 Subject: [PATCH 06/14] Add descriptions to test name --- cpp/tests/request.cpp | 72 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 73759f381..e211e5f82 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -5,6 
+5,7 @@ #include #include #include +#include #include #include #include @@ -750,6 +751,53 @@ TEST_P(RequestTestNoAmAllocator, MemoryPutWithOffset) ASSERT_THAT(recvOffset, sendOffset); } +// Custom naming function for parameterized tests +std::string generateTestName( + const ::testing::TestParamInfo>& + info) +{ + auto [bufferType, + registerCustomAmAllocator, + enableDelayedSubmission, + progressMode, + messageLength] = info.param; + + std::string name; + + // Buffer type + name += (bufferType == ucxx::BufferType::Host) ? "Host" : "RMM"; + + // Custom AM allocator + if (registerCustomAmAllocator) { name += "_CustomAmAlloc"; } + + // Delayed submission + if (enableDelayedSubmission) { name += "_DelayedSubmission"; } + + // Progress mode + name += "_"; + switch (progressMode) { + case ProgressMode::Polling: name += "Polling"; break; + case ProgressMode::Blocking: name += "Blocking"; break; + case ProgressMode::Wait: name += "Wait"; break; + case ProgressMode::ThreadPolling: name += "ThreadPolling"; break; + case ProgressMode::ThreadBlocking: name += "ThreadBlocking"; break; + } + + // Message length + name += "_Msg"; + if (messageLength == 0) { + name += "Empty"; + } else if (messageLength >= 1048576) { + name += "1MB"; + } else if (messageLength >= 1024) { + name += std::to_string(messageLength / 1024) + "KB"; + } else { + name += std::to_string(messageLength) + "B"; + } + + return name; +} + // Tests that support custom AM allocator INSTANTIATE_TEST_SUITE_P(HostProgressModes, RequestTestAmAllocator, @@ -761,7 +809,8 @@ INSTANTIATE_TEST_SUITE_P(HostProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, RequestTestAmAllocator, @@ -769,7 +818,8 @@ INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, Values(false), Values(true), Values(ProgressMode::ThreadPolling, 
ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, @@ -782,7 +832,8 @@ INSTANTIATE_TEST_SUITE_P(RMMProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, RequestTestAmAllocator, @@ -790,7 +841,8 @@ INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, Values(false, true), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); #endif // Tests that do NOT support custom AM allocator (always false for _registerCustomAmAllocator) @@ -804,7 +856,8 @@ INSTANTIATE_TEST_SUITE_P(HostProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, RequestTestNoAmAllocator, @@ -812,7 +865,8 @@ INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, Values(false), // Never use custom AM allocator for these tests Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, @@ -825,7 +879,8 @@ INSTANTIATE_TEST_SUITE_P(RMMProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, RequestTestNoAmAllocator, @@ -833,7 +888,8 @@ 
INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, Values(false), // Never use custom AM allocator for these tests Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); #endif } // namespace From da80a5bc4ad86211cb496e5613c9ca13c18b2b7f Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 4 Aug 2025 12:49:51 -0700 Subject: [PATCH 07/14] Add support for user header --- cpp/include/ucxx/internal/request_am.h | 21 ++-- cpp/include/ucxx/typedefs.h | 162 +++++++++++++++++++++++-- cpp/src/internal/request_am.cpp | 9 +- cpp/src/request_am.cpp | 53 ++++++-- cpp/tests/request.cpp | 158 +++++++++++++++++++++++- cpp/tests/worker.cpp | 5 +- 6 files changed, 370 insertions(+), 38 deletions(-) diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h index c259d4886..b23f088bf 100644 --- a/cpp/include/ucxx/internal/request_am.h +++ b/cpp/include/ucxx/internal/request_am.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -38,6 +38,8 @@ class RecvAmMessage { std::shared_ptr _request{ nullptr}; ///< Request which will later be notified/delivered to user std::shared_ptr _buffer{nullptr}; ///< Buffer containing the received data + std::optional _receiverCallbackInfo{ + std::nullopt}; ///< Callback info with user header RecvAmMessage() = delete; RecvAmMessage(const RecvAmMessage&) = delete; @@ -50,18 +52,21 @@ class RecvAmMessage { * * Construct the object, setting attributes that are later needed by the callback. * - * @param[in] amData active messages worker data. - * @param[in] ep handle containing address of the reply endpoint (i.e., - endpoint where user is requesting to receive). 
- * @param[in] request request to be later notified/delivered to user. - * @param[in] buffer buffer containing the received data - * @param[in] receiverCallback receiver callback to execute when request completes. + * @param[in] amData active messages worker data. + * @param[in] ep handle containing address of the reply endpoint (i.e., + * endpoint where user is requesting to receive). + * @param[in] request request to be later notified/delivered to user. + * @param[in] buffer buffer containing the received data. + * @param[in] receiverCallback receiver callback to execute when request completes. + * @param[in] receiverCallbackInfo receiver callback info to execute when request completes, + * including user header. */ RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback = AmReceiverCallbackType()); + AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(), + std::optional receiverCallbackInfo = std::nullopt); /** * @brief Set the UCP request. diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 1a74e0a79..234fa9bdd 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -7,8 +7,12 @@ #include #include #include +#include +#include #include #include +#include +#include #include @@ -140,15 +144,6 @@ typedef RequestCallbackUserData EndpointCloseCallbackUserData; */ typedef std::function(size_t)> AmAllocatorType; -/** - * @brief Active Message receiver callback. - * - * Type for a custom Active Message receiver callback, executed by the remote worker upon - * Active Message request completion. The first parameter is the request that completed, - * the second is the handle of the UCX endpoint of the sender. - */ -typedef std::function, ucp_ep_h)> AmReceiverCallbackType; - /** * @brief Active Message receiver callback owner name. 
* @@ -194,6 +189,140 @@ struct AmData { AmData(void* data, size_t length) : data(data), length(length) {} }; +/** + * @brief Container for arbitrary user header data that can be attached to Active Messages. + * + * This class provides a type-safe interface for storing arbitrary user header data that will be + * serialized and transmitted with Active Messages. It supports common data types while + * also allowing direct access to the underlying byte storage for custom serialization. + */ +class AmUserHeader { + private: + std::vector _data; + + public: + AmUserHeader() = delete; + + /** + * @brief Construct AmUserHeader from a byte array. + * + * @param[in] data Pointer to the data to copy. + * @param[in] size Size of the data in bytes. + * + * @throws std::invalid_argument if data is null or if size is 0. + */ + AmUserHeader(const void* data, size_t size) + { + if (size == 0) { throw std::invalid_argument("AmUserHeader size must be greater than zero"); } + if (data == nullptr) { + throw std::invalid_argument("AmUserHeader data pointer cannot be null"); + } + _data.assign(static_cast(data), static_cast(data) + size); + } + + /** + * @brief Construct AmUserHeader from a string. + * + * @param[in] str The string to store. + * + * @throws std::invalid_argument if the string is empty. + */ + explicit AmUserHeader(const std::string& str) : _data(str.begin(), str.end()) + { + if (str.empty()) { throw std::invalid_argument("AmUserHeader string cannot be empty"); } + } + + /** + * @brief Construct AmUserHeader from a vector of bytes. + * + * @param[in] data The byte vector to copy. + * + * @throws std::invalid_argument if the vector is empty. + */ + explicit AmUserHeader(const std::vector& data) : _data(data) + { + if (data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } + } + + /** + * @brief Construct AmUserHeader from a vector of bytes (move constructor). + * + * @param[in] data The byte vector to move. 
+ * + * @throws std::invalid_argument if the vector is empty. + */ + explicit AmUserHeader(std::vector&& data) : _data(std::move(data)) + { + if (_data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } + } + + /** + * @brief Template constructor for POD types. + * + * @param[in] value The POD value to store. + */ + template + explicit AmUserHeader(const T& value) + : _data(reinterpret_cast(&value), + reinterpret_cast(&value) + sizeof(T)) + { + static_assert(std::is_trivially_copyable_v, "Type must be trivially copyable"); + static_assert(sizeof(T) > 0, "Type size must be greater than zero"); + } + + /** + * @brief Get the underlying data as a byte array. + * + * @returns Pointer to the underlying data. + */ + [[nodiscard]] const uint8_t* data() const { return _data.data(); } + + /** + * @brief Get the size of the data in bytes. + * + * @returns Size of the data in bytes. + */ + [[nodiscard]] size_t size() const { return _data.size(); } + + /** + * @brief Check if the user data is empty. + * + * @returns True if no data is stored, false otherwise. + */ + [[nodiscard]] bool empty() const { return _data.empty(); } + + /** + * @brief Get the data as a string. + * + * @returns String representation of the data. + */ + [[nodiscard]] std::string asString() const { return std::string(_data.begin(), _data.end()); } + + /** + * @brief Get the data as a specific POD type. + * + * @returns Reference to the data interpreted as type T. + * @throws std::runtime_error if the size doesn't match. + */ + template + [[nodiscard]] const T& as() const + { + static_assert(std::is_trivially_copyable_v, "Type must be trivially copyable"); + if (_data.size() != sizeof(T)) { + throw std::runtime_error("AmUserHeader size mismatch: expected " + std::to_string(sizeof(T)) + + " bytes, got " + std::to_string(_data.size())); + } + return *reinterpret_cast(_data.data()); + } + + /** + * @brief Get a copy of the underlying byte vector. 
+ * + * @returns Copy of the underlying data vector. + */ + [[nodiscard]] std::vector getBytes() const { return _data; } +}; + /** * @brief Information of an Active Message receiver callback. * @@ -204,6 +333,7 @@ class AmReceiverCallbackInfo { const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback const AmReceiverCallbackIdType id; ///< The unique identifier of the callback const bool delayReceive; ///< Whether to delay receiving data (user-controlled) + const std::optional userHeader; ///< Optional arbitrary user header data AmReceiverCallbackInfo() = delete; @@ -214,12 +344,24 @@ class AmReceiverCallbackInfo { * @param[in] id The unique identifier of the callback. * @param[in] delayReceive Whether to delay receiving data, allowing user to control when * ucp_am_recv_data_nbx is called. + * @param[in] userHeader Optional arbitrary user header data to be transmitted with the AM. */ AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id, - bool delayReceive = false); + bool delayReceive = false, + std::optional userHeader = std::nullopt); }; +/** + * @brief Active Message receiver callback. + * + * Type for a custom Active Message receiver callback, executed by the remote worker upon + * Active Message request completion. The first parameter is the request that completed, + * the second is the handle of the UCX endpoint of the sender. + */ +typedef std::function, ucp_ep_h, AmReceiverCallbackInfo&)> + AmReceiverCallbackType; + /** * @brief Serialized form of a remote key. * diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index d45dee659..5b8ccafaa 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #include @@ -18,8 +18,9 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback) - : _amData(amData), _ep(ep), _request(request) + AmReceiverCallbackType receiverCallback, + std::optional receiverCallbackInfo) + : _amData(amData), _ep(ep), _request(request), _receiverCallbackInfo(receiverCallbackInfo) { std::visit(data::dispatch{ [this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; }, @@ -29,7 +30,7 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, if (receiverCallback) { _request->_callback = [this, receiverCallback](ucs_status_t, std::shared_ptr) { - receiverCallback(_request, _ep); + if (_receiverCallbackInfo) { receiverCallback(_request, _ep, *_receiverCallbackInfo); } }; } } diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 696d704e2..e46fa741c 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -23,8 +24,9 @@ namespace ucxx { AmReceiverCallbackInfo::AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id, - bool delayReceive) - : owner(owner), id(id), delayReceive(delayReceive) + bool delayReceive, + std::optional userHeader) + : owner(owner), id(id), delayReceive(delayReceive), userHeader(userHeader) { } @@ -62,8 +64,23 @@ struct AmHeader { bool delayReceive{false}; decode(&delayReceive, sizeof(delayReceive)); - return AmHeader{.memoryType = memoryType, - .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id, delayReceive)}; + // Check if user data is present + bool hasUserHeader{false}; + decode(&hasUserHeader, sizeof(hasUserHeader)); + + std::optional userHeader = std::nullopt; + if (hasUserHeader) { + size_t userHeaderSize{0}; + decode(&userHeaderSize, sizeof(userHeaderSize)); + + std::vector userHeaderBytes(userHeaderSize); + 
decode(userHeaderBytes.data(), userHeaderSize); + userHeader = AmUserHeader(std::move(userHeaderBytes)); + } + + return AmHeader{ + .memoryType = memoryType, + .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id, delayReceive, userHeader)}; } return AmHeader{.memoryType = memoryType, .receiverCallbackInfo = std::nullopt}; @@ -73,11 +90,17 @@ struct AmHeader { { size_t offset{0}; bool hasReceiverCallback{static_cast(receiverCallbackInfo)}; - const size_t ownerSize = (receiverCallbackInfo) ? receiverCallbackInfo->owner.size() : 0; + const size_t ownerSize = (receiverCallbackInfo) ? receiverCallbackInfo->owner.size() : 0; + const size_t userHeaderSize = (receiverCallbackInfo && receiverCallbackInfo->userHeader) + ? receiverCallbackInfo->userHeader->size() + : 0; + const bool hasUserHeader = (receiverCallbackInfo && receiverCallbackInfo->userHeader); const size_t amReceiverCallbackInfoSize = - (receiverCallbackInfo) ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) + - sizeof(receiverCallbackInfo->delayReceive) - : 0; + (receiverCallbackInfo) + ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) + + sizeof(receiverCallbackInfo->delayReceive) + sizeof(hasUserHeader) + + (hasUserHeader ? 
sizeof(userHeaderSize) + userHeaderSize : 0) + : 0; const size_t totalSize = sizeof(memoryType) + sizeof(hasReceiverCallback) + amReceiverCallbackInfoSize; std::string serialized(totalSize, 0); @@ -94,6 +117,11 @@ struct AmHeader { encode(receiverCallbackInfo->owner.c_str(), ownerSize); encode(&receiverCallbackInfo->id, sizeof(receiverCallbackInfo->id)); encode(&receiverCallbackInfo->delayReceive, sizeof(receiverCallbackInfo->delayReceive)); + encode(&hasUserHeader, sizeof(hasUserHeader)); + if (hasUserHeader) { + encode(&userHeaderSize, sizeof(userHeaderSize)); + encode(receiverCallbackInfo->userHeader->data(), userHeaderSize); + } } return serialized; @@ -302,7 +330,7 @@ ucs_status_t RequestAm::recvCallback(void* arg, length); // Execute receiver callback if present - if (receiverCallback) { receiverCallback(req, ep); } + if (receiverCallback) { receiverCallback(req, ep, *amHeader.receiverCallbackInfo); } return UCS_INPROGRESS; } else { @@ -325,8 +353,8 @@ ucs_status_t RequestAm::recvCallback(void* arg, ucxx_debug("Exception calling allocator: %s", e.what()); } - auto recvAmMessage = - std::make_shared(amData, ep, req, buf, receiverCallback); + auto recvAmMessage = std::make_shared( + amData, ep, req, buf, receiverCallback, amHeader.receiverCallbackInfo); ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | @@ -387,7 +415,8 @@ ucs_status_t RequestAm::recvCallback(void* arg, } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); - internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback); + internal::RecvAmMessage recvAmMessage( + amData, ep, req, buf, receiverCallback, amHeader.receiverCallbackInfo); if (buf == nullptr) { ucxx_debug("Failed to allocate %lu bytes of memory", length); recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY); diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index e211e5f82..72e47145a 100644 --- a/cpp/tests/request.cpp +++ 
b/cpp/tests/request.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -169,6 +170,9 @@ class RequestTestAmAllocator : public RequestTestBase {}; class RequestTestNoAmAllocator : public RequestTestBase {}; +// Limited test suite specifically for user header functionality to reduce test runtime +class RequestTestAmUserHeader : public RequestTestBase {}; + TEST_P(RequestTestAmAllocator, ProgressAm) { if (_progressMode == ProgressMode::Wait) { @@ -234,7 +238,8 @@ TEST_P(RequestTestAmAllocator, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { + [this, &receivedRequests, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); receivedRequests.push_back(req); @@ -300,7 +305,7 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) // Define AM receiver callback and register with worker auto callback = ucxx::AmReceiverCallbackType( [this, &receivedRequest, &manualRecvBuffer, &receiveDataRequest, &mutex]( - std::shared_ptr req, ucp_ep_h) { + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); receivedRequest = req; @@ -420,6 +425,134 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTestAmUserHeader, ProgressAmReceiverCallbackWithUserHeader) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); 
+#endif + } + + // Create user header data - test different data types + std::string userHeaderString = "Hello from user header!"; + ucxx::AmUserHeader userHeader(userHeaderString); + std::optional receivedUserHeader = std::nullopt; + + // Define AM receiver callback's owner and id with user header + ucxx::AmReceiverCallbackInfo receiverCallbackInfo("TestAppWithHeader", 42, false, userHeader); + + // Mutex required for blocking progress mode, otherwise `receivedRequests` may be + // accessed before `push_back()` completed. + std::mutex mutex; + + // Define AM receiver callback and register with worker + std::vector> receivedRequests; + auto callback = ucxx::AmReceiverCallbackType( + [this, &receivedRequests, &mutex, &receivedUserHeader]( + std::shared_ptr req, ucp_ep_h, ucxx::AmReceiverCallbackInfo& info) { + { + std::lock_guard lock(mutex); + receivedRequests.push_back(req); + // auto userHeader = std::move(info.userHeader); + receivedUserHeader = std::move(info.userHeader); + } + }); + + _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + + allocate(1, false); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + while (receivedRequests.size() < 1) + _progressWorker(); + + { + std::lock_guard lock(mutex); + _recvPtr[0] = receivedRequests[0]->getRecvBuffer()->data(); + + // Messages larger than `_rndvThresh` are rendezvous and will use custom allocator, + // smaller messages are eager and will always be host-allocated. + ASSERT_THAT(receivedRequests[0]->getRecvBuffer()->getType(), + (_registerCustomAmAllocator && _messageSize >= _rndvThresh) + ? 
_bufferType + : ucxx::BufferType::Host); + } + + copyResults(); + + // Assert header and data correctness + ASSERT_EQ(receivedUserHeader->asString(), userHeaderString); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST(AmUserHeaderTest, Validation) +{ + // Test that AmUserHeader constructors properly validate input + + // Test valid constructions (should not throw) + ASSERT_NO_THROW([]() { ucxx::AmUserHeader h("test"); }()); + ASSERT_NO_THROW([]() { + std::string s = "test"; + ucxx::AmUserHeader h(s); + }()); + ASSERT_NO_THROW([]() { + std::vector data(3, 0); + ucxx::AmUserHeader h(data); + }()); + ASSERT_NO_THROW([]() { + std::vector data(3, 0); + ucxx::AmUserHeader h(std::move(data)); + }()); + ASSERT_NO_THROW([]() { + int value = 42; + ucxx::AmUserHeader h(value); + }()); + ASSERT_NO_THROW([]() { ucxx::AmUserHeader h("test", 4); }()); + + // Test invalid constructions (should throw) + EXPECT_THROW([]() { ucxx::AmUserHeader h(std::string("")); }(), std::invalid_argument); + EXPECT_THROW( + []() { + std::string empty = ""; + ucxx::AmUserHeader h(empty); + }(), + std::invalid_argument); + EXPECT_THROW( + []() { + std::vector emptyVec; + ucxx::AmUserHeader h(emptyVec); + }(), + std::invalid_argument); + EXPECT_THROW( + []() { + std::vector emptyVec; + ucxx::AmUserHeader h(std::move(emptyVec)); + }(), + std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h(nullptr, 5); }(), std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h("test", 0); }(), std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h("", 0); }(), std::invalid_argument); + + // Test data access works correctly + std::string testStr = "Hello"; + ucxx::AmUserHeader header(testStr); + ASSERT_EQ(header.size(), 5); + ASSERT_EQ(header.asString(), "Hello"); + ASSERT_FALSE(header.empty()); +} + TEST_P(RequestTestNoAmAllocator, ProgressStream) { allocate(); @@ -892,4 +1025,25 @@ INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, generateTestName); #endif +// Limited 
parameter set for user header tests - only test single message size +INSTANTIATE_TEST_SUITE_P(UserHeaderLimited, + RequestTestAmUserHeader, + Combine(Values(ucxx::BufferType::Host), + Values(false), + Values(false), + Values(ProgressMode::Polling, ProgressMode::Blocking), + Values(1024)), // Only test with 1024 byte messages + generateTestName); + +#if UCXX_ENABLE_RMM +INSTANTIATE_TEST_SUITE_P(UserHeaderLimitedRMM, + RequestTestAmUserHeader, + Combine(Values(ucxx::BufferType::RMM), + Values(false), + Values(false), + Values(ProgressMode::Polling, ProgressMode::Blocking), + Values(1024)), // Only test with 1024 byte messages + generateTestName); +#endif + } // namespace diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 9bc6bfbbf..823dc3898 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #include @@ -194,7 +194,8 @@ TEST_P(WorkerProgressTest, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { + [this, &receivedRequests, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); receivedRequests.push_back(req); From d20f401f288bf1bed3361546b457ab5aed359ad9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 5 Aug 2025 12:26:48 -0700 Subject: [PATCH 08/14] Support receiveData with raw pointer --- cpp/include/ucxx/request_am.h | 37 +++++++++++- cpp/include/ucxx/request_data.h | 4 +- cpp/src/request_am.cpp | 103 +++++++++++++++++++++----------- 3 files changed, 106 insertions(+), 38 deletions(-) diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index 903f739ba..bc0a3a111 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -34,6 +34,20 @@ class RequestAm : public Request { std::string _header{}; ///< Retain copy of header for send requests as workaround for ///< https://github.com/openucx/ucx/issues/10424 + /** + * @brief Common implementation for receiveData operations. + * + * @param[in] amData AM data containing pointer and length. + * @param[in] bufferPtr Pointer to the buffer to receive data into. + * @param[in] managedBuffer Optional managed buffer for lifetime management (nullptr for raw + * pointers). + * + * @returns A shared_ptr to a Request object for the delayed receive operation. + */ + std::shared_ptr receiveDataImpl(const AmData& amData, + void* bufferPtr, + Buffer* managedBuffer); + /** * @brief Private constructor of `ucxx::RequestAm`. * @@ -158,9 +172,9 @@ class RequestAm : public Request { /** * @brief Receive delayed Active Message data into a user-provided buffer. 
* - * This method is used for delayed receive operations where delayReceive was enabled. - * It takes a user-provided buffer and internally calls ucp_am_recv_data_nbx to receive - * the AM data that was stored when the message first arrived. Returns a Request object + * This method is used for delayed receive operations where `delayReceive` was enabled. + * It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx` to receive + * the AM data that was stored when the message first arrived. Returns a `Request` object * that the user can wait on for completion. * * @param[in] buffer The buffer to receive the AM data into. Must be large enough to hold @@ -172,6 +186,23 @@ class RequestAm : public Request { */ [[nodiscard]] std::shared_ptr receiveData(std::shared_ptr buffer); + /** + * @brief Receive delayed Active Message data into a user-provided buffer pointer. + * + * This method is used for delayed receive operations where `delayReceive` was enabled. + * It takes a user-provided pointer to buffer and internally calls `ucp_am_recv_data_nbx` + * to receive the AM data that was stored when the message first arrived. Returns a + * `Request` object that the user can wait on for completion. + * + * @param[in] buffer The buffer pointer to receive the AM data into. Must be large enough + * to hold the AM data (length available via the original AM callback). + * + * @returns A shared_ptr to a Request object that can be waited on for completion. + * Returns nullptr if this was not a delayed receive operation or AM data + * is not available. + */ + [[nodiscard]] std::shared_ptr receiveData(void* buffer); + /** * @brief Get the length of delayed Active Message data. 
* diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index 66348a2e4..d700dfcee 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -62,7 +62,9 @@ class AmReceive { public: std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer std::optional<::ucxx::AmData> _amData{ - std::nullopt}; ///< The AM data pointer and length for delayed receiving + std::nullopt}; ///< The AM data pointer and length for delayed receiving + void* _rawBuffer{nullptr}; ///< Raw buffer pointer for delayed receiving + size_t _rawBufferSize{0}; ///< Size of the raw buffer /** * @brief Constructor for Active Message-specific receive data. diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index e46fa741c..08998ad42 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -478,48 +478,83 @@ std::shared_ptr RequestAm::receiveData(std::shared_ptr buffer) throw std::runtime_error("Buffer too small for AM data"); } - // Create a new RequestAm for the delayed receive operation - auto request = std::shared_ptr(new RequestAm( - _worker, data::AmReceive(), "amDelayedReceive", _enablePythonFuture, nullptr, nullptr)); - - // Store the buffer in the request - auto& requestAmReceiveData = std::get(request->_requestData); - requestAmReceiveData._buffer = buffer; - - // Simple completion callback that directly calls request completion - auto callback = - [](void* ucpRequest, ucs_status_t status, size_t /*length*/, void* userData) { - auto* req = static_cast(userData); - req->callback(ucpRequest, status); - }; - - // Set up UCP request parameters - ucp_request_param_t requestParam = { - .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, - .cb = {.recv_am = callback}, - .user_data = request.get()}; - - // Call ucp_am_recv_data_nbx - ucs_status_ptr_t status = ucp_am_recv_data_nbx( - _worker->getHandle(), amData.data, buffer->data(), amData.length, &requestParam); - - if 
(UCS_PTR_IS_ERR(status)) { - request->callback(nullptr, UCS_PTR_STATUS(status)); - } else if (status == nullptr) { - // Completed immediately - request->callback(nullptr, UCS_OK); - } else { - // Will be completed by callback - store UCP request - request->_request = status; + return receiveDataImpl(amData, buffer->data(), buffer.get()); + }, + [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + +std::shared_ptr RequestAm::receiveData(void* data) +{ + return std::visit( + data::dispatch{ + [this, data](data::AmReceive amReceive) -> std::shared_ptr { + if (!amReceive._amData.has_value()) { + // No AM data available - not a delayed receive operation + return nullptr; } - return request; + if (data == nullptr) { throw std::runtime_error("Buffer pointer cannot be null"); } + + auto& amData = amReceive._amData.value(); + return receiveDataImpl(amData, data, nullptr); }, [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, }, _requestData); } +std::shared_ptr RequestAm::receiveDataImpl(const AmData& amData, + void* bufferPtr, + Buffer* managedBuffer) +{ + // Create a new RequestAm for the delayed receive operation + auto request = std::shared_ptr(new RequestAm( + _worker, data::AmReceive(), "amDelayedReceive", _enablePythonFuture, nullptr, nullptr)); + + // Store the buffer information in the request + auto& requestAmReceiveData = std::get(request->_requestData); + if (managedBuffer) { + // For shared_ptr case, store the managed buffer + requestAmReceiveData._buffer = std::shared_ptr(managedBuffer, [](Buffer*) { + // Custom deleter that does nothing - the original shared_ptr manages the lifetime + }); + } else { + // For raw pointer case, store the raw buffer info + requestAmReceiveData._rawBuffer = bufferPtr; + requestAmReceiveData._rawBufferSize = amData.length; + } + + // Simple completion callback that directly calls request completion + auto callback = [](void* ucpRequest, ucs_status_t status, 
size_t /*length*/, void* userData) { + auto* req = static_cast(userData); + req->callback(ucpRequest, status); + }; + + // Set up UCP request parameters + ucp_request_param_t requestParam = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb = {.recv_am = callback}, + .user_data = request.get()}; + + // Call ucp_am_recv_data_nbx + ucs_status_ptr_t status = ucp_am_recv_data_nbx( + _worker->getHandle(), amData.data, bufferPtr, amData.length, &requestParam); + + if (UCS_PTR_IS_ERR(status)) { + request->callback(nullptr, UCS_PTR_STATUS(status)); + } else if (status == nullptr) { + // Completed immediately + request->callback(nullptr, UCS_OK); + } else { + // Will be completed by callback - store UCP request + request->_request = status; + } + + return request; +} + size_t RequestAm::getAmDataLength() { return std::visit(data::dispatch{ From 8bc98ae3f286f22dd23a6a0748ecc68310f2f1f0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 6 Aug 2025 13:00:44 -0700 Subject: [PATCH 09/14] Add logging when receiving wrong data --- cpp/src/request_am.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 08998ad42..d89d25a9c 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -491,7 +491,13 @@ std::shared_ptr RequestAm::receiveData(void* data) data::dispatch{ [this, data](data::AmReceive amReceive) -> std::shared_ptr { if (!amReceive._amData.has_value()) { - // No AM data available - not a delayed receive operation + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "receiveData, data %p", + data); + ucxx_warn("No AM data available, not a delayed receive operation"); return nullptr; } From 5f4d24dcd2e850fb030e45d321c510f3a5f7ec53 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 6 Aug 2025 13:01:34 -0700 Subject: [PATCH 10/14] Add receiveData with raw pointer tests --- cpp/tests/request.cpp | 
109 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 3 deletions(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 72e47145a..551e4460e 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -300,11 +300,12 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) // Storage for the received request and receive operation std::shared_ptr receivedRequest{nullptr}; std::shared_ptr manualRecvBuffer{nullptr}; + std::unique_ptr rawBuffer{nullptr}; std::shared_ptr receiveDataRequest{nullptr}; // Define AM receiver callback and register with worker auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequest, &manualRecvBuffer, &receiveDataRequest, &mutex]( + [this, &receivedRequest, &manualRecvBuffer, &rawBuffer, &receiveDataRequest, &mutex]( std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); @@ -334,10 +335,29 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) FAIL() << "Unsupported memory type for test"; } - // Use the new receiveData() API to receive the AM data + // Test both receiveData APIs: managed buffer and raw pointer + // First test null pointer validation for raw pointer API + EXPECT_THROW(requestAm->receiveData(nullptr), std::runtime_error); + + // Use the managed buffer receiveData() API receiveDataRequest = requestAm->receiveData(manualRecvBuffer); ASSERT_NE(receiveDataRequest, nullptr) - << "receiveData should return a valid request for delayed receive"; + << "receiveData with managed buffer should return a valid request for delayed receive"; + + // Wait for the managed buffer receive to complete + while (!receiveDataRequest->isCompleted()) { + _progressWorker(); + } + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) + << "Managed buffer receive should complete successfully"; + + // Now test the raw pointer API with a separate buffer + // Note: For delayed receive, we can only receive the 
data once per AM message, + // so we allocate a raw buffer to validate the API but copy from managed buffer + rawBuffer = std::make_unique(messageLength); + + // Copy data from managed buffer to raw buffer to simulate raw pointer receive + std::memcpy(rawBuffer.get(), manualRecvBuffer->data(), messageLength); } else { // Immediate/eager receive: data is already available via getRecvBuffer() manualRecvBuffer = requestAm->getRecvBuffer(); @@ -398,10 +418,15 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) // Verify we have the manually received data ASSERT_NE(manualRecvBuffer, nullptr) << "Manual receive buffer should be allocated"; + ASSERT_NE(rawBuffer, nullptr) << "Raw buffer should be allocated for testing"; // Verify buffer type matches expectation for delayed receive ASSERT_THAT(manualRecvBuffer->getType(), (_memoryType == UCS_MEMORY_TYPE_CUDA) ? _bufferType : ucxx::BufferType::Host); + + // Verify that both buffers contain the same data + ASSERT_EQ(std::memcmp(manualRecvBuffer->data(), rawBuffer.get(), _messageSize), 0) + << "Managed buffer and raw buffer should contain identical data"; } else { // For immediate receives, the buffer should be available from getRecvBuffer() ASSERT_NE(manualRecvBuffer, nullptr) << "Immediate receive buffer should be available"; @@ -425,6 +450,84 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceiveRawPointer) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); +#endif + } + + // Only test with larger messages to ensure 
rendezvous/delayed receive + if (_messageSize < _rndvThresh) { + GTEST_SKIP() << "Test only runs with rendezvous messages for delayed receive"; + } + + // Define AM receiver callback with delayReceive enabled + ucxx::AmReceiverCallbackInfo receiverCallbackInfo("TestAppRaw", 0, true); + + std::mutex mutex; + std::shared_ptr receivedRequest{nullptr}; + std::unique_ptr rawBuffer{nullptr}; + std::shared_ptr receiveDataRequest{nullptr}; + + auto callback = ucxx::AmReceiverCallbackType( + [this, &receivedRequest, &rawBuffer, &receiveDataRequest, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { + { + std::lock_guard lock(mutex); + receivedRequest = req; + + auto requestAm = std::dynamic_pointer_cast(req); + ASSERT_NE(requestAm, nullptr); + + if (requestAm->isDelayedReceive()) { + size_t messageLength = requestAm->getAmDataLength(); + ASSERT_EQ(messageLength, _messageSize); + + // Allocate raw buffer and test the raw pointer receiveData API + rawBuffer = std::make_unique(messageLength); + receiveDataRequest = requestAm->receiveData(rawBuffer.get()); + ASSERT_NE(receiveDataRequest, nullptr); + } else { + FAIL() << "Expected delayed receive operation"; + } + } + }); + + _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + allocate(1, false); + + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + while (receivedRequest == nullptr) + _progressWorker(); + while (receiveDataRequest == nullptr) + _progressWorker(); + while (!receiveDataRequest->isCompleted()) + _progressWorker(); + + { + std::lock_guard lock(mutex); + ASSERT_TRUE(receiveDataRequest->isCompleted()); + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK); + _recvPtr[0] = rawBuffer.get(); + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTestAmUserHeader, ProgressAmReceiverCallbackWithUserHeader) { if 
(_progressMode == ProgressMode::Wait) { From d9a04eead98bc682024e9bd60ac02b485d240374 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 6 Aug 2025 13:07:20 -0700 Subject: [PATCH 11/14] Improve AmData --- cpp/include/ucxx/typedefs.h | 42 ++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 234fa9bdd..4598a47ac 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -173,6 +173,13 @@ typedef const std::string AmReceiverCallbackInfoSerialized; * * Structure containing the Active Message data pointer and length that will be used * when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually. + * + * @note This struct holds a non-owning pointer to Active Message data. The data pointer + * must remain valid for the lifetime of this AmData object. The caller is responsible + * for managing the lifetime of the data and ensuring it is not freed while this + * AmData object is in use. Copying is disabled to prevent potential use-after-free + * or double-free errors since multiple AmData objects sharing the same data pointer + * could lead to undefined behavior. */ struct AmData { void* data; ///< The Active Message data pointer from the receive callback @@ -185,8 +192,41 @@ struct AmData { * * @param[in] data The Active Message data pointer from the receive callback. * @param[in] length The length of the Active Message data. + * + * @throws std::invalid_argument if length is greater than zero but data is nullptr. 
+ */ + AmData(void* data, size_t length) : data(data), length(length) + { + if (length > 0 && data == nullptr) { + throw std::invalid_argument( + "AmData data pointer cannot be null when length is greater than zero"); + } + } + + // Delete copy operations to prevent unsafe sharing of non-owning pointer + AmData(const AmData&) = delete; + AmData& operator=(const AmData&) = delete; + + /** + * @brief Move constructor. + * + * Transfers ownership of the data pointer from another AmData object. This is safe + * since the source object will no longer reference the data after the move. + * + * @param[in] other The AmData object to move from. + */ + AmData(AmData&& other) = default; + + /** + * @brief Move assignment operator. + * + * Transfers ownership of the data pointer from another AmData object. This is safe + * since the source object will no longer reference the data after the move. + * + * @param[in] other The AmData object to move from. + * @return Reference to this AmData object. */ - AmData(void* data, size_t length) : data(data), length(length) {} + AmData& operator=(AmData&& other) = default; }; /** From 04eded785173862fb20b2c075e984de8109d1c2a Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 11 Aug 2025 16:40:31 +0200 Subject: [PATCH 12/14] Revert "Improve AmData" This reverts commit d9a04eead98bc682024e9bd60ac02b485d240374. --- cpp/include/ucxx/typedefs.h | 42 +------------------------------------ 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 4598a47ac..234fa9bdd 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -173,13 +173,6 @@ typedef const std::string AmReceiverCallbackInfoSerialized; * * Structure containing the Active Message data pointer and length that will be used * when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually. - * - * @note This struct holds a non-owning pointer to Active Message data. 
The data pointer - * must remain valid for the lifetime of this AmData object. The caller is responsible - * for managing the lifetime of the data and ensuring it is not freed while this - * AmData object is in use. Copying is disabled to prevent potential use-after-free - * or double-free errors since multiple AmData objects sharing the same data pointer - * could lead to undefined behavior. */ struct AmData { void* data; ///< The Active Message data pointer from the receive callback @@ -192,41 +185,8 @@ struct AmData { * * @param[in] data The Active Message data pointer from the receive callback. * @param[in] length The length of the Active Message data. - * - * @throws std::invalid_argument if length is greater than zero but data is nullptr. - */ - AmData(void* data, size_t length) : data(data), length(length) - { - if (length > 0 && data == nullptr) { - throw std::invalid_argument( - "AmData data pointer cannot be null when length is greater than zero"); - } - } - - // Delete copy operations to prevent unsafe sharing of non-owning pointer - AmData(const AmData&) = delete; - AmData& operator=(const AmData&) = delete; - - /** - * @brief Move constructor. - * - * Transfers ownership of the data pointer from another AmData object. This is safe - * since the source object will no longer reference the data after the move. - * - * @param[in] other The AmData object to move from. - */ - AmData(AmData&& other) = default; - - /** - * @brief Move assignment operator. - * - * Transfers ownership of the data pointer from another AmData object. This is safe - * since the source object will no longer reference the data after the move. - * - * @param[in] other The AmData object to move from. - * @return Reference to this AmData object. 
*/ - AmData& operator=(AmData&& other) = default; + AmData(void* data, size_t length) : data(data), length(length) {} }; /** From 7567b11dde6dd81db42ca04d36990e4ad160059d Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 11 Aug 2025 16:42:11 +0200 Subject: [PATCH 13/14] Fix build warning --- cpp/tests/request.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 551e4460e..b921842cd 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -337,7 +337,7 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) // Test both receiveData APIs: managed buffer and raw pointer // First test null pointer validation for raw pointer API - EXPECT_THROW(requestAm->receiveData(nullptr), std::runtime_error); + EXPECT_THROW(std::ignore = requestAm->receiveData(nullptr), std::runtime_error); // Use the managed buffer receiveData() API receiveDataRequest = requestAm->receiveData(manualRecvBuffer); From 071066870a4db3605c59a0286a767216a0d4ceaa Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 11 Aug 2025 18:43:07 +0200 Subject: [PATCH 14/14] Fix test, progressing from within a callback is not allowed --- cpp/tests/request.cpp | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index b921842cd..6a4a92b9a 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -343,21 +343,6 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) receiveDataRequest = requestAm->receiveData(manualRecvBuffer); ASSERT_NE(receiveDataRequest, nullptr) << "receiveData with managed buffer should return a valid request for delayed receive"; - - // Wait for the managed buffer receive to complete - while (!receiveDataRequest->isCompleted()) { - _progressWorker(); - } - ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) - << "Managed buffer receive should complete successfully"; - - // Now 
test the raw pointer API with a separate buffer - // Note: For delayed receive, we can only receive the data once per AM message, - // so we allocate a raw buffer to validate the API but copy from managed buffer - rawBuffer = std::make_unique(messageLength); - - // Copy data from managed buffer to raw buffer to simulate raw pointer receive - std::memcpy(rawBuffer.get(), manualRecvBuffer->data(), messageLength); } else { // Immediate/eager receive: data is already available via getRecvBuffer() manualRecvBuffer = requestAm->getRecvBuffer(); @@ -418,15 +403,10 @@ TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) // Verify we have the manually received data ASSERT_NE(manualRecvBuffer, nullptr) << "Manual receive buffer should be allocated"; - ASSERT_NE(rawBuffer, nullptr) << "Raw buffer should be allocated for testing"; // Verify buffer type matches expectation for delayed receive ASSERT_THAT(manualRecvBuffer->getType(), (_memoryType == UCS_MEMORY_TYPE_CUDA) ? _bufferType : ucxx::BufferType::Host); - - // Verify that both buffers contain the same data - ASSERT_EQ(std::memcmp(manualRecvBuffer->data(), rawBuffer.get(), _messageSize), 0) - << "Managed buffer and raw buffer should contain identical data"; } else { // For immediate receives, the buffer should be available from getRecvBuffer() ASSERT_NE(manualRecvBuffer, nullptr) << "Immediate receive buffer should be available";