diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py
index 95616ceab..dd989cc88 100644
--- a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py
+++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py
@@ -1,6 +1,7 @@
 import types
 from functools import wraps
 from types import TracebackType
+from typing import NoReturn
 
 import click
 
@@ -13,12 +14,21 @@ def format_message(self) -> str:
 
 
 def async_handle_exceptions(func):
-    """Decorator to handle exceptions in async functions."""
+    """Decorator to handle exceptions in async functions, including those wrapped in BaseExceptionGroup."""
 
     @wraps(func)
     async def wrapped(*args, **kwargs):
         try:
             return await func(*args, **kwargs)
+        except BaseExceptionGroup as eg:
+            # Handle exceptions wrapped in ExceptionGroup (e.g., from task groups)
+            for exc in leaf_exceptions(eg, fix_tracebacks=False):
+                if isinstance(exc, JumpstarterException):
+                    raise ClickExceptionRed(str(exc)) from None
+                elif isinstance(exc, click.ClickException):
+                    raise exc from None
+            # If no handled exceptions, re-raise the original group
+            raise eg
         except JumpstarterException as e:
             raise ClickExceptionRed(str(e)) from None
         except click.ClickException:
@@ -46,26 +56,48 @@ def wrapped(*args, **kwargs):
     return wrapped
 
 
+def _handle_connection_error_with_reauth(exc, login_func):
+    """Handle ConnectionError with reauthentication logic."""
+    if "expired" in str(exc).lower():
+        click.echo(click.style("Token is expired, triggering re-authentication", fg="red"))
+        config = exc.get_config()
+        login_func(config)
+        raise ClickExceptionRed("Please try again now") from None
+    else:
+        raise ClickExceptionRed(str(exc)) from None
+
+
+def _handle_single_exception_with_reauth(exc, login_func):
+    """Handle a single exception (may raise)."""
+    if isinstance(exc, ConnectionError):
+        _handle_connection_error_with_reauth(exc, login_func)
+    elif isinstance(exc, JumpstarterException):
+        raise ClickExceptionRed(str(exc)) from None
+    elif isinstance(exc, click.ClickException):
+        raise exc from None
+    # Not handled: fall through
+
+
+def _handle_exception_group_with_reauth(eg, login_func) -> NoReturn:
+    """Handle exceptions wrapped in BaseExceptionGroup."""
+    for exc in leaf_exceptions(eg, fix_tracebacks=False):
+        _handle_single_exception_with_reauth(exc, login_func)
+    # If no handled exceptions, re-raise the original group
+    raise eg
+
+
 def handle_exceptions_with_reauthentication(login_func):
-    """Decorator to handle exceptions in blocking functions."""
+    """Decorator to handle exceptions in blocking functions, including those wrapped in BaseExceptionGroup."""
 
     def decorator(func):
         @wraps(func)
         def wrapped(*args, **kwargs):
             try:
                 return func(*args, **kwargs)
-            except ConnectionError as e:
-                if "expired" in str(e).lower():
-                    click.echo(click.style("Token is expired, triggering re-authentication", fg="red"))
-                    config = e.get_config()
-                    login_func(config)
-                    raise ClickExceptionRed("Please try again now") from None
-                else:
-                    raise ClickExceptionRed(str(e)) from None
-            except JumpstarterException as e:
-                raise ClickExceptionRed(str(e)) from None
-            except click.ClickException:
-                raise  # if it was already a click exception from the cli commands, just re-raise it
+            except BaseExceptionGroup as eg:
+                _handle_exception_group_with_reauth(eg, login_func)
+            except (ConnectionError, JumpstarterException, click.ClickException) as e:
+                _handle_single_exception_with_reauth(e, login_func)
             except Exception:
                 raise
 
@@ -74,7 +106,7 @@ def wrapped(*args, **kwargs):
     return decorator
 
 
