Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ export const SECRET_STATE_KEYS = [
"codebaseIndexVercelAiGatewayApiKey",
"codebaseIndexOpenRouterApiKey",
"sambaNovaApiKey",
"veniceApiKey",
"zaiApiKey",
"fireworksApiKey",
"vercelAiGatewayApiKey",
Expand Down
10 changes: 10 additions & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
xaiModels,
internationalZAiModels,
minimaxModels,
veniceModels,
} from "./providers/index.js"

/**
Expand Down Expand Up @@ -126,6 +127,7 @@ export const providerNames = [
"roo",
"sambanova",
"vertex",
"venice",
"xai",
"zai",
] as const
Expand Down Expand Up @@ -372,6 +374,10 @@ const zaiSchema = apiModelIdProviderModelSchema.extend({
zaiApiLine: zaiApiLineSchema.optional(),
})

const veniceSchema = apiModelIdProviderModelSchema.extend({
veniceApiKey: z.string().optional(),
})

const fireworksSchema = apiModelIdProviderModelSchema.extend({
fireworksApiKey: z.string().optional(),
})
Expand Down Expand Up @@ -427,6 +433,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
fireworksSchema.merge(z.object({ apiProvider: z.literal("fireworks") })),
qwenCodeSchema.merge(z.object({ apiProvider: z.literal("qwen-code") })),
rooSchema.merge(z.object({ apiProvider: z.literal("roo") })),
veniceSchema.merge(z.object({ apiProvider: z.literal("venice") })),
vercelAiGatewaySchema.merge(z.object({ apiProvider: z.literal("vercel-ai-gateway") })),
defaultSchema,
])
Expand Down Expand Up @@ -461,6 +468,7 @@ export const providerSettingsSchema = z.object({
...fireworksSchema.shape,
...qwenCodeSchema.shape,
...rooSchema.shape,
...veniceSchema.shape,
...vercelAiGatewaySchema.shape,
...codebaseIndexProviderSchema.shape,
})
Expand Down Expand Up @@ -529,6 +537,7 @@ export const modelIdKeysByProvider: Record<TypicalProvider, ModelIdKey> = {
"qwen-code": "apiModelId",
requesty: "requestyModelId",
unbound: "unboundModelId",
venice: "apiModelId",
xai: "apiModelId",
baseten: "apiModelId",
litellm: "litellmModelId",
Expand Down Expand Up @@ -643,6 +652,7 @@ export const MODELS_BY_PROVIDER: Record<
label: "VS Code LM API",
models: Object.keys(vscodeLlmModels),
},
venice: { id: "venice", label: "Venice AI", models: Object.keys(veniceModels) },
xai: { id: "xai", label: "xAI (Grok)", models: Object.keys(xaiModels) },
zai: { id: "zai", label: "Z.ai", models: Object.keys(internationalZAiModels) },
baseten: { id: "baseten", label: "Baseten", models: Object.keys(basetenModels) },
Expand Down
4 changes: 4 additions & 0 deletions packages/types/src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export * from "./requesty.js"
export * from "./roo.js"
export * from "./sambanova.js"
export * from "./unbound.js"
export * from "./venice.js"
export * from "./vertex.js"
export * from "./vscode-llm.js"
export * from "./xai.js"
Expand All @@ -43,6 +44,7 @@ import { requestyDefaultModelId } from "./requesty.js"
import { rooDefaultModelId } from "./roo.js"
import { sambaNovaDefaultModelId } from "./sambanova.js"
import { unboundDefaultModelId } from "./unbound.js"
import { veniceDefaultModelId } from "./venice.js"
import { vertexDefaultModelId } from "./vertex.js"
import { vscodeLlmDefaultModelId } from "./vscode-llm.js"
import { xaiDefaultModelId } from "./xai.js"
Expand Down Expand Up @@ -113,6 +115,8 @@ export function getProviderDefaultModelId(
return poeDefaultModelId
case "unbound":
return unboundDefaultModelId
case "venice":
return veniceDefaultModelId
case "vercel-ai-gateway":
return vercelAiGatewayDefaultModelId
case "anthropic":
Expand Down
81 changes: 81 additions & 0 deletions packages/types/src/providers/venice.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import type { ModelInfo } from "../model.js"

// Venice AI
// https://docs.venice.ai/api-reference/chat-completions

// Model identifiers accepted by the Venice AI chat-completions endpoint.
export type VeniceModelId =
	| "glm-4-32b"
	| "trinity-v1"
	| "deepseek-r1-671b"
	| "deepseek-v3-0324"
	| "qwen-2.5-coder-32b"
	| "llama-3.3-70b"

export const veniceDefaultModelId: VeniceModelId = "glm-4-32b"

// Static model metadata for Venice AI.
// `satisfies Record<VeniceModelId, ModelInfo>` (rather than Record<string, ...>)
// makes the compiler reject a missing or misspelled entry for any declared
// VeniceModelId, while `as const` preserves the literal key/value types for
// downstream inference (e.g. Object.keys(veniceModels) in MODELS_BY_PROVIDER).
// NOTE(review): all prices are 0 — confirm actual Venice pricing before relying
// on these values for cost reporting.
export const veniceModels = {
	"glm-4-32b": {
		maxTokens: 8192,
		contextWindow: 32768,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "GLM-4 32B model via Venice AI with private inference.",
	},
	"trinity-v1": {
		maxTokens: 8192,
		contextWindow: 32768,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "Venice Trinity V1 model with private inference.",
	},
	"deepseek-r1-671b": {
		maxTokens: 8192,
		contextWindow: 65536,
		supportsImages: false,
		supportsPromptCache: false,
		supportsReasoningBudget: true, // R1 is a reasoning model; see docs link above
		inputPrice: 0,
		outputPrice: 0,
		description: "DeepSeek R1 671B reasoning model via Venice AI.",
	},
	"deepseek-v3-0324": {
		maxTokens: 8192,
		contextWindow: 65536,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "DeepSeek V3 0324 model via Venice AI.",
	},
	"qwen-2.5-coder-32b": {
		maxTokens: 8192,
		contextWindow: 32768,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "Qwen 2.5 Coder 32B model via Venice AI.",
	},
	"llama-3.3-70b": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0,
		outputPrice: 0,
		description: "Meta Llama 3.3 70B model via Venice AI.",
	},
} as const satisfies Record<VeniceModelId, ModelInfo>

// Default-model metadata, derived from the table above so the two can never
// drift apart (previously this was a hand-copied duplicate of the glm-4-32b
// entry). Spread into a fresh object so a caller that mutates the default
// cannot corrupt the model table itself.
export const veniceDefaultModelInfo: ModelInfo = { ...veniceModels[veniceDefaultModelId] }
3 changes: 3 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
VsCodeLmHandler,
RequestyHandler,
UnboundHandler,
VeniceHandler,
FakeAIHandler,
XAIHandler,
LiteLLMHandler,
Expand Down Expand Up @@ -177,6 +178,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new MiniMaxHandler(options)
case "baseten":
return new BasetenHandler(options)
case "venice":
return new VeniceHandler(options)
case "poe":
return new PoeHandler(options)
default:
Expand Down
147 changes: 147 additions & 0 deletions src/api/providers/__tests__/venice.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// npx vitest run src/api/providers/__tests__/venice.spec.ts

import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { type VeniceModelId, veniceDefaultModelId, veniceModels } from "@roo-code/types"

import { VeniceHandler } from "../venice"

// Replace the OpenAI SDK with a stub. The factory closes over a single
// `createMock`, so EVERY `new OpenAI()` returns an object sharing the same
// chat.completions.create spy — this is what lets the tests below re-invoke the
// mocked constructor to reach the spy used inside the handler.
vitest.mock("openai", () => {
	const createMock = vitest.fn()
	return {
		default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
	}
})

describe("VeniceHandler", () => {
	let handler: VeniceHandler
	let mockCreate: any

	beforeEach(() => {
		// Reset call history first, then grab the shared create-spy by invoking the
		// mocked constructor (see factory above), then build a fresh handler so each
		// test starts from a clean slate.
		vitest.clearAllMocks()
		mockCreate = (OpenAI as unknown as any)().chat.completions.create
		handler = new VeniceHandler({ veniceApiKey: "test-venice-api-key" })
	})

	// The handler must point the OpenAI client at Venice's OpenAI-compatible endpoint.
	it("should use the correct Venice base URL", () => {
		new VeniceHandler({ veniceApiKey: "test-venice-api-key" })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.venice.ai/api/v1" }))
	})

	// The veniceApiKey option must be forwarded verbatim as the client's apiKey.
	it("should use the provided API key", () => {
		const veniceApiKey = "test-venice-api-key"
		new VeniceHandler({ veniceApiKey })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: veniceApiKey }))
	})

	// With no apiModelId configured, getModel() falls back to the Venice default.
	it("should return default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe(veniceDefaultModelId)
		expect(model.info).toEqual(veniceModels[veniceDefaultModelId])
	})

	// A configured apiModelId that exists in veniceModels is used as-is.
	it("should return specified model when valid model is provided", () => {
		const testModelId: VeniceModelId = "deepseek-r1-671b"
		const handlerWithModel = new VeniceHandler({
			apiModelId: testModelId,
			veniceApiKey: "test-venice-api-key",
		})
		const model = handlerWithModel.getModel()
		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(veniceModels[testModelId])
	})

	// Non-streaming completion: the first choice's message content is returned.
	it("completePrompt method should return text from Venice API", async () => {
		const expectedResponse = "This is a test response from Venice"
		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
		const result = await handler.completePrompt("test prompt")
		expect(result).toBe(expectedResponse)
	})

	// API failures are re-thrown wrapped with a provider-specific prefix.
	it("should handle errors in completePrompt", async () => {
		const errorMessage = "Venice API error"
		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Venice completion error: ${errorMessage}`)
	})

	// Streaming: a delta with content should surface as a { type: "text" } chunk.
	it("createMessage should yield text content from stream", async () => {
		const testContent = "This is test content from Venice stream"

		// Hand-rolled async iterator: yields one chunk, then signals completion.
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: vitest
						.fn()
						.mockResolvedValueOnce({
							done: false,
							value: { choices: [{ delta: { content: testContent } }] },
						})
						.mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
	})

	// Streaming: a chunk carrying `usage` should surface token counts as a usage chunk.
	it("createMessage should yield usage data from stream", async () => {
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: vitest
						.fn()
						.mockResolvedValueOnce({
							done: false,
							// Empty content delta so only the usage payload is meaningful.
							value: {
								choices: [{ delta: { content: "" } }],
								usage: { prompt_tokens: 10, completion_tokens: 5 },
							},
						})
						.mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual(
			expect.objectContaining({
				type: "usage",
				inputTokens: 10,
				outputTokens: 5,
			}),
		)
	})

	// The request sent to the SDK must target the resolved model and use streaming.
	it("should pass the correct parameters to OpenAI API", async () => {
		// Empty stream — we only care about the arguments passed to create().
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: vitest.fn().mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		const systemPrompt = "You are a helpful assistant"
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

		const stream = handler.createMessage(systemPrompt, messages)

		// Consume the stream
		const results = []
		for await (const chunk of stream) {
			results.push(chunk)
		}

		const callArgs = mockCreate.mock.calls[0][0]
		expect(callArgs.model).toBe(veniceDefaultModelId)
		expect(callArgs.stream).toBe(true)
	})
})
1 change: 1 addition & 0 deletions src/api/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export { QwenCodeHandler } from "./qwen-code"
export { RequestyHandler } from "./requesty"
export { SambaNovaHandler } from "./sambanova"
export { UnboundHandler } from "./unbound"
export { VeniceHandler } from "./venice"
export { VertexHandler } from "./vertex"
export { VsCodeLmHandler } from "./vscode-lm"
export { XAIHandler } from "./xai"
Expand Down
18 changes: 18 additions & 0 deletions src/api/providers/venice.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { type VeniceModelId, veniceDefaultModelId, veniceModels } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

/**
 * API handler for Venice AI, an OpenAI-compatible provider.
 *
 * All request logic lives in {@link BaseOpenAiCompatibleProvider}; this subclass
 * only wires up the Venice endpoint, the `veniceApiKey` credential, and the
 * Venice model table / default model id.
 */
export class VeniceHandler extends BaseOpenAiCompatibleProvider<VeniceModelId> {
	constructor(options: ApiHandlerOptions) {
		const veniceProviderOptions = {
			...options,
			providerName: "Venice",
			baseURL: "https://api.venice.ai/api/v1",
			apiKey: options.veniceApiKey,
			defaultProviderModelId: veniceDefaultModelId,
			providerModels: veniceModels,
		}
		super(veniceProviderOptions)
	}
}
1 change: 1 addition & 0 deletions src/shared/ProfileValidator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export class ProfileValidator {
case "deepseek":
case "xai":
case "sambanova":
case "venice":
case "fireworks":
return profile.apiModelId
case "litellm":
Expand Down
Loading
Loading