Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ import org.springframework.ai.tool.ToolCallback
*
* @param chatModel The Spring AI ChatModel to use for LLM calls
* @param chatOptions Options for the LLM call (temperature, etc.)
* @param toolResponseContentAdapter Adapts tool response content for provider-specific
* format requirements (e.g., JSON wrapping for Google GenAI)
*/
internal class SpringAiLlmMessageSender(
private val chatModel: ChatModel,
private val chatOptions: ChatOptions,
private val toolResponseContentAdapter: ToolResponseContentAdapter = ToolResponseContentAdapter.PASSTHROUGH,
) : LlmMessageSender {

private val logger = loggerFor<SpringAiLlmMessageSender>()
Expand All @@ -48,8 +51,11 @@ internal class SpringAiLlmMessageSender(
messages: List<Message>,
tools: List<Tool>,
): LlmMessageResponse {
// Convert Embabel messages to Spring AI messages
val springAiMessages = messages.map { it.toSpringAiMessage() }.mergeConsecutiveToolResponses()
// Convert Embabel messages to Spring AI messages, applying provider-specific
// tool response formatting (e.g., JSON wrapping for Google GenAI)
val springAiMessages = messages
.map { it.toSpringAiMessage(toolResponseContentAdapter) }
.mergeConsecutiveToolResponses()

// Convert Embabel tools to Spring AI tool callbacks using existing adapter
val toolCallbacks = tools.toSpringToolCallbacks()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ import java.time.LocalDate
* @param promptContributors List of prompt contributors for this model.
* Knowledge cutoff is automatically included if knowledgeCutoffDate is set.
* @param pricingModel Pricing model for this LLM, if known
* @param toolResponseContentAdapter Adapts tool response content for provider-specific
* format requirements. Defaults to [ToolResponseContentAdapter.PASSTHROUGH].
* Google GenAI requires JSON; OpenAI/Anthropic accept plain text.
*/
@JsonSerialize(`as` = LlmMetadata::class)
data class SpringAiLlmService @JvmOverloads constructor(
Expand All @@ -54,6 +57,7 @@ data class SpringAiLlmService @JvmOverloads constructor(
override val promptContributors: List<PromptContributor> =
buildList { knowledgeCutoffDate?.let { add(KnowledgeCutoffDate(it)) } },
override val pricingModel: PricingModel? = null,
val toolResponseContentAdapter: ToolResponseContentAdapter = ToolResponseContentAdapter.PASSTHROUGH,
) : LlmService<SpringAiLlmService>, AiModel<ChatModel> {

/**
Expand All @@ -64,7 +68,7 @@ data class SpringAiLlmService @JvmOverloads constructor(

override fun createMessageSender(options: LlmOptions): LlmMessageSender {
val chatOptions = optionsConverter.convertOptions(options)
return SpringAiLlmMessageSender(chatModel, chatOptions)
return SpringAiLlmMessageSender(chatModel, chatOptions, toolResponseContentAdapter)
}

override fun withKnowledgeCutoffDate(date: LocalDate): SpringAiLlmService =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2024-2026 Embabel Pty Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.embabel.agent.spi.support.springai

import com.fasterxml.jackson.databind.ObjectMapper

/**
* Adapts tool response content to meet provider-specific format requirements
* before the content is sent to the LLM.
*
* Some LLM providers impose constraints on tool response format. For example,
* Google GenAI requires tool responses (`FunctionResponse.response`) to be valid
* JSON objects — plain text causes `JsonParseException` or silent data loss.
* Other providers like OpenAI and Anthropic accept plain text as-is.
*
* This interface allows each provider's auto-configuration to supply its own
* adapter, keeping provider-specific logic out of the shared message conversion
* infrastructure.
*
* Follows the same pattern as [com.embabel.common.ai.model.OptionsConverter]
* — a per-provider strategy plugged in at configuration time.
*
* @see JsonWrappingToolResponseContentAdapter
*/
fun interface ToolResponseContentAdapter {

/**
* Adapt tool response content for the target LLM provider.
*
* @param content The raw tool response content (may be plain text, JSON, etc.)
* @return The adapted content suitable for the target provider
*/
fun adapt(content: String): String

companion object {

/**
* Default adapter that passes content through unchanged.
* Suitable for providers that accept plain text in tool responses
* (e.g., OpenAI, Anthropic).
*/
@JvmField
val PASSTHROUGH = ToolResponseContentAdapter { it }
}
}

/**
* Wraps non-JSON tool response content in a JSON object for providers
* that require valid JSON in tool responses (e.g., Google GenAI / Gemini).
*
* Behavior:
* - Content that already looks like a JSON object (`{...}`) or array (`[...]`)
* is passed through unchanged.
* - All other content is wrapped as `{"result": "<content>"}`.
*
* This acts as a safety net at the provider boundary. Tools are encouraged
* to return structured JSON directly (see `enabledToolsJson()` in
* [com.embabel.agent.api.tool.progressive.UnfoldingTool]), but this adapter
* catches any remaining plain-text responses from arbitrary [com.embabel.agent.api.tool.Tool]
* implementations.
*
* Note: The `startsWith` check is a fast-path heuristic, not full JSON
* validation. Malformed JSON-like strings (e.g., `{not json}`) will pass
* through and may fail downstream — this is acceptable since such output
* is extremely rare from well-behaved tools.
*/
class JsonWrappingToolResponseContentAdapter : ToolResponseContentAdapter {

private val objectMapper = ObjectMapper()

override fun adapt(content: String): String {
val trimmed = content.trimStart()
if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
return content
}
return objectMapper.writeValueAsString(mapOf("result" to content))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ import org.springframework.ai.chat.messages.UserMessage as SpringAiUserMessage

/**
* Convert one of our messages to a Spring AI message with multimodal support.
*
* @param toolResponseContentAdapter Adapts tool response content for provider-specific
* format requirements (e.g., JSON wrapping for Google GenAI).
* Defaults to [ToolResponseContentAdapter.PASSTHROUGH].
*/
fun Message.toSpringAiMessage(): SpringAiMessage {
fun Message.toSpringAiMessage(
toolResponseContentAdapter: ToolResponseContentAdapter = ToolResponseContentAdapter.PASSTHROUGH,
): SpringAiMessage {
val name = (this as? BaseMessage)?.name
val metadata: Map<String, Any> = if (name != null) mapOf("name" to name) else emptyMap()
return when (this) {
Expand All @@ -51,7 +57,7 @@ fun Message.toSpringAiMessage(): SpringAiMessage {
val toolResponse = ToolResponseMessage.ToolResponse(
this.toolCallId,
this.toolName,
this.content
toolResponseContentAdapter.adapt(this.content)
)
ToolResponseMessage.builder().responses(listOf(toolResponse)).metadata(metadata).build()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,4 +514,90 @@ class MessageConversionTest {
)
}
}

/**
* Tests for provider-specific tool response content adaptation.
*
* Verifies that [ToolResponseContentAdapter] is correctly applied during
* message conversion, allowing providers like Google GenAI to receive
* JSON-wrapped tool responses while others receive plain text.
*
* See: https://github.com/embabel/embabel-agent/issues/1391
*/
@Nested
inner class ToolResponseContentAdapterTests {

@Test
fun `tool result uses PASSTHROUGH adapter by default`() {
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "search",
content = "plain text result"
)

val springAiMessage = message.toSpringAiMessage() as ToolResponseMessage

assertThat(springAiMessage.responses[0].responseData())
.isEqualTo("plain text result")
}

@Test
fun `tool result applies JsonWrapping adapter for plain text`() {
val adapter = JsonWrappingToolResponseContentAdapter()
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "search",
content = "2 results: HBNB Services - Technical Blockchain Advisor"
)

val springAiMessage = message.toSpringAiMessage(adapter) as ToolResponseMessage
val responseData = springAiMessage.responses[0].responseData()

assertThat(responseData).startsWith("{")
assertThat(responseData).contains("\"result\"")
assertThat(responseData).contains("HBNB Services")
}

@Test
fun `tool result applies JsonWrapping adapter preserving JSON content`() {
val adapter = JsonWrappingToolResponseContentAdapter()
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "get_count",
content = """{"count": 5}"""
)

val springAiMessage = message.toSpringAiMessage(adapter) as ToolResponseMessage

assertThat(springAiMessage.responses[0].responseData())
.isEqualTo("""{"count": 5}""")
}

@Test
fun `adapter does not affect non-tool messages`() {
val adapter = JsonWrappingToolResponseContentAdapter()
val userMsg = UserMessage("Hello")
val assistantMsg = AssistantMessage("Hi")
val systemMsg = SystemMessage("You are helpful")

assertThat(userMsg.toSpringAiMessage(adapter).text).isEqualTo("Hello")
assertThat(assistantMsg.toSpringAiMessage(adapter).text).isEqualTo("Hi")
assertThat(systemMsg.toSpringAiMessage(adapter).text).isEqualTo("You are helpful")
}

@Test
fun `custom adapter is applied to tool response`() {
val uppercaseAdapter = ToolResponseContentAdapter { it.uppercase() }
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "test",
content = "hello world"
)

val springAiMessage = message.toSpringAiMessage(uppercaseAdapter) as ToolResponseMessage

assertThat(springAiMessage.responses[0].responseData())
.isEqualTo("HELLO WORLD")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,50 @@ class SpringAiLlmServiceTest {
}
}

@Nested
inner class ToolResponseContentAdapterTests {

@Test
fun `defaults to PASSTHROUGH adapter`() {
val service = SpringAiLlmService(
name = "test-model",
provider = "Provider",
chatModel = mockChatModel,
)

assertThat(service.toolResponseContentAdapter)
.isSameAs(ToolResponseContentAdapter.PASSTHROUGH)
}

@Test
fun `accepts custom adapter`() {
val customAdapter = ToolResponseContentAdapter { "{\"wrapped\": \"$it\"}" }
val service = SpringAiLlmService(
name = "test-model",
provider = "Provider",
chatModel = mockChatModel,
toolResponseContentAdapter = customAdapter,
)

assertThat(service.toolResponseContentAdapter).isSameAs(customAdapter)
}

@Test
fun `adapter is preserved through copy`() {
val customAdapter = JsonWrappingToolResponseContentAdapter()
val original = SpringAiLlmService(
name = "test-model",
provider = "Provider",
chatModel = mockChatModel,
toolResponseContentAdapter = customAdapter,
)

val copy = original.copy(name = "other-model")

assertThat(copy.toolResponseContentAdapter).isSameAs(customAdapter)
}
}

@Nested
inner class ModelPropertyTests {

Expand Down
Loading
Loading