diff --git a/src/providers/embedding/openai.ts b/src/providers/embedding/openai.ts index c868d4a..99205c6 100644 --- a/src/providers/embedding/openai.ts +++ b/src/providers/embedding/openai.ts @@ -1,16 +1,33 @@ import type { EmbeddingProvider } from "../../types.js"; import { getEnvVar } from "../../config.js"; -const API_URL = "https://api.openai.com/v1/embeddings"; - +const DEFAULT_BASE_URL = "https://api.openai.com"; +const DEFAULT_MODEL = "text-embedding-3-small"; + +/** + * OpenAI-compatible embedding provider. + * + * Required env vars: + * OPENAI_API_KEY — API key + * + * Optional: + * OPENAI_BASE_URL — base URL without path (default: https://api.openai.com) + * OPENAI_EMBEDDING_MODEL — model name (default: text-embedding-3-small) + */ export class OpenAIEmbeddingProvider implements EmbeddingProvider { readonly name = "openai"; readonly dimensions = 1536; private apiKey: string; + private baseUrl: string; + private model: string; constructor(apiKey?: string) { this.apiKey = apiKey || getEnvVar("OPENAI_API_KEY") || ""; if (!this.apiKey) throw new Error("OPENAI_API_KEY is required"); + this.baseUrl = + getEnvVar("OPENAI_BASE_URL") || DEFAULT_BASE_URL; + this.model = + getEnvVar("OPENAI_EMBEDDING_MODEL") || DEFAULT_MODEL; } async embed(text: string): Promise<number[]> { @@ -19,14 +36,15 @@ export class OpenAIEmbeddingProvider implements EmbeddingProvider { } async embedBatch(texts: string[]): Promise<number[][]> { - const response = await fetch(API_URL, { + const url = `${this.baseUrl}/v1/embeddings`; + const response = await fetch(url, { method: "POST", headers: { Authorization: `Bearer ${this.apiKey}`, "Content-Type": "application/json", }, body: JSON.stringify({ - model: "text-embedding-3-small", + model: this.model, input: texts, }), }); diff --git a/test/embedding-provider.test.ts b/test/embedding-provider.test.ts index 05d01bd..2cf3271 100644 --- a/test/embedding-provider.test.ts +++ b/test/embedding-provider.test.ts @@ -47,3 +47,60 @@ describe("createEmbeddingProvider", () =>
{ expect(provider).toBeInstanceOf(OpenAIEmbeddingProvider); }); }); + +describe("OpenAIEmbeddingProvider", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + process.env = { ...originalEnv }; + delete process.env["OPENAI_BASE_URL"]; + delete process.env["OPENAI_EMBEDDING_MODEL"]; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it("uses default base URL and model when env vars are not set", () => { + const provider = new OpenAIEmbeddingProvider("test-key"); + expect(provider.name).toBe("openai"); + expect(provider.dimensions).toBe(1536); + }); + + it("throws when no API key is provided", () => { + delete process.env["OPENAI_API_KEY"]; + expect(() => new OpenAIEmbeddingProvider()).toThrow("OPENAI_API_KEY is required"); + }); + + it("respects OPENAI_BASE_URL env var", async () => { + process.env["OPENAI_BASE_URL"] = "https://my-proxy.example.com"; + const provider = new OpenAIEmbeddingProvider("test-key"); + + const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), { status: 200 }), + ); + + await provider.embed("hello"); + expect(fetchSpy).toHaveBeenCalledWith( + "https://my-proxy.example.com/v1/embeddings", + expect.any(Object), + ); + + fetchSpy.mockRestore(); + }); + + it("respects OPENAI_EMBEDDING_MODEL env var", async () => { + process.env["OPENAI_EMBEDDING_MODEL"] = "text-embedding-3-large"; + const provider = new OpenAIEmbeddingProvider("test-key"); + + const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), { status: 200 }), + ); + + await provider.embed("hello"); + const body = JSON.parse((fetchSpy.mock.calls[0][1] as RequestInit).body as string); + expect(body.model).toBe("text-embedding-3-large"); + + fetchSpy.mockRestore(); + }); +});