diff --git a/app/routers/classify_router.py b/app/routers/classify_router.py index bd18895..e1caff2 100644 --- a/app/routers/classify_router.py +++ b/app/routers/classify_router.py @@ -5,7 +5,7 @@ import asyncio from pydantic import BaseModel, Field from app.models.model_handler import ModelHandler -from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response +from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, validate_decodable from fastapi.concurrency import run_in_threadpool logger = logging.getLogger(__name__) @@ -75,6 +75,9 @@ async def classify_image( # Resolve image bytes from URI (URL, data URI, or local path) try: image_bytes = await resolve_image_uri(classify_request.image_uri) + # Reject undecodable/corrupt images as a 4xx (client error) so + # callers treat them as a permanent failure, not a retryable 5xx. + validate_decodable(image_bytes) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/app/routers/extract_router.py b/app/routers/extract_router.py index 3f1842a..7e50eb4 100644 --- a/app/routers/extract_router.py +++ b/app/routers/extract_router.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, model_validator from app.models.model_handler import ModelHandler from app.models.miewid import MiewidModel -from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response +from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, validate_decodable from fastapi.concurrency import run_in_threadpool logger = logging.getLogger(__name__) @@ -95,6 +95,9 @@ async def extract_embeddings( # Resolve image bytes from URI (URL, data URI, or local path) try: image_bytes = await resolve_image_uri(extract_request.image_uri) + # Reject undecodable/corrupt images as a 4xx (client error) so + # callers treat them as a permanent failure, not a retryable 5xx. + validate_decodable(image_bytes) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/app/routers/pipeline_router.py b/app/routers/pipeline_router.py index b2d294b..be42d51 100644 --- a/app/routers/pipeline_router.py +++ b/app/routers/pipeline_router.py @@ -9,7 +9,7 @@ from app.models.densenet_classifier import DenseNetClassifierModel from app.models.miewid import MiewidModel from app.models.densenet_orientation import DenseNetOrientationModel -from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, sanitize_uri_for_logging +from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, sanitize_uri_for_logging, validate_decodable from fastapi.concurrency import run_in_threadpool logger = logging.getLogger(__name__) @@ -125,6 +125,9 @@ async def run_pipeline( # Resolve image bytes from URI (URL, data URI, or local path) try: image_bytes = await resolve_image_uri(pipeline_request.image_uri) + # Reject undecodable/corrupt images as a 4xx (client error) so + # callers treat them as a permanent failure, not a retryable 5xx. + validate_decodable(image_bytes) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/app/routers/predict_router.py b/app/routers/predict_router.py index 8d0d35b..d896b19 100755 --- a/app/routers/predict_router.py +++ b/app/routers/predict_router.py @@ -7,7 +7,7 @@ import json import os from app.models.model_handler import ModelHandler -from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_logging +from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_logging, validate_decodable from fastapi.concurrency import run_in_threadpool logger = logging.getLogger(__name__) @@ -88,6 +88,9 @@ async def predict( # Resolve image bytes from URI (URL, data URI, or local path) try: image_bytes = await resolve_image_uri(prediction.image_uri) + # Reject undecodable/corrupt images as a 4xx (client error) so + # callers treat them as a permanent failure, not a retryable 5xx. + validate_decodable(image_bytes) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/app/utils/image_uri.py b/app/utils/image_uri.py index b1761ad..2cca778 100644 --- a/app/utils/image_uri.py +++ b/app/utils/image_uri.py @@ -1,10 +1,24 @@ """Utilities for resolving image URIs to bytes.""" import base64 +from io import BytesIO from pathlib import Path from typing import Tuple import httpx +from PIL import Image, UnidentifiedImageError + + +class ImageDecodeError(ValueError): + """Raised when image bytes cannot be decoded into a usable image. + + Subclasses ValueError so callers that already map ValueError from image + resolution to an HTTP 400 treat an undecodable image the same way — a + client/input error, not a server (5xx) error. This matters because + consumers (e.g. Wildbook) retry 5xx responses as transient, but a corrupt + image is a permanent failure: it must be reported as a 4xx so it is marked + terminal rather than retried indefinitely. + """ def is_data_uri(uri: str) -> bool: @@ -59,3 +73,29 @@ async def resolve_image_uri(uri: str) -> bytes: raise ValueError(f"File not found: {uri}") with open(file_path, "rb") as f: return f.read() + + +def validate_decodable(image_bytes: bytes) -> None: + """Confirm image_bytes decode into a usable image, else raise ImageDecodeError. + + A header-only check (Image.verify) is insufficient: corrupt JPEGs often have + a valid header but a broken entropy-coded scan stream that only fails during + a full pixel load (e.g. "broken data stream when reading image file" / + "Unsupported marker type 0xNN"). We therefore fully load() the image. + + Catches: + - UnidentifiedImageError: not a recognizable image at all. + - OSError: broken/truncated scan stream surfaced during load(). + - Image.DecompressionBombError: pathologically large image rejected by + Pillow's bomb guard. It is not an OSError, so it must be listed + explicitly or it would escape to the routers' generic 500 handler. + Like the others it is a permanent, non-retryable bad input. + + Raises: + ImageDecodeError (a ValueError): if the bytes cannot be decoded. + """ + try: + img = Image.open(BytesIO(image_bytes)) + img.load() + except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as e: + raise ImageDecodeError(f"unprocessable image: cannot decode ({e})") diff --git a/tests/test_image_decode_validation.py b/tests/test_image_decode_validation.py new file mode 100644 index 0000000..72826c6 --- /dev/null +++ b/tests/test_image_decode_validation.py @@ -0,0 +1,86 @@ +"""Unit tests for app.utils.image_uri.validate_decodable. + +Pure-PIL tests (no model handler / GPU), so they run anywhere. They verify +that a corrupt image is rejected with ImageDecodeError (a ValueError) — which +the inference routers map to HTTP 400, so consumers treat a corrupt image as a +permanent (non-retryable) client error rather than a retryable 5xx. +""" + +import io + +import pytest +from PIL import Image + +from app.utils.image_uri import ImageDecodeError, validate_decodable + + +def _valid_jpeg_bytes() -> bytes: + img = Image.new("RGB", (64, 64), (120, 120, 120)) + buf = io.BytesIO() + img.save(buf, "jpeg", quality=90) + return buf.getvalue() + + +def _corrupt_jpeg_bytes() -> bytes: + """A JPEG with a valid header but a broken entropy-coded scan stream. + + Splicing 0xFF 0x99 (an unsupported marker) into the scan data reproduces the + real-world failure class ("Unsupported marker type 0x99" / + "broken data stream when reading image file") that fails during full load(), + not at header parse. + """ + data = bytearray(_valid_jpeg_bytes()) + at = max(20, len(data) - 40) + for i in range(at, min(at + 32, len(data) - 1), 2): + data[i] = 0xFF + data[i + 1] = 0x99 + return bytes(data) + + +def test_valid_image_passes(): + # Should not raise. + validate_decodable(_valid_jpeg_bytes()) + + +def test_corrupt_image_raises_image_decode_error(): + with pytest.raises(ImageDecodeError): + validate_decodable(_corrupt_jpeg_bytes()) + + +def test_image_decode_error_is_value_error(): + # Routers catch ValueError -> HTTP 400; ImageDecodeError must subclass it + # so an undecodable image is reported as a 4xx, not an unhandled 5xx. + assert issubclass(ImageDecodeError, ValueError) + + +def test_non_image_bytes_raise_image_decode_error(): + with pytest.raises(ImageDecodeError): + validate_decodable(b"this is definitely not an image") + + +def test_empty_bytes_raise_image_decode_error(): + with pytest.raises(ImageDecodeError): + validate_decodable(b"") + + +def test_truncated_jpeg_raises_image_decode_error(): + # Keep only the first 200 bytes (header + partial scan) — load() must fail. + truncated = _valid_jpeg_bytes()[:200] + with pytest.raises(ImageDecodeError): + validate_decodable(truncated) + + +def test_decompression_bomb_raises_image_decode_error(): + # A decompression bomb raises PIL's DecompressionBombError, which is NOT an + # OSError; verify it is still mapped to ImageDecodeError. Force the guard by + # lowering Pillow's pixel limit so an ordinary image trips it. + from PIL import Image + + valid = _valid_jpeg_bytes() # 64x64 = 4096 px + saved = Image.MAX_IMAGE_PIXELS + try: + Image.MAX_IMAGE_PIXELS = 1 # 2*1 < 4096 -> DecompressionBombError on load + with pytest.raises(ImageDecodeError): + validate_decodable(valid) + finally: + Image.MAX_IMAGE_PIXELS = saved