diff --git a/internal/protocol/chat/adapter.go b/internal/protocol/chat/adapter.go index a5380fac..8ae9a3cc 100644 --- a/internal/protocol/chat/adapter.go +++ b/internal/protocol/chat/adapter.go @@ -33,6 +33,20 @@ type ChatProviderAdapter struct { streamEvents []ChatStreamChunk } +type chatStreamChoiceState struct { + started bool + blockIndex int // monotonically increasing content block index + textStarted bool // whether a text block is active + hasReasoning bool // whether a reasoning block is active + reasonIndex int // block index for the reasoning content block + toolCallIdx int // next tool call content block index + callStarted map[int]bool // tracks which tool call indices have been started + toolCallSlot map[int]int // tool_call delta index -> content block index + reasoningContent string // accumulated reasoning content for the current reasoning block + thinkActive bool // whether a ... text tag is open + thinkBuffer string // buffered text for split tag detection +} + // NewChatProviderAdapter creates a new ChatProviderAdapter. // // client is the HTTP client for Chat API calls. May be nil if the adapter @@ -123,7 +137,7 @@ func (a *ChatProviderAdapter) FromCoreRequest(ctx context.Context, req *format.C Function: FunctionDef{ Name: t.Name, Description: t.Description, - Parameters: t.InputSchema, + Parameters: defaultToolParameters(t.InputSchema), }, }) } @@ -176,6 +190,7 @@ func (a *ChatProviderAdapter) ToCoreResponse(ctx context.Context, resp any) (*fo coreResp := &format.CoreResponse{ ID: chatResp.ID, + Model: chatResp.Model, Status: status, Messages: messages, StopReason: stopReason, @@ -241,17 +256,7 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan defer close(events) // Per-choice state for streaming. - type choiceState struct { - started bool - blockIndex int // monotonically increasing content block index - hasReasoning bool // whether a reasoning block is active - reasonIndex int // block index for the reasoning content block - toolCallIdx int // next tool call content block index (starts after text block) - callStarted map[int]bool // tracks which tool call indices have been started - toolCallSlot map[int]int // tool_call delta index -> content block index - reasoningContent string // accumulated reasoning content for the current reasoning block - } - choices := make(map[int]*choiceState) + choices := make(map[int]*chatStreamChoiceState) var seqNum int64 var finalUsage *format.CoreUsage var lastModel string @@ -301,69 +306,67 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan for _, sc := range chunk.Choices { state := choices[sc.Index] if state == nil { - state = &choiceState{blockIndex: sc.Index * 2} + state = &chatStreamChoiceState{blockIndex: sc.Index * 2} choices[sc.Index] = state } ci := sc.Index - // Emit content_block.started on first appearance with role. + // Mark the assistant turn on first appearance with role. The + // concrete block is started lazily when text/reasoning/tool + // content arrives, which avoids emitting empty text items for + // tool-only or reasoning-only turns. if !state.started && sc.Delta.Role == "assistant" { state.started = true - blockType := "text" - if sc.Delta.ReasoningContent != "" { - blockType = "reasoning" - state.hasReasoning = true - state.reasonIndex = state.blockIndex + } + + startText := func() { + if state.textStarted { + return + } + state.textStarted = true + emit(format.CoreStreamEvent{ + Type: format.CoreContentBlockStarted, + Index: state.blockIndex, + ChoiceIndex: &ci, + ContentBlock: &format.CoreContentBlock{Type: "text"}, + }) + } + closeText := func(stopReason string) { + if !state.textStarted { + return } emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockStarted, + Type: format.CoreContentBlockDone, Index: state.blockIndex, + StopReason: stopReason, ChoiceIndex: &ci, - ContentBlock: &format.CoreContentBlock{ - Type: blockType, - }, }) + state.textStarted = false + state.blockIndex++ } - - // Emit text delta. - // Emit reasoning content as text delta. - // Note: reasoning_content may appear AFTER the text block has started - // (DeepSeek first sends role=assistant, then reasoning_content in subsequent chunks). - if sc.Delta.ReasoningContent != "" { - if !state.hasReasoning { - // Transition from premature text block to reasoning block. - state.hasReasoning = true - state.reasonIndex = state.blockIndex + 1 - state.blockIndex = state.reasonIndex - emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockDone, - Index: state.blockIndex, - ChoiceIndex: &ci, - }) - emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockStarted, - Index: state.reasonIndex, - ChoiceIndex: &ci, - ContentBlock: &format.CoreContentBlock{ - Type: "reasoning", - }, - }) + startReasoning := func() { + if state.hasReasoning { + return } - state.reasoningContent += sc.Delta.ReasoningContent + closeText("") + state.hasReasoning = true + state.reasonIndex = state.blockIndex emit(format.CoreStreamEvent{ - Type: format.CoreTextDelta, - Index: state.reasonIndex, - Delta: sc.Delta.ReasoningContent, - ChoiceIndex: &ci, + Type: format.CoreContentBlockStarted, + Index: state.reasonIndex, + ChoiceIndex: &ci, + ContentBlock: &format.CoreContentBlock{Type: "reasoning"}, }) } - - // Transition from reasoning block to text block. - if sc.Delta.Content != "" && state.hasReasoning { + closeReasoning := func(stopReason string) { + if !state.hasReasoning { + return + } emit(format.CoreStreamEvent{ Type: format.CoreContentBlockDone, Index: state.reasonIndex, + StopReason: stopReason, ChoiceIndex: &ci, ContentBlock: &format.CoreContentBlock{ Type: "reasoning", @@ -373,24 +376,48 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan state.reasoningContent = "" state.hasReasoning = false state.blockIndex = state.reasonIndex + 1 + } + emitTextDelta := func(delta string) { + if delta == "" { + return + } + startText() emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockStarted, + Type: format.CoreTextDelta, Index: state.blockIndex, + Delta: delta, ChoiceIndex: &ci, - ContentBlock: &format.CoreContentBlock{ - Type: "text", - }, }) } - if sc.Delta.Content != "" { + emitReasoningDelta := func(delta string) { + if delta == "" { + return + } + startReasoning() + state.reasoningContent += delta emit(format.CoreStreamEvent{ Type: format.CoreTextDelta, - Index: state.blockIndex, - Delta: sc.Delta.Content, + Index: state.reasonIndex, + Delta: delta, ChoiceIndex: &ci, }) } + // Emit provider-native reasoning_content as reasoning deltas. + // Note: reasoning_content may appear AFTER the text block has started + // (DeepSeek first sends role=assistant, then reasoning_content in subsequent chunks). + if sc.Delta.ReasoningContent != "" { + emitReasoningDelta(sc.Delta.ReasoningContent) + } + + // Transition from reasoning block to text block. + if sc.Delta.Content != "" && state.hasReasoning && !state.thinkActive && state.thinkBuffer == "" { + closeReasoning("") + } + if sc.Delta.Content != "" { + state.consumeThinkTaggedDelta(sc.Delta.Content, emitTextDelta, emitReasoningDelta, closeReasoning) + } + // Emit tool call content blocks and args deltas. for toolPos, tc := range sc.Delta.ToolCalls { callPos := toolPos @@ -399,8 +426,12 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan } if state.callStarted == nil { state.callStarted = make(map[int]bool) - // Start tool call indices after the current text/reasoning block. - state.toolCallIdx = state.blockIndex + 1 + // Start tool call indices after any active text/reasoning block. + if state.textStarted || state.hasReasoning { + state.toolCallIdx = state.blockIndex + 1 + } else { + state.toolCallIdx = state.blockIndex + } } if state.toolCallSlot == nil { state.toolCallSlot = make(map[int]int) @@ -456,28 +487,14 @@ func (a *ChatProviderAdapter) ToCoreStream(ctx context.Context, src any) (<-chan // Emit content_block.done when finish_reason is set. if sc.FinishReason != "" { stopReason := a.mapFinishReason(sc.FinishReason) + state.flushThinkTaggedContent(emitTextDelta, emitReasoningDelta, closeReasoning) if state.hasReasoning { - emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockDone, - Index: state.blockIndex, - StopReason: stopReason, - ChoiceIndex: &ci, - ContentBlock: &format.CoreContentBlock{ - Type: "reasoning", - ReasoningText: state.reasoningContent, - }, - }) - state.reasoningContent = "" - } else { - emit(format.CoreStreamEvent{ - Type: format.CoreContentBlockDone, - Index: state.blockIndex, - StopReason: stopReason, - ChoiceIndex: &ci, - }) + closeReasoning(stopReason) + } else if state.textStarted { + closeText(stopReason) } // Complete tool call blocks. - for idx := state.blockIndex + 1; idx < state.toolCallIdx; idx++ { + for idx := 0; idx < state.toolCallIdx; idx++ { if !state.callStarted[idx] { continue } @@ -805,14 +822,17 @@ func (a *ChatProviderAdapter) fromChatContent(content any) []format.CoreContentB if v == "" { return nil } - return []format.CoreContentBlock{ - {Type: "text", Text: v}, - } + return splitThinkTaggedText(v) case []any: blocks := make([]format.CoreContentBlock, 0, len(v)) for _, item := range v { if m, ok := item.(map[string]any); ok { - blocks = append(blocks, a.fromContentPartMap(m)) + block := a.fromContentPartMap(m) + if block.Type == "text" && block.Text != "" { + blocks = append(blocks, splitThinkTaggedText(block.Text)...) + } else { + blocks = append(blocks, block) + } } } return blocks @@ -878,3 +898,144 @@ func unquoteArguments(raw json.RawMessage) json.RawMessage { } return json.RawMessage(s) } + +func defaultToolParameters(schema map[string]any) map[string]any { + if len(schema) > 0 { + return schema + } + return map[string]any{"type": "object"} +} + +const ( + thinkOpenTag = "" + thinkCloseTag = "" +) + +func splitThinkTaggedText(text string) []format.CoreContentBlock { + var blocks []format.CoreContentBlock + remaining := text + for remaining != "" { + open := indexFold(remaining, thinkOpenTag) + if open < 0 { + appendTextBlock(&blocks, remaining) + break + } + appendTextBlock(&blocks, remaining[:open]) + remaining = remaining[open+len(thinkOpenTag):] + close := indexFold(remaining, thinkCloseTag) + if close < 0 { + appendReasoningBlock(&blocks, remaining) + break + } + appendReasoningBlock(&blocks, remaining[:close]) + remaining = trimOneLeadingLineBreak(remaining[close+len(thinkCloseTag):]) + } + return blocks +} + +func appendTextBlock(blocks *[]format.CoreContentBlock, text string) { + if text == "" { + return + } + *blocks = append(*blocks, format.CoreContentBlock{Type: "text", Text: text}) +} + +func appendReasoningBlock(blocks *[]format.CoreContentBlock, text string) { + if text == "" { + return + } + *blocks = append(*blocks, format.CoreContentBlock{ + Type: "reasoning", + ReasoningText: text, + }) +} + +func (state *chatStreamChoiceState) consumeThinkTaggedDelta( + delta string, + emitText func(string), + emitReasoning func(string), + closeReasoning func(string), +) { + state.thinkBuffer += delta + for state.thinkBuffer != "" { + if state.thinkActive { + close := indexFold(state.thinkBuffer, thinkCloseTag) + if close >= 0 { + emitReasoning(state.thinkBuffer[:close]) + state.thinkBuffer = trimOneLeadingLineBreak(state.thinkBuffer[close+len(thinkCloseTag):]) + state.thinkActive = false + closeReasoning("") + continue + } + emit, keep := splitSafeForTag(state.thinkBuffer, thinkCloseTag) + emitReasoning(emit) + state.thinkBuffer = keep + return + } + + open := indexFold(state.thinkBuffer, thinkOpenTag) + if open >= 0 { + emitText(state.thinkBuffer[:open]) + state.thinkBuffer = state.thinkBuffer[open+len(thinkOpenTag):] + state.thinkActive = true + continue + } + emit, keep := splitSafeForTag(state.thinkBuffer, thinkOpenTag) + emitText(emit) + state.thinkBuffer = keep + return + } +} + +func (state *chatStreamChoiceState) flushThinkTaggedContent( + emitText func(string), + emitReasoning func(string), + closeReasoning func(string), +) { + if state.thinkBuffer == "" { + return + } + if state.thinkActive { + emitReasoning(state.thinkBuffer) + state.thinkBuffer = "" + state.thinkActive = false + closeReasoning("") + return + } + emitText(state.thinkBuffer) + state.thinkBuffer = "" +} + +func splitSafeForTag(text string, tag string) (emit string, keep string) { + keepLen := longestTagPrefixSuffix(text, tag) + if keepLen == 0 { + return text, "" + } + return text[:len(text)-keepLen], text[len(text)-keepLen:] +} + +func longestTagPrefixSuffix(text string, tag string) int { + maxLen := min(len(text), len(tag)-1) + lowerText := strings.ToLower(text) + lowerTag := strings.ToLower(tag) + for n := maxLen; n > 0; n-- { + if strings.HasSuffix(lowerText, lowerTag[:n]) { + return n + } + } + return 0 +} + +func indexFold(s string, substr string) int { + return strings.Index(strings.ToLower(s), strings.ToLower(substr)) +} + +func trimOneLeadingLineBreak(s string) string { + if strings.HasPrefix(s, "\r\n") { + return s[2:] + } + if strings.HasPrefix(s, "\n") || strings.HasPrefix(s, "\r") { + return s[1:] + } + return s +} diff --git a/internal/protocol/chat/chat_test.go b/internal/protocol/chat/chat_test.go index 35b71c2e..defa2909 100644 --- a/internal/protocol/chat/chat_test.go +++ b/internal/protocol/chat/chat_test.go @@ -1127,6 +1127,25 @@ func TestFromCoreRequest_Tools(t *testing.T) { } } +func TestFromCoreRequest_ToolsDefaultEmptyParameters(t *testing.T) { + adapter := newTestAdapter() + result, err := adapter.FromCoreRequest(context.Background(), &format.CoreRequest{ + Model: "gpt-4o", + Messages: []format.CoreMessage{ + {Role: "user", Content: []format.CoreContentBlock{{Type: "text", Text: "call tool"}}}, + }, + Tools: []format.CoreTool{{Name: "ping"}}, + }) + if err != nil { + t.Fatal(err) + } + chatReq := result.(*chat.ChatRequest) + params := chatReq.Tools[0].Function.Parameters + if got := params["type"]; got != "object" { + t.Fatalf("default parameters type = %v, want object; params=%+v", got, params) + } +} + func TestFromCoreRequest_ToolChoice(t *testing.T) { adapter := newTestAdapter() tests := []struct { @@ -1505,6 +1524,43 @@ func TestToCoreResponse_ToolCalls(t *testing.T) { } } +func TestToCoreResponse_SplitsThinkTaggedText(t *testing.T) { + adapter := newTestAdapter() + chatResp := &chat.ChatResponse{ + ID: "chatcmpl-think", + Model: "MiniMax-M3", + Choices: []chat.Choice{{ + Index: 0, + Message: chat.ChatMessage{ + Role: "assistant", + Content: "\nplan privately\n\npong", + }, + FinishReason: "stop", + }}, + } + + result, err := adapter.ToCoreResponse(context.Background(), chatResp) + if err != nil { + t.Fatal(err) + } + if result.Model != "MiniMax-M3" { + t.Fatalf("Model = %q, want MiniMax-M3", result.Model) + } + if len(result.Messages) != 1 { + t.Fatalf("Messages: got %d, want 1", len(result.Messages)) + } + content := result.Messages[0].Content + if len(content) != 2 { + t.Fatalf("content len = %d, want 2: %+v", len(content), content) + } + if content[0].Type != "reasoning" || content[0].ReasoningText != "\nplan privately\n" { + t.Fatalf("reasoning block = %+v", content[0]) + } + if content[1].Type != "text" || content[1].Text != "pong" { + t.Fatalf("text block = %+v", content[1]) + } +} + func TestToCoreResponse_FinishReasonVariants(t *testing.T) { adapter := newTestAdapter() tests := []struct { @@ -1749,6 +1805,83 @@ func TestToCoreStream_ToolCallArgsDelta(t *testing.T) { } } +func TestToCoreStream_SplitsThinkTaggedTextAcrossChunks(t *testing.T) { + adapter := newTestAdapter() + src := make(chan chat.ChatStreamChunk, 6) + src <- chat.ChatStreamChunk{ + Model: "MiniMax-M3", + Choices: []chat.StreamChoice{{ + Index: 0, + Delta: chat.Delta{Role: "assistant", Content: "\nplan"}, + }}, + } + src <- chat.ChatStreamChunk{ + Choices: []chat.StreamChoice{{ + Index: 0, + Delta: chat.Delta{Content: "\n\npong"}, + }}, + } + src <- chat.ChatStreamChunk{ + Choices: []chat.StreamChoice{{Index: 0, FinishReason: "stop"}}, + } + close(src) + + events, err := adapter.ToCoreStream(context.Background(), (<-chan chat.ChatStreamChunk)(src)) + if err != nil { + t.Fatal(err) + } + + var reasoningStarted bool + var reasoningDone *format.CoreStreamEvent + var text string + var completed *format.CoreStreamEvent + for e := range events { + if e.Type == format.CoreContentBlockStarted && e.ContentBlock != nil && e.ContentBlock.Type == "reasoning" { + reasoningStarted = true + } + if e.Type == format.CoreTextDelta { + if strings.Contains(e.Delta, " 0 { + expanded = append(expanded, expandResponseNamespaceTools(tool.Tools)...) + continue + } + if len(tool.Tools) > 0 { + tool.Tools = expandResponseNamespaceTools(tool.Tools) + } + expanded = append(expanded, tool) + } + return expanded +} + +func copyUpstreamResponseBody(target io.Writer, source io.Reader, responseWriter http.ResponseWriter) (int64, error) { + buffer := make([]byte, 32*1024) + flusher, _ := responseWriter.(http.Flusher) + var written int64 + for { + n, readErr := source.Read(buffer) + if n > 0 { + chunk := buffer[:n] + writeN, writeErr := target.Write(chunk) + written += int64(writeN) + if writeErr != nil { + return written, writeErr + } + if writeN != n { + return written, io.ErrShortWrite + } + if flusher != nil { + flusher.Flush() + } + } + if readErr == io.EOF { + return written, nil + } + if readErr != nil { + return written, readErr + } + } +} + +const openAIStreamMonitorMaxParseBytes = 64 * 1024 + +type openAIStreamMonitor struct { + Bytes int64 + Terminal string + ErrorType string + ErrorCode string + ErrorMessage string + Truncated bool + + parsedBytes int + lineBuf []byte + eventName string + dataLines []string +} + +func (m *openAIStreamMonitor) Write(p []byte) (int, error) { + m.Bytes += int64(len(p)) + for _, b := range p { + if m.Truncated { + continue + } + if m.parsedBytes >= openAIStreamMonitorMaxParseBytes { + m.truncateParsing() + continue + } + m.parsedBytes++ + if b == '\n' { + m.processLine(string(m.lineBuf)) + m.lineBuf = m.lineBuf[:0] + continue + } + m.lineBuf = append(m.lineBuf, b) + } + return len(p), nil +} + +func (m *openAIStreamMonitor) Finish() { + if m.Truncated { + return + } + if len(m.lineBuf) > 0 { + m.processLine(string(m.lineBuf)) + m.lineBuf = nil + } + m.finishEvent() +} + +func (m *openAIStreamMonitor) truncateParsing() { + m.Truncated = true + m.lineBuf = nil + m.eventName = "" + m.dataLines = nil +} + +func (m *openAIStreamMonitor) processLine(line string) { + line = strings.TrimSuffix(line, "\r") + if strings.TrimSpace(line) == "" { + m.finishEvent() + return + } + if strings.HasPrefix(line, "event:") { + m.eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + return + } + if strings.HasPrefix(line, "data:") { + m.dataLines = append(m.dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } +} + +func (m *openAIStreamMonitor) finishEvent() { + if m.eventName == "" && len(m.dataLines) == 0 { + return + } + data := strings.Join(m.dataLines, "\n") + if data == "[DONE]" { + m.Terminal = "[DONE]" + } else if m.eventName == "response.completed" || m.eventName == "response.failed" || m.eventName == "error" { + m.Terminal = m.eventName + m.captureError(data) + } else if data != "" { + m.captureTerminalFromData(data) + } + m.eventName = "" + m.dataLines = nil +} + +func (m *openAIStreamMonitor) captureTerminalFromData(data string) { + var payload struct { + Type string `json:"type"` + Error openai.ErrorObject `json:"error"` + Response struct { + Status string `json:"status"` + Error *openai.ErrorObject `json:"error"` + } `json:"response"` + } + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return + } + switch payload.Type { + case "response.completed", "response.failed", "error": + m.Terminal = payload.Type + m.captureErrorFromPayload(payload.Error, payload.Response.Error) + } +} + +func (m *openAIStreamMonitor) captureError(data string) { + if data == "" || data == "[DONE]" { + return + } + var payload struct { + Error openai.ErrorObject `json:"error"` + Response struct { + Error *openai.ErrorObject `json:"error"` + } `json:"response"` + } + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return + } + m.captureErrorFromPayload(payload.Error, payload.Response.Error) +} + +func (m *openAIStreamMonitor) captureErrorFromPayload(eventErr openai.ErrorObject, responseErr *openai.ErrorObject) { + if eventErr.Message != "" || eventErr.Type != "" || eventErr.Code != "" { + m.ErrorType = eventErr.Type + m.ErrorCode = eventErr.Code + m.ErrorMessage = eventErr.Message + } + if responseErr != nil && (m.ErrorMessage == "" && m.ErrorType == "" && m.ErrorCode == "") { + m.ErrorType = responseErr.Type + m.ErrorCode = responseErr.Code + m.ErrorMessage = responseErr.Message + } +} diff --git a/internal/service/server/dispatch_internal_test.go b/internal/service/server/dispatch_internal_test.go new file mode 100644 index 00000000..b2eacdfd --- /dev/null +++ b/internal/service/server/dispatch_internal_test.go @@ -0,0 +1,33 @@ +package server + +import ( + "strings" + "testing" +) + +func TestOpenAIStreamMonitorTruncatesParsingWithoutBlockingWrites(t *testing.T) { + monitor := &openAIStreamMonitor{} + payload := strings.Repeat("x", openAIStreamMonitorMaxParseBytes+1024) + + n, err := monitor.Write([]byte(payload)) + if err != nil { + t.Fatalf("Write() error = %v", err) + } + if n != len(payload) { + t.Fatalf("Write() bytes = %d, want %d", n, len(payload)) + } + if monitor.Bytes != int64(len(payload)) { + t.Fatalf("Bytes = %d, want %d", monitor.Bytes, len(payload)) + } + if !monitor.Truncated { + t.Fatal("Truncated = false, want true") + } + if len(monitor.lineBuf) != 0 || len(monitor.dataLines) != 0 || monitor.eventName != "" { + t.Fatalf("monitor buffers were not cleared: line=%d data=%d event=%q", len(monitor.lineBuf), len(monitor.dataLines), monitor.eventName) + } + + monitor.Finish() + if monitor.Terminal != "" || monitor.ErrorMessage != "" { + t.Fatalf("Finish() parsed truncated data: terminal=%q error=%q", monitor.Terminal, monitor.ErrorMessage) + } +} diff --git a/internal/service/server/server_test.go b/internal/service/server/server_test.go index bd6cd39c..784f5c20 100644 --- a/internal/service/server/server_test.go +++ b/internal/service/server/server_test.go @@ -10,16 +10,17 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "strings" "testing" + "moonbridge/internal/config" "moonbridge/internal/extension/codex" deepseekv4 "moonbridge/internal/extension/deepseek_v4" "moonbridge/internal/extension/plugin" - "moonbridge/internal/config" + "moonbridge/internal/format" "moonbridge/internal/logger" "moonbridge/internal/protocol/openai" - "moonbridge/internal/format" "moonbridge/internal/service/provider" "moonbridge/internal/service/server" "moonbridge/internal/service/stats" @@ -82,7 +83,6 @@ func (provider providerFunc) StreamMessage(ctx context.Context, req any) (<-chan return provider.stream(ctx, req) } - type roundTripFunc func(*http.Request) (*http.Response, error) func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) { @@ -95,6 +95,22 @@ type captureCompletionPlugin struct { called bool } +type flushRecorder struct { + *httptest.ResponseRecorder + flushes int +} + +func newFlushRecorder() *flushRecorder { + return &flushRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (r *flushRecorder) Flush() { + r.flushes++ + if r.Code == 0 { + r.WriteHeader(http.StatusOK) + } +} + func (p *captureCompletionPlugin) Name() string { return "capture_completion" } func (p *captureCompletionPlugin) EnabledForModel(string) bool { return true @@ -756,6 +772,125 @@ func TestResponsesHandlerPassesOpenAIStreamUsageToMetrics(t *testing.T) { } } +func TestResponsesHandlerFlushesOpenAIStreamPassthrough(t *testing.T) { + httpClient := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_1","status":"in_progress"}}`, + ``, + `event: response.completed`, + `data: {"type":"response.completed","response":{"id":"resp_1","status":"completed"}}`, + ``, + }, "\n"))), + }, nil + })} + + providerMgr, err := provider.NewProviderManager(map[string]provider.ProviderConfig{ + "openai": { + BaseURL: "https://openai.example.test", + APIKey: "openai-key", + Protocol: config.ProtocolOpenAIResponse, + }, + }, map[string]provider.ModelRoute{ + "gpt-direct": {Provider: "openai", Name: "gpt-upstream"}, + }) + if err != nil { + t.Fatalf("NewProviderManager() error = %v", err) + } + handler := server.New(server.Config{ + ProviderMgr: providerMgr, + OpenAIHTTPClient: httpClient, + }) + + recorder := newFlushRecorder() + request := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(`{"model":"gpt-direct","input":"hello","stream":true}`)) + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", recorder.Code, recorder.Body.String()) + } + if recorder.flushes == 0 { + t.Fatalf("expected stream passthrough to flush writes") + } + if !strings.Contains(recorder.Body.String(), "event: response.completed") { + t.Fatalf("stream response missing response.completed: %s", recorder.Body.String()) + } +} + +func TestResponsesHandlerExpandsNamespaceToolsForOpenAIStreamPassthrough(t *testing.T) { + var upstreamRequest openai.ResponsesRequest + httpClient := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + if err := json.NewDecoder(request.Body).Decode(&upstreamRequest); err != nil { + t.Fatalf("decode upstream request: %v", err) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: response.completed`, + `data: {"type":"response.completed","response":{"id":"resp_1","status":"completed"}}`, + ``, + }, "\n"))), + }, nil + })} + + providerMgr, err := provider.NewProviderManager(map[string]provider.ProviderConfig{ + "openai": { + BaseURL: "https://openai.example.test", + APIKey: "openai-key", + Protocol: config.ProtocolOpenAIResponse, + }, + }, map[string]provider.ModelRoute{ + "gpt-direct": {Provider: "openai", Name: "gpt-upstream"}, + }) + if err != nil { + t.Fatalf("NewProviderManager() error = %v", err) + } + handler := server.New(server.Config{ + ProviderMgr: providerMgr, + OpenAIHTTPClient: httpClient, + }) + + body := `{ + "model":"gpt-direct", + "input":"hello", + "stream":true, + "tools":[ + {"type":"function","name":"top_level","parameters":{"type":"object"}}, + {"type":"namespace","name":"multi_agent_v1","tools":[ + {"type":"function","name":"spawn_agent","parameters":{"type":"object"}}, + {"type":"function","name":"wait_agent","parameters":{"type":"object"}} + ]} + ] + }` + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body)) + handler.ServeHTTP(recorder, request) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", recorder.Code, recorder.Body.String()) + } + if upstreamRequest.Model != "gpt-upstream" { + t.Fatalf("upstream model = %q", upstreamRequest.Model) + } + if len(upstreamRequest.Tools) != 3 { + t.Fatalf("tools len = %d, tools = %+v", len(upstreamRequest.Tools), upstreamRequest.Tools) + } + for _, tool := range upstreamRequest.Tools { + if tool.Type == "namespace" { + t.Fatalf("namespace tool was not expanded: %+v", tool) + } + } + gotNames := []string{upstreamRequest.Tools[0].Name, upstreamRequest.Tools[1].Name, upstreamRequest.Tools[2].Name} + wantNames := []string{"top_level", "spawn_agent", "wait_agent"} + if !reflect.DeepEqual(gotNames, wantNames) { + t.Fatalf("tool names = %+v, want %+v", gotNames, wantNames) + } +} + func TestOpenAIResponsePassthroughWritesTraceOnSuccess(t *testing.T) { traceRoot := t.TempDir() httpClient := &http.Client{Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) {