Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@ def __post_init__(self):
if self.ssh_identity and self.ssh_identity_file:
raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file")

# If ssh_identity_file is provided, read it into ssh_identity
if self.ssh_identity_file:
try:
self.ssh_identity = Path(self.ssh_identity_file).read_text()
except Exception as e:
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None

@classmethod
def client(cls) -> str:
return "jumpstarter_driver_ssh.client.SSHWrapperClient"
Expand All @@ -48,4 +41,10 @@ def get_ssh_command(self):
@export
def get_ssh_identity(self):
"""Get the SSH identity key content"""
# If ssh_identity_file is provided, read it lazily and cache in ssh_identity
if self.ssh_identity is None and self.ssh_identity_file:
try:
self.ssh_identity = Path(self.ssh_identity_file).read_text()
except Exception as e:
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None
return self.ssh_identity
311 changes: 311 additions & 0 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from jumpstarter.common.exceptions import ConfigurationError
from jumpstarter.common.utils import serve

# Test SSH key content used in multiple tests
TEST_SSH_KEY = (
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
"test-key-content\n"
"-----END OPENSSH PRIVATE KEY-----"
)


def test_ssh_wrapper_defaults():
"""Test SSH wrapper with default configuration"""
Expand Down Expand Up @@ -348,3 +355,307 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject
assert ssh_l_index < hostname_index < command_l_index

assert result == 0


def test_ssh_identity_string_configuration():
"""Test SSH wrapper with ssh_identity string configuration"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY
)

# Test that the instance was created correctly
assert instance.ssh_identity == TEST_SSH_KEY
assert instance.ssh_identity_file is None

# Test that the client class is correct
assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient"


def test_ssh_identity_file_configuration():
"""Test SSH wrapper with ssh_identity_file configuration"""
import os
import tempfile

# Create a temporary file with SSH key content
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file:
temp_file.write(TEST_SSH_KEY)
temp_file_path = temp_file.name

try:
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity_file=temp_file_path
)

# Test that the instance was created correctly
# ssh_identity should be None until first use (lazy loading)
assert instance.ssh_identity is None
assert instance.ssh_identity_file == temp_file_path

# Test that get_ssh_identity() reads the file on first use
identity = instance.get_ssh_identity()
assert identity == TEST_SSH_KEY

# Test that ssh_identity is now cached
assert instance.ssh_identity == TEST_SSH_KEY

# Test that the client class is correct
assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient"
finally:
# Clean up the temporary file
os.unlink(temp_file_path)


def test_ssh_identity_validation_error():
"""Test SSH wrapper raises error when both ssh_identity and ssh_identity_file are provided"""
with pytest.raises(ConfigurationError, match="Cannot specify both ssh_identity and ssh_identity_file"):
SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity="test-key-content",
ssh_identity_file="/path/to/key"
)


def test_ssh_identity_file_read_error():
"""Test SSH wrapper raises error when ssh_identity_file cannot be read on first use"""
# Instance creation should succeed (lazy loading)
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity_file="/nonexistent/path/to/key"
)

# Error should be raised when get_ssh_identity() is called
with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"):
instance.get_ssh_identity()


def test_ssh_command_with_identity_string():
"""Test SSH command execution with ssh_identity string"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

# Test SSH command with identity string
result = client.run(False, ["hostname"])

# Verify subprocess.run was called
assert mock_run.called
call_args = mock_run.call_args[0][0] # First positional argument

# Should include -i flag with temporary identity file
assert "-i" in call_args
identity_file_index = call_args.index("-i")
identity_file_path = call_args[identity_file_index + 1]

# The identity file should be a temporary file
assert identity_file_path.endswith("_ssh_key")
assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path

# Should include -l testuser
assert "-l" in call_args
assert "testuser" in call_args

# Should include the actual hostname (127.0.0.1) at the end
assert "127.0.0.1" in call_args
assert "hostname" in call_args

assert result == 0


