diff --git a/src/nene2/middleware/request_id.py b/src/nene2/middleware/request_id.py index 11b7e6e..6dafa21 100644 --- a/src/nene2/middleware/request_id.py +++ b/src/nene2/middleware/request_id.py @@ -4,6 +4,7 @@ Uses contextvars so downstream code (e.g. structlog) can read the ID. """ +import re import uuid from contextvars import ContextVar @@ -13,14 +14,31 @@ _REQUEST_ID_HEADER = "X-Request-Id" +# UUID v4 canonical form — 8-4-4-4-12 hex, version=4, variant=8/9/a/b +_UUID_V4_RE = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + re.IGNORECASE, +) + request_id_var: ContextVar[str] = ContextVar("request_id", default="") +def _validated_request_id(value: str | None) -> str: + """Return value if it is a valid UUID v4, otherwise generate a fresh one.""" + if value and _UUID_V4_RE.match(value): + return value.lower() + return str(uuid.uuid4()) + + class RequestIdMiddleware(BaseHTTPMiddleware): - """Generate or forward X-Request-Id and expose it via contextvars.""" + """Generate or forward X-Request-Id and expose it via contextvars. + + Client-supplied X-Request-Id is accepted only when it matches UUID v4 + format, preventing log injection via arbitrary header values. + """ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - request_id = request.headers.get(_REQUEST_ID_HEADER) or str(uuid.uuid4()) + request_id = _validated_request_id(request.headers.get(_REQUEST_ID_HEADER)) request_id_var.set(request_id) response = await call_next(request) response.headers[_REQUEST_ID_HEADER] = request_id diff --git a/tests/nene2/middleware/test_request_id.py b/tests/nene2/middleware/test_request_id.py index 6c6f0e0..889392d 100644 --- a/tests/nene2/middleware/test_request_id.py +++ b/tests/nene2/middleware/test_request_id.py @@ -1,11 +1,17 @@ """Tests for RequestIdMiddleware.""" +import re + from fastapi import FastAPI from fastapi.responses import JSONResponse from fastapi.testclient import TestClient from nene2.middleware import RequestIdMiddleware, request_id_var +_UUID_V4_RE = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" +) + def _make_app() -> FastAPI: app = FastAPI() @@ -24,13 +30,32 @@ def test_response_has_x_request_id() -> None: assert response.status_code == 200 assert "X-Request-Id" in response.headers rid = response.headers["X-Request-Id"] - assert len(rid) == 36 # UUID v4 format + assert len(rid) == 36 + + +def test_forwards_valid_uuid_v4_request_id() -> None: + valid_id = "550e8400-e29b-41d4-a716-446655440000" + client = TestClient(_make_app()) + response = client.get("/ping", headers={"X-Request-Id": valid_id}) + assert response.headers["X-Request-Id"] == valid_id -def test_forwards_provided_request_id() -> None: +def test_invalid_request_id_is_replaced_with_new_uuid() -> None: + """Non-UUID values must not be forwarded to prevent log injection.""" client = TestClient(_make_app()) response = client.get("/ping", headers={"X-Request-Id": "my-trace-id-123"}) - assert response.headers["X-Request-Id"] == "my-trace-id-123" + rid = response.headers["X-Request-Id"] + assert rid != "my-trace-id-123" + assert _UUID_V4_RE.match(rid), f"Expected UUID v4, got {rid!r}" + + +def test_newline_in_request_id_is_rejected() -> None: + """Newlines in X-Request-Id must be rejected to prevent log injection.""" + client = TestClient(_make_app()) + response = client.get("/ping", headers={"X-Request-Id": "abc\nERROR injected"}) + rid = response.headers["X-Request-Id"] + assert "\n" not in rid + assert _UUID_V4_RE.match(rid) def test_request_id_available_in_contextvars() -> None: