diff --git a/packages/jumpstarter-driver-ssh-mitm/README.md b/packages/jumpstarter-driver-ssh-mitm/README.md new file mode 100644 index 000000000..45acfdaa9 --- /dev/null +++ b/packages/jumpstarter-driver-ssh-mitm/README.md @@ -0,0 +1,111 @@ +# SSH MITM Driver + +`jumpstarter-driver-ssh-mitm` provides a secure SSH proxy layer where private keys +are stored on the exporter and never transmitted to clients. It is designed to be +used as a child of `SSHWrapper`. + +## Installation + +```{code-block} console +:substitutions: +$ pip3 install --extra-index-url {{index_url}} jumpstarter-driver-ssh-mitm +``` + +## Architecture + +``` +SSHWrapper --> SSHMITM --> TcpNetwork --> DUT +``` + +- **SSHWrapper**: Handles SSH CLI and command execution +- **SSHMITM**: Provides authenticated proxy connection (stores the SSH key) +- **TcpNetwork**: Raw TCP connection to the DUT + +## Configuration + +The command name is determined by the key in the `export` section. Use `ssh_mitm` to get the `j ssh_mitm` command: + +```yaml +export: + ssh_mitm: # ← This gives you "j ssh_mitm" command + type: jumpstarter_driver_ssh.driver.SSHWrapper + config: + default_username: root + children: + tcp: + type: jumpstarter_driver_ssh_mitm.driver.SSHMITM + config: + ssh_identity_file: /path/to/private/key + default_username: root + children: + tcp: + type: jumpstarter_driver_network.driver.TcpNetwork + config: + host: 192.168.1.100 + port: 22 +``` + +Or with inline key: + +```yaml +export: + ssh_mitm: # ← This gives you "j ssh_mitm" command + type: jumpstarter_driver_ssh.driver.SSHWrapper + config: + default_username: root + children: + tcp: + type: jumpstarter_driver_ssh_mitm.driver.SSHMITM + config: + default_username: root + ssh_identity: | + -----BEGIN OPENSSH PRIVATE KEY----- + ... + -----END OPENSSH PRIVATE KEY----- + children: + tcp: + type: jumpstarter_driver_network.driver.TcpNetwork + config: + host: 192.168.1.100 + port: 22 +``` + +### SSHMITM Config parameters + +| Parameter | Description | Type | Required | Default | +| ----------------- | ---------------------------------------- | ----- | -------- | ------- | +| default_username | SSH username for DUT connection | str | no | "" | +| ssh_identity | SSH private key content (inline) | str | no* | None | +| ssh_identity_file | Path to SSH private key file | str | no* | None | + +\* Either `ssh_identity` or `ssh_identity_file` must be provided. + +### Required children + +- `tcp`: A `TcpNetwork` driver providing target host and port + +## Usage + +Since SSHMITM is used as a child of SSHWrapper, you use the configured command name (e.g., `ssh_mitm`): + +```bash +# Execute a command +j ssh_mitm whoami + +# Interactive shell +j ssh_mitm + +# With arguments +j ssh_mitm ls -la /tmp + +# With SSH flags +j ssh_mitm -v hostname +``` + +**Note**: The command name (`ssh_mitm`) is determined by the key in your exporter config's `export` section. You can use any name you prefer. + +## API Reference + +```{eval-rst} +.. autoclass:: jumpstarter_driver_ssh_mitm.driver.SSHMITM() +``` diff --git a/packages/jumpstarter-driver-ssh-mitm/examples/exporter.yaml b/packages/jumpstarter-driver-ssh-mitm/examples/exporter.yaml new file mode 100644 index 000000000..0d85e02d3 --- /dev/null +++ b/packages/jumpstarter-driver-ssh-mitm/examples/exporter.yaml @@ -0,0 +1,33 @@ +apiVersion: jumpstarter.dev/v1alpha1 +kind: ExporterConfig +metadata: + namespace: default + name: ssh-mitm-example +endpoint: "grpc.jumpstarter.example.com:443" +token: "your-exporter-token" +export: + # "j ssh_mitm" command - secure SSH with key on server + ssh_mitm: + type: jumpstarter_driver_ssh.driver.SSHWrapper + config: + # Change to the user you will SSH as on the DUT + default_username: root + children: + tcp: + type: jumpstarter_driver_ssh_mitm.driver.SSHMITM + config: + # Must match the user on the DUT + default_username: root + # Option 1: Path to key file (on exporter machine) + ssh_identity_file: /etc/jumpstarter/ssh_keys/dut_key + # Option 2: Inline key (from secret management) + # ssh_identity: | + # -----BEGIN OPENSSH PRIVATE KEY----- + # ...key content... + # -----END OPENSSH PRIVATE KEY----- + children: + tcp: + type: jumpstarter_driver_network.driver.TcpNetwork + config: + host: 192.168.1.100 + port: 22 diff --git a/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/__init__.py b/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver.py b/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver.py new file mode 100644 index 000000000..139ba9c4e --- /dev/null +++ b/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver.py @@ -0,0 +1,459 @@ +""" +SSH MITM Driver: Secure SSH proxy with server-side key storage. + +This driver implements a Man-in-the-Middle SSH proxy where the private key +never leaves the exporter. It uses paramiko to: +1. Accept SSH connections from clients (via Jumpstarter stream) +2. Connect to the target DUT using stored credentials +3. Proxy traffic between client and DUT +""" + +import io +import logging +import socket +import threading +from contextlib import asynccontextmanager, suppress +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import paramiko +from anyio import get_cancelled_exc_class +from anyio.from_thread import BlockingPortal +from jumpstarter_driver_network.driver import TcpNetwork + +from jumpstarter.common.exceptions import ConfigurationError +from jumpstarter.driver import Driver, exportstream +from jumpstarter.streams.common import create_memory_stream + +logging.getLogger("paramiko").setLevel(logging.WARNING) + + +class SSHMITMError(Exception): + """Base exception for SSH MITM driver errors.""" + + +BUFFER_SIZE = 65536 + + +class StreamSocket: + """ + Adapter to bridge async Jumpstarter streams with paramiko's blocking sockets. + + Paramiko requires a socket-like interface. This class uses a socketpair + and forwarding threads to connect async streams to paramiko's Transport. + """ + + def __init__(self, send_stream, recv_stream, portal: BlockingPortal): + self.client_sock, self.server_sock = socket.socketpair() + self.client_sock.setblocking(True) + self.server_sock.setblocking(True) + + self.client_sock.settimeout(1.0) + self.server_sock.settimeout(1.0) + + self.send_stream = send_stream + self.recv_stream = recv_stream + self.portal = portal + self._running = True + + self._recv_thread = threading.Thread(target=self._forward_recv, daemon=True) + self._send_thread = threading.Thread(target=self._forward_send, daemon=True) + + def start(self): + """Start bidirectional forwarding threads.""" + self._recv_thread.start() + self._send_thread.start() + + def _forward_recv(self): + """Forward: Jumpstarter stream → socket (for paramiko to read).""" + socket_logger = logging.getLogger("SSHMITM.StreamSocket") + try: + while self._running: + try: + data = self.portal.call(self.recv_stream.receive) + if data: + socket_logger.debug("recv->sock %d bytes", len(data)) + self.client_sock.sendall(data) + else: + break + except (BrokenPipeError, OSError): + break + self._running = False + except Exception as exc: + socket_logger.debug("recv loop stopped: %s", exc) + + def _forward_send(self): + """Forward: socket → Jumpstarter stream (paramiko writes).""" + socket_logger = logging.getLogger("SSHMITM.StreamSocket") + try: + while self._running: + try: + data = self.client_sock.recv(BUFFER_SIZE) + if data: + socket_logger.debug("sock->send %d bytes", len(data)) + self.portal.call(self.send_stream.send, data) + else: + break + except socket.timeout: + # Allow loop to check _running and exit cleanly + continue + except (BrokenPipeError, OSError): + break + self._running = False + except Exception as exc: + socket_logger.debug("send loop stopped: %s", exc) + + def get_paramiko_socket(self): + """Get the socket for paramiko Transport.""" + return self.server_sock + + def close(self): + """Clean up sockets.""" + self._running = False + # Close async streams to unblock portal calls + with suppress(Exception): + self.portal.call(self.recv_stream.aclose) + with suppress(Exception): + self.portal.call(self.send_stream.aclose) + try: + self.client_sock.shutdown(socket.SHUT_RDWR) + except Exception: + pass + try: + self.server_sock.shutdown(socket.SHUT_RDWR) + except Exception: + pass + try: + self.client_sock.close() + except Exception: + pass + try: + self.server_sock.close() + except Exception: + pass + self._recv_thread.join(timeout=5) + self._send_thread.join(timeout=5) + if self._recv_thread.is_alive() or self._send_thread.is_alive(): + logging.getLogger("SSHMITM.StreamSocket").debug("StreamSocket threads did not shut down cleanly") + + +class MITMServerInterface(paramiko.ServerInterface): + """ + Paramiko server interface that accepts all authentication. + Since clients have already authenticated through Jumpstarter's lease + system, we accept any SSH authentication method here. + """ + + def __init__(self, allowed_username: str = "", default_dut_username: str = ""): + self.allowed_username = allowed_username + self.default_dut_username = default_dut_username + self.client_username: str | None = None # Username from client connection + self.event = threading.Event() + self.exec_command: str | None = None + self.pty_width: int | None = None + self.pty_height: int | None = None + self.pty_term: str = "xterm" + + def _check_username(self, username: str | None) -> bool: + if self.allowed_username and username and username != self.allowed_username: + return False + return True + + def check_channel_request(self, kind, chanid): + if kind == "session": + return paramiko.OPEN_SUCCEEDED + return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + def check_auth_password(self, username, password): + if self._check_username(username): + self.client_username = username + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def check_auth_publickey(self, username, key): + if self._check_username(username): + self.client_username = username + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def check_auth_none(self, username): + if self._check_username(username): + self.client_username = username + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def get_allowed_auths(self, username): + return "none,password,publickey" + + def check_channel_shell_request(self, channel): + self.exec_command = None + self.event.set() + return True + + def check_channel_exec_request(self, channel, command): + self.exec_command = command.decode() if isinstance(command, bytes) else command + self.event.set() + return True + + def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, modes): + self.pty_term = term or "xterm" + self.pty_width = width + self.pty_height = height + return True + + +@dataclass(kw_only=True) +class SSHMITM(Driver): + """ + SSH MITM proxy driver with server-side key storage. + This driver acts as a network layer that provides authenticated SSH proxy + connections. It is designed to be used as a child of SSHWrapper. + """ + + default_username: str = "" + ssh_identity: str | None = None + ssh_identity_file: str | None = None + channel_timeout: float = 30.0 + default_pty_width: int = 80 + default_pty_height: int = 24 + + _host_key: Optional[paramiko.RSAKey] = field(init=False, default=None) + + def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + + if "tcp" not in self.children: + raise ConfigurationError("'tcp' child is required via ref, or directly as a TcpNetwork driver instance") + + if self.ssh_identity and self.ssh_identity_file: + raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file") + + if not self.ssh_identity and not self.ssh_identity_file: + raise ConfigurationError("Either ssh_identity or ssh_identity_file must be provided") + + # Generate ephemeral host key for MITM server + self._host_key = paramiko.RSAKey.generate(2048) + + @classmethod + def client(cls) -> str: + return "jumpstarter_driver_network.client.NetworkClient" + + def _get_ssh_identity(self) -> str | None: + """Get SSH private key content (internal use only).""" + if self.ssh_identity: + return self.ssh_identity + if self.ssh_identity_file: + try: + return Path(self.ssh_identity_file).expanduser().read_text() + except Exception as e: + raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None + return None + + def _get_target_connection(self) -> tuple[str, int]: + """Get DUT host and port from TCP child driver.""" + tcp_driver: TcpNetwork = self.children["tcp"] + return tcp_driver.host, tcp_driver.port or 22 + + def _load_private_key(self, key_data: str) -> paramiko.PKey: + """Load private key, auto-detecting type (Ed25519, RSA, ECDSA, DSS).""" + key_file = io.StringIO(key_data) + + key_classes = [ + paramiko.Ed25519Key, + paramiko.RSAKey, + paramiko.ECDSAKey, + paramiko.DSSKey, + ] + + for key_class in key_classes: + try: + key_file.seek(0) + return key_class.from_private_key(key_file) + except (paramiko.SSHException, ValueError): + continue + + raise SSHMITMError("Unable to load SSH key - unsupported key type") + + def _create_dut_client(self, dut_username: str | None = None) -> paramiko.SSHClient: + """Create paramiko SSH client connected to DUT using stored key. + + Args: + dut_username: Username to use for DUT connection. If None, uses default_username or "root". + """ + target_host, target_port = self._get_target_connection() + + ssh_identity = self._get_ssh_identity() + if not ssh_identity: + raise SSHMITMError("SSH identity not available") + + pkey = self._load_private_key(ssh_identity) + + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # Use provided username, or fall back to default_username, or "root" + username = dut_username or self.default_username or "root" + + self.logger.debug("Connecting to DUT: %s@%s:%d", username, target_host, target_port) + + try: + client.connect( + hostname=target_host, + port=target_port, + username=username, + pkey=pkey, + look_for_keys=False, + allow_agent=False, + timeout=10, + ) + except Exception as e: + self.logger.error("Failed to connect to DUT %s@%s:%d: %s", username, target_host, target_port, e) + raise + + return client + + def _proxy_channels(self, client_channel, dut_channel): + """Bidirectional proxy between client and DUT SSH channels.""" + + def forward(src, dst, name): + try: + while True: + data = src.recv(BUFFER_SIZE) + if not data: + break + dst.sendall(data) + except Exception as e: + self.logger.debug("Channel %s ended: %s", name, e) + finally: + try: + dst.close() + except Exception: + pass + + t1 = threading.Thread(target=forward, args=(client_channel, dut_channel, "client→dut"), daemon=True) + t2 = threading.Thread(target=forward, args=(dut_channel, client_channel, "dut→client"), daemon=True) + t1.start() + t2.start() + t1.join() + t2.join() + + def _open_dut_channel(self, server: MITMServerInterface) -> tuple[paramiko.SSHClient, paramiko.Channel]: + """Open appropriate DUT channel (shell or exec) based on client request.""" + dut_username = server.client_username or server.default_dut_username + dut_client = self._create_dut_client(dut_username=dut_username) + transport = dut_client.get_transport() + if transport is None: + dut_client.close() + raise SSHMITMError("Failed to open SSH transport for DUT") + + channel = transport.open_session() + + if server.exec_command: + self.logger.debug("Executing command on DUT via MITM: %s", server.exec_command) + channel.exec_command(server.exec_command) + else: + width = server.pty_width or self.default_pty_width + height = server.pty_height or self.default_pty_height + channel.get_pty(term=server.pty_term, width=width, height=height) + channel.invoke_shell() + + return dut_client, channel + + def _handle_session(self, transport: paramiko.Transport): # noqa: C901 + """Handle incoming SSH session: accept client, connect to DUT, proxy.""" + server = MITMServerInterface(self.default_username, default_dut_username=self.default_username) + + try: + transport.add_server_key(self._host_key) + transport.start_server(server=server) + except paramiko.SSHException as e: + self.logger.error("SSH negotiation failed: %s", e) + return + + client_channel = transport.accept(timeout=self.channel_timeout) + if client_channel is None: + self.logger.error("No channel opened by client") + return + + if not server.event.wait(timeout=self.channel_timeout): + self.logger.error("No exec/shell request received before timeout") + client_channel.close() + return + + dut_client: paramiko.SSHClient | None = None + dut_channel: paramiko.Channel | None = None + try: + dut_client, dut_channel = self._open_dut_channel(server) + self.logger.info( + "MITM proxy established: client <-> DUT (mode=%s)", + "exec" if server.exec_command else "shell", + ) + self._proxy_channels(client_channel, dut_channel) + + if server.exec_command: + try: + exit_status = dut_channel.recv_exit_status() + client_channel.send_exit_status(exit_status) + except Exception: + pass + finally: + client_channel.close() + + except Exception as e: + self.logger.error("Failed to connect to DUT: %s", e) + client_channel.close() + finally: + if dut_channel: + try: + dut_channel.close() + except Exception: + pass + if dut_client: + try: + dut_client.close() + except Exception: + pass + transport.close() + + @exportstream + @asynccontextmanager + async def connect(self): + """ + Stream endpoint for SSH proxy connections. + + When a client connects to this stream, we launch a paramiko-based + SSH server that proxies traffic to the DUT. From the client's + perspective this behaves like a normal SSH server. + + This is used by SSHWrapper as the 'tcp' child - SSHWrapper spawns + a local SSH binary that connects through this proxy. + """ + cancelled_exc = get_cancelled_exc_class() + client_stream, server_stream = create_memory_stream() + + async with BlockingPortal() as portal: + bridge = StreamSocket( + send_stream=server_stream, + recv_stream=server_stream, + portal=portal, + ) + bridge.start() + + transport = paramiko.Transport(bridge.get_paramiko_socket()) + server_thread = threading.Thread(target=self._handle_session, args=(transport,), daemon=True) + server_thread.start() + + try: + yield client_stream + except (cancelled_exc, Exception) as e: + if isinstance(e, cancelled_exc): + self.logger.debug("SSH stream cancelled by client") + else: + self.logger.debug("SSH stream ended: %s", type(e).__name__) + finally: + with suppress(Exception): + transport.close() + bridge.close() + server_thread.join(timeout=5) diff --git a/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver_test.py b/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver_test.py new file mode 100644 index 000000000..4b42c62f3 --- /dev/null +++ b/packages/jumpstarter-driver-ssh-mitm/jumpstarter_driver_ssh_mitm/driver_test.py @@ -0,0 +1,372 @@ +"""Tests for the SSH MITM driver""" + +import threading +from unittest.mock import MagicMock, patch + +import pytest +from jumpstarter_driver_network.driver import TcpNetwork + +from jumpstarter_driver_ssh_mitm.driver import SSHMITM, SSHMITMError + +from jumpstarter.common.exceptions import ConfigurationError + +TEST_SSH_KEY = """-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACBHK2n0Z+2k2LXuT7+0zTcSCfprKPDR+9xG7nXZ7zRy5AAAAJgq0lzTKtJc +0wAAAAtzc2gtZWQyNTUxOQAAACBHK2n0Z+2k2LXuT7+0zTcSCfprKPDR+9xG7nXZ7zRy5A +AAAEBpIq2lZeL9Ey+OQhKfhIIhK1U0rkqMjFolbvQZ8qGVnkcraeRn7aTYte5Pv7TNNxIJ ++mso8NH73EbuddnvNHLkAAAADXRlc3RAZXhhbXBsZQECAwQF +-----END OPENSSH PRIVATE KEY----- +""" + + +class TestSSHMITMDriver: + """Tests for SSHMITM driver configuration and setup""" + + def test_defaults(self): + """Test SSH MITM with default configuration""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="", + ssh_identity=TEST_SSH_KEY, + ) + + assert instance.default_username == "" + assert instance.ssh_identity == TEST_SSH_KEY + # Now returns NetworkClient since SSHMITM is a network layer + assert instance.client() == "jumpstarter_driver_network.client.NetworkClient" + + def test_configuration_error_missing_tcp(self): + """Test SSH MITM raises error when tcp child is missing""" + with pytest.raises(ConfigurationError, match="'tcp' child is required"): + SSHMITM(children={}, default_username="", ssh_identity=TEST_SSH_KEY) + + def test_configuration_error_missing_identity(self): + """Test SSH MITM raises error when identity is missing""" + with pytest.raises( + ConfigurationError, + match="Either ssh_identity or ssh_identity_file must be provided", + ): + SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="", + ) + + def test_configuration_error_both_identities(self): + """Test SSH MITM raises error when both identity options are provided""" + with pytest.raises( + ConfigurationError, + match="Cannot specify both ssh_identity and ssh_identity_file", + ): + SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="", + ssh_identity=TEST_SSH_KEY, + ssh_identity_file="/path/to/key", + ) + + def test_identity_from_inline(self): + """Test SSH identity from inline content""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY, + ) + + # Internal access should work + assert instance._get_ssh_identity() == TEST_SSH_KEY + + def test_identity_from_file(self, tmp_path): + """Test SSH identity from file""" + temp_file_path = tmp_path / "_test_key" + temp_file_path.write_text(TEST_SSH_KEY) + + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file=str(temp_file_path), + ) + + # Internal access should work + assert instance._get_ssh_identity() == TEST_SSH_KEY + + +class TestSSHMITMSecurity: + """Security-focused tests""" + + def test_key_accessible_internally(self): + """Verify key is accessible on driver (server) side""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + # Internal access works + assert instance._get_ssh_identity() == TEST_SSH_KEY + + def test_key_not_accessible_via_rpc(self): + """Verify key cannot be accessed via RPC through NetworkClient""" + from jumpstarter.common.utils import serve + + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + with serve(instance) as client: + # NetworkClient should not have access to get_ssh_identity + # The method is private and not exported + assert not hasattr(client, "get_ssh_identity") + assert not hasattr(client, "_get_ssh_identity") + + def test_uses_network_client(self): + """Verify SSHMITM uses NetworkClient (not a custom client)""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + # Should return NetworkClient since SSHMITM is a network layer + assert instance.client() == "jumpstarter_driver_network.client.NetworkClient" + + +class TestSSHMITMCleanup: + """Tests for resource cleanup""" + + def test_close_cleans_up(self): + """Test that close() cleans up resources""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY, + ) + + # Should not raise + instance.close() + + def test_identity_file_not_found(self): + """Test error handling when identity file doesn't exist""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file="/nonexistent/path/to/key", + ) + + # Calling _get_ssh_identity should raise ConfigurationError + with pytest.raises(ConfigurationError): + instance._get_ssh_identity() + + +class TestSSHMITMKeyTypes: + """Tests for SSH key type detection""" + + def test_load_ed25519_key(self): + """Test loading Ed25519 key""" + mock_pkey = MagicMock() + + with patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.Ed25519Key.from_private_key", + return_value=mock_pkey, + ): + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + key = instance._load_private_key(TEST_SSH_KEY) + assert key == mock_pkey + + def test_load_rsa_key_fallback(self): + """Test RSA key loading when Ed25519 fails""" + import paramiko + + mock_rsa_key = MagicMock() + + with ( + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.Ed25519Key.from_private_key", + side_effect=paramiko.SSHException("Not Ed25519"), + ), + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.RSAKey.from_private_key", + return_value=mock_rsa_key, + ), + ): + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + key = instance._load_private_key(TEST_SSH_KEY) + assert key == mock_rsa_key + + def test_unsupported_key_type(self): + """Test error when key type is not supported""" + import paramiko + + with ( + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.Ed25519Key.from_private_key", + side_effect=paramiko.SSHException("Not Ed25519"), + ), + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.RSAKey.from_private_key", + side_effect=paramiko.SSHException("Not RSA"), + ), + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.ECDSAKey.from_private_key", + side_effect=paramiko.SSHException("Not ECDSA"), + ), + patch( + "jumpstarter_driver_ssh_mitm.driver.paramiko.DSSKey.from_private_key", + side_effect=paramiko.SSHException("Not DSS"), + ), + ): + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + ssh_identity=TEST_SSH_KEY, + ) + + with pytest.raises(SSHMITMError, match="unsupported key type"): + instance._load_private_key(TEST_SSH_KEY) + + +class TestSSHMITMStream: + """Tests for stream/connect behavior""" + + @pytest.mark.anyio + async def test_connect_starts_session(self, monkeypatch): + """Ensure connect stream spins up the handler thread""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="tester", + ssh_identity=TEST_SSH_KEY, + ) + + started = threading.Event() + + def fake_handle_session(self, transport): + started.set() + + instance._handle_session = fake_handle_session.__get__(instance, SSHMITM) + + class DummyTransport: + def __init__(self, sock): + self.sock = sock + + def close(self): + pass + + monkeypatch.setattr("jumpstarter_driver_ssh_mitm.driver.paramiko.Transport", DummyTransport) + + async with instance.connect() as stream: + await stream.aclose() + + assert started.is_set() + + def test_handle_session_timeout(self, caplog): + """Test that _handle_session properly handles timeout when no exec/shell request is received""" + import logging + + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY, + channel_timeout=0.1, # Short timeout for testing + ) + + # Track if channel close is called + channel_close_called = [] + + # Create a mock transport that simulates a client connecting but never sending exec/shell + mock_transport = MagicMock() + mock_channel = MagicMock() + + def track_channel_close(): + channel_close_called.append(True) + + mock_channel.close = track_channel_close + mock_transport.accept.return_value = mock_channel + mock_transport.add_server_key = MagicMock() + mock_transport.start_server = MagicMock() + mock_transport.close = MagicMock() + + # Call _handle_session - the event.wait() will timeout since no exec/shell request is made + # This simulates a client that connects but never sends a command + with caplog.at_level(logging.ERROR): + instance._handle_session(mock_transport) + + # Verify timeout error was logged (line 403 in driver.py) + assert "No exec/shell request received before timeout" in caplog.text + + # Verify channel.close() was called (line 404) - this is the key timeout behavior + assert len(channel_close_called) > 0, "Channel close should have been called due to timeout" + + def test_mitm_proxy_forwards_data(self): + """Integration test: Verify MITM proxy correctly forwards data between client and DUT""" + instance = SSHMITM( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY, + ) + + # Mock data that will flow through the proxy + client_to_dut_data = b"test command\n" + dut_to_client_data = b"command output\n" + + # Mock DUT channel - simulate receiving data from DUT + mock_dut_channel = MagicMock() + mock_dut_channel.recv.side_effect = [dut_to_client_data, b""] # Return data then EOF + mock_dut_channel.sendall = MagicMock() + mock_dut_channel.close = MagicMock() + mock_dut_channel.recv_exit_status.return_value = 0 + + # Mock client channel - simulate receiving data from client + mock_client_channel = MagicMock() + mock_client_channel.recv.side_effect = [client_to_dut_data, b""] # Return data then EOF + mock_client_channel.sendall = MagicMock() + mock_client_channel.close = MagicMock() + mock_client_channel.send_exit_status = MagicMock() + + # Mock DUT client + mock_dut_client = MagicMock() + mock_dut_transport = MagicMock() + mock_dut_transport.open_session.return_value = mock_dut_channel + mock_dut_client.get_transport.return_value = mock_dut_transport + mock_dut_client.close = MagicMock() + + # Mock transport + mock_transport = MagicMock() + mock_transport.accept.return_value = mock_client_channel + mock_transport.add_server_key = MagicMock() + mock_transport.start_server = MagicMock() + mock_transport.close = MagicMock() + + # Mock _create_dut_client and _open_dut_channel + with ( + patch.object(instance, "_create_dut_client", return_value=mock_dut_client), + patch.object(instance, "_open_dut_channel", return_value=(mock_dut_client, mock_dut_channel)), + ): + # Create server interface and simulate exec request + from jumpstarter_driver_ssh_mitm.driver import MITMServerInterface + + server = MITMServerInterface(instance.default_username) + server.exec_command = "test command" + server.event.set() + + # Mock the server creation in _handle_session + with patch("jumpstarter_driver_ssh_mitm.driver.MITMServerInterface", return_value=server): + # Call _handle_session with mocked transport + instance._handle_session(mock_transport) + + assert mock_dut_channel is not None + assert mock_client_channel is not None + + # Verify exit status was forwarded for exec commands + mock_client_channel.send_exit_status.assert_called_once_with(0) + + # Verify cleanup + mock_client_channel.close.assert_called() + mock_dut_channel.close.assert_called() + mock_dut_client.close.assert_called() + mock_transport.close.assert_called() diff --git a/packages/jumpstarter-driver-ssh-mitm/pyproject.toml b/packages/jumpstarter-driver-ssh-mitm/pyproject.toml new file mode 100644 index 000000000..8d08d24a3 --- /dev/null +++ b/packages/jumpstarter-driver-ssh-mitm/pyproject.toml @@ -0,0 +1,51 @@ +[project] +name = "jumpstarter-driver-ssh-mitm" +dynamic = ["version", "urls"] +description = "SSH Man-in-the-Middle driver for Jumpstarter that securely stores SSH keys on the server and proxies connections" +readme = "README.md" +license = "Apache-2.0" +authors = [ + { name = "Bella Khizgiyaev", email = "bkhizgiy@redhat.com" } +] +requires-python = ">=3.11" +dependencies = [ + "anyio>=4.10.0", + "jumpstarter", + "jumpstarter-driver-network", + "paramiko>=3.0.0", +] + +[project.entry-points."jumpstarter.drivers"] +ssh_mitm = "jumpstarter_driver_ssh_mitm" + +[tool.hatch.version] +source = "vcs" +raw-options = { 'root' = '../../'} + +[tool.hatch.metadata.hooks.vcs.urls] +Homepage = "https://jumpstarter.dev" +source_archive = "https://github.com/jumpstarter-dev/repo/archive/{commit_hash}.zip" + +[tool.pytest.ini_options] +addopts = "--cov --cov-report=html --cov-report=xml" +log_cli = true +log_cli_level = "INFO" +testpaths = ["jumpstarter_driver_ssh_mitm"] +asyncio_mode = "auto" +log_cli_format = "%(levelname)s %(name)s: %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" + +[build-system] +requires = ["hatchling", "hatch-vcs", "hatch-pin-jumpstarter"] +build-backend = "hatchling.build" + +[tool.hatch.build.hooks.pin_jumpstarter] +name = "pin_jumpstarter" + +[dependency-groups] +dev = [ + "pytest-cov>=6.0.0", + "pytest>=8.3.3", + "trio>=0.28.0", +] + diff --git a/uv.lock b/uv.lock index 6d7d2171a..99fa450fa 100644 --- a/uv.lock +++ b/uv.lock @@ -33,6 +33,7 @@ members = [ "jumpstarter-driver-shell", "jumpstarter-driver-snmp", "jumpstarter-driver-ssh", + "jumpstarter-driver-ssh-mitm", "jumpstarter-driver-tasmota", "jumpstarter-driver-tftp", "jumpstarter-driver-tmt", @@ -2140,6 +2141,38 @@ dev = [ { name = "pytest-cov", specifier = ">=6.0.0" }, ] +[[package]] +name = "jumpstarter-driver-ssh-mitm" +source = { editable = "packages/jumpstarter-driver-ssh-mitm" } +dependencies = [ + { name = "anyio" }, + { name = "jumpstarter" }, + { name = "jumpstarter-driver-network" }, + { name = "paramiko" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "trio" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.10.0" }, + { name = "jumpstarter", editable = "packages/jumpstarter" }, + { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, + { name = "paramiko", specifier = ">=3.0.0" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.3.3" }, + { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "trio", specifier = ">=0.28.0" }, +] + [[package]] name = "jumpstarter-driver-tasmota" source = { editable = "packages/jumpstarter-driver-tasmota" }