@dataclass(kw_only=True)
class HttpServerClient(FileServerClient):
    """
    HTTP flavour of the shared file-server client.

    Only ``get_url`` is defined here; every other operation is inherited
    from ``FileServerClient``.
    """

    def get_url(self) -> str:
        """Ask the driver for the HTTP server's base URL and return it."""
        base_url = self.call("get_url")
        return base_url
""" try: file_path = os.path.join(self.root_dir, filename) @@ -79,12 +76,26 @@ async def put_file(self, filename: str, src_stream) -> str: if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()): raise HttpServerError("Invalid target path") + self.logger.info(f"Starting file upload: {filename}") + if client_checksum: + self.logger.info(f"Expected checksum from client: {client_checksum}") + async with await FileWriteStream.from_path(file_path) as dst: async with self.resource(src_stream) as src: async for chunk in src: await dst.send(chunk) - self.logger.info(f"File '{filename}' written to '{file_path}'") + actual_checksum = self._compute_checksum(file_path) + self.logger.info(f"Server computed checksum: {actual_checksum}") + + if client_checksum is not None: + if actual_checksum != client_checksum: + self.logger.warning(f"Checksum mismatch for {filename}") + self.logger.warning(f"Expected: {client_checksum}") + self.logger.warning(f"Actual: {actual_checksum}") + else: + self.logger.info("Checksum verification successful") + return f"{self.get_url()}/{filename}" except Exception as e: @@ -222,6 +233,36 @@ def get_port(self) -> int: """ return self.port + @export + def check_file_checksum(self, filename: str, client_checksum: str) -> bool: + """Check if a file matches the expected checksum. 
@export
def check_file_checksum(self, filename: str, client_checksum: str) -> bool:
    """Check whether a served file matches the expected SHA256 checksum.

    The lookup is confined to ``root_dir`` with the same resolve +
    ``is_relative_to`` containment test that ``put_file`` applies, so a
    crafted ``filename`` (``../..`` segments or an absolute path) cannot be
    used to probe files outside the served tree.

    Args:
        filename (str): Name of the file to check, relative to ``root_dir``
        client_checksum (str): Expected SHA256 checksum (hex digest)

    Returns:
        bool: True if the file exists inside ``root_dir`` and matches the
        checksum, False otherwise
    """
    root = Path(self.root_dir).resolve()
    target = Path(os.path.join(self.root_dir, filename)).resolve()

    if not target.is_relative_to(root):
        self.logger.warning(f"Refusing checksum check outside root directory: {filename}")
        return False

    # is_file() rather than exists(): a directory would make open() raise
    # instead of yielding the documented False.
    if not target.is_file():
        self.logger.debug(f"File {filename} does not exist")
        return False

    current_checksum = self._compute_checksum(str(target))
    self.logger.debug(f"Computed checksum: {current_checksum}")
    self.logger.debug(f"Client checksum: {client_checksum}")

    return current_checksum == client_checksum


def _compute_checksum(self, path: str) -> str:
    """Return the SHA256 hex digest of the file at ``path``.

    The file is read in ``CHUNK_SIZE`` pieces so large images are never
    loaded into memory in one go.
    """
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(CHUNK_SIZE):
            hasher.update(chunk)
    return hasher.hexdigest()
class FileServerClient(DriverClient):
    """Base client for file server implementations (HTTP, TFTP, etc).

    Wraps the driver calls shared by the file-serving drivers and adds
    client-side SHA256 helpers so an upload can be skipped when the server
    already holds an identical file.
    """

    def start(self):
        """Start the file server"""
        self.call("start")

    def stop(self):
        """Stop the file server"""
        self.call("stop")

    def list_files(self) -> list[str]:
        """List files in the server root directory"""
        return self.call("list_files")

    @staticmethod
    def _sha256_hexdigest(stream) -> str:
        """Hash a readable binary stream in CHUNK_SIZE pieces and return the hex digest."""
        hasher = hashlib.sha256()
        while chunk := stream.read(CHUNK_SIZE):
            hasher.update(chunk)
        return hasher.hexdigest()

    def compute_checksum(self, filepath: str | Path) -> str:
        """
        Compute SHA256 checksum of a local file

        Args:
            filepath: Path to the file to checksum

        Returns:
            str: Hex digest of SHA256 hash
        """
        with open(filepath, "rb") as f:
            return self._sha256_hexdigest(f)

    def compute_opendal_checksum(self, operator: Operator, path: str) -> str:
        """
        Compute SHA256 checksum of a file from an OpenDAL operator

        Args:
            operator: OpenDAL operator to read from
            path: Path within the operator's storage

        Returns:
            str: Hex digest of SHA256 hash
        """
        with operator.open(path, "rb") as f:
            return self._sha256_hexdigest(f)

    def check_file_checksum(self, filename: str, expected_checksum: str) -> bool:
        """
        Check if a server-side file matches an expected checksum

        Args:
            filename: Name of file to check
            expected_checksum: Expected SHA256 checksum

        Returns:
            bool: True if checksums match, False otherwise
        """
        return self.call("check_file_checksum", filename, expected_checksum)

    def put_file(self, filename: str, src_stream, checksum: str | None = None):
        """
        Upload a file to the server

        When a checksum is given it is forwarded to the driver; if that call
        raises ``TypeError`` or ``ValueError`` the upload is retried without
        the checksum.

        Args:
            filename: Name to save the file as
            src_stream: Source stream to read data from
            checksum: Optional SHA256 checksum for verification

        Returns:
            Whatever the driver's ``put_file`` call returns.
        """
        if checksum is not None:
            try:
                return self.call("put_file", filename, src_stream, checksum)
            except (TypeError, ValueError):
                # NOTE(review): the retry below reuses src_stream; this assumes
                # a rejected call did not consume it - confirm against the driver.
                self.logger.debug("Server does not support checksum verification, falling back to basic upload")

        return self.call("put_file", filename, src_stream)

    def put_file_from_source(self, source: str | Path, checksum: str | None = None):
        """
        Upload a file from either a local path or URL to the server.

        For local files the SHA256 is always computed; it is used as the
        checksum when none is supplied, otherwise it is only compared with the
        supplied value and a mismatch is logged. If the server already holds a
        file with the same name and checksum, the upload is skipped.

        Args:
            source (str | Path): Local file path or http(s) URL to upload
            checksum (str, optional): SHA256 checksum of the file. If provided,
                will be used for verification

        Returns:
            The bare file name when the upload is skipped, otherwise the value
            returned by ``put_file``.

        Raises:
            ValueError: If ``source`` is a URL whose path has no final file
                name component (for example, it ends in ``/``).
        """
        # Accept Path objects as well; the URL test below needs a str.
        source = str(source)
        self.logger.info(f"Starting upload from source: {source}")

        if source.startswith(("http://", "https://")):
            parsed_url = urlparse(source)
            operator = Operator(
                "http",
                root="/",
                endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}",
            )
            filename = parsed_url.path.split("/")[-1]
            path = parsed_url.path.removeprefix("/")

            # Fail here with a clear message instead of sending an empty
            # file name to the server.
            if not filename:
                raise ValueError(f"Cannot derive a file name from URL: {source}")

            if checksum is None:
                self.logger.warning("No checksum provided for remote file - skipping verification")
        else:
            operator = Operator("fs", root="/")
            path = str(Path(source).resolve())
            filename = Path(path).name

            # Hash once; both the "provided" and "not provided" cases need it.
            computed_checksum = self.compute_checksum(path)
            if checksum is None:
                self.logger.info(f"Computed checksum for local file {filename}: {computed_checksum}")
                checksum = computed_checksum
            else:
                self.logger.info(f"Provided checksum: {checksum}")
                self.logger.info(f"Computed checksum: {computed_checksum}")
                if computed_checksum != checksum:
                    self.logger.warning("Checksum mismatch between provided and computed values")

        if checksum and self.check_file_checksum(filename, checksum):
            self.logger.info(f"Skipping upload of identical file: {filename}")
            return filename

        self.logger.info(f"Opening adapter for {filename}")
        with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle:
            self.logger.info(f"Putting file {filename}")
            result = self.put_file(filename, handle, checksum)
            self.logger.info(f"Completed upload of {filename}")
            return result

    def put_local_file(self, filepath: str):
        """
        Upload a file from the local filesystem to the server.

        Thin wrapper over ``put_file_from_source`` for callers that use the
        ``put_local_file`` name. The path is resolved to an absolute path
        first, so it is always treated as a local file and never as a URL.

        Args:
            filepath (str): Path to the local file to upload.
        """
        return self.put_file_from_source(str(Path(filepath).resolve()))

    def delete_file(self, filename: str) -> str:
        """Delete a file from the server"""
        return self.call("delete_file", filename)

    def get_host(self) -> str:
        """Get the host address the server is listening on"""
        return self.call("get_host")

    def get_port(self) -> int:
        """Get the port number the server is listening on"""
        return self.call("get_port")
@dataclass(kw_only=True)
class TftpServerClient(FileServerClient):
    """
    TFTP flavour of the shared file-server client.

    Defines nothing of its own; every operation is inherited from
    ``FileServerClient``.
    """