diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 0f9e13a03..3a56d878b 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -96,14 +96,22 @@ class InflightRequests { void remove(const std::shared_ptr& request); /** - * @brief Issue cancelation of all inflight requests and clear the internal container. + * @brief Issue cancelation of inflight requests and clear the internal container. * - * Issue cancelation of all inflight requests known to this object and clear the - * internal container. The total number of canceled requests is returned. + * Issue cancelation of inflight requests known to this object. The total number of + * canceled requests is returned. * + * When `workerOnly` is `true`, only worker-scoped operations + * (`Request::isWorkerOperation() == true`, i.e. receive variants) are cancelled and + * removed; endpoint-scoped operations are left in the container untouched, on the + * assumption that the caller will follow up with `ucp_ep_close_nbx(FORCE)` which + * handles their UCT-level cleanup atomically. When `workerOnly` is `false` (the + * default), all inflight requests are cancelled. + * + * @param[in] workerOnly if `true`, cancel only worker-scoped requests. * @returns The total number of canceled requests. */ - size_t cancelAll(); + size_t cancelAll(bool workerOnly = false); /** * @brief Releases the internally-tracked containers. diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 6ad7607bd..8aefab930 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -127,6 +127,27 @@ class Request : public Component { */ virtual void cancel(); + /** + * @brief Whether this request is worker-scoped (vs endpoint-scoped). + * + * Returns `true` for operations that are tracked by the UCP worker rather than a + * specific UCP endpoint, currently the receive variants (`TagReceive`, + * `TagReceiveWithHandle`, `AmReceive`, `StreamReceive`, `TagMultiReceive`). + * + * `Endpoint::closeBlocking()` uses this to decide which requests must be cancelled + * explicitly via `ucp_request_cancel`: worker-scoped requests are not cancelled by + * `ucp_ep_close_nbx(UCP_EP_CLOSE_FLAG_FORCE)` (which only tears down endpoint-bound + * state) and would otherwise hang forever. Endpoint-scoped requests are left to + * UCX's FORCE close, which handles their UCT-level cleanup atomically, calling + * `ucp_request_cancel` on them in addition to FORCE close has been observed to + * leave UCT-level pending queue entries referencing freed staging buffers, which + * the next `ucp_worker_progress()` then crashes dispatching (see + * test_send_recv_multi.py CUDA segfault). + * + * @returns `true` if this request is worker-scoped, `false` if endpoint-scoped. + */ + [[nodiscard]] bool isWorkerOperation() const; + /** * @brief Return the status of the request. * diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 5d880c359..4dd9ce3f1 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -269,13 +269,6 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) { if (_closing.exchange(true) || _handle == nullptr) return; - size_t canceled = cancelInflightRequestsBlocking(3000000000 /* 3s */, 3); - ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, canceled %lu requests", - __func__, - this, - _handle, - canceled); - ucp_request_param_t param{}; if (_endpointErrorHandling) param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE}; @@ -288,8 +281,45 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) bool submitted = false; for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) { if (!submitted) { + // Cancel WORKER-SCOPED inflight requests (tag_recv & friends) and submit + // FORCE close ATOMICALLY in a single pre-callback no + // ucp_worker_progress() between them, and no ucp_request_cancel() on + // endpoint-scoped operations. + // + // Why cancel anything at all (UCX FORCE close handles endpoint state): + // Receive operations (`tag_recv`, `am_recv`, ...) post requests on the + // UCP worker via `ucp_*_recv_nbx(worker, ...)`. ucp_ep_close_nbx(FORCE) + // tears down the UCP endpoint but leaves worker-scoped requests + // pending forever, an `await ep.close()` racing with an outstanding + // `await ep.recv()` would hang. See + // test_shutdown.py::test_{server,client}_shutdown. + // + // Why ONLY worker-scoped (not endpoint-scoped sends/RMA): + // Calling ucp_request_cancel() on an endpoint-scoped tag_send and then + // FORCE-closing the endpoint has been observed to leave UCT-level TCP + // pending queue entries pointing at freed staging buffers; the next + // ucp_worker_progress() then crashed dispatching them + // (uct_tcp_pending_queue_dispatch -> uct_cuda_copy_ep_get_short -> + // cuMemcpyAsync -> SIGSEGV), see test_send_recv_multi.py CUDA segfault. + // FORCE close handles endpoint-scoped requests' UCT cleanup atomically + // on its own, so we leave them alone. + // + // Why atomic with FORCE close (not as a separate pre-callback): + // When cancelAll and FORCE close were separate pre-callbacks (the + // old cancelInflightRequestsBlocking path), a full ucp_worker_progress() + // ran between them, see prior commit "Cancel inflight requests and + // submit force-close atomically in a single pre-callback". if (!worker->registerGenericPre( - [this, &status, ¶m]() { status = ucp_ep_close_nbx(_handle, ¶m); }, period)) + [this, &status, ¶m]() { + _inflightRequests->cancelAll(/*workerOnly=*/true); + status = ucp_ep_close_nbx(_handle, ¶m); + // Invalidate _handle synchronously immediately, to prevent + // time window where _handle` points to freed UCP memory, usually + // observed in `populateDelayedSubmission()`. + _originalHandle = _handle; + _handle = nullptr; + }, + period)) continue; submitted = true; } @@ -306,7 +336,7 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "endpoint: %s", __func__, this, - _handle, + _originalHandle, ucs_status_string(UCS_PTR_STATUS(status))); } }, @@ -323,10 +353,16 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, all attempts to close timed out", __func__, this, - _handle); + _originalHandle != nullptr ? _originalHandle : _handle); } } else { - status = ucp_ep_close_nbx(_handle, ¶m); + // No progress thread: cancel inflight + FORCE close back-to-back, then + // drive progress here. Same atomicity reasoning as the progress-thread + // path above (no ucp_worker_progress() between cancel and FORCE close). + _inflightRequests->cancelAll(/*workerOnly=*/true); + status = ucp_ep_close_nbx(_handle, ¶m); + _originalHandle = _handle; + _handle = nullptr; if (UCS_PTR_IS_PTR(status)) { ucs_status_t s; while ((s = ucp_request_check_status(status)) == UCS_INPROGRESS) @@ -337,11 +373,12 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, Error while closing endpoint: %s", __func__, this, - _handle, + _originalHandle, ucs_status_string(UCS_PTR_STATUS(status))); } } - ucxx_trace("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _handle); + ucxx_trace( + "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _originalHandle); if (UCS_PTR_IS_PTR(status)) ucp_request_free(status); @@ -351,14 +388,12 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, calling user close callback", __func__, this, - _handle); + _originalHandle); _closeCallback(_status, _closeCallbackArg); _closeCallback = nullptr; _closeCallbackArg = nullptr; } } - - std::swap(_handle, _originalHandle); } ucp_ep_h Endpoint::getHandle() { return _handle; } @@ -448,7 +483,7 @@ size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAtt "cancel inflight requests failed", __func__, this, - _handle); + _originalHandle != nullptr ? _originalHandle : _handle); } else { canceled = _inflightRequests->cancelAll(); } diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 4c9b396e1..76bd2d5f4 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -41,18 +41,34 @@ void InflightRequests::merge(TrackedRequests&& trackedRequests) if (r) _canceling.insert(std::move(r)); } -size_t InflightRequests::cancelAll() +size_t InflightRequests::cancelAll(bool workerOnly) { decltype(_inflight) toCancel; + decltype(_inflight) toKeep; { std::lock_guard lock(_mutex); - toCancel = std::exchange(_inflight, {}); + if (workerOnly) { + // Partition: receive (worker-scoped) requests get explicit ucp_request_cancel, + // endpoint-scoped requests stay in `_inflight` to be cleaned up by FORCE close. + for (auto& r : _inflight) { + if (r && r->isWorkerOperation()) + toCancel.insert(r); + else + toKeep.insert(r); + } + _inflight = std::move(toKeep); + } else { + toCancel = std::exchange(_inflight, {}); + } } size_t total = toCancel.size(); if (total == 0) return 0; - ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests (workerOnly=%d)", + __func__, + total, + workerOnly); for (auto& r : toCancel) { if (r) r->cancel(); diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 1551e4f23..65c97416b 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -71,6 +71,24 @@ Request::~Request() ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } +bool Request::isWorkerOperation() const +{ + // Receive operations go through ucp_tag_recv_nbx / ucp_am_recv_nbx and friends, + // which post requests against the UCP worker rather than a specific endpoint. + // ucp_ep_close_nbx(FORCE) does not cancel them, so closeBlocking() must do so + // explicitly. Send / RMA operations are endpoint-bound and are cleaned up by + // FORCE close itself. + return std::visit(data::dispatch{ + [](const data::TagReceive&) { return true; }, + [](const data::TagReceiveWithHandle&) { return true; }, + [](const data::TagMultiReceive&) { return true; }, + [](const data::AmReceive&) { return true; }, + [](const data::StreamReceive&) { return true; }, + [](const auto&) { return false; }, + }, + _requestData); +} + void Request::cancel() { std::lock_guard lock(_mutex); diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 7e1bbc5fe..2d26b5b73 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -498,9 +498,9 @@ async def write( for each_frame in send_frames: await self.ep.send(each_frame) return sum(sizes) - except ucxx.exceptions.UCXError: + except BaseException as e: self.abort() - raise CommClosedError("While writing, the connection was closed") + raise CommClosedError("While writing, the connection was closed") from e @log_errors async def read(self, deserializers=("cuda", "dask", "pickle", "error")): @@ -729,6 +729,8 @@ async def serve_forever(client_ep): self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port) def stop(self): + if self.ucxx_server is not None: + self.ucxx_server.close() self.ucxx_server = None _deregister_dask_resource(self._resource_id) diff --git a/python/distributed-ucxx/distributed_ucxx/utils_test.py b/python/distributed-ucxx/distributed_ucxx/utils_test.py index f539c5125..c6aeb51b4 100644 --- a/python/distributed-ucxx/distributed_ucxx/utils_test.py +++ b/python/distributed-ucxx/distributed_ucxx/utils_test.py @@ -1,9 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations import asyncio +import gc import logging import sys @@ -84,6 +85,10 @@ def ucxx_loop(request): with check_thread_leak(): yield loop + + # Collect garbage to break any remaining reference cycles before reset. + gc.collect() + if ignore_alive_references: try: ucxx.reset() diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 89488c1ac..1f1aa5a5f 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import logging @@ -311,7 +311,7 @@ def create_listener( cb_args=( loop, callback_func, - self, + weakref.ref(self), endpoint_error_handling, connect_timeout, listener_id, @@ -321,6 +321,7 @@ def create_listener( ), listener_id, self._listener_active_clients, + ctx=self, ) return ret diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 9abed0375..114de7dfb 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -34,7 +34,9 @@ def __init__(self): def add_listener(self, ident: int) -> None: if ident in self._active_clients: - raise ValueError("Listener {ident} is already registered in ActiveClients.") + raise ValueError( + f"Listener {ident} is already registered in ActiveClients." + ) self._locks[ident] = threading.Lock() self._active_clients[ident] = 0 @@ -44,7 +46,7 @@ def remove_listener(self, ident: int) -> None: active_clients = self.get_active(ident) if active_clients > 0: raise RuntimeError( - "Listener {ident} is being removed from ActiveClients, but " + f"Listener {ident} is being removed from ActiveClients, but " f"{active_clients} active client(s) is(are) still accounted for." ) @@ -97,11 +99,15 @@ class Listener: Please use `create_listener()` to create an Listener. """ - def __init__(self, listener, ident, active_clients): + def __init__(self, listener, ident, active_clients, ctx=None): if not isinstance(listener, ucx_api.UCXListener): raise ValueError("listener must be an instance of UCXListener") self._listener = listener + # Hold a strong reference to ApplicationContext so that reset() correctly + # detects a live Listener. Released by close() so the context can be freed + # even if UCXListener is still in a reference cycle. + self._ctx = ctx active_clients.add_listener(ident) self._ident = ident @@ -130,12 +136,13 @@ def active_clients(self): def close(self): """Closing the listener""" + self._ctx = None self._listener = None async def _listener_handler_coroutine( conn_request, - ctx, + ctx_weakref, func, endpoint_error_handling, connect_timeout, @@ -148,6 +155,14 @@ async def _listener_handler_coroutine( # 3) Exchange endpoint info such as tags # 4) Setup control receive callback # 5) Execute the listener's callback function + + ctx = ctx_weakref() + if ctx is None: + logger.debug( + "ApplicationContext was freed before listener handler coroutine ran" + ) + return + active_clients.inc(ident) ep = None try: @@ -187,7 +202,7 @@ async def _listener_handler_coroutine( ) # Removing references here to avoid delayed clean up - del ctx + ctx = None # Finally, we call `func` if inspect.iscoroutinefunction(func): @@ -201,6 +216,7 @@ async def _listener_handler_coroutine( logger.exception("Unexpected error in listener handler coroutine") finally: active_clients.dec(ident) + del ctx del ep @@ -208,7 +224,7 @@ def _listener_handler( conn_request, event_loop, callback_func, - ctx, + ctx_weakref, endpoint_error_handling, connect_timeout, ident, @@ -217,7 +233,7 @@ def _listener_handler( asyncio.run_coroutine_threadsafe( _listener_handler_coroutine( conn_request, - ctx, + ctx_weakref, callback_func, endpoint_error_handling, connect_timeout, diff --git a/python/ucxx/ucxx/_lib_async/tests/conftest.py b/python/ucxx/ucxx/_lib_async/tests/conftest.py index b1bc3f76f..cc52e2171 100644 --- a/python/ucxx/ucxx/_lib_async/tests/conftest.py +++ b/python/ucxx/ucxx/_lib_async/tests/conftest.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import asyncio +import functools import gc import inspect import os @@ -112,6 +113,39 @@ def pytest_runtest_setup(item: pytest.Item) -> None: item.config.stash[ASYNCIO_PLUGIN_TIMEOUT_STASH_KEY] = timeout +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + """ + Inject the asyncio timeout BEFORE pytest-asyncio applies its MonkeyPatch. + + In pytest-asyncio 1.3.0 (asyncio_mode=auto) + Python 3.14, the framework + replaces ``item.obj`` with a sync wrapper inside + ``PytestAsyncioFunction.runtest()``, which is called during + ``pytest_runtest_call``. By the time ``pytest_pyfunc_call`` fires, + ``item.obj`` is already a sync function and + ``inspect.iscoroutinefunction()`` returns False. Wrapping here instead + guarantees ``asyncio.wait_for`` is present in the coroutine that + pytest-asyncio eventually passes to its ``asyncio.Runner``. + """ + if isinstance(item, pytest.Function) and inspect.iscoroutinefunction(item.obj): + timeout = _asyncio_plugin_timeout_seconds(item) + if timeout > 0.0: + original = item.obj + test_name = item.name + + @functools.wraps(original) + async def timed_coroutine(*args, **kwargs): + try: + return await asyncio.wait_for( + original(*args, **kwargs), timeout=timeout + ) + except (asyncio.CancelledError, asyncio.TimeoutError): + pytest.fail(f"{test_name} timed out after {timeout} seconds.") + + item.obj = timed_coroutine + yield + + @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_pyfunc_call(pyfuncitem: pytest.Function): """ diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py index 0fc9a776a..d2f59cc61 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import functools @@ -94,6 +94,7 @@ async def test_send_recv_cupy(size, dtype): resp = cupy.empty_like(msg) await client.recv(resp) np.testing.assert_array_equal(cupy.asnumpy(resp), cupy.asnumpy(msg)) + await client.close() await wait_listener_client_handlers(listener) @@ -116,6 +117,7 @@ async def test_send_recv_numba(size, dtype): resp = cuda.device_array_like(msg) await client.recv(resp) np.testing.assert_array_equal(np.array(resp), np.array(msg)) + await client.close() await wait_listener_client_handlers(listener) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py index 7a33664d0..f8fcae96e 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import pytest @@ -8,7 +8,11 @@ np = pytest.importorskip("numpy") -msg_sizes = [2**i for i in range(0, 25, 4)] +# Some CI nodes can be _very_ slow for large sized messages, generally only on +# 4 or 8 messages, thus substantially increase the timeouts for 16MiB messages. +msg_sizes = [2**i for i in range(0, 24, 4)] + [ + pytest.param(2**24, marks=pytest.mark.asyncio_timeout(240)) +] # multi_sizes = [0, 1, 2, 3, 4, 8] multi_sizes = [1, 2, 3, 4, 8] dtypes = ["|u1", " 0: + if deadline is not None and loop.time() >= deadline: + raise asyncio.TimeoutError( + f"Listener still has {listener.active_clients} active client(s) " + f"after {timeout}s, likely due to a deadlock in the handler coroutine." + ) # Minimal delay to yield to the event loop so call_soon_threadsafe callbacks # run. Using a very short positive sleep ensures pending callbacks are # processed and significantly reduces "coroutine never awaited" warnings. - await asyncio.sleep(1e-9) + try: + await asyncio.sleep(1e-9) + except asyncio.CancelledError: + cancelled = True + # Python 3.11+ tracks cancellation depth; calling uncancel() lets + # this loop continue iterating rather than being re-cancelled on + # the very next await. + task = asyncio.current_task() + if task is not None and hasattr(task, "uncancel"): + task.uncancel() if not ucxx.core._get_ctx().progress_mode.startswith("thread"): ucxx.progress() + if cancelled: + raise asyncio.CancelledError()