From 807f1323a62ea5826091a2a1d173a9b6bb8db610 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 19 May 2026 14:53:02 +0200 Subject: [PATCH] Add support to query UCP debug info from requests (#437) Add an opt-in mechanism for querying UCP request attributes (memory type, debug string) on every `ucxx::Request`. Enabled via a new `ucxx::experimental::WorkerBuilder::requestAttributes(true)` option. When enabled, every `ucxx::Request` submit site funnels through a small `Request::publishRequest()` helper that stores the UCP handle and queries `ucp_request_query` under the existing `_mutex`. `ucp_request_free` moves from `Request::callback` into `Request::setStatus`, making the query and the free mutually exclusive without any new atomics or callback-side locking. Wired through every request type and exposed to users via `Request::queryAttributes()`, which throws `ucxx::UnsupportedError` when the feature is disabled on the owning worker and `ucxx::NoElemError` when UCX took an inline path that produced no UCP handle to query (e.g., an eager UCX transfer). Tag, AM, and MemoryGet test are asserted strictly above the rendezvous threshold, where UCX deterministically allocates a queryable request on every transport. Stream and MemoryPut use lenient assertions (substring-check on success, throw is acceptable) because stream has no rendezvous protocol and small RMA puts are fire-and-forget, both of which have transport-dependent inline-completion behavior that no fixed size threshold can portably predict. The toggle is worker-scoped: enabling it queries attributes on **every** request created from that worker, which has potentially non-negligible per-request cost. Fine-grained per-request opt-in (so callers can attribute-query only the requests they care about) is not implemented here, it requires a builder-pattern constructor at the request level which doesn't exist yet, and is deferred to a follow-up. For now, users who need attributes accept the worker-wide cost, and users who don't, opt out by leaving the default. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/ucxx/pull/437 --- ci/run_cpp.sh | 2 +- .../ucxx/experimental/worker_builder.h | 13 + cpp/include/ucxx/internal/request_am.h | 9 - cpp/include/ucxx/request.h | 54 +++ cpp/include/ucxx/worker.h | 43 ++- cpp/src/experimental/worker_builder.cpp | 8 + cpp/src/internal/request_am.cpp | 2 - cpp/src/request.cpp | 67 +++- cpp/src/request_am.cpp | 5 +- cpp/src/request_endpoint_close.cpp | 5 +- cpp/src/request_flush.cpp | 5 +- cpp/src/request_mem.cpp | 5 +- cpp/src/request_stream.cpp | 5 +- cpp/src/request_tag.cpp | 5 +- cpp/src/worker.cpp | 19 ++ cpp/tests/request.cpp | 313 ++++++++++++++++-- cpp/tests/worker.cpp | 44 +++ 17 files changed, 547 insertions(+), 57 deletions(-) diff --git a/ci/run_cpp.sh b/ci/run_cpp.sh index 0d0a4ab14..bf725e8f6 100755 --- a/ci/run_cpp.sh +++ b/ci/run_cpp.sh @@ -45,7 +45,7 @@ else fi run_cpp_tests() { - CMD_LINE="python ${TIMEOUT_TOOL_PATH} $((10*60)) ${GTESTS_PATH}/UCXX_TEST" + CMD_LINE="python ${TIMEOUT_TOOL_PATH} $((20*60)) ${GTESTS_PATH}/UCXX_TEST" log_command "${CMD_LINE}" UCX_TCP_CM_REUSEADDR=y ${CMD_LINE} diff --git a/cpp/include/ucxx/experimental/worker_builder.h b/cpp/include/ucxx/experimental/worker_builder.h index b8f9115be..1a9596cdd 100644 --- a/cpp/include/ucxx/experimental/worker_builder.h +++ b/cpp/include/ucxx/experimental/worker_builder.h @@ -90,6 +90,19 @@ class WorkerBuilder final { */ WorkerBuilder& pythonFuture(bool enable = true); + /** + * @brief Configure request attributes querying. + * + * When enabled, each `ucxx::Request` created from the worker will have its UCP + * attributes (such as the debug string) queried immediately after submission, making + * them available via `ucxx::Request::getRequestAttributes()`. This may have + * non-negligible runtime cost and is therefore disabled by default. + * + * @param[in] enable whether request attributes querying is enabled (default: true). + * @return Reference to this builder for method chaining. + */ + WorkerBuilder& requestAttributes(bool enable = true); + /** * @brief Configure the preferred buffer type for CUDA allocations. * diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h index fe3e8ac49..37d7dd972 100644 --- a/cpp/include/ucxx/internal/request_am.h +++ b/cpp/include/ucxx/internal/request_am.h @@ -67,15 +67,6 @@ class RecvAmMessage { AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(), std::vector userHeader = {}); - /** - * @brief Set the UCP request. - * - * Set the underlying UCP request (`_request` attribute) of the `RequestAm`. - * - * @param[in] request the UCP request associated to the active message receive operation. - */ - void setUcpRequest(void* request); - /** * @brief Execute the `ucxx::Request::callback()`. * diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 6ad7607bd..c718ebca4 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -37,6 +38,14 @@ namespace ucxx { */ class Request : public Component { protected: + /** + * @brief Request attributes reported by `ucp_request_query`. + */ + struct Attributes { + ucs_memory_type memoryType{UCS_MEMORY_TYPE_UNKNOWN}; ///< Memory type of the request + std::string debugString{}; ///< Stored debug string + }; + ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status std::string _status_msg{}; ///< Human-readable status message void* _request{nullptr}; ///< Pointer to UCP request @@ -54,6 +63,9 @@ class Request : public Component { bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data + Attributes _requestAttr{}; ///< Request attributes queried when request is posted; the + ///< default `memoryType == UCS_MEMORY_TYPE_UNKNOWN` doubles + ///< as the "not populated yet" sentinel /** * @brief Protected constructor of an abstract `ucxx::Request`. @@ -235,6 +247,48 @@ class Request : public Component { * @return The received user header (if applicable) or an empty string. */ [[nodiscard]] virtual std::string getRecvHeader(); + + /** + * @brief Get the requests's attributes. + * + * Returns the request attributes as a struct. The owning `ucxx::Worker` must have been + * created with request attributes querying enabled (see + * `ucxx::experimental::WorkerBuilder::requestAttributes()`); otherwise the attributes + * are never populated and this method throws. Querying the underlying UCP request is + * an implementation detail performed eagerly when the request is submitted. All + * non-status fields exposed by UCP are queried, use `getStatus()` to obtain the status. + * + * @throw ucxx::UnsupportedError if the owning worker was not built with request + * attributes querying enabled. Requires `Worker` + * created with + * `ucxx::experimental::WorkerBuilder::requestAttributes(true)`. + * @throw ucxx::NoElemError if attributes are unavailable for this specific + * request: either because UCX took an inline-completion + * path that produced no UCP request to query, or because + * the request has not completed yet. Callers can + * distinguish the latter from the former by checking + * `isCompleted()`. + * + * @return An `Attributes` containing the request attributes. + */ + [[nodiscard]] Attributes queryAttributes(); + + protected: + /** + * @brief Publish the UCP request handle and capture its attributes. + * + * Single critical section that stores the UCP request pointer in `_request` and, when + * the owning worker has request attributes querying enabled, immediately queries those + * attributes. The completion path frees the UCP request inside `setStatus` under the + * same `_mutex`, so this helper guarantees the query and the free are mutually + * exclusive and that there are no use-after-free in threaded progress modes. + * + * Every submit site calls this after obtaining the request handle from the corresponding + * `ucp_*_nbx` function. + * + * @param[in] request the UCP request pointer returned by a non-blocking submit. + */ + void publishRequest(void* request); }; } // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index 661e4eb46..20e2705f6 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -79,6 +79,8 @@ class Worker : public Component { protected: bool _enableFuture{ false}; ///< Boolean identifying whether the worker was created with future capability + bool _enableRequestAttributes{ + false}; ///< Whether request attributes (e.g. UCP debug info) are queried for each request std::mutex _futuresPoolMutex{}; ///< Mutex to access the futures pool std::queue> _futuresPool{}; ///< Futures pool to prevent running out of fresh futures @@ -507,6 +509,19 @@ class Worker : public Component { */ [[nodiscard]] bool isFutureEnabled() const; + /** + * @brief Inquire if worker has been created with request attributes querying enabled. + * + * Check whether the worker has been created with request attributes querying enabled. + * When enabled, each `ucxx::Request` will have its UCP attributes (such as the debug + * string) queried immediately after submission, making them available via + * `ucxx::Request::queryAttributes()`. Querying request attributes has a + * non-negligible runtime cost and is therefore disabled by default. + * + * @returns `true` if request attributes querying is enabled, `false` otherwise. + */ + [[nodiscard]] bool isRequestAttributesEnabled() const noexcept; + /** * @brief Get the preferred buffer type for CUDA allocations. * @@ -1000,7 +1015,7 @@ class Worker : public Component { * * Using a Python future may be requested by specifying `enablePythonFuture`. If a * Python future is requested, the Python application must then await on this future to - * ensure the transfer has completed. Requires UCXX Python support. + * ensure the transfer has completed. * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller @@ -1020,6 +1035,32 @@ class Worker : public Component { const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Worker attributes reported by `ucp_worker_query`. + */ + struct Attributes { + /// Thread safety level the worker was created with. + ucs_thread_mode_t threadMode{UCS_THREAD_MODE_MULTI}; + /// Maximum allowed header size for `ucp_am_send_nbx`. + size_t maxAmHeader{0}; + /// Worker name used by tracing and analysis tools. + std::string name{}; + /// Maximum debug-string buffer size accepted by `ucp_request_query`. + size_t maxDebugString{0}; + }; + + /** + * @brief Get the worker's attributes. + * + * Returns the worker attributes as a struct, querying UCP via `ucp_worker_query` under + * the hood. All non-address fields exposed by UCP are queried, use `getAddress()` to + * obtain the address. + * + * @returns An `Attributes` filled with all queried fields. + * @throws ucxx::Error if an error occurred while querying worker attributes. + */ + [[nodiscard]] Attributes queryAttributes() const; }; /** diff --git a/cpp/src/experimental/worker_builder.cpp b/cpp/src/experimental/worker_builder.cpp index a8c780aa2..1efd0994f 100644 --- a/cpp/src/experimental/worker_builder.cpp +++ b/cpp/src/experimental/worker_builder.cpp @@ -17,6 +17,7 @@ struct WorkerBuilder::Impl { std::shared_ptr context; bool enableDelayedSubmission{false}; bool enableFuture{false}; + bool enableRequestAttributes{false}; BufferType cudaBufferType{BufferType::Invalid}; explicit Impl(std::shared_ptr ctx) : context(std::move(ctx)) {} @@ -41,6 +42,12 @@ WorkerBuilder& WorkerBuilder::pythonFuture(bool enable) return *this; } +WorkerBuilder& WorkerBuilder::requestAttributes(bool enable) +{ + _impl->enableRequestAttributes = enable; + return *this; +} + WorkerBuilder& WorkerBuilder::cudaBufferType(BufferType bufferType) { _impl->cudaBufferType = bufferType; @@ -51,6 +58,7 @@ std::shared_ptr WorkerBuilder::build() const { auto worker = ucxx::createWorker(_impl->context, _impl->enableDelayedSubmission, _impl->enableFuture); + worker->_enableRequestAttributes = _impl->enableRequestAttributes; if (_impl->cudaBufferType != BufferType::Invalid) worker->setCudaBufferType(_impl->cudaBufferType); return worker; diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index e2eb65c25..6fbe7d922 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -40,8 +40,6 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, } } -void RecvAmMessage::setUcpRequest(void* request) { _request->_request = request; } - void RecvAmMessage::callback(void* request, ucs_status_t status) { std::visit(data::dispatch{ diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 1551e4f23..05eb48f58 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include #include #include #include @@ -140,7 +141,7 @@ void Request::callback(void* request, ucs_status_t status) if (_status != UCS_INPROGRESS) ucxx_trace_req_f(_ownerString.c_str(), this, - _request, + request, _operationName.c_str(), "has status already set to %d (%s), callback setting %d (%s)", _status, @@ -148,12 +149,10 @@ void Request::callback(void* request, ucs_status_t status) status, ucs_status_string(status)); - if (UCS_PTR_IS_PTR(_request)) ucp_request_free(request); - - ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "completed"); + ucxx_trace_req_f(_ownerString.c_str(), this, request, _operationName.c_str(), "completed"); setStatus(status); ucxx_trace_req_f( - _ownerString.c_str(), this, _request, _operationName.c_str(), "isCompleted: %d", isCompleted()); + _ownerString.c_str(), this, request, _operationName.c_str(), "isCompleted: %d", isCompleted()); } void Request::process() @@ -235,11 +234,69 @@ void Request::setStatus(ucs_status_t status) _ownerString.c_str(), this, _request, _operationName.c_str(), "invoking user callback"); _callback(status, _callbackData); } + + // Free the UCP request inside the lock so it is mutually exclusive with + // `publishRequest()`/`queryRequestAttributes()` on the submit thread. + if (UCS_PTR_IS_PTR(_request)) { + ucp_request_free(_request); + _request = nullptr; + } } } const std::string& Request::getOwnerString() const { return _ownerString; } +void Request::publishRequest(void* request) +{ + if (!_worker->isRequestAttributesEnabled()) { + std::lock_guard lock(_mutex); + _request = request; + return; + } + + std::lock_guard lock(_mutex); + _request = request; + + if (_requestAttr.memoryType != UCS_MEMORY_TYPE_UNKNOWN) return; + + ucp_request_attr_t result; + + auto worker_attr = _worker->queryAttributes(); + + std::string debugString(worker_attr.maxDebugString, '\0'); + + result.field_mask = UCP_REQUEST_ATTR_FIELD_MEM_TYPE | UCP_REQUEST_ATTR_FIELD_INFO_STRING | + UCP_REQUEST_ATTR_FIELD_INFO_STRING_SIZE; + + result.debug_string = debugString.data(); + result.debug_string_size = debugString.size(); + + if (UCS_PTR_IS_PTR(_request)) { + auto queryStatus = ucp_request_query(_request, &result); + if (queryStatus == UCS_OK && result.debug_string != nullptr) { + debugString.resize(std::strlen(debugString.c_str())); + _requestAttr.debugString = std::move(debugString); + _requestAttr.memoryType = result.mem_type; + } + } +} + +Request::Attributes Request::queryAttributes() +{ + if (!_worker->isRequestAttributesEnabled()) + throw ucxx::UnsupportedError( + "Request attributes querying is disabled on the owning worker; build the worker " + "with `ucxx::experimental::WorkerBuilder::requestAttributes(true)` to enable it"); + + std::lock_guard lock(_mutex); + + if (_requestAttr.memoryType != UCS_MEMORY_TYPE_UNKNOWN) return _requestAttr; + + throw ucxx::NoElemError( + "Request attributes are not available for this request: UCX took an inline-completion " + "path with no queryable UCP request, or the request has not completed yet"); +} + std::shared_ptr Request::getRecvBuffer() { return nullptr; } std::string Request::getRecvHeader() { return {}; } diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index d77ac2bb9..4ccdae333 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -368,7 +368,7 @@ ucs_status_t RequestAm::recvCallback(void* arg, return s; } else { // The request will be handled by the callback - recvAmMessage->setUcpRequest(status); + req->publishRequest(status); amData->_registerInflightRequest(req); { @@ -470,8 +470,7 @@ void RequestAm::request() amSend._count, ¶m); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); }, [](auto) { throw ucxx::UnsupportedError("Only send active messages can call request()"); }, }, diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 3e1175824..05f83845d 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include "ucxx/request_data.h" @@ -78,8 +78,7 @@ void RequestEndpointClose::request() else throw ucxx::Error("A valid endpoint or worker is required for a close operation."); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestEndpointClose::populateDelayedSubmission() diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp index 1dcb3a936..bb20491be 100644 --- a/cpp/src/request_flush.cpp +++ b/cpp/src/request_flush.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -77,8 +77,7 @@ void RequestFlush::request() else throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestFlush::populateDelayedSubmission() diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index bff6caeee..16f157e59 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -127,8 +127,7 @@ void RequestMem::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestMem::populateDelayedSubmission() diff --git a/cpp/src/request_stream.cpp b/cpp/src/request_stream.cpp index 4327cc407..30b42f60d 100644 --- a/cpp/src/request_stream.cpp +++ b/cpp/src/request_stream.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -91,8 +91,7 @@ void RequestStream::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestStream::populateDelayedSubmission() diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index a4424a4fe..edc0f8627 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -161,8 +161,7 @@ void RequestTag::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestTag::populateDelayedSubmission() diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 9e6b556aa..b9b9063d1 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include #include #include #include @@ -200,6 +201,22 @@ std::string Worker::getInfo() return utils::decodeTextFileDescriptor(TextFileDescriptor); } +Worker::Attributes Worker::queryAttributes() const +{ + ucp_worker_attr_t attr = { + .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE | UCP_WORKER_ATTR_FIELD_MAX_AM_HEADER | + UCP_WORKER_ATTR_FIELD_NAME | UCP_WORKER_ATTR_FIELD_MAX_INFO_STRING}; + + utils::ucsErrorThrow(ucp_worker_query(_handle, &attr)); + + return Attributes{ + .threadMode = attr.thread_mode, + .maxAmHeader = attr.max_am_header, + .name = std::string(attr.name, ::strnlen(attr.name, sizeof(attr.name))), + .maxDebugString = attr.max_debug_string, + }; +} + bool Worker::isDelayedRequestSubmissionEnabled() const { return _delayedSubmissionCollection->isDelayedRequestSubmissionEnabled(); @@ -207,6 +224,8 @@ bool Worker::isDelayedRequestSubmissionEnabled() const bool Worker::isFutureEnabled() const { return _enableFuture; } +bool Worker::isRequestAttributesEnabled() const noexcept { return _enableRequestAttributes; } + BufferType Worker::getCudaBufferType() const { return _cudaBufferType; } void Worker::setCudaBufferType(BufferType bufferType) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 6997e6c44..e28584df4 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -62,6 +62,40 @@ class RequestTest : public ::testing::TestWithParam< std::vector _sendPtr{nullptr}; std::vector _recvPtr{nullptr}; + void buildWorker(bool enableRequestAttributes) + { + auto builder = ucxx::experimental::createWorker(_context) + .delayedSubmission(_enableDelayedSubmission) + .requestAttributes(enableRequestAttributes); + + if (_bufferType == ucxx::BufferType::RMM || _bufferType == ucxx::BufferType::CCCL) + builder.cudaBufferType(_bufferType); + + _worker = builder.build(); + + if (_progressMode == ProgressMode::Blocking) { + _worker->initBlockingProgressMode(); + } else if (_progressMode == ProgressMode::ThreadPolling || + _progressMode == ProgressMode::ThreadBlocking) { + _worker->setProgressThreadStartCallback(::createCudaContextCallback, nullptr); + + if (_progressMode == ProgressMode::ThreadPolling) _worker->startProgressThread(true); + if (_progressMode == ProgressMode::ThreadBlocking) _worker->startProgressThread(false); + } + + _progressWorker = getProgressFunction(_worker, _progressMode); + + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + } + + void rebuildWorker(bool enableRequestAttributes) + { + if (_worker && _worker->isProgressThreadRunning()) _worker->stopProgressThread(); + _ep.reset(); + _worker.reset(); + buildWorker(enableRequestAttributes); + } + void SetUp() { std::tie(_bufferType, @@ -88,25 +122,7 @@ class RequestTest : public ::testing::TestWithParam< _context = ucxx::createContext({{"RNDV_THRESH", std::to_string(_rndvThresh)}}, ucxx::Context::defaultFeatureFlags); - auto builder = - ucxx::experimental::createWorker(_context).delayedSubmission(_enableDelayedSubmission); - if (_bufferType == ucxx::BufferType::RMM || _bufferType == ucxx::BufferType::CCCL) - builder.cudaBufferType(_bufferType); - _worker = builder.build(); - - if (_progressMode == ProgressMode::Blocking) { - _worker->initBlockingProgressMode(); - } else if (_progressMode == ProgressMode::ThreadPolling || - _progressMode == ProgressMode::ThreadBlocking) { - _worker->setProgressThreadStartCallback(::createCudaContextCallback, nullptr); - - if (_progressMode == ProgressMode::ThreadPolling) _worker->startProgressThread(true); - if (_progressMode == ProgressMode::ThreadBlocking) _worker->startProgressThread(false); - } - - _progressWorker = getProgressFunction(_worker, _progressMode); - - _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + buildWorker(false); } void TearDown() @@ -222,8 +238,8 @@ TEST_P(RequestTest, ProgressAm) auto recvReq = requests[1]; _recvPtr[0] = recvReq->getRecvBuffer()->data(); - // Messages larger than `_rndvThresh` are rendezvous and will use custom allocator, - // smaller messages are eager and will always be host-allocated. + // Messages of size `_rndvThresh` or larger are rendezvous and will use the custom + // allocator, smaller messages are eager and will always be host-allocated. ASSERT_THAT(recvReq->getRecvBuffer()->getType(), (_registerCustomAmAllocator && _messageSize >= _rndvThresh) ? _bufferType : ucxx::BufferType::Host); @@ -533,6 +549,261 @@ TEST_P(RequestTest, ProgressTag) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressTagRequestAttributes) +{ + if (_messageSize < _rndvThresh) + GTEST_SKIP() << "Eager messages do not create a ucp_request and thus no debug info"; + + rebuildWorker(true); + + allocate(); + + std::vector> requests; + requests.push_back(_ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0})); + requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull)); + waitRequests(_worker, requests, _progressWorker); + + for (const auto& request : requests) { + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + ASSERT_THAT(debugString, + ::testing::HasSubstr(_memoryType == UCS_MEMORY_TYPE_HOST ? "host" : "cuda")); + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +class RequestAttributesDisabledTest : public ::testing::Test { + protected: + static constexpr size_t kMessageLength = 1024; + static constexpr size_t kMessageSize = kMessageLength * sizeof(int); + + std::shared_ptr _context; + std::shared_ptr _worker; + std::shared_ptr _ep; + std::function _progressWorker; + std::vector _sendBuf; + std::vector _recvBuf; + + void SetUp() override + { + _context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _worker = ucxx::experimental::createWorker(_context).build(); + ASSERT_FALSE(_worker->isRequestAttributesEnabled()); + + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + _progressWorker = getProgressFunction(_worker, ProgressMode::Polling); + + _sendBuf.resize(kMessageLength); + _recvBuf.resize(kMessageLength); + std::iota(_sendBuf.begin(), _sendBuf.end(), 0); + } + + void expectAllThrow(const std::vector>& requests) const + { + for (const auto& request : requests) { + EXPECT_THROW(std::ignore = request->queryAttributes(), ucxx::UnsupportedError); + } + } +}; + +TEST_F(RequestAttributesDisabledTest, Tag) +{ + std::vector> requests; + requests.push_back(_ep->tagSend(_sendBuf.data(), kMessageSize, ucxx::Tag{0})); + requests.push_back(_ep->tagRecv(_recvBuf.data(), kMessageSize, ucxx::Tag{0}, ucxx::TagMaskFull)); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, Stream) +{ + std::vector> requests; + requests.push_back(_ep->streamSend(_sendBuf.data(), kMessageSize, 0)); + requests.push_back(_ep->streamRecv(_recvBuf.data(), kMessageSize, 0)); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, Am) +{ + std::vector> requests; + requests.push_back(_ep->amSend(_sendBuf.data(), kMessageSize, UCS_MEMORY_TYPE_HOST)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + + auto recvBuffer = requests[1]->getRecvBuffer(); + ASSERT_EQ(recvBuffer->getSize(), kMessageSize); + std::vector received(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + kMessageLength); + ASSERT_THAT(received, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, MemoryGet) +{ + auto memoryHandle = _context->createMemoryHandle(kMessageSize, nullptr, UCS_MEMORY_TYPE_HOST); + std::memcpy( + reinterpret_cast(memoryHandle->getBaseAddress()), _sendBuf.data(), kMessageSize); + + auto serializedRemoteKey = memoryHandle->createRemoteKey()->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memGet(_recvBuf.data(), kMessageSize, remoteKey); + std::vector> requests{request, _ep->flush()}; + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow({request}); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, MemoryPut) +{ + auto memoryHandle = _context->createMemoryHandle(kMessageSize, nullptr, UCS_MEMORY_TYPE_HOST); + + auto serializedRemoteKey = memoryHandle->createRemoteKey()->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memPut(_sendBuf.data(), kMessageSize, remoteKey); + std::vector> requests{request, _ep->flush()}; + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow({request}); + + std::memcpy( + _recvBuf.data(), reinterpret_cast(memoryHandle->getBaseAddress()), kMessageSize); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_P(RequestTest, ProgressStreamRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Stream rejects zero-length transfers"; + + rebuildWorker(true); + + allocate(); + + auto sendRequest = _ep->streamSend(_sendPtr[0], _messageSize, 0); + auto recvRequest = _ep->streamRecv(_recvPtr[0], _messageSize, 0); + std::vector> requests{sendRequest, recvRequest}; + waitRequests(_worker, requests, _progressWorker); + + try { + auto sendDebug = sendRequest->queryAttributes().debugString; + EXPECT_FALSE(sendDebug.empty()); + EXPECT_THAT(sendDebug, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } catch (const ucxx::NoElemError&) { + // Send completed inline; no UCP request handle to query. + } + + try { + auto recvDebug = recvRequest->queryAttributes().debugString; + EXPECT_THAT(recvDebug, ::testing::HasSubstr("no debug info")); + } catch (const ucxx::NoElemError&) { + // Recv completed inline; no UCP request handle to query. + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, ProgressAmRequestAttributes) +{ + if (_messageSize < _rndvThresh) + GTEST_SKIP() << "Eager messages complete inline without a UCP request to query"; + if (_progressMode == ProgressMode::Wait) + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + + rebuildWorker(true); + + allocate(1, false); + + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + for (const auto& request : requests) { + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } + + auto recvReq = requests[1]; + _recvPtr[0] = recvReq->getRecvBuffer()->data(); + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Zero-length memGet completes without a UCP request"; + + rebuildWorker(true); + + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr, _memoryType); + copyMemoryTypeAware( + reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memGet(_recvPtr[0], _messageSize, remoteKey); + std::vector> requests; + requests.push_back(request); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Zero-length memPut completes without a UCP request"; + + rebuildWorker(true); + + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr, _memoryType); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memPut(_sendPtr[0], _messageSize, remoteKey); + std::vector> requests; + requests.push_back(request); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + try { + auto debugString = request->queryAttributes().debugString; + EXPECT_FALSE(debugString.empty()); + EXPECT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } catch (const ucxx::NoElemError&) { + // Request completed inline; no UCP request handle to query. + } + + copyMemoryTypeAware( + _recvPtr[0], reinterpret_cast(memoryHandle->getBaseAddress()), _messageSize); + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressTagMulti) { if (_progressMode == ProgressMode::Wait) { diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index b7350a93e..1749c8c7b 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -108,6 +108,21 @@ class WorkerGenericCallbackSingleTest : public WorkerProgressTest {}; TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); } +TEST_F(WorkerTest, QueryAttributes) +{ + auto attrs = _worker->queryAttributes(); + + // The worker was created with UCS_THREAD_MODE_MULTI in the constructor. + EXPECT_EQ(attrs.threadMode, UCS_THREAD_MODE_MULTI); + + // The remaining fields are determined by UCX configuration, so the strongest + // portable assertion is that they were populated with non-zero / non-empty + // values. + EXPECT_GT(attrs.maxAmHeader, 0u); + EXPECT_FALSE(attrs.name.empty()); + EXPECT_GT(attrs.maxDebugString, 0u); +} + TEST_P(WorkerCapabilityTest, CheckCapability) { ASSERT_EQ(_worker->isDelayedRequestSubmissionEnabled(), _enableDelayedSubmission); @@ -876,6 +891,35 @@ TEST(WorkerBuilderTest, BuilderBackwardCompatibility) ASSERT_TRUE(worker2->isFutureEnabled()); } +TEST(WorkerBuilderTest, RequestAttributesDefaultDisabled) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_FALSE(worker->isRequestAttributesEnabled()); +} + +TEST(WorkerBuilderTest, RequestAttributesEnabled) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).requestAttributes(true).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_TRUE(worker->isRequestAttributesEnabled()); + ASSERT_FALSE(worker->isDelayedRequestSubmissionEnabled()); + ASSERT_FALSE(worker->isFutureEnabled()); +} + +TEST(WorkerBuilderTest, RequestAttributesExplicitDisable) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).requestAttributes(false).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_FALSE(worker->isRequestAttributesEnabled()); +} + TEST(AmReceiverCallbackOwnerTypeTest, DefaultConstructsEmpty) { ucxx::AmReceiverCallbackOwnerType owner;