Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 117 additions & 59 deletions src/providers/proxies/claude-proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,77 @@ const buildUpstreamUrl = (search: string): string => {
return upstream.toString();
};

const findSseEventBoundary = (
buffer: string
): { index: number; length: number } | null => {
const match = /\r\n\r\n|\n\n|\n\r\n|\r\n\n/u.exec(buffer);
if (!match || match.index === undefined) {
return null;
}

return {
index: match.index,
length: match[0].length,
};
};

const parseSseEventData = (chunk: string): string | null => {
const dataLines = chunk
.replace(/\r\n/gu, "\n")
.split("\n")
.filter((line) => line.startsWith("data:"))
.map((line) => line.slice(5).trimStart());
if (!dataLines.length) {
return null;
}

const data = dataLines.join("\n").trim();
if (!data || data === "[DONE]") {
return null;
}

return data;
};

const rewriteSseDataLines = (chunk: string, payload: string): string => {
const boundary = findSseEventBoundary(chunk);
const eventTrailer =
boundary && boundary.index + boundary.length === chunk.length
? chunk.slice(boundary.index)
: "";
const chunkBody = eventTrailer ? chunk.slice(0, -eventTrailer.length) : chunk;
const rewrittenBody = chunkBody.replace(
/((?:^|\r\n|\n))(?:data:.*(?:\r\n|\n)?)+$/u,
(_, separator: string) => `${separator}data: ${payload}`
);

if (rewrittenBody === chunkBody) {
return chunk;
}

return `${rewrittenBody}${eventTrailer}`;
};

const transformSseEventChunk = (
chunk: string,
toolPrefix: string,
readStreamUsage: (payload: unknown) => void
): string => {
const payload = parseSseEventData(chunk);
if (!payload) {
return chunk;
}

try {
const jsonBody = JSON.parse(payload) as unknown;
readStreamUsage(jsonBody);
const transformed = transformClaudeResponsePayload(jsonBody, toolPrefix);
return rewriteSseDataLines(chunk, JSON.stringify(transformed));
} catch {
return chunk;
}
};

const maybeTransformClaudeStreamResponse = (
response: Response,
toolPrefix: string,
Expand All @@ -196,7 +267,7 @@ const maybeTransformClaudeStreamResponse = (
const reader = response.body.getReader();
const encoder = new TextEncoder();
const decoder = new TextDecoder();
let pendingText = "";
let buffer = "";

const streamUsage: TokenUsage = {
inputTokens: 0,
Expand Down Expand Up @@ -278,66 +349,53 @@ const maybeTransformClaudeStreamResponse = (
}
};

const transformSseLine = (line: string): string => {
if (!line.startsWith("data:")) {
return line;
}

const payload = line.slice(5).trimStart();
if (!payload || payload === "[DONE]") {
return line;
}

try {
const jsonBody = JSON.parse(payload) as unknown;
readStreamUsage(jsonBody);
const transformed = transformClaudeResponsePayload(jsonBody, toolPrefix);
return `data: ${JSON.stringify(transformed)}`;
} catch {
return line;
}
};

const enqueueChunk = (
controller: ReadableStreamDefaultController<Uint8Array>,
chunk: string
): void => {
if (!chunk) {
return;
}
const transformedChunk = chunk
.split("\n")
.map((line) => transformSseLine(line))
.join("\n");

controller.enqueue(encoder.encode(transformedChunk));
};

const stream = new ReadableStream<Uint8Array>({
async pull(controller): Promise<void> {
const { done, value } = await reader.read();
if (done) {
pendingText += decoder.decode();
enqueueChunk(controller, pendingText);
pendingText = "";
onTokenUsage?.(streamUsage);
controller.close();
return;
}

if (!value) {
return;
}

pendingText += decoder.decode(value, { stream: true });
const lastLineBreak = pendingText.lastIndexOf("\n");
if (lastLineBreak === -1) {
return;
}
start(controller): void {
const pump = async (): Promise<void> => {
try {
while (true) {
const { done, value } = await reader.read();
if (done) {
buffer += decoder.decode();
if (buffer) {
controller.enqueue(
encoder.encode(
transformSseEventChunk(buffer, toolPrefix, readStreamUsage)
)
);
buffer = "";
}
onTokenUsage?.(streamUsage);
controller.close();
return;
}

if (!value) {
continue;
}

buffer += decoder.decode(value, { stream: true });

let boundary = findSseEventBoundary(buffer);
while (boundary) {
const chunk = buffer.slice(0, boundary.index + boundary.length);
buffer = buffer.slice(boundary.index + boundary.length);
controller.enqueue(
encoder.encode(
transformSseEventChunk(chunk, toolPrefix, readStreamUsage)
)
);
boundary = findSseEventBoundary(buffer);
}
}
} catch (error) {
controller.error(error);
}
};

const completeChunk = pendingText.slice(0, lastLineBreak + 1);
pendingText = pendingText.slice(lastLineBreak + 1);
enqueueChunk(controller, completeChunk);
pump().catch((error: unknown) => {
controller.error(error);
});
},
cancel(reason): Promise<void> {
return reader.cancel(reason);
Expand Down
Loading