diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..5d13780a --- /dev/null +++ b/.env.example @@ -0,0 +1,42 @@ +# ============================================================================= +# OpenAI +# ============================================================================= +OPENAI_API_KEY=your-openai-api-key + +# Optional: override the API base URL (e.g. for org proxies, vLLM, or LiteLLM gateways) +# OPENAI_API_BASE=https://your-org-gateway.example.com/v1 + +# ============================================================================= +# Anthropic +# ============================================================================= +ANTHROPIC_API_KEY=your-anthropic-api-key + +# Optional: override the API base URL (e.g. for org proxies or self-hosted endpoints) +# ANTHROPIC_BASE_URL=https://your-org-gateway.example.com + +# ============================================================================= +# Google Gemini +# ============================================================================= +GEMINI_API_KEY=your-gemini-api-key + +# Optional: override the API base URL (e.g. for Vertex AI proxies) +# GEMINI_API_BASE=https://your-org-gateway.example.com + +# ============================================================================= +# Other supported environment variables +# ============================================================================= + +# OpenRouter (https://openrouter.ai) +# OPENROUTER_API_KEY=your-openrouter-api-key + +# Vercel AI Gateway (https://ai-gateway.vercel.sh) +# AI_GATEWAY_API_KEY=your-vercel-ai-gateway-api-key + +# Azure OpenAI +# AZURE_OPENAI_API_KEY=your-azure-openai-api-key +# AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com +# AZURE_OPENAI_API_VERSION=2024-02-01 +# AZURE_OPENAI_DEPLOYMENT=your-deployment-name + +# Prime Intellect +# PRIME_API_KEY=your-prime-intellect-api-key diff --git a/rlm/clients/anthropic.py b/rlm/clients/anthropic.py index 5c747de4..04f0e83c 100644 --- a/rlm/clients/anthropic.py +++ b/rlm/clients/anthropic.py @@ -1,11 +1,18 @@ +import os from collections import defaultdict from typing import Any import anthropic +from dotenv import load_dotenv from rlm.clients.base_lm import BaseLM from rlm.core.types import ModelUsageSummary, UsageSummary +load_dotenv() + +DEFAULT_ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") +DEFAULT_ANTHROPIC_BASE_URL = os.getenv("ANTHROPIC_BASE_URL") + class AnthropicClient(BaseLM): """ @@ -14,14 +21,32 @@ class AnthropicClient(BaseLM): def __init__( self, - api_key: str, + api_key: str | None = None, model_name: str | None = None, max_tokens: int = 32768, + base_url: str | None = None, **kwargs, ): super().__init__(model_name=model_name, **kwargs) - self.client = anthropic.Anthropic(api_key=api_key, timeout=self.timeout) - self.async_client = anthropic.AsyncAnthropic(api_key=api_key, timeout=self.timeout) + + if api_key is None: + api_key = DEFAULT_ANTHROPIC_API_KEY + + if api_key is None: + raise ValueError( + "Anthropic API key is required. Set ANTHROPIC_API_KEY env var or pass api_key." + ) + + # Fall back to ANTHROPIC_BASE_URL env var if base_url is not explicitly provided. + if base_url is None: + base_url = DEFAULT_ANTHROPIC_BASE_URL + + client_kwargs = {"api_key": api_key, "timeout": self.timeout} + if base_url is not None: + client_kwargs["base_url"] = base_url + + self.client = anthropic.Anthropic(**client_kwargs) + self.async_client = anthropic.AsyncAnthropic(**client_kwargs) self.model_name = model_name self.max_tokens = max_tokens diff --git a/rlm/clients/gemini.py b/rlm/clients/gemini.py index 8a211611..2f7cecd6 100644 --- a/rlm/clients/gemini.py +++ b/rlm/clients/gemini.py @@ -12,6 +12,7 @@ load_dotenv() DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +DEFAULT_GEMINI_API_BASE = os.getenv("GEMINI_API_BASE") class GeminiClient(BaseLM): @@ -24,6 +25,7 @@ def __init__( self, api_key: str | None = None, model_name: str | None = "gemini-2.5-flash", + base_url: str | None = None, **kwargs, ): super().__init__(model_name=model_name, **kwargs) @@ -36,8 +38,17 @@ def __init__( "Gemini API key is required. Set GEMINI_API_KEY env var or pass api_key." ) + # Fall back to GEMINI_API_BASE env var if base_url is not explicitly provided. + if base_url is None: + base_url = DEFAULT_GEMINI_API_BASE + # Configure HTTP options with timeout - http_options = types.HttpOptions(timeout=int(self.timeout * 1000)) # milliseconds + http_options = types.HttpOptions( + timeout=int(self.timeout * 1000), # milliseconds + **({ + "base_url": base_url + } if base_url is not None else {}), + ) self.client = genai.Client(api_key=api_key, http_options=http_options) self.model_name = model_name diff --git a/rlm/clients/openai.py b/rlm/clients/openai.py index ae590e49..e88dba46 100644 --- a/rlm/clients/openai.py +++ b/rlm/clients/openai.py @@ -10,8 +10,9 @@ load_dotenv() -# Load API keys from environment variables +# Load API keys and base URL from environment variables DEFAULT_OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +DEFAULT_OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") DEFAULT_OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") DEFAULT_VERCEL_API_KEY = os.getenv("AI_GATEWAY_API_KEY") DEFAULT_PRIME_API_KEY = os.getenv("PRIME_API_KEY") @@ -36,6 +37,10 @@ def __init__( ): super().__init__(model_name=model_name, **kwargs) + # Fall back to OPENAI_API_BASE env var if base_url is not explicitly provided. + if base_url is None: + base_url = DEFAULT_OPENAI_API_BASE + if api_key is None: if base_url == "https://api.openai.com/v1" or base_url is None: api_key = DEFAULT_OPENAI_API_KEY @@ -45,6 +50,10 @@ def __init__( api_key = DEFAULT_VERCEL_API_KEY elif base_url == DEFAULT_PRIME_INTELLECT_BASE_URL: api_key = DEFAULT_PRIME_API_KEY + else: + # For any custom/unknown base URL (e.g. org proxies set via OPENAI_API_BASE), + # fall back to OPENAI_API_KEY so the client is not left unauthenticated. + api_key = DEFAULT_OPENAI_API_KEY # Pass through arbitrary kwargs to the OpenAI client (e.g. default_headers, default_query, max_retries). # Exclude model_name since it is not an OpenAI client constructor argument. diff --git a/tests/clients/test_base_url_env_vars.py b/tests/clients/test_base_url_env_vars.py new file mode 100644 index 00000000..db98e6d7 --- /dev/null +++ b/tests/clients/test_base_url_env_vars.py @@ -0,0 +1,206 @@ +""" +Unit tests for automatic base URL and API key env var fallbacks +in OpenAI, Anthropic, and Gemini clients. +""" + +import os +from unittest.mock import patch + +import pytest + +# ============================================================================= +# OpenAI +# ============================================================================= + + +class TestOpenAIClientBaseUrl: + """Tests for OPENAI_API_BASE env var fallback in OpenAIClient.""" + + def test_base_url_from_env(self): + """When base_url is not passed, OPENAI_API_BASE is used.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai, patch( + "rlm.clients.openai.openai.AsyncOpenAI" + ), patch.dict( + os.environ, + {"OPENAI_API_BASE": "https://my-org.example.com/v1", "OPENAI_API_KEY": "test-key"}, + ): + from rlm.clients.openai import OpenAIClient + + with patch("rlm.clients.openai.DEFAULT_OPENAI_API_BASE", "https://my-org.example.com/v1"): + _ = OpenAIClient(api_key="test-key") + + _, kwargs = mock_openai.call_args + assert kwargs.get("base_url") == "https://my-org.example.com/v1" + + def test_explicit_base_url_overrides_env(self): + """An explicitly passed base_url takes precedence over OPENAI_API_BASE.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai, patch( + "rlm.clients.openai.openai.AsyncOpenAI" + ), patch("rlm.clients.openai.DEFAULT_OPENAI_API_BASE", "https://env-base.example.com/v1"): + from rlm.clients.openai import OpenAIClient + + _ = OpenAIClient(api_key="test-key", base_url="https://explicit.example.com/v1") + + _, kwargs = mock_openai.call_args + assert kwargs.get("base_url") == "https://explicit.example.com/v1" + + def test_base_url_is_none_when_env_not_set(self): + """When OPENAI_API_BASE is not set and base_url not passed, base_url is None.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai, patch( + "rlm.clients.openai.openai.AsyncOpenAI" + ), patch("rlm.clients.openai.DEFAULT_OPENAI_API_BASE", None): + from rlm.clients.openai import OpenAIClient + + _ = OpenAIClient(api_key="test-key") + + _, kwargs = mock_openai.call_args + assert kwargs.get("base_url") is None + + def test_openai_api_key_used_for_custom_base_url(self): + """When OPENAI_API_BASE is a custom URL and api_key is omitted, OPENAI_API_KEY is still used.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai, patch( + "rlm.clients.openai.openai.AsyncOpenAI" + ), patch( + "rlm.clients.openai.DEFAULT_OPENAI_API_BASE", "https://my-org.example.com/v1" + ), patch( + "rlm.clients.openai.DEFAULT_OPENAI_API_KEY", "my-openai-key" + ): + from rlm.clients.openai import OpenAIClient + + _ = OpenAIClient() + + _, kwargs = mock_openai.call_args + assert kwargs.get("api_key") == "my-openai-key" + assert kwargs.get("base_url") == "https://my-org.example.com/v1" + + +# ============================================================================= +# Anthropic +# ============================================================================= + + +class TestAnthropicClientEnvVars: + """Tests for ANTHROPIC_API_KEY and ANTHROPIC_BASE_URL env var fallbacks.""" + + def test_api_key_from_env(self): + """When api_key is not passed, ANTHROPIC_API_KEY env var is used.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic, patch( + "rlm.clients.anthropic.anthropic.AsyncAnthropic" + ), patch("rlm.clients.anthropic.DEFAULT_ANTHROPIC_API_KEY", "env-anthropic-key"), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_BASE_URL", None + ): + from rlm.clients.anthropic import AnthropicClient + + _ = AnthropicClient() + + _, kwargs = mock_anthropic.call_args + assert kwargs.get("api_key") == "env-anthropic-key" + + def test_raises_when_no_api_key(self): + """Raises ValueError when neither api_key arg nor ANTHROPIC_API_KEY env var is set.""" + with patch("rlm.clients.anthropic.DEFAULT_ANTHROPIC_API_KEY", None), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_BASE_URL", None + ): + from rlm.clients.anthropic import AnthropicClient + + with pytest.raises(ValueError, match="Anthropic API key is required"): + AnthropicClient(api_key=None) + + def test_base_url_from_env(self): + """When base_url is not passed, ANTHROPIC_BASE_URL env var is used.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic, patch( + "rlm.clients.anthropic.anthropic.AsyncAnthropic" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_API_KEY", "test-key" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_BASE_URL", "https://my-org.example.com" + ): + from rlm.clients.anthropic import AnthropicClient + + _ = AnthropicClient() + + _, kwargs = mock_anthropic.call_args + assert kwargs.get("base_url") == "https://my-org.example.com" + + def test_explicit_base_url_overrides_env(self): + """An explicitly passed base_url takes precedence over ANTHROPIC_BASE_URL.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic, patch( + "rlm.clients.anthropic.anthropic.AsyncAnthropic" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_API_KEY", "test-key" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_BASE_URL", "https://env-base.example.com" + ): + from rlm.clients.anthropic import AnthropicClient + + _ = AnthropicClient(base_url="https://explicit.example.com") + + _, kwargs = mock_anthropic.call_args + assert kwargs.get("base_url") == "https://explicit.example.com" + + def test_base_url_not_passed_when_env_not_set(self): + """When ANTHROPIC_BASE_URL is not set, base_url key is absent from client kwargs.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic, patch( + "rlm.clients.anthropic.anthropic.AsyncAnthropic" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_API_KEY", "test-key" + ), patch( + "rlm.clients.anthropic.DEFAULT_ANTHROPIC_BASE_URL", None + ): + from rlm.clients.anthropic import AnthropicClient + + _ = AnthropicClient() + + _, kwargs = mock_anthropic.call_args + assert "base_url" not in kwargs + + +# ============================================================================= +# Gemini +# ============================================================================= + + +class TestGeminiClientBaseUrl: + """Tests for GEMINI_API_BASE env var fallback in GeminiClient.""" + + def test_base_url_from_env(self): + """When base_url is not passed, GEMINI_API_BASE is used in HttpOptions.""" + with patch("rlm.clients.gemini.genai.Client"), patch( + "rlm.clients.gemini.types.HttpOptions" + ) as mock_http_options, patch( + "rlm.clients.gemini.DEFAULT_GEMINI_API_BASE", "https://my-org.example.com" + ): + from rlm.clients.gemini import GeminiClient + + _ = GeminiClient(api_key="test-key") + + _, kwargs = mock_http_options.call_args + assert kwargs.get("base_url") == "https://my-org.example.com" + + def test_explicit_base_url_overrides_env(self): + """An explicitly passed base_url takes precedence over GEMINI_API_BASE.""" + with patch("rlm.clients.gemini.genai.Client"), patch( + "rlm.clients.gemini.types.HttpOptions" + ) as mock_http_options, patch( + "rlm.clients.gemini.DEFAULT_GEMINI_API_BASE", "https://env-base.example.com" + ): + from rlm.clients.gemini import GeminiClient + + _ = GeminiClient(api_key="test-key", base_url="https://explicit.example.com") + + _, kwargs = mock_http_options.call_args + assert kwargs.get("base_url") == "https://explicit.example.com" + + def test_base_url_absent_when_env_not_set(self): + """When GEMINI_API_BASE is not set, base_url is not passed to HttpOptions.""" + with patch("rlm.clients.gemini.genai.Client"), patch( + "rlm.clients.gemini.types.HttpOptions" + ) as mock_http_options, patch( + "rlm.clients.gemini.DEFAULT_GEMINI_API_BASE", None + ): + from rlm.clients.gemini import GeminiClient + + _ = GeminiClient(api_key="test-key") + + _, kwargs = mock_http_options.call_args + assert "base_url" not in kwargs