diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py
index 428292f89..e4ff3f865 100644
--- a/packages/jumpstarter/jumpstarter/exporter/exporter.py
+++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass, field
 
 import grpc
-from anyio import connect_unix, create_task_group
+from anyio import connect_unix, create_memory_object_stream, create_task_group, sleep
 from google.protobuf import empty_pb2
 from jumpstarter_protocol import (
     jumpstarter_pb2,
@@ -38,9 +38,12 @@ async def __aexit__(self, exc_type, exc_value, traceback):
         )
 
     async def __handle(self, path, endpoint, token, tls_config, grpc_options):
-        async with await connect_unix(path) as stream:
-            async with connect_router_stream(endpoint, token, stream, tls_config, grpc_options):
-                pass
+        try:
+            async with await connect_unix(path) as stream:
+                async with connect_router_stream(endpoint, token, stream, tls_config, grpc_options):
+                    pass
+        except Exception as e:
+            logger.info("failed to handle connection: {}".format(e))
 
     @asynccontextmanager
     async def session(self):
@@ -65,23 +68,71 @@ async def session(self):
                 yield path
 
     async def handle(self, lease_name, tg):
-        controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
         logger.info("Listening for incoming connection requests on lease %s", lease_name)
+
+        listen_tx, listen_rx = create_memory_object_stream()
+
+        async def listen(retries=5, backoff=3):
+            retries_left = retries
+            while True:
+                try:
+                    controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
+                    async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)):
+                        await listen_tx.send(request)
+                except Exception as e:
+                    if retries_left > 0:
+                        retries_left -= 1
+                        logger.info(
+                            "Listen stream interrupted, restarting in {}s, {} retries left: {}".format(
+                                backoff, retries_left, e
+                            )
+                        )
+                        await sleep(backoff)
+                    else:
+                        raise
+                else:
+                    retries_left = retries
+
+        tg.start_soon(listen)
+
         async with self.session() as path:
-            async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)):
+            async for request in listen_rx:
                 logger.info("Handling new connection request on lease %s", lease_name)
                 tg.start_soon(
                     self.__handle, path, request.router_endpoint, request.router_token, self.tls, self.grpc_options
                 )
 
-    async def serve(self):
-        controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
+    async def serve(self):  # noqa: C901
         # initial registration
         async with self.session():
             pass
         started = False
+        status_tx, status_rx = create_memory_object_stream()
+
+        async def status(retries=5, backoff=3):
+            retries_left = retries
+            while True:
+                try:
+                    controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
+                    async for status in controller.Status(jumpstarter_pb2.StatusRequest()):
+                        await status_tx.send(status)
+                except Exception as e:
+                    if retries_left > 0:
+                        retries_left -= 1
+                        logger.info(
+                            "Status stream interrupted, restarting in {}s, {} retries left: {}".format(
+                                backoff, retries_left, e
+                            )
+                        )
+                        await sleep(backoff)
+                    else:
+                        raise
+                else:
+                    retries_left = retries
+
         async with create_task_group() as tg:
-            async for status in controller.Status(jumpstarter_pb2.StatusRequest()):
+            tg.start_soon(status)
+            async for status in status_rx:
                 if self.lease_name != "" and self.lease_name != status.lease_name:
                     self.lease_name = status.lease_name
                     logger.info("Lease status changed, killing existing connections")