From 663d67c1db21a2959802cfd6539f88f1708fcb18 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 12:41:42 -0500 Subject: [PATCH 1/7] Use Opendal driver in Tftp driver --- .../jumpstarter_driver_tftp/client.py | 65 +--------- .../jumpstarter_driver_tftp/driver.py | 110 +---------------- .../jumpstarter_driver_tftp/driver_test.py | 113 ++++-------------- .../jumpstarter-driver-tftp/pyproject.toml | 1 + uv.lock | 2 + 5 files changed, 34 insertions(+), 257 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py index 24081eea0..9cc53e55c 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py @@ -1,16 +1,10 @@ -import hashlib from dataclasses import dataclass -from pathlib import Path -from jumpstarter_driver_opendal.adapter import OpendalAdapter -from opendal import Operator - -from . import CHUNK_SIZE -from jumpstarter.client import DriverClient +from jumpstarter_driver_composite.client import CompositeClient @dataclass(kw_only=True) -class TftpServerClient(DriverClient): +class TftpServerClient(CompositeClient): """ Client interface for TFTP Server driver @@ -38,54 +32,6 @@ def stop(self): """ self.call("stop") - def list_files(self) -> list[str]: - """ - List files in the TFTP server root directory - - Returns: - list[str]: A list of filenames present in the TFTP server's root directory - """ - return self.call("list_files") - - def put_file(self, operator: Operator, path: str): - filename = Path(path).name - client_checksum = self._compute_checksum(operator, path) - - if self.call("check_file_checksum", filename, client_checksum): - self.logger.info(f"Skipping upload of identical file: {filename}") - return filename - - with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle: - return self.call("put_file", filename, handle, client_checksum) - - def put_local_file(self, filepath: str): - absolute = Path(filepath).resolve() - filename = absolute.name - - operator = Operator("fs", root="/") - client_checksum = self._compute_checksum(operator, str(absolute)) - - if self.call("check_file_checksum", filename, client_checksum): - self.logger.info(f"Skipping upload of identical file: {filename}") - return filename - - self.logger.info(f"checksum: {client_checksum}") - with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle: - return self.call("put_file", filename, handle, client_checksum) - - def delete_file(self, filename: str): - """ - Delete a file from the TFTP server - - Args: - filename (str): Name of the file to delete - - Raises: - FileNotFound: If the specified file doesn't exist - TftpError: If deletion fails for other reasons - """ - return self.call("delete_file", filename) - def get_host(self) -> str: """ Get the host address the TFTP server is listening on @@ -103,10 +49,3 @@ def get_port(self) -> int: int: The port number (default is 69) """ return self.call("get_port") - - def _compute_checksum(self, operator: Operator, path: str) -> str: - hasher = hashlib.sha256() - with operator.open(path, "rb") as f: - while chunk := f.read(CHUNK_SIZE): - hasher.update(chunk) - return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index d90ea6773..489c9fe42 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -1,17 +1,14 @@ import asyncio -import hashlib import os import socket import threading from dataclasses import dataclass, field -from pathlib import Path from typing import Optional -from anyio.streams.file import FileWriteStream +from jumpstarter_driver_opendal.driver import Opendal from jumpstarter_driver_tftp.server import TftpServer -from . import CHUNK_SIZE from jumpstarter.driver import Driver, export @@ -27,12 +24,6 @@ class ServerNotRunning(TftpError): pass -class FileNotFound(TftpError): - """File not found""" - - pass - - @dataclass(kw_only=True) class Tftp(Driver): """TFTP Server driver for Jumpstarter @@ -59,6 +50,9 @@ def __post_init__(self): super().__post_init__() os.makedirs(self.root_dir, exist_ok=True) + + self.children["storage"] = Opendal(scheme="fs", kwargs={"root": self.root_dir}) + if self.host == "": self.host = self.get_default_ip() @@ -156,95 +150,6 @@ def stop(self): self.logger.info("TFTP server stopped successfully") self.server_thread = None - @export - def list_files(self) -> list[str]: - """List all files available in the TFTP server root directory. - - Returns: - list[str]: A list of filenames present in the root directory - """ - return os.listdir(self.root_dir) - - @export - async def put_file(self, filename: str, src_stream, client_checksum: str): - """Upload a file to the TFTP server. - - Args: - filename (str): Name of the file to create - src_stream: Source stream to read the file data from - client_checksum (str): SHA256 checksum of the file for verification - - Returns: - str: The filename that was uploaded - - Raises: - TftpError: If the file upload fails or path validation fails - """ - file_path = os.path.join(self.root_dir, filename) - - try: - if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()): - raise TftpError("Invalid target path") - - async with await FileWriteStream.from_path(file_path) as dst: - async with self.resource(src_stream) as src: - async for chunk in src: - await dst.send(chunk) - - return filename - except Exception as e: - raise TftpError(f"Failed to upload file: {str(e)}") from e - - @export - def delete_file(self, filename: str): - """Delete a file from the TFTP server. - - Args: - filename (str): Name of the file to delete - - Returns: - str: The filename that was deleted - - Raises: - FileNotFound: If the specified file does not exist - TftpError: If the deletion operation fails - """ - file_path = os.path.join(self.root_dir, filename) - - if not os.path.exists(file_path): - raise FileNotFound(f"File {filename} not found") - - try: - os.remove(file_path) - return filename - except Exception as e: - raise TftpError(f"Failed to delete {filename}") from e - - @export - def check_file_checksum(self, filename: str, client_checksum: str) -> bool: - """Check if a file matches the expected checksum. - - Args: - filename (str): Name of the file to check - client_checksum (str): Expected SHA256 checksum - - Returns: - bool: True if the file exists and matches the checksum, False otherwise - """ - file_path = os.path.join(self.root_dir, filename) - self.logger.debug(f"checking checksum for file: {filename}") - self.logger.debug(f"file path: {file_path}") - - if not os.path.exists(file_path): - self.logger.debug(f"File {filename} does not exist") - return False - - current_checksum = self._compute_checksum(file_path) - self.logger.debug(f"Computed checksum: {current_checksum}") - self.logger.debug(f"Client checksum: {client_checksum}") - - return current_checksum == client_checksum - @export def get_host(self) -> str: """Get the host address the server is bound to. @@ -267,10 +172,3 @@ def close(self): if self.server_thread is not None: self.stop() super().close() - - def _compute_checksum(self, path: str) -> str: - hasher = hashlib.sha256() - with open(path, "rb") as f: - while chunk := f.read(CHUNK_SIZE): - hasher.update(chunk) - return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index f74e8ec07..e29b648bb 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -1,114 +1,51 @@ -import hashlib -import os -import tempfile -from pathlib import Path -from uuid import uuid4 - -import anyio import pytest -from anyio import create_memory_object_stream -from jumpstarter_driver_tftp.driver import ( - FileNotFound, - Tftp, -) +from jumpstarter_driver_tftp.driver import Tftp -from jumpstarter.common.resources import ClientStreamResource +from jumpstarter.common.utils import serve @pytest.fixture -def temp_dir(): - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir +def anyio_backend(): + return "asyncio" @pytest.fixture -def server(temp_dir): - server = Tftp(root_dir=temp_dir, host="127.0.0.1") - yield server - server.close() +def tftp(tmp_path): + with serve(Tftp(root_dir=str(tmp_path), host="127.0.0.1")) as client: + try: + yield client + finally: + client.close() @pytest.mark.anyio -async def test_tftp_file_operations(server): +async def test_tftp_file_operations(tftp, tmp_path): + local_filename = "test_src.txt" filename = "test.txt" test_data = b"Hello" - client_checksum = hashlib.sha256(test_data).hexdigest() - - send_stream, receive_stream = create_memory_object_stream(max_buffer_size=10) - - resource_uuid = uuid4() - server.resources[resource_uuid] = receive_stream - - resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") - async def send_data(): - await send_stream.send(test_data) - await send_stream.aclose() + (tmp_path / local_filename).write_bytes(test_data) - async with anyio.create_task_group() as tg: - tg.start_soon(send_data) - await server.put_file(filename, resource_handle, client_checksum) + file = tftp.storage.open(filename, "wb") + file.write(str(tmp_path / local_filename)) + file.close() - files = server.list_files() + files = list(tftp.storage.list("/")) assert filename in files - file_path = Path(server.root_dir) / filename - assert file_path.read_bytes() == test_data + tftp.storage.delete(filename) + assert filename not in list(tftp.storage.list("/")) - server.delete_file(filename) - assert filename not in server.list_files() - with pytest.raises(FileNotFound): - server.delete_file("nonexistent.txt") - - -def test_tftp_host_config(temp_dir): +def test_tftp_host_config(tmp_path): custom_host = "192.168.1.1" - server = Tftp(root_dir=temp_dir, host=custom_host) + server = Tftp(root_dir=str(tmp_path), host=custom_host) assert server.get_host() == custom_host -def test_tftp_root_directory_creation(temp_dir): - new_dir = os.path.join(temp_dir, "new_tftp_root") - server = Tftp(root_dir=new_dir) - assert os.path.exists(new_dir) +def test_tftp_root_directory_creation(tmp_path): + new_dir = tmp_path / "new_tftp_root" + server = Tftp(root_dir=str(new_dir)) + assert new_dir.exists() server.close() - - -@pytest.mark.anyio -async def test_tftp_detect_corrupted_file(server): - filename = "corrupted.txt" - original_data = b"Original Data" - client_checksum = hashlib.sha256(original_data).hexdigest() - - await _upload_file(server, filename, original_data) - - assert server.check_file_checksum(filename, client_checksum) - - file_path = Path(server.root_dir, filename) - file_path.write_bytes(b"corrupted Data") - - assert not server.check_file_checksum(filename, client_checksum) - - -@pytest.fixture -def anyio_backend(): - return "asyncio" - - -async def _upload_file(server, filename: str, data: bytes) -> str: - send_stream, receive_stream = create_memory_object_stream() - resource_uuid = uuid4() - server.resources[resource_uuid] = receive_stream - resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") - - async def send_data(): - await send_stream.send(data) - await send_stream.aclose() - - async with anyio.create_task_group() as tg: - tg.start_soon(send_data) - await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest()) - - return hashlib.sha256(data).hexdigest() diff --git a/packages/jumpstarter-driver-tftp/pyproject.toml b/packages/jumpstarter-driver-tftp/pyproject.toml index 50dfe17e4..3858d2da1 100644 --- a/packages/jumpstarter-driver-tftp/pyproject.toml +++ b/packages/jumpstarter-driver-tftp/pyproject.toml @@ -10,6 +10,7 @@ requires-python = ">=3.12" dependencies = [ "anyio>=4.6.2.post1", "jumpstarter", + "jumpstarter-driver-composite", "jumpstarter-driver-opendal", "aiofiles>=24.1.0" ] diff --git a/uv.lock b/uv.lock index dbe8d7d72..6bebfc121 100644 --- a/uv.lock +++ b/uv.lock @@ -1466,6 +1466,7 @@ dependencies = [ { name = "aiofiles" }, { name = "anyio" }, { name = "jumpstarter" }, + { name = "jumpstarter-driver-composite" }, { name = "jumpstarter-driver-opendal" }, ] @@ -1483,6 +1484,7 @@ requires-dist = [ { name = "aiofiles", specifier = ">=24.1.0" }, { name = "anyio", specifier = ">=4.6.2.post1" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-composite", editable = "packages/jumpstarter-driver-composite" }, { name = "jumpstarter-driver-opendal", editable = "packages/jumpstarter-driver-opendal" }, ] From 406397933b9762bf808bc81dc53c5ea0c05cb3c5 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 12:54:06 -0500 Subject: [PATCH 2/7] Use opendal internally in TftpServer --- .../jumpstarter_driver_tftp/driver.py | 2 +- .../jumpstarter_driver_tftp/server.py | 36 +++++++++---------- .../jumpstarter_driver_tftp/server_test.py | 3 +- .../jumpstarter-driver-tftp/pyproject.toml | 3 +- uv.lock | 11 ------ 5 files changed, 21 insertions(+), 34 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 489c9fe42..d2dac0406 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -73,7 +73,7 @@ def client(cls) -> str: def _start_server(self): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir) + self.server = TftpServer(host=self.host, port=self.port, operator=self.operator) try: self._loop_ready.set() self._loop.run_until_complete(self._run_server()) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index 1374df08e..d8cf3bb33 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -1,11 +1,10 @@ import asyncio import logging -import os import pathlib from enum import IntEnum from typing import Optional, Set, Tuple -import aiofiles +from opendal import Operator class Opcode(IntEnum): @@ -34,11 +33,11 @@ class TftpServer: """ def __init__( - self, host: str, port: int, root_dir: str, block_size: int = 512, timeout: float = 5.0, retries: int = 3 + self, host: str, port: int, operator: Operator, block_size: int = 512, timeout: float = 5.0, retries: int = 3 ): self.host = host self.port = port - self.root_dir = pathlib.Path(os.path.abspath(root_dir)) + self.operator = operator self.block_size = block_size self.timeout = timeout self.retries = retries @@ -219,21 +218,20 @@ def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: return False return True - def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[pathlib.Path]: - requested_path = self.server.root_dir / filename - resolved_path = requested_path.resolve() - - if not resolved_path.is_file(): - self.logger.error(f"File not found: {resolved_path}") + def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[str]: + try: + stat = self.server.operator.stat(filename) + except FileNotFoundError: + self.logger.error(f"File not found: {filename}") self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") return None - if not is_subpath(resolved_path, self.server.root_dir): - self.logger.error(f"Access violation: {resolved_path} is outside root directory") - self._send_error(addr, TftpErrorCode.ACCESS_VIOLATION, "Access denied") + if not stat.mode.is_file: + self.logger.error(f"File not found: {filename}") + self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") return None - return resolved_path + return filename def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int: if requested_blksize is None: @@ -289,7 +287,7 @@ def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: return negotiated, blksize, timeout async def _start_transfer( - self, filepath: pathlib.Path, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict + self, filepath: str, addr: Tuple[str, int], blksize: int, timeout: float, negotiated_options: dict ): transfer = TftpReadTransfer( server=self.server, @@ -354,7 +352,7 @@ class TftpReadTransfer(TftpTransfer): def __init__( self, server: TftpServer, - filepath: pathlib.Path, + filepath: str, client_addr: Tuple[str, int], block_size: int, timeout: float, @@ -377,7 +375,7 @@ def __init__( self.current_packet: Optional[bytes] = None async def start(self): - self.logger.info(f"Starting read transfer of '{self.filepath.name}' to {self.client_addr}") + self.logger.info(f"Starting read transfer of '{self.filepath}' to {self.client_addr}") if not await self._initialize_transfer(): return @@ -415,13 +413,13 @@ async def _initialize_transfer(self) -> bool: return True async def _perform_transfer(self): - async with aiofiles.open(self.filepath, "rb") as f: + async with await self.server.operator.to_async_operator().open(self.filepath, "rb") as f: while True: if self.server.shutdown_event.is_set(): self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") break - data = await f.read(self.block_size) + data = await f.read(size=self.block_size) if not await self._handle_data_block(data): break diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 679e55cc2..6ff95daac 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from opendal import Operator from jumpstarter_driver_tftp.server import Opcode, TftpServer @@ -13,7 +14,7 @@ async def tftp_server(): test_file_path = Path(temp_dir) / "test.txt" test_file_path.write_text("Hello, TFTP!") - server = TftpServer(host="127.0.0.1", port=0, root_dir=temp_dir) + server = TftpServer(host="127.0.0.1", port=0, operator=Operator("fs", root=str(temp_dir))) server_task = asyncio.create_task(server.start()) for _ in range(10): diff --git a/packages/jumpstarter-driver-tftp/pyproject.toml b/packages/jumpstarter-driver-tftp/pyproject.toml index 3858d2da1..64a6a42dc 100644 --- a/packages/jumpstarter-driver-tftp/pyproject.toml +++ b/packages/jumpstarter-driver-tftp/pyproject.toml @@ -11,8 +11,7 @@ dependencies = [ "anyio>=4.6.2.post1", "jumpstarter", "jumpstarter-driver-composite", - "jumpstarter-driver-opendal", - "aiofiles>=24.1.0" + "jumpstarter-driver-opendal" ] [dependency-groups] diff --git a/uv.lock b/uv.lock index 6bebfc121..62742f552 100644 --- a/uv.lock +++ b/uv.lock @@ -55,15 +55,6 @@ docs = [ { name = "sphinxcontrib-mermaid", specifier = ">=0.9.2" }, ] -[[package]] -name = "aiofiles" -version = "24.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0b/03/a88171e277e8caa88a4c77808c20ebb04ba74cc4681bf1e9416c862de237/aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c", size = 30247 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/45/30bb92d442636f570cb5651bc661f52b610e2eec3f891a5dc3a4c3667db0/aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5", size = 15896 }, -] - [[package]] name = "aiohappyeyeballs" version = "2.4.6" @@ -1463,7 +1454,6 @@ name = "jumpstarter-driver-tftp" version = "0.1.0" source = { editable = "packages/jumpstarter-driver-tftp" } dependencies = [ - { name = "aiofiles" }, { name = "anyio" }, { name = "jumpstarter" }, { name = "jumpstarter-driver-composite" }, @@ -1481,7 +1471,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aiofiles", specifier = ">=24.1.0" }, { name = "anyio", specifier = ">=4.6.2.post1" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, { name = "jumpstarter-driver-composite", editable = "packages/jumpstarter-driver-composite" }, From 282c63f9e1274ecd2e013f8072d51bc51187d81b Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 13:06:02 -0500 Subject: [PATCH 3/7] Add hash method to opendal operator --- .../jumpstarter_driver_opendal/client.py | 8 ++++++++ .../jumpstarter_driver_opendal/driver.py | 20 ++++++++++++++++++- .../jumpstarter_driver_opendal/driver_test.py | 6 ++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index 86a2170fc..f66e79217 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -3,6 +3,7 @@ from collections.abc import Generator from dataclasses import dataclass from pathlib import Path +from typing import Literal from uuid import UUID import asyncclick as click @@ -127,6 +128,13 @@ def stat(self, /, path: str) -> Metadata: """ return self.call("stat", path) + @validate_call(validate_return=True) + def hash(self, /, path: str, algo: Literal["md5", "sha256"] = "sha256") -> str: + """ + Get current path's hash + """ + return self.call("hash", path, algo) + @validate_call(validate_return=True) def copy(self, /, source: str, target: str): """ diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py index c547d7ce4..185f7ec5f 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py @@ -1,8 +1,9 @@ +import hashlib from abc import ABCMeta, abstractmethod from collections.abc import AsyncGenerator from dataclasses import dataclass, field from tempfile import NamedTemporaryFile, _TemporaryFileWrapper -from typing import Any +from typing import Any, Literal from uuid import UUID, uuid4 from anyio.streams.file import FileReadStream, FileWriteStream @@ -98,6 +99,23 @@ async def file_writable(self, /, fd: UUID) -> bool: async def stat(self, /, path: str) -> Metadata: return Metadata.model_validate(await self._operator.stat(path), from_attributes=True) + @export + @validate_call(validate_return=True) + async def hash(self, /, path: str, algo: Literal["md5", "sha256"] = "sha256") -> str: + match algo: + case "md5": + m = hashlib.md5() + case "sha256": + m = hashlib.sha256() + async with await self._operator.open(path, "rb") as f: + while True: + data = await f.read(size=65536) + if len(data) == 0: + break + m.update(data) + + return m.hexdigest() + @export @validate_call(validate_return=True) async def copy(self, /, source: str, target: str): diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py index ae735312c..bda001b49 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py @@ -48,6 +48,12 @@ def test_drivers_opendal(tmp_path): assert test_file.tell() == 0 assert test_file.seek(2) == 2 + assert client.hash("test_dir/test_file", "md5") == "5d41402abc4b2a76b9719d911017c592" + assert ( + client.hash("test_dir/test_file", "sha256") + == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + ) + test_file.read(str(tmp_path / "dst")) assert (tmp_path / "dst").read_text() == "llo" From f55fffcc35f2765b837674a119ac4531dfc4d86c Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 13:13:53 -0500 Subject: [PATCH 4/7] Make HashAlgo a shared type --- .../jumpstarter_driver_opendal/client.py | 5 ++--- .../jumpstarter_driver_opendal/common.py | 1 + .../jumpstarter_driver_opendal/driver.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index f66e79217..97a2a141f 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -3,7 +3,6 @@ from collections.abc import Generator from dataclasses import dataclass from pathlib import Path -from typing import Literal from uuid import UUID import asyncclick as click @@ -11,7 +10,7 @@ from pydantic import ConfigDict, validate_call from .adapter import OpendalAdapter -from .common import Capability, Metadata, Mode, PresignedRequest +from .common import Capability, HashAlgo, Metadata, Mode, PresignedRequest from jumpstarter.client import DriverClient @@ -129,7 +128,7 @@ def stat(self, /, path: str) -> Metadata: return self.call("stat", path) @validate_call(validate_return=True) - def hash(self, /, path: str, algo: Literal["md5", "sha256"] = "sha256") -> str: + def hash(self, /, path: str, algo: HashAlgo = "sha256") -> str: """ Get current path's hash """ diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py index 3a7c26014..e9e028f14 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, model_validator Mode = Literal["rb", "wb"] +HashAlgo = Literal["md5", "sha256"] class EntryMode(BaseModel): diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py index 185f7ec5f..19e667572 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass, field from tempfile import NamedTemporaryFile, _TemporaryFileWrapper -from typing import Any, Literal +from typing import Any from uuid import UUID, uuid4 from anyio.streams.file import FileReadStream, FileWriteStream @@ -11,7 +11,7 @@ from pydantic import validate_call from .adapter import AsyncFileStream -from .common import Capability, Metadata, Mode, PresignedRequest +from .common import Capability, HashAlgo, Metadata, Mode, PresignedRequest from jumpstarter.driver import Driver, export @@ -101,7 +101,7 @@ async def stat(self, /, path: str) -> Metadata: @export @validate_call(validate_return=True) - async def hash(self, /, path: str, algo: Literal["md5", "sha256"] = "sha256") -> str: + async def hash(self, /, path: str, algo: HashAlgo = "sha256") -> str: match algo: case "md5": m = hashlib.md5() From f4d5e356927467f44d621d03e90824086f04b359 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 13:25:17 -0500 Subject: [PATCH 5/7] Fixup tftp doctest --- docs/source/api-reference/drivers/tftp.md | 35 ++++--------------- .../jumpstarter_driver_tftp/driver.py | 2 +- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/docs/source/api-reference/drivers/tftp.md b/docs/source/api-reference/drivers/tftp.md index 68cc8c261..858023824 100644 --- a/docs/source/api-reference/drivers/tftp.md +++ b/docs/source/api-reference/drivers/tftp.md @@ -41,10 +41,6 @@ export: .. autoclass:: jumpstarter_driver_tftp.driver.ServerNotRunning :members: :show-inheritance: - -.. autoclass:: jumpstarter_driver_tftp.driver.FileNotFound - :members: - :show-inheritance: ``` ## Examples @@ -53,6 +49,7 @@ export: >>> import tempfile >>> import os >>> from jumpstarter_driver_tftp.driver import Tftp +>>> from jumpstarter.common.utils import serve >>> with tempfile.TemporaryDirectory() as tmp_dir: ... # Create a test file ... test_file = os.path.join(tmp_dir, "test.txt") @@ -60,30 +57,12 @@ export: ... _ = f.write("hello") ... ... # Start TFTP server -... tftp = Tftp(root_dir=tmp_dir, host="127.0.0.1", port=6969) -... tftp.start() +... with serve(Tftp(root_dir=tmp_dir, host="127.0.0.1", port=6969)) as tftp: +... tftp.start() ... -... # List files -... files = tftp.list_files() -... assert "test.txt" in files +... # List files +... files = list(tftp.storage.list("/")) +... assert "test.txt" in files ... -... tftp.stop() -``` - -```{testsetup} * -import tempfile -import os -from jumpstarter_driver_tftp.driver import Tftp -from jumpstarter.common.utils import serve - -# Create a persistent temp dir that won't be removed by the example -TEST_DIR = tempfile.mkdtemp(prefix='tftp-test-') -instance = serve(Tftp(root_dir=TEST_DIR, host="127.0.0.1")) -client = instance.__enter__() -``` - -```{testcleanup} * -instance.__exit__(None, None, None) -import shutil -shutil.rmtree(TEST_DIR, ignore_errors=True) +... tftp.stop() ``` diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index d2dac0406..68269df70 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -73,7 +73,7 @@ def client(cls) -> str: def _start_server(self): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self.server = TftpServer(host=self.host, port=self.port, operator=self.operator) + self.server = TftpServer(host=self.host, port=self.port, operator=self.children["storage"]._operator) try: self._loop_ready.set() self._loop.run_until_complete(self._run_server()) From 0b96ab80d76c5dd2fc628fcb2de23cfc1e7b222e Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 13:31:33 -0500 Subject: [PATCH 6/7] Ignore warning about shadowed field --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7d8672a02..5793a70f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,3 +71,6 @@ skip_empty = true [tool.pytest.ini_options] addopts = "--capture=no --doctest-modules --cov --cov-report=html --cov-report=xml" +filterwarnings = [ + 'ignore:Field name "copy" in "Capability" shadows an attribute in parent "BaseModel":UserWarning', +] From 2526e390d42132ae8cf53184ac94ad0993833f33 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 26 Feb 2025 13:51:39 -0500 Subject: [PATCH 7/7] Make opendal client methods accept os.PathLike --- .../jumpstarter_driver_opendal/client.py | 34 +++++++++---------- .../jumpstarter_driver_opendal/common.py | 3 +- .../jumpstarter_driver_opendal/driver_test.py | 4 +-- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index 97a2a141f..11e3902f4 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -10,7 +10,7 @@ from pydantic import ConfigDict, validate_call from .adapter import OpendalAdapter -from .common import Capability, HashAlgo, Metadata, Mode, PresignedRequest +from .common import Capability, HashAlgo, Metadata, Mode, PathBuf, PresignedRequest from jumpstarter.client import DriverClient @@ -30,7 +30,7 @@ def __read(self, handle): return self.client.call("file_read", self.fd, handle) @validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True)) - def write(self, path: str, operator: Operator | None = None): + def write(self, path: PathBuf, operator: Operator | None = None): """ Write into remote file with content from local file """ @@ -41,7 +41,7 @@ def write(self, path: str, operator: Operator | None = None): return self.__write(handle) @validate_call(validate_return=True, config=ConfigDict(arbitrary_types_allowed=True)) - def read(self, path: str, operator: Operator | None = None): + def read(self, path: PathBuf, operator: Operator | None = None): """ Read content from remote file into local file """ @@ -114,49 +114,49 @@ def writable(self) -> bool: class OpendalClient(DriverClient): @validate_call - def open(self, /, path: str, mode: Mode) -> OpendalFile: + def open(self, /, path: PathBuf, mode: Mode) -> OpendalFile: """ Open a file-like reader for the given path """ return OpendalFile(client=self, fd=self.call("open", path, mode)) @validate_call(validate_return=True) - def stat(self, /, path: str) -> Metadata: + def stat(self, /, path: PathBuf) -> Metadata: """ Get current path's metadata """ return self.call("stat", path) @validate_call(validate_return=True) - def hash(self, /, path: str, algo: HashAlgo = "sha256") -> str: + def hash(self, /, path: PathBuf, algo: HashAlgo = "sha256") -> str: """ Get current path's hash """ return self.call("hash", path, algo) @validate_call(validate_return=True) - def copy(self, /, source: str, target: str): + def copy(self, /, source: PathBuf, target: PathBuf): """ Copy source to target """ self.call("copy", source, target) @validate_call(validate_return=True) - def rename(self, /, source: str, target: str): + def rename(self, /, source: PathBuf, target: PathBuf): """ Rename source to target """ self.call("rename", source, target) @validate_call(validate_return=True) - def remove_all(self, /, path: str): + def remove_all(self, /, path: PathBuf): """ Remove all file under path """ self.call("remove_all", path) @validate_call(validate_return=True) - def create_dir(self, /, path: str): + def create_dir(self, /, path: PathBuf): """ Create a dir at given path @@ -168,7 +168,7 @@ def create_dir(self, /, path: str): self.call("create_dir", path) @validate_call(validate_return=True) - def delete(self, /, path: str): + def delete(self, /, path: PathBuf): """ Delete given path @@ -177,42 +177,42 @@ def delete(self, /, path: str): self.call("delete", path) @validate_call(validate_return=True) - def exists(self, /, path: str) -> bool: + def exists(self, /, path: PathBuf) -> bool: """ Check if given path exists """ return self.call("exists", path) @validate_call - def list(self, /, path: str) -> Generator[str, None, None]: + def list(self, /, path: PathBuf) -> Generator[str, None, None]: """ List files and directories under given path """ yield from self.streamingcall("list", path) @validate_call - def scan(self, /, path: str) -> Generator[str, None, None]: + def scan(self, /, path: PathBuf) -> Generator[str, None, None]: """ List files and directories under given path recursively """ yield from self.streamingcall("scan", path) @validate_call(validate_return=True) - def presign_stat(self, /, path: str, expire_second: int) -> PresignedRequest: + def presign_stat(self, /, path: PathBuf, expire_second: int) -> PresignedRequest: """ Presign an operation for stat (HEAD) which expires after expire_second seconds """ return self.call("presign_stat", path, expire_second) @validate_call(validate_return=True) - def presign_read(self, /, path: str, expire_second: int) -> PresignedRequest: + def presign_read(self, /, path: PathBuf, expire_second: int) -> PresignedRequest: """ Presign an operation for read (GET) which expires after expire_second seconds """ return self.call("presign_read", path, expire_second) @validate_call(validate_return=True) - def presign_write(self, /, path: str, expire_second: int) -> PresignedRequest: + def presign_write(self, /, path: PathBuf, expire_second: int) -> PresignedRequest: """ Presign an operation for write (PUT) which expires after expire_second seconds """ diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py index e9e028f14..d5f68af71 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/common.py @@ -1,5 +1,5 @@ # Reference: https://github.com/apache/opendal/blob/main/bindings/python/python/opendal/__init__.pyi - +from os import PathLike from typing import Any, Literal, Optional import opendal @@ -7,6 +7,7 @@ Mode = Literal["rb", "wb"] HashAlgo = Literal["md5", "sha256"] +PathBuf = str | PathLike class EntryMode(BaseModel): diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py index bda001b49..ea8d432c9 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver_test.py @@ -34,7 +34,7 @@ def test_drivers_opendal(tmp_path): assert test_file.writable() (tmp_path / "src").write_text("hello") - test_file.write(str(tmp_path / "src")) + test_file.write(tmp_path / "src") test_file.close() assert test_file.closed @@ -54,7 +54,7 @@ def test_drivers_opendal(tmp_path): == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" ) - test_file.read(str(tmp_path / "dst")) + test_file.read(tmp_path / "dst") assert (tmp_path / "dst").read_text() == "llo" assert client.stat("dst").content_length == 3