diff --git a/core/src/agents/llm_agent.ts b/core/src/agents/llm_agent.ts
index 30aba287..bb2d5712 100644
--- a/core/src/agents/llm_agent.ts
+++ b/core/src/agents/llm_agent.ts
@@ -6,6 +6,7 @@
 
 import {GenerateContentConfig, Schema} from '@google/genai';
 import {context, trace} from '@opentelemetry/api';
+import {FunctionTool} from '../tools/function_tool.js';
 import {z as z3} from 'zod/v3';
 import {z as z4} from 'zod/v4';
 
@@ -743,7 +744,25 @@ export class LlmAgent extends BaseAgent {
     }
     // TODO - b/425992518: check if tool preprocessors can be simplified.
     // Run pre-processors for tools.
-    for (const toolUnion of this.tools) {
+    // When an output schema is combined with tools, also register a synthetic
+    // `set_model_response` tool that carries the schema as its parameters.
+    const allTools = [...this.tools];
+    if (this.outputSchema && allTools.length > 0) {
+      const setModelResponseTool = new FunctionTool({
+        name: 'set_model_response',
+        description:
+          'Call this tool to submit your final response conforming to the output schema. Use this tool only when you have collected all the information and are ready to return the final answer.',
+        parameters: this.outputSchema,
+        execute: async (args, toolContext) => {
+          if (toolContext) {
+            toolContext.actions.skipSummarization = true;
+          }
+          return JSON.stringify(args);
+        },
+      });
+      allTools.push(setModelResponseTool);
+    }
+    for (const toolUnion of allTools) {
       const toolContext = new Context({invocationContext});
 
       // process all tools from this tool union
@@ -758,7 +777,8 @@ export class LlmAgent extends BaseAgent {
         // The allowedTools set is populated by request processors.
         return (
           !llmRequest.allowedTools ||
-          llmRequest.allowedTools.includes(tool.name)
+          llmRequest.allowedTools.includes(tool.name) ||
+          tool.name === 'set_model_response'
         );
       });
 
@@ -880,8 +900,16 @@ export class LlmAgent extends BaseAgent {
 
     if (mergedEvent.content) {
       const functionCalls = getFunctionCalls(mergedEvent);
-      if (functionCalls?.length) {
-        // TODO - b/425992518: rename topopulate if missing.
+      const setModelResponseCall = functionCalls?.find(
+        (call) => call.name === 'set_model_response',
+      );
+      if (setModelResponseCall) {
+        // `args` may be absent; fall back to `{}` so the text part is always
+        // valid JSON (JSON.stringify(undefined) yields undefined).
+        const args = setModelResponseCall.args ?? {};
+        mergedEvent.content.parts = [{text: JSON.stringify(args)}];
+        mergedEvent.actions.skipSummarization = true;
+      } else if (functionCalls?.length) {
         populateClientFunctionCallId(mergedEvent);
         // TODO - b/425992518: hacky, transaction log, simplify.
         // Long running is a property of tool in registry.
diff --git a/core/src/agents/processors/basic_llm_request_processor.ts b/core/src/agents/processors/basic_llm_request_processor.ts
index 56eb9b47..7ee785fd 100644
--- a/core/src/agents/processors/basic_llm_request_processor.ts
+++ b/core/src/agents/processors/basic_llm_request_processor.ts
@@ -30,7 +30,8 @@ export class BasicLlmRequestProcessor extends BaseLlmRequestProcessor {
     llmRequest.model = agent.canonicalModel.model;
 
     llmRequest.config = {...(agent.generateContentConfig ?? {})};
-    if (agent.outputSchema) {
+    // With tools, structured output is returned via `set_model_response`.
+    if (agent.outputSchema && (!agent.tools || agent.tools.length === 0)) {
       setOutputSchema(llmRequest, agent.outputSchema);
     }
 
diff --git a/core/src/agents/processors/instructions_llm_request_processor.ts b/core/src/agents/processors/instructions_llm_request_processor.ts
index acd381a3..aff580b4 100644
--- a/core/src/agents/processors/instructions_llm_request_processor.ts
+++ b/core/src/agents/processors/instructions_llm_request_processor.ts
@@ -61,6 +61,13 @@ export class InstructionsLlmRequestProcessor extends BaseLlmRequestProcessor {
       }
       appendInstructions(llmRequest, [instructionWithState]);
     }
+
+    // Tell the model to answer through the synthetic structured-output tool.
+    if (agent.outputSchema && agent.tools && agent.tools.length > 0) {
+      appendInstructions(llmRequest, [
+        'To output the final result, you must call the "set_model_response" function with the appropriate values. Do not output anything else.',
+      ]);
+    }
   }
 }
 
diff --git a/core/test/agents/processors/basic_llm_request_processor_test.ts b/core/test/agents/processors/basic_llm_request_processor_test.ts
index 611495ba..e9c9ab74 100644
--- a/core/test/agents/processors/basic_llm_request_processor_test.ts
+++ b/core/test/agents/processors/basic_llm_request_processor_test.ts
@@ -9,6 +9,7 @@ import {
   BaseLlm,
   BaseLlmConnection,
   createSession,
+  FunctionTool,
   InvocationContext,
   LlmAgent,
   LLMRegistry,
@@ -168,6 +169,34 @@ describe('BasicLlmRequestProcessor', () => {
     expect(llmRequest.config?.responseMimeType).toBe('application/json');
   });
 
+  it('should not set outputSchema in config when agent has outputSchema and tools', async () => {
+    const outputSchema = {
+      type: 'object' as const,
+      properties: {
+        answer: {type: 'string' as const},
+      },
+    };
+    const agent = new LlmAgent({
+      name: 'test_agent',
+      model: 'test-basic-processor-model',
+      outputSchema,
+      tools: [
+        new FunctionTool({
+          name: 'some_tool',
+          description: 'A test tool',
+          execute: () => 'result',
+        }),
+      ],
+    });
+    const invocationContext = createMockInvocationContext(agent);
+    const llmRequest = makeLlmRequest();
+
+    await runProcessor(invocationContext, llmRequest);
+
+    expect(llmRequest.config?.responseSchema).toBeUndefined();
+    expect(llmRequest.config?.responseMimeType).toBeUndefined();
+  });
+
   it('should populate liveConnectConfig from runConfig', async () => {
     const agent = new LlmAgent({
       name: 'test_agent',
diff --git a/core/test/agents/processors/instructions_llm_request_processor_test.ts b/core/test/agents/processors/instructions_llm_request_processor_test.ts
index cf68189f..2f64e501 100644
--- a/core/test/agents/processors/instructions_llm_request_processor_test.ts
+++ b/core/test/agents/processors/instructions_llm_request_processor_test.ts
@@ -6,12 +6,13 @@
 
 import {
   BaseAgent,
+  createSession,
+  FunctionTool,
   InvocationContext,
   LlmAgent,
   LlmRequest,
   PluginManager,
   ReadonlyContext,
-  createSession,
 } from '@google/adk';
 import {describe, expect, it} from 'vitest';
 import {INSTRUCTIONS_LLM_REQUEST_PROCESSOR} from '../../../src/agents/processors/instructions_llm_request_processor.js';
@@ -159,4 +160,44 @@ describe('InstructionsLlmRequestProcessor', () => {
       'Global instruction\n\nLocal instruction',
     );
   });
+
+  it('should append set_model_response instruction when outputSchema and tools are present', async () => {
+    const outputSchema = {
+      type: 'object' as const,
+      properties: {
+        answer: {type: 'string' as const},
+      },
+    };
+    const agent = new LlmAgent({
+      name: 'test_agent',
+      model: 'gemini-2.5-flash',
+      instruction: 'Base instruction',
+      outputSchema,
+      tools: [
+        new FunctionTool({
+          name: 'some_tool',
+          description: 'A test tool',
+          execute: () => 'result',
+        }),
+      ],
+    });
+
+    const invocationContext = createMockInvocationContext(agent);
+    const llmRequest: LlmRequest = {
+      contents: [],
+      toolsDict: {},
+      liveConnectConfig: {},
+    };
+
+    for await (const _ of INSTRUCTIONS_LLM_REQUEST_PROCESSOR.runAsync(
+      invocationContext,
+      llmRequest,
+    )) {
+      // intentionally empty
+    }
+
+    expect(llmRequest.config?.systemInstruction).toContain(
+      'To output the final result, you must call the "set_model_response" function with the appropriate values. Do not output anything else.',
+    );
+  });
 });