Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions python/ucxx/ucxx/_lib_async/tests/test_multiple_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from ucxx._lib_async.utils_test import wait_listener_client_handlers


# Connect timeout handed to `ucxx.create_listener` / `ucxx.create_endpoint`
# when the test creates fewer than MANY_ENDPOINTS_THRESHOLD endpoints.
# NOTE(review): presumably expressed in seconds -- confirm against the ucxx API.
DEFAULT_CONNECT_TIMEOUT = 10.0
# Longer connect timeout selected once the endpoint count reaches the threshold.
MANY_ENDPOINTS_CONNECT_TIMEOUT = 30.0
# Total endpoint count (num_clients * num_servers) at or above which the
# longer timeout is used instead of the default one.
MANY_ENDPOINTS_THRESHOLD = 50


def get_somaxconn(path="/proc/sys/net/core/somaxconn"):
    """Return the maximum listen backlog configured for the system.

    Parameters
    ----------
    path : str, optional
        File whose first line holds the integer value. Defaults to the
        Linux procfs entry for ``net.core.somaxconn``.

    Returns
    -------
    int
        The value parsed from the first line of ``path``, or
        ``socket.SOMAXCONN`` when ``path`` does not exist (for example on
        platforms without procfs).

    Raises
    ------
    ValueError
        If the first line of ``path`` is not a valid integer.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import socket

    try:
        with open(path, "r", encoding="utf-8") as f:
            return int(f.readline())
    except FileNotFoundError:
        # No procfs entry: fall back to the platform's compile-time limit
        # instead of failing before the test body even starts.
        return socket.SOMAXCONN
Expand All @@ -31,8 +36,10 @@ async def server_node(ep):
# assert isinstance(ep.ucx_info(), str)


async def client_node(port):
ep = await ucxx.create_endpoint(ucxx.get_address(), port, connect_timeout=10.0)
async def client_node(port, connect_timeout):
    """Connect to a listener on this host, run ``hello``, then close.

    Parameters
    ----------
    port
        Port of the listener to connect to.
    connect_timeout
        Forwarded unchanged as ``connect_timeout`` to
        ``ucxx.create_endpoint``.
    """
    ep = await ucxx.create_endpoint(
        ucxx.get_address(), port, connect_timeout=connect_timeout
    )
    try:
        await hello(ep)
    finally:
        # Close the endpoint even when the exchange fails, so a failing
        # client does not leave an open endpoint behind for later tests.
        await ep.close()
    # assert isinstance(ep.ucx_info(), str)
Expand All @@ -51,18 +58,33 @@ async def client_node(port):
)
async def test_many_servers_many_clients(num_servers, num_clients):
somaxconn = get_somaxconn()
num_endpoints = num_clients * num_servers
connect_timeout = (
MANY_ENDPOINTS_CONNECT_TIMEOUT
if num_endpoints >= MANY_ENDPOINTS_THRESHOLD
else DEFAULT_CONNECT_TIMEOUT
)

listeners = []

for _ in range(num_servers):
listeners.append(ucxx.create_listener(server_node, connect_timeout=10.0))
listeners.append(
ucxx.create_listener(server_node, connect_timeout=connect_timeout)
)

# We ensure no more than `somaxconn` connections are submitted
# at once. Doing otherwise can block and hang indefinitely.
for i in range(0, num_clients * num_servers, somaxconn):
for batch_start in range(0, num_endpoints, somaxconn):
clients = []
for __ in range(i, min(i + somaxconn, num_clients * num_servers)):
clients.append(client_node(listeners[__ % num_servers].port))
for endpoint_index in range(
batch_start, min(batch_start + somaxconn, num_endpoints)
):
clients.append(
client_node(
listeners[endpoint_index % num_servers].port,
connect_timeout=connect_timeout,
)
)
await asyncio.gather(*clients)
await asyncio.gather(
*(wait_listener_client_handlers(listener) for listener in listeners)
Expand Down
Loading