Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/how-to/new-project.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ Create `src/app.py`:

```python
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError

from nene2.config import AppSettings
from nene2.log import setup_logging
from nene2.middleware import ErrorHandlerMiddleware
from nene2.middleware.error_handler import request_validation_error_handler
from nene2.middleware.request_id import RequestIdMiddleware
from nene2.middleware.request_logging import RequestLoggingMiddleware
from nene2.middleware.request_size_limit import RequestSizeLimitMiddleware
Expand Down Expand Up @@ -104,6 +106,9 @@ def create_app(settings: AppSettings | None = None) -> FastAPI:
window=settings.throttle_window,
)

# Convert Pydantic BaseModel validation errors to RFC 9457 Problem Details
app.add_exception_handler(RequestValidationError, request_validation_error_handler) # type: ignore[arg-type]

return app


Expand Down
6 changes: 6 additions & 0 deletions src/example/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Application factory — wires dependencies and registers routes."""

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine
Expand All @@ -23,6 +24,7 @@
SecurityHeadersMiddleware,
ThrottleMiddleware,
)
from nene2.middleware.error_handler import request_validation_error_handler
from nene2.validation.exceptions import ValidationException

from .comment.exceptions import CommentNotFoundExceptionHandler
Expand Down Expand Up @@ -150,6 +152,10 @@ def create_app(settings: AppSettings | None = None) -> FastAPI:
ValidationException,
ErrorHandlerMiddleware.handle_validation_exception,
)
app.add_exception_handler(
RequestValidationError,
request_validation_error_handler,
)

note_repo, tag_repo, comment_repo, db_executor = _build_repositories(cfg)

Expand Down
33 changes: 32 additions & 1 deletion src/nene2/middleware/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from typing import Any

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response

from nene2.http.problem_details import problem_details_response
from nene2.validation.exceptions import ValidationException
from nene2.validation.exceptions import ValidationError, ValidationException

from .domain_exception import DomainExceptionHandlerProtocol

Expand Down Expand Up @@ -79,3 +80,33 @@ async def handle_validation_exception(_request: Request, exc: Exception) -> JSON
"The request contains invalid values.",
extra={"errors": [e.to_dict() for e in exc.errors]},
)


async def request_validation_error_handler(_request: Request, exc: Exception) -> JSONResponse:
"""Convert FastAPI RequestValidationError to nene2 Problem Details (422).

Register with FastAPI to replace the default Pydantic validation error format::

from fastapi.exceptions import RequestValidationError
from nene2.middleware.error_handler import request_validation_error_handler

app.add_exception_handler(RequestValidationError, request_validation_error_handler)
"""
if not isinstance(exc, RequestValidationError):
raise TypeError(f"Expected RequestValidationError, got {type(exc)}")

errors: list[ValidationError] = []
for raw in exc.errors():
loc = raw.get("loc", ())
field = ".".join(str(part) for part in loc if part != "body") or "request"
message = str(raw.get("msg", "Invalid value."))
code = str(raw.get("type", "invalid"))
errors.append(ValidationError(field=field or "request", message=message, code=code))

return problem_details_response(
"validation-failed",
"Validation Failed",
422,
"The request contains invalid values.",
extra={"errors": [e.to_dict() for e in errors]},
)
41 changes: 41 additions & 0 deletions tests/nene2/middleware/test_error_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Tests for ErrorHandlerMiddleware."""

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from starlette.responses import Response

from nene2.http.problem_details import problem_details_response
from nene2.middleware import DomainExceptionHandlerProtocol, ErrorHandlerMiddleware
from nene2.middleware.error_handler import request_validation_error_handler
from nene2.validation.exceptions import ValidationError, ValidationException


Expand Down Expand Up @@ -86,3 +89,41 @@ def test_unregistered_domain_exception_falls_through_to_500() -> None:
client = TestClient(_make_app(domain_handlers=[]), raise_server_exceptions=False)
r = client.get("/domain-error")
assert r.status_code == 500


def _make_app_with_pydantic_handler() -> FastAPI:
class _Body(BaseModel):
rating: int = Field(ge=1, le=5)
price: int = Field(ge=0)

app = FastAPI()
app.add_middleware(ErrorHandlerMiddleware)
app.add_exception_handler(
RequestValidationError,
request_validation_error_handler, # type: ignore[arg-type]
)

@app.post("/items")
async def create_item(body: _Body) -> JSONResponse:
return JSONResponse({"rating": body.rating, "price": body.price})

return app


def test_pydantic_validation_error_returns_problem_details_format() -> None:
client = TestClient(_make_app_with_pydantic_handler(), raise_server_exceptions=False)
r = client.post("/items", json={"rating": 99, "price": -1})
assert r.status_code == 422
body = r.json()
assert body["type"].endswith("validation-failed")
assert body["status"] == 422
assert isinstance(body["errors"], list)
assert len(body["errors"]) == 2


def test_pydantic_validation_error_field_names_are_extracted() -> None:
client = TestClient(_make_app_with_pydantic_handler(), raise_server_exceptions=False)
r = client.post("/items", json={"rating": 99, "price": 100})
assert r.status_code == 422
fields = [e["field"] for e in r.json()["errors"]]
assert "rating" in fields
Loading