diff --git a/python/docs/source/reference/package-apis/drivers/index.md b/python/docs/source/reference/package-apis/drivers/index.md index fab9c2f02..2b6694ab2 100644 --- a/python/docs/source/reference/package-apis/drivers/index.md +++ b/python/docs/source/reference/package-apis/drivers/index.md @@ -111,6 +111,7 @@ General-purpose utility drivers: * **[Shell](shell.md)** (`jumpstarter-driver-shell`) - Shell command execution * **[TMT](tmt.md)** (`jumpstarter-driver-tmt`) - TMT (Test Management Tool) wrapper driver * **[SSH](ssh.md)** (`jumpstarter-driver-ssh`) - SSH wrapper driver +* **[SSH Mount](ssh-mount.md)** (`jumpstarter-driver-ssh-mount`) - SSHFS remote filesystem mounting ```{toctree} :hidden: @@ -140,8 +141,9 @@ gpiod.md ridesx.md sdwire.md shell.md -ssh.md snmp.md +ssh.md +ssh-mount.md someip.md tasmota.md tmt.md diff --git a/python/docs/source/reference/package-apis/drivers/ssh-mount.md b/python/docs/source/reference/package-apis/drivers/ssh-mount.md new file mode 120000 index 000000000..17b1fa0bd --- /dev/null +++ b/python/docs/source/reference/package-apis/drivers/ssh-mount.md @@ -0,0 +1 @@ +../../../../../packages/jumpstarter-driver-ssh-mount/README.md \ No newline at end of file diff --git a/python/packages/jumpstarter-driver-ssh-mount/.gitignore b/python/packages/jumpstarter-driver-ssh-mount/.gitignore new file mode 100644 index 000000000..cbc5d672b --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +.coverage +coverage.xml diff --git a/python/packages/jumpstarter-driver-ssh-mount/README.md b/python/packages/jumpstarter-driver-ssh-mount/README.md new file mode 100644 index 000000000..2179cd86e --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/README.md @@ -0,0 +1,119 @@ +# SSHMount Driver + +`jumpstarter-driver-ssh-mount` provides remote filesystem mounting via sshfs. It allows you to mount remote directories from a target device to your local machine using SSHFS (SSH Filesystem). + +## Installation + +```shell +pip3 install --extra-index-url https://pkg.jumpstarter.dev/simple/ jumpstarter-driver-ssh-mount +``` + +You also need `sshfs` installed on the client machine: + +- **Fedora/RHEL**: `sudo dnf install fuse-sshfs` +- **Debian/Ubuntu**: `sudo apt-get install sshfs` +- **macOS**: Install macFUSE from https://macfuse.github.io/ and then install + sshfs from source, as Homebrew has removed sshfs support. + +## Configuration + +The SSHMount driver references an existing SSH driver to inherit credentials +(username, identity key) and TCP connectivity. No duplicate configuration is needed. + +Example exporter configuration: + +```yaml +export: + ssh: + type: jumpstarter_driver_ssh.driver.SSHWrapper + config: + default_username: "root" + # ssh_identity_file: "/path/to/ssh/key" + children: + tcp: + type: jumpstarter_driver_network.driver.TcpNetwork + config: + host: "192.168.1.100" + port: 22 + mount: + type: jumpstarter_driver_ssh_mount.driver.SSHMount + children: + ssh: + ref: "ssh" +``` + +## CLI Usage + +Inside a `jmp shell` session: + +```shell +# Mount remote filesystem (spawns a subshell; type 'exit' to unmount) +j mount /local/mountpoint +j mount /local/mountpoint -r /remote/path +j mount /local/mountpoint --direct + +# Mount in foreground mode (blocks until Ctrl+C) +j mount /local/mountpoint --foreground + +# Pass extra sshfs options +j mount /local/mountpoint -o reconnect -o cache=yes + +# Unmount an orphaned mount +j mount --umount /local/mountpoint +j mount --umount /local/mountpoint --lazy +``` + +By default, `j mount` runs sshfs in foreground mode and spawns a subshell +with a modified prompt. The mount stays active while the subshell is running. +When you type `exit` (or press Ctrl+D), sshfs is terminated and all resources +(port forwards, temporary identity files) are cleaned up automatically. + +Use `--foreground` to skip the subshell and block directly on sshfs. Press +Ctrl+C to unmount. + +The `--umount` flag is available as a fallback for mounts that were orphaned +(e.g., if the process was killed without cleanup). + +## Security: `allow_other` mount option + +By default, sshfs is invoked with `-o allow_other`, which permits all local +users to access the mounted filesystem — not just the user who ran `j mount`. +This is convenient for build workflows where tools run under different UIDs, +but it has security implications on multi-user systems: + +- Any local user can read (and potentially write) files on the remote device + through the mountpoint. +- The option requires that `/etc/fuse.conf` contains `user_allow_other`; + otherwise the mount will fail. + +**Automatic fallback:** if `allow_other` is rejected by FUSE (e.g., +`user_allow_other` is not set), the driver automatically retries the mount +without it. In that case only the mounting user can access the filesystem. + +To explicitly disable `allow_other` without relying on the fallback, you can +override the option via `--extra-args`: + +```shell +j mount /mnt/device -o allow_other=0 +``` + +## API Reference + +### SSHMountClient + +- `mount(mountpoint, *, remote_path="/", direct=False, foreground=False, extra_args=None)` - Mount remote filesystem locally via sshfs +- `umount(mountpoint, *, lazy=False)` - Unmount an sshfs filesystem (fallback for orphaned mounts) + +### Required Children + +| Child name | Type | Description | +|-----------|------|-------------| +| `ssh` | `jumpstarter_driver_ssh.driver.SSHWrapper` | SSH driver providing credentials (username, identity key) and TCP connectivity. Must itself have a `tcp` child of type `TcpNetwork`. | + +### CLI + +The driver registers as `mount` in the exporter config. When used in a `jmp shell` session, the CLI is a single command with a `--umount` flag for unmounting. + +Note: `extra_args` values (passed via `-o`) are forwarded directly to sshfs. This +can be used to override defaults such as `StrictHostKeyChecking=no` -- for example, +`-o StrictHostKeyChecking=yes`. diff --git a/python/packages/jumpstarter-driver-ssh-mount/examples/exporter.yaml b/python/packages/jumpstarter-driver-ssh-mount/examples/exporter.yaml new file mode 100644 index 000000000..be0600156 --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/examples/exporter.yaml @@ -0,0 +1,24 @@ +apiVersion: jumpstarter.dev/v1alpha1 +kind: ExporterConfig +metadata: + namespace: default + name: demo +endpoint: grpc.jumpstarter.192.168.0.203.nip.io:8082 +token: "" +export: + ssh: + type: jumpstarter_driver_ssh.driver.SSHWrapper + config: + default_username: "root" + # ssh_identity_file: "/path/to/key" + children: + tcp: + type: jumpstarter_driver_network.driver.TcpNetwork + config: + host: "192.168.1.100" + port: 22 + mount: + type: jumpstarter_driver_ssh_mount.driver.SSHMount + children: + ssh: + ref: "ssh" diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/__init__.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/__init__.py @@ -0,0 +1 @@ + diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py new file mode 100644 index 000000000..1ecd90cfb --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/client.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +import time +from dataclasses import dataclass +from urllib.parse import urlparse + +import click +from jumpstarter_driver_network.adapters import TcpPortforwardAdapter +from jumpstarter_driver_ssh._ssh_utils import cleanup_identity_file, create_temp_identity_file + +from jumpstarter.client import DriverClient +from jumpstarter.client.core import DriverMethodNotImplemented +from jumpstarter.client.decorators import driver_click_command + +# Timeout in seconds for subprocess calls (mount test run, umount) +SUBPROCESS_TIMEOUT = 120 + +# Polling parameters for mount readiness check +MOUNT_POLL_INTERVAL = 0.5 +MOUNT_POLL_TIMEOUT = 10.0 + + +@dataclass(kw_only=True) +class SSHMountClient(DriverClient): + + def cli(self): + @driver_click_command(self) + @click.argument("mountpoint", type=click.Path()) + @click.option("--umount", "-u", is_flag=True, help="Unmount instead of mount") + @click.option("--remote-path", "-r", default="/", help="Remote path to mount (default: /)") + @click.option("--direct", is_flag=True, help="Use direct TCP address") + @click.option("--lazy", "-l", is_flag=True, help="Lazy unmount (detach filesystem now, clean up later)") + @click.option("--foreground", is_flag=True, help="Block on sshfs in foreground without spawning a subshell") + @click.option("--extra-args", "-o", multiple=True, help="Extra arguments to pass to sshfs") + def mount(mountpoint, umount, remote_path, direct, lazy, foreground, extra_args): + """Mount or unmount remote filesystem via sshfs""" + if umount: + self.umount(mountpoint, lazy=lazy) + else: + self.mount( + mountpoint, + remote_path=remote_path, + direct=direct, + foreground=foreground, + extra_args=list(extra_args), + ) + + return mount + + @property + def ssh(self): + return self.children["ssh"] + + @property + def identity(self) -> str | None: + return self.ssh.identity + + @property + def username(self) -> str: + return self.ssh.username + + def mount(self, mountpoint, *, remote_path="/", direct=False, foreground=False, extra_args=None): + """Mount remote filesystem locally via sshfs. + + Runs sshfs in foreground mode (-f) and spawns a subshell so that + the mount stays alive while the user works. When the subshell exits, + sshfs is terminated and all resources are cleaned up automatically. + + Args: + mountpoint: Local directory to mount the remote filesystem on. + remote_path: Remote path to mount (default: /). + direct: If True, connect directly to the host's TCP address. + foreground: If True, block on sshfs without spawning a subshell. + extra_args: Extra arguments to pass to sshfs. + """ + sshfs_path = self._find_executable("sshfs") + if not sshfs_path: + raise click.ClickException( + "sshfs is not installed. Please install it:\n" + " Fedora/RHEL: sudo dnf install fuse-sshfs\n" + " Debian/Ubuntu: sudo apt-get install sshfs\n" + " macOS: Install macFUSE from https://macfuse.github.io/ and then install\n" + " sshfs from source, as Homebrew has removed sshfs support." + ) + + mountpoint = os.path.realpath(mountpoint) + os.makedirs(mountpoint, exist_ok=True) + + if direct: + try: + address = self.ssh.tcp.address() + parsed = urlparse(address) + host = parsed.hostname + port = parsed.port + if not host or not port: + raise ValueError(f"Invalid address format: {address}") + self.logger.debug("Using direct TCP connection for sshfs - host: %s, port: %s", host, port) + self._run_sshfs(host, port, mountpoint, remote_path, extra_args, + foreground=foreground) + except (DriverMethodNotImplemented, ValueError) as e: + self.logger.error( + "Direct address connection failed (%s), falling back to port forwarding", e + ) + self.mount(mountpoint, remote_path=remote_path, direct=False, + foreground=foreground, extra_args=extra_args) + else: + self.logger.debug("Using SSH port forwarding for sshfs connection") + with TcpPortforwardAdapter(client=self.ssh.tcp) as (host, port): + self.logger.debug("SSH port forward established - host: %s, port: %s", host, port) + self._run_sshfs(host, port, mountpoint, remote_path, extra_args, + foreground=foreground) + + def _run_sshfs(self, host, port, mountpoint, remote_path, extra_args, *, foreground): + identity_file = create_temp_identity_file(self.identity, self.logger) + sshfs_proc = None + + try: + sshfs_args = self._build_sshfs_args(host, port, mountpoint, remote_path, identity_file, extra_args) + sshfs_args.append("-f") + + self.logger.debug("Running sshfs command: %s", sshfs_args) + + sshfs_proc = self._start_sshfs_with_fallback(sshfs_args, mountpoint) + + default_username = self.username + user_prefix = f"{default_username}@" if default_username else "" + remote_spec = f"{user_prefix}{host}:{remote_path}" + click.echo(f"Mounted {remote_spec} on {mountpoint}") + + if foreground: + click.echo("Press Ctrl+C to unmount and exit.") + try: + sshfs_proc.wait() + except KeyboardInterrupt: + click.echo("\nUnmounting...") + else: + click.echo("Type 'exit' to unmount and return.") + self._run_subshell(mountpoint, remote_path) + finally: + if sshfs_proc is not None and sshfs_proc.poll() is None: + sshfs_proc.terminate() + try: + sshfs_proc.wait(timeout=10) + except subprocess.TimeoutExpired: + sshfs_proc.kill() + sshfs_proc.wait() + + self._force_umount(mountpoint) + if os.path.ismount(mountpoint): + self.logger.warning("Mountpoint %s may still be mounted after cleanup", mountpoint) + else: + click.echo(f"Unmounted {mountpoint}") + cleanup_identity_file(identity_file, self.logger) + + def _start_sshfs_with_fallback(self, sshfs_args, mountpoint): + """Start sshfs, retrying without allow_other if it fails on that option. + + We do a quick test run (without -f) to check if sshfs can mount + successfully, then start the real foreground process. + """ + test_args = [a for a in sshfs_args if a != "-f"] + result = subprocess.run(test_args, capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT) + + if result.returncode != 0 and "allow_other" in result.stderr: + self.logger.debug("Retrying sshfs without allow_other option") + sshfs_args = self._remove_allow_other(sshfs_args) + test_args = [a for a in sshfs_args if a != "-f"] + result = subprocess.run(test_args, capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT) + + if result.returncode != 0: + stderr = result.stderr.strip() + raise click.ClickException( + f"sshfs mount failed (exit code {result.returncode}): {stderr}" + ) + + self._force_umount(mountpoint) + + proc = subprocess.Popen( + sshfs_args, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + # Poll until mount is ready or sshfs exits unexpectedly + deadline = time.monotonic() + MOUNT_POLL_TIMEOUT + while True: + ret = proc.poll() + if ret is not None: + raise click.ClickException( + f"sshfs mount failed immediately (exit code {ret})" + ) + if os.path.ismount(mountpoint): + break + if time.monotonic() >= deadline: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + raise click.ClickException( + f"sshfs started but {mountpoint} is not mounted after {MOUNT_POLL_TIMEOUT}s" + ) + time.sleep(MOUNT_POLL_INTERVAL) + + return proc + + def _remove_allow_other(self, sshfs_args): + filtered = [] + skip_next = False + for i, arg in enumerate(sshfs_args): + if skip_next: + skip_next = False + continue + if arg == "-o" and i + 1 < len(sshfs_args) and sshfs_args[i + 1] == "allow_other": + skip_next = True + continue + filtered.append(arg) + return filtered + + def _run_subshell(self, mountpoint, remote_path): + """Spawn an interactive subshell with a modified prompt.""" + shell = os.environ.get("SHELL", "/bin/sh") + env = os.environ.copy() + + # Modify the prompt to indicate the active mount + prompt_prefix = f"[sshfs:{remote_path}] " + try: + if "bash" in shell: + env["PS1"] = prompt_prefix + env.get("PS1", r"\$ ") + subprocess.run( + [shell, "--norc", "--noprofile", "-i"], + env=env, + ) + elif "zsh" in shell: + env["PS1"] = prompt_prefix + env.get("PS1", "%# ") + subprocess.run([shell, "-i"], env=env) + else: + subprocess.run([shell, "-i"], env=env) + except FileNotFoundError as err: + raise click.ClickException( + f"Shell '{shell}' not found. Set the SHELL environment variable to a valid shell." + ) from err + + def _build_sshfs_args(self, host, port, mountpoint, remote_path, identity_file, extra_args): + default_username = self.username + user_prefix = f"{default_username}@" if default_username else "" + remote_spec = f"{user_prefix}{host}:{remote_path}" + + sshfs_args = ["sshfs", remote_spec, mountpoint] + + ssh_opts = [ + "StrictHostKeyChecking=no", + "UserKnownHostsFile=/dev/null", + "LogLevel=ERROR", + ] + + if port and port != 22: + sshfs_args.extend(["-p", str(port)]) + + if identity_file: + ssh_opts.append(f"IdentityFile={identity_file}") + + ssh_opts.append("allow_other") + + for opt in ssh_opts: + sshfs_args.extend(["-o", opt]) + + if extra_args: + for arg in extra_args: + sshfs_args.extend(["-o", arg]) + + return sshfs_args + + def umount(self, mountpoint, *, lazy=False): + """Unmount an sshfs filesystem (fallback for orphaned mounts).""" + mountpoint = os.path.realpath(mountpoint) + cmd = self._build_umount_cmd(mountpoint, lazy=lazy) + + self.logger.debug("Running unmount command: %s", cmd) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT) + + if result.returncode != 0: + stderr = result.stderr.strip() + raise click.ClickException(f"Unmount failed (exit code {result.returncode}): {stderr}") + + click.echo(f"Unmounted {mountpoint}") + + def _force_umount(self, mountpoint): + """Best-effort unmount, logging errors at debug level (used during cleanup).""" + cmd = self._build_umount_cmd(mountpoint, lazy=False) + try: + subprocess.run(cmd, capture_output=True, text=True, timeout=SUBPROCESS_TIMEOUT) + except Exception as e: + self.logger.debug("Force umount of %s failed: %s", mountpoint, e) + + def _build_umount_cmd(self, mountpoint, *, lazy=False): + fusermount = self._find_executable("fusermount3") or self._find_executable("fusermount") + if fusermount: + cmd = [fusermount, "-u"] + if lazy: + cmd.append("-z") + else: + cmd = ["umount"] + if lazy: + # macOS umount does not support -l; use -f (force) instead + if sys.platform == "darwin": + cmd.append("-f") + else: + cmd.append("-l") + cmd.append(mountpoint) + return cmd + + @staticmethod + def _find_executable(name): + return shutil.which(name) diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver.py new file mode 100644 index 000000000..48c3fae84 --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +from jumpstarter.common.exceptions import ConfigurationError +from jumpstarter.driver import Driver + + +@dataclass(kw_only=True) +class SSHMount(Driver): + """SSHFS mount/umount driver for Jumpstarter + + This driver provides remote filesystem mounting via sshfs. + It requires an 'ssh' child driver (SSHWrapper) which provides + SSH credentials and a 'tcp' sub-child for network connectivity. + """ + + def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + + if "ssh" not in self.children: + raise ConfigurationError( + "'ssh' child is required via ref to an SSHWrapper driver instance" + ) + + @classmethod + def client(cls) -> str: + return "jumpstarter_driver_ssh_mount.client.SSHMountClient" diff --git a/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py new file mode 100644 index 000000000..9c36f9e96 --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/jumpstarter_driver_ssh_mount/driver_test.py @@ -0,0 +1,732 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest +from jumpstarter_driver_network.driver import TcpNetwork +from jumpstarter_driver_ssh.driver import SSHWrapper + +from jumpstarter_driver_ssh_mount.client import MOUNT_POLL_INTERVAL +from jumpstarter_driver_ssh_mount.driver import SSHMount + +from jumpstarter.common.exceptions import ConfigurationError +from jumpstarter.common.utils import serve + +TEST_SSH_KEY = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "test-key-content\n" + "-----END OPENSSH PRIVATE KEY-----" +) + + +def _make_ssh_child(default_username="testuser", ssh_identity=None, ssh_identity_file=None, + host="127.0.0.1", port=22): + """Helper to create an SSHWrapper driver instance for use as a child of SSHMount.""" + kwargs = { + "default_username": default_username, + "children": {"tcp": TcpNetwork(host=host, port=port)}, + } + if ssh_identity is not None: + kwargs["ssh_identity"] = ssh_identity + if ssh_identity_file is not None: + kwargs["ssh_identity_file"] = ssh_identity_file + return SSHWrapper(**kwargs) + + +def _fake_find_executable(name): + """Return plausible paths per executable name.""" + paths = { + "sshfs": "/usr/bin/sshfs", + "fusermount3": "/usr/bin/fusermount3", + "fusermount": "/usr/bin/fusermount", + } + return paths.get(name) + + +@pytest.fixture +def mount_instance(): + return SSHMount(children={"ssh": _make_ssh_child()}) + + +@pytest.fixture +def mount_instance_with_identity(): + return SSHMount(children={"ssh": _make_ssh_child(ssh_identity=TEST_SSH_KEY)}) + + +@pytest.fixture +def mock_portforward(): + with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: + mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 2222)) + mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + yield mock_adapter + + +@pytest.fixture +def mock_portforward_22(): + with patch('jumpstarter_driver_ssh_mount.client.TcpPortforwardAdapter') as mock_adapter: + mock_adapter.return_value.__enter__ = MagicMock(return_value=("127.0.0.1", 22)) + mock_adapter.return_value.__exit__ = MagicMock(return_value=None) + yield mock_adapter + + +# --------------------------------------------------------------------------- +# Driver configuration tests +# --------------------------------------------------------------------------- + +def test_ssh_mount_requires_ssh_child(): + """Test that SSHMount driver requires an ssh child""" + with pytest.raises(ConfigurationError, match="'ssh' child is required"): + SSHMount() + + +# --------------------------------------------------------------------------- +# _build_sshfs_args unit tests (argument construction validated independently) +# --------------------------------------------------------------------------- + +def test_build_sshfs_args_basic(mount_instance): + """Test basic sshfs argument construction""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, None) + assert args[0] == "sshfs" + assert "testuser@192.168.1.1:/" in args + assert "/mnt/remote" in args + assert "-p" not in args + + +def test_build_sshfs_args_custom_port(mount_instance): + """Test sshfs args include -p for non-default port""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 2222, "/mnt/remote", "/", None, None) + assert "-p" in args + assert "2222" in args + + +def test_build_sshfs_args_with_identity(mount_instance): + """Test sshfs args include IdentityFile when identity file is provided""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", + "/tmp/my_key", None) + identity_opts = [args[i + 1] for i in range(len(args) - 1) + if args[i] == "-o" and args[i + 1].startswith("IdentityFile=")] + assert len(identity_opts) == 1 + assert identity_opts[0] == "IdentityFile=/tmp/my_key" + + +def test_build_sshfs_args_allow_other_present(mount_instance): + """Test sshfs args include allow_other by default""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, None) + assert "allow_other" in args + + +def test_build_sshfs_args_with_extra_args(mount_instance): + """Test extra args are prefixed with -o""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("192.168.1.1", 22, "/mnt/remote", "/", None, + ["reconnect", "cache=yes"]) + for extra in ["reconnect", "cache=yes"]: + idx = args.index(extra) + assert args[idx - 1] == "-o" + + +def test_build_sshfs_args_remote_path(mount_instance): + """Test sshfs args use the correct remote path""" + with serve(mount_instance) as client: + args = client._build_sshfs_args("10.0.0.1", 22, "/mnt/remote", "/home/user", None, None) + assert "testuser@10.0.0.1:/home/user" in args + + +def test_build_sshfs_args_no_username(): + """Test sshfs args without default username""" + instance = SSHMount(children={"ssh": _make_ssh_child(default_username="")}) + with serve(instance) as client: + args = client._build_sshfs_args("10.0.0.1", 22, "/mnt/remote", "/", None, None) + assert "10.0.0.1:/" in args + assert not any("@" in a for a in args if ":" in a) + + +# --------------------------------------------------------------------------- +# Mount workflow tests +# --------------------------------------------------------------------------- + +def test_mount_sshfs_not_installed(mount_instance): + """Test mount fails gracefully when sshfs is not installed""" + with serve(mount_instance) as client: + with patch.object(client, '_find_executable', return_value=None): + with pytest.raises(Exception, match="sshfs is not installed"): + client.mount("/tmp/test-mount") + + +def test_mount_sshfs_success(mount_instance, mock_portforward): + """Test successful sshfs mount via port forwarding with subshell""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", remote_path="/home/user") + + test_run_args = mock_run.call_args_list[0][0][0] + assert test_run_args[0] == "sshfs" + assert "testuser@127.0.0.1:/home/user" in test_run_args + assert os.path.realpath("/tmp/test-mount") in test_run_args + assert "-p" in test_run_args + assert "2222" in test_run_args + assert "-f" not in test_run_args + + +def test_mount_sshfs_with_identity(mount_instance_with_identity, mock_portforward_22): + """Test sshfs mount with SSH identity""" + with serve(mount_instance_with_identity) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + test_run_args = mock_run.call_args_list[0][0][0] + identity_opts = [ + test_run_args[i + 1] for i in range(len(test_run_args) - 1) + if test_run_args[i] == "-o" and test_run_args[i + 1].startswith("IdentityFile=") + ] + assert len(identity_opts) == 1 + + +def test_mount_sshfs_allow_other_fallback(mount_instance, mock_portforward_22): + """Test sshfs mount falls back when allow_other fails""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.side_effect = [ + MagicMock(returncode=1, stdout="", stderr="allow_other: permission denied"), + MagicMock(returncode=0, stdout="", stderr=""), + MagicMock(returncode=0, stdout="", stderr=""), + ] + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + second_call_args = mock_run.call_args_list[1][0][0] + assert "allow_other" not in second_call_args + for i, arg in enumerate(second_call_args): + if arg == "-o": + assert i + 1 < len(second_call_args), "Orphaned -o flag found" + assert not second_call_args[i + 1].startswith("-"), \ + f"Orphaned -o flag followed by {second_call_args[i + 1]}" + + +def test_mount_sshfs_generic_failure(mount_instance, mock_portforward_22): + """Test mount failure with a non-allow_other error""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="Connection refused") + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + assert mock_run.call_count == 2 + first_call_args = mock_run.call_args_list[0][0][0] + assert first_call_args[0] == "sshfs" + + +def test_mount_sshfs_direct_success(): + """Test sshfs mount using direct TCP address""" + instance = SSHMount(children={"ssh": _make_ssh_child(host="10.0.0.1", port=2222)}) + + with serve(instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", direct=True) + + test_run_args = mock_run.call_args_list[0][0][0] + assert test_run_args[0] == "sshfs" + assert "testuser@10.0.0.1:/" in test_run_args + assert "-p" in test_run_args + assert "2222" in test_run_args + + +def test_mount_sshfs_direct_fallback_to_portforward(mount_instance, mock_portforward): + """Test that direct mount falls back to port forwarding on failure""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + original_ssh = client.ssh + + class FakeTcp: + def address(self): + raise ValueError("not available") + + class FakeSsh: + def __getattr__(self, name): + if name == "tcp": + return FakeTcp() + return getattr(original_ssh, name) + + with patch.object(client, 'children', {**client.children, "ssh": FakeSsh()}): + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", direct=True) + + test_run_args = mock_run.call_args_list[0][0][0] + assert "2222" in test_run_args + + +def test_mount_foreground_mode(mount_instance, mock_portforward_22): + """Test that foreground flag blocks on sshfs without spawning subshell""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + poll_calls = [0] + def poll_side_effect(): + poll_calls[0] += 1 + if poll_calls[0] >= 3: + return None + return None + mock_proc.poll.side_effect = poll_side_effect + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc) as mock_popen, + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.return_value = None + + client.mount("/tmp/test-mount", foreground=True) + + assert mock_proc.wait.call_count >= 1 + mock_portforward_22.return_value.__exit__.assert_called() + popen_args = mock_popen.call_args[0][0] + assert "-f" in popen_args + + +def test_mount_subshell_mode(mount_instance, mock_portforward_22): + """Test that default mode spawns a subshell""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch.object(client, '_run_subshell') as mock_subshell, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + client.mount("/tmp/test-mount") + + resolved = os.path.realpath("/tmp/test-mount") + mock_subshell.assert_called_once_with(resolved, "/") + + +def test_mount_cleanup_on_failure(mount_instance_with_identity, mock_portforward_22): + """Test that identity file is cleaned up when mount fails""" + with serve(mount_instance_with_identity) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('os.makedirs'), + patch('os.unlink') as mock_unlink, + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="Connection refused") + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + assert mock_unlink.called + unlink_path = mock_unlink.call_args_list[-1][0][0] + assert unlink_path.endswith("_ssh_key") + + +# --------------------------------------------------------------------------- +# Unmount tests +# --------------------------------------------------------------------------- + +def test_umount_with_fusermount(mount_instance): + """Test unmount using fusermount""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") + + call_args = mock_run.call_args[0][0] + assert call_args[0] == "/usr/bin/fusermount3" + assert "-u" in call_args + + +def test_umount_with_system_umount_fallback(mount_instance): + """Test unmount falls back to system umount when fusermount is not available""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") + + call_args = mock_run.call_args[0][0] + assert call_args[0] == "umount" + + +def test_umount_lazy(mount_instance): + """Test lazy unmount""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount", lazy=True) + + call_args = mock_run.call_args[0][0] + assert "-z" in call_args + + +def test_umount_failure(mount_instance): + """Test unmount failure""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="not mounted") + + with pytest.raises(Exception, match="Unmount failed"): + client.umount("/tmp/test-mount") + + +def test_umount_prefers_fusermount3(mount_instance): + """Test that fusermount3 is preferred over fusermount when both are available""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") + + call_args = mock_run.call_args[0][0] + assert call_args[0] == "/usr/bin/fusermount3" + + +def test_umount_lazy_macos_uses_force(mount_instance): + """Test that lazy unmount on macOS uses -f instead of -l""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + patch('jumpstarter_driver_ssh_mount.client.sys') as mock_sys, + ): + mock_sys.platform = "darwin" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + client.umount("/tmp/test-mount", lazy=True) + + call_args = mock_run.call_args[0][0] + assert "-f" in call_args + assert "-l" not in call_args + + +def test_umount_passes_timeout(mount_instance): + """Test that umount subprocess calls include SUBPROCESS_TIMEOUT""" + with serve(mount_instance) as client: + with ( + patch.object(client, '_find_executable', return_value=None), + patch('subprocess.run') as mock_run, + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + client.umount("/tmp/test-mount") + + assert mock_run.call_args[1].get("timeout") == 120 + + +# --------------------------------------------------------------------------- +# CLI tests +# --------------------------------------------------------------------------- + +def test_cli_has_mount_and_umount_flag(mount_instance): + """Test that the CLI exposes mount command with --umount and --foreground flags""" + with serve(mount_instance) as client: + cli = client.cli() + from click.testing import CliRunner + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + assert "mountpoint" in result.output.lower() or "MOUNTPOINT" in result.output + assert "--umount" in result.output + assert "--foreground" in result.output + + +def test_cli_dispatches_mount(mount_instance): + """Test that CLI invocation with a mountpoint dispatches to self.mount()""" + with serve(mount_instance) as client: + cli = client.cli() + from click.testing import CliRunner + runner = CliRunner() + + with patch.object(client, 'mount') as mock_mount: + result = runner.invoke(cli, ["/tmp/test-cli-mount", "-r", "/home"]) + assert result.exit_code == 0 + mock_mount.assert_called_once_with( + "/tmp/test-cli-mount", + remote_path="/home", + direct=False, + foreground=False, + extra_args=[], + ) + + +def test_cli_dispatches_umount(mount_instance): + """Test that CLI invocation with --umount dispatches to self.umount()""" + with serve(mount_instance) as client: + cli = client.cli() + from click.testing import CliRunner + runner = CliRunner() + + with patch.object(client, 'umount') as mock_umount: + result = runner.invoke(cli, ["--umount", "/tmp/test-cli-mount", "--lazy"]) + assert result.exit_code == 0 + mock_umount.assert_called_once_with("/tmp/test-cli-mount", lazy=True) + + +# --------------------------------------------------------------------------- +# Polling / mount-readiness tests +# --------------------------------------------------------------------------- + +def test_mount_polling_waits_for_mount(mount_instance, mock_portforward_22): + """Test that the polling loop waits for os.path.ismount to return True""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + ismount_calls = [0] + def ismount_side_effect(path): + ismount_calls[0] += 1 + return ismount_calls[0] >= 3 + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', side_effect=ismount_side_effect), + patch('jumpstarter_driver_ssh_mount.client.time.sleep') as mock_sleep, + patch.object(client, '_run_subshell'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + client.mount("/tmp/test-mount") + + assert mock_sleep.call_count >= 2 + mock_sleep.assert_called_with(MOUNT_POLL_INTERVAL) + + +def test_mount_polling_timeout(mount_instance, mock_portforward_22): + """Test that mount fails if mountpoint is never mounted within timeout""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=False), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch('jumpstarter_driver_ssh_mount.client.MOUNT_POLL_TIMEOUT', 0), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + with pytest.raises(Exception, match="is not mounted"): + client.mount("/tmp/test-mount", foreground=True) + + mock_proc.terminate.assert_called() + + +def test_mount_sshfs_not_mounted_after_startup(mount_instance, mock_portforward_22): + """Test that mount fails if sshfs starts but mountpoint is not actually mounted""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=False), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + patch('jumpstarter_driver_ssh_mount.client.MOUNT_POLL_TIMEOUT', 0), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + with pytest.raises(Exception, match="is not mounted"): + client.mount("/tmp/test-mount", foreground=True) + + mock_proc.terminate.assert_called() + + +# --------------------------------------------------------------------------- +# Foreground / KeyboardInterrupt tests +# --------------------------------------------------------------------------- + +def test_mount_foreground_keyboard_interrupt(mount_instance, mock_portforward_22): + """Test that KeyboardInterrupt during foreground mode terminates sshfs and unmounts""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.returncode = 0 + + mock_proc.wait.side_effect = [ + KeyboardInterrupt(), + None, + ] + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + patch('os.path.ismount', return_value=True), + patch('jumpstarter_driver_ssh_mount.client.time.sleep'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + client.mount("/tmp/test-mount", foreground=True) + + mock_proc.terminate.assert_called_once() + + +# --------------------------------------------------------------------------- +# Extra args and port tests +# --------------------------------------------------------------------------- + +def test_extra_args_prefixed_with_dash_o(mount_instance, mock_portforward_22): + """Test that extra_args are correctly prefixed with -o in sshfs command""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount", extra_args=["reconnect", "cache=yes"]) + + test_run_args = mock_run.call_args_list[0][0][0] + for extra in ["reconnect", "cache=yes"]: + idx = test_run_args.index(extra) + assert test_run_args[idx - 1] == "-o" + + +def test_mount_port_22_omits_p_flag(mount_instance, mock_portforward_22): + """Test that port 22 does not add -p flag to sshfs args""" + with serve(mount_instance) as client: + mock_proc = MagicMock() + mock_proc.poll.return_value = 0 + mock_proc.stderr = None + + with ( + patch.object(client, '_find_executable', side_effect=_fake_find_executable), + patch('subprocess.run') as mock_run, + patch('subprocess.Popen', return_value=mock_proc), + patch('os.makedirs'), + ): + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + mock_proc.wait.side_effect = [None] + + with pytest.raises(Exception, match="sshfs mount failed"): + client.mount("/tmp/test-mount") + + test_run_args = mock_run.call_args_list[0][0][0] + assert "-p" not in test_run_args + + +# --------------------------------------------------------------------------- +# Subshell tests +# --------------------------------------------------------------------------- + +def test_subshell_bad_shell_raises_click_exception(mount_instance): + """Test that _run_subshell raises ClickException when shell binary is not found""" + with serve(mount_instance) as client: + with patch.dict(os.environ, {"SHELL": "/nonexistent/shell"}): + with patch('subprocess.run', side_effect=FileNotFoundError("No such file")): + with pytest.raises(Exception, match="Shell .* not found"): + client._run_subshell("/tmp/test-mount", "/") diff --git a/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml b/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml new file mode 100644 index 000000000..f8e3b0bd2 --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh-mount/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "jumpstarter-driver-ssh-mount" +dynamic = ["version", "urls"] +description = "SSHFS mount/umount driver for Jumpstarter that provides remote filesystem mounting via sshfs" +readme = "README.md" +license = "Apache-2.0" +authors = [ + { name = "The Jumpstarter Authors" } +] +requires-python = ">=3.11" +dependencies = [ + "click>=8.0.0", + "jumpstarter", + "jumpstarter-driver-network", + "jumpstarter-driver-ssh", +] + +[project.entry-points."jumpstarter.drivers"] +SSHMount = "jumpstarter_driver_ssh_mount.driver:SSHMount" + +[tool.hatch.version] +source = "vcs" +raw-options = { 'root' = '../../../'} + +[tool.hatch.metadata.hooks.vcs.urls] +Homepage = "https://jumpstarter.dev" +source_archive = "https://github.com/jumpstarter-dev/repo/archive/{commit_hash}.zip" + +[tool.pytest.ini_options] +addopts = "--cov --cov-report=html --cov-report=xml" +log_cli = true +log_cli_level = "INFO" +testpaths = ["jumpstarter_driver_ssh_mount"] +asyncio_mode = "auto" + +[build-system] +requires = ["hatchling", "hatch-vcs", "hatch-pin-jumpstarter"] +build-backend = "hatchling.build" + +[tool.hatch.build.hooks.pin_jumpstarter] +name = "pin_jumpstarter" + +[dependency-groups] +dev = [ + "pytest-cov>=6.0.0", + "pytest>=8.3.3", +] diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py new file mode 100644 index 000000000..96fd3739f --- /dev/null +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/_ssh_utils.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import os +import tempfile + + +def create_temp_identity_file(ssh_identity: str, logger) -> str | None: + if not ssh_identity: + return None + + fd = None + temp_path = None + try: + fd, temp_path = tempfile.mkstemp(suffix="_ssh_key") + os.write(fd, ssh_identity.encode()) + os.close(fd) + fd = None + logger.debug("Created temporary identity file: %s", temp_path) + return temp_path + except Exception as e: + logger.error("Failed to create temporary identity file: %s", e) + if fd is not None: + try: + os.close(fd) + except Exception: + pass + if temp_path: + try: + os.unlink(temp_path) + except Exception: + pass + raise + + +def cleanup_identity_file(identity_file: str | None, logger) -> None: + if identity_file: + try: + os.unlink(identity_file) + logger.debug("Cleaned up temporary identity file: %s", identity_file) + except Exception as e: + logger.warning("Failed to clean up identity file %s: %s", identity_file, e) diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index 5574dcc1a..e47ef92a2 100644 --- a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -1,7 +1,5 @@ -import os import shlex import subprocess -import tempfile from contextlib import asynccontextmanager from dataclasses import dataclass from urllib.parse import urlparse @@ -10,6 +8,7 @@ from jumpstarter_driver_composite.client import CompositeClient from jumpstarter_driver_network.adapters import TcpPortforwardAdapter +from ._ssh_utils import cleanup_identity_file, create_temp_identity_file from jumpstarter.client.core import DriverMethodNotImplemented from jumpstarter.client.decorators import driver_click_command @@ -151,27 +150,7 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: def _run_ssh_local(self, host, port, options, args): """Run SSH command with the given host, port, and arguments""" - # Create temporary identity file if needed - ssh_identity = self.identity - identity_file = None - temp_file = None - if ssh_identity: - try: - temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key') - temp_file.write(ssh_identity) - temp_file.close() - # Set proper permissions (600) for SSH key - os.chmod(temp_file.name, 0o600) - identity_file = temp_file.name - self.logger.debug("Created temporary identity file: %s", identity_file) - except Exception as e: - self.logger.error("Failed to create temporary identity file: %s", e) - if temp_file: - try: - os.unlink(temp_file.name) - except Exception: - pass - raise + identity_file = create_temp_identity_file(self.identity, self.logger) try: # Build SSH command arguments @@ -186,13 +165,7 @@ def _run_ssh_local(self, host, port, options, args): # Execute the command return self._execute_ssh_command(ssh_args, options) finally: - # Clean up temporary identity file - if identity_file: - try: - os.unlink(identity_file) - self.logger.debug("Cleaned up temporary identity file: %s", identity_file) - except Exception as e: - self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e)) + cleanup_identity_file(identity_file, self.logger) def _build_ssh_command_args(self, port, identity_file, args): """Build initial SSH command arguments""" diff --git a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 92a540406..64b05ac5b 100644 --- a/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/python/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -585,6 +585,9 @@ def test_ssh_command_without_identity(): assert result.stdout == "some stdout" +_UTILS = "jumpstarter_driver_ssh._ssh_utils" + + def test_ssh_identity_temp_file_creation_and_cleanup(): """Test that temporary identity file is created and cleaned up properly""" instance = SSHWrapper( @@ -597,33 +600,23 @@ def test_ssh_identity_temp_file_creation_and_cleanup(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - with patch('os.chmod') as mock_chmod: - with patch('os.unlink') as mock_unlink: - # Mock the temporary file - mock_temp_file_instance = MagicMock() - mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" - mock_temp_file_instance.write = MagicMock() - mock_temp_file_instance.close = MagicMock() - mock_temp_file.return_value = mock_temp_file_instance - - # Test SSH command with identity - result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - assert isinstance(result, SSHCommandRunResult) - - # Verify temporary file was created - mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key') - mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY) - mock_temp_file_instance.close.assert_called_once() - - # Verify proper permissions were set - mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv) as mock_mkstemp, + patch(f"{_UTILS}.os.write") as mock_write, + patch(f"{_UTILS}.os.close") as mock_close, + patch(f"{_UTILS}.os.unlink") as mock_unlink, + ): + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) - # Verify temporary file was cleaned up - mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") + mock_mkstemp.assert_called_once_with(suffix="_ssh_key") + mock_write.assert_called_once_with(5, TEST_SSH_KEY.encode()) + mock_close.assert_called_once_with(5) + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") - assert result.return_code == 0 - assert result.stdout == "some stdout" + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_identity_temp_file_creation_error(): @@ -638,16 +631,46 @@ def test_ssh_identity_temp_file_creation_error(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0) - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - mock_temp_file.side_effect = OSError("Permission denied") + with patch(f"{_UTILS}.tempfile.mkstemp") as mock_mkstemp: + mock_mkstemp.side_effect = OSError("Permission denied") + + with pytest.raises(ExceptionGroup) as exc_info: + client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + + assert any( + isinstance(e, OSError) and "Permission denied" in str(e) + for e in exc_info.value.exceptions + ) + + +def test_ssh_identity_temp_file_creation_error_fd_cleanup(): + """Test that fd is closed when write fails after mkstemp succeeds""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) - # Test SSH command with identity should raise an error - # The exception will be wrapped in an ExceptionGroup due to the context manager + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv), + patch(f"{_UTILS}.os.write", side_effect=OSError("Disk full")), + patch(f"{_UTILS}.os.close") as mock_close, + patch(f"{_UTILS}.os.unlink") as mock_unlink, + ): with pytest.raises(ExceptionGroup) as exc_info: client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - # Check that the original OSError is in the exception group - assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) + assert any( + isinstance(e, OSError) and "Disk full" in str(e) + for e in exc_info.value.exceptions + ) + mock_close.assert_called_once_with(5) + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") def test_ssh_identity_temp_file_cleanup_error(): @@ -662,36 +685,24 @@ def test_ssh_identity_temp_file_cleanup_error(): with patch('subprocess.run') as mock_run: mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") - with patch('tempfile.NamedTemporaryFile') as mock_temp_file: - with patch('os.chmod') as mock_chmod: - with patch('os.unlink') as mock_unlink: - # Mock the temporary file - mock_temp_file_instance = MagicMock() - mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" - mock_temp_file_instance.write = MagicMock() - mock_temp_file_instance.close = MagicMock() - mock_temp_file.return_value = mock_temp_file_instance - - # Mock cleanup failure - mock_unlink.side_effect = OSError("Permission denied") - - # Test SSH command with identity - should still succeed but log warning - with patch.object(client, 'logger') as mock_logger: - result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) - assert isinstance(result, SSHCommandRunResult) - - # Verify chmod was called - mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) - - # Verify warning was logged - mock_logger.warning.assert_called_once_with( - "Failed to clean up temporary identity file %s: %s", - "/tmp/test_ssh_key_12345", - str(mock_unlink.side_effect) - ) - - assert result.return_code == 0 - assert result.stdout == "some stdout" + mkstemp_rv = (5, "/tmp/test_ssh_key_12345") + with ( + patch(f"{_UTILS}.tempfile.mkstemp", return_value=mkstemp_rv), + patch(f"{_UTILS}.os.write"), + patch(f"{_UTILS}.os.close"), + patch(f"{_UTILS}.os.unlink", side_effect=OSError("Permission denied")), + ): + with patch.object(client, 'logger') as mock_logger: + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"]) + assert isinstance(result, SSHCommandRunResult) + + mock_logger.warning.assert_called_once() + warning_args = mock_logger.warning.call_args[0] + assert "Failed to clean up identity file" in warning_args[0] + assert "/tmp/test_ssh_key_12345" in warning_args[1] + + assert result.return_code == 0 + assert result.stdout == "some stdout" def test_ssh_client_properties(): diff --git a/python/pyproject.toml b/python/pyproject.toml index dbecb9b3b..f60828624 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -36,6 +36,7 @@ jumpstarter-driver-tftp = { workspace = true } jumpstarter-driver-snmp = { workspace = true } jumpstarter-driver-shell = { workspace = true } jumpstarter-driver-ssh = { workspace = true } +jumpstarter-driver-ssh-mount = { workspace = true } jumpstarter-driver-uboot = { workspace = true } jumpstarter-driver-uds = { workspace = true } jumpstarter-driver-uds-can = { workspace = true } diff --git a/python/uv.lock b/python/uv.lock index 234253413..d8a8f317f 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14'", @@ -33,6 +33,7 @@ members = [ "jumpstarter-driver-iscsi", "jumpstarter-driver-mitmproxy", "jumpstarter-driver-network", + "jumpstarter-driver-noyito-relay", "jumpstarter-driver-opendal", "jumpstarter-driver-pi-pico", "jumpstarter-driver-power", @@ -46,6 +47,7 @@ members = [ "jumpstarter-driver-someip", "jumpstarter-driver-ssh", "jumpstarter-driver-ssh-mitm", + "jumpstarter-driver-ssh-mount", "jumpstarter-driver-tasmota", "jumpstarter-driver-tftp", "jumpstarter-driver-tmt", @@ -1680,6 +1682,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/e7/ae38d7a6dfba0533684e0b2136817d667588ae3ec984c1a4e5df5eb88482/hatchling-1.27.0-py3-none-any.whl", hash = "sha256:d3a2f3567c4f926ea39849cdf924c7e99e6686c9c8e288ae1037c8fa2a5d937b", size = 75794, upload-time = "2024-12-15T17:08:10.364Z" }, ] +[[package]] +name = "hid" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/f8/0357a8aa8874a243e96d08a8568efaf7478293e1a3441ddca18039b690c1/hid-1.0.9.tar.gz", hash = "sha256:f4471f11f0e176d1b0cb1b243e55498cc90347a3aede735655304395694ac182", size = 4973, upload-time = "2026-02-05T15:35:20.595Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/c7/f0e1ad95179f44a6fc7a9140be025812cc7a62cf7390442b685a57ee1417/hid-1.0.9-py3-none-any.whl", hash = "sha256:6b9289e00bbc1e1589bec0c7f376a63fe03a4a4a1875575d0ad60e3e11a349f4", size = 4959, upload-time = "2026-02-05T15:35:19.269Z" }, +] + [[package]] name = "hpack" version = "4.1.0" @@ -2263,12 +2274,15 @@ source = { editable = "packages/jumpstarter-driver-ble" } dependencies = [ { name = "anyio" }, { name = "bleak" }, + { name = "click" }, { name = "jumpstarter" }, + { name = "jumpstarter-driver-network" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-anyio" }, { name = "pytest-cov" }, ] @@ -2276,12 +2290,15 @@ dev = [ requires-dist = [ { name = "anyio", specifier = ">=4.10.0" }, { name = "bleak", specifier = ">=1.1.1" }, + { name = "click", specifier = ">=8.1.8" }, { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, ] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-anyio", specifier = ">=0.0.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, ] @@ -2753,6 +2770,38 @@ dev = [ { name = "websocket-client", specifier = ">=1.8.0" }, ] +[[package]] +name = "jumpstarter-driver-noyito-relay" +source = { editable = "packages/jumpstarter-driver-noyito-relay" } +dependencies = [ + { name = "hid" }, + { name = "jumpstarter" }, + { name = "jumpstarter-driver-power" }, + { name = "pyserial" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, +] + +[package.metadata] +requires-dist = [ + { name = "hid", specifier = ">=1.0.4" }, + { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-power", editable = "packages/jumpstarter-driver-power" }, + { name = "pyserial", specifier = ">=3.5" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, +] + [[package]] name = "jumpstarter-driver-opendal" source = { editable = "packages/jumpstarter-driver-opendal" } @@ -3094,7 +3143,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "jumpstarter", editable = "packages/jumpstarter" }, - { name = "opensomeip", specifier = ">=0.1.2" }, + { name = "opensomeip", specifier = ">=0.1.2,<0.2.0" }, ] [package.metadata.requires-dev] @@ -3167,6 +3216,36 @@ dev = [ { name = "trio", specifier = ">=0.28.0" }, ] +[[package]] +name = "jumpstarter-driver-ssh-mount" +source = { editable = "packages/jumpstarter-driver-ssh-mount" } +dependencies = [ + { name = "click" }, + { name = "jumpstarter" }, + { name = "jumpstarter-driver-network" }, + { name = "jumpstarter-driver-ssh" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[package.metadata] +requires-dist = [ + { name = "click", specifier = ">=8.0.0" }, + { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, + { name = "jumpstarter-driver-ssh", editable = "packages/jumpstarter-driver-ssh" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, +] + [[package]] name = "jumpstarter-driver-tasmota" source = { editable = "packages/jumpstarter-driver-tasmota" } @@ -5257,6 +5336,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/d2/dfc2f25f3905921c2743c300a48d9494d29032f1389fc142e718d6978fb2/pytest_httpserver-1.1.3-py3-none-any.whl", hash = "sha256:5f84757810233e19e2bb5287f3826a71c97a3740abe3a363af9155c0f82fdbb9", size = 21000, upload-time = "2025-04-10T08:17:13.906Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "pytest-mqtt" version = "0.5.0"