From 204274b675a8649a4accf62c4c0632dfebf367fc Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Mon, 2 Mar 2026 19:46:58 -0500 Subject: [PATCH 1/7] Pass metadata context through tool pipeline to MCP calls (#1323) Introduce ToolCallContext as an immutable, framework-agnostic key-value bag for passing out-of-band metadata (auth tokens, tenant IDs, correlation IDs) through the tool pipeline without polluting JSON input schemas. Core API: - ToolCallContext with merge, contains, and Map conversion - Tool interface two-arg call(input, context) with backward-compatible default - DelegatingTool default propagation through decorator chains - ContextAwareFunction for lambda-based context-aware tools Implementation: - MethodTool (Kotlin + Java): inject ToolCallContext, exclude from schema - ReplanningTool, ArtifactSinkingTool: context-aware overrides - DefaultToolLoop: accept and forward context on every invocation - SpringToolCallbackAdapter/Wrapper: bridge to/from Spring AI ToolContext - ProcessOptions: toolCallContext property with Java-friendly Map overload - ToolCallContextMcpMetaConverter: allowlist/denylist filtering (not yet wired) Tests: - ToolCallContext unit, flow, and Java interop tests - MethodTool context injection and schema exclusion (Kotlin + Java) --- .../agent/api/tool/ArtifactSinkingTool.kt | 10 +- .../embabel/agent/api/tool/DelegatingTool.kt | 15 + .../com/embabel/agent/api/tool/MethodTool.kt | 36 +- .../embabel/agent/api/tool/ReplanningTools.kt | 30 +- .../kotlin/com/embabel/agent/api/tool/Tool.kt | 62 +++ .../embabel/agent/api/tool/ToolCallContext.kt | 94 ++++ .../com/embabel/agent/core/ProcessOptions.kt | 14 + .../agent/core/support/LlmInteraction.kt | 2 + .../embabel/agent/spi/loop/ToolLoopFactory.kt | 6 + .../agent/spi/loop/support/DefaultToolLoop.kt | 6 +- .../spi/loop/support/ParallelToolLoop.kt | 3 + .../spi/support/ObservabilityToolCallback.kt | 13 +- .../support/OutputTransformingToolCallback.kt | 11 +- .../agent/spi/support/ToolDecorators.kt | 72 ++- .../spi/support/ToolLoopLlmOperations.kt | 27 ++ .../springai/SpringToolCallbackAdapter.kt | 36 +- .../mcp/ToolCallContextMcpMetaConverter.kt | 132 ++++++ .../MethodToolContextInjectionJavaTest.java | 301 ++++++++++++ .../api/tool/ToolCallContextJavaTest.java | 88 ++++ .../tool/MethodToolContextInjectionTest.kt | 315 +++++++++++++ .../agent/api/tool/ToolCallContextFlowTest.kt | 443 ++++++++++++++++++ .../agent/api/tool/ToolCallContextTest.kt | 145 ++++++ .../agent/spi/loop/ToolLoopFactoryTest.kt | 4 + .../ChatClientLlmOperationsGuardRailTest.kt | 1 + .../support/ChatClientLlmOperationsTest.kt | 4 + .../ChatClientLlmOperationsThinkingTest.kt | 1 + .../support/ChatClientLlmTransformerTest.kt | 3 + .../com/embabel/agent/shell/ShellCommands.kt | 78 ++- 28 files changed, 1900 insertions(+), 52 deletions(-) create mode 100644 embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ToolCallContext.kt create mode 100644 embabel-agent-api/src/main/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverter.kt create mode 100644 embabel-agent-api/src/test/java/com/embabel/agent/api/tool/MethodToolContextInjectionJavaTest.java create mode 100644 embabel-agent-api/src/test/java/com/embabel/agent/api/tool/ToolCallContextJavaTest.java create mode 100644 embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/MethodToolContextInjectionTest.kt create mode 100644 embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextFlowTest.kt create mode 100644 embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextTest.kt diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt index d8889c0fd..c1a0a8c4e 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt @@ -88,8 +88,14 @@ class ArtifactSinkingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { - val result = delegate.call(input) + override fun call(input: String): Tool.Result = + callAndSink { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callAndSink { delegate.call(input, context) } + + private inline fun callAndSink(action: () -> Tool.Result): Tool.Result { + val result = action() if (result is Tool.Result.WithArtifact) { val artifact = result.artifact diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt index 12c714138..e9f7b19fd 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt @@ -19,6 +19,12 @@ package com.embabel.agent.api.tool * Interface for tool decorators that wrap another tool. * Enables unwrapping to find the underlying tool implementation. * Thus, it is important that tool wrappers implement this interface to allow unwrapping. + * + * The default [call] (String, ToolCallContext) implementation propagates + * context through the decorator chain by delegating to + * `delegate.call(input, context)`. Decorators that add behavior + * (e.g., artifact sinking, replanning) should override this method + * to apply their logic while preserving context propagation. */ interface DelegatingTool : Tool { @@ -26,4 +32,13 @@ interface DelegatingTool : Tool { * The underlying tool being delegated to. */ val delegate: Tool + + /** + * Propagates [context] through the decorator chain. + * Decorators that override [call] (String) to add behavior should + * also override this method to apply the same behavior while + * forwarding context to [delegate]. + */ + override fun call(input: String, context: ToolCallContext): Tool.Result = + delegate.call(input, context) } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/MethodTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/MethodTool.kt index 8562f2625..a6284936a 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/MethodTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/MethodTool.kt @@ -32,6 +32,11 @@ import kotlin.reflect.jvm.javaType /** * Tool implementation that wraps a method annotated with [@LlmTool]. + * + * Supports [ToolCallContext] injection: if the annotated method declares a + * parameter of type [ToolCallContext], the framework injects the current context + * automatically — just like Spring AI injects `ToolContext` into `@Tool` methods. + * Such parameters are excluded from the JSON input schema sent to the LLM. */ internal sealed class MethodTool( protected val instance: Any, @@ -43,10 +48,16 @@ internal sealed class MethodTool( override val metadata: Tool.Metadata = Tool.Metadata(returnDirect = annotation.returnDirect) - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callWithContext(input, ToolCallContext.EMPTY) + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callWithContext(input, context) + + private fun callWithContext(input: String, context: ToolCallContext): Tool.Result { return try { val args = parseArguments(input) - val result = invokeMethod(args) + val result = invokeMethod(args, context) convertResult(result) } catch (e: Exception) { // Unwrap InvocationTargetException to get the actual cause @@ -74,7 +85,7 @@ internal sealed class MethodTool( } } - protected abstract fun invokeMethod(args: Map): Any? + protected abstract fun invokeMethod(args: Map, context: ToolCallContext): Any? private fun convertResult(result: Any?): Tool.Result { return when (result) { @@ -138,8 +149,10 @@ internal class KotlinMethodTool( override val definition: Tool.Definition by lazy { val name = annotation.name.ifEmpty { method.name } // Use victools-based schema generation for proper generic type handling + // Exclude ToolCallContext parameters — they are framework-injected, not LLM-provided val parameterInfos = method.parameters .filter { it.kind == KParameter.Kind.VALUE } + .filter { it.type.javaType != ToolCallContext::class.java } .map { param -> val paramAnnotation = param.findAnnotation() ParameterInfo( @@ -156,7 +169,7 @@ internal class KotlinMethodTool( ) } - override fun invokeMethod(args: Map): Any? { + override fun invokeMethod(args: Map, context: ToolCallContext): Any? { val params = method.parameters val callArgs = mutableMapOf() @@ -165,6 +178,12 @@ internal class KotlinMethodTool( when (param.kind) { KParameter.Kind.INSTANCE -> callArgs[param] = instance KParameter.Kind.VALUE -> { + // Inject ToolCallContext if the parameter type matches + if (param.type.javaType == ToolCallContext::class.java) { + callArgs[param] = context + continue + } + val paramAnnotation = param.findAnnotation() val paramName = param.name ?: continue val value = args[paramName] @@ -207,7 +226,9 @@ internal class JavaMethodTool( override val definition: Tool.Definition by lazy { val name = annotation.name.ifEmpty { method.name } // Use victools-based schema generation for proper generic type handling + // Exclude ToolCallContext parameters — they are framework-injected, not LLM-provided val parameterInfos = method.parameters + .filter { !ToolCallContext::class.java.isAssignableFrom(it.type) } .map { param -> val paramAnnotation = param.getAnnotation(Param::class.java) ParameterInfo( @@ -224,11 +245,16 @@ internal class JavaMethodTool( ) } - override fun invokeMethod(args: Map): Any? { + override fun invokeMethod(args: Map, context: ToolCallContext): Any? { val params = method.parameters val callArgs = arrayOfNulls(method.parameters.size) for ((index, param) in params.withIndex()) { + // Inject ToolCallContext if the parameter type matches + if (ToolCallContext::class.java.isAssignableFrom(param.type)) { + callArgs[index] = context + continue + } val value = args[param.name] if (value != null) { // Convert value to expected type if needed diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt index 1f249a7fe..4e5af31d1 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt @@ -43,9 +43,6 @@ fun interface ReplanningToolBlackboardUpdater { * - Chat routing: A routing tool classifies intent and triggers replan to switch handlers * - Discovery: A tool discovers information that requires a different plan * - * Note: This tool accesses [AgentProcess] via thread-local at call time, which is set - * by the decorator chain. - * * @param delegate The tool to wrap * @param reason Human-readable explanation of why replan is needed * @param blackboardUpdater Callback to update the blackboard before replanning. @@ -65,8 +62,14 @@ class ReplanningTool @JvmOverloads constructor( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { - val result = delegate.call(input) + override fun call(input: String): Tool.Result = + callAndReplan { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callAndReplan { delegate.call(input, context) } + + private inline fun callAndReplan(action: () -> Tool.Result): Tool.Result { + val result = action() val resultContent = result.content throw ReplanRequestedException( @@ -135,9 +138,6 @@ fun interface ReplanDecider { * Unlike [ReplanningTool] which always triggers replanning, this tool allows the [ReplanDecider] * to inspect the result and decide whether to replan. * - * Note: This tool accesses [AgentProcess] via thread-local at call time, which is set - * by the decorator chain. - * * @param delegate The tool to wrap * @param decider Decider that inspects the result context and determines whether to replan */ @@ -149,19 +149,25 @@ class ConditionalReplanningTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { - val result = delegate.call(input) + override fun call(input: String): Tool.Result = + callAndMaybeReplan { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callAndMaybeReplan { delegate.call(input, context) } + + private inline fun callAndMaybeReplan(action: () -> Tool.Result): Tool.Result { + val result = action() val agentProcess = AgentProcess.get() ?: throw IllegalStateException("No AgentProcess available for ConditionalReplanningTool") - val context = ReplanContext( + val replanContext = ReplanContext( result = result, agentProcess = agentProcess, tool = delegate, ) - val decision = decider.evaluate(context) + val decision = decider.evaluate(replanContext) if (decision != null) { throw ReplanRequestedException( reason = decision.reason, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt index 3657cc651..a13059b0a 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt @@ -51,6 +51,22 @@ interface Tool : ToolInfo { */ fun call(input: String): Result + /** + * Execute the tool with JSON input and out-of-band context. + * + * The default implementation simply delegates to [call] (String), + * discarding the context. Override this method to receive context + * explicitly (e.g., for auth tokens, tenant IDs, or correlation IDs). + * + * [DelegatingTool] provides a default that propagates context through + * decorator chains, so most decorators do not need to override this. + * + * @param input JSON string matching inputSchema + * @param context out-of-band metadata (auth tokens, tenant IDs, etc.) + * @return Result to send back to LLM + */ + fun call(input: String, context: ToolCallContext): Result = call(input) + /** * Framework-agnostic tool definition. */ @@ -259,6 +275,14 @@ interface Tool : ToolInfo { fun invoke(input: String): Result } + /** + * Functional interface for context-aware tool implementations. + * Use when the tool needs out-of-band metadata (auth tokens, tenant IDs, etc.). + */ + fun interface ContextAwareFunction { + fun invoke(input: String, context: ToolCallContext): Result + } + /** * Java-friendly functional interface for tool implementations. * Uses `handle` method name which is more idiomatic in Java than `invoke`. @@ -295,6 +319,32 @@ interface Tool : ToolInfo { function: Function, ): Tool = of(name, description, InputSchema.empty(), metadata, function) + /** + * Create a context-aware tool from a [ContextAwareFunction]. + * The function receives [ToolCallContext] explicitly at call time. + */ + fun of( + name: String, + description: String, + inputSchema: InputSchema, + metadata: Metadata = Metadata.DEFAULT, + function: ContextAwareFunction, + ): Tool = ContextAwareFunctionalTool( + definition = Definition(name, description, inputSchema), + metadata = metadata, + function = function, + ) + + /** + * Create a context-aware tool with no parameters. + */ + fun of( + name: String, + description: String, + metadata: Metadata = Metadata.DEFAULT, + function: ContextAwareFunction, + ): Tool = of(name, description, InputSchema.empty(), metadata, function) + /** * Create a tool with no parameters (Java-friendly). * This method is easier to call from Java as it uses the Handler interface. @@ -698,3 +748,15 @@ private class FunctionalTool( override fun call(input: String): Tool.Result = function.invoke(input) } + +private class ContextAwareFunctionalTool( + override val definition: Tool.Definition, + override val metadata: Tool.Metadata, + private val function: Tool.ContextAwareFunction, +) : Tool { + override fun call(input: String): Tool.Result = + function.invoke(input, ToolCallContext.EMPTY) + + override fun call(input: String, context: ToolCallContext): Tool.Result = + function.invoke(input, context) +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ToolCallContext.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ToolCallContext.kt new file mode 100644 index 000000000..13c3e0fa7 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ToolCallContext.kt @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool + +/** + * Framework-agnostic, immutable context passed to tools at call time. + * Carries out-of-band metadata such as auth tokens, tenant IDs, or + * correlation IDs without polluting the tool's JSON input schema. + * + * Context flows explicitly through the [Tool.call] two-arg overload and is + * propagated through decorator chains by [DelegatingTool]. + */ +class ToolCallContext private constructor( + private val entries: Map, +) { + + val isEmpty: Boolean get() = entries.isEmpty() + + /** + * Retrieve a value by key, cast to the expected type. + * Returns `null` when the key is absent. + */ + @Suppress("UNCHECKED_CAST") + fun get(key: String): T? = entries[key] as T? + + /** + * Retrieve a value by key, returning [default] when absent. + */ + @Suppress("UNCHECKED_CAST") + fun getOrDefault(key: String, default: T): T = + (entries[key] as T?) ?: default + + /** + * Check whether the context contains a given key. + * Supports Kotlin `in` operator: `"token" in ctx`. + */ + operator fun contains(key: String): Boolean = key in entries + + /** + * Merge this context with [other]. Values in [other] win on conflict. + */ + fun merge(other: ToolCallContext): ToolCallContext { + if (other.isEmpty) return this + if (this.isEmpty) return other + return ToolCallContext(this.entries + other.entries) + } + + /** + * Snapshot as an unmodifiable map, safe to hand to Spring AI [org.springframework.ai.chat.model.ToolContext]. + */ + fun toMap(): Map = entries.toMap() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ToolCallContext) return false + return entries == other.entries + } + + override fun hashCode(): Int = entries.hashCode() + + override fun toString(): String = "ToolCallContext($entries)" + + companion object { + + @JvmField + val EMPTY = ToolCallContext(emptyMap()) + + // ---- Factory methods ---- + + @JvmStatic + fun of(entries: Map): ToolCallContext { + if (entries.isEmpty()) return EMPTY + return ToolCallContext(entries.toMap()) + } + + fun of(vararg pairs: Pair): ToolCallContext { + if (pairs.isEmpty()) return EMPTY + return ToolCallContext(pairs.toMap()) + } + } +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/ProcessOptions.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/ProcessOptions.kt index 58b34ad0e..4e5440605 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/ProcessOptions.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/ProcessOptions.kt @@ -20,6 +20,7 @@ import com.embabel.agent.api.channel.OutputChannel import com.embabel.agent.api.common.PlannerType import com.embabel.agent.api.event.AgenticEventListener import com.embabel.agent.api.identity.User +import com.embabel.agent.api.tool.ToolCallContext /** * Control how much detail to log from LLM interactions. @@ -206,6 +207,9 @@ constructor( * @param listeners additional listeners (beyond platform event listeners) to receive events from this process. * @param outputChannel custom output channel to use for this process. * @param plannerType the type of planner to use for this process. Defaults to GOAP planner. + * @param toolCallContext out-of-band metadata (e.g., auth tokens, tenant IDs) passed to tools + * at call time. This context is propagated to all tools, including MCP tools where it bridges + * to Spring AI's ToolContext and ultimately to MCP's McpMeta. */ data class ProcessOptions @JvmOverloads constructor( val contextId: ContextId? = null, @@ -222,6 +226,7 @@ data class ProcessOptions @JvmOverloads constructor( val listeners: List = emptyList(), val outputChannel: OutputChannel = DevNullOutputChannel, val plannerType: PlannerType = PlannerType.GOAP, + val toolCallContext: ToolCallContext = ToolCallContext.EMPTY, ) { /** @@ -281,6 +286,15 @@ data class ProcessOptions @JvmOverloads constructor( fun withPlannerType(plannerType: PlannerType): ProcessOptions = this.copy(plannerType = plannerType) + fun withToolCallContext(toolCallContext: ToolCallContext): ProcessOptions = + this.copy(toolCallContext = toolCallContext) + + /** + * Java-friendly overload accepting a raw map. + */ + fun withToolCallContext(context: Map): ProcessOptions = + this.copy(toolCallContext = ToolCallContext.of(context)) + companion object { @JvmField diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/support/LlmInteraction.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/support/LlmInteraction.kt index 179ef4969..13c58df95 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/support/LlmInteraction.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/support/LlmInteraction.kt @@ -18,6 +18,7 @@ package com.embabel.agent.core.support import com.embabel.agent.api.common.ContextualPromptElement import com.embabel.agent.api.common.InteractionId import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.core.ToolConsumer import com.embabel.agent.core.ToolGroupConsumer import com.embabel.agent.core.ToolGroupRequirement @@ -131,6 +132,7 @@ data class LlmInteraction( val additionalInjectionStrategies: List = emptyList(), val inspectors: List = emptyList(), val transformers: List = emptyList(), + val toolCallContext: ToolCallContext = ToolCallContext.EMPTY, ) : LlmCall { override val name: String = id.value diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/ToolLoopFactory.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/ToolLoopFactory.kt index fbe87b410..82ecdd38c 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/ToolLoopFactory.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/ToolLoopFactory.kt @@ -17,6 +17,7 @@ package com.embabel.agent.spi.loop import com.embabel.agent.api.common.Asyncer import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.config.ToolLoopConfiguration import com.embabel.agent.api.tool.config.ToolLoopConfiguration.ToolLoopType import com.embabel.agent.spi.loop.support.DefaultToolLoop @@ -52,6 +53,7 @@ fun interface ToolLoopFactory { * @param toolDecorator optional decorator for injected tools * @param inspectors read-only observers for tool loop lifecycle events * @param transformers transformers for modifying conversation history or tool results + * @param toolCallContext context propagated to tool invocations */ fun create( llmMessageSender: LlmMessageSender, @@ -61,6 +63,7 @@ fun interface ToolLoopFactory { toolDecorator: ((Tool) -> Tool)?, inspectors: List, transformers: List, + toolCallContext: ToolCallContext, ): ToolLoop companion object { @@ -94,6 +97,7 @@ internal class ConfigurableToolLoopFactory( toolDecorator: ((Tool) -> Tool)?, inspectors: List, transformers: List, + toolCallContext: ToolCallContext, ): ToolLoop = when (config.type) { ToolLoopType.DEFAULT -> DefaultToolLoop( llmMessageSender = llmMessageSender, @@ -103,6 +107,7 @@ internal class ConfigurableToolLoopFactory( toolDecorator = toolDecorator, inspectors = inspectors, transformers = transformers, + toolCallContext = toolCallContext, ) ToolLoopType.PARALLEL -> ParallelToolLoop( llmMessageSender = llmMessageSender, @@ -114,6 +119,7 @@ internal class ConfigurableToolLoopFactory( transformers = transformers, asyncer = asyncer, parallelConfig = config.parallel, + toolCallContext = toolCallContext, ) } } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/DefaultToolLoop.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/DefaultToolLoop.kt index 3a3eae5e2..0632b7f47 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/DefaultToolLoop.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/DefaultToolLoop.kt @@ -16,6 +16,7 @@ package com.embabel.agent.spi.loop.support import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.ToolControlFlowSignal import com.embabel.agent.core.BlackboardUpdater import com.embabel.agent.core.ReplanRequestedException @@ -58,6 +59,7 @@ internal open class DefaultToolLoop( private val toolDecorator: ((Tool) -> Tool)? = null, protected val inspectors: List = emptyList(), protected val transformers: List = emptyList(), + private val toolCallContext: ToolCallContext = ToolCallContext.EMPTY, ) : ToolLoop { private val logger = LoggerFactory.getLogger(javaClass) @@ -96,7 +98,7 @@ internal open class DefaultToolLoop( val callResult = llmMessageSender.call(state.conversationHistory, state.availableTools) accumulateUsage(callResult.usage, state) -9 + /* ------------------------------------------------- * Apply afterLlmCall callbacks - START * ------------------------------------------------- */ @@ -264,7 +266,7 @@ internal open class DefaultToolLoop( toolCall: ToolCall, ): Pair { logger.debug("Executing tool: {} with input: {}", toolCall.name, toolCall.arguments) - val result = tool.call(toolCall.arguments) + val result = tool.call(toolCall.arguments, toolCallContext) val content = when (result) { is Tool.Result.Text -> result.content is Tool.Result.WithArtifact -> result.content diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/ParallelToolLoop.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/ParallelToolLoop.kt index 8e25b4457..2babe7c06 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/ParallelToolLoop.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/loop/support/ParallelToolLoop.kt @@ -17,6 +17,7 @@ package com.embabel.agent.spi.loop.support import com.embabel.agent.api.common.Asyncer import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.ToolControlFlowSignal import com.embabel.agent.api.tool.config.ToolLoopConfiguration.ParallelModeProperties import com.embabel.agent.core.BlackboardUpdater @@ -67,6 +68,7 @@ internal class ParallelToolLoop( transformers: List = emptyList(), private val asyncer: Asyncer, private val parallelConfig: ParallelModeProperties, + toolCallContext: ToolCallContext = ToolCallContext.EMPTY, ) : DefaultToolLoop( llmMessageSender = llmMessageSender, objectMapper = objectMapper, @@ -75,6 +77,7 @@ internal class ParallelToolLoop( toolDecorator = toolDecorator, inspectors = inspectors, transformers = transformers, + toolCallContext = toolCallContext, ) { private val logger = LoggerFactory.getLogger(javaClass) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ObservabilityToolCallback.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ObservabilityToolCallback.kt index 055482d92..1aa36f4b7 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ObservabilityToolCallback.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ObservabilityToolCallback.kt @@ -18,6 +18,7 @@ package com.embabel.agent.spi.support import com.embabel.common.util.loggerFor import io.micrometer.observation.Observation import io.micrometer.observation.ObservationRegistry +import org.springframework.ai.chat.model.ToolContext import org.springframework.ai.tool.ToolCallback import org.springframework.ai.tool.definition.ToolDefinition @@ -31,9 +32,15 @@ class ObservabilityToolCallback( override fun getToolDefinition(): ToolDefinition = delegate.toolDefinition - override fun call(toolInput: String): String { + override fun call(toolInput: String): String = + callWithObservation(toolInput) { delegate.call(toolInput) } + + override fun call(toolInput: String, toolContext: ToolContext?): String = + callWithObservation(toolInput) { delegate.call(toolInput, toolContext) } + + private inline fun callWithObservation(toolInput: String, action: () -> String): String { if (observationRegistry == null) { - return delegate.call(toolInput) + return action() } val currentObservation = observationRegistry.currentObservation if (currentObservation == null) { @@ -50,7 +57,7 @@ class ObservabilityToolCallback( .parentObservation(currentObservation) .start() return try { - val result = delegate.call(toolInput) + val result = action() observation.lowCardinalityKeyValue("status", "success") observation.highCardinalityKeyValue("result", result) result diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/OutputTransformingToolCallback.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/OutputTransformingToolCallback.kt index 3d62adcfd..753117f22 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/OutputTransformingToolCallback.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/OutputTransformingToolCallback.kt @@ -17,6 +17,7 @@ package com.embabel.agent.spi.support import com.embabel.common.util.StringTransformer import org.slf4j.LoggerFactory +import org.springframework.ai.chat.model.ToolContext import org.springframework.ai.tool.ToolCallback import org.springframework.ai.tool.definition.ToolDefinition @@ -32,8 +33,14 @@ class OutputTransformingToolCallback( override fun getToolDefinition(): ToolDefinition = delegate.toolDefinition - override fun call(toolInput: String): String { - val rawOutput = delegate.call(toolInput) + override fun call(toolInput: String): String = + transformOutput(toolInput) { delegate.call(toolInput) } + + override fun call(toolInput: String, toolContext: ToolContext?): String = + transformOutput(toolInput) { delegate.call(toolInput, toolContext) } + + private inline fun transformOutput(toolInput: String, action: () -> String): String { + val rawOutput = action() val transformed = outputTransformer.transform(rawOutput) logger.debug( "Tool {} called with input: {}, raw output: {}, transformed output: {}", diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt index 89467850d..a853339ea 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt @@ -18,6 +18,7 @@ package com.embabel.agent.spi.support import com.embabel.agent.api.event.ToolCallRequestEvent import com.embabel.agent.api.tool.DelegatingTool import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.ToolControlFlowSignal import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess @@ -68,9 +69,15 @@ class ObservabilityTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callWithObservation(input) { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callWithObservation(input) { delegate.call(input, context) } + + private inline fun callWithObservation(input: String, action: () -> Tool.Result): Tool.Result { if (observationRegistry == null) { - return delegate.call(input) + return action() } val currentObservation = observationRegistry.currentObservation if (currentObservation == null) { @@ -87,7 +94,7 @@ class ObservabilityTool( .parentObservation(currentObservation) .start() return try { - val result = delegate.call(input) + val result = action() observation.lowCardinalityKeyValue("status", "success") observation.highCardinalityKeyValue("result", result.content) result @@ -118,8 +125,14 @@ class OutputTransformingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { - val rawResult = delegate.call(input) + override fun call(input: String): Tool.Result = + transformOutput(input) { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + transformOutput(input) { delegate.call(input, context) } + + private inline fun transformOutput(input: String, action: () -> Tool.Result): Tool.Result { + val rawResult = action() val transformed = outputTransformer.transform(rawResult.content) logger.debug( "Tool {} called with input: {}, raw output: {}, transformed output: {}", @@ -145,13 +158,17 @@ class MetadataEnrichingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callWithMetadata(input) { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callWithMetadata(input) { delegate.call(input, context) } + + private inline fun callWithMetadata(input: String, action: () -> Tool.Result): Tool.Result { try { - return delegate.call(input) + return action() } catch (t: Throwable) { if (t is ToolControlFlowSignal) { - // ToolControlFlowSignal exceptions are not failures - they are control flow signals - // (e.g., ReplanRequestedException, UserInputRequiredException) throw t } loggerFor().warn( @@ -179,10 +196,16 @@ class EventPublishingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callWithEvents(input) { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callWithEvents(input) { delegate.call(input, context) } + + private inline fun callWithEvents(input: String, crossinline action: () -> Tool.Result): Tool.Result { val functionCallRequestEvent = ToolCallRequestEvent( agentProcess = agentProcess, - action = action, + action = this.action, llmOptions = llmOptions, tool = delegate.definition.name, toolGroupMetadata = (delegate as? MetadataEnrichingTool)?.toolGroupMetadata, @@ -194,7 +217,7 @@ class EventPublishingTool( agentProcess.processContext.onProcessEvent(functionCallRequestEvent) val (result: Result, millis) = time { try { - Result.success(delegate.call(input)) + Result.success(action()) } catch (t: Throwable) { Result.failure(t) } @@ -240,14 +263,17 @@ class ExceptionSuppressingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callSuppressing { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callSuppressing { delegate.call(input, context) } + + private inline fun callSuppressing(action: () -> Tool.Result): Tool.Result { return try { - delegate.call(input) + action() } catch (t: Throwable) { - if (t is ToolControlFlowSignal) { - // ToolControlFlowSignal must propagate - it's a control flow signal, not an error - throw t - } + if (t is ToolControlFlowSignal) throw t Tool.Result.text("WARNING: Tool '${delegate.definition.name}' failed with exception: ${t.message ?: "No message"}") } } @@ -264,11 +290,17 @@ class AgentProcessBindingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String): Tool.Result = + callWithBinding { delegate.call(input) } + + override fun call(input: String, context: ToolCallContext): Tool.Result = + callWithBinding { delegate.call(input, context) } + + private inline fun callWithBinding(action: () -> Tool.Result): Tool.Result { val previousValue = AgentProcess.get() try { AgentProcess.set(agentProcess) - return delegate.call(input) + return action() } finally { if (previousValue != null) { AgentProcess.set(previousValue) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt index 46e4b0fe1..5422f69dc 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt @@ -19,6 +19,7 @@ import com.embabel.agent.api.common.Asyncer import com.embabel.agent.api.event.LlmRequestEvent import com.embabel.agent.api.event.ToolLoopStartEvent import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.config.ToolLoopConfiguration import com.embabel.agent.core.LlmInvocation import com.embabel.agent.core.ReplanRequestedException @@ -145,6 +146,9 @@ open class ToolLoopLlmOperations( val injectedToolDecorator = createInjectedToolDecorator(llmRequestEvent, interaction) val injectionStrategy = createInjectionStrategy(interaction) + // Merge process-level and interaction-level context (interaction wins on conflict) + val effectiveContext = resolveToolCallContext(llmRequestEvent, interaction) + val toolLoop = toolLoopFactory.create( llmMessageSender = messageSender, objectMapper = objectMapper, @@ -153,6 +157,7 @@ open class ToolLoopLlmOperations( toolDecorator = injectedToolDecorator, inspectors = interaction.inspectors, transformers = interaction.transformers, + toolCallContext = effectiveContext, ) val initialMessages = buildInitialMessages(promptContributions, messages, schemaFormat) @@ -228,6 +233,9 @@ open class ToolLoopLlmOperations( ToolInjectionStrategy.DEFAULT } + // Merge process-level and interaction-level context (interaction wins on conflict) + val effectiveContext = resolveToolCallContext(llmRequestEvent, interaction) + val toolLoop = toolLoopFactory.create( llmMessageSender = messageSender, objectMapper = objectMapper, @@ -236,6 +244,7 @@ open class ToolLoopLlmOperations( toolDecorator = injectedToolDecorator, inspectors = interaction.inspectors, transformers = interaction.transformers, + toolCallContext = effectiveContext, ) // Build MaybeReturn prompt contribution @@ -572,6 +581,24 @@ open class ToolLoopLlmOperations( } } + /** + * Resolve the effective [ToolCallContext] by merging process-level context + * (from [ProcessOptions]) with interaction-level context. + * Interaction-level values win on conflict. + */ + private fun resolveToolCallContext( + llmRequestEvent: LlmRequestEvent<*>?, + interaction: LlmInteraction, + ): ToolCallContext { + val processContext = llmRequestEvent + ?.agentProcess + ?.processContext + ?.processOptions + ?.toolCallContext + ?: ToolCallContext.EMPTY + return processContext.merge(interaction.toolCallContext) + } + /** * Check if examples should be generated based on properties and interaction settings. */ diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt index 3b616315e..9828ec4c7 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt @@ -16,6 +16,7 @@ package com.embabel.agent.spi.support.springai import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import org.slf4j.LoggerFactory import org.springframework.ai.chat.model.ToolContext import org.springframework.ai.tool.ToolCallback @@ -70,12 +71,25 @@ class SpringToolCallbackAdapter( } /** - * Override to avoid Spring AI's default warning about unused ToolContext. - * Embabel manages context through [com.embabel.agent.core.AgentProcess] thread-local - * rather than Spring AI's ToolContext. + * Bridges Spring AI's ToolContext with Embabel's ToolCallContext. + * Converts any incoming Spring AI ToolContext to an Embabel ToolCallContext + * and passes it explicitly to the tool. */ override fun call(toolInput: String, toolContext: ToolContext?): String { - return call(toolInput) + val context = toolContext?.let { ToolCallContext.of(it.context) } ?: ToolCallContext.EMPTY + return try { + when (val result = tool.call(toolInput, context)) { + is Tool.Result.Text -> result.content + is Tool.Result.WithArtifact -> result.content + is Tool.Result.Error -> { + logger.warn("Tool '{}' returned error: {}", tool.definition.name, result.message) + "ERROR: ${result.message}" + } + } + } catch (e: Exception) { + logger.error("Tool '{}' threw exception: {}", tool.definition.name, e.message, e) + "ERROR: ${e.message ?: "Unknown error"}" + } } } @@ -112,7 +126,19 @@ class SpringToolCallbackWrapper( override fun call(input: String): Tool.Result { return try { - val result = callback.call(input) + Tool.Result.text(callback.call(input)) + } catch (e: Exception) { + Tool.Result.error(e.message ?: "Tool execution failed", e) + } + } + + override fun call(input: String, context: ToolCallContext): Tool.Result { + return try { + val result = if (context.isEmpty) { + callback.call(input) + } else { + callback.call(input, ToolContext(context.toMap())) + } Tool.Result.text(result) } catch (e: Exception) { Tool.Result.error(e.message ?: "Tool execution failed", e) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverter.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverter.kt new file mode 100644 index 000000000..134d69e6f --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverter.kt @@ -0,0 +1,132 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.tools.mcp + +import com.embabel.agent.api.tool.ToolCallContext + +/** + * Converts a [ToolCallContext] into MCP `_meta` metadata for outbound MCP tool calls. + * + * This is the gateway filter that controls what context entries cross the process boundary + * when Embabel calls tools on remote MCP servers. Think of it like an HTTP header filter + * on a reverse proxy: the converter decides which entries are safe and relevant to propagate + * to a third-party server, and which should stay local. + * + * ## Why This Matters + * + * A [ToolCallContext] may carry sensitive entries (API keys, auth tokens, tenant secrets) + * alongside benign metadata (tenant IDs, correlation IDs, user preferences). Without filtering, + * all entries would be sent as MCP `_meta` to every MCP server — including untrusted third-party + * servers that should never see secrets. + * + * ## Usage + * + * The converter is used by [com.embabel.agent.spi.support.springai.SpringToolCallbackWrapper] + * when bridging Embabel tools to Spring AI MCP callbacks. + * + * ### Default behavior + * + * The [passThrough] converter propagates all entries. Use this only when all MCP servers + * are trusted (e.g., internal infrastructure). + * + * ### Allowlist approach (recommended for production) + * ```kotlin + * val converter = ToolCallContextMcpMetaConverter.allowKeys("tenantId", "correlationId", "locale") + * ``` + * + * ### Denylist approach + * ```kotlin + * val converter = ToolCallContextMcpMetaConverter.denyKeys("apiKey", "secretToken", "authHeader") + * ``` + * + * ### Custom logic + * ```kotlin + * val converter = ToolCallContextMcpMetaConverter { context -> + * mapOf( + * "tenantId" to (context.get("tenantId") ?: "unknown"), + * "requestedAt" to Instant.now().toString(), + * ) + * } + * ``` + * + * ### Spring Bean (applied globally) + * ```kotlin + * @Bean + * fun toolCallContextMcpMetaConverter() = + * ToolCallContextMcpMetaConverter.allowKeys("tenantId", "correlationId") + * ``` + * + * If no bean is defined, the framework defaults to [passThrough] for backward compatibility. + * + * @see com.embabel.agent.api.tool.ToolCallContext + */ +fun interface ToolCallContextMcpMetaConverter { + + /** + * Convert a [ToolCallContext] to a metadata map suitable for MCP `_meta`. + * + * @param context The full tool call context from the current execution + * @return A filtered/transformed map of metadata to send with the MCP tool call. + * An empty map means no metadata will be attached. + */ + fun convert(context: ToolCallContext): Map + + companion object { + + /** + * Converter that propagates all context entries as MCP metadata. + * Use only when all MCP servers are trusted. + */ + @JvmStatic + fun passThrough(): ToolCallContextMcpMetaConverter = + ToolCallContextMcpMetaConverter { it.toMap() } + + /** + * Converter that suppresses all context — no metadata is sent to MCP servers. + */ + @JvmStatic + fun noOp(): ToolCallContextMcpMetaConverter = + ToolCallContextMcpMetaConverter { emptyMap() } + + /** + * Converter that only propagates entries whose keys are in the allowlist. + * This is the recommended approach for production: explicitly declare + * what crosses the boundary. + * + * @param keys The keys to allow through + */ + @JvmStatic + fun allowKeys(vararg keys: String): ToolCallContextMcpMetaConverter { + val allowed = keys.toSet() + return ToolCallContextMcpMetaConverter { context -> + context.toMap().filterKeys { it in allowed } + } + } + + /** + * Converter that propagates all entries except those whose keys match the denylist. + * + * @param keys The keys to exclude + */ + @JvmStatic + fun denyKeys(vararg keys: String): ToolCallContextMcpMetaConverter { + val denied = keys.toSet() + return ToolCallContextMcpMetaConverter { context -> + context.toMap().filterKeys { it !in denied } + } + } + } +} diff --git a/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/MethodToolContextInjectionJavaTest.java b/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/MethodToolContextInjectionJavaTest.java new file mode 100644 index 000000000..fa9603789 --- /dev/null +++ b/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/MethodToolContextInjectionJavaTest.java @@ -0,0 +1,301 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool; + +import com.embabel.agent.api.annotation.LlmTool; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests verifying that {@link MethodTool} (Java variant) correctly handles + * {@link ToolCallContext} injection into Java {@code @LlmTool}-annotated methods. + * + *

These tests exercise the {@code JavaMethodTool} path, which uses + * {@code java.lang.reflect.Method} and positional argument arrays rather + * than Kotlin's {@code KFunction} and {@code callBy}. + */ +class MethodToolContextInjectionJavaTest { + + private final ObjectMapper objectMapper = new ObjectMapper(); + + // ---- Java test fixtures ---- + + /** + * Tool with a ToolCallContext parameter alongside a regular parameter. + */ + public static class JavaContextAwareTool { + public ToolCallContext lastContext; + + @LlmTool(description = "Search with auth context") + public String search( + @LlmTool.Param(description = "Search query") String query, + ToolCallContext context + ) { + this.lastContext = context; + String token = context.get("authToken"); + return "Results for '" + query + "' with token=" + (token != null ? token : "none"); + } + } + + /** + * Tool with only a ToolCallContext parameter — no LLM-facing parameters. + */ + public static class JavaContextOnlyTool { + public ToolCallContext lastContext; + + @LlmTool(description = "Audit action") + public String audit(ToolCallContext context) { + this.lastContext = context; + String userId = context.get("userId"); + return "Audit logged for " + (userId != null ? userId : "anonymous"); + } + } + + /** + * Tool without any ToolCallContext parameter — backward compatibility. + */ + public static class JavaNoContextTool { + @LlmTool(description = "Simple greeting") + public String greet(@LlmTool.Param(description = "Name") String name) { + return "Hello, " + name + "!"; + } + } + + /** + * Tool with multiple regular parameters and ToolCallContext. + */ + public static class JavaMultiParamTool { + public ToolCallContext lastContext; + + @LlmTool(description = "Transfer funds") + public String transfer( + @LlmTool.Param(description = "Source account") String from, + @LlmTool.Param(description = "Destination account") String to, + @LlmTool.Param(description = "Amount") int amount, + ToolCallContext context + ) { + this.lastContext = context; + String tenantId = context.get("tenantId"); + return "Transferred " + amount + " from " + from + " to " + to + + " (tenant=" + (tenantId != null ? tenantId : "unknown") + ")"; + } + } + + /** + * Tool where ToolCallContext appears in the middle of the parameter list. + */ + public static class JavaContextInMiddleTool { + public ToolCallContext lastContext; + + @LlmTool(description = "Process with context in middle") + public String process( + @LlmTool.Param(description = "Input") String input, + ToolCallContext context, + @LlmTool.Param(description = "Mode") String mode + ) { + this.lastContext = context; + return input + ":" + mode; + } + } + + @Nested + class ContextInjection { + + @Test + void contextIsInjectedIntoMethodWithToolCallContextParameter() { + var instance = new JavaContextAwareTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var ctx = ToolCallContext.of(Map.of("authToken", "bearer-secret-123")); + var result = tool.call("{\"query\":\"embabel agent\"}", ctx); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals( + "Results for 'embabel agent' with token=bearer-secret-123", + ((Tool.Result.Text) result).getContent() + ); + assertNotNull(instance.lastContext); + assertEquals("bearer-secret-123", instance.lastContext.get("authToken")); + } + + @Test + void emptyContextIsInjectedWhenSingleArgCallIsUsed() { + var instance = new JavaContextAwareTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var result = tool.call("{\"query\":\"test\"}"); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals( + "Results for 'test' with token=none", + ((Tool.Result.Text) result).getContent() + ); + assertNotNull(instance.lastContext); + assertTrue(instance.lastContext.isEmpty()); + } + + @Test + void contextOnlyToolReceivesContext() { + var instance = new JavaContextOnlyTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var ctx = ToolCallContext.of(Map.of("userId", "user-42")); + var result = tool.call("{}", ctx); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals("Audit logged for user-42", ((Tool.Result.Text) result).getContent()); + assertEquals("user-42", instance.lastContext.get("userId")); + } + + @Test + void methodWithoutToolCallContextWorksNormally() { + var instance = new JavaNoContextTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var result = tool.call("{\"name\":\"Claude\"}"); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals("Hello, Claude!", ((Tool.Result.Text) result).getContent()); + } + + @Test + void methodWithoutToolCallContextIgnoresProvidedContext() { + var instance = new JavaNoContextTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var ctx = ToolCallContext.of(Map.of("authToken", "should-be-ignored")); + var result = tool.call("{\"name\":\"Claude\"}", ctx); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals("Hello, Claude!", ((Tool.Result.Text) result).getContent()); + } + + @Test + void contextWorksAlongsideMultipleRegularParameters() { + var instance = new JavaMultiParamTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var ctx = ToolCallContext.of(Map.of("tenantId", "acme-corp")); + var result = tool.call( + "{\"from\":\"checking\",\"to\":\"savings\",\"amount\":500}", + ctx + ); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals( + "Transferred 500 from checking to savings (tenant=acme-corp)", + ((Tool.Result.Text) result).getContent() + ); + assertEquals("acme-corp", instance.lastContext.get("tenantId")); + } + + @Test + void contextWorksWhenDeclaredInMiddleOfParameterList() { + var instance = new JavaContextInMiddleTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var ctx = ToolCallContext.of(Map.of("traceId", "trace-abc")); + var result = tool.call("{\"input\":\"data\",\"mode\":\"fast\"}", ctx); + + assertInstanceOf(Tool.Result.Text.class, result); + assertEquals("data:fast", ((Tool.Result.Text) result).getContent()); + assertNotNull(instance.lastContext); + assertEquals("trace-abc", instance.lastContext.get("traceId")); + } + } + + @Nested + class SchemaExclusion { + + @Test + @SuppressWarnings("unchecked") + void toolCallContextParameterIsExcludedFromInputSchema() throws Exception { + var instance = new JavaContextAwareTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var schema = tool.getDefinition().getInputSchema().toJsonSchema(); + var schemaMap = objectMapper.readValue(schema, Map.class); + var properties = (Map) schemaMap.get("properties"); + + assertTrue(properties.containsKey("query"), "Schema should include 'query' parameter"); + assertFalse(properties.containsKey("context"), "Schema must NOT include ToolCallContext parameter"); + } + + @Test + @SuppressWarnings("unchecked") + void schemaForContextOnlyToolHasNoParameters() throws Exception { + var instance = new JavaContextOnlyTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var schema = tool.getDefinition().getInputSchema().toJsonSchema(); + var schemaMap = objectMapper.readValue(schema, Map.class); + var properties = (Map) schemaMap.getOrDefault("properties", Map.of()); + + assertTrue(properties.isEmpty(), + "Schema should have no properties when only ToolCallContext is declared"); + } + + @Test + @SuppressWarnings("unchecked") + void schemaForMultiParamToolExcludesOnlyToolCallContext() throws Exception { + var instance = new JavaMultiParamTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var schema = tool.getDefinition().getInputSchema().toJsonSchema(); + var schemaMap = objectMapper.readValue(schema, Map.class); + var properties = (Map) schemaMap.get("properties"); + + assertEquals(3, properties.size(), "Should have exactly 3 parameters (from, to, amount)"); + assertTrue(properties.containsKey("from")); + assertTrue(properties.containsKey("to")); + assertTrue(properties.containsKey("amount")); + assertFalse(properties.containsKey("context"), "ToolCallContext must be excluded"); + } + + @Test + @SuppressWarnings("unchecked") + void schemaExcludesContextWhenInMiddleOfParameterList() throws Exception { + var instance = new JavaContextInMiddleTool(); + var tools = Tool.fromInstance(instance, objectMapper); + var tool = tools.get(0); + + var schema = tool.getDefinition().getInputSchema().toJsonSchema(); + var schemaMap = objectMapper.readValue(schema, Map.class); + var properties = (Map) schemaMap.get("properties"); + + assertEquals(2, properties.size(), "Should have exactly 2 parameters (input, mode)"); + assertTrue(properties.containsKey("input")); + assertTrue(properties.containsKey("mode")); + assertFalse(properties.containsKey("context"), "ToolCallContext must be excluded"); + } + } +} diff --git a/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/ToolCallContextJavaTest.java b/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/ToolCallContextJavaTest.java new file mode 100644 index 000000000..ca260328a --- /dev/null +++ b/embabel-agent-api/src/test/java/com/embabel/agent/api/tool/ToolCallContextJavaTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Verifies that {@link ToolCallContext} is idiomatic to use from Java. + */ +class ToolCallContextJavaTest { + + @Nested + class FactoryMethods { + + @Test + void createFromMap() { + var ctx = ToolCallContext.of(Map.of("token", "abc", "tenant", "acme")); + assertEquals("abc", ctx.get("token")); + assertEquals("acme", ctx.get("tenant")); + assertFalse(ctx.isEmpty()); + } + + @Test + void emptyContextIsAvailable() { + var ctx = ToolCallContext.EMPTY; + assertTrue(ctx.isEmpty()); + assertNull(ctx.get("anything")); + } + } + + @Nested + class Merge { + + @Test + void mergeGivesOtherPrecedence() { + var base = ToolCallContext.of(Map.of("key", "old", "extra", "kept")); + var override = ToolCallContext.of(Map.of("key", "new")); + var merged = base.merge(override); + assertEquals("new", merged.get("key")); + assertEquals("kept", merged.get("extra")); + } + } + + @Nested + class GetOrDefault { + + @Test + void returnsDefaultForMissingKey() { + var ctx = ToolCallContext.of(Map.of("a", 1)); + assertEquals("fallback", ctx.getOrDefault("missing", "fallback")); + } + + @Test + void returnsValueWhenPresent() { + var ctx = ToolCallContext.of(Map.of("a", 1)); + assertEquals(1, ctx.getOrDefault("a", 99)); + } + } + + @Nested + class WithProcessOptions { + + @Test + void processOptionsWitherAcceptsMap() { + var options = new com.embabel.agent.core.ProcessOptions() + .withToolCallContext(Map.of("authToken", "bearer-xyz")); + assertEquals("bearer-xyz", options.getToolCallContext().get("authToken")); + } + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/MethodToolContextInjectionTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/MethodToolContextInjectionTest.kt new file mode 100644 index 000000000..a21f5af2d --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/MethodToolContextInjectionTest.kt @@ -0,0 +1,315 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool + +import com.embabel.agent.api.annotation.LlmTool +import com.embabel.agent.api.annotation.LlmTool.Param +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +/** + * Tests verifying that [MethodTool] (both Kotlin and Java variants) correctly + * handles [ToolCallContext] injection into `@LlmTool`-annotated methods. + * + * Covers: + * - Context is injected when method declares a [ToolCallContext] parameter + * - Context parameter is excluded from the JSON input schema sent to the LLM + * - Methods without [ToolCallContext] parameter continue to work (backward compat) + * - EMPTY context is injected when no context is provided + * - Context works alongside regular parameters + */ +class MethodToolContextInjectionTest { + + private val objectMapper = jacksonObjectMapper() + + // ---- Kotlin test fixtures ---- + + class ContextAwareTools { + var lastContext: ToolCallContext? = null + + @LlmTool(description = "Search with auth context") + fun search( + @Param(description = "Search query") query: String, + context: ToolCallContext, + ): String { + lastContext = context + val token = context.get("authToken") ?: "none" + return "Results for '$query' with token=$token" + } + } + + class ContextOnlyTool { + var lastContext: ToolCallContext? = null + + @LlmTool(description = "Tool that only takes context") + fun audit(context: ToolCallContext): String { + lastContext = context + val userId = context.get("userId") ?: "anonymous" + return "Audit logged for $userId" + } + } + + class NoContextTools { + @LlmTool(description = "Simple greeting") + fun greet(@Param(description = "Name to greet") name: String): String { + return "Hello, $name!" + } + } + + class MultiParamWithContext { + var lastContext: ToolCallContext? = null + + @LlmTool(description = "Transfer funds between accounts") + fun transfer( + @Param(description = "Source account") from: String, + @Param(description = "Destination account") to: String, + @Param(description = "Amount to transfer") amount: Int, + context: ToolCallContext, + ): String { + lastContext = context + val tenantId = context.get("tenantId") ?: "unknown" + return "Transferred $amount from $from to $to (tenant=$tenantId)" + } + } + + class ContextWithOptionalParam { + var lastContext: ToolCallContext? = null + + @LlmTool(description = "Fetch data with optional format") + fun fetch( + @Param(description = "Resource ID") id: String, + @Param(description = "Response format", required = false) format: String = "json", + context: ToolCallContext, + ): String { + lastContext = context + return "Fetched $id as $format" + } + } + + @Nested + inner class KotlinContextInjection { + + @Test + fun `context is injected into method with ToolCallContext parameter`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of("authToken" to "bearer-secret-123") + val result = tool.call("""{"query":"embabel agent"}""", ctx) + + assertTrue(result is Tool.Result.Text) + assertEquals( + "Results for 'embabel agent' with token=bearer-secret-123", + (result as Tool.Result.Text).content, + ) + assertNotNull(instance.lastContext) + assertEquals("bearer-secret-123", instance.lastContext!!.get("authToken")) + } + + @Test + fun `EMPTY context is injected when single-arg call is used`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val result = tool.call("""{"query":"test"}""") + + assertTrue(result is Tool.Result.Text) + assertEquals( + "Results for 'test' with token=none", + (result as Tool.Result.Text).content, + ) + assertNotNull(instance.lastContext) + assertTrue(instance.lastContext!!.isEmpty) + } + + @Test + fun `method with only ToolCallContext parameter receives context`() { + val instance = ContextOnlyTool() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of("userId" to "user-42") + val result = tool.call("{}", ctx) + + assertTrue(result is Tool.Result.Text) + assertEquals("Audit logged for user-42", (result as Tool.Result.Text).content) + assertEquals("user-42", instance.lastContext!!.get("userId")) + } + + @Test + fun `method without ToolCallContext parameter works normally`() { + val instance = NoContextTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val result = tool.call("""{"name":"Claude"}""") + + assertTrue(result is Tool.Result.Text) + assertEquals("Hello, Claude!", (result as Tool.Result.Text).content) + } + + @Test + fun `method without ToolCallContext parameter ignores provided context`() { + val instance = NoContextTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of("authToken" to "should-be-ignored") + val result = tool.call("""{"name":"Claude"}""", ctx) + + assertTrue(result is Tool.Result.Text) + assertEquals("Hello, Claude!", (result as Tool.Result.Text).content) + } + + @Test + fun `context works alongside multiple regular parameters`() { + val instance = MultiParamWithContext() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of("tenantId" to "acme-corp") + val result = tool.call( + """{"from":"checking","to":"savings","amount":500}""", + ctx, + ) + + assertTrue(result is Tool.Result.Text) + assertEquals( + "Transferred 500 from checking to savings (tenant=acme-corp)", + (result as Tool.Result.Text).content, + ) + assertEquals("acme-corp", instance.lastContext!!.get("tenantId")) + } + + @Test + fun `context works with optional parameters using default values`() { + val instance = ContextWithOptionalParam() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of("traceId" to "trace-abc") + val result = tool.call("""{"id":"resource-1"}""", ctx) + + assertTrue(result is Tool.Result.Text) + assertEquals("Fetched resource-1 as json", (result as Tool.Result.Text).content) + assertNotNull(instance.lastContext) + } + + @Test + fun `context with multiple entries is fully available in method`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val ctx = ToolCallContext.of( + "authToken" to "bearer-xyz", + "tenantId" to "acme", + "correlationId" to "req-123", + ) + tool.call("""{"query":"test"}""", ctx) + + val captured = instance.lastContext!! + assertEquals("bearer-xyz", captured.get("authToken")) + assertEquals("acme", captured.get("tenantId")) + assertEquals("req-123", captured.get("correlationId")) + } + } + + @Nested + inner class SchemaExclusion { + + @Test + fun `ToolCallContext parameter is excluded from input schema`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val schema = tool.definition.inputSchema.toJsonSchema() + val schemaMap = objectMapper.readValue(schema, Map::class.java) + + @Suppress("UNCHECKED_CAST") + val properties = schemaMap["properties"] as Map + + assertTrue("query" in properties, "Schema should include 'query' parameter") + assertFalse("context" in properties, "Schema must NOT include 'context' (ToolCallContext) parameter") + } + + @Test + fun `schema has correct required fields excluding ToolCallContext`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val schema = tool.definition.inputSchema.toJsonSchema() + val schemaMap = objectMapper.readValue(schema, Map::class.java) + + @Suppress("UNCHECKED_CAST") + val required = schemaMap["required"] as? List ?: emptyList() + + assertTrue("query" in required, "'query' should be required") + assertFalse("context" in required, "'context' must NOT appear in required") + } + + @Test + fun `schema for context-only tool has no parameters`() { + val instance = ContextOnlyTool() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val schema = tool.definition.inputSchema.toJsonSchema() + val schemaMap = objectMapper.readValue(schema, Map::class.java) + + @Suppress("UNCHECKED_CAST") + val properties = schemaMap["properties"] as? Map ?: emptyMap() + + assertTrue(properties.isEmpty(), "Schema should have no properties when only ToolCallContext is declared") + } + + @Test + fun `schema for multi-param tool excludes only ToolCallContext`() { + val instance = MultiParamWithContext() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val schema = tool.definition.inputSchema.toJsonSchema() + val schemaMap = objectMapper.readValue(schema, Map::class.java) + + @Suppress("UNCHECKED_CAST") + val properties = schemaMap["properties"] as Map + + assertEquals(3, properties.size, "Should have exactly 3 parameters (from, to, amount)") + assertTrue("from" in properties) + assertTrue("to" in properties) + assertTrue("amount" in properties) + assertFalse("context" in properties, "ToolCallContext must be excluded") + } + + @Test + fun `parameter count matches schema parameters excluding context`() { + val instance = ContextAwareTools() + val tools = Tool.fromInstance(instance, objectMapper) + val tool = tools.single() + + val params = tool.definition.inputSchema.parameters + assertEquals(1, params.size, "Should have 1 parameter (query only, not context)") + assertEquals("query", params[0].name) + } + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextFlowTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextFlowTest.kt new file mode 100644 index 000000000..66cf254e2 --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextFlowTest.kt @@ -0,0 +1,443 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool + +import com.embabel.agent.spi.loop.LlmMessageResponse +import com.embabel.agent.spi.loop.LlmMessageSender +import com.embabel.agent.spi.loop.support.DefaultToolLoop +import com.embabel.agent.spi.support.ObservabilityToolCallback +import com.embabel.agent.spi.support.OutputTransformingToolCallback +import com.embabel.agent.spi.support.springai.SpringToolCallbackAdapter +import com.embabel.agent.spi.support.springai.SpringToolCallbackWrapper +import com.embabel.chat.AssistantMessage +import com.embabel.chat.AssistantMessageWithToolCalls +import com.embabel.chat.Message +import com.embabel.chat.ToolCall +import com.embabel.chat.UserMessage +import com.embabel.common.util.StringTransformer +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import io.micrometer.observation.ObservationRegistry +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.springframework.ai.chat.model.ToolContext +import org.springframework.ai.tool.ToolCallback +import org.springframework.ai.tool.definition.DefaultToolDefinition +import org.springframework.ai.tool.metadata.DefaultToolMetadata + +/** + * End-to-end tests verifying that [ToolCallContext] flows correctly through + * the tool execution pipeline: DefaultToolLoop → Tool → Spring AI bridges. + * + * All context is passed explicitly — no ThreadLocal is used anywhere in the pipeline. + * + * These tests validate the complete implementation of GitHub issue #1323: + * "Allow metadata to be passed to MCP calls". + */ +class ToolCallContextFlowTest { + + private val objectMapper = jacksonObjectMapper() + + @Nested + inner class DefaultToolLoopContextFlow { + + @Test + fun `tool receives context from tool loop`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of( + name = "ctx_tool", + description = "Context test", + ) { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val context = ToolCallContext.of("authToken" to "bearer-xyz", "tenantId" to "acme") + val toolLoop = DefaultToolLoop( + llmMessageSender = singleToolCallThenAnswer("ctx_tool"), + objectMapper = objectMapper, + toolCallContext = context, + ) + toolLoop.execute( + initialMessages = listOf(UserMessage("go")), + initialTools = listOf(tool), + outputParser = { it }, + ) + assertNotNull(receivedContext) + assertEquals("bearer-xyz", receivedContext!!.get("authToken")) + assertEquals("acme", receivedContext!!.get("tenantId")) + } + + @Test + fun `tool receives EMPTY context when none configured`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of( + name = "no_ctx_tool", + description = "No context test", + ) { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val toolLoop = DefaultToolLoop( + llmMessageSender = singleToolCallThenAnswer("no_ctx_tool"), + objectMapper = objectMapper, + ) + toolLoop.execute( + initialMessages = listOf(UserMessage("go")), + initialTools = listOf(tool), + outputParser = { it }, + ) + assertNotNull(receivedContext) + assertTrue(receivedContext!!.isEmpty) + } + + @Test + fun `context is available across multiple tool calls`() { + val captured = mutableListOf() + val tool = Tool.of( + name = "multi_tool", + description = "Multi call test", + ) { _, context -> + captured.add(context) + Tool.Result.text("ok") + } + val context = ToolCallContext.of("key" to "value") + val mockCaller = object : LlmMessageSender { + private var call = 0 + override fun call(messages: List, tools: List): LlmMessageResponse { + call++ + return when (call) { + 1 -> toolCallResponse("call_1", "multi_tool", "{}") + 2 -> toolCallResponse("call_2", "multi_tool", "{}") + else -> textResponse("done") + } + } + } + val toolLoop = DefaultToolLoop( + llmMessageSender = mockCaller, + objectMapper = objectMapper, + toolCallContext = context, + ) + toolLoop.execute( + initialMessages = listOf(UserMessage("go")), + initialTools = listOf(tool), + outputParser = { it }, + ) + assertEquals(2, captured.size) + captured.forEach { assertEquals("value", it.get("key")) } + } + } + + @Nested + inner class DefaultCallBehavior { + + @Test + fun `default two-arg call discards context for legacy tools`() { + var singleArgCalled = false + val legacyTool = object : Tool { + override val definition = Tool.Definition("legacy", "Legacy tool", Tool.InputSchema.empty()) + override fun call(input: String): Tool.Result { + singleArgCalled = true + return Tool.Result.text("legacy-ok") + } + } + val ctx = ToolCallContext.of("secret" to "42") + val result = legacyTool.call("{}", ctx) + assertTrue(singleArgCalled, "Single-arg call should be invoked via default delegation") + assertTrue(result is Tool.Result.Text) + assertEquals("legacy-ok", (result as Tool.Result.Text).content) + } + + @Test + fun `context-aware tools receive context explicitly`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of( + name = "aware_tool", + description = "Context-aware", + ) { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val ctx = ToolCallContext.of("authToken" to "bearer-abc") + tool.call("""{"query":"test"}""", ctx) + assertEquals("bearer-abc", receivedContext!!.get("authToken")) + } + + @Test + fun `context-aware tool receives EMPTY context via single-arg call`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of( + name = "no_ctx_tool", + description = "No context", + ) { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + // Single-arg call — context-aware factory wraps with EMPTY + tool.call("{}") + assertNotNull(receivedContext) + assertTrue(receivedContext!!.isEmpty) + } + } + + @Nested + inner class SpringToolCallbackWrapperContextBridging { + + @Test + fun `wrapper bridges ToolCallContext to Spring AI ToolContext`() { + var receivedToolContext: ToolContext? = null + val springCallback = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("mcp_tool").description("MCP tool").inputSchema("{}").build() + override fun getToolMetadata() = DefaultToolMetadata.builder().build() + override fun call(toolInput: String) = "no-context" + override fun call(toolInput: String, toolContext: ToolContext?): String { + receivedToolContext = toolContext + return "with-context" + } + } + val wrapper = SpringToolCallbackWrapper(springCallback) + val ctx = ToolCallContext.of("authToken" to "xyz", "tenantId" to "acme") + val result = wrapper.call("{}", ctx) + assertTrue(result is Tool.Result.Text) + assertEquals("with-context", (result as Tool.Result.Text).content) + assertNotNull(receivedToolContext) + assertEquals("xyz", receivedToolContext!!.context["authToken"]) + assertEquals("acme", receivedToolContext!!.context["tenantId"]) + } + + @Test + fun `wrapper calls without ToolContext when context is empty`() { + var calledWithContext = false + val springCallback = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("test").description("").inputSchema("{}").build() + override fun getToolMetadata() = DefaultToolMetadata.builder().build() + override fun call(toolInput: String): String { + calledWithContext = false + return "no-context" + } + override fun call(toolInput: String, toolContext: ToolContext?): String { + calledWithContext = true + return "with-context" + } + } + val wrapper = SpringToolCallbackWrapper(springCallback) + wrapper.call("{}", ToolCallContext.EMPTY) + assertFalse(calledWithContext) + } + + @Test + fun `single-arg call does not bridge context`() { + var receivedToolContext: ToolContext? = null + val springCallback = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("test").description("").inputSchema("{}").build() + override fun getToolMetadata() = DefaultToolMetadata.builder().build() + override fun call(toolInput: String): String { + // Single-arg — no context expected + return "single-arg" + } + override fun call(toolInput: String, toolContext: ToolContext?): String { + receivedToolContext = toolContext + return "two-arg" + } + } + val wrapper = SpringToolCallbackWrapper(springCallback) + val result = wrapper.call("{}") + assertTrue(result is Tool.Result.Text) + assertEquals("single-arg", (result as Tool.Result.Text).content) + // Two-arg variant should NOT have been called + assertNull(receivedToolContext) + } + } + + @Nested + inner class SpringToolCallbackAdapterContextBridging { + + @Test + fun `adapter passes ToolContext to two-arg tool call`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of("ctx_tool", "Context test") { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val adapter = SpringToolCallbackAdapter(tool) + val springCtx = ToolContext(mapOf("from-spring" to "spring-val")) + adapter.call("{}", springCtx) + assertNotNull(receivedContext) + assertEquals("spring-val", receivedContext!!.get("from-spring")) + } + + @Test + fun `adapter passes EMPTY when no ToolContext provided`() { + var receivedContext: ToolCallContext? = null + val tool = Tool.of("no_ctx_tool", "No context") { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val adapter = SpringToolCallbackAdapter(tool) + adapter.call("{}", null) + assertNotNull(receivedContext) + assertTrue(receivedContext!!.isEmpty) + } + } + + @Nested + inner class EndToEndMcpSimulation { + + @Test + fun `context flows from DefaultToolLoop through SpringToolCallbackWrapper to ToolCallback`() { + var mcpReceivedContext: ToolContext? = null + // Simulate an MCP ToolCallback that expects ToolContext (like McpMeta) + val mcpCallback = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("mcp_search").description("MCP search").inputSchema("{}").build() + override fun getToolMetadata() = DefaultToolMetadata.builder().build() + override fun call(toolInput: String) = "no-context" + override fun call(toolInput: String, toolContext: ToolContext?): String { + mcpReceivedContext = toolContext + return """{"results": ["found"]}""" + } + } + // Wrap as Embabel Tool (this is what SpringAiMcpToolFactory does) + val tool = SpringToolCallbackWrapper(mcpCallback) + val context = ToolCallContext.of("authToken" to "bearer-secret", "userId" to "user-42") + val toolLoop = DefaultToolLoop( + llmMessageSender = singleToolCallThenAnswer("mcp_search"), + objectMapper = objectMapper, + toolCallContext = context, + ) + val result = toolLoop.execute( + initialMessages = listOf(UserMessage("search for something")), + initialTools = listOf(tool), + outputParser = { it }, + ) + assertNotNull(mcpReceivedContext) + assertEquals("bearer-secret", mcpReceivedContext!!.context["authToken"]) + assertEquals("user-42", mcpReceivedContext!!.context["userId"]) + assertEquals("done", result.result) + } + } + + @Nested + inner class DelegatingToolContextForwarding { + + @Test + fun `renamed tool forwards context to delegate`() { + var receivedContext: ToolCallContext? = null + val inner = Tool.of("original", "test") { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val renamed = inner.withName("renamed_tool") + val ctx = ToolCallContext.of("key" to "value") + renamed.call("{}", ctx) + assertNotNull(receivedContext) + assertEquals("value", receivedContext!!.get("key")) + } + + @Test + fun `described tool forwards context to delegate`() { + var receivedContext: ToolCallContext? = null + val inner = Tool.of("original", "test") { _, context -> + receivedContext = context + Tool.Result.text("ok") + } + val described = inner.withDescription("new description") + val ctx = ToolCallContext.of("key" to "value") + described.call("{}", ctx) + assertNotNull(receivedContext) + assertEquals("value", receivedContext!!.get("key")) + } + } + + @Nested + inner class ObservabilityToolCallbackContextForwarding { + + @Test + fun `observability callback forwards ToolContext to delegate`() { + var receivedToolContext: ToolContext? = null + val delegate = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("obs_tool").description("").inputSchema("{}").build() + override fun call(toolInput: String) = "no-context" + override fun call(toolInput: String, toolContext: ToolContext?): String { + receivedToolContext = toolContext + return "observed" + } + } + val observed = ObservabilityToolCallback(delegate, ObservationRegistry.NOOP) + val ctx = ToolContext(mapOf("key" to "val")) + val result = observed.call("{}", ctx) + assertEquals("observed", result) + assertNotNull(receivedToolContext) + assertEquals("val", receivedToolContext!!.context["key"]) + } + } + + @Nested + inner class OutputTransformingToolCallbackContextForwarding { + + @Test + fun `output transforming callback forwards ToolContext to delegate`() { + var receivedToolContext: ToolContext? = null + val delegate = object : ToolCallback { + override fun getToolDefinition() = DefaultToolDefinition.builder() + .name("xform_tool").description("").inputSchema("{}").build() + override fun call(toolInput: String) = "no-context" + override fun call(toolInput: String, toolContext: ToolContext?): String { + receivedToolContext = toolContext + return "UPPER result" + } + } + val transformer = StringTransformer { it.lowercase() } + val xform = OutputTransformingToolCallback(delegate, transformer) + val ctx = ToolContext(mapOf("key" to "val")) + val result = xform.call("{}", ctx) + assertEquals("upper result", result) + assertNotNull(receivedToolContext) + assertEquals("val", receivedToolContext!!.context["key"]) + } + } + + // -- Helpers -- + + private fun singleToolCallThenAnswer(toolName: String): LlmMessageSender { + return object : LlmMessageSender { + private var called = false + override fun call(messages: List, tools: List): LlmMessageResponse { + if (!called) { + called = true + return toolCallResponse("call_1", toolName, "{}") + } + return textResponse("done") + } + } + } + + private fun toolCallResponse(id: String, name: String, arguments: String) = LlmMessageResponse( + message = AssistantMessageWithToolCalls( + content = " ", + toolCalls = listOf(ToolCall(id, name, arguments)), + ), + textContent = "", + ) + + private fun textResponse(text: String) = LlmMessageResponse( + message = AssistantMessage(text), + textContent = text, + ) +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextTest.kt new file mode 100644 index 000000000..79b85a6ed --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/ToolCallContextTest.kt @@ -0,0 +1,145 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool + +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +class ToolCallContextTest { + + @Nested + inner class BasicOperations { + + @Test + fun `EMPTY context has no entries`() { + assertTrue(ToolCallContext.EMPTY.isEmpty) + assertNull(ToolCallContext.EMPTY.get("anything")) + } + + @Test + fun `of creates context from pairs`() { + val ctx = ToolCallContext.of("key" to "value", "num" to 42) + assertEquals("value", ctx.get("key")) + assertEquals(42, ctx.get("num")) + assertFalse(ctx.isEmpty) + } + + @Test + fun `of creates context from map`() { + val ctx = ToolCallContext.of(mapOf("token" to "abc123")) + assertEquals("abc123", ctx.get("token")) + } + + @Test + fun `of empty pairs returns EMPTY singleton`() { + assertSame(ToolCallContext.EMPTY, ToolCallContext.of()) + } + + @Test + fun `of empty map returns EMPTY singleton`() { + assertSame(ToolCallContext.EMPTY, ToolCallContext.of(emptyMap())) + } + + @Test + fun `get returns null for missing key`() { + val ctx = ToolCallContext.of("a" to 1) + assertNull(ctx.get("missing")) + } + + @Test + fun `contains checks key presence`() { + val ctx = ToolCallContext.of("present" to true) + assertTrue("present" in ctx) + assertFalse("absent" in ctx) + } + + @Test + fun `getOrDefault returns default for missing key`() { + val ctx = ToolCallContext.of("a" to 1) + assertEquals("fallback", ctx.getOrDefault("missing", "fallback")) + assertEquals(1, ctx.getOrDefault("a", 99)) + } + + @Test + fun `toMap returns defensive copy`() { + val ctx = ToolCallContext.of("k" to "v") + val map = ctx.toMap() + assertEquals(mapOf("k" to "v"), map) + } + } + + @Nested + inner class MergeTests { + + @Test + fun `merge combines two contexts`() { + val a = ToolCallContext.of("x" to 1) + val b = ToolCallContext.of("y" to 2) + val merged = a.merge(b) + assertEquals(1, merged.get("x")) + assertEquals(2, merged.get("y")) + } + + @Test + fun `merge gives other precedence on conflict`() { + val base = ToolCallContext.of("key" to "old") + val override = ToolCallContext.of("key" to "new") + val merged = base.merge(override) + assertEquals("new", merged.get("key")) + } + + @Test + fun `merge with EMPTY returns original`() { + val ctx = ToolCallContext.of("a" to 1) + val merged = ctx.merge(ToolCallContext.EMPTY) + assertSame(ctx, merged) + } + + @Test + fun `EMPTY merge with other returns other`() { + val ctx = ToolCallContext.of("a" to 1) + val merged = ToolCallContext.EMPTY.merge(ctx) + assertSame(ctx, merged) + } + } + + @Nested + inner class EqualityTests { + + @Test + fun `contexts with same entries are equal`() { + val a = ToolCallContext.of("x" to 1, "y" to 2) + val b = ToolCallContext.of("x" to 1, "y" to 2) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + fun `contexts with different entries are not equal`() { + val a = ToolCallContext.of("x" to 1) + val b = ToolCallContext.of("x" to 2) + assertNotEquals(a, b) + } + + @Test + fun `toString includes entries`() { + val ctx = ToolCallContext.of("key" to "val") + assertTrue(ctx.toString().contains("key")) + assertTrue(ctx.toString().contains("val")) + } + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/loop/ToolLoopFactoryTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/loop/ToolLoopFactoryTest.kt index a4af3ba41..5e12f4ca4 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/loop/ToolLoopFactoryTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/loop/ToolLoopFactoryTest.kt @@ -15,6 +15,7 @@ */ package com.embabel.agent.spi.loop +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.api.tool.config.ToolLoopConfiguration import com.embabel.agent.api.tool.config.ToolLoopConfiguration.ToolLoopType import com.embabel.agent.spi.loop.support.DefaultToolLoop @@ -50,6 +51,7 @@ class ToolLoopFactoryTest { toolDecorator = null, inspectors = emptyList(), transformers = emptyList(), + toolCallContext = ToolCallContext.EMPTY, ) assertNotNull(toolLoop) @@ -69,6 +71,7 @@ class ToolLoopFactoryTest { toolDecorator = null, inspectors = emptyList(), transformers = emptyList(), + toolCallContext = ToolCallContext.EMPTY, ) assertNotNull(toolLoop) @@ -91,6 +94,7 @@ class ToolLoopFactoryTest { toolDecorator = null, inspectors = emptyList(), transformers = emptyList(), + toolCallContext = ToolCallContext.EMPTY, ) assertNotNull(toolLoop) diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsGuardRailTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsGuardRailTest.kt index 9ee91839d..6e8c7493b 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsGuardRailTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsGuardRailTest.kt @@ -127,6 +127,7 @@ class ChatClientLlmOperationsGuardRailTest { emptyList() ) every { mockProcessContext.platformServices.eventListener } returns ese + every { mockProcessContext.processOptions } returns ProcessOptions() val mockAgentProcess = mockk() every { mockAgentProcess.recordLlmInvocation(any()) } answers { mutableLlmInvocationHistory.invocations.add(firstArg()) diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt index 2a93c0fe9..03e9d8f25 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt @@ -21,6 +21,7 @@ import com.embabel.agent.api.tool.ToolObject import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.Blackboard import com.embabel.agent.core.ProcessContext +import com.embabel.agent.core.ProcessOptions import com.embabel.agent.core.internal.LlmOperations import com.embabel.agent.core.support.InvalidLlmReturnFormatException import com.embabel.agent.core.support.InvalidLlmReturnTypeException @@ -127,6 +128,7 @@ class ChatClientLlmOperationsTest { emptyList() ) every { mockProcessContext.platformServices.eventListener } returns ese + every { mockProcessContext.processOptions } returns ProcessOptions() val mockAgentProcess = mockk() every { mockAgentProcess.recordLlmInvocation(any()) } answers { mutableLlmInvocationHistory.invocations.add(firstArg()) @@ -698,6 +700,7 @@ class ChatClientLlmOperationsTest { emptyList() ) every { mockProcessContext.platformServices.eventListener } returns ese + every { mockProcessContext.processOptions } returns ProcessOptions() val mockAgentProcess = mockk() every { mockAgentProcess.recordLlmInvocation(any()) } answers { mutableLlmInvocationHistory.invocations.add(firstArg()) @@ -1039,6 +1042,7 @@ class ChatClientLlmOperationsTest { emptyList() ) every { mockProcessContext.platformServices.eventListener } returns ese + every { mockProcessContext.processOptions } returns ProcessOptions() val mockAgentProcess = mockk() every { mockAgentProcess.recordLlmInvocation(any()) } answers { mutableLlmInvocationHistory.invocations.add(firstArg()) diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsThinkingTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsThinkingTest.kt index d387c3ca3..d5893a0ca 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsThinkingTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsThinkingTest.kt @@ -86,6 +86,7 @@ class ChatClientLlmOperationsThinkingTest { emptyList() ) every { mockProcessContext.platformServices.eventListener } returns ese + every { mockProcessContext.processOptions } returns com.embabel.agent.core.ProcessOptions() val mockAgentProcess = mockk() every { mockAgentProcess.recordLlmInvocation(any()) } answers { mutableLlmInvocationHistory.invocations.add(firstArg()) diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmTransformerTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmTransformerTest.kt index 0886d5027..20eb5f453 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmTransformerTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmTransformerTest.kt @@ -25,6 +25,7 @@ import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.LlmInvocation import com.embabel.agent.core.LlmInvocationHistory import com.embabel.agent.core.ProcessContext +import com.embabel.agent.core.ProcessOptions import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.spi.support.springai.ChatClientLlmOperations import com.embabel.agent.spi.support.MaybeReturn @@ -158,6 +159,7 @@ class ChatClientLlmTransformerTest { val mockProcessContext = mockk() every { mockProcessContext.onProcessEvent(any()) } answers { eventListener.onProcessEvent(firstArg()) } every { mockProcessContext.platformServices } returns mockPlatformServices + every { mockProcessContext.processOptions } returns ProcessOptions() every { mockProcessContext.agentProcess } returns mockAgentProcess every { mockAgentProcess.processContext } returns mockProcessContext @@ -360,6 +362,7 @@ class ChatClientLlmTransformerTest { val mockProcessContext = mockk() every { mockProcessContext.onProcessEvent(any()) } answers { eventListener.onProcessEvent(firstArg()) } every { mockProcessContext.platformServices } returns mockPlatformServices + every { mockProcessContext.processOptions } returns ProcessOptions() every { mockProcessContext.agentProcess } returns mockAgentProcess every { mockAgentProcess.processContext } returns mockProcessContext every { mockAgentProcess.recordLlmInvocation(any()) } answers { diff --git a/embabel-agent-shell/src/main/kotlin/com/embabel/agent/shell/ShellCommands.kt b/embabel-agent-shell/src/main/kotlin/com/embabel/agent/shell/ShellCommands.kt index ff4813487..7c4a631c2 100644 --- a/embabel-agent-shell/src/main/kotlin/com/embabel/agent/shell/ShellCommands.kt +++ b/embabel-agent-shell/src/main/kotlin/com/embabel/agent/shell/ShellCommands.kt @@ -17,6 +17,7 @@ package com.embabel.agent.shell import com.embabel.agent.api.common.ToolsStats import com.embabel.agent.api.common.autonomy.* +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.core.* import com.embabel.agent.domain.io.UserInput import com.embabel.agent.shell.config.ShellProperties @@ -76,6 +77,12 @@ class ShellCommands( */ private var openMode: Boolean = false + /** + * Persistent tool call context, set via `set-context` command. + * Passed to all subsequent agent executions. + */ + private var persistentToolCallContext: ToolCallContext = ToolCallContext.EMPTY + private var defaultProcessOptions: ProcessOptions = ProcessOptions( verbosity = Verbosity( debug = false, @@ -91,6 +98,38 @@ class ShellCommands( return "Blackboard cleared" } + @ShellMethod( + value = "Set persistent tool call context as key=value pairs, passed to all tools during execution. " + + "Example: set-context tenantId=acme,apiKey=secret123", + key = ["set-context", "sc"], + ) + fun setContext( + @ShellOption( + help = "Comma-separated key=value pairs (e.g. tenantId=acme,apiKey=secret). Use 'clear' to reset.", + defaultValue = "", + ) context: String, + ): String { + if (context.isBlank() || context == "clear") { + persistentToolCallContext = ToolCallContext.EMPTY + return "Tool call context cleared".color(colorPalette.color2) + } + persistentToolCallContext = parseToolCallContext(context) + return "Tool call context set: ${persistentToolCallContext.toMap()}".color(colorPalette.color2) + } + + @ShellMethod( + value = "Show current tool call context", + key = ["show-context"], + ) + fun showContext(): String { + val ctx = persistentToolCallContext.toMap() + return if (ctx.isEmpty()) { + "Tool call context is empty" + } else { + "Tool call context: $ctx" + }.color(colorPalette.color2) + } + @ShellMethod(value = "Show recent agent process runs. This is what actually happened, not just what was planned.") fun runs(): String { val plans = agentProcesses.map { @@ -319,6 +358,12 @@ class ShellCommands( help = "show detailed planning info", defaultValue = "true", ) showPlanning: Boolean = true, + @ShellOption( + value = ["-c", "--context"], + help = "Tool call context as comma-separated key=value pairs (e.g. tenantId=acme,apiKey=secret). " + + "Merged with persistent context set via set-context; these values win on conflict.", + defaultValue = ShellOption.NULL, + ) context: String? = null, ): String { // Override any options setOptions( @@ -331,9 +376,24 @@ class ShellCommands( operationDelay = operationDelay, showPlanning = showPlanning, ) + // Merge persistent context with one-off context (one-off wins on conflict) + val effectiveContext = if (context != null) { + persistentToolCallContext.merge(parseToolCallContext(context)) + } else { + persistentToolCallContext + } + val processOptions = if (effectiveContext != ToolCallContext.EMPTY) { + logger.info( + "ToolCallContext: {}".color(colorPalette.highlight), + effectiveContext.toMap(), + ) + defaultProcessOptions.withToolCallContext(effectiveContext) + } else { + defaultProcessOptions + } return executeIntent( intent = intent, - processOptions = defaultProcessOptions, + processOptions = processOptions, ) } @@ -392,6 +452,22 @@ class ShellCommands( } + /** + * Parse a comma-separated "key=value" string into a [ToolCallContext]. + * Example input: "tenantId=acme,apiKey=secret123" + */ + private fun parseToolCallContext(input: String): ToolCallContext { + if (input.isBlank()) return ToolCallContext.EMPTY + val map = input.split(",") + .map { it.trim() } + .filter { it.contains("=") } + .associate { entry -> + val (key, value) = entry.split("=", limit = 2) + key.trim() to value.trim() + } + return ToolCallContext.of(map) + } + private fun recordAgentProcess(agentProcess: AgentProcess) { agentProcesses.add(agentProcess) blackboard = agentProcess.processContext.blackboard From 066ee54f72895756998172360e5ddd933129c29b Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Mon, 2 Mar 2026 20:28:01 -0500 Subject: [PATCH 2/7] Update User Guide --- .../reference/agent-process/page.adoc | 30 ++++ .../main/asciidoc/reference/tools/page.adoc | 155 ++++++++++++++++++ .../src/main/asciidoc/shell/commands.adoc | 46 +++++- 3 files changed, 230 insertions(+), 1 deletion(-) diff --git a/embabel-agent-docs/src/main/asciidoc/reference/agent-process/page.adoc b/embabel-agent-docs/src/main/asciidoc/reference/agent-process/page.adoc index 1962e4ca0..9c5fe39d0 100644 --- a/embabel-agent-docs/src/main/asciidoc/reference/agent-process/page.adoc +++ b/embabel-agent-docs/src/main/asciidoc/reference/agent-process/page.adoc @@ -19,6 +19,36 @@ Allows fine grained control over logging prompts, LLM returns and detailed plann * `control`: Control options, determining whether the agent should be terminated as a last resort. `EarlyTerminationPolicy` can based on an absolute number of actions or a maximum budget. * Delays: Both operations (actions) and tools can have delays. This is useful to avoid rate limiting. +* `toolCallContext`: Out-of-band metadata (e.g., auth tokens, tenant IDs, correlation IDs) passed to all tool invocations during the process. +This context propagates through the entire tool pipeline—including decorator chains and MCP tools—without being exposed to the LLM. +Set via `withToolCallContext()`: + +[tabs] +==== +Java:: ++ +[source,java] +---- +var processOptions = new ProcessOptions() + .withToolCallContext(Map.of( + "authToken", bearerToken, + "tenantId", tenantId + )); +---- + +Kotlin:: ++ +[source,kotlin] +---- +val processOptions = ProcessOptions() + .withToolCallContext(ToolCallContext.of( + "authToken" to bearerToken, + "tenantId" to tenantId, + )) +---- +==== + +See <> for how tools receive this context. [graphviz, embabel_execution_context.dot, png] diff --git a/embabel-agent-docs/src/main/asciidoc/reference/tools/page.adoc b/embabel-agent-docs/src/main/asciidoc/reference/tools/page.adoc index b61cf7595..4ce5eb9bb 100644 --- a/embabel-agent-docs/src/main/asciidoc/reference/tools/page.adoc +++ b/embabel-agent-docs/src/main/asciidoc/reference/tools/page.adoc @@ -115,6 +115,161 @@ fun bindCustomer(id: Long): String { ---- ==== +[[reference.tools__tool-call-context]] +==== Receiving Out-of-Band Context in Tools + +Tool methods often need access to infrastructure metadata—auth tokens, tenant IDs, correlation IDs—that should not be part of the LLM-facing JSON schema. +`ToolCallContext` provides this: an immutable key-value bag that flows through the tool pipeline without the LLM ever seeing it. + +Think of it like HTTP headers on a request. +The caller sets them at the boundary (a REST filter, an event handler), and every handler in the chain can read them—but the request body (what the LLM provides) is unaffected. + +===== Injecting ToolCallContext into @LlmTool Methods + +Declare a `ToolCallContext` parameter on any `@LlmTool` method. +The framework will: + +* **Inject** the current context at call time (or `ToolCallContext.EMPTY` if none was set) +* **Exclude** the parameter from the JSON schema the LLM sees + +[tabs] +==== +Java:: ++ +[source,java] +---- +public class CustomerTools { + + @LlmTool(description = "Look up customer by ID") + public String lookupCustomer( + @LlmTool.Param(description = "Customer ID") long customerId, + ToolCallContext context) { + String tenantId = context.get("tenantId"); + String authToken = context.get("authToken"); + return customerService.lookup(customerId, tenantId, authToken); + } +} +---- + +Kotlin:: ++ +[source,kotlin] +---- +class CustomerTools { + + @LlmTool(description = "Look up customer by ID") + fun lookupCustomer( + @LlmTool.Param(description = "Customer ID") customerId: Long, + context: ToolCallContext, + ): String { + val tenantId = context.get("tenantId") + val authToken = context.get("authToken") + return customerService.lookup(customerId, tenantId, authToken) + } +} +---- +==== + +The LLM sees only the `customerId` parameter. +The `ToolCallContext` parameter is invisible in the tool's schema. + +This works for both `KotlinMethodTool` and `JavaMethodTool`—the `ToolCallContext` parameter can appear at any position in the method signature. + +===== Setting Context via ProcessOptions + +Context is set at the process boundary using `ProcessOptions.withToolCallContext()`. +It then propagates to every tool invocation in the process—including MCP tools, where it bridges to Spring AI's `ToolContext`. + +[tabs] +==== +Java:: ++ +[source,java] +---- +// In a REST controller or event handler +var processOptions = new ProcessOptions() + .withToolCallContext(Map.of( + "authToken", request.getHeader("Authorization"), + "tenantId", request.getHeader("X-Tenant-Id"), + "correlationId", UUID.randomUUID().toString() + )); + +var invocation = AgentInvocation.builder(agentPlatform) + .options(processOptions) + .build(CustomerReport.class); + +CustomerReport report = invocation.invoke(customerQuery); +---- + +Kotlin:: ++ +[source,kotlin] +---- +// In a REST controller or event handler +val processOptions = ProcessOptions() + .withToolCallContext(ToolCallContext.of( + "authToken" to request.getHeader("Authorization"), + "tenantId" to request.getHeader("X-Tenant-Id"), + "correlationId" to UUID.randomUUID().toString(), + )) + +val invocation = AgentInvocation.builder(agentPlatform) + .options(processOptions) + .build() + +val report = invocation.invoke(customerQuery) +---- +==== + +===== Context Propagation Through Decorators + +`ToolCallContext` flows automatically through decorator chains. +Any tool implementing `DelegatingTool` forwards the context to its delegate by default. +Built-in decorators like `ArtifactSinkingTool` and `ReplanningTool` follow this pattern, so context reaches the underlying tool without any extra wiring. + +===== Using Context in Framework-Agnostic Tools + +For programmatically created tools, use `Tool.ContextAwareFunction` to receive context in the handler. +The `Tool.of()` factory method accepts a `ContextAwareFunction` as the last parameter: + +[tabs] +==== +Java:: ++ +[source,java] +---- +Tool tenantAwareTool = Tool.of( + "search", + "Search within tenant scope", + Tool.InputSchema.of(Tool.Parameter.string("query", "Search query")), + Tool.Metadata.DEFAULT, + (Tool.ContextAwareFunction) (input, context) -> { + String tenantId = context.get("tenantId"); + return Tool.Result.text(searchService.search(input, tenantId)); + } +); +---- + +Kotlin:: ++ +[source,kotlin] +---- +val tenantAwareTool = Tool.of( + name = "search", + description = "Search within tenant scope", + inputSchema = Tool.InputSchema.of(Tool.Parameter.string("query", "Search query")), +) { input: String, context: ToolCallContext -> + val tenantId = context.get("tenantId") + Tool.Result.text(searchService.search(input, tenantId)) +} +---- +==== + +When no context is provided, the function receives `ToolCallContext.EMPTY`. + +TIP: Context is immutable and safe to read from any thread. +If you need to pass context from a web request through to tool invocations, set it once on `ProcessOptions` and every tool in the process will receive it. + [[reference.tools__tool-groups]] ==== Tool Groups diff --git a/embabel-agent-docs/src/main/asciidoc/shell/commands.adoc b/embabel-agent-docs/src/main/asciidoc/shell/commands.adoc index a9c5cb71f..253dac2c9 100644 --- a/embabel-agent-docs/src/main/asciidoc/shell/commands.adoc +++ b/embabel-agent-docs/src/main/asciidoc/shell/commands.adoc @@ -1,2 +1,46 @@ [[shell.commands]] -=== Shell Commands \ No newline at end of file +=== Shell Commands + +==== Tool Call Context Commands + +The shell supports setting out-of-band metadata that is passed to all tools during agent execution. +This is the shell interface for <>. + +===== set-context (sc) + +Set persistent tool call context as comma-separated `key=value` pairs. +This context is passed to every subsequent `execute` / `x` invocation until cleared. + +---- +# Set persistent context +set-context tenantId=acme,authToken=bearer-xyz123 + +# Shorthand alias +sc tenantId=acme,authToken=bearer-xyz123 + +# Clear persistent context +set-context clear +---- + +===== show-context + +Display the current persistent tool call context. + +---- +shell:> show-context +Tool call context: {tenantId=acme, authToken=bearer-xyz123} +---- + +===== Per-Execution Context Override + +The `execute` command accepts a `-c` / `--context` flag for one-off context entries. +These are merged with the persistent context; per-execution entries win on conflict. + +---- +# Persistent context: tenantId=acme +# Per-execution override adds correlationId and could override tenantId +x "Find news for Alice" -c "correlationId=req-456,tenantId=beta" +---- + +In this example the effective context for that single execution is `{tenantId=beta, authToken=bearer-xyz123, correlationId=req-456}`. +The persistent context remains `{tenantId=acme, authToken=bearer-xyz123}` for future invocations. \ No newline at end of file From 5e070beeab440a8b66731167a908eb17f22dba39 Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Tue, 3 Mar 2026 00:15:51 -0500 Subject: [PATCH 3/7] More Unit Tests --- .../ToolCallContextMcpMetaConverterTest.kt | 192 +++++++++++ .../agent/shell/ShellCommandsContextTest.kt | 304 ++++++++++++++++++ 2 files changed, 496 insertions(+) create mode 100644 embabel-agent-api/src/test/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverterTest.kt create mode 100644 embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverterTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverterTest.kt new file mode 100644 index 000000000..993777f04 --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/tools/mcp/ToolCallContextMcpMetaConverterTest.kt @@ -0,0 +1,192 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.tools.mcp + +import com.embabel.agent.api.tool.ToolCallContext +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +class ToolCallContextMcpMetaConverterTest { + + private val context = ToolCallContext.of( + "authToken" to "bearer-secret", + "tenantId" to "acme", + "correlationId" to "req-123", + "apiKey" to "sk-supersecret", + ) + + @Nested + inner class PassThrough { + + @Test + fun `propagates all entries`() { + val converter = ToolCallContextMcpMetaConverter.passThrough() + val result = converter.convert(context) + assertEquals(context.toMap(), result) + } + + @Test + fun `returns empty map for EMPTY context`() { + val converter = ToolCallContextMcpMetaConverter.passThrough() + val result = converter.convert(ToolCallContext.EMPTY) + assertTrue(result.isEmpty()) + } + } + + @Nested + inner class NoOp { + + @Test + fun `suppresses all entries`() { + val converter = ToolCallContextMcpMetaConverter.noOp() + val result = converter.convert(context) + assertTrue(result.isEmpty()) + } + + @Test + fun `returns empty map for EMPTY context`() { + val converter = ToolCallContextMcpMetaConverter.noOp() + val result = converter.convert(ToolCallContext.EMPTY) + assertTrue(result.isEmpty()) + } + } + + @Nested + inner class AllowKeys { + + @Test + fun `only propagates allowed keys`() { + val converter = ToolCallContextMcpMetaConverter.allowKeys("tenantId", "correlationId") + val result = converter.convert(context) + assertEquals( + mapOf("tenantId" to "acme", "correlationId" to "req-123"), + result, + ) + } + + @Test + fun `returns empty map when no keys match`() { + val converter = ToolCallContextMcpMetaConverter.allowKeys("nonExistent") + val result = converter.convert(context) + assertTrue(result.isEmpty()) + } + + @Test + fun `returns empty map for EMPTY context`() { + val converter = ToolCallContextMcpMetaConverter.allowKeys("tenantId") + val result = converter.convert(ToolCallContext.EMPTY) + assertTrue(result.isEmpty()) + } + + @Test + fun `single key allowlist works`() { + val converter = ToolCallContextMcpMetaConverter.allowKeys("authToken") + val result = converter.convert(context) + assertEquals(mapOf("authToken" to "bearer-secret"), result) + } + } + + @Nested + inner class DenyKeys { + + @Test + fun `excludes denied keys`() { + val converter = ToolCallContextMcpMetaConverter.denyKeys("authToken", "apiKey") + val result = converter.convert(context) + assertEquals( + mapOf("tenantId" to "acme", "correlationId" to "req-123"), + result, + ) + } + + @Test + fun `propagates all when no keys match denylist`() { + val converter = ToolCallContextMcpMetaConverter.denyKeys("nonExistent") + val result = converter.convert(context) + assertEquals(context.toMap(), result) + } + + @Test + fun `returns empty map for EMPTY context`() { + val converter = ToolCallContextMcpMetaConverter.denyKeys("authToken") + val result = converter.convert(ToolCallContext.EMPTY) + assertTrue(result.isEmpty()) + } + + @Test + fun `denying all keys produces empty map`() { + val converter = ToolCallContextMcpMetaConverter.denyKeys( + "authToken", "tenantId", "correlationId", "apiKey", + ) + val result = converter.convert(context) + assertTrue(result.isEmpty()) + } + } + + @Nested + inner class CustomConverter { + + @Test + fun `lambda converter can transform entries`() { + val converter = ToolCallContextMcpMetaConverter { ctx -> + mapOf( + "tenant" to (ctx.get("tenantId") ?: "unknown"), + "hasAuth" to (ctx.get("authToken") != null).toString(), + ) + } + val result = converter.convert(context) + assertEquals( + mapOf("tenant" to "acme", "hasAuth" to "true"), + result, + ) + } + + @Test + fun `lambda converter can return empty map`() { + val converter = ToolCallContextMcpMetaConverter { emptyMap() } + val result = converter.convert(context) + assertTrue(result.isEmpty()) + } + + @Test + fun `lambda converter receives full context`() { + var receivedContext: ToolCallContext? = null + val converter = ToolCallContextMcpMetaConverter { ctx -> + receivedContext = ctx + emptyMap() + } + converter.convert(context) + assertNotNull(receivedContext) + assertEquals("bearer-secret", receivedContext!!.get("authToken")) + } + } + + @Nested + inner class AllowVsDenySymmetry { + + @Test + fun `allowKeys and denyKeys are complementary for full key set`() { + val allowResult = ToolCallContextMcpMetaConverter + .allowKeys("tenantId", "correlationId") + .convert(context) + val denyResult = ToolCallContextMcpMetaConverter + .denyKeys("authToken", "apiKey") + .convert(context) + assertEquals(allowResult, denyResult) + } + } +} diff --git a/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt b/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt new file mode 100644 index 000000000..1d09cd1d0 --- /dev/null +++ b/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt @@ -0,0 +1,304 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.shell + +import com.embabel.agent.api.common.ToolsStats +import com.embabel.agent.api.common.autonomy.Autonomy +import com.embabel.agent.api.common.autonomy.AutonomyProperties +import com.embabel.agent.api.common.autonomy.NoAgentFound +import com.embabel.agent.api.common.ranking.Rankings +import com.embabel.agent.api.tool.ToolCallContext +import com.embabel.agent.core.Agent +import com.embabel.agent.core.AgentPlatform +import com.embabel.agent.core.ProcessOptions +import com.embabel.agent.shell.config.ShellProperties +import com.embabel.agent.spi.logging.ColorPalette +import com.embabel.agent.spi.logging.LoggingPersonality +import com.embabel.common.ai.model.ModelProvider +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import io.mockk.* +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.springframework.context.ConfigurableApplicationContext +import org.springframework.core.env.ConfigurableEnvironment + +/** + * Tests for the ToolCallContext-related shell commands introduced in PR #1456: + * set-context, show-context, and the -c flag on execute. + */ +class ShellCommandsContextTest { + + private val autonomy: Autonomy = mockk(relaxed = true) + private val modelProvider: ModelProvider = mockk(relaxed = true) + private val terminalServices: TerminalServices = mockk(relaxed = true) + private val environment: ConfigurableEnvironment = mockk(relaxed = true) + private val objectMapper: ObjectMapper = jacksonObjectMapper() + private val colorPalette: ColorPalette = object : ColorPalette { + // Use no-op colors for testing (ANSI escape won't affect assertions) + override val highlight: Int = 0xbeb780 + override val color2: Int = 0x7da17e + } + private val loggingPersonality: LoggingPersonality = mockk(relaxed = true) { + every { logger } returns mockk(relaxed = true) + every { colorPalette } returns this@ShellCommandsContextTest.colorPalette + } + private val toolsStats: ToolsStats = mockk(relaxed = true) + private val context: ConfigurableApplicationContext = mockk(relaxed = true) + private val agentPlatform: AgentPlatform = mockk(relaxed = true) + private val autonomyProperties: AutonomyProperties = mockk(relaxed = true) { + every { agentConfidenceCutOff } returns 0.6 + every { goalConfidenceCutOff } returns 0.6 + } + + private lateinit var shellCommands: ShellCommands + + @BeforeEach + fun setUp() { + every { autonomy.agentPlatform } returns agentPlatform + every { autonomy.properties } returns autonomyProperties + shellCommands = ShellCommands( + autonomy = autonomy, + modelProvider = modelProvider, + terminalServices = terminalServices, + environment = environment, + objectMapper = objectMapper, + colorPalette = colorPalette, + loggingPersonality = loggingPersonality, + toolsStats = toolsStats, + context = context, + shellProperties = ShellProperties(), + ) + } + + @Nested + inner class SetContext { + + @Test + fun `sets context from key=value pairs`() { + val result = shellCommands.setContext("tenantId=acme,apiKey=secret123") + assertTrue(result.contains("tenantId")) + assertTrue(result.contains("acme")) + assertTrue(result.contains("apiKey")) + assertTrue(result.contains("secret123")) + } + + @Test + fun `clears context when input is 'clear'`() { + // First set something + shellCommands.setContext("tenantId=acme") + // Then clear + val result = shellCommands.setContext("clear") + assertTrue(result.contains("cleared")) + } + + @Test + fun `clears context when input is blank`() { + shellCommands.setContext("tenantId=acme") + val result = shellCommands.setContext("") + assertTrue(result.contains("cleared")) + } + + @Test + fun `handles value containing equals sign`() { + val result = shellCommands.setContext("token=abc=def") + assertTrue(result.contains("token")) + assertTrue(result.contains("abc=def")) + } + + @Test + fun `ignores entries without equals sign`() { + shellCommands.setContext("validKey=value,invalidEntry") + val showResult = shellCommands.showContext() + assertTrue(showResult.contains("validKey")) + assertFalse(showResult.contains("invalidEntry")) + } + + @Test + fun `trims whitespace around keys and values`() { + shellCommands.setContext(" tenantId = acme , apiKey = secret ") + val showResult = shellCommands.showContext() + assertTrue(showResult.contains("tenantId")) + assertTrue(showResult.contains("acme")) + } + } + + @Nested + inner class ShowContext { + + @Test + fun `shows empty message when no context is set`() { + val result = shellCommands.showContext() + assertTrue(result.contains("empty")) + } + + @Test + fun `shows context entries after setContext`() { + shellCommands.setContext("tenantId=acme,authToken=bearer-xyz") + val result = shellCommands.showContext() + assertTrue(result.contains("tenantId")) + assertTrue(result.contains("acme")) + assertTrue(result.contains("authToken")) + assertTrue(result.contains("bearer-xyz")) + } + + @Test + fun `shows empty after clearing context`() { + shellCommands.setContext("tenantId=acme") + shellCommands.setContext("clear") + val result = shellCommands.showContext() + assertTrue(result.contains("empty")) + } + } + + @Nested + inner class ExecuteContextMerge { + + @BeforeEach + fun setUpAutonomyToThrow() { + // Make autonomy throw NoAgentFound so we can test context propagation + // without needing a full agent execution pipeline + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = any(), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + } + + @Test + fun `execute without context uses empty context`() { + val capturedOptions = slot() + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = capture(capturedOptions), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + + shellCommands.execute( + intent = "test intent", + showPrompts = false, + ) + + assertTrue(capturedOptions.captured.toolCallContext.isEmpty) + } + + @Test + fun `execute with persistent context propagates it`() { + val capturedOptions = slot() + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = capture(capturedOptions), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + + shellCommands.setContext("tenantId=acme") + shellCommands.execute( + intent = "test intent", + showPrompts = false, + ) + + assertEquals("acme", capturedOptions.captured.toolCallContext.get("tenantId")) + } + + @Test + fun `execute with per-execution context merges with persistent`() { + val capturedOptions = slot() + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = capture(capturedOptions), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + + shellCommands.setContext("tenantId=acme,authToken=xyz") + shellCommands.execute( + intent = "test intent", + showPrompts = false, + context = "correlationId=req-123", + ) + + val ctx = capturedOptions.captured.toolCallContext + assertEquals("acme", ctx.get("tenantId")) + assertEquals("xyz", ctx.get("authToken")) + assertEquals("req-123", ctx.get("correlationId")) + } + + @Test + fun `per-execution context wins on conflict with persistent`() { + val capturedOptions = slot() + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = capture(capturedOptions), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + + shellCommands.setContext("tenantId=acme") + shellCommands.execute( + intent = "test intent", + showPrompts = false, + context = "tenantId=beta", + ) + + assertEquals("beta", capturedOptions.captured.toolCallContext.get("tenantId")) + } + + @Test + fun `persistent context is unchanged after per-execution override`() { + every { + autonomy.chooseAndRunAgent( + intent = any(), + processOptions = any(), + ) + } throws NoAgentFound( + agentRankings = Rankings(emptyList()), + basis = "test", + ) + + shellCommands.setContext("tenantId=acme") + shellCommands.execute( + intent = "test intent", + showPrompts = false, + context = "tenantId=beta", + ) + + // Persistent context should still be "acme" + val showResult = shellCommands.showContext() + assertTrue(showResult.contains("acme")) + assertFalse(showResult.contains("beta")) + } + } +} From 1ddabf58c7c36722fff75492a81a90a8b6cb29ae Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Wed, 4 Mar 2026 19:31:58 -0500 Subject: [PATCH 4/7] fix: make call(String, ToolCallContext) the canonical entry point for DelegatingTool --- .../agent/api/tool/ArtifactSinkingTool.kt | 3 - .../embabel/agent/api/tool/DelegatingTool.kt | 31 ++++++--- .../embabel/agent/api/tool/ReplanningTools.kt | 6 -- .../kotlin/com/embabel/agent/api/tool/Tool.kt | 4 -- .../embabel/agent/core/hitl/AwaitingTools.kt | 13 ++-- .../agent/spi/support/ToolDecorators.kt | 18 ------ .../embabel/chat/support/AssetAddingTool.kt | 5 +- .../tool/DelegatingToolArchitectureTest.kt | 64 +++++++++++++++++++ 8 files changed, 96 insertions(+), 48 deletions(-) create mode 100644 embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/DelegatingToolArchitectureTest.kt diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt index c1a0a8c4e..a786771a1 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ArtifactSinkingTool.kt @@ -88,9 +88,6 @@ class ArtifactSinkingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callAndSink { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callAndSink { delegate.call(input, context) } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt index e9f7b19fd..6d407b93f 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/DelegatingTool.kt @@ -20,11 +20,16 @@ package com.embabel.agent.api.tool * Enables unwrapping to find the underlying tool implementation. * Thus, it is important that tool wrappers implement this interface to allow unwrapping. * - * The default [call] (String, ToolCallContext) implementation propagates - * context through the decorator chain by delegating to - * `delegate.call(input, context)`. Decorators that add behavior - * (e.g., artifact sinking, replanning) should override this method - * to apply their logic while preserving context propagation. + * ## Canonical call method + * + * [call] (String, ToolCallContext) is the **single canonical entry point** for + * decorator logic. Decorators should override only this method. The single-arg + * [call] (String) routes through it automatically via [ToolCallContext.EMPTY], + * so both call paths execute the same decorator behavior. + * + * This eliminates a class of bugs where a decorator overrides [call] (String) + * but the two-arg variant (used by [com.embabel.agent.spi.loop.support.DefaultToolLoop]) + * bypasses the decorator entirely. */ interface DelegatingTool : Tool { @@ -34,10 +39,18 @@ interface DelegatingTool : Tool { val delegate: Tool /** - * Propagates [context] through the decorator chain. - * Decorators that override [call] (String) to add behavior should - * also override this method to apply the same behavior while - * forwarding context to [delegate]. + * Routes single-arg calls through the canonical two-arg method, + * ensuring decorator logic in [call] (String, ToolCallContext) is + * always executed regardless of which overload the caller uses. + */ + override fun call(input: String): Tool.Result = + call(input, ToolCallContext.EMPTY) + + /** + * Canonical entry point for decorator logic. Override this method + * to add behavior while preserving context propagation to [delegate]. + * + * The default implementation simply forwards to the delegate. */ override fun call(input: String, context: ToolCallContext): Tool.Result = delegate.call(input, context) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt index 4e5af31d1..709677bfe 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/ReplanningTools.kt @@ -62,9 +62,6 @@ class ReplanningTool @JvmOverloads constructor( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callAndReplan { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callAndReplan { delegate.call(input, context) } @@ -149,9 +146,6 @@ class ConditionalReplanningTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callAndMaybeReplan { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callAndMaybeReplan { delegate.call(input, context) } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt index a13059b0a..d1f728e4f 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/tool/Tool.kt @@ -622,8 +622,6 @@ private class RenamedTool( override val metadata: Tool.Metadata get() = delegate.metadata - - override fun call(input: String): Tool.Result = delegate.call(input) } /** @@ -643,8 +641,6 @@ private class DescribedTool( override val metadata: Tool.Metadata get() = delegate.metadata - - override fun call(input: String): Tool.Result = delegate.call(input) } // Private implementations diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/hitl/AwaitingTools.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/hitl/AwaitingTools.kt index b5248faae..f7e98ff28 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/hitl/AwaitingTools.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/hitl/AwaitingTools.kt @@ -17,6 +17,7 @@ package com.embabel.agent.core.hitl import com.embabel.agent.api.tool.DelegatingTool import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.agent.core.AgentProcess /** @@ -63,7 +64,7 @@ class ConfirmingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String, context: ToolCallContext): Tool.Result { val message = messageProvider(input) throw AwaitableResponseException( ConfirmationRequest( @@ -92,21 +93,21 @@ class ConditionalAwaitingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String, context: ToolCallContext): Tool.Result { val agentProcess = AgentProcess.get() ?: throw IllegalStateException("No AgentProcess available for ConditionalAwaitingTool") - val context = AwaitContext( + val awaitContext = AwaitContext( input = input, agentProcess = agentProcess, tool = delegate, ) - decider.evaluate(context)?.let { awaitable -> + decider.evaluate(awaitContext)?.let { awaitable -> throw AwaitableResponseException(awaitable) } - return delegate.call(input) + return delegate.call(input, context) } } @@ -131,7 +132,7 @@ class TypeRequestingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result { + override fun call(input: String, context: ToolCallContext): Tool.Result { throw AwaitableResponseException( TypeRequest( type = type, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt index a853339ea..ee798d77e 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolDecorators.kt @@ -69,9 +69,6 @@ class ObservabilityTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callWithObservation(input) { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callWithObservation(input) { delegate.call(input, context) } @@ -125,9 +122,6 @@ class OutputTransformingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - transformOutput(input) { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = transformOutput(input) { delegate.call(input, context) } @@ -158,9 +152,6 @@ class MetadataEnrichingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callWithMetadata(input) { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callWithMetadata(input) { delegate.call(input, context) } @@ -196,9 +187,6 @@ class EventPublishingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callWithEvents(input) { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callWithEvents(input) { delegate.call(input, context) } @@ -263,9 +251,6 @@ class ExceptionSuppressingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callSuppressing { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callSuppressing { delegate.call(input, context) } @@ -290,9 +275,6 @@ class AgentProcessBindingTool( override val definition: Tool.Definition = delegate.definition override val metadata: Tool.Metadata = delegate.metadata - override fun call(input: String): Tool.Result = - callWithBinding { delegate.call(input) } - override fun call(input: String, context: ToolCallContext): Tool.Result = callWithBinding { delegate.call(input, context) } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/chat/support/AssetAddingTool.kt b/embabel-agent-api/src/main/kotlin/com/embabel/chat/support/AssetAddingTool.kt index 1319fa5eb..ff83e39d2 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/chat/support/AssetAddingTool.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/chat/support/AssetAddingTool.kt @@ -17,6 +17,7 @@ package com.embabel.chat.support import com.embabel.agent.api.tool.DelegatingTool import com.embabel.agent.api.tool.Tool +import com.embabel.agent.api.tool.ToolCallContext import com.embabel.chat.Asset import com.embabel.chat.AssetTracker import org.slf4j.LoggerFactory @@ -43,8 +44,8 @@ class AssetAddingTool( private val logger = LoggerFactory.getLogger(javaClass) - override fun call(input: String): Tool.Result { - val result = delegate.call(input) + override fun call(input: String, context: ToolCallContext): Tool.Result { + val result = delegate.call(input, context) when (result) { is Tool.Result.WithArtifact -> { val artifact = result.artifact diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/DelegatingToolArchitectureTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/DelegatingToolArchitectureTest.kt new file mode 100644 index 000000000..bc70b8858 --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/tool/DelegatingToolArchitectureTest.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.api.tool + +import com.tngtech.archunit.core.importer.ClassFileImporter +import com.tngtech.archunit.core.importer.ImportOption +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test + +/** + * Verifies that [DelegatingTool] implementors follow the canonical call pattern: + * override [Tool.call] (String, ToolCallContext) only, never [Tool.call] (String). + * + * [DelegatingTool.call] (String) routes to `call(String, ToolCallContext.EMPTY)` so that + * the two-arg method is the single entry point for decorator logic. If a decorator + * overrides the single-arg version instead, [com.embabel.agent.spi.loop.support.DefaultToolLoop] + * (which calls the two-arg version) will bypass the decorator's behavior silently. + * + * @see DelegatingTool + */ +class DelegatingToolArchitectureTest { + + @Test + fun `DelegatingTool implementors must not override single-arg call`() { + val classes = ClassFileImporter() + .withImportOption(ImportOption.DoNotIncludeTests()) + .importPackages("com.embabel") + + val violations = classes + .filter { it.isAssignableTo(DelegatingTool::class.java) } + .filter { !it.isInterface } + .filter { javaClass -> + javaClass.getMethods().any { method -> + method.getName() == "call" + && method.getRawParameterTypes().size == 1 + && method.getRawParameterTypes()[0].isEquivalentTo(String::class.java) + && method.getOwner() == javaClass + } + } + .map { it.getName() } + + assertThat(violations) + .describedAs( + "These DelegatingTool implementations override call(String), which is " + + "bypassed by DefaultToolLoop. Override call(String, ToolCallContext) " + + "instead — see DelegatingTool KDoc.\n " + + violations.joinToString("\n ") + ) + .isEmpty() + } +} From d61ab68fc65454d1c9b10b7827b5d792341bfacc Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Wed, 4 Mar 2026 20:25:07 -0500 Subject: [PATCH 5/7] refactor: deduplicate SpringToolCallbackAdapter call paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single-arg call(String) now delegates to call(String, ToolContext?). One code path for result mapping, logging, and error handling. Also fixes incorrect PR reference in ShellCommandsContextTest KDoc (#1456 → #1462). --- .../springai/SpringToolCallbackAdapter.kt | 20 +++---------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt index 9828ec4c7..d88107d05 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringToolCallbackAdapter.kt @@ -52,23 +52,8 @@ class SpringToolCallbackAdapter( .build() } - override fun call(toolInput: String): String { - logger.debug("Executing tool '{}' with input: {}", tool.definition.name, toolInput) - - return try { - when (val result = tool.call(toolInput)) { - is Tool.Result.Text -> result.content - is Tool.Result.WithArtifact -> result.content - is Tool.Result.Error -> { - logger.warn("Tool '{}' returned error: {}", tool.definition.name, result.message) - "ERROR: ${result.message}" - } - } - } catch (e: Exception) { - logger.error("Tool '{}' threw exception: {}", tool.definition.name, e.message, e) - "ERROR: ${e.message ?: "Unknown error"}" - } - } + override fun call(toolInput: String): String = + call(toolInput, null) /** * Bridges Spring AI's ToolContext with Embabel's ToolCallContext. @@ -76,6 +61,7 @@ class SpringToolCallbackAdapter( * and passes it explicitly to the tool. */ override fun call(toolInput: String, toolContext: ToolContext?): String { + logger.debug("Executing tool '{}' with input: {}", tool.definition.name, toolInput) val context = toolContext?.let { ToolCallContext.of(it.context) } ?: ToolCallContext.EMPTY return try { when (val result = tool.call(toolInput, context)) { From 9934a00e09b13cb9b123fde2bc058f132c49c1e1 Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Thu, 5 Mar 2026 22:32:04 -0500 Subject: [PATCH 6/7] Update KDoc in ShellCommandsContextTest.kt --- .../kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt b/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt index 1d09cd1d0..258d93de1 100644 --- a/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt +++ b/embabel-agent-shell/src/test/kotlin/com/embabel/agent/shell/ShellCommandsContextTest.kt @@ -39,7 +39,7 @@ import org.springframework.context.ConfigurableApplicationContext import org.springframework.core.env.ConfigurableEnvironment /** - * Tests for the ToolCallContext-related shell commands introduced in PR #1456: + * Tests for the ToolCallContext-related shell commands introduced in PR #1462 (issue #1323): * set-context, show-context, and the -c flag on execute. */ class ShellCommandsContextTest { From 0cd2af84f223b1240b5d9457625b215b51a984fb Mon Sep 17 00:00:00 2001 From: alexheifetz Date: Thu, 5 Mar 2026 23:44:58 -0500 Subject: [PATCH 7/7] Rebase and add toolCallContext to doTransformWithThinking --- .../com/embabel/agent/spi/support/ToolLoopLlmOperations.kt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt index 5422f69dc..73114e5e8 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt @@ -363,6 +363,7 @@ open class ToolLoopLlmOperations( val injectedToolDecorator = createInjectedToolDecorator(llmRequestEvent, interaction) val injectionStrategy = createInjectionStrategy(interaction) + val effectiveContext = resolveToolCallContext(llmRequestEvent, interaction) val toolLoop = toolLoopFactory.create( llmMessageSender = messageSender, @@ -372,6 +373,7 @@ open class ToolLoopLlmOperations( toolDecorator = injectedToolDecorator, inspectors = interaction.inspectors, transformers = interaction.transformers, + toolCallContext = effectiveContext, ) val initialMessages = buildInitialMessages(promptContributions, messages, schemaFormat)