diff --git a/README.md b/README.md index 77bcf8db..30e29ac5 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,14 @@ A toolkit for building LLM-powered applications and agent loops. uv add ai ``` -AI Gateway usage works with the base package. Direct providers that use an -OpenAI-compatible or Anthropic-compatible adapter load the corresponding -official SDK lazily and require optional extras: +AI Gateway API-key usage works with the base package. Direct providers that +use an OpenAI-compatible or Anthropic-compatible adapter load the corresponding +official SDK lazily. Vercel OIDC for AI Gateway also uses an optional extra: ```bash uv add "ai[openai]" # OpenAI-compatible providers uv add "ai[anthropic]" # Anthropic-compatible providers +uv add "ai[vercel]" # Vercel OIDC for AI Gateway ``` ```python diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index b0c5bb65..eed26ebc 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -2,8 +2,9 @@ from __future__ import annotations +import importlib import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, cast import agent as agent_ import fastapi @@ -17,11 +18,67 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator + import starlette.types + + +class _VercelHeaders(Protocol): + def set_headers(self, headers: dict[str, str] | None) -> None: ... + + +class VercelOIDCHeadersMiddleware: + def __init__(self, app: starlette.types.ASGIApp) -> None: + self.app = app + + async def __call__( + self, + scope: starlette.types.Scope, + receive: starlette.types.Receive, + send: starlette.types.Send, + ) -> None: + headers = _vercel_headers() + if scope.get("type") != "http" or headers is None: + await self.app(scope, receive, send) + return + + headers.set_headers(_scope_headers(scope)) + try: + await self.app(scope, receive, send) + finally: + headers.set_headers(None) + + +def _vercel_headers() -> _VercelHeaders | None: + try: + return cast( + "_VercelHeaders", + importlib.import_module("vercel.headers"), + ) + except ModuleNotFoundError as exc: + if exc.name not in {"vercel", "vercel.headers"}: + raise + return None + + +def _scope_headers(scope: starlette.types.Scope) -> dict[str, str]: + return { + _header_text(key): _header_text(value) + for key, value in scope.get("headers", []) + } + + +def _header_text(value: object) -> str: + if isinstance(value, bytes | bytearray): + return bytes(value).decode("latin1") + return str(value) + + app = fastapi.FastAPI( title="py-ai-fastapi-chat", description="Chat demo using Python Vercel AI SDK", ) +app.add_middleware(VercelOIDCHeadersMiddleware) + app.add_middleware( fastapi.middleware.cors.CORSMiddleware, allow_origins=["*"], diff --git a/examples/fastapi-vite/e2e-test/run.sh b/examples/fastapi-vite/e2e-test/run.sh index 8374ea15..9b355af4 100755 --- a/examples/fastapi-vite/e2e-test/run.sh +++ b/examples/fastapi-vite/e2e-test/run.sh @@ -38,7 +38,7 @@ trap cleanup EXIT echo "Starting backend on :$BACKEND_PORT..." ( cd "$ROOT/backend" - uv run --frozen --with-editable ~/src/py-ai/ fastapi dev main.py --port "$BACKEND_PORT" + uv run --frozen --with-editable "$ROOT/../.." fastapi dev main.py --port "$BACKEND_PORT" ) > "$LOGS/backend.log" 2>&1 & BACKEND_PID=$! diff --git a/examples/fastapi-vite/frontend/pnpm-lock.yaml b/examples/fastapi-vite/frontend/pnpm-lock.yaml index 2fa7c898..fa8cbc94 100644 --- a/examples/fastapi-vite/frontend/pnpm-lock.yaml +++ b/examples/fastapi-vite/frontend/pnpm-lock.yaml @@ -112,112 +112,6 @@ importers: specifier: ^7.2.4 version: 7.3.1(@types/node@24.10.11)(jiti@2.6.1)(lightningcss@1.30.2) - frontend: - dependencies: - '@ai-sdk/react': - specifier: ^3.0.74 - version: 3.0.74(react@19.2.4)(zod@4.3.6) - '@streamdown/cjk': - specifier: ^1.0.1 - version: 1.0.1(@types/mdast@4.0.4)(micromark-util-types@2.0.2)(micromark@4.0.2)(react@19.2.4)(unified@11.0.5) - '@streamdown/code': - specifier: ^1.0.1 - version: 1.0.1(react@19.2.4) - '@streamdown/math': - specifier: ^1.0.1 - version: 1.0.1(react@19.2.4) - '@streamdown/mermaid': - specifier: ^1.0.1 - version: 1.0.1(react@19.2.4) - '@tanstack/react-query': - specifier: ^5.90.20 - version: 5.90.20(react@19.2.4) - '@tanstack/react-router': - specifier: ^1.158.1 - version: 1.158.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - ai: - specifier: ^6.0.72 - version: 6.0.72(zod@4.3.6) - class-variance-authority: - specifier: ^0.7.1 - version: 0.7.1 - clsx: - specifier: ^2.1.1 - version: 2.1.1 - cmdk: - specifier: ^1.1.1 - version: 1.1.1(@types/react-dom@19.2.3(@types/react@19.2.13))(@types/react@19.2.13)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - lucide-react: - specifier: ^0.563.0 - version: 0.563.0(react@19.2.4) - nanoid: - specifier: ^5.1.6 - version: 5.1.6 - radix-ui: - specifier: ^1.4.3 - version: 1.4.3(@types/react-dom@19.2.3(@types/react@19.2.13))(@types/react@19.2.13)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - react: - specifier: ^19.2.0 - version: 19.2.4 - react-dom: - specifier: ^19.2.0 - version: 19.2.4(react@19.2.4) - shiki: - specifier: ^3.22.0 - version: 3.22.0 - streamdown: - specifier: ^2.1.0 - version: 2.1.0(react@19.2.4) - tailwind-merge: - specifier: ^3.4.0 - version: 3.4.0 - use-stick-to-bottom: - specifier: ^1.1.2 - version: 1.1.2(react@19.2.4) - devDependencies: - '@eslint/js': - specifier: ^9.39.1 - version: 9.39.2 - '@tailwindcss/vite': - specifier: ^4.1.18 - version: 4.1.18(vite@7.3.1(@types/node@24.10.11)(jiti@2.6.1)(lightningcss@1.30.2)) - '@types/node': - specifier: ^24.10.1 - version: 24.10.11 - '@types/react': - specifier: ^19.2.5 - version: 19.2.13 - '@types/react-dom': - specifier: ^19.2.3 - version: 19.2.3(@types/react@19.2.13) - '@vitejs/plugin-react': - specifier: ^5.1.1 - version: 5.1.3(vite@7.3.1(@types/node@24.10.11)(jiti@2.6.1)(lightningcss@1.30.2)) - eslint: - specifier: ^9.39.1 - version: 9.39.2(jiti@2.6.1) - eslint-plugin-react-hooks: - specifier: ^7.0.1 - version: 7.0.1(eslint@9.39.2(jiti@2.6.1)) - eslint-plugin-react-refresh: - specifier: ^0.4.24 - version: 0.4.26(eslint@9.39.2(jiti@2.6.1)) - globals: - specifier: ^16.5.0 - version: 16.5.0 - tailwindcss: - specifier: ^4.1.18 - version: 4.1.18 - typescript: - specifier: ~5.9.3 - version: 5.9.3 - typescript-eslint: - specifier: ^8.46.4 - version: 8.54.0(eslint@9.39.2(jiti@2.6.1))(typescript@5.9.3) - vite: - specifier: ^7.2.4 - version: 7.3.1(@types/node@24.10.11)(jiti@2.6.1)(lightningcss@1.30.2) - packages: '@ai-sdk/gateway@3.0.35': @@ -1349,66 +1243,79 @@ packages: resolution: {integrity: sha512-F8sWbhZ7tyuEfsmOxwc2giKDQzN3+kuBLPwwZGyVkLlKGdV1nvnNwYD0fKQ8+XS6hp9nY7B+ZeK01EBUE7aHaw==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.57.1': resolution: {integrity: sha512-rGfNUfn0GIeXtBP1wL5MnzSj98+PZe/AXaGBCRmT0ts80lU5CATYGxXukeTX39XBKsxzFpEeK+Mrp9faXOlmrw==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.57.1': resolution: {integrity: sha512-MMtej3YHWeg/0klK2Qodf3yrNzz6CGjo2UntLvk2RSPlhzgLvYEB3frRvbEF2wRKh1Z2fDIg9KRPe1fawv7C+g==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.57.1': resolution: {integrity: sha512-1a/qhaaOXhqXGpMFMET9VqwZakkljWHLmZOX48R0I/YLbhdxr1m4gtG1Hq7++VhVUmf+L3sTAf9op4JlhQ5u1Q==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.57.1': resolution: {integrity: sha512-QWO6RQTZ/cqYtJMtxhkRkidoNGXc7ERPbZN7dVW5SdURuLeVU7lwKMpo18XdcmpWYd0qsP1bwKPf7DNSUinhvA==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.57.1': resolution: {integrity: sha512-xpObYIf+8gprgWaPP32xiN5RVTi/s5FCR+XMXSKmhfoJjrpRAjCuuqQXyxUa/eJTdAE6eJ+KDKaoEqjZQxh3Gw==} cpu: [loong64] os: [linux] + libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.57.1': resolution: {integrity: sha512-4BrCgrpZo4hvzMDKRqEaW1zeecScDCR+2nZ86ATLhAoJ5FQ+lbHVD3ttKe74/c7tNT9c6F2viwB3ufwp01Oh2w==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.57.1': resolution: {integrity: sha512-NOlUuzesGauESAyEYFSe3QTUguL+lvrN1HtwEEsU2rOwdUDeTMJdO5dUYl/2hKf9jWydJrO9OL/XSSf65R5+Xw==} cpu: [ppc64] os: [linux] + libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.57.1': resolution: {integrity: sha512-ptA88htVp0AwUUqhVghwDIKlvJMD/fmL/wrQj99PRHFRAG6Z5nbWoWG4o81Nt9FT+IuqUQi+L31ZKAFeJ5Is+A==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.57.1': resolution: {integrity: sha512-S51t7aMMTNdmAMPpBg7OOsTdn4tySRQvklmL3RpDRyknk87+Sp3xaumlatU+ppQ+5raY7sSTcC2beGgvhENfuw==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.57.1': resolution: {integrity: sha512-Bl00OFnVFkL82FHbEqy3k5CUCKH6OEJL54KCyx2oqsmZnFTR8IoNqBF+mjQVcRCT5sB6yOvK8A37LNm/kPJiZg==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.57.1': resolution: {integrity: sha512-ABca4ceT4N+Tv/GtotnWAeXZUZuM/9AQyCyKYyKnpk4yoA7QIAuBt6Hkgpw8kActYlew2mvckXkvx0FfoInnLg==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.57.1': resolution: {integrity: sha512-HFps0JeGtuOR2convgRRkHCekD7j+gdAuXM+/i6kGzQtFhlCtQkpwtNzkNj6QhCDp7DRJ7+qC/1Vg2jt5iSOFw==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openbsd-x64@4.57.1': resolution: {integrity: sha512-H+hXEv9gdVQuDTgnqD+SQffoWoc0Of59AStSzTEj/feWTBAnSfSD3+Dql1ZruJQxmykT/JVY0dE8Ka7z0DH1hw==} @@ -1522,24 +1429,28 @@ packages: engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [glibc] '@tailwindcss/oxide-linux-arm64-musl@4.1.18': resolution: {integrity: sha512-1px92582HkPQlaaCkdRcio71p8bc8i/ap5807tPRDK/uw953cauQBT8c5tVGkOwrHMfc2Yh6UuxaH4vtTjGvHg==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] + libc: [musl] '@tailwindcss/oxide-linux-x64-gnu@4.1.18': resolution: {integrity: sha512-v3gyT0ivkfBLoZGF9LyHmts0Isc8jHZyVcbzio6Wpzifg/+5ZJpDiRiUhDLkcr7f/r38SWNe7ucxmGW3j3Kb/g==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: [glibc] '@tailwindcss/oxide-linux-x64-musl@4.1.18': resolution: {integrity: sha512-bhJ2y2OQNlcRwwgOAGMY0xTFStt4/wyU6pvI6LSuZpRgKQwxTec0/3Scu91O8ir7qCR3AuepQKLU/kX99FouqQ==} engines: {node: '>= 10'} cpu: [x64] os: [linux] + libc: [musl] '@tailwindcss/oxide-wasm32-wasi@4.1.18': resolution: {integrity: sha512-LffYTvPjODiP6PT16oNeUQJzNVyJl1cjIebq/rWWBF+3eDst5JGEFSc5cWxyRCJ0Mxl+KyIkqRxk1XPEs9x8TA==} @@ -2535,24 +2446,28 @@ packages: engines: {node: '>= 12.0.0'} cpu: [arm64] os: [linux] + libc: [glibc] lightningcss-linux-arm64-musl@1.30.2: resolution: {integrity: sha512-5Vh9dGeblpTxWHpOx8iauV02popZDsCYMPIgiuw97OJ5uaDsL86cnqSFs5LZkG3ghHoX5isLgWzMs+eD1YzrnA==} engines: {node: '>= 12.0.0'} cpu: [arm64] os: [linux] + libc: [musl] lightningcss-linux-x64-gnu@1.30.2: resolution: {integrity: sha512-Cfd46gdmj1vQ+lR6VRTTadNHu6ALuw2pKR9lYq4FnhvgBc4zWY1EtZcAc6EffShbb1MFrIPfLDXD6Xprbnni4w==} engines: {node: '>= 12.0.0'} cpu: [x64] os: [linux] + libc: [glibc] lightningcss-linux-x64-musl@1.30.2: resolution: {integrity: sha512-XJaLUUFXb6/QG2lGIW6aIk6jKdtjtcffUT0NKvIqhSBY3hh9Ch+1LCeH80dR9q9LBjG3ewbDjnumefsLsP6aiA==} engines: {node: '>= 12.0.0'} cpu: [x64] os: [linux] + libc: [musl] lightningcss-win32-arm64-msvc@1.30.2: resolution: {integrity: sha512-FZn+vaj7zLv//D/192WFFVA0RgHawIcHqLX9xuWiQt7P0PtdFEVaxgF9rjM/IRYHQXNnk61/H/gb2Ei+kUQ4xQ==} diff --git a/pyproject.toml b/pyproject.toml index 32e4a38c..4d93e286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ anthropic = ["anthropic>=0.83.0"] mcp = ["mcp>=1.18.0"] openai = ["openai>=2.14.0"] +vercel = [ + "vercel>=0.5.9", +] [build-system] requires = ["hatchling", "uv-dynamic-versioning>=0.7.0"] diff --git a/src/ai/providers/ai_gateway/client/__init__.py b/src/ai/providers/ai_gateway/client/__init__.py index 9642109b..8a650f05 100644 --- a/src/ai/providers/ai_gateway/client/__init__.py +++ b/src/ai/providers/ai_gateway/client/__init__.py @@ -1,6 +1,6 @@ """Async client for the AI Gateway provider protocol.""" from . import errors -from ._client import GatewayClient, ModelType +from ._client import AuthMethod, GatewayClient, ModelType -__all__ = ["GatewayClient", "ModelType", "errors"] +__all__ = ["AuthMethod", "GatewayClient", "ModelType", "errors"] diff --git a/src/ai/providers/ai_gateway/client/_client.py b/src/ai/providers/ai_gateway/client/_client.py index 206f0b2e..48028602 100644 --- a/src/ai/providers/ai_gateway/client/_client.py +++ b/src/ai/providers/ai_gateway/client/_client.py @@ -19,6 +19,7 @@ _PROTOCOL_VERSION = "0.0.1" ModelType = Literal["language", "image", "video"] +AuthMethod = Literal["api-key", "oidc"] class GatewayClient: @@ -34,17 +35,34 @@ def __init__( *, base_url: str, api_key: str | None = None, + auth_token: str | None = None, + auth_method: AuthMethod | None = None, headers: Mapping[str, str] | None = None, client: httpx.AsyncClient | None = None, ) -> None: self.base_url = base_url - self.api_key = api_key + self.auth_token = auth_token if auth_token is not None else api_key + self.auth_method: AuthMethod | None = ( + auth_method if auth_method is not None else "api-key" + ) self.headers = dict(headers or {}) self._http = client or httpx.AsyncClient( timeout=httpx.Timeout(timeout=300.0, connect=10.0), ) self._owns_http = client is None + @property + def api_key(self) -> str | None: + """API key auth token, if this client is using API key auth.""" + if self.auth_method == "api-key": + return self.auth_token + return None + + @api_key.setter + def api_key(self, value: str | None) -> None: + self.auth_token = value + self.auth_method = "api-key" if value else None + async def aclose(self) -> None: if self._owns_http and not self._http.is_closed: await self._http.aclose() @@ -59,9 +77,9 @@ def origin_url(self, path: str) -> str: def protocol_headers(self) -> dict[str, str]: headers = dict(self.headers) headers["ai-gateway-protocol-version"] = _PROTOCOL_VERSION - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - headers["ai-gateway-auth-method"] = "api-key" + if self.auth_token and self.auth_method: + headers["Authorization"] = f"Bearer {self.auth_token}" + headers["ai-gateway-auth-method"] = self.auth_method return headers def model_headers( @@ -165,7 +183,10 @@ async def probe_model(self, model_id: str) -> None: auth_resp = await self.get("v1/credits", origin=True) if auth_resp.status_code in {401, 403}: raise errors.GatewayAuthenticationError.create_contextual( - api_key_provided=bool(self.api_key), + api_key_provided=self.auth_method == "api-key" + and bool(self.auth_token), + oidc_token_provided=self.auth_method == "oidc" + and bool(self.auth_token), status_code=auth_resp.status_code, ) if auth_resp.status_code != 200: @@ -239,7 +260,10 @@ async def raise_for_error(self, response: httpx.Response) -> None: raise errors.create_gateway_error( response_body=response.text, status_code=response.status_code, - api_key_provided=bool(self.api_key), + api_key_provided=self.auth_method == "api-key" + and bool(self.auth_token), + oidc_token_provided=self.auth_method == "oidc" + and bool(self.auth_token), ) async def iter_sse( diff --git a/src/ai/providers/ai_gateway/client/errors.py b/src/ai/providers/ai_gateway/client/errors.py index 0cfa774f..9b8db66b 100644 --- a/src/ai/providers/ai_gateway/client/errors.py +++ b/src/ai/providers/ai_gateway/client/errors.py @@ -92,6 +92,7 @@ def create_contextual( cls, *, api_key_provided: bool, + oidc_token_provided: bool = False, status_code: int = 401, generation_id: str | None = None, ) -> Self: @@ -103,6 +104,12 @@ def create_contextual( "Provide via 'api_key' option or " "'AI_GATEWAY_API_KEY' environment variable." ) + elif oidc_token_provided: + msg = ( + "AI Gateway authentication failed: Invalid OIDC token.\n\n" + "Check that Vercel OIDC is enabled for this project and " + "that the token has not expired." + ) else: msg = ( "AI Gateway authentication failed: " @@ -266,6 +273,7 @@ def create_gateway_error( response_body: Any, status_code: int, api_key_provided: bool = False, + oidc_token_provided: bool = False, ) -> GatewayError: """Create a typed error from a gateway JSON error response. @@ -310,6 +318,7 @@ def create_gateway_error( case "authentication_error": err = GatewayAuthenticationError.create_contextual( api_key_provided=api_key_provided, + oidc_token_provided=oidc_token_provided, status_code=status_code, generation_id=generation_id, ) diff --git a/src/ai/providers/ai_gateway/provider.py b/src/ai/providers/ai_gateway/provider.py index cf92b4f6..390fee0a 100644 --- a/src/ai/providers/ai_gateway/provider.py +++ b/src/ai/providers/ai_gateway/provider.py @@ -5,10 +5,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast from ... import errors as ai_errors -from .. import base +from .. import _optional, base from . import client as gateway_client from . import errors from . import protocol as protocol_module @@ -30,6 +30,32 @@ _BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" _API_KEY_ENV = "AI_GATEWAY_API_KEY" +_VERCEL_ENV = "VERCEL" +_OIDC_TOKEN_ENV = "VERCEL_OIDC_TOKEN" + + +class _VercelOidc(Protocol): + def get_vercel_oidc_token(self) -> str: ... + + +def _get_vercel_oidc_token() -> str: + try: + oidc = cast( + "_VercelOidc", + _optional.import_optional_sdk( + "vercel.oidc", + provider="AI Gateway OIDC", + extra="vercel", + ), + ) + except ai_errors.InstallationError as exc: + raise ai_errors.InstallationError( + "AI Gateway OIDC authentication requires the optional `vercel` " + 'package. Install it with `pip install "ai[vercel]"` or ' + '`uv add "ai[vercel]"`, or set `AI_GATEWAY_API_KEY` to use ' + "API key authentication." + ) from exc + return oidc.get_vercel_oidc_token() class GatewayProvider(base.Provider[gateway_client.GatewayClient]): @@ -59,7 +85,6 @@ def __init__( self._set_client( gateway_client.GatewayClient( base_url=self.base_url, - api_key=self.api_key, headers=self.headers, client=client, ) @@ -68,14 +93,36 @@ def __init__( @property def client(self) -> gateway_client.GatewayClient: client = super().client + auth_token, auth_method = self._gateway_auth() client.base_url = self.base_url - client.api_key = self.api_key + client.auth_token = auth_token + client.auth_method = auth_method client.headers = self.headers return client + def _gateway_auth( + self, + ) -> tuple[str | None, gateway_client.AuthMethod | None]: + api_key = self.api_key + if api_key: + return api_key, "api-key" + if self._config_value(_VERCEL_ENV) == "1" or self._config_value( + _OIDC_TOKEN_ENV + ): + return _get_vercel_oidc_token(), "oidc" + return None, None + + def is_configured(self) -> bool: + """Return ``True`` when Gateway auth can be attempted.""" + return ( + bool(self.api_key) + or self._config_value(_VERCEL_ENV) == "1" + or bool(self._config_value(_OIDC_TOKEN_ENV)) + ) + async def aclose(self) -> None: """Close the provider-owned Gateway client, if any.""" - await self.client.aclose() + await super().client.aclose() def stream( self, diff --git a/tests/providers/ai_gateway/test_errors.py b/tests/providers/ai_gateway/test_errors.py index 6e55ce1d..0ecef6e1 100644 --- a/tests/providers/ai_gateway/test_errors.py +++ b/tests/providers/ai_gateway/test_errors.py @@ -88,6 +88,21 @@ def test_authentication_error_from_json_string(self) -> None: # contextual message includes the key URL assert "vercel.com/d?to=" in str(err) + def test_authentication_error_from_oidc(self) -> None: + body = { + "error": { + "message": "Invalid OIDC token", + "type": "authentication_error", + } + } + err = client_errors.create_gateway_error( + response_body=body, + status_code=401, + oidc_token_provided=True, + ) + assert isinstance(err, client_errors.GatewayAuthenticationError) + assert "OIDC token" in str(err) + def test_invalid_request_error(self) -> None: body = { "error": { diff --git a/tests/providers/ai_gateway/test_provider.py b/tests/providers/ai_gateway/test_provider.py index 0d6d7054..17b25d01 100644 --- a/tests/providers/ai_gateway/test_provider.py +++ b/tests/providers/ai_gateway/test_provider.py @@ -1,5 +1,9 @@ from __future__ import annotations +import importlib +from collections.abc import Callable +from types import ModuleType + import httpx import pytest @@ -7,6 +11,33 @@ from ai.providers.ai_gateway.client import errors +def _set_oidc_token( + monkeypatch: pytest.MonkeyPatch, + get_token: Callable[[], str], +) -> None: + real_import_module = importlib.import_module + oidc = ModuleType("vercel.oidc") + oidc.__dict__["get_vercel_oidc_token"] = get_token + + def _import_module(name: str, package: str | None = None) -> ModuleType: + if name == "vercel.oidc": + return oidc + return real_import_module(name, package) + + monkeypatch.setattr(importlib, "import_module", _import_module) + + +def _fail_oidc_import(monkeypatch: pytest.MonkeyPatch) -> None: + real_import_module = importlib.import_module + + def _import_module(name: str, package: str | None = None) -> ModuleType: + if name == "vercel.oidc": + pytest.fail("OIDC should not be imported when an API key is set") + return real_import_module(name, package) + + monkeypatch.setattr(importlib, "import_module", _import_module) + + async def test_list_models_gets_config_with_gateway_headers_and_sorts_ids() -> ( None ): @@ -71,3 +102,163 @@ def _handler(request: httpx.Request) -> httpx.Response: assert isinstance( exc_info.value.__cause__, errors.GatewayAuthenticationError ) + + +async def test_list_models_uses_oidc_on_vercel_when_no_api_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("AI_GATEWAY_API_KEY", raising=False) + monkeypatch.setenv("VERCEL", "1") + _set_oidc_token(monkeypatch, lambda: "oidc-test-token") + captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured_headers.update(dict(request.headers)) + return httpx.Response( + 200, + json={"models": [{"id": "anthropic/claude-a"}]}, + ) + + provider = ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + try: + ids = await provider.list_models() + finally: + await provider.aclose() + + assert ids == ["anthropic/claude-a"] + assert captured_headers["authorization"] == "Bearer oidc-test-token" + assert captured_headers["ai-gateway-auth-method"] == "oidc" + + +async def test_list_models_uses_oidc_token_env_without_vercel_flag( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("AI_GATEWAY_API_KEY", raising=False) + monkeypatch.delenv("VERCEL", raising=False) + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "pulled-oidc-token") + _set_oidc_token(monkeypatch, lambda: "pulled-oidc-token") + captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured_headers.update(dict(request.headers)) + return httpx.Response(200, json={"models": []}) + + provider = ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + try: + await provider.list_models() + finally: + await provider.aclose() + + assert captured_headers["authorization"] == "Bearer pulled-oidc-token" + assert captured_headers["ai-gateway-auth-method"] == "oidc" + + +async def test_oidc_expected_without_vercel_extra_raises_installation_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("AI_GATEWAY_API_KEY", raising=False) + monkeypatch.setenv("VERCEL", "1") + real_import_module = importlib.import_module + + def _import_module(name: str, package: str | None = None) -> ModuleType: + if name == "vercel.oidc": + raise ModuleNotFoundError(name="vercel") + return real_import_module(name, package) + + def _handler(request: httpx.Request) -> httpx.Response: + pytest.fail("Gateway should not be called without the OIDC helper") + + monkeypatch.setattr(importlib, "import_module", _import_module) + provider = ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + try: + with pytest.raises(ai.InstallationError) as exc_info: + await provider.list_models() + finally: + await provider.aclose() + + assert "AI Gateway OIDC authentication requires" in str(exc_info.value) + assert "ai[vercel]" in str(exc_info.value) + assert "AI_GATEWAY_API_KEY" in str(exc_info.value) + + +async def test_api_key_env_takes_precedence_over_oidc( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("AI_GATEWAY_API_KEY", "env-test-key") + monkeypatch.setenv("VERCEL", "1") + _fail_oidc_import(monkeypatch) + captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured_headers.update(dict(request.headers)) + return httpx.Response(200, json={"models": []}) + + provider = ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + try: + await provider.list_models() + finally: + await provider.aclose() + + assert captured_headers["authorization"] == "Bearer env-test-key" + assert captured_headers["ai-gateway-auth-method"] == "api-key" + + +async def test_explicit_api_key_takes_precedence_over_oidc( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("AI_GATEWAY_API_KEY", "env-test-key") + monkeypatch.setenv("VERCEL", "1") + _fail_oidc_import(monkeypatch) + captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured_headers.update(dict(request.headers)) + return httpx.Response(200, json={"models": []}) + + provider = ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + api_key="explicit-test-key", + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + try: + await provider.list_models() + finally: + await provider.aclose() + + assert captured_headers["authorization"] == "Bearer explicit-test-key" + assert captured_headers["ai-gateway-auth-method"] == "api-key" + + +async def test_is_configured_on_vercel_without_api_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("AI_GATEWAY_API_KEY", raising=False) + monkeypatch.setenv("VERCEL", "1") + + provider = ai.get_provider("vercel") + try: + assert provider.is_configured() is True + finally: + await provider.aclose() diff --git a/uv.lock b/uv.lock index e168bba7..963b3a08 100644 --- a/uv.lock +++ b/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-18T18:36:33.433440494Z" +exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer-span = "P2D" [[package]] @@ -31,6 +31,9 @@ mcp = [ openai = [ { name = "openai" }, ] +vercel = [ + { name = "vercel" }, +] [package.dev-dependencies] dev = [ @@ -62,8 +65,9 @@ requires-dist = [ { name = "openai", marker = "extra == 'openai'", specifier = ">=2.14.0" }, { name = "pydantic", specifier = ">=2.12.5" }, { name = "typing-extensions", specifier = ">=4.15.0" }, + { name = "vercel", marker = "extra == 'vercel'", specifier = ">=0.5.9" }, ] -provides-extras = ["anthropic", "mcp", "openai"] +provides-extras = ["anthropic", "mcp", "openai", "vercel"] [package.metadata.requires-dev] dev = [ @@ -194,6 +198,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, ] +[[package]] +name = "cbor2" +version = "5.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/cb/09939728be094d155b5d4ac262e39877875f5f7e36eea66beb359f647bd0/cbor2-5.9.0.tar.gz", hash = "sha256:85c7a46279ac8f226e1059275221e6b3d0e370d2bb6bd0500f9780781615bcea", size = 111231, upload-time = "2026-03-22T15:56:50.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/39/72d8a5a4b06565561ec28f4fcb41aff7bb77f51705c01f00b8254a2aca4f/cbor2-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1f223dffb1bcdd2764665f04c1152943d9daa4bc124a576cd8dee1cad4264313", size = 71223, upload-time = "2026-03-22T15:56:13.68Z" }, + { url = "https://files.pythonhosted.org/packages/09/fd/7ddf3d3153b54c69c3be77172b8d9aa3a9d74f62a7fbde614d53eaeed9a4/cbor2-5.9.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae6c706ac1d85a0b3cb3395308fd0c4d55e3202b4760773675957e93cdff45fc", size = 287865, upload-time = "2026-03-22T15:56:14.813Z" }, + { url = "https://files.pythonhosted.org/packages/db/9d/7ede2cc42f9bb4260492e7d29d2aab781eacbbcfb09d983de1e695077199/cbor2-5.9.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4cd43d8fc374b31643b2830910f28177a606a7bc84975a62675dd3f2e320fc7b", size = 288246, upload-time = "2026-03-22T15:56:16.113Z" }, + { url = "https://files.pythonhosted.org/packages/ce/9d/588ebc7c5bc5843f609b05fe07be8575c7dec987735b0bbc908ac9c1264a/cbor2-5.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aa07b392cc3d76fb31c08a46a226b58c320d1c172ff3073e864409ced7bc50f", size = 280214, upload-time = "2026-03-22T15:56:17.519Z" }, + { url = "https://files.pythonhosted.org/packages/f7/a1/6fc8f4b15c6a27e7fbb7966c30c2b4b18c274a3221fa2f5e6235502d34bc/cbor2-5.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:971d425b3a23b75953d8853d5f9911bdeefa09d759ee3b5e6b07b5ff3cbd9073", size = 282162, upload-time = "2026-03-22T15:56:18.975Z" }, + { url = "https://files.pythonhosted.org/packages/cf/20/9a22cfe08be16ddfeef2542cf4eeed1b29f3f57ddbba0b42f7e0bb8331fd/cbor2-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:34a6cb15e6ab6a8eae94ad2041731cd3ef786af43a8df99f847969af5b902ee7", size = 70049, upload-time = "2026-03-22T15:56:20.502Z" }, + { url = "https://files.pythonhosted.org/packages/c6/9e/695f92d09006614034e25a9f5b10620f3b219f79c1bec3c37b7c6f27a7a9/cbor2-5.9.0-cp312-cp312-win_arm64.whl", hash = "sha256:7d1ddc4541e7367ac58c2470cc0df847f7137167fe4f5729e2d3cc0b993d7da4", size = 65382, upload-time = "2026-03-22T15:56:21.526Z" }, + { url = "https://files.pythonhosted.org/packages/81/c5/4901e21a8afe9448fd947b11e8f383903207cd6dd0800e5f5a386838de5b/cbor2-5.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fbb06f34aa645b4deca66643bba3d400d20c15312d1fe88d429be60c1ab50f27", size = 71284, upload-time = "2026-03-22T15:56:22.836Z" }, + { url = "https://files.pythonhosted.org/packages/1b/10/df643a381aebc3f05486de4813662bc58accb640fc3275cb276a75e89694/cbor2-5.9.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac684fe195c39821fca70d18afbf748f728aefbfbf88456018d299e559b8cae0", size = 287682, upload-time = "2026-03-22T15:56:24.024Z" }, + { url = "https://files.pythonhosted.org/packages/c6/0c/8aa6b766059ae4a0ca1ec3ff96fe3823a69a7be880dba2e249f7fbe2700b/cbor2-5.9.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a54fbb32cb828c214f7f333a707e4aec61182e7efdc06ea5d9596d3ecee624a", size = 288009, upload-time = "2026-03-22T15:56:25.305Z" }, + { url = "https://files.pythonhosted.org/packages/74/07/6236bc25c183a9cf7e8062e5dddf9eae9b0b14ebf14a58a69fe5a1e872c6/cbor2-5.9.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4753a6d1bc71054d9179557bc65740860f185095ccb401d46637fff028a5b3ec", size = 280437, upload-time = "2026-03-22T15:56:26.479Z" }, + { url = "https://files.pythonhosted.org/packages/4e/0a/84328d23c3c68874ac6497edb9b1900579a1028efa54734df3f1762bbc15/cbor2-5.9.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:380e534482b843e43442b87d8777a7bf9bed20cb7526f89b780c3400f617304b", size = 282247, upload-time = "2026-03-22T15:56:28.644Z" }, + { url = "https://files.pythonhosted.org/packages/9b/f6/89b4627e09d028c8e5fcaf7cb55f225c33ce6e037ec1844e65d02bcfa945/cbor2-5.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:dcf0f695873e5c94bd072d6af8698e72b8fb7f7a18f37e0bced1041b7111a6cf", size = 70089, upload-time = "2026-03-22T15:56:29.801Z" }, + { url = "https://files.pythonhosted.org/packages/e2/7c/efadcd5f0102db692490e4e206988a2f98d39a09912090db497a2b800885/cbor2-5.9.0-cp313-cp313-win_arm64.whl", hash = "sha256:f7c9751a9611601ab326d8f5837f01379195bbf06175fb4effeb552140e7c9e8", size = 65466, upload-time = "2026-03-22T15:56:30.823Z" }, + { url = "https://files.pythonhosted.org/packages/08/7d/9ccc36d10ef96e6038e48046ebe1ce35a1e7814da0e1e204d09e6ef09b8d/cbor2-5.9.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23606d31ba1368bd1b6602e3020ee88fe9523ca80e8630faf6b2fc904fd84560", size = 71500, upload-time = "2026-03-22T15:56:31.876Z" }, + { url = "https://files.pythonhosted.org/packages/70/e1/a6cca2cc72e13f00030c6a649f57ae703eb2c620806ab70c40db8eab33fa/cbor2-5.9.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0322296b9d52f55880e300ba8ba09ecf644303b99b51138bbb1c0fb644fa7c3e", size = 286953, upload-time = "2026-03-22T15:56:33.292Z" }, + { url = "https://files.pythonhosted.org/packages/08/3c/24cd5ef488a957d90e016f200a3aad820e4c2f85edd61c9fe4523007a1ee/cbor2-5.9.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:422817286c1d0ce947fb2f7eca9212b39bddd7231e8b452e2d2cc52f15332dba", size = 285454, upload-time = "2026-03-22T15:56:34.703Z" }, + { url = "https://files.pythonhosted.org/packages/a4/35/dca96818494c0ba47cdd73e8d809b27fa91f8fa0ce32a068a09237687454/cbor2-5.9.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9a4907e0c3035bb8836116854ed8e56d8aef23909d601fa59706320897ec2551", size = 279441, upload-time = "2026-03-22T15:56:35.888Z" }, + { url = "https://files.pythonhosted.org/packages/a4/44/d3362378b16e53cf7e535a3f5aed8476e2109068154e24e31981ef5bde9e/cbor2-5.9.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:fb7afe77f8d269e42d7c4b515c6fd14f1ccc0625379fb6829b269f493d16eddd", size = 279673, upload-time = "2026-03-22T15:56:37.08Z" }, + { url = "https://files.pythonhosted.org/packages/43/d1/3533a697e5842fff7c2f64912eb251f8dcab3a8b5d88e228d6eebc3b5021/cbor2-5.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:86baf870d4c0bfc6f79de3801f3860a84ab76d9c8b0abb7f081f2c14c38d79d3", size = 71940, upload-time = "2026-03-22T15:56:38.366Z" }, + { url = "https://files.pythonhosted.org/packages/ff/e2/c6ba75f3fb25dfa15ab6999cc8709c821987e9ed8e375d7f58539261bcb9/cbor2-5.9.0-cp314-cp314-win_arm64.whl", hash = "sha256:7221483fad0c63afa4244624d552abf89d7dfdbc5f5edfc56fc1ff2b4b818975", size = 67639, upload-time = "2026-03-22T15:56:39.39Z" }, + { url = "https://files.pythonhosted.org/packages/42/ff/b83492b096fbef26e9cb62c1a4bf2d3cef579ea7b33138c6c37c4ae66f67/cbor2-5.9.0-py3-none-any.whl", hash = "sha256:27695cbd70c90b8de5c4a284642c2836449b14e2c2e07e3ffe0744cb7669a01b", size = 24627, upload-time = "2026-03-22T15:56:48.847Z" }, +] + [[package]] name = "certifi" version = "2026.1.4" @@ -1297,6 +1331,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" }, ] +[[package]] +name = "vercel" +version = "0.5.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "cbor2" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "vercel-workers" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/53/e9/86d6780301e36ecfe45ec401212593f24c555b284b5b080ac9a6253803ec/vercel-0.5.9.tar.gz", hash = "sha256:88d2c27ecf7e02de67b711d4dd9df84e70c6ae1823c09f1fd67146e3d54be462", size = 119164, upload-time = "2026-05-19T19:41:01.261Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/0e/dc0cb599036ff3055428459a427ed895e1a4cd3db387655672c5d2dba04b/vercel-0.5.9-py3-none-any.whl", hash = "sha256:f87a354fc110f04f3d6a2270c800cd8d42a763ea116b2625fc049d54a856a31b", size = 140907, upload-time = "2026-05-19T19:40:59.84Z" }, +] + +[[package]] +name = "vercel-workers" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "vercel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/07/3befa3d3f0fb6e1e603ac0919414d732720bba4661c10b8069b49b276634/vercel_workers-0.0.24.tar.gz", hash = "sha256:d21deb13b02ccaf57c3a0f9ac898df5473dc3e617723366b4bdffd53f161a6aa", size = 63095, upload-time = "2026-05-21T15:18:03.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/90/1da6cd309d77bcf70861606293c80858443796029a07da5a2566f1eeb487/vercel_workers-0.0.24-py3-none-any.whl", hash = "sha256:4218f29e0255351d778456054395c41cc948fdfda5567172704ffac19cefce9f", size = 61873, upload-time = "2026-05-21T15:18:04.266Z" }, +] + [[package]] name = "websockets" version = "16.0"