Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from anyio import fail_after, sleep
from anyio.streams.file import FileReadStream, FileWriteStream
from jumpstarter_driver_composite.driver import CompositeInterface
from jumpstarter_driver_opendal.driver import StorageMuxInterface
from jumpstarter_driver_opendal.driver import StorageMuxFlasherInterface
from jumpstarter_driver_power.driver import PowerInterface, PowerReading
from jumpstarter_driver_pyserial.driver import PySerial
from serial.serialutil import SerialException
Expand Down Expand Up @@ -156,7 +156,7 @@ async def read(self) -> AsyncGenerator[PowerReading, None]:


@dataclass(kw_only=True)
class DutlinkStorageMux(DutlinkConfig, StorageMuxInterface, Driver):
class DutlinkStorageMux(DutlinkConfig, StorageMuxFlasherInterface, Driver):
storage_device: str

def control(self, action):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def power_test(power):


def storage_test(storage):
storage.write_local_file("/dev/null")
storage.flash("/dev/null")


def serial_test(serial):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .adapter import OpendalAdapter
from .common import Capability, HashAlgo, Metadata, Mode, PathBuf, PresignedRequest
from jumpstarter.client import DriverClient
from jumpstarter.common.exceptions import ArgumentError


@dataclass(kw_only=True)
Expand Down Expand Up @@ -524,6 +525,28 @@ def dump(
with OpendalAdapter(client=self, operator=operator, path=path, mode="wb") as handle:
return self.call("dump", handle, partition)

def cli(self):
@click.group
def base():
"""Generic flasher interface"""
pass

@base.command()
@click.argument("file")
@click.option("--partition", type=str)
def flash(file, partition):
"""Flash image to DUT from file"""
self.flash(file, partition=partition)

@base.command()
@click.argument("file")
@click.option("--partition", type=str)
def dump(file, partition):
"""Dump image from DUT to file"""
self.dump(file, partition=partition)

return base


class StorageMuxClient(DriverClient):
def host(self):
Expand Down Expand Up @@ -562,11 +585,9 @@ def read_local_file(self, filepath):
absolute = Path(filepath).resolve()
return self.read_file(operator=Operator("fs", root="/"), path=str(absolute))

def cli(self):
@click.group
def base():
"""Generic storage mux"""
pass
def cli(self, base=None):
if base is None:
base = click.group(lambda: None)

@base.command()
def host():
Expand All @@ -589,3 +610,54 @@ def write_local_file(file):
self.write_local_file(file)

return base

class StorageMuxFlasherClient(FlasherClient, StorageMuxClient):
def flash(
self,
path: PathBuf,
*,
partition: str | None = None,
operator: Operator | None = None,
):
"""Flash image to DUT"""
if partition is not None:
raise ArgumentError(
f"partition is not supported for StorageMuxFlasherClient, {partition} provided")

self.host()

if operator is None:
path, operator = _fs_operator_for_path(path)

with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle:
try:
return self.write(handle)
finally:
self.dut()

def dump(
self,
path: PathBuf,
*,
partition: str | None = None,
operator: Operator | None = None,
):
"""Dump image from DUT"""
if partition is not None:
raise ArgumentError(
f"partition is not supported for StorageMuxFlasherClient, {partition} provided")

self.call("host")

if operator is None:
path, operator = _fs_operator_for_path(path)

with OpendalAdapter(client=self, operator=operator, path=path, mode="wb") as handle:
try:
return self.call("read", handle)
finally:
self.call("dut")

def cli(self):
top_cli = FlasherClient.cli(self)
return StorageMuxClient.cli(self, top_cli)
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ async def write(self, src: str): ...
async def read(self, dst: str): ...


class StorageMuxFlasherInterface(StorageMuxInterface):
@classmethod
def client(cls) -> str:
return "jumpstarter_driver_opendal.client.StorageMuxFlasherClient"


@dataclass
class MockStorageMux(StorageMuxInterface, Driver):
file: _TemporaryFileWrapper = field(default_factory=NamedTemporaryFile)
Expand Down Expand Up @@ -270,3 +276,7 @@ async def read(self, dst: str):
async with self.resource(dst) as res:
async for chunk in stream:
await res.send(chunk)

@dataclass
class MockStorageMuxFlasher(StorageMuxFlasherInterface, MockStorageMux):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from random import randbytes
from tempfile import TemporaryDirectory
from threading import Thread
from unittest import mock

import pytest
from opendal import Operator

from .common import PresignedRequest
from .driver import MockFlasher, MockStorageMux, Opendal
from .driver import MockFlasher, MockStorageMux, MockStorageMuxFlasher, Opendal
from jumpstarter.common.utils import serve


Expand Down Expand Up @@ -148,6 +149,30 @@ def test_driver_flasher(tmp_path, partition):

assert (tmp_path / "dump.img").read_bytes() == b"hello"

def test_driver_mock_storage_mux_flasher(tmp_path):
with serve(MockStorageMuxFlasher()) as flasher:
(tmp_path / "disk.img").write_bytes(b"hello")

# mock the StorageMuxClient dut/host methods
with mock.patch.object(flasher, "call", side_effect=flasher.call) as mock_method:

flasher.flash(tmp_path / "disk.img")
# assert the mock had a call to "host", "write" and "dut"
assert mock_method.call_args_list == [
mock.call("host"),
mock.call("write", mock.ANY),
mock.call("dut"),
]

mock_method.reset_mock()
flasher.dump(tmp_path / "dump.img")
assert mock_method.call_args_list == [
mock.call("host"),
mock.call("read", mock.ANY),
mock.call("dut"),
]

assert (tmp_path / "dump.img").read_bytes() == b"hello"

def test_drivers_mock_storage_mux_fs(monkeypatch: pytest.MonkeyPatch):
with serve(MockStorageMux()) as client:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import usb.util
from anyio import fail_after, sleep
from anyio.streams.file import FileReadStream, FileWriteStream
from jumpstarter_driver_opendal.driver import StorageMuxInterface
from jumpstarter_driver_opendal.driver import StorageMuxFlasherInterface

from jumpstarter.driver import Driver, export


@dataclass(kw_only=True)
class SDWire(StorageMuxInterface, Driver):
class SDWire(StorageMuxFlasherInterface, Driver):
serial: str | None = field(default=None)
dev: usb.core.Device = field(init=False)
itf: usb.core.Interface = field(init=False)
Expand Down