Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions deployments/api/tests/test_auth_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Integration tests for auth module JIT user provisioning."""

from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from unittest.mock import MagicMock, AsyncMock

import pytest
from sqlalchemy import select

Expand Down Expand Up @@ -258,3 +262,89 @@ async def test_backfill_survives_caller_rollback(
).scalar_one()
assert row.name == "Filled"
assert row.email == "filled@example.com"


class TestGetCurrentUserIntegrityErrorHandling:
"""Unit tests for narrowed IntegrityError recovery."""

@pytest.mark.anyio
async def test_unrelated_integrity_error_propagates(self):
"""An IntegrityError without a concurrent row must propagate, not be masked."""
miss_result = MagicMock()
miss_result.scalar_one_or_none = MagicMock(return_value=None)

session = MagicMock(spec=AsyncSession)
session.execute = AsyncMock(return_value=miss_result)
session.add = MagicMock()
session.commit = AsyncMock(
side_effect=IntegrityError(
"INSERT INTO users", {}, Exception("simulated check constraint")
)
)
session.rollback = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=None)

factory = MagicMock(return_value=session)

claims = _make_claims(sub="auth0|broken")

with pytest.raises(IntegrityError):
await get_current_user(claims, factory)

session.rollback.assert_awaited_once()


class TestGetCurrentUserMachineToMachine:
"""JIT provisioning works for Auth0 M2M tokens (sub=...@clients, no email/name)."""

@pytest.mark.anyio
async def test_creates_user_for_m2m_token(
self,
integration_session_factory,
):
"""M2M token with no email/name JIT-creates a users row keyed by sub."""
claims = TokenClaims(
sub="tujhthy5vlhVd43C27te5Vtdy6a5BMJ7@clients",
email=None,
name=None,
permissions=frozenset({"resource:read:public"}),
raw={},
)

user = await get_current_user(claims, integration_session_factory)

assert user.sub == "tujhthy5vlhVd43C27te5Vtdy6a5BMJ7@clients"
assert user.email is None
assert user.name is None
assert user.id is not None

async with integration_session_factory() as session:
row = (
await session.execute(
select(UserModel).where(UserModel.sub == claims.sub)
)
).scalar_one()
assert row.email is None
assert row.name is None

@pytest.mark.anyio
async def test_returns_existing_m2m_user(
self,
integration_session_factory,
):
"""Subsequent M2M requests find the existing row rather than re-inserting."""
async with integration_session_factory() as session:
session.add(UserModel(sub="abc123@clients", name=None, email=None))
await session.commit()

claims = TokenClaims(
sub="abc123@clients",
email=None,
name=None,
raw={},
)

user = await get_current_user(claims, integration_session_factory)

assert user.sub == "abc123@clients"
64 changes: 64 additions & 0 deletions packages/stitch-auth/tests/test_validator_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,67 @@ def test_rejects_permissions_as_dict(

with pytest.raises(TokenValidationError, match="permissions"):
validator.validate(token)


class TestJWTValidatorMachineToMachineTokens:
"""Auth0 client-credentials tokens: sub ends in @clients, no email/name."""

def test_validates_m2m_token_with_clients_sub(
self, oidc_settings, mock_jwks_client, token_factory
):
"""M2M token validates and produces claims with the @clients sub preserved."""
token = token_factory(
sub="tujhthy5vlhVd43C27te5Vtdy6a5BMJ7@clients",
email=None,
name=None,
extra_claims={
"azp": "tujhthy5vlhVd43C27te5Vtdy6a5BMJ7",
"scope": "resource:read:public",
"permissions": ["resource:read:public"],
},
)
validator = JWTValidator(oidc_settings)

claims = validator.validate(token)

assert claims.sub == "tujhthy5vlhVd43C27te5Vtdy6a5BMJ7@clients"
assert claims.email is None
assert claims.name is None
assert claims.permissions == frozenset({"resource:read:public"})

def test_validates_m2m_token_without_permissions_claim(
self, oidc_settings, mock_jwks_client, token_factory
):
"""When RBAC is off, permissions array is absent — claims.permissions is empty.

