Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions core/src/agents/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {isEmpty} from 'lodash-es';
import {InvocationContext} from '../agents/invocation_context.js';
import {createEvent, Event, getFunctionCalls} from '../events/event.js';
import {mergeEventActions} from '../events/event_actions.js';
import {AgentTool, isAgentTool} from '../tools/agent_tool.js';
import {BaseTool} from '../tools/base_tool.js';
import {ToolConfirmation} from '../tools/tool_confirmation.js';
import {randomUUID} from '../utils/env_aware_utils.js';
Expand Down Expand Up @@ -313,6 +314,126 @@ export async function handleFunctionCallsAsync({
});
}

/**
* Handles function calls and yields events from {@link AgentTool} sub-agents
* as they are produced, then yields the final merged function-response event
* last.
*
* Non-{@link AgentTool} calls are delegated to {@link handleFunctionCallList}
* (preserving plugin/before/after-callback semantics). {@link AgentTool} calls
* are streamed via {@link AgentTool.runAsyncWithEvents} so the parent runner
* can surface sub-agent activity in real time. The caller is expected to
* identify the terminal function-response event via
* {@link getFunctionResponses}.
*/
export async function* handleFunctionCallsStreamingAsync({
invocationContext,
functionCallEvent,
toolsDict,
beforeToolCallbacks,
afterToolCallbacks,
toolConfirmationDict,
}: {
invocationContext: InvocationContext;
functionCallEvent: Event;
toolsDict: Record<string, BaseTool>;
beforeToolCallbacks: SingleBeforeToolCallback[];
afterToolCallbacks: SingleAfterToolCallback[];
toolConfirmationDict?: Record<string, ToolConfirmation>;
}): AsyncGenerator<Event> {
const functionCalls = getFunctionCalls(functionCallEvent);
if (!functionCalls.length) {
return;
}

const agentToolCalls: Array<{call: FunctionCall; tool: AgentTool}> = [];
const regularCalls: FunctionCall[] = [];
for (const functionCall of functionCalls) {
const tool = functionCall.name ? toolsDict[functionCall.name] : undefined;
if (tool && isAgentTool(tool)) {
agentToolCalls.push({call: functionCall, tool});
} else {
regularCalls.push(functionCall);
}
}

// Stream sub-agent events for each AgentTool call and build response events.
const agentToolResponseEvents: Event[] = [];
for (const {call, tool} of agentToolCalls) {
const toolContext = new Context({
invocationContext,
functionCallId: call.id || undefined,
toolConfirmation:
toolConfirmationDict && call.id
? toolConfirmationDict[call.id]
: undefined,
});

let lastContent: Content | undefined;
for await (const event of tool.runAsyncWithEvents({
args: call.args ?? {},
toolContext,
})) {
if (invocationContext.abortSignal?.aborted) {
return;
}
if (event.content) {
lastContent = event.content;
}
yield event;
}

const toolResult = tool.buildToolResultFromContent(lastContent);
const responseValue =
typeof toolResult !== 'object' || toolResult == null
? {result: toolResult}
: (toolResult as Record<string, unknown>);

agentToolResponseEvents.push(
createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: createUserContent({
functionResponse: {
id: call.id,
name: tool.name,
response: responseValue,
},
}),
actions: toolContext.actions,
branch: invocationContext.branch,
}),
);
}

// Run any remaining (non-AgentTool) calls through the standard pipeline.
let regularResponseEvent: Event | null = null;
if (regularCalls.length) {
regularResponseEvent = await handleFunctionCallList({
invocationContext,
functionCalls: regularCalls,
toolsDict,
beforeToolCallbacks,
afterToolCallbacks,
toolConfirmationDict,
});
}

const allResponseEvents: Event[] = [];
if (regularResponseEvent) {
allResponseEvents.push(regularResponseEvent);
}
allResponseEvents.push(...agentToolResponseEvents);

if (!allResponseEvents.length) {
return;
}

yield allResponseEvents.length === 1
? allResponseEvents[0]
: mergeParallelFunctionResponseEvents(allResponseEvents);
}

