diff --git a/config.example.yml b/config.example.yml index e4a32168..adac778e 100644 --- a/config.example.yml +++ b/config.example.yml @@ -9,8 +9,13 @@ log: server: addr: "127.0.0.1:38440" - # Set a Bearer token for API authentication. When non-empty, all requests - # must include an Authorization: Bearer header. If empty, auth is disabled. + # How incoming Bearer tokens are handled: + # "authentication" (default) — validate against auth_token below + # "transform" — forward the user's Bearer token as the provider's api_key + # auth_type: "authentication" + # + # When auth_type is "authentication", this Bearer token must match the + # request's Authorization header. Leave empty to disable auth entirely. # auth_token: "replace-with-your-secret-token" persistence: diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 0391d1e6..c38fc246 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -37,9 +37,19 @@ defaults: ```yaml server: addr: "127.0.0.1:38440" # 监听地址 - auth_token: "" # Bearer 认证 Token(空 = 不认证) + auth_type: "authentication" # authentication(默认)| transform + auth_token: "" # Bearer 认证 Token(空 = 不认证;auth_type=transform 时忽略) ``` +### auth_type + +| 值 | 行为 | +|-----|------| +| `authentication`(默认) | 请求必须携带 `Authorization: Bearer `,与 `auth_token` 比对。`auth_token` 为空时跳过认证。 | +| `transform` | 提取请求中的 `Bearer `,替换掉 `providers..api_key` 转发给上游供应商。无需配置 `auth_token`。 | + +`transform` 模式适用于将 Moon Bridge 作为代理直接暴露给终端用户,由用户提供自己的 API Key 的场景。此时 `providers..api_key` 可留空,用户的 token 会替代它在所有协议路径(Anthropic `x-api-key`、OpenAI Chat/Response `Bearer`、Google Gemini API key / Vertex AI `Bearer`)中生效。 + ## Models 模型定义包含上下文窗口、推理能力、扩展支持等元信息: diff --git a/docs/GETTING-STARTED.md b/docs/GETTING-STARTED.md index 2d65a32e..09f31451 100644 --- a/docs/GETTING-STARTED.md +++ b/docs/GETTING-STARTED.md @@ -44,6 +44,10 @@ cp config.example.yml config.yml mode: "Transform" server: addr: "127.0.0.1:38440" + # auth_type 默认为 "authentication"。auth_token 为空时不验证。 + # 设为 "transform" 则转发用户的 Bearer token 给上游 Provider。 + # auth_type: "authentication" + # auth_token: "your-secret-token" defaults: model: "deepseek-chat" diff --git a/internal/config/config.go b/internal/config/config.go index 091b0905..32fd4708 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "errors" "fmt" "moonbridge/internal/modelref" @@ -20,8 +21,8 @@ const ( ProtocolAnthropic = "anthropic" ProtocolOpenAIResponse = "openai-response" // Phase 5: New protocol constants (D-08) - ProtocolGoogleGenAI = "google-genai" - ProtocolOpenAIChat = "openai-chat" + ProtocolGoogleGenAI = "google-genai" + ProtocolOpenAIChat = "openai-chat" ) type Mode string @@ -32,6 +33,31 @@ const ( ModeTransform Mode = "Transform" ) +// AuthType controls how the server handles incoming Bearer tokens. +type AuthType string + +const ( + AuthTypeAuthentication AuthType = "authentication" + AuthTypeTransform AuthType = "transform" +) + +// contextKey is the type used for context keys in the config package. +type contextKey string + +const transformAuthTokenKey contextKey = "transform_auth_token" + +// WithTransformAuthToken stores a transformed auth token in the context. +func WithTransformAuthToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, transformAuthTokenKey, token) +} + +// TransformAuthTokenFromContext retrieves the transformed auth token from the context. +// Returns the token and true if present. +func TransformAuthTokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(transformAuthTokenKey).(string) + return token, ok +} + type WebSearchSupport string const ( @@ -51,22 +77,23 @@ type WebSearchConfig struct { } type Config struct { - Mode Mode - Addr string - AuthToken string - TraceRequests bool - LogLevel string - LogFormat string - SystemPrompt string - DefaultModel string - WebSearchSupport WebSearchSupport - WebSearchMaxUses int - TavilyAPIKey string - FirecrawlAPIKey string - SearchMaxRounds int - DefaultMaxTokens int - MaxSessions int `yaml:"max_sessions"` // 0 = unlimited - SessionTTL string `yaml:"session_ttl"` // default "24h" + Mode Mode + Addr string + AuthToken string + AuthType AuthType + TraceRequests bool + LogLevel string + LogFormat string + SystemPrompt string + DefaultModel string + WebSearchSupport WebSearchSupport + WebSearchMaxUses int + TavilyAPIKey string + FirecrawlAPIKey string + SearchMaxRounds int + DefaultMaxTokens int + MaxSessions int `yaml:"max_sessions"` // 0 = unlimited + SessionTTL string `yaml:"session_ttl"` // default "24h" // Defaults holds the default configuration values. Defaults Defaults // Models is the canonical model definition map (shared across providers). @@ -115,11 +142,11 @@ type RouteEntry struct { // ProviderDef defines a single upstream provider. type ProviderDef struct { - BaseURL string - APIKey string - Version string - UserAgent string - Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat" + BaseURL string + APIKey string + Version string + UserAgent string + Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat" // Phase 5: Google GenAI flat fields (D-09). // Only relevant when Protocol == ProtocolGoogleGenAI. // project: Google Cloud project ID (Vertex AI). @@ -129,7 +156,7 @@ type ProviderDef struct { Location string `yaml:"location,omitempty"` APIVersion string `yaml:"api_version,omitempty"` // Cache config for this provider. If nil, provider does not use caching. - Cache *CacheConfig `yaml:"cache,omitempty"` + Cache *CacheConfig `yaml:"cache,omitempty"` WebSearchSupport WebSearchSupport WebSearchMaxUses int TavilyAPIKey string @@ -204,11 +231,11 @@ type ModelDef struct { // OfferEntry declares that a provider offers a model defined in Models. type OfferEntry struct { - Model string // references models. - UpstreamName string // optional, upstream model name (empty = same as slug) - Priority int // lower value = higher priority (0 is highest) + Model string // references models. + UpstreamName string // optional, upstream model name (empty = same as slug) + Priority int // lower value = higher priority (0 is highest) Pricing ModelPricing - Overrides *ModelDef // optional provider-specific overrides + Overrides *ModelDef // optional provider-specific overrides } type ResponseProxyConfig struct { @@ -285,7 +312,7 @@ func (cfg Config) validateTransform() error { if def.BaseURL == "" { return fmt.Errorf("providers.%s.base_url is required", key) } - if def.APIKey == "" { + if cfg.AuthType != AuthTypeTransform && def.APIKey == "" { return fmt.Errorf("providers.%s.api_key is required", key) } switch def.Protocol { diff --git a/internal/config/config_loader.go b/internal/config/config_loader.go index 76dee14c..9edc4692 100644 --- a/internal/config/config_loader.go +++ b/internal/config/config_loader.go @@ -88,8 +88,9 @@ type TraceFileConfig struct { } type ServerFileConfig struct { - Addr string `yaml:"addr" json:"addr,omitempty"` - AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"` + Addr string `yaml:"addr" json:"addr,omitempty"` + AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"` + AuthType string `yaml:"auth_type" json:"auth_type,omitempty"` MaxSessions int `yaml:"max_sessions"` SessionTTL string `yaml:"session_ttl"` } @@ -122,33 +123,33 @@ type ModelDefFileConfig struct { } type OfferFileConfig struct { - Model string `yaml:"model" json:"model"` - UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"` - Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` - Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"` - Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"` + Model string `yaml:"model" json:"model"` + UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"` + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"` + Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"` } type ProviderDefFileConfig struct { - BaseURL string `yaml:"base_url" json:"base_url"` - APIKey string `yaml:"api_key" json:"api_key"` - Version string `yaml:"version,omitempty" json:"version,omitempty"` - UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"` - Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"` - WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` - Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` - Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"` + BaseURL string `yaml:"base_url" json:"base_url"` + APIKey string `yaml:"api_key" json:"api_key"` + Version string `yaml:"version,omitempty" json:"version,omitempty"` + UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"` + Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"` + WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` + Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` + Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"` } type RouteFileConfig struct { - To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model" - Model string `yaml:"model,omitempty" json:"model,omitempty"` - Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` - DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"` - WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` - Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` + To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model" + Model string `yaml:"model,omitempty" json:"model,omitempty"` + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` + DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"` + Description string `yaml:"description,omitempty" json:"description,omitempty"` + ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"` + WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` + Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` } func (cfg *RouteFileConfig) UnmarshalYAML(value *yaml.Node) error { @@ -348,10 +349,16 @@ func FromFileConfigWithOptions(fileConfig FileConfig, opts LoadOptions) (Config, responseProxy := FromResponseProxyFileConfig(fileConfig.Proxy.Response) anthropicProxy := FromAnthropicProxyFileConfig(fileConfig.Proxy.Anthropic) + authType, err := parseAuthType(fileConfig.Server.AuthType) + if err != nil { + return Config{}, err + } + cfg := Config{ Mode: mode, Addr: valueOrDefault(strings.TrimSpace(fileConfig.Server.Addr), DefaultAddr), AuthToken: strings.TrimSpace(fileConfig.Server.AuthToken), + AuthType: authType, MaxSessions: intOrDefault(fileConfig.Server.MaxSessions, 0), SessionTTL: valueOrDefault(strings.TrimSpace(fileConfig.Server.SessionTTL), "24h"), TraceRequests: traceEnabled, @@ -786,6 +793,21 @@ func parseMode(value string) (Mode, error) { } } +func parseAuthType(value string) (AuthType, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return AuthTypeAuthentication, nil + } + switch AuthType(trimmed) { + case AuthTypeTransform: + return AuthTypeTransform, nil + case AuthTypeAuthentication: + return AuthTypeAuthentication, nil + default: + return "", fmt.Errorf("invalid auth_type %q (expected \"authentication\" or \"transform\")", trimmed) + } +} + func parseWebSearchSupport(value string) (WebSearchSupport, error) { switch support := WebSearchSupport(strings.TrimSpace(value)); support { case "": diff --git a/internal/config/convert.go b/internal/config/convert.go index faccab17..3f6c6117 100644 --- a/internal/config/convert.go +++ b/internal/config/convert.go @@ -19,6 +19,7 @@ func (cfg Config) ToFileConfig() FileConfig { Server: ServerFileConfig{ Addr: cfg.Addr, AuthToken: cfg.AuthToken, + AuthType: string(cfg.AuthType), MaxSessions: cfg.MaxSessions, SessionTTL: cfg.SessionTTL, }, diff --git a/internal/config/domain_server.go b/internal/config/domain_server.go index 0097225e..0cc7e0dc 100644 --- a/internal/config/domain_server.go +++ b/internal/config/domain_server.go @@ -5,6 +5,7 @@ package config type ServerConfig struct { Addr string AuthToken string + AuthType AuthType Mode string MaxSessions int SessionTTL string @@ -15,6 +16,7 @@ func ServerFromGlobalConfig(cfg *Config) ServerConfig { return ServerConfig{ Addr: cfg.Addr, AuthToken: cfg.AuthToken, + AuthType: cfg.AuthType, Mode: string(cfg.Mode), MaxSessions: cfg.MaxSessions, SessionTTL: cfg.SessionTTL, diff --git a/internal/protocol/anthropic/client.go b/internal/protocol/anthropic/client.go index f5dcf6ee..e0e9a2c1 100644 --- a/internal/protocol/anthropic/client.go +++ b/internal/protocol/anthropic/client.go @@ -11,6 +11,8 @@ import ( "log/slog" "net/http" "strings" + + "moonbridge/internal/config" ) type ClientConfig struct { @@ -205,7 +207,7 @@ func (client *Client) newRequest(ctx context.Context, messageRequest MessageRequ return nil, err } httpRequest.Header.Set("content-type", "application/json") - httpRequest.Header.Set("x-api-key", client.apiKey) + httpRequest.Header.Set("x-api-key", client.effectiveAPIKey(ctx)) if client.version != "" { httpRequest.Header.Set("anthropic-version", client.version) } @@ -375,3 +377,12 @@ func (err *ProviderError) OpenAIType() string { func UnsupportedStreamEvent(event string) error { return fmt.Errorf("unsupported stream event %q", event) } + +// effectiveAPIKey returns the transformed auth token from context if available, +// otherwise falls back to the client's configured API key. +func (client *Client) effectiveAPIKey(ctx context.Context) string { + if token, ok := config.TransformAuthTokenFromContext(ctx); ok { + return token + } + return client.apiKey +} diff --git a/internal/protocol/chat/client.go b/internal/protocol/chat/client.go index dafa3536..ec0868a1 100644 --- a/internal/protocol/chat/client.go +++ b/internal/protocol/chat/client.go @@ -14,6 +14,8 @@ import ( "log/slog" "net/http" "strings" + + "moonbridge/internal/config" ) // ClientConfig configures the OpenAI Chat Completions HTTP client. @@ -162,7 +164,7 @@ func (c *Client) newRequest(ctx context.Context, req *ChatRequest) (*http.Reques return nil, fmt.Errorf("chat API request build: %w", err) } httpReq.Header.Set("content-type", "application/json") - httpReq.Header.Set("authorization", "Bearer "+c.apiKey) + httpReq.Header.Set("authorization", "Bearer "+c.effectiveAPIKey(ctx)) if c.userAgent != "" { httpReq.Header.Set("user-agent", c.userAgent) } @@ -226,3 +228,12 @@ func safeUsage(u *Usage) Usage { } return *u } + +// effectiveAPIKey returns the transformed auth token from context if available, +// otherwise falls back to the client's configured API key. +func (c *Client) effectiveAPIKey(ctx context.Context) string { + if token, ok := config.TransformAuthTokenFromContext(ctx); ok { + return token + } + return c.apiKey +} diff --git a/internal/protocol/google/client.go b/internal/protocol/google/client.go index e7a3591e..4e1efbbc 100644 --- a/internal/protocol/google/client.go +++ b/internal/protocol/google/client.go @@ -14,15 +14,17 @@ import ( "log/slog" "net/http" "strings" + + "moonbridge/internal/config" ) // ClientConfig configures the Gemini API HTTP client. type ClientConfig struct { BaseURL string APIKey string - Project string // Vertex AI project ID (optional, for Vertex AI endpoint) - Location string // Vertex AI location (optional, default "us-central1") - Version string // API version (default "v1") + Project string // Vertex AI project ID (optional, for Vertex AI endpoint) + Location string // Vertex AI location (optional, default "us-central1") + Version string // API version (default "v1") UserAgent string Client *http.Client } @@ -146,6 +148,14 @@ func (c *Client) StreamGenerateContent(ctx context.Context, model string, req *G // to close (connections are managed by http.Client), so this is a no-op. func (c *Client) Close() error { return nil } +// effectiveAPIKey returns the transformed auth token from context if available, +// otherwise falls back to the client's configured API key. +func (c *Client) effectiveAPIKey(ctx context.Context) string { + if token, ok := config.TransformAuthTokenFromContext(ctx); ok { + return token + } + return c.apiKey +} // ============================================================================ // CachedContent API methods @@ -160,7 +170,7 @@ func (c *Client) CreateCachedContent(ctx context.Context, cc *CachedContent) (*C return nil, fmt.Errorf("create cached content: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-goog-api-key", c.apiKey) + req.Header.Set("x-goog-api-key", c.effectiveAPIKey(ctx)) resp, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("create cached content: %w", err) @@ -183,7 +193,7 @@ func (c *Client) GetCachedContent(ctx context.Context, name string) (*CachedCont if err != nil { return nil, fmt.Errorf("get cached content: %w", err) } - req.Header.Set("x-goog-api-key", c.apiKey) + req.Header.Set("x-goog-api-key", c.effectiveAPIKey(ctx)) resp, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("get cached content: %w", err) @@ -209,7 +219,7 @@ func (c *Client) UpdateCachedContent(ctx context.Context, name, ttl string) (*Ca return nil, fmt.Errorf("update cached content: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-goog-api-key", c.apiKey) + req.Header.Set("x-goog-api-key", c.effectiveAPIKey(ctx)) resp, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("update cached content: %w", err) @@ -232,7 +242,7 @@ func (c *Client) DeleteCachedContent(ctx context.Context, name string) error { if err != nil { return fmt.Errorf("delete cached content: %w", err) } - req.Header.Set("x-goog-api-key", c.apiKey) + req.Header.Set("x-goog-api-key", c.effectiveAPIKey(ctx)) resp, err := c.client.Do(req) if err != nil { return fmt.Errorf("delete cached content: %w", err) @@ -243,6 +253,7 @@ func (c *Client) DeleteCachedContent(ctx context.Context, name string) error { } return nil } + // ============================================================================ // Internal helpers // ============================================================================ @@ -264,8 +275,9 @@ func (c *Client) newRequest(ctx context.Context, model, action string, req *Gene c.baseURL, c.version, c.project, c.location, model, action) } else { // Gemini API: API key in query param + effectiveKey := c.effectiveAPIKey(ctx) url = fmt.Sprintf("%s/%s/models/%s%s?key=%s", - c.baseURL, c.version, model, action, c.apiKey) + c.baseURL, c.version, model, action, effectiveKey) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) @@ -276,9 +288,12 @@ func (c *Client) newRequest(ctx context.Context, model, action string, req *Gene if c.userAgent != "" { httpReq.Header.Set("user-agent", c.userAgent) } - if c.project != "" && c.apiKey != "" { + if c.project != "" { // Vertex AI uses Bearer token (APIKey field holds the OAuth token) - httpReq.Header.Set("authorization", "Bearer "+c.apiKey) + effectiveKey := c.effectiveAPIKey(ctx) + if effectiveKey != "" { + httpReq.Header.Set("authorization", "Bearer "+effectiveKey) + } } return httpReq, nil } diff --git a/internal/service/app/app.go b/internal/service/app/app.go index bccf2b6a..ea94168a 100644 --- a/internal/service/app/app.go +++ b/internal/service/app/app.go @@ -326,6 +326,12 @@ func resolvePerProviderWebSearch(ctx context.Context, cfg config.Config, pm *pro if pm == nil { return } + // In transform auth mode, upstream probing is impossible (no api_key configured). + // Explicit configs (enabled/disabled/injected) are honored regardless. + // In transform mode, auto probing is deferred as "unknown" and lazily + // resolved on the first authenticated request carrying a user token. + isTransform := cfg.AuthType == config.AuthTypeTransform + // 1. Resolve provider-level defaults. for _, key := range pm.ProviderKeys() { protocol := pm.ProtocolForKey(key) @@ -343,12 +349,17 @@ func resolvePerProviderWebSearch(ctx context.Context, cfg config.Config, pm *pro pm.SetResolvedWebSearch(key, "injected") slog.Info("网页搜索注入模式已启用", "provider", key) default: - resolved := probeProviderWebSearch(ctx, key, pm, errors) - if resolved == "disabled" && cfg.TavilyAPIKey != "" { - resolved = "injected" - slog.Info("网页搜索自动探测失败,回退到注入模式", "provider", key) + if isTransform { + pm.SetResolvedWebSearch(key, "unknown") + slog.Info("transform 模式推迟网页搜索探测", "provider", key) + } else { + resolved := probeProviderWebSearch(ctx, key, pm, errors) + if resolved == "disabled" && cfg.TavilyAPIKey != "" { + resolved = "injected" + slog.Info("网页搜索自动探测失败,回退到注入模式", "provider", key) + } + pm.SetResolvedWebSearch(key, resolved) } - pm.SetResolvedWebSearch(key, resolved) } case config.ProtocolOpenAIResponse: switch support { @@ -433,9 +444,15 @@ func resolveModelWebSearch(ctx context.Context, alias, providerKey, upstreamMode pm.SetResolvedWebSearch(candidateKey, "injected") slog.Info("模型配置启用网页搜索注入模式", "model", alias) default: - resolved := resolveModelWebSearchWithProber(ctx, alias, providerKey, upstreamModel, modelWS, pm, cfg, errors, pm) - pm.SetResolvedWebSearch(modelKey, resolved) - pm.SetResolvedWebSearch(candidateKey, resolved) + if cfg.AuthType == config.AuthTypeTransform { + pm.SetResolvedWebSearch(modelKey, "unknown") + pm.SetResolvedWebSearch(candidateKey, "unknown") + slog.Info("transform 模式推迟模型级网页搜索探测", "model", alias) + } else { + resolved := resolveModelWebSearchWithProber(ctx, alias, providerKey, upstreamModel, modelWS, pm, cfg, errors, pm) + pm.SetResolvedWebSearch(modelKey, resolved) + pm.SetResolvedWebSearch(candidateKey, resolved) + } } } diff --git a/internal/service/server/adapter_dispatch.go b/internal/service/server/adapter_dispatch.go index d3e65a8d..9e12b180 100644 --- a/internal/service/server/adapter_dispatch.go +++ b/internal/service/server/adapter_dispatch.go @@ -170,7 +170,7 @@ func (s *Server) handleWithAdapters( // the upstream provider receives the correct model identifier. coreReq.Model = preferred.UpstreamModel - wsMode := resolvedWebSearchMode(pm, openAIReq.Model, preferred) + wsMode := s.resolveWebSearchLazy(ctx, pm, preferred.ProviderKey, preferred.UpstreamModel, openAIReq.Model) // Inject web search tools at Core level if mode is "injected". // This replaces web_search/web_search_preview with tavily_search/firecrawl_fetch tools. @@ -2143,7 +2143,6 @@ func normalizeAnthropicRequest(upstream any) (anthropic.MessageRequest, error) { } } - // injectCoreWebSearch replaces web_search tools in coreReq.Tools with injected // tavily_search/firecrawl_fetch tools when the resolved web search mode is "injected". // Returns true if injection was applied. @@ -2192,6 +2191,50 @@ func resolvedWebSearchMode(pm *provider.ProviderManager, modelAlias string, pref return "" } +// resolveWebSearchLazy probes web search support on-demand when the resolved +// mode is "unknown" (startup state for transform auth mode). On success, +// updates the provider manager cache and returns the final mode. +func (s *Server) resolveWebSearchLazy(ctx context.Context, pm *provider.ProviderManager, providerKey, upstreamModel, modelAlias string) string { + // Resolve current mode using the same priority as resolvedWebSearchMode. + mode := "" + if providerKey != "" && upstreamModel != "" { + mode = pm.ResolvedWebSearchForCandidate(providerKey, upstreamModel) + } + if mode == "" && modelAlias != "" { + mode = pm.ResolvedWebSearchForModel(modelAlias) + } + if mode != "unknown" { + return mode + } + + // Only probe if context carries a transform token (otherwise skip). + if _, ok := config.TransformAuthTokenFromContext(ctx); !ok { + return "disabled" + } + + slog.Default().Info("lazy web search probe", "provider", providerKey, "model", upstreamModel) + supported, err := pm.ProbeWebSearchCandidate(ctx, providerKey, upstreamModel) + if err != nil || !supported { + slog.Default().Warn("lazy web search probe failed or unsupported", "provider", providerKey, "model", upstreamModel, "error", err) + if providerKey != "" && upstreamModel != "" { + pm.SetResolvedWebSearch(provider.WebSearchCandidateKey(providerKey, upstreamModel), "disabled") + } + if modelAlias != "" { + pm.SetResolvedWebSearch("model:"+modelAlias, "disabled") + } + return "disabled" + } + + slog.Default().Info("lazy web search probe: enabled", "provider", providerKey, "model", upstreamModel) + if providerKey != "" && upstreamModel != "" { + pm.SetResolvedWebSearch(provider.WebSearchCandidateKey(providerKey, upstreamModel), "enabled") + } + if modelAlias != "" { + pm.SetResolvedWebSearch("model:"+modelAlias, "enabled") + } + return "enabled" +} + // searchProvider wraps the websearchinjected orchestrator's behavior. type searchProvider interface { CreateMessage(ctx context.Context, req anthropic.MessageRequest) (anthropic.MessageResponse, error) diff --git a/internal/service/server/dispatch.go b/internal/service/server/dispatch.go index c5d4707d..be929ac5 100644 --- a/internal/service/server/dispatch.go +++ b/internal/service/server/dispatch.go @@ -325,6 +325,14 @@ func (server *Server) handleOpenAIResponse(writer http.ResponseWriter, request * return } + // Resolve web search once before the loop (includes lazy probe for transform mode). + wsEnabled := false + if len(openaiCandidates) > 0 { + first := openaiCandidates[0] + wsMode := server.resolveWebSearchLazy(request.Context(), pm, first.ProviderKey, first.UpstreamModel, responsesRequest.Model) + wsEnabled = wsMode == "enabled" + } + for i, candidate := range openaiCandidates { providerKey := candidate.ProviderKey isLast := i == len(openaiCandidates)-1 @@ -337,6 +345,10 @@ func (server *Server) handleOpenAIResponse(writer http.ResponseWriter, request * baseURL := pm.ProviderBaseURL(providerKey) apiKey := pm.ProviderAPIKey(providerKey) + // In transform mode, use the user's Bearer token instead of the provider's API key. + if token, ok := config.TransformAuthTokenFromContext(request.Context()); ok { + apiKey = token + } if baseURL == "" { if isLast { log.Error("OpenAI 提供商缺少 base_url") @@ -371,7 +383,7 @@ func (server *Server) handleOpenAIResponse(writer http.ResponseWriter, request * actualModel = candidate.UpstreamModel // Inject web_search tool if enabled for this model. - if pm.ResolvedWebSearchForModel(responsesRequest.Model) == "enabled" { + if wsEnabled { upstreamRequest.Tools = InjectWebSearchTool(upstreamRequest.Tools) } diff --git a/internal/service/server/server.go b/internal/service/server/server.go index 4c792212..a7279672 100644 --- a/internal/service/server/server.go +++ b/internal/service/server/server.go @@ -162,17 +162,39 @@ func New(cfg Config) *Server { } func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if token := s.currentConfig().AuthToken; token != "" { - if !checkAuth(request, token) { + cfg := s.currentConfig() + + switch cfg.AuthType { + case config.AuthTypeTransform: + // In transform mode, extract the user's Bearer token and forward it + // to the upstream provider instead of the configured api_key. + token := extractBearerToken(request) + if token == "" { writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(http.StatusUnauthorized) json.NewEncoder(writer).Encode(openai.ErrorResponse{Error: openai.ErrorObject{ - Message: "未提供有效的认证令牌,请在 Authorization header 中使用 Bearer 方案", + Message: "在 transform 模式下需要提供 Authorization: Bearer ", Type: "authentication_error", Code: "invalid_auth", }}) return } + ctx := config.WithTransformAuthToken(request.Context(), token) + request = request.WithContext(ctx) + default: + // authentication mode (default): verify Bearer token against auth_token. + if cfg.AuthToken != "" { + if !checkAuth(request, cfg.AuthToken) { + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(writer).Encode(openai.ErrorResponse{Error: openai.ErrorObject{ + Message: "未提供有效的认证令牌,请在 Authorization header 中使用 Bearer 方案", + Type: "authentication_error", + Code: "invalid_auth", + }}) + return + } + } } s.mux.ServeHTTP(writer, request) } @@ -321,11 +343,17 @@ func slugDisplayName(slug string) string { } func checkAuth(r *http.Request, expectedToken string) bool { + return extractBearerToken(r) == expectedToken +} + +// extractBearerToken extracts the Bearer token from an Authorization header. +// Returns the token value, or an empty string if no Bearer token is present. +func extractBearerToken(r *http.Request) string { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { - return false + return "" } - return strings.TrimSpace(auth[7:]) == expectedToken + return strings.TrimSpace(auth[7:]) } func (s *Server) resolveModelOrFallback(modelName string) (*provider.ResolvedRoute, error) {