diff --git a/python/ucxx/ucxx/_lib_async/tests/test_multiple_nodes.py b/python/ucxx/ucxx/_lib_async/tests/test_multiple_nodes.py index a9d6de35..91b42f47 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_multiple_nodes.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_multiple_nodes.py @@ -10,6 +10,11 @@ from ucxx._lib_async.utils_test import wait_listener_client_handlers +DEFAULT_CONNECT_TIMEOUT = 10.0 +MANY_ENDPOINTS_CONNECT_TIMEOUT = 30.0 +MANY_ENDPOINTS_THRESHOLD = 50 + + def get_somaxconn(): with open("/proc/sys/net/core/somaxconn", "r") as f: return int(f.readline()) @@ -31,8 +36,10 @@ async def server_node(ep): # assert isinstance(ep.ucx_info(), str) -async def client_node(port): - ep = await ucxx.create_endpoint(ucxx.get_address(), port, connect_timeout=10.0) +async def client_node(port, connect_timeout): + ep = await ucxx.create_endpoint( + ucxx.get_address(), port, connect_timeout=connect_timeout + ) await hello(ep) await ep.close() # assert isinstance(ep.ucx_info(), str) @@ -51,18 +58,33 @@ async def client_node(port): ) async def test_many_servers_many_clients(num_servers, num_clients): somaxconn = get_somaxconn() + num_endpoints = num_clients * num_servers + connect_timeout = ( + MANY_ENDPOINTS_CONNECT_TIMEOUT + if num_endpoints >= MANY_ENDPOINTS_THRESHOLD + else DEFAULT_CONNECT_TIMEOUT + ) listeners = [] for _ in range(num_servers): - listeners.append(ucxx.create_listener(server_node, connect_timeout=10.0)) + listeners.append( + ucxx.create_listener(server_node, connect_timeout=connect_timeout) + ) # We ensure no more than `somaxconn` connections are submitted # at once. Doing otherwise can block and hang indefinitely. - for i in range(0, num_clients * num_servers, somaxconn): + for batch_start in range(0, num_endpoints, somaxconn): clients = [] - for __ in range(i, min(i + somaxconn, num_clients * num_servers)): - clients.append(client_node(listeners[__ % num_servers].port)) + for endpoint_index in range( + batch_start, min(batch_start + somaxconn, num_endpoints) + ): + clients.append( + client_node( + listeners[endpoint_index % num_servers].port, + connect_timeout=connect_timeout, + ) + ) await asyncio.gather(*clients) await asyncio.gather( *(wait_listener_client_handlers(listener) for listener in listeners)