From 7f4bc70c2a1a3d0fbfc1792e8754f9d57e748bbb Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Thu, 20 Feb 2025 11:13:30 -0500 Subject: [PATCH 1/4] Make client adapters proper function style context managers --- docs/source/api-reference/adapters/network.md | 15 ++-- .../adapters/fabric.py | 47 ++++++------ .../adapters/novnc.py | 30 ++++---- .../adapters/pexpect.py | 25 +++---- .../adapters/portforward.py | 72 ++++++++++--------- .../jumpstarter_driver_opendal/adapter.py | 49 +++++++------ .../jumpstarter/client/adapters.py | 23 +++--- .../jumpstarter/common/tempfile.py | 2 +- 8 files changed, 130 insertions(+), 133 deletions(-) diff --git a/docs/source/api-reference/adapters/network.md b/docs/source/api-reference/adapters/network.md index 0bd8b149e..485e3f243 100644 --- a/docs/source/api-reference/adapters/network.md +++ b/docs/source/api-reference/adapters/network.md @@ -3,28 +3,23 @@ Network adapters are for transforming network connections exposed by drivers ```{eval-rst} -.. autoclass:: jumpstarter_driver_network.adapters.TcpPortforwardAdapter - :members: +.. autofunction:: jumpstarter_driver_network.adapters.TcpPortforwardAdapter ``` ```{eval-rst} -.. autoclass:: jumpstarter_driver_network.adapters.UnixPortforwardAdapter - :members: +.. autofunction:: jumpstarter_driver_network.adapters.UnixPortforwardAdapter ``` ```{eval-rst} -.. autoclass:: jumpstarter_driver_network.adapters.NovncAdapter - :members: +.. autofunction:: jumpstarter_driver_network.adapters.NovncAdapter ``` ```{eval-rst} -.. autoclass:: jumpstarter_driver_network.adapters.PexpectAdapter - :members: +.. autofunction:: jumpstarter_driver_network.adapters.PexpectAdapter ``` ```{eval-rst} -.. autoclass:: jumpstarter_driver_network.adapters.FabricAdapter - :members: +.. autofunction:: jumpstarter_driver_network.adapters.FabricAdapter ``` ## Examples diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/fabric.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/fabric.py index afff5e88f..31d777744 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/fabric.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/fabric.py @@ -1,30 +1,37 @@ -from dataclasses import dataclass +from contextlib import asynccontextmanager +from functools import partial from typing import Any from fabric.config import Config from fabric.connection import Connection -from .portforward import TcpPortforwardAdapter +from .portforward import handler +from jumpstarter.client import DriverClient +from jumpstarter.client.adapters import blocking +from jumpstarter.common import TemporaryTcpListener -@dataclass(kw_only=True) -class FabricAdapter(TcpPortforwardAdapter): - user: str | None = None - config: Config | None = None - forward_agent: bool | None = None - connect_timeout: int | None = None - connect_kwargs: dict[str, Any] | None = None - inline_ssh_env: bool | None = None - - async def __aenter__(self): - addr = await super().__aenter__() - return Connection( +@blocking +@asynccontextmanager +async def FabricAdapter( + *, + client: DriverClient, + method: str = "connect", + user: str | None = None, + config: Config | None = None, + forward_agent: bool | None = None, + connect_timeout: int | None = None, + connect_kwargs: dict[str, Any] | None = None, + inline_ssh_env: bool | None = None, +): + async with TemporaryTcpListener(partial(handler, client, method)) as addr: + yield Connection( addr[0], - user=self.user, + user=user, port=addr[1], - config=self.config, - forward_agent=self.forward_agent, - connect_timeout=self.connect_timeout, - connect_kwargs=self.connect_kwargs, - inline_ssh_env=self.inline_ssh_env, + config=config, + forward_agent=forward_agent, + connect_timeout=connect_timeout, + connect_kwargs=connect_kwargs, + inline_ssh_env=inline_ssh_env, ) diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/novnc.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/novnc.py index 8756f3a0e..c163a8783 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/novnc.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/novnc.py @@ -1,16 +1,25 @@ -from dataclasses import dataclass +from contextlib import asynccontextmanager from urllib.parse import urlencode, urlunparse from ..streams import WebsocketServerStream -from .portforward import TcpPortforwardAdapter +from jumpstarter.client import DriverClient +from jumpstarter.client.adapters import blocking +from jumpstarter.common import TemporaryTcpListener from jumpstarter.streams import forward_stream -@dataclass(kw_only=True) -class NovncAdapter(TcpPortforwardAdapter): - async def __aenter__(self): - addr = await super().__aenter__() - return urlunparse( +@blocking +@asynccontextmanager +async def NovncAdapter(*, client: DriverClient, method: str = "connect"): + async def handler(conn): + async with conn: + async with client.stream_async(method) as stream: + async with WebsocketServerStream(stream=stream) as stream: + async with forward_stream(conn, stream): + pass + + async with TemporaryTcpListener(handler) as addr: + yield urlunparse( ( "https", "novnc.com", @@ -20,10 +29,3 @@ async def __aenter__(self): "", ) ) - - async def handler(self, conn): - async with conn: - async with self.client.stream_async(self.method) as stream: - async with WebsocketServerStream(stream=stream) as stream: - async with forward_stream(conn, stream): - pass diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/pexpect.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/pexpect.py index 653f88567..1edd1be9f 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/pexpect.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/pexpect.py @@ -1,22 +1,19 @@ import socket -from dataclasses import dataclass +from contextlib import contextmanager from pexpect.fdpexpect import fdspawn from .portforward import TcpPortforwardAdapter +from jumpstarter.client import DriverClient -@dataclass(kw_only=True) -class PexpectAdapter(TcpPortforwardAdapter): - async def __aenter__(self): - addr = await super().__aenter__() +@contextmanager +def PexpectAdapter(*, client: DriverClient, method: str = "connect"): + with TcpPortforwardAdapter(client=client, method=method) as addr: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(addr) - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.connect(addr) - - return fdspawn(self.socket) - - async def __aexit__(self, exc_type, exc_value, traceback): - self.socket.close() - - await super().__aexit__(exc_type, exc_value, traceback) + try: + yield fdspawn(sock) + finally: + sock.close() diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/portforward.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/portforward.py index 45024a1b9..557bb14eb 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/portforward.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/portforward.py @@ -1,40 +1,42 @@ -from dataclasses import dataclass +from contextlib import asynccontextmanager +from functools import partial -from jumpstarter.client.adapters import ClientAdapter +from jumpstarter.client import DriverClient +from jumpstarter.client.adapters import blocking from jumpstarter.common import TemporaryTcpListener, TemporaryUnixListener from jumpstarter.streams import forward_stream -@dataclass(kw_only=True) -class PortforwardAdapter(ClientAdapter): - method: str = "connect" - - async def __aexit__(self, exc_type, exc_value, traceback): - return await self.listener.__aexit__(exc_type, exc_value, traceback) - - async def handler(self, conn): - async with conn: - async with self.client.stream_async(self.method) as stream: - async with forward_stream(conn, stream): - pass - - -@dataclass(kw_only=True) -class TcpPortforwardAdapter(PortforwardAdapter): - local_host: str = "127.0.0.1" - local_port: int = 0 - - async def __aenter__(self): - self.listener = TemporaryTcpListener( - self.handler, local_host=self.local_host, local_port=self.local_port, reuse_port=True - ) - - return await self.listener.__aenter__() - - -@dataclass(kw_only=True) -class UnixPortforwardAdapter(PortforwardAdapter): - async def __aenter__(self): - self.listener = TemporaryUnixListener(self.handler) - - return await self.listener.__aenter__() +async def handler(client, method, conn): + async with conn: + async with client.stream_async(method) as stream: + async with forward_stream(conn, stream): + pass + + +@blocking +@asynccontextmanager +async def TcpPortforwardAdapter( + *, + client: DriverClient, + method: str = "connect", + local_host: str = "127.0.0.1", + local_port: int = 0, +): + async with TemporaryTcpListener( + partial(handler, client, method), + local_host=local_host, + local_port=local_port, + ) as addr: + yield addr + + +@blocking +@asynccontextmanager +async def UnixPortforwardAdapter( + *, + client: DriverClient, + method: str = "connect", +): + async with TemporaryUnixListener(partial(handler, client, method)) as addr: + yield addr diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/adapter.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/adapter.py index 07d09b527..14c68ae6a 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/adapter.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/adapter.py @@ -1,4 +1,4 @@ -from contextlib import suppress +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from typing import Literal @@ -7,7 +7,8 @@ from opendal import AsyncFile, Operator from opendal.exceptions import Error -from jumpstarter.client.adapters import ClientAdapter +from jumpstarter.client import DriverClient +from jumpstarter.client.adapters import blocking from jumpstarter.common.resources import PresignedRequestResource @@ -44,26 +45,24 @@ async def aclose(self): await self.file.close() -@dataclass(kw_only=True) -class OpendalAdapter(ClientAdapter): - operator: Operator # opendal.Operator for the storage backend - path: str # file path in storage backend relative to the storage root - mode: Literal["rb", "wb"] = "rb" # binary read or binary write mode - - async def __aenter__(self): - # if the access mode is binary read, and the storage backend supports presigned read requests - if self.mode == "rb" and self.operator.capability().presign_read: - # create presigned url for the specified file with a 60 second expiration - presigned = await self.operator.to_async_operator().presign_read(self.path, expire_second=60) - return PresignedRequestResource( - headers=presigned.headers, url=presigned.url, method=presigned.method - ).model_dump(mode="json") - # otherwise stream the file content from the client to the exporter - else: - file = await self.operator.to_async_operator().open(self.path, self.mode) - self.resource = self.client.resource_async(AsyncFileStream(file=file)) - return await self.resource.__aenter__() - - async def __aexit__(self, exc_type, exc_value, traceback): - if hasattr(self, "resource"): - await self.resource.__aexit__(exc_type, exc_value, traceback) +@blocking +@asynccontextmanager +async def OpendalAdapter( + *, + client: DriverClient, + operator: Operator, # opendal.Operator for the storage backend + path: str, # file path in storage backend relative to the storage root + mode: Literal["rb", "wb"] = "rb", # binary read or binary write mode +): + # if the access mode is binary read, and the storage backend supports presigned read requests + if mode == "rb" and operator.capability().presign_read: + # create presigned url for the specified file with a 60 second expiration + presigned = await operator.to_async_operator().presign_read(path, expire_second=60) + yield PresignedRequestResource( + headers=presigned.headers, url=presigned.url, method=presigned.method + ).model_dump(mode="json") + # otherwise stream the file content from the client to the exporter + else: + file = await operator.to_async_operator().open(path, mode) + async with client.resource_async(AsyncFileStream(file=file)) as res: + yield res diff --git a/packages/jumpstarter/jumpstarter/client/adapters.py b/packages/jumpstarter/jumpstarter/client/adapters.py index 94ed4393f..419c6bbc0 100644 --- a/packages/jumpstarter/jumpstarter/client/adapters.py +++ b/packages/jumpstarter/jumpstarter/client/adapters.py @@ -1,17 +1,12 @@ -from contextlib import AbstractAsyncContextManager, AbstractContextManager -from dataclasses import dataclass +from contextlib import contextmanager +from functools import wraps -from jumpstarter.client import DriverClient +def blocking(f): + @wraps(f) + @contextmanager + def wrapper(*args, **kwargs): + with kwargs["client"].portal.wrap_async_context_manager(f(*args, **kwargs)) as res: + yield res -@dataclass(kw_only=True) -class ClientAdapter(AbstractContextManager, AbstractAsyncContextManager): - client: DriverClient - - def __enter__(self): - self.manager = self.client.portal.wrap_async_context_manager(self) - - return self.manager.__enter__() - - def __exit__(self, exc_type, exc_value, traceback): - self.manager.__exit__(exc_type, exc_value, traceback) + return wrapper diff --git a/packages/jumpstarter/jumpstarter/common/tempfile.py b/packages/jumpstarter/jumpstarter/common/tempfile.py index 87d757dbe..8ecdea16e 100644 --- a/packages/jumpstarter/jumpstarter/common/tempfile.py +++ b/packages/jumpstarter/jumpstarter/common/tempfile.py @@ -27,7 +27,7 @@ async def TemporaryUnixListener(handler): @asynccontextmanager async def TemporaryTcpListener( - handler, local_host=None, local_port=0, family=AddressFamily.AF_UNSPEC, backlog=65536, reuse_port=False + handler, local_host="127.0.0.1", local_port=0, family=AddressFamily.AF_UNSPEC, backlog=65536, reuse_port=True ): async with await create_tcp_listener( local_host=local_host, From 7da1709ba456e99aa2380cbb8d5eb5d7e2a907fc Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Thu, 20 Feb 2025 11:57:15 -0500 Subject: [PATCH 2/4] Add missing try block to context managers --- .../jumpstarter/jumpstarter/common/utils.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index 59f78492d..ef496ac8d 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -18,18 +18,22 @@ async def serve_async(root_device: Driver, portal: BlockingPortal): async with session.serve_unix_async() as path: # SAFETY: the root_device instance is constructed locally thus considered trusted async with client_from_path(path, portal, allow=[], unsafe=True) as client: - yield client - if hasattr(client, "close"): - client.close() + try: + yield client + finally: + if hasattr(client, "close"): + client.close() @contextmanager def serve(root_device: Driver): with start_blocking_portal() as portal: with portal.wrap_async_context_manager(serve_async(root_device, portal)) as client: - yield client - if hasattr(client, "close"): - client.close() + try: + yield client + finally: + if hasattr(client, "close"): + client.close() @asynccontextmanager @@ -48,9 +52,11 @@ async def env_async(portal): allow, unsafe = _allow_from_env() async with client_from_path(host, portal, allow=allow, unsafe=unsafe) as client: - yield client - if hasattr(client, "close"): - client.close() + try: + yield client + finally: + if hasattr(client, "close"): + client.close() @contextmanager From da801a7e637f5878e53295d3c29e29be3ad145ba Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Thu, 20 Feb 2025 12:56:21 -0500 Subject: [PATCH 3/4] Manage context managers with ExitStack --- .../jumpstarter_driver_pyserial/client.py | 7 +---- .../jumpstarter/jumpstarter/client/base.py | 10 +++---- .../jumpstarter/jumpstarter/client/client.py | 8 ++++-- .../jumpstarter/jumpstarter/client/lease.py | 17 +++++++---- .../jumpstarter/jumpstarter/common/utils.py | 28 ++++++++++--------- .../jumpstarter/jumpstarter/config/client.py | 13 +++++++-- .../jumpstarter/config/exporter_test.py | 14 +++++++--- .../jumpstarter/jumpstarter/listener_test.py | 26 +++++++++-------- 8 files changed, 72 insertions(+), 51 deletions(-) diff --git a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py index fca7550ed..f1bc0782b 100644 --- a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py +++ b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/client.py @@ -22,12 +22,7 @@ def open(self) -> fdspawn: Returns: fdspawn: The pexpect session object. """ - self._context_manager = self.pexpect() - return self._context_manager.__enter__() - - def close(self): - if hasattr(self, "_context_manager"): - self._context_manager.__exit__(None, None, None) + return self.stack.enter_context(self.pexpect()) @contextmanager def pexpect(self): diff --git a/packages/jumpstarter/jumpstarter/client/base.py b/packages/jumpstarter/jumpstarter/client/base.py index 55b3f0c45..f57ce0917 100644 --- a/packages/jumpstarter/jumpstarter/client/base.py +++ b/packages/jumpstarter/jumpstarter/client/base.py @@ -4,7 +4,7 @@ from __future__ import annotations -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from dataclasses import field from anyio.from_thread import BlockingPortal @@ -31,6 +31,7 @@ class DriverClient(AsyncDriverClient): children: dict[str, DriverClient] = field(default_factory=dict) portal: BlockingPortal + stack: ExitStack def call(self, method, *args): """ @@ -86,16 +87,13 @@ def open_stream(self) -> BlockingStream: :return: blocking stream session object. :rtype: BlockingStream """ - self._context_manager = self.stream() - return self._context_manager.__enter__() + return self.stack.enter_context(self.stream()) def close(self): """ Close the open stream session without a context manager. """ - if hasattr(self, "_context_manager"): - self._context_manager.__exit__(None, None, None) - del self._context_manager + self.stack.close() def __del__(self): self.close() diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index 87aabf6bc..86dd12848 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -1,5 +1,5 @@ from collections import OrderedDict, defaultdict -from contextlib import asynccontextmanager +from contextlib import ExitStack, asynccontextmanager from graphlib import TopologicalSorter from uuid import UUID @@ -13,16 +13,17 @@ @asynccontextmanager -async def client_from_path(path: str, portal: BlockingPortal, allow: list[str], unsafe: bool): +async def client_from_path(path: str, portal: BlockingPortal, stack: ExitStack, allow: list[str], unsafe: bool): async with grpc.aio.secure_channel( f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) ) as channel: - yield await client_from_channel(channel, portal, allow, unsafe) + yield await client_from_channel(channel, portal, stack, allow, unsafe) async def client_from_channel( channel: grpc.aio.Channel, portal: BlockingPortal, + stack: ExitStack, allow: list[str], unsafe: bool, ) -> DriverClient: @@ -49,6 +50,7 @@ async def client_from_channel( labels=report.labels, channel=channel, portal=portal, + stack=stack.enter_context(ExitStack()), children={reports[k].labels["jumpstarter.dev/name"]: clients[k] for k in topo[uuid]}, ) diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index a836cba68..5f9ff8502 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -1,5 +1,11 @@ import logging -from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager, contextmanager +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + ExitStack, + asynccontextmanager, + contextmanager, +) from dataclasses import dataclass, field from anyio import fail_after, sleep @@ -147,15 +153,16 @@ async def serve_unix_async(self): yield path @asynccontextmanager - async def connect_async(self): + async def connect_async(self, stack): async with self.serve_unix_async() as path: - async with client_from_path(path, self.portal, allow=self.allow, unsafe=self.unsafe) as client: + async with client_from_path(path, self.portal, stack, allow=self.allow, unsafe=self.unsafe) as client: yield client @contextmanager def connect(self): - with self.portal.wrap_async_context_manager(self.connect_async()) as client: - yield client + with ExitStack() as stack: + with self.portal.wrap_async_context_manager(self.connect_async(stack)) as client: + yield client @contextmanager def serve_unix(self): diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index ef496ac8d..686c7c2ab 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -1,6 +1,6 @@ import os import sys -from contextlib import asynccontextmanager, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager from subprocess import Popen from anyio.from_thread import BlockingPortal, start_blocking_portal @@ -13,11 +13,11 @@ @asynccontextmanager -async def serve_async(root_device: Driver, portal: BlockingPortal): +async def serve_async(root_device: Driver, portal: BlockingPortal, stack: ExitStack): with Session(root_device=root_device) as session: async with session.serve_unix_async() as path: # SAFETY: the root_device instance is constructed locally thus considered trusted - async with client_from_path(path, portal, allow=[], unsafe=True) as client: + async with client_from_path(path, portal, stack, allow=[], unsafe=True) as client: try: yield client finally: @@ -28,16 +28,17 @@ async def serve_async(root_device: Driver, portal: BlockingPortal): @contextmanager def serve(root_device: Driver): with start_blocking_portal() as portal: - with portal.wrap_async_context_manager(serve_async(root_device, portal)) as client: - try: - yield client - finally: - if hasattr(client, "close"): - client.close() + with ExitStack() as stack: + with portal.wrap_async_context_manager(serve_async(root_device, portal, stack)) as client: + try: + yield client + finally: + if hasattr(client, "close"): + client.close() @asynccontextmanager -async def env_async(portal): +async def env_async(portal, stack): """Provide a client for an existing JUMPSTARTER_HOST environment variable. Async version of env() @@ -51,7 +52,7 @@ async def env_async(portal): allow, unsafe = _allow_from_env() - async with client_from_path(host, portal, allow=allow, unsafe=unsafe) as client: + async with client_from_path(host, portal, stack, allow=allow, unsafe=unsafe) as client: try: yield client finally: @@ -67,8 +68,9 @@ def env(): to either a local exporter or a remote one. """ with start_blocking_portal() as portal: - with portal.wrap_async_context_manager(env_async(portal)) as client: - yield client + with ExitStack() as stack: + with portal.wrap_async_context_manager(env_async(portal, stack)) as client: + yield client ANSI_GRAY = "\\[\\e[90m\\]" diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index 43c73201e..2a793e398 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -77,7 +77,11 @@ def release_lease(self, name): with start_blocking_portal() as portal: portal.call(self.release_lease_async, name) - async def request_lease_async(self, metadata_filter: MetadataFilter, portal: BlockingPortal): + async def request_lease_async( + self, + metadata_filter: MetadataFilter, + portal: BlockingPortal, + ): # dynamically import to avoid circular imports from jumpstarter.client import Lease @@ -104,7 +108,12 @@ async def release_lease_async(self, name): await controller.ReleaseLease(jumpstarter_pb2.ReleaseLeaseRequest(name=name)) @asynccontextmanager - async def lease_async(self, metadata_filter: MetadataFilter, lease_name: str | None, portal: BlockingPortal): + async def lease_async( + self, + metadata_filter: MetadataFilter, + lease_name: str | None, + portal: BlockingPortal, + ): from jumpstarter.client import Lease # if no lease_name provided, check if it is set in the environment diff --git a/packages/jumpstarter/jumpstarter/config/exporter_test.py b/packages/jumpstarter/jumpstarter/config/exporter_test.py index d648c039f..e4eb98149 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter_test.py +++ b/packages/jumpstarter/jumpstarter/config/exporter_test.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack from pathlib import Path import pytest @@ -51,10 +52,15 @@ async def test_exporter_serve(mock_controller): tg.start_soon(exporter.serve) with start_blocking_portal() as portal: - async with client.lease_async(metadata_filter=MetadataFilter(), lease_name=None, portal=portal) as lease: - async with lease.connect_async() as client: - await client.power.call_async("on") - assert hasattr(client.nested, "tcp") + async with client.lease_async( + metadata_filter=MetadataFilter(), + lease_name=None, + portal=portal, + ) as lease: + with ExitStack() as stack: + async with lease.connect_async(stack) as client: + await client.power.call_async("on") + assert hasattr(client.nested, "tcp") tg.cancel_scope.cancel() diff --git a/packages/jumpstarter/jumpstarter/listener_test.py b/packages/jumpstarter/jumpstarter/listener_test.py index 1daaee0dc..3bd58f989 100644 --- a/packages/jumpstarter/jumpstarter/listener_test.py +++ b/packages/jumpstarter/jumpstarter/listener_test.py @@ -1,6 +1,7 @@ # These tests are flaky # https://github.com/grpc/grpc/issues/25364 +from contextlib import ExitStack from uuid import uuid4 import grpc @@ -47,10 +48,10 @@ async def handle_async(stream): unsafe=True, ) - monkeypatch.setattr(lease, "handle_async", handle_async) - - async with lease.connect_async() as client: - await client.call_async("on") + with ExitStack() as stack: + monkeypatch.setattr(lease, "handle_async", handle_async) + async with lease.connect_async(stack) as client: + await client.call_async("on") tg.cancel_scope.cancel() @@ -97,13 +98,14 @@ async def test_controller(mock_controller): allow=[], unsafe=True, ) as lease: - async with lease.connect_async() as client: - await client.call_async("on") - # test concurrent connections - async with lease.connect_async() as client2: - await client2.call_async("on") - - async with lease.connect_async() as client: - await client.call_async("on") + with ExitStack() as stack: + async with lease.connect_async(stack) as client: + await client.call_async("on") + # test concurrent connections + async with lease.connect_async(stack) as client2: + await client2.call_async("on") + + async with lease.connect_async(stack) as client: + await client.call_async("on") tg.cancel_scope.cancel() From 83d06fbeb8d5cceb12a1f2fb62116f8472c6dbc2 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Mon, 24 Feb 2025 11:36:05 -0500 Subject: [PATCH 4/4] fixup! Make client adapters proper function style context managers --- .../adapters/dbus.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/dbus.py b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/dbus.py index 92eaf1e9d..e764c5bd2 100644 --- a/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/dbus.py +++ b/packages/jumpstarter-driver-network/jumpstarter_driver_network/adapters/dbus.py @@ -1,28 +1,31 @@ -from dataclasses import dataclass +from contextlib import contextmanager from os import environ, getenv from .portforward import TcpPortforwardAdapter +from jumpstarter.client import DriverClient -@dataclass(kw_only=True) -class DbusAdapter(TcpPortforwardAdapter): - async def __aenter__(self): - addr = await super().__aenter__() - match self.client.kind: - case "system": - self.varname = "DBUS_SYSTEM_BUS_ADDRESS" - pass - case "session": - self.varname = "DBUS_SESSION_BUS_ADDRESS" - pass - case _: - raise ValueError(f"invalid bus type: {self.client.kind}") - self.oldenv = getenv(self.varname) - environ[self.varname] = f"tcp:host={addr[0]},port={addr[1]}" +@contextmanager +def DbusAdapter(*, client: DriverClient): + match client.kind: + case "system": + varname = "DBUS_SYSTEM_BUS_ADDRESS" + pass + case "session": + varname = "DBUS_SESSION_BUS_ADDRESS" + pass + case _: + raise ValueError(f"invalid bus type: {client.kind}") - async def __aexit__(self, exc_type, exc_value, traceback): - await super().__aexit__(exc_type, exc_value, traceback) - if self.oldenv is None: - del environ[self.varname] - else: - environ[self.varname] = self.oldenv + oldenv = getenv(varname) + + with TcpPortforwardAdapter(client=client) as addr: + environ[varname] = f"tcp:host={addr[0]},port={addr[1]}" + + try: + yield + finally: + if oldenv is None: + del environ[varname] + else: + environ[varname] = oldenv