Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion app/routers/classify_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
from pydantic import BaseModel, Field
from app.models.model_handler import ModelHandler
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, validate_decodable
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,6 +75,9 @@ async def classify_image(
# Resolve image bytes from URI (URL, data URI, or local path)
try:
image_bytes = await resolve_image_uri(classify_request.image_uri)
# Reject undecodable/corrupt images as a 4xx (client error) so
# callers treat them as a permanent failure, not a retryable 5xx.
validate_decodable(image_bytes)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
5 changes: 4 additions & 1 deletion app/routers/extract_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel, Field, model_validator
from app.models.model_handler import ModelHandler
from app.models.miewid import MiewidModel
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, validate_decodable
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -95,6 +95,9 @@ async def extract_embeddings(
# Resolve image bytes from URI (URL, data URI, or local path)
try:
image_bytes = await resolve_image_uri(extract_request.image_uri)
# Reject undecodable/corrupt images as a 4xx (client error) so
# callers treat them as a permanent failure, not a retryable 5xx.
validate_decodable(image_bytes)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
5 changes: 4 additions & 1 deletion app/routers/pipeline_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.models.densenet_classifier import DenseNetClassifierModel
from app.models.miewid import MiewidModel
from app.models.densenet_orientation import DenseNetOrientationModel
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, sanitize_uri_for_logging
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_response, sanitize_uri_for_logging, validate_decodable
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -125,6 +125,9 @@ async def run_pipeline(
# Resolve image bytes from URI (URL, data URI, or local path)
try:
image_bytes = await resolve_image_uri(pipeline_request.image_uri)
# Reject undecodable/corrupt images as a 4xx (client error) so
# callers treat them as a permanent failure, not a retryable 5xx.
validate_decodable(image_bytes)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
5 changes: 4 additions & 1 deletion app/routers/predict_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import os
from app.models.model_handler import ModelHandler
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_logging
from app.utils.image_uri import resolve_image_uri, sanitize_uri_for_logging, validate_decodable
from fastapi.concurrency import run_in_threadpool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,6 +88,9 @@ async def predict(
# Resolve image bytes from URI (URL, data URI, or local path)
try:
image_bytes = await resolve_image_uri(prediction.image_uri)
# Reject undecodable/corrupt images as a 4xx (client error) so
# callers treat them as a permanent failure, not a retryable 5xx.
validate_decodable(image_bytes)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
40 changes: 40 additions & 0 deletions app/utils/image_uri.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
"""Utilities for resolving image URIs to bytes."""

import base64
from io import BytesIO
from pathlib import Path
from typing import Tuple

import httpx
from PIL import Image, UnidentifiedImageError


class ImageDecodeError(ValueError):
"""Raised when image bytes cannot be decoded into a usable image.

Subclasses ValueError so callers that already map ValueError from image
resolution to an HTTP 400 treat an undecodable image the same way — a
client/input error, not a server (5xx) error. This matters because
consumers (e.g. Wildbook) retry 5xx responses as transient, but a corrupt
image is a permanent failure: it must be reported as a 4xx so it is marked
terminal rather than retried indefinitely.
"""


def is_data_uri(uri: str) -> bool:
Expand Down Expand Up @@ -59,3 +73,29 @@ async def resolve_image_uri(uri: str) -> bytes:
raise ValueError(f"File not found: {uri}")
with open(file_path, "rb") as f:
return f.read()


def validate_decodable(image_bytes: bytes) -> None:
"""Confirm image_bytes decode into a usable image, else raise ImageDecodeError.

A header-only check (Image.verify) is insufficient: corrupt JPEGs often have
a valid header but a broken entropy-coded scan stream that only fails during
a full pixel load (e.g. "broken data stream when reading image file" /
"Unsupported marker type 0xNN"). We therefore fully load() the image.

Catches:
- UnidentifiedImageError: not a recognizable image at all.
- OSError: broken/truncated scan stream surfaced during load().
- Image.DecompressionBombError: pathologically large image rejected by
Pillow's bomb guard. It is not an OSError, so it must be listed
explicitly or it would escape to the routers' generic 500 handler.
Like the others it is a permanent, non-retryable bad input.

Raises:
ImageDecodeError (a ValueError): if the bytes cannot be decoded.
"""
try:
img = Image.open(BytesIO(image_bytes))
img.load()
except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as e:
raise ImageDecodeError(f"unprocessable image: cannot decode ({e})")
86 changes: 86 additions & 0 deletions tests/test_image_decode_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Unit tests for app.utils.image_uri.validate_decodable.

Pure-PIL tests (no model handler / GPU), so they run anywhere. They verify
that a corrupt image is rejected with ImageDecodeError (a ValueError) — which
the inference routers map to HTTP 400, so consumers treat a corrupt image as a
permanent (non-retryable) client error rather than a retryable 5xx.
"""

import io

import pytest
from PIL import Image

from app.utils.image_uri import ImageDecodeError, validate_decodable


def _valid_jpeg_bytes() -> bytes:
img = Image.new("RGB", (64, 64), (120, 120, 120))
buf = io.BytesIO()
img.save(buf, "jpeg", quality=90)
return buf.getvalue()


def _corrupt_jpeg_bytes() -> bytes:
"""A JPEG with a valid header but a broken entropy-coded scan stream.

Splicing 0xFF 0x99 (an unsupported marker) into the scan data reproduces the
real-world failure class ("Unsupported marker type 0x99" /
"broken data stream when reading image file") that fails during full load(),
not at header parse.
"""
data = bytearray(_valid_jpeg_bytes())
at = max(20, len(data) - 40)
for i in range(at, min(at + 32, len(data) - 1), 2):
data[i] = 0xFF
data[i + 1] = 0x99
return bytes(data)


def test_valid_image_passes():
# Should not raise.
validate_decodable(_valid_jpeg_bytes())


def test_corrupt_image_raises_image_decode_error():
with pytest.raises(ImageDecodeError):
validate_decodable(_corrupt_jpeg_bytes())


def test_image_decode_error_is_value_error():
# Routers catch ValueError -> HTTP 400; ImageDecodeError must subclass it
# so an undecodable image is reported as a 4xx, not an unhandled 5xx.
assert issubclass(ImageDecodeError, ValueError)


def test_non_image_bytes_raise_image_decode_error():
with pytest.raises(ImageDecodeError):
validate_decodable(b"this is definitely not an image")


def test_empty_bytes_raise_image_decode_error():
with pytest.raises(ImageDecodeError):
validate_decodable(b"")


def test_truncated_jpeg_raises_image_decode_error():
# Keep only the first 200 bytes (header + partial scan) — load() must fail.
truncated = _valid_jpeg_bytes()[:200]
with pytest.raises(ImageDecodeError):
validate_decodable(truncated)


def test_decompression_bomb_raises_image_decode_error():
# A decompression bomb raises PIL's DecompressionBombError, which is NOT an
# OSError; verify it is still mapped to ImageDecodeError. Force the guard by
# lowering Pillow's pixel limit so an ordinary image trips it.
from PIL import Image

valid = _valid_jpeg_bytes() # 64x64 = 4096 px
saved = Image.MAX_IMAGE_PIXELS
try:
Image.MAX_IMAGE_PIXELS = 1 # 2*1 < 4096 -> DecompressionBombError on load
with pytest.raises(ImageDecodeError):
validate_decodable(valid)
finally:
Image.MAX_IMAGE_PIXELS = saved