-# https://peps.python.org/pep-0785/#reference-implementation
+# Adapted from the PEP 785 reference implementation: https://peps.python.org/pep-0785/#reference-implementation
 def leaf_exceptions(self: BaseExceptionGroup, *, fix_tracebacks: bool = True) -> list[BaseException]:
     """
     Return a flat list of all 'leaf' exceptions.
diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py
index 1a696d393..4e460dd71 100644
--- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py
+++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py
@@ -1,9 +1,12 @@
 import sys
 from datetime import timedelta
 
+import anyio
 import click
+from anyio import create_task_group, get_cancelled_exc_class
 from jumpstarter_cli_common.config import opt_config
 from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication
+from jumpstarter_cli_common.signal import signal_handler
 
 from .common import opt_duration_partial, opt_selector
 from .login import relogin_client
@@ -12,6 +15,52 @@
 from jumpstarter.config.exporter import ExporterConfigV1Alpha1
 
 
+def _run_shell_with_lease(lease, exporter_logs, config, command):
+    """Run shell with lease context managers."""
+    def launch_remote_shell(path: str) -> int:
+        return launch_shell(
+            path, lease.exporter_name, config.drivers.allow, config.drivers.unsafe,
+            config.shell.use_profiles, command=command
+        )
+
+    with lease.serve_unix() as path:
+        with lease.monitor():
+            if exporter_logs:
+                with lease.connect() as client:
+                    with client.log_stream():
+                        return launch_remote_shell(path)
+            else:
+                return launch_remote_shell(path)
+
+
+async def _shell_with_signal_handling(config, selector, lease_name, duration, exporter_logs, command):
+    """Handle lease acquisition and shell execution with signal handling."""
+    exit_code = 0
+    cancelled_exc_class = get_cancelled_exc_class()
+
+    async with create_task_group() as tg:
+        tg.start_soon(signal_handler, tg.cancel_scope)
+        try:
+            try:
+                async with anyio.from_thread.BlockingPortal() as portal:
+                    async with config.lease_async(selector, lease_name, duration, portal) as lease:
+                        exit_code = await anyio.to_thread.run_sync(
+                            _run_shell_with_lease, lease, exporter_logs, config, command
+                        )
+            except BaseExceptionGroup as eg:
+                for exc in eg.exceptions:
+                    if isinstance(exc, TimeoutError):
+                        raise exc from None
+                raise
+        except cancelled_exc_class:
+            exit_code = 2
+        finally:
+            if not tg.cancel_scope.cancel_called:
+                tg.cancel_scope.cancel()
+
+    return exit_code
+
+
 @click.command("shell")
 @opt_config()
 @click.argument("command", nargs=-1)
@@ -38,27 +87,9 @@ def shell(config, command: tuple[str, ...], lease_name, selector, duration, expo
 
     match config:
         case ClientConfigV1Alpha1():
-            exit_code = 0
-            def _launch_remote_shell(path: str) -> int:
-                return launch_shell(
-                    path,
-                    "remote",
-                    config.drivers.allow,
-                    config.drivers.unsafe,
-                    config.shell.use_profiles,
-                    command=command,
-                )
-
-            with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease:
-                with lease.serve_unix() as path:
-                    with lease.monitor():
-                        if exporter_logs:
-                            with lease.connect() as client:
-                                with client.log_stream():
-                                    exit_code = _launch_remote_shell(path)
-                        else:
-                            exit_code = _launch_remote_shell(path)
-            # we exit here to make sure that all the with clauses unwind
+            exit_code = anyio.run(
+                _shell_with_signal_handling, config, selector, lease_name, duration, exporter_logs, command
+            )
             sys.exit(exit_code)
 
         case ExporterConfigV1Alpha1():
diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py
index 953d30867..08af4b20b 100644
--- a/packages/jumpstarter/jumpstarter/client/lease.py
+++ b/packages/jumpstarter/jumpstarter/client/lease.py
@@ -9,7 +9,14 @@
 from datetime import datetime, timedelta
 from typing import Any, Self
 
-from anyio import AsyncContextManagerMixin, ContextManagerMixin, create_task_group, fail_after, sleep
+from anyio import (
+    AsyncContextManagerMixin,
+    CancelScope,
+    ContextManagerMixin,
+    create_task_group,
+    fail_after,
+    sleep,
+)
 from anyio.from_thread import BlockingPortal
 from grpc.aio import Channel
 from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc
@@ -40,6 +47,8 @@ class Lease(ContextManagerMixin, AsyncContextManagerMixin):
     controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False)
     tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
     grpc_options: dict[str, Any] = field(default_factory=dict)
+    acquisition_timeout: int = field(default=7200)  # Timeout in seconds for lease acquisition, polled in 5s intervals
+    exporter_name: str = field(default="remote", init=False)  # Populated during acquisition
 
     def __post_init__(self):
         if hasattr(super(), "__post_init__"):
@@ -57,7 +66,7 @@ async def _create(self):
                     duration=self.duration,
                 )
             ).name
-        logger.info("Created lease request for selector %s for duration %s", self.selector, self.duration)
+        logger.info("Acquiring lease %s for selector %s for duration %s", self.name, self.selector, self.duration)
 
     async def get(self):
         with translate_grpc_exceptions():
@@ -99,6 +108,7 @@ async def request_async(self):
                 await self._create()
         else:
             await self._create()
+
         return await self._acquire()
 
     async def _acquire(self):
@@ -106,47 +116,62 @@ async def _acquire(self):
 
         Makes sure the lease is ready, and returns the lease object.
         """
