Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/example/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
"""Application factory — wires dependencies and registers routes."""

import os

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

from nene2.auth import ApiKeyAuthMiddleware, BearerTokenMiddleware, LocalTokenVerifier
from nene2.auth import (
ApiKeyAuthMiddleware,
BearerTokenMiddleware,
LocalBearerJwtVerifier,
LocalTokenVerifier,
)
from nene2.config import AppSettings
from nene2.database import (
DatabaseHealthCheck,
Expand Down Expand Up @@ -67,6 +74,7 @@
_FRAMEWORK_DESCRIPTION = "JSON APIs first, minimal server HTML, frontend ready, AI-readable."
# Matches nene2-js tools/compose-ft-evac.yaml default for local evac smoke.
_DEFAULT_MACHINE_API_KEYS = ["ft-evac-local-machine-api-key-32ch!!"]
_DEFAULT_JWT_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105

type _Repos = tuple[
NoteRepositoryInterface,
Expand Down Expand Up @@ -159,6 +167,13 @@ def create_app(settings: AppSettings | None = None) -> FastAPI:
include_paths=["/machine/health"],
header_name="X-NENE2-API-Key",
)
jwt_secret = os.getenv("NENE2_LOCAL_JWT_SECRET", _DEFAULT_JWT_SECRET)
if len(jwt_secret) >= 32:
app.add_middleware(
BearerTokenMiddleware,
verifier=LocalBearerJwtVerifier(jwt_secret),
include_paths=["/examples/protected"],
)
# CORS must be outermost — register last so preflight OPTIONS is handled
# before throttle, auth, or any other middleware runs.
if cfg.cors_enabled:
Expand Down Expand Up @@ -229,6 +244,16 @@ async def framework_smoke() -> JSONResponse:
async def example_ping() -> JSONResponse:
return JSONResponse({"message": "pong", "status": "ok"})

@app.get("/examples/protected", tags=["Examples"], summary="Protected example endpoint")
async def examples_protected(request: Request) -> JSONResponse:
claims = getattr(request.state, "nene2_auth_claims", {})
return JSONResponse(
{
"message": "Welcome, authenticated user.",
"claims": claims,
}
)

@app.get("/machine/health", tags=["system"], summary="Protected machine health endpoint")
async def machine_health(request: Request) -> JSONResponse:
credential_type = getattr(request.state, "nene2_auth_credential_type", "api_key")
Expand Down
2 changes: 2 additions & 0 deletions src/nene2/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .deps import make_require_auth
from .exceptions import TokenVerificationException
from .interfaces import TokenIssuerProtocol, TokenVerifierProtocol
from .local_bearer_jwt import LocalBearerJwtVerifier
from .local_issuer import LocalTokenIssuer, LocalTokenIssuerVerifier
from .local_verifier import LocalTokenVerifier

Expand All @@ -21,6 +22,7 @@
"BearerTokenMiddleware",
"CompositeAuthMiddleware",
"CompositeAuthRule",
"LocalBearerJwtVerifier",
"LocalTokenIssuer",
"LocalTokenIssuerVerifier",
"LocalTokenVerifier",
Expand Down
14 changes: 14 additions & 0 deletions src/nene2/auth/bearer_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
)
response.headers["WWW-Authenticate"] = _WWW_AUTH
return response
decode_claims = getattr(self._verifier, "decode_claims", None)
if callable(decode_claims):
try:
request.state.nene2_auth_claims = decode_claims(token)
except TokenVerificationException:
response = problem_details_response(
"unauthorized",
"Unauthorized",
401,
"The provided token is invalid or expired.",
)
response.headers["WWW-Authenticate"] = _WWW_AUTH
return response
request.state.nene2_auth_credential_type = "bearer"
return await call_next(request)
60 changes: 60 additions & 0 deletions src/nene2/auth/local_bearer_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""HS256 JWT verifier compatible with NENE2 LocalBearerTokenVerifier (dev/FT only)."""

import base64
import hashlib
import hmac
import json
import time

from .exceptions import TokenVerificationException


def _b64url_decode(segment: str) -> bytes:
padding = (4 - len(segment) % 4) % 4
return base64.urlsafe_b64decode(segment + "=" * padding)


class LocalBearerJwtVerifier:
"""Verify NENE2-style HS256 JWT bearer tokens (three dot-separated segments)."""

def __init__(self, secret: str) -> None:
if len(secret) < 32:
msg = "JWT secret must be at least 32 characters for local bearer verification."
raise ValueError(msg)
self._secret = secret.encode("utf-8")

def decode_claims(self, token: str) -> dict[str, object]:
"""Return claims when valid; raises TokenVerificationException otherwise."""
parts = token.split(".")
if len(parts) != 3:
raise TokenVerificationException("Token format is invalid: expected three segments.")
header_b64, payload_b64, sig_b64 = parts
header = json.loads(_b64url_decode(header_b64))
if header.get("alg") != "HS256":
raise TokenVerificationException("Token algorithm must be HS256.")
signing_input = f"{header_b64}.{payload_b64}".encode()
expected_sig = (
base64.urlsafe_b64encode(
hmac.new(self._secret, signing_input, hashlib.sha256).digest(),
)
.rstrip(b"=")
.decode()
)
if not hmac.compare_digest(expected_sig, sig_b64):
raise TokenVerificationException("Token signature is invalid.")
claims: dict[str, object] = json.loads(_b64url_decode(payload_b64))
now = int(time.time())
nbf = claims.get("nbf")
if isinstance(nbf, int) and nbf > now:
raise TokenVerificationException("Token is not yet valid.")
exp = claims.get("exp")
if isinstance(exp, int) and exp < now:
raise TokenVerificationException("Token has expired.")
return claims

def verify(self, token: str) -> bool:
try:
self.decode_claims(token)
except TokenVerificationException:
return False
return True
53 changes: 53 additions & 0 deletions tests/example/test_examples_protected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""GET /examples/protected — Bearer JWT parity with NENE2."""

