diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index ffc8fd7f9..fbf9536fc 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -163,14 +163,13 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]): if not self._validate_mode(mode, addr): return - resolved_path = self._resolve_and_validate_path(filename, addr) + resolved_path = await self._resolve_and_validate_path(filename, addr) if not resolved_path: return negotiated_options, blksize, timeout = self._negotiate_options(options) self.logger.info(f"Negotiated options: {negotiated_options}") await self._start_transfer(resolved_path, addr, blksize, timeout, negotiated_options) - except Exception as e: self.logger.error(f"Error handling RRQ from {addr}: {e}") self._send_error(addr, TftpErrorCode.NOT_DEFINED, str(e)) @@ -230,17 +229,17 @@ def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: return False return True - def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[str]: + async def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[str]: try: - stat = self.server.operator.stat(filename) + stat = await self.server.operator.stat(filename) except FileNotFoundError: self.logger.error(f"File not found: {filename}") self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") return None if not stat.mode.is_file(): - self.logger.error(f"File not found: {filename}") - self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") + self.logger.error(f"Not a file: {filename}") + self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "Not a file") return None return filename @@ -425,14 +424,22 @@ async def _initialize_transfer(self) -> bool: return True async def _perform_transfer(self): - async with await self.server.operator.to_async_operator().open(self.filepath, "rb") as f: + async with await self.server.operator.open(self.filepath, "rb") as f: while True: if self.server.shutdown_event.is_set(): self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") break - data = await f.read(size=self.block_size) - if not await self._handle_data_block(data): + # read a full block or until EOF + data = bytearray() + while len(data) < self.block_size: + chunk = await f.read(size=self.block_size - len(data)) + if not chunk: # EOF reached + break + data.extend(chunk) + + # send the data (converted to bytes) + if not await self._handle_data_block(bytes(data)): break async def _handle_data_block(self, data: bytes) -> bool: diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 6ff95daac..72bba485d 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest -from opendal import Operator +from opendal import AsyncOperator from jumpstarter_driver_tftp.server import Opcode, TftpServer @@ -14,7 +14,7 @@ async def tftp_server(): test_file_path = Path(temp_dir) / "test.txt" test_file_path.write_text("Hello, TFTP!") - server = TftpServer(host="127.0.0.1", port=0, operator=Operator("fs", root=str(temp_dir))) + server = TftpServer(host="127.0.0.1", port=0, operator=AsyncOperator("fs", root=str(temp_dir))) server_task = asyncio.create_task(server.start()) for _ in range(10):