diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index e28a052d4..f405b9f49 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -20,7 +20,7 @@ def _run_shell_with_lease(lease, exporter_logs, config, command): def launch_remote_shell(path: str) -> int: return launch_shell( path, lease.exporter_name, config.drivers.allow, config.drivers.unsafe, - config.shell.use_profiles, command=command + config.shell.use_profiles, command=command, lease=lease ) with lease.serve_unix() as path: diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index fc61d45f6..36a5c3513 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -1,7 +1,7 @@ import logging import os import sys -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import ( ExitStack, asynccontextmanager, @@ -54,6 +54,9 @@ class Lease(ContextManagerMixin, AsyncContextManagerMixin): grpc_options: dict[str, Any] = field(default_factory=dict) acquisition_timeout: int = field(default=7200) # Timeout in seconds for lease acquisition, polled in 5s intervals exporter_name: str = field(default="remote", init=False) # Populated during acquisition + lease_ending_callback: Callable[[Self, timedelta], None] | None = field( + default=None, init=False + ) # Called when lease is ending def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -208,11 +211,14 @@ async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: yield value finally: if self.release and self.name: - logger.info("Releasing Lease %s", self.name) # Shield cleanup from cancellation to ensure it completes with CancelScope(shield=True): try: with fail_after(30): + # skip the message if the lease is already expired + lease = await self.get() + if not lease.effective_end_time: + logger.info("Releasing Lease %s", self.name) await self.svc.DeleteLease( name=self.name, ) @@ -280,6 +286,8 @@ async def _monitor(): if remain < timedelta(0): # lease already expired, stopping monitor logger.info("Lease {} ended at {}".format(self.name, end_time)) + if self.lease_ending_callback is not None: + self.lease_ending_callback(self, timedelta(0)) break # Log once when entering the threshold window if threshold - timedelta(seconds=check_interval) <= remain < threshold: @@ -288,6 +296,9 @@ async def _monitor(): self.name, int((remain.total_seconds() + 30) // 60), end_time ) ) + # Notify callback about approaching expiration + if self.lease_ending_callback is not None: + self.lease_ending_callback(self, remain) await sleep(min(remain.total_seconds(), check_interval)) else: await sleep(1) diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index 8fb3cc67f..dac73cad0 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -1,6 +1,9 @@ import os +import signal import sys from contextlib import ExitStack, asynccontextmanager, contextmanager +from datetime import timedelta +from functools import partial from subprocess import Popen from anyio.from_thread import BlockingPortal, start_blocking_portal @@ -46,6 +49,34 @@ def serve(root_device: Driver): PROMPT_CWD = "\\W" +def lease_ending_handler(process: Popen, lease, remaining_time) -> None: + """Lease ending handler to terminate a process when lease ends. + + Args: + process: The process to terminate + lease: The lease instance + remaining_time: Time remaining until lease expiration + """ + + if remaining_time <= timedelta(0): + try: + process.send_signal(signal.SIGHUP) + except (ProcessLookupError, OSError): + pass # Process already terminated + + +def _run_process( + cmd: list[str], + env: dict[str, str], + lease=None, +) -> int: + """Helper to run a process with an option to set a lease ending callback.""" + process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) + if lease is not None: + lease.lease_ending_callback = partial(lease_ending_handler, process) + return process.wait() + + def launch_shell( host: str, context: str, @@ -54,6 +85,7 @@ def launch_shell( use_profiles: bool, *, command: tuple[str, ...] | None = None, + lease=None, ) -> int: """Launch a shell with a custom prompt indicating the exporter type. @@ -62,6 +94,12 @@ def launch_shell( context: The context of the shell (e.g. "local" or exporter name) allow: List of allowed drivers unsafe: Whether to allow drivers outside of the allow list + use_profiles: Whether to load shell profile files + command: Optional command to run instead of launching an interactive shell + lease: Optional Lease object to set up lease ending callback + + Returns: + The exit code of the shell or command process """ shell = os.environ.get("SHELL", "bash") @@ -73,19 +111,16 @@ def launch_shell( } if command: - process = Popen(command, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) - return process.wait() + return _run_process(list(command), common_env, lease) if shell_name.endswith("bash"): env = common_env | { "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", } - cmd = [shell] if not use_profiles: cmd.extend(["--norc", "--noprofile"]) - process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) - return process.wait() + return _run_process(cmd, env, lease) elif shell_name == "fish": fish_fn = ( @@ -102,26 +137,20 @@ def launch_shell( "end" ) cmd = [shell, "--init-command", fish_fn] - process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) - return process.wait() + return _run_process(cmd, common_env, lease) elif shell_name == "zsh": env = common_env | { "PS1": f"%F{{8}}%1~ %F{{yellow}}⚡%F{{white}}{context} %F{{yellow}}➤%f ", } - if "HISTFILE" not in env: env["HISTFILE"] = os.path.join(os.path.expanduser("~"), ".zsh_history") cmd = [shell] if not use_profiles: cmd.append("--no-rcs") - cmd.extend(["-o", "inc_append_history", "-o", "share_history"]) - - process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) - return process.wait() + return _run_process(cmd, env, lease) else: - process = Popen([shell], stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) - return process.wait() + return _run_process([shell], common_env, lease)