From a397f94f3c6a03c84625192104fd130d69ea01db Mon Sep 17 00:00:00 2001 From: Ec3o <2499302531@qq.com> Date: Sat, 6 Jun 2026 23:26:32 +0800 Subject: [PATCH] feat(session): capture streaming token usage --- packages/core/package.json | 2 +- packages/session/package.json | 2 +- .../src/__tests__/openai-compatible.test.ts | 33 +++++++++++++++++++ packages/session/src/__tests__/turn.test.ts | 27 +++++++++++++++ packages/session/src/adapters/anthropic.ts | 21 ++++++++++++ .../session/src/adapters/openai-compatible.ts | 21 ++++++++---- packages/session/src/create-session.ts | 13 +++++++- packages/session/src/types/llm.ts | 7 ++++ 8 files changed, 116 insertions(+), 10 deletions(-) diff --git a/packages/core/package.json b/packages/core/package.json index ebfaa04..8e4d5d3 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@stello-ai/core", - "version": "0.10.1", + "version": "0.10.2", "description": "The first open-source conversation topology engine", "license": "Apache-2.0", "author": "Stello Contributors", diff --git a/packages/session/package.json b/packages/session/package.json index 036b561..e7222df 100644 --- a/packages/session/package.json +++ b/packages/session/package.json @@ -1,6 +1,6 @@ { "name": "@stello-ai/session", - "version": "0.8.0", + "version": "0.8.1", "description": "Session layer for Stello — conversation topology engine", "license": "Apache-2.0", "author": "Stello Contributors", diff --git a/packages/session/src/__tests__/openai-compatible.test.ts b/packages/session/src/__tests__/openai-compatible.test.ts index 5cc2c2e..a059616 100644 --- a/packages/session/src/__tests__/openai-compatible.test.ts +++ b/packages/session/src/__tests__/openai-compatible.test.ts @@ -347,6 +347,39 @@ describe('createOpenAICompatibleAdapter', () => { expect(chunks.map((chunk) => chunk.delta).join('')).toBe('上海中心大厦') }) + it('stream() 请求 provider usage 并透传 usage-only chunk', async () => { + createCompletion.mockResolvedValueOnce((async function* () { + yield { choices: [{ delta: { content: 'ok' } }] } + yield { choices: [], usage: { prompt_tokens: 10, completion_tokens: 2 } } + })()) + + const adapter = createOpenAICompatibleAdapter({ + apiKey: 'test-key', + baseURL: 'https://api.example.com/v1', + model: 'test-model', + maxContextTokens: 128_000, + }) + + if (!adapter.stream) throw new Error('adapter.stream is required') + + const chunks = [] + for await (const chunk of adapter.stream([{ role: 'user', content: 'hello' }])) { + chunks.push(chunk) + } + + expect(createCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + stream: true, + stream_options: { include_usage: true }, + }), + undefined, + ) + expect(chunks).toEqual([ + { delta: 'ok' }, + { delta: '', usage: { promptTokens: 10, completionTokens: 2 } }, + ]) + }) + it('StepFun 3.7 多模态能力不绑定固定 baseURL', async () => { const adapter = createOpenAICompatibleAdapter({ apiKey: 'test-key', diff --git a/packages/session/src/__tests__/turn.test.ts b/packages/session/src/__tests__/turn.test.ts index e8c9ebc..eb9c768 100644 --- a/packages/session/src/__tests__/turn.test.ts +++ b/packages/session/src/__tests__/turn.test.ts @@ -270,6 +270,33 @@ describe('send() 契约', () => { expect(messages).toHaveLength(2) expect(messages[1]!.content).toBe('你好,世界') }) + + it('stream() 汇总 adapter chunk usage 并透传到最终结果', async () => { + const { session } = await makeSession({ + llm: { + maxContextTokens: 1_000_000, + async complete() { + return { content: 'unused' } + }, + async *stream() { + yield { delta: '你', usage: { promptTokens: 11, completionTokens: 0 } } + yield { delta: '好' } + yield { delta: '', usage: { promptTokens: 11, completionTokens: 2 } } + }, + }, + }) + + const stream = session.stream('hello') + const chunks: string[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + const result = await stream.result + + expect(chunks).toEqual(['你', '好']) + expect(result.content).toBe('你好') + expect(result.usage).toEqual({ promptTokens: 11, completionTokens: 2 }) + }) }) describe('Session.setTools (per-session tool list mutation)', () => { diff --git a/packages/session/src/adapters/anthropic.ts b/packages/session/src/adapters/anthropic.ts index 147dc0f..60e4a45 100644 --- a/packages/session/src/adapters/anthropic.ts +++ b/packages/session/src/adapters/anthropic.ts @@ -11,6 +11,7 @@ import type { LLMAdapter, LLMResult, LLMChunk, + LLMUsage, Message, ToolCall, LLMCompleteOptions, @@ -27,6 +28,11 @@ type AnthropicProviderBlock = { content?: unknown } & Record +type AnthropicStreamUsage = { + input_tokens?: number | null + output_tokens?: number | null +} + /** Anthropic 原生协议的配置选项 */ export interface AnthropicAdapterOptions { apiKey: string @@ -115,6 +121,14 @@ function toAnthropicMessages(messages: Message[]): MessageParam[] { return result } +function mergeAnthropicUsage(current: LLMUsage | undefined, usage: AnthropicStreamUsage | undefined): LLMUsage | undefined { + if (!usage) return current + return { + promptTokens: usage.input_tokens ?? current?.promptTokens ?? 0, + completionTokens: usage.output_tokens ?? current?.completionTokens ?? 0, + } +} + /** 将 Stello tools schema 转换为 Anthropic Tool 格式 */ function toAnthropicTools( tools: NonNullable, @@ -252,6 +266,7 @@ export function createAnthropicAdapter(options: AnthropicAdapterOptions): LLMAda completeOptions?.signal ? { signal: completeOptions.signal } : undefined, ) + let usage: LLMUsage | undefined for await (const event of stream) { if (event.type === 'content_block_start') { // tool_use 块的 id 和 name 只在 start 事件里下发, @@ -285,6 +300,12 @@ export function createAnthropicAdapter(options: AnthropicAdapterOptions): LLMAda }], } } + } else if (event.type === 'message_start') { + usage = mergeAnthropicUsage(usage, event.message.usage) + if (usage) yield { delta: '', usage } + } else if (event.type === 'message_delta') { + usage = mergeAnthropicUsage(usage, event.usage) + if (usage) yield { delta: '', usage } } } }, diff --git a/packages/session/src/adapters/openai-compatible.ts b/packages/session/src/adapters/openai-compatible.ts index 8b509ed..581ef8c 100644 --- a/packages/session/src/adapters/openai-compatible.ts +++ b/packages/session/src/adapters/openai-compatible.ts @@ -92,6 +92,15 @@ function isProviderToolCall(call: RawOpenAIToolCall): boolean { return typeof call.type === 'string' && call.type !== 'function' } +function toProviderUsage(usage: ChatCompletion['usage'] | ChatCompletionChunk['usage'] | undefined) { + return usage + ? { + promptTokens: usage.prompt_tokens, + completionTokens: usage.completion_tokens, + } + : undefined +} + function toProviderToolEvent(call: RawOpenAIToolCall): ProviderToolEvent { const event: ProviderToolEvent = { ...(call.id ? { id: call.id } : {}), @@ -268,12 +277,7 @@ export function createOpenAICompatibleAdapter(options: OpenAICompatibleOptions): }] }), ...(providerToolEvents.length > 0 ? { providerToolEvents } : {}), - usage: response.usage - ? { - promptTokens: response.usage.prompt_tokens, - completionTokens: response.usage.completion_tokens, - } - : undefined, + usage: toProviderUsage(response.usage), } }, async *stream(messages: Message[], completeOptions?: LLMCompleteOptions) { @@ -282,6 +286,7 @@ export function createOpenAICompatibleAdapter(options: OpenAICompatibleOptions): ...(await buildParams(messages, completeOptions)), ...(options.extraBody ?? {}), stream: true, + stream_options: { include_usage: true }, } as Parameters[0], completeOptions?.signal ? { signal: completeOptions.signal } : undefined, ) as Stream @@ -297,18 +302,20 @@ export function createOpenAICompatibleAdapter(options: OpenAICompatibleOptions): const providerToolEvents = rawToolCalls .filter(isProviderToolCall) .map(toProviderToolEvent) + const usage = toProviderUsage(chunk.usage) const toolCallDeltas = rawToolCalls.filter((call) => !isProviderToolCall(call)).map((call) => ({ index: call.index ?? 0, id: call.id, name: call.function?.name, input: call.function?.arguments, })) - if (delta || reasoningDelta || toolCallDeltas.length > 0 || providerToolEvents.length > 0) { + if (delta || reasoningDelta || toolCallDeltas.length > 0 || providerToolEvents.length > 0 || usage) { yield { delta, ...(reasoningDelta ? { reasoningDelta } : {}), ...(toolCallDeltas.length > 0 ? { toolCallDeltas } : {}), ...(providerToolEvents.length > 0 ? { providerToolEvents } : {}), + ...(usage ? { usage } : {}), } } } diff --git a/packages/session/src/create-session.ts b/packages/session/src/create-session.ts index 269325b..c8b4578 100644 --- a/packages/session/src/create-session.ts +++ b/packages/session/src/create-session.ts @@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto' import type { Session, MessageQueryOptions, SessionInput, SessionSendOptions } from './types/session-api.js' import { SessionArchivedError } from './types/session-api.js' import type { SessionMeta, SessionMetaUpdate, ForkOptions } from './types/session.js' -import type { Message } from './types/llm.js' +import type { LLMUsage, Message } from './types/llm.js' import type { CreateSessionOptions, LoadSessionOptions, SendResult, StreamResult } from './types/functions.js' import { assembleSessionContext, buildSessionIdentityMessages, createBuiltinCompressFn, flushCompressionCache, hydrateCompressionCache, removeIncompleteToolCallGroups, type CompressionCache } from './context-utils.js' @@ -76,6 +76,14 @@ function stripMultimodalParts(records: Message[]): Message[] { }) } +function mergeUsage(current: LLMUsage | undefined, next: LLMUsage | undefined): LLMUsage | undefined { + if (!next) return current + return { + promptTokens: next.promptTokens ?? current?.promptTokens ?? 0, + completionTokens: next.completionTokens ?? current?.completionTokens ?? 0, + } +} + /** 为 toolResults continuation 组装固定上下文与历史。 */ async function assembleSessionReplayContext( sessionId: string, @@ -369,12 +377,14 @@ function buildSession( if (options.llm.stream) { let accumulated = '' let accumulatedReasoning = '' + let usage: LLMUsage | undefined const toolCallsByIndex = new Map() // adapter 在 abort 时抛 AbortError,这里直接向上传播给 result promise; // 下方 L3 写入分支不会执行(policy: drop entirely),与非流式 send() 对称。 for await (const chunk of options.llm.stream(promptMessages, { tools, signal: sendOptions?.signal })) { accumulated += chunk.delta if (chunk.reasoningDelta) accumulatedReasoning += chunk.reasoningDelta + usage = mergeUsage(usage, chunk.usage) push(chunk.delta) for (const delta of chunk.toolCallDeltas ?? []) { const current = toolCallsByIndex.get(delta.index) ?? { input: '' } @@ -393,6 +403,7 @@ function buildSession( content: accumulated, ...(accumulatedReasoning ? { reasoningContent: accumulatedReasoning } : {}), toolCalls, + ...(usage ? { usage } : {}), } } else { result = await options.llm.complete(promptMessages, { tools, signal: sendOptions?.signal }) diff --git a/packages/session/src/types/llm.ts b/packages/session/src/types/llm.ts index 72cbb08..5006bef 100644 --- a/packages/session/src/types/llm.ts +++ b/packages/session/src/types/llm.ts @@ -139,6 +139,11 @@ export interface LLMResult { } } +export interface LLMUsage { + promptTokens: number + completionTokens: number +} + /** 流式输出的单个 chunk */ export interface LLMChunk { /** 文本增量片段 */ @@ -154,6 +159,8 @@ export interface LLMChunk { }> /** Provider 内置 tool 事件 / 结果,不进入客户端 tool loop。 */ providerToolEvents?: ProviderToolEvent[] + /** token 用量统计;通常由 provider 在流结束前后的 usage-only chunk 返回。 */ + usage?: LLMUsage } /**