From 94a7c46b1f49cc3d14a28e763115951c24cb8290 Mon Sep 17 00:00:00 2001 From: Michal Skrivanek Date: Wed, 13 Aug 2025 16:15:53 +0200 Subject: [PATCH 1/2] refactor signal handling Properly handle termination signals to exit cleanly. Signals are forwarded to child, and its exit code is propagated to parent. SIGHUP restarts the child while keeping the parent. --- .../jumpstarter-cli/jumpstarter_cli/run.py | 111 +++++++++++++++--- .../jumpstarter/config/exporter.py | 21 +++- .../jumpstarter/exporter/exporter.py | 37 ++++-- .../jumpstarter/exporter/session.py | 14 ++- 4 files changed, 149 insertions(+), 34 deletions(-) diff --git a/packages/jumpstarter-cli/jumpstarter_cli/run.py b/packages/jumpstarter-cli/jumpstarter_cli/run.py index 348755c73..44bdd7a23 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/run.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/run.py @@ -1,31 +1,109 @@ +import logging import os -from multiprocessing.sharedctypes import Value +import signal +import sys import anyio import click +from anyio import create_task_group, open_signal_receiver from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions, leaf_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions + +logger = logging.getLogger(__name__) + + +def _handle_child(exporter): + """Handle child process with graceful shutdown.""" + async def serve_with_graceful_shutdown(): + received_signal = 0 + signal_handled = False + + async def signal_handler(cancel_func): + nonlocal received_signal, signal_handled + + with open_signal_receiver(signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT) as signals: + async for sig in signals: + if signal_handled: + continue # Ignore duplicate signals + signal_handled = True + received_signal = sig + logger.info("CHILD: Received %d (%s)", received_signal, signal.Signals(received_signal).name) + # Cancel exporter task group(leaves signal handler running) + cancel_func() + + 
# Run signal handler and exporter with separate task groups + async with create_task_group() as signal_tg: + exporter_tg = None + + async def run_exporter(): + nonlocal exporter_tg + try: + async with create_task_group() as tg: + exporter_tg = tg + await exporter.serve() + except* Exception as excgroup: + from jumpstarter_cli_common.exceptions import leaf_exceptions + for exc in leaf_exceptions(excgroup): + if not isinstance(exc, anyio.get_cancelled_exc_class()): + click.echo( + f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", + err=True, + ) + + async def signal_handler_wrapper(): + await signal_handler(lambda: exporter_tg.cancel_scope.cancel() if exporter_tg else None) + + signal_tg.start_soon(signal_handler_wrapper) + await run_exporter() + # Cancel the signal handler after exporter completes + signal_tg.cancel_scope.cancel() + + # Return signal number if received, otherwise 0 for immediate restart + return received_signal if received_signal else 0 + + sys.exit(anyio.run(serve_with_graceful_shutdown)) + + +def _handle_parent(pid): + """Handle parent process waiting for child and signal forwarding.""" + def parent_signal_handler(signum, _): + logger.info("PARENT: Received %d (%s), forwarding to child PID %d", signum, signal.Signals(signum).name, pid) + if pid and pid > 0: + try: + os.kill(pid, signum) + except ProcessLookupError: + pass + + # Set up signal handlers after fork + for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT): + signal.signal(sig, parent_signal_handler) + + _, status = os.waitpid(pid, 0) + if os.WIFEXITED(status): + # Interpret child exit code + child_exit_code = os.WEXITSTATUS(status) + if child_exit_code == 0 or child_exit_code == signal.SIGHUP: + return None # restart child (exception/unexpected or SIGHUP) + else: + # Child indicates termination (signal number) + return 128 + child_exit_code # Return standard Unix exit code + else: + # Child killed by unhandled signal - terminate + 
child_exit_signal = os.WTERMSIG(status) if os.WIFSIGNALED(status) else 0 + click.echo(f"Child killed by unhandled signal: {child_exit_signal}", err=True) + return 128 + child_exit_signal def _serve_with_exc_handling(exporter): while True: - result = Value("i", 0) pid = os.fork() + if pid > 0: - os.waitpid(pid, 0) - if result.value != 0: - return result.value + if (exit_code := _handle_parent(pid)) is not None: + return exit_code else: - try: - anyio.run(exporter.serve) - except* Exception as excgroup: - for exc in leaf_exceptions(excgroup): - click.echo( - f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", - err=True, - ) - result.value = 1 - return + _handle_child(exporter) + sys.exit(1) # should never happen @click.command("run") @@ -33,5 +111,4 @@ def _serve_with_exc_handling(exporter): @handle_exceptions def run(config): """Run an exporter locally.""" - return _serve_with_exc_handling(config) diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index c7431aea6..35a04c1dd 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -158,6 +158,8 @@ def serve_unix(self): async def serve(self): # dynamic import to avoid circular imports + from anyio import CancelScope + from jumpstarter.exporter import Exporter async def channel_factory(): @@ -170,13 +172,20 @@ async def channel_factory(): ) return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) - async with Exporter( - channel_factory=channel_factory, - device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, - tls=self.tls, - grpc_options=self.grpcOptions, - ) as exporter: + exporter = None + try: + exporter = Exporter( + channel_factory=channel_factory, + device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, + tls=self.tls, + grpc_options=self.grpcOptions, + ) await exporter.serve() 
+ finally: + # Shield all cleanup operations from abrupt cancellation for clean shutdown + if exporter: + with CancelScope(shield=True): + await exporter.__aexit__(None, None, None) class ExporterConfigListV1Alpha1(BaseModel): diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index 643cd2be8..2679d4d87 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -28,16 +28,32 @@ class Exporter(AbstractAsyncContextManager, Metadata): tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, str] = field(default_factory=dict) registered: bool = field(init=False, default=False) - async def __aexit__(self, exc_type, exc_value, traceback): - if self.registered: - controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) - logger.info("Unregistering exporter with controller") - await controller.Unregister( - jumpstarter_pb2.UnregisterRequest( - reason="TODO", - ) - ) + import anyio + + try: + if self.registered: + logger.info("Unregistering exporter with controller") + try: + with anyio.move_on_after(10): # 10 second timeout + channel = await self.channel_factory() + try: + controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel) + await controller.Unregister( + jumpstarter_pb2.UnregisterRequest( + reason="Exporter shutdown", + ) + ) + logger.info("Controller unregistration completed successfully") + finally: + with anyio.CancelScope(shield=True): + await channel.close() + except Exception as e: + logger.error("Error during controller unregistration: %s", e, exc_info=True) + + except Exception as e: + logger.error("Error during exporter cleanup: %s", e, exc_info=True) + # Don't re-raise to avoid masking the original exception async def __handle(self, path, endpoint, token, tls_config, grpc_options): try: @@ -106,6 +122,9 @@ async def listen(retries=5, backoff=3): ) 
async def serve(self): # noqa: C901 + """ + Serve the exporter. + """ # initial registration async with self.session(): pass diff --git a/packages/jumpstarter/jumpstarter/exporter/session.py b/packages/jumpstarter/jumpstarter/exporter/session.py index 9affa83e7..7334e235c 100644 --- a/packages/jumpstarter/jumpstarter/exporter/session.py +++ b/packages/jumpstarter/jumpstarter/exporter/session.py @@ -44,8 +44,18 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self.root_device.close() - logging.getLogger().removeHandler(self._logging_handler) + try: + self.root_device.close() + except Exception as e: + # Get driver name from report for more descriptive logging + try: + report = self.root_device.report() + driver_name = report.labels.get('jumpstarter.dev/name', self.root_device.__class__.__name__) + except Exception: + driver_name = self.root_device.__class__.__name__ + logger.error("Error closing driver %s: %s", driver_name, e, exc_info=True) + finally: + logging.getLogger().removeHandler(self._logging_handler) def __init__(self, *args, root_device, **kwargs): super().__init__(*args, **kwargs) From f4754db6dc028cca1f25f4c93b2b0d15e7951f9b Mon Sep 17 00:00:00 2001 From: Michal Skrivanek Date: Wed, 20 Aug 2025 14:00:23 +0200 Subject: [PATCH 2/2] expose stop() method to the signal handler, use SIGHUP for deferred stop Sending SIGHUP to the child process results in deferred termination: the exporter waits for the lease to be free (either the lease ends or another error causes the child to restart). Intended for graceful termination of the exporter after the current lease holder is done. --- .../jumpstarter-cli/jumpstarter_cli/run.py | 40 +++++++++---------- .../jumpstarter/config/exporter.py | 17 ++++++-- .../jumpstarter/exporter/exporter.py | 36 ++++++++++++++--- 3 files changed, 62 insertions(+), 31 deletions(-) diff --git a/packages/jumpstarter-cli/jumpstarter_cli/run.py b/packages/jumpstarter-cli/jumpstarter_cli/run.py index 44bdd7a23..484e16d6b
100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/run.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/run.py @@ -12,35 +12,38 @@ logger = logging.getLogger(__name__) -def _handle_child(exporter): +def _handle_child(config): """Handle child process with graceful shutdown.""" async def serve_with_graceful_shutdown(): received_signal = 0 signal_handled = False + exporter = None - async def signal_handler(cancel_func): + async def signal_handler(): nonlocal received_signal, signal_handled with open_signal_receiver(signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT) as signals: async for sig in signals: if signal_handled: continue # Ignore duplicate signals - signal_handled = True received_signal = sig logger.info("CHILD: Received %d (%s)", received_signal, signal.Signals(received_signal).name) - # Cancel exporter task group(leaves signal handler running) - cancel_func() + if exporter: + # Terminate exporter. SIGHUP waits until current lease is let go. Later SIGTERM still overrides + if received_signal != signal.SIGHUP: + signal_handled = True + exporter.stop(wait_for_lease_exit=received_signal == signal.SIGHUP) - # Run signal handler and exporter with separate task groups + # Start signal handler first, then create exporter async with create_task_group() as signal_tg: - exporter_tg = None - async def run_exporter(): - nonlocal exporter_tg + # Start signal handler immediately + signal_tg.start_soon(signal_handler) + + # Create exporter and run it + async with config.create_exporter() as exporter: try: - async with create_task_group() as tg: - exporter_tg = tg - await exporter.serve() + await exporter.serve() except* Exception as excgroup: from jumpstarter_cli_common.exceptions import leaf_exceptions for exc in leaf_exceptions(excgroup): @@ -50,11 +53,6 @@ async def run_exporter(): err=True, ) - async def signal_handler_wrapper(): - await signal_handler(lambda: exporter_tg.cancel_scope.cancel() if exporter_tg else None) - - 
signal_tg.start_soon(signal_handler_wrapper) - await run_exporter() # Cancel the signal handler after exporter completes signal_tg.cancel_scope.cancel() @@ -82,8 +80,8 @@ def parent_signal_handler(signum, _): if os.WIFEXITED(status): # Interpret child exit code child_exit_code = os.WEXITSTATUS(status) - if child_exit_code == 0 or child_exit_code == signal.SIGHUP: - return None # restart child (exception/unexpected or SIGHUP) + if child_exit_code == 0: + return None # restart child (unexpected exit/exception) else: # Child indicates termination (signal number) return 128 + child_exit_code # Return standard Unix exit code @@ -94,7 +92,7 @@ def parent_signal_handler(signum, _): return 128 + child_exit_signal -def _serve_with_exc_handling(exporter): +def _serve_with_exc_handling(config): while True: pid = os.fork() @@ -102,7 +100,7 @@ def _serve_with_exc_handling(exporter): if (exit_code := _handle_parent(pid)) is not None: return exit_code else: - _handle_child(exporter) + _handle_child(config) sys.exit(1) # should never happen diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index 35a04c1dd..8cc8b9bd6 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -156,7 +156,9 @@ def serve_unix(self): with portal.wrap_async_context_manager(self.serve_unix_async()) as path: yield path - async def serve(self): + @asynccontextmanager + async def create_exporter(self): + """Create and manage an exporter instance with proper lifecycle.""" # dynamic import to avoid circular imports from anyio import CancelScope @@ -165,7 +167,6 @@ async def serve(self): async def channel_factory(): if self.endpoint is None or self.token is None: raise ConfigurationError("endpoint or token not set in exporter config") - credentials = grpc.composite_channel_credentials( await ssl_channel_credentials(self.endpoint, self.tls), call_credentials("Exporter", 
self.metadata, self.token), @@ -173,6 +174,7 @@ async def channel_factory(): return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) exporter = None + entered = False try: exporter = Exporter( channel_factory=channel_factory, @@ -180,13 +182,20 @@ async def channel_factory(): tls=self.tls, grpc_options=self.grpcOptions, ) - await exporter.serve() + # Initialize the exporter (registration, etc.) + await exporter.__aenter__() + entered = True + yield exporter finally: # Shield all cleanup operations from abrupt cancellation for clean shutdown - if exporter: + if exporter and entered: with CancelScope(shield=True): await exporter.__aexit__(None, None, None) + async def serve(self): + async with self.create_exporter() as exporter: + await exporter.serve() + class ExporterConfigListV1Alpha1(BaseModel): api_version: Literal["jumpstarter.dev/v1alpha1"] = Field(alias="apiVersion", default="jumpstarter.dev/v1alpha1") diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index 2679d4d87..8fda8fbe2 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -1,10 +1,11 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass, field import grpc from anyio import connect_unix, create_memory_object_stream, create_task_group, sleep +from anyio.abc import TaskGroup from google.protobuf import empty_pb2 from jumpstarter_protocol import ( jumpstarter_pb2, @@ -22,12 +23,31 @@ @dataclass(kw_only=True) class Exporter(AbstractAsyncContextManager, Metadata): - channel_factory: Callable[[], grpc.aio.Channel] + channel_factory: Callable[[], Awaitable[grpc.aio.Channel]] device_factory: Callable[[], Driver] lease_name: str = field(init=False, default="") tls: TLSConfigV1Alpha1 = 
field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, str] = field(default_factory=dict) registered: bool = field(init=False, default=False) + _stop_requested: bool = field(init=False, default=False) + _started: bool = field(init=False, default=False) + _tg: TaskGroup | None = field(init=False, default=None) + + def stop(self, wait_for_lease_exit=False): + """Signal the exporter to stop. + + Args: + wait_for_lease_exit (bool): If True, wait for the current lease to exit before stopping. + """ + + # Stop immediately if not started yet or if immediate stop is requested + if (not self._started or not wait_for_lease_exit) and self._tg is not None: + logger.info("Stopping exporter immediately") + self._tg.cancel_scope.cancel() + elif not self._stop_requested: + self._stop_requested = True + logger.info("Exporter marked for stop upon lease exit") + async def __aexit__(self, exc_type, exc_value, traceback): import anyio @@ -128,7 +148,6 @@ async def serve(self): # noqa: C901 # initial registration async with self.session(): pass - started = False status_tx, status_rx = create_memory_object_stream() async def status(retries=5, backoff=3): @@ -153,18 +172,23 @@ async def status(retries=5, backoff=3): retries_left = retries async with create_task_group() as tg: + self._tg = tg tg.start_soon(status) async for status in status_rx: if self.lease_name != "" and self.lease_name != status.lease_name: self.lease_name = status.lease_name logger.info("Lease status changed, killing existing connections") - tg.cancel_scope.cancel() + self.stop() break self.lease_name = status.lease_name - if not started and self.lease_name != "": - started = True + if not self._started and self.lease_name != "": + self._started = True tg.start_soon(self.handle, self.lease_name, tg) if status.leased: logger.info("Currently leased by %s under %s", status.client_name, status.lease_name) else: logger.info("Currently not leased") + if self._stop_requested: + self.stop() + break + self._tg = None