# Patch summary: unified git diff touching four Python files under packages/.
#
# NOTE(review): the newlines of this patch have been collapsed into spaces, so hunk
# bodies run together on a few very long physical lines and Python indentation is not
# recoverable from this text. Regenerate the patch from the repository before applying
# it or before reasoning about how statements are nested.
#
# packages/jumpstarter-cli/jumpstarter_cli/run.py
#   * `_serve_with_exc_handling(config)` no longer uses a multiprocessing `Value`; each
#     loop iteration forks, the parent calls `_handle_parent(pid)` and returns its
#     result unless that result is None, and the child calls `_handle_child(config)`.
#   * `_handle_child(config)` runs `exporter.serve()` (exporter obtained from
#     `config.create_exporter()`) alongside a receiver for SIGINT, SIGTERM, SIGHUP and
#     SIGQUIT. A received signal is logged and, when an exporter exists, passed on as
#     `exporter.stop(wait_for_lease_exit=<signal is SIGHUP>)`. Non-cancellation
#     exceptions from `serve()` are echoed to stderr. The process ends through
#     `sys.exit()` with the received signal number, or 0 when no signal was received.
#   * `_handle_parent(pid)` installs handlers that forward those four signals to the
#     child with `os.kill`, then waits with `os.waitpid`. It returns None when the child
#     exited with status 0 (the caller then forks again), `128 + exit status` for any
#     other normal exit, and `128 + terminating signal` (or 128) otherwise.
#
# packages/jumpstarter/jumpstarter/config/exporter.py
#   * Adds the async context manager `create_exporter()`: it builds an `Exporter`,
#     awaits `__aenter__()`, yields it, and in `finally` awaits
#     `__aexit__(None, None, None)` inside `CancelScope(shield=True)`.
#   * `serve()` now enters `create_exporter()` and awaits `exporter.serve()`.
#
# packages/jumpstarter/jumpstarter/exporter/exporter.py
#   * `channel_factory` is annotated as returning `Awaitable[grpc.aio.Channel]`.
#   * Adds the fields `_stop_requested`, `_started` and `_tg`, plus
#     `stop(wait_for_lease_exit=False)`, which either cancels `_tg.cancel_scope` or sets
#     `_stop_requested`.
#   * `__aexit__` performs `Unregister` inside `anyio.move_on_after(10)`, closes the
#     channel inside a shielded cancel scope, and logs `Exception`s rather than raising.
#   * `serve()` stores its task group in `self._tg`, tracks start-up in `self._started`,
#     calls `self.stop()` on a lease change, and calls `self.stop()` then breaks when
#     `_stop_requested` is set; `self._tg` is reset to None afterwards.
#     NOTE(review): whether the `_stop_requested` check sits inside the "not leased"
#     branch cannot be determined from this flattened text -- confirm against the repo.
#
# packages/jumpstarter/jumpstarter/exporter/session.py
#   * `__exit__` catches `Exception` from `root_device.close()`, logs it with a driver
#     name taken from the report label 'jumpstarter.dev/name' (falling back to the class
#     name), and removes the logging handler in `finally`.
#     NOTE(review): the hunk calls `logger.error(...)` while only `logging.getLogger()`
#     is visible here -- confirm that session.py defines a module-level `logger`.
#
diff --git a/packages/jumpstarter-cli/jumpstarter_cli/run.py b/packages/jumpstarter-cli/jumpstarter_cli/run.py index 348755c73..484e16d6b 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/run.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/run.py @@ -1,31 +1,107 @@ +import logging import os -from multiprocessing.sharedctypes import Value +import signal +import sys import anyio import click +from anyio import create_task_group, open_signal_receiver from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions, leaf_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions +logger = logging.getLogger(__name__) -def _serve_with_exc_handling(exporter): + +def _handle_child(config): + """Handle child process with graceful shutdown.""" + async def serve_with_graceful_shutdown(): + received_signal = 0 + signal_handled = False + exporter = None + + async def signal_handler(): + nonlocal received_signal, signal_handled + + with open_signal_receiver(signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT) as signals: + async for sig in signals: + if signal_handled: + continue # Ignore duplicate signals + received_signal = sig + logger.info("CHILD: Received %d (%s)", received_signal, signal.Signals(received_signal).name) + if exporter: + # Terminate exporter. SIGHUP waits until current lease is let go. 
Later SIGTERM still overrides + if received_signal != signal.SIGHUP: + signal_handled = True + exporter.stop(wait_for_lease_exit=received_signal == signal.SIGHUP) + + # Start signal handler first, then create exporter + async with create_task_group() as signal_tg: + + # Start signal handler immediately + signal_tg.start_soon(signal_handler) + + # Create exporter and run it + async with config.create_exporter() as exporter: + try: + await exporter.serve() + except* Exception as excgroup: + from jumpstarter_cli_common.exceptions import leaf_exceptions + for exc in leaf_exceptions(excgroup): + if not isinstance(exc, anyio.get_cancelled_exc_class()): + click.echo( + f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", + err=True, + ) + + # Cancel the signal handler after exporter completes + signal_tg.cancel_scope.cancel() + + # Return signal number if received, otherwise 0 for immediate restart + return received_signal if received_signal else 0 + + sys.exit(anyio.run(serve_with_graceful_shutdown)) + + +def _handle_parent(pid): + """Handle parent process waiting for child and signal forwarding.""" + def parent_signal_handler(signum, _): + logger.info("PARENT: Received %d (%s), forwarding to child PID %d", signum, signal.Signals(signum).name, pid) + if pid and pid > 0: + try: + os.kill(pid, signum) + except ProcessLookupError: + pass + + # Set up signal handlers after fork + for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT): + signal.signal(sig, parent_signal_handler) + + _, status = os.waitpid(pid, 0) + if os.WIFEXITED(status): + # Interpret child exit code + child_exit_code = os.WEXITSTATUS(status) + if child_exit_code == 0: + return None # restart child (unexpected exit/exception) + else: + # Child indicates termination (signal number) + return 128 + child_exit_code # Return standard Unix exit code + else: + # Child killed by unhandled signal - terminate + child_exit_signal = os.WTERMSIG(status) if os.WIFSIGNALED(status) 
else 0 + click.echo(f"Child killed by unhandled signal: {child_exit_signal}", err=True) + return 128 + child_exit_signal + + +def _serve_with_exc_handling(config): while True: - result = Value("i", 0) pid = os.fork() + if pid > 0: - os.waitpid(pid, 0) - if result.value != 0: - return result.value + if (exit_code := _handle_parent(pid)) is not None: + return exit_code else: - try: - anyio.run(exporter.serve) - except* Exception as excgroup: - for exc in leaf_exceptions(excgroup): - click.echo( - f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", - err=True, - ) - result.value = 1 - return + _handle_child(config) + sys.exit(1) # should never happen @click.command("run") @@ -33,5 +109,4 @@ def _serve_with_exc_handling(exporter): @handle_exceptions def run(config): """Run an exporter locally.""" - return _serve_with_exc_handling(config) diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index c7431aea6..8cc8b9bd6 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -156,26 +156,44 @@ def serve_unix(self): with portal.wrap_async_context_manager(self.serve_unix_async()) as path: yield path - async def serve(self): + @asynccontextmanager + async def create_exporter(self): + """Create and manage an exporter instance with proper lifecycle.""" # dynamic import to avoid circular imports + from anyio import CancelScope + from jumpstarter.exporter import Exporter async def channel_factory(): if self.endpoint is None or self.token is None: raise ConfigurationError("endpoint or token not set in exporter config") - credentials = grpc.composite_channel_credentials( await ssl_channel_credentials(self.endpoint, self.tls), call_credentials("Exporter", self.metadata, self.token), ) return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) - async with Exporter( - channel_factory=channel_factory, - 
device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, - tls=self.tls, - grpc_options=self.grpcOptions, - ) as exporter: + exporter = None + entered = False + try: + exporter = Exporter( + channel_factory=channel_factory, + device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, + tls=self.tls, + grpc_options=self.grpcOptions, + ) + # Initialize the exporter (registration, etc.) + await exporter.__aenter__() + entered = True + yield exporter + finally: + # Shield all cleanup operations from abrupt cancellation for clean shutdown + if exporter and entered: + with CancelScope(shield=True): + await exporter.__aexit__(None, None, None) + + async def serve(self): + async with self.create_exporter() as exporter: await exporter.serve() diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index 643cd2be8..8fda8fbe2 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -1,10 +1,11 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, field import grpc from anyio import connect_unix, create_memory_object_stream, create_task_group, sleep +from anyio.abc import TaskGroup from google.protobuf import empty_pb2 from jumpstarter_protocol import ( jumpstarter_pb2, @@ -22,22 +23,57 @@ @dataclass(kw_only=True) class Exporter(AbstractAsyncContextManager, Metadata): - channel_factory: Callable[[], grpc.aio.Channel] + channel_factory: Callable[[], Awaitable[grpc.aio.Channel]] device_factory: Callable[[], Driver] lease_name: str = field(init=False, default="") tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, str] = field(default_factory=dict) registered: bool = field(init=False, 
default=False) + _stop_requested: bool = field(init=False, default=False) + _started: bool = field(init=False, default=False) + _tg: TaskGroup | None = field(init=False, default=None) + + def stop(self, wait_for_lease_exit=False): + """Signal the exporter to stop. + + Args: + wait_for_lease_exit (bool): If True, wait for the current lease to exit before stopping. + """ + + # Stop immediately if not started yet or if immediate stop is requested + if (not self._started or not wait_for_lease_exit) and self._tg is not None: + logger.info("Stopping exporter immediately") + self._tg.cancel_scope.cancel() + elif not self._stop_requested: + self._stop_requested = True + logger.info("Exporter marked for stop upon lease exit") async def __aexit__(self, exc_type, exc_value, traceback): - if self.registered: - controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) - logger.info("Unregistering exporter with controller") - await controller.Unregister( - jumpstarter_pb2.UnregisterRequest( - reason="TODO", - ) - ) + import anyio + + try: + if self.registered: + logger.info("Unregistering exporter with controller") + try: + with anyio.move_on_after(10): # 10 second timeout + channel = await self.channel_factory() + try: + controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel) + await controller.Unregister( + jumpstarter_pb2.UnregisterRequest( + reason="Exporter shutdown", + ) + ) + logger.info("Controller unregistration completed successfully") + finally: + with anyio.CancelScope(shield=True): + await channel.close() + except Exception as e: + logger.error("Error during controller unregistration: %s", e, exc_info=True) + + except Exception as e: + logger.error("Error during exporter cleanup: %s", e, exc_info=True) + # Don't re-raise to avoid masking the original exception async def __handle(self, path, endpoint, token, tls_config, grpc_options): try: @@ -106,10 +142,12 @@ async def listen(retries=5, backoff=3): ) async def serve(self): # 
noqa: C901 + """ + Serve the exporter. + """ # initial registration async with self.session(): pass - started = False status_tx, status_rx = create_memory_object_stream() async def status(retries=5, backoff=3): @@ -134,18 +172,23 @@ async def status(retries=5, backoff=3): retries_left = retries async with create_task_group() as tg: + self._tg = tg tg.start_soon(status) async for status in status_rx: if self.lease_name != "" and self.lease_name != status.lease_name: self.lease_name = status.lease_name logger.info("Lease status changed, killing existing connections") - tg.cancel_scope.cancel() + self.stop() break self.lease_name = status.lease_name - if not started and self.lease_name != "": - started = True + if not self._started and self.lease_name != "": + self._started = True tg.start_soon(self.handle, self.lease_name, tg) if status.leased: logger.info("Currently leased by %s under %s", status.client_name, status.lease_name) else: logger.info("Currently not leased") + if self._stop_requested: + self.stop() + break + self._tg = None diff --git a/packages/jumpstarter/jumpstarter/exporter/session.py b/packages/jumpstarter/jumpstarter/exporter/session.py index 9affa83e7..7334e235c 100644 --- a/packages/jumpstarter/jumpstarter/exporter/session.py +++ b/packages/jumpstarter/jumpstarter/exporter/session.py @@ -44,8 +44,18 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self.root_device.close() - logging.getLogger().removeHandler(self._logging_handler) + try: + self.root_device.close() + except Exception as e: + # Get driver name from report for more descriptive logging + try: + report = self.root_device.report() + driver_name = report.labels.get('jumpstarter.dev/name', self.root_device.__class__.__name__) + except Exception: + driver_name = self.root_device.__class__.__name__ + logger.error("Error closing driver %s: %s", driver_name, e, exc_info=True) + finally: + logging.getLogger().removeHandler(self._logging_handler) def 
__init__(self, *args, root_device, **kwargs): super().__init__(*args, **kwargs)