Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/app/server/src/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export const env = createEnv({
GROQ_API_KEY: z.string().optional(),
XAI_API_KEY: z.string().optional(),
OPENROUTER_API_KEY: z.string().optional(),
AI_GATEWAY_API_KEY: z.string().optional(),
TAVILY_API_KEY: z.string().optional(),
E2B_API_KEY: z.string().optional(),
GOOGLE_SERVICE_ACCOUNT_KEY_ENCODED: z.string().optional(),
Expand Down
7 changes: 7 additions & 0 deletions packages/app/server/src/providers/ProviderFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { OpenAIImageProvider } from './OpenAIImageProvider';
import { OpenAIResponsesProvider } from './OpenAIResponsesProvider';
import { OpenRouterProvider } from './OpenRouterProvider';
import { ProviderType } from './ProviderType';
import { VercelAIGatewayProvider } from './VercelAIGatewayProvider';
import { XAIProvider } from './XAIProvider';
import {
VertexAIProvider,
Expand Down Expand Up @@ -58,6 +59,10 @@ const createChatModelToProviderMapping = (): Record<string, ProviderType> => {
case 'Xai':
mapping[modelConfig.model_id] = ProviderType.XAI;
break;
case 'Vercel AI Gateway':
case 'Vercel':
mapping[modelConfig.model_id] = ProviderType.VERCEL_AI_GATEWAY;
break;
// Add other providers as needed
default:
// Skip models with unsupported providers
Expand Down Expand Up @@ -192,6 +197,8 @@ export const getProvider = (
return new GroqProvider(stream, model);
case ProviderType.XAI:
return new XAIProvider(stream, model);
case ProviderType.VERCEL_AI_GATEWAY:
return new VercelAIGatewayProvider(stream, model);
default:
throw new Error(`Unknown provider type: ${type}`);
}
Expand Down
1 change: 1 addition & 0 deletions packages/app/server/src/providers/ProviderType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ export enum ProviderType {
OPENAI_VIDEOS = 'OPENAI_VIDEOS',
GROQ = 'GROQ',
XAI = 'XAI',
VERCEL_AI_GATEWAY = 'VERCEL_AI_GATEWAY',
}
259 changes: 259 additions & 0 deletions packages/app/server/src/providers/VercelAIGatewayProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import { LlmTransactionMetadata, Transaction } from '../types';
import { getCostPerToken } from '../services/AccountingService';
import { BaseProvider } from './BaseProvider';
import { ProviderType } from './ProviderType';
import { parseSSEGPTFormat } from './GPTProvider';
import logger from '../logger';
import { env } from '../env';

/**
 * Token usage in the Vercel AI SDK / AI Gateway naming convention
 * (camelCase: inputTokens / outputTokens / totalTokens).
 */
interface AIGatewayUsage {
  inputTokens?: number;
  outputTokens?: number;
  totalTokens?: number;
}

/**
 * Token usage in the OpenAI-compatible naming convention
 * (snake_case: prompt_tokens / completion_tokens / total_tokens).
 */
interface OpenAICompatibleUsage {
  prompt_tokens?: number;
  completion_tokens?: number;
  total_tokens?: number;
}

/**
 * Non-streaming response body returned by the gateway. Depending on which
 * endpoint served the request, usage may arrive in either naming convention,
 * and the provider-side id may be at the top level or nested under `response`.
 */
interface AIGatewayResponseBody {
  id?: string;
  response?: {
    id?: string;
  };
  usage?: AIGatewayUsage | OpenAICompatibleUsage;
}

/**
 * One parsed SSE chunk from an AI SDK-protocol stream. Usage is expected on
 * the part whose `type` is 'finish' (see handleBody in
 * VercelAIGatewayProvider).
 */
interface AIGatewayStreamPart {
  type?: string;
  id?: string;
  response?: {
    id?: string;
  };
  usage?: AIGatewayUsage;
}

const ECHO_VERCEL_MODEL_PREFIX = 'vercel/';

/**
 * Strips the Echo-internal `vercel/` prefix from a model id, if present,
 * yielding the bare id understood by the Vercel AI Gateway. Ids without the
 * prefix pass through unchanged.
 */
function toVercelModelId(model: string): string {
  if (!model.startsWith(ECHO_VERCEL_MODEL_PREFIX)) {
    return model;
  }
  return model.slice(ECHO_VERCEL_MODEL_PREFIX.length);
}

/**
 * Type guard: true when the usage object uses the AI SDK camelCase field
 * names (any one of inputTokens/outputTokens/totalTokens is enough).
 */
function isAIGatewayUsage(
  usage: AIGatewayUsage | OpenAICompatibleUsage | undefined
): usage is AIGatewayUsage {
  if (usage === undefined) {
    return false;
  }
  const aiSdkKeys = ['inputTokens', 'outputTokens', 'totalTokens'];
  return aiSdkKeys.some(key => key in usage);
}

/**
 * Parses a raw Server-Sent-Events payload into the JSON objects carried by
 * its `data: ` lines. Events are delimited by blank lines; multi-line data
 * within one event is rejoined with '\n' before parsing. Empty payloads and
 * the terminal `[DONE]` sentinel are skipped; unparseable payloads are
 * logged and dropped rather than aborting the whole stream.
 */
function parseSSEJsonObjects(data: string): unknown[] {
  const parsed: unknown[] = [];

  for (const rawEvent of data.split('\n\n')) {
    const payloadLines: string[] = [];
    for (const line of rawEvent.split('\n')) {
      if (line.startsWith('data: ')) {
        payloadLines.push(line.slice(6));
      }
    }

    if (payloadLines.length === 0) {
      continue;
    }

    const body = payloadLines.join('\n').trim();
    if (body === '' || body === '[DONE]') {
      continue;
    }

    try {
      parsed.push(JSON.parse(body));
    } catch (error) {
      logger.error(`Error parsing Vercel AI Gateway SSE chunk: ${error}`);
    }
  }

  return parsed;
}

/**
 * Provider that proxies Echo traffic to the Vercel AI Gateway. The gateway
 * exposes two surfaces: an OpenAI-compatible API at /v1 and the AI SDK
 * protocol at /v1/ai (language-model, embedding-model, config routes).
 * This class routes requests to the right surface, rewrites the model id
 * to the gateway's naming, and meters token usage from either response
 * format for accounting.
 */
export class VercelAIGatewayProvider extends BaseProvider {
  // OpenAI-compatible surface (chat completions etc.).
  private readonly VERCEL_AI_GATEWAY_OPENAI_BASE_URL =
    'https://ai-gateway.vercel.sh/v1';
  // AI SDK protocol surface.
  private readonly VERCEL_AI_GATEWAY_AI_SDK_BASE_URL =
    'https://ai-gateway.vercel.sh/v1/ai';

  getType(): ProviderType {
    return ProviderType.VERCEL_AI_GATEWAY;
  }

  /**
   * Chooses the upstream base URL from the incoming request path:
   * AI SDK routes (language-model / embedding-model / config) go to the
   * /v1/ai surface, everything else to the OpenAI-compatible surface.
   */
  getBaseUrl(reqPath?: string): string {
    if (
      reqPath?.includes('/language-model') ||
      reqPath?.includes('/embedding-model') ||
      reqPath?.endsWith('/config')
    ) {
      return this.VERCEL_AI_GATEWAY_AI_SDK_BASE_URL;
    }

    return this.VERCEL_AI_GATEWAY_OPENAI_BASE_URL;
  }

  /**
   * Builds the full upstream URL, preserving the original query string.
   * AI SDK routes are pinned to their canonical /v1/ai endpoints (any Echo
   * path prefix before the route suffix is discarded); all other paths are
   * appended verbatim to the base URL chosen by getBaseUrl.
   */
  override formatUpstreamUrl(req: { path: string; url: string }): string {
    const query = req.url.includes('?')
      ? req.url.substring(req.url.indexOf('?'))
      : '';

    if (req.path.endsWith('/language-model')) {
      return `${this.VERCEL_AI_GATEWAY_AI_SDK_BASE_URL}/language-model${query}`;
    }
    if (req.path.endsWith('/embedding-model')) {
      return `${this.VERCEL_AI_GATEWAY_AI_SDK_BASE_URL}/embedding-model${query}`;
    }
    if (req.path.endsWith('/config')) {
      return `${this.VERCEL_AI_GATEWAY_AI_SDK_BASE_URL}/config${query}`;
    }

    return `${this.getBaseUrl(req.path)}${req.path}${query}`;
  }

  getApiKey(): string | undefined {
    return env.AI_GATEWAY_API_KEY;
  }

  /**
   * Applies the base auth headers, then forces the model-id header: any
   * caller-supplied value is removed (both header casings, since plain
   * objects are case-sensitive unlike HTTP) and replaced with the gateway
   * form of this provider's model.
   */
  override async formatAuthHeaders(
    headers: Record<string, string>
  ): Promise<Record<string, string>> {
    const formattedHeaders = await super.formatAuthHeaders(headers);

    delete formattedHeaders['AI-Language-Model-Id'];
    delete formattedHeaders['ai-language-model-id'];
    formattedHeaders['ai-language-model-id'] = toVercelModelId(this.getModel());

    return formattedHeaders;
  }

  /**
   * Leaves AI SDK language-model requests untouched — only OpenAI-compatible
   * requests get the base-class stream-usage injection. (Presumably the AI
   * SDK stream protocol reports usage without opting in — confirm against
   * the gateway docs.)
   */
  override ensureStreamUsage(
    reqBody: Record<string, unknown>,
    reqPath: string
  ): Record<string, unknown> {
    if (reqPath.includes('/language-model')) {
      return reqBody;
    }

    return super.ensureStreamUsage(reqBody, reqPath);
  }

  /**
   * Rewrites the body's `model` field (when present as a string) to the
   * gateway's model id, stripping the Echo `vercel/` prefix.
   */
  override transformRequestBody(
    reqBody: Record<string, unknown>
  ): Record<string, unknown> {
    if (typeof reqBody.model === 'string') {
      return {
        ...reqBody,
        model: toVercelModelId(reqBody.model),
      };
    }

    return reqBody;
  }

  override supportsStream(): boolean {
    return true;
  }

  /**
   * Normalizes AI SDK camelCase usage into prompt/completion/total counts,
   * defaulting missing fields to 0 and deriving the total when absent.
   */
  private getUsageFromAIResponse(usage: AIGatewayUsage | undefined): {
    promptTokens: number;
    completionTokens: number;
    totalTokens: number;
  } {
    const promptTokens = usage?.inputTokens ?? 0;
    const completionTokens = usage?.outputTokens ?? 0;
    const totalTokens = usage?.totalTokens ?? promptTokens + completionTokens;

    return { promptTokens, completionTokens, totalTokens };
  }

  /**
   * Turns a completed response body (streaming SSE or plain JSON) into a
   * billing Transaction. Streams are first read as AI SDK parts (usage on
   * the 'finish' part); if that yields zero tokens, the data is re-parsed
   * as OpenAI-format SSE. Non-stream bodies may carry usage in either
   * naming convention. Cost is computed via getCostPerToken; rethrows on
   * any parse/pricing failure after logging.
   */
  async handleBody(data: string): Promise<Transaction> {
    try {
      let prompt_tokens = 0;
      let completion_tokens = 0;
      let total_tokens = 0;
      // 'null' string sentinel when no provider id is found in the body.
      let providerId = 'null';

      if (this.getIsStream()) {
        const chunks = parseSSEJsonObjects(data);

        for (const chunk of chunks) {
          const streamPart = chunk as AIGatewayStreamPart;

          if (streamPart.type === 'finish' && streamPart.usage) {
            const usage = this.getUsageFromAIResponse(streamPart.usage);
            prompt_tokens += usage.promptTokens;
            completion_tokens += usage.completionTokens;
            total_tokens += usage.totalTokens;
          }

          // Keep the last id seen across the stream.
          providerId = streamPart.response?.id ?? streamPart.id ?? providerId;
        }

        // Fallback: OpenAI-compatible SSE framing (no AI SDK 'finish' part).
        if (total_tokens === 0) {
          for (const chunk of parseSSEGPTFormat(data)) {
            if (chunk.usage !== null) {
              prompt_tokens += chunk.usage.prompt_tokens;
              completion_tokens += chunk.usage.completion_tokens;
              total_tokens += chunk.usage.total_tokens;
            }
            providerId = chunk.id || providerId;
          }
        }
      } else {
        const parsed = JSON.parse(data) as AIGatewayResponseBody;

        if (isAIGatewayUsage(parsed.usage)) {
          // AI SDK camelCase usage.
          const usage = this.getUsageFromAIResponse(parsed.usage);
          prompt_tokens += usage.promptTokens;
          completion_tokens += usage.completionTokens;
          total_tokens += usage.totalTokens;
          providerId = parsed.response?.id ?? parsed.id ?? 'null';
        } else if (parsed.usage) {
          // OpenAI-compatible snake_case usage.
          prompt_tokens += parsed.usage.prompt_tokens ?? 0;
          completion_tokens += parsed.usage.completion_tokens ?? 0;
          total_tokens +=
            parsed.usage.total_tokens ?? prompt_tokens + completion_tokens;
          providerId = parsed.id || 'null';
        } else {
          // No usage reported: transaction is recorded with zero tokens.
          providerId = parsed.response?.id ?? parsed.id ?? 'null';
        }
      }

      const cost = getCostPerToken(
        this.getModel(),
        prompt_tokens,
        completion_tokens
      );

      const metadata: LlmTransactionMetadata = {
        providerId,
        provider: this.getType(),
        model: this.getModel(),
        inputTokens: prompt_tokens,
        outputTokens: completion_tokens,
        totalTokens: total_tokens,
      };

      return {
        rawTransactionCost: cost,
        metadata,
        status: 'success',
      };
    } catch (error) {
      logger.error(`Error processing data: ${error}`);
      throw error;
    }
  }
}
57 changes: 54 additions & 3 deletions packages/app/server/src/services/AccountingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import {
OpenAIImageModels,
SupportedOpenAIResponseToolPricing,
SupportedModel,
TokenPricingTier,
SupportedImageModel,
SupportedVideoModel,
XAIModels,
VercelAIGatewayModels,
} from '@merit-systems/echo-typescript-sdk';

import { Decimal } from '@prisma/client/runtime/library';
Expand All @@ -30,6 +32,7 @@ export const ALL_SUPPORTED_MODELS: SupportedModel[] = [
...OpenRouterModels,
...GroqModels,
...XAIModels,
...VercelAIGatewayModels,
];

// Handle image models separately since they have different pricing structure
Expand Down Expand Up @@ -67,6 +70,8 @@ export const getModelPrice = (model: string) => {
return {
input_cost_per_token: supportedModel.input_cost_per_token,
output_cost_per_token: supportedModel.output_cost_per_token,
input_cost_per_token_tiers: supportedModel.input_cost_per_token_tiers,
output_cost_per_token_tiers: supportedModel.output_cost_per_token_tiers,
provider: supportedModel.provider,
model: supportedModel.model_id,
};
Expand Down Expand Up @@ -135,9 +140,17 @@ export const getCostPerToken = (
throw new Error(`Invalid pricing for model: ${model}`);
}

const cost = new Decimal(modelPrice.input_cost_per_token)
.mul(inputTokens)
.plus(new Decimal(modelPrice.output_cost_per_token).mul(outputTokens));
const cost = getTieredTokenCost(
inputTokens,
modelPrice.input_cost_per_token,
modelPrice.input_cost_per_token_tiers
).plus(
getTieredTokenCost(
outputTokens,
modelPrice.output_cost_per_token,
modelPrice.output_cost_per_token_tiers
)
);

if (cost.lessThan(0)) {
throw new Error(`Invalid cost for model: ${model}`);
Expand All @@ -146,6 +159,44 @@ export const getCostPerToken = (
return cost;
};

/**
 * Returns the per-token rate for a given token count. When pricing tiers
 * are provided, the first tier whose [min, max) range contains the count
 * wins (an undefined max means unbounded); with no tiers, or no matching
 * tier, the flat fallback rate applies.
 */
export function getTieredTokenRate(
  tokens: number,
  fallbackRate: number,
  tiers?: TokenPricingTier[]
): number {
  if (!tiers?.length) {
    return fallbackRate;
  }

  const matched = tiers.find(tier => {
    const aboveMin = tokens >= tier.min;
    const belowMax = tier.max === undefined || tokens < tier.max;
    return aboveMin && belowMax;
  });

  return matched?.cost ?? fallbackRate;
}

/**
 * Returns the worst-case (highest) per-token rate: the maximum of the flat
 * fallback rate and every tier's rate. With no tiers, the fallback wins.
 */
export function getMaxTokenRate(
  fallbackRate: number,
  tiers?: TokenPricingTier[]
): number {
  if (!tiers?.length) {
    return fallbackRate;
  }

  let highest = fallbackRate;
  for (const tier of tiers) {
    if (tier.cost > highest) {
      highest = tier.cost;
    }
  }
  return highest;
}

/**
 * Total cost for a token count: the (possibly tiered) per-token rate for
 * that count, multiplied by the count, as an exact Decimal. Note the single
 * matched rate applies to ALL tokens, not marginally per tier.
 */
function getTieredTokenCost(
  tokens: number,
  fallbackRate: number,
  tiers?: TokenPricingTier[]
): Decimal {
  const rate = getTieredTokenRate(tokens, fallbackRate, tiers);
  return new Decimal(rate).mul(tokens);
}

export const getImageModelCost = (
model: string,
textTokens: number,
Expand Down
Loading