Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/example/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nene2.middleware import ErrorHandlerMiddleware
from nene2.validation.exceptions import ValidationException

from .note.exceptions import NoteNotFoundExceptionHandler
from .note.handler import make_note_router
from .note.repository import InMemoryNoteRepository
from .note.use_case import CreateNoteUseCase, GetNoteUseCase, ListNotesUseCase
Expand All @@ -22,7 +23,11 @@ def create_app(settings: AppSettings | None = None) -> FastAPI:
openapi_url="/openapi.json",
)

app.add_middleware(ErrorHandlerMiddleware, debug=cfg.app_debug)
app.add_middleware(
ErrorHandlerMiddleware,
debug=cfg.app_debug,
domain_handlers=[NoteNotFoundExceptionHandler()],
)
app.add_exception_handler(
ValidationException,
ErrorHandlerMiddleware.handle_validation_exception,
Expand Down
19 changes: 19 additions & 0 deletions src/example/note/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Note domain exceptions and their HTTP handlers."""

from starlette.responses import Response

from nene2.http.problem_details import problem_details_response


class NoteNotFoundException(Exception):
def __init__(self, note_id: int) -> None:
self.note_id = note_id
super().__init__(f"Note {note_id} not found.")


class NoteNotFoundExceptionHandler:
def handles(self, exc: Exception) -> bool:
return isinstance(exc, NoteNotFoundException)

def handle(self, exc: Exception) -> Response:
return problem_details_response("not-found", "Not Found", 404)
4 changes: 1 addition & 3 deletions src/example/note/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from nene2.http import PaginationQueryParser, PaginationResponse, problem_details_response
from nene2.http import PaginationQueryParser, PaginationResponse
from nene2.validation.exceptions import ValidationError, ValidationException

from .use_case import (
Expand Down Expand Up @@ -44,8 +44,6 @@ async def list_notes(request: Request) -> JSONResponse:
@router.get("/{note_id}")
async def get_note(note_id: int) -> JSONResponse:
note = get_use_case.execute(note_id)
if note is None:
return problem_details_response("not-found", "Not Found", 404)
return JSONResponse({"id": note.id, "title": note.title, "body": note.body})

@router.post("", status_code=201)
Expand Down
8 changes: 6 additions & 2 deletions src/example/note/use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass

from .entity import Note
from .exceptions import NoteNotFoundException
from .repository import NoteRepositoryInterface


Expand Down Expand Up @@ -53,5 +54,8 @@ class GetNoteUseCase:
def __init__(self, repository: NoteRepositoryInterface) -> None:
self._repository = repository

def execute(self, note_id: int) -> Note | None:
return self._repository.find_by_id(note_id)
def execute(self, note_id: int) -> Note:
note = self._repository.find_by_id(note_id)
if note is None:
raise NoteNotFoundException(note_id)
return note
3 changes: 2 additions & 1 deletion src/nene2/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""NENE2 middleware pipeline."""

from .domain_exception import DomainExceptionHandlerProtocol
from .error_handler import ErrorHandlerMiddleware

__all__ = ["ErrorHandlerMiddleware"]
__all__ = ["DomainExceptionHandlerProtocol", "ErrorHandlerMiddleware"]
18 changes: 18 additions & 0 deletions src/nene2/middleware/domain_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""DomainExceptionHandlerProtocol — delegate domain errors to typed handlers."""

from typing import Protocol, runtime_checkable

from starlette.responses import Response


@runtime_checkable
class DomainExceptionHandlerProtocol(Protocol):
"""Map a domain exception to an HTTP response."""

def handles(self, exc: Exception) -> bool:
"""Return True if this handler is responsible for *exc*."""
...

def handle(self, exc: Exception) -> Response:
"""Convert *exc* to an HTTP response. Called only when handles() is True."""
...
14 changes: 13 additions & 1 deletion src/nene2/middleware/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from nene2.http.problem_details import problem_details_response
from nene2.validation.exceptions import ValidationException

from .domain_exception import DomainExceptionHandlerProtocol

_ASGIApp = Callable[
[
MutableMapping[str, Any],
Expand All @@ -31,9 +33,16 @@
class ErrorHandlerMiddleware(BaseHTTPMiddleware):
"""Catch-all error handler that maps exceptions to Problem Details responses."""

def __init__(self, app: _ASGIApp, *, debug: bool = False) -> None:
def __init__(
self,
app: _ASGIApp,
*,
debug: bool = False,
domain_handlers: list[DomainExceptionHandlerProtocol] | None = None,
) -> None:
super().__init__(app)
self.debug = debug
self._domain_handlers: list[DomainExceptionHandlerProtocol] = domain_handlers or []

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
try:
Expand All @@ -47,6 +56,9 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
extra={"errors": [e.to_dict() for e in exc.errors]},
)
except Exception as exc:
for handler in self._domain_handlers:
if handler.handles(exc):
return handler.handle(exc)
logger.exception("Unhandled exception")
detail = str(exc) if self.debug else "The server encountered an unexpected condition."
return problem_details_response(
Expand Down
41 changes: 38 additions & 3 deletions tests/nene2/middleware/test_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,30 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from starlette.responses import Response

from nene2.middleware import ErrorHandlerMiddleware
from nene2.http.problem_details import problem_details_response
from nene2.middleware import DomainExceptionHandlerProtocol, ErrorHandlerMiddleware
from nene2.validation.exceptions import ValidationError, ValidationException


def _make_app(*, debug: bool = False) -> FastAPI:
class _DomainError(Exception):
pass


class _DomainErrorHandler:
def handles(self, exc: Exception) -> bool:
return isinstance(exc, _DomainError)

def handle(self, exc: Exception) -> Response:
return problem_details_response("domain-error", "Domain Error", 409)


def _make_app(
*, debug: bool = False, domain_handlers: list[DomainExceptionHandlerProtocol] | None = None
) -> FastAPI:
app = FastAPI()
app.add_middleware(ErrorHandlerMiddleware, debug=debug)
app.add_middleware(ErrorHandlerMiddleware, debug=debug, domain_handlers=domain_handlers)
app.add_exception_handler(
ValidationException,
ErrorHandlerMiddleware.handle_validation_exception, # type: ignore[arg-type]
Expand All @@ -20,6 +36,10 @@ def _make_app(*, debug: bool = False) -> FastAPI:
async def boom() -> JSONResponse:
raise RuntimeError("secret internal detail")

@app.get("/domain-error")
async def domain_error() -> JSONResponse:
raise _DomainError()

@app.get("/validation-error")
async def validation_error() -> JSONResponse:
raise ValidationException([ValidationError("field", "bad value", "invalid")])
Expand Down Expand Up @@ -51,3 +71,18 @@ def test_validation_exception_returns_422() -> None:
r = client.get("/validation-error")
assert r.status_code == 422
assert r.json()["errors"][0]["field"] == "field"


def test_domain_exception_handler_returns_mapped_status() -> None:
client = TestClient(
_make_app(domain_handlers=[_DomainErrorHandler()]), raise_server_exceptions=False
)
r = client.get("/domain-error")
assert r.status_code == 409
assert r.json()["type"].endswith("domain-error")


def test_unregistered_domain_exception_falls_through_to_500() -> None:
client = TestClient(_make_app(domain_handlers=[]), raise_server_exceptions=False)
r = client.get("/domain-error")
assert r.status_code == 500
Loading