From d930cf529f7c7a6cb9a3a4d803825437671147f9 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Mon, 10 Feb 2025 10:50:36 -0500 Subject: [PATCH 1/2] Power.on/off method returns None --- .../getting-started/setup-local-exporter.md | 4 +- .../driver_test.py | 4 +- .../jumpstarter_driver_dutlink/driver.py | 8 +-- .../jumpstarter_driver_power/client.py | 12 ++--- .../jumpstarter_driver_power/client_test.py | 8 +-- .../jumpstarter_driver_power/driver.py | 20 +++---- .../jumpstarter_driver_power/driver_test.py | 8 +-- .../jumpstarter_driver_raspberrypi/client.py | 4 +- .../jumpstarter_driver_raspberrypi/driver.py | 4 +- .../jumpstarter_driver_yepkit/driver.py | 53 ++++++------------- .../jumpstarter_testing/pytest_test.py | 2 +- .../jumpstarter/config/exporter.py | 1 + .../jumpstarter/config/exporter_test.py | 2 +- .../jumpstarter/jumpstarter/listener_test.py | 8 +-- 14 files changed, 59 insertions(+), 79 deletions(-) diff --git a/docs/source/getting-started/setup-local-exporter.md b/docs/source/getting-started/setup-local-exporter.md index 2dd41b37b..a805e4250 100644 --- a/docs/source/getting-started/setup-local-exporter.md +++ b/docs/source/getting-started/setup-local-exporter.md @@ -164,10 +164,10 @@ from jumpstarter_testing.pytest import JumpstarterTest class MyTest(JumpstarterTest): def test_power_on(self, client): - assert client.power.on() == "ok" + client.power.on() def test_power_off(self, client): - assert client.power.off() == "ok" + client.power.off() ``` ```shell diff --git a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py index 750ad902f..43e0508dd 100644 --- a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py +++ b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py @@ -17,5 +17,5 @@ def test_drivers_composite(): }, ) ) as client: - assert client.power0.on() == "ok" - assert client.composite1.power1.on() == "ok" + client.power0.on() + client.composite1.power1.on() diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 39de02dbf..4926aec64 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -122,12 +122,12 @@ def close(self): self.off() @export - def on(self): - return self.control("on") + def on(self) -> None: + self.control("on") @export - def off(self): - return self.control("off") + def off(self) -> None: + self.control("off") @export async def read(self) -> AsyncGenerator[PowerReading, None]: diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py index 924f59675..ff6a572ef 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client.py @@ -7,11 +7,11 @@ class PowerClient(DriverClient): - def on(self) -> str: - return self.call("on") + def on(self) -> None: + self.call("on") - def off(self) -> str: - return self.call("off") + def off(self) -> None: + self.call("off") def read(self) -> Generator[PowerReading, None, None]: for v in self.streamingcall("read"): @@ -26,11 +26,11 @@ def base(): @base.command() def on(): """Power on""" - click.echo(self.on()) + self.on() @base.command() def off(): """Power off""" - click.echo(self.off()) + self.off() return base diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client_test.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client_test.py index 0fd9ae8e6..9f204dd8a 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/client_test.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/client_test.py @@ -5,8 +5,8 @@ def test_client_mock_power(): with serve(MockPower()) as client: - assert client.on() == "ok" - assert client.off() == "ok" + client.on() + client.off() assert list(client.read()) == [ PowerReading(voltage=0.0, current=0.0), @@ -16,8 +16,8 @@ def test_client_mock_power(): def test_client_sync_mock_power(): with serve(SyncMockPower()) as client: - assert client.on() == "ok" - assert client.off() == "ok" + client.on() + client.off() assert list(client.read()) == [ PowerReading(voltage=0.0, current=0.0), diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py index b5326a578..cd23c2ba4 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver.py @@ -11,10 +11,10 @@ def client(cls) -> str: return "jumpstarter_driver_power.client.PowerClient" @abstractmethod - async def on(self) -> str: ... + async def on(self) -> None: ... @abstractmethod - async def off(self) -> str: ... + async def off(self) -> None: ... @abstractmethod async def read(self) -> AsyncGenerator[PowerReading, None]: ... @@ -22,12 +22,12 @@ async def read(self) -> AsyncGenerator[PowerReading, None]: ... class MockPower(PowerInterface, Driver): @export - async def on(self) -> str: - return "ok" + async def on(self) -> None: + pass @export - async def off(self) -> str: - return "ok" + async def off(self) -> None: + pass @export async def read(self) -> AsyncGenerator[PowerReading, None]: @@ -37,12 +37,12 @@ async def read(self) -> AsyncGenerator[PowerReading, None]: class SyncMockPower(PowerInterface, Driver): @export - def on(self) -> str: - return "ok" + def on(self) -> None: + pass @export - def off(self) -> str: - return "ok" + def off(self) -> None: + pass @export def read(self) -> Generator[PowerReading, None]: diff --git a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py index f0824ff47..9f766a1f8 100644 --- a/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py +++ b/packages/jumpstarter-driver-power/jumpstarter_driver_power/driver_test.py @@ -9,8 +9,8 @@ async def test_driver_mock_power(): driver = MockPower() - assert await driver.on() == "ok" - assert await driver.off() == "ok" + await driver.on() + await driver.off() assert [v async for v in driver.read()] == [ PowerReading(voltage=0.0, current=0.0), @@ -21,8 +21,8 @@ async def test_driver_mock_power(): def test_driver_sync_mock_power(): driver = SyncMockPower() - assert driver.on() == "ok" - assert driver.off() == "ok" + driver.on() + driver.off() assert list(driver.read()) == [ PowerReading(voltage=0.0, current=0.0), diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py index 3242c78c9..e31325d03 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/client.py @@ -5,10 +5,10 @@ @dataclass(kw_only=True) class DigitalOutputClient(DriverClient): - def off(self): + def off(self) -> None: self.call("off") - def on(self): + def on(self) -> None: self.call("on") diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py index f159d2345..0805bddf4 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py @@ -26,14 +26,14 @@ def close(self): super().close() @export - def off(self): + def off(self) -> None: if not isinstance(self.device, DigitalOutputDevice): self.device.close() self.device = DigitalOutputDevice(pin=self.pin, initial_value=None) self.device.off() @export - def on(self): + def on(self) -> None: if not isinstance(self.device, DigitalOutputDevice): self.device.close() self.device = DigitalOutputDevice(pin=self.pin, initial_value=None) diff --git a/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py b/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py index 8d16820b9..875dda0e5 100644 --- a/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py +++ b/packages/jumpstarter-driver-yepkit/jumpstarter_driver_yepkit/driver.py @@ -11,25 +11,11 @@ VID = 0x04D8 PID = 0xF2F7 -PORT_UP_COMMANDS = { - '1': 0x11, - '2': 0x12, - '3': 0x13, - 'all': 0x1A -} - -PORT_DOWN_COMMANDS = { - '1': 0x01, - '2': 0x02, - '3': 0x03, - 'all': 0x0A -} - -PORT_STATUS_COMMANDS = { - '1': 0x21, - '2': 0x22, - '3': 0x23 -} +PORT_UP_COMMANDS = {"1": 0x11, "2": 0x12, "3": 0x13, "all": 0x1A} + +PORT_DOWN_COMMANDS = {"1": 0x01, "2": 0x02, "3": 0x03, "all": 0x0A} + +PORT_STATUS_COMMANDS = {"1": 0x21, "2": 0x22, "3": 0x23} VALID_DEFAULTS = ["on", "off", "keep"] @@ -37,9 +23,11 @@ _USB_DEVS = {} _USB_DEVS_LOCK = threading.Lock() # Lock for synchronizing access, we don't do multithread, but just in case.. + @dataclass(kw_only=True) class Ykush(PowerInterface, Driver): - """ driver for Yepkit Ykush USB Hub with Power control """ + """driver for Yepkit Ykush USB Hub with Power control""" + serial: str | None = field(default=None) default: str = "off" port: str = "all" @@ -52,12 +40,10 @@ def __post_init__(self): keys = PORT_UP_COMMANDS.keys() if self.port not in keys: - raise ValueError( - f"The ykush driver port must be any of the following values: {keys}") + raise ValueError(f"The ykush driver port must be any of the following values: {keys}") if self.default not in VALID_DEFAULTS: - raise ValueError( - f"The ykush driver default must be any of the following values: {VALID_DEFAULTS}") + raise ValueError(f"The ykush driver default must be any of the following values: {VALID_DEFAULTS}") with _USB_DEVS_LOCK: # another instance already claimed this device? @@ -75,8 +61,7 @@ def __post_init__(self): if serial == self.serial or self.serial is None: _USB_DEVS[serial] = dev if self.serial is None: - self.logger.warning( - f"No serial number provided for ykush, using the first one found: {serial}") + self.logger.warning(f"No serial number provided for ykush, using the first one found: {serial}") self.serial = serial self.dev = dev return @@ -86,7 +71,7 @@ def __post_init__(self): def _send_cmd(self, cmd, report_size=64): out_ep, in_ep = self._get_endpoints(self.dev) out_buf = [0x00] * report_size - out_buf[0] = cmd # YKUSH command + out_buf[0] = cmd # YKUSH command # Write to the OUT endpoint out_ep.write(out_buf) @@ -103,15 +88,11 @@ def _get_endpoints(self, dev): interface = cfg[(0, 0)] out_endpoint = usb.util.find_descriptor( - interface, - custom_match=lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT + interface, custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT ) in_endpoint = usb.util.find_descriptor( - interface, - custom_match=lambda e: \ - usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN + interface, custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN ) if not out_endpoint or not in_endpoint: @@ -127,18 +108,16 @@ def reset(self): self.off() @export - def on(self): + def on(self) -> None: self.logger.info(f"Power ON for Ykush {self.serial} on port {self.port}") cmd = PORT_UP_COMMANDS.get(self.port) _ = self._send_cmd(cmd) - return @export - def off(self): + def off(self) -> None: self.logger.info(f"Power OFF for Ykush {self.serial} on port {self.port}") cmd = PORT_DOWN_COMMANDS.get(self.port) _ = self._send_cmd(cmd) - return @export def read(self) -> AsyncGenerator[PowerReading, None]: diff --git a/packages/jumpstarter-testing/jumpstarter_testing/pytest_test.py b/packages/jumpstarter-testing/jumpstarter_testing/pytest_test.py index b026ec5d3..f1697cc2c 100644 --- a/packages/jumpstarter-testing/jumpstarter_testing/pytest_test.py +++ b/packages/jumpstarter-testing/jumpstarter_testing/pytest_test.py @@ -12,7 +12,7 @@ def test_env(pytester: Pytester, monkeypatch): class TestSample(JumpstarterTest): def test_simple(self, client): - assert client.on() == "ok" + client.on() """ ) diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index 60dad8fe0..3d7b1f35f 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -34,6 +34,7 @@ def from_path(cls, path: str) -> ExporterConfigV1Alpha1DriverInstance: with open(path) as f: return cls.model_validate(yaml.safe_load(f)) + class ExporterConfigV1Alpha1(BaseModel): BASE_PATH: ClassVar[Path] = Path("/etc/jumpstarter/exporters") diff --git a/packages/jumpstarter/jumpstarter/config/exporter_test.py b/packages/jumpstarter/jumpstarter/config/exporter_test.py index 33348cc0a..d648c039f 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter_test.py +++ b/packages/jumpstarter/jumpstarter/config/exporter_test.py @@ -53,7 +53,7 @@ async def test_exporter_serve(mock_controller): with start_blocking_portal() as portal: async with client.lease_async(metadata_filter=MetadataFilter(), lease_name=None, portal=portal) as lease: async with lease.connect_async() as client: - assert await client.power.call_async("on") == "ok" + await client.power.call_async("on") assert hasattr(client.nested, "tcp") tg.cancel_scope.cancel() diff --git a/packages/jumpstarter/jumpstarter/listener_test.py b/packages/jumpstarter/jumpstarter/listener_test.py index d2025489f..81bdf8136 100644 --- a/packages/jumpstarter/jumpstarter/listener_test.py +++ b/packages/jumpstarter/jumpstarter/listener_test.py @@ -49,7 +49,7 @@ async def handle_async(stream): monkeypatch.setattr(lease, "handle_async", handle_async) async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") tg.cancel_scope.cancel() @@ -97,12 +97,12 @@ async def test_controller(mock_controller): unsafe=True, ) as lease: async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") # test concurrent connections async with lease.connect_async() as client2: - assert await client2.call_async("on") == "ok" + await client2.call_async("on") async with lease.connect_async() as client: - assert await client.call_async("on") == "ok" + await client.call_async("on") tg.cancel_scope.cancel() From 82860f284a88ddf5213338da24194f7ffba4152c Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Mon, 10 Feb 2025 10:52:06 -0500 Subject: [PATCH 2/2] treewide: ruff --- .../examples/tftp_test.py | 1 + .../jumpstarter_driver_tftp/__init__.py | 2 +- .../jumpstarter_driver_tftp/driver.py | 10 +- .../jumpstarter_driver_tftp/driver_test.py | 7 + .../jumpstarter_driver_tftp/server.py | 151 +++++++++--------- .../jumpstarter_driver_tftp/server_test.py | 93 +++++------ 6 files changed, 128 insertions(+), 136 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/examples/tftp_test.py b/packages/jumpstarter-driver-tftp/examples/tftp_test.py index 735fcc140..c5aa221cb 100644 --- a/packages/jumpstarter-driver-tftp/examples/tftp_test.py +++ b/packages/jumpstarter-driver-tftp/examples/tftp_test.py @@ -6,6 +6,7 @@ log = logging.getLogger(__name__) + class TestResource(JumpstarterTest): filter_labels = {"board": "rpi4"} diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py index fc3188467..e70eab33a 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py @@ -1 +1 @@ -CHUNK_SIZE = 1024 * 1024 * 4 # 4MB +CHUNK_SIZE = 1024 * 1024 * 4 # 4MB diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 5d526e2bd..d90ea6773 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -17,16 +17,22 @@ class TftpError(Exception): """Base exception for TFTP server errors""" + pass + class ServerNotRunning(TftpError): """Server is not running""" + pass + class FileNotFound(TftpError): """File not found""" + pass + @dataclass(kw_only=True) class Tftp(Driver): """TFTP Server driver for Jumpstarter @@ -40,7 +46,7 @@ class Tftp(Driver): """ root_dir: str = "/var/lib/tftpboot" - host: str = field(default='') + host: str = field(default="") port: int = 69 server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) @@ -53,7 +59,7 @@ def __post_init__(self): super().__post_init__() os.makedirs(self.root_dir, exist_ok=True) - if self.host == '': + if self.host == "": self.host = self.get_default_ip() def get_default_ip(self): diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index 3f0f6911b..f74e8ec07 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -21,12 +21,14 @@ def temp_dir(): with tempfile.TemporaryDirectory() as tmpdir: yield tmpdir + @pytest.fixture def server(temp_dir): server = Tftp(root_dir=temp_dir, host="127.0.0.1") yield server server.close() + @pytest.mark.anyio async def test_tftp_file_operations(server): filename = "test.txt" @@ -60,17 +62,20 @@ async def send_data(): with pytest.raises(FileNotFound): server.delete_file("nonexistent.txt") + def test_tftp_host_config(temp_dir): custom_host = "192.168.1.1" server = Tftp(root_dir=temp_dir, host=custom_host) assert server.get_host() == custom_host + def test_tftp_root_directory_creation(temp_dir): new_dir = os.path.join(temp_dir, "new_tftp_root") server = Tftp(root_dir=new_dir) assert os.path.exists(new_dir) server.close() + @pytest.mark.anyio async def test_tftp_detect_corrupted_file(server): filename = "corrupted.txt" @@ -86,10 +91,12 @@ async def test_tftp_detect_corrupted_file(server): assert not server.check_file_checksum(filename, client_checksum) + @pytest.fixture def anyio_backend(): return "asyncio" + async def _upload_file(server, filename: str, data: bytes) -> str: send_stream, receive_stream = create_memory_object_stream() resource_uuid = uuid4() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index 37e83a3a1..1374df08e 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -33,18 +33,19 @@ class TftpServer: TFTP Server that handles read requests (RRQ). """ - def __init__(self, host: str, port: int, root_dir: str, - block_size: int = 512, timeout: float = 5.0, retries: int = 3): + def __init__( + self, host: str, port: int, root_dir: str, block_size: int = 512, timeout: float = 5.0, retries: int = 3 + ): self.host = host self.port = port self.root_dir = pathlib.Path(os.path.abspath(root_dir)) self.block_size = block_size self.timeout = timeout self.retries = retries - self.active_transfers: Set['TftpTransfer'] = set() + self.active_transfers: Set["TftpTransfer"] = set() self.shutdown_event = asyncio.Event() self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional['TftpServerProtocol'] = None + self.protocol: Optional["TftpServerProtocol"] = None self.logger = logging.getLogger(self.__class__.__name__) self.ready_event = asyncio.Event() @@ -52,7 +53,7 @@ def __init__(self, host: str, port: int, root_dir: str, def address(self) -> Optional[Tuple[str, int]]: """Get the server's bound address and port.""" if self.transport: - return self.transport.get_extra_info('socket').getsockname() + return self.transport.get_extra_info("socket").getsockname() return None async def start(self): @@ -61,8 +62,7 @@ async def start(self): self.ready_event.set() self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpServerProtocol(self), - local_addr=(self.host, self.port) + lambda: TftpServerProtocol(self), local_addr=(self.host, self.port) ) try: @@ -92,11 +92,11 @@ async def shutdown(self): self.logger.info("Shutdown signal received for TFTP server") self.shutdown_event.set() - def register_transfer(self, transfer: 'TftpTransfer'): + def register_transfer(self, transfer: "TftpTransfer"): self.active_transfers.add(transfer) self.logger.debug(f"Registered transfer: {transfer}") - def unregister_transfer(self, transfer: 'TftpTransfer'): + def unregister_transfer(self, transfer: "TftpTransfer"): self.active_transfers.discard(transfer) self.logger.debug(f"Unregistered transfer: {transfer}") @@ -130,7 +130,7 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return try: - opcode = Opcode(int.from_bytes(data[0:2], 'big')) + opcode = Opcode(int.from_bytes(data[0:2], "big")) except ValueError: self.logger.error(f"Unknown opcode from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") @@ -166,9 +166,9 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]): def _send_oack(self, addr: Tuple[str, int], options: dict): """Send Option Acknowledgment (OACK) packet.""" - oack_data = Opcode.OACK.to_bytes(2, 'big') + oack_data = Opcode.OACK.to_bytes(2, "big") for opt_name, opt_value in options.items(): - oack_data += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + oack_data += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8") if self.transport: self.transport.sendto(oack_data, addr) @@ -176,39 +176,36 @@ def _send_oack(self, addr: Tuple[str, int], options: dict): def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, 'big') + - error_code.to_bytes(2, 'big') + - message.encode('utf-8') + b'\x00' + Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" ) if self.transport: self.transport.sendto(error_packet, addr) self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}") def _parse_request(self, data: bytes) -> Tuple[str, str, dict]: - parts = data[2:].split(b'\x00') + parts = data[2:].split(b"\x00") if len(parts) < 2: raise ValueError("Invalid RRQ format") - filename = parts[0].decode('utf-8') + filename = parts[0].decode("utf-8") if len(filename) > 255: # RFC 1350 doesn't specify a limit raise ValueError("Filename too long") if not all(c.isprintable() and c not in '<>:"/\\|?*' for c in filename): raise ValueError("Invalid characters in filename") - if '\x00' in filename: + if "\x00" in filename: raise ValueError("Null byte in filename") - mode = parts[1].decode('utf-8').lower() + mode = parts[1].decode("utf-8").lower() options = self._parse_options(parts[2:]) return filename, mode, options - def _parse_options(self, option_parts: list) -> dict: options = {} i = 0 while i < len(option_parts) - 1: try: - opt_name = option_parts[i].decode('utf-8').lower() - opt_value = option_parts[i + 1].decode('utf-8') + opt_name = option_parts[i].decode("utf-8").lower() + opt_value = option_parts[i + 1].decode("utf-8") options[opt_name] = opt_value i += 2 except Exception: @@ -216,7 +213,7 @@ def _parse_options(self, option_parts: list) -> dict: return options def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: - if mode not in ('netascii', 'octet'): + if mode not in ("netascii", "octet"): self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode") return False @@ -248,14 +245,12 @@ def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int: return blksize else: self.logger.warning( - f"Requested block size {blksize} out of range (512-65464), " - f"using default: {self.server.block_size}" + f"Requested block size {blksize} out of range (512-65464), using default: {self.server.block_size}" ) return self.server.block_size except ValueError: self.logger.warning( - f"Invalid block size value '{requested_blksize}', " - f"using default: {self.server.block_size}" + f"Invalid block size value '{requested_blksize}', using default: {self.server.block_size}" ) return self.server.block_size @@ -269,15 +264,11 @@ def _negotiate_timeout(self, requested_timeout: Optional[str]) -> float: return float(timeout) else: self.logger.warning( - f"Timeout value {timeout} out of range (1-255), " - f"using default: {self.server.timeout}" + f"Timeout value {timeout} out of range (1-255), using default: {self.server.timeout}" ) return self.server.timeout except ValueError: - self.logger.warning( - f"Invalid timeout value '{requested_timeout}', " - f"using default: {self.server.timeout}" - ) + self.logger.warning(f"Invalid timeout value '{requested_timeout}', using default: {self.server.timeout}") return self.server.timeout def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: @@ -285,21 +276,21 @@ def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: blksize = self.server.block_size timeout = self.server.timeout - if 'blksize' in options: - requested = options['blksize'] + if "blksize" in options: + requested = options["blksize"] blksize = self._negotiate_block_size(requested) - negotiated['blksize'] = blksize + negotiated["blksize"] = blksize - if 'timeout' in options: - requested = options['timeout'] + if "timeout" in options: + requested = options["timeout"] timeout = self._negotiate_timeout(requested) - negotiated['timeout'] = int(timeout) + negotiated["timeout"] = int(timeout) return negotiated, blksize, timeout - - async def _start_transfer(self, filepath: pathlib.Path, addr: Tuple[str, int], - blksize: int, timeout: float, negotiated_options: dict): + async def _start_transfer( + self, filepath: pathlib.Path, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict + ): transfer = TftpReadTransfer( server=self.server, filepath=filepath, @@ -307,11 +298,12 @@ async def _start_transfer(self, filepath: pathlib.Path, addr: Tuple[str, int], block_size=blksize, timeout=timeout, retries=self.server.retries, - negotiated_options=negotiated_options + negotiated_options=negotiated_options, ) self.server.register_transfer(transfer) asyncio.create_task(transfer.start()) + def is_subpath(path: pathlib.Path, root: pathlib.Path) -> bool: try: path.relative_to(root) @@ -325,8 +317,15 @@ class TftpTransfer: Base class for TFTP transfers. """ - def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], - block_size: int, timeout: float, retries: int): + def __init__( + self, + server: TftpServer, + filepath: pathlib.Path, + client_addr: Tuple[str, int], + block_size: int, + timeout: float, + retries: int, + ): self.server = server self.filepath = filepath self.client_addr = client_addr @@ -334,7 +333,7 @@ def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tupl self.timeout = timeout self.retries = retries self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional['TftpTransferProtocol'] = None + self.protocol: Optional["TftpTransferProtocol"] = None self.cleanup_task: Optional[asyncio.Task] = None self.logger = logging.getLogger(self.__class__.__name__) @@ -352,15 +351,23 @@ async def cleanup(self): class TftpReadTransfer(TftpTransfer): - def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], - block_size: int, timeout: float, retries: int, negotiated_options: Optional[dict] = None): + def __init__( + self, + server: TftpServer, + filepath: pathlib.Path, + client_addr: Tuple[str, int], + block_size: int, + timeout: float, + retries: int, + negotiated_options: Optional[dict] = None, + ): super().__init__( server=server, filepath=filepath, client_addr=client_addr, block_size=block_size, timeout=timeout, - retries=retries + retries=retries, ) self.block_num = 0 self.ack_received = asyncio.Event() @@ -390,17 +397,14 @@ async def _initialize_transfer(self) -> bool: loop = asyncio.get_running_loop() self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpTransferProtocol(self), - local_addr=('0.0.0.0', 0), - remote_addr=self.client_addr + lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr ) - local_addr = self.transport.get_extra_info('sockname') + local_addr = self.transport.get_extra_info("sockname") self.logger.debug(f"Transfer bound to local {local_addr}") # Only send OACK if we have non-default options to negotiate if self.negotiated_options and ( - self.negotiated_options['blksize'] != 512 or - self.negotiated_options['timeout'] != self.server.timeout + self.negotiated_options["blksize"] != 512 or self.negotiated_options["timeout"] != self.server.timeout ): oack_packet = self._create_oack_packet() if not await self._send_with_retries(oack_packet, is_oack=True): @@ -411,7 +415,7 @@ async def _initialize_transfer(self) -> bool: return True async def _perform_transfer(self): - async with aiofiles.open(self.filepath, 'rb') as f: + async with aiofiles.open(self.filepath, "rb") as f: while True: if self.server.shutdown_event.is_set(): self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") @@ -428,7 +432,7 @@ async def _handle_data_block(self, data: bytes) -> bool: """ if not data and self.block_num == 1: # Empty file case - packet = self._create_data_packet(b'') + packet = self._create_data_packet(b"") await self._send_with_retries(packet) return False elif data: @@ -450,7 +454,7 @@ async def _handle_data_block(self, data: bytes) -> bool: return True else: # EOF reached - packet = self._create_data_packet(b'') + packet = self._create_data_packet(b"") success = await self._send_with_retries(packet) if not success: self.logger.error(f"Failed to send final block {self.block_num}") @@ -459,25 +463,21 @@ async def _handle_data_block(self, data: bytes) -> bool: return False def _create_oack_packet(self) -> bytes: - packet = Opcode.OACK.to_bytes(2, 'big') + packet = Opcode.OACK.to_bytes(2, "big") for opt_name, opt_value in self.negotiated_options.items(): - packet += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + packet += f"{opt_name}\0{str(opt_value)}\0".encode("utf-8") return packet def _create_data_packet(self, data: bytes) -> bytes: - return ( - Opcode.DATA.to_bytes(2, 'big') + - self.block_num.to_bytes(2, 'big') + - data - ) + return Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + data def _send_packet(self, packet: bytes): self.transport.sendto(packet) - if packet[0:2] == Opcode.DATA.to_bytes(2, 'big'): - block = int.from_bytes(packet[2:4], 'big') + if packet[0:2] == Opcode.DATA.to_bytes(2, "big"): + block = int.from_bytes(packet[2:4], "big") data_length = len(packet) - 4 self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}") - elif packet[0:2] == Opcode.OACK.to_bytes(2, 'big'): + elif packet[0:2] == Opcode.OACK.to_bytes(2, "big"): self.logger.debug(f"Sent OACK to {self.client_addr}") async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool: @@ -488,8 +488,7 @@ async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool try: self._send_packet(packet) self.logger.debug( - f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, " - f"waiting for ACK (Attempt {attempt})" + f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})" ) self.ack_received.clear() await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout) @@ -524,6 +523,7 @@ def handle_ack(self, block_num: int): else: self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}") + class TftpTransferProtocol(asyncio.DatagramProtocol): """ Protocol for handling ACKs during a TFTP transfer. @@ -535,7 +535,7 @@ def __init__(self, transfer: TftpReadTransfer): def connection_made(self, transport: asyncio.DatagramTransport): self.transfer.transport = transport - local_addr = transport.get_extra_info('sockname') + local_addr = transport.get_extra_info("sockname") self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}") def datagram_received(self, data: bytes, addr: Tuple[str, int]): @@ -549,21 +549,20 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return try: - opcode = Opcode(int.from_bytes(data[0:2], 'big')) + opcode = Opcode(int.from_bytes(data[0:2], "big")) except ValueError: self.logger.error(f"Unknown opcode from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") return if opcode == Opcode.ACK: - block_num = int.from_bytes(data[2:4], 'big') + block_num = int.from_bytes(data[2:4], "big") self.logger.debug(f"Received ACK for block {block_num} from {addr}") self.transfer.handle_ack(block_num) else: self.logger.warning(f"Unexpected opcode {opcode} from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unexpected opcode") - def error_received(self, exc): self.logger.error(f"Error received: {exc}") @@ -572,9 +571,7 @@ def connection_lost(self, exc): def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, 'big') + - error_code.to_bytes(2, 'big') + - message.encode('utf-8') + b'\x00' + Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" ) if self.transfer.transport: self.transfer.transport.sendto(error_packet) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 5242f0c99..679e55cc2 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -13,11 +13,7 @@ async def tftp_server(): test_file_path = Path(temp_dir) / "test.txt" test_file_path.write_text("Hello, TFTP!") - server = TftpServer( - host="127.0.0.1", - port=0, - root_dir=temp_dir - ) + server = TftpServer(host="127.0.0.1", port=0, root_dir=temp_dir) server_task = asyncio.create_task(server.start()) for _ in range(10): @@ -42,14 +38,13 @@ async def tftp_server(): except asyncio.CancelledError: pass + async def create_test_client(server_port): loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - asyncio.DatagramProtocol, - remote_addr=('127.0.0.1', 0) - ) + transport, protocol = await loop.create_datagram_endpoint(asyncio.DatagramProtocol, remote_addr=("127.0.0.1", 0)) return transport, protocol + @pytest.mark.asyncio async def test_server_startup_and_shutdown(tftp_server): """Test that server starts up and shuts down cleanly.""" @@ -64,6 +59,7 @@ async def test_server_startup_and_shutdown(tftp_server): assert True + @pytest.mark.asyncio async def test_read_request_for_existing_file(tftp_server): """Test reading an existing file from the server.""" @@ -76,9 +72,9 @@ async def test_read_request_for_existing_file(tftp_server): transport, _ = await create_test_client(server.port) rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + # filename - b'octet\x00' # mode + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" # filename + + b"octet\x00" # mode ) transport.sendto(rrq_packet) @@ -91,6 +87,7 @@ async def test_read_request_for_existing_file(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_read_request_for_nonexistent_file(tftp_server): """Test reading a non-existent file returns appropriate error.""" @@ -101,11 +98,7 @@ async def test_read_request_for_nonexistent_file(tftp_server): try: transport, protocol = await create_test_client(server.port) - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'nonexistent.txt\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"nonexistent.txt\x00" + b"octet\x00" transport.sendto(rrq_packet) assert server.transport is not None @@ -115,20 +108,16 @@ async def test_read_request_for_nonexistent_file(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_write_request_rejection(tftp_server): """Test that write requests are properly rejected (server is read-only).""" server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - try: transport, _ = await create_test_client(server.port) - wrq_packet = ( - Opcode.WRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' - ) + wrq_packet = Opcode.WRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00" transport.sendto(wrq_packet) @@ -139,6 +128,7 @@ async def test_write_request_rejection(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_invalid_packet_handling(tftp_server): server, temp_dir, server_port = tftp_server @@ -147,7 +137,7 @@ async def test_invalid_packet_handling(tftp_server): try: transport, _ = await create_test_client(server.port) - transport.sendto(b'\x00\x01') + transport.sendto(b"\x00\x01") assert server.transport is not None @@ -156,6 +146,7 @@ async def test_invalid_packet_handling(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_path_traversal_prevention(tftp_server): """Test that path traversal attempts are blocked.""" @@ -167,11 +158,7 @@ async def test_path_traversal_prevention(tftp_server): try: transport, _ = await create_test_client(server.port) - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'../../../etc/passwd\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"../../../etc/passwd\x00" + b"octet\x00" transport.sendto(rrq_packet) @@ -182,6 +169,7 @@ async def test_path_traversal_prevention(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_options_negotiation(tftp_server): """Test that options (blksize, timeout) are properly negotiated.""" @@ -194,13 +182,13 @@ async def test_options_negotiation(tftp_server): # RRQ with options rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' + - b'blksize\x00' + - b'1024\x00' + - b'timeout\x00' + - b'3\x00' + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" + + b"octet\x00" + + b"blksize\x00" + + b"1024\x00" + + b"timeout\x00" + + b"3\x00" ) transport.sendto(rrq_packet) @@ -212,6 +200,7 @@ async def test_options_negotiation(tftp_server): await server.shutdown() await server_task + @pytest.mark.asyncio async def test_retry_mechanism(tftp_server): server, _, server_port = tftp_server @@ -234,29 +223,21 @@ def datagram_received(self, data, addr): try: loop = asyncio.get_running_loop() - transport, protocol = await loop.create_datagram_endpoint( - lambda: TestProtocol(), - local_addr=('127.0.0.1', 0) - ) + transport, protocol = await loop.create_datagram_endpoint(lambda: TestProtocol(), local_addr=("127.0.0.1", 0)) assert transport is not None, "Failed to create transport" - rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' - ) + rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00" - transport.sendto(rrq_packet, ('127.0.0.1', server_port)) + transport.sendto(rrq_packet, ("127.0.0.1", server_port)) await asyncio.sleep(server.timeout * 2) - data_packets = [p for p in protocol.received_packets - if p[0:2] == Opcode.DATA.to_bytes(2, 'big')] + data_packets = [p for p in protocol.received_packets if p[0:2] == Opcode.DATA.to_bytes(2, "big")] assert len(data_packets) > 1, "Server should have retried sending DATA packet" - block_numbers = {int.from_bytes(p[2:4], 'big') for p in data_packets} + block_numbers = {int.from_bytes(p[2:4], "big") for p in data_packets} assert len(block_numbers) == 1, "All retried packets should be for the same block" assert 1 in block_numbers, "First block number should be 1" @@ -278,13 +259,13 @@ async def test_invalid_options_handling(tftp_server): transport, _ = await create_test_client(server.port) rrq_packet = ( - Opcode.RRQ.to_bytes(2, 'big') + - b'test.txt\x00' + - b'octet\x00' + - b'blksize\x00' + - b'invalid\x00' + - b'timeout\x00' + - b'999999\x00' + Opcode.RRQ.to_bytes(2, "big") + + b"test.txt\x00" + + b"octet\x00" + + b"blksize\x00" + + b"invalid\x00" + + b"timeout\x00" + + b"999999\x00" ) transport.sendto(rrq_packet)