From 5e6af025799818f236612d9d8c1dca18ef57191a Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Mon, 3 Mar 2025 17:22:04 +0000 Subject: [PATCH] Implement the StorageMuxFlasher interface, and Flasher cli This enables simplified usage of mux storage, where the driver can be instantiated for use with a Flasher interface (which only provides flash and dump), or a more granular StorageMuxInterface with dut/host/write/read/off. --- .../jumpstarter_driver_dutlink/driver.py | 4 +- .../jumpstarter_driver_dutlink/driver_test.py | 2 +- .../jumpstarter_driver_opendal/client.py | 82 +++++++++++++++++-- .../jumpstarter_driver_opendal/driver.py | 10 +++ .../jumpstarter_driver_opendal/driver_test.py | 27 +++++- .../jumpstarter_driver_sdwire/driver.py | 4 +- 6 files changed, 118 insertions(+), 11 deletions(-) diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 4926aec64..3c3494b90 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -11,7 +11,7 @@ from anyio import fail_after, sleep from anyio.streams.file import FileReadStream, FileWriteStream from jumpstarter_driver_composite.driver import CompositeInterface -from jumpstarter_driver_opendal.driver import StorageMuxInterface +from jumpstarter_driver_opendal.driver import StorageMuxFlasherInterface from jumpstarter_driver_power.driver import PowerInterface, PowerReading from jumpstarter_driver_pyserial.driver import PySerial from serial.serialutil import SerialException @@ -156,7 +156,7 @@ async def read(self) -> AsyncGenerator[PowerReading, None]: @dataclass(kw_only=True) -class DutlinkStorageMux(DutlinkConfig, StorageMuxInterface, Driver): +class DutlinkStorageMux(DutlinkConfig, StorageMuxFlasherInterface, Driver): storage_device: str def control(self, action): diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py index 9a1c4a959..72317cb8c 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py @@ -17,7 +17,7 @@ def power_test(power): def storage_test(storage): - storage.write_local_file("/dev/null") + storage.flash("/dev/null") def serial_test(serial): diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index 9395a8047..595762b97 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -16,6 +16,7 @@ from .adapter import OpendalAdapter from .common import Capability, HashAlgo, Metadata, Mode, PathBuf, PresignedRequest from jumpstarter.client import DriverClient +from jumpstarter.common.exceptions import ArgumentError @dataclass(kw_only=True) @@ -524,6 +525,28 @@ def dump( with OpendalAdapter(client=self, operator=operator, path=path, mode="wb") as handle: return self.call("dump", handle, partition) + def cli(self): + @click.group + def base(): + """Generic flasher interface""" + pass + + @base.command() + @click.argument("file") + @click.option("--partition", type=str) + def flash(file, partition): + """Flash image to DUT from file""" + self.flash(file, partition=partition) + + @base.command() + @click.argument("file") + @click.option("--partition", type=str) + def dump(file, partition): + """Dump image from DUT to file""" + self.dump(file, partition=partition) + + return base + class StorageMuxClient(DriverClient): def host(self): @@ -562,11 +585,9 @@ def read_local_file(self, filepath): absolute = Path(filepath).resolve() return self.read_file(operator=Operator("fs", root="/"), path=str(absolute)) - def cli(self): - @click.group - def base(): - """Generic storage mux""" - pass + def cli(self, base=None): + if base is None: + base = click.group(lambda: None) @base.command() def host(): @@ -589,3 +610,54 @@ def write_local_file(file): self.write_local_file(file) return base + +class StorageMuxFlasherClient(FlasherClient, StorageMuxClient): + def flash( + self, + path: PathBuf, + *, + partition: str | None = None, + operator: Operator | None = None, + ): + """Flash image to DUT""" + if partition is not None: + raise ArgumentError( + f"partition is not supported for StorageMuxFlasherClient, {partition} provided") + + self.host() + + if operator is None: + path, operator = _fs_operator_for_path(path) + + with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle: + try: + return self.write(handle) + finally: + self.dut() + + def dump( + self, + path: PathBuf, + *, + partition: str | None = None, + operator: Operator | None = None, + ): + """Dump image from DUT""" + if partition is not None: + raise ArgumentError( + f"partition is not supported for StorageMuxFlasherClient, {partition} provided") + + self.call("host") + + if operator is None: + path, operator = _fs_operator_for_path(path) + + with OpendalAdapter(client=self, operator=operator, path=path, mode="wb") as handle: + try: + return self.call("read", handle) + finally: + self.call("dut") + + def cli(self): + top_cli = FlasherClient.cli(self) + return StorageMuxClient.cli(self, top_cli) diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py index 58c852dd9..5faf7f840 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py @@ -241,6 +241,12 @@ async def write(self, src: str): ... async def read(self, dst: str): ... +class StorageMuxFlasherInterface(StorageMuxInterface): + @classmethod + def client(cls) -> str: + return "jumpstarter_driver_opendal.client.StorageMuxFlasherClient" + + @dataclass class MockStorageMux(StorageMuxInterface, Driver): file: _TemporaryFileWrapper = field(default_factory=NamedTemporaryFile) @@ -270,3 +276,7 @@ async def read(self, dst: str): async with self.resource(dst) as res: async for chunk in stream: await res.send(chunk) + +@dataclass +class MockStorageMuxFlasher(StorageMuxFlasherInterface, MockStorageMux): + pass diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py index dbbc038f5..6290d767a 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py @@ -5,12 +5,13 @@ from random import randbytes from tempfile import TemporaryDirectory from threading import Thread +from unittest import mock import pytest from opendal import Operator from .common import PresignedRequest -from .driver import MockFlasher, MockStorageMux, Opendal +from .driver import MockFlasher, MockStorageMux, MockStorageMuxFlasher, Opendal from jumpstarter.common.utils import serve @@ -148,6 +149,30 @@ def test_driver_flasher(tmp_path, partition): assert (tmp_path / "dump.img").read_bytes() == b"hello" +def test_driver_mock_storage_mux_flasher(tmp_path): + with serve(MockStorageMuxFlasher()) as flasher: + (tmp_path / "disk.img").write_bytes(b"hello") + + # mock the StorageMuxClient dut/host methods + with mock.patch.object(flasher, "call", side_effect=flasher.call) as mock_method: + + flasher.flash(tmp_path / "disk.img") + # assert the mock had a call to "host", "write" and "dut" + assert mock_method.call_args_list == [ + mock.call("host"), + mock.call("write", mock.ANY), + mock.call("dut"), + ] + + mock_method.reset_mock() + flasher.dump(tmp_path / "dump.img") + assert mock_method.call_args_list == [ + mock.call("host"), + mock.call("read", mock.ANY), + mock.call("dut"), + ] + + assert (tmp_path / "dump.img").read_bytes() == b"hello" def test_drivers_mock_storage_mux_fs(monkeypatch: pytest.MonkeyPatch): with serve(MockStorageMux()) as client: diff --git a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py index bcbf99068..8d41e36df 100644 --- a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py +++ b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py @@ -8,13 +8,13 @@ import usb.util from anyio import fail_after, sleep from anyio.streams.file import FileReadStream, FileWriteStream -from jumpstarter_driver_opendal.driver import StorageMuxInterface +from jumpstarter_driver_opendal.driver import StorageMuxFlasherInterface from jumpstarter.driver import Driver, export @dataclass(kw_only=True) -class SDWire(StorageMuxInterface, Driver): +class SDWire(StorageMuxFlasherInterface, Driver): serial: str | None = field(default=None) dev: usb.core.Device = field(init=False) itf: usb.core.Interface = field(init=False)