diff --git a/config.example.yml b/config.example.yml index e4a32168..b45c5a25 100644 --- a/config.example.yml +++ b/config.example.yml @@ -215,6 +215,24 @@ providers: output_price: 15 cache_write_price: 3.75 cache_read_price: 0.30 + # Example: per-(provider, upstream-model) extra_body for vendor-specific + # switches that are not part of the OpenAI Chat Completions schema. Only + # honored on the openai-chat protocol dispatch path; the configured map is + # merged into the top-level JSON of each outbound request to that specific + # offer. Standard fields (model/messages/stream/...) take precedence and + # cannot be clobbered. Sibling offers on the same provider are unaffected. + # + # qwen-dashscope: + # base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + # api_key: "replace-with-dashscope-api-key" + # protocol: openai-chat + # offers: + # - model: qwen3-plus + # overrides: + # extra_body: + # enable_search: true + # - model: deepseek-v4-pro + # # No overrides.extra_body → no vendor switches injected. routes: # Optional aliases. Provider models are listed directly in the Codex catalog diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 0391d1e6..d2f64f66 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -105,6 +105,10 @@ providers: output_price: 8 cache_write_price: 1 cache_read_price: 0.25 + overrides: + # 供应商私有顶层 JSON 字段(仅 openai-chat 协议)。作用域为当前 offer。 + extra_body: + enable_search: true ``` ### Protocol 类型 diff --git a/internal/config/config.go b/internal/config/config.go index 091b0905..b6cc56b8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -168,6 +168,12 @@ type ModelMeta struct { // WebSearch holds model-level web search config (overrides provider-level). WebSearch WebSearchConfig Extensions map[string]ExtensionSettings + // ExtraBody carries vendor-specific top-level JSON switches that the + // openai-chat protocol dispatcher merges into the outbound Chat Completions + // request body (e.g. {"enable_search": true} for Qwen/DashScope). Scoped + // per (provider, upstream-model) tuple — sourced from + // providers..offers[].overrides.extra_body in YAML. + ExtraBody map[string]any } // ModelPricing holds per-provider model pricing. @@ -200,6 +206,12 @@ type ModelDef struct { SupportsImageDetailOriginal bool WebSearch WebSearchConfig Extensions map[string]ExtensionSettings + // ExtraBody carries vendor-specific top-level JSON switches. On ModelDef + // it exists only as the intermediate type used by the offer-override merge + // pipeline; it is not exposed at the top-level `models:` YAML segment. + // Final propagation target is ModelMeta.ExtraBody, populated per + // (provider, upstream-model) tuple via providers..offers[].overrides.extra_body. + ExtraBody map[string]any } // OfferEntry declares that a provider offers a model defined in Models. diff --git a/internal/config/config_loader.go b/internal/config/config_loader.go index 76dee14c..20eb75f3 100644 --- a/internal/config/config_loader.go +++ b/internal/config/config_loader.go @@ -119,6 +119,7 @@ type ModelDefFileConfig struct { SupportsImageDetailOriginal *bool `yaml:"supports_image_detail_original,omitempty" json:"supports_image_detail_original,omitempty"` WebSearch WebSearchFileConfig `yaml:"web_search,omitempty" json:"web_search,omitempty"` Extensions map[string]ExtensionFileConfig `yaml:"extensions,omitempty" json:"extensions,omitempty"` + ExtraBody map[string]any `yaml:"extra_body,omitempty" json:"extra_body,omitempty"` } type OfferFileConfig struct { @@ -611,6 +612,9 @@ func mergeModelDefOverrides(base ModelDef, override ModelDefFileConfig) ModelDef } } } + if len(override.ExtraBody) > 0 { + out.ExtraBody = cloneAnyMap(override.ExtraBody) + } return out } @@ -651,6 +655,9 @@ func applyModelOverrides(meta *ModelMeta, override ModelDef) { if override.SupportsImageDetailOriginal { meta.SupportsImageDetailOriginal = true } + if len(override.ExtraBody) > 0 { + meta.ExtraBody = cloneAnyMap(override.ExtraBody) + } } // buildRoutes parses route specs and merges model metadata. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index a143cd48..7e576a55 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1074,3 +1074,79 @@ routes: t.Fatalf("DisplayName = %q, want \"My Custom Display Name\"", route.DisplayName) } } + +// TestLoadFromYAMLParsesOfferExtraBody verifies that vendor-specific top-level +// JSON switches written under providers..offers[].overrides.extra_body are +// propagated to the corresponding ModelMeta in cfg.ProviderDefs[provider].Models[upstream]. +// +// This is the (provider, upstream-model) tuple granularity: two different offers +// on the same provider, and the same model name offered by two different providers, +// each carry independent ExtraBody maps. +func TestLoadFromYAMLParsesOfferExtraBody(t *testing.T) { + cfg, err := config.LoadFromYAML([]byte(` +mode: Transform +models: + qwen3-plus: + context_window: 1000000 + deepseek-v4-pro: + context_window: 1000000 +providers: + aliyun: + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + api_key: aliyun-key + protocol: openai-chat + offers: + - model: qwen3-plus + overrides: + extra_body: + enable_search: true + extra_flag: 1 + - model: deepseek-v4-pro + bailian: + base_url: https://bailian.example.test/v1 + api_key: bailian-key + protocol: openai-chat + offers: + - model: qwen3-plus + overrides: + extra_body: + enable_search: false +routes: + qwen: + model: qwen3-plus + provider: aliyun +`)) + if err != nil { + t.Fatalf("LoadFromYAML() error = %v", err) + } + + aliyunQwen, ok := cfg.ProviderDefs["aliyun"].Models["qwen3-plus"] + if !ok { + t.Fatalf("ProviderDefs[aliyun].Models[qwen3-plus] missing") + } + if aliyunQwen.ExtraBody["enable_search"] != true { + t.Errorf("aliyun/qwen3-plus ExtraBody[enable_search] = %v, want true", aliyunQwen.ExtraBody["enable_search"]) + } + if aliyunQwen.ExtraBody["extra_flag"] != 1 { + t.Errorf("aliyun/qwen3-plus ExtraBody[extra_flag] = %v, want 1", aliyunQwen.ExtraBody["extra_flag"]) + } + + aliyunDeepseek, ok := cfg.ProviderDefs["aliyun"].Models["deepseek-v4-pro"] + if !ok { + t.Fatalf("ProviderDefs[aliyun].Models[deepseek-v4-pro] missing") + } + if len(aliyunDeepseek.ExtraBody) != 0 { + t.Errorf("aliyun/deepseek-v4-pro ExtraBody = %+v, want empty when no overrides.extra_body is set", aliyunDeepseek.ExtraBody) + } + + bailianQwen, ok := cfg.ProviderDefs["bailian"].Models["qwen3-plus"] + if !ok { + t.Fatalf("ProviderDefs[bailian].Models[qwen3-plus] missing") + } + if bailianQwen.ExtraBody["enable_search"] != false { + t.Errorf("bailian/qwen3-plus ExtraBody[enable_search] = %v, want false (independent from aliyun's same-name offer)", bailianQwen.ExtraBody["enable_search"]) + } + if _, hasExtraFlag := bailianQwen.ExtraBody["extra_flag"]; hasExtraFlag { + t.Errorf("bailian/qwen3-plus should not carry aliyun's extra_flag key — per-(provider, model) isolation broken") + } +} diff --git a/internal/config/convert.go b/internal/config/convert.go index faccab17..312e7333 100644 --- a/internal/config/convert.go +++ b/internal/config/convert.go @@ -134,6 +134,10 @@ func toModelDefFileConfig(def ModelDef) ModelDefFileConfig { } } + if len(def.ExtraBody) > 0 { + m.ExtraBody = cloneAnyMap(def.ExtraBody) + } + return m } diff --git a/internal/protocol/chat/adapter.go b/internal/protocol/chat/adapter.go index a5380fac..bba062d6 100644 --- a/internal/protocol/chat/adapter.go +++ b/internal/protocol/chat/adapter.go @@ -138,9 +138,29 @@ func (a *ChatProviderAdapter) FromCoreRequest(ctx context.Context, req *format.C } } + // extra_body: vendor-specific top-level JSON switches (e.g. enable_search) + // sourced from CoreRequest.Extensions["openai_chat"]["extra_body"]. + // The dispatcher writes this value once per request based on the resolved + // (provider, upstream-model) tuple. ChatRequest.MarshalJSON flattens it to + // the top-level JSON object of the outbound body. + if extra := extractChatExtraBody(req.Extensions); len(extra) > 0 { + chatReq.ExtraParams = extra + } + return chatReq, nil } +// extractChatExtraBody returns the openai-chat extra_body map carried on a +// CoreRequest, or nil when the key is absent or has the wrong shape. +func extractChatExtraBody(ext map[string]any) map[string]any { + bag, ok := ext["openai_chat"].(map[string]any) + if !ok { + return nil + } + extra, _ := bag["extra_body"].(map[string]any) + return extra +} + // ========================================================================= // ToCoreResponse — *ChatResponse → *CoreResponse // ========================================================================= diff --git a/internal/protocol/chat/chat_test.go b/internal/protocol/chat/chat_test.go index 35b71c2e..83958f3d 100644 --- a/internal/protocol/chat/chat_test.go +++ b/internal/protocol/chat/chat_test.go @@ -7,11 +7,15 @@ package chat_test import ( "context" "encoding/json" + "fmt" + "io" "net/http" "net/http/httptest" "strings" + "sync" "testing" + visualpkg "moonbridge/internal/extension/visual" "moonbridge/internal/format" "moonbridge/internal/protocol/chat" ) @@ -2602,3 +2606,455 @@ func TestFromCoreRequest_ToolCallArgumentsAreJSONString(t *testing.T) { t.Errorf("unexpected city value: %v", obj["city"]) } } + +// ============================================================================ +// ChatRequest.ExtraParams: vendor-specific top-level passthrough +// ============================================================================ +// +// ExtraParams carries vendor-specific switches (such as enable_search for +// Qwen/DashScope) that are not part of the OpenAI Chat Completions schema. +// MarshalJSON must flatten these into the top-level JSON object so the upstream +// sees them as siblings of model/messages. Standard fields take precedence — +// ExtraParams cannot clobber model/messages/stream by accident. + +// decodeTopLevel marshals a ChatRequest and returns its top-level JSON object. +func decodeTopLevel(t *testing.T, req chat.ChatRequest) map[string]json.RawMessage { + t.Helper() + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + var top map[string]json.RawMessage + if err := json.Unmarshal(data, &top); err != nil { + t.Fatalf("Unmarshal top-level: %v (raw=%s)", err, data) + } + return top +} + +func TestTypes_ChatRequest_ExtraParams_FlattensToTopLevel(t *testing.T) { + req := chat.ChatRequest{ + Model: "qwen3-max", + Messages: []chat.ChatMessage{{Role: "user", Content: "hi"}}, + ExtraParams: map[string]any{ + "enable_search": true, + }, + } + top := decodeTopLevel(t, req) + + raw, ok := top["enable_search"] + if !ok { + t.Fatalf("enable_search missing from top-level JSON; keys=%v", keysOf(top)) + } + if string(raw) != "true" { + t.Errorf("enable_search = %s, want true", raw) + } + if _, ok := top["model"]; !ok { + t.Error("model field disappeared after MarshalJSON") + } + if _, ok := top["messages"]; !ok { + t.Error("messages field disappeared after MarshalJSON") + } +} + +func TestTypes_ChatRequest_ExtraParams_NilHasNoEffect(t *testing.T) { + req := chat.ChatRequest{ + Model: "gpt-4o", + Messages: []chat.ChatMessage{{Role: "user", Content: "hi"}}, + ExtraParams: nil, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if strings.Contains(string(data), "extra_params") { + t.Errorf("Output should not expose extra_params key: %s", data) + } + // Round-trip should still produce a valid ChatRequest with model intact. + var out chat.ChatRequest + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if out.Model != "gpt-4o" { + t.Errorf("Model = %q, want gpt-4o", out.Model) + } +} + +func TestTypes_ChatRequest_ExtraParams_DoesNotOverrideExistingFields(t *testing.T) { + req := chat.ChatRequest{ + Model: "qwen3-max", + Messages: []chat.ChatMessage{{Role: "user", Content: "hi"}}, + ExtraParams: map[string]any{ + "model": "evil-override", + "enable_search": true, + }, + } + top := decodeTopLevel(t, req) + + var modelOut string + if err := json.Unmarshal(top["model"], &modelOut); err != nil { + t.Fatalf("decode model: %v", err) + } + if modelOut != "qwen3-max" { + t.Errorf("model = %q, want qwen3-max (extra_params must not clobber real field)", modelOut) + } + if string(top["enable_search"]) != "true" { + t.Errorf("enable_search = %s, want true", top["enable_search"]) + } +} + +func TestTypes_ChatRequest_ExtraParams_MultipleKeys(t *testing.T) { + req := chat.ChatRequest{ + Model: "qwen3-max", + Messages: []chat.ChatMessage{{Role: "user", Content: "hi"}}, + ExtraParams: map[string]any{ + "enable_search": true, + "some_int": 42, + }, + } + top := decodeTopLevel(t, req) + + if string(top["enable_search"]) != "true" { + t.Errorf("enable_search = %s, want true", top["enable_search"]) + } + if string(top["some_int"]) != "42" { + t.Errorf("some_int = %s, want 42", top["some_int"]) + } +} + +func keysOf(m map[string]json.RawMessage) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +// TestAdapter_ChatExtraBody_FromExtensions verifies that +// CoreRequest.Extensions["openai_chat"]["extra_body"] is propagated to +// ChatRequest.ExtraParams during FromCoreRequest. This is the single read site +// that every outbound chat completion call (direct, streaming, and per-round +// inside the visual orchestrator) shares, so the test guards the central +// behaviour rather than each call-site individually. +func TestAdapter_ChatExtraBody_FromExtensions(t *testing.T) { + adapter := newTestAdapter() + core := &format.CoreRequest{ + Model: "deepseek-v4-pro", + Messages: []format.CoreMessage{{Role: "user", Content: []format.CoreContentBlock{{Type: "text", Text: "hi"}}}}, + Extensions: map[string]any{ + "openai_chat": map[string]any{ + "extra_body": map[string]any{ + "enable_search": true, + "extra_flag": 1, + }, + }, + }, + } + upstream, err := adapter.FromCoreRequest(context.Background(), core) + if err != nil { + t.Fatalf("FromCoreRequest: %v", err) + } + chatReq, ok := upstream.(*chat.ChatRequest) + if !ok { + t.Fatalf("upstream type = %T, want *chat.ChatRequest", upstream) + } + if chatReq.ExtraParams["enable_search"] != true { + t.Errorf("ExtraParams[enable_search] = %v, want true", chatReq.ExtraParams["enable_search"]) + } + if chatReq.ExtraParams["extra_flag"] != 1 { + t.Errorf("ExtraParams[extra_flag] = %v, want 1", chatReq.ExtraParams["extra_flag"]) + } + // Confirm the keys are flattened to the top level of the marshaled JSON. + data, err := json.Marshal(chatReq) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if !strings.Contains(string(data), `"enable_search":true`) { + t.Errorf("outbound body missing enable_search at top level: %s", data) + } +} + +func TestAdapter_ChatExtraBody_AbsentExtensionLeavesExtraParamsNil(t *testing.T) { + adapter := newTestAdapter() + core := &format.CoreRequest{ + Model: "deepseek-v4-pro", + Messages: []format.CoreMessage{{Role: "user", Content: []format.CoreContentBlock{{Type: "text", Text: "hi"}}}}, + } + upstream, err := adapter.FromCoreRequest(context.Background(), core) + if err != nil { + t.Fatalf("FromCoreRequest: %v", err) + } + chatReq := upstream.(*chat.ChatRequest) + if chatReq.ExtraParams != nil { + t.Errorf("ExtraParams = %+v, want nil when Extensions carries no openai_chat bag", chatReq.ExtraParams) + } + data, _ := json.Marshal(chatReq) + if strings.Contains(string(data), "enable_search") || strings.Contains(string(data), "extra_body") { + t.Errorf("outbound body should not mention vendor keys when none are configured: %s", data) + } +} + +// ============================================================================ +// extra_body 在三条出口路径上的端到端透传 +// ============================================================================ +// +// 用 httptest.NewServer 拦截 chat.Client 真正发出去的请求体,断言 +// CoreRequest.Extensions["openai_chat"]["extra_body"] 里配的供应商私有字段 +// 出现在出口 JSON 的顶层,并且不会覆盖标准字段。覆盖三种出口路径: +// +// 1. 非流式直连:chat.Client.CreateChat +// 2. 流式:chat.Client.StreamChat +// 3. visual orchestrator 编排:每轮通过 ChatProviderAdapter.FromCoreRequest +// 重新构建 ChatRequest 后转发 +// +// 任何一条路径漏掉读取 Extensions 都会被对应的测试捕获。 + +// extraBodyCoreRequest builds a CoreRequest that carries the vendor switches +// the dispatcher would have written based on the resolved (provider, upstream-model). +func extraBodyCoreRequest() *format.CoreRequest { + return &format.CoreRequest{ + Model: "deepseek-v4-pro", + Messages: []format.CoreMessage{{ + Role: "user", + Content: []format.CoreContentBlock{{Type: "text", Text: "hi"}}, + }}, + Extensions: map[string]any{ + "openai_chat": map[string]any{ + "extra_body": map[string]any{ + "enable_search": true, + "extra_flag": 1, + }, + }, + }, + } +} + +// assertExtraBodyOnTopLevel decodes a captured chat completion request body +// and asserts the vendor switches landed at the top level of the JSON object, +// while the standard fields remain intact. +func assertExtraBodyOnTopLevel(t *testing.T, captured []byte, label string) { + t.Helper() + var top map[string]json.RawMessage + if err := json.Unmarshal(captured, &top); err != nil { + t.Fatalf("%s: outbound body is not valid JSON: %v (raw=%s)", label, err, captured) + } + if string(top["enable_search"]) != "true" { + t.Errorf("%s: top-level enable_search = %s, want true (raw=%s)", label, top["enable_search"], captured) + } + if string(top["extra_flag"]) != "1" { + t.Errorf("%s: top-level extra_flag = %s, want 1 (raw=%s)", label, top["extra_flag"], captured) + } + var model string + if err := json.Unmarshal(top["model"], &model); err != nil { + t.Fatalf("%s: cannot decode top-level model: %v", label, err) + } + if model != "deepseek-v4-pro" { + t.Errorf("%s: top-level model = %q, want deepseek-v4-pro (raw=%s)", label, model, captured) + } + if _, ok := top["messages"]; !ok { + t.Errorf("%s: top-level messages field disappeared (raw=%s)", label, captured) + } +} + +func TestOpenAIChat_ExtraBody_PropagatesOnNonStreamDirect(t *testing.T) { + ctx := context.Background() + + var captured []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"id":"x","object":"chat.completion","model":"m","choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":"ok"}}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`) + })) + defer srv.Close() + + adapter := newTestAdapter() + upstreamAny, err := adapter.FromCoreRequest(ctx, extraBodyCoreRequest()) + if err != nil { + t.Fatalf("FromCoreRequest: %v", err) + } + chatReq := upstreamAny.(*chat.ChatRequest) + + client := chat.NewClient(chat.ClientConfig{BaseURL: srv.URL, APIKey: "k", Client: srv.Client()}) + if _, err := client.CreateChat(ctx, chatReq); err != nil { + t.Fatalf("CreateChat: %v", err) + } + + assertExtraBodyOnTopLevel(t, captured, "non-stream direct") +} + +func TestOpenAIChat_ExtraBody_PropagatesOnStreaming(t *testing.T) { + ctx := context.Background() + + var captured []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"ok\"},\"finish_reason\":null}]}\n\n") + fmt.Fprint(w, "data: {\"id\":\"x\",\"object\":\"chat.completion.chunk\",\"model\":\"m\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1,\"total_tokens\":2}}\n\n") + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer srv.Close() + + adapter := newTestAdapter() + upstreamAny, err := adapter.FromCoreRequest(ctx, extraBodyCoreRequest()) + if err != nil { + t.Fatalf("FromCoreRequest: %v", err) + } + chatReq := upstreamAny.(*chat.ChatRequest) + chatReq.Stream = true + chatReq.StreamOptions = &chat.StreamOptions{IncludeUsage: true} + + client := chat.NewClient(chat.ClientConfig{BaseURL: srv.URL, APIKey: "k", Client: srv.Client()}) + stream, err := client.StreamChat(ctx, chatReq) + if err != nil { + t.Fatalf("StreamChat: %v", err) + } + for range stream { + } + + assertExtraBodyOnTopLevel(t, captured, "streaming") + + var top map[string]json.RawMessage + _ = json.Unmarshal(captured, &top) + if string(top["stream"]) != "true" { + t.Errorf("streaming: top-level stream = %s, want true (raw=%s)", top["stream"], captured) + } +} + +// TestOpenAIChat_ExtraBody_PropagatesThroughVisualOrchestrator guards the +// path that triggered the original bug. With visual orchestration enabled the +// outbound chat completion is built per-round by ChatProviderAdapter.FromCoreRequest +// against a cloned CoreRequest, so the extra_body bag must survive both the +// orchestrator clone and the adapter conversion. Every upstream call the +// orchestrator makes must carry the configured vendor switches at the top level. +func TestOpenAIChat_ExtraBody_PropagatesThroughVisualOrchestrator(t *testing.T) { + ctx := context.Background() + + type observed struct { + mu sync.Mutex + bodies [][]byte + rounds int + } + upstreamObs := &observed{} + visualObs := &observed{} + + upstreamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + upstreamObs.mu.Lock() + upstreamObs.rounds++ + round := upstreamObs.rounds + upstreamObs.bodies = append(upstreamObs.bodies, body) + upstreamObs.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if round == 1 { + fmt.Fprint(w, `{ + "id":"chatcmpl_upstream_1","object":"chat.completion","model":"deepseek-test", + "choices":[{"index":0,"finish_reason":"tool_calls","message":{ + "role":"assistant","content":null, + "tool_calls":[{ + "id":"call_visual_1","type":"function", + "function":{"name":"visual_brief","arguments":"{\"image_refs\":[\"Image #1\"],\"context\":\"describe\"}"} + }] + }}], + "usage":{"prompt_tokens":50,"completion_tokens":10,"total_tokens":60} + }`) + return + } + fmt.Fprint(w, `{ + "id":"chatcmpl_upstream_2","object":"chat.completion","model":"deepseek-test", + "choices":[{"index":0,"finish_reason":"stop","message":{ + "role":"assistant","content":"a chat screenshot" + }}], + "usage":{"prompt_tokens":80,"completion_tokens":12,"total_tokens":92} + }`) + })) + defer upstreamSrv.Close() + + visualSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + visualObs.mu.Lock() + visualObs.rounds++ + visualObs.bodies = append(visualObs.bodies, body) + visualObs.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{ + "id":"chatcmpl_visual_1","object":"chat.completion","model":"qwen-vl-test", + "choices":[{"index":0,"finish_reason":"stop","message":{ + "role":"assistant","content":"a chat screenshot with two people" + }}], + "usage":{"prompt_tokens":200,"completion_tokens":10,"total_tokens":210} + }`) + })) + defer visualSrv.Close() + + hooks := format.CorePluginHooks{}.WithDefaults() + upstreamAdapter := chat.NewChatProviderAdapter(2048, nil, hooks) + visualAdapter := chat.NewChatProviderAdapter(2048, nil, hooks) + + upstreamChatClient := chat.NewClient(chat.ClientConfig{BaseURL: upstreamSrv.URL, APIKey: "k", Client: upstreamSrv.Client()}) + visualChatClient := chat.NewClient(chat.ClientConfig{BaseURL: visualSrv.URL, APIKey: "k", Client: visualSrv.Client()}) + + chatCoreProvider := func(adapter *chat.ChatProviderAdapter, client *chat.Client) visualpkg.CoreProvider { + return visualpkg.CoreProviderFunc(func(ctx context.Context, req *format.CoreRequest) (*format.CoreResponse, error) { + upstreamAny, err := adapter.FromCoreRequest(ctx, req) + if err != nil { + return nil, err + } + chatReq := upstreamAny.(*chat.ChatRequest) + chatResp, err := client.CreateChat(ctx, chatReq) + if err != nil { + return nil, err + } + return adapter.ToCoreResponse(ctx, chatResp) + }) + } + + bridge := visualpkg.NewCoreBridge( + chatCoreProvider(upstreamAdapter, upstreamChatClient), + chatCoreProvider(visualAdapter, visualChatClient), + "qwen-vl-test", 4, 2048, + ) + + coreReq := &format.CoreRequest{ + Model: "deepseek-test", + Messages: []format.CoreMessage{{ + Role: "user", + Content: []format.CoreContentBlock{ + {Type: "text", Text: "describe the image"}, + {Type: "image", ImageData: "ZHVtbXlpbWFnZQ==", MediaType: "image/png"}, + }, + }}, + ToolChoice: &format.CoreToolChoice{Mode: "auto"}, + Extensions: map[string]any{ + "openai_chat": map[string]any{ + "extra_body": map[string]any{ + "enable_search": true, + "extra_flag": 1, + }, + }, + }, + } + + if _, err := bridge.CreateCore(ctx, coreReq); err != nil { + t.Fatalf("CreateCore: %v", err) + } + + if upstreamObs.rounds != 2 { + t.Fatalf("upstream rounds = %d, want 2", upstreamObs.rounds) + } + for i, body := range upstreamObs.bodies { + var top map[string]json.RawMessage + if err := json.Unmarshal(body, &top); err != nil { + t.Fatalf("upstream round %d body is not valid JSON: %v (raw=%s)", i+1, err, body) + } + if string(top["enable_search"]) != "true" { + t.Errorf("upstream round %d: enable_search = %s, want true (raw=%s)", i+1, top["enable_search"], body) + } + if string(top["extra_flag"]) != "1" { + t.Errorf("upstream round %d: extra_flag = %s, want 1 (raw=%s)", i+1, top["extra_flag"], body) + } + } +} diff --git a/internal/protocol/chat/types.go b/internal/protocol/chat/types.go index b18673c3..8640fb84 100644 --- a/internal/protocol/chat/types.go +++ b/internal/protocol/chat/types.go @@ -23,6 +23,45 @@ type ChatRequest struct { StreamOptions *StreamOptions `json:"stream_options,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` User string `json:"user,omitempty"` + + // ExtraParams holds vendor-specific top-level JSON fields merged at marshal + // time (e.g. {"enable_search": true} for Qwen/DashScope). It is not a JSON + // field itself; MarshalJSON flattens it into the top-level object. Keys + // that collide with already-serialized fields are dropped, so this cannot + // clobber model/messages/stream/etc. + ExtraParams map[string]any `json:"-"` +} + +// MarshalJSON flattens ExtraParams into the top-level JSON object. Standard +// fields take precedence; ExtraParams entries are added only when their key +// is not already present in the serialized base output. +func (r ChatRequest) MarshalJSON() ([]byte, error) { + type alias ChatRequest + base, err := json.Marshal(alias(r)) + if err != nil { + return nil, err + } + if len(r.ExtraParams) == 0 { + return base, nil + } + var merged map[string]json.RawMessage + if err := json.Unmarshal(base, &merged); err != nil { + return nil, err + } + if merged == nil { + merged = make(map[string]json.RawMessage, len(r.ExtraParams)) + } + for k, v := range r.ExtraParams { + if _, exists := merged[k]; exists { + continue + } + raw, err := json.Marshal(v) + if err != nil { + return nil, err + } + merged[k] = raw + } + return json.Marshal(merged) } // ChatMessage represents a single message in the conversation. diff --git a/internal/service/server/adapter_dispatch.go b/internal/service/server/adapter_dispatch.go index d3e65a8d..96982be2 100644 --- a/internal/service/server/adapter_dispatch.go +++ b/internal/service/server/adapter_dispatch.go @@ -177,6 +177,12 @@ func (s *Server) handleWithAdapters( wsInjected := s.injectCoreWebSearch(ctx, coreReq, preferred, openAIReq, wsMode) searchCfg := s.resolvedSearchConfig(preferred.ProviderKey, openAIReq.Model) + // Inject (provider, upstream-model)-scoped extra_body into CoreRequest.Extensions + // once per request. The chat adapter consumes this in FromCoreRequest, so the + // same value is honored uniformly across direct, streaming, and visual-orchestrator + // paths — all of which call FromCoreRequest to build the outbound *chat.ChatRequest. + injectChatExtraBodyIntoCoreRequest(coreReq, s.modelExtraBody(preferred.ProviderKey, preferred.UpstreamModel)) + upstreamAny, err := providerAdapter.FromCoreRequest(ctx, coreReq) if err != nil { log.Error("adapter path: FromCoreRequest failed", "error", err) @@ -371,10 +377,10 @@ func (s *Server) handleWithAdapters( record.ChatRequest = chatReq - // finalizeChatUpstream applies per-round mutations (cached reasoning - // replay) on every orchestrator round. prependCachedReasoningForChat - // is idempotent so duplicate application against the initial chatReq - // above is safe. + // finalizeChatUpstream applies per-round mutations on every orchestrator + // round: cached-reasoning replay. extra_body propagation is handled + // uniformly by the chat adapter via CoreRequest.Extensions, so it does + // not appear here. finalizeChatUpstream := func(_ context.Context, upstream any) (any, error) { req, ok := upstream.(*chat.ChatRequest) if !ok { @@ -2129,6 +2135,32 @@ func (p *chatProviderClient) StreamMessage(ctx context.Context, req any) (<-chan return out, nil } +// injectChatExtraBodyIntoCoreRequest stores a (provider, upstream-model)-scoped +// extra_body map inside CoreRequest.Extensions under the +// Extensions["openai_chat"]["extra_body"] key. The chat adapter's FromCoreRequest +// reads this key and translates it into ChatRequest.ExtraParams, which the +// custom MarshalJSON flattens to the top-level JSON of each outbound chat +// completion request. +// +// Centralizing the write here ensures every outbound chat completion call — +// direct, streaming, and per-round inside the visual orchestrator — picks up +// the same vendor switches, because all three paths go through FromCoreRequest. +// No-ops when extra is empty. +func injectChatExtraBodyIntoCoreRequest(coreReq *format.CoreRequest, extra map[string]any) { + if coreReq == nil || len(extra) == 0 { + return + } + if coreReq.Extensions == nil { + coreReq.Extensions = make(map[string]any) + } + bag, _ := coreReq.Extensions["openai_chat"].(map[string]any) + if bag == nil { + bag = make(map[string]any) + } + bag["extra_body"] = extra + coreReq.Extensions["openai_chat"] = bag +} + func normalizeAnthropicRequest(upstream any) (anthropic.MessageRequest, error) { switch v := upstream.(type) { case anthropic.MessageRequest: diff --git a/internal/service/server/server.go b/internal/service/server/server.go index 4c792212..0eb77437 100644 --- a/internal/service/server/server.go +++ b/internal/service/server/server.go @@ -96,6 +96,18 @@ func (s *Server) activeProviderDefs() map[string]config.ProviderDef { return nil } +// modelExtraBody returns the vendor-specific top-level JSON switches configured +// for a (provider, upstream-model) tuple via providers..offers[].overrides.extra_body. +// Returns nil when no overrides are configured for this offer. +func (s *Server) modelExtraBody(providerKey, upstreamModel string) map[string]any { + if snap := s.runtimeSnapshot(); snap != nil { + if meta, ok := snap.Config.ModelMetaFor(upstreamModel, providerKey); ok { + return meta.ExtraBody + } + } + return nil +} + func (s *Server) activeChatClient(providerKey string) any { if snap := s.runtimeSnapshot(); snap != nil { if def, ok := snap.Config.ProviderDefs[providerKey]; ok && def.Protocol == config.ProtocolOpenAIChat {