diff --git a/craft_application/services/request.py b/craft_application/services/request.py index ced299e2d..23431039c 100644 --- a/craft_application/services/request.py +++ b/craft_application/services/request.py @@ -42,6 +42,7 @@ def __init__( super().__init__(app, services) self._session = requests.Session() self._session.headers["User-Agent"] = f"{self._app.name}/{self._app.version}" + self._max_retries = 3 # Passthroughs for requests methods so other services can use the session. self.request = self._session.request @@ -65,12 +66,35 @@ def download_chunks(self, url: str, dest: pathlib.Path) -> Iterator[int]: filename = util.get_filename_from_url_path(url) dest = dest / filename - with self.get(url, stream=True) as download: - with dest.open("wb") as file: - yield int(download.headers.get("Content-Length", -1)) - for chunk in download.iter_content(None): - file.write(chunk) - yield len(chunk) + content_length_yielded = False + + for attempt in range(self._max_retries): + downloaded_bytes = 0 + try: + with self.get(url, stream=True) as download: + with dest.open("wb") as file: + if not content_length_yielded: + content_length = int( + download.headers.get("Content-Length", -1) + ) + yield content_length + content_length_yielded = True + + # Download and track chunks + for chunk in download.iter_content(None): + file.write(chunk) + downloaded_bytes += len(chunk) + yield len(chunk) + break + except requests.exceptions.ChunkedEncodingError: + if attempt < self._max_retries - 1: + craft_cli.emit.progress( + f"Download interrupted, retrying... (attempt {attempt + 1}/{self._max_retries})" + ) + # Yield negative sum of chunk sizes to indicate rollback + yield -downloaded_bytes + else: + raise def download_with_progress(self, url: str, dest: pathlib.Path) -> pathlib.Path: """Download a single file with a progress bar.""" diff --git a/tests/unit/services/test_request.py b/tests/unit/services/test_request.py index 5b241f30a..efa4ffb6d 100644 --- a/tests/unit/services/test_request.py +++ b/tests/unit/services/test_request.py @@ -15,11 +15,12 @@ # along with this program. If not, see . """Unit tests for the Request service.""" -from unittest.mock import call +from unittest.mock import call, patch import craft_cli.pytest_plugin import pytest import pytest_check +import requests import responses from hypothesis import HealthCheck, given, settings, strategies @@ -114,3 +115,110 @@ def test_download_files_with_progress(tmp_path, emitter, request_service, downlo for url, path in results.items(): assert path.read_bytes() == downloads[url] + + +def failing_iter_content(chunk_size=None): # pylint: disable=unused-argument + """Simulate a ChunkedEncodingError during iter_content().""" + # Yield some data first to simulate partial download + yield b"partial" + # Then raise the error + raise requests.exceptions.ChunkedEncodingError( + "Connection broken: Invalid chunk encoding" + ) + + +@responses.activate +def test_download_chunks_with_chunked_encoding_error_retry( + tmp_path, emitter, request_service +): + """Test that download_chunks retries on ChunkedEncodingError and eventually succeeds. + + This test simulates a ChunkedEncodingError occurring during download.iter_content(), + verifies that the download is retried, and ensures the final download completes + successfully with the correct data and chunk count (not counting failed attempts). + """ + data = b"This is test data for download retry" + output_file = tmp_path / "file" + + # Patch the get method to simulate ChunkedEncodingError on first attempts + original_get = request_service.get + call_count = {"count": 0} + + # Set up the mock response + responses.add( + responses.GET, + "http://example/file", + body=data, + headers={"Content-Length": str(len(data))}, + ) + + def patched_get(*args, **kwargs): + call_count["count"] += 1 + response = original_get(*args, **kwargs) + + # Make iter_content raise ChunkedEncodingError on first two attempts + if call_count["count"] <= 2: + response.iter_content = failing_iter_content + + return response + + with patch.object(request_service, "get", side_effect=patched_get): + downloader = request_service.download_chunks("http://example/file", output_file) + size = next(downloader) + dl_size = sum(downloader) + + # Verify that the download eventually succeeded + pytest_check.equal(int(size), len(data), "Downloaded size is incorrect") + pytest_check.equal(dl_size, len(data), "Downloaded size is incorrect") + pytest_check.equal(output_file.read_bytes(), data, "Download data is incorrect") + + # Verify that retry messages were emitted + progress_calls = [ + interaction + for interaction in emitter.interactions + if len(interaction.args) > 0 + and interaction.args[0] == "progress" + and len(interaction.args) > 1 + and "retrying" in interaction.args[1] + ] + pytest_check.equal(len(progress_calls), 2, "Expected 2 retry progress messages") + + # Verify that we made 3 attempts (2 failures + 1 success) + pytest_check.equal(call_count["count"], 3, "Expected 3 download attempts") + + +@responses.activate +def test_download_chunks_chunked_encoding_error_exhausted(tmp_path, request_service): + """Test that download_chunks raises ChunkedEncodingError after max retries. + + This test simulates a persistent ChunkedEncodingError that occurs on every + download attempt, and verifies that after exhausting all retry attempts, + the error is properly raised to the caller. + """ + data = b"test data" + output_file = tmp_path / "file" + + # Set up the mock response + responses.add( + responses.GET, + "http://example/file", + body=data, + headers={"Content-Length": str(len(data))}, + ) + + # Patch to always fail + original_get = request_service.get + + def patched_get(*args, **kwargs): + response = original_get(*args, **kwargs) + + response.iter_content = failing_iter_content + return response + + with patch.object(request_service, "get", side_effect=patched_get): + downloader = request_service.download_chunks("http://example/file", output_file) + next(downloader) # Get the size + + # The error should be raised after max_retries attempts + with pytest.raises(requests.exceptions.ChunkedEncodingError): + list(downloader) # Consume the iterator