From f9dbbad1b3ea2be51c457145508802e372252130 Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Wed, 19 Mar 2025 17:30:31 +0100 Subject: [PATCH] Allow grpc option tweaking, i.e. keepalive settings --- docs/source/cli/clients.md | 3 +++ docs/source/introduction/exporters.md | 3 +++ .../jumpstarter/jumpstarter/client/lease.py | 6 ++++- .../jumpstarter/jumpstarter/common/grpc.py | 25 +++++++++++++------ .../jumpstarter/jumpstarter/common/streams.py | 4 +-- .../jumpstarter/jumpstarter/config/client.py | 5 +++- .../jumpstarter/config/client_config_test.py | 3 +++ .../jumpstarter/config/exporter.py | 4 ++- .../jumpstarter/exporter/exporter.py | 9 ++++--- 9 files changed, 46 insertions(+), 16 deletions(-) diff --git a/docs/source/cli/clients.md b/docs/source/cli/clients.md index 98354c03e..b4d1ff377 100644 --- a/docs/source/cli/clients.md +++ b/docs/source/cli/clients.md @@ -32,6 +32,9 @@ metadata: name: john endpoint: grpc.jumpstarter.192.168.1.10.nip.io:8082 token: <> +grpcOptions: + # Please refer to the https://grpc.github.io/grpc/core/group__grpc__arg__keys.html documentation + grpc.keepalive_time_ms: 20000 tls: ca: '' insecure: False diff --git a/docs/source/introduction/exporters.md b/docs/source/introduction/exporters.md index 2937b26b6..0334939b8 100644 --- a/docs/source/introduction/exporters.md +++ b/docs/source/introduction/exporters.md @@ -30,6 +30,9 @@ metadata: name: demo endpoint: grpc.jumpstarter.example.com:443 token: xxxxx +grpcOptions: + # Please refer to the https://grpc.github.io/grpc/core/group__grpc__arg__keys.html documentation + grpc.keepalive_time_ms: 20000 export: power: type: jumpstarter_driver_yepkit.driver.Ykush diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index b65b3e458..3ae519678 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -8,6 +8,7 @@ ) from dataclasses import dataclass, field from 
datetime import datetime, timedelta +from typing import Any from anyio import create_task_group, fail_after, sleep from anyio.from_thread import BlockingPortal @@ -39,6 +40,7 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager): release: bool = True # release on contexts exit controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False) tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) + grpc_options: dict[str, Any] = field(default_factory=dict) def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -149,7 +151,9 @@ def __exit__(self, exc_type, exc_value, traceback): async def handle_async(self, stream): logger.debug("Connecting to Lease with name %s", self.name) response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) - async with connect_router_stream(response.router_endpoint, response.router_token, stream, self.tls_config): + async with connect_router_stream( + response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options + ): pass @asynccontextmanager diff --git a/packages/jumpstarter/jumpstarter/common/grpc.py b/packages/jumpstarter/jumpstarter/common/grpc.py index 8536be200..c686b2826 100644 --- a/packages/jumpstarter/jumpstarter/common/grpc.py +++ b/packages/jumpstarter/jumpstarter/common/grpc.py @@ -3,6 +3,7 @@ import socket import ssl from contextlib import contextmanager +from typing import Any, Sequence, Tuple from urllib.parse import urlparse import grpc @@ -34,20 +35,28 @@ def ssl_channel_credentials(target: str, tls_config): return grpc.ssl_channel_credentials() -def aio_secure_channel(target: str, credentials: grpc.ChannelCredentials): +def aio_secure_channel(target: str, credentials: grpc.ChannelCredentials, grpc_options: dict[str, Any] | None): return grpc.aio.secure_channel( target, credentials, - options=( - ("grpc.lb_policy_name", "round_robin"), - ("grpc.keepalive_time_ms", 350000), - ("grpc.keepalive_timeout_ms", 5000), - 
("grpc.http2.max_pings_without_data", 5), - ("grpc.keepalive_permit_without_calls", 1), - ), + options=_override_default_grpc_options(grpc_options), ) +def _override_default_grpc_options(grpc_options: dict[str, str | int] | None) -> Sequence[Tuple[str, Any]]: + defaults = ( + ("grpc.lb_policy_name", "round_robin"), + # we keep a low keepalive time to avoid idle timeouts on cloud load balancers + ("grpc.keepalive_time_ms", 20000), + ("grpc.keepalive_timeout_ms", 5000), + ("grpc.http2.max_pings_without_data", 5), + ("grpc.keepalive_permit_without_calls", 1), + ) + options = dict(defaults) + options.update(grpc_options or {}) + return tuple(options.items()) + + def configure_grpc_env(): # disable informative logs by default, i.e.: # WARNING: All log messages before absl::InitializeLog() is called are written to STDERR diff --git a/packages/jumpstarter/jumpstarter/common/streams.py b/packages/jumpstarter/jumpstarter/common/streams.py index d1fa265d2..ebd3f2f7f 100644 --- a/packages/jumpstarter/jumpstarter/common/streams.py +++ b/packages/jumpstarter/jumpstarter/common/streams.py @@ -32,13 +32,13 @@ class StreamRequestMetadata(BaseModel): @asynccontextmanager -async def connect_router_stream(endpoint, token, stream, tls_config): +async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options): credentials = grpc.composite_channel_credentials( ssl_channel_credentials(endpoint, tls_config), grpc.access_token_call_credentials(token), ) - async with aio_secure_channel(endpoint, credentials) as channel: + async with aio_secure_channel(endpoint, credentials, grpc_options) as channel: router = router_pb2_grpc.RouterServiceStub(channel) context = router.Stream(metadata=()) async with RouterStream(context=context) as s: diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index 341858fbc..6e0346339 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ 
b/packages/jumpstarter/jumpstarter/config/client.py @@ -48,6 +48,7 @@ class ClientConfigV1Alpha1(BaseModel): endpoint: str tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1) token: str + grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) drivers: ClientConfigV1Alpha1Drivers @@ -57,7 +58,7 @@ async def channel(self): call_credentials("Client", self.metadata, self.token), ) - return aio_secure_channel(self.endpoint, credentials) + return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) @contextmanager def lease(self, metadata_filter: MetadataFilter, lease_name: str | None = None): @@ -122,6 +123,7 @@ async def request_lease_async( allow=self.drivers.allow, unsafe=self.drivers.unsafe, tls_config=self.tls, + grpc_options=self.grpcOptions, ) with translate_grpc_exceptions(): return await lease.request_async() @@ -161,6 +163,7 @@ async def lease_async( unsafe=self.drivers.unsafe, release=release_lease, tls_config=self.tls, + grpc_options=self.grpcOptions, ) as lease: yield lease diff --git a/packages/jumpstarter/jumpstarter/config/client_config_test.py b/packages/jumpstarter/jumpstarter/config/client_config_test.py index 353c3b07f..a83d0d89e 100644 --- a/packages/jumpstarter/jumpstarter/config/client_config_test.py +++ b/packages/jumpstarter/jumpstarter/config/client_config_test.py @@ -206,6 +206,7 @@ def test_client_config_save(monkeypatch: pytest.MonkeyPatch): ca: '' insecure: false token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz +grpcOptions: {} drivers: allow: - jumpstarter.drivers.* @@ -241,6 +242,7 @@ def test_client_config_save_explicit_path(): ca: '' insecure: false token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz +grpcOptions: {} drivers: allow: - jumpstarter.drivers.* @@ -274,6 +276,7 @@ def test_client_config_save_unsafe_drivers(): ca: '' insecure: false token: 
dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz +grpcOptions: {} drivers: allow: [] unsafe: true diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index 18a0e5738..f4240bc67 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -81,6 +81,7 @@ class ExporterConfigV1Alpha1(BaseModel): endpoint: str tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1) token: str + grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) export: dict[str, ExporterConfigV1Alpha1DriverInstance] = Field(default_factory=dict) @@ -163,12 +164,13 @@ def channel_factory(): ssl_channel_credentials(self.endpoint, self.tls), call_credentials("Exporter", self.metadata, self.token), ) - return aio_secure_channel(self.endpoint, credentials) + return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) async with Exporter( channel_factory=channel_factory, device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, tls=self.tls, + grpc_options=self.grpcOptions, ) as exporter: await exporter.serve() diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index ff3963d32..428292f89 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -26,6 +26,7 @@ class Exporter(AbstractAsyncContextManager, Metadata): device_factory: Callable[[], Driver] lease_name: str = field(init=False, default="") tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) + grpc_options: dict[str, str] = field(default_factory=dict) async def __aexit__(self, exc_type, exc_value, traceback): controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory()) @@ -36,9 +37,9 @@ async def __aexit__(self, exc_type, exc_value, traceback): 
) ) - async def __handle(self, path, endpoint, token, tls_config): + async def __handle(self, path, endpoint, token, tls_config, grpc_options): async with await connect_unix(path) as stream: - async with connect_router_stream(endpoint, token, stream, tls_config): + async with connect_router_stream(endpoint, token, stream, tls_config, grpc_options): pass @asynccontextmanager @@ -69,7 +70,9 @@ async def handle(self, lease_name, tg): async with self.session() as path: async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)): logger.info("Handling new connection request on lease %s", lease_name) - tg.start_soon(self.__handle, path, request.router_endpoint, request.router_token, self.tls) + tg.start_soon( + self.__handle, path, request.router_endpoint, request.router_token, self.tls, self.grpc_options + ) async def serve(self): controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())