diff --git a/core/src/agents/functions.ts b/core/src/agents/functions.ts index 8697cdd4..b51bb0df 100644 --- a/core/src/agents/functions.ts +++ b/core/src/agents/functions.ts @@ -10,6 +10,7 @@ import {isEmpty} from 'lodash-es'; import {InvocationContext} from '../agents/invocation_context.js'; import {createEvent, Event, getFunctionCalls} from '../events/event.js'; import {mergeEventActions} from '../events/event_actions.js'; +import {AgentTool, isAgentTool} from '../tools/agent_tool.js'; import {BaseTool} from '../tools/base_tool.js'; import {ToolConfirmation} from '../tools/tool_confirmation.js'; import {randomUUID} from '../utils/env_aware_utils.js'; @@ -313,6 +314,126 @@ export async function handleFunctionCallsAsync({ }); } +/** + * Handles function calls and yields events from {@link AgentTool} sub-agents + * as they are produced, then yields the final merged function-response event + * last. Nothing further is yielded if the invocation is aborted mid-stream. + * + * Non-{@link AgentTool} calls are delegated to {@link handleFunctionCallList} + * (preserving plugin/before/after-callback semantics). {@link AgentTool} calls + * are streamed via {@link AgentTool.runAsyncWithEvents} so the parent runner + * can surface sub-agent activity in real time. The caller is expected to + * identify the terminal function-response event via + * {@link getFunctionResponses}. NOTE(review): the LlmAgent caller in this patch instead takes the last yielded event. 
+ */ +export async function* handleFunctionCallsStreamingAsync({ + invocationContext, + functionCallEvent, + toolsDict, + beforeToolCallbacks, + afterToolCallbacks, + toolConfirmationDict, +}: { + invocationContext: InvocationContext; + functionCallEvent: Event; + toolsDict: Record; + beforeToolCallbacks: SingleBeforeToolCallback[]; + afterToolCallbacks: SingleAfterToolCallback[]; + toolConfirmationDict?: Record; +}): AsyncGenerator { + const functionCalls = getFunctionCalls(functionCallEvent); + if (!functionCalls.length) { + return; + } + + const agentToolCalls: Array<{call: FunctionCall; tool: AgentTool}> = []; + const regularCalls: FunctionCall[] = []; + for (const functionCall of functionCalls) { + const tool = functionCall.name ? toolsDict[functionCall.name] : undefined; + if (tool && isAgentTool(tool)) { + agentToolCalls.push({call: functionCall, tool}); + } else { + regularCalls.push(functionCall); + } + } + + // Stream sub-agent events for each AgentTool call and build response events. NOTE(review): this path never invokes beforeToolCallbacks/afterToolCallbacks; confirm that is intended. + const agentToolResponseEvents: Event[] = []; + for (const {call, tool} of agentToolCalls) { + const toolContext = new Context({ + invocationContext, + functionCallId: call.id || undefined, + toolConfirmation: + toolConfirmationDict && call.id + ? toolConfirmationDict[call.id] + : undefined, + }); + + let lastContent: Content | undefined; + for await (const event of tool.runAsyncWithEvents({ + args: call.args ?? {}, + toolContext, + })) { + if (invocationContext.abortSignal?.aborted) { + return; + } + if (event.content) { + lastContent = event.content; + } + yield event; + } + + const toolResult = tool.buildToolResultFromContent(lastContent); + const responseValue = + typeof toolResult !== 'object' || toolResult == null + ? 
{result: toolResult} + : (toolResult as Record); + + agentToolResponseEvents.push( + createEvent({ + invocationId: invocationContext.invocationId, + author: invocationContext.agent.name, + content: createUserContent({ + functionResponse: { + id: call.id, + name: tool.name, + response: responseValue, + }, + }), + actions: toolContext.actions, + branch: invocationContext.branch, + }), + ); + } + + // Run any remaining (non-AgentTool) calls through the standard pipeline. These start only after every AgentTool call above has finished streaming. + let regularResponseEvent: Event | null = null; + if (regularCalls.length) { + regularResponseEvent = await handleFunctionCallList({ + invocationContext, + functionCalls: regularCalls, + toolsDict, + beforeToolCallbacks, + afterToolCallbacks, + toolConfirmationDict, + }); + } + + const allResponseEvents: Event[] = []; + if (regularResponseEvent) { + allResponseEvents.push(regularResponseEvent); + } + allResponseEvents.push(...agentToolResponseEvents); + + if (!allResponseEvents.length) { + return; + } + + yield allResponseEvents.length === 1 + ? allResponseEvents[0] + : mergeParallelFunctionResponseEvents(allResponseEvents); +} + /** * The underlying implementation of handleFunctionCalls, but takes a list of * function calls instead of an event. diff --git a/core/src/agents/llm_agent.ts b/core/src/agents/llm_agent.ts index 30aba287..515355c3 100644 --- a/core/src/agents/llm_agent.ts +++ b/core/src/agents/llm_agent.ts @@ -49,7 +49,7 @@ import { generateAuthEvent, generateRequestConfirmationEvent, getLongRunningFunctionCalls, - handleFunctionCallsAsync, + handleFunctionCallsStreamingAsync, populateClientFunctionCallId, } from './functions.js'; @@ -907,13 +907,30 @@ export class LlmAgent extends BaseAgent { // Call functions // TODO - b/425992518: bloated funciton input, fix. // Tool callback passed to get rid of cyclic dependency. 
- const functionResponseEvent = await handleFunctionCallsAsync({ + // Use the streaming variant so events from AgentTool sub-agents are + // surfaced live (parity with adk-python PR #3991). When the generator + // runs to completion its last event is the merged function-response + // event, so hold the latest event in a one-step buffer: every earlier + // sub-agent event (including the sub-agent's own tool responses) is + // yielded, and only the final one is captured as the response. + let functionResponseEvent: Event | null = null; + let pendingEvent: Event | null = null; + for await (const event of handleFunctionCallsStreamingAsync({ invocationContext: invocationContext, functionCallEvent: mergedEvent, toolsDict: llmRequest.toolsDict, beforeToolCallbacks: this.canonicalBeforeToolCallbacks, afterToolCallbacks: this.canonicalAfterToolCallbacks, - }); + })) { + if (invocationContext.abortSignal?.aborted) { + return; + } + if (pendingEvent) { + yield pendingEvent; + } + pendingEvent = event; + } + functionResponseEvent = pendingEvent; if (!functionResponseEvent || invocationContext.abortSignal?.aborted) { return; diff --git a/core/src/tools/agent_tool.ts b/core/src/tools/agent_tool.ts index 16edb581..0fe178be 100644 --- a/core/src/tools/agent_tool.ts +++ b/core/src/tools/agent_tool.ts @@ -116,14 +116,20 @@ export class AgentTool extends BaseTool { return declaration; } - override async runAsync({ + /** + * Sets up the Runner and Session for sub-agent execution. + * + * Shared by {@link runAsync} and {@link runAsyncWithEvents}. 
+ */ + private async setupRunnerAndSession({ args, toolContext, - }: RunAsyncToolRequest): Promise { - if (this.skipSummarization) { - toolContext.actions.skipSummarization = true; - } - + }: RunAsyncToolRequest): Promise<{ + runner: Runner; + content: Content; + sessionUserId: string; + sessionId: string; + }> { const hasInputSchema = isLlmAgent(this.agent) && this.agent.inputSchema; const content: Content = { role: 'user', @@ -158,15 +164,55 @@ export class AgentTool extends BaseTool { state: toolContext.state.toRecord(), }); + return { + runner, + content, + sessionUserId: session.userId, + sessionId: session.id, + }; + } + + /** + * Builds the tool result from the given sub-agent content; returns an empty string when it is missing or has no parts. + * + * Excludes thought parts, joins the remaining text parts with newlines, and JSON-parses the result when the agent has an output schema (no schema validation yet). + */ + buildToolResultFromContent(lastContent: Content | undefined): unknown { + if (!lastContent?.parts?.length) { + return ''; + } + const hasOutputSchema = isLlmAgent(this.agent) && this.agent.outputSchema; + const mergedText = lastContent.parts + .filter((part) => !part.thought) + .map((part) => part.text) + .filter((text) => text) + .join('\n'); + // TODO - b/425992518: In case of output schema, the output should be + // validated. Consider similar logic to one we have in Python ADK. + return hasOutputSchema ? 
JSON.parse(mergedText) : mergedText; + } + + override async runAsync({ + args, + toolContext, + }: RunAsyncToolRequest): Promise { + if (this.skipSummarization) { + toolContext.actions.skipSummarization = true; + } + + const {runner, content, sessionUserId, sessionId} = + await this.setupRunnerAndSession({args, toolContext}); + if (toolContext.abortSignal?.aborted) { return ''; } let lastEvent: Event | undefined; for await (const event of runner.runAsync({ - userId: session.userId, - sessionId: session.id, + userId: sessionUserId, + sessionId, newMessage: content, + runConfig: toolContext.invocationContext.runConfig, abortSignal: toolContext.abortSignal, })) { if (toolContext.abortSignal?.aborted) { @@ -180,20 +226,48 @@ export class AgentTool extends BaseTool { lastEvent = event; } - if (!lastEvent?.content?.parts?.length) { - return ''; + return this.buildToolResultFromContent(lastEvent?.content); + } + + /** + * Runs the wrapped agent and yields the sub-agent's events as they are + * produced, providing real-time visibility into sub-agent progress. Any stateDelta on an event is applied to toolContext.state before that event is yielded. + * + * Counterpart to {@link runAsync}; the caller is responsible for tracking + * the last content event and building the final tool result via + * {@link buildToolResultFromContent}. + */ + async *runAsyncWithEvents({ + args, + toolContext, + }: RunAsyncToolRequest): AsyncGenerator { + if (this.skipSummarization) { + toolContext.actions.skipSummarization = true; } - const hasOutputSchema = isLlmAgent(this.agent) && this.agent.outputSchema; - // Exclude thoughts from the merged text. - const mergedText = lastEvent.content.parts - .filter((part) => !part.thought) - .map((part) => part.text) - .filter((text) => text) - .join('\n'); + const {runner, content, sessionUserId, sessionId} = + await this.setupRunnerAndSession({args, toolContext}); - // TODO - b/425992518: In case of output schema, the output should be - // validated. Consider similar logic to one we have in Python ADK. - return hasOutputSchema ? 
JSON.parse(mergedText) : mergedText; + if (toolContext.abortSignal?.aborted) { + return; + } + + for await (const event of runner.runAsync({ + userId: sessionUserId, + sessionId, + newMessage: content, + runConfig: toolContext.invocationContext.runConfig, + abortSignal: toolContext.abortSignal, + })) { + if (toolContext.abortSignal?.aborted) { + return; + } + + if (event.actions.stateDelta) { + toolContext.state.update(event.actions.stateDelta); + } + + yield event; + } } }