diff --git a/src/providers/proxies/claude-proxy.ts b/src/providers/proxies/claude-proxy.ts index 43f1323..699c9f2 100644 --- a/src/providers/proxies/claude-proxy.ts +++ b/src/providers/proxies/claude-proxy.ts @@ -179,6 +179,77 @@ const buildUpstreamUrl = (search: string): string => { return upstream.toString(); }; +const findSseEventBoundary = ( + buffer: string +): { index: number; length: number } | null => { + const match = /\r\n\r\n|\n\n|\n\r\n|\r\n\n/u.exec(buffer); + if (!match || match.index === undefined) { + return null; + } + + return { + index: match.index, + length: match[0].length, + }; +}; + +const parseSseEventData = (chunk: string): string | null => { + const dataLines = chunk + .replace(/\r\n/gu, "\n") + .split("\n") + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trimStart()); + if (!dataLines.length) { + return null; + } + + const data = dataLines.join("\n").trim(); + if (!data || data === "[DONE]") { + return null; + } + + return data; +}; + +const rewriteSseDataLines = (chunk: string, payload: string): string => { + const boundary = findSseEventBoundary(chunk); + const eventTrailer = + boundary && boundary.index + boundary.length === chunk.length + ? chunk.slice(boundary.index) + : ""; + const chunkBody = eventTrailer ? chunk.slice(0, -eventTrailer.length) : chunk; + const rewrittenBody = chunkBody.replace( + /((?:^|\r\n|\n))(?:data:.*(?:\r\n|\n)?)+$/u, + (_, separator: string) => `${separator}data: ${payload}` + ); + + if (rewrittenBody === chunkBody) { + return chunk; + } + + return `${rewrittenBody}${eventTrailer}`; +}; + +const transformSseEventChunk = ( + chunk: string, + toolPrefix: string, + readStreamUsage: (payload: unknown) => void +): string => { + const payload = parseSseEventData(chunk); + if (!payload) { + return chunk; + } + + try { + const jsonBody = JSON.parse(payload) as unknown; + readStreamUsage(jsonBody); + const transformed = transformClaudeResponsePayload(jsonBody, toolPrefix); + return rewriteSseDataLines(chunk, JSON.stringify(transformed)); + } catch { + return chunk; + } +}; + const maybeTransformClaudeStreamResponse = ( response: Response, toolPrefix: string, @@ -196,7 +267,7 @@ const maybeTransformClaudeStreamResponse = ( const reader = response.body.getReader(); const encoder = new TextEncoder(); const decoder = new TextDecoder(); - let pendingText = ""; + let buffer = ""; const streamUsage: TokenUsage = { inputTokens: 0, @@ -278,66 +349,53 @@ const maybeTransformClaudeStreamResponse = ( } }; - const transformSseLine = (line: string): string => { - if (!line.startsWith("data:")) { - return line; - } - - const payload = line.slice(5).trimStart(); - if (!payload || payload === "[DONE]") { - return line; - } - - try { - const jsonBody = JSON.parse(payload) as unknown; - readStreamUsage(jsonBody); - const transformed = transformClaudeResponsePayload(jsonBody, toolPrefix); - return `data: ${JSON.stringify(transformed)}`; - } catch { - return line; - } - }; - - const enqueueChunk = ( - controller: ReadableStreamDefaultController, - chunk: string - ): void => { - if (!chunk) { - return; - } - const transformedChunk = chunk - .split("\n") - .map((line) => transformSseLine(line)) - .join("\n"); - - controller.enqueue(encoder.encode(transformedChunk)); - }; - const stream = new ReadableStream({ - async pull(controller): Promise { - const { done, value } = await reader.read(); - if (done) { - pendingText += decoder.decode(); - enqueueChunk(controller, pendingText); - pendingText = ""; - onTokenUsage?.(streamUsage); - controller.close(); - return; - } - - if (!value) { - return; - } - - pendingText += decoder.decode(value, { stream: true }); - const lastLineBreak = pendingText.lastIndexOf("\n"); - if (lastLineBreak === -1) { - return; - } + start(controller): void { + const pump = async (): Promise => { + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + buffer += decoder.decode(); + if (buffer) { + controller.enqueue( + encoder.encode( + transformSseEventChunk(buffer, toolPrefix, readStreamUsage) + ) + ); + buffer = ""; + } + onTokenUsage?.(streamUsage); + controller.close(); + return; + } + + if (!value) { + continue; + } + + buffer += decoder.decode(value, { stream: true }); + + let boundary = findSseEventBoundary(buffer); + while (boundary) { + const chunk = buffer.slice(0, boundary.index + boundary.length); + buffer = buffer.slice(boundary.index + boundary.length); + controller.enqueue( + encoder.encode( + transformSseEventChunk(chunk, toolPrefix, readStreamUsage) + ) + ); + boundary = findSseEventBoundary(buffer); + } + } + } catch (error) { + controller.error(error); + } + }; - const completeChunk = pendingText.slice(0, lastLineBreak + 1); - pendingText = pendingText.slice(lastLineBreak + 1); - enqueueChunk(controller, completeChunk); + pump().catch((error: unknown) => { + controller.error(error); + }); }, cancel(reason): Promise { return reader.cancel(reason); diff --git a/tests/providers/proxy-contract.test.ts b/tests/providers/proxy-contract.test.ts index e005639..e177701 100644 --- a/tests/providers/proxy-contract.test.ts +++ b/tests/providers/proxy-contract.test.ts @@ -731,6 +731,160 @@ describe("proxy contract: claude", () => { expect(transformedText).toContain('"name":"shell"'); }); + test("rewrites fragmented multiline SSE events at event boundaries", async () => { + const result = prepareClaudeUsageRequest(); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue(encoder.encode("event: message\n")); + controller.enqueue(encoder.encode('data: {"type":"tool_use",\n')); + controller.enqueue(encoder.encode('data: "name":"mcp_shell"}\n\n')); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + const transformedText = await transformedResponse.text(); + + expect(transformedText).toBe( + 'event: message\ndata: {"type":"tool_use","name":"shell"}\n\n' + ); + }); + + test("rewrites CRLF-delimited SSE events without buffering until EOF", async () => { + const result = prepareClaudeUsageRequest(); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue( + encoder.encode( + 'event: message\r\ndata: {"type":"tool_use","name":"mcp_shell"}\r\n\r\n' + ) + ); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + const transformedText = await transformedResponse.text(); + + expect(transformedText).toBe( + 'event: message\r\ndata: {"type":"tool_use","name":"shell"}\r\n\r\n' + ); + }); + + test("rewrites SSE events with mixed newline boundary separators", async () => { + const result = prepareClaudeUsageRequest(); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue( + encoder.encode( + 'event: message\ndata: {"type":"tool_use","name":"mcp_shell"}\n\r\n' + ) + ); + controller.enqueue( + encoder.encode( + 'event: message\r\ndata: {"type":"tool_use","name":"mcp_browser"}\r\n\n' + ) + ); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + const transformedText = await transformedResponse.text(); + + expect(transformedText).toBe( + 'event: message\ndata: {"type":"tool_use","name":"shell"}\n\r\n' + + 'event: message\r\ndata: {"type":"tool_use","name":"browser"}\r\n\n' + ); + }); + + test("rewrites fragmented SSE events with mixed internal newlines", async () => { + const result = prepareClaudeUsageRequest(); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue(encoder.encode("event: message\r\n")); + controller.enqueue(encoder.encode('data: {"type":"tool_use",\n')); + controller.enqueue(encoder.encode('data: "name":"mcp_shell"}\n\r\n')); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + const transformedText = await transformedResponse.text(); + + expect(transformedText).toBe( + 'event: message\r\ndata: {"type":"tool_use","name":"shell"}\n\r\n' + ); + }); + + test("rewrites multiple SSE events delivered in one chunk", async () => { + const result = prepareClaudeUsageRequest(); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue( + encoder.encode( + 'event: message\ndata: {"type":"tool_use","name":"mcp_shell"}\n\n' + + 'event: message\ndata: {"type":"tool_use","name":"mcp_browser"}\n\n' + ) + ); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + const transformedText = await transformedResponse.text(); + + expect(transformedText).toBe( + 'event: message\ndata: {"type":"tool_use","name":"shell"}\n\n' + + 'event: message\ndata: {"type":"tool_use","name":"browser"}\n\n' + ); + }); + const claudeStreamUsageCases = [ { name: "extracts usage from claude streaming message events", @@ -805,6 +959,51 @@ describe("proxy contract: claude", () => { }); } + test("extracts usage from fragmented streaming events", async () => { + const capture = createUsageCapture(); + const result = prepareClaudeUsageRequest(capture.onTokenUsage); + const encoder = new TextEncoder(); + + const sourceResponse = new Response( + new ReadableStream({ + start(controller): void { + controller.enqueue( + encoder.encode('data: {"type":"message_start","message":{"usage":{') + ); + controller.enqueue( + encoder.encode('"input_tokens":55,"cache_read_input_tokens":11,') + ); + controller.enqueue( + encoder.encode('"cache_creation_input_tokens":5}}}\n\n') + ); + controller.enqueue( + encoder.encode('data: {"type":"message_delta","usage":{') + ); + controller.enqueue( + encoder.encode('"output_tokens":13,"cache_creation_input_tokens":7') + ); + controller.enqueue(encoder.encode("}}\n\n")); + controller.close(); + }, + }), + { + headers: { + "content-type": "text/event-stream", + }, + } + ); + + const transformedResponse = await result.transformResponse(sourceResponse); + await transformedResponse.text(); + + expect(capture.read()).toEqual({ + inputTokens: 55, + outputTokens: 13, + cacheReadTokens: 11, + cacheWriteTokens: 7, + }); + }); + test("does not rewrite non-tool SSE name fields", async () => { const result = prepareClaudeUsageRequest();