diff --git a/README.md b/README.md index 6fc91d1..358ea32 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ Configure your LLM provider in `.dev.vars`: # You can always skip these but you'll have to set them manually for each Agency OR Blueprint. LLM_API_KEY=sk-your-key LLM_API_BASE=https://api.openai.com/v1 +LLM_RETRY_MAX=2 +LLM_RETRY_BACKOFF_MS=500 +LLM_RETRY_MAX_BACKOFF_MS=8000 +LLM_RETRY_JITTER_RATIO=0.2 +LLM_RETRY_STATUS_CODES=429,500,502,503,504,520 ``` Start the development server: diff --git a/docs/guides/deployment.md b/docs/guides/deployment.md index 5e3d692..2f60a16 100644 --- a/docs/guides/deployment.md +++ b/docs/guides/deployment.md @@ -40,6 +40,11 @@ npx wrangler secret put SECRET |----------|----------|-------------| | `LLM_API_KEY` | Yes | API key for your LLM provider | | `LLM_API_BASE` | Yes | Base URL for the LLM API | +| `LLM_RETRY_MAX` | No | Max retry attempts for LLM requests (default 2) | +| `LLM_RETRY_BACKOFF_MS` | No | Base backoff delay in ms (default 500) | +| `LLM_RETRY_MAX_BACKOFF_MS` | No | Max backoff delay in ms (default 8000) | +| `LLM_RETRY_JITTER_RATIO` | No | Jitter ratio applied to backoff (default 0.2) | +| `LLM_RETRY_STATUS_CODES` | No | Comma-separated retryable HTTP status codes | | `SECRET` | No | Secret for API authentication | ### Local Development @@ -49,6 +54,11 @@ For local development, create `.dev.vars`: ```bash LLM_API_KEY=sk-your-key LLM_API_BASE=https://api.openai.com/v1 +LLM_RETRY_MAX=2 +LLM_RETRY_BACKOFF_MS=500 +LLM_RETRY_MAX_BACKOFF_MS=8000 +LLM_RETRY_JITTER_RATIO=0.2 +LLM_RETRY_STATUS_CODES=429,500,502,503,504,520 SECRET=dev-secret ``` diff --git a/lib/README.md b/lib/README.md index 5319de5..727622d 100644 --- a/lib/README.md +++ b/lib/README.md @@ -30,6 +30,11 @@ Configure your LLM provider in `.dev.vars`: ``` LLM_API_KEY=sk-your-key LLM_API_BASE=https://api.openai.com/v1 +LLM_RETRY_MAX=2 +LLM_RETRY_BACKOFF_MS=500 +LLM_RETRY_MAX_BACKOFF_MS=8000 +LLM_RETRY_JITTER_RATIO=0.2 +LLM_RETRY_STATUS_CODES=429,500,502,503,504,520 ``` Start developing: diff --git a/lib/runtime/config.ts b/lib/runtime/config.ts index ca8ad1b..947d3e0 100644 --- a/lib/runtime/config.ts +++ b/lib/runtime/config.ts @@ -40,6 +40,52 @@ export const DEFAULT_LLM_API_BASE = "https://api.openai.com/v1"; */ export const VAR_DEFAULT_MODEL = "DEFAULT_MODEL"; +/** + * Maximum number of retries for LLM requests. + * + * @var LLM_RETRY_MAX + * @default 2 + * @example 0 disables retries + */ +export const VAR_LLM_RETRY_MAX = "LLM_RETRY_MAX"; +export const DEFAULT_LLM_RETRY_MAX = 2; + +/** + * Base backoff delay in milliseconds for LLM retries. + * + * @var LLM_RETRY_BACKOFF_MS + * @default 500 + */ +export const VAR_LLM_RETRY_BACKOFF_MS = "LLM_RETRY_BACKOFF_MS"; +export const DEFAULT_LLM_RETRY_BACKOFF_MS = 500; + +/** + * Maximum backoff delay in milliseconds for LLM retries. + * + * @var LLM_RETRY_MAX_BACKOFF_MS + * @default 8000 + */ +export const VAR_LLM_RETRY_MAX_BACKOFF_MS = "LLM_RETRY_MAX_BACKOFF_MS"; +export const DEFAULT_LLM_RETRY_MAX_BACKOFF_MS = 8000; + +/** + * Jitter ratio applied to LLM retry backoff delays. + * + * @var LLM_RETRY_JITTER_RATIO + * @default 0.2 + */ +export const VAR_LLM_RETRY_JITTER_RATIO = "LLM_RETRY_JITTER_RATIO"; +export const DEFAULT_LLM_RETRY_JITTER_RATIO = 0.2; + +/** + * Comma-separated list of HTTP status codes eligible for retry. + * + * @var LLM_RETRY_STATUS_CODES + * @default "429,500,502,503,504" + */ +export const VAR_LLM_RETRY_STATUS_CODES = "LLM_RETRY_STATUS_CODES"; +export const DEFAULT_LLM_RETRY_STATUS_CODES = [429, 500, 502, 503, 504, 520]; + // ============================================================ // Agent Loop Configuration // ============================================================ diff --git a/lib/runtime/hub.ts b/lib/runtime/hub.ts index b280243..3801aa4 100644 --- a/lib/runtime/hub.ts +++ b/lib/runtime/hub.ts @@ -3,6 +3,14 @@ import { HubAgent } from "./agent"; import { Agency, type McpToolCallRequest, type McpToolCallResponse } from "./agency"; import { AgentEventType } from "./events"; import { makeChatCompletions, type Provider } from "./providers"; +import { + DEFAULT_LLM_RETRY_BACKOFF_MS, + DEFAULT_LLM_RETRY_JITTER_RATIO, + DEFAULT_LLM_RETRY_MAX, + DEFAULT_LLM_RETRY_MAX_BACKOFF_MS, + DEFAULT_LLM_RETRY_STATUS_CODES, + DEFAULT_LLM_API_BASE, +} from "./config"; import type { Tool, AgentPlugin, @@ -13,6 +21,37 @@ import type { } from "./types"; import { createHandler, type HandlerOptions } from "./worker"; +function readNumberVar(value: unknown, fallback: number): number { + if (typeof value === "number" && Number.isFinite(value)) return value; + if (typeof value === "string") { + const trimmed = value.trim(); + if (trimmed.length > 0 && Number.isFinite(Number(trimmed))) { + return Number(trimmed); + } + } + return fallback; +} + +function readStatusCodesVar( + value: unknown, + fallback: number[] +): number[] { + if (Array.isArray(value)) { + const codes = value + .map((v) => readNumberVar(v, NaN)) + .filter((v) => Number.isFinite(v)); + if (codes.length > 0) return codes; + } + if (typeof value === "string") { + const codes = value + .split(",") + .map((v) => readNumberVar(v, NaN)) + .filter((v) => Number.isFinite(v)); + if (codes.length > 0) return codes; + } + return fallback; +} + // MCP tool info from Agency (enriched with serverName for capability matching) export interface McpToolInfo { serverId: string; @@ -496,11 +535,39 @@ export class AgentHub { const apiKey = (this.vars.LLM_API_KEY as string) ?? this.env.LLM_API_KEY; const apiBase = - (this.vars.LLM_API_BASE as string) ?? this.env.LLM_API_BASE; + (this.vars.LLM_API_BASE as string) ?? + this.env.LLM_API_BASE ?? + DEFAULT_LLM_API_BASE; if (!apiKey) throw new Error("Neither LLM_API_KEY nor custom provider set"); - baseProvider = makeChatCompletions(apiKey, apiBase); + const retry = { + maxRetries: readNumberVar( + this.vars.LLM_RETRY_MAX ?? this.env.LLM_RETRY_MAX, + DEFAULT_LLM_RETRY_MAX + ), + backoffMs: readNumberVar( + this.vars.LLM_RETRY_BACKOFF_MS ?? this.env.LLM_RETRY_BACKOFF_MS, + DEFAULT_LLM_RETRY_BACKOFF_MS + ), + maxBackoffMs: readNumberVar( + this.vars.LLM_RETRY_MAX_BACKOFF_MS ?? + this.env.LLM_RETRY_MAX_BACKOFF_MS, + DEFAULT_LLM_RETRY_MAX_BACKOFF_MS + ), + jitterRatio: readNumberVar( + this.vars.LLM_RETRY_JITTER_RATIO ?? + this.env.LLM_RETRY_JITTER_RATIO, + DEFAULT_LLM_RETRY_JITTER_RATIO + ), + retryableStatusCodes: readStatusCodesVar( + this.vars.LLM_RETRY_STATUS_CODES ?? + this.env.LLM_RETRY_STATUS_CODES, + DEFAULT_LLM_RETRY_STATUS_CODES + ), + }; + + baseProvider = makeChatCompletions(apiKey, apiBase, { retry }); } return { diff --git a/lib/runtime/providers/chat-completions.ts b/lib/runtime/providers/chat-completions.ts index c413238..d72dcdd 100644 --- a/lib/runtime/providers/chat-completions.ts +++ b/lib/runtime/providers/chat-completions.ts @@ -20,6 +20,18 @@ type OAChatMsg = }>; }; +export type ChatCompletionsRetryOptions = { + maxRetries: number; + backoffMs: number; + maxBackoffMs: number; + jitterRatio: number; + retryableStatusCodes: number[]; +}; + +export type ChatCompletionsOptions = { + retry?: ChatCompletionsRetryOptions; +}; + function toOA(req: ModelRequest) { const msgs: OAChatMsg[] = []; if (req.systemPrompt) @@ -102,45 +114,142 @@ function fromOA(choice: { message: OAChatMsg }): ChatMessage { return { role: "assistant", reasoning: msg?.reasoning, content: msg?.content ?? "" }; } +function sleep(ms: number, signal?: AbortSignal): Promise { + if (ms <= 0) return Promise.resolve(); + return new Promise((resolve, reject) => { + const timer = setTimeout(resolve, ms); + const abortError = new Error("Request aborted"); + abortError.name = "AbortError"; + + if (signal?.aborted) { + clearTimeout(timer); + return reject(abortError); + } + + if (signal) { + signal.addEventListener( + "abort", + () => { + clearTimeout(timer); + reject(abortError); + }, + { once: true } + ); + } + }); +} + +function parseRetryAfterMs(value: string | null): number | null { + if (!value) return null; + const seconds = Number.parseFloat(value); + if (Number.isFinite(seconds)) return Math.max(0, seconds * 1000); + const dateMs = Date.parse(value); + if (!Number.isNaN(dateMs)) { + const diff = dateMs - Date.now(); + return diff > 0 ? diff : 0; + } + return null; +} + +function computeDelayMs( + attempt: number, + retry: ChatCompletionsRetryOptions, + retryAfterMs: number | null +): number { + let delay = + retryAfterMs ?? + Math.min(retry.maxBackoffMs, retry.backoffMs * 2 ** attempt); + if (retry.jitterRatio > 0) { + const jitter = delay * retry.jitterRatio; + delay += (Math.random() * 2 - 1) * jitter; + } + return Math.max(0, Math.round(delay)); +} + +function isAbortError(error: unknown): boolean { + return error instanceof Error && error.name === "AbortError"; +} + +class NonRetryableError extends Error { + readonly retryable = false; +} + /** * Creates a provider for OpenAI-compatible chat completions APIs. * Works with OpenAI, OpenRouter, Azure OpenAI, and other compatible endpoints. */ export function makeChatCompletions( apiKey: string, - baseUrl = "https://api.openai.com/v1" + baseUrl = "https://api.openai.com/v1", + options: ChatCompletionsOptions = {} ): Provider { const headers = { "content-type": "application/json", authorization: `Bearer ${apiKey}` }; + const retry = options.retry && options.retry.maxRetries > 0 ? options.retry : null; return { async invoke(req, { signal }) { const body = toOA(req); - const res = await fetch(`${baseUrl}/chat/completions`, { - method: "POST", - headers, - body: JSON.stringify({ ...body, stream: false }), - signal - }); - if (!res.ok) { - const errTxt = await res.text().catch(() => ""); - throw new Error(`Chat completions error ${res.status}: ${errTxt}`); - } + const payload = JSON.stringify({ ...body, stream: false }); + for (let attempt = 0; ; attempt++) { + try { + const res = await fetch(`${baseUrl}/chat/completions`, { + method: "POST", + headers, + body: payload, + signal + }); + + if (!res.ok) { + const retryAfterMs = parseRetryAfterMs( + res.headers.get("Retry-After") + ); + if ( + retry && + retry.retryableStatusCodes.includes(res.status) && + attempt < retry.maxRetries + ) { + await sleep(computeDelayMs(attempt, retry, retryAfterMs), signal); + continue; + } - const json = (await res.json()) as { - choices: Array<{ message: OAChatMsg }>; - usage: { prompt_tokens: number; completion_tokens: number }; - }; - const message = fromOA(json.choices?.[0]); - const usage = json.usage - ? { - promptTokens: json.usage.prompt_tokens, - completionTokens: json.usage.completion_tokens + const errTxt = await res.text().catch(() => ""); + throw new NonRetryableError( + `Chat completions error ${res.status}: ${errTxt}` + ); } - : undefined; - return { message, usage }; + + const json = (await res.json()) as { + choices: Array<{ message: OAChatMsg }>; + usage: { prompt_tokens: number; completion_tokens: number }; + }; + const message = fromOA(json.choices?.[0]); + const usage = json.usage + ? { + promptTokens: json.usage.prompt_tokens, + completionTokens: json.usage.completion_tokens + } + : undefined; + return { message, usage }; + } catch (error) { + if (signal?.aborted || isAbortError(error)) { + throw error; + } + + if ( + retry && + attempt < retry.maxRetries && + !(error instanceof NonRetryableError) + ) { + await sleep(computeDelayMs(attempt, retry, null), signal); + continue; + } + + throw error; + } + } }, async stream(_req, _onDelta) { diff --git a/lib/runtime/types.ts b/lib/runtime/types.ts index 944517b..7d34437 100644 --- a/lib/runtime/types.ts +++ b/lib/runtime/types.ts @@ -161,6 +161,11 @@ export interface AgentEnv { AGENCY: DurableObjectNamespace; LLM_API_KEY?: string; LLM_API_BASE?: string; + LLM_RETRY_MAX?: string | number; + LLM_RETRY_BACKOFF_MS?: string | number; + LLM_RETRY_MAX_BACKOFF_MS?: string | number; + LLM_RETRY_JITTER_RATIO?: string | number; + LLM_RETRY_STATUS_CODES?: string; FS?: R2Bucket; SANDBOX?: DurableObjectNamespace; } diff --git a/lib/tests/llm-retry.test.ts b/lib/tests/llm-retry.test.ts new file mode 100644 index 0000000..498f9b3 --- /dev/null +++ b/lib/tests/llm-retry.test.ts @@ -0,0 +1,67 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { makeChatCompletions } from "../runtime/providers/chat-completions"; + +describe("LLM provider retries", () => { + const req = { + model: "test-model", + messages: [{ role: "user" as const, content: "Hello" }], + }; + + const okResponse = new Response( + JSON.stringify({ + choices: [{ message: { role: "assistant", content: "ok" } }], + usage: { prompt_tokens: 1, completion_tokens: 1 }, + }), + { status: 200, headers: { "content-type": "application/json" } } + ); + + const retryOptions = { + maxRetries: 1, + backoffMs: 0, + maxBackoffMs: 0, + jitterRatio: 0, + retryableStatusCodes: [520], + }; + + let fetchSpy: ReturnType; + + beforeEach(() => { + fetchSpy = vi.fn(); + vi.stubGlobal("fetch", fetchSpy); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("retries when the response status is retryable", async () => { + fetchSpy + .mockResolvedValueOnce( + new Response("temporary", { status: 520, headers: { "Retry-After": "0" } }) + ) + .mockResolvedValueOnce(okResponse); + + const provider = makeChatCompletions("test-key", "https://example.test/v1", { + retry: retryOptions, + }); + + const result = await provider.invoke(req, {}); + expect(fetchSpy).toHaveBeenCalledTimes(2); + if ("content" in result.message) { + expect(result.message.content).toBe("ok"); + } + }); + + it("does not retry when the status is not retryable", async () => { + fetchSpy.mockResolvedValueOnce(new Response("bad", { status: 400 })); + + const provider = makeChatCompletions("test-key", "https://example.test/v1", { + retry: retryOptions, + }); + + await expect(provider.invoke(req, {})).rejects.toThrow( + "Chat completions error 400" + ); + expect(fetchSpy).toHaveBeenCalledTimes(1); + }); +});