diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 7ca88110a..cd15c4d64 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -16,6 +16,9 @@ elicitation-sep1034-client-defaults - Elicitation with default accept callback auth/client-credentials-jwt - Client credentials with private_key_jwt auth/client-credentials-basic - Client credentials with client_secret_basic + auth/enterprise-token-exchange - Enterprise auth with OIDC ID token (SEP-990) + auth/enterprise-saml-exchange - Enterprise auth with SAML assertion (SEP-990) + auth/enterprise-id-jag-validation - Validate ID-JAG token structure (SEP-990) auth/* - Authorization code flow (default for auth scenarios) """ @@ -293,6 +296,255 @@ async def run_auth_code_client(server_url: str) -> None: await _run_auth_session(server_url, oauth_auth) +@register("auth/enterprise-token-exchange") +async def run_enterprise_token_exchange(server_url: str) -> None: + """Enterprise managed auth: Token exchange flow (RFC 8693).""" + from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, + ) + + context = get_conformance_context() + id_token = context.get("id_token") + idp_token_endpoint = context.get("idp_token_endpoint") + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") + mcp_server_resource_id = context.get("mcp_server_resource_id") + scope = context.get("scope") + + if not id_token: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'id_token'") + if not idp_token_endpoint: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") + if not mcp_server_auth_issuer: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'") + if not mcp_server_resource_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'") + + # Create token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer=mcp_server_auth_issuer, + mcp_server_resource_id=mcp_server_resource_id, + scope=scope, + ) + + # Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="conformance-enterprise-client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=InMemoryTokenStorage(), + idp_token_endpoint=idp_token_endpoint, + token_exchange_params=token_exchange_params, + ) + + # Perform token exchange flow + async with httpx.AsyncClient() as client: + # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow) + logger.debug(f"Setting OAuth metadata for {server_url}") + from pydantic import AnyUrl as PydanticAnyUrl + + from mcp.shared.auth import OAuthMetadata + + # Extract base URL from server_url + base_url = server_url.replace("/mcp", "") + token_endpoint_url = f"{base_url}/oauth/token" + auth_endpoint_url = f"{base_url}/oauth/authorize" + + enterprise_auth.context.oauth_metadata = OAuthMetadata( + issuer=mcp_server_auth_issuer, + authorization_endpoint=PydanticAnyUrl(auth_endpoint_url), + token_endpoint=PydanticAnyUrl(token_endpoint_url), + ) + logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}") + + # Step 2: Exchange ID token for ID-JAG + logger.debug("Exchanging ID token for ID-JAG") + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + logger.debug(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 3: Exchange ID-JAG for access token + logger.debug("Exchanging ID-JAG for access token") + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s") + + # Step 4: Verify we can make authenticated requests + logger.debug("Verifying access token with MCP endpoint") + auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"}) + response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp") + if response.status_code == 200: + logger.debug(f"Successfully authenticated with MCP server: {response.json()}") + else: + logger.warning(f"MCP server returned {response.status_code}") + + logger.debug("Enterprise auth flow completed successfully") + + +@register("auth/enterprise-saml-exchange") +async def run_enterprise_saml_exchange(server_url: str) -> None: + """Enterprise managed auth: SAML assertion exchange flow.""" + from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, + ) + + context = get_conformance_context() + saml_assertion = context.get("saml_assertion") + idp_token_endpoint = context.get("idp_token_endpoint") + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") + mcp_server_resource_id = context.get("mcp_server_resource_id") + scope = context.get("scope") + + if not saml_assertion: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'saml_assertion'") + if not idp_token_endpoint: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'") + if not mcp_server_auth_issuer: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'") + if not mcp_server_resource_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'") + + # Create token exchange parameters for SAML + token_exchange_params = TokenExchangeParameters.from_saml_assertion( + saml_assertion=saml_assertion, + mcp_server_auth_issuer=mcp_server_auth_issuer, + mcp_server_resource_id=mcp_server_resource_id, + scope=scope, + ) + + # Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="conformance-enterprise-saml-client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=InMemoryTokenStorage(), + idp_token_endpoint=idp_token_endpoint, + token_exchange_params=token_exchange_params, + ) + + # Perform token exchange flow + async with httpx.AsyncClient() as client: + # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow) + logger.debug(f"Setting OAuth metadata for {server_url}") + from pydantic import AnyUrl as PydanticAnyUrl + + from mcp.shared.auth import OAuthMetadata + + # Extract base URL from server_url + base_url = server_url.replace("/mcp", "") + token_endpoint_url = f"{base_url}/oauth/token" + auth_endpoint_url = f"{base_url}/oauth/authorize" + + enterprise_auth.context.oauth_metadata = OAuthMetadata( + issuer=mcp_server_auth_issuer, + authorization_endpoint=PydanticAnyUrl(auth_endpoint_url), + token_endpoint=PydanticAnyUrl(token_endpoint_url), + ) + logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}") + + # Step 2: Exchange SAML assertion for ID-JAG + logger.debug("Exchanging SAML assertion for ID-JAG") + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + logger.debug(f"Obtained ID-JAG from SAML: {id_jag[:50]}...") + + # Step 3: Exchange ID-JAG for access token + logger.debug("Exchanging ID-JAG for access token") + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s") + + # Step 4: Verify we can make authenticated requests + logger.debug("Verifying access token with MCP endpoint") + auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"}) + response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp") + if response.status_code == 200: + logger.debug(f"Successfully authenticated with MCP server: {response.json()}") + else: + logger.warning(f"MCP server returned {response.status_code}") + + logger.debug("SAML enterprise auth flow completed successfully") + + +@register("auth/enterprise-id-jag-validation") +async def run_id_jag_validation(server_url: str) -> None: + """Validate ID-JAG token structure and claims.""" + from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, + decode_id_jag, + validate_token_exchange_params, + ) + + context = get_conformance_context() + id_token = context.get("id_token") + idp_token_endpoint = context.get("idp_token_endpoint") + mcp_server_auth_issuer = context.get("mcp_server_auth_issuer") + mcp_server_resource_id = context.get("mcp_server_resource_id") + + if not all([id_token, idp_token_endpoint, mcp_server_auth_issuer, mcp_server_resource_id]): + raise RuntimeError("Missing required context parameters for ID-JAG validation") + + # Create and validate token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer=mcp_server_auth_issuer, + mcp_server_resource_id=mcp_server_resource_id, + ) + + logger.debug("Validating token exchange parameters") + validate_token_exchange_params(token_exchange_params) + logger.debug("Token exchange parameters validated successfully") + + # Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url=server_url, + client_metadata=OAuthClientMetadata( + client_name="conformance-validation-client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=InMemoryTokenStorage(), + idp_token_endpoint=idp_token_endpoint, + token_exchange_params=token_exchange_params, + ) + + async with httpx.AsyncClient() as client: + # Get ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + logger.debug(f"Obtained ID-JAG for validation: {id_jag[:50]}...") + + # Decode and validate ID-JAG claims + logger.debug("Decoding ID-JAG token") + claims = decode_id_jag(id_jag) + + # Validate required claims + assert claims.typ == "oauth-id-jag+jwt", f"Invalid typ: {claims.typ}" + assert claims.jti, "Missing jti claim" + assert claims.iss == mcp_server_auth_issuer or claims.iss, "Missing or invalid iss claim" + assert claims.sub, "Missing sub claim" + assert claims.aud, "Missing aud claim" + assert claims.resource == mcp_server_resource_id, f"Invalid resource: {claims.resource}" + assert claims.client_id, "Missing client_id claim" + assert claims.exp > claims.iat, "Invalid expiration" + + logger.debug("ID-JAG validated successfully:") + logger.debug(f" Subject: {claims.sub}") + logger.debug(f" Issuer: {claims.iss}") + logger.debug(f" Audience: {claims.aud}") + logger.debug(f" Resource: {claims.resource}") + logger.debug(f" Client ID: {claims.client_id}") + + logger.debug("ID-JAG validation completed successfully") + + async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: """Common session logic for all OAuth flows.""" client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) diff --git a/.github/actions/conformance/enterprise_auth_server.py b/.github/actions/conformance/enterprise_auth_server.py new file mode 100644 index 000000000..5065add0f --- /dev/null +++ b/.github/actions/conformance/enterprise_auth_server.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +"""Enterprise Auth Mock Server for Conformance Testing + +This server provides: +1. Mock IdP token exchange endpoint (RFC 8693) +2. MCP OAuth token endpoint (RFC 7523) +3. Protected MCP tools endpoint +4. OAuth metadata endpoint + +Run on port 3002 to avoid conflicts with everything-server. +""" + +import asyncio +import json +import logging +import secrets +import time +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any + +import jwt +from cryptography.hazmat.primitives.asymmetric import rsa +from fastapi import Depends, FastAPI, Form, HTTPException +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from uvicorn import Config, Server + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Generate RSA key pair for JWT signing +PRIVATE_KEY = rsa.generate_private_key(public_exponent=65537, key_size=2048) +PUBLIC_KEY = PRIVATE_KEY.public_key() + +# Server configuration +IDP_ISSUER = "https://conformance-idp.example.com" +MCP_SERVER_ISSUER = "https://conformance-mcp.example.com" +MCP_SERVER_RESOURCE = "https://conformance-mcp.example.com" + +# In-memory storage +ACCESS_TOKENS: dict[str, dict[str, Any]] = {} + +app = FastAPI() + +# Security +security = HTTPBearer(auto_error=False) + + +def create_test_id_token(subject: str = "test-user@example.com", client_id: str = "test-client") -> str: + """Create a test ID token.""" + now = datetime.now(timezone.utc) + claims = { + "iss": IDP_ISSUER, + "sub": subject, + "aud": client_id, + "exp": int((now + timedelta(hours=1)).timestamp()), + "iat": int(now.timestamp()), + "email": subject, + } + return jwt.encode(claims, PRIVATE_KEY, algorithm="RS256") + + +def create_id_jag( + subject: str, + audience: str, + resource: str, + client_id: str, + scope: str | None = None, +) -> str: + """Create an ID-JAG token.""" + now = datetime.now(timezone.utc) + claims = { + "jti": str(uuid.uuid4()), + "iss": IDP_ISSUER, + "sub": subject, + "aud": audience, + "resource": resource, + "client_id": client_id, + "exp": int((now + timedelta(minutes=5)).timestamp()), + "iat": int(now.timestamp()), + } + if scope: + claims["scope"] = scope + + return jwt.encode(claims, PRIVATE_KEY, algorithm="RS256", headers={"typ": "oauth-id-jag+jwt"}) + + +def verify_id_token(id_token: str) -> dict[str, Any]: + """Verify and decode an ID token.""" + try: + claims = jwt.decode(id_token, PUBLIC_KEY, algorithms=["RS256"], options={"verify_aud": False}) + return claims + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=400, detail=f"Invalid ID token: {e}") from e + + +def verify_id_jag(id_jag: str) -> dict[str, Any]: + """Verify and decode an ID-JAG token.""" + try: + header = jwt.get_unverified_header(id_jag) + if header.get("typ") != "oauth-id-jag+jwt": + raise HTTPException(status_code=400, detail="Invalid ID-JAG type") + claims = jwt.decode(id_jag, PUBLIC_KEY, algorithms=["RS256"], options={"verify_aud": False}) + return claims + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=400, detail=f"Invalid ID-JAG: {e}") from e + + +# OAuth Metadata Endpoint +@app.get("/.well-known/oauth-authorization-server") +async def oauth_metadata() -> JSONResponse: + """OAuth 2.0 Authorization Server Metadata (RFC 8414).""" + metadata = { + "issuer": MCP_SERVER_ISSUER, + "token_endpoint": "http://localhost:3002/oauth/token", + "grant_types_supported": ["urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post", "none"], + } + return JSONResponse(metadata) + + +# Token Exchange Endpoint (IdP) +@app.post("/token-exchange") +async def token_exchange( + grant_type: str = Form(...), + requested_token_type: str = Form(...), + audience: str = Form(...), + resource: str = Form(...), + subject_token: str = Form(...), + subject_token_type: str = Form(...), + scope: str | None = Form(None), + client_id: str | None = Form(None), + client_secret: str | None = Form(None), +) -> JSONResponse: + """RFC 8693 Token Exchange endpoint.""" + logger.info(f"Token exchange request: grant_type={grant_type}, subject_token_type={subject_token_type}") + + if grant_type != "urn:ietf:params:oauth:grant-type:token-exchange": + return JSONResponse( + status_code=400, + content={"error": "unsupported_grant_type", "error_description": "Only token-exchange grant supported"}, + ) + + if requested_token_type != "urn:ietf:params:oauth:token-type:id-jag": + return JSONResponse( + status_code=400, + content={ + "error": "invalid_request", + "error_description": f"Unsupported token type: {requested_token_type}", + }, + ) + + # Extract subject based on token type + if subject_token_type == "urn:ietf:params:oauth:token-type:id_token": + id_token_claims = verify_id_token(subject_token) + subject = id_token_claims["sub"] + elif subject_token_type == "urn:ietf:params:oauth:token-type:saml2": + # For SAML, extract from mock data + import base64 + + try: + saml_data = json.loads(base64.b64decode(subject_token)) + subject = saml_data.get("subject", "saml-user@example.com") + except Exception: + subject = "saml-user@example.com" + else: + return JSONResponse( + status_code=400, + content={ + "error": "invalid_request", + "error_description": f"Unsupported subject token type: {subject_token_type}", + }, + ) + + # Create ID-JAG + id_jag = create_id_jag( + subject=subject, + audience=audience, + resource=resource, + client_id=client_id or "test-client", + scope=scope, + ) + + response = { + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": id_jag, + "token_type": "N_A", + "expires_in": 300, + } + if scope: + response["scope"] = scope + + logger.info("Token exchange successful") + return JSONResponse(response) + + +# JWT Bearer Grant Endpoint (MCP Server) +@app.post("/oauth/token") +async def jwt_bearer_grant( + grant_type: str = Form(...), + assertion: str = Form(...), + client_id: str | None = Form(None), + client_secret: str | None = Form(None), +) -> JSONResponse: + """RFC 7523 JWT Bearer Grant endpoint.""" + logger.info(f"JWT bearer grant request: grant_type={grant_type}") + + if grant_type != "urn:ietf:params:oauth:grant-type:jwt-bearer": + return JSONResponse( + status_code=400, + content={"error": "unsupported_grant_type", "error_description": "Only jwt-bearer grant supported"}, + ) + + # Verify ID-JAG + id_jag_claims = verify_id_jag(assertion) + + # Create access token + access_token = secrets.token_urlsafe(32) + expires_in = 3600 + + # Store token info + ACCESS_TOKENS[access_token] = { + "subject": id_jag_claims["sub"], + "client_id": id_jag_claims.get("client_id"), + "scope": id_jag_claims.get("scope"), + "expires_at": time.time() + expires_in, + } + + response = { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": expires_in, + } + if id_jag_claims.get("scope"): + response["scope"] = id_jag_claims["scope"] + + logger.info("JWT bearer grant successful") + return JSONResponse(response) + + +# MCP Endpoints (protected by access token) +@app.get("/mcp") +async def mcp_root() -> JSONResponse: + """MCP root endpoint - returns basic info.""" + return JSONResponse({"server": "enterprise-auth-test-server", "version": "1.0"}) + + +@app.post("/mcp") +async def mcp_jsonrpc( + credentials: HTTPAuthorizationCredentials | None = Depends(security), +) -> JSONResponse: + """MCP JSON-RPC endpoint.""" + if not credentials: + raise HTTPException(status_code=401, detail="Missing authorization") + + token = credentials.credentials + if token not in ACCESS_TOKENS: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + token_info = ACCESS_TOKENS[token] + if token_info["expires_at"] < time.time(): + del ACCESS_TOKENS[token] + raise HTTPException(status_code=401, detail="Token expired") + + # Return a simple MCP response + return JSONResponse( + { + "jsonrpc": "2.0", + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "enterprise-auth-test-server", "version": "1.0"}, + }, + } + ) + + +# Helper endpoint to get test ID token +@app.get("/test/id-token") +async def get_test_id_token(subject: str = "test-user@example.com", client_id: str = "test-client") -> JSONResponse: + """Get a test ID token for conformance testing.""" + id_token = create_test_id_token(subject, client_id) + return JSONResponse({"id_token": id_token}) + + +@app.get("/test/context") +async def get_test_context() -> JSONResponse: + """Get complete test context for conformance testing.""" + id_token = create_test_id_token() + + # Create mock SAML assertion + import base64 + + saml_data = { + "issuer": IDP_ISSUER, + "subject": "saml-user@example.com", + "issued_at": datetime.now(timezone.utc).isoformat(), + } + saml_assertion = base64.b64encode(json.dumps(saml_data).encode()).decode() + + context = { + "id_token": id_token, + "saml_assertion": saml_assertion, + "idp_token_endpoint": "http://localhost:3002/token-exchange", + "mcp_server_auth_issuer": MCP_SERVER_ISSUER, + "mcp_server_resource_id": MCP_SERVER_RESOURCE, + "client_id": "test-client", + "scope": "mcp:tools mcp:resources", + } + + return JSONResponse(context) + + +async def run_server(port: int = 3002) -> None: + """Run the mock server.""" + config = Config(app=app, host="0.0.0.0", port=port, log_level="info") + server = Server(config) + logger.info(f"Starting Enterprise Auth Mock Server on port {port}") + logger.info(f"Token Exchange endpoint: http://localhost:{port}/token-exchange") + logger.info(f"JWT Bearer Grant endpoint: http://localhost:{port}/oauth/token") + logger.info(f"MCP endpoint: http://localhost:{port}/mcp") + logger.info(f"OAuth metadata: http://localhost:{port}/.well-known/oauth-authorization-server") + logger.info(f"Test context: http://localhost:{port}/test/context") + await server.serve() + + +if __name__ == "__main__": + import sys + + port = int(sys.argv[1]) if len(sys.argv) > 1 else 3002 + asyncio.run(run_server(port)) diff --git a/.github/actions/conformance/run-enterprise-auth-with-server.sh b/.github/actions/conformance/run-enterprise-auth-with-server.sh new file mode 100755 index 000000000..c8ddfdcc6 --- /dev/null +++ b/.github/actions/conformance/run-enterprise-auth-with-server.sh @@ -0,0 +1,169 @@ +#!/bin/bash +set -e + +# Enterprise Auth Full Conformance Test with Mock Server +# This script: +# 1. Starts the enterprise auth mock server (IdP + OAuth endpoints) +# 2. Fetches test context from the server +# 3. Runs all enterprise auth conformance scenarios +# 4. Cleans up servers on exit + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + +MOCK_SERVER_PORT=3002 +MOCK_SERVER_URL="http://localhost:${MOCK_SERVER_PORT}" + +echo "===================================================================" +echo " Enterprise Auth Conformance Tests with Mock Server (SEP-990)" +echo "===================================================================" +echo "" + +# Function to cleanup servers +cleanup() { + echo "" + echo "Cleaning up servers..." + if [ -n "$MOCK_SERVER_PID" ]; then + kill $MOCK_SERVER_PID 2>/dev/null || true + wait $MOCK_SERVER_PID 2>/dev/null || true + echo "✓ Mock server stopped" + fi +} + +trap cleanup EXIT + +# Start enterprise auth mock server +echo "Starting Enterprise Auth Mock Server on port ${MOCK_SERVER_PORT}..." +uv run --frozen python "$SCRIPT_DIR/enterprise_auth_server.py" $MOCK_SERVER_PORT > /tmp/enterprise_auth_server.log 2>&1 & +MOCK_SERVER_PID=$! + +# Wait for mock server to be ready +echo "Waiting for mock server to be ready..." +MAX_RETRIES=30 +RETRY_COUNT=0 +while ! curl -s "${MOCK_SERVER_URL}/test/context" > /dev/null 2>&1; do + RETRY_COUNT=$((RETRY_COUNT + 1)) + if [ $RETRY_COUNT -ge $MAX_RETRIES ]; then + echo "✗ Mock server failed to start after ${MAX_RETRIES} retries" >&2 + echo "Server log:" + cat /tmp/enterprise_auth_server.log + exit 1 + fi + sleep 0.5 +done + +echo "✓ Mock server ready at ${MOCK_SERVER_URL}" +echo "" + +# Fetch test context from server +echo "Fetching test context from mock server..." +export MCP_CONFORMANCE_CONTEXT=$(curl -s "${MOCK_SERVER_URL}/test/context") + +if [ -z "$MCP_CONFORMANCE_CONTEXT" ]; then + echo "✗ Failed to fetch test context" >&2 + exit 1 +fi + +echo "✓ Test context retrieved" +echo "" + +# Display server info +echo "Server Endpoints:" +echo " - Token Exchange (IdP): ${MOCK_SERVER_URL}/token-exchange" +echo " - OAuth Token (MCP): ${MOCK_SERVER_URL}/oauth/token" +echo " - MCP Endpoint: ${MOCK_SERVER_URL}/mcp" +echo " - OAuth Metadata: ${MOCK_SERVER_URL}/.well-known/oauth-authorization-server" +echo "" + +# Run conformance scenarios +echo "===================================================================" +echo " Running Conformance Test Scenarios" +echo "===================================================================" +echo "" + +# Test 1: ID-JAG Validation +echo "--- Test 1: ID-JAG Token Validation ---" +export MCP_CONFORMANCE_SCENARIO="auth/enterprise-id-jag-validation" +if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test1.log | grep -q "ID-JAG validation completed successfully"; then + echo "✓ ID-JAG validation PASSED" + TEST1_PASS=true +else + echo "✗ ID-JAG validation FAILED" + echo "Log output:" + cat /tmp/test1.log | tail -20 + TEST1_PASS=false +fi +echo "" + +# Test 2: OIDC ID Token Exchange Flow +echo "--- Test 2: OIDC ID Token Exchange Flow ---" +export MCP_CONFORMANCE_SCENARIO="auth/enterprise-token-exchange" +if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test2.log | grep -q "Enterprise auth flow completed successfully"; then + echo "✓ OIDC token exchange PASSED" + TEST2_PASS=true +else + echo "✗ OIDC token exchange FAILED" + echo "Log output:" + cat /tmp/test2.log | tail -20 + TEST2_PASS=false +fi +echo "" + +# Test 3: SAML Assertion Exchange Flow +echo "--- Test 3: SAML Assertion Exchange Flow ---" +export MCP_CONFORMANCE_SCENARIO="auth/enterprise-saml-exchange" +if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test3.log | grep -q "SAML enterprise auth flow completed successfully"; then + echo "✓ SAML assertion exchange PASSED" + TEST3_PASS=true +else + echo "✗ SAML assertion exchange FAILED" + echo "Log output:" + cat /tmp/test3.log | tail -20 + TEST3_PASS=false +fi +echo "" + +# Summary +echo "===================================================================" +echo " Test Results Summary" +echo "===================================================================" +echo "" + +TESTS_PASSED=0 +TESTS_FAILED=0 + +if [ "$TEST1_PASS" = true ]; then + echo "✓ ID-JAG Validation" + TESTS_PASSED=$((TESTS_PASSED + 1)) +else + echo "✗ ID-JAG Validation" + TESTS_FAILED=$((TESTS_FAILED + 1)) +fi + +if [ "$TEST2_PASS" = true ]; then + echo "✓ OIDC Token Exchange" + TESTS_PASSED=$((TESTS_PASSED + 1)) +else + echo "✗ OIDC Token Exchange" + TESTS_FAILED=$((TESTS_FAILED + 1)) +fi + +if [ "$TEST3_PASS" = true ]; then + echo "✓ SAML Assertion Exchange" + TESTS_PASSED=$((TESTS_PASSED + 1)) +else + echo "✗ SAML Assertion Exchange" + TESTS_FAILED=$((TESTS_FAILED + 1)) +fi + +echo "" +echo "Total: ${TESTS_PASSED}/3 passed, ${TESTS_FAILED}/3 failed" +echo "" + +if [ $TESTS_FAILED -eq 0 ]; then + echo "🎉 All enterprise auth conformance tests PASSED!" + exit 0 +else + echo "❌ Some tests failed. Check logs above for details." + exit 1 +fi diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 248e5bf6a..3b2c8cbe0 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -43,3 +43,18 @@ jobs: node-version: 24 - run: uv sync --frozen --all-extras --package mcp - run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all + + enterprise-auth-conformance: + runs-on: ubuntu-latest + continue-on-error: true + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # v7.2.0 + with: + enable-cache: true + version: 0.9.5 + - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + with: + node-version: 24 + - run: uv sync --frozen --all-extras --package mcp + - run: ./.github/actions/conformance/run-enterprise-auth-with-server.sh diff --git a/README.md b/README.md index 468e1d85d..e6f39968a 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ - [Writing MCP Clients](#writing-mcp-clients) - [Client Display Utilities](#client-display-utilities) - [OAuth Authentication for Clients](#oauth-authentication-for-clients) + - [Enterprise Managed Authorization](#enterprise-managed-authorization) - [Parsing Tool Results](#parsing-tool-results) - [MCP Primitives](#mcp-primitives) - [Server Capabilities](#server-capabilities) @@ -2421,6 +2422,288 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/). +#### Enterprise Managed Authorization + +The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports: + +- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token → ID-JAG) +- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG → Access Token) +- Integration with enterprise identity providers (Okta, Azure AD, etc.) + +**Key Components:** + +The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow: + +**Token Exchange Flow:** + +1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD) +2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange +3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant +4. **Use Access Token** to call protected MCP server tools + +**Using the Access Token with MCP Server:** + +1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server +2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions. + +**Handling Token Expiration and Refresh:** + +Access tokens have a limited lifetime and will expire. When tokens expire: + +- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires +- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP +- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production) +- **Error Handling**: Catch authentication errors and retry with refreshed tokens + +**Important Notes:** + +- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange +- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts +- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes +- **Security**: Never log or expose access tokens or ID tokens in production environments + +**Example Usage:** + + +```python +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + +import httpx +from pydantic import AnyUrl + +from mcp import ClientSession +from mcp.client.auth import OAuthTokenError, TokenStorage +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.client.sse import sse_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.types import CallToolResult + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """Placeholder function to get ID token from your IdP. + + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +def is_token_expired(access_token: OAuthToken) -> bool: + """Check if the access token has expired.""" + if access_token.expires_in: + # Calculate expiration time + issued_at = datetime.now(timezone.utc) + expiration_time = issued_at + timedelta(seconds=access_token.expires_in) + return datetime.now(timezone.utc) >= expiration_time + return False + + +async def refresh_access_token( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> OAuthToken: + """Refresh the access token when it expires.""" + try: + # Update token exchange parameters with fresh ID token + enterprise_auth.token_exchange_params.subject_token = id_token + + # Re-exchange for new ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + + # Get new access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + return access_token + except Exception as e: + print(f"Token refresh failed: {e}") + # Re-authenticate with IdP if ID token is also expired + id_token = await get_id_token_from_idp() + return await refresh_access_token(enterprise_auth, client, id_token) + + +async def call_tool_with_retry( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> CallToolResult | None: + """Call a tool with automatic retry on token expiration.""" + max_retries = 1 + + for attempt in range(max_retries + 1): + try: + result = await session.call_tool(tool_name, arguments) + return result + except OAuthTokenError: + if attempt < max_retries: + print("Token expired, refreshing...") + # Refresh token and reconnect + _access_token = await refresh_access_token(enterprise_auth, client, id_token) + # Note: In production, you'd need to reconnect the session here + else: + raise + return None + + +async def main() -> None: + # Step 1: Get ID token from your IdP (example with Okta) + id_token = await get_id_token_from_idp() # Your IdP authentication + + # Step 2: Configure token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL + mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 3: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + async with httpx.AsyncClient() as client: + # Step 4: Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 5: Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + print(f"Access token obtained, expires in: {access_token.expires_in}s") + + # Step 6: Check if token is expired (for demonstration) + if is_token_expired(access_token): + print("Token is expired, refreshing...") + access_token = await refresh_access_token(enterprise_auth, client, id_token) + + # Step 7: Use the access token to connect to MCP server + headers = {"Authorization": f"Bearer {access_token.access_token}"} + + async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Call tools with automatic retry on token expiration + result = await call_tool_with_retry( + session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token + ) + if result: + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def maintain_active_session( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + mcp_server_url: str, +) -> None: + """Maintain an active session with automatic token refresh.""" + id_token_var = await get_id_token_from_idp() + + async with httpx.AsyncClient() as client: + while True: + try: + # Update token exchange params with current ID token + enterprise_auth.token_exchange_params.subject_token = id_token_var + + # Get access token + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + + # Calculate refresh time (refresh before expiration) + refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300 + + # Use the token for MCP operations + headers = {"Authorization": f"Bearer {access_token.access_token}"} + async with sse_client(mcp_server_url, headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Perform operations... + # Schedule refresh before token expires + await asyncio.sleep(refresh_in) + + except Exception as e: + print(f"Session error: {e}") + # Re-authenticate with IdP + id_token_var = await get_id_token_from_idp() + await asyncio.sleep(5) # Wait before retry + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_ + + +**Working with SAML Assertions:** + +If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions: + +```python +token_exchange_params = TokenExchangeParameters.from_saml_assertion( + saml_assertion=saml_assertion_string, + mcp_server_auth_issuer="https://your-idp.com", + mcp_server_resource_id="https://mcp-server.example.com", + scope="mcp:tools", +) +``` + +**Decoding and Inspecting ID-JAG Tokens:** + +You can decode ID-JAG tokens to inspect their claims: + +```python +from mcp.client.auth.extensions import decode_id_jag + +# Decode without signature verification (for inspection only) +claims = decode_id_jag(id_jag) +print(f"Subject: {claims.sub}") +print(f"Issuer: {claims.iss}") +print(f"Audience: {claims.aud}") +print(f"Client ID: {claims.client_id}") +print(f"Resource: {claims.resource}") +``` + ### Parsing Tool Results When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs. diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py new file mode 100644 index 000000000..57070043a --- /dev/null +++ b/examples/snippets/clients/enterprise_managed_auth_client.py @@ -0,0 +1,204 @@ +import asyncio +from datetime import datetime, timedelta, timezone +from typing import Any + +import httpx +from pydantic import AnyUrl + +from mcp import ClientSession +from mcp.client.auth import OAuthTokenError, TokenStorage +from mcp.client.auth.extensions import ( + EnterpriseAuthOAuthClientProvider, + TokenExchangeParameters, +) +from mcp.client.sse import sse_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.types import CallToolResult + + +# Placeholder function for IdP authentication +async def get_id_token_from_idp() -> str: + """Placeholder function to get ID token from your IdP. + + In production, implement actual IdP authentication flow. + """ + raise NotImplementedError("Implement your IdP authentication flow here") + + +# Define token storage implementation +class SimpleTokenStorage(TokenStorage): + def __init__(self) -> None: + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +def is_token_expired(access_token: OAuthToken) -> bool: + """Check if the access token has expired.""" + if access_token.expires_in: + # Calculate expiration time + issued_at = datetime.now(timezone.utc) + expiration_time = issued_at + timedelta(seconds=access_token.expires_in) + return datetime.now(timezone.utc) >= expiration_time + return False + + +async def refresh_access_token( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> OAuthToken: + """Refresh the access token when it expires.""" + try: + # Update token exchange parameters with fresh ID token + enterprise_auth.token_exchange_params.subject_token = id_token + + # Re-exchange for new ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + + # Get new access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + return access_token + except Exception as e: + print(f"Token refresh failed: {e}") + # Re-authenticate with IdP if ID token is also expired + id_token = await get_id_token_from_idp() + return await refresh_access_token(enterprise_auth, client, id_token) + + +async def call_tool_with_retry( + session: ClientSession, + tool_name: str, + arguments: dict[str, Any], + enterprise_auth: EnterpriseAuthOAuthClientProvider, + client: httpx.AsyncClient, + id_token: str, +) -> CallToolResult | None: + """Call a tool with automatic retry on token expiration.""" + max_retries = 1 + + for attempt in range(max_retries + 1): + try: + result = await session.call_tool(tool_name, arguments) + return result + except OAuthTokenError: + if attempt < max_retries: + print("Token expired, refreshing...") + # Refresh token and reconnect + _access_token = await refresh_access_token(enterprise_auth, client, id_token) + # Note: In production, you'd need to reconnect the session here + else: + raise + return None + + +async def main() -> None: + # Step 1: Get ID token from your IdP (example with Okta) + id_token = await get_id_token_from_idp() # Your IdP authentication + + # Step 2: Configure token exchange parameters + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=id_token, + mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL + mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID + scope="mcp:tools mcp:resources", # Optional scopes + ) + + # Step 3: Create enterprise auth provider + enterprise_auth = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example.com", + client_metadata=OAuthClientMetadata( + client_name="Enterprise MCP Client", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"], + response_types=["token"], + ), + storage=SimpleTokenStorage(), + idp_token_endpoint="https://your-idp.com/oauth2/v1/token", + token_exchange_params=token_exchange_params, + ) + + async with httpx.AsyncClient() as client: + # Step 4: Exchange ID token for ID-JAG + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + print(f"Obtained ID-JAG: {id_jag[:50]}...") + + # Step 5: Exchange ID-JAG for access token + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + print(f"Access token obtained, expires in: {access_token.expires_in}s") + + # Step 6: Check if token is expired (for demonstration) + if is_token_expired(access_token): + print("Token is expired, refreshing...") + access_token = await refresh_access_token(enterprise_auth, client, id_token) + + # Step 7: Use the access token to connect to MCP server + headers = {"Authorization": f"Bearer {access_token.access_token}"} + + async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Call tools with automatic retry on token expiration + result = await call_tool_with_retry( + session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token + ) + if result: + print(f"Tool result: {result.content}") + + # List available resources + resources = await session.list_resources() + for resource in resources.resources: + print(f"Resource: {resource.uri}") + + +async def maintain_active_session( + enterprise_auth: EnterpriseAuthOAuthClientProvider, + mcp_server_url: str, +) -> None: + """Maintain an active session with automatic token refresh.""" + id_token_var = await get_id_token_from_idp() + + async with httpx.AsyncClient() as client: + while True: + try: + # Update token exchange params with current ID token + enterprise_auth.token_exchange_params.subject_token = id_token_var + + # Get access token + id_jag = await enterprise_auth.exchange_token_for_id_jag(client) + access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag) + + # Calculate refresh time (refresh before expiration) + refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300 + + # Use the token for MCP operations + headers = {"Authorization": f"Bearer {access_token.access_token}"} + async with sse_client(mcp_server_url, headers=headers) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # Perform operations... + # Schedule refresh before token expires + await asyncio.sleep(refresh_in) + + except Exception as e: + print(f"Session error: {e}") + # Re-authenticate with IdP + id_token_var = await get_id_token_from_idp() + await asyncio.sleep(5) # Wait before retry + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 87eac7213..ca19e33a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dev = [ "dirty-equals>=0.9.0", "coverage[toml]>=7.13.1", "pillow>=12.0", + "fastapi>=0.115.0", # For enterprise auth conformance test mock server ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py index e69de29bb..56ba368ef 100644 --- a/src/mcp/client/auth/extensions/__init__.py +++ b/src/mcp/client/auth/extensions/__init__.py @@ -0,0 +1,19 @@ +"""MCP Client Auth Extensions.""" + +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) + +__all__ = [ + "EnterpriseAuthOAuthClientProvider", + "IDJAGClaims", + "TokenExchangeParameters", + "TokenExchangeResponse", + "decode_id_jag", + "validate_token_exchange_params", +] diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py new file mode 100644 index 000000000..5ebc19c56 --- /dev/null +++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py @@ -0,0 +1,416 @@ +"""Enterprise Managed Authorization extension for MCP (SEP-990). + +Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for +enterprise SSO integration. +""" + +import logging +from collections.abc import Awaitable, Callable + +import httpx +import jwt +from pydantic import BaseModel, Field +from typing_extensions import TypedDict + +from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage +from mcp.shared.auth import OAuthClientMetadata, OAuthToken + +logger = logging.getLogger(__name__) + + +class TokenExchangeRequestData(TypedDict, total=False): + """Type definition for RFC 8693 Token Exchange request data.""" + + grant_type: str + requested_token_type: str + audience: str + resource: str + subject_token: str + subject_token_type: str + scope: str + client_id: str + client_secret: str + + +class JWTBearerGrantRequestData(TypedDict, total=False): + """Type definition for RFC 7523 JWT Bearer Grant request data.""" + + grant_type: str + assertion: str + client_id: str + client_secret: str + + +class TokenExchangeParameters(BaseModel): + """Parameters for RFC 8693 Token Exchange request.""" + + requested_token_type: str = Field( + default="urn:ietf:params:oauth:token-type:id-jag", + description="Type of token being requested (ID-JAG)", + ) + + audience: str = Field( + ..., + description="Issuer URL of the MCP Server's authorization server", + ) + + resource: str = Field( + ..., + description="RFC 9728 Resource Identifier of the MCP Server", + ) + + scope: str | None = Field( + default=None, + description="Space-separated list of scopes being requested", + ) + + subject_token: str = Field( + ..., + description="ID Token or SAML assertion for the end user", + ) + + subject_token_type: str = Field( + ..., + description="Type of subject token (id_token or saml2)", + ) + + @classmethod + def from_id_token( + cls, + id_token: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for OIDC ID Token exchange.""" + return cls( + subject_token=id_token, + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + @classmethod + def from_saml_assertion( + cls, + saml_assertion: str, + mcp_server_auth_issuer: str, + mcp_server_resource_id: str, + scope: str | None = None, + ) -> "TokenExchangeParameters": + """Create parameters for SAML assertion exchange.""" + return cls( + subject_token=saml_assertion, + subject_token_type="urn:ietf:params:oauth:token-type:saml2", + audience=mcp_server_auth_issuer, + resource=mcp_server_resource_id, + scope=scope, + ) + + +class TokenExchangeResponse(BaseModel): + """Response from RFC 8693 Token Exchange.""" + + issued_token_type: str = Field( + ..., + description="Type of token issued (should be id-jag)", + ) + + access_token: str = Field( + ..., + description="The ID-JAG token (named access_token per RFC 8693)", + ) + + token_type: str = Field( + ..., + description="Token type (should be N_A for ID-JAG)", + ) + + scope: str | None = Field( + default=None, + description="Granted scopes", + ) + + expires_in: int | None = Field( + default=None, + description="Lifetime in seconds", + ) + + @property + def id_jag(self) -> str: + """Get the ID-JAG token.""" + return self.access_token + + +class IDJAGClaims(BaseModel): + """Claims structure for Identity Assertion JWT Authorization Grant.""" + + model_config = {"extra": "allow"} + + # JWT header + typ: str = Field( + ..., + description="JWT type - must be 'oauth-id-jag+jwt'", + ) + + # Required claims + jti: str = Field(..., description="Unique JWT ID") + iss: str = Field(..., description="IdP issuer URL") + sub: str = Field(..., description="Subject (user) identifier") + aud: str = Field(..., description="MCP Server's auth server issuer") + resource: str = Field(..., description="MCP Server resource identifier") + client_id: str = Field(..., description="MCP Client identifier") + exp: int = Field(..., description="Expiration timestamp") + iat: int = Field(..., description="Issued-at timestamp") + + # Optional claims + scope: str | None = Field(None, description="Space-separated scopes") + email: str | None = Field(None, description="User email") + + +class EnterpriseAuthOAuthClientProvider(OAuthClientProvider): + """OAuth client provider for Enterprise Managed Authorization (SEP-990). + + Implements: + - RFC 8693: Token Exchange (ID Token → ID-JAG) + - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token) + """ + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + idp_token_endpoint: str, + token_exchange_params: TokenExchangeParameters, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, + timeout: float = 300.0, + ) -> None: + """Initialize Enterprise Auth OAuth Client. + + Args: + server_url: MCP server URL + client_metadata: OAuth client metadata + storage: Token storage implementation + idp_token_endpoint: Enterprise IdP token endpoint URL + token_exchange_params: Token exchange parameters + redirect_handler: Optional redirect handler + callback_handler: Optional callback handler + timeout: Request timeout in seconds + """ + super().__init__( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + ) + self.idp_token_endpoint = idp_token_endpoint + self.token_exchange_params = token_exchange_params + self._id_jag: str | None = None + + async def exchange_token_for_id_jag( + self, + client: httpx.AsyncClient, + ) -> str: + """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange. + + Args: + client: HTTP client for making requests + + Returns: + The ID-JAG token string + + Raises: + OAuthTokenError: If token exchange fails + """ + logger.info("Starting token exchange for ID-JAG") + + # Build token exchange request + token_data: TokenExchangeRequestData = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": self.token_exchange_params.requested_token_type, + "audience": self.token_exchange_params.audience, + "resource": self.token_exchange_params.resource, + "subject_token": self.token_exchange_params.subject_token, + "subject_token_type": self.token_exchange_params.subject_token_type, + } + + if self.token_exchange_params.scope: + token_data["scope"] = self.token_exchange_params.scope + + # Add client authentication if needed + if self.context.client_info: + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret is not None: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + self.idp_token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data: dict[str, str] = ( + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} + ) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "Token exchange failed") + raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}") + + # Parse response + token_response = TokenExchangeResponse.model_validate_json(response.content) + + # Validate response + if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag": + raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}") + + if token_response.token_type != "N_A": + logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'") + + logger.info("Successfully obtained ID-JAG") + self._id_jag = token_response.id_jag + return token_response.id_jag + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e + + async def exchange_id_jag_for_access_token( + self, + client: httpx.AsyncClient, + id_jag: str, + ) -> OAuthToken: + """Exchange ID-JAG for access token using RFC 7523 JWT Bearer Grant. + + Args: + client: HTTP client for making requests + id_jag: The ID-JAG token + + Returns: + OAuth access token + + Raises: + OAuthTokenError: If JWT bearer grant fails + """ + logger.info("Exchanging ID-JAG for access token") + + # Discover token endpoint from MCP server if not already done + if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint: + raise OAuthFlowError("MCP server token endpoint not discovered") + + token_endpoint = str(self.context.oauth_metadata.token_endpoint) + + # Build JWT bearer grant request + token_data: JWTBearerGrantRequestData = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": id_jag, + } + + # Add client authentication + if self.context.client_info: + if self.context.client_info.client_id is not None: + token_data["client_id"] = self.context.client_info.client_id + if self.context.client_info.client_secret is not None: + token_data["client_secret"] = self.context.client_info.client_secret + + try: + response = await client.post( + token_endpoint, + data=token_data, + timeout=self.context.timeout, + ) + + if response.status_code != 200: + error_data: dict[str, str] = ( + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} + ) + error: str = error_data.get("error", "unknown_error") + error_description: str = error_data.get("error_description", "JWT bearer grant failed") + raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}") + + # Parse OAuth token response + token = OAuthToken.model_validate_json(response.content) + + # Store tokens + self.context.current_tokens = token + self.context.update_token_expiry(token) + await self.context.storage.set_tokens(token) + + logger.info("Successfully obtained access token via ID-JAG") + return token + + except httpx.HTTPError as e: + raise OAuthTokenError(f"HTTP error during JWT bearer grant: {e}") from e + + async def _perform_authorization(self) -> httpx.Request: + """Perform enterprise authorization flow. + + Overrides parent method to use token exchange + JWT bearer grant + instead of standard authorization code flow. + """ + # Check if we already have valid tokens + if self.context.is_token_valid(): + # Return a dummy request - we don't need to make any request + return httpx.Request("GET", self.context.server_url) + + # For now, raise NotImplementedError as this requires integration + # with the full httpx auth flow + raise NotImplementedError( + "Full enterprise auth flow integration not yet implemented. " + "Use exchange_token_for_id_jag and exchange_id_jag_for_access_token directly." + ) + + +def decode_id_jag(id_jag: str) -> IDJAGClaims: + """Decode an ID-JAG token without verification. + + Args: + id_jag: The ID-JAG token string + + Returns: + Decoded ID-JAG claims + + Note: + For verification, use server-side validation instead. + """ + # Decode without verification for inspection + claims = jwt.decode(id_jag, options={"verify_signature": False}) + header = jwt.get_unverified_header(id_jag) + + # Add typ from header to claims + claims["typ"] = header.get("typ", "") + + return IDJAGClaims.model_validate(claims) + + +def validate_token_exchange_params( + params: TokenExchangeParameters, +) -> None: + """Validate token exchange parameters. + + Args: + params: Token exchange parameters to validate + + Raises: + ValueError: If parameters are invalid + """ + if not params.subject_token: + raise ValueError("subject_token is required") + + if not params.audience: + raise ValueError("audience is required") + + if not params.resource: + raise ValueError("resource is required") + + if params.subject_token_type not in [ + "urn:ietf:params:oauth:token-type:id_token", + "urn:ietf:params:oauth:token-type:saml2", + ]: + raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}") diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py new file mode 100644 index 000000000..729189766 --- /dev/null +++ b/tests/client/auth/test_enterprise_managed_auth_client.py @@ -0,0 +1,1108 @@ +"""Tests for Enterprise Managed Authorization client-side implementation.""" + +import time +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import jwt +import pytest +from pydantic import AnyHttpUrl, AnyUrl + +from mcp.client.auth import OAuthTokenError +from mcp.client.auth.extensions.enterprise_managed_auth import ( + EnterpriseAuthOAuthClientProvider, + IDJAGClaims, + TokenExchangeParameters, + TokenExchangeResponse, + decode_id_jag, + validate_token_exchange_params, +) +from mcp.shared.auth import OAuthClientMetadata + + +@pytest.fixture +def sample_id_token() -> str: + """Generate a sample ID token for testing.""" + payload = { + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "mcp-client-app", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + "email": "user@example.com", + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +@pytest.fixture +def sample_id_jag() -> str: + """Generate a sample ID-JAG token for testing.""" + # Create typed claims using IDJAGClaims model + claims = IDJAGClaims( + typ="oauth-id-jag+jwt", + jti="unique-jwt-id-12345", + iss="https://idp.example.com", + sub="user123", + aud="https://auth.mcp-server.example/", + resource="https://mcp-server.example/", + client_id="mcp-client-app", + exp=int(time.time()) + 300, + iat=int(time.time()), + scope="read write", + email=None, # Optional field + ) + + # Dump to dict for JWT encoding (exclude typ as it goes in header) + payload = claims.model_dump(exclude={"typ"}, exclude_none=True) + + return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"}) + + +@pytest.fixture +def mock_token_storage() -> Any: + """Create a mock token storage.""" + storage = Mock() + storage.get_tokens = AsyncMock(return_value=None) + storage.set_tokens = AsyncMock() + storage.get_client_info = AsyncMock(return_value=None) + storage.set_client_info = AsyncMock() + return storage + + +def test_token_exchange_params_from_id_token(): + """Test creating TokenExchangeParameters from ID token.""" + params = TokenExchangeParameters.from_id_token( + id_token="eyJhbGc...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read write", + ) + + assert params.subject_token == "eyJhbGc..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read write" + assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag" + + +def test_token_exchange_params_from_saml_assertion(): + """Test creating TokenExchangeParameters from SAML assertion.""" + params = TokenExchangeParameters.from_saml_assertion( + saml_assertion="...", + mcp_server_auth_issuer="https://auth.server.example/", + mcp_server_resource_id="https://server.example/", + scope="read", + ) + + assert params.subject_token == "..." + assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2" + assert params.audience == "https://auth.server.example/" + assert params.resource == "https://server.example/" + assert params.scope == "read" + + +def test_validate_token_exchange_params_valid(): + """Test validating valid token exchange parameters.""" + params = TokenExchangeParameters.from_id_token( + id_token="token", + mcp_server_auth_issuer="https://auth.example/", + mcp_server_resource_id="https://server.example/", + ) + + # Should not raise + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_invalid_token_type(): + """Test validation fails for invalid subject token type.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="invalid:type", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="Invalid subject_token_type"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_subject_token(): + """Test validation fails for missing subject token.""" + params = TokenExchangeParameters( + subject_token="", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="subject_token is required"): + validate_token_exchange_params(params) + + +def test_token_exchange_response_parsing(): + """Test parsing token exchange response.""" + response_json = """{ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": "eyJhbGc...", + "token_type": "N_A", + "scope": "read write", + "expires_in": 300 + }""" + + response = TokenExchangeResponse.model_validate_json(response_json) + + assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag" + assert response.id_jag == "eyJhbGc..." + assert response.access_token == "eyJhbGc..." + assert response.token_type == "N_A" + assert response.scope == "read write" + assert response.expires_in == 300 + + +def test_token_exchange_response_id_jag_property(): + """Test id_jag property returns access_token.""" + response = TokenExchangeResponse( + issued_token_type="urn:ietf:params:oauth:token-type:id-jag", + access_token="the-id-jag-token", + token_type="N_A", + ) + + assert response.id_jag == "the-id-jag-token" + + +def test_decode_id_jag(sample_id_jag: str): + """Test decoding ID-JAG token.""" + claims = decode_id_jag(sample_id_jag) + + assert claims.iss == "https://idp.example.com" + assert claims.sub == "user123" + assert claims.aud == "https://auth.mcp-server.example/" + assert claims.resource == "https://mcp-server.example/" + assert claims.client_id == "mcp-client-app" + assert claims.scope == "read write" + + +def test_id_jag_claims_with_extra_fields(): + """Test IDJAGClaims allows extra fields.""" + claims_data = { + "typ": "oauth-id-jag+jwt", + "jti": "jti123", + "iss": "https://idp.example.com", + "sub": "user123", + "aud": "https://auth.server.example/", + "resource": "https://server.example/", + "client_id": "client123", + "exp": int(time.time()) + 300, + "iat": int(time.time()), + "scope": "read", + "email": "user@example.com", + "custom_claim": "custom_value", # Extra field + } + + claims = IDJAGClaims.model_validate(claims_data) + assert claims.email == "user@example.com" + # Extra field should be preserved + assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value" + + +# ============================================================================ +# Tests for EnterpriseAuthOAuthClientProvider +# ============================================================================ + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test successful token exchange for ID-JAG.""" + # Create provider + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify + assert id_jag == sample_id_jag + assert provider._id_jag == sample_id_jag + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "https://idp.example.com/oauth2/token" + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange" + assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag" + assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/" + assert call_args[1]["data"]["resource"] == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange failure handling.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_request", + "error_description": "Invalid subject token", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Token exchange failed"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with unexpected token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with wrong token type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "access_token": "some-token", + "token_type": "Bearer", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="Unexpected token type"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any): + """Test successful JWT bearer grant to get access token.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + from mcp.shared.auth import OAuthMetadata + + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify tokens were stored + mock_token_storage.set_tokens.assert_called_once() + + # Verify request was made correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert call_args[1]["data"]["assertion"] == sample_id_jag + + +@pytest.mark.anyio +async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant fails without OAuth metadata.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # No OAuth metadata set + mock_client = Mock(spec=httpx.AsyncClient) + + # Should raise OAuthFlowError + from mcp.client.auth import OAuthFlowError + + with pytest.raises(OAuthFlowError, match="token endpoint not discovered"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_perform_authorization_not_implemented(mock_token_storage: Any): + """Test that _perform_authorization raises NotImplementedError.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Should raise NotImplementedError + with pytest.raises(NotImplementedError, match="not yet implemented"): + await provider._perform_authorization() + + +@pytest.mark.anyio +async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any): + """Test that _perform_authorization returns dummy request when tokens are valid.""" + from mcp.shared.auth import OAuthToken + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set valid tokens + provider.context.current_tokens = OAuthToken( + token_type="Bearer", + access_token="valid-token", + expires_in=3600, + ) + provider.context.token_expiry_time = time.time() + 3600 + + # Should return a dummy request + request = await provider._perform_authorization() + assert request.method == "GET" + assert str(request.url) == "https://mcp-server.example/" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_authentication( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any): + """Test token exchange with client_id but no client_secret (covers branch 232->235).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with HTTP error.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any): + """Test token exchange with non-JSON error response.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=500, + content=b"Internal Server Error", + headers={"content-type": "text/plain"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"): + await provider.exchange_token_for_id_jag(mock_client) + + +@pytest.mark.anyio +async def test_exchange_token_warning_for_non_na_token_type( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange logs warning for non-N_A token type.""" + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Mock response with different token_type + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "Bearer", # Not N_A + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should succeed but log warning + import logging + + with patch.object( + logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning" + ) as mock_warning: + id_jag = await provider.exchange_token_for_id_jag(mock_client) + assert id_jag == sample_id_jag + mock_warning.assert_called_once() + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with client authentication.""" + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with secret + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify token was returned + assert token.access_token == "mcp-access-token-12345" + + # Verify client credentials were included + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert call_args[1]["data"]["client_secret"] == "test-client-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307).""" + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info WITHOUT secret (client_secret=None) + provider.context.client_info = OAuthClientInformationFull( + client_id="test-client-id", + client_secret=None, # No secret + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify token was returned correctly + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + + # Verify client_id was included but NOT client_secret + call_args = mock_client.post.call_args + assert call_args[1]["data"]["client_id"] == "test-client-id" + assert "client_secret" not in call_args[1]["data"] + + +@pytest.mark.anyio +async def test_exchange_id_jag_error_response(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with error response.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response + mock_response = httpx.Response( + status_code=400, + json={ + "error": "invalid_grant", + "error_description": "Invalid assertion", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_non_json_error(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with non-JSON error response.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Mock error response with non-JSON content + mock_response = httpx.Response( + status_code=503, + content=b"Service Unavailable", + headers={"content-type": "text/html"}, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Should raise OAuthTokenError with default error + with pytest.raises(OAuthTokenError, match="JWT bearer grant failed: unknown_error"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage: Any): + """Test JWT bearer grant with HTTP error.""" + from mcp.shared.auth import OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(side_effect=httpx.ReadTimeout("Request timeout")) + + # Should raise OAuthTokenError + with pytest.raises(OAuthTokenError, match="HTTP error during JWT bearer grant"): + await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + +@pytest.mark.anyio +async def test_exchange_token_with_client_info_but_no_client_id( + sample_id_token: str, sample_id_jag: str, mock_token_storage: Any +): + """Test token exchange when client_info exists but client_id is None (covers line 231).""" + from mcp.shared.auth import OAuthClientInformationFull + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token=sample_id_token, + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + scope="read write", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + client_name="Test Client", + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag", + "access_token": sample_id_jag, + "token_type": "N_A", + "scope": "read write", + "expires_in": 300, + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform token exchange + id_jag = await provider.exchange_token_for_id_jag(mock_client) + + # Verify the ID-JAG was returned + assert id_jag == sample_id_jag + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + +@pytest.mark.anyio +async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any): + """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302).""" + from pydantic import AnyHttpUrl + + from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata + + token_exchange_params = TokenExchangeParameters.from_id_token( + id_token="dummy-token", + mcp_server_auth_issuer="https://auth.mcp-server.example/", + mcp_server_resource_id="https://mcp-server.example/", + ) + + provider = EnterpriseAuthOAuthClientProvider( + server_url="https://mcp-server.example/", + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ), + storage=mock_token_storage, + idp_token_endpoint="https://idp.example.com/oauth2/token", + token_exchange_params=token_exchange_params, + ) + + # Set up OAuth metadata + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.mcp-server.example/"), + authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"), + token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"), + ) + + # Set client info with client_id=None + provider.context.client_info = OAuthClientInformationFull( + client_id=None, # This should skip the client_id assignment + client_secret="test-secret", + redirect_uris=[AnyUrl("http://localhost:8080/callback")], + ) + + # Mock HTTP response + mock_response = httpx.Response( + status_code=200, + json={ + "token_type": "Bearer", + "access_token": "mcp-access-token-12345", + "expires_in": 3600, + "scope": "read write", + }, + ) + + mock_client = Mock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=mock_response) + + # Perform JWT bearer grant + token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag) + + # Verify + assert token.access_token == "mcp-access-token-12345" + assert token.token_type == "Bearer" + assert token.expires_in == 3600 + + # Verify client_id was not included (None), but client_secret was included + call_args = mock_client.post.call_args + assert "client_id" not in call_args[1]["data"] + assert call_args[1]["data"]["client_secret"] == "test-secret" + + +def test_validate_token_exchange_params_missing_audience(): + """Test validation fails for missing audience.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="", + resource="https://server.example/", + ) + + with pytest.raises(ValueError, match="audience is required"): + validate_token_exchange_params(params) + + +def test_validate_token_exchange_params_missing_resource(): + """Test validation fails for missing resource.""" + params = TokenExchangeParameters( + subject_token="token", + subject_token_type="urn:ietf:params:oauth:token-type:id_token", + audience="https://auth.example/", + resource="", + ) + + with pytest.raises(ValueError, match="resource is required"): + validate_token_exchange_params(params) diff --git a/uv.lock b/uv.lock index 5d36da2e3..256b1cc02 100644 --- a/uv.lock +++ b/uv.lock @@ -29,6 +29,15 @@ members = [ "mcp-structured-output-lowlevel", ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -494,6 +503,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fastapi" +version = "0.128.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic", version = "2.11.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, + { name = "pydantic", version = "2.12.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -754,6 +779,7 @@ ws = [ dev = [ { name = "coverage", extra = ["toml"] }, { name = "dirty-equals" }, + { name = "fastapi" }, { name = "inline-snapshot" }, { name = "mcp", extra = ["cli", "ws"] }, { name = "pillow" }, @@ -802,6 +828,7 @@ provides-extras = ["cli", "rich", "ws"] dev = [ { name = "coverage", extras = ["toml"], specifier = ">=7.13.1" }, { name = "dirty-equals", specifier = ">=0.9.0" }, + { name = "fastapi", specifier = ">=0.115.0" }, { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, { name = "pillow", specifier = ">=12.0" },