diff --git a/src/nene2/middleware/request_size_limit.py b/src/nene2/middleware/request_size_limit.py index 610330f..59ea086 100644 --- a/src/nene2/middleware/request_size_limit.py +++ b/src/nene2/middleware/request_size_limit.py @@ -1,7 +1,8 @@ """Request body size limit middleware. -Rejects requests whose Content-Length exceeds the configured maximum. -Protects against memory exhaustion from oversized payloads. +Rejects requests whose body exceeds the configured maximum. +Protects against memory exhaustion from oversized payloads, including +chunked-transfer requests that omit the Content-Length header. """ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint @@ -12,9 +13,16 @@ _DEFAULT_MAX_BYTES = 1_048_576 # 1 MiB +_TOO_LARGE = "Request body must not exceed {limit} bytes." + class RequestSizeLimitMiddleware(BaseHTTPMiddleware): - """Reject requests whose Content-Length exceeds max_bytes.""" + """Reject requests whose body exceeds max_bytes. + + Checks the Content-Length header first for a fast pre-flight reject, + then reads the actual body to catch chunked-transfer requests that + omit Content-Length entirely. + """ def __init__(self, app: object, *, max_bytes: int = _DEFAULT_MAX_BYTES) -> None: super().__init__(app) # type: ignore[arg-type] @@ -22,11 +30,23 @@ def __init__(self, app: object, *, max_bytes: int = _DEFAULT_MAX_BYTES) -> None: async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: content_length = request.headers.get("Content-Length") - if content_length is not None and int(content_length) > self._max_bytes: - return problem_details_response( - "payload-too-large", - "Payload Too Large", - 413, - f"Request body must not exceed {self._max_bytes} bytes.", - ) + if content_length is not None: + try: + if int(content_length) > self._max_bytes: + return self._too_large() + except ValueError: + pass + + body = await request.body() + if len(body) > self._max_bytes: + return self._too_large() + return await call_next(request) + + def _too_large(self) -> Response: + return problem_details_response( + "payload-too-large", + "Payload Too Large", + 413, + _TOO_LARGE.format(limit=self._max_bytes), + ) diff --git a/tests/nene2/middleware/test_request_size_limit.py b/tests/nene2/middleware/test_request_size_limit.py index 7125af8..619c72d 100644 --- a/tests/nene2/middleware/test_request_size_limit.py +++ b/tests/nene2/middleware/test_request_size_limit.py @@ -44,3 +44,25 @@ def test_no_content_length_passes() -> None: client = TestClient(_make_app(max_bytes=10_000)) response = client.post("/upload", json={"data": "small"}) assert response.status_code == 200 + + +def test_oversized_body_without_content_length_returns_413() -> None: + """Chunked-transfer (no Content-Length) must also be caught.""" + client = TestClient(_make_app(max_bytes=100)) + response = client.post( + "/upload", + content=b"x" * 200, + headers={"Content-Type": "application/octet-stream"}, + ) + assert response.status_code == 413 + + +def test_malformed_content_length_is_tolerated() -> None: + """Non-integer Content-Length header must not crash the middleware.""" + client = TestClient(_make_app(max_bytes=1000)) + response = client.post( + "/upload", + content=b"hello", + headers={"Content-Length": "abc", "Content-Type": "application/octet-stream"}, + ) + assert response.status_code == 200