The route layer is responsible for treating empty permissions as 403; the
validator must not reject the token just because the array is missing.
"""
token = token_factory(
sub="abc123@clients",
email=None,
name=None,
extra_claims={"scope": "openid"},
)
validator = JWTValidator(oidc_settings)

claims = validator.validate(token)

assert claims.sub == "abc123@clients"
assert claims.permissions == frozenset()

def test_m2m_token_raw_payload_preserves_azp(
self, oidc_settings, mock_jwks_client, token_factory
):
"""`azp` (authorized party) round-trips into raw for downstream audit logging."""
token = token_factory(
sub="abc123@clients",
email=None,
name=None,
extra_claims={"azp": "abc123", "permissions": []},
)
validator = JWTValidator(oidc_settings)

claims = validator.validate(token)

assert claims.raw["azp"] == "abc123"
9 changes: 7 additions & 2 deletions packages/stitch-client/src/stitch/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from .async_client import AsyncStitchClient
from .auth import (
Auth0M2MAuth,
STITCH_CLIENT_BEARER_TOKEN_ENV_VAR,
env_bearer_token_headers_provider,
)
from .errors import StitchAPIError
from .config import StitchClientConfig
from .errors import StitchAPIError, StitchAuthError

__all__ = [
"AsyncStitchClient",
"STITCH_CLIENT_BEARER_TOKEN_ENV_VAR",
"AsyncStitchClient",
"Auth0M2MAuth",
"StitchAPIError",
"StitchAuthError",
"StitchClientConfig",
"env_bearer_token_headers_provider",
]
32 changes: 31 additions & 1 deletion packages/stitch-client/src/stitch/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import asyncio
import logging
from collections.abc import Callable, Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

import httpx
from stitch.ogsi.model import OGFieldResource

from .errors import StitchAPIError

if TYPE_CHECKING:
from .config import StitchClientConfig

logger = logging.getLogger("stitch.client")


Expand Down Expand Up @@ -44,6 +47,33 @@ def __init__(
timeout=timeout if timeout is not None else 30.0,
)

@classmethod
def from_config(
cls,
config: "StitchClientConfig",
*,
timeout: float = 30.0,
) -> "AsyncStitchClient":
from .auth import Auth0M2MAuth, fetch_auth_jwt

async def _fetch() -> str:
return await fetch_auth_jwt(config)

httpx_client = httpx.AsyncClient(
base_url=config.api_base_url,
timeout=timeout,
auth=Auth0M2MAuth(_fetch),
)
instance = cls(client=httpx_client)
instance._owns_client = True
return instance

@classmethod
def from_env(cls, *, timeout: float = 30.0) -> "AsyncStitchClient":
from .config import StitchClientConfig

return cls.from_config(StitchClientConfig.from_env(), timeout=timeout)

async def __aenter__(self) -> "AsyncStitchClient":
return self

Expand Down
93 changes: 92 additions & 1 deletion packages/stitch-client/src/stitch/client/auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,102 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator, Awaitable, Callable
import os
from collections.abc import Callable


from typing import Any

import httpx

from .config import StitchClientConfig
from .errors import StitchAuthError

STITCH_CLIENT_BEARER_TOKEN_ENV_VAR = "STITCH_CLIENT_BEARER_TOKEN"


async def fetch_auth_jwt(
config: StitchClientConfig,
*,
transport: httpx.AsyncBaseTransport | None = None,
) -> str:
"""POST client-credentials to Auth0 and return the access_token string."""
payload = {
"client_id": config.client_id,
"client_secret": config.client_secret,
"audience": config.audience,
"grant_type": "client_credentials",
}
async with httpx.AsyncClient(
base_url=config.auth_issuer_url,
transport=transport,
timeout=30.0,
) as auth_client:
try:
res = await auth_client.post("/oauth/token", json=payload)
except httpx.HTTPError as exc:
raise StitchAuthError(f"Auth0 token request failed: {exc}") from exc

if res.status_code != 200:
raise StitchAuthError(
f"Auth0 token request returned status {res.status_code}",
status_code=res.status_code,
response_text=res.text,
)
try:
body = res.json()
except ValueError as exc:
raise StitchAuthError(
"Auth0 token response was not valid JSON",
status_code=res.status_code,
response_text=res.text,
) from exc
token = body.get("access_token")
if not isinstance(token, str) or not token:
raise StitchAuthError(
"Auth0 token response missing 'access_token'",
status_code=res.status_code,
response_text=res.text,
)
return token


TokenFetcher = Callable[[], Awaitable[str]]


class Auth0M2MAuth(httpx.Auth):
requires_response_body = True

def __init__(self, token_fetcher: TokenFetcher) -> None:
self._token_fetcher = token_fetcher
self._token: str | None = None
self._lock = asyncio.Lock()

async def _ensure_token(self, *, force: bool = False) -> str:
async with self._lock:
if force or self._token is None:
self._token = await self._token_fetcher()
return self._token

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
token = await self._ensure_token()
request.headers["Authorization"] = f"Bearer {token}"
response = yield request

if response.status_code != 401:
return

await response.aread()
token = await self._ensure_token(force=True)
request.headers["Authorization"] = f"Bearer {token}"
yield request

def sync_auth_flow(self, request: httpx.Request) -> Any: # pragma: no cover
raise RuntimeError("Auth0M2MAuth only supports async usage")


def env_bearer_token_headers_provider() -> Callable[[], dict[str, str]]:
"""
Build a headers provider backed by STITCH_CLIENT_BEARER_TOKEN.
Expand Down
46 changes: 46 additions & 0 deletions packages/stitch-client/src/stitch/client/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

import os
from dataclasses import dataclass

from .errors import StitchAuthError

_REQUIRED_ENV_VARS = (
"STITCH_AUTH_CLIENT_ID",
"STITCH_AUTH_CLIENT_SECRET",
"STITCH_AUTH_AUDIENCE",
"STITCH_AUTH_ISSUER_URL",
"STITCH_API_BASE_URL",
)


@dataclass(frozen=True)
class StitchClientConfig:
client_id: str
client_secret: str
audience: str
auth_issuer_url: str
api_base_url: str

@classmethod
def from_env(cls) -> "StitchClientConfig":
missing: list[str] = []
values: dict[str, str] = {}
for var in _REQUIRED_ENV_VARS:
v = os.environ.get(var)
if not v:
missing.append(var)
else:
values[var] = v
if missing:
raise StitchAuthError(
"Missing required environment variable(s) for "
f"StitchClientConfig.from_env(): {', '.join(missing)}"
)
return cls(
client_id=values["STITCH_AUTH_CLIENT_ID"],
client_secret=values["STITCH_AUTH_CLIENT_SECRET"],
audience=values["STITCH_AUTH_AUDIENCE"],
auth_issuer_url=values["STITCH_AUTH_ISSUER_URL"],
api_base_url=values["STITCH_API_BASE_URL"],
)
Loading
Loading