From a2c620b338c0fed0fdd5390afdcbc6e87f56be24 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 Apr 2026 09:24:29 -0700 Subject: [PATCH 01/18] Catch `BaseException` during distributed-ucxx `write()` --- python/distributed-ucxx/distributed_ucxx/ucxx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 7e1bbc5fe..020efd234 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -498,9 +498,9 @@ async def write( for each_frame in send_frames: await self.ep.send(each_frame) return sum(sizes) - except ucxx.exceptions.UCXError: + except BaseException as e: self.abort() - raise CommClosedError("While writing, the connection was closed") + raise CommClosedError("While writing, the connection was closed") from e @log_errors async def read(self, deserializers=("cuda", "dask", "pickle", "error")): From 4ce668c36e8c34b6b48902388c6d6fc133014320 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 Apr 2026 09:25:58 -0700 Subject: [PATCH 02/18] Fix possible ApplicationContext leak and f-strings --- python/ucxx/ucxx/_lib_async/listener.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 9abed0375..e8442b199 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -34,7 +34,9 @@ def __init__(self): def add_listener(self, ident: int) -> None: if ident in self._active_clients: - raise ValueError("Listener {ident} is already registered in ActiveClients.") + raise ValueError( + f"Listener {ident} is already registered in ActiveClients." + ) self._locks[ident] = threading.Lock() self._active_clients[ident] = 0 @@ -44,7 +46,7 @@ def remove_listener(self, ident: int) -> None: active_clients = self.get_active(ident) if active_clients > 0: raise RuntimeError( - "Listener {ident} is being removed from ActiveClients, but " + f"Listener {ident} is being removed from ActiveClients, but " f"{active_clients} active client(s) is(are) still accounted for." ) @@ -201,6 +203,8 @@ async def _listener_handler_coroutine( logger.exception("Unexpected error in listener handler coroutine") finally: active_clients.dec(ident) + if ep is not None: + ep._ctx = None del ep From 4a3b2455e532196334b991f824af0a57d7fdea4b Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 Apr 2026 10:06:57 -0700 Subject: [PATCH 03/18] Revert wrong circular reference breaking --- python/ucxx/ucxx/_lib_async/listener.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index e8442b199..5861ca194 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -203,8 +203,6 @@ async def _listener_handler_coroutine( logger.exception("Unexpected error in listener handler coroutine") finally: active_clients.dec(ident) - if ep is not None: - ep._ctx = None del ep From 5daeb14be416605f1144160d27973759044f4d35 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 23 Apr 2026 00:52:40 -0700 Subject: [PATCH 04/18] Flush event loop before ucxx.reset() --- .../distributed_ucxx/utils_test.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/distributed-ucxx/distributed_ucxx/utils_test.py b/python/distributed-ucxx/distributed_ucxx/utils_test.py index f539c5125..3e6c96ffb 100644 --- a/python/distributed-ucxx/distributed_ucxx/utils_test.py +++ b/python/distributed-ucxx/distributed_ucxx/utils_test.py @@ -1,9 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations import asyncio +import gc import logging import sys @@ -84,6 +85,25 @@ def ucxx_loop(request): with check_thread_leak(): yield loop + + # Flush the event loop so that any in-flight _listener_handler_coroutine + # finally-blocks (which release Endpoint._ctx) have a chance to complete + # before we call ucxx.reset(). We also collect garbage to break the + # UCXXListener -> Listener -> UCXListener -> cb_args -> serve_forever closure + # -> UCXXListener reference cycle that CPython's reference counting alone + # won't free. + try: + asyncio_loop = loop.asyncio_loop + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), asyncio_loop).result( + timeout=5 + ) + asyncio.run_coroutine_threadsafe(asyncio.sleep(0), asyncio_loop).result( + timeout=5 + ) + except Exception: + pass + gc.collect() + if ignore_alive_references: try: ucxx.reset() From 74ee825f69bab4e2e8c3a71ce7baa2cf979ba100 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 23 Apr 2026 04:51:05 -0700 Subject: [PATCH 05/18] Attempt to break UCXXListener refcycle --- .../distributed-ucxx/distributed_ucxx/ucxx.py | 21 ++++++++++++------- .../distributed_ucxx/utils_test.py | 17 +-------------- .../ucxx/_lib_async/application_context.py | 4 ++-- python/ucxx/ucxx/_lib_async/listener.py | 4 ++++ 4 files changed, 20 insertions(+), 26 deletions(-) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 020efd234..8e010b1ba 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -707,21 +707,26 @@ def address(self): return f"{self.prefix}{self.ip}:{self.port}" async def start(self): + listener_ref = weakref.ref(self) + async def serve_forever(client_ep): - ucx = self.comm_class( + listener = listener_ref() + if listener is None: + return + ucx = listener.comm_class( client_ep, - local_addr=self.address, - peer_addr=self.address, - deserialize=self.deserialize, + local_addr=listener.address, + peer_addr=listener.address, + deserialize=listener.deserialize, ) - ucx.allow_offload = self.allow_offload + ucx.allow_offload = listener.allow_offload try: - await self.on_connection(ucx) + await listener.on_connection(ucx) except CommClosedError: logger.debug("Connection closed before handshake completed") return - if self.comm_handler: - await self.comm_handler(ucx) + if listener.comm_handler: + await listener.comm_handler(ucx) init_once() self._resource_id = _register_dask_resource() diff --git a/python/distributed-ucxx/distributed_ucxx/utils_test.py b/python/distributed-ucxx/distributed_ucxx/utils_test.py index 3e6c96ffb..c6aeb51b4 100644 --- a/python/distributed-ucxx/distributed_ucxx/utils_test.py +++ b/python/distributed-ucxx/distributed_ucxx/utils_test.py @@ -86,22 +86,7 @@ def ucxx_loop(request): with check_thread_leak(): yield loop - # Flush the event loop so that any in-flight _listener_handler_coroutine - # finally-blocks (which release Endpoint._ctx) have a chance to complete - # before we call ucxx.reset(). We also collect garbage to break the - # UCXXListener -> Listener -> UCXListener -> cb_args -> serve_forever closure - # -> UCXXListener reference cycle that CPython's reference counting alone - # won't free. - try: - asyncio_loop = loop.asyncio_loop - asyncio.run_coroutine_threadsafe(asyncio.sleep(0), asyncio_loop).result( - timeout=5 - ) - asyncio.run_coroutine_threadsafe(asyncio.sleep(0), asyncio_loop).result( - timeout=5 - ) - except Exception: - pass + # Collect garbage to break any remaining reference cycles before reset. gc.collect() if ignore_alive_references: diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 89488c1ac..78fabaa0e 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import logging @@ -311,7 +311,7 @@ def create_listener( cb_args=( loop, callback_func, - self, + weakref.ref(self), endpoint_error_handling, connect_timeout, listener_id, diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 5861ca194..3c24776f6 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -150,6 +150,10 @@ async def _listener_handler_coroutine( # 3) Exchange endpoint info such as tags # 4) Setup control receive callback # 5) Execute the listener's callback function + if isinstance(ctx, weakref.ref): + ctx = ctx() + if ctx is None: + return active_clients.inc(ident) ep = None try: From 047402e5ee1bd8ce049c66356947b6c49e9ca60d Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 23 Apr 2026 05:29:48 -0700 Subject: [PATCH 06/18] Stop ucxx_server in UCXXListener.stop() --- .../distributed-ucxx/distributed_ucxx/ucxx.py | 23 ++++++++----------- .../ucxx/_lib_async/application_context.py | 4 ++-- python/ucxx/ucxx/_lib_async/listener.py | 4 ---- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 8e010b1ba..2d26b5b73 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -707,26 +707,21 @@ def address(self): return f"{self.prefix}{self.ip}:{self.port}" async def start(self): - listener_ref = weakref.ref(self) - async def serve_forever(client_ep): - listener = listener_ref() - if listener is None: - return - ucx = listener.comm_class( + ucx = self.comm_class( client_ep, - local_addr=listener.address, - peer_addr=listener.address, - deserialize=listener.deserialize, + local_addr=self.address, + peer_addr=self.address, + deserialize=self.deserialize, ) - ucx.allow_offload = listener.allow_offload + ucx.allow_offload = self.allow_offload try: - await listener.on_connection(ucx) + await self.on_connection(ucx) except CommClosedError: logger.debug("Connection closed before handshake completed") return - if listener.comm_handler: - await listener.comm_handler(ucx) + if self.comm_handler: + await self.comm_handler(ucx) init_once() self._resource_id = _register_dask_resource() @@ -734,6 +729,8 @@ async def serve_forever(client_ep): self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port) def stop(self): + if self.ucxx_server is not None: + self.ucxx_server.close() self.ucxx_server = None _deregister_dask_resource(self._resource_id) diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 78fabaa0e..89488c1ac 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import logging @@ -311,7 +311,7 @@ def create_listener( cb_args=( loop, callback_func, - weakref.ref(self), + self, endpoint_error_handling, connect_timeout, listener_id, diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 3c24776f6..5861ca194 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -150,10 +150,6 @@ async def _listener_handler_coroutine( # 3) Exchange endpoint info such as tags # 4) Setup control receive callback # 5) Execute the listener's callback function - if isinstance(ctx, weakref.ref): - ctx = ctx() - if ctx is None: - return active_clients.inc(ident) ep = None try: From 91b8d21ca5376b5d51f68fbb839ae61d2ef2fd4a Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 23 Apr 2026 07:27:15 -0700 Subject: [PATCH 07/18] Replace ApplicationContext self reference for UCXListener --- .../ucxx/_lib_async/application_context.py | 5 ++-- python/ucxx/ucxx/_lib_async/listener.py | 29 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 89488c1ac..1f1aa5a5f 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import logging @@ -311,7 +311,7 @@ def create_listener( cb_args=( loop, callback_func, - self, + weakref.ref(self), endpoint_error_handling, connect_timeout, listener_id, @@ -321,6 +321,7 @@ def create_listener( ), listener_id, self._listener_active_clients, + ctx=self, ) return ret diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 5861ca194..fe2887fc4 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -99,11 +99,15 @@ class Listener: Please use `create_listener()` to create an Listener. """ - def __init__(self, listener, ident, active_clients): + def __init__(self, listener, ident, active_clients, ctx=None): if not isinstance(listener, ucx_api.UCXListener): raise ValueError("listener must be an instance of UCXListener") self._listener = listener + # Hold a strong reference to ApplicationContext so that reset() correctly + # detects a live Listener. Released by close() so the context can be freed + # even if UCXListener is still in a reference cycle. + self._ctx = ctx active_clients.add_listener(ident) self._ident = ident @@ -132,12 +136,13 @@ def active_clients(self): def close(self): """Closing the listener""" + self._ctx = None self._listener = None async def _listener_handler_coroutine( conn_request, - ctx, + ctx_weakref, func, endpoint_error_handling, connect_timeout, @@ -150,6 +155,17 @@ async def _listener_handler_coroutine( # 3) Exchange endpoint info such as tags # 4) Setup control receive callback # 5) Execute the listener's callback function + + # Dereference the weakref immediately so the UCXListener's cb_args tuple + # does not keep ApplicationContext alive through this coroutine's frame. + ctx = ctx_weakref() + del ctx_weakref + if ctx is None: + logger.debug( + "ApplicationContext was freed before listener handler coroutine ran" + ) + return + active_clients.inc(ident) ep = None try: @@ -189,7 +205,7 @@ async def _listener_handler_coroutine( ) # Removing references here to avoid delayed clean up - del ctx + ctx = None # Finally, we call `func` if inspect.iscoroutinefunction(func): @@ -203,6 +219,9 @@ async def _listener_handler_coroutine( logger.exception("Unexpected error in listener handler coroutine") finally: active_clients.dec(ident) + # Release ApplicationContext reference even on early exits (e.g., cancellation + # before the explicit ctx = None above). + del ctx del ep @@ -210,7 +229,7 @@ def _listener_handler( conn_request, event_loop, callback_func, - ctx, + ctx_weakref, endpoint_error_handling, connect_timeout, ident, @@ -219,7 +238,7 @@ def _listener_handler( asyncio.run_coroutine_threadsafe( _listener_handler_coroutine( conn_request, - ctx, + ctx_weakref, callback_func, endpoint_error_handling, connect_timeout, From c0aa3be3fc99df927dff554c60dd700fab3e85a8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 24 Apr 2026 05:01:47 -0700 Subject: [PATCH 08/18] Fix asyncio timeout for Python 3.14 + pytest-asyncio 1.3.0 In pytest-asyncio 1.3.0 (asyncio_mode=auto) + Python 3.14, the framework replaces item.obj with a sync wrapper before pytest_pyfunc_call fires, so inspect.iscoroutinefunction() returns False and asyncio.wait_for was never applied, and tests could hang indefinitely. Fix by adding a pytest_runtest_call hook that wraps the original async function with asyncio.wait_for before pytest-asyncio's MonkeyPatch runs. Also add an optional timeout parameter to wait_listener_client_handlers as a defensive guard against infinite waits in handler polling loops. --- python/ucxx/ucxx/_lib_async/tests/conftest.py | 34 +++++++++++++++++++ python/ucxx/ucxx/_lib_async/utils_test.py | 9 ++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/python/ucxx/ucxx/_lib_async/tests/conftest.py b/python/ucxx/ucxx/_lib_async/tests/conftest.py index b1bc3f76f..cc52e2171 100644 --- a/python/ucxx/ucxx/_lib_async/tests/conftest.py +++ b/python/ucxx/ucxx/_lib_async/tests/conftest.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import asyncio +import functools import gc import inspect import os @@ -112,6 +113,39 @@ def pytest_runtest_setup(item: pytest.Item) -> None: item.config.stash[ASYNCIO_PLUGIN_TIMEOUT_STASH_KEY] = timeout +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + """ + Inject the asyncio timeout BEFORE pytest-asyncio applies its MonkeyPatch. + + In pytest-asyncio 1.3.0 (asyncio_mode=auto) + Python 3.14, the framework + replaces ``item.obj`` with a sync wrapper inside + ``PytestAsyncioFunction.runtest()``, which is called during + ``pytest_runtest_call``. By the time ``pytest_pyfunc_call`` fires, + ``item.obj`` is already a sync function and + ``inspect.iscoroutinefunction()`` returns False. Wrapping here instead + guarantees ``asyncio.wait_for`` is present in the coroutine that + pytest-asyncio eventually passes to its ``asyncio.Runner``. + """ + if isinstance(item, pytest.Function) and inspect.iscoroutinefunction(item.obj): + timeout = _asyncio_plugin_timeout_seconds(item) + if timeout > 0.0: + original = item.obj + test_name = item.name + + @functools.wraps(original) + async def timed_coroutine(*args, **kwargs): + try: + return await asyncio.wait_for( + original(*args, **kwargs), timeout=timeout + ) + except (asyncio.CancelledError, asyncio.TimeoutError): + pytest.fail(f"{test_name} timed out after {timeout} seconds.") + + item.obj = timed_coroutine + yield + + @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_pyfunc_call(pyfuncitem: pytest.Function): """ diff --git a/python/ucxx/ucxx/_lib_async/utils_test.py b/python/ucxx/ucxx/_lib_async/utils_test.py index e6b6359e5..e877d39c2 100644 --- a/python/ucxx/ucxx/_lib_async/utils_test.py +++ b/python/ucxx/ucxx/_lib_async/utils_test.py @@ -196,8 +196,15 @@ async def am_recv(ep): return frames, msg -async def wait_listener_client_handlers(listener): +async def wait_listener_client_handlers(listener, timeout=None): + loop = asyncio.get_event_loop() + deadline = (loop.time() + timeout) if timeout is not None else None while listener.active_clients > 0: + if deadline is not None and loop.time() >= deadline: + raise asyncio.TimeoutError( + f"Listener still has {listener.active_clients} active client(s) " + f"after {timeout}s, likely due to a deadlock in the handler coroutine." + ) # Minimal delay to yield to the event loop so call_soon_threadsafe callbacks # run. Using a very short positive sleep ensures pending callbacks are # processed and significantly reduces "coroutine never awaited" warnings. From 9597e3a7c057f0ec710148731afef6306611342e Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 27 Apr 2026 01:23:02 -0700 Subject: [PATCH 09/18] Increase timeout to 120s for cupy/numba 16 MB in tag_multi tests These tests were failing CI with the new 60s default timeout. GPU transfers of 16 MB buffers with multi_size up to 8 can legitimately take over 60 seconds on CI hardware. --- python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py index 7a33664d0..1cef6c70f 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import pytest @@ -8,7 +8,9 @@ np = pytest.importorskip("numpy") -msg_sizes = [2**i for i in range(0, 25, 4)] +msg_sizes = [2**i for i in range(0, 24, 4)] + [ + pytest.param(2**24, marks=pytest.mark.asyncio_timeout(120)) +] # multi_sizes = [0, 1, 2, 3, 4, 8] multi_sizes = [1, 2, 3, 4, 8] dtypes = ["|u1", " Date: Mon, 27 Apr 2026 01:49:43 -0700 Subject: [PATCH 10/18] Increase timeout to 240s seconds --- python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py index 1cef6c70f..d099545b9 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -9,7 +9,7 @@ np = pytest.importorskip("numpy") msg_sizes = [2**i for i in range(0, 24, 4)] + [ - pytest.param(2**24, marks=pytest.mark.asyncio_timeout(120)) + pytest.param(2**24, marks=pytest.mark.asyncio_timeout(240)) ] # multi_sizes = [0, 1, 2, 3, 4, 8] multi_sizes = [1, 2, 3, 4, 8] From a35bc22062394cb3d20442c43660fef3a6d74af0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 27 Apr 2026 01:51:29 -0700 Subject: [PATCH 11/18] Add timeout comment --- python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py index d099545b9..b42771632 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -8,6 +8,8 @@ np = pytest.importorskip("numpy") +# Some CI nodes can be _very_ slow for large sized messages, generally only on +# 4 or 8 messages, thus substantially increase the timeouts for 16MiB messages. msg_sizes = [2**i for i in range(0, 24, 4)] + [ pytest.param(2**24, marks=pytest.mark.asyncio_timeout(240)) ] From 7917b45b3125f4082bf635c2a09fe6af1dad9866 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 28 Apr 2026 03:04:05 -0700 Subject: [PATCH 12/18] Prevent segfault when test timeout cancels mid-CUDA-transfer handler When asyncio.wait_for fires during wait_listener_client_handlers, the CancelledError propagated immediately, letting the Listener be GC'd while a handler coroutine still held an in-flight CUDA send/recv. The UCX progress thread's cuMemcpyAsync then raced with the GC finalizer's close_blocking() call, causing a segfault in ucp_mem_type_pack. The fix catches CancelledError in wait_listener_client_handlers and defers it until active_clients reaches 0, keeping the Listener alive long enough for all handlers to complete. Calls task.uncancel() on Python 3.11+ to prevent immediate re-cancellation on the next await. --- python/ucxx/ucxx/_lib_async/utils_test.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/python/ucxx/ucxx/_lib_async/utils_test.py b/python/ucxx/ucxx/_lib_async/utils_test.py index e877d39c2..b40f214c0 100644 --- a/python/ucxx/ucxx/_lib_async/utils_test.py +++ b/python/ucxx/ucxx/_lib_async/utils_test.py @@ -199,6 +199,12 @@ async def am_recv(ep): async def wait_listener_client_handlers(listener, timeout=None): loop = asyncio.get_event_loop() deadline = (loop.time() + timeout) if timeout is not None else None + # If this coroutine is cancelled (e.g., by asyncio.wait_for test timeout) + # while handlers are still active, we defer the CancelledError until all + # handlers finish. Raising immediately would let the Listener be GC'd + # while a handler holds an in-flight CUDA transfer, which races with the + # UCX progress thread and causes a segfault. + cancelled = False while listener.active_clients > 0: if deadline is not None and loop.time() >= deadline: raise asyncio.TimeoutError( @@ -208,6 +214,17 @@ async def wait_listener_client_handlers(listener, timeout=None): # Minimal delay to yield to the event loop so call_soon_threadsafe callbacks # run. Using a very short positive sleep ensures pending callbacks are # processed and significantly reduces "coroutine never awaited" warnings. - await asyncio.sleep(1e-9) + try: + await asyncio.sleep(1e-9) + except asyncio.CancelledError: + cancelled = True + # Python 3.11+ tracks cancellation depth; calling uncancel() lets + # this loop continue iterating rather than being re-cancelled on + # the very next await. + task = asyncio.current_task() + if task is not None and hasattr(task, "uncancel"): + task.uncancel() if not ucxx.core._get_ctx().progress_mode.startswith("thread"): ucxx.progress() + if cancelled: + raise asyncio.CancelledError() From afa3f4777f8249a3cdb5e7efd71568774d3326e1 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 28 Apr 2026 04:29:36 -0700 Subject: [PATCH 13/18] Remove unnecessary del and misleading comment from listener handler del ctx_weakref was a no-op: weakrefs cannot keep the referent alive by definition, so there was nothing to release. The accompanying comment was also wrong for the same reason. --- python/ucxx/ucxx/_lib_async/listener.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index fe2887fc4..114de7dfb 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -156,10 +156,7 @@ async def _listener_handler_coroutine( # 4) Setup control receive callback # 5) Execute the listener's callback function - # Dereference the weakref immediately so the UCXListener's cb_args tuple - # does not keep ApplicationContext alive through this coroutine's frame. ctx = ctx_weakref() - del ctx_weakref if ctx is None: logger.debug( "ApplicationContext was freed before listener handler coroutine ran" @@ -219,8 +216,6 @@ async def _listener_handler_coroutine( logger.exception("Unexpected error in listener handler coroutine") finally: active_clients.dec(ident) - # Release ApplicationContext reference even on early exits (e.g., cancellation - # before the explicit ctx = None above). del ctx del ep From 68bdc705f500ef97a3f5967bb73ff9fc18a50c28 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 28 Apr 2026 04:38:08 -0700 Subject: [PATCH 14/18] Close client endpoint before waiting for listener handlers in CUDA tests Without an explicit close, the client Endpoint is finalized during asyncio event loop teardown. Its _finalizer calls close_blocking(), which calls ucp_worker_progress() from the Python thread concurrently with the WorkerProgressThread causing a race on cuMemcpyAsync that segfaults. Closing the client inside the test coroutine (while the event loop is still running) performs the UCX close handshake while the progress thread is properly synchronized, so the finalizer is a no-op at teardown time. This also lets the echo server's ep.close() complete the handshake immediately rather than blocking for up to 10 s waiting for the peer. --- python/ucxx/ucxx/_lib_async/tests/test_send_recv.py | 4 +++- python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py index 0fc9a776a..d2f59cc61 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause import functools @@ -94,6 +94,7 @@ async def test_send_recv_cupy(size, dtype): resp = cupy.empty_like(msg) await client.recv(resp) np.testing.assert_array_equal(cupy.asnumpy(resp), cupy.asnumpy(msg)) + await client.close() await wait_listener_client_handlers(listener) @@ -116,6 +117,7 @@ async def test_send_recv_numba(size, dtype): resp = cuda.device_array_like(msg) await client.recv(resp) np.testing.assert_array_equal(np.array(resp), np.array(msg)) + await client.close() await wait_listener_client_handlers(listener) diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py index b42771632..f8fcae96e 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -69,6 +69,7 @@ async def test_send_recv_numpy(size, multi_size, dtype): recv_msg = await client.recv_multi() for r, s in zip(recv_msg, send_msg): np.testing.assert_array_equal(r.view(dtype), s) + await client.close() await wait_listener_client_handlers(listener) @@ -88,6 +89,7 @@ async def test_send_recv_cupy(size, multi_size, dtype): recv_msg = await client.recv_multi() for r, s in zip(recv_msg, send_msg): cupy.testing.assert_array_equal(cupy.asarray(r).view(dtype), cupy.asarray(s)) + await client.close() await wait_listener_client_handlers(listener) @@ -109,4 +111,5 @@ async def test_send_recv_numba(size, multi_size, dtype): np.testing.assert_array_equal( r.copy_to_host().view(dtype), s.copy_to_host().view(dtype) ) + await client.close() await wait_listener_client_handlers(listener) From 37ee61aa5d0a5b3355b35823522f71e189c5bad4 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 28 Apr 2026 11:57:31 -0700 Subject: [PATCH 15/18] Do not cancel inflight requests before endpoint close --- cpp/src/endpoint.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 5d880c359..b273909a5 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -269,13 +269,6 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) { if (_closing.exchange(true) || _handle == nullptr) return; - size_t canceled = cancelInflightRequestsBlocking(3000000000 /* 3s */, 3); - ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, canceled %lu requests", - __func__, - this, - _handle, - canceled); - ucp_request_param_t param{}; if (_endpointErrorHandling) param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, .flags = UCP_EP_CLOSE_FLAG_FORCE}; From f4d4cb99493ab5cbc08e2bdb24ae2cf19ab7ab88 Mon Sep 17 00:00:00 2001 From: Horde Date: Tue, 28 Apr 2026 20:01:35 +0000 Subject: [PATCH 16/18] Cancel inflight requests and submit force-close atomically in a single pre-callback --- cpp/src/endpoint.cpp | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index b273909a5..14141b6d7 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -281,8 +281,33 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) bool submitted = false; for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) { if (!submitted) { + // Cancel inflight requests and submit FORCE close ATOMICALLY in a + // single pre-callback, with no ucp_worker_progress() between them. + // + // Why cancel here at all (UCX FORCE close already cancels endpoint + // operations): + // tag_recv requests are worker-scoped (ucp_tag_recv_nbx(worker, ...)), + // not endpoint-scoped, so ucp_ep_close_nbx(FORCE) leaves them pending. + // Without ucp_request_cancel() here, an `await ep.close()` running + // alongside an outstanding `await ep.recv()` would hang forever. + // See test_shutdown.py::test_{server,client}_shutdown. + // + // Why atomic with FORCE close (not as a separate pre-callback): + // When cancelAll and FORCE close were separate pre-callbacks (the + // old cancelInflightRequestsBlocking path), a full ucp_worker_progress() + // ran between them. That intermediate progress could leave UCT-level + // TCP pending entries half-dispatched (mid-cuMemcpyAsync staging of + // a CUDA send); the next progress after FORCE close then crashed + // dispatching them on a freed staging buffer (uct_cuda_copy_ep_get_short + // -> cuMemcpyAsync -> SIGSEGV). Running them in a single pre-callback + // matches the safe single-threaded ordering proven by the regression + // test in cpp/tests/endpoint_close_force_tcp_cuda_race.cpp. if (!worker->registerGenericPre( - [this, &status, ¶m]() { status = ucp_ep_close_nbx(_handle, ¶m); }, period)) + [this, &status, ¶m]() { + _inflightRequests->cancelAll(); + status = ucp_ep_close_nbx(_handle, ¶m); + }, + period)) continue; submitted = true; } @@ -319,6 +344,10 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) _handle); } } else { + // No progress thread: cancel inflight + FORCE close back-to-back, then + // drive progress here. Same atomicity reasoning as the progress-thread + // path above (no ucp_worker_progress() between cancel and FORCE close). + _inflightRequests->cancelAll(); status = ucp_ep_close_nbx(_handle, ¶m); if (UCS_PTR_IS_PTR(status)) { ucs_status_t s; From 53dcb83455ce5c6beccaf54ad1956e9712eea7fc Mon Sep 17 00:00:00 2001 From: Horde Date: Wed, 29 Apr 2026 12:19:29 +0000 Subject: [PATCH 17/18] Fix invalid _handle usage --- cpp/src/endpoint.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 14141b6d7..874872bcf 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -306,6 +306,11 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) [this, &status, ¶m]() { _inflightRequests->cancelAll(); status = ucp_ep_close_nbx(_handle, ¶m); + // Invalidate _handle synchronously immediately, to prevent + // time window where _handle` points to freed UCP memory, usually + // observed in `populateDelayedSubmission()`. + _originalHandle = _handle; + _handle = nullptr; }, period)) continue; @@ -324,7 +329,7 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "endpoint: %s", __func__, this, - _handle, + _originalHandle, ucs_status_string(UCS_PTR_STATUS(status))); } }, @@ -341,14 +346,16 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, all attempts to close timed out", __func__, this, - _handle); + _originalHandle != nullptr ? _originalHandle : _handle); } } else { // No progress thread: cancel inflight + FORCE close back-to-back, then // drive progress here. Same atomicity reasoning as the progress-thread // path above (no ucp_worker_progress() between cancel and FORCE close). _inflightRequests->cancelAll(); - status = ucp_ep_close_nbx(_handle, ¶m); + status = ucp_ep_close_nbx(_handle, ¶m); + _originalHandle = _handle; + _handle = nullptr; if (UCS_PTR_IS_PTR(status)) { ucs_status_t s; while ((s = ucp_request_check_status(status)) == UCS_INPROGRESS) @@ -359,11 +366,12 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, Error while closing endpoint: %s", __func__, this, - _handle, + _originalHandle, ucs_status_string(UCS_PTR_STATUS(status))); } } - ucxx_trace("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _handle); + ucxx_trace( + "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _originalHandle); if (UCS_PTR_IS_PTR(status)) ucp_request_free(status); @@ -373,14 +381,12 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) ucxx_debug("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, calling user close callback", __func__, this, - _handle); + _originalHandle); _closeCallback(_status, _closeCallbackArg); _closeCallback = nullptr; _closeCallbackArg = nullptr; } } - - std::swap(_handle, _originalHandle); } ucp_ep_h Endpoint::getHandle() { return _handle; } From 2bd2e1ddebfdf1b484dd3d64b5234dc2dc08ccf7 Mon Sep 17 00:00:00 2001 From: Horde Date: Tue, 12 May 2026 15:03:58 +0000 Subject: [PATCH 18/18] Scoped inflight request cancelation --- cpp/include/ucxx/inflight_requests.h | 16 +++++++--- cpp/include/ucxx/request.h | 21 +++++++++++++ cpp/src/endpoint.cpp | 47 ++++++++++++++++------------ cpp/src/inflight_requests.cpp | 22 +++++++++++-- cpp/src/request.cpp | 18 +++++++++++ 5 files changed, 97 insertions(+), 27 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 0f9e13a03..3a56d878b 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -96,14 +96,22 @@ class InflightRequests { void remove(const std::shared_ptr& request); /** - * @brief Issue cancelation of all inflight requests and clear the internal container. + * @brief Issue cancelation of inflight requests and clear the internal container. * - * Issue cancelation of all inflight requests known to this object and clear the - * internal container. The total number of canceled requests is returned. + * Issue cancelation of inflight requests known to this object. The total number of + * canceled requests is returned. * + * When `workerOnly` is `true`, only worker-scoped operations + * (`Request::isWorkerOperation() == true`, i.e. receive variants) are cancelled and + * removed; endpoint-scoped operations are left in the container untouched, on the + * assumption that the caller will follow up with `ucp_ep_close_nbx(FORCE)` which + * handles their UCT-level cleanup atomically. When `workerOnly` is `false` (the + * default), all inflight requests are cancelled. + * + * @param[in] workerOnly if `true`, cancel only worker-scoped requests. * @returns The total number of canceled requests. */ - size_t cancelAll(); + size_t cancelAll(bool workerOnly = false); /** * @brief Releases the internally-tracked containers. diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 6ad7607bd..8aefab930 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -127,6 +127,27 @@ class Request : public Component { */ virtual void cancel(); + /** + * @brief Whether this request is worker-scoped (vs endpoint-scoped). + * + * Returns `true` for operations that are tracked by the UCP worker rather than a + * specific UCP endpoint, currently the receive variants (`TagReceive`, + * `TagReceiveWithHandle`, `AmReceive`, `StreamReceive`, `TagMultiReceive`). + * + * `Endpoint::closeBlocking()` uses this to decide which requests must be cancelled + * explicitly via `ucp_request_cancel`: worker-scoped requests are not cancelled by + * `ucp_ep_close_nbx(UCP_EP_CLOSE_FLAG_FORCE)` (which only tears down endpoint-bound + * state) and would otherwise hang forever. Endpoint-scoped requests are left to + * UCX's FORCE close, which handles their UCT-level cleanup atomically, calling + * `ucp_request_cancel` on them in addition to FORCE close has been observed to + * leave UCT-level pending queue entries referencing freed staging buffers, which + * the next `ucp_worker_progress()` then crashes dispatching (see + * test_send_recv_multi.py CUDA segfault). + * + * @returns `true` if this request is worker-scoped, `false` if endpoint-scoped. + */ + [[nodiscard]] bool isWorkerOperation() const; + /** * @brief Return the status of the request. * diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 874872bcf..4dd9ce3f1 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -281,30 +281,37 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) bool submitted = false; for (uint64_t i = 0; i < maxAttempts && !closeSuccess; ++i) { if (!submitted) { - // Cancel inflight requests and submit FORCE close ATOMICALLY in a - // single pre-callback, with no ucp_worker_progress() between them. + // Cancel WORKER-SCOPED inflight requests (tag_recv & friends) and submit + // FORCE close ATOMICALLY in a single pre-callback no + // ucp_worker_progress() between them, and no ucp_request_cancel() on + // endpoint-scoped operations. // - // Why cancel here at all (UCX FORCE close already cancels endpoint - // operations): - // tag_recv requests are worker-scoped (ucp_tag_recv_nbx(worker, ...)), - // not endpoint-scoped, so ucp_ep_close_nbx(FORCE) leaves them pending. - // Without ucp_request_cancel() here, an `await ep.close()` running - // alongside an outstanding `await ep.recv()` would hang forever. - // See test_shutdown.py::test_{server,client}_shutdown. + // Why cancel anything at all (UCX FORCE close handles endpoint state): + // Receive operations (`tag_recv`, `am_recv`, ...) post requests on the + // UCP worker via `ucp_*_recv_nbx(worker, ...)`. ucp_ep_close_nbx(FORCE) + // tears down the UCP endpoint but leaves worker-scoped requests + // pending forever, an `await ep.close()` racing with an outstanding + // `await ep.recv()` would hang. See + // test_shutdown.py::test_{server,client}_shutdown. + // + // Why ONLY worker-scoped (not endpoint-scoped sends/RMA): + // Calling ucp_request_cancel() on an endpoint-scoped tag_send and then + // FORCE-closing the endpoint has been observed to leave UCT-level TCP + // pending queue entries pointing at freed staging buffers; the next + // ucp_worker_progress() then crashed dispatching them + // (uct_tcp_pending_queue_dispatch -> uct_cuda_copy_ep_get_short -> + // cuMemcpyAsync -> SIGSEGV), see test_send_recv_multi.py CUDA segfault. + // FORCE close handles endpoint-scoped requests' UCT cleanup atomically + // on its own, so we leave them alone. // // Why atomic with FORCE close (not as a separate pre-callback): // When cancelAll and FORCE close were separate pre-callbacks (the // old cancelInflightRequestsBlocking path), a full ucp_worker_progress() - // ran between them. That intermediate progress could leave UCT-level - // TCP pending entries half-dispatched (mid-cuMemcpyAsync staging of - // a CUDA send); the next progress after FORCE close then crashed - // dispatching them on a freed staging buffer (uct_cuda_copy_ep_get_short - // -> cuMemcpyAsync -> SIGSEGV). Running them in a single pre-callback - // matches the safe single-threaded ordering proven by the regression - // test in cpp/tests/endpoint_close_force_tcp_cuda_race.cpp. + // ran between them, see prior commit "Cancel inflight requests and + // submit force-close atomically in a single pre-callback". if (!worker->registerGenericPre( [this, &status, ¶m]() { - _inflightRequests->cancelAll(); + _inflightRequests->cancelAll(/*workerOnly=*/true); status = ucp_ep_close_nbx(_handle, ¶m); // Invalidate _handle synchronously immediately, to prevent // time window where _handle` points to freed UCP memory, usually @@ -350,9 +357,9 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) } } else { // No progress thread: cancel inflight + FORCE close back-to-back, then - // drive progress here. Same atomicity reasoning as the progress-thread + // drive progress here. Same atomicity reasoning as the progress-thread // path above (no ucp_worker_progress() between cancel and FORCE close). - _inflightRequests->cancelAll(); + _inflightRequests->cancelAll(/*workerOnly=*/true); status = ucp_ep_close_nbx(_handle, ¶m); _originalHandle = _handle; _handle = nullptr; @@ -476,7 +483,7 @@ size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAtt "cancel inflight requests failed", __func__, this, - _handle); + _originalHandle != nullptr ? _originalHandle : _handle); } else { canceled = _inflightRequests->cancelAll(); } diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 4c9b396e1..76bd2d5f4 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -41,18 +41,34 @@ void InflightRequests::merge(TrackedRequests&& trackedRequests) if (r) _canceling.insert(std::move(r)); } -size_t InflightRequests::cancelAll() +size_t InflightRequests::cancelAll(bool workerOnly) { decltype(_inflight) toCancel; + decltype(_inflight) toKeep; { std::lock_guard lock(_mutex); - toCancel = std::exchange(_inflight, {}); + if (workerOnly) { + // Partition: receive (worker-scoped) requests get explicit ucp_request_cancel, + // endpoint-scoped requests stay in `_inflight` to be cleaned up by FORCE close. + for (auto& r : _inflight) { + if (r && r->isWorkerOperation()) + toCancel.insert(r); + else + toKeep.insert(r); + } + _inflight = std::move(toKeep); + } else { + toCancel = std::exchange(_inflight, {}); + } } size_t total = toCancel.size(); if (total == 0) return 0; - ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests (workerOnly=%d)", + __func__, + total, + workerOnly); for (auto& r : toCancel) { if (r) r->cancel(); diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 1551e4f23..65c97416b 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -71,6 +71,24 @@ Request::~Request() ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } +bool Request::isWorkerOperation() const +{ + // Receive operations go through ucp_tag_recv_nbx / ucp_am_recv_nbx and friends, + // which post requests against the UCP worker rather than a specific endpoint. + // ucp_ep_close_nbx(FORCE) does not cancel them, so closeBlocking() must do so + // explicitly. Send / RMA operations are endpoint-bound and are cleaned up by + // FORCE close itself. + return std::visit(data::dispatch{ + [](const data::TagReceive&) { return true; }, + [](const data::TagReceiveWithHandle&) { return true; }, + [](const data::TagMultiReceive&) { return true; }, + [](const data::AmReceive&) { return true; }, + [](const data::StreamReceive&) { return true; }, + [](const auto&) { return false; }, + }, + _requestData); +} + void Request::cancel() { std::lock_guard lock(_mutex);