import base64
import hashlib
import hmac
import json
import time

from fastapi.testclient import TestClient

from example.app import create_app
from nene2.config import AppSettings

_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105


def _bearer() -> str:
now = int(time.time())
header_b64 = (
base64.urlsafe_b64encode(
json.dumps({"typ": "JWT", "alg": "HS256"}).encode(),
)
.rstrip(b"=")
.decode()
)
claims = {"sub": "user-42", "scope": "read:system", "iat": now, "exp": now + 3600}
payload_b64 = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode()
sig_b64 = (
base64.urlsafe_b64encode(
hmac.new(
_SECRET.encode(),
f"{header_b64}.{payload_b64}".encode(),
hashlib.sha256,
).digest(),
)
.rstrip(b"=")
.decode()
)
return f"{header_b64}.{payload_b64}.{sig_b64}"


def test_protected_requires_bearer() -> None:
client = TestClient(create_app(AppSettings(throttle_enabled=False)))
assert client.get("/examples/protected").status_code == 401


def test_protected_with_jwt() -> None:
client = TestClient(create_app(AppSettings(throttle_enabled=False)))
r = client.get("/examples/protected", headers={"Authorization": f"Bearer {_bearer()}"})
assert r.status_code == 200
body = r.json()
assert "Welcome" in body["message"]
assert body["claims"]["sub"] == "user-42"
60 changes: 60 additions & 0 deletions tests/nene2/auth/test_local_bearer_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""LocalBearerJwtVerifier — NENE2 HS256 JWT parity."""

import base64
import hashlib
import hmac
import json
import time

import pytest

from nene2.auth import LocalBearerJwtVerifier
from nene2.auth.exceptions import TokenVerificationException

_SECRET = "ft-evac-local-jwt-secret-min-32-chars!!" # noqa: S105


def _issue(secret: str, claims: dict[str, object]) -> str:
header_b64 = (
base64.urlsafe_b64encode(
json.dumps({"typ": "JWT", "alg": "HS256"}).encode(),
)
.rstrip(b"=")
.decode()
)
payload_b64 = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode()
sig_b64 = (
base64.urlsafe_b64encode(
hmac.new(
secret.encode(),
f"{header_b64}.{payload_b64}".encode(),
hashlib.sha256,
).digest(),
)
.rstrip(b"=")
.decode()
)
return f"{header_b64}.{payload_b64}.{sig_b64}"


def test_verify_and_decode_ok() -> None:
now = int(time.time())
claims = {"sub": "user-42", "scope": "read:system", "iat": now, "exp": now + 3600}
token = _issue(_SECRET, claims)
verifier = LocalBearerJwtVerifier(_SECRET)
assert verifier.verify(token) is True
claims = verifier.decode_claims(token)
assert claims["sub"] == "user-42"


def test_rejects_expired() -> None:
now = int(time.time())
token = _issue(_SECRET, {"sub": "x", "exp": now - 10})
verifier = LocalBearerJwtVerifier(_SECRET)
with pytest.raises(TokenVerificationException):
verifier.decode_claims(token)


def test_secret_too_short() -> None:
with pytest.raises(ValueError):
LocalBearerJwtVerifier("short")
Loading