diff --git a/README.md b/README.md
index 64dc173c..04958912 100644
--- a/README.md
+++ b/README.md
@@ -93,7 +93,6 @@ docker run -p 38440:38440 -v $(pwd)/config.yml:/config/config.yml moonbridge
| `/v1/models` | GET | 列出可用模型 |
| `/models` | GET | 同上 |
| `/api/v1/` | — | 管理 API(需启用持久化) |
-| `/health` | GET | 健康检查 |
详细 API 文档见 [API.md](docs/api.md)。
diff --git a/config.example.yml b/config.example.yml
index e4a32168..1f97cce2 100644
--- a/config.example.yml
+++ b/config.example.yml
@@ -67,6 +67,11 @@ extensions:
# binding: MOONBRIDGE_DB
# Persistence consumer: request metrics
+ # apply_patch 代理扩展:控制 Codex 文件编辑工具是否展开为结构化代理工具。
+ # 默认关闭,开启后 apply_patch 被拆分为 add_file/delete_file/update_file/replace_file/batch
+ # codex_tool_proxy:
+ # enabled: true
+
metrics:
enabled: true
config:
diff --git a/docs/api.md b/docs/api.md
index 2f3257f0..2d024e84 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -66,10 +66,32 @@ data: {"response": {...}}
| 端点 | 方法 | 功能 |
|------|------|------|
-| `/api/v1/config` | GET/PUT | 获取/更新运行时配置 |
-| `/api/v1/codex/config` | GET | 生成 Codex TOML 配置 |
-| `/api/v1/providers` | GET/POST/DELETE | 管理 Provider |
-| `/api/v1/sessions/{id}` | GET | 获取会话用量统计 |
+| `/api/v1/providers` | GET/POST/PUT/PATCH/DELETE | 管理 Provider CRUD |
+| `/api/v1/providers/{key}/offers` | POST | 创建 Offer |
+| `/api/v1/providers/{key}/offers/{model}` | PATCH/DELETE | 更新/删除 Offer |
+| `/api/v1/providers/{key}/test` | POST | 测试 Provider 连通性 |
+| `/api/v1/models` | GET | 列出模型定义 |
+| `/api/v1/models/{slug}` | GET/PUT/DELETE | 管理单个模型定义 |
+| `/api/v1/routes` | GET | 列出路由 |
+| `/api/v1/routes/{alias}` | GET/PUT/DELETE | 管理单个路由 |
+| `/api/v1/defaults` | GET/PUT | 管理默认配置(model/max_tokens/system_prompt) |
+| `/api/v1/web-search` | GET/PUT | 管理全局 Web Search 设置 |
+| `/api/v1/extensions` | GET | 列出扩展 |
+| `/api/v1/extensions/{name}` | GET/PUT | 管理扩展设置 |
+| `/api/v1/config/effective` | GET | 获取生效配置 |
+| `/api/v1/config/export` | GET | 导出配置 YAML |
+| `/api/v1/config/import` | POST | 导入配置 YAML |
+| `/api/v1/config/validate` | POST | 校验配置 |
+| `/api/v1/changes` | GET | 列出未提交的配置变更 |
+| `/api/v1/changes/apply` | POST | 应用配置变更 |
+| `/api/v1/changes/discard` | POST | 放弃配置变更 |
+| `/api/v1/status` | GET | 系统状态 |
+| `/api/v1/status/providers` | GET | Provider 运行状态 |
+| `/api/v1/sessions` | GET | 列出活跃会话 |
+| `/api/v1/stats` | GET | 用量统计 |
+| `/api/v1/stats/summary` | GET | 用量统计摘要 |
+| `/api/v1/logs` | GET | 日志查询 |
+| `/api/v1/version` | GET | 版本信息 |
## 错误处理
diff --git a/docs/architecture.md b/docs/architecture.md
index 128ecfc1..d5775d22 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -8,28 +8,46 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对
## 四层架构
-```
-┌─────────────────────────────────────────────────┐
-│ Service 层 │
-│ server(路由/处理) adapter_dispatch(协议分发) │
-│ provider(路由) stats(统计) trace(跟踪) │
-│ proxy(Capture代理) api(管理 API) │
-│ store(持久化) runtime(运行时) │
-├─────────────────────────────────────────────────┤
-│ Protocol 层 │
-│ format(核心类型/注册表) anthropic(Anthropic 适配) │
-│ openai(OpenAI 适配) google(GenAI 适配) │
-│ chat(OpenAI Chat 适配) cache(缓存) │
-├─────────────────────────────────────────────────┤
-│ 基础组件(直接位于 internal/ 下) │
-│ config(配置) logger(日志) openai_dto(共享 DTO) │
-│ modelref(模型引用) session(会话) db(数据库) │
-├─────────────────────────────────────────────────┤
-│ Extension 层 │
-│ deepseek_v4 visual websearch websearchinjected│
-│ kimi_workaround metrics codex(Codex模型目录) │
-│ plugin(插件注册/接口) db(持久化后端: SQLite/D1) │
-└─────────────────────────────────────────────────┘
+```mermaid
+flowchart TB
+ subgraph Service["Service 层"]
+ s1["server(路由/处理)"]
+ s2["adapter_dispatch(协议分发)"]
+ s3["provider(路由)"]
+ s4["stats(统计)"]
+ s5["trace(跟踪)"]
+ s6["proxy(Capture代理)"]
+ s7["api(管理 API)"]
+ s8["store(持久化)"]
+ s9["runtime(运行时)"]
+ end
+ subgraph Protocol["Protocol 层"]
+ p1["format(核心类型/注册表)"]
+ p2["anthropic(Anthropic 适配)"]
+ p3["openai(OpenAI 适配)"]
+ p4["google(GenAI 适配)"]
+ p5["chat(OpenAI Chat 适配)"]
+ p6["cache(缓存)"]
+ end
+ subgraph Base["基础组件"]
+ b1["config(配置)"]
+ b2["logger(日志)"]
+ b3["openai_dto(共享 DTO)"]
+ b4["modelref(模型引用)"]
+ b5["session(会话)"]
+ b6["db(数据库)"]
+ end
+ subgraph Extension["Extension 层"]
+ e1["deepseek_v4"]
+ e2["visual"]
+ e3["websearch"]
+ e4["websearchinjected"]
+ e5["kimi_workaround"]
+ e6["metrics"]
+ e7["codex(模型目录)"]
+ e8["plugin(插件注册/接口)"]
+ e9["db(SQLite/D1)"]
+ end
```
### 基础组件(internal/ 顶层包)
@@ -58,7 +76,7 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对
业务编排层,组合基础层和 Protocol 组件:
-- `internal/service/server` — HTTP 服务器、路由(`/v1/responses`、`/v1/models`、`/health` 等)、认证
+- `internal/service/server` — HTTP 服务器、路由(`/v1/responses`、`/v1/models` 等)、认证
- `internal/service/server/adapter_dispatch.go` — Adapter 分发路径(switch 协议类型 → 调用对应 Adapter)
- `internal/service/provider` — Provider 管理器(多 Provider 路由、配置热重载)
- `internal/service/proxy` — Capture 模式下的透明代理
@@ -68,7 +86,6 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对
- `internal/service/trace` — 请求跟踪(捕获请求/响应的完整链路,持久化到 `data/trace/`)
- `internal/service/store` — 配置持久化存储(SQLite / D1)
- `internal/service/runtime` — 运行时上下文
-- `internal/service/bridge` — 备用桥接层
### Extension 层
@@ -81,6 +98,8 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对
- `internal/extension/metrics` — 请求指标采集与查询
- `internal/extension/plugin` — 三方插件注册管理(`PluginRegistry` + `CorePluginHooks`)
- `internal/extension/codex` — Codex 模型目录
+- `internal/extension/codex_tool_proxy` — apply_patch 代理扩展
+- `internal/extension/kimi_workaround` — Kimi 工具调用轮次限制
- `internal/extension/db` — 持久化 Provider(SQLite 本地 / Cloudflare D1 Worker)
## 三种运行模式
@@ -93,26 +112,22 @@ Moon Bridge 是一个 Go 语言编写的 HTTP 代理/协议转换服务器。对
## 请求生命周期数据流(Transform 模式)
-```
-客户端 (Codex CLI)
- │ POST /v1/responses (OpenAI Responses 格式)
- ▼
-server.handleResponses()
- │ 认证 / 日志 / 统计初始化 / 路由解析
- ▼
-adapter_dispatch.go (Adapter 分发)
- │ preferred.Protocol 决定上游协议
- │
- ├── ProtocolAnthropic → anthropic adapter
- ├── ProtocolGoogleGenAI → google adapter
- ├── ProtocolOpenAIChat → chat adapter
- └── ProtocolOpenAIResponse → 直通(无协议转换)
- │
- ├── 插件拦截 (PluginHooks)
- │ MutateCoreRequest → [Adapter] → RememberContent → OnStreamEvent
- │
- ▼
-客户端 ←── OpenAI Responses 响应
+```mermaid
+flowchart TD
+ A["客户端 (Codex CLI)"]
+ A -->|"POST /v1/responses
(OpenAI Responses 格式)"| B["server.handleResponses()"]
+ B -->|"认证 / 日志 / 统计初始化 / 路由解析"| C["adapter_dispatch.go (Adapter 分发)"]
+ C -->|"preferred.Protocol 决定上游协议"| D{"协议分支"}
+ D -->|"ProtocolAnthropic"| E["anthropic adapter"]
+ D -->|"ProtocolGoogleGenAI"| F["google adapter"]
+ D -->|"ProtocolOpenAIChat"| G["chat adapter"]
+ D -->|"ProtocolOpenAIResponse"| H["直通(无协议转换)"]
+ E --> I["插件拦截 (PluginHooks)"]
+ F --> I
+ G --> I
+ H --> I
+ I --> J["MutateCoreRequest → [Adapter] → RememberContent → OnStreamEvent"]
+ J --> K["客户端
← OpenAI Responses 响应"]
```
## 模型路由
@@ -141,9 +156,9 @@ adapter_dispatch.go (Adapter 分发)
```go
type ClientAdapter interface {
- Protocol() string
- ToCoreRequest(context.Context, []byte) (*CoreRequest, error)
- FromCoreResponse(context.Context, *CoreResponse) ([]byte, error)
+ ClientProtocol() string
+ ToCoreRequest(context.Context, any) (*CoreRequest, error)
+ FromCoreResponse(context.Context, *CoreResponse) (any, error)
}
type ProviderAdapter interface {
diff --git a/docs/development-conventions.md b/docs/development-conventions.md
index 6f7eb8fa..6c1457ee 100644
--- a/docs/development-conventions.md
+++ b/docs/development-conventions.md
@@ -4,47 +4,66 @@
### 目录布局
-```
-internal/
-├── config/ # 配置加载/校验/Schema
-├── logger/ # 结构化日志(slog 封装)
-├── openai_dto/ # 共享 OpenAI DTO
-├── modelref/ # 模型引用解析
-├── session/ # 会话管理
-├── db/ # 数据库抽象与注册表
-├── format/ # Core 类型(CoreRequest/CoreResponse/Registry/Adapter 接口)
-├── protocol/ # 协议转换层
-│ ├── anthropic/ # Anthropic Messages Adapter
-│ ├── cache/ # Prompt 缓存规划
-│ ├── chat/ # OpenAI Chat Adapter
-│ ├── format/ # (遗留层,功能已迁移到 internal/format)
-│ ├── google/ # Google Gemini Adapter
-│ └── openai/ # OpenAI Responses Adapter
-├── service/ # 业务编排层
-│ ├── api/ # 管理 REST API(路由在 router.go)
-│ ├── app/ # 应用生命周期管理、Extension 目录
-│ ├── bridge/ # (空目录,保留以备将来使用)
-│ ├── e2e/ # 服务层 E2E 测试
-│ ├── provider/ # Provider 管理器
-│ ├── proxy/ # Capture 模式代理
-│ ├── runtime/ # 运行时上下文
-│ ├── server/ # HTTP 服务器 + 路由 + 认证 + Adapter 分发
-│ │ ├── session/ # 会话管理
-│ │ ├── trace/ # 请求跟踪写入
-│ │ └── usage/ # 用量跟踪
-│ ├── stats/ # 用量统计
-│ └── trace/ # 请求跟踪记录
-├── extension/ # 可插拔扩展
-│ ├── codex/ # Codex 模型目录(catalog.go、default_instructions.go)
-│ ├── db/ # 数据库 Provider(sqlite/、d1/)
-│ ├── deepseek_v4/ # DeepSeek V4 推理优化
-│ ├── kimi_workaround/ # Kimi 模型 tool call 轮次限制
-│ ├── metrics/ # 用量指标采集与查询
-│ ├── plugin/ # Plugin 接口 + 能力接口 + 注册表
-│ ├── visual/ # 视觉模型分发(CoreProvider 模式)
-│ ├── websearch/ # Web Search 编排器
-│ └── websearchinjected/ # 注入式搜索插件
-└── e2e/ # 端到端集成测试(协议转换)
+```mermaid
+flowchart TD
+ subgraph internal["internal/"]
+ direction TB
+ config["config/ — 配置加载/校验/Schema"]
+ logger["logger/ — 结构化日志(slog封装)"]
+ openai_dto["openai_dto/ — 共享 OpenAI DTO"]
+ modelref["modelref/ — 模型引用解析"]
+ session["session/ — 会话管理"]
+ db["db/ — 数据库抽象与注册表"]
+ fmt["format/ — Core类型/Registry/Adapter接口"]
+
+ subgraph protocol["protocol/ — 协议转换层"]
+ direction TB
+ pa["anthropic/ — Anthropic Messages Adapter"]
+ pc["cache/ — Prompt 缓存规划"]
+ pch["chat/ — OpenAI Chat Adapter"]
+ pf["format/ — (遗留层,功能已迁移到 internal/format)"]
+ pg["google/ — Google Gemini Adapter"]
+ po["openai/ — OpenAI Responses Adapter"]
+ end
+
+ subgraph service["service/ — 业务编排层"]
+ direction TB
+ sa["api/ — 管理 REST API"]
+ sapp["app/ — 应用生命周期管理、Extension 目录"]
+ se["e2e/ — 服务层 E2E 测试"]
+ sp["provider/ — Provider 管理器"]
+ spr["proxy/ — Capture 模式代理"]
+ srt["runtime/ — 运行时上下文"]
+ subgraph srv["server/ — HTTP服务器/路由/认证/Adapter分发"]
+ direction TB
+ ss["session/ — 会话管理"]
+ st["trace/ — 请求跟踪写入"]
+ su["usage/ — 用量跟踪"]
+ end
+ sst["stats/ — 用量统计"]
+ str["trace/ — 请求跟踪记录"]
+ end
+
+ subgraph extension["extension/ — 可插拔扩展"]
+ direction TB
+ ec["codex/ — Codex 模型目录"]
+ subgraph edb["db/ — 数据库 Provider"]
+ es["sqlite/"]
+ ed1["d1/"]
+ end
+ eds["deepseek_v4/ — DeepSeek V4 推理优化"]
+ ek["kimi_workaround/ — Kimi tool call 轮次限制"]
+ em["metrics/ — 用量指标采集与查询"]
+ ep["plugin/ — Plugin 接口+能力接口+注册表"]
+ ev["visual/ — 视觉模型分发(CoreProvider模式)"]
+ ew["websearch/ — Web Search 编排器"]
+ ewi["websearchinjected/ — 注入式搜索插件"]
+ ectp["codex_tool_proxy/ — apply_patch 代理扩展"]
+ ect["codextool/ — 工具类型定义与工具映射"]
+ end
+
+ e2e["e2e/ — 端到端集成测试(协议转换)"]
+ end
```
### 依赖方向
diff --git a/docs/extensions-overview.md b/docs/extensions-overview.md
index c7d78bd1..b773662f 100644
--- a/docs/extensions-overview.md
+++ b/docs/extensions-overview.md
@@ -250,6 +250,46 @@ moonbridge -config config.yml -print-codex-config my-model
---
+
+---
+
+## codex_tool_proxy(apply_patch 代理扩展)
+
+控制 Codex 的 `apply_patch` 自定义工具是否被展开为 5 个结构化代理工具(`add_file`、`delete_file`、`update_file`、`replace_file`、`batch`)发送给上游模型。
+
+**位置**:`internal/extension/codex_tool_proxy/`
+
+**文件清单**:
+
+| 文件 | 用途 |
+|------|------|
+| `plugin.go` | Plugin 实现 + PatchProxyDecider |
+
+**行为**:
+
+- **默认关闭**(`DefaultEnabled: false`):`apply_patch` 以 raw grammar 形态原样透传给上游模型
+- **开启后**:展开为 5 个独立的结构化工具,让上游模型以 JSON schema 方式调用
+
+**实现的能力**:
+
+```go
+var (
+ _ plugin.Plugin = (*ProxyPlugin)(nil)
+ _ plugin.ConfigSpecProvider = (*ProxyPlugin)(nil)
+ _ plugin.PatchProxyDecider = (*ProxyPlugin)(nil)
+)
+```
+
+**启用方式**:
+
+```yaml
+extensions:
+ codex_tool_proxy:
+ enabled: true
+```
+
+支持 route / model / provider 级别覆盖。
+
## visual(视觉扩展)
diff --git a/internal/config/config.go b/internal/config/config.go
index 091b0905..e5695535 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -20,8 +20,8 @@ const (
ProtocolAnthropic = "anthropic"
ProtocolOpenAIResponse = "openai-response"
// Phase 5: New protocol constants (D-08)
- ProtocolGoogleGenAI = "google-genai"
- ProtocolOpenAIChat = "openai-chat"
+ ProtocolGoogleGenAI = "google-genai"
+ ProtocolOpenAIChat = "openai-chat"
)
type Mode string
@@ -51,22 +51,22 @@ type WebSearchConfig struct {
}
type Config struct {
- Mode Mode
- Addr string
- AuthToken string
- TraceRequests bool
- LogLevel string
- LogFormat string
- SystemPrompt string
- DefaultModel string
- WebSearchSupport WebSearchSupport
- WebSearchMaxUses int
- TavilyAPIKey string
- FirecrawlAPIKey string
- SearchMaxRounds int
- DefaultMaxTokens int
- MaxSessions int `yaml:"max_sessions"` // 0 = unlimited
- SessionTTL string `yaml:"session_ttl"` // default "24h"
+ Mode Mode
+ Addr string
+ AuthToken string
+ TraceRequests bool
+ LogLevel string
+ LogFormat string
+ SystemPrompt string
+ DefaultModel string
+ WebSearchSupport WebSearchSupport
+ WebSearchMaxUses int
+ TavilyAPIKey string
+ FirecrawlAPIKey string
+ SearchMaxRounds int
+ DefaultMaxTokens int
+ MaxSessions int `yaml:"max_sessions"` // 0 = unlimited
+ SessionTTL string `yaml:"session_ttl"` // default "24h"
// Defaults holds the default configuration values.
Defaults Defaults
// Models is the canonical model definition map (shared across providers).
@@ -115,11 +115,11 @@ type RouteEntry struct {
// ProviderDef defines a single upstream provider.
type ProviderDef struct {
- BaseURL string
- APIKey string
- Version string
- UserAgent string
- Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat"
+ BaseURL string
+ APIKey string
+ Version string
+ UserAgent string
+ Protocol string // "anthropic" (default), "openai-response", "google-genai", or "openai-chat"
// Phase 5: Google GenAI flat fields (D-09).
// Only relevant when Protocol == ProtocolGoogleGenAI.
// project: Google Cloud project ID (Vertex AI).
@@ -129,7 +129,7 @@ type ProviderDef struct {
Location string `yaml:"location,omitempty"`
APIVersion string `yaml:"api_version,omitempty"`
// Cache config for this provider. If nil, provider does not use caching.
- Cache *CacheConfig `yaml:"cache,omitempty"`
+ Cache *CacheConfig `yaml:"cache,omitempty"`
WebSearchSupport WebSearchSupport
WebSearchMaxUses int
TavilyAPIKey string
@@ -204,11 +204,11 @@ type ModelDef struct {
// OfferEntry declares that a provider offers a model defined in Models.
type OfferEntry struct {
- Model string // references models.
- UpstreamName string // optional, upstream model name (empty = same as slug)
- Priority int // lower value = higher priority (0 is highest)
+ Model string // references models.
+ UpstreamName string // optional, upstream model name (empty = same as slug)
+ Priority int // lower value = higher priority (0 is highest)
Pricing ModelPricing
- Overrides *ModelDef // optional provider-specific overrides
+ Overrides *ModelDef // optional provider-specific overrides
}
type ResponseProxyConfig struct {
diff --git a/internal/config/config_loader.go b/internal/config/config_loader.go
index 76dee14c..2b366138 100644
--- a/internal/config/config_loader.go
+++ b/internal/config/config_loader.go
@@ -88,8 +88,8 @@ type TraceFileConfig struct {
}
type ServerFileConfig struct {
- Addr string `yaml:"addr" json:"addr,omitempty"`
- AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"`
+ Addr string `yaml:"addr" json:"addr,omitempty"`
+ AuthToken string `yaml:"auth_token" json:"auth_token,omitempty"`
MaxSessions int `yaml:"max_sessions"`
SessionTTL string `yaml:"session_ttl"`
}
@@ -122,33 +122,33 @@ type ModelDefFileConfig struct {
}
type OfferFileConfig struct {
- Model string `yaml:"model" json:"model"`
- UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"`
- Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
- Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"`
- Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"`
+ Model string `yaml:"model" json:"model"`
+ UpstreamName string `yaml:"upstream_name,omitempty" json:"upstream_name,omitempty"`
+ Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
+ Pricing ModelPricingFileConfig `yaml:"pricing,omitempty" json:"pricing,omitempty"`
+ Overrides *ModelDefFileConfig `yaml:"overrides,omitempty" json:"overrides,omitempty"`
}
type ProviderDefFileConfig struct {
- BaseURL string `yaml:"base_url" json:"base_url"`
- APIKey string `yaml:"api_key" json:"api_key"`
- Version string `yaml:"version,omitempty" json:"version,omitempty"`
- UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
- Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"`
- WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"`
- Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"`
- Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"`
+ BaseURL string `yaml:"base_url" json:"base_url"`
+ APIKey string `yaml:"api_key" json:"api_key"`
+ Version string `yaml:"version,omitempty" json:"version,omitempty"`
+ UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
+ Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"`
+ WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"`
+ Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"`
+ Offers []OfferFileConfig `yaml:"offers,omitempty" json:"offers,omitempty"`
}
type RouteFileConfig struct {
- To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model"
- Model string `yaml:"model,omitempty" json:"model,omitempty"`
- Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
- DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"`
- Description string `yaml:"description,omitempty" json:"description,omitempty"`
- ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"`
- WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"`
- Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"`
+ To string `yaml:"to,omitempty" json:"to,omitempty"` // backward compat "provider/model"
+ Model string `yaml:"model,omitempty" json:"model,omitempty"`
+ Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
+ DisplayName string `yaml:"display_name,omitempty" json:"display_name,omitempty"`
+ Description string `yaml:"description,omitempty" json:"description,omitempty"`
+ ContextWindow int `yaml:"context_window,omitempty" json:"context_window,omitempty"`
+ WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"`
+ Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"`
}
func (cfg *RouteFileConfig) UnmarshalYAML(value *yaml.Node) error {
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index a143cd48..4bbd6992 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -6,9 +6,9 @@ import (
"strings"
"testing"
+ "moonbridge/internal/config"
deepseekv4 "moonbridge/internal/extension/deepseek_v4"
"moonbridge/internal/extension/visual"
- "moonbridge/internal/config"
)
func builtinExtensionSpecsForTest() []config.ExtensionConfigSpec {
diff --git a/internal/e2e/anthropic_e2e_test.go b/internal/e2e/anthropic_e2e_test.go
index 5235a3c5..b9769d07 100644
--- a/internal/e2e/anthropic_e2e_test.go
+++ b/internal/e2e/anthropic_e2e_test.go
@@ -12,8 +12,8 @@ import (
"strings"
"testing"
- "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/protocol/openai"
)
@@ -73,8 +73,8 @@ func TestAnthropicE2E_TextRoundTrip(t *testing.T) {
// Step 1: Build OpenAI Responses request.
openAIReq := openai.ResponsesRequest{
- Model: "claude-3.5-sonnet",
- Input: json.RawMessage(`"Hello"`),
+ Model: "claude-3.5-sonnet",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
@@ -292,7 +292,7 @@ func TestAnthropicE2E_ToolUseRoundTrip(t *testing.T) {
},
},
},
- ToolChoice: json.RawMessage(`"required"`),
+ ToolChoice: json.RawMessage(`"required"`),
MaxOutputTokens: 200,
}
@@ -436,10 +436,10 @@ func TestAnthropicE2E_Streaming(t *testing.T) {
// Step 1: Build streaming OpenAI request.
openAIReq := openai.ResponsesRequest{
- Model: "claude-3.5-sonnet",
- Input: json.RawMessage(`"Hello streaming"`),
+ Model: "claude-3.5-sonnet",
+ Input: json.RawMessage(`"Hello streaming"`),
MaxOutputTokens: 100,
- Stream: true,
+ Stream: true,
}
// Step 2: ClientAdapter.ToCoreRequest.
@@ -480,11 +480,17 @@ func TestAnthropicE2E_Streaming(t *testing.T) {
}
// Step 6: ClientStreamAdapter.FromCoreStream.
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents)
+ streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events)
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
// Consume the OpenAI stream and verify expected events.
var seenEvents []string
@@ -569,8 +575,8 @@ func TestAnthropicE2E_ErrorResponse(t *testing.T) {
defer mockSrv.Close()
openAIReq := openai.ResponsesRequest{
- Model: "claude-3.5-sonnet",
- Input: json.RawMessage(`"Hello"`),
+ Model: "claude-3.5-sonnet",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
@@ -696,7 +702,7 @@ func TestAnthropicE2E_MultiTurnToolChain(t *testing.T) {
},
},
},
- ToolChoice: json.RawMessage(`"auto"`),
+ ToolChoice: json.RawMessage(`"auto"`),
MaxOutputTokens: 300,
}
@@ -840,7 +846,6 @@ func TestAnthropicE2E_MultiTurnToolChain(t *testing.T) {
}
}
-
// ============================================================================
const (
@@ -947,7 +952,7 @@ func testAnthropicRealToolCall(t *testing.T, apiKey string) {
"required": []any{"city"},
},
}},
- ToolChoice: json.RawMessage(`"auto"`),
+ ToolChoice: json.RawMessage(`"auto"`),
MaxOutputTokens: 300,
}
@@ -1044,7 +1049,7 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) {
"required": []any{"city"},
},
}},
- ToolChoice: json.RawMessage(`"auto"`),
+ ToolChoice: json.RawMessage(`"auto"`),
MaxOutputTokens: 300,
}
@@ -1109,9 +1114,9 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) {
}, anthropic.Message{
Role: "user",
Content: []anthropic.ContentBlock{{
- Type: "tool_result",
+ Type: "tool_result",
ToolUseID: toolCallID,
- Content: []anthropic.ContentBlock{{Type: "text", Text: "The weather in Tokyo is 25 degrees and Sunny."}},
+ Content: []anthropic.ContentBlock{{Type: "text", Text: "The weather in Tokyo is 25 degrees and Sunny."}},
}},
}),
}
@@ -1138,7 +1143,6 @@ func testAnthropicRealMultiTurnToolChain(t *testing.T, apiKey string) {
}
}
-
// Config constants (used by E2E tests)
// ============================================================================
// Config constants (used by E2E tests)
diff --git a/internal/e2e/e2e_test.go b/internal/e2e/e2e_test.go
index a6b69efe..6bcfcc64 100644
--- a/internal/e2e/e2e_test.go
+++ b/internal/e2e/e2e_test.go
@@ -3,19 +3,19 @@
package e2e_test
import (
+ "bufio"
"context"
"fmt"
"net/http"
"os"
- "bufio"
"path/filepath"
"strings"
"testing"
"moonbridge/internal/config"
+ "moonbridge/internal/format"
"moonbridge/internal/protocol/anthropic"
"moonbridge/internal/protocol/chat"
- "moonbridge/internal/format"
"moonbridge/internal/protocol/google"
"moonbridge/internal/protocol/openai"
)
@@ -187,7 +187,6 @@ func assertResponseBasics(t testing.TB, oaiResp *openai.Response, wantModel stri
}
}
-
func loadDotEnv(t testing.TB) {
if t != nil {
t.Helper()
@@ -203,10 +202,10 @@ func loadDotEnv(t testing.TB) {
f, err := os.Open(path)
if err != nil {
if t != nil {
- t.Logf("warning: cannot open %s: %v", path, err)
- } else {
- println("warning: cannot open", path, err.Error())
- }
+ t.Logf("warning: cannot open %s: %v", path, err)
+ } else {
+ println("warning: cannot open", path, err.Error())
+ }
return
}
defer f.Close()
@@ -238,4 +237,4 @@ func loadDotEnv(t testing.TB) {
}
dir = parent
}
-}
\ No newline at end of file
+}
diff --git a/internal/e2e/google_genai_e2e_test.go b/internal/e2e/google_genai_e2e_test.go
index 163d7129..8320129a 100644
--- a/internal/e2e/google_genai_e2e_test.go
+++ b/internal/e2e/google_genai_e2e_test.go
@@ -64,8 +64,8 @@ func TestGoogleGenaiE2E_TextRoundTrip(t *testing.T) {
// Step 1: Build OpenAI Responses request.
openAIReq := openai.ResponsesRequest{
- Model: "gemini-2.0-flash",
- Input: json.RawMessage(`"Hello"`),
+ Model: "gemini-2.0-flash",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
@@ -395,10 +395,10 @@ func TestGoogleGenaiE2E_Streaming(t *testing.T) {
// Step 1: Build streaming OpenAI request.
openAIReq := openai.ResponsesRequest{
- Model: "gemini-2.0-flash",
- Input: json.RawMessage(`"Hello streaming"`),
+ Model: "gemini-2.0-flash",
+ Input: json.RawMessage(`"Hello streaming"`),
MaxOutputTokens: 100,
- Stream: true,
+ Stream: true,
}
// Step 2: ClientAdapter.ToCoreRequest.
@@ -446,11 +446,17 @@ func TestGoogleGenaiE2E_Streaming(t *testing.T) {
}
// Step 5: ClientStreamAdapter.FromCoreStream.
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents)
+ streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events)
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
// Consume the OpenAI stream and verify expected events.
var seenEvents []string
@@ -532,8 +538,8 @@ func TestGoogleGenaiE2E_ErrorResponse(t *testing.T) {
defer mockSrv.Close()
openAIReq := openai.ResponsesRequest{
- Model: "gemini-2.0-flash",
- Input: json.RawMessage(`"Hello"`),
+ Model: "gemini-2.0-flash",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
diff --git a/internal/e2e/openai_chat_e2e_test.go b/internal/e2e/openai_chat_e2e_test.go
index 7e74b3b9..c148bee5 100644
--- a/internal/e2e/openai_chat_e2e_test.go
+++ b/internal/e2e/openai_chat_e2e_test.go
@@ -12,8 +12,8 @@ import (
"strings"
"testing"
- "moonbridge/internal/protocol/chat"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/chat"
"moonbridge/internal/protocol/openai"
)
@@ -71,8 +71,8 @@ func TestOpenAIChatE2E_TextRoundTrip(t *testing.T) {
// Step 1: Build OpenAI Responses request.
openAIReq := openai.ResponsesRequest{
- Model: "gpt-4o",
- Input: json.RawMessage(`"Hello"`),
+ Model: "gpt-4o",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
@@ -395,10 +395,10 @@ func TestOpenAIChatE2E_Streaming(t *testing.T) {
// Step 1: Build streaming OpenAI request.
openAIReq := openai.ResponsesRequest{
- Model: "gpt-4o",
- Input: json.RawMessage(`"Hello streaming"`),
+ Model: "gpt-4o",
+ Input: json.RawMessage(`"Hello streaming"`),
MaxOutputTokens: 100,
- Stream: true,
+ Stream: true,
}
// Step 2: ClientAdapter.ToCoreRequest.
@@ -454,11 +454,17 @@ func TestOpenAIChatE2E_Streaming(t *testing.T) {
}
// Step 5: ClientStreamAdapter.FromCoreStream.
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents)
+ streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events)
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
// Consume the OpenAI stream and verify expected events.
var seenEvents []string
@@ -540,8 +546,8 @@ func TestOpenAIChatE2E_ErrorResponse(t *testing.T) {
defer mockSrv.Close()
openAIReq := openai.ResponsesRequest{
- Model: "gpt-4o",
- Input: json.RawMessage(`"Hello"`),
+ Model: "gpt-4o",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
diff --git a/internal/e2e/openai_response_e2e_test.go b/internal/e2e/openai_response_e2e_test.go
index 1e92b83b..57b3359d 100644
--- a/internal/e2e/openai_response_e2e_test.go
+++ b/internal/e2e/openai_response_e2e_test.go
@@ -311,8 +311,8 @@ func TestOpenAIResponsePassthroughE2E_Streaming(t *testing.T) {
// content_block.started (text type)
events <- format.CoreStreamEvent{
- Type: format.CoreContentBlockStarted,
- Index: 0,
+ Type: format.CoreContentBlockStarted,
+ Index: 0,
ContentBlock: &format.CoreContentBlock{Type: "text"},
}
// text delta
@@ -347,7 +347,13 @@ func TestOpenAIResponsePassthroughE2E_Streaming(t *testing.T) {
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
// Step 3: Consume and verify.
var seenEvents []string
diff --git a/internal/e2e/plugin_hooks_e2e_test.go b/internal/e2e/plugin_hooks_e2e_test.go
index ee9f341e..f90227a4 100644
--- a/internal/e2e/plugin_hooks_e2e_test.go
+++ b/internal/e2e/plugin_hooks_e2e_test.go
@@ -616,11 +616,17 @@ func TestPluginHooks_OnStreamEvent(t *testing.T) {
t.Fatalf("ToCoreStream: %v", err)
}
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents)
+ streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events)
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
var seenEvents []string
for ev := range openAIStream {
@@ -769,11 +775,17 @@ func TestPluginHooks_OnStreamComplete(t *testing.T) {
t.Fatalf("ToCoreStream: %v", err)
}
- streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents)
+ streamOutAny, err := clientStream.FromCoreStream(ctx, coreReq, coreEvents.Events)
if err != nil {
t.Fatalf("FromCoreStream: %v", err)
}
- openAIStream := streamOutAny.(<-chan openai.StreamEvent)
+ var openAIStream <-chan openai.StreamEvent
+ oaiResult, ok := streamOutAny.(*openai.OpenAIStreamResult)
+ if ok {
+ openAIStream = oaiResult.Chan()
+ } else {
+ openAIStream = streamOutAny.(<-chan openai.StreamEvent)
+ }
// Consume all stream events to trigger completion.
for ev := range openAIStream {
diff --git a/internal/e2e/visual_chat_e2e_test.go b/internal/e2e/visual_chat_e2e_test.go
index 14718269..493c34a2 100644
--- a/internal/e2e/visual_chat_e2e_test.go
+++ b/internal/e2e/visual_chat_e2e_test.go
@@ -36,9 +36,9 @@ func TestVisualOnOpenAIChat_OrchestratesBriefAcrossTwoMocks(t *testing.T) {
ctx := context.Background()
type observed struct {
- mu sync.Mutex
- bodies [][]byte
- rounds int
+ mu sync.Mutex
+ bodies [][]byte
+ rounds int
}
upstreamObs := &observed{}
visualObs := &observed{}
diff --git a/internal/e2e/websearch_injection_e2e_test.go b/internal/e2e/websearch_injection_e2e_test.go
index 68cf1866..b5e41179 100644
--- a/internal/e2e/websearch_injection_e2e_test.go
+++ b/internal/e2e/websearch_injection_e2e_test.go
@@ -10,8 +10,8 @@ import (
"net/http/httptest"
"testing"
- "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/protocol/openai"
)
@@ -109,8 +109,8 @@ func TestWebSearchE2E_InjectionEnabled(t *testing.T) {
// Step 1: Build an OpenAI ResponsesRequest WITHOUT web_search in tools.
openAIReq := openai.ResponsesRequest{
- Model: "claude-3.5-sonnet",
- Input: json.RawMessage(`"Search for latest AI breakthroughs"`),
+ Model: "claude-3.5-sonnet",
+ Input: json.RawMessage(`"Search for latest AI breakthroughs"`),
MaxOutputTokens: 100,
}
@@ -253,8 +253,8 @@ func TestWebSearchE2E_InjectionDisabled(t *testing.T) {
// Step 1: Build OpenAI request without web_search.
openAIReq := openai.ResponsesRequest{
- Model: "claude-3.5-sonnet",
- Input: json.RawMessage(`"Hello"`),
+ Model: "claude-3.5-sonnet",
+ Input: json.RawMessage(`"Hello"`),
MaxOutputTokens: 100,
}
@@ -367,7 +367,7 @@ func TestWebSearchE2E_AlreadyPresentNotOverwritten(t *testing.T) {
Name: "get_weather",
Description: "Get the current weather for a city",
Parameters: map[string]any{
- "type": "object",
+ "type": "object",
"properties": map[string]any{
"city": map[string]any{"type": "string"},
},
diff --git a/internal/extension/codex_tool_proxy/plugin.go b/internal/extension/codex_tool_proxy/plugin.go
index 863bd75d..681d3b63 100644
--- a/internal/extension/codex_tool_proxy/plugin.go
+++ b/internal/extension/codex_tool_proxy/plugin.go
@@ -54,7 +54,7 @@ func DisablePatchProxyForModel(cfg config.Config, model string) bool {
func ConfigSpecs() []config.ExtensionConfigSpec {
return []config.ExtensionConfigSpec{{
Name: PluginName,
- DefaultEnabled: true,
+ DefaultEnabled: false,
Scopes: []config.ExtensionScope{
config.ExtensionScopeGlobal,
config.ExtensionScopeProvider,
diff --git a/internal/extension/codex_tool_proxy/plugin_test.go b/internal/extension/codex_tool_proxy/plugin_test.go
index 7ef37212..a239bafb 100644
--- a/internal/extension/codex_tool_proxy/plugin_test.go
+++ b/internal/extension/codex_tool_proxy/plugin_test.go
@@ -11,8 +11,8 @@ func TestDisablePatchProxyForModelDefaultEnabled(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if DisablePatchProxyForModel(cfg, "test-model") {
- t.Fatal("expected proxy enabled by default")
+ if !DisablePatchProxyForModel(cfg, "test-model") {
+ t.Fatal("expected proxy disabled by default")
}
}
@@ -34,8 +34,8 @@ func TestDisablePatchProxyForModelRouteOverride(t *testing.T) {
if !DisablePatchProxyForModel(cfg, "specific") {
t.Fatal("expected proxy disabled via specific route override")
}
- if DisablePatchProxyForModel(cfg, "unmatched-model") {
- t.Fatal("expected proxy enabled for unmatched model")
+ if !DisablePatchProxyForModel(cfg, "unmatched-model") {
+ t.Fatal("expected proxy disabled for unmatched model (default)")
}
}
diff --git a/internal/extension/codextool/customtool.go b/internal/extension/codextool/customtool.go
index a83bf167..a52fea9e 100644
--- a/internal/extension/codextool/customtool.go
+++ b/internal/extension/codextool/customtool.go
@@ -114,6 +114,13 @@ func OutputItemFromBlock(
return "custom_tool_call", spec.OpenAIName, "", RebuildGrammar(blockName, toolInput), false, nil
case ToolRaw:
return "custom_tool_call", spec.OpenAIName, "", InputFromRaw(toolInput), false, nil
+ case ToolNestedOneOf, ToolNestedAnyOf:
+ action, paramsStr, err := DecodeNestedCall(toolInput, spec.Kind)
+ if err != nil || action == "" {
+ // Fallback: return raw input with blockName as the tool Name.
+ return "function_call", blockName, spec.Namespace, string(toolInput), false, nil
+ }
+ return "function_call", action, spec.Namespace, string(paramsStr), false, nil
case ToolFunction:
return "function_call", spec.OpenAIName, spec.Namespace, string(toolInput), false, nil
default:
diff --git a/internal/extension/codextool/namespace_schema.go b/internal/extension/codextool/namespace_schema.go
new file mode 100644
index 00000000..650c7a9a
--- /dev/null
+++ b/internal/extension/codextool/namespace_schema.go
@@ -0,0 +1,248 @@
+// Package codextool provides namespace tool flattening and nested schema building.
+package codextool
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "moonbridge/internal/format"
+)
+
+// NamespaceStrategy controls how namespace tools are converted for upstream providers.
+type NamespaceStrategy string
+
+const (
+ NestedOneOf NamespaceStrategy = "nested_oneof"
+ NestedAnyOf NamespaceStrategy = "nested_anyof"
+ Flat NamespaceStrategy = "flat"
+)
+
+// BuildNamespaceTools converts a namespace tool to one or more CoreTools
+// according to the given strategy.
+func BuildNamespaceTools(
+ toolNames []string,
+ toolMap map[string]format.CoreTool,
+ parentNamespace string,
+ strategy NamespaceStrategy,
+) ([]format.CoreTool, error) {
+ switch strategy {
+ case NestedOneOf:
+ return buildNestedOneOf(toolNames, toolMap, parentNamespace), nil
+ case NestedAnyOf:
+ return buildNestedAnyOf(toolNames, toolMap, parentNamespace), nil
+ case Flat:
+ return buildFlat(toolNames, toolMap, parentNamespace), nil
+ default:
+ return buildFlat(toolNames, toolMap, parentNamespace), nil
+ }
+}
+
+// buildNestedOneOf generates a single tool with a oneOf Schema.
+// Each sub-tool becomes a oneOf branch keyed by a single-value "action" enum.
+func buildNestedOneOf(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+ if len(toolNames) == 0 {
+ return nil
+ }
+
+ mergedName := namespace
+
+ oneOf := make([]map[string]any, 0, len(toolNames))
+ for _, name := range toolNames {
+ sub, ok := toolMap[name]
+ if !ok {
+ continue
+ }
+ props := make(map[string]any)
+ required := []string{"action"}
+ if sub.InputSchema != nil {
+ if p, ok := sub.InputSchema["properties"].(map[string]any); ok {
+ for k, v := range p {
+ props[k] = v
+ }
+ }
+ if r, ok := sub.InputSchema["required"].([]any); ok {
+ for _, rv := range r {
+ if rs, ok := rv.(string); ok {
+ required = append(required, rs)
+ }
+ }
+ }
+ }
+ // action field with single-value enum
+ props["action"] = map[string]any{
+ "type": "string",
+ "enum": []string{name},
+ }
+ branch := map[string]any{
+ "type": "object",
+ "title": name,
+ "properties": props,
+ "required": required,
+ "additionalProperties": false,
+ }
+ oneOf = append(oneOf, branch)
+ }
+
+ if len(oneOf) == 0 {
+ return nil
+ }
+
+ mergedSchema := map[string]any{
+ "type": "object",
+ "oneOf": oneOf,
+ }
+
+ ct := format.CoreTool{
+ Name: mergedName,
+ Description: fmt.Sprintf("Namespace tool with %d sub-tools. Pick the matching action.", len(oneOf)),
+ InputSchema: mergedSchema,
+ }
+ AnnotateCoreTool(&ct, ToolNestedOneOf, mergedName, namespace)
+ return []format.CoreTool{ct}
+}
+
+// buildNestedAnyOf generates a single tool with an action enum + params anyOf
+// (the PR #75 compatible format).
+func buildNestedAnyOf(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+ if len(toolNames) == 0 {
+ return nil
+ }
+
+ mergedName := namespace
+
+ actions := make([]string, 0, len(toolNames))
+ anyOf := make([]map[string]any, 0, len(toolNames))
+ for _, name := range toolNames {
+ sub, ok := toolMap[name]
+ if !ok {
+ continue
+ }
+ actions = append(actions, name)
+ branch := sub.InputSchema
+ if branch == nil {
+ branch = map[string]any{"type": "object"}
+ }
+ anyOf = append(anyOf, branch)
+ }
+
+ mergedSchema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "action": map[string]any{
+ "type": "string",
+ "enum": actions,
+ },
+ "params": map[string]any{
+ "oneOf": anyOf,
+ },
+ },
+ "required": []string{"action", "params"},
+ "additionalProperties": false,
+ }
+
+ ct := format.CoreTool{
+ Name: mergedName,
+ Description: fmt.Sprintf("Namespace tool with %d sub-tools. Use action to select, params for arguments.", len(anyOf)),
+ InputSchema: mergedSchema,
+ }
+ AnnotateCoreTool(&ct, ToolNestedAnyOf, mergedName, namespace)
+ return []format.CoreTool{ct}
+}
+
+// buildFlat flattens namespace tools into individual CoreTools.
+func buildFlat(toolNames []string, toolMap map[string]format.CoreTool, namespace string) []format.CoreTool {
+ result := make([]format.CoreTool, 0, len(toolNames))
+ for _, name := range toolNames {
+ sub, ok := toolMap[name]
+ if !ok {
+ continue
+ }
+ fullName := NamespacedToolName(namespace, name)
+ ct := format.CoreTool{
+ Name: fullName,
+ Description: sub.Description,
+ InputSchema: sub.InputSchema,
+ }
+ AnnotateCoreTool(&ct, ToolFunction, name, namespace)
+ result = append(result, ct)
+ }
+ return result
+}
+
+// DecodeNestedCall extracts the action name and parameters from a nested
+// namespace tool call, regardless of whether the model used the oneOf
+// or anyOf schema format.
+//
+// For oneOf format: {"action": "read_file", "path": "/foo", ...}
+//
+// → action = "read_file", params = {"path": "/foo", ...}
+//
+// For anyOf format: {"action": "read_file", "params": {"path": "/foo"}}
+//
+// → action = "read_file", params = {"path": "/foo"}
+func DecodeNestedCall(input json.RawMessage, schemaKind ToolKind) (action string, params json.RawMessage, err error) {
+ if len(input) == 0 || string(input) == "null" {
+ return "", nil, fmt.Errorf("empty input")
+ }
+
+ var raw map[string]json.RawMessage
+ if err := json.Unmarshal(input, &raw); err != nil {
+ return "", nil, fmt.Errorf("unmarshal: %w", err)
+ }
+
+ // Extract action
+ actionBytes, hasAction := raw["action"]
+ if !hasAction {
+ return "", input, fmt.Errorf("no action field")
+ }
+ if err := json.Unmarshal(actionBytes, &action); err != nil {
+ return "", input, fmt.Errorf("action parse: %w", err)
+ }
+
+ // Extract params based on schema format
+ switch schemaKind {
+ case ToolNestedOneOf:
+ // oneOf format: action is inline with other params
+ delete(raw, "action")
+ if len(raw) == 0 {
+ return action, json.RawMessage(`{}`), nil
+ }
+ params, err = json.Marshal(raw)
+ return action, params, err
+
+ case ToolNestedAnyOf:
+ // anyOf format: params are in a separate "params" field
+ if paramsRaw, ok := raw["params"]; ok {
+ return action, paramsRaw, nil
+ }
+ return action, json.RawMessage(`{}`), nil
+
+ default:
+ return action, input, nil
+ }
+}
+
+// tryExtractAction scans partial JSON for "action": "value" and returns the value.
+// Returns ("", false) if the action field is not yet complete.
+func TryExtractAction(raw string) (string, bool) {
+ idx := strings.Index(raw, `"action"`)
+ if idx < 0 {
+ return "", false
+ }
+ rest := raw[idx+8:] // skip past "action"
+ // Skip whitespace and colon
+ rest = strings.TrimSpace(strings.TrimPrefix(rest, ":"))
+ rest = strings.TrimSpace(rest)
+ // Find opening quote
+ if len(rest) == 0 || rest[0] != '"' {
+ return "", false
+ }
+ rest = rest[1:] // skip opening quote
+ // Find closing quote
+ q2 := strings.IndexByte(rest, '"')
+ if q2 < 0 {
+ return "", false
+ }
+ return rest[:q2], true
+}
diff --git a/internal/extension/codextool/tool_context.go b/internal/extension/codextool/tool_context.go
index cd8561a3..a199e0bd 100644
--- a/internal/extension/codextool/tool_context.go
+++ b/internal/extension/codextool/tool_context.go
@@ -11,12 +11,14 @@ package codextool
type ToolKind string
const (
- ToolApplyPatch ToolKind = "apply_patch"
- ToolExec ToolKind = "exec"
- ToolRaw ToolKind = "raw"
- ToolFunction ToolKind = "function"
- ToolLocalShell ToolKind = "local_shell"
- ToolUnknown ToolKind = "unknown"
+ ToolApplyPatch ToolKind = "apply_patch"
+ ToolExec ToolKind = "exec"
+ ToolRaw ToolKind = "raw"
+ ToolFunction ToolKind = "function"
+ ToolLocalShell ToolKind = "local_shell"
+ ToolNestedOneOf ToolKind = "nested_oneof"
+ ToolNestedAnyOf ToolKind = "nested_namespace"
+ ToolUnknown ToolKind = "unknown"
)
// ToolSpec describes an expanded tool entry for reverse mapping.
diff --git a/internal/extension/db/d1/plugin.go b/internal/extension/db/d1/plugin.go
index 1ebce296..ef14b174 100644
--- a/internal/extension/db/d1/plugin.go
+++ b/internal/extension/db/d1/plugin.go
@@ -20,9 +20,9 @@ import (
"database/sql"
"fmt"
- "moonbridge/internal/extension/plugin"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ "moonbridge/internal/extension/plugin"
)
const PluginName = "db_d1"
diff --git a/internal/extension/db/d1/plugin_test.go b/internal/extension/db/d1/plugin_test.go
index c1c7bc12..cb3a2e53 100644
--- a/internal/extension/db/d1/plugin_test.go
+++ b/internal/extension/db/d1/plugin_test.go
@@ -5,10 +5,10 @@ import (
"database/sql"
"testing"
- dbd1 "moonbridge/internal/extension/db/d1"
- "moonbridge/internal/extension/plugin"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ dbd1 "moonbridge/internal/extension/db/d1"
+ "moonbridge/internal/extension/plugin"
_ "modernc.org/sqlite"
)
diff --git a/internal/extension/db/sqlite/plugin.go b/internal/extension/db/sqlite/plugin.go
index 430e163e..1bfa93a7 100644
--- a/internal/extension/db/sqlite/plugin.go
+++ b/internal/extension/db/sqlite/plugin.go
@@ -19,9 +19,9 @@ import (
"fmt"
"path/filepath"
- "moonbridge/internal/extension/plugin"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ "moonbridge/internal/extension/plugin"
)
const PluginName = "db_sqlite"
diff --git a/internal/extension/db/sqlite/plugin_test.go b/internal/extension/db/sqlite/plugin_test.go
index 05a0e8c0..f623cc79 100644
--- a/internal/extension/db/sqlite/plugin_test.go
+++ b/internal/extension/db/sqlite/plugin_test.go
@@ -4,9 +4,9 @@ import (
"context"
"testing"
+ "moonbridge/internal/config"
dbsqlite "moonbridge/internal/extension/db/sqlite"
"moonbridge/internal/extension/plugin"
- "moonbridge/internal/config"
)
func TestName(t *testing.T) {
diff --git a/internal/extension/deepseek_v4/deepseek_v4.go b/internal/extension/deepseek_v4/deepseek_v4.go
index cd5bc2bb..295c7ee3 100644
--- a/internal/extension/deepseek_v4/deepseek_v4.go
+++ b/internal/extension/deepseek_v4/deepseek_v4.go
@@ -5,9 +5,9 @@ import (
"encoding/json"
"strings"
- "moonbridge/internal/protocol/openai"
- "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/anthropic"
+ "moonbridge/internal/protocol/openai"
)
// StripReasoningContent removes the reasoning_content field from message
diff --git a/internal/extension/deepseek_v4/deepseek_v4_test.go b/internal/extension/deepseek_v4/deepseek_v4_test.go
index 66158b2f..59baa9e9 100644
--- a/internal/extension/deepseek_v4/deepseek_v4_test.go
+++ b/internal/extension/deepseek_v4/deepseek_v4_test.go
@@ -7,10 +7,10 @@ import (
"strings"
"testing"
- "moonbridge/internal/format"
pluginpkg "moonbridge/internal/extension/plugin"
- "moonbridge/internal/protocol/openai"
+ "moonbridge/internal/format"
"moonbridge/internal/protocol/anthropic"
+ "moonbridge/internal/protocol/openai"
)
func TestStripReasoningContentStripsField(t *testing.T) {
@@ -144,8 +144,8 @@ func TestPrependThinkingWarnsWhenUsingRequiredFallback(t *testing.T) {
logs.Reset()
summary := []openai.ReasoningItemSummary{{
- Type: "summary_text",
- Text: EncodeThinkingSummary(format.CoreContentBlock{Type: "reasoning", ReasoningSignature: "sig_summary"}),
+ Type: "summary_text",
+ Text: EncodeThinkingSummary(format.CoreContentBlock{Type: "reasoning", ReasoningSignature: "sig_summary"}),
}}
got = p.PrependThinkingForToolUse([]format.CoreMessage{{
Role: "assistant",
diff --git a/internal/extension/kimi_workaround/plugin_test.go b/internal/extension/kimi_workaround/plugin_test.go
index 33b07ed1..4066d47f 100644
--- a/internal/extension/kimi_workaround/plugin_test.go
+++ b/internal/extension/kimi_workaround/plugin_test.go
@@ -230,4 +230,3 @@ func buildToolConversation(rounds int) []format.CoreMessage {
}
return msgs
}
-
diff --git a/internal/extension/metrics/plugin.go b/internal/extension/metrics/plugin.go
index cc2da6a2..8063dd14 100644
--- a/internal/extension/metrics/plugin.go
+++ b/internal/extension/metrics/plugin.go
@@ -23,9 +23,9 @@ import (
"strconv"
"time"
- "moonbridge/internal/extension/plugin"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ "moonbridge/internal/extension/plugin"
)
const PluginName = "metrics"
diff --git a/internal/extension/metrics/plugin_test.go b/internal/extension/metrics/plugin_test.go
index 35b81449..f5d12b3d 100644
--- a/internal/extension/metrics/plugin_test.go
+++ b/internal/extension/metrics/plugin_test.go
@@ -7,10 +7,10 @@ import (
"testing"
"time"
- mbtrics "moonbridge/internal/extension/metrics"
- "moonbridge/internal/extension/plugin"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ mbtrics "moonbridge/internal/extension/metrics"
+ "moonbridge/internal/extension/plugin"
_ "modernc.org/sqlite"
)
diff --git a/internal/extension/metrics/store_test.go b/internal/extension/metrics/store_test.go
index efe2026a..a471df0b 100644
--- a/internal/extension/metrics/store_test.go
+++ b/internal/extension/metrics/store_test.go
@@ -6,8 +6,8 @@ import (
"testing"
"time"
- mbtrics "moonbridge/internal/extension/metrics"
"moonbridge/internal/db"
+ mbtrics "moonbridge/internal/extension/metrics"
_ "modernc.org/sqlite"
)
diff --git a/internal/extension/plugin/capabilities.go b/internal/extension/plugin/capabilities.go
index 3f3bff56..6206575a 100644
--- a/internal/extension/plugin/capabilities.go
+++ b/internal/extension/plugin/capabilities.go
@@ -3,12 +3,12 @@ package plugin
import (
"context"
"encoding/json"
- "net/http"
"moonbridge/internal/protocol/anthropic"
+ "net/http"
"time"
- "moonbridge/internal/logger"
"moonbridge/internal/format"
+ "moonbridge/internal/logger"
"moonbridge/internal/protocol/openai"
foundationdb "moonbridge/internal/db"
@@ -91,7 +91,7 @@ type StreamEvent struct {
Type string // "block_start", "block_delta", "block_stop"
Index int
Block *format.CoreContentBlock // for block_start
- Delta anthropic.StreamDelta // for block_delta
+ Delta anthropic.StreamDelta // for block_delta
}
// --- Error handling ---
@@ -213,10 +213,8 @@ type CoreContentRememberer interface {
RememberCoreContent(ctx context.Context, content []format.CoreContentBlock)
}
-
// PatchProxyDecider is implemented by plugins that control whether apply_patch
// custom tools are expanded into structured proxy tools for upstream models.
type PatchProxyDecider interface {
DisablePatchProxy(model string) bool
}
-
diff --git a/internal/extension/visual/chat_strip.go b/internal/extension/visual/chat_strip.go
index 029c4698..d5af7548 100644
--- a/internal/extension/visual/chat_strip.go
+++ b/internal/extension/visual/chat_strip.go
@@ -23,8 +23,17 @@ func StripImagesFromChat(req chat.ChatRequest) (chat.ChatRequest, bool) {
for mi := range out.Messages {
msg := &out.Messages[mi]
- parts, ok := msg.Content.([]chat.ContentPart)
- if !ok {
+ var originalIsString bool
+ var parts []chat.ContentPart
+ switch v := msg.Content.(type) {
+ case string:
+ originalIsString = true
+ if v != "" {
+ parts = []chat.ContentPart{{Type: "text", Text: v}}
+ }
+ case []chat.ContentPart:
+ parts = v
+ default:
continue
}
newParts := make([]chat.ContentPart, 0, len(parts))
@@ -40,7 +49,11 @@ func StripImagesFromChat(req chat.ChatRequest) (chat.ChatRequest, bool) {
}
newParts = append(newParts, part)
}
- msg.Content = newParts
+ if originalIsString && !modified && len(newParts) == 1 && newParts[0].Type == "text" {
+ msg.Content = newParts[0].Text
+ } else {
+ msg.Content = newParts
+ }
}
return out, modified
}
diff --git a/internal/extension/visual/client.go b/internal/extension/visual/client.go
index def1708e..c9a3e592 100644
--- a/internal/extension/visual/client.go
+++ b/internal/extension/visual/client.go
@@ -8,7 +8,7 @@ import (
"moonbridge/internal/format"
)
-const visualSystemPrompt = "You are Kimi running behind Moon Bridge Visual. Analyze images carefully, state uncertainty, and do not invent visual facts."
+const visualSystemPrompt = "You are a vision analysis model behind Moon Bridge Visual. Analyze images carefully, state uncertainty, and do not invent visual facts."
// CoreProvider is a protocol-agnostic LLM provider interface.
// It operates on format.CoreRequest / format.CoreResponse so the visual plugin
diff --git a/internal/extension/visual/core_orchestrator.go b/internal/extension/visual/core_orchestrator.go
index 0b87bfe4..da6ed3a6 100644
--- a/internal/extension/visual/core_orchestrator.go
+++ b/internal/extension/visual/core_orchestrator.go
@@ -88,9 +88,33 @@ func (o *CoreOrchestrator) CreateCore(ctx context.Context, req *format.CoreReque
}
toolUses, nonVisual := coreSplitVisualToolUses(lastAssistant.Content)
- if len(nonVisual) > 0 || len(toolUses) == 0 {
+ if len(toolUses) == 0 {
return applyCoreUsageAggregation(resp, aggregatedUsage, hasAggregatedUsage), nil
}
+ if len(nonVisual) > 0 {
+ // Execute visual tools, feed results back to model, and continue.
+ // Non-visual tool_uses remain in the assistant message so the Bridge
+ // can process them on the next round.
+ toolResults := make([]format.CoreContentBlock, 0, len(toolUses))
+ for _, toolUse := range toolUses {
+ result := o.executeCoreVisualTool(ctx, toolUse, availableImages)
+ toolResults = append(toolResults, format.CoreContentBlock{
+ Type: "tool_result",
+ ToolUseID: toolUse.ToolUseID,
+ ToolResultContent: []format.CoreContentBlock{{Type: "text", Text: result}},
+ })
+ }
+ req.Messages = append(req.Messages, *lastAssistant)
+ req.Messages = append(req.Messages, format.CoreMessage{
+ Role: "tool",
+ Content: toolResults,
+ })
+ if req.ToolChoice != nil && req.ToolChoice.Mode != "auto" {
+ req.ToolChoice = &format.CoreToolChoice{Mode: "auto"}
+ }
+ log.Debug("Core visual mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual))
+ continue
+ }
// Execute each visual tool via the vision client.
toolResults := make([]format.CoreContentBlock, 0, len(toolUses))
@@ -106,7 +130,7 @@ func (o *CoreOrchestrator) CreateCore(ctx context.Context, req *format.CoreReque
// Append assistant message and tool_result message for next round.
req.Messages = append(req.Messages, *lastAssistant)
req.Messages = append(req.Messages, format.CoreMessage{
- Role: "user",
+ Role: "tool",
Content: toolResults,
})
@@ -174,26 +198,50 @@ func imageInputFromCoreBlock(block format.CoreContentBlock) (ImageInput, bool) {
if block.ImageData == "" {
return ImageInput{}, false
}
+ // If MediaType is explicitly set, treat as base64.
if block.MediaType != "" {
- // base64-encoded image
return ImageInput{Data: block.ImageData, MimeType: block.MediaType}, true
}
- // URL-based image (ImageData holds the URL when MediaType is empty)
+ // Check for data: URL (contains embedded MIME type).
+ if strings.HasPrefix(block.ImageData, "data:") {
+ mediaType, raw := splitDataURL(block.ImageData)
+ return ImageInput{Data: raw, MimeType: mediaType}, true
+ }
+ // URL-based image (ImageData holds the URL when MediaType is empty).
url := strings.TrimSpace(block.ImageData)
- if !isSupportedImageURL(url) {
- return ImageInput{}, false
+ if isSupportedImageURL(url) {
+ return ImageInput{URL: url}, true
}
- return ImageInput{URL: url}, true
+ // Fallback: treat as base64 with default MIME type rather than silently dropping.
+ return ImageInput{Data: block.ImageData, MimeType: "image/png"}, true
}
// findLastAssistantMessage finds the last assistant message in a slice.
+// Prefers assistant messages that contain tool_use blocks, searching backward
+// so the most recent tool-use-bearing assistant message is returned.
func findLastAssistantMessage(messages []format.CoreMessage) *format.CoreMessage {
+ var candidate *format.CoreMessage
for i := len(messages) - 1; i >= 0; i-- {
- if messages[i].Role == "assistant" {
+ if messages[i].Role != "assistant" {
+ continue
+ }
+ // Check if this assistant message has any tool_use content blocks.
+ hasToolUse := false
+ for _, block := range messages[i].Content {
+ if block.Type == "tool_use" {
+ hasToolUse = true
+ break
+ }
+ }
+ if hasToolUse {
return &messages[i]
}
+ // Fallback: remember the last assistant message if we don't find a better one.
+ if candidate == nil {
+ candidate = &messages[i]
+ }
}
- return nil
+ return candidate
}
// coreSplitVisualToolUses separates visual tool_use blocks from non-visual ones.
diff --git a/internal/extension/visual/core_orchestrator_test.go b/internal/extension/visual/core_orchestrator_test.go
index cddf80ba..76b06eda 100644
--- a/internal/extension/visual/core_orchestrator_test.go
+++ b/internal/extension/visual/core_orchestrator_test.go
@@ -899,7 +899,7 @@ func TestCoreOrchestratorIsolatesRequestsAcrossRounds(t *testing.T) {
}
toolResultMsg := secondReq.Messages[2]
- if toolResultMsg.Role != "user" {
+ if toolResultMsg.Role != "tool" {
t.Fatalf("tool_result message role = %q", toolResultMsg.Role)
}
if len(toolResultMsg.Content) != 1 || toolResultMsg.Content[0].Type != "tool_result" || toolResultMsg.Content[0].ToolUseID != "toolu_1" {
diff --git a/internal/extension/visual/legacy.go b/internal/extension/visual/legacy.go
index 392020cb..caf0b124 100644
--- a/internal/extension/visual/legacy.go
+++ b/internal/extension/visual/legacy.go
@@ -101,6 +101,7 @@ func textFromContent(blocks []anthropic.ContentBlock) string {
}
// HasAnthropicSource checks if ImageInput can produce a valid Anthropic source.
+// Deprecated: use image.AnthropicSource() != nil directly.
func (image ImageInput) HasAnthropicSource() bool {
return image.AnthropicSource() != nil
}
@@ -132,17 +133,7 @@ func (image ImageInput) AnthropicSource() *anthropic.ImageSource {
}
func dataURLSource(value string) *anthropic.ImageSource {
- header, data, ok := strings.Cut(value, ",")
- if !ok {
- return nil
- }
- mediaType := strings.TrimPrefix(header, "data:")
- if semicolon := strings.IndexByte(mediaType, ';'); semicolon >= 0 {
- mediaType = mediaType[:semicolon]
- }
- if mediaType == "" {
- mediaType = "image/png"
- }
+ mediaType, data := splitDataURL(value)
return &anthropic.ImageSource{Type: "base64", MediaType: mediaType, Data: data}
}
diff --git a/internal/extension/visual/orchestrator.go b/internal/extension/visual/orchestrator.go
index 5e24d30f..d2b4c155 100644
--- a/internal/extension/visual/orchestrator.go
+++ b/internal/extension/visual/orchestrator.go
@@ -71,9 +71,32 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR
}
toolUses, nonVisual := splitVisualToolUses(resp.Content)
- if len(nonVisual) > 0 || len(toolUses) == 0 {
+ if len(toolUses) == 0 {
return resp, nil
}
+ if len(nonVisual) > 0 {
+ // Execute visual tools, feed results to model.
+ toolResults := make([]anthropic.ContentBlock, 0, len(toolUses))
+ for _, toolUse := range toolUses {
+ result := o.executeVisualTool(ctx, toolUse, availableImages)
+ toolResults = append(toolResults, anthropic.ContentBlock{
+ Type: "tool_result",
+ ToolUseID: toolUse.ID,
+ Content: result,
+ })
+ }
+ req.Messages = append(req.Messages, anthropic.Message{
+ Role: "assistant",
+ Content: resp.Content,
+ })
+ req.Messages = append(req.Messages, anthropic.Message{
+ Role: "user",
+ Content: toolResults,
+ })
+ req.ToolChoice = &anthropic.ToolChoice{Type: "auto"}
+ log.Debug("Visual mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual))
+ continue
+ }
toolResults := make([]anthropic.ContentBlock, 0, len(toolUses))
for _, toolUse := range toolUses {
@@ -128,12 +151,35 @@ func (o *Orchestrator) StreamMessage(ctx context.Context, req anthropic.MessageR
assistantContent := collectContentFromEvents(events)
toolUses, nonVisual := splitVisualToolUses(assistantContent)
- if len(nonVisual) > 0 || len(toolUses) == 0 {
+ if len(toolUses) == 0 {
if lastUsage != nil {
allEvents = injectUsageIntoStart(allEvents, *lastUsage)
}
return &staticStream{events: allEvents}, nil
}
+ if len(nonVisual) > 0 {
+ // Execute visual tools, feed results to model.
+ toolResults := make([]anthropic.ContentBlock, 0, len(toolUses))
+ for _, toolUse := range toolUses {
+ result := o.executeVisualTool(ctx, toolUse, availableImages)
+ toolResults = append(toolResults, anthropic.ContentBlock{
+ Type: "tool_result",
+ ToolUseID: toolUse.ID,
+ Content: result,
+ })
+ }
+ req.Messages = append(req.Messages, anthropic.Message{
+ Role: "assistant",
+ Content: assistantContent,
+ })
+ req.Messages = append(req.Messages, anthropic.Message{
+ Role: "user",
+ Content: toolResults,
+ })
+ req.ToolChoice = &anthropic.ToolChoice{Type: "auto"}
+ log.Debug("Visual stream mixed tool loop", "round", round+1, "visual_tools", len(toolUses), "non_visual", len(nonVisual))
+ continue
+ }
toolResults := make([]anthropic.ContentBlock, 0, len(toolUses))
for _, toolUse := range toolUses {
@@ -384,7 +430,7 @@ func normalizeImages(single string, urls []string, images []ImageInput, refs []s
}
continue
}
- if !image.HasAnthropicSource() {
+ if image.AnthropicSource() == nil {
continue
}
normalized = append(normalized, image)
diff --git a/internal/extension/visual/orchestrator_test.go b/internal/extension/visual/orchestrator_test.go
index b5101749..d30dbd9f 100644
--- a/internal/extension/visual/orchestrator_test.go
+++ b/internal/extension/visual/orchestrator_test.go
@@ -429,4 +429,3 @@ func TestStripImagesFromAnthropic_MultiTurnDoesNotRestoreImages(t *testing.T) {
t.Fatal("second user message was modified")
}
}
-
diff --git a/internal/extension/visual/tools.go b/internal/extension/visual/tools.go
index 7866c84f..1f47d3df 100644
--- a/internal/extension/visual/tools.go
+++ b/internal/extension/visual/tools.go
@@ -1,8 +1,8 @@
package visual
import (
- "moonbridge/internal/protocol/anthropic"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/anthropic"
)
const (
diff --git a/internal/extension/websearch/firecrawl.go b/internal/extension/websearch/firecrawl.go
index 2c1fa561..7451dedb 100644
--- a/internal/extension/websearch/firecrawl.go
+++ b/internal/extension/websearch/firecrawl.go
@@ -26,7 +26,7 @@ func NewFirecrawlClient(apiKey string) *FirecrawlClient {
}
}
-// Fetch scrapes a URL and returns its content as markdown.
+// Fetch scrapes a URL and returns its content as markdown with retry on transient errors.
func (c *FirecrawlClient) Fetch(ctx context.Context, req FetchRequest) (*FetchResult, error) {
if len(req.Formats) == 0 {
req.Formats = []string{"markdown"}
@@ -40,36 +40,54 @@ func (c *FirecrawlClient) Fetch(ctx context.Context, req FetchRequest) (*FetchRe
return nil, fmt.Errorf("marshal fetch request: %w", err)
}
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, firecrawlBaseURL+"/v1/scrape", bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("create fetch request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+ var lastErr error
+ for attempt := 0; attempt < 3; attempt++ {
+ if attempt > 0 {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(1<= 600) {
+ return nil, lastErr
+ }
}
-
- var result FetchResult
- if err := json.Unmarshal(respBody, &result); err != nil {
- return nil, fmt.Errorf("unmarshal fetch response: %w", err)
- }
- return &result, nil
+ return nil, fmt.Errorf("firecrawl fetch failed after 3 attempts: %w", lastErr)
}
// Enabled returns whether the Firecrawl client is configured with a valid API key.
diff --git a/internal/extension/websearch/orchestrator.go b/internal/extension/websearch/orchestrator.go
index a93ecf44..1e832b8f 100644
--- a/internal/extension/websearch/orchestrator.go
+++ b/internal/extension/websearch/orchestrator.go
@@ -107,9 +107,27 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR
searchUses := o.filterSearchTools(toolUses)
nonSearchUses := subtractToolUses(toolUses, searchUses)
- // If there are non-search tool_use blocks, return the response
- // so the caller (Bridge) can handle them as normal tool calls.
+ // If there are non-search tool_use blocks, execute any pending search
+ // tools and return the response with search tool_uses filtered out.
+ // This prevents injected search tools (tavily_search, firecrawl_fetch)
+ // from leaking to the client through the Bridge.
if len(nonSearchUses) > 0 {
+ // Execute search tools first as a side effect.
+ for _, tu := range searchUses {
+ _, execErr := o.executeSearch(ctx, tu)
+ if execErr != nil {
+ log.Warn("搜索执行失败(混合调用)", "tool", tu.Name, "error", execErr)
+ }
+ }
+ // Filter search tool_uses from the response content.
+ filtered := make([]anthropic.ContentBlock, 0, len(resp.Content))
+ for _, block := range resp.Content {
+ if block.Type == "tool_use" && o.isSearchTool(block.Name) {
+ continue
+ }
+ filtered = append(filtered, block)
+ }
+ resp.Content = filtered
return resp, nil
}
@@ -117,25 +135,7 @@ func (o *Orchestrator) CreateMessage(ctx context.Context, req anthropic.MessageR
return resp, nil
}
- // Execute search/fetch calls and build tool results.
- toolResults := make([]anthropic.ContentBlock, 0, len(searchUses))
- for _, tu := range searchUses {
- result, execErr := o.executeSearch(ctx, tu)
- if execErr != nil {
- log.Warn("搜索执行失败", "tool", tu.Name, "error", execErr)
- toolResults = append(toolResults, anthropic.ContentBlock{
- Type: "tool_result",
- ToolUseID: tu.ID,
- Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())),
- })
- continue
- }
- toolResults = append(toolResults, anthropic.ContentBlock{
- Type: "tool_result",
- ToolUseID: tu.ID,
- Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))),
- })
- }
+ toolResults := o.buildToolResults(ctx, searchUses)
// Append the assistant message (with search tool_use blocks) and
// user message (with tool_results) to the request for the next round.
@@ -205,34 +205,31 @@ func (o *Orchestrator) StreamMessage(ctx context.Context, req anthropic.MessageR
searchUses := o.filterSearchTools(toolUses)
nonSearchUses := subtractToolUses(toolUses, searchUses)
- if len(nonSearchUses) > 0 || len(searchUses) == 0 {
+ if len(searchUses) == 0 {
allEvents = events
if lastUsage != nil {
allEvents = injectUsageIntoStart(allEvents, *lastUsage)
}
return &staticStream{events: allEvents}, nil
}
-
- // Execute searches and build follow-up request.
- toolResults := make([]anthropic.ContentBlock, 0, len(searchUses))
- for _, tu := range searchUses {
- result, execErr := o.executeSearch(ctx, tu)
- if execErr != nil {
- log.Warn("流式搜索执行失败", "tool", tu.Name, "error", execErr)
- toolResults = append(toolResults, anthropic.ContentBlock{
- Type: "tool_result",
- ToolUseID: tu.ID,
- Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())),
- })
- continue
+ if len(nonSearchUses) > 0 {
+ // Execute search tools as side effect, but return only non-search content.
+ for _, tu := range searchUses {
+ _, execErr := o.executeSearch(ctx, tu)
+ if execErr != nil {
+ log.Warn("流式搜索执行失败(混合调用)", "tool", tu.Name, "error", execErr)
+ }
}
- toolResults = append(toolResults, anthropic.ContentBlock{
- Type: "tool_result",
- ToolUseID: tu.ID,
- Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))),
- })
+ // Filter search tool_uses from the returned events.
+ allEvents = events
+ if lastUsage != nil {
+ allEvents = injectUsageIntoStart(allEvents, *lastUsage)
+ }
+ return &staticStream{events: allEvents}, nil
}
+ toolResults := o.buildToolResults(ctx, searchUses)
+
req.Messages = append(req.Messages, anthropic.Message{
Role: "assistant",
Content: collectContentFromEvents(events),
@@ -288,7 +285,7 @@ func (o *Orchestrator) executeTavilySearch(ctx context.Context, raw json.RawMess
if err != nil {
return "", err
}
- return formatTavilyResults(result), nil
+ return FormatTavilyResults(result), nil
}
func (o *Orchestrator) executeFirecrawlFetch(ctx context.Context, raw json.RawMessage) (string, error) {
@@ -310,7 +307,37 @@ func (o *Orchestrator) executeFirecrawlFetch(ctx context.Context, raw json.RawMe
if err != nil {
return "", err
}
- return formatFirecrawlResult(result), nil
+ return FormatFirecrawlResult(result), nil
+}
+
+// isSearchTool returns true if the tool name is a registered search handler.
+func (o *Orchestrator) isSearchTool(name string) bool {
+ _, ok := o.toolHandlers[name]
+ return ok
+}
+
+// buildToolResults executes search/fetch for each tool use and returns tool_result blocks.
+func (o *Orchestrator) buildToolResults(ctx context.Context, searchUses []anthropic.ContentBlock) []anthropic.ContentBlock {
+ log := slog.Default()
+ results := make([]anthropic.ContentBlock, 0, len(searchUses))
+ for _, tu := range searchUses {
+ result, execErr := o.executeSearch(ctx, tu)
+ if execErr != nil {
+ log.Warn("搜索执行失败", "tool", tu.Name, "error", execErr)
+ results = append(results, anthropic.ContentBlock{
+ Type: "tool_result",
+ ToolUseID: tu.ID,
+ Content: json.RawMessage(fmt.Sprintf(`"Search error: %s"`, execErr.Error())),
+ })
+ continue
+ }
+ results = append(results, anthropic.ContentBlock{
+ Type: "tool_result",
+ ToolUseID: tu.ID,
+ Content: json.RawMessage(fmt.Sprintf(`"%s"`, escapeForJSON(result))),
+ })
+ }
+ return results
}
// filterSearchTools returns tool_use blocks that are registered search handlers.
@@ -325,13 +352,13 @@ func (o *Orchestrator) filterSearchTools(toolUses []anthropic.ContentBlock) []an
}
// formatTavilyResults formats Tavily search results as a readable text block.
-func formatTavilyResults(result *SearchResult) string {
+func FormatTavilyResults(result *SearchResult) string {
var b strings.Builder
b.WriteString(fmt.Sprintf("Search results for %q:\n\n", result.Query))
if result.Answer != "" {
b.WriteString("Answer: ")
- b.WriteString(truncate(result.Answer, 2000))
+ b.WriteString(Truncate(result.Answer, 2000))
b.WriteString("\n\n")
}
@@ -341,19 +368,19 @@ func formatTavilyResults(result *SearchResult) string {
}
b.WriteString(fmt.Sprintf("%d. [%s](%s)\n", i+1, item.Title, item.URL))
b.WriteString(fmt.Sprintf(" Score: %.2f\n", item.Score))
- b.WriteString(fmt.Sprintf(" %s\n\n", truncate(item.Content, 500)))
+ b.WriteString(fmt.Sprintf(" %s\n\n", Truncate(item.Content, 500)))
}
return b.String()
}
// formatFirecrawlResult formats Firecrawl scrape results as a readable text block.
-func formatFirecrawlResult(result *FetchResult) string {
+func FormatFirecrawlResult(result *FetchResult) string {
var b strings.Builder
b.WriteString(fmt.Sprintf("Content from %s:\n\n", result.Data.Metadata.SourceURL))
if result.Data.Metadata.Title != "" {
b.WriteString(fmt.Sprintf("Title: %s\n\n", result.Data.Metadata.Title))
}
- b.WriteString(truncate(result.Data.Markdown, 8000))
+ b.WriteString(Truncate(result.Data.Markdown, 8000))
return b.String()
}
@@ -562,7 +589,7 @@ func (s *staticStream) Close() error {
return nil
}
-func truncate(s string, maxLen int) string {
+func Truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
@@ -570,11 +597,17 @@ func truncate(s string, maxLen int) string {
}
func escapeForJSON(s string) string {
- // Escape backslashes and double quotes for embedding in JSON strings.
- s = strings.ReplaceAll(s, "\\", "\\\\")
- s = strings.ReplaceAll(s, "\"", "\\\"")
- s = strings.ReplaceAll(s, "\n", "\\n")
- s = strings.ReplaceAll(s, "\r", "\\r")
- s = strings.ReplaceAll(s, "\t", "\\t")
- return s
+ // Use json.Marshal for proper Unicode/control character escaping.
+ // The go standard library handles all escaping rules according to RFC 8259.
+ encoded, err := json.Marshal(s)
+ if err != nil {
+ return s
+ }
+ // json.Marshal returns a quoted JSON string. Strip the surrounding quotes
+ // for embedding in the tool_result content template.
+ raw := string(encoded)
+ if len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
+ return raw[1 : len(raw)-1]
+ }
+ return raw
}
diff --git a/internal/extension/websearch/tavily.go b/internal/extension/websearch/tavily.go
index 388f74b6..ea43a91d 100644
--- a/internal/extension/websearch/tavily.go
+++ b/internal/extension/websearch/tavily.go
@@ -26,7 +26,7 @@ func NewTavilyClient(apiKey string) *TavilyClient {
}
}
-// Search executes a search query against the Tavily API.
+// Search executes a search query against the Tavily API with retry on transient errors.
func (c *TavilyClient) Search(ctx context.Context, req SearchRequest) (*SearchResult, error) {
if req.MaxResults <= 0 {
req.MaxResults = 5
@@ -40,34 +40,52 @@ func (c *TavilyClient) Search(ctx context.Context, req SearchRequest) (*SearchRe
return nil, fmt.Errorf("marshal search request: %w", err)
}
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilyBaseURL+"/search", bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("create search request: %w", err)
- }
- httpReq.Header.Set("Content-Type", "application/json")
- httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+ var lastErr error
+ for attempt := 0; attempt < 3; attempt++ {
+ if attempt > 0 {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(time.Duration(1<= 600) {
+ return nil, lastErr
+ }
}
-
- var result SearchResult
- if err := json.Unmarshal(respBody, &result); err != nil {
- return nil, fmt.Errorf("unmarshal search response: %w", err)
- }
- return &result, nil
+ return nil, fmt.Errorf("tavily search failed after 3 attempts: %w", lastErr)
}
diff --git a/internal/extension/websearchinjected/plugin.go b/internal/extension/websearchinjected/plugin.go
index 895368e6..3b22253a 100644
--- a/internal/extension/websearchinjected/plugin.go
+++ b/internal/extension/websearchinjected/plugin.go
@@ -2,10 +2,10 @@ package websearchinjected
import (
"moonbridge/internal/extension/plugin"
- "moonbridge/internal/protocol/anthropic"
- "moonbridge/internal/service/provider"
"moonbridge/internal/extension/websearch"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/anthropic"
+ "moonbridge/internal/service/provider"
)
const PluginName = "web_search_injected"
diff --git a/internal/format/adapter.go b/internal/format/adapter.go
index 40770a45..dc55008f 100644
--- a/internal/format/adapter.go
+++ b/internal/format/adapter.go
@@ -101,10 +101,27 @@ type ProviderStreamAdapter interface {
// ToCoreStream consumes an upstream stream source (e.g. anthropic.Stream)
// and returns a channel of CoreStreamEvent. The adapter is responsible for
// the read-loop inside a goroutine.
- ToCoreStream(ctx context.Context, src any) (<-chan CoreStreamEvent, error)
+ ToCoreStream(ctx context.Context, src any) (*StreamResult, error)
}
// ============================================================================
+
+// StreamResult wraps the output of a streaming invocation with per-stream
+// buffer access for trace capture and plugin state tracking.
+//
+// Using a concrete return type (rather than a bare channel) keeps the
+// buffer local to each streaming call, eliminating the data race that
+// occurred when buffers were stored on the shared adapter instance.
+type StreamResult struct {
+ // Events is the channel of CoreStreamEvent from the upstream stream.
+ Events <-chan CoreStreamEvent
+
+ // StreamBuffer returns the captured raw upstream events for trace,
+ // plugin reasoning replay, and other post-stream processing.
+ // Must only be called after the Events channel is fully consumed.
+ StreamBuffer func() []any
+}
+
// CorePluginHooks — protocol-agnostic plugin hooks operating on Core format
// ============================================================================
diff --git a/internal/format/registry_test.go b/internal/format/registry_test.go
index c274fdc6..f0d0da73 100644
--- a/internal/format/registry_test.go
+++ b/internal/format/registry_test.go
@@ -11,15 +11,23 @@ import (
type mockClient struct{ protocol string }
-func (m *mockClient) ClientProtocol() string { return m.protocol }
-func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { return &CoreRequest{}, nil }
-func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { return nil, nil }
+func (m *mockClient) ClientProtocol() string { return m.protocol }
+func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) {
+ return &CoreRequest{}, nil
+}
+func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) {
+ return nil, nil
+}
type mockProvider struct{ protocol string }
-func (m *mockProvider) ProviderProtocol() string { return m.protocol }
-func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { return nil, nil }
-func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { return &CoreResponse{}, nil }
+func (m *mockProvider) ProviderProtocol() string { return m.protocol }
+func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) {
+ return nil, nil
+}
+func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) {
+ return &CoreResponse{}, nil
+}
type mockClientStream struct{ protocol string }
@@ -31,7 +39,7 @@ func (m *mockClientStream) FromCoreStream(_ context.Context, _ *CoreRequest, _ <
type mockProviderStream struct{ protocol string }
func (m *mockProviderStream) ProviderProtocol() string { return m.protocol }
-func (m *mockProviderStream) ToCoreStream(_ context.Context, _ any) (<-chan CoreStreamEvent, error) {
+func (m *mockProviderStream) ToCoreStream(_ context.Context, _ any) (*StreamResult, error) {
return nil, nil
}
diff --git a/internal/format/types.go b/internal/format/types.go
index 085ba096..8b552279 100644
--- a/internal/format/types.go
+++ b/internal/format/types.go
@@ -118,7 +118,7 @@ type CoreToolChoice struct {
// nil = use provider defaults.
type CoreThinkingConfig struct {
Type string `json:"type,omitempty"` // "enabled" | "disabled"
- BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking
+ BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking
}
// CoreOutputConfig controls output generation behavior.
@@ -139,8 +139,8 @@ type CoreCacheControl struct {
// CoreRequest is the protocol-agnostic representation of an LLM request.
type CoreRequest struct {
- Model string `json:"model"`
- Messages []CoreMessage `json:"messages"`
+ Model string `json:"model"`
+ Messages []CoreMessage `json:"messages"`
System []CoreContentBlock `json:"system,omitempty"`
// Tools
@@ -172,7 +172,6 @@ type CoreRequest struct {
// Zero value (nil) = not set — adapter uses provider defaults.
GenerationConfig map[string]any `json:"generation_config,omitempty"`
-
// Thinking controls extended thinking/reasoning behavior (e.g. Anthropic extended thinking).
// nil = use provider defaults.
Thinking *CoreThinkingConfig `json:"thinking,omitempty"`
@@ -304,7 +303,6 @@ type CoreStreamEvent struct {
Extensions map[string]any `json:"extensions,omitempty"`
}
-
// StripImageData scans string content for base64-encoded image data (data:image URLs
// and raw PNG/JPEG base64 blobs) and replaces them with short placeholders.
// This prevents large image payloads from wasting tokens when sent to text-only models
@@ -329,7 +327,7 @@ func StripImageData(s string) string {
for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') {
end++
}
- if end > dataStart + 500 {
+ if end > dataStart+500 {
imgType := s[start+len(marker) : start+commaIdx]
result.WriteString(fmt.Sprintf("[Image data: %s, %d bytes]", imgType, end-dataStart))
pos = end
@@ -352,7 +350,7 @@ func StripImageData(s string) string {
for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') {
end++
}
- if end > start + 500 {
+ if end > start+500 {
result.WriteString(fmt.Sprintf("[Image data: png, %d bytes]", end-start))
pos = end
continue
@@ -370,7 +368,7 @@ func StripImageData(s string) string {
for end < len(s) && (isBase64Char(s[end]) || s[end] == '=') {
end++
}
- if end > start + 500 {
+ if end > start+500 {
result.WriteString(fmt.Sprintf("[Image data: jpeg, %d bytes]", end-start))
pos = end
continue
diff --git a/internal/logger/consumer.go b/internal/logger/consumer.go
index 4e485a31..ea95b060 100644
--- a/internal/logger/consumer.go
+++ b/internal/logger/consumer.go
@@ -46,10 +46,10 @@ func (s *consumeState) store(fn ConsumeFunc) {
// handlerAttrs / handlerGroups and merged into LogEntry.Attrs before
// dispatching to the consume pipeline, so consumers see the full context.
type consumeHandler struct {
- inner slog.Handler
- consume *consumeState
- handlerAttrs []slog.Attr // attrs from WithAttrs calls
- handlerGroups []string // groups from WithGroup calls
+ inner slog.Handler
+ consume *consumeState
+ handlerAttrs []slog.Attr // attrs from WithAttrs calls
+ handlerGroups []string // groups from WithGroup calls
}
// newConsumeHandler wraps the given handler with consume-function support.
@@ -129,9 +129,9 @@ func (h *consumeHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
copy(combined, h.handlerAttrs)
combined = append(combined, attrs...)
return &consumeHandler{
- inner: h.inner.WithAttrs(attrs),
- consume: h.consume,
- handlerAttrs: combined,
+ inner: h.inner.WithAttrs(attrs),
+ consume: h.consume,
+ handlerAttrs: combined,
handlerGroups: h.handlerGroups,
}
}
diff --git a/internal/protocol/anthropic/adapter.go b/internal/protocol/anthropic/adapter.go
index 5a09e767..ef39e74e 100644
--- a/internal/protocol/anthropic/adapter.go
+++ b/internal/protocol/anthropic/adapter.go
@@ -47,8 +47,14 @@ type AnthropicProviderAdapter struct {
cacheMgr CacheManager
hooks format.CorePluginHooks
- streamMu sync.Mutex
- streamEvents []StreamEvent
+ // cacheKeyMu guards cacheKeyStore, which maps context pointers to cache
+ // key/ttl pairs computed during PlanAndInject and consumed during ToCoreResponse.
+ cacheKeyMu sync.Mutex
+ cacheKeyStore map[string]cacheKeyEntry
+}
+
+type cacheKeyEntry struct {
+ key, ttl string
}
// NewAnthropicProviderAdapter creates a new AnthropicProviderAdapter.
@@ -57,9 +63,10 @@ type AnthropicProviderAdapter struct {
// if caching is not needed.
func NewAnthropicProviderAdapter(cfgMaxTokens int, cacheMgr CacheManager, hooks format.CorePluginHooks) *AnthropicProviderAdapter {
return &AnthropicProviderAdapter{
- cfgMaxTokens: cfgMaxTokens,
- cacheMgr: cacheMgr,
- hooks: hooks.WithDefaults(),
+ cfgMaxTokens: cfgMaxTokens,
+ cacheMgr: cacheMgr,
+ hooks: hooks.WithDefaults(),
+ cacheKeyStore: make(map[string]cacheKeyEntry),
}
}
@@ -294,6 +301,11 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for
Role: a.mapRole(msg.Role),
Content: a.toContentBlocks(msg.Content),
}
+ // Skip messages with no content blocks — they are empty and contribute
+ // no semantic value to the upstream API.
+ if len(anthroMsg.Content) == 0 {
+ continue
+ }
last := len(anthropicReq.Messages) - 1
if last >= 0 && anthroMsg.Role == "user" &&
anthropicReq.Messages[last].Role == "user" &&
@@ -306,6 +318,15 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for
}
}
+ // Ensure first message has role "user" — Anthropic API rejects requests
+ // where the first message is assistant, tool, or any non-user role.
+ if len(anthropicReq.Messages) > 0 && anthropicReq.Messages[0].Role != "user" {
+ anthropicReq.Messages = append(
+ []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "_"}}}},
+ anthropicReq.Messages...,
+ )
+ }
+
// Tools
if len(req.Tools) > 0 {
anthropicReq.Tools = make([]Tool, 0, len(req.Tools))
@@ -333,7 +354,13 @@ func (a *AnthropicProviderAdapter) FromCoreRequest(ctx context.Context, req *for
// Step 3: Cache planning via CacheManager.
// PlanAndInject may modify anthropicReq in-place by setting cache_control
// on tools, system blocks, messages, or the request-level field.
- a.cacheMgr.PlanAndInject(ctx, &anthropicReq, req)
+ key, ttl := a.cacheMgr.PlanAndInject(ctx, &anthropicReq, req)
+
+ // Store cache key/ttl for retrieval in ToCoreResponse.
+ ctxKey := fmt.Sprintf("%p", ctx)
+ a.cacheKeyMu.Lock()
+ a.cacheKeyStore[ctxKey] = cacheKeyEntry{key: key, ttl: ttl}
+ a.cacheKeyMu.Unlock()
return &anthropicReq, nil
}
@@ -384,10 +411,19 @@ func (a *AnthropicProviderAdapter) ToCoreResponse(ctx context.Context, resp any)
}
// Update cache registry from usage signals via CacheManager.
- // The key/ttl were computed during PlanAndInject and must be accessible
- // through the CacheManager's own state (or context).
+ // The key/ttl were computed during PlanAndInject and are retrieved
+ // from the per-request cache key store.
if a.cacheMgr != nil {
- a.cacheMgr.UpdateRegistry(ctx, "", "", msgResp.Usage)
+ ctxKey := fmt.Sprintf("%p", ctx)
+ a.cacheKeyMu.Lock()
+ entry, ok := a.cacheKeyStore[ctxKey]
+ delete(a.cacheKeyStore, ctxKey)
+ a.cacheKeyMu.Unlock()
+ if ok {
+ a.cacheMgr.UpdateRegistry(ctx, entry.key, entry.ttl, msgResp.Usage)
+ } else {
+ a.cacheMgr.UpdateRegistry(ctx, "", "", msgResp.Usage)
+ }
}
return coreResp, nil
@@ -420,15 +456,18 @@ type streamConverterState struct {
blockTypes map[int]string // content index → block type
blockSignatures map[int]string // content index → reasoning signature (from signature_delta)
finalUsage *format.CoreUsage // tracked from message_delta, passed to message_stop
- adapter *AnthropicProviderAdapter // for buffering raw stream events (trace)
+ adapter *AnthropicProviderAdapter // for plugin hooks
suppressText map[int]bool // text indices to suppress (server-side search status, etc.)
+ buf *[]StreamEvent // per-stream event buffer (local, not shared)
+ bufMu *sync.Mutex // guards buf
+ ctx context.Context // for context-aware channel sends
}
// ToCoreStream consumes an anthropic.Stream and returns a channel of CoreStreamEvent.
//
// The adapter owns the read-loop goroutine. The returned channel is closed when
// the stream ends, context is cancelled, or an error occurs.
-func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) {
+func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) {
stream, ok := src.(Stream)
if !ok {
return nil, fmt.Errorf("anthropic adapter: expected anthropic.Stream, got %T", src)
@@ -436,13 +475,14 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<
ctx = coreHookContext(ctx, "")
events := make(chan format.CoreStreamEvent, 64)
- // Initialize stream event buffer for trace capture.
- a.streamMu.Lock()
- a.streamEvents = make([]StreamEvent, 0, 64)
- a.streamMu.Unlock()
+ // Per-stream buffer — local to this call, not shared across concurrent requests.
+ var buf []StreamEvent
+ var bufMu sync.Mutex
+ bufReady := make(chan struct{})
go func() {
defer close(events)
+ defer close(bufReady)
defer stream.Close()
state := &streamConverterState{
@@ -450,6 +490,9 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<
blockSignatures: make(map[int]string),
adapter: a,
suppressText: make(map[int]bool),
+ buf: &buf,
+ bufMu: &bufMu,
+ ctx: ctx,
}
for {
@@ -462,10 +505,8 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<
ev, err := stream.Next()
if err != nil {
if err == io.EOF {
- // Stream ended normally.
return
}
- // Context cancellation is clean shutdown, not a failure.
if err == context.Canceled || err == context.DeadlineExceeded {
return
}
@@ -478,11 +519,30 @@ func (a *AnthropicProviderAdapter) ToCoreStream(ctx context.Context, src any) (<
return
}
+ // Check context immediately after Next() to close the race window.
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
state.convertEvent(events, ev)
}
}()
- return events, nil
+ return &format.StreamResult{
+ Events: events,
+ StreamBuffer: func() []any {
+ <-bufReady
+ bufMu.Lock()
+ defer bufMu.Unlock()
+ result := make([]any, len(buf))
+ for i, ev := range buf {
+ result[i] = ev
+ }
+ return result
+ },
+ }, nil
}
// =========================================================================
@@ -496,13 +556,24 @@ func (s *streamConverterState) nextSeq() int64 {
func (s *streamConverterState) emit(events chan<- format.CoreStreamEvent, ev format.CoreStreamEvent) {
ev.SeqNum = s.nextSeq()
- events <- ev
+ if s.ctx != nil {
+ select {
+ case <-s.ctx.Done():
+ case events <- ev:
+ }
+ } else {
+ events <- ev
+ }
}
func (s *streamConverterState) convertEvent(events chan<- format.CoreStreamEvent, ev StreamEvent) {
- // Buffer the original event for trace (max 4MB).
- if s.adapter != nil {
- s.adapter.bufferStreamEvent(ev)
+ // Buffer the original event for trace in the per-stream local buffer.
+ if s.bufMu != nil && s.buf != nil {
+ s.bufMu.Lock()
+ if len(*s.buf) < 1024 {
+ *s.buf = append(*s.buf, ev)
+ }
+ s.bufMu.Unlock()
}
switch ev.Type {
case "message_start":
@@ -741,7 +812,7 @@ func (a *AnthropicProviderAdapter) toContentBlock(b format.CoreContentBlock) Con
}
case "tool_result":
- var content any
+ var content any = ""
if len(b.ToolResultContent) > 0 {
content = a.toContentBlocks(b.ToolResultContent)
}
@@ -969,20 +1040,15 @@ func (a *AnthropicProviderAdapter) mapStopReasonToStatus(reason string) string {
// bufferStreamEvent buffers the raw anthropic stream event for trace capture,
// up to the 4MB limit. The event is JSON-marshalled to estimate its size.
func (a *AnthropicProviderAdapter) bufferStreamEvent(ev StreamEvent) {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- // Cap buffer at 1024 events (~4MB estimation) to prevent unbounded memory growth.
- if len(a.streamEvents) >= 1024 {
- return
- }
- a.streamEvents = append(a.streamEvents, ev)
+ // streamConverterState captures buf/bufMu from the goroutine closure.
+ // This is a no-op without a per-stream buffer — use the state.bufferStreamEvent instead.
}
// StreamBuffer returns the buffered stream events for trace capture.
func (a *AnthropicProviderAdapter) StreamBuffer() []StreamEvent {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- return a.streamEvents
+ // Deprecated: use StreamResult.StreamBuffer instead.
+ // This method will be removed after all callers migrate.
+ return nil
}
// RememberStreamContent stores response content from a stream for plugin state tracking.
diff --git a/internal/protocol/anthropic/adapter_test.go b/internal/protocol/anthropic/adapter_test.go
index ed057567..88d0e353 100644
--- a/internal/protocol/anthropic/adapter_test.go
+++ b/internal/protocol/anthropic/adapter_test.go
@@ -261,12 +261,17 @@ func TestFromCoreRequest_ToolUseContent(t *testing.T) {
}
msgReq := result.(*anthropic.MessageRequest)
- if len(msgReq.Messages) != 2 {
- t.Fatalf("got %d messages, want 2", len(msgReq.Messages))
+ if len(msgReq.Messages) != 3 {
+ t.Fatalf("got %d messages, want 3", len(msgReq.Messages))
+ }
+
+ // First message is the inserted user placeholder.
+ if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" {
+ t.Fatalf("expected user placeholder as first message")
}
// assistant tool_use block
- blocks0 := msgReq.Messages[0].Content
+ blocks0 := msgReq.Messages[1].Content
if len(blocks0) != 1 || blocks0[0].Type != "tool_use" {
t.Fatalf("expected tool_use block in assistant message")
}
@@ -277,8 +282,8 @@ func TestFromCoreRequest_ToolUseContent(t *testing.T) {
t.Errorf("tool_use name = %q", blocks0[0].Name)
}
- // user tool_result block
- blocks1 := msgReq.Messages[1].Content
+ // user tool_result block (index 2 because placeholder is at 0)
+ blocks1 := msgReq.Messages[2].Content
if len(blocks1) != 1 || blocks1[0].Type != "tool_result" {
t.Fatalf("expected tool_result block in user message")
}
@@ -313,7 +318,14 @@ func TestFromCoreRequest_Reasoning(t *testing.T) {
t.Fatal(err)
}
msgReq := result.(*anthropic.MessageRequest)
- blocks := msgReq.Messages[0].Content
+ // First message is the inserted user placeholder.
+ if len(msgReq.Messages) != 3 {
+ t.Fatalf("got %d messages, want 3", len(msgReq.Messages))
+ }
+ if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" {
+ t.Fatalf("expected user placeholder as first message, got role=%q content=%v", msgReq.Messages[0].Role, msgReq.Messages[0].Content)
+ }
+ blocks := msgReq.Messages[1].Content
if len(blocks) != 2 {
t.Fatalf("got %d blocks, want 2", len(blocks))
@@ -684,18 +696,23 @@ func TestFromCoreRequest_MergesConsecutiveToolResultMessages(t *testing.T) {
t.Fatalf("expected *MessageRequest, got %T", result)
}
- // We should have: assistant + user (merged from 2 tool_result messages)
- if len(msgReq.Messages) != 2 {
- t.Fatalf("expected 2 messages (assistant + merged user), got %d", len(msgReq.Messages))
+ // We should have: placeholder user + assistant + merged user
+ if len(msgReq.Messages) != 3 {
+ t.Fatalf("expected 3 messages (placeholder + assistant + merged user), got %d", len(msgReq.Messages))
+ }
+
+ // First message is the inserted user placeholder.
+ if msgReq.Messages[0].Role != "user" || len(msgReq.Messages[0].Content) != 1 || msgReq.Messages[0].Content[0].Text != "_" {
+ t.Fatalf("expected user placeholder as first message, got role=%q", msgReq.Messages[0].Role)
}
- // First message should be assistant with tool_use
- if msgReq.Messages[0].Role != "assistant" {
- t.Errorf("messages[0].Role = %q, want assistant", msgReq.Messages[0].Role)
+ // Second message should be assistant with tool_use
+ if msgReq.Messages[1].Role != "assistant" {
+ t.Errorf("messages[1].Role = %q, want assistant", msgReq.Messages[1].Role)
}
- // Second message should be user with 2 tool_result blocks (merged)
- merged := msgReq.Messages[1]
+ // Third message should be user with 2 tool_result blocks (merged)
+ merged := msgReq.Messages[2]
if merged.Role != "user" {
t.Errorf("merged message role = %q, want user", merged.Role)
}
diff --git a/internal/protocol/anthropic/cache.go b/internal/protocol/anthropic/cache.go
index 7a415b10..17cbad9b 100644
--- a/internal/protocol/anthropic/cache.go
+++ b/internal/protocol/anthropic/cache.go
@@ -7,8 +7,8 @@ import (
"strings"
"moonbridge/internal/config"
- "moonbridge/internal/protocol/cache"
"moonbridge/internal/format"
+ "moonbridge/internal/protocol/cache"
"moonbridge/internal/protocol/openai"
)
@@ -100,7 +100,7 @@ func (m *adapterCacheManager) UpdateRegistry(ctx context.Context, key, ttl strin
func PlanCache(cfg cache.PlanCacheConfig, registry *cache.MemoryRegistry, request openai.ResponsesRequest, converted MessageRequest) (cache.CacheCreationPlan, error) {
if request.PromptCacheRetention == "24h" && !cfg.AllowRetentionDowngrade {
return cache.CacheCreationPlan{}, &cachePlanError{
- Status: 400,
+ Status: 400,
Message: "prompt_cache_retention 24h is not supported by Anthropic prompt caching",
Param: "prompt_cache_retention",
Code: "unsupported_parameter",
diff --git a/internal/protocol/chat/adapter.go b/internal/protocol/chat/adapter.go
index a5380fac..ff4bce5e 100644
--- a/internal/protocol/chat/adapter.go
+++ b/internal/protocol/chat/adapter.go
@@ -28,9 +28,6 @@ type ChatProviderAdapter struct {
cfgMaxTokens int
client *Client
hooks format.CorePluginHooks
-
- streamMu sync.Mutex
- streamEvents []ChatStreamChunk
}
// NewChatProviderAdapter creates a new ChatProviderAdapter.
@@ -95,6 +92,11 @@ func (a *ChatProviderAdapter) FromCoreRequest(ctx context.Context, req *format.C
// Messages.
for _, msg := range req.Messages {
chatMsg := a.toChatMessage(msg)
+ // Skip messages with neither text content nor tool calls — empty messages
+ // contribute no semantic value and may be rejected by some upstreams.
+ if chatMsg.Content == nil && len(chatMsg.ToolCalls) == 0 {
+ continue
+ }
chatReq.Messages = append(chatReq.Messages, chatMsg)
}
@@ -198,16 +200,14 @@ func (a *ChatProviderAdapter) ToCoreResponse(ctx context.Context, resp any) (*fo
// =========================================================================
// bufferStreamEvent buffers raw ChatStreamChunk for trace capture.
func (a *ChatProviderAdapter) bufferStreamEvent(ev ChatStreamChunk) {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- a.streamEvents = append(a.streamEvents, ev)
+ // No-op: per-stream buffer is captured by goroutine closure.
+ // Use the StreamResult.StreamBuffer to access captured events.
}
// StreamBuffer returns the buffered stream events for trace capture.
func (a *ChatProviderAdapter) StreamBuffer() []ChatStreamChunk {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- return a.streamEvents
+ // Deprecated: use StreamResult.StreamBuffer instead.
+ return nil
}
// ToCoreStream — <-chan ChatStreamChunk → <-chan CoreStreamEvent
@@ -224,7 +224,7 @@ func (a *ChatProviderAdapter) StreamBuffer() []ChatStreamChunk {
// - core.text.delta (chunks with content delta)
// - core.content_block.done (chunk with finish_reason set)
// - core.completed (final chunk with Usage)
-func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) {
+func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) {
ch, ok := src.(<-chan ChatStreamChunk)
if !ok {
return nil, fmt.Errorf("chat adapter: expected <-chan ChatStreamChunk, got %T", src)
@@ -232,13 +232,14 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan
events := make(chan format.CoreStreamEvent, 64)
- // Initialize stream event buffer for trace capture.
- a.streamMu.Lock()
- a.streamEvents = make([]ChatStreamChunk, 0, 64)
- a.streamMu.Unlock()
+ // Per-stream buffer — local to this call, not shared across concurrent requests.
+ var buf []ChatStreamChunk
+ var bufMu sync.Mutex
+ bufReady := make(chan struct{})
go func() {
defer close(events)
+ defer close(bufReady)
// Per-choice state for streaming.
type choiceState struct {
@@ -271,7 +272,12 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan
case <-ctx.Done():
return
case chunk, ok := <-ch:
- a.bufferStreamEvent(chunk)
+ // Append to local per-stream buffer with size cap.
+ bufMu.Lock()
+ if len(buf) < 1024 {
+ buf = append(buf, chunk)
+ }
+ bufMu.Unlock()
if !ok {
// Channel closed — emit completion if not already done.
if !seenCompletion {
@@ -510,7 +516,19 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan
}
}()
- return events, nil
+ return &format.StreamResult{
+ Events: events,
+ StreamBuffer: func() []any {
+ <-bufReady
+ bufMu.Lock()
+ defer bufMu.Unlock()
+ result := make([]any, len(buf))
+ for i, ev := range buf {
+ result[i] = ev
+ }
+ return result
+ },
+ }, nil
}
// =========================================================================
diff --git a/internal/protocol/chat/chat_e2e_test.go b/internal/protocol/chat/chat_e2e_test.go
index 10201b7d..7cbe8b62 100644
--- a/internal/protocol/chat/chat_e2e_test.go
+++ b/internal/protocol/chat/chat_e2e_test.go
@@ -117,7 +117,7 @@ func TestE2EChatProvider(t *testing.T) {
Name: "get_weather",
Description: "Get the current weather for a city",
Parameters: map[string]any{
- "type": "object",
+ "type": "object",
"properties": map[string]any{
"location": map[string]any{
"type": "string",
@@ -130,7 +130,7 @@ func TestE2EChatProvider(t *testing.T) {
},
},
ToolChoice: []byte(`"auto"`),
- MaxTokens: 200,
+ MaxTokens: 200,
}
resp, err := client.CreateChat(context.Background(), req)
diff --git a/internal/protocol/chat/chat_test.go b/internal/protocol/chat/chat_test.go
index 35b71c2e..300f1424 100644
--- a/internal/protocol/chat/chat_test.go
+++ b/internal/protocol/chat/chat_test.go
@@ -1674,7 +1674,7 @@ func TestToCoreStream_BasicDelta(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1734,7 +1734,7 @@ func TestToCoreStream_ToolCallArgsDelta(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1763,7 +1763,7 @@ func TestToCoreStream_EmptyChunk(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1786,7 +1786,7 @@ func TestToCoreStream_NoContent(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1813,7 +1813,7 @@ func TestToCoreStream_ContextCancel(t *testing.T) {
// Events channel should close immediately due to cancelled context.
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
if len(evts) != 0 {
@@ -1859,7 +1859,7 @@ func TestToCoreStream_MultiChoice(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -2234,7 +2234,7 @@ func TestToCoreStream_WithModel(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
if len(evts) < 4 {
@@ -2262,7 +2262,7 @@ func TestToCoreStream_ContentBlockStartedNoRole(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
if len(evts) < 3 {
@@ -2314,7 +2314,7 @@ func TestToCoreStream_ToolCallArgsDeltaByPosition(t *testing.T) {
t.Fatal(err)
}
var deltas []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
if e.Type == format.CoreToolCallArgsDelta {
deltas = append(deltas, e)
}
@@ -2372,7 +2372,7 @@ func TestToCoreStream_ToolCallArgsDeltaRespectsExplicitToolIndex(t *testing.T) {
}
var started []format.CoreStreamEvent
var deltas []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
if e.Type == format.CoreContentBlockStarted && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
started = append(started, e)
}
diff --git a/internal/protocol/format/registry_test.go b/internal/protocol/format/registry_test.go
index c274fdc6..562326a1 100644
--- a/internal/protocol/format/registry_test.go
+++ b/internal/protocol/format/registry_test.go
@@ -11,15 +11,23 @@ import (
type mockClient struct{ protocol string }
-func (m *mockClient) ClientProtocol() string { return m.protocol }
-func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) { return &CoreRequest{}, nil }
-func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) { return nil, nil }
+func (m *mockClient) ClientProtocol() string { return m.protocol }
+func (m *mockClient) ToCoreRequest(_ context.Context, _ any) (*CoreRequest, error) {
+ return &CoreRequest{}, nil
+}
+func (m *mockClient) FromCoreResponse(_ context.Context, _ *CoreResponse) (any, error) {
+ return nil, nil
+}
type mockProvider struct{ protocol string }
-func (m *mockProvider) ProviderProtocol() string { return m.protocol }
-func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) { return nil, nil }
-func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) { return &CoreResponse{}, nil }
+func (m *mockProvider) ProviderProtocol() string { return m.protocol }
+func (m *mockProvider) FromCoreRequest(_ context.Context, _ *CoreRequest) (any, error) {
+ return nil, nil
+}
+func (m *mockProvider) ToCoreResponse(_ context.Context, _ any) (*CoreResponse, error) {
+ return &CoreResponse{}, nil
+}
type mockClientStream struct{ protocol string }
diff --git a/internal/protocol/format/types.go b/internal/protocol/format/types.go
index 520e2dba..c4a349b1 100644
--- a/internal/protocol/format/types.go
+++ b/internal/protocol/format/types.go
@@ -114,7 +114,7 @@ type CoreToolChoice struct {
// nil = use provider defaults.
type CoreThinkingConfig struct {
Type string `json:"type,omitempty"` // "enabled" | "disabled"
- BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking
+ BudgetTokens int `json:"budget_tokens,omitempty"` // token budget for thinking
}
// CoreOutputConfig controls output generation behavior.
@@ -135,8 +135,8 @@ type CoreCacheControl struct {
// CoreRequest is the protocol-agnostic representation of an LLM request.
type CoreRequest struct {
- Model string `json:"model"`
- Messages []CoreMessage `json:"messages"`
+ Model string `json:"model"`
+ Messages []CoreMessage `json:"messages"`
System []CoreContentBlock `json:"system,omitempty"`
// Tools
@@ -168,7 +168,6 @@ type CoreRequest struct {
// Zero value (nil) = not set — adapter uses provider defaults.
GenerationConfig map[string]any `json:"generation_config,omitempty"`
-
// Thinking controls extended thinking/reasoning behavior (e.g. Anthropic extended thinking).
// nil = use provider defaults.
Thinking *CoreThinkingConfig `json:"thinking,omitempty"`
diff --git a/internal/protocol/google/adapter.go b/internal/protocol/google/adapter.go
index 27904928..5450ec1c 100644
--- a/internal/protocol/google/adapter.go
+++ b/internal/protocol/google/adapter.go
@@ -38,13 +38,6 @@ type GeminiProviderAdapter struct {
// currentModel tracks the model for the current request (used by cache).
currentModel string
- // toolUseIDMap maps ToolUseID → ToolName for FunctionResponse name resolution.
- // Only valid during a single FromCoreRequest call.
- toolUseIDMap map[string]string
- toolUseIDMu sync.Mutex
-
- streamMu sync.Mutex
- streamEvents []GenerateContentResponse
prevSnapshots map[int]string // candidate index → previous text for delta computation
}
@@ -84,10 +77,6 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format
return nil, fmt.Errorf("google adapter: core request is nil")
}
- // Initialize per-request state.
- a.toolUseIDMu.Lock()
- a.toolUseIDMap = make(map[string]string)
- a.toolUseIDMu.Unlock()
a.currentModel = req.Model
// Step 1: Allow plugins to mutate the CoreRequest before conversion.
@@ -105,9 +94,11 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format
Contents: make([]Content, 0, len(req.Messages)),
}
+ toolUseIDMap := make(map[string]string)
+
// System instruction (D-01): CoreRequest.System → Gemini system_instruction
if len(req.System) > 0 {
- sysContent := a.blocksToContent(req.System)
+ sysContent := a.blocksToContent(req.System, toolUseIDMap)
if len(sysContent.Parts) > 0 {
geminiReq.SystemInstruction = &sysContent
}
@@ -118,7 +109,12 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format
// with the same role (e.g. tool_result after user text) are merged.
mergedContents := make([]Content, 0, len(req.Messages))
for _, msg := range req.Messages {
- content := a.blocksToContent(msg.Content)
+ content := a.blocksToContent(msg.Content, toolUseIDMap)
+ // Skip messages with no content parts — they contribute no semantic value
+ // and may cause SDK role-alternating contract violations.
+ if len(content.Parts) == 0 {
+ continue
+ }
content.Role = a.mapRoleToGemini(msg.Role)
if len(mergedContents) > 0 && mergedContents[len(mergedContents)-1].Role == content.Role {
mergedContents[len(mergedContents)-1].Parts = append(mergedContents[len(mergedContents)-1].Parts, content.Parts...)
@@ -126,6 +122,14 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format
mergedContents = append(mergedContents, content)
}
}
+ // Ensure first Content has role "user" — Gemini API requires alternating
+ // user/model roles starting with user. Insert a placeholder if needed.
+ if len(mergedContents) > 0 && mergedContents[0].Role == "model" {
+ mergedContents = append(
+ []Content{{Role: "user", Parts: []Part{{Text: "_"}}}},
+ mergedContents...,
+ )
+ }
geminiReq.Contents = mergedContents
// SafetySettings (D-02): CoreRequest.SafetySettings map → Gemini []SafetySetting
@@ -155,10 +159,6 @@ func (a *GeminiProviderAdapter) FromCoreRequest(ctx context.Context, req *format
// Cache integration — look up or create CachedContent.
a.prepareCache(ctx, geminiReq)
- // Clean up per-request state.
- a.toolUseIDMu.Lock()
- a.toolUseIDMap = nil
- a.toolUseIDMu.Unlock()
a.currentModel = ""
return geminiReq, nil
@@ -228,16 +228,13 @@ func (a *GeminiProviderAdapter) ToCoreResponse(ctx context.Context, resp any) (*
// =========================================================================
// bufferStreamEvent buffers raw GenerateContentResponse for trace capture.
func (a *GeminiProviderAdapter) bufferStreamEvent(ev GenerateContentResponse) {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- a.streamEvents = append(a.streamEvents, ev)
+ // No-op: per-stream buffer is captured by goroutine closure.
}
// StreamBuffer returns the buffered stream events for trace capture.
func (a *GeminiProviderAdapter) StreamBuffer() []GenerateContentResponse {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- return a.streamEvents
+ // Deprecated: use StreamResult.StreamBuffer instead.
+ return nil
}
// ToCoreStream — <-chan GenerateContentResponse → <-chan CoreStreamEvent
@@ -255,7 +252,7 @@ func (a *GeminiProviderAdapter) StreamBuffer() []GenerateContentResponse {
// - core.text.delta (each subsequent chunk with new text)
// - core.content_block.done (chunk with FinishReason set)
// - core.completed (final chunk with UsageMetadata)
-func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan format.CoreStreamEvent, error) {
+func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (*format.StreamResult, error) {
ch, ok := src.(<-chan GenerateContentResponse)
if !ok {
return nil, fmt.Errorf("google adapter: expected <-chan GenerateContentResponse, got %T", src)
@@ -263,13 +260,14 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch
events := make(chan format.CoreStreamEvent, 64)
- // Initialize stream event buffer for trace capture.
- a.streamMu.Lock()
- a.streamEvents = make([]GenerateContentResponse, 0, 64)
- a.streamMu.Unlock()
+ // Per-stream buffer — local to this call, not shared across concurrent requests.
+ var buf []GenerateContentResponse
+ var bufMu sync.Mutex
+ bufReady := make(chan struct{})
go func() {
defer close(events)
+ defer close(bufReady)
// Per-candidate state for delta computation.
type candidateState struct {
@@ -297,7 +295,12 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch
case <-ctx.Done():
return
case chunk, ok := <-ch:
- a.bufferStreamEvent(chunk)
+ // Append to local per-stream buffer with size cap.
+ bufMu.Lock()
+ if len(buf) < 1024 {
+ buf = append(buf, chunk)
+ }
+ bufMu.Unlock()
if !ok {
// Channel closed — emit completion if not already done.
if !seenCompletion {
@@ -386,7 +389,19 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch
}
}()
- return events, nil
+ return &format.StreamResult{
+ Events: events,
+ StreamBuffer: func() []any {
+ <-bufReady
+ bufMu.Lock()
+ defer bufMu.Unlock()
+ result := make([]any, len(buf))
+ for i, ev := range buf {
+ result[i] = ev
+ }
+ return result
+ },
+ }, nil
}
// =========================================================================
@@ -394,7 +409,7 @@ func (a *GeminiProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-ch
// =========================================================================
// blocksToContent converts []CoreContentBlock to Content (Gemini format).
-func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock) Content {
+func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock, toolUseIDMap map[string]string) Content {
parts := make([]Part, 0, len(blocks))
for _, b := range blocks {
switch b.Type {
@@ -409,9 +424,9 @@ func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock
})
case "tool_use":
// Store ToolUseID -> ToolName mapping for later FunctionResponse resolution (G-01).
- a.toolUseIDMu.Lock()
- a.toolUseIDMap[b.ToolUseID] = b.ToolName
- a.toolUseIDMu.Unlock()
+ if toolUseIDMap != nil {
+ toolUseIDMap[b.ToolUseID] = b.ToolName
+ }
parts = append(parts, Part{
FunctionCall: &FunctionCall{
Name: b.ToolName,
@@ -421,11 +436,11 @@ func (a *GeminiProviderAdapter) blocksToContent(blocks []format.CoreContentBlock
case "tool_result":
// Look up the function name from ToolUseID (G-01).
funcName := b.ToolUseID
- a.toolUseIDMu.Lock()
- if fn, ok := a.toolUseIDMap[b.ToolUseID]; ok {
- funcName = fn
+ if toolUseIDMap != nil {
+ if fn, ok := toolUseIDMap[b.ToolUseID]; ok {
+ funcName = fn
+ }
}
- a.toolUseIDMu.Unlock()
// Combine tool result content into a single text for the response.
var respText string
@@ -554,45 +569,59 @@ func (a *GeminiProviderAdapter) applyGenerationConfigMap(gc *GenerationConfig, c
func (a *GeminiProviderAdapter) fromParts(parts []Part) []format.CoreContentBlock {
result := make([]format.CoreContentBlock, 0, len(parts))
funcCallSeq := make(map[string]int)
+ callIDStacks := make(map[string][]string) // per-function-name call ID stack for FunctionResponse matching
for _, p := range parts {
- result = append(result, a.fromPartWithSeq(p, funcCallSeq))
+ block, stacks := a.fromPartWithSeq(p, funcCallSeq, callIDStacks)
+ if stacks != nil {
+ callIDStacks = stacks
+ }
+ result = append(result, block)
}
return result
}
// fromPartWithSeq converts a single Gemini Part to CoreContentBlock.
-func (a *GeminiProviderAdapter) fromPartWithSeq(p Part, funcCallSeq map[string]int) format.CoreContentBlock {
+// Returns the block and optionally an updated callIDStacks for FunctionCall tracking.
+func (a *GeminiProviderAdapter) fromPartWithSeq(p Part, funcCallSeq map[string]int, callIDStacks map[string][]string) (format.CoreContentBlock, map[string][]string) {
switch {
case p.Text != "":
return format.CoreContentBlock{
Type: "text",
Text: p.Text,
- }
+ }, nil
case p.FunctionCall != nil:
callName := p.FunctionCall.Name
funcCallSeq[callName]++
callID := callName + "__call_" + strconv.Itoa(funcCallSeq[callName])
+ callIDStacks[callName] = append(callIDStacks[callName], callID)
return format.CoreContentBlock{
Type: "tool_use",
ToolUseID: callID,
ToolName: callName,
ToolInput: p.FunctionCall.Args,
- }
+ }, callIDStacks
case p.FunctionResponse != nil:
+ respName := p.FunctionResponse.Name
+ callID := ""
+ if stack := callIDStacks[respName]; len(stack) > 0 {
+ callID = stack[len(stack)-1]
+ } else {
+ callID = respName + "__call_1"
+ }
return format.CoreContentBlock{
Type: "tool_result",
- ToolUseID: p.FunctionResponse.Name,
- }
+ ToolUseID: callID,
+ }, nil
case p.InlineData != nil:
return format.CoreContentBlock{
Type: "image",
ImageData: p.InlineData.Data,
MediaType: p.InlineData.MimeType,
- }
+ }, nil
default:
return format.CoreContentBlock{
Type: "text",
- }
+ }, nil
}
}
diff --git a/internal/protocol/google/client.go b/internal/protocol/google/client.go
index e7a3591e..7aaa4d9c 100644
--- a/internal/protocol/google/client.go
+++ b/internal/protocol/google/client.go
@@ -20,9 +20,9 @@ import (
type ClientConfig struct {
BaseURL string
APIKey string
- Project string // Vertex AI project ID (optional, for Vertex AI endpoint)
- Location string // Vertex AI location (optional, default "us-central1")
- Version string // API version (default "v1")
+ Project string // Vertex AI project ID (optional, for Vertex AI endpoint)
+ Location string // Vertex AI location (optional, default "us-central1")
+ Version string // API version (default "v1")
UserAgent string
Client *http.Client
}
@@ -146,7 +146,6 @@ func (c *Client) StreamGenerateContent(ctx context.Context, model string, req *G
// to close (connections are managed by http.Client), so this is a no-op.
func (c *Client) Close() error { return nil }
-
// ============================================================================
// CachedContent API methods
// ============================================================================
@@ -243,6 +242,7 @@ func (c *Client) DeleteCachedContent(ctx context.Context, name string) error {
}
return nil
}
+
// ============================================================================
// Internal helpers
// ============================================================================
diff --git a/internal/protocol/google/google_test.go b/internal/protocol/google/google_test.go
index c876bf9a..13f94ffa 100644
--- a/internal/protocol/google/google_test.go
+++ b/internal/protocol/google/google_test.go
@@ -816,12 +816,17 @@ func TestFromCoreRequest_ToolUseAndToolResult(t *testing.T) {
}
geminiReq := result.(*google.GenerateContentRequest)
- if len(geminiReq.Contents) != 2 {
- t.Fatalf("Contents: got %d, want 2", len(geminiReq.Contents))
+ if len(geminiReq.Contents) != 3 {
+ t.Fatalf("Contents: got %d, want 3 (placeholder + assistant + user)", len(geminiReq.Contents))
+ }
+
+ // First Content is the inserted user placeholder.
+ if geminiReq.Contents[0].Role != "user" || len(geminiReq.Contents[0].Parts) != 1 || geminiReq.Contents[0].Parts[0].Text != "_" {
+ t.Fatalf("expected user placeholder as first Content, got role=%q parts=%v", geminiReq.Contents[0].Role, geminiReq.Contents[0].Parts)
}
// Assistant message: FunctionCall
- astParts := geminiReq.Contents[0].Parts
+ astParts := geminiReq.Contents[1].Parts
if len(astParts) != 1 {
t.Fatalf("assistant Parts: got %d, want 1", len(astParts))
}
@@ -836,7 +841,7 @@ func TestFromCoreRequest_ToolUseAndToolResult(t *testing.T) {
}
// User message: FunctionResponse
- userParts := geminiReq.Contents[1].Parts
+ userParts := geminiReq.Contents[2].Parts
if len(userParts) != 1 {
t.Fatalf("user Parts: got %d, want 1", len(userParts))
}
@@ -1081,15 +1086,23 @@ func TestFromCoreRequest_ReasoningBlock(t *testing.T) {
}
geminiReq := result.(*google.GenerateContentRequest)
+ // First Content is the inserted user placeholder.
+ if len(geminiReq.Contents) != 3 {
+ t.Fatalf("Contents: got %d, want 3 (placeholder + assistant + user)", len(geminiReq.Contents))
+ }
+ if geminiReq.Contents[0].Role != "user" || len(geminiReq.Contents[0].Parts) != 1 || geminiReq.Contents[0].Parts[0].Text != "_" {
+ t.Fatalf("expected user placeholder as first Content, got role=%q", geminiReq.Contents[0].Role)
+ }
+
// Assistant message should have 2 parts (reasoning converted to text).
- if len(geminiReq.Contents[0].Parts) != 2 {
- t.Fatalf("assistant Parts: got %d, want 2 (reasoning block converted to text)", len(geminiReq.Contents[0].Parts))
+ if len(geminiReq.Contents[1].Parts) != 2 {
+ t.Fatalf("assistant Parts: got %d, want 2 (reasoning block converted to text)", len(geminiReq.Contents[1].Parts))
}
- if geminiReq.Contents[0].Parts[0].Text != "thinking step by step" {
- t.Errorf("assistant parts[0] reasoning text = %q, want thinking step by step", geminiReq.Contents[0].Parts[0].Text)
+ if geminiReq.Contents[1].Parts[0].Text != "thinking step by step" {
+ t.Errorf("assistant parts[0] reasoning text = %q, want thinking step by step", geminiReq.Contents[1].Parts[0].Text)
}
- if geminiReq.Contents[0].Parts[1].Text != "final answer" {
- t.Errorf("assistant parts[1] text = %q, want final answer", geminiReq.Contents[0].Parts[1].Text)
+ if geminiReq.Contents[1].Parts[1].Text != "final answer" {
+ t.Errorf("assistant parts[1] text = %q, want final answer", geminiReq.Contents[1].Parts[1].Text)
}
}
@@ -1381,7 +1394,7 @@ func TestToCoreStream_SingleCandidate(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1450,7 +1463,7 @@ func TestToCoreStream_MultiCandidate(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1496,7 +1509,7 @@ func TestToCoreStream_ContextCancel(t *testing.T) {
cancel()
// Channel should close cleanly
- for range events {
+ for range events.Events {
}
}
@@ -1511,7 +1524,7 @@ func TestToCoreStream_EmptyChannel(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
@@ -1794,8 +1807,10 @@ func TestFromCoreRequest_DefaultContentBlockNoText(t *testing.T) {
}
geminiReq := result.(*google.GenerateContentRequest)
// Unknown type with no text produces no Part (default falls through without appending)
- if len(geminiReq.Contents[0].Parts) != 0 {
- t.Errorf("Parts: got %d, want 0 (unknown type with no text should be skipped)", len(geminiReq.Contents[0].Parts))
+ if len(geminiReq.Contents) == 0 {
+ t.Logf("Contents empty (all blocks filtered)")
+ } else if len(geminiReq.Contents[0].Parts) != 0 {
+ t.Errorf("Parts: got %d, want 0 (unknown type should be skipped)", len(geminiReq.Contents[0].Parts))
}
}
@@ -1822,7 +1837,7 @@ func TestToCoreResponse_FromPartFunctionResponse(t *testing.T) {
if blocks[0].Type != "tool_result" {
t.Errorf("Type = %q, want tool_result", blocks[0].Type)
}
- if blocks[0].ToolUseID != "get_weather" {
+ if blocks[0].ToolUseID != "get_weather__call_1" {
t.Errorf("ToolUseID = %q", blocks[0].ToolUseID)
}
}
@@ -1903,7 +1918,7 @@ func TestToCoreStream_ComputeDeltaNoChange(t *testing.T) {
}
var evts []format.CoreStreamEvent
- for e := range events {
+ for e := range events.Events {
evts = append(evts, e)
}
diff --git a/internal/protocol/google/types.go b/internal/protocol/google/types.go
index e95e47c8..bbd64110 100644
--- a/internal/protocol/google/types.go
+++ b/internal/protocol/google/types.go
@@ -10,12 +10,12 @@ import "encoding/json"
// GenerateContentRequest maps to Gemini's generateContent request body.
// https://ai.google.dev/api/generate-content
type GenerateContentRequest struct {
- Contents []Content `json:"contents"`
- SystemInstruction *Content `json:"systemInstruction,omitempty"`
- SafetySettings []SafetySetting `json:"safetySettings,omitempty"`
- GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
- Tools []Tool `json:"tools,omitempty"`
- ToolConfig json.RawMessage `json:"toolConfig,omitempty"`
+ Contents []Content `json:"contents"`
+ SystemInstruction *Content `json:"systemInstruction,omitempty"`
+ SafetySettings []SafetySetting `json:"safetySettings,omitempty"`
+ GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
+ Tools []Tool `json:"tools,omitempty"`
+ ToolConfig json.RawMessage `json:"toolConfig,omitempty"`
// CachedContent references a CachedContent resource for prompt caching.
// When set, system_instruction, tools, and tool_config must not be set
// (Gemini API constraint — they become part of the cached content).
@@ -31,11 +31,11 @@ type Content struct {
// Part represents a single part within Content.
type Part struct {
- Text string `json:"text,omitempty"`
- InlineData *Blob `json:"inlineData,omitempty"`
- FileData *FileData `json:"fileData,omitempty"`
- FunctionCall *FunctionCall `json:"functionCall,omitempty"`
- FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
+ Text string `json:"text,omitempty"`
+ InlineData *Blob `json:"inlineData,omitempty"`
+ FileData *FileData `json:"fileData,omitempty"`
+ FunctionCall *FunctionCall `json:"functionCall,omitempty"`
+ FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}
// Blob represents inline binary data.
@@ -106,10 +106,10 @@ type GenerateContentResponse struct {
// Candidate represents a single response candidate.
type Candidate struct {
- Index int `json:"index"`
- Content Content `json:"content"`
- FinishReason string `json:"finishReason"` // STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
- SafetyRatings []SafetyRating `json:"safetyRatings,omitempty"`
+ Index int `json:"index"`
+ Content Content `json:"content"`
+ FinishReason string `json:"finishReason"` // STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
+ SafetyRatings []SafetyRating `json:"safetyRatings,omitempty"`
}
// SafetyRating represents a safety rating for a category.
diff --git a/internal/protocol/openai/adapter.go b/internal/protocol/openai/adapter.go
index 32038518..330960a4 100644
--- a/internal/protocol/openai/adapter.go
+++ b/internal/protocol/openai/adapter.go
@@ -34,15 +34,19 @@ type OpenAIAdapter struct {
hooks format.CorePluginHooks
disablePatchProxy func(string) bool
- streamMu sync.Mutex
- streamEvents []StreamEvent
+ nsStrategy codextool.NamespaceStrategy
}
// NewOpenAIAdapter creates a new OpenAIAdapter with the given config and hooks.
-func NewOpenAIAdapter(hooks format.CorePluginHooks) *OpenAIAdapter {
+func NewOpenAIAdapter(hooks format.CorePluginHooks, nsStrategy ...codextool.NamespaceStrategy) *OpenAIAdapter {
+ strategy := codextool.NestedOneOf
+ if len(nsStrategy) > 0 {
+ strategy = nsStrategy[0]
+ }
return &OpenAIAdapter{
- hooks: hooks.WithDefaults(),
+ hooks: hooks.WithDefaults(),
disablePatchProxy: hooks.DisablePatchProxy,
+ nsStrategy: strategy,
}
}
@@ -110,7 +114,7 @@ func (a *OpenAIAdapter) ToCoreRequest(ctx context.Context, req any) (*format.Cor
// 5. Convert tools.
if len(openaiReq.Tools) > 0 {
- coreReq.Tools = flattenToolsWithNamespace(openaiReq.Tools, "", a.disablePatchProxy)
+ coreReq.Tools = flattenToolsWithNamespace(openaiReq.Tools, "", a.disablePatchProxy, a.nsStrategy)
}
if injected := a.hooks.InjectTools(format.ContextWithCoreRequest(ctx, coreReq)); len(injected) > 0 {
coreReq.Tools = append(coreReq.Tools, injected...)
@@ -318,39 +322,85 @@ func (a *OpenAIAdapter) FromCoreResponse(ctx context.Context, resp *format.CoreR
// to produce correct OpenAI stream semantics.
func (a *OpenAIAdapter) FromCoreStream(ctx context.Context, req *format.CoreRequest, events <-chan format.CoreStreamEvent) (any, error) {
out := make(chan StreamEvent)
-
- go a.streamLoop(ctx, req, events, out)
-
- return (<-chan StreamEvent)(out), nil
+ bufReady := make(chan struct{})
+
+ var buf []StreamEvent
+ var bufMu sync.Mutex
+
+ go func() {
+ defer close(bufReady)
+ a.streamLoopWithBuf(ctx, req, events, out, &buf, &bufMu)
+ }()
+
+ return &OpenAIStreamResult{
+ ch: out,
+ buf: func() []any {
+ <-bufReady
+ bufMu.Lock()
+ defer bufMu.Unlock()
+ result := make([]any, len(buf))
+ for i, ev := range buf {
+ result[i] = ev
+ }
+ return result
+ },
+ }, nil
}
// bufferStreamEvent buffers the OpenAI stream event for trace capture,
// up to the 4MB limit. The event is JSON-marshalled to estimate its size.
func (a *OpenAIAdapter) bufferStreamEvent(ev StreamEvent) {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- a.streamEvents = append(a.streamEvents, ev)
+ // No-op: per-stream buffer is captured by streamLoop's closure.
}
// StreamBuffer returns the buffered stream events for trace capture.
func (a *OpenAIAdapter) StreamBuffer() []StreamEvent {
- a.streamMu.Lock()
- defer a.streamMu.Unlock()
- return a.streamEvents
+ // Deprecated: use StreamResult.StreamBuffer instead.
+ return nil
+}
+
+// openaiStreamResult wraps the OpenAI stream channel with per-stream buffer access.
+type OpenAIStreamResult struct {
+ ch <-chan StreamEvent
+ buf func() []any
+}
+
+// Chan returns the underlying channel of StreamEvents.
+func (r *OpenAIStreamResult) Chan() <-chan StreamEvent {
+ return r.ch
+}
+
+// Buffer returns the captured stream events for post-stream processing.
+func (r *OpenAIStreamResult) Buffer() []any {
+ if r.buf == nil {
+ return nil
+ }
+ return r.buf()
}
// streamLoop is the goroutine body for FromCoreStream.
-func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequest, events <-chan format.CoreStreamEvent, out chan<- StreamEvent) {
- defer close(out)
+// nestedBufferState tracks two-level buffering for nested namespace tool calls.
+type nestedBufferState struct {
+ toolUseID string
+ toolName string // original namespace-expanded name
+ actionName string // extracted sub-tool action name
+ namespace string // item namespace
+ outputIndex int // index in response.Output
+ emitted bool // whether output_item.added has been sent
+ buf strings.Builder // accumulated raw JSON arguments
+ sequence func() int64 // event sequencer (captures next func)
+}
- // Reset stream event buffer for this request.
- a.streamMu.Lock()
- a.streamEvents = nil
- a.streamMu.Unlock()
+func (a *OpenAIAdapter) streamLoopWithBuf(ctx context.Context, coreReq *format.CoreRequest, events <-chan format.CoreStreamEvent, out chan<- StreamEvent, buf *[]StreamEvent, bufMu *sync.Mutex) {
+ defer close(out)
// send buffers the event for trace capture before writing to the output channel.
send := func(ev StreamEvent) {
- a.bufferStreamEvent(ev)
+ bufMu.Lock()
+ if len(*buf) < 1024 {
+ *buf = append(*buf, ev)
+ }
+ bufMu.Unlock()
out <- ev
}
@@ -372,6 +422,7 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ
itemIDs := make(map[int]string)
reasonIndexes := make(map[int]bool)
toolCallFinalized := make(map[int]bool)
+ nestedBuffers := make(map[int]*nestedBufferState)
for event := range events {
// Let hooks skip events.
@@ -468,18 +519,37 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ
}
itemIDs[index] = fmt.Sprintf("fc_item_%d", index)
toolBlockNames[index] = event.ContentBlock.ToolName
- item := buildToolOutputItemStreaming(event.ContentBlock, coreReq.Extensions, toolUseID)
- outputIndexes[index] = len(response.Output)
- response.Output = append(response.Output, item)
- send(StreamEvent{
- Event: "response.output_item.added",
- Data: OutputItemEvent{
- Type: "response.output_item.added",
- SequenceNumber: next(),
- OutputIndex: outputIndexes[index],
- Item: item,
- },
- })
+
+ // Check if this tool is a nested namespace (NestedOneOf/NestedAnyOf).
+ // If so, defer output_item.added until we extract the action from args.
+ toolMap := codextool.DecodeToolMapFromExtensions(coreReq.Extensions)
+ spec, hasSpec := toolMap.Lookup(event.ContentBlock.ToolName)
+ isNested := hasSpec && (spec.Kind == codextool.ToolNestedOneOf || spec.Kind == codextool.ToolNestedAnyOf)
+
+ if isNested {
+ // Defer emission: buffer args until action is extracted.
+ nestedBuffers[index] = &nestedBufferState{
+ toolUseID: toolUseID,
+ toolName: event.ContentBlock.ToolName,
+ namespace: spec.Namespace,
+ emitted: false,
+ outputIndex: -1,
+ sequence: next,
+ }
+ } else {
+ item := buildToolOutputItemStreaming(event.ContentBlock, coreReq.Extensions, toolUseID)
+ outputIndexes[index] = len(response.Output)
+ response.Output = append(response.Output, item)
+ send(StreamEvent{
+ Event: "response.output_item.added",
+ Data: OutputItemEvent{
+ Type: "response.output_item.added",
+ SequenceNumber: next(),
+ OutputIndex: outputIndexes[index],
+ Item: item,
+ },
+ })
+ }
}
// ==================================================================
@@ -619,6 +689,50 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ
case format.CoreToolCallArgsDelta:
index := event.Index
toolCallArgs[index] += event.Delta
+
+ // Check if this is a buffered nested namespace tool call.
+ if nBuf, isBuffered := nestedBuffers[index]; isBuffered {
+ nBuf.buf.WriteString(event.Delta)
+
+ if !nBuf.emitted {
+ if action, ok := codextool.TryExtractAction(nBuf.buf.String()); ok {
+ nBuf.actionName = action
+ nBuf.emitted = true
+
+ // Emit output_item.added with the correct action name.
+ item := OutputItem{
+ Type: "function_call",
+ ID: nBuf.toolUseID,
+ CallID: nBuf.toolUseID,
+ Name: action,
+ Status: "in_progress",
+ }
+ if nBuf.namespace != "" {
+ item.Namespace = nBuf.namespace
+ }
+ outputIndexes[index] = len(response.Output)
+ nBuf.outputIndex = outputIndexes[index]
+ response.Output = append(response.Output, item)
+ send(StreamEvent{
+ Event: "response.output_item.added",
+ Data: OutputItemEvent{
+ Type: "response.output_item.added",
+ SequenceNumber: next(),
+ OutputIndex: outputIndexes[index],
+ Item: item,
+ },
+ })
+
+ // Replay already-buffered params (minus the action prefix).
+ replayNestedBuffer(nBuf, send, next, index, itemIDs)
+ }
+ } else {
+ // Already emitted: pass through deltas directly.
+ emitNestedDelta(nBuf, event.Delta, send, next, index, itemIDs, outputIndexes)
+ }
+ break
+ }
+
send(StreamEvent{
Event: "response.function_call_arguments.delta",
Data: FunctionCallArgumentsDeltaEvent{
@@ -639,6 +753,66 @@ func (a *OpenAIAdapter) streamLoop(ctx context.Context, coreReq *format.CoreRequ
break
}
finalArgs := event.Delta
+
+ // Check if this is a buffered nested namespace tool call that hasn't
+ // emitted yet (action never extracted — flush all buffered data).
+ if nBuf, isBuffered := nestedBuffers[index]; isBuffered {
+ if !nBuf.emitted {
+ // Action never extracted — flush everything as the original name.
+ nBuf.actionName = nBuf.toolName
+ finalCombined := nBuf.buf.String()
+ if finalArgs != "" && finalCombined == "" {
+ finalCombined = finalArgs
+ }
+ item := OutputItem{
+ Type: "function_call",
+ ID: nBuf.toolUseID,
+ CallID: nBuf.toolUseID,
+ Name: nBuf.toolName,
+ Status: "completed",
+ }
+ if nBuf.namespace != "" {
+ item.Namespace = nBuf.namespace
+ }
+ item.Arguments = finalCombined
+ outputIndexes[index] = len(response.Output)
+ nBuf.outputIndex = outputIndexes[index]
+ response.Output = append(response.Output, item)
+ nBuf.emitted = true
+ send(StreamEvent{
+ Event: "response.output_item.added",
+ Data: OutputItemEvent{
+ Type: "response.output_item.added",
+ SequenceNumber: next(),
+ OutputIndex: outputIndexes[index],
+ Item: item,
+ },
+ })
+ send(StreamEvent{
+ Event: "response.function_call_arguments.done",
+ Data: FunctionCallArgumentsDoneEvent{
+ Type: "response.function_call_arguments.done",
+ SequenceNumber: next(),
+ ItemID: itemIDs[index],
+ OutputIndex: outputIndexes[index],
+ Arguments: finalCombined,
+ },
+ })
+ send(StreamEvent{
+ Event: "response.output_item.done",
+ Data: OutputItemEvent{
+ Type: "response.output_item.done",
+ SequenceNumber: next(),
+ OutputIndex: outputIndexes[index],
+ Item: response.Output[outputIndexes[index]],
+ },
+ })
+ delete(nestedBuffers, index)
+ break
+ }
+ // Already emitted — use existing output index.
+ delete(nestedBuffers, index)
+ }
if finalArgs == "" {
finalArgs = toolCallArgs[index]
}
@@ -1491,6 +1665,131 @@ func buildToolOutputItem(block format.CoreContentBlock, extensions map[string]an
// buildToolOutputItemStreaming constructs a streaming OutputItem for a tool_use content block start.
// The item is created with "in_progress" status.
+
+// replayNestedBuffer emits the accumulated params from a nested namespace buffer
+// as function_call_arguments.delta events, stripping the action prefix.
+func replayNestedBuffer(nBuf *nestedBufferState, send func(StreamEvent), next func() int64, index int, itemIDs map[int]string) {
+ if nBuf.buf.Len() == 0 {
+ return
+ }
+ paramsOnly := stripPrefixActionFromJSON(nBuf.buf.String(), nBuf.actionName)
+ if paramsOnly != "" {
+ send(StreamEvent{
+ Event: "response.function_call_arguments.delta",
+ Data: FunctionCallArgumentsDeltaEvent{
+ Type: "response.function_call_arguments.delta",
+ SequenceNumber: next(),
+ ItemID: itemIDs[index],
+ OutputIndex: nBuf.outputIndex,
+ Delta: paramsOnly,
+ },
+ })
+ }
+}
+
+// emitNestedDelta sends a function_call_arguments.delta for a nested namespace tool
+// that has already emitted its output_item.added.
+func emitNestedDelta(nBuf *nestedBufferState, delta string, send func(StreamEvent), next func() int64, index int, itemIDs map[int]string, outputIndexes map[int]int) {
+ cleanedDelta := stripPrefixActionFromJSON(delta, nBuf.actionName)
+ if cleanedDelta == "" {
+ return
+ }
+ oi := nBuf.outputIndex
+ if oi < 0 {
+ oi = outputIndexes[index]
+ }
+ send(StreamEvent{
+ Event: "response.function_call_arguments.delta",
+ Data: FunctionCallArgumentsDeltaEvent{
+ Type: "response.function_call_arguments.delta",
+ SequenceNumber: next(),
+ ItemID: itemIDs[index],
+ OutputIndex: oi,
+ Delta: cleanedDelta,
+ },
+ })
+}
+
+// stripPrefixActionFromJSON removes the "action": "value" portion from the start
+// of a partial JSON string. Uses a position-constrained scan that only looks for
+// "action" as a top-level key (before any nested "{" or after the first "," that
+// signals the end of the first key-value pair). Falls back to full JSON parse
+// when the buffer is syntactically complete.
+func stripPrefixActionFromJSON(raw string, action string) string {
+ if raw == "" {
+ return ""
+ }
+
+ // First, try a full JSON parse — if the buffer is complete, this is the
+ // most robust path.
+ var parsed map[string]json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &parsed); err == nil {
+ delete(parsed, "action")
+ if len(parsed) == 0 {
+ return ""
+ }
+ data, _ := json.Marshal(parsed)
+ result := string(data)
+ // Strip outer braces for streaming delta context.
+ result = strings.TrimPrefix(result, "{")
+ result = strings.TrimSuffix(result, "}")
+ return strings.TrimSpace(result)
+ }
+
+ // Fallback: position-constrained scan. Only look for "action" in the
+ // first object level — roughly the content before a nested "{" or after
+ // the first top-level comma that follows the action key-value pair.
+ //
+ // Strategy: find "action" at the top level, extract its value, remove
+ // the key-value pair, and return the remaining JSON fragment.
+ idx := strings.Index(raw, `"action"`)
+ if idx < 0 {
+ return raw
+ }
+
+ // Only treat as top-level action if it appears before any nested object
+ // (the namespace tool schema is flat — action is at the root).
+ firstBrace := strings.IndexByte(raw, '{')
+ if firstBrace >= 0 && firstBrace < idx {
+ // "action" is not in the first object — return raw unchanged.
+ return raw
+ }
+
+ // Find the colon after the key.
+ afterKey := raw[idx+8:]
+ colonIdx := strings.IndexByte(afterKey, ':')
+ if colonIdx < 0 {
+ return raw
+ }
+
+ // Skip whitespace and colon.
+ afterColon := strings.TrimSpace(afterKey[colonIdx+1:])
+ if len(afterColon) == 0 {
+ return raw
+ }
+
+ // Must start with a quote (action is always a string).
+ if afterColon[0] != '"' {
+ return raw
+ }
+ afterColon = afterColon[1:] // skip opening quote
+ endQuote := strings.IndexByte(afterColon, '"')
+ if endQuote < 0 {
+ return raw
+ }
+
+ // Extract the portion after the action value.
+ afterValue := strings.TrimSpace(afterColon[endQuote+1:])
+ // Strip trailing comma.
+ afterValue = strings.TrimLeft(afterValue, ", ")
+
+ // Combine the part before the action key with the part after the value.
+ prefix := strings.TrimRight(raw[:idx], ", ")
+ if prefix == "" || prefix == "{" || strings.TrimSpace(prefix) == "{" {
+ return afterValue
+ }
+ return prefix + ", " + afterValue
+}
func buildToolOutputItemStreaming(block *format.CoreContentBlock, extensions map[string]any, toolUseID string) OutputItem {
toolMap := codextool.DecodeToolMapFromExtensions(extensions)
itemT, itemN, itemNS, itemInput, isLS, actionJSON := codextool.OutputItemFromBlock(block.ToolName, block.ToolInput, toolMap)
@@ -1523,7 +1822,7 @@ func buildToolOutputItemStreaming(block *format.CoreContentBlock, extensions map
// Function/web_search/file_search/code_interpreter/computer_use_preview pass through.
// Custom tools are expanded using codex package helpers.
// Namespace tools are recursively flattened.
-func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy func(string) bool) []format.CoreTool {
+func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy func(string) bool, nsStrategy codextool.NamespaceStrategy) []format.CoreTool {
name := namespacedToolName(namespace, tool.Name)
ext := make(map[string]any)
@@ -1573,7 +1872,23 @@ func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy fun
case "namespace":
ns := namespacedToolName(namespace, tool.Name)
- return flattenToolsWithNamespace(tool.Tools, ns, disablePatchProxy)
+ // Build sub-tool map for BuildNamespaceTools.
+ subMap := make(map[string]format.CoreTool)
+ var subNames []string
+ for _, sub := range tool.Tools {
+ subNames = append(subNames, sub.Name)
+ subMap[sub.Name] = format.CoreTool{
+ Name: sub.Name,
+ Description: sub.Description,
+ InputSchema: sub.Parameters,
+ }
+ }
+ tools, err := codextool.BuildNamespaceTools(subNames, subMap, ns, nsStrategy)
+ if err != nil || len(tools) == 0 {
+ // Fallback to flat expansion.
+ return flattenToolsWithNamespace(tool.Tools, ns, disablePatchProxy, nsStrategy)
+ }
+ return tools
case "custom":
grammar := codextool.CustomToolGrammar(tool.Format)
@@ -1627,22 +1942,28 @@ func convertToolWithNamespace(tool Tool, namespace string, disablePatchProxy fun
// flattenToolsWithNamespace recursively flattens namespace tools and converts
// individual tools, building a flat list of CoreTools suitable for upstream providers.
-func flattenToolsWithNamespace(openaiTools []Tool, namespace string, disablePatchProxy func(string) bool) []format.CoreTool {
+func flattenToolsWithNamespace(openaiTools []Tool, namespace string, disablePatchProxy func(string) bool, nsStrategy codextool.NamespaceStrategy) []format.CoreTool {
var result []format.CoreTool
for _, t := range openaiTools {
- converted := convertToolWithNamespace(t, namespace, disablePatchProxy)
+ converted := convertToolWithNamespace(t, namespace, disablePatchProxy, nsStrategy)
result = append(result, converted...)
}
// Deduplicate by name: Codex may send the same tool both as a namespace member
// and as an independently-injected function tool (e.g. MCP tools that inject themselves
- // after first use). Keep the first occurrence, which has the correct metadata.
- seen := make(map[string]bool, len(result))
+ // after first use). Prefer tools with a codex_namespace annotation (comes from namespace
+ // expansion) over flat function tools with the same name.
+ seen := make(map[string]int, len(result)) // name → index in deduped
deduped := make([]format.CoreTool, 0, len(result))
for _, t := range result {
- if seen[t.Name] {
+ if existing, exists := seen[t.Name]; exists {
+ existingNS, _ := deduped[existing].Extensions["codex_namespace"].(string)
+ newNS, _ := t.Extensions["codex_namespace"].(string)
+ if existingNS == "" && newNS != "" {
+ deduped[existing] = t
+ }
continue
}
- seen[t.Name] = true
+ seen[t.Name] = len(deduped)
deduped = append(deduped, t)
}
result = deduped
diff --git a/internal/protocol/openai/adapter_test.go b/internal/protocol/openai/adapter_test.go
index 0bcca6d5..be3e9b3f 100644
--- a/internal/protocol/openai/adapter_test.go
+++ b/internal/protocol/openai/adapter_test.go
@@ -294,8 +294,8 @@ func TestToCoreRequest_BatchesCustomToolCallsAndOutputsIntoSingleRound(t *testin
for i, want := range []struct {
assistantTextIdx int
msgIdx int
- callID string
- outcome string
+ callID string
+ outcome string
}{
{0, 1, "call_a", "ok a"},
{3, 4, "call_b", "ok b"},
@@ -345,7 +345,13 @@ func TestFromCoreStream_NoDuplicateDoneForToolUse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- stream := streamAny.(<-chan openai.StreamEvent)
+ var stream <-chan openai.StreamEvent
+ oaiResult, ok := streamAny.(*openai.OpenAIStreamResult)
+ if ok {
+ stream = oaiResult.Chan()
+ } else {
+ stream = streamAny.(<-chan openai.StreamEvent)
+ }
var argsDone int
var itemDone int
for ev := range stream {
diff --git a/internal/protocol/openai/types.go b/internal/protocol/openai/types.go
index 24ba63c7..6f82371a 100644
--- a/internal/protocol/openai/types.go
+++ b/internal/protocol/openai/types.go
@@ -116,10 +116,10 @@ type ContentPart struct {
// Usage represents token usage statistics.
type Usage struct {
- InputTokens int `json:"input_tokens,omitempty"`
- OutputTokens int `json:"output_tokens,omitempty"`
- TotalTokens int `json:"total_tokens"`
- InputTokensDetails InputTokensDetails `json:"input_tokens_details,omitempty"`
+ InputTokens int `json:"input_tokens,omitempty"`
+ OutputTokens int `json:"output_tokens,omitempty"`
+ TotalTokens int `json:"total_tokens"`
+ InputTokensDetails InputTokensDetails `json:"input_tokens_details,omitempty"`
OutputTokensDetails OutputTokensDetails `json:"output_tokens_details,omitempty"`
}
diff --git a/internal/service/api/api_test.go b/internal/service/api/api_test.go
index ce5ef004..dff58be3 100644
--- a/internal/service/api/api_test.go
+++ b/internal/service/api/api_test.go
@@ -130,9 +130,9 @@ func buildTableNameMap(tables []db.TableSpec) map[string]string {
// testAPIDB implements db.Store backed by an in-memory SQLite database for API tests.
type testAPIDB struct {
- t *testing.T
- db *sql.DB
- consumer string
+ t *testing.T
+ db *sql.DB
+ consumer string
tableNames map[string]string
}
diff --git a/internal/service/api/models.go b/internal/service/api/models.go
index fd196f1b..a7ffcd53 100644
--- a/internal/service/api/models.go
+++ b/internal/service/api/models.go
@@ -98,13 +98,13 @@ func (r *Router) handleGetModel(w http.ResponseWriter, req *http.Request) {
}
resp := map[string]any{
- "slug": slug,
- "display_name": def.DisplayName,
- "description": def.Description,
- "context_window": def.ContextWindow,
- "max_output_tokens": def.MaxOutputTokens,
- "input_modalities": def.InputModalities,
- "providers": providers,
+ "slug": slug,
+ "display_name": def.DisplayName,
+ "description": def.Description,
+ "context_window": def.ContextWindow,
+ "max_output_tokens": def.MaxOutputTokens,
+ "input_modalities": def.InputModalities,
+ "providers": providers,
}
respondJSON(w, http.StatusOK, resp)
@@ -132,10 +132,10 @@ func (r *Router) handlePutModel(w http.ResponseWriter, req *http.Request) {
// Build metadata JSON.
meta := map[string]any{
- "display_name": body.DisplayName,
- "description": body.Description,
- "context_window": body.ContextWindow,
- "max_output_tokens": body.MaxOutputTokens,
+ "display_name": body.DisplayName,
+ "description": body.Description,
+ "context_window": body.ContextWindow,
+ "max_output_tokens": body.MaxOutputTokens,
}
metaJSON, _ := json.Marshal(meta)
diff --git a/internal/service/api/router.go b/internal/service/api/router.go
index 500f31b1..5d5c85aa 100644
--- a/internal/service/api/router.go
+++ b/internal/service/api/router.go
@@ -55,14 +55,14 @@ func NewRouter(cfg ConfigStore, rt *runtime.Runtime, st *stats.SessionStats, reg
mux := http.NewServeMux()
// Auth middleware for all API routes.
- authMW := AuthMiddleware(func() string {
- return r.server.CurrentConfig().AuthToken()
- }, func() bool { return r.store != nil })
+ authMW := AuthMiddleware(func() string {
+ return r.server.CurrentConfig().AuthToken()
+ }, func() bool { return r.store != nil })
- // Register all routes using Go 1.22+ pattern matching.
- registerRoutes(mux, r)
+ // Register all routes using Go 1.22+ pattern matching.
+ registerRoutes(mux, r)
- return authMW(mux)
+ return authMW(mux)
}
// registerRoutes registers all API endpoints with the mux.
diff --git a/internal/service/api/settings.go b/internal/service/api/settings.go
index f7dcf333..11ab3c99 100644
--- a/internal/service/api/settings.go
+++ b/internal/service/api/settings.go
@@ -66,11 +66,11 @@ func (r *Router) handleGetWebSearch(w http.ResponseWriter, req *http.Request) {
cfg := r.runtime.Current()
resp := map[string]any{
- "support": string(cfg.Config.WebSearchSupport),
- "max_uses": cfg.Config.WebSearchMaxUses,
- "tavily_api_key": maskAPIKey(cfg.Config.TavilyAPIKey),
- "firecrawl_api_key": maskAPIKey(cfg.Config.FirecrawlAPIKey),
- "search_max_rounds": cfg.Config.SearchMaxRounds,
+ "support": string(cfg.Config.WebSearchSupport),
+ "max_uses": cfg.Config.WebSearchMaxUses,
+ "tavily_api_key": maskAPIKey(cfg.Config.TavilyAPIKey),
+ "firecrawl_api_key": maskAPIKey(cfg.Config.FirecrawlAPIKey),
+ "search_max_rounds": cfg.Config.SearchMaxRounds,
}
respondJSON(w, http.StatusOK, resp)
@@ -103,11 +103,11 @@ func (r *Router) handlePutWebSearch(w http.ResponseWriter, req *http.Request) {
}
wsJSON, _ := json.Marshal(map[string]any{
- "support": body.Support,
- "max_uses": body.MaxUses,
- "tavily_api_key": tavilyKey,
- "firecrawl_api_key": firecrawlKey,
- "search_max_rounds": body.SearchMaxRounds,
+ "support": body.Support,
+ "max_uses": body.MaxUses,
+ "tavily_api_key": tavilyKey,
+ "firecrawl_api_key": firecrawlKey,
+ "search_max_rounds": body.SearchMaxRounds,
})
chID, err := r.store.StageChange(store.ChangeRow{
@@ -304,10 +304,10 @@ func (r *Router) handlePostConfigImport(w http.ResponseWriter, req *http.Request
for slug, def := range cfg.Models {
meta := map[string]any{
- "display_name": def.DisplayName,
- "description": def.Description,
- "context_window": def.ContextWindow,
- "max_output_tokens": def.MaxOutputTokens,
+ "display_name": def.DisplayName,
+ "description": def.Description,
+ "context_window": def.ContextWindow,
+ "max_output_tokens": def.MaxOutputTokens,
}
metaJSON, _ := json.Marshal(meta)
afterJSON, _ := json.Marshal(map[string]any{
@@ -381,11 +381,11 @@ func (r *Router) handlePostConfigImport(w http.ResponseWriter, req *http.Request
// Stage web_search if set.
if cfg.WebSearchSupport != "" || cfg.TavilyAPIKey != "" || cfg.FirecrawlAPIKey != "" {
wsJSON, _ := json.Marshal(map[string]any{
- "support": string(cfg.WebSearchSupport),
- "max_uses": cfg.WebSearchMaxUses,
- "tavily_api_key": cfg.TavilyAPIKey,
- "firecrawl_api_key": cfg.FirecrawlAPIKey,
- "search_max_rounds": cfg.SearchMaxRounds,
+ "support": string(cfg.WebSearchSupport),
+ "max_uses": cfg.WebSearchMaxUses,
+ "tavily_api_key": cfg.TavilyAPIKey,
+ "firecrawl_api_key": cfg.FirecrawlAPIKey,
+ "search_max_rounds": cfg.SearchMaxRounds,
})
chID, err := r.store.StageChange(store.ChangeRow{
Action: "update",
diff --git a/internal/service/api/status.go b/internal/service/api/status.go
index 8ec967a3..f5e43cc1 100644
--- a/internal/service/api/status.go
+++ b/internal/service/api/status.go
@@ -4,7 +4,6 @@ import (
"net/http"
"sort"
"time"
-
)
// ---- Status ----
diff --git a/internal/service/app/app.go b/internal/service/app/app.go
index bccf2b6a..b4ed6775 100644
--- a/internal/service/app/app.go
+++ b/internal/service/app/app.go
@@ -12,6 +12,7 @@ import (
"log/slog"
"moonbridge/internal/config"
"moonbridge/internal/db"
+ "moonbridge/internal/extension/codextool"
"moonbridge/internal/format"
"moonbridge/internal/logger"
"moonbridge/internal/protocol/anthropic"
@@ -205,7 +206,7 @@ func runTransform(ctx context.Context, cfg config.Config, errors io.Writer) erro
coreHooks := plugins.CorePluginHooks()
// Inbound: OpenAI Responses client adapter.
- oaiAdapter := openai.NewOpenAIAdapter(coreHooks)
+ oaiAdapter := openai.NewOpenAIAdapter(coreHooks, codextool.NestedOneOf)
_ = adapterReg.RegisterClient(oaiAdapter)
_ = adapterReg.RegisterClientStream(oaiAdapter)
diff --git a/internal/service/app/extensions.go b/internal/service/app/extensions.go
index 423691b4..12782aad 100644
--- a/internal/service/app/extensions.go
+++ b/internal/service/app/extensions.go
@@ -4,15 +4,15 @@ import (
"database/sql"
"log/slog"
+ "moonbridge/internal/config"
+ codextoolproxy "moonbridge/internal/extension/codex_tool_proxy"
dbd1 "moonbridge/internal/extension/db/d1"
dbsqlite "moonbridge/internal/extension/db/sqlite"
deepseekv4 "moonbridge/internal/extension/deepseek_v4"
kimiworkaround "moonbridge/internal/extension/kimi_workaround"
mbtrics "moonbridge/internal/extension/metrics"
- codextoolproxy "moonbridge/internal/extension/codex_tool_proxy"
"moonbridge/internal/extension/plugin"
"moonbridge/internal/extension/visual"
- "moonbridge/internal/config"
)
// ExtensionOptions controls optional initialization of built-in plugins.
diff --git a/internal/service/provider/manager.go b/internal/service/provider/manager.go
index 7076c3d1..949bac1c 100644
--- a/internal/service/provider/manager.go
+++ b/internal/service/provider/manager.go
@@ -135,7 +135,7 @@ func NewAnthropicClientAdapter(client *anthropic.Client) ProviderClient {
}
type ProviderManager struct {
- mu sync.Mutex // guards field replacement during Reload
+ mu sync.RWMutex // guards field replacement during Reload
clients map[string]ProviderClient
providers map[string]ProviderConfig // provider key -> config (for inspection)
routes map[string]ModelRoute // model alias -> route
@@ -276,6 +276,8 @@ func (pm *ProviderManager) Reload(cfg config.ProviderConfig) error {
// It returns the default provider if the alias is not explicitly routed.
func (pm *ProviderManager) ClientFor(modelAlias string) (string, ProviderClient, error) {
// Direct provider/model reference.
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if provider, upstream := ParseModelRef(modelAlias); provider != "" {
if client, ok := pm.clients[provider]; ok {
return upstream, client, nil
@@ -312,6 +314,8 @@ func (pm *ProviderManager) ClientFor(modelAlias string) (string, ProviderClient,
// Returns error if no candidates are found.
func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error) {
// 1. Route alias (highest priority)
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if route, ok := pm.routes[modelName]; ok {
providerKey := route.Provider
if providerKey == "" {
@@ -325,7 +329,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error
Candidates: []ProviderCandidate{{
ProviderKey: providerKey,
UpstreamModel: route.Name,
- Protocol: pm.ProtocolForKey(providerKey),
+ Protocol: pm.protocolForKeyInline(providerKey),
Client: client,
}},
}, nil
@@ -341,7 +345,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error
Candidates: []ProviderCandidate{{
ProviderKey: providerKey,
UpstreamModel: upstreamModel,
- Protocol: pm.ProtocolForKey(providerKey),
+ Protocol: pm.protocolForKeyInline(providerKey),
Client: client,
}},
}, nil
@@ -369,7 +373,7 @@ func (pm *ProviderManager) ResolveModel(modelName string) (*ResolvedRoute, error
candidates = append(candidates, ProviderCandidate{
ProviderKey: entry.providerKey,
UpstreamModel: modelName,
- Protocol: pm.ProtocolForKey(entry.providerKey),
+ Protocol: pm.protocolForKeyInline(entry.providerKey),
Client: client,
})
}
@@ -417,6 +421,8 @@ func (pm *ProviderManager) ProbeWebSearchCandidate(ctx context.Context, provider
// ProviderKeys returns all configured provider keys.
func (pm *ProviderManager) ProviderKeys() []string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
keys := make([]string, 0, len(pm.clients))
for k := range pm.clients {
keys = append(keys, k)
@@ -426,6 +432,8 @@ func (pm *ProviderManager) ProviderKeys() []string {
// DefaultKey returns the default provider key.
func (pm *ProviderManager) DefaultKey() string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
return pm.defaultK
}
@@ -445,10 +453,11 @@ func newHTTPClient(cfg HTTPConfig) *http.Client {
return &http.Client{
Transport: &http.Transport{
- MaxIdleConns: 100,
- MaxIdleConnsPerHost: maxIdle,
- IdleConnTimeout: idleTimeout,
- DisableCompression: false,
+ MaxIdleConns: 100,
+ MaxIdleConnsPerHost: maxIdle,
+ IdleConnTimeout: idleTimeout,
+ DisableCompression: false,
+ ResponseHeaderTimeout: 30 * time.Second,
},
}
}
@@ -462,6 +471,8 @@ func valueOrDefault(value, fallback string) string {
// ClientForKey returns the anthropic.Client for a given provider key.
func (pm *ProviderManager) ClientForKey(key string) (ProviderClient, error) {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
client, ok := pm.clients[key]
if !ok {
return nil, fmt.Errorf("provider %q not found", key)
@@ -471,7 +482,25 @@ func (pm *ProviderManager) ClientForKey(key string) (ProviderClient, error) {
// ProtocolForKey returns the protocol for a given provider key.
// Returns "anthropic" if not configured.
+// protocolForKeyInline returns the protocol for a provider key.
+// Caller must hold pm.mu (read lock).
+func (pm *ProviderManager) protocolForKeyInline(key string) string {
+ if pm.providers == nil {
+ return "anthropic"
+ }
+ cfg, ok := pm.providers[key]
+ if !ok {
+ return "anthropic"
+ }
+ if cfg.Protocol == "" {
+ return "anthropic"
+ }
+ return cfg.Protocol
+}
+
func (pm *ProviderManager) ProtocolForKey(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if pm.providers == nil {
return "anthropic"
}
@@ -488,24 +517,28 @@ func (pm *ProviderManager) ProtocolForKey(key string) string {
// ProtocolForModel returns the protocol for the provider serving the given model alias.
// Returns "anthropic" if the model is not explicitly routed.
func (pm *ProviderManager) ProtocolForModel(modelAlias string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
// Direct provider/model reference.
if provider, _ := ParseModelRef(modelAlias); provider != "" {
- return pm.ProtocolForKey(provider)
+ return pm.protocolForKeyInline(provider)
}
route, ok := pm.routes[modelAlias]
if !ok {
- return pm.ProtocolForKey(pm.defaultK)
+ return pm.protocolForKeyInline(pm.defaultK)
}
providerKey := route.Provider
if providerKey == "" {
providerKey = pm.defaultK
}
- return pm.ProtocolForKey(providerKey)
+ return pm.protocolForKeyInline(providerKey)
}
// UpstreamModelFor returns the upstream model name for a model alias.
func (pm *ProviderManager) UpstreamModelFor(modelAlias string) string {
// Direct provider/model reference.
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if provider, upstream := ParseModelRef(modelAlias); provider != "" {
if _, ok := pm.clients[provider]; ok {
return upstream
@@ -520,6 +553,8 @@ func (pm *ProviderManager) UpstreamModelFor(modelAlias string) string {
// ProviderBaseURL returns the base URL for a given provider key.
func (pm *ProviderManager) ProviderBaseURL(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
cfg, ok := pm.providers[key]
if !ok {
return ""
@@ -529,6 +564,8 @@ func (pm *ProviderManager) ProviderBaseURL(key string) string {
// ProviderAPIKey returns the API key for a given provider key.
func (pm *ProviderManager) ProviderAPIKey(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
cfg, ok := pm.providers[key]
if !ok {
return ""
@@ -540,6 +577,8 @@ func (pm *ProviderManager) ProviderAPIKey(key string) string {
// Falls back to defaultK when the model has no explicit route.
func (pm *ProviderManager) ProviderKeyForModel(modelAlias string) string {
// Direct provider/model reference.
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if provider, _ := ParseModelRef(modelAlias); provider != "" {
if _, ok := pm.clients[provider]; ok {
return provider
@@ -552,6 +591,21 @@ func (pm *ProviderManager) ProviderKeyForModel(modelAlias string) string {
return route.Provider
}
+// providerKeyForModelInline returns the provider key for a model alias.
+// Caller must hold pm.mu (read lock).
+func (pm *ProviderManager) providerKeyForModelInline(modelAlias string) string {
+ if provider, _ := modelref.Parse(modelAlias); provider != "" {
+ if _, ok := pm.clients[provider]; ok {
+ return provider
+ }
+ }
+ route, ok := pm.routes[modelAlias]
+ if !ok || route.Provider == "" {
+ return pm.defaultK
+ }
+ return route.Provider
+}
+
// WebSearchCandidateKey returns the runtime key for a resolved provider/model pair.
func WebSearchCandidateKey(providerKey, upstreamModel string) string {
return "candidate:" + providerKey + "/" + upstreamModel
@@ -566,11 +620,15 @@ func (pm *ProviderManager) SetResolvedWebSearch(key string, support string) {
// ResolvedWebSearch returns the resolved web search support for a provider key.
// Returns empty string if not yet resolved.
func (pm *ProviderManager) ResolvedWebSearch(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
return pm.resolvedWS[key]
}
// ModelMetaFor returns the ModelMeta for a model name within a specific provider.
func (pm *ProviderManager) ModelMetaFor(modelName string, providerKey string) (ModelMeta, bool) {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
cfg, ok := pm.providers[providerKey]
if !ok {
return ModelMeta{}, false
@@ -581,6 +639,8 @@ func (pm *ProviderManager) ModelMetaFor(modelName string, providerKey string) (M
// ProviderDefForKey returns the full ProviderConfig for a given provider key.
func (pm *ProviderManager) ProviderDefForKey(key string) (ProviderConfig, bool) {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
cfg, ok := pm.providers[key]
if !ok {
return ProviderConfig{}, false
@@ -592,21 +652,47 @@ func (pm *ProviderManager) ProviderDefForKey(key string) (ProviderConfig, bool)
// Checks model-level first, then falls back to provider-level.
func (pm *ProviderManager) ResolvedWebSearchForModel(modelAlias string) string {
// Check model-level resolution first.
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if v, ok := pm.resolvedWS["model:"+modelAlias]; ok {
return v
}
- if providerKey, upstreamModel, ok := pm.ProviderAndUpstreamForModel(modelAlias); ok {
+ var providerKey, upstreamModel string
+ ok := false
+ if provider, upstream := modelref.Parse(modelAlias); provider != "" {
+ if _, exists := pm.clients[provider]; exists {
+ providerKey, upstreamModel, ok = provider, upstream, true
+ }
+ }
+ if !ok {
+ if route, exists := pm.routes[modelAlias]; exists {
+ if route.Provider == "" {
+ providerKey = pm.defaultK
+ } else {
+ providerKey = route.Provider
+ }
+ upstreamModel, ok = route.Name, true
+ }
+ }
+ if !ok {
+ if pm.defaultK != "" {
+ providerKey, upstreamModel, ok = pm.defaultK, modelAlias, true
+ }
+ }
+ if ok {
if v, ok := pm.resolvedWS[WebSearchCandidateKey(providerKey, upstreamModel)]; ok {
return v
}
}
// Fall back to provider-level.
- return pm.resolvedWS[pm.ProviderKeyForModel(modelAlias)]
+ return pm.resolvedWS[pm.providerKeyForModelInline(modelAlias)]
}
// ResolvedWebSearchForCandidate returns the resolved web search support for a provider/model pair.
// Falls back to provider-level support when no candidate-specific resolution exists.
func (pm *ProviderManager) ResolvedWebSearchForCandidate(providerKey, upstreamModel string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if providerKey == "" {
return ""
}
@@ -620,6 +706,8 @@ func (pm *ProviderManager) ResolvedWebSearchForCandidate(providerKey, upstreamMo
// WebSearchConfigForKey returns the configured web search support for a provider key.
func (pm *ProviderManager) WebSearchConfigForKey(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
cfg, ok := pm.providers[key]
if !ok {
return ""
@@ -631,6 +719,8 @@ func (pm *ProviderManager) WebSearchConfigForKey(key string) string {
// alias routed to the given provider key. Falls back to the provider's own
// model list when no route alias references it. Returns empty string if none found.
func (pm *ProviderManager) FirstUpstreamModelForKey(key string) string {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
for _, route := range pm.routes {
pk := route.Provider
if pk == "" {
@@ -651,6 +741,8 @@ func (pm *ProviderManager) FirstUpstreamModelForKey(key string) string {
// ProviderAndUpstreamForModel resolves the provider key and upstream model for a model alias.
func (pm *ProviderManager) ProviderAndUpstreamForModel(modelAlias string) (providerKey string, upstreamModel string, ok bool) {
// Direct provider/model reference.
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
if provider, upstream := ParseModelRef(modelAlias); provider != "" {
if _, exists := pm.clients[provider]; exists {
return provider, upstream, true
diff --git a/internal/service/runtime/runtime_test.go b/internal/service/runtime/runtime_test.go
index e2982726..1d011727 100644
--- a/internal/service/runtime/runtime_test.go
+++ b/internal/service/runtime/runtime_test.go
@@ -188,7 +188,7 @@ func TestReloadWithInvalidConfigReturnsError(t *testing.T) {
// Reload with invalid mode (empty mode fails validation).
invalidCfg := config.Config{
- Mode: config.Mode(""),
+ Mode: config.Mode(""),
Cache: config.CacheConfig{Mode: "off"},
}
diff --git a/internal/service/server/adapter_core_provider_test.go b/internal/service/server/adapter_core_provider_test.go
index 50970904..24488565 100644
--- a/internal/service/server/adapter_core_provider_test.go
+++ b/internal/service/server/adapter_core_provider_test.go
@@ -27,7 +27,7 @@ func TestCoreResponseToStreamEventsEmitsTextAndUsage(t *testing.T) {
Usage: format.CoreUsage{InputTokens: 11, OutputTokens: 7, CachedInputTokens: 3},
}
- events := collectCoreStreamEvents(coreResponseToStreamEvents(resp))
+ events := collectCoreStreamEvents(coreResponseToStreamEvents(context.Background(), resp))
if len(events) == 0 {
t.Fatal("no stream events emitted")
}
diff --git a/internal/service/server/adapter_dispatch.go b/internal/service/server/adapter_dispatch.go
index d3e65a8d..c3aa5c36 100644
--- a/internal/service/server/adapter_dispatch.go
+++ b/internal/service/server/adapter_dispatch.go
@@ -518,6 +518,29 @@ func (s *Server) handleWithAdapters(
}
record.UpstreamRequest = googleReq
+ // Wrap with visual orchestrator if enabled for this model.
+ googlePreferred := preferred
+ googlePreferred.Client = &googleProviderClient{c: googleClient, model: googlePreferred.UpstreamModel}
+ if visProv := s.wrapWithVisual(ctx, openAIReq.Model, googlePreferred, providerAdapter, nil); visProv != nil {
+ var visErr error
+ coreResp, visErr = visProv.CreateCore(ctx, coreReq)
+ if visErr != nil {
+ log.Error("adapter path: google visual CreateCore failed", "error", visErr)
+ payload := openai.ErrorResponse{
+ Error: openai.ErrorObject{
+ Message: fmt.Sprintf("google visual orchestration failed: %v", visErr),
+ Type: "server_error",
+ Code: "provider_error",
+ },
+ }
+ record.Error = traceError("google_visual_core", visErr)
+ record.OpenAIResponse = payload
+ writeOpenAIError(w, http.StatusBadGateway, payload)
+ return
+ }
+ break
+ }
+
var googleResp *google.GenerateContentResponse
if wsInjected {
googleResp, err = s.executeGoogleSearchLoop(ctx, googleClient, preferred.UpstreamModel, googleReq, searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds)
@@ -903,6 +926,9 @@ func (s *Server) handleAdapterStream(
// Protocol-specific upstream streaming: get stream + convert to CoreStreamEvent.
var coreEvents <-chan format.CoreStreamEvent
var providerStream format.ProviderStreamAdapter
+ var sr *format.StreamResult // result from ToCoreStream, captures events + buffer
+ var providerBuf func() []any // per-request provider stream buffer (from ToCoreStream StreamResult)
+ var clientBuf func() []any // per-request client stream buffer (from OpenAI FromCoreStream)
switch candidate.Protocol {
case config.ProtocolAnthropic:
@@ -1037,7 +1063,7 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
- coreEvents, err = providerStream.ToCoreStream(ctx, stream)
+ sr, err = providerStream.ToCoreStream(ctx, stream)
if err != nil {
log.Error("adapter stream: ToCoreStream failed", "error", err)
payload := openai.ErrorResponse{
@@ -1052,6 +1078,10 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
+ coreEvents = sr.Events
+ if sr.StreamBuffer != nil {
+ providerBuf = sr.StreamBuffer
+ }
}
case config.ProtocolOpenAIChat:
@@ -1124,6 +1154,52 @@ func (s *Server) handleAdapterStream(
streamRecord.ChatRequest = chatReq
var chatStream <-chan chat.ChatStreamChunk
var err error
+
+ providerAdapter, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat)
+ if !ok {
+ log.Warn("adapter stream: no chat provider adapter for visual path")
+ }
+
+ // Visual orchestrator for streaming path: non-streaming orchestration
+ // → synthetic stream events, matching the anthropic streaming pattern.
+ if s.pluginRegistry != nil && s.runtime != nil && openAIReq.Model != "" && ok && providerAdapter != nil {
+ cfgV := s.runtime.Current().Config
+ visCfg, visOk := visualpkg.ConfigForModelFromResolvedConfig(cfgV, openAIReq.Model)
+ if visOk && visCfg.Provider != "" && visCfg.Model != "" {
+ finalizeUpstream := func(_ context.Context, upstream any) (any, error) {
+ req, ok := upstream.(*chat.ChatRequest)
+ if !ok {
+ return nil, fmt.Errorf("chat visual finalize: expected *chat.ChatRequest, got %T", upstream)
+ }
+ if s.pluginRegistry != nil && sess != nil {
+ prependCachedReasoningForChat(req, sess)
+ }
+ return req, nil
+ }
+ visCandidate := candidate
+ visCandidate.Client = &chatProviderClient{c: chatClient}
+ if visProv := s.wrapWithVisual(ctx, openAIReq.Model, visCandidate, providerAdapter, finalizeUpstream); visProv != nil {
+ coreResp, visErr := visProv.CreateCore(ctx, coreReq)
+ if visErr != nil {
+ log.Error("adapter stream: chat visual CreateCore failed", "error", visErr)
+ payload := openai.ErrorResponse{
+ Error: openai.ErrorObject{
+ Message: fmt.Sprintf("chat visual orchestration failed: %v", visErr),
+ Type: "server_error",
+ Code: "provider_error",
+ },
+ }
+ streamRecord.Error = traceError("stream_chat_visual", visErr)
+ streamRecord.OpenAIResponse = payload
+ writeOpenAIError(w, http.StatusBadGateway, payload)
+ return
+ }
+ coreEvents = coreResponseToCoreStream(ctx, coreResp)
+ break
+ }
+ }
+ }
+
if wsInjected {
searchCfg := s.resolvedSearchConfig(candidate.ProviderKey, openAIReq.Model)
chatStream, err = s.chatSearchBufferedStream(ctx, chatClient, chatReq, searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds)
@@ -1160,7 +1236,7 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
- coreEvents, err = providerStream.ToCoreStream(ctx, chatStream)
+ sr, err = providerStream.ToCoreStream(ctx, chatStream)
if err != nil {
log.Error("adapter stream: Chat ToCoreStream failed", "error", err)
payload := openai.ErrorResponse{
@@ -1175,6 +1251,10 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
+ coreEvents = sr.Events
+ if sr.StreamBuffer != nil {
+ providerBuf = sr.StreamBuffer
+ }
case config.ProtocolGoogleGenAI:
googleReq, ok := upstreamReq.(*google.GenerateContentRequest)
@@ -1225,6 +1305,38 @@ func (s *Server) handleAdapterStream(
}
streamRecord.UpstreamRequest = googleReq
+
+ // Visual orchestrator for streaming path: non-streaming orchestration
+ // → synthetic stream events, matching the anthropic/chat streaming pattern.
+ providerAdapter, ok := s.adapterRegistry.GetProvider(config.ProtocolGoogleGenAI)
+ if ok && providerAdapter != nil && s.runtime != nil && openAIReq.Model != "" {
+ cfgV := s.runtime.Current().Config
+ visCfg, visOk := visualpkg.ConfigForModelFromResolvedConfig(cfgV, openAIReq.Model)
+ if visOk && visCfg.Provider != "" && visCfg.Model != "" {
+ visCandidate := candidate
+ visCandidate.Client = &googleProviderClient{c: googleClient, model: candidate.UpstreamModel}
+ if visProv := s.wrapWithVisual(ctx, openAIReq.Model, visCandidate, providerAdapter, nil); visProv != nil {
+ coreResp, visErr := visProv.CreateCore(ctx, coreReq)
+ if visErr != nil {
+ log.Error("adapter stream: google visual CreateCore failed", "error", visErr)
+ payload := openai.ErrorResponse{
+ Error: openai.ErrorObject{
+ Message: fmt.Sprintf("google visual orchestration failed: %v", visErr),
+ Type: "server_error",
+ Code: "provider_error",
+ },
+ }
+ streamRecord.Error = traceError("stream_google_visual", visErr)
+ streamRecord.OpenAIResponse = payload
+ writeOpenAIError(w, http.StatusBadGateway, payload)
+ return
+ }
+ coreEvents = coreResponseToCoreStream(ctx, coreResp)
+ break
+ }
+ }
+ }
+
if wsInjected {
searchCfg := s.resolvedSearchConfig(candidate.ProviderKey, openAIReq.Model)
googleResp, err := s.executeGoogleSearchLoop(ctx, googleClient, candidate.UpstreamModel, googleReq, searchCfg.tavilyKey, searchCfg.firecrawlKey, searchCfg.maxRounds)
@@ -1324,7 +1436,7 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
- coreEvents, err = providerStream.ToCoreStream(ctx, googleStream)
+ sr, err = providerStream.ToCoreStream(ctx, googleStream)
if err != nil {
log.Error("adapter stream: Google ToCoreStream failed", "error", err)
payload := openai.ErrorResponse{
@@ -1339,6 +1451,10 @@ func (s *Server) handleAdapterStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
+ coreEvents = sr.Events
+ if sr.StreamBuffer != nil {
+ providerBuf = sr.StreamBuffer
+ }
default:
log.Error("adapter stream: unsupported protocol", "protocol", candidate.Protocol)
@@ -1389,8 +1505,13 @@ func (s *Server) handleAdapterStream(
return
}
- streamChan, ok := streamChanAny.(<-chan openai.StreamEvent)
- if !ok {
+ var streamChan <-chan openai.StreamEvent
+ if oaiResult, ok := streamChanAny.(*openai.OpenAIStreamResult); ok {
+ streamChan = oaiResult.Chan()
+ clientBuf = oaiResult.Buffer
+ } else if ch, ok := streamChanAny.(<-chan openai.StreamEvent); ok {
+ streamChan = ch
+ } else {
log.Error("adapter stream: unexpected stream channel type", "type", fmt.Sprintf("%T", streamChanAny))
payload := openai.ErrorResponse{
Error: openai.ErrorObject{
@@ -1435,8 +1556,15 @@ func (s *Server) handleAdapterStream(
remembered := rememberStreamResponseContent(s.pluginRegistry, sess, openAIReq.Model, finalResp)
if !remembered {
if anthProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolAnthropic); ok {
- if anthAdapter, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok {
- events := anthAdapter.StreamBuffer()
+ if _, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok {
+ var events []anthropic.StreamEvent
+ if providerBuf != nil {
+ raw := providerBuf()
+ events = make([]anthropic.StreamEvent, len(raw))
+ for i, r := range raw {
+ events[i], _ = r.(anthropic.StreamEvent)
+ }
+ }
if len(events) > 0 {
states := s.pluginRegistry.NewStreamStates(openAIReq.Model)
for _, ev := range events {
@@ -1474,8 +1602,19 @@ func (s *Server) handleAdapterStream(
// This must not depend on trace being enabled.
if sess != nil {
if chatProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat); ok {
- if chatAdapter, ok := chatProvider.(*chat.ChatProviderAdapter); ok {
- if events := chatAdapter.StreamBuffer(); len(events) > 0 {
+ if _, ok := chatProvider.(*chat.ChatProviderAdapter); ok {
+ var chatEvents []chat.ChatStreamChunk
+ if providerBuf != nil {
+ raw := providerBuf()
+ chatEvents = make([]chat.ChatStreamChunk, 0, len(raw))
+ for _, r := range raw {
+ if ev, ok := r.(chat.ChatStreamChunk); ok {
+ chatEvents = append(chatEvents, ev)
+ }
+ }
+ }
+ events := chatEvents
+ if len(events) > 0 {
var streamReasoning string
seenToolCallIDs := make(map[string]struct{})
streamToolCallIDs := make([]string, 0, 4)
@@ -1506,30 +1645,40 @@ func (s *Server) handleAdapterStream(
// Capture stream events for trace.
if s.tracer != nil && s.tracer.Enabled() {
- // OpenAI stream events from client adapter
- if oaiClient, ok := s.adapterRegistry.GetClient(config.ProtocolOpenAIResponse); ok {
- if oaiAdapter, ok := oaiClient.(*openai.OpenAIAdapter); ok {
- if events := oaiAdapter.StreamBuffer(); len(events) > 0 {
- streamRecord.OpenAIStreamEvents = events
+ // Provider stream buffer (anthropic/chat/google raw events)
+ if providerBuf != nil {
+ raw := providerBuf()
+ var anthBuf []anthropic.StreamEvent
+ for _, r := range raw {
+ if ev, ok := r.(anthropic.StreamEvent); ok {
+ anthBuf = append(anthBuf, ev)
}
}
- }
- // Anthropic stream events from provider adapter
- if anthProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolAnthropic); ok {
- if anthAdapter, ok := anthProvider.(*anthropic.AnthropicProviderAdapter); ok {
- if events := anthAdapter.StreamBuffer(); len(events) > 0 {
- streamRecord.AnthropicStreamEvents = events
+ if len(anthBuf) > 0 {
+ streamRecord.AnthropicStreamEvents = anthBuf
+ }
+ var chatBuf []chat.ChatStreamChunk
+ for _, r := range raw {
+ if ev, ok := r.(chat.ChatStreamChunk); ok {
+ chatBuf = append(chatBuf, ev)
}
-
- // Chat stream events from provider adapter
- if chatProvider, ok := s.adapterRegistry.GetProvider(config.ProtocolOpenAIChat); ok {
- if chatAdapter, ok := chatProvider.(*chat.ChatProviderAdapter); ok {
- if events := chatAdapter.StreamBuffer(); len(events) > 0 {
- streamRecord.ChatStreamEvents = events
- }
- }
+ }
+ if len(chatBuf) > 0 {
+ streamRecord.ChatStreamEvents = chatBuf
+ }
+ }
+ // Client stream buffer (OpenAI stream events)
+ if clientBuf != nil {
+ raw := clientBuf()
+ var openAIBuf []openai.StreamEvent
+ for _, r := range raw {
+ if ev, ok := r.(openai.StreamEvent); ok {
+ openAIBuf = append(openAIBuf, ev)
}
}
+ if len(openAIBuf) > 0 {
+ streamRecord.OpenAIStreamEvents = openAIBuf
+ }
}
}
if s.stats != nil && (finalUsage.InputTokens > 0 || finalUsage.OutputTokens > 0) {
@@ -1635,7 +1784,7 @@ func (s *Server) writeCoreResponseAsOpenAIStream(
return
}
- streamChanAny, err := clientStream.FromCoreStream(ctx, coreReq, coreResponseToStreamEvents(coreResp))
+ streamChanAny, err := clientStream.FromCoreStream(ctx, coreReq, coreResponseToStreamEvents(ctx, coreResp))
if err != nil {
payload := openai.ErrorResponse{
Error: openai.ErrorObject{
@@ -1649,8 +1798,12 @@ func (s *Server) writeCoreResponseAsOpenAIStream(
writeOpenAIError(w, http.StatusInternalServerError, payload)
return
}
- streamChan, ok := streamChanAny.(<-chan openai.StreamEvent)
- if !ok {
+ var streamChan <-chan openai.StreamEvent
+ if oaiResult, ok := streamChanAny.(*openai.OpenAIStreamResult); ok {
+ streamChan = oaiResult.Chan()
+ } else if ch, ok := streamChanAny.(<-chan openai.StreamEvent); ok {
+ streamChan = ch
+ } else {
payload := openai.ErrorResponse{
Error: openai.ErrorObject{
Message: "unexpected stream channel type",
@@ -1710,21 +1863,33 @@ func (s *Server) writeCoreResponseAsOpenAIStream(
}
}
-func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan format.CoreStreamEvent {
+func coreResponseToStreamEvents(ctx context.Context, resp *format.CoreResponse) <-chan format.CoreStreamEvent {
out := make(chan format.CoreStreamEvent, 16)
go func() {
defer close(out)
+
+ send := func(ev format.CoreStreamEvent) bool {
+ select {
+ case <-ctx.Done():
+ return false
+ case out <- ev:
+ return true
+ }
+ }
+
if resp == nil {
- out <- format.CoreStreamEvent{
+ send(format.CoreStreamEvent{
Type: format.CoreEventFailed,
Error: &format.CoreError{
Message: "core response is nil",
Type: "server_error",
},
- }
+ })
+ return
+ }
+ if !send(format.CoreStreamEvent{Type: format.CoreEventCreated, ItemID: resp.ID, Model: resp.Model}) {
return
}
- out <- format.CoreStreamEvent{Type: format.CoreEventCreated, ItemID: resp.ID, Model: resp.Model}
index := 0
for _, msg := range resp.Messages {
if msg.Role != "assistant" {
@@ -1733,21 +1898,33 @@ func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan format.CoreStr
for _, block := range msg.Content {
switch block.Type {
case "reasoning":
- out <- format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "reasoning"}}
+ if !send(format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "reasoning"}}) {
+ return
+ }
if block.ReasoningText != "" {
- out <- format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.ReasoningText}
+ if !send(format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.ReasoningText}) {
+ return
+ }
}
- out <- format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index, ContentBlock: &format.CoreContentBlock{
+ if !send(format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index, ContentBlock: &format.CoreContentBlock{
Type: "reasoning",
ReasoningSignature: block.ReasoningSignature,
- }}
+ }}) {
+ return
+ }
index++
case "text":
- out <- format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "text"}}
+ if !send(format.CoreStreamEvent{Type: format.CoreContentBlockStarted, Index: index, ContentBlock: &format.CoreContentBlock{Type: "text"}}) {
+ return
+ }
if block.Text != "" {
- out <- format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.Text}
+ if !send(format.CoreStreamEvent{Type: format.CoreTextDelta, Index: index, Delta: block.Text}) {
+ return
+ }
+ }
+ if !send(format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index}) {
+ return
}
- out <- format.CoreStreamEvent{Type: format.CoreContentBlockDone, Index: index}
index++
}
}
@@ -1762,13 +1939,13 @@ func coreResponseToStreamEvents(resp *format.CoreResponse) <-chan format.CoreStr
} else if status == "incomplete" {
eventType = format.CoreEventIncomplete
}
- out <- format.CoreStreamEvent{
+ send(format.CoreStreamEvent{
Type: eventType,
Status: status,
Model: resp.Model,
Usage: &resp.Usage,
Error: resp.Error,
- }
+ })
}()
return out
}
@@ -2080,6 +2257,22 @@ func (s *Server) wrapWithVisual(
return nil
}
visClient = &chatProviderClient{c: chatClient}
+ case config.ProtocolGoogleGenAI:
+ gcRaw := s.activeGoogleClient(visCfg.Provider)
+ if gcRaw == nil {
+ slog.Default().Warn("visual: no google client for visual provider", "visual_provider", visCfg.Provider, "model", modelAlias)
+ return nil
+ }
+ gc, ok := gcRaw.(*google.Client)
+ if !ok || gc == nil {
+ slog.Default().Warn("visual: google client type mismatch", "visual_provider", visCfg.Provider)
+ return nil
+ }
+ visModel := pm.FirstUpstreamModelForKey(visCfg.Provider)
+ if visModel == "" {
+ visModel = visCfg.Model
+ }
+ visClient = &googleProviderClient{c: gc, model: visModel}
default:
c, err := pm.ClientForKey(visCfg.Provider)
if err != nil || c == nil {
@@ -2100,6 +2293,29 @@ func (s *Server) wrapWithVisual(
// pm.ClientForKey only constructs anthropic-shaped clients; chat-protocol
// providers keep their dedicated *chat.Client in s.chatClients. This adapter
// bridges the two when visual orchestration needs to call into a chat upstream.
+// googleProviderClient adapts *google.Client to provider.ProviderClient so the
+// adapter-based CoreProvider machinery can drive a google-genai protocol
+// upstream uniformly across protocols. Google's GenerateContent requires
+// a model parameter in the call signature (unlike anthropic/chat), so we
+// capture the model name at construction time.
+type googleProviderClient struct {
+ c *google.Client
+ model string
+}
+
+func (p *googleProviderClient) CreateMessage(ctx context.Context, req any) (any, error) {
+ googleReq, ok := req.(*google.GenerateContentRequest)
+ if !ok {
+ return nil, fmt.Errorf("googleProviderClient: expected *google.GenerateContentRequest, got %T", req)
+ }
+ return p.c.GenerateContent(ctx, p.model, googleReq)
+}
+
+func (p *googleProviderClient) StreamMessage(ctx context.Context, req any) (<-chan any, error) {
+ // Not used by visual orchestrator (uses CreateCore non-streaming path).
+ return nil, fmt.Errorf("googleProviderClient: streaming not supported via ProviderClient interface")
+}
+
type chatProviderClient struct{ c *chat.Client }
func (p *chatProviderClient) CreateMessage(ctx context.Context, req any) (any, error) {
@@ -2123,7 +2339,16 @@ func (p *chatProviderClient) StreamMessage(ctx context.Context, req any) (<-chan
go func() {
defer close(out)
for chunk := range stream {
- out <- chunk
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+ select {
+ case out <- chunk:
+ case <-ctx.Done():
+ return
+ }
}
}()
return out, nil
@@ -2143,7 +2368,6 @@ func normalizeAnthropicRequest(upstream any) (anthropic.MessageRequest, error) {
}
}
-
// injectCoreWebSearch replaces web_search tools in coreReq.Tools with injected
// tavily_search/firecrawl_fetch tools when the resolved web search mode is "injected".
// Returns true if injection was applied.
@@ -2233,6 +2457,11 @@ func (a *searchProviderAdapter) StreamMessage(ctx context.Context, req any) (<-c
defer close(out)
defer stream.Close()
for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
ev, err := stream.Next()
if err != nil {
if err == io.EOF {
@@ -2240,7 +2469,11 @@ func (a *searchProviderAdapter) StreamMessage(ctx context.Context, req any) (<-c
}
return
}
- out <- ev
+ select {
+ case out <- ev:
+ case <-ctx.Done():
+ return
+ }
}
}()
return out, nil
diff --git a/internal/service/server/server.go b/internal/service/server/server.go
index 4c792212..19d76df6 100644
--- a/internal/service/server/server.go
+++ b/internal/service/server/server.go
@@ -73,6 +73,13 @@ type Server struct {
sessionManager session.Manager
usageTracker usage.Tracker
traceWriter trace.Writer
+
+ // clientCaches holds lazily-created HTTP clients for runtime-reloaded providers.
+ // Keyed by provider key, invalidated when Runtime reloads.
+ clientCache map[string]*chat.Client
+ googleCache map[string]*google.Client
+ clientCacheMu sync.RWMutex
+ googleCacheMu sync.RWMutex
}
func (s *Server) runtimeSnapshot() *runtime.ConfigSnapshot {
@@ -97,22 +104,42 @@ func (s *Server) activeProviderDefs() map[string]config.ProviderDef {
}
func (s *Server) activeChatClient(providerKey string) any {
+ // Check runtime-driven cache first.
+ s.clientCacheMu.RLock()
+ if cached, ok := s.clientCache[providerKey]; ok {
+ s.clientCacheMu.RUnlock()
+ return cached
+ }
+ s.clientCacheMu.RUnlock()
+
if snap := s.runtimeSnapshot(); snap != nil {
if def, ok := snap.Config.ProviderDefs[providerKey]; ok && def.Protocol == config.ProtocolOpenAIChat {
- return chat.NewClient(chat.ClientConfig{
+ client := chat.NewClient(chat.ClientConfig{
BaseURL: def.BaseURL,
APIKey: def.APIKey,
UserAgent: def.UserAgent,
})
+ s.clientCacheMu.Lock()
+ s.clientCache[providerKey] = client
+ s.clientCacheMu.Unlock()
+ return client
}
}
return s.chatClients[providerKey]
}
func (s *Server) activeGoogleClient(providerKey string) any {
+ // Check runtime-driven cache first.
+ s.googleCacheMu.RLock()
+ if cached, ok := s.googleCache[providerKey]; ok {
+ s.googleCacheMu.RUnlock()
+ return cached
+ }
+ s.googleCacheMu.RUnlock()
+
if snap := s.runtimeSnapshot(); snap != nil {
if def, ok := snap.Config.ProviderDefs[providerKey]; ok && def.Protocol == config.ProtocolGoogleGenAI {
- return google.NewClient(google.ClientConfig{
+ client := google.NewClient(google.ClientConfig{
BaseURL: def.BaseURL,
APIKey: def.APIKey,
Project: def.Project,
@@ -120,6 +147,10 @@ func (s *Server) activeGoogleClient(providerKey string) any {
Version: def.APIVersion,
UserAgent: def.UserAgent,
})
+ s.googleCacheMu.Lock()
+ s.googleCache[providerKey] = client
+ s.googleCacheMu.Unlock()
+ return client
}
}
return s.googleClients[providerKey]
@@ -148,6 +179,8 @@ func New(cfg Config) *Server {
sessionManager: cfg.SessionManager,
usageTracker: cfg.UsageTracker,
traceWriter: cfg.TraceWriter,
+ clientCache: make(map[string]*chat.Client),
+ googleCache: make(map[string]*google.Client),
}
s.mux.HandleFunc("/v1/responses", s.handleResponses)
s.mux.HandleFunc("/responses", s.handleResponses)
diff --git a/internal/service/server/server_test.go b/internal/service/server/server_test.go
index bd6cd39c..3eaaedf0 100644
--- a/internal/service/server/server_test.go
+++ b/internal/service/server/server_test.go
@@ -13,13 +13,13 @@ import (
"strings"
"testing"
+ "moonbridge/internal/config"
"moonbridge/internal/extension/codex"
deepseekv4 "moonbridge/internal/extension/deepseek_v4"
"moonbridge/internal/extension/plugin"
- "moonbridge/internal/config"
+ "moonbridge/internal/format"
"moonbridge/internal/logger"
"moonbridge/internal/protocol/openai"
- "moonbridge/internal/format"
"moonbridge/internal/service/provider"
"moonbridge/internal/service/server"
"moonbridge/internal/service/stats"
@@ -82,7 +82,6 @@ func (provider providerFunc) StreamMessage(ctx context.Context, req any) (<-chan
return provider.stream(ctx, req)
}
-
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
diff --git a/internal/service/server/trace/writer.go b/internal/service/server/trace/writer.go
index 78b7310f..e202f222 100644
--- a/internal/service/server/trace/writer.go
+++ b/internal/service/server/trace/writer.go
@@ -24,8 +24,8 @@ type Writer interface {
// FileWriter implements Writer by delegating to a *mbtrace.Tracer.
type FileWriter struct {
- tracer *mbtrace.Tracer
- errors io.Writer
+ tracer *mbtrace.Tracer
+ errors io.Writer
}
// NewFileWriter creates a new FileWriter.
diff --git a/internal/service/server/websearch_inject.go b/internal/service/server/websearch_inject.go
index 068a7b3c..89496f69 100644
--- a/internal/service/server/websearch_inject.go
+++ b/internal/service/server/websearch_inject.go
@@ -116,6 +116,12 @@ func (s *Server) executeChatSearchLoop(
}
msg := resp.Choices[0].Message
+ // Defensive: ensure message role is set. Upstream may return empty role
+ // in some error-recovery scenarios, which would break the alternating
+ // user/assistant/tool contract on subsequent rounds.
+ if msg.Role == "" {
+ msg.Role = "assistant"
+ }
if len(msg.ToolCalls) == 0 {
return resp, nil
}
@@ -136,6 +142,22 @@ func (s *Server) executeChatSearchLoop(
return resp, nil
}
if len(nonSearchCalls) > 0 {
+ // Execute search calls as side effect, return only non-search content.
+ var toolResultMsgs []chat.ChatMessage
+ for _, tc := range searchCalls {
+ result, execErr := executeChatSearchCall(ctx, tavily, firecrawl, tc)
+ if execErr != nil {
+ log.Warn("Chat搜索执行失败(混合调用)", "tool", tc.Function.Name, "error", execErr)
+ result = fmt.Sprintf("Search error: %s", execErr.Error())
+ }
+ toolResultMsgs = append(toolResultMsgs, chat.ChatMessage{
+ Role: "tool",
+ ToolCallID: tc.ID,
+ Content: result,
+ })
+ }
+ req.Messages = append(req.Messages, msg)
+ req.Messages = append(req.Messages, toolResultMsgs...)
return resp, nil
}
@@ -193,7 +215,7 @@ func executeChatSearchCall(
if err != nil {
return "", err
}
- return formatTavilyResults(result), nil
+ return websearch.FormatTavilyResults(result), nil
case "firecrawl_fetch":
if firecrawl == nil {
@@ -216,7 +238,7 @@ func executeChatSearchCall(
if err != nil {
return "", err
}
- return formatFirecrawlResult(result), nil
+ return websearch.FormatFirecrawlResult(result), nil
default:
return "", fmt.Errorf("unknown search tool: %s", tc.Function.Name)
@@ -299,12 +321,18 @@ func (s *Server) executeGoogleSearchLoop(
if err != nil {
return nil, err
}
- if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 {
+ if len(resp.Candidates) == 0 {
return resp, nil
}
- // Check for function call parts.
- funcCalls := googleFuncCalls(resp.Candidates[0].Content.Parts)
+ // Collect function calls from ALL candidates, not just the first.
+ var funcCalls []google.FunctionCall
+ for _, c := range resp.Candidates {
+ if len(c.Content.Parts) == 0 {
+ continue
+ }
+ funcCalls = append(funcCalls, googleFuncCalls(c.Content.Parts)...)
+ }
if len(funcCalls) == 0 {
return resp, nil
}
@@ -314,6 +342,34 @@ func (s *Server) executeGoogleSearchLoop(
return resp, nil
}
if len(nonSearchCalls) > 0 {
+ // Execute search calls as side effect, feed results to the model.
+ responseParts := make([]google.Part, 0, len(searchCalls))
+ for _, fc := range searchCalls {
+ result, execErr := executeGoogleSearchCall(ctx, tavily, firecrawl, fc)
+ if execErr != nil {
+ log.Warn("Google搜索执行失败(混合调用)", "tool", fc.Name, "error", execErr)
+ result = execErr.Error()
+ }
+ respJSON, _ := json.Marshal(map[string]any{"result": result})
+ responseParts = append(responseParts, google.Part{
+ FunctionResponse: &google.FunctionResponse{
+ Name: fc.Name,
+ Response: respJSON,
+ },
+ })
+ }
+ for _, c := range resp.Candidates {
+ if len(c.Content.Parts) > 0 {
+ req.Contents = append(req.Contents, google.Content{
+ Role: "model",
+ Parts: c.Content.Parts,
+ })
+ }
+ }
+ req.Contents = append(req.Contents, google.Content{
+ Role: "function",
+ Parts: responseParts,
+ })
return resp, nil
}
@@ -335,10 +391,14 @@ func (s *Server) executeGoogleSearchLoop(
}
// Append model response + function response for next round.
- req.Contents = append(req.Contents, google.Content{
- Role: "model",
- Parts: resp.Candidates[0].Content.Parts,
- })
+ for _, c := range resp.Candidates {
+ if len(c.Content.Parts) > 0 {
+ req.Contents = append(req.Contents, google.Content{
+ Role: "model",
+ Parts: c.Content.Parts,
+ })
+ }
+ }
req.Contents = append(req.Contents, google.Content{
Role: "function",
Parts: responseParts,
@@ -405,7 +465,7 @@ func executeGoogleSearchCall(
if err != nil {
return "", err
}
- return formatTavilyResults(result), nil
+ return websearch.FormatTavilyResults(result), nil
case "firecrawl_fetch":
if firecrawl == nil {
@@ -429,52 +489,14 @@ func executeGoogleSearchCall(
if err != nil {
return "", err
}
- return formatFirecrawlResult(result), nil
+ return websearch.FormatFirecrawlResult(result), nil
default:
return "", fmt.Errorf("unknown search tool: %s", fc.Name)
}
}
-// ============================================================================
-// Formatting helpers (duplicated from websearch package for encapsulation)
-// ============================================================================
-
-func formatTavilyResults(result *websearch.SearchResult) string {
- var b strings.Builder
- b.WriteString(fmt.Sprintf("Search results for %q:\n\n", result.Query))
- if result.Answer != "" {
- b.WriteString("Answer: ")
- b.WriteString(truncate(result.Answer, 2000))
- b.WriteString("\n\n")
- }
- for i, item := range result.Results {
- if i >= 10 {
- break
- }
- b.WriteString(fmt.Sprintf("%d. [%s](%s)\n", i+1, item.Title, item.URL))
- b.WriteString(fmt.Sprintf(" Score: %.2f\n", item.Score))
- b.WriteString(fmt.Sprintf(" %s\n\n", truncate(item.Content, 500)))
- }
- return b.String()
-}
-
-func formatFirecrawlResult(result *websearch.FetchResult) string {
- var b strings.Builder
- b.WriteString(fmt.Sprintf("Content from %s:\n\n", result.Data.Metadata.SourceURL))
- if result.Data.Metadata.Title != "" {
- b.WriteString(fmt.Sprintf("Title: %s\n\n", result.Data.Metadata.Title))
- }
- b.WriteString(truncate(result.Data.Markdown, 8000))
- return b.String()
-}
-
-func truncate(s string, maxLen int) string {
- if len(s) <= maxLen {
- return s
- }
- return s[:maxLen] + "..."
-}
+// Formatting helpers — delegated to exported websearch package functions.
// ============================================================================
// Chat streaming search loop
@@ -537,6 +559,29 @@ func (s *Server) chatSearchBufferedStream(
break
}
if len(nonSearchCalls) > 0 {
+ // Execute search calls, feed results to next round.
+ var toolResultMsgs []chat.ChatMessage
+ for _, tc := range searchCalls {
+ result, execErr := executeChatSearchCall(ctx, tavily, firecrawl, tc)
+ if execErr != nil {
+ log.Warn("流式搜索执行失败(混合调用)", "tool", tc.Function.Name, "error", execErr)
+ result = fmt.Sprintf("Search error: %s", execErr.Error())
+ }
+ toolResultMsgs = append(toolResultMsgs, chat.ChatMessage{
+ Role: "tool",
+ ToolCallID: tc.ID,
+ Content: result,
+ })
+ }
+ asstContent := collectChatStreamContent(events)
+ reasoningContent := collectChatStreamReasoning(events)
+ req.Messages = append(req.Messages, chat.ChatMessage{
+ Role: "assistant",
+ Content: asstContent,
+ ToolCalls: toolCalls,
+ ReasoningContent: reasoningContent,
+ })
+ req.Messages = append(req.Messages, toolResultMsgs...)
allEvents = append(allEvents, events...)
exhausted = false
break
diff --git a/internal/service/store/config_store.go b/internal/service/store/config_store.go
index 32c8c273..26bc8596 100644
--- a/internal/service/store/config_store.go
+++ b/internal/service/store/config_store.go
@@ -80,16 +80,16 @@ type ModelRow struct {
// RouteRow represents a row in the config_store_routes table.
type RouteRow struct {
- Alias string
- ModelSlug string
- ProviderKey string
- DisplayName string
- ContextWindow int
- MaxOutputTokens int
- Extensions string // JSON-serialized map[string]config.ExtensionFileConfig
- WebSearch string // JSON-serialized config.WebSearchFileConfig
- CreatedAt string
- UpdatedAt string
+ Alias string
+ ModelSlug string
+ ProviderKey string
+ DisplayName string
+ ContextWindow int
+ MaxOutputTokens int
+ Extensions string // JSON-serialized map[string]config.ExtensionFileConfig
+ WebSearch string // JSON-serialized config.WebSearchFileConfig
+ CreatedAt string
+ UpdatedAt string
}
// SettingRow represents a row in the config_store_settings table.
@@ -100,16 +100,16 @@ type SettingRow struct {
// ChangeRow represents a row in the config_store_changes table.
type ChangeRow struct {
- ID int64
- BatchID string
- Action string // "create", "update", "delete"
- Resource string // "provider", "offer", "model", "route", "setting"
- TargetKey string
- Before string // JSON-serialized "before" state (empty for create)
- After string // JSON-serialized "after" state (empty for delete)
- Applied bool
- Error string
- Revision int
- CreatedAt string
- AppliedAt string
+ ID int64
+ BatchID string
+ Action string // "create", "update", "delete"
+ Resource string // "provider", "offer", "model", "route", "setting"
+ TargetKey string
+ Before string // JSON-serialized "before" state (empty for create)
+ After string // JSON-serialized "after" state (empty for delete)
+ Applied bool
+ Error string
+ Revision int
+ CreatedAt string
+ AppliedAt string
}
diff --git a/internal/service/store/consumer.go b/internal/service/store/consumer.go
index 196a366e..042aa282 100644
--- a/internal/service/store/consumer.go
+++ b/internal/service/store/consumer.go
@@ -12,7 +12,7 @@ import (
type ConfigStoreConsumer struct {
store db.Store
persistenceDisabled bool
- extensionSpecs []config.ExtensionConfigSpec
+ extensionSpecs []config.ExtensionConfigSpec
logger *slog.Logger
configStore ConfigStore
}
@@ -152,5 +152,5 @@ func (c *ConfigStoreConsumer) Store() ConfigStore {
// compile-time interface checks.
var (
- _ db.Consumer = (*ConfigStoreConsumer)(nil)
+ _ db.Consumer = (*ConfigStoreConsumer)(nil)
)