Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/nene2/middleware/request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Uses contextvars so downstream code (e.g. structlog) can read the ID.
"""

import re
import uuid
from contextvars import ContextVar

Expand All @@ -13,14 +14,31 @@

_REQUEST_ID_HEADER = "X-Request-Id"

# UUID v4 canonical form — 8-4-4-4-12 hex, version=4, variant=8/9/a/b
_UUID_V4_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
re.IGNORECASE,
)

request_id_var: ContextVar[str] = ContextVar("request_id", default="")


def _validated_request_id(value: str | None) -> str:
"""Return value if it is a valid UUID v4, otherwise generate a fresh one."""
if value and _UUID_V4_RE.match(value):
return value.lower()
return str(uuid.uuid4())


class RequestIdMiddleware(BaseHTTPMiddleware):
"""Generate or forward X-Request-Id and expose it via contextvars."""
"""Generate or forward X-Request-Id and expose it via contextvars.

Client-supplied X-Request-Id is accepted only when it matches UUID v4
format, preventing log injection via arbitrary header values.
"""

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
request_id = request.headers.get(_REQUEST_ID_HEADER) or str(uuid.uuid4())
request_id = _validated_request_id(request.headers.get(_REQUEST_ID_HEADER))
request_id_var.set(request_id)
response = await call_next(request)
response.headers[_REQUEST_ID_HEADER] = request_id
Expand Down
31 changes: 28 additions & 3 deletions tests/nene2/middleware/test_request_id.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Tests for RequestIdMiddleware."""

import re

from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

from nene2.middleware import RequestIdMiddleware, request_id_var

_UUID_V4_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
)


def _make_app() -> FastAPI:
app = FastAPI()
Expand All @@ -24,13 +30,32 @@ def test_response_has_x_request_id() -> None:
assert response.status_code == 200
assert "X-Request-Id" in response.headers
rid = response.headers["X-Request-Id"]
assert len(rid) == 36 # UUID v4 format
assert len(rid) == 36


def test_forwards_valid_uuid_v4_request_id() -> None:
valid_id = "550e8400-e29b-41d4-a716-446655440000"
client = TestClient(_make_app())
response = client.get("/ping", headers={"X-Request-Id": valid_id})
assert response.headers["X-Request-Id"] == valid_id


def test_forwards_provided_request_id() -> None:
def test_invalid_request_id_is_replaced_with_new_uuid() -> None:
"""Non-UUID values must not be forwarded to prevent log injection."""
client = TestClient(_make_app())
response = client.get("/ping", headers={"X-Request-Id": "my-trace-id-123"})
assert response.headers["X-Request-Id"] == "my-trace-id-123"
rid = response.headers["X-Request-Id"]
assert rid != "my-trace-id-123"
assert _UUID_V4_RE.match(rid), f"Expected UUID v4, got {rid!r}"


def test_newline_in_request_id_is_rejected() -> None:
"""Newlines in X-Request-Id must be rejected to prevent log injection."""
client = TestClient(_make_app())
response = client.get("/ping", headers={"X-Request-Id": "abc\nERROR injected"})
rid = response.headers["X-Request-Id"]
assert "\n" not in rid
assert _UUID_V4_RE.match(rid)


def test_request_id_available_in_contextvars() -> None:
Expand Down
Loading