-        with fail_after(300):  # TODO: configurable timeout
-            while True:
-                logger.debug("Polling Lease %s", self.name)
-                result = await self.get()
-                # lease ready
-                if condition_true(result.conditions, "Ready"):
-                    logger.debug("Lease %s acquired", self.name)
-                    return self
-                # lease unsatisfiable
-                if condition_true(result.conditions, "Unsatisfiable"):
-                    message = condition_message(result.conditions, "Unsatisfiable")
-                    logger.debug(
-                        "Lease %s cannot be satisfied: %s",
-                        self.name,
-                        condition_message(result.conditions, "Unsatisfiable"),
-                    )
-                    raise LeaseError(f"the lease cannot be satisfied: {message}")
-
-                # lease not pending
-                if condition_false(result.conditions, "Pending"):
-                    raise LeaseError(
-                        f"Lease {self.name} is not in pending, but it isn't in Ready or Unsatisfiable state either"
-                    )
-
-                # lease released
-                if condition_present_and_equal(result.conditions, "Ready", "False", "Released"):
-                    raise LeaseError(f"lease {self.name} released")
-
-                await sleep(1)
+        try:
+            with fail_after(self.acquisition_timeout):
+                while True:
+                    logger.debug("Polling Lease %s", self.name)
+                    result = await self.get()
+                    # lease ready
+                    if condition_true(result.conditions, "Ready"):
+                        logger.debug("Lease %s acquired", self.name)
+                        self.exporter_name = result.exporter
+                        return self
+                    # lease unsatisfiable
+                    if condition_true(result.conditions, "Unsatisfiable"):
+                        message = condition_message(result.conditions, "Unsatisfiable")
+                        logger.debug("Lease %s cannot be satisfied: %s", self.name, message)
+                        raise LeaseError(f"the lease cannot be satisfied: {message}")
+
+                    # lease invalid
+                    if condition_true(result.conditions, "Invalid"):
+                        message = condition_message(result.conditions, "Invalid")
+                        logger.debug("Lease %s is invalid: %s", self.name, message)
+                        raise LeaseError(f"the lease is invalid: {message}")
+
+                    # lease not pending
+                    if condition_false(result.conditions, "Pending"):
+                        raise LeaseError(
+                            f"Lease {self.name} is not in pending, but it isn't in Ready or Unsatisfiable state either"
+                        )
+
+                    # lease released
+                    if condition_present_and_equal(result.conditions, "Ready", "False", "Released"):
+                        raise LeaseError(f"lease {self.name} released")
+
+                    await sleep(5)
+        except TimeoutError:
+            logger.debug("Lease %s acquisition timed out after %s seconds", self.name, self.acquisition_timeout)
+            raise LeaseError(
+                f"lease {self.name} acquisition timed out after {self.acquisition_timeout} seconds"
+            ) from None
 
     @asynccontextmanager
     async def __asynccontextmanager__(self) -> AsyncGenerator[Self]:
