From f66d006ea5f146547a108f1ba29933e8047c2969 Mon Sep 17 00:00:00 2001 From: Ec3o <2499302531@qq.com> Date: Sat, 6 Jun 2026 21:22:43 +0800 Subject: [PATCH] feat(core): propagate turn runner usage --- packages/core/package.json | 2 +- .../__tests__/session-runtime.test.ts | 1 + packages/core/src/adapters/session-runtime.ts | 2 ++ .../src/engine/__tests__/turn-runner.test.ts | 34 +++++++++++++++++++ packages/core/src/engine/turn-runner.ts | 32 +++++++++++++++++ 5 files changed, 70 insertions(+), 1 deletion(-) diff --git a/packages/core/package.json b/packages/core/package.json index a515ee0..ebfaa04 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@stello-ai/core", - "version": "0.10.0", + "version": "0.10.1", "description": "The first open-source conversation topology engine", "license": "Apache-2.0", "author": "Stello Contributors", diff --git a/packages/core/src/adapters/__tests__/session-runtime.test.ts b/packages/core/src/adapters/__tests__/session-runtime.test.ts index b2149d2..1d832c9 100644 --- a/packages/core/src/adapters/__tests__/session-runtime.test.ts +++ b/packages/core/src/adapters/__tests__/session-runtime.test.ts @@ -16,6 +16,7 @@ describe('session-runtime adapters', () => { const parsed = sessionSendResultParser.parse(raw); expect(parsed.content).toBe('done'); + expect(parsed.usage).toEqual({ promptTokens: 10, completionTokens: 5 }); expect(parsed.toolCalls).toEqual([ { id: 't1', name: 'read_file', args: { path: 'a.ts' } }, ]); diff --git a/packages/core/src/adapters/session-runtime.ts b/packages/core/src/adapters/session-runtime.ts index e873891..d6e132d 100644 --- a/packages/core/src/adapters/session-runtime.ts +++ b/packages/core/src/adapters/session-runtime.ts @@ -128,11 +128,13 @@ export const sessionSendResultParser: ToolCallParser = { name: string; args: Record; }>; + usage?: SessionCompatibleSendResult['usage']; }; return { content: parsed.content, toolCalls: parsed.toolCalls ?? [], + usage: parsed.usage, }; }, }; diff --git a/packages/core/src/engine/__tests__/turn-runner.test.ts b/packages/core/src/engine/__tests__/turn-runner.test.ts index 719e950..b3eeaf8 100644 --- a/packages/core/src/engine/__tests__/turn-runner.test.ts +++ b/packages/core/src/engine/__tests__/turn-runner.test.ts @@ -55,6 +55,40 @@ describe('TurnRunner', () => { expect(result.toolCallsExecuted).toBe(1); }); + it('聚合 tool loop 内每次 LLM 调用的 usage', async () => { + const session = { + id: 's1', + send: vi + .fn() + .mockResolvedValueOnce( + JSON.stringify({ + content: null, + toolCalls: [{ id: '1', name: 'read', args: {} }], + usage: { promptTokens: 10, completionTokens: 2 }, + }), + ) + .mockResolvedValueOnce( + JSON.stringify({ + content: 'done', + toolCalls: [], + usage: { promptTokens: 8, completionTokens: 4 }, + }), + ), + }; + const tools = { + executeTool: vi.fn().mockResolvedValue({ success: true, data: { ok: true } }), + }; + + const runner = new TurnRunner(parser); + const result = await runner.run(session, 'hello', tools); + + expect(result.usage).toEqual({ + promptTokens: 18, + completionTokens: 6, + totalTokens: 24, + }); + }); + it('多个 tool call 在同轮内并行执行,但调用顺序保持输入序', async () => { const session = { id: 's1', diff --git a/packages/core/src/engine/turn-runner.ts b/packages/core/src/engine/turn-runner.ts index e085a8d..97c36dc 100644 --- a/packages/core/src/engine/turn-runner.ts +++ b/packages/core/src/engine/turn-runner.ts @@ -80,6 +80,30 @@ export interface ParsedTurnResponse { content: string | null; /** 需要由 Engine 执行的工具调用 */ toolCalls: ToolCall[]; + /** 本次 LLM 调用的 token 用量 */ + usage?: TurnRunnerUsage; +} + +/** 单次或聚合后的 LLM token 用量 */ +export interface TurnRunnerUsage { + promptTokens?: number; + completionTokens?: number; + totalTokens?: number; +} + +function addOptionalNumbers(a: number | undefined, b: number | undefined): number | undefined { + return a === undefined && b === undefined ? undefined : (a ?? 0) + (b ?? 0); +} + +function addUsage(current: TurnRunnerUsage | undefined, next: TurnRunnerUsage | undefined): TurnRunnerUsage | undefined { + if (!next) return current; + const currentTotal = current?.totalTokens ?? ((current?.promptTokens ?? 0) + (current?.completionTokens ?? 0)); + const nextTotal = next.totalTokens ?? ((next.promptTokens ?? 0) + (next.completionTokens ?? 0)); + return { + promptTokens: addOptionalNumbers(current?.promptTokens, next.promptTokens), + completionTokens: addOptionalNumbers(current?.completionTokens, next.completionTokens), + totalTokens: currentTotal + nextTotal, + }; } /** Session 调用的运行时选项 */ @@ -160,6 +184,8 @@ export interface TurnRunnerResult { toolCallsExecuted: number; /** 原始最终响应 */ rawResponse: string; + /** 本轮内所有 LLM 调用聚合后的 token 用量 */ + usage?: TurnRunnerUsage; } /** 流式 tool loop 的执行结果 */ @@ -197,11 +223,13 @@ export class TurnRunner { let toolRoundCount = 0; let toolCallsExecuted = 0; let lastRawResponse = ''; + let usage: TurnRunnerUsage | undefined; while (true) { options.signal?.throwIfAborted(); lastRawResponse = await session.send(currentInput, { signal: options.signal }); const parsed = this.parser.parse(lastRawResponse); + usage = addUsage(usage, parsed.usage); if (parsed.toolCalls.length === 0) { return { @@ -209,6 +237,7 @@ export class TurnRunner { toolRoundCount, toolCallsExecuted, rawResponse: lastRawResponse, + usage, }; } @@ -290,6 +319,7 @@ export class TurnRunner { let lastRawResponse = await rawResult options.signal?.throwIfAborted() let parsed = this.parser.parse(lastRawResponse) + let usage = addUsage(undefined, parsed.usage) while (parsed.toolCalls.length > 0) { if (toolRoundCount >= maxToolRounds) { @@ -303,6 +333,7 @@ export class TurnRunner { options.signal?.throwIfAborted() lastRawResponse = await session.send(JSON.stringify({ toolResults }), { signal: options.signal }) parsed = this.parser.parse(lastRawResponse) + usage = addUsage(usage, parsed.usage) } return { @@ -310,6 +341,7 @@ export class TurnRunner { toolRoundCount, toolCallsExecuted, rawResponse: lastRawResponse, + usage, } } }