Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,18 @@ interface MatryoshkaTool : Tool {
}
}

/**
* Returns a JSON-formatted tool response listing enabled tools.
* JSON format ensures compatibility with LLM providers that require
* valid JSON in tool responses (e.g. Google Gemini).
*/
private fun enabledToolsJson(toolNames: List<String>): Tool.Result {
val toolNamesJson = toolNames.joinToString(", ") { "\"$it\"" }
return Tool.Result.text(
"""{"enabled_tools_count": ${toolNames.size}, "enabled_tools": [$toolNamesJson]}"""
)
}

/**
* Simple implementation that exposes all inner tools.
*/
Expand All @@ -571,9 +583,7 @@ private class SimpleMatryoshkaTool(

override fun call(input: String): Tool.Result {
val toolNames = innerTools.map { it.definition.name }
return Tool.Result.text(
"Enabled ${innerTools.size} tools: ${toolNames.joinToString(", ")}"
)
return enabledToolsJson(toolNames)
}
}

Expand All @@ -593,8 +603,6 @@ private class SelectableMatryoshkaTool(
override fun call(input: String): Tool.Result {
val selected = selectTools(input)
val toolNames = selected.map { it.definition.name }
return Tool.Result.text(
"Enabled ${selected.size} tools: ${toolNames.joinToString(", ")}"
)
return enabledToolsJson(toolNames)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.embabel.agent.spi.support.springai

import com.embabel.chat.*
import com.fasterxml.jackson.databind.ObjectMapper
import org.springframework.ai.content.Media
import org.springframework.core.io.ByteArrayResource
import org.springframework.util.MimeTypeUtils
Expand Down Expand Up @@ -50,7 +51,7 @@ fun Message.toSpringAiMessage(): SpringAiMessage {
val toolResponse = ToolResponseMessage.ToolResponse(
this.toolCallId,
this.toolName,
this.textContent
this.textContent.ensureJson()
)
ToolResponseMessage.builder().responses(listOf(toolResponse)).metadata(metadata).build()
}
Expand Down Expand Up @@ -123,6 +124,19 @@ internal fun List<SpringAiMessage>.mergeConsecutiveToolResponses(): List<SpringA
* Convert a Spring AI AssistantMessage to an Embabel message.
* Handles both regular messages and messages with tool calls.
*/
/**
* Ensures the string is valid JSON for Gemini compatibility.
* The Google GenAI adapter parses tool response data as JSON.
* Plain text responses must be wrapped in a JSON object.
*/
private fun String.ensureJson(): String {
val trimmed = trimStart()
if (trimmed.startsWith('{') || trimmed.startsWith('[')) {
return this
}
return ObjectMapper().writeValueAsString(mapOf("result" to this))
}

fun SpringAiAssistantMessage.toEmbabelMessage(): Message {
val toolCalls = this.toolCalls
return if (toolCalls.isNullOrEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MatryoshkaToolTest {
}

@Test
fun `call returns message listing enabled tools`() {
fun `call returns JSON message listing enabled tools`() {
val innerTool1 = MockTool("tool_a", "Tool A") { Tool.Result.text("a") }
val innerTool2 = MockTool("tool_b", "Tool B") { Tool.Result.text("b") }

Expand All @@ -78,9 +78,10 @@ class MatryoshkaToolTest {

assertTrue(result is Tool.Result.Text)
val text = (result as Tool.Result.Text).content
assertTrue(text.contains("2 tools"))
assertTrue(text.contains("tool_a"))
assertTrue(text.contains("tool_b"))
val json = objectMapper.readTree(text)
assertEquals(2, json["enabled_tools_count"].intValue())
val toolNames = json["enabled_tools"].map { it.textValue() }
assertEquals(listOf("tool_a", "tool_b"), toolNames)
}

@Test
Expand Down Expand Up @@ -282,7 +283,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "outer",
toolInput = "{}",
result = "Enabled 1 tools: inner",
result = """{"enabled_tools_count": 1, "enabled_tools": ["inner"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -316,7 +317,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "persistent",
toolInput = "{}",
result = "Enabled 1 tools: inner",
result = """{"enabled_tools_count": 1, "enabled_tools": ["inner"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -353,7 +354,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "selector",
toolInput = """{"pick": "one"}""",
result = "Enabled 1 tools: tool1",
result = """{"enabled_tools_count": 1, "enabled_tools": ["tool1"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -383,7 +384,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "empty",
toolInput = "{}",
result = "Enabled 0 tools:",
result = """{"enabled_tools_count": 0, "enabled_tools": []}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -413,7 +414,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "composer_stats",
toolInput = "{}",
result = "Enabled 2 tools: count, getValues",
result = """{"enabled_tools_count": 2, "enabled_tools": ["count", "getValues"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -445,7 +446,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "composer_stats",
toolInput = "{}",
result = "Enabled 2 tools: count, getValues",
result = """{"enabled_tools_count": 2, "enabled_tools": ["count", "getValues"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -481,7 +482,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "composer_stats",
toolInput = "{}",
result = "Enabled 2 tools: count, getValues",
result = """{"enabled_tools_count": 2, "enabled_tools": ["count", "getValues"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -520,7 +521,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "empty",
toolInput = "{}",
result = "Enabled 0 tools:",
result = """{"enabled_tools_count": 0, "enabled_tools": []}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -550,7 +551,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "spotify_search",
toolInput = "{}",
result = "Enabled 2 tools: vector_search, text_search",
result = """{"enabled_tools_count": 2, "enabled_tools": ["vector_search", "text_search"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -588,7 +589,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "music_search",
toolInput = "{}",
result = "Enabled 2 tools",
result = """{"enabled_tools_count": 2, "enabled_tools": ["vector_search", "text_search"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -625,7 +626,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "no_notes",
toolInput = "{}",
result = "Enabled 1 tool",
result = """{"enabled_tools_count": 1, "enabled_tools": ["tool1"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -695,7 +696,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "outer",
toolInput = "{}",
result = "Enabled 1 tools: inner",
result = """{"enabled_tools_count": 1, "enabled_tools": ["inner"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down Expand Up @@ -936,7 +937,7 @@ class MatryoshkaToolTest {
lastToolCall = ToolCallResult(
toolName = "music_search",
toolInput = "{}",
result = "Enabled 2 tools",
result = """{"enabled_tools_count": 2, "enabled_tools": ["vectorSearch", "textSearch"]}""",
resultObject = null,
),
iterationCount = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,62 @@ class MessageConversionTest {
assertThat(response.name()).isEqualTo("get_weather")
assertThat(response.responseData()).isEqualTo("""{"temperature": 72}""")
}

@Test
fun `tool result with plain text content is wrapped as JSON`() {
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "search",
content = "2 results: HBNB Services - Technical Blockchain Advisor"
)

val springAiMessage = message.toSpringAiMessage() as ToolResponseMessage
val responseData = springAiMessage.responses[0].responseData()

assertThat(responseData).startsWith("{")
assertThat(responseData).contains("\"result\"")
assertThat(responseData).contains("2 results: HBNB Services - Technical Blockchain Advisor")
}

@Test
fun `tool result with valid JSON object content is preserved`() {
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "get_count",
content = """{"count": 5}"""
)

val springAiMessage = message.toSpringAiMessage() as ToolResponseMessage

assertThat(springAiMessage.responses[0].responseData()).isEqualTo("""{"count": 5}""")
}

@Test
fun `tool result with valid JSON array content is preserved`() {
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "get_items",
content = """[1, 2, 3]"""
)

val springAiMessage = message.toSpringAiMessage() as ToolResponseMessage

assertThat(springAiMessage.responses[0].responseData()).isEqualTo("""[1, 2, 3]""")
}

@Test
fun `tool result with whitespace-only content is wrapped as JSON`() {
val message = ToolResultMessage(
toolCallId = "call-1",
toolName = "no_result",
content = " "
)

val springAiMessage = message.toSpringAiMessage() as ToolResponseMessage
val responseData = springAiMessage.responses[0].responseData()

assertThat(responseData).isEqualTo("""{"result":" "}""")
}
}

@Nested
Expand Down