diff --git a/src/nene2/auth/api_key.py b/src/nene2/auth/api_key.py index 2788f61..c69f563 100644 --- a/src/nene2/auth/api_key.py +++ b/src/nene2/auth/api_key.py @@ -10,10 +10,18 @@ from nene2.http.problem_details import problem_details_response +from .exceptions import TokenVerificationException from .interfaces import TokenVerifierProtocol _API_KEY_HEADER = "X-Api-Key" +_UNAUTHORIZED = problem_details_response( + "unauthorized", + "Unauthorized", + 401, + "A valid X-Api-Key header is required.", +) + class ApiKeyAuthMiddleware(BaseHTTPMiddleware): """Require a valid X-Api-Key header on every request.""" @@ -24,11 +32,10 @@ def __init__(self, app: object, *, verifier: TokenVerifierProtocol) -> None: async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: api_key = request.headers.get(_API_KEY_HEADER, "") - if not api_key or not self._verifier.verify(api_key): - return problem_details_response( - "unauthorized", - "Unauthorized", - 401, - "A valid X-Api-Key header is required.", - ) + try: + verified = bool(api_key) and self._verifier.verify(api_key) + except TokenVerificationException: + verified = False + if not verified: + return _UNAUTHORIZED return await call_next(request) diff --git a/tests/nene2/auth/test_api_key.py b/tests/nene2/auth/test_api_key.py index e15e247..d682b78 100644 --- a/tests/nene2/auth/test_api_key.py +++ b/tests/nene2/auth/test_api_key.py @@ -5,6 +5,7 @@ from fastapi.testclient import TestClient from nene2.auth import ApiKeyAuthMiddleware, LocalTokenVerifier +from nene2.auth.exceptions import TokenVerificationException def _make_app(keys: list[str]) -> FastAPI: @@ -43,3 +44,25 @@ def test_multiple_allowed_keys() -> None: assert client.get("/secret", headers={"X-Api-Key": "key-a"}).status_code == 200 assert client.get("/secret", headers={"X-Api-Key": "key-b"}).status_code == 200 assert client.get("/secret", headers={"X-Api-Key": "key-c"}).status_code == 401 + + +def test_verifier_raises_token_verification_exception_returns_401() -> None: + """TokenVerificationException from verifier must return 401, not 500.""" + + class ExplodingVerifier: + def verify(self, token: str) -> bool: + raise TokenVerificationException("simulated failure") + + app = FastAPI() + app.add_middleware( + ApiKeyAuthMiddleware, + verifier=ExplodingVerifier(), # type: ignore[arg-type] + ) + + @app.get("/secret") + async def secret() -> JSONResponse: + return JSONResponse({"ok": True}) + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/secret", headers={"X-Api-Key": "any-key"}) + assert response.status_code == 401