Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ import com.embabel.common.ai.model.ModelProvider
import com.embabel.common.ai.model.ModelSelectionCriteria
import com.embabel.common.core.thinking.ThinkingResponse
import com.embabel.common.util.time
import jakarta.validation.ConstraintViolation
import jakarta.validation.Validator
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.lang.reflect.Field
import java.time.Duration
import java.util.concurrent.ExecutionException
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.function.Predicate

// Log message constants to avoid duplication
private const val LLM_TIMEOUT_MESSAGE = "LLM {}: attempt {} timed out after {}ms"
Expand Down Expand Up @@ -162,6 +165,7 @@ abstract class AbstractLlmOperations(
validationPromptGenerator.generateRequirementsPrompt(
validator = validator,
outputClass = outputClass,
fieldFilter = interaction.fieldFilter,
)
)
} else {
Expand All @@ -186,6 +190,8 @@ abstract class AbstractLlmOperations(
}
if (interaction.validation) {
var constraintViolations = validator.validate(candidate)
constraintViolations =
filterConstraintViolations(constraintViolations, outputClass, interaction.fieldFilter)
if (constraintViolations.isNotEmpty()) {
// If we had violations, try again, once, before throwing an exception
candidate = dataBindingProperties.retryTemplate(interaction.id.value)
Expand All @@ -207,6 +213,8 @@ abstract class AbstractLlmOperations(
}
}
constraintViolations = validator.validate(candidate)
constraintViolations =
filterConstraintViolations(constraintViolations, outputClass, interaction.fieldFilter)
if (constraintViolations.isNotEmpty()) {
throw InvalidLlmReturnTypeException(
returnedObject = candidate as Any,
Expand All @@ -227,6 +235,17 @@ abstract class AbstractLlmOperations(
return createdObject
}

private fun <O> filterConstraintViolations(
constraintViolations: Set<ConstraintViolation<O>>,
outputClass: Class<O>,
fieldFilter: Predicate<Field>,
): Set<ConstraintViolation<O>> =
constraintViolations.filterTo(mutableSetOf()) { violation ->
runCatching { outputClass.getDeclaredField(violation.propertyPath.toString()) }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

.map { fieldFilter.test(it) }
.getOrDefault(true)
}

final override fun <O> createObjectIfPossible(
messages: List<Message>,
interaction: LlmInteraction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,30 @@ package com.embabel.agent.spi.validation

import jakarta.validation.ConstraintViolation
import jakarta.validation.Validator
import java.lang.reflect.Field
import java.util.function.Predicate

class DefaultValidationPromptGenerator : ValidationPromptGenerator {

override fun generateRequirementsPrompt(
validator: Validator,
outputClass: Class<*>,
fieldFilter: Predicate<Field>,
): String {
val descriptor = validator.getConstraintsForClass(outputClass)
val requirements = mutableListOf<String>()

descriptor.constrainedProperties.forEach { propertyDescriptor ->
val propertyName = propertyDescriptor.propertyName
val constraints = propertyDescriptor.constraintDescriptors
if (filter(propertyName, outputClass, fieldFilter)) {
val constraints = propertyDescriptor.constraintDescriptors

constraints.forEach { constraint ->
val annotationType = constraint.annotation.annotationClass.simpleName
val message = constraint.messageTemplate
constraints.forEach { constraint ->
val annotationType = constraint.annotation.annotationClass.simpleName
val message = constraint.messageTemplate

requirements.add("- Field '$propertyName': $annotationType constraint ($message)")
requirements.add("- Field '$propertyName': $annotationType constraint ($message)")
}
}
}

Expand All @@ -46,6 +51,19 @@ class DefaultValidationPromptGenerator : ValidationPromptGenerator {
}
}

private fun filter(
propertyName: String,
outputClass: Class<*>,
fieldFilter: Predicate<Field>,
): Boolean = try {
val field = outputClass.getDeclaredField(propertyName)
fieldFilter.test(field)
} catch (_: NoSuchFieldException) {
true
} catch (_: SecurityException) {
true
}

/**
* (b) Generate a string based on actual constraint violations
* This describes what went wrong after validation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package com.embabel.agent.spi.validation

import jakarta.validation.ConstraintViolation
import jakarta.validation.Validator
import java.lang.reflect.Field
import java.util.function.Predicate

/**
* Generate validation prompts for JSR-380 annotated types
Expand All @@ -30,6 +32,7 @@ interface ValidationPromptGenerator {
fun generateRequirementsPrompt(
validator: Validator,
outputClass: Class<*>,
fieldFilter: Predicate<Field> = Predicate { true },
): String

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.model.tool.ToolCallingChatOptions
import java.time.LocalDate
import java.util.concurrent.Executors
import java.util.function.Predicate
import kotlin.test.assertEquals

/**
Expand Down Expand Up @@ -1339,6 +1340,65 @@ class ChatClientLlmOperationsTest {
)
assertEquals(invalidHusky, createdDog, "Invalid response should have been corrected")
}

@Test
fun `field filter suppresses constraint violation for excluded field`() {
data class BorderCollie(
val name: String,
@field:Pattern(regexp = "^mince$", message = "eats field must be 'mince'")
val eats: String,
)

val invalidHusky = BorderCollie("Husky", eats = "kibble")
val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(invalidHusky))
val setup = createChatClientLlmOperations(fakeChatModel)

// Exclude 'eats' from the field filter — its constraint violation should be ignored
val result = setup.llmOperations.createObject(
messages = listOf(UserMessage("prompt")),
interaction = LlmInteraction(
id = InteractionId("id"),
llm = LlmOptions(),
fieldFilter = Predicate { field -> field.name != "eats" },
),
outputClass = BorderCollie::class.java,
action = SimpleTestAgent.actions.first(),
agentProcess = setup.mockAgentProcess,
)

assertEquals(invalidHusky, result, "Filtered-out field violation should not block the result")
}

@Test
fun `field filter does not suppress constraint violation for included field`() {
data class BorderCollie(
val name: String,
@field:Pattern(regexp = "^mince$", message = "eats field must be 'mince'")
val eats: String,
)

val invalidHusky = BorderCollie("Husky", eats = "kibble")
val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(invalidHusky))
val setup = createChatClientLlmOperations(fakeChatModel)

// 'eats' is still included in the filter — violation should be raised
try {
setup.llmOperations.createObject(
messages = listOf(UserMessage("prompt")),
interaction = LlmInteraction(
id = InteractionId("id"),
llm = LlmOptions(),
fieldFilter = Predicate { true },
),
outputClass = BorderCollie::class.java,
action = SimpleTestAgent.actions.first(),
agentProcess = setup.mockAgentProcess,
)
fail("Should have thrown InvalidLlmReturnTypeException")
} catch (e: InvalidLlmReturnTypeException) {
assertTrue(e.constraintViolations.any { it.propertyPath.toString() == "eats" })
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import jakarta.validation.constraints.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import java.lang.reflect.Field
import java.time.LocalDate
import java.util.function.Predicate
import kotlin.test.assertEquals
import kotlin.test.assertTrue

Expand Down Expand Up @@ -87,6 +89,13 @@ class DefaultValidationPromptGeneratorTest {

class EmptyClass

data class TwoFieldClass(
@field:NotBlank(message = "Name cannot be blank")
val name: String,
@field:Email(message = "Must be a valid email address")
val email: String,
)

@Nested
inner class GenerateRequirementsPrompt {

Expand Down Expand Up @@ -302,6 +311,86 @@ class DefaultValidationPromptGeneratorTest {
}
}

@Nested
inner class FieldFilterBehavior {

@Test
fun `default field filter includes all constrained fields in requirements`() {
val result = generator.generateRequirementsPrompt(validator, TwoFieldClass::class.java)

assertTrue(result.contains("Field 'name'"), "name should be included with default filter")
assertTrue(result.contains("Field 'email'"), "email should be included with default filter")
}

@Test
fun `field filter excluding a field omits its requirements from prompt`() {
val filterExcludeEmail = Predicate<Field> { it.name != "email" }

val result = generator.generateRequirementsPrompt(validator, TwoFieldClass::class.java, filterExcludeEmail)

assertTrue(result.contains("Field 'name'"), "name should still appear")
assertTrue(!result.contains("Field 'email'"), "email should be omitted when filtered out")
}

@Test
fun `field filter excluding all fields yields no constraints message`() {
val excludeAll = Predicate<Field> { false }

val result = generator.generateRequirementsPrompt(validator, TwoFieldClass::class.java, excludeAll)

assertEquals("No validation constraints defined.", result)
}

@Test
fun `field filter including only one field omits the other`() {
val nameOnly = Predicate<Field> { it.name == "name" }

val result = generator.generateRequirementsPrompt(validator, TwoFieldClass::class.java, nameOnly)

assertTrue(result.contains("Field 'name'"))
assertTrue(!result.contains("Field 'email'"))
}

@Test
fun `filterConstraintViolations drops violations for fields excluded by filter`() {
val invalid = TwoFieldClass(name = "", email = "not-an-email")
val violations = validator.validate(invalid)
assertTrue(violations.isNotEmpty())

val nameOnly = Predicate<Field> { it.name == "name" }
val filtered = violations.filterTo(mutableSetOf()) { violation ->
runCatching { TwoFieldClass::class.java.getDeclaredField(violation.propertyPath.toString()) }
.map { nameOnly.test(it) }
.getOrDefault(true)
}

assertTrue(
filtered.all { it.propertyPath.toString() == "name" },
"only name violations should survive the filter"
)
assertTrue(
filtered.none { it.propertyPath.toString() == "email" },
"email violations should be dropped by the filter"
)
}

@Test
fun `filterConstraintViolations keeps violations when field cannot be resolved`() {
val invalid = TwoFieldClass(name = "", email = "not-an-email")
val violations = validator.validate(invalid)

val excludeAll = Predicate<Field> { false }
// Look up fields on an unrelated class so lookup always fails → default true → all kept
val filtered = violations.filterTo(mutableSetOf()) { violation ->
runCatching { String::class.java.getDeclaredField(violation.propertyPath.toString()) }
.map { excludeAll.test(it) }
.getOrDefault(true)
}

assertEquals(violations.size, filtered.size)
}
}

// Validation groups for testing
interface CreateGroup
interface UpdateGroup
Expand Down