diff --git a/src/example/app.py b/src/example/app.py index 73f7cb3..b52a798 100644 --- a/src/example/app.py +++ b/src/example/app.py @@ -1,5 +1,7 @@ """Application factory — wires dependencies and registers routes.""" +import os + from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -7,7 +9,12 @@ from sqlalchemy import create_engine from sqlalchemy.pool import StaticPool -from nene2.auth import ApiKeyAuthMiddleware, BearerTokenMiddleware, LocalTokenVerifier +from nene2.auth import ( + ApiKeyAuthMiddleware, + BearerTokenMiddleware, + LocalBearerJwtVerifier, + LocalTokenVerifier, +) from nene2.config import AppSettings from nene2.database import ( DatabaseHealthCheck, @@ -67,6 +74,7 @@ _FRAMEWORK_DESCRIPTION = "JSON APIs first, minimal server HTML, frontend ready, AI-readable." # Matches nene2-js tools/compose-ft-evac.yaml default for local evac smoke. _DEFAULT_MACHINE_API_KEYS = ["ft-evac-local-machine-api-key-32ch!!"] +_DEFAULT_JWT_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105 type _Repos = tuple[ NoteRepositoryInterface, @@ -159,6 +167,13 @@ def create_app(settings: AppSettings | None = None) -> FastAPI: include_paths=["/machine/health"], header_name="X-NENE2-API-Key", ) + jwt_secret = os.getenv("NENE2_LOCAL_JWT_SECRET", _DEFAULT_JWT_SECRET) + if len(jwt_secret) >= 32: + app.add_middleware( + BearerTokenMiddleware, + verifier=LocalBearerJwtVerifier(jwt_secret), + include_paths=["/examples/protected"], + ) # CORS must be outermost — register last so preflight OPTIONS is handled # before throttle, auth, or any other middleware runs. if cfg.cors_enabled: @@ -229,6 +244,16 @@ async def framework_smoke() -> JSONResponse: async def example_ping() -> JSONResponse: return JSONResponse({"message": "pong", "status": "ok"}) + @app.get("/examples/protected", tags=["Examples"], summary="Protected example endpoint") + async def examples_protected(request: Request) -> JSONResponse: + claims = getattr(request.state, "nene2_auth_claims", {}) + return JSONResponse( + { + "message": "Welcome, authenticated user.", + "claims": claims, + } + ) + @app.get("/machine/health", tags=["system"], summary="Protected machine health endpoint") async def machine_health(request: Request) -> JSONResponse: credential_type = getattr(request.state, "nene2_auth_credential_type", "api_key") diff --git a/src/nene2/auth/__init__.py b/src/nene2/auth/__init__.py index 21bdefe..d492af0 100644 --- a/src/nene2/auth/__init__.py +++ b/src/nene2/auth/__init__.py @@ -12,6 +12,7 @@ from .deps import make_require_auth from .exceptions import TokenVerificationException from .interfaces import TokenIssuerProtocol, TokenVerifierProtocol +from .local_bearer_jwt import LocalBearerJwtVerifier from .local_issuer import LocalTokenIssuer, LocalTokenIssuerVerifier from .local_verifier import LocalTokenVerifier @@ -21,6 +22,7 @@ "BearerTokenMiddleware", "CompositeAuthMiddleware", "CompositeAuthRule", + "LocalBearerJwtVerifier", "LocalTokenIssuer", "LocalTokenIssuerVerifier", "LocalTokenVerifier", diff --git a/src/nene2/auth/bearer_token.py b/src/nene2/auth/bearer_token.py index 2258381..92b874a 100644 --- a/src/nene2/auth/bearer_token.py +++ b/src/nene2/auth/bearer_token.py @@ -91,4 +91,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - ) response.headers["WWW-Authenticate"] = _WWW_AUTH return response + decode_claims = getattr(self._verifier, "decode_claims", None) + if callable(decode_claims): + try: + request.state.nene2_auth_claims = decode_claims(token) + except TokenVerificationException: + response = problem_details_response( + "unauthorized", + "Unauthorized", + 401, + "The provided token is invalid or expired.", + ) + response.headers["WWW-Authenticate"] = _WWW_AUTH + return response + request.state.nene2_auth_credential_type = "bearer" return await call_next(request) diff --git a/src/nene2/auth/local_bearer_jwt.py b/src/nene2/auth/local_bearer_jwt.py new file mode 100644 index 0000000..12b8dfe --- /dev/null +++ b/src/nene2/auth/local_bearer_jwt.py @@ -0,0 +1,60 @@ +"""HS256 JWT verifier compatible with NENE2 LocalBearerTokenVerifier (dev/FT only).""" + +import base64 +import hashlib +import hmac +import json +import time + +from .exceptions import TokenVerificationException + + +def _b64url_decode(segment: str) -> bytes: + padding = (4 - len(segment) % 4) % 4 + return base64.urlsafe_b64decode(segment + "=" * padding) + + +class LocalBearerJwtVerifier: + """Verify NENE2-style HS256 JWT bearer tokens (three dot-separated segments).""" + + def __init__(self, secret: str) -> None: + if len(secret) < 32: + msg = "JWT secret must be at least 32 characters for local bearer verification." + raise ValueError(msg) + self._secret = secret.encode("utf-8") + + def decode_claims(self, token: str) -> dict[str, object]: + """Return claims when valid; raises TokenVerificationException otherwise.""" + parts = token.split(".") + if len(parts) != 3: + raise TokenVerificationException("Token format is invalid: expected three segments.") + header_b64, payload_b64, sig_b64 = parts + header = json.loads(_b64url_decode(header_b64)) + if header.get("alg") != "HS256": + raise TokenVerificationException("Token algorithm must be HS256.") + signing_input = f"{header_b64}.{payload_b64}".encode() + expected_sig = ( + base64.urlsafe_b64encode( + hmac.new(self._secret, signing_input, hashlib.sha256).digest(), + ) + .rstrip(b"=") + .decode() + ) + if not hmac.compare_digest(expected_sig, sig_b64): + raise TokenVerificationException("Token signature is invalid.") + claims: dict[str, object] = json.loads(_b64url_decode(payload_b64)) + now = int(time.time()) + nbf = claims.get("nbf") + if isinstance(nbf, int) and nbf > now: + raise TokenVerificationException("Token is not yet valid.") + exp = claims.get("exp") + if isinstance(exp, int) and exp < now: + raise TokenVerificationException("Token has expired.") + return claims + + def verify(self, token: str) -> bool: + try: + self.decode_claims(token) + except TokenVerificationException: + return False + return True diff --git a/tests/example/test_examples_protected.py b/tests/example/test_examples_protected.py new file mode 100644 index 0000000..60c5563 --- /dev/null +++ b/tests/example/test_examples_protected.py @@ -0,0 +1,53 @@ +"""GET /examples/protected — Bearer JWT parity with NENE2.""" + +import base64 +import hashlib +import hmac +import json +import time + +from fastapi.testclient import TestClient + +from example.app import create_app +from nene2.config import AppSettings + +_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105 + + +def _bearer() -> str: + now = int(time.time()) + header_b64 = ( + base64.urlsafe_b64encode( + json.dumps({"typ": "JWT", "alg": "HS256"}).encode(), + ) + .rstrip(b"=") + .decode() + ) + claims = {"sub": "user-42", "scope": "read:system", "iat": now, "exp": now + 3600} + payload_b64 = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() + sig_b64 = ( + base64.urlsafe_b64encode( + hmac.new( + _SECRET.encode(), + f"{header_b64}.{payload_b64}".encode(), + hashlib.sha256, + ).digest(), + ) + .rstrip(b"=") + .decode() + ) + return f"{header_b64}.{payload_b64}.{sig_b64}" + + +def test_protected_requires_bearer() -> None: + client = TestClient(create_app(AppSettings(throttle_enabled=False))) + assert client.get("/examples/protected").status_code == 401 + + +def test_protected_with_jwt() -> None: + client = TestClient(create_app(AppSettings(throttle_enabled=False))) + r = client.get("/examples/protected", headers={"Authorization": f"Bearer {_bearer()}"}) + assert r.status_code == 200 + body = r.json() + assert "Welcome" in body["message"] + assert body["claims"]["sub"] == "user-42" diff --git a/tests/nene2/auth/test_local_bearer_jwt.py b/tests/nene2/auth/test_local_bearer_jwt.py new file mode 100644 index 0000000..154eb26 --- /dev/null +++ b/tests/nene2/auth/test_local_bearer_jwt.py @@ -0,0 +1,60 @@ +"""LocalBearerJwtVerifier — NENE2 HS256 JWT parity.""" + +import base64 +import hashlib +import hmac +import json +import time + +import pytest + +from nene2.auth import LocalBearerJwtVerifier +from nene2.auth.exceptions import TokenVerificationException + +_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105 + + +def _issue(secret: str, claims: dict[str, object]) -> str: + header_b64 = ( + base64.urlsafe_b64encode( + json.dumps({"typ": "JWT", "alg": "HS256"}).encode(), + ) + .rstrip(b"=") + .decode() + ) + payload_b64 = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() + sig_b64 = ( + base64.urlsafe_b64encode( + hmac.new( + secret.encode(), + f"{header_b64}.{payload_b64}".encode(), + hashlib.sha256, + ).digest(), + ) + .rstrip(b"=") + .decode() + ) + return f"{header_b64}.{payload_b64}.{sig_b64}" + + +def test_verify_and_decode_ok() -> None: + now = int(time.time()) + claims = {"sub": "user-42", "scope": "read:system", "iat": now, "exp": now + 3600} + token = _issue(_SECRET, claims) + verifier = LocalBearerJwtVerifier(_SECRET) + assert verifier.verify(token) is True + claims = verifier.decode_claims(token) + assert claims["sub"] == "user-42" + + +def test_rejects_expired() -> None: + now = int(time.time()) + token = _issue(_SECRET, {"sub": "x", "exp": now - 10}) + verifier = LocalBearerJwtVerifier(_SECRET) + with pytest.raises(TokenVerificationException): + verifier.decode_claims(token) + + +def test_secret_too_short() -> None: + with pytest.raises(ValueError): + LocalBearerJwtVerifier("short")