From c6fab27e283b5a37a46de5b4071ab38f6f30e553 Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Sun, 5 Oct 2025 01:14:01 +0200 Subject: [PATCH 1/2] Test the jumpstarter-driver-ssh identity key injection (#689) (cherry picked from commit 0f1555fb4784511236df4a0942095ed987bd2dd5) --- .../jumpstarter_driver_ssh/driver_test.py | 299 ++++++++++++++++++ 1 file changed, 299 insertions(+) diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 4501828c9..0533c5a65 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -10,6 +10,13 @@ from jumpstarter.common.exceptions import ConfigurationError from jumpstarter.common.utils import serve +# Test SSH key content used in multiple tests +TEST_SSH_KEY = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "test-key-content\n" + "-----END OPENSSH PRIVATE KEY-----" +) + def test_ssh_wrapper_defaults(): """Test SSH wrapper with default configuration""" @@ -348,3 +355,295 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject assert ssh_l_index < hostname_index < command_l_index assert result == 0 + + +def test_ssh_identity_string_configuration(): + """Test SSH wrapper with ssh_identity string configuration""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + # Test that the instance was created correctly + assert instance.ssh_identity == TEST_SSH_KEY + assert instance.ssh_identity_file is None + + # Test that the client class is correct + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" + + +def test_ssh_identity_file_configuration(): + """Test SSH wrapper with ssh_identity_file configuration""" + import os + import tempfile + + # Create a temporary file with SSH key content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: + temp_file.write(TEST_SSH_KEY) + temp_file_path = temp_file.name + + try: + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file=temp_file_path + ) + + # Test that the instance was created correctly + assert instance.ssh_identity == TEST_SSH_KEY + assert instance.ssh_identity_file == temp_file_path + + # Test that the client class is correct + assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" + finally: + # Clean up the temporary file + os.unlink(temp_file_path) + + +def test_ssh_identity_validation_error(): + """Test SSH wrapper raises error when both ssh_identity and ssh_identity_file are provided""" + with pytest.raises(ConfigurationError, match="Cannot specify both ssh_identity and ssh_identity_file"): + SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity="test-key-content", + ssh_identity_file="/path/to/key" + ) + + +def test_ssh_identity_file_read_error(): + """Test SSH wrapper raises error when ssh_identity_file cannot be read""" + with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"): + SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file="/nonexistent/path/to/key" + ) + + +def test_ssh_command_with_identity_string(): + """Test SSH command execution with ssh_identity string""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command with identity string + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should include -i flag with temporary identity file + assert "-i" in call_args + identity_file_index = call_args.index("-i") + identity_file_path = call_args[identity_file_index + 1] + + # The identity file should be a temporary file + assert identity_file_path.endswith("_ssh_key") + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + + +def test_ssh_command_with_identity_file(): + """Test SSH command execution with ssh_identity_file""" + import os + import tempfile + + # Create a temporary file with SSH key content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file: + temp_file.write(TEST_SSH_KEY) + temp_file_path = temp_file.name + + try: + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file=temp_file_path + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command with identity file + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should include -i flag with temporary identity file + assert "-i" in call_args + identity_file_index = call_args.index("-i") + identity_file_path = call_args[identity_file_index + 1] + + # The identity file should be a temporary file (not the original file) + assert identity_file_path.endswith("_ssh_key") + assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path + assert identity_file_path != temp_file_path + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + finally: + # Clean up the temporary file + os.unlink(temp_file_path) + + +def test_ssh_command_without_identity(): + """Test SSH command execution without identity (should not include -i flag)""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser" + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Test SSH command without identity + result = client.run(False, ["hostname"]) + + # Verify subprocess.run was called + assert mock_run.called + call_args = mock_run.call_args[0][0] # First positional argument + + # Should NOT include -i flag + assert "-i" not in call_args + + # Should include -l testuser + assert "-l" in call_args + assert "testuser" in call_args + + # Should include the actual hostname (127.0.0.1) at the end + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result == 0 + + +def test_ssh_identity_temp_file_creation_and_cleanup(): + """Test that temporary identity file is created and cleaned up properly""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + with patch('os.chmod') as mock_chmod: + with patch('os.unlink') as mock_unlink: + # Mock the temporary file + mock_temp_file_instance = MagicMock() + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" + mock_temp_file_instance.write = MagicMock() + mock_temp_file_instance.close = MagicMock() + mock_temp_file.return_value = mock_temp_file_instance + + # Test SSH command with identity + result = client.run(False, ["hostname"]) + + # Verify temporary file was created + mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key') + mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY) + mock_temp_file_instance.close.assert_called_once() + + # Verify proper permissions were set + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + + # Verify temporary file was cleaned up + mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345") + + assert result == 0 + + +def test_ssh_identity_temp_file_creation_error(): + """Test error handling when temporary identity file creation fails""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + mock_temp_file.side_effect = OSError("Permission denied") + + # Test SSH command with identity should raise an error + # The exception will be wrapped in an ExceptionGroup due to the context manager + with pytest.raises(ExceptionGroup) as exc_info: + client.run(False, ["hostname"]) + + # Check that the original OSError is in the exception group + assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions) + + +def test_ssh_identity_temp_file_cleanup_error(): + """Test error handling when temporary identity file cleanup fails""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity=TEST_SSH_KEY + ) + + with serve(instance) as client: + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + with patch('tempfile.NamedTemporaryFile') as mock_temp_file: + with patch('os.chmod') as mock_chmod: + with patch('os.unlink') as mock_unlink: + # Mock the temporary file + mock_temp_file_instance = MagicMock() + mock_temp_file_instance.name = "/tmp/test_ssh_key_12345" + mock_temp_file_instance.write = MagicMock() + mock_temp_file_instance.close = MagicMock() + mock_temp_file.return_value = mock_temp_file_instance + + # Mock cleanup failure + mock_unlink.side_effect = OSError("Permission denied") + + # Test SSH command with identity - should still succeed but log warning + with patch.object(client, 'logger') as mock_logger: + result = client.run(False, ["hostname"]) + + # Verify chmod was called + mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600) + + # Verify warning was logged + mock_logger.warning.assert_called_once() + warning_call = mock_logger.warning.call_args[0][0] + assert "Failed to clean up temporary identity file" in warning_call + assert "/tmp/test_ssh_key_12345" in warning_call + + assert result == 0 From 228a4e8d0c1d32d9e93100c015a976dabde524c6 Mon Sep 17 00:00:00 2001 From: Michal Skrivanek Date: Mon, 20 Oct 2025 10:33:05 +0200 Subject: [PATCH 2/2] load ssh identity lazily on first use, to prevent a failure on startup if the file doesn't exist yet (cherry picked from commit af481163563459b69089b230a27a59de93847df0) --- .../jumpstarter_driver_ssh/driver.py | 13 +++++----- .../jumpstarter_driver_ssh/driver_test.py | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py index ec5597ca9..2a0fda411 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py @@ -24,13 +24,6 @@ def __post_init__(self): if self.ssh_identity and self.ssh_identity_file: raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file") - # If ssh_identity_file is provided, read it into ssh_identity - if self.ssh_identity_file: - try: - self.ssh_identity = Path(self.ssh_identity_file).read_text() - except Exception as e: - raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None - @classmethod def client(cls) -> str: return "jumpstarter_driver_ssh.client.SSHWrapperClient" @@ -48,4 +41,10 @@ def get_ssh_command(self): @export def get_ssh_identity(self): """Get the SSH identity key content""" + # If ssh_identity_file is provided, read it lazily and cache in ssh_identity + if self.ssh_identity is None and self.ssh_identity_file: + try: + self.ssh_identity = Path(self.ssh_identity_file).read_text() + except Exception as e: + raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None return self.ssh_identity diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 0533c5a65..02ea1ba1a 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -391,9 +391,17 @@ def test_ssh_identity_file_configuration(): ) # Test that the instance was created correctly - assert instance.ssh_identity == TEST_SSH_KEY + # ssh_identity should be None until first use (lazy loading) + assert instance.ssh_identity is None assert instance.ssh_identity_file == temp_file_path + # Test that get_ssh_identity() reads the file on first use + identity = instance.get_ssh_identity() + assert identity == TEST_SSH_KEY + + # Test that ssh_identity is now cached + assert instance.ssh_identity == TEST_SSH_KEY + # Test that the client class is correct assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient" finally: @@ -413,13 +421,17 @@ def test_ssh_identity_validation_error(): def test_ssh_identity_file_read_error(): - """Test SSH wrapper raises error when ssh_identity_file cannot be read""" + """Test SSH wrapper raises error when ssh_identity_file cannot be read on first use""" + # Instance creation should succeed (lazy loading) + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ssh_identity_file="/nonexistent/path/to/key" + ) + + # Error should be raised when get_ssh_identity() is called with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"): - SSHWrapper( - children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, - default_username="testuser", - ssh_identity_file="/nonexistent/path/to/key" - ) + instance.get_ssh_identity() def test_ssh_command_with_identity_string():