-        value = await self.request_async()
         try:
+            value = await self.request_async()
             yield value
         finally:
-            if self.release:
+            if self.release and self.name:
                 logger.info("Releasing Lease %s", self.name)
-                await self.svc.DeleteLease(
-                    name=self.name,
-                )
+                # Shield cleanup from cancellation to ensure it completes
+                with CancelScope(shield=True):
+                    try:
+                        with fail_after(30):
+                            await self.svc.DeleteLease(
+                                name=self.name,
+                            )
+                    except TimeoutError:
+                        logger.warning("Timeout while deleting lease %s during cleanup", self.name)
 
     @contextmanager
     def __contextmanager__(self) -> Generator[Self]:
diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py
index b960bfdb0..8fb3cc67f 100644
--- a/packages/jumpstarter/jumpstarter/common/utils.py
+++ b/packages/jumpstarter/jumpstarter/common/utils.py
@@ -59,7 +59,7 @@ def launch_shell(
 
     Args:
         host: The jumpstarter host path
-        context: The context of the shell ("local" or "remote")
+        context: The context of the shell (e.g. "local" or exporter name)
         allow: List of allowed drivers
         unsafe: Whether to allow drivers outside of the allow list
     """
diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py
index 9872c3c7e..fc4c3d1c2 100644
--- a/packages/jumpstarter/jumpstarter/config/client.py
+++ b/packages/jumpstarter/jumpstarter/config/client.py
@@ -86,6 +86,16 @@ def decode_unsafe(self) -> Self:
         return self
 
 
+class ClientConfigV1Alpha1Lease(BaseSettings):
+    """Configuration for lease operations."""
+
+    acquisition_timeout: int = Field(
+        default=7200,
+        description="Timeout in seconds for lease acquisition",
+        ge=5,  # Must be at least 5 seconds (polling interval)
+    )
+
+
 class ClientConfigV1Alpha1(BaseSettings):
     CLIENT_CONFIGS_PATH: ClassVar[Path] = CONFIG_PATH / "clients"
 
@@ -108,6 +118,8 @@ class ClientConfigV1Alpha1(BaseSettings):
 
     shell: ShellConfigV1Alpha1 = Field(default_factory=ShellConfigV1Alpha1)
 
+    leases: ClientConfigV1Alpha1Lease = Field(default_factory=ClientConfigV1Alpha1Lease)
+
     async def channel(self):
         if self.endpoint is None or self.token is None:
             raise ConfigurationError("endpoint or token not set in client config")
@@ -258,6 +270,7 @@ async def lease_async(
             release=release_lease,
             tls_config=self.tls,
             grpc_options=self.grpcOptions,
+            acquisition_timeout=self.leases.acquisition_timeout,
         ) as lease:
             yield lease
 
diff --git a/packages/jumpstarter/jumpstarter/config/client_config_test.py b/packages/jumpstarter/jumpstarter/config/client_config_test.py
index 3dbcdabf1..a82ed01f6 100644
--- a/packages/jumpstarter/jumpstarter/config/client_config_test.py
+++ b/packages/jumpstarter/jumpstarter/config/client_config_test.py
@@ -214,6 +214,8 @@ def test_client_config_save(monkeypatch: pytest.MonkeyPatch):
   unsafe: false
 shell:
   use_profiles: false
+leases:
+  acquisition_timeout: 7200
 """
     config = ClientConfigV1Alpha1(
         alias="testclient",
@@ -253,6 +255,8 @@ def test_client_config_save_explicit_path():
   unsafe: false
 shell:
   use_profiles: false
+leases:
+  acquisition_timeout: 7200
 """
     config = ClientConfigV1Alpha1(
         alias="testclient",
@@ -288,6 +292,8 @@ def test_client_config_save_unsafe_drivers():
   unsafe: true
 shell:
   use_profiles: false
+leases:
+  acquisition_timeout: 7200
 """
     config = ClientConfigV1Alpha1(
         alias="testclient",