From 5ccaf665f9182dcacf6f050196e299e44155062f Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Mar 2025 10:32:35 -0400 Subject: [PATCH 1/2] Make ssl_channel_credentials async --- packages/jumpstarter/jumpstarter/common/grpc.py | 5 +++-- packages/jumpstarter/jumpstarter/common/streams.py | 2 +- packages/jumpstarter/jumpstarter/config/client.py | 2 +- packages/jumpstarter/jumpstarter/config/exporter.py | 4 ++-- packages/jumpstarter/jumpstarter/exporter/exporter.py | 8 ++++---- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/grpc.py b/packages/jumpstarter/jumpstarter/common/grpc.py index a2924db73..060846b3a 100644 --- a/packages/jumpstarter/jumpstarter/common/grpc.py +++ b/packages/jumpstarter/jumpstarter/common/grpc.py @@ -7,11 +7,12 @@ from urllib.parse import urlparse import grpc +from anyio.to_thread import run_sync from jumpstarter.common.exceptions import ConfigurationError, ConnectionError -def ssl_channel_credentials(target: str, tls_config): +async def ssl_channel_credentials(target: str, tls_config): configure_grpc_env() if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1": try: @@ -21,7 +22,7 @@ def ssl_channel_credentials(target: str, tls_config): raise ConfigurationError(f"Failed parsing {target}") from e try: - root_certificates = ssl.get_server_certificate((parsed.hostname, port)) + root_certificates = await run_sync(ssl.get_server_certificate, (parsed.hostname, port)) return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode()) except socket.gaierror as e: raise ConnectionError(f"Failed resolving {parsed.hostname}") from e diff --git a/packages/jumpstarter/jumpstarter/common/streams.py b/packages/jumpstarter/jumpstarter/common/streams.py index ebd3f2f7f..cddd9f374 100644 --- a/packages/jumpstarter/jumpstarter/common/streams.py +++ b/packages/jumpstarter/jumpstarter/common/streams.py @@ -34,7 +34,7 @@ class StreamRequestMetadata(BaseModel): @asynccontextmanager async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options): credentials = grpc.composite_channel_credentials( - ssl_channel_credentials(endpoint, tls_config), + await ssl_channel_credentials(endpoint, tls_config), grpc.access_token_call_credentials(token), ) diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index 7fba240fe..6c1841d28 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -54,7 +54,7 @@ class ClientConfigV1Alpha1(BaseModel): async def channel(self): credentials = grpc.composite_channel_credentials( - ssl_channel_credentials(self.endpoint, self.tls), + await ssl_channel_credentials(self.endpoint, self.tls), call_credentials("Client", self.metadata, self.token), ) diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index f4240bc67..d8e2f45f0 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -159,9 +159,9 @@ async def serve(self): # dynamic import to avoid circular imports from jumpstarter.exporter import Exporter - def channel_factory(): + async def channel_factory(): credentials = grpc.composite_channel_credentials( - ssl_channel_credentials(self.endpoint, self.tls), + await ssl_channel_credentials(self.endpoint, self.tls), call_credentials("Exporter", self.metadata, self.token), ) return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index e4ff3f865..a901b4c85 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -29,7 +29,7 @@ class Exporter(AbstractAsyncContextManager, Metadata): grpc_options: dict[str, str] = field(default_factory=dict) async def __aexit__(self, exc_type, exc_value, traceback): - controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory()) + controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) logger.info("Unregistering exporter with controller") await controller.Unregister( jumpstarter_pb2.UnregisterRequest( @@ -47,7 +47,7 @@ async def __handle(self, path, endpoint, token, tls_config, grpc_options): @asynccontextmanager async def session(self): - controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory()) + controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) with Session( uuid=self.uuid, labels=self.labels, @@ -76,7 +76,7 @@ async def listen(retries=5, backoff=3): retries_left = retries while True: try: - controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory()) + controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)): await listen_tx.send(request) except Exception as e: @@ -113,7 +113,7 @@ async def status(retries=5, backoff=3): retries_left = retries while True: try: - controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory()) + controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) async for status in controller.Status(jumpstarter_pb2.StatusRequest()): await status_tx.send(status) except Exception as e: From 3907ba7395dc501458c57757164f315465d186bd Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Mar 2025 10:38:17 -0400 Subject: [PATCH 2/2] Add timeout to ssl_channel_credentials --- packages/jumpstarter/jumpstarter/common/grpc.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/grpc.py b/packages/jumpstarter/jumpstarter/common/grpc.py index 060846b3a..f0b59bc99 100644 --- a/packages/jumpstarter/jumpstarter/common/grpc.py +++ b/packages/jumpstarter/jumpstarter/common/grpc.py @@ -7,12 +7,13 @@ from urllib.parse import urlparse import grpc +from anyio import fail_after from anyio.to_thread import run_sync from jumpstarter.common.exceptions import ConfigurationError, ConnectionError -async def ssl_channel_credentials(target: str, tls_config): +async def ssl_channel_credentials(target: str, tls_config, timeout=5): configure_grpc_env() if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1": try: @@ -22,12 +23,15 @@ async def ssl_channel_credentials(target: str, tls_config): raise ConfigurationError(f"Failed parsing {target}") from e try: - root_certificates = await run_sync(ssl.get_server_certificate, (parsed.hostname, port)) + with fail_after(timeout): + root_certificates = await run_sync(ssl.get_server_certificate, (parsed.hostname, port)) return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode()) except socket.gaierror as e: raise ConnectionError(f"Failed resolving {parsed.hostname}") from e except ConnectionRefusedError as e: raise ConnectionError(f"Failed connecting to {parsed.hostname}:{port}") from e + except TimeoutError as e: + raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e elif tls_config.ca != "": ca_certificate = base64.b64decode(tls_config.ca)