/**
* The underlying implementation of handleFunctionCalls, but takes a list of
* function calls instead of an event.
Expand Down
23 changes: 20 additions & 3 deletions core/src/agents/llm_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import {
generateAuthEvent,
generateRequestConfirmationEvent,
getLongRunningFunctionCalls,
handleFunctionCallsAsync,
handleFunctionCallsStreamingAsync,
populateClientFunctionCallId,
} from './functions.js';

Expand Down Expand Up @@ -907,13 +907,30 @@ export class LlmAgent extends BaseAgent {
// Call functions
// TODO - b/425992518: bloated funciton input, fix.
// Tool callback passed to get rid of cyclic dependency.
const functionResponseEvent = await handleFunctionCallsAsync({
// Use the streaming variant so events from AgentTool sub-agents are
// surfaced live (Issue parity with adk-python PR #3991). The terminal
// event of the generator is always the merged function-response event;
// hold the latest event in a one-step buffer so we yield every
// intermediate sub-agent event (including the sub-agent's own tool
// responses) and capture only the final one as the response.
let functionResponseEvent: Event | null = null;
let pendingEvent: Event | null = null;
for await (const event of handleFunctionCallsStreamingAsync({
invocationContext: invocationContext,
functionCallEvent: mergedEvent,
toolsDict: llmRequest.toolsDict,
beforeToolCallbacks: this.canonicalBeforeToolCallbacks,
afterToolCallbacks: this.canonicalAfterToolCallbacks,
});
})) {
if (invocationContext.abortSignal?.aborted) {
return;
}
if (pendingEvent) {
yield pendingEvent;
}
pendingEvent = event;
}
functionResponseEvent = pendingEvent;

if (!functionResponseEvent || invocationContext.abortSignal?.aborted) {
return;
Expand Down
114 changes: 94 additions & 20 deletions core/src/tools/agent_tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,20 @@ export class AgentTool extends BaseTool {
return declaration;
}

override async runAsync({
/**
* Sets up the Runner and Session for sub-agent execution.
*
* Shared by {@link runAsync} and {@link runAsyncWithEvents}.
*/
private async setupRunnerAndSession({
args,
toolContext,
}: RunAsyncToolRequest): Promise<unknown> {
if (this.skipSummarization) {
toolContext.actions.skipSummarization = true;
}

}: RunAsyncToolRequest): Promise<{
runner: Runner;
content: Content;
sessionUserId: string;
sessionId: string;
}> {
const hasInputSchema = isLlmAgent(this.agent) && this.agent.inputSchema;
const content: Content = {
role: 'user',
Expand Down Expand Up @@ -158,15 +164,55 @@ export class AgentTool extends BaseTool {
state: toolContext.state.toRecord(),
});

return {
runner,
content,
sessionUserId: session.userId,
sessionId: session.id,
};
}

/**
* Builds the tool result from the last content event of the sub-agent.
*
* Excludes thought parts and applies the output schema (if any).
*/
buildToolResultFromContent(lastContent: Content | undefined): unknown {
if (!lastContent?.parts?.length) {
return '';
}
const hasOutputSchema = isLlmAgent(this.agent) && this.agent.outputSchema;
const mergedText = lastContent.parts
.filter((part) => !part.thought)
.map((part) => part.text)
.filter((text) => text)
.join('\n');
// TODO - b/425992518: In case of output schema, the output should be
// validated. Consider similar logic to one we have in Python ADK.
return hasOutputSchema ? JSON.parse(mergedText) : mergedText;
}

override async runAsync({
args,
toolContext,
}: RunAsyncToolRequest): Promise<unknown> {
if (this.skipSummarization) {
toolContext.actions.skipSummarization = true;
}

const {runner, content, sessionUserId, sessionId} =
await this.setupRunnerAndSession({args, toolContext});

if (toolContext.abortSignal?.aborted) {
return '';
}

let lastEvent: Event | undefined;
for await (const event of runner.runAsync({
userId: session.userId,
sessionId: session.id,
userId: sessionUserId,
sessionId,
newMessage: content,
runConfig: toolContext.invocationContext.runConfig,
abortSignal: toolContext.abortSignal,
})) {
if (toolContext.abortSignal?.aborted) {
Expand All @@ -180,20 +226,48 @@ export class AgentTool extends BaseTool {
lastEvent = event;
}

if (!lastEvent?.content?.parts?.length) {
return '';
return this.buildToolResultFromContent(lastEvent?.content);
}

/**
* Runs the wrapped agent and yields the sub-agent's events as they are
* produced, providing real-time visibility into sub-agent progress.
*
* Counterpart to {@link runAsync}; the caller is responsible for tracking
* the last content event and building the final tool result via
* {@link buildToolResultFromContent}.
*/
async *runAsyncWithEvents({
args,
toolContext,
}: RunAsyncToolRequest): AsyncGenerator<Event> {
if (this.skipSummarization) {
toolContext.actions.skipSummarization = true;
}

const hasOutputSchema = isLlmAgent(this.agent) && this.agent.outputSchema;
// Exclude thoughts from the merged text.
const mergedText = lastEvent.content.parts
.filter((part) => !part.thought)
.map((part) => part.text)
.filter((text) => text)
.join('\n');
const {runner, content, sessionUserId, sessionId} =
await this.setupRunnerAndSession({args, toolContext});

// TODO - b/425992518: In case of output schema, the output should be
// validated. Consider similar logic to one we have in Python ADK.
return hasOutputSchema ? JSON.parse(mergedText) : mergedText;
if (toolContext.abortSignal?.aborted) {
return;
}

for await (const event of runner.runAsync({
userId: sessionUserId,
sessionId,
newMessage: content,
runConfig: toolContext.invocationContext.runConfig,
abortSignal: toolContext.abortSignal,
})) {
if (toolContext.abortSignal?.aborted) {
return;
}

if (event.actions.stateDelta) {
toolContext.state.update(event.actions.stateDelta);
}

yield event;
}
}
}