diff --git a/trollmoves/movers.py b/trollmoves/movers.py index e3219e52..bd3e0d9c 100644 --- a/trollmoves/movers.py +++ b/trollmoves/movers.py @@ -358,6 +358,7 @@ def move(self): def copy(self): """Upload the file.""" from scp import SCPClient + from paramiko import SSHException ssh_connection = self.get_connection(self.destination.hostname, self.destination.port or 22, @@ -371,7 +372,23 @@ def copy(self): raise try: - scp.put(self.origin, self.destination.path) + destination = self.destination.path + remote_tmp = self.attrs.get("remote_tmp", None) + if remote_tmp: + destination = os.path.join(destination, '.' + os.path.basename(self.origin)) + scp.put(self.origin, destination) + + if remote_tmp: + timeout = self.attrs.get("ssh_connection_timeout", None) + _remote_orig = os.path.join(self.destination.path, os.path.basename(self.origin)) + _cmd = f"mv {destination} {_remote_orig}" + (_, out_ret, err_ret) = ssh_connection.exec_command(_cmd, timeout=timeout) + out_lines = out_ret.readlines() + for line in out_lines: + LOGGER.debug("Remote rename stdout: %s ", str(line)) + err_lines = err_ret.readlines() + for line in err_lines: + LOGGER.error("Remote rename stderr: %s ", str(line)) except OSError as osex: if osex.errno == 2: LOGGER.error("No such file or directory. File not transfered: " @@ -380,6 +397,8 @@ def copy(self): else: LOGGER.error("OSError in scp.put: %s", str(osex)) raise + except SSHException as sshe: + LOGGER.exception("Failed to rename from tmp name: %s", str(sshe)) except Exception as err: LOGGER.error("Something went wrong with scp: %s", str(err)) LOGGER.error("Exception name %s", type(err).__name__) diff --git a/trollmoves/tests/test_ssh_server.py b/trollmoves/tests/test_ssh_server.py index 40654628..cdceda7d 100644 --- a/trollmoves/tests/test_ssh_server.py +++ b/trollmoves/tests/test_ssh_server.py @@ -21,6 +21,9 @@ # along with this program. If not, see . """Test the ssh server.""" +import os +import sys +import logging import shutil from unittest.mock import Mock, MagicMock, patch import unittest @@ -30,14 +33,24 @@ from paramiko import SSHException import pytest -import logging import socket -import sys import trollmoves logger = logging.getLogger() + +class MockChannel: + def __init__(self, content=None): + self.content = [] if not content else [content] + + def __str__(self) -> str: + return str(self.content) + + def readlines(self): + return self.content + + class TestSSHMovers(unittest.TestCase): """Tests for SSH Mover.""" @@ -280,6 +293,78 @@ def test_scp_move(self, mock_scp_client, mock_sshclient): mocked_scp_client.put.assert_called_once_with(self.origin, urlparse(self.destination_no_port).path) + @patch('trollmoves.movers.ScpMover.get_connection') + @patch('paramiko.SSHClient.connect') + @patch('scp.SCPClient', autospec=True) + def test_scp_copy_via_remote_tmp2(self, mock_scp_client, mock_sshconnect, mock_sshexec): + """Check scp copy using remote temporary file.""" + from trollmoves.movers import ScpMover + + mocked_scp_client = MagicMock() + mock_scp_client.return_value = mocked_scp_client + mock_sshexec.return_value.exec_command.return_value = [(None), (MockChannel()), (MockChannel())] + scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True}) + scp_mover.copy() + + tmp_bn = os.path.join(urlparse(self.destination_no_port).path, + "." + os.path.basename(self.origin)) + mocked_scp_client.put.assert_called_once_with(self.origin, tmp_bn) + final_remote = os.path.join(urlparse(self.destination_no_port).path, + os.path.basename(self.origin)) + _cmd = f"mv {tmp_bn} {final_remote}" + mock_sshexec.return_value.exec_command.assert_called_once_with(_cmd, timeout=None) + + @patch('trollmoves.movers.ScpMover.get_connection') + @patch('paramiko.SSHClient.connect') + @patch('scp.SCPClient', autospec=True) + def test_scp_copy_via_remote_tmp_return_values(self, mock_scp_client, mock_sshconnect, mock_sshexec): + """Check scp copy using remote temporary file.""" + from trollmoves.movers import ScpMover + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) + + mocked_scp_client = MagicMock() + mock_scp_client.return_value = mocked_scp_client + mock_sshexec.return_value.exec_command.return_value = [(None), (MockChannel("stdout")), (MockChannel("stderr"))] + try: + with self.assertLogs(logger, level=logging.DEBUG) as lc: + scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True}) + scp_mover.copy() + self.assertIn("Remote rename stdout: stdout", "".join(lc.output)) + self.assertIn("Remote rename stderr: stderr", "".join(lc.output)) + finally: + logger.removeHandler(stream_handler) + + tmp_bn = os.path.join(urlparse(self.destination_no_port).path, + "." + os.path.basename(self.origin)) + mocked_scp_client.put.assert_called_once_with(self.origin, tmp_bn) + final_remote = os.path.join(urlparse(self.destination_no_port).path, + os.path.basename(self.origin)) + _cmd = f"mv {tmp_bn} {final_remote}" + mock_sshexec.return_value.exec_command.assert_called_once_with(_cmd, timeout=None) + + @patch('trollmoves.movers.ScpMover.get_connection') + @patch('paramiko.SSHClient.connect') + @patch('scp.SCPClient', autospec=True) + def test_scp_copy_via_remote_tmp_exception(self, mock_scp_client, mock_sshconnect, mock_sshexec): + """Check scp copy using remote temporary file.""" + from trollmoves.movers import ScpMover + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + logger.setLevel(logging.INFO) + + mocked_scp_client = MagicMock() + mock_scp_client.return_value = mocked_scp_client + mock_sshexec.return_value.exec_command.side_effect = MagicMock(side_effect=SSHException) + try: + with self.assertLogs(logger, level=logging.DEBUG) as lc: + scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True}) + scp_mover.copy() + self.assertIn("Failed to rename from tmp name:", "".join(lc.output)) + finally: + logger.removeHandler(stream_handler) + if __name__ == '__main__': unittest.main()