diff --git a/README.md b/README.md index 64dc173c..04958912 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,6 @@ docker run -p 38440:38440 -v $(pwd)/config.yml:/config/config.yml moonbridge | `/v1/models` | GET | 列出可用模型 | | `/models` | GET | 同上 | | `/api/v1/` | — | 管理 API(需启用持久化) | -| `/health` | GET | 健康检查 | 详细 API 文档见 [API.md](docs/api.md)。 diff --git a/config.example.yml b/config.example.yml index e4a32168..1f97cce2 100644 --- a/config.example.yml +++ b/config.example.yml @@ -67,6 +67,11 @@ extensions: # binding: MOONBRIDGE_DB # Persistence consumer: request metrics + # apply_patch 代理扩展:控制 Codex 文件编辑工具是否展开为结构化代理工具。 + # 默认关闭,开启后 apply_patch 被拆分为 add_file/delete_file/update_file/replace_file/batch + # codex_tool_proxy: + # enabled: true + metrics: enabled: true config: diff --git a/docs/api.md b/docs/api.md index 2f3257f0..2d024e84 100644 --- a/docs/api.md +++ b/docs/api.md @@ -66,10 +66,32 @@ data: {"response": {...}} | 端点 | 方法 | 功能 | |------|------|------| -| `/api/v1/config` | GET/PUT | 获取/更新运行时配置 | -| `/api/v1/codex/config` | GET | 生成 Codex TOML 配置 | -| `/api/v1/providers` | GET/POST/DELETE | 管理 Provider | -| `/api/v1/sessions/{id}` | GET | 获取会话用量统计 | +| `/api/v1/providers` | GET/POST/PUT/PATCH/DELETE | 管理 Provider CRUD | +| `/api/v1/providers/{key}/offers` | POST | 创建 Offer | +| `/api/v1/providers/{key}/offers/{model}` | PATCH/DELETE | 更新/删除 Offer | +| `/api/v1/providers/{key}/test` | POST | 测试 Provider 连通性 | +| `/api/v1/models` | GET | 列出模型定义 | +| `/api/v1/models/{slug}` | GET/PUT/DELETE | 管理单个模型定义 | +| `/api/v1/routes` | GET | 列出路由 | +| `/api/v1/routes/{alias}` | GET/PUT/DELETE | 管理单个路由 | +| `/api/v1/defaults` | GET/PUT | 管理默认配置(model/max_tokens/system_prompt) | +| `/api/v1/web-search` | GET/PUT | 管理全局 Web Search 设置 | +| `/api/v1/extensions` | GET | 列出扩展 | +| `/api/v1/extensions/{name}` | GET/PUT | 管理扩展设置 | +| `/api/v1/config/effective` | GET | 获取生效配置 | +| `/api/v1/config/export` | GET | 导出配置 YAML | +| `/api/v1/config/import` | POST | 导入配置 YAML | +| 
`/api/v1/config/validate` | POST | 校验配置 | +| `/api/v1/changes` | GET | 列出未提交的配置变更 | +| `/api/v1/changes/apply` | POST | 应用配置变更 | +| `/api/v1/changes/discard` | POST | 放弃配置变更 | +| `/api/v1/status` | GET | 系统状态 | +| `/api/v1/status/providers` | GET | Provider 运行状态 | +| `/api/v1/sessions` | GET | 列出活跃会话 | +| `/api/v1/stats` | GET | 用量统计 | +| `/api/v1/stats/summary` | GET | 用量统计摘要 | +| `/api/v1/logs` | GET | 日志查询 | +| `/api/v1/version` | GET | 版本信息 | ## 错误处理 diff --git a/docs/architecture.md b/docs/architecture.md index 128ecfc1..d5775d22 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -8,28 +8,46 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对 ## 四层架构 -``` -┌─────────────────────────────────────────────────┐ -│ Service 层 │ -│ server(路由/处理) adapter_dispatch(协议分发) │ -│ provider(路由) stats(统计) trace(跟踪) │ -│ proxy(Capture代理) api(管理 API) │ -│ store(持久化) runtime(运行时) │ -├─────────────────────────────────────────────────┤ -│ Protocol 层 │ -│ format(核心类型/注册表) anthropic(Anthropic 适配) │ -│ openai(OpenAI 适配) google(GenAI 适配) │ -│ chat(OpenAI Chat 适配) cache(缓存) │ -├─────────────────────────────────────────────────┤ -│ 基础组件(直接位于 internal/ 下) │ -│ config(配置) logger(日志) openai_dto(共享 DTO) │ -│ modelref(模型引用) session(会话) db(数据库) │ -├─────────────────────────────────────────────────┤ -│ Extension 层 │ -│ deepseek_v4 visual websearch websearchinjected│ -│ kimi_workaround metrics codex(Codex模型目录) │ -│ plugin(插件注册/接口) db(持久化后端: SQLite/D1) │ -└─────────────────────────────────────────────────┘ +```mermaid +flowchart TB + subgraph Service["Service 层"] + s1["server(路由/处理)"] + s2["adapter_dispatch(协议分发)"] + s3["provider(路由)"] + s4["stats(统计)"] + s5["trace(跟踪)"] + s6["proxy(Capture代理)"] + s7["api(管理 API)"] + s8["store(持久化)"] + s9["runtime(运行时)"] + end + subgraph Protocol["Protocol 层"] + p1["format(核心类型/注册表)"] + p2["anthropic(Anthropic 适配)"] + p3["openai(OpenAI 适配)"] + p4["google(GenAI 适配)"] + p5["chat(OpenAI Chat 适配)"] + p6["cache(缓存)"] + end + subgraph Base["基础组件"] + 
b1["config(配置)"] + b2["logger(日志)"] + b3["openai_dto(共享 DTO)"] + b4["modelref(模型引用)"] + b5["session(会话)"] + b6["db(数据库)"] + end + subgraph Extension["Extension 层"] + e1["deepseek_v4"] + e2["visual"] + e3["websearch"] + e4["websearchinjected"] + e5["kimi_workaround"] + e6["metrics"] + e7["codex(模型目录)"] + e8["plugin(插件注册/接口)"] + e9["db(SQLite/D1)"] + end ``` ### 基础组件(internal/ 顶层包) @@ -58,7 +76,7 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对 业务编排层,组合基础层和 Protocol 组件: -- `internal/service/server` — HTTP 服务器、路由(`/v1/responses`、`/v1/models`、`/health` 等)、认证 +- `internal/service/server` — HTTP 服务器、路由(`/v1/responses`、`/v1/models` 等)、认证 - `internal/service/server/adapter_dispatch.go` — Adapter 分发路径(switch 协议类型 → 调用对应 Adapter) - `internal/service/provider` — Provider 管理器(多 Provider 路由、配置热重载) - `internal/service/proxy` — Capture 模式下的透明代理 @@ -68,7 +86,6 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对 - `internal/service/trace` — 请求跟踪(捕获请求/响应的完整链路,持久化到 `data/trace/`) - `internal/service/store` — 配置持久化存储(SQLite / D1) - `internal/service/runtime` — 运行时上下文 -- `internal/service/bridge` — 备用桥接层 ### Extension 层 @@ -81,6 +98,8 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对 - `internal/extension/metrics` — 请求指标采集与查询 - `internal/extension/plugin` — 三方插件注册管理(`PluginRegistry` + `CorePluginHooks`) - `internal/extension/codex` — Codex 模型目录 +- `internal/extension/codex_tool_proxy` — apply_patch 代理扩展 +- `internal/extension/kimi_workaround` — Kimi 工具调用轮次限制 - `internal/extension/db` — 持久化 Provider(SQLite 本地 / Cloudflare D1 Worker) ## 三种运行模式 @@ -93,26 +112,22 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对 ## 请求生命周期数据流(Transform 模式) -``` -客户端 (Codex CLI) - │ POST /v1/responses (OpenAI Responses 格式) - ▼ -server.handleResponses() - │ 认证 / 日志 / 统计初始化 / 路由解析 - ▼ -adapter_dispatch.go (Adapter 分发) - │ preferred.Protocol 决定上游协议 - │ - ├── ProtocolAnthropic → anthropic adapter - ├── ProtocolGoogleGenAI → google adapter - ├── ProtocolOpenAIChat → chat adapter - └── ProtocolOpenAIResponse → 直通(无协议转换) - │ - ├── 
插件拦截 (PluginHooks) - │ MutateCoreRequest → [Adapter] → RememberContent → OnStreamEvent - │ - ▼ -客户端 ←── OpenAI Responses 响应 +```mermaid +flowchart TD + A["客户端 (Codex CLI)"] + A -->|"POST /v1/responses<br/>(OpenAI Responses 格式)"| B["server.handleResponses()"] + B -->|"认证 / 日志 / 统计初始化 / 路由解析"| C["adapter_dispatch.go (Adapter 分发)"] + C -->|"preferred.Protocol 决定上游协议"| D{"协议分支"} + D -->|"ProtocolAnthropic"| E["anthropic adapter"] + D -->|"ProtocolGoogleGenAI"| F["google adapter"] + D -->|"ProtocolOpenAIChat"| G["chat adapter"] + D -->|"ProtocolOpenAIResponse"| H["直通(无协议转换)"] + E --> I["插件拦截 (PluginHooks)"] + F --> I + G --> I + H --> I + I --> J["MutateCoreRequest → [Adapter] → RememberContent → OnStreamEvent"] + J --> K["客户端<br/>← OpenAI Responses 响应"] ``` ## 模型路由 @@ -141,9 +156,9 @@ adapter_dispatch.go (Adapter 分发) ```go type ClientAdapter interface { - Protocol() string - ToCoreRequest(context.Context, []byte) (*CoreRequest, error) - FromCoreResponse(context.Context, *CoreResponse) ([]byte, error) + ClientProtocol() string + ToCoreRequest(context.Context, any) (*CoreRequest, error) + FromCoreResponse(context.Context, *CoreResponse) (any, error) } type ProviderAdapter interface { diff --git a/docs/development-conventions.md b/docs/development-conventions.md index 6f7eb8fa..6c1457ee 100644 --- a/docs/development-conventions.md +++ b/docs/development-conventions.md @@ -4,47 +4,66 @@ ### 目录布局 -``` -internal/ -├── config/ # 配置加载/校验/Schema -├── logger/ # 结构化日志(slog 封装) -├── openai_dto/ # 共享 OpenAI DTO -├── modelref/ # 模型引用解析 -├── session/ # 会话管理 -├── db/ # 数据库抽象与注册表 -├── format/ # Core 类型(CoreRequest/CoreResponse/Registry/Adapter 接口) -├── protocol/ # 协议转换层 -│ ├── anthropic/ # Anthropic Messages Adapter -│ ├── cache/ # Prompt 缓存规划 -│ ├── chat/ # OpenAI Chat Adapter -│ ├── format/ # (遗留层,功能已迁移到 internal/format) -│ ├── google/ # Google Gemini Adapter -│ └── openai/ # OpenAI Responses Adapter -├── service/ # 业务编排层 -│ ├── api/ # 管理 REST API(路由在 router.go) -│ ├── app/ # 应用生命周期管理、Extension 目录 -│ ├── bridge/ # (空目录,保留以备将来使用) -│ ├── e2e/ # 服务层 E2E 测试 -│ ├── provider/ # Provider 管理器 -│ ├── proxy/ # Capture 模式代理 -│ ├── runtime/ # 运行时上下文 -│ ├── server/ # HTTP 服务器 + 路由 + 认证 + Adapter 分发 -│ │ ├── session/ # 会话管理 -│ │ ├── trace/ # 请求跟踪写入 -│ │ └── usage/ # 用量跟踪 -│ ├── stats/ # 用量统计 -│ └── trace/ # 请求跟踪记录 -├── extension/ # 可插拔扩展 -│ ├── codex/ # Codex 模型目录(catalog.go、default_instructions.go) -│ ├── db/ # 数据库 Provider(sqlite/、d1/) -│ ├── deepseek_v4/ # DeepSeek V4 推理优化 -│ ├── kimi_workaround/ # Kimi 模型 tool call 轮次限制 -│ ├── metrics/ # 用量指标采集与查询 -│ ├── plugin/ # Plugin 接口 + 能力接口 + 注册表 -│ ├── visual/ # 视觉模型分发(CoreProvider 模式) -│ ├── websearch/ # Web Search 编排器 -│ └── websearchinjected/ # 注入式搜索插件 -└── e2e/ # 
端到端集成测试(协议转换) +```mermaid +flowchart TD + subgraph internal["internal/"] + direction TB + config["config/ — 配置加载/校验/Schema"] + logger["logger/ — 结构化日志(slog封装)"] + openai_dto["openai_dto/ — 共享 OpenAI DTO"] + modelref["modelref/ — 模型引用解析"] + session["session/ — 会话管理"] + db["db/ — 数据库抽象与注册表"] + fmt["format/ — Core类型/Registry/Adapter接口"] + + subgraph protocol["protocol/ — 协议转换层"] + direction TB + pa["anthropic/ — Anthropic Messages Adapter"] + pc["cache/ — Prompt 缓存规划"] + pch["chat/ — OpenAI Chat Adapter"] + pf["format/ — (遗留层,功能已迁移到 internal/format)"] + pg["google/ — Google Gemini Adapter"] + po["openai/ — OpenAI Responses Adapter"] + end + + subgraph service["service/ — 业务编排层"] + direction TB + sa["api/ — 管理 REST API"] + sapp["app/ — 应用生命周期管理、Extension 目录"] + se["e2e/ — 服务层 E2E 测试"] + sp["provider/ — Provider 管理器"] + spr["proxy/ — Capture 模式代理"] + srt["runtime/ — 运行时上下文"] + subgraph srv["server/ — HTTP服务器/路由/认证/Adapter分发"] + direction TB + ss["session/ — 会话管理"] + st["trace/ — 请求跟踪写入"] + su["usage/ — 用量跟踪"] + end + sst["stats/ — 用量统计"] + str["trace/ — 请求跟踪记录"] + end + + subgraph extension["extension/ — 可插拔扩展"] + direction TB + ec["codex/ — Codex 模型目录"] + subgraph edb["db/ — 数据库 Provider"] + es["sqlite/"] + ed1["d1/"] + end + eds["deepseek_v4/ — DeepSeek V4 推理优化"] + ek["kimi_workaround/ — Kimi tool call 轮次限制"] + em["metrics/ — 用量指标采集与查询"] + ep["plugin/ — Plugin 接口+能力接口+注册表"] + ev["visual/ — 视觉模型分发(CoreProvider模式)"] + ew["websearch/ — Web Search 编排器"] + ewi["websearchinjected/ — 注入式搜索插件"] + ectp["codex_tool_proxy/ — apply_patch 代理扩展"] + ect["codextool/ — 工具类型定义与工具映射"] + end + + e2e["e2e/ — 端到端集成测试(协议转换)"] + end ``` ### 依赖方向 diff --git a/docs/extensions-overview.md b/docs/extensions-overview.md index c7d78bd1..b773662f 100644 --- a/docs/extensions-overview.md +++ b/docs/extensions-overview.md @@ -250,6 +250,46 @@ moonbridge -config config.yml -print-codex-config my-model --- + +--- + +## codex_tool_proxy(apply_patch 代理扩展) + +控制 Codex 的 `apply_patch` 自定义工具是否被展开为 5 
个结构化代理工具(`add_file`、`delete_file`、`update_file`、`replace_file`、`batch`)发送给上游模型。 + +**位置**:`internal/extension/codex_tool_proxy/` + +**文件清单**: + +| 文件 | 用途 | +|------|------| +| `plugin.go` | Plugin 实现 + PatchProxyDecider | + +**行为**: + +- **默认关闭**(`DefaultEnabled: false`):`apply_patch` 以 raw grammar 形态原样透传给上游模型 +- **开启后**:展开为 5 个独立的结构化工具,让上游模型以 JSON schema 方式调用 + +**实现的能力**: + +```go +var ( + _ plugin.Plugin = (*ProxyPlugin)(nil) + _ plugin.ConfigSpecProvider = (*ProxyPlugin)(nil) + _ plugin.PatchProxyDecider = (*ProxyPlugin)(nil) +) +``` + +**启用方式**: + +```yaml +extensions: + codex_tool_proxy: + enabled: true +``` + +支持 route / model / provider 级别覆盖。 + ## visual(视觉扩展) diff --git a/internal/config/config.go b/internal/config/config.go index 091b0905..e5695535 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,8 +20,8 @@ const ( ProtocolAnthropic = "anthropic" ProtocolOpenAIResponse = "openai-response" // Phase 5: New protocol constants (D-08) - ProtocolGoogleGenAI = "google-genai" - ProtocolOpenAIChat = "openai-chat" + ProtocolGoogleGenAI = "google-genai" + ProtocolOpenAIChat = "openai-chat" ) type Mode string @@ -51,22 +51,22 @@ type WebSearchConfig struct { } type Config struct { - Mode Mode - Addr string - AuthToken string - TraceRequests bool - LogLevel string - LogFormat string - SystemPrompt string - DefaultModel string - WebSearchSupport WebSearchSupport - WebSearchMaxUses int - TavilyAPIKey string - FirecrawlAPIKey string - SearchMaxRounds int - DefaultMaxTokens int - MaxSessions int `yaml:"max_sessions"` // 0 = unlimited - SessionTTL string `yaml:"session_ttl"` // default "24h" + Mode Mode + Addr string + AuthToken string + TraceRequests bool + LogLevel string + LogFormat string + SystemPrompt string + DefaultModel string + WebSearchSupport WebSearchSupport + WebSearchMaxUses int + TavilyAPIKey string + FirecrawlAPIKey string + SearchMaxRounds int + DefaultMaxTokens int + MaxSessions int `yaml:"max_sessions"` // 0 = unlimited + 
SessionTTL string `yaml:"session_ttl"` // default "24h" // Defaults holds the default configuration values. Defaults Defaults // Models is the canonical model definition map (shared across providers). @@ -115,11 +115,11 @@ type RouteEntry struct { // ProviderDef defines a single upstream provider. type ProviderDef struct { - BaseURL string - APIKey string - Version string - UserAgent string - Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat" + BaseURL string + APIKey string + Version string + UserAgent string + Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat" // Phase 5: Google GenAI flat fields (D-09). // Only relevant when Protocol == ProtocolGoogleGenAI. // project: Google Cloud project ID (Vertex AI). @@ -129,7 +129,7 @@ type ProviderDef struct { Location string `yaml:"location,omitempty"` APIVersion string `yaml:"api_version,omitempty"` // Cache config for this provider. If nil, provider does not use caching. - Cache *CacheConfig `yaml:"cache,omitempty"` + Cache *CacheConfig `yaml:"cache,omitempty"` WebSearchSupport WebSearchSupport WebSearchMaxUses int TavilyAPIKey string @@ -204,11 +204,11 @@ type ModelDef struct { // OfferEntry declares that a provider offers a model defined in Models. type OfferEntry struct { - Model string // references models. - UpstreamName string // optional, upstream model name (empty = same as slug) - Priority int // lower value = higher priority (0 is highest) + Model string // references models. 
+ UpstreamName string // optional, upstream model name (empty = same as slug) + Priority int // lower value = higher priority (0 is highest) Pricing ModelPricing - Overrides *ModelDef // optional provider-specific overrides + Overrides *ModelDef // optional provider-specific overrides } type ResponseProxyConfig struct { diff --git a/internal/config/config_loader.go b/internal/config/config_loader.go index 76dee14c..2b366138 100644 --- a/internal/config/config_loader.go +++ b/internal/config/config_loader.go @@ -88,8 +88,8 @@ type TraceFileConfig struct { } type ServerFileConfig struct { - Addr string `yaml:"addr" json:"addr,omitempty"` - AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"` + Addr string `yaml:"addr" json:"addr,omitempty"` + AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"` MaxSessions int `yaml:"max_sessions"` SessionTTL string `yaml:"session_ttl"` } @@ -122,33 +122,33 @@ type ModelDefFileConfig struct { } type OfferFileConfig struct { - Model string `yaml:"model" json:"model"` - UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"` - Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` - Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"` - Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"` + Model string `yaml:"model" json:"model"` + UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"` + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"` + Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"` } type ProviderDefFileConfig struct { - BaseURL string `yaml:"base_url" json:"base_url"` - APIKey string `yaml:"api_key" json:"api_key"` - Version string `yaml:"version,omitempty" json:"version,omitempty"` - UserAgent string `yaml:"user_agent,omitempty" 
json:"user_agent,omitempty"` - Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"` - WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` - Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` - Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"` + BaseURL string `yaml:"base_url" json:"base_url"` + APIKey string `yaml:"api_key" json:"api_key"` + Version string `yaml:"version,omitempty" json:"version,omitempty"` + UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"` + Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"` + WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` + Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` + Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"` } type RouteFileConfig struct { - To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model" - Model string `yaml:"model,omitempty" json:"model,omitempty"` - Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` - DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"` - WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` - Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` + To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model" + Model string `yaml:"model,omitempty" json:"model,omitempty"` + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` + DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"` + Description string 
`yaml:"description,omitempty" json:"description,omitempty"` + ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"` + WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` + Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` } func (cfg *RouteFileConfig) UnmarshalYAML(value *yaml.Node) error { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a143cd48..4bbd6992 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -6,9 +6,9 @@ import ( "strings" "testing" + "moonbridge/internal/config" deepseekv4 "moonbridge/internal/extension/deepseek_v4" "moonbridge/internal/extension/visual" - "moonbridge/internal/config" ) func builtinExtensionSpecsForTest() []config.ExtensionConfigSpec { diff --git a/internal/e2e/anthropic_e2e_test.go b/internal/e2e/anthropic_e2e_test.go index 5235a3c5..b9769d07 100644 --- a/internal/e2e/anthropic_e2e_test.go +++ b/internal/e2e/anthropic_e2e_test.go @@ -12,8 +12,8 @@ import ( "strings" "testing" - "moonbridge/internal/protocol/anthropic" "moonbridge/internal/format" + "moonbridge/internal/protocol/anthropic" "moonbridge/internal/protocol/openai" ) @@ -73,8 +73,8 @@ func TestAnthropicE2E_TextRoundTrip(t *testing.T) { // Step 1: Build OpenAI Responses request. openAIReq := openai.ResponsesRequest{ - Model: "claude-3.5-sonnet", - Input: json.RawMessage(`"Hello"`), + Model: "claude-3.5-sonnet", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } @@ -292,7 +292,7 @@ func TestAnthropicE2E_ToolUseRoundTrip(t *testing.T) { }, }, }, - ToolChoice: json.RawMessage(`"required"`), + ToolChoice: json.RawMessage(`"required"`), MaxOutputTokens: 200, } @@ -436,10 +436,10 @@ func TestAnthropicE2E_Streaming(t *testing.T) { // Step 1: Build streaming OpenAI request. 
openAIReq := openai.ResponsesRequest{ - Model: "claude-3.5-sonnet", - Input: json.RawMessage(`"Hello streaming"`), + Model: "claude-3.5-sonnet", + Input: json.RawMessage(`"Hello streaming"`), MaxOutputTokens: 100, - Stream: true, + Stream: true, } // Step 2: ClientAdapter.ToCoreRequest. @@ -480,11 +480,17 @@ func TestAnthropicE2E_Streaming(t *testing.T) { } // Step 6: ClientStreamAdapter.FromCoreStream. - streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents) + streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events) if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } // Consume the OpenAI stream and verify expected events. var seenEvents []string @@ -569,8 +575,8 @@ func TestAnthropicE2E_ErrorResponse(t *testing.T) { defer mockSrv.Close() openAIReq := openai.ResponsesRequest{ - Model: "claude-3.5-sonnet", - Input: json.RawMessage(`"Hello"`), + Model: "claude-3.5-sonnet", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } @@ -696,7 +702,7 @@ func TestAnthropicE2E_MultiTurnToolChain(t *testing.T) { }, }, }, - ToolChoice: json.RawMessage(`"auto"`), + ToolChoice: json.RawMessage(`"auto"`), MaxOutputTokens: 300, } @@ -840,7 +846,6 @@ func TestAnthropicE2E_MultiTurnToolChain(t *testing.T) { } } - // ============================================================================ const ( @@ -947,7 +952,7 @@ func testAnthropicRealToolCall(t *testing.T, apiKey string) { "required": []any{"city"}, }, }}, - ToolChoice: json.RawMessage(`"auto"`), + ToolChoice: json.RawMessage(`"auto"`), MaxOutputTokens: 300, } @@ -1044,7 +1049,7 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) { "required": []any{"city"}, }, }}, - 
ToolChoice: json.RawMessage(`"auto"`), + ToolChoice: json.RawMessage(`"auto"`), MaxOutputTokens: 300, } @@ -1109,9 +1114,9 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) { }, anthropic.Message{ Role: "user", Content: []anthropic.ContentBlock{{ - Type: "tool_result", + Type: "tool_result", ToolUseID: toolCallID, - Content: []anthropic.ContentBlock{{Type: "text", Text: "The weather in Tokyo is 25 degrees and Sunny."}}, + Content: []anthropic.ContentBlock{{Type: "text", Text: "The weather in Tokyo is 25 degrees and Sunny."}}, }}, }), } @@ -1138,7 +1143,6 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) { } } - // Config constants (used by E2E tests) // ============================================================================ // Config constants (used by E2E tests) diff --git a/internal/e2e/e2e_test.go b/internal/e2e/e2e_test.go index a6b69efe..6bcfcc64 100644 --- a/internal/e2e/e2e_test.go +++ b/internal/e2e/e2e_test.go @@ -3,19 +3,19 @@ package e2e_test import ( + "bufio" "context" "fmt" "net/http" "os" - "bufio" "path/filepath" "strings" "testing" "moonbridge/internal/config" + "moonbridge/internal/format" "moonbridge/internal/protocol/anthropic" "moonbridge/internal/protocol/chat" - "moonbridge/internal/format" "moonbridge/internal/protocol/google" "moonbridge/internal/protocol/openai" ) @@ -187,7 +187,6 @@ func assertResponseBasics(t testing.TB, oaiResp *openai.Response, wantModel stri } } - func loadDotEnv(t testing.TB) { if t != nil { t.Helper() @@ -203,10 +202,10 @@ func loadDotEnv(t testing.TB) { f, err := os.Open(path) if err != nil { if t != nil { - t.Logf("warning: cannot open %s: %v", path, err) - } else { - println("warning: cannot open", path, err.Error()) - } + t.Logf("warning: cannot open %s: %v", path, err) + } else { + println("warning: cannot open", path, err.Error()) + } return } defer f.Close() @@ -238,4 +237,4 @@ func loadDotEnv(t testing.TB) { } dir = parent } -} \ No newline at end of file +} 
diff --git a/internal/e2e/google_genai_e2e_test.go b/internal/e2e/google_genai_e2e_test.go index 163d7129..8320129a 100644 --- a/internal/e2e/google_genai_e2e_test.go +++ b/internal/e2e/google_genai_e2e_test.go @@ -64,8 +64,8 @@ func TestGoogleGenaiE2E_TextRoundTrip(t *testing.T) { // Step 1: Build OpenAI Responses request. openAIReq := openai.ResponsesRequest{ - Model: "gemini-2.0-flash", - Input: json.RawMessage(`"Hello"`), + Model: "gemini-2.0-flash", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } @@ -395,10 +395,10 @@ func TestGoogleGenaiE2E_Streaming(t *testing.T) { // Step 1: Build streaming OpenAI request. openAIReq := openai.ResponsesRequest{ - Model: "gemini-2.0-flash", - Input: json.RawMessage(`"Hello streaming"`), + Model: "gemini-2.0-flash", + Input: json.RawMessage(`"Hello streaming"`), MaxOutputTokens: 100, - Stream: true, + Stream: true, } // Step 2: ClientAdapter.ToCoreRequest. @@ -446,11 +446,17 @@ func TestGoogleGenaiE2E_Streaming(t *testing.T) { } // Step 5: ClientStreamAdapter.FromCoreStream. - streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents) + streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events) if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } // Consume the OpenAI stream and verify expected events. 
var seenEvents []string @@ -532,8 +538,8 @@ func TestGoogleGenaiE2E_ErrorResponse(t *testing.T) { defer mockSrv.Close() openAIReq := openai.ResponsesRequest{ - Model: "gemini-2.0-flash", - Input: json.RawMessage(`"Hello"`), + Model: "gemini-2.0-flash", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } diff --git a/internal/e2e/openai_chat_e2e_test.go b/internal/e2e/openai_chat_e2e_test.go index 7e74b3b9..c148bee5 100644 --- a/internal/e2e/openai_chat_e2e_test.go +++ b/internal/e2e/openai_chat_e2e_test.go @@ -12,8 +12,8 @@ import ( "strings" "testing" - "moonbridge/internal/protocol/chat" "moonbridge/internal/format" + "moonbridge/internal/protocol/chat" "moonbridge/internal/protocol/openai" ) @@ -71,8 +71,8 @@ func TestOpenAIChatE2E_TextRoundTrip(t *testing.T) { // Step 1: Build OpenAI Responses request. openAIReq := openai.ResponsesRequest{ - Model: "gpt-4o", - Input: json.RawMessage(`"Hello"`), + Model: "gpt-4o", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } @@ -395,10 +395,10 @@ func TestOpenAIChatE2E_Streaming(t *testing.T) { // Step 1: Build streaming OpenAI request. openAIReq := openai.ResponsesRequest{ - Model: "gpt-4o", - Input: json.RawMessage(`"Hello streaming"`), + Model: "gpt-4o", + Input: json.RawMessage(`"Hello streaming"`), MaxOutputTokens: 100, - Stream: true, + Stream: true, } // Step 2: ClientAdapter.ToCoreRequest. @@ -454,11 +454,17 @@ func TestOpenAIChatE2E_Streaming(t *testing.T) { } // Step 5: ClientStreamAdapter.FromCoreStream. 
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents) + streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events) if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } // Consume the OpenAI stream and verify expected events. var seenEvents []string @@ -540,8 +546,8 @@ func TestOpenAIChatE2E_ErrorResponse(t *testing.T) { defer mockSrv.Close() openAIReq := openai.ResponsesRequest{ - Model: "gpt-4o", - Input: json.RawMessage(`"Hello"`), + Model: "gpt-4o", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } diff --git a/internal/e2e/openai_response_e2e_test.go b/internal/e2e/openai_response_e2e_test.go index 1e92b83b..57b3359d 100644 --- a/internal/e2e/openai_response_e2e_test.go +++ b/internal/e2e/openai_response_e2e_test.go @@ -311,8 +311,8 @@ func TestOpenAIResponsePassthroughE2E_Streaming(t *testing.T) { // content_block.started (text type) events <- format.CoreStreamEvent{ - Type: format.CoreContentBlockStarted, - Index: 0, + Type: format.CoreContentBlockStarted, + Index: 0, ContentBlock: &format.CoreContentBlock{Type: "text"}, } // text delta @@ -347,7 +347,13 @@ func TestOpenAIResponsePassthroughE2E_Streaming(t *testing.T) { if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } // Step 3: Consume and verify. 
var seenEvents []string diff --git a/internal/e2e/plugin_hooks_e2e_test.go b/internal/e2e/plugin_hooks_e2e_test.go index ee9f341e..f90227a4 100644 --- a/internal/e2e/plugin_hooks_e2e_test.go +++ b/internal/e2e/plugin_hooks_e2e_test.go @@ -616,11 +616,17 @@ func TestPluginHooks_OnStreamEvent(t *testing.T) { t.Fatalf("ToCoreStream: %v", err) } - streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents) + streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events) if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } var seenEvents []string for ev := range openAIStream { @@ -769,11 +775,17 @@ func TestPluginHooks_OnStreamComplete(t *testing.T) { t.Fatalf("ToCoreStream: %v", err) } - streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents) + streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events) if err != nil { t.Fatalf("FromCoreStream: %v", err) } - openAIStream := streamOutAny.(<-chan openai.StreamEvent) + var openAIStream <-chan openai.StreamEvent + oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult) + if ok { + openAIStream = oaiResult.Chan() + } else { + openAIStream = streamOutAny.(<-chan openai.StreamEvent) + } // Consume all stream events to trigger completion. 
for ev := range openAIStream { diff --git a/internal/e2e/visual_chat_e2e_test.go b/internal/e2e/visual_chat_e2e_test.go index 14718269..493c34a2 100644 --- a/internal/e2e/visual_chat_e2e_test.go +++ b/internal/e2e/visual_chat_e2e_test.go @@ -36,9 +36,9 @@ func TestVisualOnOpenAIChat_OrchestratesBriefAcrossTwoMocks(t *testing.T) { ctx := context.Background() type observed struct { - mu sync.Mutex - bodies [][]byte - rounds int + mu sync.Mutex + bodies [][]byte + rounds int } upstreamObs := &observed{} visualObs := &observed{} diff --git a/internal/e2e/websearch_injection_e2e_test.go b/internal/e2e/websearch_injection_e2e_test.go index 68cf1866..b5e41179 100644 --- a/internal/e2e/websearch_injection_e2e_test.go +++ b/internal/e2e/websearch_injection_e2e_test.go @@ -10,8 +10,8 @@ import ( "net/http/httptest" "testing" - "moonbridge/internal/protocol/anthropic" "moonbridge/internal/format" + "moonbridge/internal/protocol/anthropic" "moonbridge/internal/protocol/openai" ) @@ -109,8 +109,8 @@ func TestWebSearchE2E_InjectionEnabled(t *testing.T) { // Step 1: Build an OpenAI ResponsesRequest WITHOUT web_search in tools. openAIReq := openai.ResponsesRequest{ - Model: "claude-3.5-sonnet", - Input: json.RawMessage(`"Search for latest AI breakthroughs"`), + Model: "claude-3.5-sonnet", + Input: json.RawMessage(`"Search for latest AI breakthroughs"`), MaxOutputTokens: 100, } @@ -253,8 +253,8 @@ func TestWebSearchE2E_InjectionDisabled(t *testing.T) { // Step 1: Build OpenAI request without web_search. 
openAIReq := openai.ResponsesRequest{ - Model: "claude-3.5-sonnet", - Input: json.RawMessage(`"Hello"`), + Model: "claude-3.5-sonnet", + Input: json.RawMessage(`"Hello"`), MaxOutputTokens: 100, } @@ -367,7 +367,7 @@ func TestWebSearchE2E_AlreadyPresentNotOverwritten(t *testing.T) { Name: "get_weather", Description: "Get the current weather for a city", Parameters: map[string]any{ - "type": "object", + "type": "object", "properties": map[string]any{ "city": map[string]any{"type": "string"}, }, diff --git a/internal/extension/codex_tool_proxy/plugin.go b/internal/extension/codex_tool_proxy/plugin.go index 863bd75d..681d3b63 100644 --- a/internal/extension/codex_tool_proxy/plugin.go +++ b/internal/extension/codex_tool_proxy/plugin.go @@ -54,7 +54,7 @@ func DisablePatchProxyForModel(cfg config.Config, model string) bool { func ConfigSpecs() []config.ExtensionConfigSpec { return []config.ExtensionConfigSpec{{ Name: PluginName, - DefaultEnabled: true, + DefaultEnabled: false, Scopes: []config.ExtensionScope{ config.ExtensionScopeGlobal, config.ExtensionScopeProvider, diff --git a/internal/extension/codex_tool_proxy/plugin_test.go b/internal/extension/codex_tool_proxy/plugin_test.go index 7ef37212..a239bafb 100644 --- a/internal/extension/codex_tool_proxy/plugin_test.go +++ b/internal/extension/codex_tool_proxy/plugin_test.go @@ -11,8 +11,8 @@ func TestDisablePatchProxyForModelDefaultEnabled(t *testing.T) { if err != nil { t.Fatal(err) } - if DisablePatchProxyForModel(cfg, "test-model") { - t.Fatal("expected proxy enabled by default") + if !DisablePatchProxyForModel(cfg, "test-model") { + t.Fatal("expected proxy disabled by default") } } @@ -34,8 +34,8 @@ func TestDisablePatchProxyForModelRouteOverride(t *testing.T) { if !DisablePatchProxyForModel(cfg, "specific") { t.Fatal("expected proxy disabled via specific route override") } - if DisablePatchProxyForModel(cfg, "unmatched-model") { - t.Fatal("expected proxy enabled for unmatched model") + if 
!DisablePatchProxyForModel(cfg, "unmatched-model") { + t.Fatal("expected proxy disabled for unmatched model (default)") } } diff --git a/internal/extension/codextool/customtool.go b/internal/extension/codextool/customtool.go index a83bf167..a52fea9e 100644 --- a/internal/extension/codextool/customtool.go +++ b/internal/extension/codextool/customtool.go @@ -114,6 +114,13 @@ func OutputItemFromBlock( return "custom_tool_call", spec.OpenAIName, "", RebuildGrammar(blockName, toolInput), false, nil case ToolRaw: return "custom_tool_call", spec.OpenAIName, "", InputFromRaw(toolInput), false, nil + case ToolNestedOneOf, ToolNestedAnyOf: + action, paramsStr, err := DecodeNestedCall(toolInput, spec.Kind) + if err != nil || action == "" { + // Fallback: return raw input with blockName as the tool Name. + return "function_call", blockName, spec.Namespace, string(toolInput), false, nil + } + return "function_call", action, spec.Namespace, string(paramsStr), false, nil case ToolFunction: return "function_call", spec.OpenAIName, spec.Namespace, string(toolInput), false, nil default: diff --git a/internal/extension/codextool/namespace_schema.go b/internal/extension/codextool/namespace_schema.go new file mode 100644 index 00000000..650c7a9a --- /dev/null +++ b/internal/extension/codextool/namespace_schema.go @@ -0,0 +1,248 @@ +// Package codextool provides namespace tool flattening and nested schema building. +package codextool + +import ( + "encoding/json" + "fmt" + "strings" + + "moonbridge/internal/format" +) + +// NamespaceStrategy controls how namespace tools are converted for upstream providers. +type NamespaceStrategy string + +const ( + NestedOneOf NamespaceStrategy = "nested_oneof" + NestedAnyOf NamespaceStrategy = "nested_anyof" + Flat NamespaceStrategy = "flat" +) + +// BuildNamespaceTools converts a namespace tool to one or more CoreTools +// according to the given strategy. 
+func BuildNamespaceTools(
+	toolNames []string,
+	toolMap map[string]format.CoreTool,
+	parentNamespace string,
+	strategy NamespaceStrategy,
+) ([]format.CoreTool, error) {
+	switch strategy {
+	case NestedOneOf:
+		return buildNestedOneOf(toolNames, toolMap, parentNamespace), nil
+	case NestedAnyOf:
+		return buildNestedAnyOf(toolNames, toolMap, parentNamespace), nil
+	case Flat:
+		return buildFlat(toolNames, toolMap, parentNamespace), nil
+	default:
+		return buildFlat(toolNames, toolMap, parentNamespace), nil
+	}
+}
+
+// buildNestedOneOf merges the named sub-tools into one tool whose schema is a top-level oneOf.
+// Each branch inlines the sub-tool's properties and pins "action" to that sub-tool's name.
+func buildNestedOneOf(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+	if len(toolNames) == 0 {
+		return nil
+	}
+
+	mergedName := namespace
+
+	oneOf := make([]map[string]any, 0, len(toolNames))
+	for _, name := range toolNames {
+		sub, ok := toolMap[name]
+		if !ok {
+			continue
+		}
+		props := make(map[string]any)
+		required := []string{"action"}
+		if sub.InputSchema != nil {
+			if p, ok := sub.InputSchema["properties"].(map[string]any); ok {
+				for k, v := range p {
+					props[k] = v
+				}
+			}
+			if r, ok := sub.InputSchema["required"].([]any); ok {
+				for _, rv := range r {
+					if rs, ok := rv.(string); ok {
+						required = append(required, rs)
+					}
+				}
+			}
+		}
+		// Set last so a sub-tool property that is itself named "action" cannot shadow the selector.
+		props["action"] = map[string]any{
+			"type": "string",
+			"enum": []string{name},
+		}
+		branch := map[string]any{
+			"type":                 "object",
+			"title":                name,
+			"properties":           props,
+			"required":             required,
+			"additionalProperties": false,
+		}
+		oneOf = append(oneOf, branch)
+	}
+
+	if len(oneOf) == 0 {
+		return nil
+	}
+
+	mergedSchema := map[string]any{
+		"type":  "object",
+		"oneOf": oneOf,
+	}
+
+	ct := format.CoreTool{
+		Name:        mergedName,
+		Description: fmt.Sprintf("Namespace tool with %d sub-tools. Pick the matching action.", len(oneOf)),
+		InputSchema: mergedSchema,
+	}
+	AnnotateCoreTool(&ct, ToolNestedOneOf, mergedName, namespace)
+	return []format.CoreTool{ct}
+}
+
+// buildNestedAnyOf merges the named sub-tools into one tool with an "action" enum and a
+// "params" object constrained by anyOf over the sub-tool schemas (the PR #75 compatible format).
+func buildNestedAnyOf(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+	if len(toolNames) == 0 {
+		return nil
+	}
+
+	mergedName := namespace
+
+	actions := make([]string, 0, len(toolNames))
+	anyOf := make([]map[string]any, 0, len(toolNames))
+	for _, name := range toolNames {
+		sub, ok := toolMap[name]
+		if !ok {
+			continue
+		}
+		actions = append(actions, name)
+		branch := sub.InputSchema
+		if branch == nil {
+			branch = map[string]any{"type": "object"}
+		}
+		anyOf = append(anyOf, branch)
+	}
+
+	mergedSchema := map[string]any{
+		"type": "object",
+		"properties": map[string]any{
+			"action": map[string]any{
+				"type": "string",
+				"enum": actions,
+			},
+			"params": map[string]any{
+				"anyOf": anyOf,
+			},
+		},
+		"required":             []string{"action", "params"},
+		"additionalProperties": false,
+	}
+
+	ct := format.CoreTool{
+		Name:        mergedName,
+		Description: fmt.Sprintf("Namespace tool with %d sub-tools. Use action to select, params for arguments.", len(anyOf)),
+		InputSchema: mergedSchema,
+	}
+	AnnotateCoreTool(&ct, ToolNestedAnyOf, mergedName, namespace)
+	return []format.CoreTool{ct}
+}
+
+// buildFlat flattens namespace tools into individual CoreTools.
+func buildFlat(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+	result := make([]format.CoreTool, 0, len(toolNames))
+	for _, name := range toolNames {
+		sub, ok := toolMap[name]
+		if !ok {
+			continue
+		}
+		fullName := NamespacedToolName(namespace, name)
+		ct := format.CoreTool{
+			Name:        fullName,
+			Description: sub.Description,
+			InputSchema: sub.InputSchema,
+		}
+		AnnotateCoreTool(&ct, ToolFunction, name, namespace)
+		result = append(result, ct)
+	}
+	return result
+}
+
+// DecodeNestedCall extracts the action name and parameters from a nested
+// namespace tool call, regardless of whether the model used the oneOf
+// or anyOf schema format.
+//
+// For oneOf format: {"action": "read_file", "path": "/foo", ...}
+//
+// → action = "read_file", params = {"path": "/foo", ...}
+//
+// For anyOf format: {"action": "read_file", "params": {"path": "/foo"}}
+//
+// → action = "read_file", params = {"path": "/foo"}
+func DecodeNestedCall(input json.RawMessage, schemaKind ToolKind) (action string, params json.RawMessage, err error) {
+	if len(input) == 0 || string(input) == "null" {
+		return "", nil, fmt.Errorf("empty input")
+	}
+
+	var raw map[string]json.RawMessage
+	if err := json.Unmarshal(input, &raw); err != nil {
+		return "", nil, fmt.Errorf("unmarshal: %w", err)
+	}
+
+	// The action selector is mandatory; on failure the untouched input is returned with the error.
+	actionBytes, hasAction := raw["action"]
+	if !hasAction {
+		return "", input, fmt.Errorf("no action field")
+	}
+	if err := json.Unmarshal(actionBytes, &action); err != nil {
+		return "", input, fmt.Errorf("action parse: %w", err)
+	}
+
+	// Extract params based on schema format
+	switch schemaKind {
+	case ToolNestedOneOf:
+		// oneOf format: every remaining key after removing "action" is a parameter
+		delete(raw, "action")
+		if len(raw) == 0 {
+			return action, json.RawMessage(`{}`), nil
+		}
+		params, err = json.Marshal(raw)
+		return action, params, err
+
+	case ToolNestedAnyOf:
+		// anyOf format: params are in a separate "params" field; absent means no arguments
+		if paramsRaw, ok := raw["params"]; ok {
+			return action, paramsRaw, nil
+		}
+		return action, json.RawMessage(`{}`), nil
+
+	default:
+		return action, input, nil
+	}
+}
+
+// TryExtractAction scans partial JSON for "action": "value" and returns the value.
+// It returns ("", false) if the action field is absent or its string value is not yet complete.
+func TryExtractAction(raw string) (string, bool) {
+	idx := strings.Index(raw, `"action"`)
+	if idx < 0 {
+		return "", false
+	}
+	rest := raw[idx+len(`"action"`):] // skip past the quoted key
+	// JSON permits whitespace on both sides of the colon, so trim before and after removing it.
+	rest = strings.TrimPrefix(strings.TrimSpace(rest), ":")
+	rest = strings.TrimSpace(rest)
+	// The value must start with a quote; anything else is a non-string or not yet streamed.
+	if len(rest) == 0 || rest[0] != '"' {
+		return "", false
+	}
+	rest = rest[1:] // skip opening quote
+	// A missing closing quote means the value is still being streamed.
+	q2 := strings.IndexByte(rest, '"')
+	if q2 < 0 {
+		return "", false
+	}
+	return rest[:q2], true
+}
diff --git a/internal/extension/codextool/tool_context.go b/internal/extension/codextool/tool_context.go
index cd8561a3..a199e0bd 100644
--- a/internal/extension/codextool/tool_context.go
+++ b/internal/extension/codextool/tool_context.go
@@ -11,12 +11,14 @@ package codextool
 type ToolKind string
 
 const (
-	ToolApplyPatch ToolKind = "apply_patch"
-	ToolExec       ToolKind = "exec"
-	ToolRaw        ToolKind = "raw"
-	ToolFunction   ToolKind = "function"
-	ToolLocalShell ToolKind = "local_shell"
-	ToolUnknown    ToolKind = "unknown"
+	ToolApplyPatch  ToolKind = "apply_patch"
+	ToolExec        ToolKind = "exec"
+	ToolRaw         ToolKind = "raw"
+	ToolFunction    ToolKind = "function"
+	ToolLocalShell  ToolKind = "local_shell"
+	ToolNestedOneOf ToolKind = "nested_oneof"
+	ToolNestedAnyOf ToolKind = "nested_namespace"
+	ToolUnknown     ToolKind = "unknown"
 )
 
 // ToolSpec describes an expanded tool entry for reverse mapping.
diff --git a/internal/extension/db/d1/plugin.go b/internal/extension/db/d1/plugin.go index 1ebce296..ef14b174 100644 --- a/internal/extension/db/d1/plugin.go +++ b/internal/extension/db/d1/plugin.go @@ -20,9 +20,9 @@ import ( "database/sql" "fmt" - "moonbridge/internal/extension/plugin" "moonbridge/internal/config" "moonbridge/internal/db" + "moonbridge/internal/extension/plugin" ) const PluginName = "db_d1" diff --git a/internal/extension/db/d1/plugin_test.go b/internal/extension/db/d1/plugin_test.go index c1c7bc12..cb3a2e53 100644 --- a/internal/extension/db/d1/plugin_test.go +++ b/internal/extension/db/d1/plugin_test.go @@ -5,10 +5,10 @@ import ( "database/sql" "testing" - dbd1 "moonbridge/internal/extension/db/d1" - "moonbridge/internal/extension/plugin" "moonbridge/internal/config" "moonbridge/internal/db" + dbd1 "moonbridge/internal/extension/db/d1" + "moonbridge/internal/extension/plugin" _ "modernc.org/sqlite" ) diff --git a/internal/extension/db/sqlite/plugin.go b/internal/extension/db/sqlite/plugin.go index 430e163e..1bfa93a7 100644 --- a/internal/extension/db/sqlite/plugin.go +++ b/internal/extension/db/sqlite/plugin.go @@ -19,9 +19,9 @@ import ( "fmt" "path/filepath" - "moonbridge/internal/extension/plugin" "moonbridge/internal/config" "moonbridge/internal/db" + "moonbridge/internal/extension/plugin" ) const PluginName = "db_sqlite" diff --git a/internal/extension/db/sqlite/plugin_test.go b/internal/extension/db/sqlite/plugin_test.go index 05a0e8c0..f623cc79 100644 --- a/internal/extension/db/sqlite/plugin_test.go +++ b/internal/extension/db/sqlite/plugin_test.go @@ -4,9 +4,9 @@ import ( "context" "testing" + "moonbridge/internal/config" dbsqlite "moonbridge/internal/extension/db/sqlite" "moonbridge/internal/extension/plugin" - "moonbridge/internal/config" ) func TestName(t *testing.T) { diff --git a/internal/extension/deepseek_v4/deepseek_v4.go b/internal/extension/deepseek_v4/deepseek_v4.go index cd5bc2bb..295c7ee3 100644 --- 
a/internal/extension/deepseek_v4/deepseek_v4.go +++ b/internal/extension/deepseek_v4/deepseek_v4.go @@ -5,9 +5,9 @@ import ( "encoding/json" "strings" - "moonbridge/internal/protocol/openai" - "moonbridge/internal/protocol/anthropic" "moonbridge/internal/format" + "moonbridge/internal/protocol/anthropic" + "moonbridge/internal/protocol/openai" ) // StripReasoningContent removes the reasoning_content field from message diff --git a/internal/extension/deepseek_v4/deepseek_v4_test.go b/internal/extension/deepseek_v4/deepseek_v4_test.go index 66158b2f..59baa9e9 100644 --- a/internal/extension/deepseek_v4/deepseek_v4_test.go +++ b/internal/extension/deepseek_v4/deepseek_v4_test.go @@ -7,10 +7,10 @@ import ( "strings" "testing" - "moonbridge/internal/format" pluginpkg "moonbridge/internal/extension/plugin" - "moonbridge/internal/protocol/openai" + "moonbridge/internal/format" "moonbridge/internal/protocol/anthropic" + "moonbridge/internal/protocol/openai" ) func TestStripReasoningContentStripsField(t *testing.T) { @@ -144,8 +144,8 @@ func TestPrependThinkingWarnsWhenUsingRequiredFallback(t *testing.T) { logs.Reset() summary := []openai.ReasoningItemSummary{{ - Type: "summary_text", - Text: EncodeThinkingSummary(format.CoreContentBlock{Type: "reasoning", ReasoningSignature: "sig_summary"}), + Type: "summary_text", + Text: EncodeThinkingSummary(format.CoreContentBlock{Type: "reasoning", ReasoningSignature: "sig_summary"}), }} got = p.PrependThinkingForToolUse([]format.CoreMessage{{ Role: "assistant", diff --git a/internal/extension/kimi_workaround/plugin_test.go b/internal/extension/kimi_workaround/plugin_test.go index 33b07ed1..4066d47f 100644 --- a/internal/extension/kimi_workaround/plugin_test.go +++ b/internal/extension/kimi_workaround/plugin_test.go @@ -230,4 +230,3 @@ func buildToolConversation(rounds int) []format.CoreMessage { } return msgs } - diff --git a/internal/extension/metrics/plugin.go b/internal/extension/metrics/plugin.go index cc2da6a2..8063dd14 100644 
--- a/internal/extension/metrics/plugin.go +++ b/internal/extension/metrics/plugin.go @@ -23,9 +23,9 @@ import ( "strconv" "time" - "moonbridge/internal/extension/plugin" "moonbridge/internal/config" "moonbridge/internal/db" + "moonbridge/internal/extension/plugin" ) const PluginName = "metrics" diff --git a/internal/extension/metrics/plugin_test.go b/internal/extension/metrics/plugin_test.go index 35b81449..f5d12b3d 100644 --- a/internal/extension/metrics/plugin_test.go +++ b/internal/extension/metrics/plugin_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - mbtrics "moonbridge/internal/extension/metrics" - "moonbridge/internal/extension/plugin" "moonbridge/internal/config" "moonbridge/internal/db" + mbtrics "moonbridge/internal/extension/metrics" + "moonbridge/internal/extension/plugin" _ "modernc.org/sqlite" ) diff --git a/internal/extension/metrics/store_test.go b/internal/extension/metrics/store_test.go index efe2026a..a471df0b 100644 --- a/internal/extension/metrics/store_test.go +++ b/internal/extension/metrics/store_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - mbtrics "moonbridge/internal/extension/metrics" "moonbridge/internal/db" + mbtrics "moonbridge/internal/extension/metrics" _ "modernc.org/sqlite" ) diff --git a/internal/extension/plugin/capabilities.go b/internal/extension/plugin/capabilities.go index 3f3bff56..6206575a 100644 --- a/internal/extension/plugin/capabilities.go +++ b/internal/extension/plugin/capabilities.go @@ -3,12 +3,12 @@ package plugin import ( "context" "encoding/json" - "net/http" "moonbridge/internal/protocol/anthropic" + "net/http" "time" - "moonbridge/internal/logger" "moonbridge/internal/format" + "moonbridge/internal/logger" "moonbridge/internal/protocol/openai" foundationdb "moonbridge/internal/db" @@ -91,7 +91,7 @@ type StreamEvent struct { Type string // "block_start", "block_delta", "block_stop" Index int Block *format.CoreContentBlock // for block_start - Delta anthropic.StreamDelta // for block_delta + Delta 
anthropic.StreamDelta // for block_delta } // --- Error handling --- @@ -213,10 +213,8 @@ type CoreContentRememberer interface { RememberCoreContent(ctx context.Context, content []format.CoreContentBlock) } - // PatchProxyDecider is implemented by plugins that control whether apply_patch // custom tools are expanded into structured proxy tools for upstream models. type PatchProxyDecider interface { DisablePatchProxy(model string) bool } - diff --git a/internal/extension/visual/chat_strip.go b/internal/extension/visual/chat_strip.go index 029c4698..d5af7548 100644 --- a/internal/extension/visual/chat_strip.go +++ b/internal/extension/visual/chat_strip.go @@ -23,8 +23,17 @@ func StripImagesFromChat(req chat.ChatRequest) (chat.ChatRequest, bool) { for mi := range out.Messages { msg := &out.Messages[mi] - parts, ok := msg.Content.([]chat.ContentPart) - if !ok { + var originalIsString bool + var parts []chat.ContentPart + switch v := msg.Content.(type) { + case string: + originalIsString = true + if v != "" { + parts = []chat.ContentPart{{Type: "text", Text: v}} + } + case []chat.ContentPart: + parts = v + default: continue } newParts := make([]chat.ContentPart, 0, len(parts)) @@ -40,7 +49,11 @@ func StripImagesFromChat(req chat.ChatRequest) (chat.ChatRequest, bool) { } newParts = append(newParts, part) } - msg.Content = newParts + if originalIsString && !modified && len(newParts) == 1 && newParts[0].Type == "text" { + msg.Content = newParts[0].Text + } else { + msg.Content = newParts + } } return out, modified } diff --git a/internal/extension/visual/client.go b/internal/extension/visual/client.go index def1708e..c9a3e592 100644 --- a/internal/extension/visual/client.go +++ b/internal/extension/visual/client.go @@ -8,7 +8,7 @@ import ( "moonbridge/internal/format" ) -const visualSystemPrompt = "You are Kimi running behind Moon Bridge Visual. Analyze images carefully, state uncertainty, and do not invent visual facts." 
+const visualSystemPrompt = "You are a vision analysis model behind Moon Bridge Visual. Analyze images carefully, state uncertainty, and do not invent visual facts." // CoreProvider is a protocol-agnostic LLM provider interface. // It operates on format.CoreRequest / format.CoreResponse so the visual plugin diff --git a/internal/extension/visual/core_orchestrator.go b/internal/extension/visual/core_orchestrator.go index 0b87bfe4..da6ed3a6 100644 --- a/internal/extension/visual/core_orchestrator.go +++ b/internal/extension/visual/core_orchestrator.go @@ -88,9 +88,33 @@ func (o *CoreOrchestrator) CreateCore(ctx context.Context, req *format.CoreReque } toolUses, nonVisual := coreSplitVisualToolUses(lastAssistant.Content) - if len(nonVisual) > 0 || len(toolUses) == 0 { + if len(toolUses) == 0 { return applyCoreUsageAggregation(resp, aggregatedUsage, hasAggregatedUsage), nil } + if len(nonVisual) > 0 { + // Execute visual tools, feed results back to model, and continue. + // Non-visual tool_uses remain in the assistant message so the Bridge + // can process them on the next round. + toolResults := make([]format.CoreContentBlock, 0, len(toolUses)) + for _, toolUse := range toolUses { + result := o.executeCoreVisualTool(ctx, toolUse, availableImages) + toolResults = append(toolResults, format.CoreContentBlock{ + Type: "tool_result", + ToolUseID: toolUse.ToolUseID, + ToolResultContent: []format.CoreContentBlock{{Type: "text", Text: result}}, + }) + } + req.Messages = append(req.Messages, *lastAssistant) + req.Messages = append(req.Messages, format.CoreMessage{ + Role: "tool", + Content: toolResults, + }) + if req.ToolChoice != nil && req.ToolChoice.Mode != "auto" { + req.ToolChoice = &format.CoreToolChoice{Mode: "auto"} + } + log.Debug("Core visual mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual)) + continue + } // Execute each visual tool via the vision client. 
toolResults := make([]format.CoreContentBlock, 0, len(toolUses)) @@ -106,7 +130,7 @@ func (o *CoreOrchestrator) CreateCore(ctx context.Context, req *format.CoreReque // Append assistant message and tool_result message for next round. req.Messages = append(req.Messages, *lastAssistant) req.Messages = append(req.Messages, format.CoreMessage{ - Role: "user", + Role: "tool", Content: toolResults, }) @@ -174,26 +198,50 @@ func imageInputFromCoreBlock(block format.CoreContentBlock) (ImageInput, bool) { if block.ImageData == "" { return ImageInput{}, false } + // If MediaType is explicitly set, treat as base64. if block.MediaType != "" { - // base64-encoded image return ImageInput{Data: block.ImageData, MimeType: block.MediaType}, true } - // URL-based image (ImageData holds the URL when MediaType is empty) + // Check for data: URL (contains embedded MIME type). + if strings.HasPrefix(block.ImageData, "data:") { + mediaType, raw := splitDataURL(block.ImageData) + return ImageInput{Data: raw, MimeType: mediaType}, true + } + // URL-based image (ImageData holds the URL when MediaType is empty). url := strings.TrimSpace(block.ImageData) - if !isSupportedImageURL(url) { - return ImageInput{}, false + if isSupportedImageURL(url) { + return ImageInput{URL: url}, true } - return ImageInput{URL: url}, true + // Fallback: treat as base64 with default MIME type rather than silently dropping. + return ImageInput{Data: block.ImageData, MimeType: "image/png"}, true } // findLastAssistantMessage finds the last assistant message in a slice. +// Prefers assistant messages that contain tool_use blocks, searching backward +// so the most recent tool-use-bearing assistant message is returned. 
func findLastAssistantMessage(messages []format.CoreMessage) *format.CoreMessage { + var candidate *format.CoreMessage for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == "assistant" { + if messages[i].Role != "assistant" { + continue + } + // Check if this assistant message has any tool_use content blocks. + hasToolUse := false + for _, block := range messages[i].Content { + if block.Type == "tool_use" { + hasToolUse = true + break + } + } + if hasToolUse { return &messages[i] } + // Fallback: remember the last assistant message if we don't find a better one. + if candidate == nil { + candidate = &messages[i] + } } - return nil + return candidate } // coreSplitVisualToolUses separates visual tool_use blocks from non-visual ones. diff --git a/internal/extension/visual/core_orchestrator_test.go b/internal/extension/visual/core_orchestrator_test.go index cddf80ba..76b06eda 100644 --- a/internal/extension/visual/core_orchestrator_test.go +++ b/internal/extension/visual/core_orchestrator_test.go @@ -899,7 +899,7 @@ func TestCoreOrchestratorIsolatesRequestsAcrossRounds(t *testing.T) { } toolResultMsg := secondReq.Messages[2] - if toolResultMsg.Role != "user" { + if toolResultMsg.Role != "tool" { t.Fatalf("tool_result message role = %q", toolResultMsg.Role) } if len(toolResultMsg.Content) != 1 || toolResultMsg.Content[0].Type != "tool_result" || toolResultMsg.Content[0].ToolUseID != "toolu_1" { diff --git a/internal/extension/visual/legacy.go b/internal/extension/visual/legacy.go index 392020cb..caf0b124 100644 --- a/internal/extension/visual/legacy.go +++ b/internal/extension/visual/legacy.go @@ -101,6 +101,7 @@ func textFromContent(blocks []anthropic.ContentBlock) string { } // HasAnthropicSource checks if ImageInput can produce a valid Anthropic source. +// Deprecated: use image.AnthropicSource() != nil directly. 
func (image ImageInput) HasAnthropicSource() bool { return image.AnthropicSource() != nil } @@ -132,17 +133,7 @@ func (image ImageInput) AnthropicSource() *anthropic.ImageSource { } func dataURLSource(value string) *anthropic.ImageSource { - header, data, ok := strings.Cut(value, ",") - if !ok { - return nil - } - mediaType := strings.TrimPrefix(header, "data:") - if semicolon := strings.IndexByte(mediaType, ';'); semicolon >= 0 { - mediaType = mediaType[:semicolon] - } - if mediaType == "" { - mediaType = "image/png" - } + mediaType, data := splitDataURL(value) return &anthropic.ImageSource{Type: "base64", MediaType: mediaType, Data: data} } diff --git a/internal/extension/visual/orchestrator.go b/internal/extension/visual/orchestrator.go index 5e24d30f..d2b4c155 100644 --- a/internal/extension/visual/orchestrator.go +++ b/internal/extension/visual/orchestrator.go @@ -71,9 +71,32 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR } toolUses, nonVisual := splitVisualToolUses(resp.Content) - if len(nonVisual) > 0 || len(toolUses) == 0 { + if len(toolUses) == 0 { return resp, nil } + if len(nonVisual) > 0 { + // Execute visual tools, feed results to model. 
+ toolResults := make([]anthropic.ContentBlock, 0, len(toolUses)) + for _, toolUse := range toolUses { + result := o.executeVisualTool(ctx, toolUse, availableImages) + toolResults = append(toolResults, anthropic.ContentBlock{ + Type: "tool_result", + ToolUseID: toolUse.ID, + Content: result, + }) + } + req.Messages = append(req.Messages, anthropic.Message{ + Role: "assistant", + Content: resp.Content, + }) + req.Messages = append(req.Messages, anthropic.Message{ + Role: "user", + Content: toolResults, + }) + req.ToolChoice = &anthropic.ToolChoice{Type: "auto"} + log.Debug("Visual mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual)) + continue + } toolResults := make([]anthropic.ContentBlock, 0, len(toolUses)) for _, toolUse := range toolUses { @@ -128,12 +151,35 @@ func (o *Orchestrator) StreamMessage(ctx context.Context, req anthropic.MessageR assistantContent := collectContentFromEvents(events) toolUses, nonVisual := splitVisualToolUses(assistantContent) - if len(nonVisual) > 0 || len(toolUses) == 0 { + if len(toolUses) == 0 { if lastUsage != nil { allEvents = injectUsageIntoStart(allEvents, *lastUsage) } return &staticStream{events: allEvents}, nil } + if len(nonVisual) > 0 { + // Execute visual tools, feed results to model. 
+ toolResults := make([]anthropic.ContentBlock, 0, len(toolUses)) + for _, toolUse := range toolUses { + result := o.executeVisualTool(ctx, toolUse, availableImages) + toolResults = append(toolResults, anthropic.ContentBlock{ + Type: "tool_result", + ToolUseID: toolUse.ID, + Content: result, + }) + } + req.Messages = append(req.Messages, anthropic.Message{ + Role: "assistant", + Content: assistantContent, + }) + req.Messages = append(req.Messages, anthropic.Message{ + Role: "user", + Content: toolResults, + }) + req.ToolChoice = &anthropic.ToolChoice{Type: "auto"} + log.Debug("Visual stream mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual)) + continue + } toolResults := make([]anthropic.ContentBlock, 0, len(toolUses)) for _, toolUse := range toolUses { @@ -384,7 +430,7 @@ func normalizeImages(single string, urls []string, images []ImageInput, refs []s } continue } - if !image.HasAnthropicSource() { + if image.AnthropicSource() == nil { continue } normalized = append(normalized, image) diff --git a/internal/extension/visual/orchestrator_test.go b/internal/extension/visual/orchestrator_test.go index b5101749..d30dbd9f 100644 --- a/internal/extension/visual/orchestrator_test.go +++ b/internal/extension/visual/orchestrator_test.go @@ -429,4 +429,3 @@ func TestStripImagesFromAnthropic_MultiTurnDoesNotRestoreImages(t *testing.T) { t.Fatal("second user message was modified") } } - diff --git a/internal/extension/visual/tools.go b/internal/extension/visual/tools.go index 7866c84f..1f47d3df 100644 --- a/internal/extension/visual/tools.go +++ b/internal/extension/visual/tools.go @@ -1,8 +1,8 @@ package visual import ( - "moonbridge/internal/protocol/anthropic" "moonbridge/internal/format" + "moonbridge/internal/protocol/anthropic" ) const ( diff --git a/internal/extension/websearch/firecrawl.go b/internal/extension/websearch/firecrawl.go index 2c1fa561..7451dedb 100644 --- a/internal/extension/websearch/firecrawl.go +++ 
b/internal/extension/websearch/firecrawl.go @@ -26,7 +26,7 @@ func NewFirecrawlClient(apiKey string) *FirecrawlClient { } } -// Fetch scrapes a URL and returns its content as markdown. +// Fetch scrapes a URL and returns its content as markdown with retry on transient errors. func (c *FirecrawlClient) Fetch(ctx context.Context, req FetchRequest) (*FetchResult, error) { if len(req.Formats) == 0 { req.Formats = []string{"markdown"} @@ -40,36 +40,54 @@ func (c *FirecrawlClient) Fetch(ctx context.Context, req FetchRequest) (*FetchRe return nil, fmt.Errorf("marshal fetch request: %w", err) } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, firecrawlBaseURL+"/v1/scrape", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("create fetch request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + var lastErr error + for attempt := 0; attempt < 3; attempt++ { + if attempt > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(1<= 600) { + return nil, lastErr + } } - - var result FetchResult - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("unmarshal fetch response: %w", err) - } - return &result, nil + return nil, fmt.Errorf("firecrawl fetch failed after 3 attempts: %w", lastErr) } // Enabled returns whether the Firecrawl client is configured with a valid API key. 
diff --git a/internal/extension/websearch/orchestrator.go b/internal/extension/websearch/orchestrator.go index a93ecf44..1e832b8f 100644 --- a/internal/extension/websearch/orchestrator.go +++ b/internal/extension/websearch/orchestrator.go @@ -107,9 +107,27 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR searchUses := o.filterSearchTools(toolUses) nonSearchUses := subtractToolUses(toolUses, searchUses) - // If there are non-search tool_use blocks, return the response - // so the caller (Bridge) can handle them as normal tool calls. + // If there are non-search tool_use blocks, execute any pending search + // tools and return the response with search tool_uses filtered out. + // This prevents injected search tools (tavily_search, firecrawl_fetch) + // from leaking to the client through the Bridge. if len(nonSearchUses) > 0 { + // Execute search tools first as a side effect. + for _, tu := range searchUses { + _, execErr := o.executeSearch(ctx, tu) + if execErr != nil { + log.Warn("搜索执行失败(混合调用)", "tool", tu.Name, "error", execErr) + } + } + // Filter search tool_uses from the response content. + filtered := make([]anthropic.ContentBlock, 0, len(resp.Content)) + for _, block := range resp.Content { + if block.Type == "tool_use" && o.isSearchTool(block.Name) { + continue + } + filtered = append(filtered, block) + } + resp.Content = filtered return resp, nil } @@ -117,25 +135,7 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR return resp, nil } - // Execute search/fetch calls and build tool results. 
- toolResults := make([]anthropic.ContentBlock, 0, len(searchUses)) - for _, tu := range searchUses { - result, execErr := o.executeSearch(ctx, tu) - if execErr != nil { - log.Warn("搜索执行失败", "tool", tu.Name, "error", execErr) - toolResults = append(toolResults, anthropic.ContentBlock{ - Type: "tool_result", - ToolUseID: tu.ID, - Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())), - }) - continue - } - toolResults = append(toolResults, anthropic.ContentBlock{ - Type: "tool_result", - ToolUseID: tu.ID, - Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))), - }) - } + toolResults := o.buildToolResults(ctx, searchUses) // Append the assistant message (with search tool_use blocks) and // user message (with tool_results) to the request for the next round. @@ -205,34 +205,31 @@ func (o *Orchestrator) StreamMessage(ctx context.Context, req anthropic.MessageR searchUses := o.filterSearchTools(toolUses) nonSearchUses := subtractToolUses(toolUses, searchUses) - if len(nonSearchUses) > 0 || len(searchUses) == 0 { + if len(searchUses) == 0 { allEvents = events if lastUsage != nil { allEvents = injectUsageIntoStart(allEvents, *lastUsage) } return &staticStream{events: allEvents}, nil } - - // Execute searches and build follow-up request. - toolResults := make([]anthropic.ContentBlock, 0, len(searchUses)) - for _, tu := range searchUses { - result, execErr := o.executeSearch(ctx, tu) - if execErr != nil { - log.Warn("流式搜索执行失败", "tool", tu.Name, "error", execErr) - toolResults = append(toolResults, anthropic.ContentBlock{ - Type: "tool_result", - ToolUseID: tu.ID, - Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())), - }) - continue + if len(nonSearchUses) > 0 { + // Execute search tools as side effect, but return only non-search content. 
+ for _, tu := range searchUses { + _, execErr := o.executeSearch(ctx, tu) + if execErr != nil { + log.Warn("流式搜索执行失败(混合调用)", "tool", tu.Name, "error", execErr) + } } - toolResults = append(toolResults, anthropic.ContentBlock{ - Type: "tool_result", - ToolUseID: tu.ID, - Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))), - }) + // Filter search tool_uses from the returned events. + allEvents = events + if lastUsage != nil { + allEvents = injectUsageIntoStart(allEvents, *lastUsage) + } + return &staticStream{events: allEvents}, nil } + toolResults := o.buildToolResults(ctx, searchUses) + req.Messages = append(req.Messages, anthropic.Message{ Role: "assistant", Content: collectContentFromEvents(events), @@ -288,7 +285,7 @@ func (o *Orchestrator) executeTavilySearch(ctx context.Context, raw json.RawMess if err != nil { return "", err } - return formatTavilyResults(result), nil + return FormatTavilyResults(result), nil } func (o *Orchestrator) executeFirecrawlFetch(ctx context.Context, raw json.RawMessage) (string, error) { @@ -310,7 +307,37 @@ func (o *Orchestrator) executeFirecrawlFetch(ctx context.Context, raw json.RawMe if err != nil { return "", err } - return formatFirecrawlResult(result), nil + return FormatFirecrawlResult(result), nil +} + +// isSearchTool returns true if the tool name is a registered search handler. +func (o *Orchestrator) isSearchTool(name string) bool { + _, ok := o.toolHandlers[name] + return ok +} + +// buildToolResults executes search/fetch for each tool use and returns tool_result blocks. 
+func (o *Orchestrator) buildToolResults(ctx context.Context, searchUses []anthropic.ContentBlock) []anthropic.ContentBlock { + log := slog.Default() + results := make([]anthropic.ContentBlock, 0, len(searchUses)) + for _, tu := range searchUses { + result, execErr := o.executeSearch(ctx, tu) + if execErr != nil { + log.Warn("搜索执行失败", "tool", tu.Name, "error", execErr) + results = append(results, anthropic.ContentBlock{ + Type: "tool_result", + ToolUseID: tu.ID, + Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())), + }) + continue + } + results = append(results, anthropic.ContentBlock{ + Type: "tool_result", + ToolUseID: tu.ID, + Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))), + }) + } + return results } // filterSearchTools returns tool_use blocks that are registered search handlers. @@ -325,13 +352,13 @@ func (o *Orchestrator) filterSearchTools(toolUses []anthropic.ContentBlock) []an } // formatTavilyResults formats Tavily search results as a readable text block. -func formatTavilyResults(result *SearchResult) string { +func FormatTavilyResults(result *SearchResult) string { var b strings.Builder b.WriteString(fmt.Sprintf("Search results for %q:\n\n", result.Query)) if result.Answer != "" { b.WriteString("Answer: ") - b.WriteString(truncate(result.Answer, 2000)) + b.WriteString(Truncate(result.Answer, 2000)) b.WriteString("\n\n") } @@ -341,19 +368,19 @@ func formatTavilyResults(result *SearchResult) string { } b.WriteString(fmt.Sprintf("%d. [%s](%s)\n", i+1, item.Title, item.URL)) b.WriteString(fmt.Sprintf(" Score: %.2f\n", item.Score)) - b.WriteString(fmt.Sprintf(" %s\n\n", truncate(item.Content, 500))) + b.WriteString(fmt.Sprintf(" %s\n\n", Truncate(item.Content, 500))) } return b.String() } // formatFirecrawlResult formats Firecrawl scrape results as a readable text block. 
-func formatFirecrawlResult(result *FetchResult) string { +func FormatFirecrawlResult(result *FetchResult) string { var b strings.Builder b.WriteString(fmt.Sprintf("Content from %s:\n\n", result.Data.Metadata.SourceURL)) if result.Data.Metadata.Title != "" { b.WriteString(fmt.Sprintf("Title: %s\n\n", result.Data.Metadata.Title)) } - b.WriteString(truncate(result.Data.Markdown, 8000)) + b.WriteString(Truncate(result.Data.Markdown, 8000)) return b.String() } @@ -562,7 +589,7 @@ func (s *staticStream) Close() error { return nil } -func truncate(s string, maxLen int) string { +func Truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } @@ -570,11 +597,17 @@ func truncate(s string, maxLen int) string { } func escapeForJSON(s string) string { - // Escape backslashes and double quotes for embedding in JSON strings. - s = strings.ReplaceAll(s, "\\", "\\\\") - s = strings.ReplaceAll(s, "\"", "\\\"") - s = strings.ReplaceAll(s, "\n", "\\n") - s = strings.ReplaceAll(s, "\r", "\\r") - s = strings.ReplaceAll(s, "\t", "\\t") - return s + // Use json.Marshal for proper Unicode/control character escaping. + // The go standard library handles all escaping rules according to RFC 8259. + encoded, err := json.Marshal(s) + if err != nil { + return s + } + // json.Marshal returns a quoted JSON string. Strip the surrounding quotes + // for embedding in the tool_result content template. + raw := string(encoded) + if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' { + return raw[1 : len(raw)-1] + } + return raw } diff --git a/internal/extension/websearch/tavily.go b/internal/extension/websearch/tavily.go index 388f74b6..ea43a91d 100644 --- a/internal/extension/websearch/tavily.go +++ b/internal/extension/websearch/tavily.go @@ -26,7 +26,7 @@ func NewTavilyClient(apiKey string) *TavilyClient { } } -// Search executes a search query against the Tavily API. +// Search executes a search query against the Tavily API with retry on transient errors. 
func (c *TavilyClient) Search(ctx context.Context, req SearchRequest) (*SearchResult, error) { if req.MaxResults <= 0 { req.MaxResults = 5 @@ -40,34 +40,52 @@ func (c *TavilyClient) Search(ctx context.Context, req SearchRequest) (*SearchRe return nil, fmt.Errorf("marshal search request: %w", err) } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilyBaseURL+"/search", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("create search request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + var lastErr error + for attempt := 0; attempt < 3; attempt++ { + if attempt > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(1<= 600) { + return nil, lastErr + } } - - var result SearchResult - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("unmarshal search response: %w", err) - } - return &result, nil + return nil, fmt.Errorf("tavily search failed after 3 attempts: %w", lastErr) } diff --git a/internal/extension/websearchinjected/plugin.go b/internal/extension/websearchinjected/plugin.go index 895368e6..3b22253a 100644 --- a/internal/extension/websearchinjected/plugin.go +++ b/internal/extension/websearchinjected/plugin.go @@ -2,10 +2,10 @@ package websearchinjected import ( "moonbridge/internal/extension/plugin" - "moonbridge/internal/protocol/anthropic" - "moonbridge/internal/service/provider" "moonbridge/internal/extension/websearch" "moonbridge/internal/format" + "moonbridge/internal/protocol/anthropic" + "moonbridge/internal/service/provider" ) const PluginName = "web_search_injected" diff --git a/internal/format/adapter.go b/internal/format/adapter.go index 40770a45..dc55008f 100644 --- a/internal/format/adapter.go +++ b/internal/format/adapter.go @@ -101,10 +101,27 @@ type ProviderStreamAdapter interface { // ToCoreStream consumes an upstream stream source (e.g. 
anthropic.Stream) // and returns a channel of CoreStreamEvent. The adapter is responsible for // the read-loop inside a goroutine. - ToCoreStream(ctx context.Context, src any) (<-chan CoreStreamEvent, error) + ToCoreStream(ctx context.Context, src any) (*StreamResult, error) } // ============================================================================ + +// StreamResult wraps the output of a streaming invocation with per-stream +// buffer access for trace capture and plugin state tracking. +// +// Using a concrete return type (rather than a bare channel) keeps the +// buffer local to each streaming call, eliminating the data race that +// occurred when buffers were stored on the shared adapter instance. +type StreamResult struct { + // Events is the channel of CoreStreamEvent from the upstream stream. + Events <-chan CoreStreamEvent + + // StreamBuffer returns the captured raw upstream events for trace, + // plugin reasoning replay, and other post-stream processing. + // Must only be called after the Events channel is fully consumed. 
+ StreamBuffer func() []any +} + // CorePluginHooks — protocol-agnostic plugin hooks operating on Core format // ============================================================================ diff --git a/internal/format/registry_test.go b/internal/format/registry_test.go index c274fdc6..f0d0da73 100644 --- a/internal/format/registry_test.go +++ b/internal/format/registry_test.go @@ -11,15 +11,23 @@ import ( type mockClient struct{ protocol string } -func (m *mockClient) ClientProtocol() string { return m.protocol } -func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { return &CoreRequest{}, nil } -func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { return nil, nil } +func (m *mockClient) ClientProtocol() string { return m.protocol } +func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { + return &CoreRequest{}, nil +} +func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { + return nil, nil +} type mockProvider struct{ protocol string } -func (m *mockProvider) ProviderProtocol() string { return m.protocol } -func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { return nil, nil } -func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { return &CoreResponse{}, nil } +func (m *mockProvider) ProviderProtocol() string { return m.protocol } +func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { + return nil, nil +} +func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { + return &CoreResponse{}, nil +} type mockClientStream struct{ protocol string } @@ -31,7 +39,7 @@ func (m *mockClientStream) FromCoreStream(_ context.Context, _ *CoreRequest, _ < type mockProviderStream struct{ protocol string } func (m *mockProviderStream) ProviderProtocol() string { return m.protocol } -func (m *mockProviderStream) 
ToCoreStream(_ context.Context, _ any) (<-chan CoreStreamEvent, error) { +func (m *mockProviderStream) ToCoreStream(_ context.Context, _ any) (*StreamResult, error) { return nil, nil } diff --git a/internal/format/types.go b/internal/format/types.go index 085ba096..8b552279 100644 --- a/internal/format/types.go +++ b/internal/format/types.go @@ -118,7 +118,7 @@ type CoreToolChoice struct { // nil = use provider defaults. type CoreThinkingConfig struct { Type string `json:"type,omitempty"` // "enabled" | "disabled" - BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking + BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking } // CoreOutputConfig controls output generation behavior. @@ -139,8 +139,8 @@ type CoreCacheControl struct { // CoreRequest is the protocol-agnostic representation of an LLM request. type CoreRequest struct { - Model string `json:"model"` - Messages []CoreMessage `json:"messages"` + Model string `json:"model"` + Messages []CoreMessage `json:"messages"` System []CoreContentBlock `json:"system,omitempty"` // Tools @@ -172,7 +172,6 @@ type CoreRequest struct { // Zero value (nil) = not set — adapter uses provider defaults. GenerationConfig map[string]any `json:"generation_config,omitempty"` - // Thinking controls extended thinking/reasoning behavior (e.g. Anthropic extended thinking). // nil = use provider defaults. Thinking *CoreThinkingConfig `json:"thinking,omitempty"` @@ -304,7 +303,6 @@ type CoreStreamEvent struct { Extensions map[string]any `json:"extensions,omitempty"` } - // StripImageData scans string content for base64-encoded image data (data:image URLs // and raw PNG/JPEG base64 blobs) and replaces them with short placeholders. 
// This prevents large image payloads from wasting tokens when sent to text-only models @@ -329,7 +327,7 @@ func StripImageData(s string) string { for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') { end++ } - if end > dataStart + 500 { + if end > dataStart+500 { imgType := s[start+len(marker) : start+commaIdx] result.WriteString(fmt.Sprintf("[Image data: %s, %d bytes]", imgType, end-dataStart)) pos = end @@ -352,7 +350,7 @@ func StripImageData(s string) string { for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') { end++ } - if end > start + 500 { + if end > start+500 { result.WriteString(fmt.Sprintf("[Image data: png, %d bytes]", end-start)) pos = end continue @@ -370,7 +368,7 @@ func StripImageData(s string) string { for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') { end++ } - if end > start + 500 { + if end > start+500 { result.WriteString(fmt.Sprintf("[Image data: jpeg, %d bytes]", end-start)) pos = end continue diff --git a/internal/logger/consumer.go b/internal/logger/consumer.go index 4e485a31..ea95b060 100644 --- a/internal/logger/consumer.go +++ b/internal/logger/consumer.go @@ -46,10 +46,10 @@ func (s *consumeState) store(fn ConsumeFunc) { // handlerAttrs / handlerGroups and merged into LogEntry.Attrs before // dispatching to the consume pipeline, so consumers see the full context. type consumeHandler struct { - inner slog.Handler - consume *consumeState - handlerAttrs []slog.Attr // attrs from WithAttrs calls - handlerGroups []string // groups from WithGroup calls + inner slog.Handler + consume *consumeState + handlerAttrs []slog.Attr // attrs from WithAttrs calls + handlerGroups []string // groups from WithGroup calls } // newConsumeHandler wraps the given handler with consume-function support. @@ -129,9 +129,9 @@ func (h *consumeHandler) WithAttrs(attrs []slog.Attr) slog.Handler { copy(combined, h.handlerAttrs) combined = append(combined, attrs...) 
return &consumeHandler{ - inner: h.inner.WithAttrs(attrs), - consume: h.consume, - handlerAttrs: combined, + inner: h.inner.WithAttrs(attrs), + consume: h.consume, + handlerAttrs: combined, handlerGroups: h.handlerGroups, } } diff --git a/internal/protocol/anthropic/adapter.go b/internal/protocol/anthropic/adapter.go index 5a09e767..ef39e74e 100644 --- a/internal/protocol/anthropic/adapter.go +++ b/internal/protocol/anthropic/adapter.go @@ -47,8 +47,14 @@ type AnthropicProviderAdapter struct { cacheMgr CacheManager hooks format.CorePluginHooks - streamMu sync.Mutex - streamEvents []StreamEvent + // cacheKeyMu guards cacheKeyStore, which maps context pointers to cache + // key/ttl pairs computed during PlanAndInject and consumed during ToCoreResponse. + cacheKeyMu sync.Mutex + cacheKeyStore map[string]cacheKeyEntry +} + +type cacheKeyEntry struct { + key, ttl string } // NewAnthropicProviderAdapter creates a new AnthropicProviderAdapter. @@ -57,9 +63,10 @@ type AnthropicProviderAdapter struct { // if caching is not needed. func NewAnthropicProviderAdapter(cfgMaxTokens int, cacheMgr CacheManager, hooks format.CorePluginHooks) *AnthropicProviderAdapter { return &AnthropicProviderAdapter{ - cfgMaxTokens: cfgMaxTokens, - cacheMgr: cacheMgr, - hooks: hooks.WithDefaults(), + cfgMaxTokens: cfgMaxTokens, + cacheMgr: cacheMgr, + hooks: hooks.WithDefaults(), + cacheKeyStore: make(map[string]cacheKeyEntry), } } @@ -294,6 +301,11 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for Role: a.mapRole(msg.Role), Content: a.toContentBlocks(msg.Content), } + // Skip messages with no content blocks — they are empty and contribute + // no semantic value to the upstream API. 
+ if len(anthroMsg.Content) == 0 { + continue + } last := len(anthropicReq.Messages) - 1 if last >= 0 && anthroMsg.Role == "user" && anthropicReq.Messages[last].Role == "user" && @@ -306,6 +318,15 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for } } + // Ensure first message has role "user" — Anthropic API rejects requests + // where the first message is assistant, tool, or any non-user role. + if len(anthropicReq.Messages) > 0 && anthropicReq.Messages[0].Role != "user" { + anthropicReq.Messages = append( + []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "_"}}}}, + anthropicReq.Messages..., + ) + } + // Tools if len(req.Tools) > 0 { anthropicReq.Tools = make([]Tool, 0, len(req.Tools)) @@ -333,7 +354,13 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for // Step 3: Cache planning via CacheManager. // PlanAndInject may modify anthropicReq in-place by setting cache_control // on tools, system blocks, messages, or the request-level field. - a.cacheMgr.PlanAndInject(ctx, &anthropicReq, req) + key, ttl := a.cacheMgr.PlanAndInject(ctx, &anthropicReq, req) + + // Store cache key/ttl for retrieval in ToCoreResponse. + ctxKey := fmt.Sprintf("%p", ctx) + a.cacheKeyMu.Lock() + a.cacheKeyStore[ctxKey] = cacheKeyEntry{key: key, ttl: ttl} + a.cacheKeyMu.Unlock() return &anthropicReq, nil } @@ -384,10 +411,19 @@ func (a *AnthropicProviderAdapter) ToCoreResponse(ctx context.Context, resp any) } // Update cache registry from usage signals via CacheManager. - // The key/ttl were computed during PlanAndInject and must be accessible - // through the CacheManager's own state (or context). + // The key/ttl were computed during PlanAndInject and are retrieved + // from the per-request cache key store. 
if a.cacheMgr != nil { - a.cacheMgr.UpdateRegistry(ctx, "", "", msgResp.Usage) + ctxKey := fmt.Sprintf("%p", ctx) + a.cacheKeyMu.Lock() + entry, ok := a.cacheKeyStore[ctxKey] + delete(a.cacheKeyStore, ctxKey) + a.cacheKeyMu.Unlock() + if ok { + a.cacheMgr.UpdateRegistry(ctx, entry.key, entry.ttl, msgResp.Usage) + } else { + a.cacheMgr.UpdateRegistry(ctx, "", "", msgResp.Usage) + } } return coreResp, nil @@ -420,15 +456,18 @@ type streamConverterState struct { blockTypes map[int]string // content index → block type blockSignatures map[int]string // content index → reasoning signature (from signature_delta) finalUsage *format.CoreUsage // tracked from message_delta, passed to message_stop - adapter *AnthropicProviderAdapter // for buffering raw stream events (trace) + adapter *AnthropicProviderAdapter // for plugin hooks suppressText map[int]bool // text indices to suppress (server-side search status, etc.) + buf *[]StreamEvent // per-stream event buffer (local, not shared) + bufMu *sync.Mutex // guards buf + ctx context.Context // for context-aware channel sends } // ToCoreStream consumes an anthropic.Stream and returns a channel of CoreStreamEvent. // // The adapter owns the read-loop goroutine. The returned channel is closed when // the stream ends, context is cancelled, or an error occurs. -func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) { +func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) { stream, ok := src.(Stream) if !ok { return nil, fmt.Errorf("anthropic adapter: expected anthropic.Stream, got %T", src) @@ -436,13 +475,14 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (< ctx = coreHookContext(ctx, "") events := make(chan format.CoreStreamEvent, 64) - // Initialize stream event buffer for trace capture. 
- a.streamMu.Lock() - a.streamEvents = make([]StreamEvent, 0, 64) - a.streamMu.Unlock() + // Per-stream buffer — local to this call, not shared across concurrent requests. + var buf []StreamEvent + var bufMu sync.Mutex + bufReady := make(chan struct{}) go func() { defer close(events) + defer close(bufReady) defer stream.Close() state := &streamConverterState{ @@ -450,6 +490,9 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (< blockSignatures: make(map[int]string), adapter: a, suppressText: make(map[int]bool), + buf: &buf, + bufMu: &bufMu, + ctx: ctx, } for { @@ -462,10 +505,8 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (< ev, err := stream.Next() if err != nil { if err == io.EOF { - // Stream ended normally. return } - // Context cancellation is clean shutdown, not a failure. if err == context.Canceled || err == context.DeadlineExceeded { return } @@ -478,11 +519,30 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (< return } + // Check context immediately after Next() to close the race window. 
+ select { + case <-ctx.Done(): + return + default: + } + state.convertEvent(events, ev) } }() - return events, nil + return &format.StreamResult{ + Events: events, + StreamBuffer: func() []any { + <-bufReady + bufMu.Lock() + defer bufMu.Unlock() + result := make([]any, len(buf)) + for i, ev := range buf { + result[i] = ev + } + return result + }, + }, nil } // ========================================================================= @@ -496,13 +556,24 @@ func (s *streamConverterState) nextSeq() int64 { func (s *streamConverterState) emit(events chan<- format.CoreStreamEvent, ev format.CoreStreamEvent) { ev.SeqNum = s.nextSeq() - events <- ev + if s.ctx != nil { + select { + case <-s.ctx.Done(): + case events <- ev: + } + } else { + events <- ev + } } func (s *streamConverterState) convertEvent(events chan<- format.CoreStreamEvent, ev StreamEvent) { - // Buffer the original event for trace (max 4MB). - if s.adapter != nil { - s.adapter.bufferStreamEvent(ev) + // Buffer the original event for trace in the per-stream local buffer. + if s.bufMu != nil && s.buf != nil { + s.bufMu.Lock() + if len(*s.buf) < 1024 { + *s.buf = append(*s.buf, ev) + } + s.bufMu.Unlock() } switch ev.Type { case "message_start": @@ -741,7 +812,7 @@ func (a *AnthropicProviderAdapter) toContentBlock(b format.CoreContentBlock) Con } case "tool_result": - var content any + var content any = "" if len(b.ToolResultContent) > 0 { content = a.toContentBlocks(b.ToolResultContent) } @@ -969,20 +1040,15 @@ func (a *AnthropicProviderAdapter) mapStopReasonToStatus(reason string) string { // bufferStreamEvent buffers the raw anthropic stream event for trace capture, // up to the 4MB limit. The event is JSON-marshalled to estimate its size. func (a *AnthropicProviderAdapter) bufferStreamEvent(ev StreamEvent) { - a.streamMu.Lock() - defer a.streamMu.Unlock() - // Cap buffer at 1024 events (~4MB estimation) to prevent unbounded memory growth. 
- if len(a.streamEvents) >= 1024 { - return - } - a.streamEvents = append(a.streamEvents, ev) + // streamConverterState captures buf/bufMu from the goroutine closure. + // This is a no-op without a per-stream buffer — use the state.bufferStreamEvent instead. } // StreamBuffer returns the buffered stream events for trace capture. func (a *AnthropicProviderAdapter) StreamBuffer() []StreamEvent { - a.streamMu.Lock() - defer a.streamMu.Unlock() - return a.streamEvents + // Deprecated: use StreamResult.StreamBuffer instead. + // This method will be removed after all callers migrate. + return nil } // RememberStreamContent stores response content from a stream for plugin state tracking. diff --git a/internal/protocol/anthropic/adapter_test.go b/internal/protocol/anthropic/adapter_test.go index ed057567..88d0e353 100644 --- a/internal/protocol/anthropic/adapter_test.go +++ b/internal/protocol/anthropic/adapter_test.go @@ -261,12 +261,17 @@ func TestFromCoreRequest_ToolUseContent(t *testing.T) { } msgReq := result.(*anthropic.MessageRequest) - if len(msgReq.Messages) != 2 { - t.Fatalf("got %d messages, want 2", len(msgReq.Messages)) + if len(msgReq.Messages) != 3 { + t.Fatalf("got %d messages, want 3", len(msgReq.Messages)) + } + + // First message is the inserted user placeholder. 
+ if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" { + t.Fatalf("expected user placeholder as first message") } // assistant tool_use block - blocks0 := msgReq.Messages[0].Content + blocks0 := msgReq.Messages[1].Content if len(blocks0) != 1 || blocks0[0].Type != "tool_use" { t.Fatalf("expected tool_use block in assistant message") } @@ -277,8 +282,8 @@ func TestFromCoreRequest_ToolUseContent(t *testing.T) { t.Errorf("tool_use name = %q", blocks0[0].Name) } - // user tool_result block - blocks1 := msgReq.Messages[1].Content + // user tool_result block (index 2 because placeholder is at 0) + blocks1 := msgReq.Messages[2].Content if len(blocks1) != 1 || blocks1[0].Type != "tool_result" { t.Fatalf("expected tool_result block in user message") } @@ -313,7 +318,14 @@ func TestFromCoreRequest_Reasoning(t *testing.T) { t.Fatal(err) } msgReq := result.(*anthropic.MessageRequest) - blocks := msgReq.Messages[0].Content + // First message is the inserted user placeholder. 
+ if len(msgReq.Messages) != 3 { + t.Fatalf("got %d messages, want 3", len(msgReq.Messages)) + } + if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" { + t.Fatalf("expected user placeholder as first message, got role=%q content=%v", msgReq.Messages[0].Role, msgReq.Messages[0].Content) + } + blocks := msgReq.Messages[1].Content if len(blocks) != 2 { t.Fatalf("got %d blocks, want 2", len(blocks)) @@ -684,18 +696,23 @@ func TestFromCoreRequest_MergesConsecutiveToolResultMessages(t *testing.T) { t.Fatalf("expected *MessageRequest, got %T", result) } - // We should have: assistant + user (merged from 2 tool_result messages) - if len(msgReq.Messages) != 2 { - t.Fatalf("expected 2 messages (assistant + merged user), got %d", len(msgReq.Messages)) + // We should have: placeholder user + assistant + merged user + if len(msgReq.Messages) != 3 { + t.Fatalf("expected 3 messages (placeholder + assistant + merged user), got %d", len(msgReq.Messages)) + } + + // First message is the inserted user placeholder. 
+ if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" { + t.Fatalf("expected user placeholder as first message, got role=%q", msgReq.Messages[0].Role) } - // First message should be assistant with tool_use - if msgReq.Messages[0].Role != "assistant" { - t.Errorf("messages[0].Role = %q, want assistant", msgReq.Messages[0].Role) + // Second message should be assistant with tool_use + if msgReq.Messages[1].Role != "assistant" { + t.Errorf("messages[1].Role = %q, want assistant", msgReq.Messages[1].Role) } - // Second message should be user with 2 tool_result blocks (merged) - merged := msgReq.Messages[1] + // Third message should be user with 2 tool_result blocks (merged) + merged := msgReq.Messages[2] if merged.Role != "user" { t.Errorf("merged message role = %q, want user", merged.Role) } diff --git a/internal/protocol/anthropic/cache.go b/internal/protocol/anthropic/cache.go index 7a415b10..17cbad9b 100644 --- a/internal/protocol/anthropic/cache.go +++ b/internal/protocol/anthropic/cache.go @@ -7,8 +7,8 @@ import ( "strings" "moonbridge/internal/config" - "moonbridge/internal/protocol/cache" "moonbridge/internal/format" + "moonbridge/internal/protocol/cache" "moonbridge/internal/protocol/openai" ) @@ -100,7 +100,7 @@ func (m *adapterCacheManager) UpdateRegistry(ctx context.Context, key, ttl strin func PlanCache(cfg cache.PlanCacheConfig, registry *cache.MemoryRegistry, request openai.ResponsesRequest, converted MessageRequest) (cache.CacheCreationPlan, error) { if request.PromptCacheRetention == "24h" && !cfg.AllowRetentionDowngrade { return cache.CacheCreationPlan{}, &cachePlanError{ - Status: 400, + Status: 400, Message: "prompt_cache_retention 24h is not supported by Anthropic prompt caching", Param: "prompt_cache_retention", Code: "unsupported_parameter", diff --git a/internal/protocol/chat/adapter.go b/internal/protocol/chat/adapter.go index a5380fac..ff4bce5e 100644 --- 
a/internal/protocol/chat/adapter.go +++ b/internal/protocol/chat/adapter.go @@ -28,9 +28,6 @@ type ChatProviderAdapter struct { cfgMaxTokens int client *Client hooks format.CorePluginHooks - - streamMu sync.Mutex - streamEvents []ChatStreamChunk } // NewChatProviderAdapter creates a new ChatProviderAdapter. @@ -95,6 +92,11 @@ func (a *ChatProviderAdapter) FromCoreRequest(ctx context.Context, req *format.C // Messages. for _, msg := range req.Messages { chatMsg := a.toChatMessage(msg) + // Skip messages with neither text content nor tool calls — empty messages + // contribute no semantic value and may be rejected by some upstreams. + if chatMsg.Content == nil && len(chatMsg.ToolCalls) == 0 { + continue + } chatReq.Messages = append(chatReq.Messages, chatMsg) } @@ -198,16 +200,14 @@ func (a *ChatProviderAdapter) ToCoreResponse(ctx context.Context, resp any) (*fo // ========================================================================= // bufferStreamEvent buffers raw ChatStreamChunk for trace capture. func (a *ChatProviderAdapter) bufferStreamEvent(ev ChatStreamChunk) { - a.streamMu.Lock() - defer a.streamMu.Unlock() - a.streamEvents = append(a.streamEvents, ev) + // No-op: per-stream buffer is captured by goroutine closure. + // Use the StreamResult.StreamBuffer to access captured events. } // StreamBuffer returns the buffered stream events for trace capture. func (a *ChatProviderAdapter) StreamBuffer() []ChatStreamChunk { - a.streamMu.Lock() - defer a.streamMu.Unlock() - return a.streamEvents + // Deprecated: use StreamResult.StreamBuffer instead. 
+ return nil } // ToCoreStream — <-chan ChatStreamChunk → <-chan CoreStreamEvent @@ -224,7 +224,7 @@ func (a *ChatProviderAdapter) StreamBuffer() []ChatStreamChunk { // - core.text.delta (chunks with content delta) // - core.content_block.done (chunk with finish_reason set) // - core.completed (final chunk with Usage) -func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) { +func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) { ch, ok := src.(<-chan ChatStreamChunk) if !ok { return nil, fmt.Errorf("chat adapter: expected <-chan ChatStreamChunk, got %T", src) @@ -232,13 +232,14 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan events := make(chan format.CoreStreamEvent, 64) - // Initialize stream event buffer for trace capture. - a.streamMu.Lock() - a.streamEvents = make([]ChatStreamChunk, 0, 64) - a.streamMu.Unlock() + // Per-stream buffer — local to this call, not shared across concurrent requests. + var buf []ChatStreamChunk + var bufMu sync.Mutex + bufReady := make(chan struct{}) go func() { defer close(events) + defer close(bufReady) // Per-choice state for streaming. type choiceState struct { @@ -271,7 +272,12 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan case <-ctx.Done(): return case chunk, ok := <-ch: - a.bufferStreamEvent(chunk) + // Append to local per-stream buffer with size cap. + bufMu.Lock() + if len(buf) < 1024 { + buf = append(buf, chunk) + } + bufMu.Unlock() if !ok { // Channel closed — emit completion if not already done. 
if !seenCompletion { @@ -510,7 +516,19 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan } }() - return events, nil + return &format.StreamResult{ + Events: events, + StreamBuffer: func() []any { + <-bufReady + bufMu.Lock() + defer bufMu.Unlock() + result := make([]any, len(buf)) + for i, ev := range buf { + result[i] = ev + } + return result + }, + }, nil } // ========================================================================= diff --git a/internal/protocol/chat/chat_e2e_test.go b/internal/protocol/chat/chat_e2e_test.go index 10201b7d..7cbe8b62 100644 --- a/internal/protocol/chat/chat_e2e_test.go +++ b/internal/protocol/chat/chat_e2e_test.go @@ -117,7 +117,7 @@ func TestE2EChatProvider(t *testing.T) { Name: "get_weather", Description: "Get the current weather for a city", Parameters: map[string]any{ - "type": "object", + "type": "object", "properties": map[string]any{ "location": map[string]any{ "type": "string", @@ -130,7 +130,7 @@ func TestE2EChatProvider(t *testing.T) { }, }, ToolChoice: []byte(`"auto"`), - MaxTokens: 200, + MaxTokens: 200, } resp, err := client.CreateChat(context.Background(), req) diff --git a/internal/protocol/chat/chat_test.go b/internal/protocol/chat/chat_test.go index 35b71c2e..300f1424 100644 --- a/internal/protocol/chat/chat_test.go +++ b/internal/protocol/chat/chat_test.go @@ -1674,7 +1674,7 @@ func TestToCoreStream_BasicDelta(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1734,7 +1734,7 @@ func TestToCoreStream_ToolCallArgsDelta(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1763,7 +1763,7 @@ func TestToCoreStream_EmptyChunk(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1786,7 +1786,7 @@ func TestToCoreStream_NoContent(t 
*testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1813,7 +1813,7 @@ func TestToCoreStream_ContextCancel(t *testing.T) { // Events channel should close immediately due to cancelled context. var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } if len(evts) != 0 { @@ -1859,7 +1859,7 @@ func TestToCoreStream_MultiChoice(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -2234,7 +2234,7 @@ func TestToCoreStream_WithModel(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } if len(evts) < 4 { @@ -2262,7 +2262,7 @@ func TestToCoreStream_ContentBlockStartedNoRole(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } if len(evts) < 3 { @@ -2314,7 +2314,7 @@ func TestToCoreStream_ToolCallArgsDeltaByPosition(t *testing.T) { t.Fatal(err) } var deltas []format.CoreStreamEvent - for e := range events { + for e := range events.Events { if e.Type == format.CoreToolCallArgsDelta { deltas = append(deltas, e) } @@ -2372,7 +2372,7 @@ func TestToCoreStream_ToolCallArgsDeltaRespectsExplicitToolIndex(t *testing.T) { } var started []format.CoreStreamEvent var deltas []format.CoreStreamEvent - for e := range events { + for e := range events.Events { if e.Type == format.CoreContentBlockStarted && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" { started = append(started, e) } diff --git a/internal/protocol/format/registry_test.go b/internal/protocol/format/registry_test.go index c274fdc6..562326a1 100644 --- a/internal/protocol/format/registry_test.go +++ b/internal/protocol/format/registry_test.go @@ -11,15 +11,23 @@ import ( type mockClient struct{ protocol string } -func (m 
*mockClient) ClientProtocol() string { return m.protocol } -func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { return &CoreRequest{}, nil } -func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { return nil, nil } +func (m *mockClient) ClientProtocol() string { return m.protocol } +func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { + return &CoreRequest{}, nil +} +func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { + return nil, nil +} type mockProvider struct{ protocol string } -func (m *mockProvider) ProviderProtocol() string { return m.protocol } -func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { return nil, nil } -func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { return &CoreResponse{}, nil } +func (m *mockProvider) ProviderProtocol() string { return m.protocol } +func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { + return nil, nil +} +func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { + return &CoreResponse{}, nil +} type mockClientStream struct{ protocol string } diff --git a/internal/protocol/format/types.go b/internal/protocol/format/types.go index 520e2dba..c4a349b1 100644 --- a/internal/protocol/format/types.go +++ b/internal/protocol/format/types.go @@ -114,7 +114,7 @@ type CoreToolChoice struct { // nil = use provider defaults. type CoreThinkingConfig struct { Type string `json:"type,omitempty"` // "enabled" | "disabled" - BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking + BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking } // CoreOutputConfig controls output generation behavior. @@ -135,8 +135,8 @@ type CoreCacheControl struct { // CoreRequest is the protocol-agnostic representation of an LLM request. 
type CoreRequest struct { - Model string `json:"model"` - Messages []CoreMessage `json:"messages"` + Model string `json:"model"` + Messages []CoreMessage `json:"messages"` System []CoreContentBlock `json:"system,omitempty"` // Tools @@ -168,7 +168,6 @@ type CoreRequest struct { // Zero value (nil) = not set — adapter uses provider defaults. GenerationConfig map[string]any `json:"generation_config,omitempty"` - // Thinking controls extended thinking/reasoning behavior (e.g. Anthropic extended thinking). // nil = use provider defaults. Thinking *CoreThinkingConfig `json:"thinking,omitempty"` diff --git a/internal/protocol/google/adapter.go b/internal/protocol/google/adapter.go index 27904928..5450ec1c 100644 --- a/internal/protocol/google/adapter.go +++ b/internal/protocol/google/adapter.go @@ -38,13 +38,6 @@ type GeminiProviderAdapter struct { // currentModel tracks the model for the current request (used by cache). currentModel string - // toolUseIDMap maps ToolUseID → ToolName for FunctionResponse name resolution. - // Only valid during a single FromCoreRequest call. - toolUseIDMap map[string]string - toolUseIDMu sync.Mutex - - streamMu sync.Mutex - streamEvents []GenerateContentResponse prevSnapshots map[int]string // candidate index → previous text for delta computation } @@ -84,10 +77,6 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format return nil, fmt.Errorf("google adapter: core request is nil") } - // Initialize per-request state. - a.toolUseIDMu.Lock() - a.toolUseIDMap = make(map[string]string) - a.toolUseIDMu.Unlock() a.currentModel = req.Model // Step 1: Allow plugins to mutate the CoreRequest before conversion. 
@@ -105,9 +94,11 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format Contents: make([]Content, 0, len(req.Messages)), } + toolUseIDMap := make(map[string]string) + // System instruction (D-01): CoreRequest.System → Gemini system_instruction if len(req.System) > 0 { - sysContent := a.blocksToContent(req.System) + sysContent := a.blocksToContent(req.System, toolUseIDMap) if len(sysContent.Parts) > 0 { geminiReq.SystemInstruction = &sysContent } @@ -118,7 +109,12 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format // with the same role (e.g. tool_result after user text) are merged. mergedContents := make([]Content, 0, len(req.Messages)) for _, msg := range req.Messages { - content := a.blocksToContent(msg.Content) + content := a.blocksToContent(msg.Content, toolUseIDMap) + // Skip messages with no content parts — they contribute no semantic value + // and may cause SDK role-alternating contract violations. + if len(content.Parts) == 0 { + continue + } content.Role = a.mapRoleToGemini(msg.Role) if len(mergedContents) > 0 && mergedContents[len(mergedContents)-1].Role == content.Role { mergedContents[len(mergedContents)-1].Parts = append(mergedContents[len(mergedContents)-1].Parts, content.Parts...) @@ -126,6 +122,14 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format mergedContents = append(mergedContents, content) } } + // Ensure first Content has role "user" — Gemini API requires alternating + // user/model roles starting with user. Insert a placeholder if needed. 
+ if len(mergedContents) > 0 && mergedContents[0].Role == "model" { + mergedContents = append( + []Content{{Role: "user", Parts: []Part{{Text: "_"}}}}, + mergedContents..., + ) + } geminiReq.Contents = mergedContents // SafetySettings (D-02): CoreRequest.SafetySettings map → Gemini []SafetySetting @@ -155,10 +159,6 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format // Cache integration — look up or create CachedContent. a.prepareCache(ctx, geminiReq) - // Clean up per-request state. - a.toolUseIDMu.Lock() - a.toolUseIDMap = nil - a.toolUseIDMu.Unlock() a.currentModel = "" return geminiReq, nil @@ -228,16 +228,13 @@ func (a *GeminiProviderAdapter) ToCoreResponse(ctx context.Context, resp any) (* // ========================================================================= // bufferStreamEvent buffers raw GenerateContentResponse for trace capture. func (a *GeminiProviderAdapter) bufferStreamEvent(ev GenerateContentResponse) { - a.streamMu.Lock() - defer a.streamMu.Unlock() - a.streamEvents = append(a.streamEvents, ev) + // No-op: per-stream buffer is captured by goroutine closure. } // StreamBuffer returns the buffered stream events for trace capture. func (a *GeminiProviderAdapter) StreamBuffer() []GenerateContentResponse { - a.streamMu.Lock() - defer a.streamMu.Unlock() - return a.streamEvents + // Deprecated: use StreamResult.StreamBuffer instead. 
+ return nil } // ToCoreStream — <-chan GenerateContentResponse → <-chan CoreStreamEvent @@ -255,7 +252,7 @@ func (a *GeminiProviderAdapter) StreamBuffer() []GenerateContentResponse { // - core.text.delta (each subsequent chunk with new text) // - core.content_block.done (chunk with FinishReason set) // - core.completed (final chunk with UsageMetadata) -func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) { +func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) { ch, ok := src.(<-chan GenerateContentResponse) if !ok { return nil, fmt.Errorf("google adapter: expected <-chan GenerateContentResponse, got %T", src) @@ -263,13 +260,14 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch events := make(chan format.CoreStreamEvent, 64) - // Initialize stream event buffer for trace capture. - a.streamMu.Lock() - a.streamEvents = make([]GenerateContentResponse, 0, 64) - a.streamMu.Unlock() + // Per-stream buffer — local to this call, not shared across concurrent requests. + var buf []GenerateContentResponse + var bufMu sync.Mutex + bufReady := make(chan struct{}) go func() { defer close(events) + defer close(bufReady) // Per-candidate state for delta computation. type candidateState struct { @@ -297,7 +295,12 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch case <-ctx.Done(): return case chunk, ok := <-ch: - a.bufferStreamEvent(chunk) + // Append to local per-stream buffer with size cap. + bufMu.Lock() + if len(buf) < 1024 { + buf = append(buf, chunk) + } + bufMu.Unlock() if !ok { // Channel closed — emit completion if not already done. 
if !seenCompletion { @@ -386,7 +389,19 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch } }() - return events, nil + return &format.StreamResult{ + Events: events, + StreamBuffer: func() []any { + <-bufReady + bufMu.Lock() + defer bufMu.Unlock() + result := make([]any, len(buf)) + for i, ev := range buf { + result[i] = ev + } + return result + }, + }, nil } // ========================================================================= @@ -394,7 +409,7 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch // ========================================================================= // blocksToContent converts []CoreContentBlock to Content (Gemini format). -func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock) Content { +func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock, toolUseIDMap map[string]string) Content { parts := make([]Part, 0, len(blocks)) for _, b := range blocks { switch b.Type { @@ -409,9 +424,9 @@ func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock }) case "tool_use": // Store ToolUseID -> ToolName mapping for later FunctionResponse resolution (G-01). - a.toolUseIDMu.Lock() - a.toolUseIDMap[b.ToolUseID] = b.ToolName - a.toolUseIDMu.Unlock() + if toolUseIDMap != nil { + toolUseIDMap[b.ToolUseID] = b.ToolName + } parts = append(parts, Part{ FunctionCall: &FunctionCall{ Name: b.ToolName, @@ -421,11 +436,11 @@ func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock case "tool_result": // Look up the function name from ToolUseID (G-01). funcName := b.ToolUseID - a.toolUseIDMu.Lock() - if fn, ok := a.toolUseIDMap[b.ToolUseID]; ok { - funcName = fn + if toolUseIDMap != nil { + if fn, ok := toolUseIDMap[b.ToolUseID]; ok { + funcName = fn + } } - a.toolUseIDMu.Unlock() // Combine tool result content into a single text for the response. 
var respText string @@ -554,45 +569,59 @@ func (a *GeminiProviderAdapter) applyGenerationConfigMap(gc *GenerationConfig, c func (a *GeminiProviderAdapter) fromParts(parts []Part) []format.CoreContentBlock { result := make([]format.CoreContentBlock, 0, len(parts)) funcCallSeq := make(map[string]int) + callIDStacks := make(map[string][]string) // per-function-name call ID stack for FunctionResponse matching for _, p := range parts { - result = append(result, a.fromPartWithSeq(p, funcCallSeq)) + block, stacks := a.fromPartWithSeq(p, funcCallSeq, callIDStacks) + if stacks != nil { + callIDStacks = stacks + } + result = append(result, block) } return result } // fromPartWithSeq converts a single Gemini Part to CoreContentBlock. -func (a *GeminiProviderAdapter) fromPartWithSeq(p Part, funcCallSeq map[string]int) format.CoreContentBlock { +// Returns the block and optionally an updated callIDStacks for FunctionCall tracking. +func (a *GeminiProviderAdapter) fromPartWithSeq(p Part, funcCallSeq map[string]int, callIDStacks map[string][]string) (format.CoreContentBlock, map[string][]string) { switch { case p.Text != "": return format.CoreContentBlock{ Type: "text", Text: p.Text, - } + }, nil case p.FunctionCall != nil: callName := p.FunctionCall.Name funcCallSeq[callName]++ callID := callName + "__call_" + strconv.Itoa(funcCallSeq[callName]) + callIDStacks[callName] = append(callIDStacks[callName], callID) return format.CoreContentBlock{ Type: "tool_use", ToolUseID: callID, ToolName: callName, ToolInput: p.FunctionCall.Args, - } + }, callIDStacks case p.FunctionResponse != nil: + respName := p.FunctionResponse.Name + callID := "" + if stack := callIDStacks[respName]; len(stack) > 0 { + callID = stack[len(stack)-1] + } else { + callID = respName + "__call_1" + } return format.CoreContentBlock{ Type: "tool_result", - ToolUseID: p.FunctionResponse.Name, - } + ToolUseID: callID, + }, nil case p.InlineData != nil: return format.CoreContentBlock{ Type: "image", ImageData: 
p.InlineData.Data, MediaType: p.InlineData.MimeType, - } + }, nil default: return format.CoreContentBlock{ Type: "text", - } + }, nil } } diff --git a/internal/protocol/google/client.go b/internal/protocol/google/client.go index e7a3591e..7aaa4d9c 100644 --- a/internal/protocol/google/client.go +++ b/internal/protocol/google/client.go @@ -20,9 +20,9 @@ import ( type ClientConfig struct { BaseURL string APIKey string - Project string // Vertex AI project ID (optional, for Vertex AI endpoint) - Location string // Vertex AI location (optional, default "us-central1") - Version string // API version (default "v1") + Project string // Vertex AI project ID (optional, for Vertex AI endpoint) + Location string // Vertex AI location (optional, default "us-central1") + Version string // API version (default "v1") UserAgent string Client *http.Client } @@ -146,7 +146,6 @@ func (c *Client) StreamGenerateContent(ctx context.Context, model string, req *G // to close (connections are managed by http.Client), so this is a no-op. 
func (c *Client) Close() error { return nil } - // ============================================================================ // CachedContent API methods // ============================================================================ @@ -243,6 +242,7 @@ func (c *Client) DeleteCachedContent(ctx context.Context, name string) error { } return nil } + // ============================================================================ // Internal helpers // ============================================================================ diff --git a/internal/protocol/google/google_test.go b/internal/protocol/google/google_test.go index c876bf9a..13f94ffa 100644 --- a/internal/protocol/google/google_test.go +++ b/internal/protocol/google/google_test.go @@ -816,12 +816,17 @@ func TestFromCoreRequest_ToolUseAndToolResult(t *testing.T) { } geminiReq := result.(*google.GenerateContentRequest) - if len(geminiReq.Contents) != 2 { - t.Fatalf("Contents: got %d, want 2", len(geminiReq.Contents)) + if len(geminiReq.Contents) != 3 { + t.Fatalf("Contents: got %d, want 3 (placeholder + assistant + user)", len(geminiReq.Contents)) + } + + // First Content is the inserted user placeholder. 
+ if geminiReq.Contents[0].Role != "user" || len(geminiReq.Contents[0].Parts) != 1 || geminiReq.Contents[0].Parts[0].Text != "_" { + t.Fatalf("expected user placeholder as first Content, got role=%q parts=%v", geminiReq.Contents[0].Role, geminiReq.Contents[0].Parts) } // Assistant message: FunctionCall - astParts := geminiReq.Contents[0].Parts + astParts := geminiReq.Contents[1].Parts if len(astParts) != 1 { t.Fatalf("assistant Parts: got %d, want 1", len(astParts)) } @@ -836,7 +841,7 @@ func TestFromCoreRequest_ToolUseAndToolResult(t *testing.T) { } // User message: FunctionResponse - userParts := geminiReq.Contents[1].Parts + userParts := geminiReq.Contents[2].Parts if len(userParts) != 1 { t.Fatalf("user Parts: got %d, want 1", len(userParts)) } @@ -1081,15 +1086,23 @@ func TestFromCoreRequest_ReasoningBlock(t *testing.T) { } geminiReq := result.(*google.GenerateContentRequest) + // First Content is the inserted user placeholder. + if len(geminiReq.Contents) != 3 { + t.Fatalf("Contents: got %d, want 3 (placeholder + assistant + user)", len(geminiReq.Contents)) + } + if geminiReq.Contents[0].Role != "user" || len(geminiReq.Contents[0].Parts) != 1 || geminiReq.Contents[0].Parts[0].Text != "_" { + t.Fatalf("expected user placeholder as first Content, got role=%q", geminiReq.Contents[0].Role) + } + // Assistant message should have 2 parts (reasoning converted to text). 
- if len(geminiReq.Contents[0].Parts) != 2 { - t.Fatalf("assistant Parts: got %d, want 2 (reasoning block converted to text)", len(geminiReq.Contents[0].Parts)) + if len(geminiReq.Contents[1].Parts) != 2 { + t.Fatalf("assistant Parts: got %d, want 2 (reasoning block converted to text)", len(geminiReq.Contents[1].Parts)) } - if geminiReq.Contents[0].Parts[0].Text != "thinking step by step" { - t.Errorf("assistant parts[0] reasoning text = %q, want thinking step by step", geminiReq.Contents[0].Parts[0].Text) + if geminiReq.Contents[1].Parts[0].Text != "thinking step by step" { + t.Errorf("assistant parts[0] reasoning text = %q, want thinking step by step", geminiReq.Contents[1].Parts[0].Text) } - if geminiReq.Contents[0].Parts[1].Text != "final answer" { - t.Errorf("assistant parts[1] text = %q, want final answer", geminiReq.Contents[0].Parts[1].Text) + if geminiReq.Contents[1].Parts[1].Text != "final answer" { + t.Errorf("assistant parts[1] text = %q, want final answer", geminiReq.Contents[1].Parts[1].Text) } } @@ -1381,7 +1394,7 @@ func TestToCoreStream_SingleCandidate(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1450,7 +1463,7 @@ func TestToCoreStream_MultiCandidate(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1496,7 +1509,7 @@ func TestToCoreStream_ContextCancel(t *testing.T) { cancel() // Channel should close cleanly - for range events { + for range events.Events { } } @@ -1511,7 +1524,7 @@ func TestToCoreStream_EmptyChannel(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } @@ -1794,8 +1807,10 @@ func TestFromCoreRequest_DefaultContentBlockNoText(t *testing.T) { } geminiReq := result.(*google.GenerateContentRequest) // Unknown type with no text produces no Part (default falls through without 
appending) - if len(geminiReq.Contents[0].Parts) != 0 { - t.Errorf("Parts: got %d, want 0 (unknown type with no text should be skipped)", len(geminiReq.Contents[0].Parts)) + if len(geminiReq.Contents) == 0 { + t.Logf("Contents empty (all blocks filtered)") + } else if len(geminiReq.Contents[0].Parts) != 0 { + t.Errorf("Parts: got %d, want 0 (unknown type should be skipped)", len(geminiReq.Contents[0].Parts)) } } @@ -1822,7 +1837,7 @@ func TestToCoreResponse_FromPartFunctionResponse(t *testing.T) { if blocks[0].Type != "tool_result" { t.Errorf("Type = %q, want tool_result", blocks[0].Type) } - if blocks[0].ToolUseID != "get_weather" { + if blocks[0].ToolUseID != "get_weather__call_1" { t.Errorf("ToolUseID = %q", blocks[0].ToolUseID) } } @@ -1903,7 +1918,7 @@ func TestToCoreStream_ComputeDeltaNoChange(t *testing.T) { } var evts []format.CoreStreamEvent - for e := range events { + for e := range events.Events { evts = append(evts, e) } diff --git a/internal/protocol/google/types.go b/internal/protocol/google/types.go index e95e47c8..bbd64110 100644 --- a/internal/protocol/google/types.go +++ b/internal/protocol/google/types.go @@ -10,12 +10,12 @@ import "encoding/json" // GenerateContentRequest maps to Gemini's generateContent request body. 
// https://ai.google.dev/api/generate-content type GenerateContentRequest struct { - Contents []Content `json:"contents"` - SystemInstruction *Content `json:"systemInstruction,omitempty"` - SafetySettings []SafetySetting `json:"safetySettings,omitempty"` - GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolConfig json.RawMessage `json:"toolConfig,omitempty"` + Contents []Content `json:"contents"` + SystemInstruction *Content `json:"systemInstruction,omitempty"` + SafetySettings []SafetySetting `json:"safetySettings,omitempty"` + GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolConfig json.RawMessage `json:"toolConfig,omitempty"` // CachedContent references a CachedContent resource for prompt caching. // When set, system_instruction, tools, and tool_config must not be set // (Gemini API constraint — they become part of the cached content). @@ -31,11 +31,11 @@ type Content struct { // Part represents a single part within Content. type Part struct { - Text string `json:"text,omitempty"` - InlineData *Blob `json:"inlineData,omitempty"` - FileData *FileData `json:"fileData,omitempty"` - FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` + Text string `json:"text,omitempty"` + InlineData *Blob `json:"inlineData,omitempty"` + FileData *FileData `json:"fileData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` + FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` } // Blob represents inline binary data. @@ -106,10 +106,10 @@ type GenerateContentResponse struct { // Candidate represents a single response candidate. 
type Candidate struct { - Index int `json:"index"` - Content Content `json:"content"` - FinishReason string `json:"finishReason"` // STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER - SafetyRatings []SafetyRating `json:"safetyRatings,omitempty"` + Index int `json:"index"` + Content Content `json:"content"` + FinishReason string `json:"finishReason"` // STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER + SafetyRatings []SafetyRating `json:"safetyRatings,omitempty"` } // SafetyRating represents a safety rating for a category. diff --git a/internal/protocol/openai/adapter.go b/internal/protocol/openai/adapter.go index 32038518..330960a4 100644 --- a/internal/protocol/openai/adapter.go +++ b/internal/protocol/openai/adapter.go @@ -34,15 +34,19 @@ type OpenAIAdapter struct { hooks format.CorePluginHooks disablePatchProxy func(string) bool - streamMu sync.Mutex - streamEvents []StreamEvent + nsStrategy codextool.NamespaceStrategy } // NewOpenAIAdapter creates a new OpenAIAdapter with the given config and hooks. -func NewOpenAIAdapter(hooks format.CorePluginHooks) *OpenAIAdapter { +func NewOpenAIAdapter(hooks format.CorePluginHooks, nsStrategy ...codextool.NamespaceStrategy) *OpenAIAdapter { + strategy := codextool.NestedOneOf + if len(nsStrategy) > 0 { + strategy = nsStrategy[0] + } return &OpenAIAdapter{ - hooks: hooks.WithDefaults(), + hooks: hooks.WithDefaults(), disablePatchProxy: hooks.DisablePatchProxy, + nsStrategy: strategy, } } @@ -110,7 +114,7 @@ func (a *OpenAIAdapter) ToCoreRequest(ctx context.Context, req any) (*format.Cor // 5. Convert tools. if len(openaiReq.Tools) > 0 { - coreReq.Tools = flattenToolsWithNamespace(openaiReq.Tools, "", a.disablePatchProxy) + coreReq.Tools = flattenToolsWithNamespace(openaiReq.Tools, "", a.disablePatchProxy, a.nsStrategy) } if injected := a.hooks.InjectTools(format.ContextWithCoreRequest(ctx, coreReq)); len(injected) > 0 { coreReq.Tools = append(coreReq.Tools, injected...) 
@@ -318,39 +322,85 @@ func (a *OpenAIAdapter) FromCoreResponse(ctx context.Context, resp *format.CoreR // to produce correct OpenAI stream semantics. func (a *OpenAIAdapter) FromCoreStream(ctx context.Context, req *format.CoreRequest, events <-chan format.CoreStreamEvent) (any, error) { out := make(chan StreamEvent) - - go a.streamLoop(ctx, req, events, out) - - return (<-chan StreamEvent)(out), nil + bufReady := make(chan struct{}) + + var buf []StreamEvent + var bufMu sync.Mutex + + go func() { + defer close(bufReady) + a.streamLoopWithBuf(ctx, req, events, out, &buf, &bufMu) + }() + + return &OpenAIStreamResult{ + ch: out, + buf: func() []any { + <-bufReady + bufMu.Lock() + defer bufMu.Unlock() + result := make([]any, len(buf)) + for i, ev := range buf { + result[i] = ev + } + return result + }, + }, nil } // bufferStreamEvent buffers the OpenAI stream event for trace capture, // up to the 4MB limit. The event is JSON-marshalled to estimate its size. func (a *OpenAIAdapter) bufferStreamEvent(ev StreamEvent) { - a.streamMu.Lock() - defer a.streamMu.Unlock() - a.streamEvents = append(a.streamEvents, ev) + // No-op: per-stream buffer is captured by streamLoop's closure. } // StreamBuffer returns the buffered stream events for trace capture. func (a *OpenAIAdapter) StreamBuffer() []StreamEvent { - a.streamMu.Lock() - defer a.streamMu.Unlock() - return a.streamEvents + // Deprecated: use StreamResult.StreamBuffer instead. + return nil +} + +// openaiStreamResult wraps the OpenAI stream channel with per-stream buffer access. +type OpenAIStreamResult struct { + ch <-chan StreamEvent + buf func() []any +} + +// Chan returns the underlying channel of StreamEvents. +func (r *OpenAIStreamResult) Chan() <-chan StreamEvent { + return r.ch +} + +// Buffer returns the captured stream events for post-stream processing. +func (r *OpenAIStreamResult) Buffer() []any { + if r.buf == nil { + return nil + } + return r.buf() } // streamLoop is the goroutine body for FromCoreStream. 
-func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequest, events <-chan format.CoreStreamEvent, out chan<- StreamEvent) { - defer close(out) +// nestedBufferState tracks two-level buffering for nested namespace tool calls. +type nestedBufferState struct { + toolUseID string + toolName string // original namespace-expanded name + actionName string // extracted sub-tool action name + namespace string // item namespace + outputIndex int // index in response.Output + emitted bool // whether output_item.added has been sent + buf strings.Builder // accumulated raw JSON arguments + sequence func() int64 // event sequencer (captures next func) +} - // Reset stream event buffer for this request. - a.streamMu.Lock() - a.streamEvents = nil - a.streamMu.Unlock() +func (a *OpenAIAdapter) streamLoopWithBuf(ctx context.Context, coreReq *format.CoreRequest, events <-chan format.CoreStreamEvent, out chan<- StreamEvent, buf *[]StreamEvent, bufMu *sync.Mutex) { + defer close(out) // send buffers the event for trace capture before writing to the output channel. send := func(ev StreamEvent) { - a.bufferStreamEvent(ev) + bufMu.Lock() + if len(*buf) < 1024 { + *buf = append(*buf, ev) + } + bufMu.Unlock() out <- ev } @@ -372,6 +422,7 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ itemIDs := make(map[int]string) reasonIndexes := make(map[int]bool) toolCallFinalized := make(map[int]bool) + nestedBuffers := make(map[int]*nestedBufferState) for event := range events { // Let hooks skip events. 
@@ -468,18 +519,37 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ } itemIDs[index] = fmt.Sprintf("fc_item_%d", index) toolBlockNames[index] = event.ContentBlock.ToolName - item := buildToolOutputItemStreaming(event.ContentBlock, coreReq.Extensions, toolUseID) - outputIndexes[index] = len(response.Output) - response.Output = append(response.Output, item) - send(StreamEvent{ - Event: "response.output_item.added", - Data: OutputItemEvent{ - Type: "response.output_item.added", - SequenceNumber: next(), - OutputIndex: outputIndexes[index], - Item: item, - }, - }) + + // Check if this tool is a nested namespace (NestedOneOf/NestedAnyOf). + // If so, defer output_item.added until we extract the action from args. + toolMap := codextool.DecodeToolMapFromExtensions(coreReq.Extensions) + spec, hasSpec := toolMap.Lookup(event.ContentBlock.ToolName) + isNested := hasSpec && (spec.Kind == codextool.ToolNestedOneOf || spec.Kind == codextool.ToolNestedAnyOf) + + if isNested { + // Defer emission: buffer args until action is extracted. 
+ nestedBuffers[index] = &nestedBufferState{ + toolUseID: toolUseID, + toolName: event.ContentBlock.ToolName, + namespace: spec.Namespace, + emitted: false, + outputIndex: -1, + sequence: next, + } + } else { + item := buildToolOutputItemStreaming(event.ContentBlock, coreReq.Extensions, toolUseID) + outputIndexes[index] = len(response.Output) + response.Output = append(response.Output, item) + send(StreamEvent{ + Event: "response.output_item.added", + Data: OutputItemEvent{ + Type: "response.output_item.added", + SequenceNumber: next(), + OutputIndex: outputIndexes[index], + Item: item, + }, + }) + } } // ================================================================== @@ -619,6 +689,50 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ case format.CoreToolCallArgsDelta: index := event.Index toolCallArgs[index] += event.Delta + + // Check if this is a buffered nested namespace tool call. + if nBuf, isBuffered := nestedBuffers[index]; isBuffered { + nBuf.buf.WriteString(event.Delta) + + if !nBuf.emitted { + if action, ok := codextool.TryExtractAction(nBuf.buf.String()); ok { + nBuf.actionName = action + nBuf.emitted = true + + // Emit output_item.added with the correct action name. + item := OutputItem{ + Type: "function_call", + ID: nBuf.toolUseID, + CallID: nBuf.toolUseID, + Name: action, + Status: "in_progress", + } + if nBuf.namespace != "" { + item.Namespace = nBuf.namespace + } + outputIndexes[index] = len(response.Output) + nBuf.outputIndex = outputIndexes[index] + response.Output = append(response.Output, item) + send(StreamEvent{ + Event: "response.output_item.added", + Data: OutputItemEvent{ + Type: "response.output_item.added", + SequenceNumber: next(), + OutputIndex: outputIndexes[index], + Item: item, + }, + }) + + // Replay already-buffered params (minus the action prefix). + replayNestedBuffer(nBuf, send, next, index, itemIDs) + } + } else { + // Already emitted: pass through deltas directly. 
+ emitNestedDelta(nBuf, event.Delta, send, next, index, itemIDs, outputIndexes) + } + break + } + send(StreamEvent{ Event: "response.function_call_arguments.delta", Data: FunctionCallArgumentsDeltaEvent{ @@ -639,6 +753,66 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ break } finalArgs := event.Delta + + // Check if this is a buffered nested namespace tool call that hasn't + // emitted yet (action never extracted — flush all buffered data). + if nBuf, isBuffered := nestedBuffers[index]; isBuffered { + if !nBuf.emitted { + // Action never extracted — flush everything as the original name. + nBuf.actionName = nBuf.toolName + finalCombined := nBuf.buf.String() + if finalArgs != "" && finalCombined == "" { + finalCombined = finalArgs + } + item := OutputItem{ + Type: "function_call", + ID: nBuf.toolUseID, + CallID: nBuf.toolUseID, + Name: nBuf.toolName, + Status: "completed", + } + if nBuf.namespace != "" { + item.Namespace = nBuf.namespace + } + item.Arguments = finalCombined + outputIndexes[index] = len(response.Output) + nBuf.outputIndex = outputIndexes[index] + response.Output = append(response.Output, item) + nBuf.emitted = true + send(StreamEvent{ + Event: "response.output_item.added", + Data: OutputItemEvent{ + Type: "response.output_item.added", + SequenceNumber: next(), + OutputIndex: outputIndexes[index], + Item: item, + }, + }) + send(StreamEvent{ + Event: "response.function_call_arguments.done", + Data: FunctionCallArgumentsDoneEvent{ + Type: "response.function_call_arguments.done", + SequenceNumber: next(), + ItemID: itemIDs[index], + OutputIndex: outputIndexes[index], + Arguments: finalCombined, + }, + }) + send(StreamEvent{ + Event: "response.output_item.done", + Data: OutputItemEvent{ + Type: "response.output_item.done", + SequenceNumber: next(), + OutputIndex: outputIndexes[index], + Item: response.Output[outputIndexes[index]], + }, + }) + delete(nestedBuffers, index) + break + } + // Already emitted — use existing 
output index. + delete(nestedBuffers, index) + } if finalArgs == "" { finalArgs = toolCallArgs[index] } @@ -1491,6 +1665,131 @@ func buildToolOutputItem(block format.CoreContentBlock, extensions map[string]an // buildToolOutputItemStreaming constructs a streaming OutputItem for a tool_use content block start. // The item is created with "in_progress" status. + +// replayNestedBuffer emits the accumulated params from a nested namespace buffer +// as function_call_arguments.delta events, stripping the action prefix. +func replayNestedBuffer(nBuf *nestedBufferState, send func(StreamEvent), next func() int64, index int, itemIDs map[int]string) { + if nBuf.buf.Len() == 0 { + return + } + paramsOnly := stripPrefixActionFromJSON(nBuf.buf.String(), nBuf.actionName) + if paramsOnly != "" { + send(StreamEvent{ + Event: "response.function_call_arguments.delta", + Data: FunctionCallArgumentsDeltaEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: next(), + ItemID: itemIDs[index], + OutputIndex: nBuf.outputIndex, + Delta: paramsOnly, + }, + }) + } +} + +// emitNestedDelta sends a function_call_arguments.delta for a nested namespace tool +// that has already emitted its output_item.added. +func emitNestedDelta(nBuf *nestedBufferState, delta string, send func(StreamEvent), next func() int64, index int, itemIDs map[int]string, outputIndexes map[int]int) { + cleanedDelta := stripPrefixActionFromJSON(delta, nBuf.actionName) + if cleanedDelta == "" { + return + } + oi := nBuf.outputIndex + if oi < 0 { + oi = outputIndexes[index] + } + send(StreamEvent{ + Event: "response.function_call_arguments.delta", + Data: FunctionCallArgumentsDeltaEvent{ + Type: "response.function_call_arguments.delta", + SequenceNumber: next(), + ItemID: itemIDs[index], + OutputIndex: oi, + Delta: cleanedDelta, + }, + }) +} + +// stripPrefixActionFromJSON removes the "action": "value" portion from the start +// of a partial JSON string. 
Uses a position-constrained scan that only looks for +// "action" as a top-level key (before any nested "{" or after the first "," that +// signals the end of the first key-value pair). Falls back to full JSON parse +// when the buffer is syntactically complete. +func stripPrefixActionFromJSON(raw string, action string) string { + if raw == "" { + return "" + } + + // First, try a full JSON parse — if the buffer is complete, this is the + // most robust path. + var parsed map[string]json.RawMessage + if err := json.Unmarshal([]byte(raw), &parsed); err == nil { + delete(parsed, "action") + if len(parsed) == 0 { + return "" + } + data, _ := json.Marshal(parsed) + result := string(data) + // Strip outer braces for streaming delta context. + result = strings.TrimPrefix(result, "{") + result = strings.TrimSuffix(result, "}") + return strings.TrimSpace(result) + } + + // Fallback: position-constrained scan. Only look for "action" in the + // first object level — roughly the content before a nested "{" or after + // the first top-level comma that follows the action key-value pair. + // + // Strategy: find "action" at the top level, extract its value, remove + // the key-value pair, and return the remaining JSON fragment. + idx := strings.Index(raw, `"action"`) + if idx < 0 { + return raw + } + + // Only treat as top-level action if it appears before any nested object + // (the namespace tool schema is flat — action is at the root). + firstBrace := strings.IndexByte(raw, '{') + if firstBrace >= 0 && firstBrace < idx { + // "action" is not in the first object — return raw unchanged. + return raw + } + + // Find the colon after the key. + afterKey := raw[idx+8:] + colonIdx := strings.IndexByte(afterKey, ':') + if colonIdx < 0 { + return raw + } + + // Skip whitespace and colon. + afterColon := strings.TrimSpace(afterKey[colonIdx+1:]) + if len(afterColon) == 0 { + return raw + } + + // Must start with a quote (action is always a string). 
+ if afterColon[0] != '"' { + return raw + } + afterColon = afterColon[1:] // skip opening quote + endQuote := strings.IndexByte(afterColon, '"') + if endQuote < 0 { + return raw + } + + // Extract the portion after the action value. + afterValue := strings.TrimSpace(afterColon[endQuote+1:]) + // Strip trailing comma. + afterValue = strings.TrimLeft(afterValue, ", ") + + // Combine the part before the action key with the part after the value. + prefix := strings.TrimRight(raw[:idx], ", ") + if prefix == "" || prefix == "{" || strings.TrimSpace(prefix) == "{" { + return afterValue + } + return prefix + ", " + afterValue +} func buildToolOutputItemStreaming(block *format.CoreContentBlock, extensions map[string]any, toolUseID string) OutputItem { toolMap := codextool.DecodeToolMapFromExtensions(extensions) itemT, itemN, itemNS, itemInput, isLS, actionJSON := codextool.OutputItemFromBlock(block.ToolName, block.ToolInput, toolMap) @@ -1523,7 +1822,7 @@ func buildToolOutputItemStreaming(block *format.CoreContentBlock, extensions map // Function/web_search/file_search/code_interpreter/computer_use_preview pass through. // Custom tools are expanded using codex package helpers. // Namespace tools are recursively flattened. -func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy func(string) bool) []format.CoreTool { +func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy func(string) bool, nsStrategy codextool.NamespaceStrategy) []format.CoreTool { name := namespacedToolName(namespace, tool.Name) ext := make(map[string]any) @@ -1573,7 +1872,23 @@ func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy fun case "namespace": ns := namespacedToolName(namespace, tool.Name) - return flattenToolsWithNamespace(tool.Tools, ns, disablePatchProxy) + // Build sub-tool map for BuildNamespaceTools. 
+ subMap := make(map[string]format.CoreTool) + var subNames []string + for _, sub := range tool.Tools { + subNames = append(subNames, sub.Name) + subMap[sub.Name] = format.CoreTool{ + Name: sub.Name, + Description: sub.Description, + InputSchema: sub.Parameters, + } + } + tools, err := codextool.BuildNamespaceTools(subNames, subMap, ns, nsStrategy) + if err != nil || len(tools) == 0 { + // Fallback to flat expansion. + return flattenToolsWithNamespace(tool.Tools, ns, disablePatchProxy, nsStrategy) + } + return tools case "custom": grammar := codextool.CustomToolGrammar(tool.Format) @@ -1627,22 +1942,28 @@ func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy fun // flattenToolsWithNamespace recursively flattens namespace tools and converts // individual tools, building a flat list of CoreTools suitable for upstream providers. -func flattenToolsWithNamespace(openaiTools []Tool, namespace string, disablePatchProxy func(string) bool) []format.CoreTool { +func flattenToolsWithNamespace(openaiTools []Tool, namespace string, disablePatchProxy func(string) bool, nsStrategy codextool.NamespaceStrategy) []format.CoreTool { var result []format.CoreTool for _, t := range openaiTools { - converted := convertToolWithNamespace(t, namespace, disablePatchProxy) + converted := convertToolWithNamespace(t, namespace, disablePatchProxy, nsStrategy) result = append(result, converted...) } // Deduplicate by name: Codex may send the same tool both as a namespace member // and as an independently-injected function tool (e.g. MCP tools that inject themselves - // after first use). Keep the first occurrence, which has the correct metadata. - seen := make(map[string]bool, len(result)) + // after first use). Prefer tools with a codex_namespace annotation (comes from namespace + // expansion) over flat function tools with the same name. 
+ seen := make(map[string]int, len(result)) // name → index in deduped deduped := make([]format.CoreTool, 0, len(result)) for _, t := range result { - if seen[t.Name] { + if existing, exists := seen[t.Name]; exists { + existingNS, _ := deduped[existing].Extensions["codex_namespace"].(string) + newNS, _ := t.Extensions["codex_namespace"].(string) + if existingNS == "" && newNS != "" { + deduped[existing] = t + } continue } - seen[t.Name] = true + seen[t.Name] = len(deduped) deduped = append(deduped, t) } result = deduped diff --git a/internal/protocol/openai/adapter_test.go b/internal/protocol/openai/adapter_test.go index 0bcca6d5..be3e9b3f 100644 --- a/internal/protocol/openai/adapter_test.go +++ b/internal/protocol/openai/adapter_test.go @@ -294,8 +294,8 @@ func TestToCoreRequest_BatchesCustomToolCallsAndOutputsIntoSingleRound(t *testin for i, want := range []struct { assistantTextIdx int msgIdx int - callID string - outcome string + callID string + outcome string }{ {0, 1, "call_a", "ok a"}, {3, 4, "call_b", "ok b"}, @@ -345,7 +345,13 @@ func TestFromCoreStream_NoDuplicateDoneForToolUse(t *testing.T) { if err != nil { t.Fatal(err) } - stream := streamAny.(<-chan openai.StreamEvent) + var stream <-chan openai.StreamEvent + oaiResult, ok := streamAny.(*openai.OpenAIStreamResult) + if ok { + stream = oaiResult.Chan() + } else { + stream = streamAny.(<-chan openai.StreamEvent) + } var argsDone int var itemDone int for ev := range stream { diff --git a/internal/protocol/openai/types.go b/internal/protocol/openai/types.go index 24ba63c7..6f82371a 100644 --- a/internal/protocol/openai/types.go +++ b/internal/protocol/openai/types.go @@ -116,10 +116,10 @@ type ContentPart struct { // Usage represents token usage statistics. 
type Usage struct { - InputTokens int `json:"input_tokens,omitempty"` - OutputTokens int `json:"output_tokens,omitempty"` - TotalTokens int `json:"total_tokens"` - InputTokensDetails InputTokensDetails `json:"input_tokens_details,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + TotalTokens int `json:"total_tokens"` + InputTokensDetails InputTokensDetails `json:"input_tokens_details,omitempty"` OutputTokensDetails OutputTokensDetails `json:"output_tokens_details,omitempty"` } diff --git a/internal/service/api/api_test.go b/internal/service/api/api_test.go index ce5ef004..dff58be3 100644 --- a/internal/service/api/api_test.go +++ b/internal/service/api/api_test.go @@ -130,9 +130,9 @@ func buildTableNameMap(tables []db.TableSpec) map[string]string { // testAPIDB implements db.Store backed by an in-memory SQLite database for API tests. type testAPIDB struct { - t *testing.T - db *sql.DB - consumer string + t *testing.T + db *sql.DB + consumer string tableNames map[string]string } diff --git a/internal/service/api/models.go b/internal/service/api/models.go index fd196f1b..a7ffcd53 100644 --- a/internal/service/api/models.go +++ b/internal/service/api/models.go @@ -98,13 +98,13 @@ func (r *Router) handleGetModel(w http.ResponseWriter, req *http.Request) { } resp := map[string]any{ - "slug": slug, - "display_name": def.DisplayName, - "description": def.Description, - "context_window": def.ContextWindow, - "max_output_tokens": def.MaxOutputTokens, - "input_modalities": def.InputModalities, - "providers": providers, + "slug": slug, + "display_name": def.DisplayName, + "description": def.Description, + "context_window": def.ContextWindow, + "max_output_tokens": def.MaxOutputTokens, + "input_modalities": def.InputModalities, + "providers": providers, } respondJSON(w, http.StatusOK, resp) @@ -132,10 +132,10 @@ func (r *Router) handlePutModel(w http.ResponseWriter, req *http.Request) { // Build metadata JSON. 
meta := map[string]any{ - "display_name": body.DisplayName, - "description": body.Description, - "context_window": body.ContextWindow, - "max_output_tokens": body.MaxOutputTokens, + "display_name": body.DisplayName, + "description": body.Description, + "context_window": body.ContextWindow, + "max_output_tokens": body.MaxOutputTokens, } metaJSON, _ := json.Marshal(meta) diff --git a/internal/service/api/router.go b/internal/service/api/router.go index 500f31b1..5d5c85aa 100644 --- a/internal/service/api/router.go +++ b/internal/service/api/router.go @@ -55,14 +55,14 @@ func NewRouter(cfg ConfigStore, rt *runtime.Runtime, st *stats.SessionStats, reg mux := http.NewServeMux() // Auth middleware for all API routes. - authMW := AuthMiddleware(func() string { - return r.server.CurrentConfig().AuthToken() - }, func() bool { return r.store != nil }) + authMW := AuthMiddleware(func() string { + return r.server.CurrentConfig().AuthToken() + }, func() bool { return r.store != nil }) - // Register all routes using Go 1.22+ pattern matching. - registerRoutes(mux, r) + // Register all routes using Go 1.22+ pattern matching. + registerRoutes(mux, r) - return authMW(mux) + return authMW(mux) } // registerRoutes registers all API endpoints with the mux. 
diff --git a/internal/service/api/settings.go b/internal/service/api/settings.go index f7dcf333..11ab3c99 100644 --- a/internal/service/api/settings.go +++ b/internal/service/api/settings.go @@ -66,11 +66,11 @@ func (r *Router) handleGetWebSearch(w http.ResponseWriter, req *http.Request) { cfg := r.runtime.Current() resp := map[string]any{ - "support": string(cfg.Config.WebSearchSupport), - "max_uses": cfg.Config.WebSearchMaxUses, - "tavily_api_key": maskAPIKey(cfg.Config.TavilyAPIKey), - "firecrawl_api_key": maskAPIKey(cfg.Config.FirecrawlAPIKey), - "search_max_rounds": cfg.Config.SearchMaxRounds, + "support": string(cfg.Config.WebSearchSupport), + "max_uses": cfg.Config.WebSearchMaxUses, + "tavily_api_key": maskAPIKey(cfg.Config.TavilyAPIKey), + "firecrawl_api_key": maskAPIKey(cfg.Config.FirecrawlAPIKey), + "search_max_rounds": cfg.Config.SearchMaxRounds, } respondJSON(w, http.StatusOK, resp) @@ -103,11 +103,11 @@ func (r *Router) handlePutWebSearch(w http.ResponseWriter, req *http.Request) { } wsJSON, _ := json.Marshal(map[string]any{ - "support": body.Support, - "max_uses": body.MaxUses, - "tavily_api_key": tavilyKey, - "firecrawl_api_key": firecrawlKey, - "search_max_rounds": body.SearchMaxRounds, + "support": body.Support, + "max_uses": body.MaxUses, + "tavily_api_key": tavilyKey, + "firecrawl_api_key": firecrawlKey, + "search_max_rounds": body.SearchMaxRounds, }) chID, err := r.store.StageChange(store.ChangeRow{ @@ -304,10 +304,10 @@ func (r *Router) handlePostConfigImport(w http.ResponseWriter, req *http.Request for slug, def := range cfg.Models { meta := map[string]any{ - "display_name": def.DisplayName, - "description": def.Description, - "context_window": def.ContextWindow, - "max_output_tokens": def.MaxOutputTokens, + "display_name": def.DisplayName, + "description": def.Description, + "context_window": def.ContextWindow, + "max_output_tokens": def.MaxOutputTokens, } metaJSON, _ := json.Marshal(meta) afterJSON, _ := json.Marshal(map[string]any{ @@ 
-381,11 +381,11 @@ func (r *Router) handlePostConfigImport(w http.ResponseWriter, req *http.Request // Stage web_search if set. if cfg.WebSearchSupport != "" || cfg.TavilyAPIKey != "" || cfg.FirecrawlAPIKey != "" { wsJSON, _ := json.Marshal(map[string]any{ - "support": string(cfg.WebSearchSupport), - "max_uses": cfg.WebSearchMaxUses, - "tavily_api_key": cfg.TavilyAPIKey, - "firecrawl_api_key": cfg.FirecrawlAPIKey, - "search_max_rounds": cfg.SearchMaxRounds, + "support": string(cfg.WebSearchSupport), + "max_uses": cfg.WebSearchMaxUses, + "tavily_api_key": cfg.TavilyAPIKey, + "firecrawl_api_key": cfg.FirecrawlAPIKey, + "search_max_rounds": cfg.SearchMaxRounds, }) chID, err := r.store.StageChange(store.ChangeRow{ Action: "update", diff --git a/internal/service/api/status.go b/internal/service/api/status.go index 8ec967a3..f5e43cc1 100644 --- a/internal/service/api/status.go +++ b/internal/service/api/status.go @@ -4,7 +4,6 @@ import ( "net/http" "sort" "time" - ) // ---- Status ---- diff --git a/internal/service/app/app.go b/internal/service/app/app.go index bccf2b6a..b4ed6775 100644 --- a/internal/service/app/app.go +++ b/internal/service/app/app.go @@ -12,6 +12,7 @@ import ( "log/slog" "moonbridge/internal/config" "moonbridge/internal/db" + "moonbridge/internal/extension/codextool" "moonbridge/internal/format" "moonbridge/internal/logger" "moonbridge/internal/protocol/anthropic" @@ -205,7 +206,7 @@ func runTransform(ctx context.Context, cfg config.Config, errors io.Writer) erro coreHooks := plugins.CorePluginHooks() // Inbound: OpenAI Responses client adapter. 
- oaiAdapter := openai.NewOpenAIAdapter(coreHooks) + oaiAdapter := openai.NewOpenAIAdapter(coreHooks, codextool.NestedOneOf) _ = adapterReg.RegisterClient(oaiAdapter) _ = adapterReg.RegisterClientStream(oaiAdapter) diff --git a/internal/service/app/extensions.go b/internal/service/app/extensions.go index 423691b4..12782aad 100644 --- a/internal/service/app/extensions.go +++ b/internal/service/app/extensions.go @@ -4,15 +4,15 @@ import ( "database/sql" "log/slog" + "moonbridge/internal/config" + codextoolproxy "moonbridge/internal/extension/codex_tool_proxy" dbd1 "moonbridge/internal/extension/db/d1" dbsqlite "moonbridge/internal/extension/db/sqlite" deepseekv4 "moonbridge/internal/extension/deepseek_v4" kimiworkaround "moonbridge/internal/extension/kimi_workaround" mbtrics "moonbridge/internal/extension/metrics" - codextoolproxy "moonbridge/internal/extension/codex_tool_proxy" "moonbridge/internal/extension/plugin" "moonbridge/internal/extension/visual" - "moonbridge/internal/config" ) // ExtensionOptions controls optional initialization of built-in plugins. diff --git a/internal/service/provider/manager.go b/internal/service/provider/manager.go index 7076c3d1..949bac1c 100644 --- a/internal/service/provider/manager.go +++ b/internal/service/provider/manager.go @@ -135,7 +135,7 @@ func NewAnthropicClientAdapter(client *anthropic.Client) ProviderClient { } type ProviderManager struct { - mu sync.Mutex // guards field replacement during Reload + mu sync.RWMutex // guards field replacement during Reload clients map[string]ProviderClient providers map[string]ProviderConfig // provider key -> config (for inspection) routes map[string]ModelRoute // model alias -> route @@ -276,6 +276,8 @@ func (pm *ProviderManager) Reload(cfg config.ProviderConfig) error { // It returns the default provider if the alias is not explicitly routed. func (pm *ProviderManager) ClientFor(modelAlias string) (string, ProviderClient, error) { // Direct provider/model reference. 
+ pm.mu.RLock() + defer pm.mu.RUnlock() if provider, upstream := ParseModelRef(modelAlias); provider != "" { if client, ok := pm.clients[provider]; ok { return upstream, client, nil @@ -312,6 +314,8 @@ func (pm *ProviderManager) ClientFor(modelAlias string) (string, ProviderClient, // Returns error if no candidates are found. func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error) { // 1. Route alias (highest priority) + pm.mu.RLock() + defer pm.mu.RUnlock() if route, ok := pm.routes[modelName]; ok { providerKey := route.Provider if providerKey == "" { @@ -325,7 +329,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error Candidates: []ProviderCandidate{{ ProviderKey: providerKey, UpstreamModel: route.Name, - Protocol: pm.ProtocolForKey(providerKey), + Protocol: pm.protocolForKeyInline(providerKey), Client: client, }}, }, nil @@ -341,7 +345,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error Candidates: []ProviderCandidate{{ ProviderKey: providerKey, UpstreamModel: upstreamModel, - Protocol: pm.ProtocolForKey(providerKey), + Protocol: pm.protocolForKeyInline(providerKey), Client: client, }}, }, nil @@ -369,7 +373,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error candidates = append(candidates, ProviderCandidate{ ProviderKey: entry.providerKey, UpstreamModel: modelName, - Protocol: pm.ProtocolForKey(entry.providerKey), + Protocol: pm.protocolForKeyInline(entry.providerKey), Client: client, }) } @@ -417,6 +421,8 @@ func (pm *ProviderManager) ProbeWebSearchCandidate(ctx context.Context, provider // ProviderKeys returns all configured provider keys. 
func (pm *ProviderManager) ProviderKeys() []string { + pm.mu.RLock() + defer pm.mu.RUnlock() keys := make([]string, 0, len(pm.clients)) for k := range pm.clients { keys = append(keys, k) @@ -426,6 +432,8 @@ func (pm *ProviderManager) ProviderKeys() []string { // DefaultKey returns the default provider key. func (pm *ProviderManager) DefaultKey() string { + pm.mu.RLock() + defer pm.mu.RUnlock() return pm.defaultK } @@ -445,10 +453,11 @@ func newHTTPClient(cfg HTTPConfig) *http.Client { return &http.Client{ Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: maxIdle, - IdleConnTimeout: idleTimeout, - DisableCompression: false, + MaxIdleConns: 100, + MaxIdleConnsPerHost: maxIdle, + IdleConnTimeout: idleTimeout, + DisableCompression: false, + ResponseHeaderTimeout: 30 * time.Second, }, } } @@ -462,6 +471,8 @@ func valueOrDefault(value, fallback string) string { // ClientForKey returns the anthropic.Client for a given provider key. func (pm *ProviderManager) ClientForKey(key string) (ProviderClient, error) { + pm.mu.RLock() + defer pm.mu.RUnlock() client, ok := pm.clients[key] if !ok { return nil, fmt.Errorf("provider %q not found", key) @@ -471,7 +482,25 @@ func (pm *ProviderManager) ClientForKey(key string) (ProviderClient, error) { // ProtocolForKey returns the protocol for a given provider key. // Returns "anthropic" if not configured. +// protocolForKeyInline returns the protocol for a provider key. +// Caller must hold pm.mu (read lock). 
+func (pm *ProviderManager) protocolForKeyInline(key string) string { + if pm.providers == nil { + return "anthropic" + } + cfg, ok := pm.providers[key] + if !ok { + return "anthropic" + } + if cfg.Protocol == "" { + return "anthropic" + } + return cfg.Protocol +} + func (pm *ProviderManager) ProtocolForKey(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() if pm.providers == nil { return "anthropic" } @@ -488,24 +517,28 @@ func (pm *ProviderManager) ProtocolForKey(key string) string { // ProtocolForModel returns the protocol for the provider serving the given model alias. // Returns "anthropic" if the model is not explicitly routed. func (pm *ProviderManager) ProtocolForModel(modelAlias string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() // Direct provider/model reference. if provider, _ := ParseModelRef(modelAlias); provider != "" { - return pm.ProtocolForKey(provider) + return pm.protocolForKeyInline(provider) } route, ok := pm.routes[modelAlias] if !ok { - return pm.ProtocolForKey(pm.defaultK) + return pm.protocolForKeyInline(pm.defaultK) } providerKey := route.Provider if providerKey == "" { providerKey = pm.defaultK } - return pm.ProtocolForKey(providerKey) + return pm.protocolForKeyInline(providerKey) } // UpstreamModelFor returns the upstream model name for a model alias. func (pm *ProviderManager) UpstreamModelFor(modelAlias string) string { // Direct provider/model reference. + pm.mu.RLock() + defer pm.mu.RUnlock() if provider, upstream := ParseModelRef(modelAlias); provider != "" { if _, ok := pm.clients[provider]; ok { return upstream @@ -520,6 +553,8 @@ func (pm *ProviderManager) UpstreamModelFor(modelAlias string) string { // ProviderBaseURL returns the base URL for a given provider key. 
func (pm *ProviderManager) ProviderBaseURL(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() cfg, ok := pm.providers[key] if !ok { return "" @@ -529,6 +564,8 @@ func (pm *ProviderManager) ProviderBaseURL(key string) string { // ProviderAPIKey returns the API key for a given provider key. func (pm *ProviderManager) ProviderAPIKey(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() cfg, ok := pm.providers[key] if !ok { return "" @@ -540,6 +577,8 @@ func (pm *ProviderManager) ProviderAPIKey(key string) string { // Falls back to defaultK when the model has no explicit route. func (pm *ProviderManager) ProviderKeyForModel(modelAlias string) string { // Direct provider/model reference. + pm.mu.RLock() + defer pm.mu.RUnlock() if provider, _ := ParseModelRef(modelAlias); provider != "" { if _, ok := pm.clients[provider]; ok { return provider @@ -552,6 +591,21 @@ func (pm *ProviderManager) ProviderKeyForModel(modelAlias string) string { return route.Provider } +// providerKeyForModelInline returns the provider key for a model alias. +// Caller must hold pm.mu (read lock). +func (pm *ProviderManager) providerKeyForModelInline(modelAlias string) string { + if provider, _ := modelref.Parse(modelAlias); provider != "" { + if _, ok := pm.clients[provider]; ok { + return provider + } + } + route, ok := pm.routes[modelAlias] + if !ok || route.Provider == "" { + return pm.defaultK + } + return route.Provider +} + // WebSearchCandidateKey returns the runtime key for a resolved provider/model pair. func WebSearchCandidateKey(providerKey, upstreamModel string) string { return "candidate:" + providerKey + "/" + upstreamModel @@ -566,11 +620,15 @@ func (pm *ProviderManager) SetResolvedWebSearch(key string, support string) { // ResolvedWebSearch returns the resolved web search support for a provider key. // Returns empty string if not yet resolved. 
func (pm *ProviderManager) ResolvedWebSearch(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() return pm.resolvedWS[key] } // ModelMetaFor returns the ModelMeta for a model name within a specific provider. func (pm *ProviderManager) ModelMetaFor(modelName string, providerKey string) (ModelMeta, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() cfg, ok := pm.providers[providerKey] if !ok { return ModelMeta{}, false @@ -581,6 +639,8 @@ func (pm *ProviderManager) ModelMetaFor(modelName string, providerKey string) (M // ProviderDefForKey returns the full ProviderConfig for a given provider key. func (pm *ProviderManager) ProviderDefForKey(key string) (ProviderConfig, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() cfg, ok := pm.providers[key] if !ok { return ProviderConfig{}, false @@ -592,21 +652,47 @@ func (pm *ProviderManager) ProviderDefForKey(key string) (ProviderConfig, bool) // Checks model-level first, then falls back to provider-level. func (pm *ProviderManager) ResolvedWebSearchForModel(modelAlias string) string { // Check model-level resolution first. + pm.mu.RLock() + defer pm.mu.RUnlock() if v, ok := pm.resolvedWS["model:"+modelAlias]; ok { return v } - if providerKey, upstreamModel, ok := pm.ProviderAndUpstreamForModel(modelAlias); ok { + var providerKey, upstreamModel string + ok := false + if provider, upstream := modelref.Parse(modelAlias); provider != "" { + if _, exists := pm.clients[provider]; exists { + providerKey, upstreamModel, ok = provider, upstream, true + } + } + if !ok { + if route, exists := pm.routes[modelAlias]; exists { + if route.Provider == "" { + providerKey = pm.defaultK + } else { + providerKey = route.Provider + } + upstreamModel, ok = route.Name, true + } + } + if !ok { + if pm.defaultK != "" { + providerKey, upstreamModel, ok = pm.defaultK, modelAlias, true + } + } + if ok { if v, ok := pm.resolvedWS[WebSearchCandidateKey(providerKey, upstreamModel)]; ok { return v } } // Fall back to provider-level. 
- return pm.resolvedWS[pm.ProviderKeyForModel(modelAlias)] + return pm.resolvedWS[pm.providerKeyForModelInline(modelAlias)] } // ResolvedWebSearchForCandidate returns the resolved web search support for a provider/model pair. // Falls back to provider-level support when no candidate-specific resolution exists. func (pm *ProviderManager) ResolvedWebSearchForCandidate(providerKey, upstreamModel string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() if providerKey == "" { return "" } @@ -620,6 +706,8 @@ func (pm *ProviderManager) ResolvedWebSearchForCandidate(providerKey, upstreamMo // WebSearchConfigForKey returns the configured web search support for a provider key. func (pm *ProviderManager) WebSearchConfigForKey(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() cfg, ok := pm.providers[key] if !ok { return "" @@ -631,6 +719,8 @@ func (pm *ProviderManager) WebSearchConfigForKey(key string) string { // alias routed to the given provider key. Falls back to the provider's own // model list when no route alias references it. Returns empty string if none found. func (pm *ProviderManager) FirstUpstreamModelForKey(key string) string { + pm.mu.RLock() + defer pm.mu.RUnlock() for _, route := range pm.routes { pk := route.Provider if pk == "" { @@ -651,6 +741,8 @@ func (pm *ProviderManager) FirstUpstreamModelForKey(key string) string { // ProviderAndUpstreamForModel resolves the provider key and upstream model for a model alias. func (pm *ProviderManager) ProviderAndUpstreamForModel(modelAlias string) (providerKey string, upstreamModel string, ok bool) { // Direct provider/model reference. 
+ pm.mu.RLock() + defer pm.mu.RUnlock() if provider, upstream := ParseModelRef(modelAlias); provider != "" { if _, exists := pm.clients[provider]; exists { return provider, upstream, true diff --git a/internal/service/runtime/runtime_test.go b/internal/service/runtime/runtime_test.go index e2982726..1d011727 100644 --- a/internal/service/runtime/runtime_test.go +++ b/internal/service/runtime/runtime_test.go @@ -188,7 +188,7 @@ func TestReloadWithInvalidConfigReturnsError(t *testing.T) { // Reload with invalid mode (empty mode fails validation). invalidCfg := config.Config{ - Mode: config.Mode(""), + Mode: config.Mode(""), Cache: config.CacheConfig{Mode: "off"}, } diff --git a/internal/service/server/adapter_core_provider_test.go b/internal/service/server/adapter_core_provider_test.go index 50970904..24488565 100644 --- a/internal/service/server/adapter_core_provider_test.go +++ b/internal/service/server/adapter_core_provider_test.go @@ -27,7 +27,7 @@ func TestCoreResponseToStreamEventsEmitsTextAndUsage(t *testing.T) { Usage: format.CoreUsage{InputTokens: 11, OutputTokens: 7, CachedInputTokens: 3}, } - events := collectCoreStreamEvents(coreResponseToStreamEvents(resp)) + events := collectCoreStreamEvents(coreResponseToStreamEvents(context.Background(), resp)) if len(events) == 0 { t.Fatal("no stream events emitted") } diff --git a/internal/service/server/adapter_dispatch.go b/internal/service/server/adapter_dispatch.go index d3e65a8d..c3aa5c36 100644 --- a/internal/service/server/adapter_dispatch.go +++ b/internal/service/server/adapter_dispatch.go @@ -518,6 +518,29 @@ func (s *Server) handleWithAdapters( } record.UpstreamRequest = googleReq + // Wrap with visual orchestrator if enabled for this model. 
+ googlePreferred := preferred + googlePreferred.Client = &googleProviderClient{c: googleClient, model: googlePreferred.UpstreamModel} + if visProv := s.wrapWithVisual(ctx, openAIReq.Model, googlePreferred, providerAdapter, nil); visProv != nil { + var visErr error + coreResp, visErr = visProv.CreateCore(ctx, coreReq) + if visErr != nil { + log.Error("adapter path: google visual CreateCore failed", "error", visErr) + payload := openai.ErrorResponse{ + Error: openai.ErrorObject{ + Message: fmt.Sprintf("google visual orchestration failed: %v", visErr), + Type: "server_error", + Code: "provider_error", + }, + } + record.Error = traceError("google_visual_core", visErr) + record.OpenAIResponse = payload + writeOpenAIError(w, http.StatusBadGateway, payload) + return + } + break + } + var googleResp *google.GenerateContentResponse if wsInjected { googleResp, err = s.executeGoogleSearchLoop(ctx, googleClient, preferred.UpstreamModel, googleReq, searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds) @@ -903,6 +926,9 @@ func (s *Server) handleAdapterStream( // Protocol-specific upstream streaming: get stream + convert to CoreStreamEvent. 
var coreEvents <-chan format.CoreStreamEvent var providerStream format.ProviderStreamAdapter + var sr *format.StreamResult // result from ToCoreStream, captures events + buffer + var providerBuf func() []any // per-request provider stream buffer (from ToCoreStream StreamResult) + var clientBuf func() []any // per-request client stream buffer (from OpenAI FromCoreStream) switch candidate.Protocol { case config.ProtocolAnthropic: @@ -1037,7 +1063,7 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } - coreEvents, err = providerStream.ToCoreStream(ctx, stream) + sr, err = providerStream.ToCoreStream(ctx, stream) if err != nil { log.Error("adapter stream: ToCoreStream failed", "error", err) payload := openai.ErrorResponse{ @@ -1052,6 +1078,10 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } + coreEvents = sr.Events + if sr.StreamBuffer != nil { + providerBuf = sr.StreamBuffer + } } case config.ProtocolOpenAIChat: @@ -1124,6 +1154,52 @@ func (s *Server) handleAdapterStream( streamRecord.ChatRequest = chatReq var chatStream <-chan chat.ChatStreamChunk var err error + + providerAdapter, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat) + if !ok { + log.Warn("adapter stream: no chat provider adapter for visual path") + } + + // Visual orchestrator for streaming path: non-streaming orchestration + // → synthetic stream events, matching the anthropic streaming pattern. 
+ if s.pluginRegistry != nil && s.runtime != nil && openAIReq.Model != "" && ok && providerAdapter != nil { + cfgV := s.runtime.Current().Config + visCfg, visOk := visualpkg.ConfigForModelFromResolvedConfig(cfgV, openAIReq.Model) + if visOk && visCfg.Provider != "" && visCfg.Model != "" { + finalizeUpstream := func(_ context.Context, upstream any) (any, error) { + req, ok := upstream.(*chat.ChatRequest) + if !ok { + return nil, fmt.Errorf("chat visual finalize: expected *chat.ChatRequest, got %T", upstream) + } + if s.pluginRegistry != nil && sess != nil { + prependCachedReasoningForChat(req, sess) + } + return req, nil + } + visCandidate := candidate + visCandidate.Client = &chatProviderClient{c: chatClient} + if visProv := s.wrapWithVisual(ctx, openAIReq.Model, visCandidate, providerAdapter, finalizeUpstream); visProv != nil { + coreResp, visErr := visProv.CreateCore(ctx, coreReq) + if visErr != nil { + log.Error("adapter stream: chat visual CreateCore failed", "error", visErr) + payload := openai.ErrorResponse{ + Error: openai.ErrorObject{ + Message: fmt.Sprintf("chat visual orchestration failed: %v", visErr), + Type: "server_error", + Code: "provider_error", + }, + } + streamRecord.Error = traceError("stream_chat_visual", visErr) + streamRecord.OpenAIResponse = payload + writeOpenAIError(w, http.StatusBadGateway, payload) + return + } + coreEvents = coreResponseToCoreStream(ctx, coreResp) + break + } + } + } + if wsInjected { searchCfg := s.resolvedSearchConfig(candidate.ProviderKey, openAIReq.Model) chatStream, err = s.chatSearchBufferedStream(ctx, chatClient, chatReq, searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds) @@ -1160,7 +1236,7 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } - coreEvents, err = providerStream.ToCoreStream(ctx, chatStream) + sr, err = providerStream.ToCoreStream(ctx, chatStream) if err != nil { log.Error("adapter stream: Chat ToCoreStream failed", "error", 
err) payload := openai.ErrorResponse{ @@ -1175,6 +1251,10 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } + coreEvents = sr.Events + if sr.StreamBuffer != nil { + providerBuf = sr.StreamBuffer + } case config.ProtocolGoogleGenAI: googleReq, ok := upstreamReq.(*google.GenerateContentRequest) @@ -1225,6 +1305,38 @@ func (s *Server) handleAdapterStream( } streamRecord.UpstreamRequest = googleReq + + // Visual orchestrator for streaming path: non-streaming orchestration + // → synthetic stream events, matching the anthropic/chat streaming pattern. + providerAdapter, ok := s.adapterRegistry.GetProvider(config.ProtocolGoogleGenAI) + if ok && providerAdapter != nil && s.runtime != nil && openAIReq.Model != "" { + cfgV := s.runtime.Current().Config + visCfg, visOk := visualpkg.ConfigForModelFromResolvedConfig(cfgV, openAIReq.Model) + if visOk && visCfg.Provider != "" && visCfg.Model != "" { + visCandidate := candidate + visCandidate.Client = &googleProviderClient{c: googleClient, model: candidate.UpstreamModel} + if visProv := s.wrapWithVisual(ctx, openAIReq.Model, visCandidate, providerAdapter, nil); visProv != nil { + coreResp, visErr := visProv.CreateCore(ctx, coreReq) + if visErr != nil { + log.Error("adapter stream: google visual CreateCore failed", "error", visErr) + payload := openai.ErrorResponse{ + Error: openai.ErrorObject{ + Message: fmt.Sprintf("google visual orchestration failed: %v", visErr), + Type: "server_error", + Code: "provider_error", + }, + } + streamRecord.Error = traceError("stream_google_visual", visErr) + streamRecord.OpenAIResponse = payload + writeOpenAIError(w, http.StatusBadGateway, payload) + return + } + coreEvents = coreResponseToCoreStream(ctx, coreResp) + break + } + } + } + if wsInjected { searchCfg := s.resolvedSearchConfig(candidate.ProviderKey, openAIReq.Model) googleResp, err := s.executeGoogleSearchLoop(ctx, googleClient, candidate.UpstreamModel, googleReq, 
searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds) @@ -1324,7 +1436,7 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } - coreEvents, err = providerStream.ToCoreStream(ctx, googleStream) + sr, err = providerStream.ToCoreStream(ctx, googleStream) if err != nil { log.Error("adapter stream: Google ToCoreStream failed", "error", err) payload := openai.ErrorResponse{ @@ -1339,6 +1451,10 @@ func (s *Server) handleAdapterStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } + coreEvents = sr.Events + if sr.StreamBuffer != nil { + providerBuf = sr.StreamBuffer + } default: log.Error("adapter stream: unsupported protocol", "protocol", candidate.Protocol) @@ -1389,8 +1505,13 @@ func (s *Server) handleAdapterStream( return } - streamChan, ok := streamChanAny.(<-chan openai.StreamEvent) - if !ok { + var streamChan <-chan openai.StreamEvent + if oaiResult, ok := streamChanAny.(*openai.OpenAIStreamResult); ok { + streamChan = oaiResult.Chan() + clientBuf = oaiResult.Buffer + } else if ch, ok := streamChanAny.(<-chan openai.StreamEvent); ok { + streamChan = ch + } else { log.Error("adapter stream: unexpected stream channel type", "type", fmt.Sprintf("%T", streamChanAny)) payload := openai.ErrorResponse{ Error: openai.ErrorObject{ @@ -1435,8 +1556,15 @@ func (s *Server) handleAdapterStream( remembered := rememberStreamResponseContent(s.pluginRegistry, sess, openAIReq.Model, finalResp) if !remembered { if anthProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolAnthropic); ok { - if anthAdapter, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok { - events := anthAdapter.StreamBuffer() + if _, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok { + var events []anthropic.StreamEvent + if providerBuf != nil { + raw := providerBuf() + events = make([]anthropic.StreamEvent, len(raw)) + for i, r := range raw { + events[i], _ = r.(anthropic.StreamEvent) + } + } 
if len(events) > 0 { states := s.pluginRegistry.NewStreamStates(openAIReq.Model) for _, ev := range events { @@ -1474,8 +1602,19 @@ func (s *Server) handleAdapterStream( // This must not depend on trace being enabled. if sess != nil { if chatProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat); ok { - if chatAdapter, ok := chatProvider.(*chat.ChatProviderAdapter); ok { - if events := chatAdapter.StreamBuffer(); len(events) > 0 { + if _, ok := chatProvider.(*chat.ChatProviderAdapter); ok { + var chatEvents []chat.ChatStreamChunk + if providerBuf != nil { + raw := providerBuf() + chatEvents = make([]chat.ChatStreamChunk, 0, len(raw)) + for _, r := range raw { + if ev, ok := r.(chat.ChatStreamChunk); ok { + chatEvents = append(chatEvents, ev) + } + } + } + events := chatEvents + if len(events) > 0 { var streamReasoning string seenToolCallIDs := make(map[string]struct{}) streamToolCallIDs := make([]string, 0, 4) @@ -1506,30 +1645,40 @@ func (s *Server) handleAdapterStream( // Capture stream events for trace. 
if s.tracer != nil && s.tracer.Enabled() { - // OpenAI stream events from client adapter - if oaiClient, ok := s.adapterRegistry.GetClient(config.ProtocolOpenAIResponse); ok { - if oaiAdapter, ok := oaiClient.(*openai.OpenAIAdapter); ok { - if events := oaiAdapter.StreamBuffer(); len(events) > 0 { - streamRecord.OpenAIStreamEvents = events + // Provider stream buffer (anthropic/chat/google raw events) + if providerBuf != nil { + raw := providerBuf() + var anthBuf []anthropic.StreamEvent + for _, r := range raw { + if ev, ok := r.(anthropic.StreamEvent); ok { + anthBuf = append(anthBuf, ev) } } - } - // Anthropic stream events from provider adapter - if anthProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolAnthropic); ok { - if anthAdapter, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok { - if events := anthAdapter.StreamBuffer(); len(events) > 0 { - streamRecord.AnthropicStreamEvents = events + if len(anthBuf) > 0 { + streamRecord.AnthropicStreamEvents = anthBuf + } + var chatBuf []chat.ChatStreamChunk + for _, r := range raw { + if ev, ok := r.(chat.ChatStreamChunk); ok { + chatBuf = append(chatBuf, ev) } - - // Chat stream events from provider adapter - if chatProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat); ok { - if chatAdapter, ok := chatProvider.(*chat.ChatProviderAdapter); ok { - if events := chatAdapter.StreamBuffer(); len(events) > 0 { - streamRecord.ChatStreamEvents = events - } - } + } + if len(chatBuf) > 0 { + streamRecord.ChatStreamEvents = chatBuf + } + } + // Client stream buffer (OpenAI stream events) + if clientBuf != nil { + raw := clientBuf() + var openAIBuf []openai.StreamEvent + for _, r := range raw { + if ev, ok := r.(openai.StreamEvent); ok { + openAIBuf = append(openAIBuf, ev) } } + if len(openAIBuf) > 0 { + streamRecord.OpenAIStreamEvents = openAIBuf + } } } if s.stats != nil && (finalUsage.InputTokens > 0 || finalUsage.OutputTokens > 0) { @@ -1635,7 +1784,7 @@ func (s *Server) 
writeCoreResponseAsOpenAIStream( return } - streamChanAny, err := clientStream.FromCoreStream(ctx, coreReq, coreResponseToStreamEvents(coreResp)) + streamChanAny, err := clientStream.FromCoreStream(ctx, coreReq, coreResponseToStreamEvents(ctx, coreResp)) if err != nil { payload := openai.ErrorResponse{ Error: openai.ErrorObject{ @@ -1649,8 +1798,12 @@ func (s *Server) writeCoreResponseAsOpenAIStream( writeOpenAIError(w, http.StatusInternalServerError, payload) return } - streamChan, ok := streamChanAny.(<-chan openai.StreamEvent) - if !ok { + var streamChan <-chan openai.StreamEvent + if oaiResult, ok := streamChanAny.(*openai.OpenAIStreamResult); ok { + streamChan = oaiResult.Chan() + } else if ch, ok := streamChanAny.(<-chan openai.StreamEvent); ok { + streamChan = ch + } else { payload := openai.ErrorResponse{ Error: openai.ErrorObject{ Message: "unexpected stream channel type", @@ -1710,21 +1863,33 @@ func (s *Server) writeCoreResponseAsOpenAIStream( } } -func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan format.CoreStreamEvent { +func coreResponseToStreamEvents(ctx context.Context, resp *format.CoreResponse) <-chan format.CoreStreamEvent { out := make(chan format.CoreStreamEvent, 16) go func() { defer close(out) + + send := func(ev format.CoreStreamEvent) bool { + select { + case <-ctx.Done(): + return false + case out <- ev: + return true + } + } + if resp == nil { - out <- format.CoreStreamEvent{ + send(format.CoreStreamEvent{ Type: format.CoreEventFailed, Error: &format.CoreError{ Message: "core response is nil", Type: "server_error", }, - } + }) + return + } + if !send(format.CoreStreamEvent{Type: format.CoreEventCreated, ItemID: resp.ID, Model: resp.Model}) { return } - out <- format.CoreStreamEvent{Type: format.CoreEventCreated, ItemID: resp.ID, Model: resp.Model} index := 0 for _, msg := range resp.Messages { if msg.Role != "assistant" { @@ -1733,21 +1898,33 @@ func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan 
format.CoreStr for _, block := range msg.Content { switch block.Type { case "reasoning": - out <- format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "reasoning"}} + if !send(format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "reasoning"}}) { + return + } if block.ReasoningText != "" { - out <- format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.ReasoningText} + if !send(format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.ReasoningText}) { + return + } } - out <- format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index, ContentBlock: &format.CoreContentBlock{ + if !send(format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index, ContentBlock: &format.CoreContentBlock{ Type: "reasoning", ReasoningSignature: block.ReasoningSignature, - }} + }}) { + return + } index++ case "text": - out <- format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "text"}} + if !send(format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "text"}}) { + return + } if block.Text != "" { - out <- format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.Text} + if !send(format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.Text}) { + return + } + } + if !send(format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index}) { + return } - out <- format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index} index++ } } @@ -1762,13 +1939,13 @@ func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan format.CoreStr } else if status == "incomplete" { eventType = format.CoreEventIncomplete } - out <- format.CoreStreamEvent{ + send(format.CoreStreamEvent{ Type: eventType, Status: status, 
Model: resp.Model, Usage: &resp.Usage, Error: resp.Error, - } + }) }() return out } @@ -2080,6 +2257,22 @@ func (s *Server) wrapWithVisual( return nil } visClient = &chatProviderClient{c: chatClient} + case config.ProtocolGoogleGenAI: + gcRaw := s.activeGoogleClient(visCfg.Provider) + if gcRaw == nil { + slog.Default().Warn("visual: no google client for visual provider", "visual_provider", visCfg.Provider, "model", modelAlias) + return nil + } + gc, ok := gcRaw.(*google.Client) + if !ok || gc == nil { + slog.Default().Warn("visual: google client type mismatch", "visual_provider", visCfg.Provider) + return nil + } + visModel := pm.FirstUpstreamModelForKey(visCfg.Provider) + if visModel == "" { + visModel = visCfg.Model + } + visClient = &googleProviderClient{c: gc, model: visModel} default: c, err := pm.ClientForKey(visCfg.Provider) if err != nil || c == nil { @@ -2100,6 +2293,29 @@ func (s *Server) wrapWithVisual( // pm.ClientForKey only constructs anthropic-shaped clients; chat-protocol // providers keep their dedicated *chat.Client in s.chatClients. This adapter // bridges the two when visual orchestration needs to call into a chat upstream. +// googleProviderClient adapts *google.Client to provider.ProviderClient so the +// adapter-based CoreProvider machinery can drive a google-genai protocol +// upstream uniformly across protocols. Google's GenerateContent requires +// a model parameter in the call signature (unlike anthropic/chat), so we +// capture the model name at construction time. 
+type googleProviderClient struct { + c *google.Client + model string +} + +func (p *googleProviderClient) CreateMessage(ctx context.Context, req any) (any, error) { + googleReq, ok := req.(*google.GenerateContentRequest) + if !ok { + return nil, fmt.Errorf("googleProviderClient: expected *google.GenerateContentRequest, got %T", req) + } + return p.c.GenerateContent(ctx, p.model, googleReq) +} + +func (p *googleProviderClient) StreamMessage(ctx context.Context, req any) (<-chan any, error) { + // Not used by visual orchestrator (uses CreateCore non-streaming path). + return nil, fmt.Errorf("googleProviderClient: streaming not supported via ProviderClient interface") +} + type chatProviderClient struct{ c *chat.Client } func (p *chatProviderClient) CreateMessage(ctx context.Context, req any) (any, error) { @@ -2123,7 +2339,16 @@ func (p *chatProviderClient) StreamMessage(ctx context.Context, req any) (<-chan go func() { defer close(out) for chunk := range stream { - out <- chunk + select { + case <-ctx.Done(): + return + default: + } + select { + case out <- chunk: + case <-ctx.Done(): + return + } } }() return out, nil @@ -2143,7 +2368,6 @@ func normalizeAnthropicRequest(upstream any) (anthropic.MessageRequest, error) { } } - // injectCoreWebSearch replaces web_search tools in coreReq.Tools with injected // tavily_search/firecrawl_fetch tools when the resolved web search mode is "injected". // Returns true if injection was applied. 
@@ -2233,6 +2457,11 @@ func (a *searchProviderAdapter) StreamMessage(ctx context.Context, req any) (<-c defer close(out) defer stream.Close() for { + select { + case <-ctx.Done(): + return + default: + } ev, err := stream.Next() if err != nil { if err == io.EOF { @@ -2240,7 +2469,11 @@ func (a *searchProviderAdapter) StreamMessage(ctx context.Context, req any) (<-c } return } - out <- ev + select { + case out <- ev: + case <-ctx.Done(): + return + } } }() return out, nil diff --git a/internal/service/server/server.go b/internal/service/server/server.go index 4c792212..19d76df6 100644 --- a/internal/service/server/server.go +++ b/internal/service/server/server.go @@ -73,6 +73,13 @@ type Server struct { sessionManager session.Manager usageTracker usage.Tracker traceWriter trace.Writer + + // clientCaches holds lazily-created HTTP clients for runtime-reloaded providers. + // Keyed by provider key, invalidated when Runtime reloads. + clientCache map[string]*chat.Client + googleCache map[string]*google.Client + clientCacheMu sync.RWMutex + googleCacheMu sync.RWMutex } func (s *Server) runtimeSnapshot() *runtime.ConfigSnapshot { @@ -97,22 +104,42 @@ func (s *Server) activeProviderDefs() map[string]config.ProviderDef { } func (s *Server) activeChatClient(providerKey string) any { + // Check runtime-driven cache first. 
+ s.clientCacheMu.RLock() + if cached, ok := s.clientCache[providerKey]; ok { + s.clientCacheMu.RUnlock() + return cached + } + s.clientCacheMu.RUnlock() + if snap := s.runtimeSnapshot(); snap != nil { if def, ok := snap.Config.ProviderDefs[providerKey]; ok && def.Protocol == config.ProtocolOpenAIChat { - return chat.NewClient(chat.ClientConfig{ + client := chat.NewClient(chat.ClientConfig{ BaseURL: def.BaseURL, APIKey: def.APIKey, UserAgent: def.UserAgent, }) + s.clientCacheMu.Lock() + s.clientCache[providerKey] = client + s.clientCacheMu.Unlock() + return client } } return s.chatClients[providerKey] } func (s *Server) activeGoogleClient(providerKey string) any { + // Check runtime-driven cache first. + s.googleCacheMu.RLock() + if cached, ok := s.googleCache[providerKey]; ok { + s.googleCacheMu.RUnlock() + return cached + } + s.googleCacheMu.RUnlock() + if snap := s.runtimeSnapshot(); snap != nil { if def, ok := snap.Config.ProviderDefs[providerKey]; ok && def.Protocol == config.ProtocolGoogleGenAI { - return google.NewClient(google.ClientConfig{ + client := google.NewClient(google.ClientConfig{ BaseURL: def.BaseURL, APIKey: def.APIKey, Project: def.Project, @@ -120,6 +147,10 @@ func (s *Server) activeGoogleClient(providerKey string) any { Version: def.APIVersion, UserAgent: def.UserAgent, }) + s.googleCacheMu.Lock() + s.googleCache[providerKey] = client + s.googleCacheMu.Unlock() + return client } } return s.googleClients[providerKey] @@ -148,6 +179,8 @@ func New(cfg Config) *Server { sessionManager: cfg.SessionManager, usageTracker: cfg.UsageTracker, traceWriter: cfg.TraceWriter, + clientCache: make(map[string]*chat.Client), + googleCache: make(map[string]*google.Client), } s.mux.HandleFunc("/v1/responses", s.handleResponses) s.mux.HandleFunc("/responses", s.handleResponses) diff --git a/internal/service/server/server_test.go b/internal/service/server/server_test.go index bd6cd39c..3eaaedf0 100644 --- a/internal/service/server/server_test.go +++ 
b/internal/service/server/server_test.go @@ -13,13 +13,13 @@ import ( "strings" "testing" + "moonbridge/internal/config" "moonbridge/internal/extension/codex" deepseekv4 "moonbridge/internal/extension/deepseek_v4" "moonbridge/internal/extension/plugin" - "moonbridge/internal/config" + "moonbridge/internal/format" "moonbridge/internal/logger" "moonbridge/internal/protocol/openai" - "moonbridge/internal/format" "moonbridge/internal/service/provider" "moonbridge/internal/service/server" "moonbridge/internal/service/stats" @@ -82,7 +82,6 @@ func (provider providerFunc) StreamMessage(ctx context.Context, req any) (<-chan return provider.stream(ctx, req) } - type roundTripFunc func(*http.Request) (*http.Response, error) func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) { diff --git a/internal/service/server/trace/writer.go b/internal/service/server/trace/writer.go index 78b7310f..e202f222 100644 --- a/internal/service/server/trace/writer.go +++ b/internal/service/server/trace/writer.go @@ -24,8 +24,8 @@ type Writer interface { // FileWriter implements Writer by delegating to a *mbtrace.Tracer. type FileWriter struct { - tracer *mbtrace.Tracer - errors io.Writer + tracer *mbtrace.Tracer + errors io.Writer } // NewFileWriter creates a new FileWriter. diff --git a/internal/service/server/websearch_inject.go b/internal/service/server/websearch_inject.go index 068a7b3c..89496f69 100644 --- a/internal/service/server/websearch_inject.go +++ b/internal/service/server/websearch_inject.go @@ -116,6 +116,12 @@ func (s *Server) executeChatSearchLoop( } msg := resp.Choices[0].Message + // Defensive: ensure message role is set. Upstream may return empty role + // in some error-recovery scenarios, which would break the alternating + // user/assistant/tool contract on subsequent rounds. 
+ if msg.Role == "" { + msg.Role = "assistant" + } if len(msg.ToolCalls) == 0 { return resp, nil } @@ -136,6 +142,22 @@ func (s *Server) executeChatSearchLoop( return resp, nil } if len(nonSearchCalls) > 0 { + // Execute search calls as side effect, return only non-search content. + var toolResultMsgs []chat.ChatMessage + for _, tc := range searchCalls { + result, execErr := executeChatSearchCall(ctx, tavily, firecrawl, tc) + if execErr != nil { + log.Warn("Chat搜索执行失败(混合调用)", "tool", tc.Function.Name, "error", execErr) + result = fmt.Sprintf("Search error: %s", execErr.Error()) + } + toolResultMsgs = append(toolResultMsgs, chat.ChatMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, + }) + } + req.Messages = append(req.Messages, msg) + req.Messages = append(req.Messages, toolResultMsgs...) return resp, nil } @@ -193,7 +215,7 @@ func executeChatSearchCall( if err != nil { return "", err } - return formatTavilyResults(result), nil + return websearch.FormatTavilyResults(result), nil case "firecrawl_fetch": if firecrawl == nil { @@ -216,7 +238,7 @@ func executeChatSearchCall( if err != nil { return "", err } - return formatFirecrawlResult(result), nil + return websearch.FormatFirecrawlResult(result), nil default: return "", fmt.Errorf("unknown search tool: %s", tc.Function.Name) @@ -299,12 +321,18 @@ func (s *Server) executeGoogleSearchLoop( if err != nil { return nil, err } - if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 { + if len(resp.Candidates) == 0 { return resp, nil } - // Check for function call parts. - funcCalls := googleFuncCalls(resp.Candidates[0].Content.Parts) + // Collect function calls from ALL candidates, not just the first. + var funcCalls []google.FunctionCall + for _, c := range resp.Candidates { + if len(c.Content.Parts) == 0 { + continue + } + funcCalls = append(funcCalls, googleFuncCalls(c.Content.Parts)...) 
+ } if len(funcCalls) == 0 { return resp, nil } @@ -314,6 +342,34 @@ func (s *Server) executeGoogleSearchLoop( return resp, nil } if len(nonSearchCalls) > 0 { + // Execute search calls as side effect, feed results to the model. + responseParts := make([]google.Part, 0, len(searchCalls)) + for _, fc := range searchCalls { + result, execErr := executeGoogleSearchCall(ctx, tavily, firecrawl, fc) + if execErr != nil { + log.Warn("Google搜索执行失败(混合调用)", "tool", fc.Name, "error", execErr) + result = execErr.Error() + } + respJSON, _ := json.Marshal(map[string]any{"result": result}) + responseParts = append(responseParts, google.Part{ + FunctionResponse: &google.FunctionResponse{ + Name: fc.Name, + Response: respJSON, + }, + }) + } + for _, c := range resp.Candidates { + if len(c.Content.Parts) > 0 { + req.Contents = append(req.Contents, google.Content{ + Role: "model", + Parts: c.Content.Parts, + }) + } + } + req.Contents = append(req.Contents, google.Content{ + Role: "function", + Parts: responseParts, + }) return resp, nil } @@ -335,10 +391,14 @@ func (s *Server) executeGoogleSearchLoop( } // Append model response + function response for next round. 
- req.Contents = append(req.Contents, google.Content{ - Role: "model", - Parts: resp.Candidates[0].Content.Parts, - }) + for _, c := range resp.Candidates { + if len(c.Content.Parts) > 0 { + req.Contents = append(req.Contents, google.Content{ + Role: "model", + Parts: c.Content.Parts, + }) + } + } req.Contents = append(req.Contents, google.Content{ Role: "function", Parts: responseParts, @@ -405,7 +465,7 @@ func executeGoogleSearchCall( if err != nil { return "", err } - return formatTavilyResults(result), nil + return websearch.FormatTavilyResults(result), nil case "firecrawl_fetch": if firecrawl == nil { @@ -429,52 +489,14 @@ func executeGoogleSearchCall( if err != nil { return "", err } - return formatFirecrawlResult(result), nil + return websearch.FormatFirecrawlResult(result), nil default: return "", fmt.Errorf("unknown search tool: %s", fc.Name) } } -// ============================================================================ -// Formatting helpers (duplicated from websearch package for encapsulation) -// ============================================================================ - -func formatTavilyResults(result *websearch.SearchResult) string { - var b strings.Builder - b.WriteString(fmt.Sprintf("Search results for %q:\n\n", result.Query)) - if result.Answer != "" { - b.WriteString("Answer: ") - b.WriteString(truncate(result.Answer, 2000)) - b.WriteString("\n\n") - } - for i, item := range result.Results { - if i >= 10 { - break - } - b.WriteString(fmt.Sprintf("%d. 
[%s](%s)\n", i+1, item.Title, item.URL)) - b.WriteString(fmt.Sprintf(" Score: %.2f\n", item.Score)) - b.WriteString(fmt.Sprintf(" %s\n\n", truncate(item.Content, 500))) - } - return b.String() -} - -func formatFirecrawlResult(result *websearch.FetchResult) string { - var b strings.Builder - b.WriteString(fmt.Sprintf("Content from %s:\n\n", result.Data.Metadata.SourceURL)) - if result.Data.Metadata.Title != "" { - b.WriteString(fmt.Sprintf("Title: %s\n\n", result.Data.Metadata.Title)) - } - b.WriteString(truncate(result.Data.Markdown, 8000)) - return b.String() -} - -func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] + "..." -} +// Formatting helpers — delegated to exported websearch package functions. // ============================================================================ // Chat streaming search loop @@ -537,6 +559,29 @@ func (s *Server) chatSearchBufferedStream( break } if len(nonSearchCalls) > 0 { + // Execute search calls, feed results to next round. + var toolResultMsgs []chat.ChatMessage + for _, tc := range searchCalls { + result, execErr := executeChatSearchCall(ctx, tavily, firecrawl, tc) + if execErr != nil { + log.Warn("流式搜索执行失败(混合调用)", "tool", tc.Function.Name, "error", execErr) + result = fmt.Sprintf("Search error: %s", execErr.Error()) + } + toolResultMsgs = append(toolResultMsgs, chat.ChatMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, + }) + } + asstContent := collectChatStreamContent(events) + reasoningContent := collectChatStreamReasoning(events) + req.Messages = append(req.Messages, chat.ChatMessage{ + Role: "assistant", + Content: asstContent, + ToolCalls: toolCalls, + ReasoningContent: reasoningContent, + }) + req.Messages = append(req.Messages, toolResultMsgs...) allEvents = append(allEvents, events...) 
exhausted = false break diff --git a/internal/service/store/config_store.go b/internal/service/store/config_store.go index 32c8c273..26bc8596 100644 --- a/internal/service/store/config_store.go +++ b/internal/service/store/config_store.go @@ -80,16 +80,16 @@ type ModelRow struct { // RouteRow represents a row in the config_store_routes table. type RouteRow struct { - Alias string - ModelSlug string - ProviderKey string - DisplayName string - ContextWindow int - MaxOutputTokens int - Extensions string // JSON-serialized map[string]config.ExtensionFileConfig - WebSearch string // JSON-serialized config.WebSearchFileConfig - CreatedAt string - UpdatedAt string + Alias string + ModelSlug string + ProviderKey string + DisplayName string + ContextWindow int + MaxOutputTokens int + Extensions string // JSON-serialized map[string]config.ExtensionFileConfig + WebSearch string // JSON-serialized config.WebSearchFileConfig + CreatedAt string + UpdatedAt string } // SettingRow represents a row in the config_store_settings table. @@ -100,16 +100,16 @@ type SettingRow struct { // ChangeRow represents a row in the config_store_changes table. 
type ChangeRow struct { - ID int64 - BatchID string - Action string // "create", "update", "delete" - Resource string // "provider", "offer", "model", "route", "setting" - TargetKey string - Before string // JSON-serialized "before" state (empty for create) - After string // JSON-serialized "after" state (empty for delete) - Applied bool - Error string - Revision int - CreatedAt string - AppliedAt string + ID int64 + BatchID string + Action string // "create", "update", "delete" + Resource string // "provider", "offer", "model", "route", "setting" + TargetKey string + Before string // JSON-serialized "before" state (empty for create) + After string // JSON-serialized "after" state (empty for delete) + Applied bool + Error string + Revision int + CreatedAt string + AppliedAt string } diff --git a/internal/service/store/consumer.go b/internal/service/store/consumer.go index 196a366e..042aa282 100644 --- a/internal/service/store/consumer.go +++ b/internal/service/store/consumer.go @@ -12,7 +12,7 @@ import ( type ConfigStoreConsumer struct { store db.Store persistenceDisabled bool - extensionSpecs []config.ExtensionConfigSpec + extensionSpecs []config.ExtensionConfigSpec logger *slog.Logger configStore ConfigStore } @@ -152,5 +152,5 @@ func (c *ConfigStoreConsumer) Store() ConfigStore { // compile-time interface checks. var ( - _ db.Consumer = (*ConfigStoreConsumer)(nil) + _ db.Consumer = (*ConfigStoreConsumer)(nil) )