def test_ssh_command_with_identity_file():
"""Test SSH command execution with ssh_identity_file"""
import os
import tempfile

# Create a temporary file with SSH key content
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file:
temp_file.write(TEST_SSH_KEY)
temp_file_path = temp_file.name

try:
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity_file=temp_file_path
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

# Test SSH command with identity file
result = client.run(False, ["hostname"])

# Verify subprocess.run was called
assert mock_run.called
call_args = mock_run.call_args[0][0] # First positional argument

# Should include -i flag with temporary identity file
assert "-i" in call_args
identity_file_index = call_args.index("-i")
identity_file_path = call_args[identity_file_index + 1]

# The identity file should be a temporary file (not the original file)
assert identity_file_path.endswith("_ssh_key")
assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path
assert identity_file_path != temp_file_path

# Should include -l testuser
assert "-l" in call_args
assert "testuser" in call_args

# Should include the actual hostname (127.0.0.1) at the end
assert "127.0.0.1" in call_args
assert "hostname" in call_args

assert result == 0
finally:
# Clean up the temporary file
os.unlink(temp_file_path)


def test_ssh_command_without_identity():
"""Test SSH command execution without identity (should not include -i flag)"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser"
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

# Test SSH command without identity
result = client.run(False, ["hostname"])

# Verify subprocess.run was called
assert mock_run.called
call_args = mock_run.call_args[0][0] # First positional argument

# Should NOT include -i flag
assert "-i" not in call_args

# Should include -l testuser
assert "-l" in call_args
assert "testuser" in call_args

# Should include the actual hostname (127.0.0.1) at the end
assert "127.0.0.1" in call_args
assert "hostname" in call_args

assert result == 0


def test_ssh_identity_temp_file_creation_and_cleanup():
"""Test that temporary identity file is created and cleaned up properly"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
with patch('os.chmod') as mock_chmod:
with patch('os.unlink') as mock_unlink:
# Mock the temporary file
mock_temp_file_instance = MagicMock()
mock_temp_file_instance.name = "/tmp/test_ssh_key_12345"
mock_temp_file_instance.write = MagicMock()
mock_temp_file_instance.close = MagicMock()
mock_temp_file.return_value = mock_temp_file_instance

# Test SSH command with identity
result = client.run(False, ["hostname"])

# Verify temporary file was created
mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key')
mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY)
mock_temp_file_instance.close.assert_called_once()

# Verify proper permissions were set
mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600)

# Verify temporary file was cleaned up
mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345")

assert result == 0


def test_ssh_identity_temp_file_creation_error():
"""Test error handling when temporary identity file creation fails"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
mock_temp_file.side_effect = OSError("Permission denied")

# Test SSH command with identity should raise an error
# The exception will be wrapped in an ExceptionGroup due to the context manager
with pytest.raises(ExceptionGroup) as exc_info:
client.run(False, ["hostname"])

# Check that the original OSError is in the exception group
assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions)


def test_ssh_identity_temp_file_cleanup_error():
"""Test error handling when temporary identity file cleanup fails"""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
ssh_identity=TEST_SSH_KEY
)

with serve(instance) as client:
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)

with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
with patch('os.chmod') as mock_chmod:
with patch('os.unlink') as mock_unlink:
# Mock the temporary file
mock_temp_file_instance = MagicMock()
mock_temp_file_instance.name = "/tmp/test_ssh_key_12345"
mock_temp_file_instance.write = MagicMock()
mock_temp_file_instance.close = MagicMock()
mock_temp_file.return_value = mock_temp_file_instance

# Mock cleanup failure
mock_unlink.side_effect = OSError("Permission denied")

# Test SSH command with identity - should still succeed but log warning
with patch.object(client, 'logger') as mock_logger:
result = client.run(False, ["hostname"])

# Verify chmod was called
mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600)

# Verify warning was logged
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
assert "Failed to clean up temporary identity file" in warning_call
assert "/tmp/test_ssh_key_12345" in warning_call

assert result == 0
Loading