diff --git a/JS/edgechains/arakoodev/src/ai/src/index.ts b/JS/edgechains/arakoodev/src/ai/src/index.ts index 2c98f37d..e1448659 100644 --- a/JS/edgechains/arakoodev/src/ai/src/index.ts +++ b/JS/edgechains/arakoodev/src/ai/src/index.ts @@ -3,3 +3,21 @@ export { GeminiAI } from "./lib/gemini/gemini.js"; export { LlamaAI } from "./lib/llama/llama.js"; export { RetellAI } from "./lib/retell-ai/retell.js"; export { RetellWebClient } from "./lib/retell-ai/retellWebClient.js"; +export { + SmartRouter, + sentryCallback, + posthogCallback, +} from "./lib/router/index.js"; +export type { + Provider, + Deployment, + Message, + ChatRequest, + ChatResponse, + StreamChunk, + Usage, + RouterCallback, + SuccessContext, + FailureContext, + RouterOptions, +} from "./lib/router/index.js"; diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/README.md b/JS/edgechains/arakoodev/src/ai/src/lib/router/README.md new file mode 100644 index 00000000..743adbee --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/README.md @@ -0,0 +1,95 @@ +# SmartRouter + +Load-balancing router for OpenAI / Google PaLM / Cohere chat completions. +Inspired by [LiteLLM's router](https://docs.litellm.ai/docs/routing) — closes +[#286](https://github.com/arakoodev/EdgeChains/issues/286). + +## Usage + +```ts +import { SmartRouter, sentryCallback, posthogCallback } from "@arakoodev/edgechains.js/ai"; + +const router = new SmartRouter({ retries: 2, fallback_attempts: 4 }); + +router.register({ + provider: "openai", + api_key: process.env.OPENAI_API_KEY!, + model: "gpt-3.5-turbo", + rpm_limit: 3500, + tpm_limit: 90_000, +}); +router.register({ + provider: "cohere", + api_key: process.env.COHERE_API_KEY!, + model: "command", +}); +router.register({ + provider: "google_palm", + api_key: process.env.PALM_API_KEY!, + model: "text-bison-001", +}); + +const r = await router.chat({ prompt: "What is the capital of France?" }); +console.log(r.content, r.usage); + +// Streaming +for await (const chunk of router.stream({ prompt: "Stream this" })) { + if (chunk.delta) process.stdout.write(chunk.delta); +} +``` + +## Routing rules + +On every call the router picks the deployment that: + +1. matches the requested `model` (if specified) +2. is below its `rpm_limit` and `tpm_limit` for the current minute +3. is not currently in 429 cooldown +4. has the **fewest cumulative tokens used** so far + +On a 429 the deployment is cooled until the next minute window and traffic +fails over to the next eligible deployment. Network errors and 5xx are +retried in place by `axios-retry` with exponential backoff (configurable via +`retries`). + +## Token usage + +Every `ChatResponse` and final `StreamChunk` carries a `usage` object +(`{ prompt_tokens, completion_tokens, total_tokens }`). Per-deployment +counters are exposed via `router.getUsage(deploymentId)`. + +## Logging callbacks + +```ts +import * as Sentry from "@sentry/node"; +import { PostHog } from "posthog-node"; + +router.addCallback(sentryCallback(Sentry)); +router.addCallback( + posthogCallback(new PostHog(process.env.POSTHOG_KEY!), { distinctId: "prod-router" }) +); +``` + +Both callbacks accept the client at construction time so the SDK doesn't +pull `@sentry/node` or `posthog-node` as hard dependencies. 
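+
+Custom callbacks are plain objects with optional `on_success` / `on_failure` hooks, so any sink works. A minimal sketch (the console logger below is illustrative, not part of the SDK):
+
+```ts
+import type { RouterCallback } from "@arakoodev/edgechains.js/ai";
+
+// Illustrative only: log routing outcomes to the console.
+const consoleCallback: RouterCallback = {
+    on_success: (ctx) =>
+        console.log(`[router] ${ctx.provider}:${ctx.model} ok in ${ctx.duration_ms}ms`, ctx.response.usage),
+    on_failure: (ctx) =>
+        console.error(`[router] ${ctx.provider}:${ctx.model} failed: ${ctx.error.message}`),
+};
+
+router.addCallback(consoleCallback);
+```
+
+Exceptions thrown inside a callback are swallowed by the router, so a failing sink never breaks a request.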
+ +## Jsonnet config + +The router takes plain JS objects, so any jsonnet output that compiles to +the `Deployment` shape works directly: + +```jsonnet +// router.jsonnet +{ + deployments: [ + { provider: "openai", api_key: std.extVar("OPENAI_KEY"), model: "gpt-3.5-turbo", rpm_limit: 3500 }, + { provider: "cohere", api_key: std.extVar("COHERE_KEY"), model: "command" }, + ], +} +``` + +```ts +import jsonnet from "@arakoodev/jsonnet"; +const cfg = JSON.parse(jsonnet.evaluateFile("router.jsonnet")); +for (const d of cfg.deployments) router.register(d); +``` diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/SmartRouter.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/SmartRouter.ts new file mode 100644 index 00000000..5341a10f --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/SmartRouter.ts @@ -0,0 +1,335 @@ +import axios, { AxiosInstance } from "axios"; +import axiosRetry from "axios-retry"; + +import { + ChatRequest, + ChatResponse, + Deployment, + FailureContext, + Provider, + RouterCallback, + RouterOptions, + StreamChunk, + SuccessContext, + Usage, +} from "./types.js"; +import { openaiChat, openaiStream } from "./providers/openai.js"; +import { palmChat, palmStream } from "./providers/palm.js"; +import { cohereChat, cohereStream } from "./providers/cohere.js"; + +interface DeploymentState { + deployment: Required> & Deployment; + /** Minute window key. When Date.now()/60000 floor changes we reset counters. */ + window: number; + requests_this_minute: number; + tokens_this_minute: number; + cumulative_tokens: number; + /** Window key that this deployment is cooled down for (set when we observe a 429). */ + cooldown_until_window: number | null; +} + +const DEFAULT_TIMEOUT_MS = 60_000; +const DEFAULT_RETRIES = 2; +const DEFAULT_FALLBACK_ATTEMPTS = 4; + +/** + * SmartRouter — load-balances chat requests across multiple OpenAI / PaLM / + * Cohere deployments. Routing picks the deployment that is below its + * configured rpm / tpm and has the fewest cumulative tokens used so far. On + * a 429 the deployment is cooled until the next minute window and traffic + * fails over to the next eligible deployment. + */ +export class SmartRouter { + private readonly states: DeploymentState[] = []; + private readonly callbacks: RouterCallback[] = []; + private readonly http: AxiosInstance; + private readonly options: Required; + + constructor(options: RouterOptions = {}) { + this.options = { + timeout_ms: options.timeout_ms ?? DEFAULT_TIMEOUT_MS, + retries: options.retries ?? DEFAULT_RETRIES, + fallback_attempts: options.fallback_attempts ?? DEFAULT_FALLBACK_ATTEMPTS, + }; + + this.http = axios.create({ timeout: this.options.timeout_ms }); + axiosRetry(this.http, { + retries: this.options.retries, + retryDelay: axiosRetry.exponentialDelay, + // Retry network errors and 5xx in place. We deliberately do NOT + // retry 429 here; those flip the deployment to cooldown and route + // away instead, which matches what LiteLLM's router does. + retryCondition: (err) => { + if (axiosRetry.isNetworkOrIdempotentRequestError(err)) return true; + const status = err.response?.status; + return typeof status === "number" && status >= 500 && status < 600; + }, + }); + } + + register(deployment: Deployment): string { + const id = deployment.id ?? `${deployment.provider}:${deployment.model}:${this.states.length}`; + const filled: DeploymentState["deployment"] = { + ...deployment, + id, + timeout_ms: deployment.timeout_ms ?? 
this.options.timeout_ms, + }; + this.states.push({ + deployment: filled, + window: currentWindow(), + requests_this_minute: 0, + tokens_this_minute: 0, + cumulative_tokens: 0, + cooldown_until_window: null, + }); + return id; + } + + addCallback(cb: RouterCallback): void { + this.callbacks.push(cb); + } + + /** Per-deployment cumulative usage. */ + getUsage(deploymentId: string): { cumulative_tokens: number; tokens_this_minute: number; requests_this_minute: number } | null { + const s = this.states.find((x) => x.deployment.id === deploymentId); + if (!s) return null; + this.rolloverIfStale(s); + return { + cumulative_tokens: s.cumulative_tokens, + tokens_this_minute: s.tokens_this_minute, + requests_this_minute: s.requests_this_minute, + }; + } + + /** + * Pick the next eligible deployment. Visible for testing — production + * callers should use `chat`. + */ + pickDeployment(req: ChatRequest = {}): DeploymentState | null { + const now = currentWindow(); + const eligible = this.states.filter((s) => { + this.rolloverIfStale(s); + if (req.model && s.deployment.model !== req.model) return false; + if (s.cooldown_until_window !== null && s.cooldown_until_window >= now) return false; + const rpm = s.deployment.rpm_limit; + const tpm = s.deployment.tpm_limit; + if (typeof rpm === "number" && s.requests_this_minute >= rpm) return false; + if (typeof tpm === "number" && s.tokens_this_minute >= tpm) return false; + return true; + }); + if (eligible.length === 0) return null; + // Least cumulative tokens wins; ties broken by the order deployments + // were registered (stable sort). + eligible.sort((a, b) => a.cumulative_tokens - b.cumulative_tokens); + return eligible[0]; + } + + async chat(req: ChatRequest): Promise { + if (req.stream) { + throw new Error("chat() called with stream:true; use stream() instead"); + } + + const tried = new Set(); + const maxAttempts = Math.max(1, this.options.fallback_attempts); + let lastErr: Error | null = null; + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + const candidate = this.pickDeploymentExcluding(req, tried); + if (!candidate) break; + tried.add(candidate.deployment.id); + + const start = Date.now(); + try { + const adapter = chatAdapterFor(candidate.deployment.provider); + const response = await adapter(this.http, candidate.deployment, req); + this.recordSuccess(candidate, response.usage); + await this.fireSuccess({ + deployment_id: candidate.deployment.id, + provider: candidate.deployment.provider, + model: candidate.deployment.model, + request: req, + duration_ms: Date.now() - start, + response, + }); + return response; + } catch (err) { + const e = err instanceof Error ? err : new Error(String(err)); + lastErr = e; + this.handleError(candidate, e); + await this.fireFailure({ + deployment_id: candidate.deployment.id, + provider: candidate.deployment.provider, + model: candidate.deployment.model, + request: req, + duration_ms: Date.now() - start, + error: e, + }); + // Loop continues; a different deployment will be picked. + } + } + + throw lastErr ?? 
new Error("SmartRouter: no eligible deployment available"); + } + + async *stream(req: ChatRequest): AsyncGenerator { + const tried = new Set(); + const maxAttempts = Math.max(1, this.options.fallback_attempts); + let lastErr: Error | null = null; + + for (let attempt = 0; attempt < maxAttempts; attempt++) { + const candidate = this.pickDeploymentExcluding(req, tried); + if (!candidate) break; + tried.add(candidate.deployment.id); + + const start = Date.now(); + const adapter = streamAdapterFor(candidate.deployment.provider); + const iter = adapter(this.http, candidate.deployment, { ...req, stream: true }); + + try { + let collected = ""; + let finalUsage: Usage | undefined; + for await (const chunk of iter) { + if (chunk.delta) collected += chunk.delta; + if (chunk.done && chunk.usage) finalUsage = chunk.usage; + yield chunk; + } + const usage = finalUsage ?? { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }; + this.recordSuccess(candidate, usage); + await this.fireSuccess({ + deployment_id: candidate.deployment.id, + provider: candidate.deployment.provider, + model: candidate.deployment.model, + request: req, + duration_ms: Date.now() - start, + response: { + content: collected, + usage, + deployment_id: candidate.deployment.id, + provider: candidate.deployment.provider, + model: candidate.deployment.model, + }, + }); + return; + } catch (err) { + const e = err instanceof Error ? err : new Error(String(err)); + lastErr = e; + this.handleError(candidate, e); + await this.fireFailure({ + deployment_id: candidate.deployment.id, + provider: candidate.deployment.provider, + model: candidate.deployment.model, + request: req, + duration_ms: Date.now() - start, + error: e, + }); + } + } + + throw lastErr ?? new Error("SmartRouter: no eligible deployment available"); + } + + // ---- internals ---- + + private pickDeploymentExcluding(req: ChatRequest, tried: Set): DeploymentState | null { + // Build the same eligibility set but skip already-tried ids. + const now = currentWindow(); + const eligible = this.states.filter((s) => { + this.rolloverIfStale(s); + if (tried.has(s.deployment.id)) return false; + if (req.model && s.deployment.model !== req.model) return false; + if (s.cooldown_until_window !== null && s.cooldown_until_window >= now) return false; + const rpm = s.deployment.rpm_limit; + const tpm = s.deployment.tpm_limit; + if (typeof rpm === "number" && s.requests_this_minute >= rpm) return false; + if (typeof tpm === "number" && s.tokens_this_minute >= tpm) return false; + return true; + }); + if (eligible.length === 0) return null; + eligible.sort((a, b) => a.cumulative_tokens - b.cumulative_tokens); + return eligible[0]; + } + + private rolloverIfStale(s: DeploymentState) { + const w = currentWindow(); + if (w !== s.window) { + s.window = w; + s.requests_this_minute = 0; + s.tokens_this_minute = 0; + // Cooldown only lasts within the window where the 429 was observed. + if (s.cooldown_until_window !== null && s.cooldown_until_window < w) { + s.cooldown_until_window = null; + } + } + } + + private recordSuccess(s: DeploymentState, usage: Usage) { + this.rolloverIfStale(s); + s.requests_this_minute += 1; + s.tokens_this_minute += usage.total_tokens; + s.cumulative_tokens += usage.total_tokens; + } + + private handleError(s: DeploymentState, err: Error) { + this.rolloverIfStale(s); + s.requests_this_minute += 1; + const status = (err as any)?.response?.status; + if (status === 429) { + // Cool this deployment down for the rest of the current minute. 
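+ // `rolloverIfStale` clears this flag once the minute window advances,
+ // so the deployment becomes eligible again on the next window without
+ // any explicit reset.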
+ s.cooldown_until_window = s.window; + } + } + + private async fireSuccess(ctx: SuccessContext) { + for (const cb of this.callbacks) { + if (cb.on_success) { + try { + await cb.on_success(ctx); + } catch { + // Callbacks must not break routing. + } + } + } + } + + private async fireFailure(ctx: FailureContext) { + for (const cb of this.callbacks) { + if (cb.on_failure) { + try { + await cb.on_failure(ctx); + } catch { + // Callbacks must not break routing. + } + } + } + } +} + +function currentWindow(): number { + return Math.floor(Date.now() / 60_000); +} + +function chatAdapterFor(provider: Provider) { + switch (provider) { + case "openai": + return openaiChat; + case "google_palm": + return palmChat; + case "cohere": + return cohereChat; + } +} + +function streamAdapterFor(provider: Provider) { + switch (provider) { + case "openai": + return openaiStream; + case "google_palm": + return palmStream; + case "cohere": + return cohereStream; + } +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/posthog.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/posthog.ts new file mode 100644 index 00000000..a9b3849a --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/posthog.ts @@ -0,0 +1,52 @@ +import { FailureContext, RouterCallback, SuccessContext } from "../types.js"; + +// PostHog callback. As with Sentry, we take the client as an arg so we don't +// need posthog-node as a hard dep. Shape matches `posthog-node`'s +// `PostHog.capture`. +export interface PostHogLike { + capture(event: { + distinctId: string; + event: string; + properties?: Record; + }): void; +} + +export interface PostHogCallbackOptions { + /** distinctId to attribute events to. Defaults to "edgechains-router". */ + distinctId?: string; +} + +export function posthogCallback(client: PostHogLike, opts: PostHogCallbackOptions = {}): RouterCallback { + const distinctId = opts.distinctId ?? "edgechains-router"; + return { + on_success: (ctx: SuccessContext) => { + client.capture({ + distinctId, + event: "edgechains.router.success", + properties: { + provider: ctx.provider, + model: ctx.model, + deployment_id: ctx.deployment_id, + duration_ms: ctx.duration_ms, + prompt_tokens: ctx.response.usage.prompt_tokens, + completion_tokens: ctx.response.usage.completion_tokens, + total_tokens: ctx.response.usage.total_tokens, + }, + }); + }, + on_failure: (ctx: FailureContext) => { + client.capture({ + distinctId, + event: "edgechains.router.failure", + properties: { + provider: ctx.provider, + model: ctx.model, + deployment_id: ctx.deployment_id, + duration_ms: ctx.duration_ms, + error_message: ctx.error.message, + error_status: (ctx.error as any)?.response?.status, + }, + }); + }, + }; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/sentry.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/sentry.ts new file mode 100644 index 00000000..0381be4d --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/callbacks/sentry.ts @@ -0,0 +1,37 @@ +import { FailureContext, RouterCallback, SuccessContext } from "../types.js"; + +// Sentry callback. We accept the Sentry hub/client at call time so the SDK +// doesn't pull `@sentry/node` as a hard dependency. The shape is the subset +// of `@sentry/node` that we actually need: `captureException` and +// `addBreadcrumb`. 
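+// `addBreadcrumb` is declared optional so a minimal client (or a test double)
+// that only implements `captureException` still satisfies the interface.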
+export interface SentryLike { + captureException(err: any, captureContext?: any): any; + addBreadcrumb?(breadcrumb: any): void; +} + +export function sentryCallback(sentry: SentryLike): RouterCallback { + return { + on_success: (ctx: SuccessContext) => { + sentry.addBreadcrumb?.({ + category: "edgechains.router", + level: "info", + message: `chat ok ${ctx.provider}:${ctx.model}`, + data: { + deployment_id: ctx.deployment_id, + duration_ms: ctx.duration_ms, + total_tokens: ctx.response.usage.total_tokens, + }, + }); + }, + on_failure: (ctx: FailureContext) => { + sentry.captureException(ctx.error, { + tags: { + "edgechains.provider": ctx.provider, + "edgechains.model": ctx.model, + "edgechains.deployment_id": ctx.deployment_id, + }, + extra: { duration_ms: ctx.duration_ms }, + }); + }, + }; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/index.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/index.ts new file mode 100644 index 00000000..4262af1a --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/index.ts @@ -0,0 +1,16 @@ +export { SmartRouter } from "./SmartRouter.js"; +export { sentryCallback } from "./callbacks/sentry.js"; +export { posthogCallback } from "./callbacks/posthog.js"; +export type { + Provider, + Deployment, + Message, + ChatRequest, + ChatResponse, + StreamChunk, + Usage, + RouterCallback, + SuccessContext, + FailureContext, + RouterOptions, +} from "./types.js"; diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/cohere.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/cohere.ts new file mode 100644 index 00000000..b04ac209 --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/cohere.ts @@ -0,0 +1,115 @@ +import { AxiosInstance } from "axios"; +import { ChatRequest, ChatResponse, Deployment, StreamChunk, Usage } from "../types.js"; +import { parseSSE } from "../sse.js"; + +const CHAT_URL = "https://api.cohere.ai/v1/chat"; + +function buildBody(req: ChatRequest, deployment: Deployment, stream: boolean) { + const message = + req.prompt ?? + (req.messages ?? []) + .filter((m) => m.role === "user") + .slice(-1)[0]?.content ?? + ""; + const chat_history = + (req.messages ?? []) + .slice(0, -1) + .filter((m) => m.role !== "system") + .map((m) => ({ + role: m.role === "assistant" ? "CHATBOT" : "USER", + message: m.content, + })); + return { + model: deployment.model, + message, + chat_history, + max_tokens: req.max_tokens ?? 256, + temperature: req.temperature ?? 0.7, + stream, + }; +} + +function headers(deployment: Deployment) { + return { + Authorization: `Bearer ${deployment.api_key}`, + "Content-Type": "application/json", + }; +} + +function usageFromMeta(meta: any): Usage { + // Cohere reports tokens under either `meta.tokens` or `meta.billed_units` + // depending on API version. Take whichever is populated. + const t = meta?.tokens ?? {}; + const b = meta?.billed_units ?? {}; + const prompt_tokens = t.input_tokens ?? b.input_tokens ?? 0; + const completion_tokens = t.output_tokens ?? b.output_tokens ?? 0; + return { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }; +} + +export async function cohereChat( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): Promise { + const resp = await http.post(CHAT_URL, buildBody(req, deployment, false), { + headers: headers(deployment), + timeout: deployment.timeout_ms, + }); + const data = resp.data; + const content: string = data?.text ?? 
""; + return { + content, + usage: usageFromMeta(data?.meta), + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; +} + +export async function* cohereStream( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): AsyncGenerator { + const resp = await http.post(CHAT_URL, buildBody(req, deployment, true), { + headers: headers(deployment), + timeout: deployment.timeout_ms, + responseType: "stream", + }); + + let usage: Usage | undefined; + + for await (const event of parseSSE(resp.data)) { + if (!event.data) continue; + let json: any; + try { + json = JSON.parse(event.data); + } catch { + continue; + } + if (json?.event_type === "text-generation" && typeof json.text === "string") { + yield { + delta: json.text, + done: false, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; + } else if (json?.event_type === "stream-end") { + usage = usageFromMeta(json?.response?.meta); + } + } + + yield { + delta: "", + done: true, + usage: usage ?? { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/openai.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/openai.ts new file mode 100644 index 00000000..62ba7708 --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/openai.ts @@ -0,0 +1,118 @@ +import { AxiosInstance } from "axios"; +import { ChatRequest, ChatResponse, Deployment, StreamChunk, Usage } from "../types.js"; +import { parseSSE } from "../sse.js"; + +const CHAT_URL = "https://api.openai.com/v1/chat/completions"; + +function buildBody(req: ChatRequest, deployment: Deployment, stream: boolean) { + const messages = + req.messages ?? + (req.prompt ? [{ role: "user" as const, content: req.prompt }] : []); + return { + model: deployment.model, + messages, + max_tokens: req.max_tokens ?? 256, + temperature: req.temperature ?? 0.7, + stream, + }; +} + +function headers(deployment: Deployment) { + return { + Authorization: `Bearer ${deployment.api_key}`, + "Content-Type": "application/json", + }; +} + +export async function openaiChat( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): Promise { + const resp = await http.post(CHAT_URL, buildBody(req, deployment, false), { + headers: headers(deployment), + timeout: deployment.timeout_ms, + }); + const data = resp.data; + const content: string = data?.choices?.[0]?.message?.content ?? ""; + const u = data?.usage ?? {}; + const usage: Usage = { + prompt_tokens: u.prompt_tokens ?? 0, + completion_tokens: u.completion_tokens ?? 0, + total_tokens: u.total_tokens ?? (u.prompt_tokens ?? 0) + (u.completion_tokens ?? 
0), + }; + return { + content, + usage, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; +} + +export async function* openaiStream( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): AsyncGenerator { + const resp = await http.post(CHAT_URL, buildBody(req, deployment, true), { + headers: headers(deployment), + timeout: deployment.timeout_ms, + responseType: "stream", + }); + + let prompt_tokens = 0; + let completion_tokens = 0; + let lastChunk: StreamChunk | null = null; + + for await (const event of parseSSE(resp.data)) { + if (!event.data || event.data === "[DONE]") continue; + let json: any; + try { + json = JSON.parse(event.data); + } catch { + continue; + } + const delta: string = json?.choices?.[0]?.delta?.content ?? ""; + if (json?.usage) { + prompt_tokens = json.usage.prompt_tokens ?? prompt_tokens; + completion_tokens = json.usage.completion_tokens ?? completion_tokens; + } + if (delta) { + // Rough completion-token estimate when usage isn't streamed back + // (OpenAI only sends usage when stream_options.include_usage is set). + completion_tokens += estimateTokens(delta); + const chunk: StreamChunk = { + delta, + done: false, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; + lastChunk = chunk; + yield chunk; + } + } + + yield { + delta: "", + done: true, + usage: { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; + void lastChunk; +} + +function estimateTokens(s: string): number { + // Cheap heuristic: ~4 chars per token. We only use this when the provider + // doesn't stream usage back, so over/under-counting by a few tokens won't + // hurt routing decisions in practice. + if (!s) return 0; + return Math.max(1, Math.ceil(s.length / 4)); +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/palm.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/palm.ts new file mode 100644 index 00000000..2fff0954 --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/providers/palm.ts @@ -0,0 +1,83 @@ +import { AxiosInstance } from "axios"; +import { ChatRequest, ChatResponse, Deployment, StreamChunk, Usage } from "../types.js"; + +// Google PaLM "generative language" v1beta2 generateText endpoint. The API was +// deprecated in favor of Gemini, but #286 explicitly names it so we keep the +// shape that matches the historical PaLM REST contract. +const BASE = "https://generativelanguage.googleapis.com/v1beta2"; + +function endpoint(model: string) { + // Allow callers to pass either "models/text-bison-001" or "text-bison-001". + const m = model.startsWith("models/") ? model : `models/${model}`; + return `${BASE}/${m}:generateText`; +} + +function buildBody(req: ChatRequest) { + const text = req.prompt ?? (req.messages ?? []).map((m) => `${m.role}: ${m.content}`).join("\n"); + return { + prompt: { text }, + temperature: req.temperature ?? 0.7, + maxOutputTokens: req.max_tokens ?? 
256, + }; +} + +function approxTokens(s: string): number { + if (!s) return 0; + return Math.max(1, Math.ceil(s.length / 4)); +} + +export async function palmChat( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): Promise { + const url = `${endpoint(deployment.model)}?key=${encodeURIComponent(deployment.api_key)}`; + const body = buildBody(req); + const resp = await http.post(url, body, { + headers: { "Content-Type": "application/json" }, + timeout: deployment.timeout_ms, + }); + const data = resp.data; + const content: string = data?.candidates?.[0]?.output ?? ""; + + // PaLM doesn't return usage on generateText; estimate from prompt + output. + const promptText = body.prompt.text; + const usage: Usage = { + prompt_tokens: approxTokens(promptText), + completion_tokens: approxTokens(content), + total_tokens: approxTokens(promptText) + approxTokens(content), + }; + + return { + content, + usage, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; +} + +export async function* palmStream( + http: AxiosInstance, + deployment: Deployment, + req: ChatRequest +): AsyncGenerator { + // PaLM generateText doesn't natively stream. We yield the full response as + // one chunk so callers using { stream: true } still get an AsyncIterable. + const full = await palmChat(http, deployment, req); + yield { + delta: full.content, + done: false, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; + yield { + delta: "", + done: true, + usage: full.usage, + deployment_id: deployment.id!, + provider: deployment.provider, + model: deployment.model, + }; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/sse.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/sse.ts new file mode 100644 index 00000000..b3bb5c7d --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/sse.ts @@ -0,0 +1,81 @@ +// Minimal Server-Sent Events parser. We deliberately avoid pulling +// `eventsource-parser` here so the SDK stays dep-light; the protocol is +// small enough to handle inline. +// +// Accepts either a Node Readable stream (axios responseType: "stream") or any +// AsyncIterable. + +export interface SSEEvent { + event?: string; + data: string; + id?: string; +} + +type Source = AsyncIterable | NodeJS.ReadableStream; + +async function* toAsyncIterable(src: Source): AsyncIterable { + // Node Readable streams are AsyncIterable in modern Node, but we coerce + // explicitly so callers don't have to care. + const it: AsyncIterable = + typeof (src as any)[Symbol.asyncIterator] === "function" + ? 
(src as AsyncIterable) + : (async function* () { + // Fallback: treat as a Readable in flowing mode + for await (const chunk of src as any) yield chunk; + })(); + + for await (const chunk of it) { + if (typeof chunk === "string") yield chunk; + else if (chunk instanceof Uint8Array) yield Buffer.from(chunk).toString("utf8"); + else yield String(chunk); + } +} + +export async function* parseSSE(src: Source): AsyncGenerator { + let buf = ""; + let event: string | undefined; + let dataLines: string[] = []; + let id: string | undefined; + + const flush = (): SSEEvent | null => { + if (dataLines.length === 0 && !event) return null; + const ev: SSEEvent = { data: dataLines.join("\n") }; + if (event) ev.event = event; + if (id) ev.id = id; + event = undefined; + dataLines = []; + id = undefined; + return ev; + }; + + for await (const piece of toAsyncIterable(src)) { + buf += piece; + let nl: number; + while ((nl = buf.indexOf("\n")) !== -1) { + let line = buf.slice(0, nl); + buf = buf.slice(nl + 1); + // Strip a trailing \r from CRLF line endings. + if (line.endsWith("\r")) line = line.slice(0, -1); + + if (line === "") { + const ev = flush(); + if (ev) yield ev; + continue; + } + if (line.startsWith(":")) continue; // comment + const colon = line.indexOf(":"); + const field = colon === -1 ? line : line.slice(0, colon); + // SSE spec: a single leading space after the colon is stripped. + let value = colon === -1 ? "" : line.slice(colon + 1); + if (value.startsWith(" ")) value = value.slice(1); + + if (field === "data") dataLines.push(value); + else if (field === "event") event = value; + else if (field === "id") id = value; + } + } + + // Flush any trailing event without a terminating blank line. + const tail = flush(); + if (tail) yield tail; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/lib/router/types.ts b/JS/edgechains/arakoodev/src/ai/src/lib/router/types.ts new file mode 100644 index 00000000..0aa438c0 --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/lib/router/types.ts @@ -0,0 +1,89 @@ +// Shared types for the Smart Router. Kept narrow on purpose — only what the +// four features in issue #286 actually need. + +export type Provider = "openai" | "google_palm" | "cohere"; + +export interface Deployment { + /** Stable id for this deployment (defaults to `${provider}:${model}` if omitted). */ + id?: string; + provider: Provider; + api_key: string; + /** Model name for the provider, e.g. "gpt-3.5-turbo", "models/text-bison-001", "command". */ + model: string; + /** Requests-per-minute soft limit. Routing skips this deployment past the limit. */ + rpm_limit?: number; + /** Tokens-per-minute soft limit. Routing skips this deployment past the limit. */ + tpm_limit?: number; + /** Per-request timeout in ms. Defaults to router-level timeout. */ + timeout_ms?: number; +} + +export interface Message { + role: "user" | "assistant" | "system"; + content: string; +} + +export interface ChatRequest { + /** Logical model the caller is asking for. If set, only deployments with matching model are eligible. */ + model?: string; + messages?: Message[]; + prompt?: string; + max_tokens?: number; + temperature?: number; + stream?: boolean; +} + +export interface Usage { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +export interface ChatResponse { + content: string; + usage: Usage; + deployment_id: string; + provider: Provider; + model: string; +} + +export interface StreamChunk { + delta: string; + /** True on the final chunk; usage may be populated. 
*/ + done: boolean; + usage?: Usage; + deployment_id: string; + provider: Provider; + model: string; +} + +export interface CallbackContext { + deployment_id: string; + provider: Provider; + model: string; + request: ChatRequest; + /** Wall-clock duration for the call in ms. */ + duration_ms: number; +} + +export interface SuccessContext extends CallbackContext { + response: ChatResponse; +} + +export interface FailureContext extends CallbackContext { + error: Error; +} + +export interface RouterCallback { + on_success?: (ctx: SuccessContext) => void | Promise; + on_failure?: (ctx: FailureContext) => void | Promise; +} + +export interface RouterOptions { + /** Default per-request timeout in ms. */ + timeout_ms?: number; + /** How many retries to attempt against alternate deployments before throwing. */ + fallback_attempts?: number; + /** Per-deployment axios-retry retry count for transient failures (5xx/network). 429s are routed away rather than retried in place. */ + retries?: number; +} diff --git a/JS/edgechains/arakoodev/src/ai/src/tests/router/callbacks.test.ts b/JS/edgechains/arakoodev/src/ai/src/tests/router/callbacks.test.ts new file mode 100644 index 00000000..c3b7c32b --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/tests/router/callbacks.test.ts @@ -0,0 +1,69 @@ +import { describe, expect, test, vi } from "vitest"; +import { sentryCallback } from "../../lib/router/callbacks/sentry.js"; +import { posthogCallback } from "../../lib/router/callbacks/posthog.js"; +import { SuccessContext, FailureContext } from "../../lib/router/types.js"; + +const successCtx: SuccessContext = { + deployment_id: "d1", + provider: "openai", + model: "gpt-3.5-turbo", + request: { prompt: "hello" }, + duration_ms: 42, + response: { + content: "world", + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + deployment_id: "d1", + provider: "openai", + model: "gpt-3.5-turbo", + }, +}; + +const failureCtx: FailureContext = { + deployment_id: "d1", + provider: "openai", + model: "gpt-3.5-turbo", + request: { prompt: "hello" }, + duration_ms: 13, + error: Object.assign(new Error("boom"), { response: { status: 500 } }), +}; + +describe("sentry callback", () => { + test("on_success emits a breadcrumb", () => { + const sentry = { captureException: vi.fn(), addBreadcrumb: vi.fn() }; + sentryCallback(sentry).on_success!(successCtx); + expect(sentry.addBreadcrumb).toHaveBeenCalledTimes(1); + const arg = sentry.addBreadcrumb.mock.calls[0][0]; + expect(arg.category).toBe("edgechains.router"); + expect(arg.data.total_tokens).toBe(12); + }); + + test("on_failure captures the exception with tags", () => { + const sentry = { captureException: vi.fn(), addBreadcrumb: vi.fn() }; + sentryCallback(sentry).on_failure!(failureCtx); + expect(sentry.captureException).toHaveBeenCalledTimes(1); + const [err, ctx] = sentry.captureException.mock.calls[0]; + expect(err.message).toBe("boom"); + expect(ctx.tags["edgechains.provider"]).toBe("openai"); + }); +}); + +describe("posthog callback", () => { + test("on_success captures a typed event", () => { + const ph = { capture: vi.fn() }; + posthogCallback(ph).on_success!(successCtx); + expect(ph.capture).toHaveBeenCalledTimes(1); + const ev = ph.capture.mock.calls[0][0]; + expect(ev.event).toBe("edgechains.router.success"); + expect(ev.properties.total_tokens).toBe(12); + expect(ev.properties.duration_ms).toBe(42); + }); + + test("on_failure captures an error event with status", () => { + const ph = { capture: vi.fn() }; + 
posthogCallback(ph).on_failure!(failureCtx); + const ev = ph.capture.mock.calls[0][0]; + expect(ev.event).toBe("edgechains.router.failure"); + expect(ev.properties.error_message).toBe("boom"); + expect(ev.properties.error_status).toBe(500); + }); +}); diff --git a/JS/edgechains/arakoodev/src/ai/src/tests/router/providers.test.ts b/JS/edgechains/arakoodev/src/ai/src/tests/router/providers.test.ts new file mode 100644 index 00000000..2583c1aa --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/tests/router/providers.test.ts @@ -0,0 +1,111 @@ +import { describe, expect, test, beforeEach, vi } from "vitest"; +import { SmartRouter } from "../../lib/router/SmartRouter.js"; + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + return { + default: { + ...actual.default, + create: () => makeFakeAxios(), + }, + }; +}); +vi.mock("axios-retry", () => ({ + default: () => undefined, + isNetworkOrIdempotentRequestError: () => false, + exponentialDelay: () => 0, +})); + +let lastFake: any = null; +function makeFakeAxios() { + const f = { + post: vi.fn(), + interceptors: { response: { use: () => 0 } }, + defaults: {}, + }; + lastFake = f; + return f; +} + +describe("provider adapters", () => { + beforeEach(() => { + lastFake = null; + vi.clearAllMocks(); + }); + + test("openai adapter normalizes content and usage", async () => { + const router = new SmartRouter(); + router.register({ + provider: "openai", + api_key: "sk-test", + model: "gpt-3.5-turbo", + id: "oai", + }); + + lastFake.post.mockResolvedValue({ + data: { + choices: [{ message: { content: "answer" } }], + usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 }, + }, + }); + + const r = await router.chat({ prompt: "hello" }); + expect(r.content).toBe("answer"); + expect(r.usage).toEqual({ prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 }); + + // Verify the request URL + auth header reached the adapter unchanged. 
+ const [url, body, config] = lastFake.post.mock.calls[0]; + expect(url).toBe("https://api.openai.com/v1/chat/completions"); + expect(body.model).toBe("gpt-3.5-turbo"); + expect(body.messages).toEqual([{ role: "user", content: "hello" }]); + expect(config.headers.Authorization).toBe("Bearer sk-test"); + }); + + test("cohere adapter pulls usage from meta.billed_units", async () => { + const router = new SmartRouter(); + router.register({ + provider: "cohere", + api_key: "co-test", + model: "command", + id: "co", + }); + + lastFake.post.mockResolvedValue({ + data: { + text: "cohere reply", + meta: { billed_units: { input_tokens: 4, output_tokens: 8 } }, + }, + }); + + const r = await router.chat({ prompt: "hi" }); + expect(r.content).toBe("cohere reply"); + expect(r.usage).toEqual({ prompt_tokens: 4, completion_tokens: 8, total_tokens: 12 }); + + const [url, body, config] = lastFake.post.mock.calls[0]; + expect(url).toBe("https://api.cohere.ai/v1/chat"); + expect(body.message).toBe("hi"); + expect(config.headers.Authorization).toBe("Bearer co-test"); + }); + + test("palm adapter sends key as query param and approximates usage", async () => { + const router = new SmartRouter(); + router.register({ + provider: "google_palm", + api_key: "AIzaTEST", + model: "text-bison-001", + id: "palm", + }); + + lastFake.post.mockResolvedValue({ + data: { candidates: [{ output: "palm reply" }] }, + }); + + const r = await router.chat({ prompt: "hi" }); + expect(r.content).toBe("palm reply"); + expect(r.usage.total_tokens).toBeGreaterThan(0); + + const [url] = lastFake.post.mock.calls[0]; + expect(url).toContain("models/text-bison-001:generateText"); + expect(url).toContain("key=AIzaTEST"); + }); +}); diff --git a/JS/edgechains/arakoodev/src/ai/src/tests/router/routing.test.ts b/JS/edgechains/arakoodev/src/ai/src/tests/router/routing.test.ts new file mode 100644 index 00000000..dcde0538 --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/tests/router/routing.test.ts @@ -0,0 +1,200 @@ +import { describe, expect, test, beforeEach, vi } from "vitest"; +import { SmartRouter } from "../../lib/router/SmartRouter.js"; + +// Stub axios.create to return a fake instance whose `.post` we can drive. +// We mock at the axios layer rather than with msw because axios-retry + +// responseType:"stream" interactions are easier to reason about with direct +// fakes, and the routing logic itself is HTTP-shape-agnostic. 
+ +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + return { + default: { + ...actual.default, + create: () => makeFakeAxios(), + }, + }; +}); + +vi.mock("axios-retry", () => ({ + default: () => undefined, + isNetworkOrIdempotentRequestError: () => false, + exponentialDelay: () => 0, +})); + +interface FakeAxios { + post: ReturnType; + interceptors: { response: { use: () => number } }; + defaults: any; +} + +let lastFake: FakeAxios | null = null; +function makeFakeAxios(): FakeAxios { + const f: FakeAxios = { + post: vi.fn(), + interceptors: { response: { use: () => 0 } }, + defaults: {}, + }; + lastFake = f; + return f; +} + +function ok(content: string, usage = { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }) { + return { + data: { + choices: [{ message: { content } }], + usage, + }, + }; +} + +function rateLimitError() { + const err: any = new Error("Too Many Requests"); + err.response = { status: 429, data: {} }; + return err; +} + +describe("SmartRouter routing", () => { + beforeEach(() => { + lastFake = null; + vi.clearAllMocks(); + }); + + test("picks the deployment with the fewest cumulative tokens", async () => { + const router = new SmartRouter(); + const a = router.register({ provider: "openai", api_key: "k1", model: "gpt-3.5-turbo", id: "a" }); + const b = router.register({ provider: "openai", api_key: "k2", model: "gpt-3.5-turbo", id: "b" }); + + // Three calls: first goes to whichever (tie -> a, registered first), + // second should go to b (b has 0, a has 12), third back to a. + lastFake!.post.mockResolvedValue(ok("hi")); + const r1 = await router.chat({ prompt: "p" }); + const r2 = await router.chat({ prompt: "p" }); + const r3 = await router.chat({ prompt: "p" }); + + expect(r1.deployment_id).toBe(a); + expect(r2.deployment_id).toBe(b); + expect(r3.deployment_id).toBe(a); + + const usageA = router.getUsage(a)!; + const usageB = router.getUsage(b)!; + expect(usageA.cumulative_tokens).toBe(24); + expect(usageB.cumulative_tokens).toBe(12); + }); + + test("skips deployments past their rpm limit", async () => { + const router = new SmartRouter(); + const a = router.register({ + provider: "openai", + api_key: "k1", + model: "gpt-3.5-turbo", + id: "a", + rpm_limit: 1, + }); + const b = router.register({ + provider: "openai", + api_key: "k2", + model: "gpt-3.5-turbo", + id: "b", + }); + + lastFake!.post.mockResolvedValue(ok("hi")); + const r1 = await router.chat({ prompt: "p" }); + const r2 = await router.chat({ prompt: "p" }); + + expect(r1.deployment_id).toBe(a); + // a is now at its rpm limit -> b is the only eligible one. 
+ expect(r2.deployment_id).toBe(b); + }); + + test("skips deployments past their tpm limit", async () => { + const router = new SmartRouter(); + const a = router.register({ + provider: "openai", + api_key: "k1", + model: "gpt-3.5-turbo", + id: "a", + tpm_limit: 5, // very tight; first response (12 tokens) blows past it + }); + const b = router.register({ + provider: "openai", + api_key: "k2", + model: "gpt-3.5-turbo", + id: "b", + }); + + lastFake!.post.mockResolvedValue(ok("hi")); + const r1 = await router.chat({ prompt: "p" }); + expect(r1.deployment_id).toBe(a); + const r2 = await router.chat({ prompt: "p" }); + expect(r2.deployment_id).toBe(b); + }); + + test("on 429 the deployment is cooled down and traffic fails over", async () => { + const router = new SmartRouter(); + const a = router.register({ provider: "openai", api_key: "k1", model: "gpt-3.5-turbo", id: "a" }); + const b = router.register({ provider: "openai", api_key: "k2", model: "gpt-3.5-turbo", id: "b" }); + + // First call: a -> 429, then b -> ok. + lastFake!.post + .mockRejectedValueOnce(rateLimitError()) + .mockResolvedValueOnce(ok("hi")); + + const r1 = await router.chat({ prompt: "p" }); + expect(r1.deployment_id).toBe(b); + + // Next call must go to b again because a is cooled for the rest of + // this minute window. + lastFake!.post.mockResolvedValueOnce(ok("hi")); + const r2 = await router.chat({ prompt: "p" }); + expect(r2.deployment_id).toBe(b); + }); + + test("respects model filter when picking", async () => { + const router = new SmartRouter(); + router.register({ provider: "openai", api_key: "k1", model: "gpt-3.5-turbo", id: "a" }); + const b = router.register({ + provider: "openai", + api_key: "k2", + model: "gpt-4", + id: "b", + }); + + lastFake!.post.mockResolvedValue(ok("hi")); + const r = await router.chat({ prompt: "p", model: "gpt-4" }); + expect(r.deployment_id).toBe(b); + }); + + test("throws when every deployment has been exhausted", async () => { + const router = new SmartRouter({ fallback_attempts: 5 }); + router.register({ provider: "openai", api_key: "k1", model: "gpt-3.5-turbo", id: "a" }); + router.register({ provider: "openai", api_key: "k2", model: "gpt-3.5-turbo", id: "b" }); + + lastFake!.post.mockRejectedValue(rateLimitError()); + + await expect(router.chat({ prompt: "p" })).rejects.toThrow(/Too Many Requests/); + }); + + test("fires success and failure callbacks", async () => { + const router = new SmartRouter(); + router.register({ provider: "openai", api_key: "k1", model: "gpt-3.5-turbo", id: "a" }); + router.register({ provider: "openai", api_key: "k2", model: "gpt-3.5-turbo", id: "b" }); + + const onSuccess = vi.fn(); + const onFailure = vi.fn(); + router.addCallback({ on_success: onSuccess, on_failure: onFailure }); + + lastFake!.post + .mockRejectedValueOnce(rateLimitError()) + .mockResolvedValueOnce(ok("hello")); + + const r = await router.chat({ prompt: "p" }); + expect(r.content).toBe("hello"); + expect(onFailure).toHaveBeenCalledTimes(1); + expect(onSuccess).toHaveBeenCalledTimes(1); + + const successCtx = onSuccess.mock.calls[0][0]; + expect(successCtx.response.usage.total_tokens).toBe(12); + expect(typeof successCtx.duration_ms).toBe("number"); + }); +}); diff --git a/JS/edgechains/arakoodev/src/ai/src/tests/router/streaming.test.ts b/JS/edgechains/arakoodev/src/ai/src/tests/router/streaming.test.ts new file mode 100644 index 00000000..74e02ffd --- /dev/null +++ b/JS/edgechains/arakoodev/src/ai/src/tests/router/streaming.test.ts @@ -0,0 +1,96 @@ +import { describe, expect, 
test, beforeEach, vi } from "vitest"; +import { Readable } from "node:stream"; +import { SmartRouter } from "../../lib/router/SmartRouter.js"; + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + return { + default: { + ...actual.default, + create: () => makeFakeAxios(), + }, + }; +}); +vi.mock("axios-retry", () => ({ + default: () => undefined, + isNetworkOrIdempotentRequestError: () => false, + exponentialDelay: () => 0, +})); + +let lastFake: any = null; +function makeFakeAxios() { + const f = { + post: vi.fn(), + interceptors: { response: { use: () => 0 } }, + defaults: {}, + }; + lastFake = f; + return f; +} + +function sseStream(events: string[]): Readable { + // Build an SSE-formatted body and emit it in two chunks to exercise the + // partial-line buffering in parseSSE. + const text = events.map((e) => `data: ${e}\n\n`).join(""); + const half = Math.floor(text.length / 2); + const a = Buffer.from(text.slice(0, half), "utf8"); + const b = Buffer.from(text.slice(half), "utf8"); + return Readable.from([a, b]); +} + +describe("SmartRouter streaming", () => { + beforeEach(() => { + lastFake = null; + vi.clearAllMocks(); + }); + + test("openai stream concatenates deltas and reports usage", async () => { + const router = new SmartRouter(); + router.register({ provider: "openai", api_key: "k", model: "gpt-3.5-turbo", id: "a" }); + + const events = [ + JSON.stringify({ choices: [{ delta: { content: "Hello" } }] }), + JSON.stringify({ choices: [{ delta: { content: ", " } }] }), + JSON.stringify({ choices: [{ delta: { content: "world" } }] }), + "[DONE]", + ]; + lastFake.post.mockResolvedValue({ data: sseStream(events) }); + + const chunks: string[] = []; + let finalUsage: any = null; + for await (const c of router.stream({ prompt: "hi" })) { + if (c.delta) chunks.push(c.delta); + if (c.done) finalUsage = c.usage; + } + + expect(chunks.join("")).toBe("Hello, world"); + expect(finalUsage.total_tokens).toBeGreaterThan(0); + + const usage = router.getUsage("a")!; + expect(usage.cumulative_tokens).toBe(finalUsage.total_tokens); + }); + + test("cohere stream parses event_type frames", async () => { + const router = new SmartRouter(); + router.register({ provider: "cohere", api_key: "k", model: "command", id: "co" }); + + const events = [ + JSON.stringify({ event_type: "text-generation", text: "foo " }), + JSON.stringify({ event_type: "text-generation", text: "bar" }), + JSON.stringify({ + event_type: "stream-end", + response: { meta: { billed_units: { input_tokens: 3, output_tokens: 5 } } }, + }), + ]; + lastFake.post.mockResolvedValue({ data: sseStream(events) }); + + const collected: string[] = []; + let usage: any = null; + for await (const c of router.stream({ prompt: "hi" })) { + if (c.delta) collected.push(c.delta); + if (c.done) usage = c.usage; + } + expect(collected.join("")).toBe("foo bar"); + expect(usage).toEqual({ prompt_tokens: 3, completion_tokens: 5, total_tokens: 8 }); + }); +});