diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py
index 7ca88110a..cd15c4d64 100644
--- a/.github/actions/conformance/client.py
+++ b/.github/actions/conformance/client.py
@@ -16,6 +16,9 @@
elicitation-sep1034-client-defaults - Elicitation with default accept callback
auth/client-credentials-jwt - Client credentials with private_key_jwt
auth/client-credentials-basic - Client credentials with client_secret_basic
+ auth/enterprise-token-exchange - Enterprise auth with OIDC ID token (SEP-990)
+ auth/enterprise-saml-exchange - Enterprise auth with SAML assertion (SEP-990)
+ auth/enterprise-id-jag-validation - Validate ID-JAG token structure (SEP-990)
auth/* - Authorization code flow (default for auth scenarios)
"""
@@ -293,6 +296,255 @@ async def run_auth_code_client(server_url: str) -> None:
await _run_auth_session(server_url, oauth_auth)
+@register("auth/enterprise-token-exchange")
+async def run_enterprise_token_exchange(server_url: str) -> None:
+ """Enterprise managed auth: Token exchange flow (RFC 8693)."""
+ from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+ )
+
+ context = get_conformance_context()
+ id_token = context.get("id_token")
+ idp_token_endpoint = context.get("idp_token_endpoint")
+ mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
+ mcp_server_resource_id = context.get("mcp_server_resource_id")
+ scope = context.get("scope")
+
+ if not id_token:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'id_token'")
+ if not idp_token_endpoint:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
+ if not mcp_server_auth_issuer:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'")
+ if not mcp_server_resource_id:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'")
+
+ # Create token exchange parameters
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ # Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="conformance-enterprise-client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=InMemoryTokenStorage(),
+ idp_token_endpoint=idp_token_endpoint,
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Perform token exchange flow
+ async with httpx.AsyncClient() as client:
+ # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow)
+ logger.debug(f"Setting OAuth metadata for {server_url}")
+ from pydantic import AnyUrl as PydanticAnyUrl
+
+ from mcp.shared.auth import OAuthMetadata
+
+ # Extract base URL from server_url
+ base_url = server_url.replace("/mcp", "")
+ token_endpoint_url = f"{base_url}/oauth/token"
+ auth_endpoint_url = f"{base_url}/oauth/authorize"
+
+ enterprise_auth.context.oauth_metadata = OAuthMetadata(
+ issuer=mcp_server_auth_issuer,
+ authorization_endpoint=PydanticAnyUrl(auth_endpoint_url),
+ token_endpoint=PydanticAnyUrl(token_endpoint_url),
+ )
+ logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}")
+
+ # Step 2: Exchange ID token for ID-JAG
+ logger.debug("Exchanging ID token for ID-JAG")
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ logger.debug(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 3: Exchange ID-JAG for access token
+ logger.debug("Exchanging ID-JAG for access token")
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s")
+
+ # Step 4: Verify we can make authenticated requests
+ logger.debug("Verifying access token with MCP endpoint")
+ auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"})
+ response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp")
+ if response.status_code == 200:
+ logger.debug(f"Successfully authenticated with MCP server: {response.json()}")
+ else:
+ logger.warning(f"MCP server returned {response.status_code}")
+
+ logger.debug("Enterprise auth flow completed successfully")
+
+
+@register("auth/enterprise-saml-exchange")
+async def run_enterprise_saml_exchange(server_url: str) -> None:
+ """Enterprise managed auth: SAML assertion exchange flow."""
+ from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+ )
+
+ context = get_conformance_context()
+ saml_assertion = context.get("saml_assertion")
+ idp_token_endpoint = context.get("idp_token_endpoint")
+ mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
+ mcp_server_resource_id = context.get("mcp_server_resource_id")
+ scope = context.get("scope")
+
+ if not saml_assertion:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'saml_assertion'")
+ if not idp_token_endpoint:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'idp_token_endpoint'")
+ if not mcp_server_auth_issuer:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_auth_issuer'")
+ if not mcp_server_resource_id:
+ raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'mcp_server_resource_id'")
+
+ # Create token exchange parameters for SAML
+ token_exchange_params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion=saml_assertion,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ # Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="conformance-enterprise-saml-client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=InMemoryTokenStorage(),
+ idp_token_endpoint=idp_token_endpoint,
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Perform token exchange flow
+ async with httpx.AsyncClient() as client:
+ # Step 1: Set OAuth metadata manually (since we're not going through full OAuth flow)
+ logger.debug(f"Setting OAuth metadata for {server_url}")
+ from pydantic import AnyUrl as PydanticAnyUrl
+
+ from mcp.shared.auth import OAuthMetadata
+
+ # Extract base URL from server_url
+ base_url = server_url.replace("/mcp", "")
+ token_endpoint_url = f"{base_url}/oauth/token"
+ auth_endpoint_url = f"{base_url}/oauth/authorize"
+
+ enterprise_auth.context.oauth_metadata = OAuthMetadata(
+ issuer=mcp_server_auth_issuer,
+ authorization_endpoint=PydanticAnyUrl(auth_endpoint_url),
+ token_endpoint=PydanticAnyUrl(token_endpoint_url),
+ )
+ logger.debug(f"OAuth metadata set, token_endpoint: {token_endpoint_url}")
+
+ # Step 2: Exchange SAML assertion for ID-JAG
+ logger.debug("Exchanging SAML assertion for ID-JAG")
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ logger.debug(f"Obtained ID-JAG from SAML: {id_jag[:50]}...")
+
+ # Step 3: Exchange ID-JAG for access token
+ logger.debug("Exchanging ID-JAG for access token")
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ logger.debug(f"Obtained access token, expires in: {access_token.expires_in}s")
+
+ # Step 4: Verify we can make authenticated requests
+ logger.debug("Verifying access token with MCP endpoint")
+ auth_client = httpx.AsyncClient(headers={"Authorization": f"Bearer {access_token.access_token}"})
+ response = await auth_client.get(server_url.replace("/mcp", "") + "/mcp")
+ if response.status_code == 200:
+ logger.debug(f"Successfully authenticated with MCP server: {response.json()}")
+ else:
+ logger.warning(f"MCP server returned {response.status_code}")
+
+ logger.debug("SAML enterprise auth flow completed successfully")
+
+
+@register("auth/enterprise-id-jag-validation")
+async def run_id_jag_validation(server_url: str) -> None:
+ """Validate ID-JAG token structure and claims."""
+ from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+ decode_id_jag,
+ validate_token_exchange_params,
+ )
+
+ context = get_conformance_context()
+ id_token = context.get("id_token")
+ idp_token_endpoint = context.get("idp_token_endpoint")
+ mcp_server_auth_issuer = context.get("mcp_server_auth_issuer")
+ mcp_server_resource_id = context.get("mcp_server_resource_id")
+
+ if not all([id_token, idp_token_endpoint, mcp_server_auth_issuer, mcp_server_resource_id]):
+ raise RuntimeError("Missing required context parameters for ID-JAG validation")
+
+ # Create and validate token exchange parameters
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer=mcp_server_auth_issuer,
+ mcp_server_resource_id=mcp_server_resource_id,
+ )
+
+ logger.debug("Validating token exchange parameters")
+ validate_token_exchange_params(token_exchange_params)
+ logger.debug("Token exchange parameters validated successfully")
+
+ # Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url=server_url,
+ client_metadata=OAuthClientMetadata(
+ client_name="conformance-validation-client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=InMemoryTokenStorage(),
+ idp_token_endpoint=idp_token_endpoint,
+ token_exchange_params=token_exchange_params,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Get ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ logger.debug(f"Obtained ID-JAG for validation: {id_jag[:50]}...")
+
+ # Decode and validate ID-JAG claims
+ logger.debug("Decoding ID-JAG token")
+ claims = decode_id_jag(id_jag)
+
+ # Validate required claims
+ assert claims.typ == "oauth-id-jag+jwt", f"Invalid typ: {claims.typ}"
+ assert claims.jti, "Missing jti claim"
+ assert claims.iss == mcp_server_auth_issuer or claims.iss, "Missing or invalid iss claim"
+ assert claims.sub, "Missing sub claim"
+ assert claims.aud, "Missing aud claim"
+ assert claims.resource == mcp_server_resource_id, f"Invalid resource: {claims.resource}"
+ assert claims.client_id, "Missing client_id claim"
+ assert claims.exp > claims.iat, "Invalid expiration"
+
+ logger.debug("ID-JAG validated successfully:")
+ logger.debug(f" Subject: {claims.sub}")
+ logger.debug(f" Issuer: {claims.iss}")
+ logger.debug(f" Audience: {claims.aud}")
+ logger.debug(f" Resource: {claims.resource}")
+ logger.debug(f" Client ID: {claims.client_id}")
+
+ logger.debug("ID-JAG validation completed successfully")
+
+
async def _run_auth_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
"""Common session logic for all OAuth flows."""
client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
diff --git a/.github/actions/conformance/enterprise_auth_server.py b/.github/actions/conformance/enterprise_auth_server.py
new file mode 100644
index 000000000..5065add0f
--- /dev/null
+++ b/.github/actions/conformance/enterprise_auth_server.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+"""Enterprise Auth Mock Server for Conformance Testing
+
+This server provides:
+1. Mock IdP token exchange endpoint (RFC 8693)
+2. MCP OAuth token endpoint (RFC 7523)
+3. Protected MCP tools endpoint
+4. OAuth metadata endpoint
+
+Run on port 3002 to avoid conflicts with everything-server.
+"""
+
+import asyncio
+import json
+import logging
+import secrets
+import time
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import jwt
+from cryptography.hazmat.primitives.asymmetric import rsa
+from fastapi import Depends, FastAPI, Form, HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from uvicorn import Config, Server
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Generate RSA key pair for JWT signing
+PRIVATE_KEY = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+PUBLIC_KEY = PRIVATE_KEY.public_key()
+
+# Server configuration
+IDP_ISSUER = "https://conformance-idp.example.com"
+MCP_SERVER_ISSUER = "https://conformance-mcp.example.com"
+MCP_SERVER_RESOURCE = "https://conformance-mcp.example.com"
+
+# In-memory storage
+ACCESS_TOKENS: dict[str, dict[str, Any]] = {}
+
+app = FastAPI()
+
+# Security
+security = HTTPBearer(auto_error=False)
+
+
+def create_test_id_token(subject: str = "test-user@example.com", client_id: str = "test-client") -> str:
+ """Create a test ID token."""
+ now = datetime.now(timezone.utc)
+ claims = {
+ "iss": IDP_ISSUER,
+ "sub": subject,
+ "aud": client_id,
+ "exp": int((now + timedelta(hours=1)).timestamp()),
+ "iat": int(now.timestamp()),
+ "email": subject,
+ }
+ return jwt.encode(claims, PRIVATE_KEY, algorithm="RS256")
+
+
+def create_id_jag(
+ subject: str,
+ audience: str,
+ resource: str,
+ client_id: str,
+ scope: str | None = None,
+) -> str:
+ """Create an ID-JAG token."""
+ now = datetime.now(timezone.utc)
+ claims = {
+ "jti": str(uuid.uuid4()),
+ "iss": IDP_ISSUER,
+ "sub": subject,
+ "aud": audience,
+ "resource": resource,
+ "client_id": client_id,
+ "exp": int((now + timedelta(minutes=5)).timestamp()),
+ "iat": int(now.timestamp()),
+ }
+ if scope:
+ claims["scope"] = scope
+
+ return jwt.encode(claims, PRIVATE_KEY, algorithm="RS256", headers={"typ": "oauth-id-jag+jwt"})
+
+
+def verify_id_token(id_token: str) -> dict[str, Any]:
+ """Verify and decode an ID token."""
+ try:
+ claims = jwt.decode(id_token, PUBLIC_KEY, algorithms=["RS256"], options={"verify_aud": False})
+ return claims
+ except jwt.InvalidTokenError as e:
+ raise HTTPException(status_code=400, detail=f"Invalid ID token: {e}") from e
+
+
+def verify_id_jag(id_jag: str) -> dict[str, Any]:
+ """Verify and decode an ID-JAG token."""
+ try:
+ header = jwt.get_unverified_header(id_jag)
+ if header.get("typ") != "oauth-id-jag+jwt":
+ raise HTTPException(status_code=400, detail="Invalid ID-JAG type")
+ claims = jwt.decode(id_jag, PUBLIC_KEY, algorithms=["RS256"], options={"verify_aud": False})
+ return claims
+ except jwt.InvalidTokenError as e:
+ raise HTTPException(status_code=400, detail=f"Invalid ID-JAG: {e}") from e
+
+
+# OAuth Metadata Endpoint
+@app.get("/.well-known/oauth-authorization-server")
+async def oauth_metadata() -> JSONResponse:
+ """OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
+ metadata = {
+ "issuer": MCP_SERVER_ISSUER,
+ "token_endpoint": "http://localhost:3002/oauth/token",
+ "grant_types_supported": ["urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token"],
+ "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post", "none"],
+ }
+ return JSONResponse(metadata)
+
+
+# Token Exchange Endpoint (IdP)
+@app.post("/token-exchange")
+async def token_exchange(
+ grant_type: str = Form(...),
+ requested_token_type: str = Form(...),
+ audience: str = Form(...),
+ resource: str = Form(...),
+ subject_token: str = Form(...),
+ subject_token_type: str = Form(...),
+ scope: str | None = Form(None),
+ client_id: str | None = Form(None),
+ client_secret: str | None = Form(None),
+) -> JSONResponse:
+ """RFC 8693 Token Exchange endpoint."""
+ logger.info(f"Token exchange request: grant_type={grant_type}, subject_token_type={subject_token_type}")
+
+ if grant_type != "urn:ietf:params:oauth:grant-type:token-exchange":
+ return JSONResponse(
+ status_code=400,
+ content={"error": "unsupported_grant_type", "error_description": "Only token-exchange grant supported"},
+ )
+
+ if requested_token_type != "urn:ietf:params:oauth:token-type:id-jag":
+ return JSONResponse(
+ status_code=400,
+ content={
+ "error": "invalid_request",
+ "error_description": f"Unsupported token type: {requested_token_type}",
+ },
+ )
+
+ # Extract subject based on token type
+ if subject_token_type == "urn:ietf:params:oauth:token-type:id_token":
+ id_token_claims = verify_id_token(subject_token)
+ subject = id_token_claims["sub"]
+ elif subject_token_type == "urn:ietf:params:oauth:token-type:saml2":
+ # For SAML, extract from mock data
+ import base64
+
+ try:
+ saml_data = json.loads(base64.b64decode(subject_token))
+ subject = saml_data.get("subject", "saml-user@example.com")
+ except Exception:
+ subject = "saml-user@example.com"
+ else:
+ return JSONResponse(
+ status_code=400,
+ content={
+ "error": "invalid_request",
+ "error_description": f"Unsupported subject token type: {subject_token_type}",
+ },
+ )
+
+ # Create ID-JAG
+ id_jag = create_id_jag(
+ subject=subject,
+ audience=audience,
+ resource=resource,
+ client_id=client_id or "test-client",
+ scope=scope,
+ )
+
+ response = {
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": id_jag,
+ "token_type": "N_A",
+ "expires_in": 300,
+ }
+ if scope:
+ response["scope"] = scope
+
+ logger.info("Token exchange successful")
+ return JSONResponse(response)
+
+
+# JWT Bearer Grant Endpoint (MCP Server)
+@app.post("/oauth/token")
+async def jwt_bearer_grant(
+ grant_type: str = Form(...),
+ assertion: str = Form(...),
+ client_id: str | None = Form(None),
+ client_secret: str | None = Form(None),
+) -> JSONResponse:
+ """RFC 7523 JWT Bearer Grant endpoint."""
+ logger.info(f"JWT bearer grant request: grant_type={grant_type}")
+
+ if grant_type != "urn:ietf:params:oauth:grant-type:jwt-bearer":
+ return JSONResponse(
+ status_code=400,
+ content={"error": "unsupported_grant_type", "error_description": "Only jwt-bearer grant supported"},
+ )
+
+ # Verify ID-JAG
+ id_jag_claims = verify_id_jag(assertion)
+
+ # Create access token
+ access_token = secrets.token_urlsafe(32)
+ expires_in = 3600
+
+ # Store token info
+ ACCESS_TOKENS[access_token] = {
+ "subject": id_jag_claims["sub"],
+ "client_id": id_jag_claims.get("client_id"),
+ "scope": id_jag_claims.get("scope"),
+ "expires_at": time.time() + expires_in,
+ }
+
+ response = {
+ "access_token": access_token,
+ "token_type": "Bearer",
+ "expires_in": expires_in,
+ }
+ if id_jag_claims.get("scope"):
+ response["scope"] = id_jag_claims["scope"]
+
+ logger.info("JWT bearer grant successful")
+ return JSONResponse(response)
+
+
+# MCP Endpoints (protected by access token)
+@app.get("/mcp")
+async def mcp_root() -> JSONResponse:
+ """MCP root endpoint - returns basic info."""
+ return JSONResponse({"server": "enterprise-auth-test-server", "version": "1.0"})
+
+
+@app.post("/mcp")
+async def mcp_jsonrpc(
+ credentials: HTTPAuthorizationCredentials | None = Depends(security),
+) -> JSONResponse:
+ """MCP JSON-RPC endpoint."""
+ if not credentials:
+ raise HTTPException(status_code=401, detail="Missing authorization")
+
+ token = credentials.credentials
+ if token not in ACCESS_TOKENS:
+ raise HTTPException(status_code=401, detail="Invalid or expired token")
+
+ token_info = ACCESS_TOKENS[token]
+ if token_info["expires_at"] < time.time():
+ del ACCESS_TOKENS[token]
+ raise HTTPException(status_code=401, detail="Token expired")
+
+ # Return a simple MCP response
+ return JSONResponse(
+ {
+ "jsonrpc": "2.0",
+ "result": {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {"tools": {}},
+ "serverInfo": {"name": "enterprise-auth-test-server", "version": "1.0"},
+ },
+ }
+ )
+
+
+# Helper endpoint to get test ID token
+@app.get("/test/id-token")
+async def get_test_id_token(subject: str = "test-user@example.com", client_id: str = "test-client") -> JSONResponse:
+ """Get a test ID token for conformance testing."""
+ id_token = create_test_id_token(subject, client_id)
+ return JSONResponse({"id_token": id_token})
+
+
+@app.get("/test/context")
+async def get_test_context() -> JSONResponse:
+ """Get complete test context for conformance testing."""
+ id_token = create_test_id_token()
+
+ # Create mock SAML assertion
+ import base64
+
+ saml_data = {
+ "issuer": IDP_ISSUER,
+ "subject": "saml-user@example.com",
+ "issued_at": datetime.now(timezone.utc).isoformat(),
+ }
+ saml_assertion = base64.b64encode(json.dumps(saml_data).encode()).decode()
+
+ context = {
+ "id_token": id_token,
+ "saml_assertion": saml_assertion,
+ "idp_token_endpoint": "http://localhost:3002/token-exchange",
+ "mcp_server_auth_issuer": MCP_SERVER_ISSUER,
+ "mcp_server_resource_id": MCP_SERVER_RESOURCE,
+ "client_id": "test-client",
+ "scope": "mcp:tools mcp:resources",
+ }
+
+ return JSONResponse(context)
+
+
+async def run_server(port: int = 3002) -> None:
+ """Run the mock server."""
+ config = Config(app=app, host="0.0.0.0", port=port, log_level="info")
+ server = Server(config)
+ logger.info(f"Starting Enterprise Auth Mock Server on port {port}")
+ logger.info(f"Token Exchange endpoint: http://localhost:{port}/token-exchange")
+ logger.info(f"JWT Bearer Grant endpoint: http://localhost:{port}/oauth/token")
+ logger.info(f"MCP endpoint: http://localhost:{port}/mcp")
+ logger.info(f"OAuth metadata: http://localhost:{port}/.well-known/oauth-authorization-server")
+ logger.info(f"Test context: http://localhost:{port}/test/context")
+ await server.serve()
+
+
+if __name__ == "__main__":
+ import sys
+
+ port = int(sys.argv[1]) if len(sys.argv) > 1 else 3002
+ asyncio.run(run_server(port))
diff --git a/.github/actions/conformance/run-enterprise-auth-with-server.sh b/.github/actions/conformance/run-enterprise-auth-with-server.sh
new file mode 100755
index 000000000..c8ddfdcc6
--- /dev/null
+++ b/.github/actions/conformance/run-enterprise-auth-with-server.sh
@@ -0,0 +1,169 @@
+#!/bin/bash
+set -e
+
+# Enterprise Auth Full Conformance Test with Mock Server
+# This script:
+# 1. Starts the enterprise auth mock server (IdP + OAuth endpoints)
+# 2. Fetches test context from the server
+# 3. Runs all enterprise auth conformance scenarios
+# 4. Cleans up servers on exit
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR/../../.."
+
+MOCK_SERVER_PORT=3002
+MOCK_SERVER_URL="http://localhost:${MOCK_SERVER_PORT}"
+
+echo "==================================================================="
+echo " Enterprise Auth Conformance Tests with Mock Server (SEP-990)"
+echo "==================================================================="
+echo ""
+
+# Function to cleanup servers
+cleanup() {
+ echo ""
+ echo "Cleaning up servers..."
+ if [ -n "$MOCK_SERVER_PID" ]; then
+ kill $MOCK_SERVER_PID 2>/dev/null || true
+ wait $MOCK_SERVER_PID 2>/dev/null || true
+ echo "✓ Mock server stopped"
+ fi
+}
+
+trap cleanup EXIT
+
+# Start enterprise auth mock server
+echo "Starting Enterprise Auth Mock Server on port ${MOCK_SERVER_PORT}..."
+uv run --frozen python "$SCRIPT_DIR/enterprise_auth_server.py" $MOCK_SERVER_PORT > /tmp/enterprise_auth_server.log 2>&1 &
+MOCK_SERVER_PID=$!
+
+# Wait for mock server to be ready
+echo "Waiting for mock server to be ready..."
+MAX_RETRIES=30
+RETRY_COUNT=0
+while ! curl -s "${MOCK_SERVER_URL}/test/context" > /dev/null 2>&1; do
+ RETRY_COUNT=$((RETRY_COUNT + 1))
+ if [ $RETRY_COUNT -ge $MAX_RETRIES ]; then
+ echo "✗ Mock server failed to start after ${MAX_RETRIES} retries" >&2
+ echo "Server log:"
+ cat /tmp/enterprise_auth_server.log
+ exit 1
+ fi
+ sleep 0.5
+done
+
+echo "✓ Mock server ready at ${MOCK_SERVER_URL}"
+echo ""
+
+# Fetch test context from server
+echo "Fetching test context from mock server..."
+export MCP_CONFORMANCE_CONTEXT=$(curl -s "${MOCK_SERVER_URL}/test/context")
+
+if [ -z "$MCP_CONFORMANCE_CONTEXT" ]; then
+ echo "✗ Failed to fetch test context" >&2
+ exit 1
+fi
+
+echo "✓ Test context retrieved"
+echo ""
+
+# Display server info
+echo "Server Endpoints:"
+echo " - Token Exchange (IdP): ${MOCK_SERVER_URL}/token-exchange"
+echo " - OAuth Token (MCP): ${MOCK_SERVER_URL}/oauth/token"
+echo " - MCP Endpoint: ${MOCK_SERVER_URL}/mcp"
+echo " - OAuth Metadata: ${MOCK_SERVER_URL}/.well-known/oauth-authorization-server"
+echo ""
+
+# Run conformance scenarios
+echo "==================================================================="
+echo " Running Conformance Test Scenarios"
+echo "==================================================================="
+echo ""
+
+# Test 1: ID-JAG Validation
+echo "--- Test 1: ID-JAG Token Validation ---"
+export MCP_CONFORMANCE_SCENARIO="auth/enterprise-id-jag-validation"
+if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test1.log | grep -q "ID-JAG validation completed successfully"; then
+ echo "✓ ID-JAG validation PASSED"
+ TEST1_PASS=true
+else
+ echo "✗ ID-JAG validation FAILED"
+ echo "Log output:"
+ cat /tmp/test1.log | tail -20
+ TEST1_PASS=false
+fi
+echo ""
+
+# Test 2: OIDC ID Token Exchange Flow
+echo "--- Test 2: OIDC ID Token Exchange Flow ---"
+export MCP_CONFORMANCE_SCENARIO="auth/enterprise-token-exchange"
+if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test2.log | grep -q "Enterprise auth flow completed successfully"; then
+ echo "✓ OIDC token exchange PASSED"
+ TEST2_PASS=true
+else
+ echo "✗ OIDC token exchange FAILED"
+ echo "Log output:"
+ cat /tmp/test2.log | tail -20
+ TEST2_PASS=false
+fi
+echo ""
+
+# Test 3: SAML Assertion Exchange Flow
+echo "--- Test 3: SAML Assertion Exchange Flow ---"
+export MCP_CONFORMANCE_SCENARIO="auth/enterprise-saml-exchange"
+if uv run --frozen python "$SCRIPT_DIR/client.py" "${MOCK_SERVER_URL}/mcp" 2>&1 | tee /tmp/test3.log | grep -q "SAML enterprise auth flow completed successfully"; then
+ echo "✓ SAML assertion exchange PASSED"
+ TEST3_PASS=true
+else
+ echo "✗ SAML assertion exchange FAILED"
+ echo "Log output:"
+ cat /tmp/test3.log | tail -20
+ TEST3_PASS=false
+fi
+echo ""
+
+# Summary
+echo "==================================================================="
+echo " Test Results Summary"
+echo "==================================================================="
+echo ""
+
+TESTS_PASSED=0
+TESTS_FAILED=0
+
+if [ "$TEST1_PASS" = true ]; then
+ echo "✓ ID-JAG Validation"
+ TESTS_PASSED=$((TESTS_PASSED + 1))
+else
+ echo "✗ ID-JAG Validation"
+ TESTS_FAILED=$((TESTS_FAILED + 1))
+fi
+
+if [ "$TEST2_PASS" = true ]; then
+ echo "✓ OIDC Token Exchange"
+ TESTS_PASSED=$((TESTS_PASSED + 1))
+else
+ echo "✗ OIDC Token Exchange"
+ TESTS_FAILED=$((TESTS_FAILED + 1))
+fi
+
+if [ "$TEST3_PASS" = true ]; then
+ echo "✓ SAML Assertion Exchange"
+ TESTS_PASSED=$((TESTS_PASSED + 1))
+else
+ echo "✗ SAML Assertion Exchange"
+ TESTS_FAILED=$((TESTS_FAILED + 1))
+fi
+
+echo ""
+echo "Total: ${TESTS_PASSED}/3 passed, ${TESTS_FAILED}/3 failed"
+echo ""
+
+if [ $TESTS_FAILED -eq 0 ]; then
+ echo "🎉 All enterprise auth conformance tests PASSED!"
+ exit 0
+else
+ echo "❌ Some tests failed. Check logs above for details."
+ exit 1
+fi
diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml
index 248e5bf6a..3b2c8cbe0 100644
--- a/.github/workflows/conformance.yml
+++ b/.github/workflows/conformance.yml
@@ -43,3 +43,18 @@ jobs:
node-version: 24
- run: uv sync --frozen --all-extras --package mcp
- run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
+
+ enterprise-auth-conformance:
+ runs-on: ubuntu-latest
+ continue-on-error: true
+ steps:
+ - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
+ - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # v7.2.0
+ with:
+ enable-cache: true
+ version: 0.9.5
+ - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
+ with:
+ node-version: 24
+ - run: uv sync --frozen --all-extras --package mcp
+ - run: ./.github/actions/conformance/run-enterprise-auth-with-server.sh
diff --git a/README.md b/README.md
index 468e1d85d..e6f39968a 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,7 @@
- [Writing MCP Clients](#writing-mcp-clients)
- [Client Display Utilities](#client-display-utilities)
- [OAuth Authentication for Clients](#oauth-authentication-for-clients)
+ - [Enterprise Managed Authorization](#enterprise-managed-authorization)
- [Parsing Tool Results](#parsing-tool-results)
- [MCP Primitives](#mcp-primitives)
- [Server Capabilities](#server-capabilities)
@@ -2421,6 +2422,288 @@ _Full example: [examples/snippets/clients/oauth_client.py](https://github.com/mo
For a complete working example, see [`examples/clients/simple-auth-client/`](examples/clients/simple-auth-client/).
+#### Enterprise Managed Authorization
+
+The SDK includes support for Enterprise Managed Authorization (SEP-990), which enables MCP clients to connect to protected servers using enterprise Single Sign-On (SSO) systems. This implementation supports:
+
+- **RFC 8693**: OAuth 2.0 Token Exchange (ID Token → ID-JAG)
+- **RFC 7523**: JSON Web Token (JWT) Profile for OAuth 2.0 Authorization Grants (ID-JAG → Access Token)
+- Integration with enterprise identity providers (Okta, Azure AD, etc.)
+
+**Key Components:**
+
+The `EnterpriseAuthOAuthClientProvider` class extends the standard OAuth provider to implement the enterprise authorization flow:
+
+**Token Exchange Flow:**
+
+1. **Obtain ID Token** from your enterprise IdP (e.g., Okta, Azure AD)
+2. **Exchange ID Token for ID-JAG** using RFC 8693 Token Exchange
+3. **Exchange ID-JAG for Access Token** using RFC 7523 JWT Bearer Grant
+4. **Use Access Token** to call protected MCP server tools
+
+**Using the Access Token with MCP Server:**
+
+1. Once you have obtained the access token, you can use it to authenticate requests to the MCP server
+2. The access token is automatically included in all subsequent requests to the MCP server, allowing you to access protected tools and resources based on your enterprise identity and permissions.
+
+**Handling Token Expiration and Refresh:**
+
+Access tokens have a limited lifetime and will expire. When tokens expire:
+
+- **Check Token Expiration**: Use the `expires_in` field to determine when the token expires
+- **Refresh Flow**: When expired, repeat the token exchange flow with a fresh ID token from your IdP
+- **Automatic Refresh**: Implement automatic token refresh before expiration (recommended for production)
+- **Error Handling**: Catch authentication errors and retry with refreshed tokens
+
+**Important Notes:**
+
+- **ID Token Expiration**: If the ID token from your IdP expires, you must re-authenticate with the IdP to obtain a new ID token before performing token exchange
+- **Token Storage**: Store tokens securely and implement the `TokenStorage` interface to persist tokens between application restarts
+- **Scope Changes**: If you need different scopes, you must obtain a new ID token from the IdP with the required scopes
+- **Security**: Never log or expose access tokens or ID tokens in production environments
+
+**Example Usage:**
+
+
+```python
+import asyncio
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import OAuthTokenError, TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.sse import sse_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+from mcp.types import CallToolResult
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+def is_token_expired(access_token: OAuthToken) -> bool:
+ """Check if the access token has expired."""
+ if access_token.expires_in:
+ # Calculate expiration time
+ issued_at = datetime.now(timezone.utc)
+ expiration_time = issued_at + timedelta(seconds=access_token.expires_in)
+ return datetime.now(timezone.utc) >= expiration_time
+ return False
+
+
+async def refresh_access_token(
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ client: httpx.AsyncClient,
+ id_token: str,
+) -> OAuthToken:
+ """Refresh the access token when it expires."""
+ try:
+ # Update token exchange parameters with fresh ID token
+ enterprise_auth.token_exchange_params.subject_token = id_token
+
+ # Re-exchange for new ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+
+ # Get new access token
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ return access_token
+ except Exception as e:
+ print(f"Token refresh failed: {e}")
+ # Re-authenticate with IdP if ID token is also expired
+ id_token = await get_id_token_from_idp()
+ return await refresh_access_token(enterprise_auth, client, id_token)
+
+
+async def call_tool_with_retry(
+ session: ClientSession,
+ tool_name: str,
+ arguments: dict[str, Any],
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ client: httpx.AsyncClient,
+ id_token: str,
+) -> CallToolResult | None:
+ """Call a tool with automatic retry on token expiration."""
+ max_retries = 1
+
+ for attempt in range(max_retries + 1):
+ try:
+ result = await session.call_tool(tool_name, arguments)
+ return result
+ except OAuthTokenError:
+ if attempt < max_retries:
+ print("Token expired, refreshing...")
+ # Refresh token and reconnect
+ _access_token = await refresh_access_token(enterprise_auth, client, id_token)
+ # Note: In production, you'd need to reconnect the session here
+ else:
+ raise
+ return None
+
+
+async def main() -> None:
+ # Step 1: Get ID token from your IdP (example with Okta)
+ id_token = await get_id_token_from_idp() # Your IdP authentication
+
+ # Step 2: Configure token exchange parameters
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL
+ mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 3: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 4: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 5: Exchange ID-JAG for access token
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Step 6: Check if token is expired (for demonstration)
+ if is_token_expired(access_token):
+ print("Token is expired, refreshing...")
+ access_token = await refresh_access_token(enterprise_auth, client, id_token)
+
+ # Step 7: Use the access token to connect to MCP server
+ headers = {"Authorization": f"Bearer {access_token.access_token}"}
+
+ async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # Call tools with automatic retry on token expiration
+ result = await call_tool_with_retry(
+ session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token
+ )
+ if result:
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def maintain_active_session(
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ mcp_server_url: str,
+) -> None:
+ """Maintain an active session with automatic token refresh."""
+ id_token_var = await get_id_token_from_idp()
+
+ async with httpx.AsyncClient() as client:
+ while True:
+ try:
+ # Update token exchange params with current ID token
+ enterprise_auth.token_exchange_params.subject_token = id_token_var
+
+ # Get access token
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+
+ # Calculate refresh time (refresh before expiration)
+ refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
+
+ # Use the token for MCP operations
+ headers = {"Authorization": f"Bearer {access_token.access_token}"}
+ async with sse_client(mcp_server_url, headers=headers) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # Perform operations...
+ # Schedule refresh before token expires
+ await asyncio.sleep(refresh_in)
+
+ except Exception as e:
+ print(f"Session error: {e}")
+ # Re-authenticate with IdP
+ id_token_var = await get_id_token_from_idp()
+ await asyncio.sleep(5) # Wait before retry
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
+```
+
+_Full example: [examples/snippets/clients/enterprise_managed_auth_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/enterprise_managed_auth_client.py)_
+
+
+**Working with SAML Assertions:**
+
+If your enterprise uses SAML instead of OIDC, you can exchange SAML assertions:
+
+```python
+token_exchange_params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion=saml_assertion_string,
+ mcp_server_auth_issuer="https://your-idp.com",
+ mcp_server_resource_id="https://mcp-server.example.com",
+ scope="mcp:tools",
+)
+```
+
+**Decoding and Inspecting ID-JAG Tokens:**
+
+You can decode ID-JAG tokens to inspect their claims:
+
+```python
+from mcp.client.auth.extensions import decode_id_jag
+
+# Decode without signature verification (for inspection only)
+claims = decode_id_jag(id_jag)
+print(f"Subject: {claims.sub}")
+print(f"Issuer: {claims.iss}")
+print(f"Audience: {claims.aud}")
+print(f"Client ID: {claims.client_id}")
+print(f"Resource: {claims.resource}")
+```
+
### Parsing Tool Results
When calling tools through MCP, the `CallToolResult` object contains the tool's response in a structured format. Understanding how to parse this result is essential for properly handling tool outputs.
diff --git a/examples/snippets/clients/enterprise_managed_auth_client.py b/examples/snippets/clients/enterprise_managed_auth_client.py
new file mode 100644
index 000000000..57070043a
--- /dev/null
+++ b/examples/snippets/clients/enterprise_managed_auth_client.py
@@ -0,0 +1,204 @@
+import asyncio
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import httpx
+from pydantic import AnyUrl
+
+from mcp import ClientSession
+from mcp.client.auth import OAuthTokenError, TokenStorage
+from mcp.client.auth.extensions import (
+ EnterpriseAuthOAuthClientProvider,
+ TokenExchangeParameters,
+)
+from mcp.client.sse import sse_client
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+from mcp.types import CallToolResult
+
+
+# Placeholder function for IdP authentication
+async def get_id_token_from_idp() -> str:
+ """Placeholder function to get ID token from your IdP.
+
+ In production, implement actual IdP authentication flow.
+ """
+ raise NotImplementedError("Implement your IdP authentication flow here")
+
+
+# Define token storage implementation
+class SimpleTokenStorage(TokenStorage):
+ def __init__(self) -> None:
+ self._tokens: OAuthToken | None = None
+ self._client_info: OAuthClientInformationFull | None = None
+
+ async def get_tokens(self) -> OAuthToken | None:
+ return self._tokens
+
+ async def set_tokens(self, tokens: OAuthToken) -> None:
+ self._tokens = tokens
+
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
+ return self._client_info
+
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+ self._client_info = client_info
+
+
+def is_token_expired(access_token: OAuthToken) -> bool:
+ """Check if the access token has expired."""
+ if access_token.expires_in:
+ # Calculate expiration time
+ issued_at = datetime.now(timezone.utc)
+ expiration_time = issued_at + timedelta(seconds=access_token.expires_in)
+ return datetime.now(timezone.utc) >= expiration_time
+ return False
+
+
+async def refresh_access_token(
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ client: httpx.AsyncClient,
+ id_token: str,
+) -> OAuthToken:
+ """Refresh the access token when it expires."""
+ try:
+ # Update token exchange parameters with fresh ID token
+ enterprise_auth.token_exchange_params.subject_token = id_token
+
+ # Re-exchange for new ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+
+ # Get new access token
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ return access_token
+ except Exception as e:
+ print(f"Token refresh failed: {e}")
+ # Re-authenticate with IdP if ID token is also expired
+ id_token = await get_id_token_from_idp()
+ return await refresh_access_token(enterprise_auth, client, id_token)
+
+
+async def call_tool_with_retry(
+ session: ClientSession,
+ tool_name: str,
+ arguments: dict[str, Any],
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ client: httpx.AsyncClient,
+ id_token: str,
+) -> CallToolResult | None:
+ """Call a tool with automatic retry on token expiration."""
+ max_retries = 1
+
+ for attempt in range(max_retries + 1):
+ try:
+ result = await session.call_tool(tool_name, arguments)
+ return result
+ except OAuthTokenError:
+ if attempt < max_retries:
+ print("Token expired, refreshing...")
+ # Refresh token and reconnect
+ _access_token = await refresh_access_token(enterprise_auth, client, id_token)
+ # Note: In production, you'd need to reconnect the session here
+ else:
+ raise
+ return None
+
+
+async def main() -> None:
+ # Step 1: Get ID token from your IdP (example with Okta)
+ id_token = await get_id_token_from_idp() # Your IdP authentication
+
+ # Step 2: Configure token exchange parameters
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=id_token,
+ mcp_server_auth_issuer="https://your-idp.com", # IdP issuer URL
+ mcp_server_resource_id="https://mcp-server.example.com", # MCP server resource ID
+ scope="mcp:tools mcp:resources", # Optional scopes
+ )
+
+ # Step 3: Create enterprise auth provider
+ enterprise_auth = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example.com",
+ client_metadata=OAuthClientMetadata(
+ client_name="Enterprise MCP Client",
+ redirect_uris=[AnyUrl("http://localhost:3000/callback")],
+ grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
+ response_types=["token"],
+ ),
+ storage=SimpleTokenStorage(),
+ idp_token_endpoint="https://your-idp.com/oauth2/v1/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ async with httpx.AsyncClient() as client:
+ # Step 4: Exchange ID token for ID-JAG
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ print(f"Obtained ID-JAG: {id_jag[:50]}...")
+
+ # Step 5: Exchange ID-JAG for access token
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+ print(f"Access token obtained, expires in: {access_token.expires_in}s")
+
+ # Step 6: Check if token is expired (for demonstration)
+ if is_token_expired(access_token):
+ print("Token is expired, refreshing...")
+ access_token = await refresh_access_token(enterprise_auth, client, id_token)
+
+ # Step 7: Use the access token to connect to MCP server
+ headers = {"Authorization": f"Bearer {access_token.access_token}"}
+
+ async with sse_client(url="https://mcp-server.example.com", headers=headers) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # Call tools with automatic retry on token expiration
+ result = await call_tool_with_retry(
+ session, "enterprise_tool", {"param": "value"}, enterprise_auth, client, id_token
+ )
+ if result:
+ print(f"Tool result: {result.content}")
+
+ # List available resources
+ resources = await session.list_resources()
+ for resource in resources.resources:
+ print(f"Resource: {resource.uri}")
+
+
+async def maintain_active_session(
+ enterprise_auth: EnterpriseAuthOAuthClientProvider,
+ mcp_server_url: str,
+) -> None:
+ """Maintain an active session with automatic token refresh."""
+ id_token_var = await get_id_token_from_idp()
+
+ async with httpx.AsyncClient() as client:
+ while True:
+ try:
+ # Update token exchange params with current ID token
+ enterprise_auth.token_exchange_params.subject_token = id_token_var
+
+ # Get access token
+ id_jag = await enterprise_auth.exchange_token_for_id_jag(client)
+ access_token = await enterprise_auth.exchange_id_jag_for_access_token(client, id_jag)
+
+ # Calculate refresh time (refresh before expiration)
+ refresh_in = access_token.expires_in - 60 if access_token.expires_in else 300
+
+ # Use the token for MCP operations
+ headers = {"Authorization": f"Bearer {access_token.access_token}"}
+ async with sse_client(mcp_server_url, headers=headers) as (read, write):
+ async with ClientSession(read, write) as session:
+ await session.initialize()
+
+ # Perform operations...
+ # Schedule refresh before token expires
+ await asyncio.sleep(refresh_in)
+
+ except Exception as e:
+ print(f"Session error: {e}")
+ # Re-authenticate with IdP
+ id_token_var = await get_id_token_from_idp()
+ await asyncio.sleep(5) # Wait before retry
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index 87eac7213..ca19e33a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ dev = [
"dirty-equals>=0.9.0",
"coverage[toml]>=7.13.1",
"pillow>=12.0",
+ "fastapi>=0.115.0", # For enterprise auth conformance test mock server
]
docs = [
"mkdocs>=1.6.1",
diff --git a/src/mcp/client/auth/extensions/__init__.py b/src/mcp/client/auth/extensions/__init__.py
index e69de29bb..56ba368ef 100644
--- a/src/mcp/client/auth/extensions/__init__.py
+++ b/src/mcp/client/auth/extensions/__init__.py
@@ -0,0 +1,19 @@
+"""MCP Client Auth Extensions."""
+
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ TokenExchangeParameters,
+ TokenExchangeResponse,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+
+__all__ = [
+ "EnterpriseAuthOAuthClientProvider",
+ "IDJAGClaims",
+ "TokenExchangeParameters",
+ "TokenExchangeResponse",
+ "decode_id_jag",
+ "validate_token_exchange_params",
+]
diff --git a/src/mcp/client/auth/extensions/enterprise_managed_auth.py b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
new file mode 100644
index 000000000..5ebc19c56
--- /dev/null
+++ b/src/mcp/client/auth/extensions/enterprise_managed_auth.py
@@ -0,0 +1,416 @@
+"""Enterprise Managed Authorization extension for MCP (SEP-990).
+
+Implements RFC 8693 Token Exchange and RFC 7523 JWT Bearer Grant for
+enterprise SSO integration.
+"""
+
+import logging
+from collections.abc import Awaitable, Callable
+
+import httpx
+import jwt
+from pydantic import BaseModel, Field
+from typing_extensions import TypedDict
+
+from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
+from mcp.shared.auth import OAuthClientMetadata, OAuthToken
+
+logger = logging.getLogger(__name__)
+
+
+class TokenExchangeRequestData(TypedDict, total=False):
+ """Type definition for RFC 8693 Token Exchange request data."""
+
+ grant_type: str
+ requested_token_type: str
+ audience: str
+ resource: str
+ subject_token: str
+ subject_token_type: str
+ scope: str
+ client_id: str
+ client_secret: str
+
+
+class JWTBearerGrantRequestData(TypedDict, total=False):
+ """Type definition for RFC 7523 JWT Bearer Grant request data."""
+
+ grant_type: str
+ assertion: str
+ client_id: str
+ client_secret: str
+
+
+class TokenExchangeParameters(BaseModel):
+ """Parameters for RFC 8693 Token Exchange request."""
+
+ requested_token_type: str = Field(
+ default="urn:ietf:params:oauth:token-type:id-jag",
+ description="Type of token being requested (ID-JAG)",
+ )
+
+ audience: str = Field(
+ ...,
+ description="Issuer URL of the MCP Server's authorization server",
+ )
+
+ resource: str = Field(
+ ...,
+ description="RFC 9728 Resource Identifier of the MCP Server",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Space-separated list of scopes being requested",
+ )
+
+ subject_token: str = Field(
+ ...,
+ description="ID Token or SAML assertion for the end user",
+ )
+
+ subject_token_type: str = Field(
+ ...,
+ description="Type of subject token (id_token or saml2)",
+ )
+
+ @classmethod
+ def from_id_token(
+ cls,
+ id_token: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for OIDC ID Token exchange."""
+ return cls(
+ subject_token=id_token,
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+ @classmethod
+ def from_saml_assertion(
+ cls,
+ saml_assertion: str,
+ mcp_server_auth_issuer: str,
+ mcp_server_resource_id: str,
+ scope: str | None = None,
+ ) -> "TokenExchangeParameters":
+ """Create parameters for SAML assertion exchange."""
+ return cls(
+ subject_token=saml_assertion,
+ subject_token_type="urn:ietf:params:oauth:token-type:saml2",
+ audience=mcp_server_auth_issuer,
+ resource=mcp_server_resource_id,
+ scope=scope,
+ )
+
+
+class TokenExchangeResponse(BaseModel):
+ """Response from RFC 8693 Token Exchange."""
+
+ issued_token_type: str = Field(
+ ...,
+ description="Type of token issued (should be id-jag)",
+ )
+
+ access_token: str = Field(
+ ...,
+ description="The ID-JAG token (named access_token per RFC 8693)",
+ )
+
+ token_type: str = Field(
+ ...,
+ description="Token type (should be N_A for ID-JAG)",
+ )
+
+ scope: str | None = Field(
+ default=None,
+ description="Granted scopes",
+ )
+
+ expires_in: int | None = Field(
+ default=None,
+ description="Lifetime in seconds",
+ )
+
+ @property
+ def id_jag(self) -> str:
+ """Get the ID-JAG token."""
+ return self.access_token
+
+
+class IDJAGClaims(BaseModel):
+ """Claims structure for Identity Assertion JWT Authorization Grant."""
+
+ model_config = {"extra": "allow"}
+
+ # JWT header
+ typ: str = Field(
+ ...,
+ description="JWT type - must be 'oauth-id-jag+jwt'",
+ )
+
+ # Required claims
+ jti: str = Field(..., description="Unique JWT ID")
+ iss: str = Field(..., description="IdP issuer URL")
+ sub: str = Field(..., description="Subject (user) identifier")
+ aud: str = Field(..., description="MCP Server's auth server issuer")
+ resource: str = Field(..., description="MCP Server resource identifier")
+ client_id: str = Field(..., description="MCP Client identifier")
+ exp: int = Field(..., description="Expiration timestamp")
+ iat: int = Field(..., description="Issued-at timestamp")
+
+ # Optional claims
+ scope: str | None = Field(None, description="Space-separated scopes")
+ email: str | None = Field(None, description="User email")
+
+
+class EnterpriseAuthOAuthClientProvider(OAuthClientProvider):
+ """OAuth client provider for Enterprise Managed Authorization (SEP-990).
+
+ Implements:
+ - RFC 8693: Token Exchange (ID Token → ID-JAG)
+ - RFC 7523: JWT Bearer Grant (ID-JAG → Access Token)
+ """
+
+ def __init__(
+ self,
+ server_url: str,
+ client_metadata: OAuthClientMetadata,
+ storage: TokenStorage,
+ idp_token_endpoint: str,
+ token_exchange_params: TokenExchangeParameters,
+ redirect_handler: Callable[[str], Awaitable[None]] | None = None,
+ callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
+ timeout: float = 300.0,
+ ) -> None:
+ """Initialize Enterprise Auth OAuth Client.
+
+ Args:
+ server_url: MCP server URL
+ client_metadata: OAuth client metadata
+ storage: Token storage implementation
+ idp_token_endpoint: Enterprise IdP token endpoint URL
+ token_exchange_params: Token exchange parameters
+ redirect_handler: Optional redirect handler
+ callback_handler: Optional callback handler
+ timeout: Request timeout in seconds
+ """
+ super().__init__(
+ server_url=server_url,
+ client_metadata=client_metadata,
+ storage=storage,
+ redirect_handler=redirect_handler,
+ callback_handler=callback_handler,
+ timeout=timeout,
+ )
+ self.idp_token_endpoint = idp_token_endpoint
+ self.token_exchange_params = token_exchange_params
+ self._id_jag: str | None = None
+
+ async def exchange_token_for_id_jag(
+ self,
+ client: httpx.AsyncClient,
+ ) -> str:
+ """Exchange ID Token for ID-JAG using RFC 8693 Token Exchange.
+
+ Args:
+ client: HTTP client for making requests
+
+ Returns:
+ The ID-JAG token string
+
+ Raises:
+ OAuthTokenError: If token exchange fails
+ """
+ logger.info("Starting token exchange for ID-JAG")
+
+ # Build token exchange request
+ token_data: TokenExchangeRequestData = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+ "requested_token_type": self.token_exchange_params.requested_token_type,
+ "audience": self.token_exchange_params.audience,
+ "resource": self.token_exchange_params.resource,
+ "subject_token": self.token_exchange_params.subject_token,
+ "subject_token_type": self.token_exchange_params.subject_token_type,
+ }
+
+ if self.token_exchange_params.scope:
+ token_data["scope"] = self.token_exchange_params.scope
+
+ # Add client authentication if needed
+ if self.context.client_info:
+ if self.context.client_info.client_id is not None:
+ token_data["client_id"] = self.context.client_info.client_id
+ if self.context.client_info.client_secret is not None:
+ token_data["client_secret"] = self.context.client_info.client_secret
+
+ try:
+ response = await client.post(
+ self.idp_token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = (
+ response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
+ )
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get("error_description", "Token exchange failed")
+ raise OAuthTokenError(f"Token exchange failed: {error} - {error_description}")
+
+ # Parse response
+ token_response = TokenExchangeResponse.model_validate_json(response.content)
+
+ # Validate response
+ if token_response.issued_token_type != "urn:ietf:params:oauth:token-type:id-jag":
+ raise OAuthTokenError(f"Unexpected token type: {token_response.issued_token_type}")
+
+ if token_response.token_type != "N_A":
+ logger.warning(f"Expected token_type 'N_A', got '{token_response.token_type}'")
+
+ logger.info("Successfully obtained ID-JAG")
+ self._id_jag = token_response.id_jag
+ return token_response.id_jag
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during token exchange: {e}") from e
+
+ async def exchange_id_jag_for_access_token(
+ self,
+ client: httpx.AsyncClient,
+ id_jag: str,
+ ) -> OAuthToken:
+ """Exchange ID-JAG for access token using RFC 7523 JWT Bearer Grant.
+
+ Args:
+ client: HTTP client for making requests
+ id_jag: The ID-JAG token
+
+ Returns:
+ OAuth access token
+
+ Raises:
+ OAuthTokenError: If JWT bearer grant fails
+ """
+ logger.info("Exchanging ID-JAG for access token")
+
+ # Discover token endpoint from MCP server if not already done
+ if not self.context.oauth_metadata or not self.context.oauth_metadata.token_endpoint:
+ raise OAuthFlowError("MCP server token endpoint not discovered")
+
+ token_endpoint = str(self.context.oauth_metadata.token_endpoint)
+
+ # Build JWT bearer grant request
+ token_data: JWTBearerGrantRequestData = {
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "assertion": id_jag,
+ }
+
+ # Add client authentication
+ if self.context.client_info:
+ if self.context.client_info.client_id is not None:
+ token_data["client_id"] = self.context.client_info.client_id
+ if self.context.client_info.client_secret is not None:
+ token_data["client_secret"] = self.context.client_info.client_secret
+
+ try:
+ response = await client.post(
+ token_endpoint,
+ data=token_data,
+ timeout=self.context.timeout,
+ )
+
+ if response.status_code != 200:
+ error_data: dict[str, str] = (
+ response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
+ )
+ error: str = error_data.get("error", "unknown_error")
+ error_description: str = error_data.get("error_description", "JWT bearer grant failed")
+ raise OAuthTokenError(f"JWT bearer grant failed: {error} - {error_description}")
+
+ # Parse OAuth token response
+ token = OAuthToken.model_validate_json(response.content)
+
+ # Store tokens
+ self.context.current_tokens = token
+ self.context.update_token_expiry(token)
+ await self.context.storage.set_tokens(token)
+
+ logger.info("Successfully obtained access token via ID-JAG")
+ return token
+
+ except httpx.HTTPError as e:
+ raise OAuthTokenError(f"HTTP error during JWT bearer grant: {e}") from e
+
+ async def _perform_authorization(self) -> httpx.Request:
+ """Perform enterprise authorization flow.
+
+ Overrides parent method to use token exchange + JWT bearer grant
+ instead of standard authorization code flow.
+ """
+ # Check if we already have valid tokens
+ if self.context.is_token_valid():
+ # Return a dummy request - we don't need to make any request
+ return httpx.Request("GET", self.context.server_url)
+
+ # For now, raise NotImplementedError as this requires integration
+ # with the full httpx auth flow
+ raise NotImplementedError(
+ "Full enterprise auth flow integration not yet implemented. "
+ "Use exchange_token_for_id_jag and exchange_id_jag_for_access_token directly."
+ )
+
+
+def decode_id_jag(id_jag: str) -> IDJAGClaims:
+ """Decode an ID-JAG token without verification.
+
+ Args:
+ id_jag: The ID-JAG token string
+
+ Returns:
+ Decoded ID-JAG claims
+
+ Note:
+ For verification, use server-side validation instead.
+ """
+ # Decode without verification for inspection
+ claims = jwt.decode(id_jag, options={"verify_signature": False})
+ header = jwt.get_unverified_header(id_jag)
+
+ # Add typ from header to claims
+ claims["typ"] = header.get("typ", "")
+
+ return IDJAGClaims.model_validate(claims)
+
+
+def validate_token_exchange_params(
+ params: TokenExchangeParameters,
+) -> None:
+ """Validate token exchange parameters.
+
+ Args:
+ params: Token exchange parameters to validate
+
+ Raises:
+ ValueError: If parameters are invalid
+ """
+ if not params.subject_token:
+ raise ValueError("subject_token is required")
+
+ if not params.audience:
+ raise ValueError("audience is required")
+
+ if not params.resource:
+ raise ValueError("resource is required")
+
+ if params.subject_token_type not in [
+ "urn:ietf:params:oauth:token-type:id_token",
+ "urn:ietf:params:oauth:token-type:saml2",
+ ]:
+ raise ValueError(f"Invalid subject_token_type: {params.subject_token_type}")
diff --git a/tests/client/auth/test_enterprise_managed_auth_client.py b/tests/client/auth/test_enterprise_managed_auth_client.py
new file mode 100644
index 000000000..729189766
--- /dev/null
+++ b/tests/client/auth/test_enterprise_managed_auth_client.py
@@ -0,0 +1,1108 @@
+"""Tests for Enterprise Managed Authorization client-side implementation."""
+
+import time
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import httpx
+import jwt
+import pytest
+from pydantic import AnyHttpUrl, AnyUrl
+
+from mcp.client.auth import OAuthTokenError
+from mcp.client.auth.extensions.enterprise_managed_auth import (
+ EnterpriseAuthOAuthClientProvider,
+ IDJAGClaims,
+ TokenExchangeParameters,
+ TokenExchangeResponse,
+ decode_id_jag,
+ validate_token_exchange_params,
+)
+from mcp.shared.auth import OAuthClientMetadata
+
+
+@pytest.fixture
+def sample_id_token() -> str:
+ """Generate a sample ID token for testing."""
+ payload = {
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "mcp-client-app",
+ "exp": int(time.time()) + 3600,
+ "iat": int(time.time()),
+ "email": "user@example.com",
+ }
+ return jwt.encode(payload, "secret", algorithm="HS256")
+
+
+@pytest.fixture
+def sample_id_jag() -> str:
+ """Generate a sample ID-JAG token for testing."""
+ # Create typed claims using IDJAGClaims model
+ claims = IDJAGClaims(
+ typ="oauth-id-jag+jwt",
+ jti="unique-jwt-id-12345",
+ iss="https://idp.example.com",
+ sub="user123",
+ aud="https://auth.mcp-server.example/",
+ resource="https://mcp-server.example/",
+ client_id="mcp-client-app",
+ exp=int(time.time()) + 300,
+ iat=int(time.time()),
+ scope="read write",
+ email=None, # Optional field
+ )
+
+ # Dump to dict for JWT encoding (exclude typ as it goes in header)
+ payload = claims.model_dump(exclude={"typ"}, exclude_none=True)
+
+ return jwt.encode(payload, "secret", algorithm="HS256", headers={"typ": "oauth-id-jag+jwt"})
+
+
+@pytest.fixture
+def mock_token_storage() -> Any:
+ """Create a mock token storage."""
+ storage = Mock()
+ storage.get_tokens = AsyncMock(return_value=None)
+ storage.set_tokens = AsyncMock()
+ storage.get_client_info = AsyncMock(return_value=None)
+ storage.set_client_info = AsyncMock()
+ return storage
+
+
+def test_token_exchange_params_from_id_token():
+ """Test creating TokenExchangeParameters from ID token."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="eyJhbGc...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read write",
+ )
+
+ assert params.subject_token == "eyJhbGc..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:id_token"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read write"
+ assert params.requested_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+
+
+def test_token_exchange_params_from_saml_assertion():
+ """Test creating TokenExchangeParameters from SAML assertion."""
+ params = TokenExchangeParameters.from_saml_assertion(
+ saml_assertion="...",
+ mcp_server_auth_issuer="https://auth.server.example/",
+ mcp_server_resource_id="https://server.example/",
+ scope="read",
+ )
+
+ assert params.subject_token == "..."
+ assert params.subject_token_type == "urn:ietf:params:oauth:token-type:saml2"
+ assert params.audience == "https://auth.server.example/"
+ assert params.resource == "https://server.example/"
+ assert params.scope == "read"
+
+
+def test_validate_token_exchange_params_valid():
+ """Test validating valid token exchange parameters."""
+ params = TokenExchangeParameters.from_id_token(
+ id_token="token",
+ mcp_server_auth_issuer="https://auth.example/",
+ mcp_server_resource_id="https://server.example/",
+ )
+
+ # Should not raise
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_invalid_token_type():
+ """Test validation fails for invalid subject token type."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="invalid:type",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="Invalid subject_token_type"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_subject_token():
+ """Test validation fails for missing subject token."""
+ params = TokenExchangeParameters(
+ subject_token="",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="subject_token is required"):
+ validate_token_exchange_params(params)
+
+
+def test_token_exchange_response_parsing():
+ """Test parsing token exchange response."""
+ response_json = """{
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": "eyJhbGc...",
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300
+ }"""
+
+ response = TokenExchangeResponse.model_validate_json(response_json)
+
+ assert response.issued_token_type == "urn:ietf:params:oauth:token-type:id-jag"
+ assert response.id_jag == "eyJhbGc..."
+ assert response.access_token == "eyJhbGc..."
+ assert response.token_type == "N_A"
+ assert response.scope == "read write"
+ assert response.expires_in == 300
+
+
+def test_token_exchange_response_id_jag_property():
+ """Test id_jag property returns access_token."""
+ response = TokenExchangeResponse(
+ issued_token_type="urn:ietf:params:oauth:token-type:id-jag",
+ access_token="the-id-jag-token",
+ token_type="N_A",
+ )
+
+ assert response.id_jag == "the-id-jag-token"
+
+
+def test_decode_id_jag(sample_id_jag: str):
+ """Test decoding ID-JAG token."""
+ claims = decode_id_jag(sample_id_jag)
+
+ assert claims.iss == "https://idp.example.com"
+ assert claims.sub == "user123"
+ assert claims.aud == "https://auth.mcp-server.example/"
+ assert claims.resource == "https://mcp-server.example/"
+ assert claims.client_id == "mcp-client-app"
+ assert claims.scope == "read write"
+
+
+def test_id_jag_claims_with_extra_fields():
+ """Test IDJAGClaims allows extra fields."""
+ claims_data = {
+ "typ": "oauth-id-jag+jwt",
+ "jti": "jti123",
+ "iss": "https://idp.example.com",
+ "sub": "user123",
+ "aud": "https://auth.server.example/",
+ "resource": "https://server.example/",
+ "client_id": "client123",
+ "exp": int(time.time()) + 300,
+ "iat": int(time.time()),
+ "scope": "read",
+ "email": "user@example.com",
+ "custom_claim": "custom_value", # Extra field
+ }
+
+ claims = IDJAGClaims.model_validate(claims_data)
+ assert claims.email == "user@example.com"
+ # Extra field should be preserved
+ assert claims.model_extra is not None and claims.model_extra.get("custom_claim") == "custom_value"
+
+
+# ============================================================================
+# Tests for EnterpriseAuthOAuthClientProvider
+# ============================================================================
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_success(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test successful token exchange for ID-JAG."""
+ # Create provider
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify
+ assert id_jag == sample_id_jag
+ assert provider._id_jag == sample_id_jag
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[0][0] == "https://idp.example.com/oauth2/token"
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:token-exchange"
+ assert call_args[1]["data"]["requested_token_type"] == "urn:ietf:params:oauth:token-type:id-jag"
+ assert call_args[1]["data"]["audience"] == "https://auth.mcp-server.example/"
+ assert call_args[1]["data"]["resource"] == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange failure handling."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_request",
+ "error_description": "Invalid subject token",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Token exchange failed"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_for_id_jag_unexpected_token_type(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with unexpected token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with wrong token type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "access_token": "some-token",
+ "token_type": "Bearer",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="Unexpected token type"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_success(sample_id_jag: str, mock_token_storage: Any):
+ """Test successful JWT bearer grant to get access token."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ from mcp.shared.auth import OAuthMetadata
+
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+ assert token.expires_in == 3600
+
+ # Verify tokens were stored
+ mock_token_storage.set_tokens.assert_called_once()
+
+ # Verify request was made correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer"
+ assert call_args[1]["data"]["assertion"] == sample_id_jag
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_for_access_token_no_metadata(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant fails without OAuth metadata."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # No OAuth metadata set
+ mock_client = Mock(spec=httpx.AsyncClient)
+
+ # Should raise OAuthFlowError
+ from mcp.client.auth import OAuthFlowError
+
+ with pytest.raises(OAuthFlowError, match="token endpoint not discovered"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_not_implemented(mock_token_storage: Any):
+ """Test that _perform_authorization raises NotImplementedError."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Should raise NotImplementedError
+ with pytest.raises(NotImplementedError, match="not yet implemented"):
+ await provider._perform_authorization()
+
+
+@pytest.mark.anyio
+async def test_perform_authorization_with_valid_tokens(mock_token_storage: Any):
+ """Test that _perform_authorization returns dummy request when tokens are valid."""
+ from mcp.shared.auth import OAuthToken
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set valid tokens
+ provider.context.current_tokens = OAuthToken(
+ token_type="Bearer",
+ access_token="valid-token",
+ expires_in=3600,
+ )
+ provider.context.token_expiry_time = time.time() + 3600
+
+ # Should return a dummy request
+ request = await provider._perform_authorization()
+ assert request.method == "GET"
+ assert str(request.url) == "https://mcp-server.example/"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_authentication(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange with client authentication."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_id_only(sample_id_token: str, sample_id_jag: str, mock_token_storage: Any):
+ """Test token exchange with client_id but no client_secret (covers branch 232->235)."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_token_http_error(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with HTTP error."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection failed"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during token exchange"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_non_json_error_response(sample_id_token: str, mock_token_storage: Any):
+ """Test token exchange with non-JSON error response."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=500,
+ content=b"Internal Server Error",
+ headers={"content-type": "text/plain"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="Token exchange failed: unknown_error"):
+ await provider.exchange_token_for_id_jag(mock_client)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_warning_for_non_na_token_type(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange logs warning for non-N_A token type."""
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Mock response with different token_type
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "Bearer", # Not N_A
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should succeed but log warning
+ import logging
+
+ with patch.object(
+ logging.getLogger("mcp.client.auth.extensions.enterprise_managed_auth"), "warning"
+ ) as mock_warning:
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+ assert id_jag == sample_id_jag
+ mock_warning.assert_called_once()
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_authentication(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with client authentication."""
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with secret
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret="test-client-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify token was returned
+ assert token.access_token == "mcp-access-token-12345"
+
+ # Verify client credentials were included
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert call_args[1]["data"]["client_secret"] == "test-client-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_id_only(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with client_id but no client_secret (covers branch 304->307)."""
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info WITHOUT secret (client_secret=None)
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id="test-client-id",
+ client_secret=None, # No secret
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify token was returned correctly
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+
+ # Verify client_id was included but NOT client_secret
+ call_args = mock_client.post.call_args
+ assert call_args[1]["data"]["client_id"] == "test-client-id"
+ assert "client_secret" not in call_args[1]["data"]
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_error_response(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with error response."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock error response
+ mock_response = httpx.Response(
+ status_code=400,
+ json={
+ "error": "invalid_grant",
+ "error_description": "Invalid assertion",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="JWT bearer grant failed"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_non_json_error(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with non-JSON error response."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Mock error response with non-JSON content
+ mock_response = httpx.Response(
+ status_code=503,
+ content=b"Service Unavailable",
+ headers={"content-type": "text/html"},
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Should raise OAuthTokenError with default error
+ with pytest.raises(OAuthTokenError, match="JWT bearer grant failed: unknown_error"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_http_error(sample_id_jag: str, mock_token_storage: Any):
+ """Test JWT bearer grant with HTTP error."""
+ from mcp.shared.auth import OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(side_effect=httpx.ReadTimeout("Request timeout"))
+
+ # Should raise OAuthTokenError
+ with pytest.raises(OAuthTokenError, match="HTTP error during JWT bearer grant"):
+ await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+
+@pytest.mark.anyio
+async def test_exchange_token_with_client_info_but_no_client_id(
+ sample_id_token: str, sample_id_jag: str, mock_token_storage: Any
+):
+ """Test token exchange when client_info exists but client_id is None (covers line 231)."""
+ from mcp.shared.auth import OAuthClientInformationFull
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token=sample_id_token,
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ scope="read write",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ client_name="Test Client",
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "issued_token_type": "urn:ietf:params:oauth:token-type:id-jag",
+ "access_token": sample_id_jag,
+ "token_type": "N_A",
+ "scope": "read write",
+ "expires_in": 300,
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform token exchange
+ id_jag = await provider.exchange_token_for_id_jag(mock_client)
+
+ # Verify the ID-JAG was returned
+ assert id_jag == sample_id_jag
+
+ # Verify client_id was not included (None), but client_secret was included
+ call_args = mock_client.post.call_args
+ assert "client_id" not in call_args[1]["data"]
+ assert call_args[1]["data"]["client_secret"] == "test-secret"
+
+
+@pytest.mark.anyio
+async def test_exchange_id_jag_with_client_info_but_no_client_id(sample_id_jag: str, mock_token_storage: Any):
+ """Test ID-JAG exchange when client_info exists but client_id is None (covers line 302)."""
+ from pydantic import AnyHttpUrl
+
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata
+
+ token_exchange_params = TokenExchangeParameters.from_id_token(
+ id_token="dummy-token",
+ mcp_server_auth_issuer="https://auth.mcp-server.example/",
+ mcp_server_resource_id="https://mcp-server.example/",
+ )
+
+ provider = EnterpriseAuthOAuthClientProvider(
+ server_url="https://mcp-server.example/",
+ client_metadata=OAuthClientMetadata(
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ ),
+ storage=mock_token_storage,
+ idp_token_endpoint="https://idp.example.com/oauth2/token",
+ token_exchange_params=token_exchange_params,
+ )
+
+ # Set up OAuth metadata
+ provider.context.oauth_metadata = OAuthMetadata(
+ issuer=AnyHttpUrl("https://auth.mcp-server.example/"),
+ authorization_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/authorize"),
+ token_endpoint=AnyHttpUrl("https://auth.mcp-server.example/oauth2/token"),
+ )
+
+ # Set client info with client_id=None
+ provider.context.client_info = OAuthClientInformationFull(
+ client_id=None, # This should skip the client_id assignment
+ client_secret="test-secret",
+ redirect_uris=[AnyUrl("http://localhost:8080/callback")],
+ )
+
+ # Mock HTTP response
+ mock_response = httpx.Response(
+ status_code=200,
+ json={
+ "token_type": "Bearer",
+ "access_token": "mcp-access-token-12345",
+ "expires_in": 3600,
+ "scope": "read write",
+ },
+ )
+
+ mock_client = Mock(spec=httpx.AsyncClient)
+ mock_client.post = AsyncMock(return_value=mock_response)
+
+ # Perform JWT bearer grant
+ token = await provider.exchange_id_jag_for_access_token(mock_client, sample_id_jag)
+
+ # Verify
+ assert token.access_token == "mcp-access-token-12345"
+ assert token.token_type == "Bearer"
+ assert token.expires_in == 3600
+
+ # Verify client_id was not included (None), but client_secret was included
+ call_args = mock_client.post.call_args
+ assert "client_id" not in call_args[1]["data"]
+ assert call_args[1]["data"]["client_secret"] == "test-secret"
+
+
+def test_validate_token_exchange_params_missing_audience():
+ """Test validation fails for missing audience."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="",
+ resource="https://server.example/",
+ )
+
+ with pytest.raises(ValueError, match="audience is required"):
+ validate_token_exchange_params(params)
+
+
+def test_validate_token_exchange_params_missing_resource():
+ """Test validation fails for missing resource."""
+ params = TokenExchangeParameters(
+ subject_token="token",
+ subject_token_type="urn:ietf:params:oauth:token-type:id_token",
+ audience="https://auth.example/",
+ resource="",
+ )
+
+ with pytest.raises(ValueError, match="resource is required"):
+ validate_token_exchange_params(params)
diff --git a/uv.lock b/uv.lock
index 5d36da2e3..256b1cc02 100644
--- a/uv.lock
+++ b/uv.lock
@@ -29,6 +29,15 @@ members = [
"mcp-structured-output-lowlevel",
]
+[[package]]
+name = "annotated-doc"
+version = "0.0.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
+]
+
[[package]]
name = "annotated-types"
version = "0.7.0"
@@ -494,6 +503,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" },
]
+[[package]]
+name = "fastapi"
+version = "0.128.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "annotated-doc" },
+ { name = "pydantic", version = "2.11.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" },
+ { name = "pydantic", version = "2.12.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" },
+ { name = "starlette" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
+]
+
[[package]]
name = "ghp-import"
version = "2.1.0"
@@ -754,6 +779,7 @@ ws = [
dev = [
{ name = "coverage", extra = ["toml"] },
{ name = "dirty-equals" },
+ { name = "fastapi" },
{ name = "inline-snapshot" },
{ name = "mcp", extra = ["cli", "ws"] },
{ name = "pillow" },
@@ -802,6 +828,7 @@ provides-extras = ["cli", "rich", "ws"]
dev = [
{ name = "coverage", extras = ["toml"], specifier = ">=7.13.1" },
{ name = "dirty-equals", specifier = ">=0.9.0" },
+ { name = "fastapi", specifier = ">=0.115.0" },
{ name = "inline-snapshot", specifier = ">=0.23.0" },
{ name = "mcp", extras = ["cli", "ws"], editable = "." },
{ name = "pillow", specifier = ">=12.0" },