diff --git a/frontend/src/components/agents-table/agent-configuration-form.tsx b/frontend/src/components/agents-table/agent-configuration-form.tsx index 8af23625..621d5e85 100644 --- a/frontend/src/components/agents-table/agent-configuration-form.tsx +++ b/frontend/src/components/agents-table/agent-configuration-form.tsx @@ -6,7 +6,7 @@ import { DynamicConfigInput } from "./dynamic-config-input"; import { TypeInfo } from "@/lib/types"; interface AgentConfigurationFormProps { - agentConfigSections: Record | Record[]>; + agentConfigSections: Record>; openSections: Record; setOpenSections: (sections: Record) => void; config: Record; @@ -27,18 +27,13 @@ export function AgentConfigurationForm({ }; useEffect(() => { - const checkForSectionErrors = (section: string, fields: Record | Record[]) => { - if (Array.isArray(fields)) { - return fields.some(obj => { - if (obj && typeof obj === 'object') { - return Object.keys(obj).some(key => fieldErrors[`${section}_${key}`]); - } - return false; - }); - } else if (fields && typeof fields === 'object') { - return Object.keys(fields).some(key => fieldErrors[`${section}_${key}`]); + const checkForSectionErrors = (section: string, fields: Record) => { + if (section === 'tool_config') { + return Object.entries(fields as Record>).some(([toolName, toolFields]) => + Object.keys(toolFields || {}).some(key => fieldErrors[`tool_config_${toolName}_${key}`]) + ); } - return false; + return Object.keys(fields).some(key => fieldErrors[`${section}_${key}`]); }; const expandSectionsWithErrors = () => { @@ -64,13 +59,13 @@ export function AgentConfigurationForm({ expandSectionsWithErrors(); }, [fieldErrors, agentConfigSections, openSections, setOpenSections]); - const hasConfigFields = Object.values(agentConfigSections).some(section => { - if (Array.isArray(section)) { - return section.some(obj => obj && typeof obj === 'object' && Object.keys(obj).length > 0); - } else if (section && typeof section === 'object') { - return Object.keys(section).length > 0; + const hasConfigFields = Object.entries(agentConfigSections).some(([section, sectionData]) => { + if (section === 'tool_config') { + return Object.values(sectionData as Record>).some( + toolFields => toolFields && typeof toolFields === 'object' && Object.keys(toolFields).length > 0 + ); } - return false; + return sectionData !== null && sectionData !== undefined && typeof sectionData === 'object' && Object.keys(sectionData as object).length > 0; }); const renderConfigField = (key: string, schema: unknown, configKey: string) => { @@ -95,16 +90,25 @@ export function AgentConfigurationForm({ ); }; - const renderToolsArrayFields = (section: string, fields: Record[]) => { - return fields.map((obj, idx) => { - const sObj = obj && typeof obj === 'object' ? obj : {}; - if (Object.keys(sObj).length === 0) return null; - + const renderRegularFields = (section: string, fields: Record) => { + return Object.entries(fields).map(([key, schema]) => { + const configKey = `${section}_${key}`; + return typeof key === 'string' ? renderConfigField(key, schema, configKey) : null; + }); + }; + + const renderToolConfigFields = (fields: Record>) => { + return Object.entries(fields).map(([toolName, toolFields]) => { + const sFields = toolFields && typeof toolFields === 'object' ? toolFields : {}; + if (Object.keys(sFields).length === 0) return null; return ( -
+
+

+ {toolName.replace(/_/g, ' ')} +

- {Object.entries(sObj).map(([key, schema]) => { - const configKey = `${section}_${key}`; + {Object.entries(sFields).map(([key, schema]) => { + const configKey = `tool_config_${toolName}_${key}`; return typeof key === 'string' ? renderConfigField(key, schema, configKey) : null; })}
@@ -113,18 +117,13 @@ export function AgentConfigurationForm({ }); }; - const renderRegularFields = (section: string, fields: Record) => { - return Object.entries(fields).map(([key, schema]) => { - const configKey = `${section}_${key}`; - return typeof key === 'string' ? renderConfigField(key, schema, configKey) : null; - }); - }; - - const renderConfigSection = (section: string, fields: Record | Record[]) => { + const renderConfigSection = (section: string, fields: Record) => { let hasFields = false; - if (Array.isArray(fields)) { - hasFields = fields.some(obj => obj && typeof obj === 'object' && Object.keys(obj).length > 0); - } else if (fields && typeof fields === 'object') { + if (section === 'tool_config') { + hasFields = Object.values(fields as Record>).some( + toolFields => toolFields && typeof toolFields === 'object' && Object.keys(toolFields).length > 0 + ); + } else { hasFields = Object.keys(fields).length > 0; } if (!hasFields) return null; @@ -132,9 +131,9 @@ export function AgentConfigurationForm({ const sectionTitle = String(section).replace(/_/g, ' '); return ( - setOpenSections({ ...openSections, [section]: open })} > @@ -149,9 +148,9 @@ export function AgentConfigurationForm({
- {Array.isArray(fields) - ? renderToolsArrayFields(section, fields) - : renderRegularFields(section, fields as Record)} + {section === 'tool_config' + ? renderToolConfigFields(fields as Record>) + : renderRegularFields(section, fields)}
diff --git a/frontend/src/components/agents-table/hooks/use-agent-form.ts b/frontend/src/components/agents-table/hooks/use-agent-form.ts index efc4cdcd..1249dde5 100644 --- a/frontend/src/components/agents-table/hooks/use-agent-form.ts +++ b/frontend/src/components/agents-table/hooks/use-agent-form.ts @@ -18,7 +18,7 @@ export function useAgentForm( const [description, setDescription] = useState(""); const [type, setType] = useState(""); const [config, setConfig] = useState>({}); - const [agentConfigSections, setAgentConfigSections] = useState | Record[]> >({}); + const [agentConfigSections, setAgentConfigSections] = useState>>({}); const [openSections, setOpenSections] = useState>({}); const [fieldErrors, setFieldErrors] = useState>({}); const [generalError, setGeneralError] = useState(""); @@ -74,7 +74,7 @@ export function useAgentForm( }; const extractConfigSectionsFromType = (selectedType: AgentChoice) => { - const configSections: Record | Record[]> = {}; + const configSections: Record> = {}; Object.entries(selectedType).forEach(([section, value]) => { if (section !== 'key' && section !== 'name' && value && typeof value === 'object') { configSections[section] = value; @@ -94,10 +94,10 @@ export function useAgentForm( const createDefaultConfigValues = (configSections: Record) => { const defaults: Record = {}; Object.entries(configSections).forEach(([section, sectionData]) => { - if (Array.isArray(sectionData)) { - sectionData.forEach(obj => { - if (obj && typeof obj === 'object') { - Object.keys(obj).forEach(k => { defaults[`${section}_${k}`] = ""; }); + if (section === 'tool_config' && sectionData && typeof sectionData === 'object' && !Array.isArray(sectionData)) { + Object.entries(sectionData as Record>).forEach(([toolName, toolFields]) => { + if (toolFields && typeof toolFields === 'object') { + Object.keys(toolFields).forEach(k => { defaults[`tool_config_${toolName}_${k}`] = ""; }); } }); } else if (sectionData && typeof sectionData === 'object') { @@ -138,11 +138,25 @@ export function useAgentForm( setConfig(prev => ({ ...prev, [key]: value })); }; + const buildToolConfigDictConfig = (fields: Record>, currentConfig: Record) => { + const out: Record> = {}; + Object.entries(fields).forEach(([toolName, toolFields]) => { + out[toolName] = {}; + if (toolFields && typeof toolFields === 'object') { + Object.keys(toolFields).forEach(k => { + const v = currentConfig[`tool_config_${toolName}_${k}`]; + if (v !== '' && v !== null && v !== undefined) out[toolName][k] = v; + }); + } + }); + return out; + }; + const buildNestedConfigFromFlattened = () => { const configObj: Record = {}; Object.entries(agentConfigSections).forEach(([section, fields]) => { - if (Array.isArray(fields)) { - configObj[section] = buildToolsArrayConfig(section, fields); + if (section === 'tool_config') { + configObj[section] = buildToolConfigDictConfig(fields as Record>, config); } else if (fields && typeof fields === 'object') { configObj[section] = buildRegularSectionConfig(section, fields); } @@ -150,19 +164,6 @@ export function useAgentForm( return configObj; }; - const buildToolsArrayConfig = (section: string, fields: Record[]) => { - return fields.map(obj => { - const out: Record = {}; - if (obj && typeof obj === 'object') { - Object.keys(obj).forEach(k => { - const v = config[`${section}_${k}`]; - if (v !== '' && v !== null && v !== undefined) out[k] = v; - }); - } - return out; - }); - }; - const buildRegularSectionConfig = (section: string, fields: Record) => { const out: Record = {}; Object.keys(fields).forEach(k => { @@ -191,20 +192,18 @@ export function useAgentForm( }; const parseFieldErrors = (errorData: unknown) => { - const newFieldErrors = parseFieldErrorsBase(errorData); - + const { config: configErrors, ...topLevelErrors } = (errorData as Record) ?? {}; + const newFieldErrors = parseFieldErrorsBase(topLevelErrors); + // Add agent-specific error parsing for complex sections - if (typeof errorData === 'object' && errorData && 'config' in errorData) { - const configErrors = (errorData as { config: unknown }).config; - if (typeof configErrors === 'object' && configErrors) { - const sectionNames = ['agent_args', 'prompt_input', 'prompt_extension']; - - sectionNames.forEach(sectionName => { - parseSingleConfigSection(sectionName, (configErrors as Record)[sectionName], newFieldErrors); - }); - - parseToolsArrayErrors((configErrors as Record).tools, newFieldErrors); - } + if (typeof configErrors === 'object' && configErrors) { + const sectionNames = ['agent_args', 'prompt_input', 'prompt_extension']; + + sectionNames.forEach(sectionName => { + parseSingleConfigSection(sectionName, (configErrors as Record)[sectionName], newFieldErrors); + }); + + parseToolConfigErrors((configErrors as Record).tool_config, newFieldErrors); } return newFieldErrors; @@ -218,12 +217,12 @@ export function useAgentForm( } }; - const parseToolsArrayErrors = (toolsErrors: unknown, newFieldErrors: Record) => { - if (toolsErrors && Array.isArray(toolsErrors)) { - toolsErrors.forEach((toolErrors: unknown) => { + const parseToolConfigErrors = (toolConfigErrors: unknown, newFieldErrors: Record) => { + if (toolConfigErrors && typeof toolConfigErrors === 'object' && !Array.isArray(toolConfigErrors)) { + Object.entries(toolConfigErrors as Record).forEach(([toolName, toolErrors]) => { if (toolErrors && typeof toolErrors === 'object') { Object.entries(toolErrors as Record).forEach(([field, error]) => { - newFieldErrors[`tools_${field}`] = Array.isArray(error) ? String(error[0]) : String(error); + newFieldErrors[`tool_config_${toolName}_${field}`] = Array.isArray(error) ? String(error[0]) : String(error); }); } }); diff --git a/frontend/src/lib/api/agents.ts b/frontend/src/lib/api/agents.ts index 49758d06..f5966817 100644 --- a/frontend/src/lib/api/agents.ts +++ b/frontend/src/lib/api/agents.ts @@ -1,14 +1,14 @@ import {BaseApiClient} from "@/lib/api/base.ts"; -import {Agent, AgentConfig, AgentDetails} from "@/lib/types.ts"; +import {Agent, AgentConfig, AgentDetails, ExtraArgDetail} from "@/lib/types.ts"; import {ApiError} from "@/lib/api-error.ts"; export type AgentChoice = { key: string; name: string; agent_args: Record; - prompt_inputs: Record; + prompt_input: Record; prompt_extension: Record; - tools: Record[]; + tool_config: Record>; }; type AvailableAgentsResponse = { diff --git a/frontend/src/lib/form-utils.ts b/frontend/src/lib/form-utils.ts index 1d3873de..614391ca 100644 --- a/frontend/src/lib/form-utils.ts +++ b/frontend/src/lib/form-utils.ts @@ -3,7 +3,7 @@ */ export function flattenConfigForForm( - config: unknown, + config: unknown, sectionMapping?: Record ): Record { const flattenedConfig: Record = {}; @@ -15,17 +15,18 @@ export function flattenConfigForForm( Object.entries(config as Record).forEach(([section, sectionData]) => { const frontendSection = sectionMapping?.[section] || section; - if (Array.isArray(sectionData)) { - sectionData.forEach(obj => { - if (obj && typeof obj === 'object') { - Object.entries(obj).forEach(([key, value]) => { + if (section === 'tool_config' && sectionData && typeof sectionData === 'object' && !Array.isArray(sectionData)) { + // 2-level: {toolName: {fieldName: value}} + Object.entries(sectionData as Record>).forEach(([toolName, toolFields]) => { + if (toolFields && typeof toolFields === 'object') { + Object.entries(toolFields).forEach(([key, value]) => { if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { - flattenedConfig[`${frontendSection}_${key}`] = value; + flattenedConfig[`${section}_${toolName}_${key}`] = value; } else if (typeof value === 'object' && value !== null) { try { - flattenedConfig[`${frontendSection}_${key}`] = JSON.stringify(value); + flattenedConfig[`${section}_${toolName}_${key}`] = JSON.stringify(value); } catch { - flattenedConfig[`${frontendSection}_${key}`] = '{}'; + flattenedConfig[`${section}_${toolName}_${key}`] = '{}'; } } }); diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 38109580..6d6455c1 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -123,7 +123,7 @@ export type AgentConfig = { agent_args?: Record; prompt_input?: Record; prompt_extension?: Record; - tools?: Array>; + tool_config?: Record>; }; export type ConversationFile = { diff --git a/plugins/enthusiast-common/enthusiast_common/agents/base.py b/plugins/enthusiast-common/enthusiast_common/agents/base.py index 7238f2b8..9e165767 100644 --- a/plugins/enthusiast-common/enthusiast_common/agents/base.py +++ b/plugins/enthusiast-common/enthusiast_common/agents/base.py @@ -68,12 +68,17 @@ def _get_system_prompt_variables(self) -> dict: return {} def set_runtime_arguments(self, runtime_arguments: Any) -> None: - tools_runtime_arguments = runtime_arguments.pop("tools") + """Inject stored config values into agent fields and named tool instances.""" + tool_config = runtime_arguments.get("tool_config", {}) for key, value in runtime_arguments.items(): + if key == "tool_config": + continue class_field_key = key.upper() field = getattr(self, class_field_key) if field is None: continue setattr(self, key.upper(), field(**value)) - for index, tool_runtime_args in enumerate(tools_runtime_arguments): - self._tools[index].set_runtime_arguments(tool_runtime_args) + for tool in self._tools: + tool_runtime_args = tool_config.get(tool.NAME) + if tool_runtime_args is not None: + tool.set_runtime_arguments(tool_runtime_args) diff --git a/plugins/enthusiast-common/enthusiast_common/agents/config.py b/plugins/enthusiast-common/enthusiast_common/agents/config.py index 1c9db481..bff6c38e 100644 --- a/plugins/enthusiast-common/enthusiast_common/agents/config.py +++ b/plugins/enthusiast-common/enthusiast_common/agents/config.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from enum import StrEnum +from typing import TYPE_CHECKING -from ..config import AgentConfig +if TYPE_CHECKING: + from ..config.base import AgentConfig class ConfigType(StrEnum): diff --git a/server/agent/management/commands/verifyagents.py b/server/agent/management/commands/verifyagents.py index d7e40517..a4fd04dc 100644 --- a/server/agent/management/commands/verifyagents.py +++ b/server/agent/management/commands/verifyagents.py @@ -30,11 +30,11 @@ def handle(self, *args, **options): agent_types_cache[agent.agent_type] = agent_class config = deepcopy(agent.config) - tools_config = config.pop("tools", []) + tool_config = config.pop("tool_config", {}) try: - self._validate_agent_config(agent_class, config, tools_config) - except (ValidationError, IndexError): + self._validate_agent_config(agent_class, config, tool_config) + except (ValidationError, TypeError, AttributeError): corrupted_count += 1 agent.corrupted = True agent.save(update_fields=["corrupted"]) @@ -42,7 +42,7 @@ def handle(self, *args, **options): print(f"Corrupted agent configurations found: {corrupted_count}") @staticmethod - def _validate_agent_config(agent_class, config, tools_config): + def _validate_agent_config(agent_class, config, tool_config): for key, value in config.items(): class_field_key = key.upper() field = getattr(agent_class, class_field_key, None) @@ -50,11 +50,12 @@ def _validate_agent_config(agent_class, config, tools_config): continue field(**value) - if len(agent_class.TOOLS) != len(tools_config): - raise ValidationError.from_exception_data("Agent configurations do not match", line_errors=[]) - - for index, tool in enumerate(agent_class.TOOLS): - field = getattr(tool, "CONFIGURATION_ARGS", None) + for tool_config_entry in agent_class.TOOLS: + tool_class = tool_config_entry.tool_class + field = getattr(tool_class, "CONFIGURATION_ARGS", None) if field is None: continue - field(**tools_config[index]) + tool_name = tool_class.NAME + if tool_name not in tool_config: + raise ValidationError.from_exception_data("Agent configurations do not match", line_errors=[]) + field(**tool_config[tool_name]) diff --git a/server/agent/serializers/configuration.py b/server/agent/serializers/configuration.py index 8f335739..eb00517c 100644 --- a/server/agent/serializers/configuration.py +++ b/server/agent/serializers/configuration.py @@ -4,7 +4,7 @@ from agent.core.registries.agents.agent_registry import AgentRegistry from agent.models import Agent -from agent.serializers.customs.fields import PydanticModelField, PydanticModelToolListField +from agent.serializers.customs.fields import PydanticModelField, PydanticModelToolConfigField from catalog.models import DataSet @@ -14,7 +14,9 @@ class AgentChoiceSerializer(serializers.Serializer): agent_args = serializers.DictField(child=ExtraArgDetailSerializer(), allow_empty=True) prompt_input = serializers.DictField(child=ExtraArgDetailSerializer(), allow_empty=True) prompt_extension = serializers.DictField(child=ExtraArgDetailSerializer(), allow_empty=True) - tools = serializers.ListField(child=serializers.DictField(child=ExtraArgDetailSerializer()), allow_empty=True) + tool_config = serializers.DictField( + child=serializers.DictField(child=ExtraArgDetailSerializer(), allow_empty=True), allow_empty=True + ) class AvailableAgentsResponseSerializer(serializers.Serializer): @@ -27,7 +29,7 @@ class AgentConfigSerializer(ParentDataContextSerializerMixin, serializers.Serial agent_args = PydanticModelField(agent_field_name="AGENT_ARGS") prompt_input = PydanticModelField(agent_field_name="PROMPT_INPUT") prompt_extension = PydanticModelField(agent_field_name="PROMPT_EXTENSION") - tools = PydanticModelToolListField(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS") + tool_config = PydanticModelToolConfigField(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS") class AgentSerializer(ParentDataContextSerializerMixin, serializers.ModelSerializer): diff --git a/server/agent/serializers/customs/fields.py b/server/agent/serializers/customs/fields.py index 05b36e17..c48c9e0c 100644 --- a/server/agent/serializers/customs/fields.py +++ b/server/agent/serializers/customs/fields.py @@ -45,7 +45,7 @@ class Meta: swagger_schema_fields = {"type": openapi.TYPE_OBJECT} -class PydanticModelToolListField(BasePydanticModelField): +class PydanticModelToolConfigField(BasePydanticModelField): def __init__(self, *, agent_field_name: str, tool_field_name: str, **kwargs): self.agent_field_name = agent_field_name self.tool_field_name = tool_field_name @@ -62,30 +62,37 @@ def to_internal_value(self, data): except Exception as e: raise serializers.ValidationError(f"Error loading agent or field: {str(e)}") - if not isinstance(data, list): - raise serializers.ValidationError("Expected a list of tool configurations.") - if len(tool_config_list) != len(data): - raise serializers.ValidationError("Mismatch between number of tools and provided configs.") + if not isinstance(data, dict): + raise serializers.ValidationError("Expected a dict of tool configurations keyed by tool name.") - validated = [] - all_errors = [] + tool_map = { + tc.tool_class.NAME: tc.tool_class + for tc in tool_config_list + if getattr(tc.tool_class, "NAME", None) is not None + } + + validated = {} + all_errors = {} has_errors = False - for idx, (tool_config_obj, tool_config_dict) in enumerate(zip(tool_config_list, data)): - config_schema = getattr(tool_config_obj.tool_class, self.tool_field_name, None) + for tool_name, tool_config_dict in data.items(): + tool_class = tool_map.get(tool_name) + if tool_class is None: + all_errors[tool_name] = [f"Unknown tool: {tool_name}"] + has_errors = True + continue + + config_schema = getattr(tool_class, self.tool_field_name, None) if not config_schema or not isinstance(config_schema, type) or not issubclass(config_schema, BaseModel): - all_errors.append({}) - validated.append({}) + validated[tool_name] = {} continue try: instance = config_schema(**tool_config_dict) - validated.append(instance.model_dump()) - all_errors.append({}) + validated[tool_name] = instance.model_dump() except PydanticValidationError as e: has_errors = True - all_errors.append(self._format_pydantic_errors(e)) - validated.append(None) + all_errors[tool_name] = self._format_pydantic_errors(e) if has_errors: raise serializers.ValidationError(all_errors) @@ -93,12 +100,7 @@ def to_internal_value(self, data): return validated def to_representation(self, value): - if isinstance(value, list) and all(isinstance(v, BaseModel) for v in value): - return [v.model_dump() for v in value] return value class Meta: - swagger_schema_fields = { - "type": openapi.TYPE_ARRAY, - "items": {"type": openapi.TYPE_OBJECT}, - } + swagger_schema_fields = {"type": openapi.TYPE_OBJECT} diff --git a/server/agent/serializers/customs/tests/test_fields.py b/server/agent/serializers/customs/tests/test_fields.py index f9453a47..1e3b4b32 100644 --- a/server/agent/serializers/customs/tests/test_fields.py +++ b/server/agent/serializers/customs/tests/test_fields.py @@ -8,7 +8,7 @@ from rest_framework import serializers from rest_framework.exceptions import APIException, ValidationError -from agent.serializers.customs.fields import PydanticModelField, PydanticModelToolListField +from agent.serializers.customs.fields import PydanticModelField, PydanticModelToolConfigField class DummySchema(BaseModel): @@ -21,7 +21,13 @@ class BadSchema(BaseModel): class DummyTool(BaseFunctionTool): - CONFIGURATION = DummySchema + NAME = "dummy_tool" + CONFIGURATION_ARGS = DummySchema + + +class DummyTool2(BaseFunctionTool): + NAME = "dummy_tool_2" + CONFIGURATION_ARGS = DummySchema def get_model_serializer( @@ -34,13 +40,14 @@ class FieldTestSerializer(serializers.Serializer): return FieldTestSerializer(data=data, context={"agent_type": "dummy_agent"}) -def get_list_model_serializer(agent_field_name: str, tool_field_name: str, data: Any): +def get_tool_config_serializer(agent_field_name: str, tool_field_name: str, data: Any): class FieldTestSerializer(serializers.Serializer): - config = PydanticModelToolListField(agent_field_name=agent_field_name, tool_field_name=tool_field_name) + config = PydanticModelToolConfigField(agent_field_name=agent_field_name, tool_field_name=tool_field_name) return FieldTestSerializer(data=data, context={"agent_type": "dummy_agent"}) + @pytest.fixture def agent_context(): return {"agent_type": "dummy_agent"} @@ -95,15 +102,16 @@ def test_pydantic_model_field_import_error(mock_import, mock_settings, available assert "Error loading agent" in str(e.value) + @patch("agent.core.registries.agents.agent_registry.settings") @patch("agent.serializers.customs.fields.AgentRegistry.get_agent_class_by_type") -def test_pydantic_model_tool_list_field_valid(mock_import, mock_settings, available_agents): +def test_pydantic_model_tool_config_field_valid(mock_import, mock_settings, available_agents): mock_settings.AVAILABLE_AGENTS = available_agents mock_import.return_value = type( - "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool)]} + "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool2)]} ) - input_data = {"config": [{"value_1": "Alice", "value_2": 25}, {"value_1": "Bob", "value_2": 30}]} - serializer = get_list_model_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION", data=input_data) + input_data = {"config": {"dummy_tool": {"value_1": "Alice", "value_2": 25}}} + serializer = get_tool_config_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS", data=input_data) serializer.is_valid(raise_exception=True) @@ -112,42 +120,81 @@ def test_pydantic_model_tool_list_field_valid(mock_import, mock_settings, availa @patch("agent.core.registries.agents.agent_registry.settings") @patch("agent.serializers.customs.fields.AgentRegistry.get_agent_class_by_type") -def test_pydantic_model_tool_list_field_invalid_configs_number(mock_import, mock_settings, available_agents): +def test_pydantic_model_tool_config_field_invalid_type(mock_import, mock_settings, available_agents): mock_settings.AVAILABLE_AGENTS = available_agents - mock_import.return_value = type("Agent", (), {"TOOLS": [DummyTool]}) - input_data = {"config": []} - serializer = get_list_model_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION", data=input_data) + mock_import.return_value = type( + "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool)]} + ) + input_data = {"config": [{"value_1": "Alice", "value_2": 25}]} + serializer = get_tool_config_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS", data=input_data) with pytest.raises(ValidationError) as e: serializer.is_valid(raise_exception=True) - assert "Mismatch between number of tools and provided configs." in str(e.value) + assert "Expected a dict" in str(e.value) @patch("agent.core.registries.agents.agent_registry.settings") @patch("agent.serializers.customs.fields.AgentRegistry.get_agent_class_by_type") -def test_pydantic_model_tool_list_field_invalid_config_type(mock_import, mock_settings, available_agents): +def test_pydantic_model_tool_config_field_unknown_tool_name(mock_import, mock_settings, available_agents): mock_settings.AVAILABLE_AGENTS = available_agents - mock_import.return_value = type("Agent", (), {"TOOLS": [DummyTool]}) - input_data = {"config": {}} - serializer = get_list_model_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION", data=input_data) + mock_import.return_value = type( + "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool)]} + ) + input_data = {"config": {"nonexistent_tool": {"value_1": "Alice", "value_2": 25}}} + serializer = get_tool_config_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS", data=input_data) with pytest.raises(ValidationError) as e: serializer.is_valid(raise_exception=True) - assert "Expected a list of tool configurations." in str(e.value) + error_detail = e.value.detail["config"]["nonexistent_tool"] + assert isinstance(error_detail, list) + assert "Unknown tool" in str(error_detail[0]) @patch("agent.core.registries.agents.agent_registry.settings") @patch("agent.serializers.customs.fields.AgentRegistry.get_agent_class_by_type") -def test_pydantic_model_tool_list_field_invalid_data(mock_import, mock_settings, agent_context, available_agents): +def test_pydantic_model_tool_config_field_invalid_field_values(mock_import, mock_settings, available_agents): mock_settings.AVAILABLE_AGENTS = available_agents mock_import.return_value = type( - "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool)]} + "Agent", (), {"TOOLS": [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool2)]} ) - input_data = {"config": [{"value_1": "Missing age"}, {"value_2": 28}]} - serializer = get_list_model_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION", data=input_data) + input_data = {"config": { + "dummy_tool": {"value_1": "Missing value_2"}, + "dummy_tool_2": {"value_2": 99}, + }} + serializer = get_tool_config_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS", data=input_data) serializer.is_valid() - assert len(serializer.errors["config"]) == 2 - assert "value_2" in serializer.errors["config"][0].keys() - assert "value_1" in serializer.errors["config"][1].keys() + assert "dummy_tool" in serializer.errors["config"] + assert "dummy_tool_2" in serializer.errors["config"] + + +@patch("agent.core.registries.agents.agent_registry.settings") +@patch("agent.serializers.customs.fields.AgentRegistry.get_agent_class_by_type") +def test_pydantic_model_tool_config_field_tool_without_config_args_is_accepted( + mock_import, mock_settings, available_agents +): + class NoConfigTool: + NAME = "no_config_tool" + CONFIGURATION_ARGS = None + + class MockToolConfigEntry: + def __init__(self, tool_class): + self.tool_class = tool_class + + mock_settings.AVAILABLE_AGENTS = available_agents + mock_import.return_value = type( + "Agent", + (), + { + "TOOLS": [ + MockToolConfigEntry(DummyTool), + MockToolConfigEntry(NoConfigTool), + ] + }, + ) + input_data = {"config": {"no_config_tool": {}}} + serializer = get_tool_config_serializer(agent_field_name="TOOLS", tool_field_name="CONFIGURATION_ARGS", data=input_data) + + assert serializer.is_valid() + assert serializer.validated_data["config"]["no_config_tool"] == {} diff --git a/server/agent/services/agent_preconfiguration_service.py b/server/agent/services/agent_preconfiguration_service.py index 5d25d12d..f2f86ad1 100644 --- a/server/agent/services/agent_preconfiguration_service.py +++ b/server/agent/services/agent_preconfiguration_service.py @@ -44,8 +44,11 @@ def _build_default_agent_configuration(agent_class: BaseAgent): "agent_args": get_model_descriptor_default_value_from_class(agent_class, "AGENT_ARGS"), "prompt_input": get_model_descriptor_default_value_from_class(agent_class, "PROMPT_INPUT"), "prompt_extension": get_model_descriptor_default_value_from_class(agent_class, "PROMPT_EXTENSION"), - "tools": [ - get_model_descriptor_default_value_from_class(tool_config.tool_class, "CONFIGURATION_ARGS") + "tool_config": { + tool_config.tool_class.NAME: get_model_descriptor_default_value_from_class( + tool_config.tool_class, "CONFIGURATION_ARGS" + ) for tool_config in agent_class.TOOLS - ], + if tool_config.tool_class.CONFIGURATION_ARGS is not None + }, } diff --git a/server/agent/tests/test_management_commands.py b/server/agent/tests/test_management_commands.py index f7358283..10940bb8 100644 --- a/server/agent/tests/test_management_commands.py +++ b/server/agent/tests/test_management_commands.py @@ -23,14 +23,25 @@ class MockToolConfigurationArgs(BaseModel): class MockToolClass: + NAME = "mock_tool" CONFIGURATION_ARGS = MockToolConfigurationArgs +class MockToolClass2: + NAME = "mock_tool_2" + CONFIGURATION_ARGS = MockToolConfigurationArgs + + +class MockToolConfig: + def __init__(self, tool_class): + self.tool_class = tool_class + + class MockAgentClass: AGENT_ARGS = MockConfigurationArgs PROMPT_INPUT = MockConfigurationArgs PROMPT_EXTENSION = MockConfigurationArgs - TOOLS = [MockToolClass, MockToolClass] + TOOLS = [MockToolConfig(MockToolClass), MockToolConfig(MockToolClass2)] @pytest.fixture @@ -39,10 +50,10 @@ def config_dict() -> dict[Any, Any]: "agent_args": {"value_1": "value_1", "value_2": "value_2"}, "prompt_input": {"value_1": "value_1", "value_2": "value_2"}, "prompt_extension": {"value_1": "value_1", "value_2": "value_2"}, - "tools": [ - {"tool_value_1": "value_1", "tool_value_2": True}, - {"tool_value_1": "value_1", "tool_value_2": False}, - ], + "tool_config": { + "mock_tool": {"tool_value_1": "value_1", "tool_value_2": True}, + "mock_tool_2": {"tool_value_1": "value_1", "tool_value_2": False}, + }, } @@ -92,8 +103,9 @@ def test_verifyagents_command_invalid_type(self, mock_agent_registry, config_dic assert corrupted_agent.corrupted is True @patch("agent.management.commands.verifyagents.AgentRegistry") - def test_verifyagents_command_missing_tool_config(self, mock_agent_registry, config_dict): - config_dict["tools"] = [{"tool_value_1": "tool1", "tool_value_2": True}, {}] + def test_verifyagents_command_missing_tool_name_in_tool_config(self, mock_agent_registry, config_dict): + # mock_tool has valid config, but mock_tool_2 key is entirely absent + config_dict["tool_config"] = {"mock_tool": {"tool_value_1": "value_1", "tool_value_2": True}} corrupted_agent = baker.make(Agent, name="Corrupted Tools", agent_type=self.AGENT_TYPE, config=config_dict) mock_registry_instance = Mock() @@ -106,11 +118,8 @@ def test_verifyagents_command_missing_tool_config(self, mock_agent_registry, con assert corrupted_agent.corrupted is True @patch("agent.management.commands.verifyagents.AgentRegistry") - def test_verifyagents_command_invalid_tool_config_type(self, mock_agent_registry, config_dict): - config_dict["tools"] = [ - {"tool_value_1": "tool1", "tool_value_2": True}, - {"tool_value_1": "tool1", "tool_value_2": "InvalidType"}, - ] + def test_verifyagents_command_missing_tool_config(self, mock_agent_registry, config_dict): + config_dict["tool_config"] = {"mock_tool": {}} corrupted_agent = baker.make(Agent, name="Corrupted Tools", agent_type=self.AGENT_TYPE, config=config_dict) mock_registry_instance = Mock() @@ -123,8 +132,11 @@ def test_verifyagents_command_invalid_tool_config_type(self, mock_agent_registry assert corrupted_agent.corrupted is True @patch("agent.management.commands.verifyagents.AgentRegistry") - def test_verifyagents_command_additional_tool(self, mock_agent_registry, config_dict): - config_dict["tools"] = [] + def test_verifyagents_command_invalid_tool_config_type(self, mock_agent_registry, config_dict): + config_dict["tool_config"] = { + "mock_tool": {"tool_value_1": "val", "tool_value_2": True}, + "mock_tool_2": {"tool_value_1": "val", "tool_value_2": "InvalidType"}, + } corrupted_agent = baker.make(Agent, name="Corrupted Tools", agent_type=self.AGENT_TYPE, config=config_dict) mock_registry_instance = Mock() @@ -137,12 +149,8 @@ def test_verifyagents_command_additional_tool(self, mock_agent_registry, config_ assert corrupted_agent.corrupted is True @patch("agent.management.commands.verifyagents.AgentRegistry") - def test_verifyagents_command_missing_tool(self, mock_agent_registry, config_dict): - config_dict["tools"] = [ - {"tool_value_1": "value_1", "tool_value_2": True}, - {"tool_value_1": "value_1", "tool_value_2": True}, - {"tool_value_1": "value_1", "tool_value_2": True}, - ] + def test_verifyagents_command_additional_tool(self, mock_agent_registry, config_dict): + config_dict["tool_config"] = {} corrupted_agent = baker.make(Agent, name="Corrupted Tools", agent_type=self.AGENT_TYPE, config=config_dict) mock_registry_instance = Mock() @@ -154,6 +162,20 @@ def test_verifyagents_command_missing_tool(self, mock_agent_registry, config_dic corrupted_agent.refresh_from_db() assert corrupted_agent.corrupted is True + @patch("agent.management.commands.verifyagents.AgentRegistry") + def test_verifyagents_command_extra_tool_key_is_ignored(self, mock_agent_registry, config_dict): + config_dict["tool_config"]["unknown_tool"] = {"tool_value_1": "value_1", "tool_value_2": True} + valid_agent = baker.make(Agent, name="Extra Tool Key Agent", agent_type=self.AGENT_TYPE, config=config_dict) + + mock_registry_instance = Mock() + mock_registry_instance.get_agent_class_by_type.return_value = self.MOCK_AGENT_CLASS + mock_agent_registry.return_value = mock_registry_instance + + call_command("verifyagents") + + valid_agent.refresh_from_db() + assert valid_agent.corrupted is False + @patch("agent.management.commands.verifyagents.AgentRegistry") def test_verifyagents_command_missing_agent_type(self, mock_agent_registry, config_dict): baker.make(Agent, agent_type=self.AGENT_TYPE, config=config_dict, _quantity=3) @@ -169,7 +191,7 @@ def test_verifyagents_command_missing_agent_type(self, mock_agent_registry, conf @patch("agent.management.commands.verifyagents.AgentRegistry") def test_verifyagents_command_multiple_corrupted_agents(self, mock_agent_registry, config_dict): - config_dict["tools"] = [{}, {}] + config_dict["tool_config"] = {"mock_tool": {}} baker.make(Agent, name="Corrupted Agent 1", agent_type=self.AGENT_TYPE, config=config_dict) baker.make(Agent, name="Corrupted Agent 2", agent_type=self.AGENT_TYPE, config=config_dict) @@ -211,13 +233,29 @@ class MockAgentClassWithNoneFields: Agent, name="Minimal Agent", agent_type="minimal_agent", - config={"agent_args": {}, "prompt_input": {}, "prompt_extension": {}, "tools": []}, + config={"agent_args": {}, "prompt_input": {}, "prompt_extension": {}, "tool_config": {}}, ) call_command("verifyagents") assert Agent.objects.filter(corrupted=True).count() == 0 + @patch("agent.management.commands.verifyagents.AgentRegistry") + def test_verifyagents_command_non_dict_config_value(self, mock_agent_registry): + corrupted_agent = baker.make( + Agent, + agent_type=self.AGENT_TYPE, + config={"agent_args": "not_a_dict", "prompt_input": {}, "prompt_extension": {}, "tool_config": {}}, + ) + mock_registry_instance = Mock() + mock_registry_instance.get_agent_class_by_type.return_value = self.MOCK_AGENT_CLASS + mock_agent_registry.return_value = mock_registry_instance + + call_command("verifyagents") + + corrupted_agent.refresh_from_db() + assert corrupted_agent.corrupted is True + def test_verifyagents_command_empty_database(self): Agent.objects.all().delete() diff --git a/server/agent/tests/test_runtime_arguments.py b/server/agent/tests/test_runtime_arguments.py new file mode 100644 index 00000000..0d90a201 --- /dev/null +++ b/server/agent/tests/test_runtime_arguments.py @@ -0,0 +1,82 @@ +from unittest.mock import MagicMock + +from enthusiast_common.agents.base import BaseAgent + +# Pure-Python unit tests — no Django DB required, no pytest.mark.django_db needed. +# _StubAgent is a minimal non-metaclass object used to call BaseAgent.set_runtime_arguments +# directly as an unbound method, bypassing the AgentExtraArgsClassBaseMeta validation. + + +class _StubAgent: + """Minimal agent-like object for testing set_runtime_arguments without metaclass validation.""" + AGENT_ARGS = None + PROMPT_INPUT = None + PROMPT_EXTENSION = None + + def __init__(self, tools): + self._tools = tools + + +def test_set_runtime_arguments_injects_config_by_tool_name(): + tool = MagicMock() + tool.NAME = "my_tool" + agent = _StubAgent(tools=[tool]) + + BaseAgent.set_runtime_arguments(agent, { + "agent_args": {}, + "prompt_input": {}, + "prompt_extension": {}, + "tool_config": {"my_tool": {"proxy": "http://example.com"}}, + }) + + tool.set_runtime_arguments.assert_called_once_with({"proxy": "http://example.com"}) + + +def test_set_runtime_arguments_skips_tool_absent_from_tool_config(): + """Tools with no CONFIGURATION_ARGS (e.g. StopExecutionTool) won't have a tool_config entry — must not crash.""" + tool = MagicMock() + tool.NAME = "stop_execution" + agent = _StubAgent(tools=[tool]) + + BaseAgent.set_runtime_arguments(agent, { + "agent_args": {}, + "prompt_input": {}, + "prompt_extension": {}, + "tool_config": {}, + }) + + tool.set_runtime_arguments.assert_not_called() + + +def test_set_runtime_arguments_tolerates_missing_tool_config_key(): + """Old-format configs without a tool_config key must not crash.""" + tool = MagicMock() + tool.NAME = "my_tool" + agent = _StubAgent(tools=[tool]) + + BaseAgent.set_runtime_arguments(agent, { + "agent_args": {}, + "prompt_input": {}, + "prompt_extension": {}, + }) + + tool.set_runtime_arguments.assert_not_called() + + +def test_set_runtime_arguments_only_injects_matching_tool(): + """When multiple tools are present, only the matching one receives its config.""" + tool_a = MagicMock() + tool_a.NAME = "tool_a" + tool_b = MagicMock() + tool_b.NAME = "tool_b" + agent = _StubAgent(tools=[tool_a, tool_b]) + + BaseAgent.set_runtime_arguments(agent, { + "agent_args": {}, + "prompt_input": {}, + "prompt_extension": {}, + "tool_config": {"tool_a": {"key": "value"}}, + }) + + tool_a.set_runtime_arguments.assert_called_once_with({"key": "value"}) + tool_b.set_runtime_arguments.assert_not_called() diff --git a/server/agent/tests/test_services.py b/server/agent/tests/test_services.py index 8257e64a..64d0f3b6 100644 --- a/server/agent/tests/test_services.py +++ b/server/agent/tests/test_services.py @@ -19,6 +19,7 @@ class ToolArgs(RequiredFieldsModel): class DummyTool(BaseFunctionTool): + NAME = "dummy_tool" CONFIGURATION_ARGS = ToolArgs @@ -45,7 +46,7 @@ class MockAgentClass: AGENT_ARGS = AgentArgs PROMPT_INPUT = PromptInput PROMPT_EXTENSION = PromptExtension - TOOLS = [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool)] + TOOLS = [FunctionToolConfig(tool_class=DummyTool)] FILE_UPLOAD = False @@ -67,7 +68,7 @@ class MockAgentClassWithoutDefaults: "agent_args": {"with_default": "default"}, "prompt_extension": {"with_default": "default"}, "prompt_input": {"with_default": "default"}, - "tools": [{"with_default": "default"}, {"with_default": "default"}], + "tool_config": {"dummy_tool": {"with_default": "default"}}, } diff --git a/server/agent/tests/test_views.py b/server/agent/tests/test_views.py index 46a1d194..8758a87e 100644 --- a/server/agent/tests/test_views.py +++ b/server/agent/tests/test_views.py @@ -58,6 +58,7 @@ class ToolArgs(RequiredFieldsModel): class DummyTool(BaseFunctionTool): + NAME = "dummy_tool" CONFIGURATION_ARGS = ToolArgs @@ -80,7 +81,7 @@ class DummyAgentBase: AGENT_ARGS = AgentArgs PROMPT_INPUT = PromptInput PROMPT_EXTENSION = PromptExtension - TOOLS = [FunctionToolConfig(tool_class=DummyTool), FunctionToolConfig(tool_class=DummyTool)] + TOOLS = [FunctionToolConfig(tool_class=DummyTool)] FILE_UPLOAD = False @@ -115,10 +116,9 @@ def config(): "required_test": "required_test", "optional_test": "optional_test", }, - "tools": [ - {"required_test": "required_test", "optional_test": "optional_test"}, - {"required_test": "required_test", "optional_test": "optional_test"}, - ], + "tool_config": { + "dummy_tool": {"required_test": "required_test", "optional_test": "optional_test"}, + }, } @@ -132,9 +132,12 @@ def test_get_agent_types_returns_200(self, api_client): assert len(response.data["choices"]) == 2 assert response.data["choices"][0]["key"] == "agent_1" assert response.data["choices"][1]["key"] == "agent_2" - assert len(response.data["choices"][0]["tools"]) == 2 - assert len(response.data["choices"][1]["tools"]) == 2 - assert list(response.data["choices"][0]["tools"][0].keys()) == ["required_test", "optional_test"] + assert len(response.data["choices"][0]["tool_config"]) == 1 + assert len(response.data["choices"][1]["tool_config"]) == 1 + assert list(response.data["choices"][0]["tool_config"]["dummy_tool"].keys()) == [ + "required_test", + "optional_test", + ] def test_get_agent_types_returns_401(self): response = APIClient().get(reverse("agent-types")) @@ -223,14 +226,11 @@ def test_post_creates_agent(self, api_client, url, config): "prompt_extension": { "required_test": "required_test", }, - "tools": [ - { - "required_test": "required_test", - }, - { + "tool_config": { + "dummy_tool": { "required_test": "required_test", }, - ], + }, } payload = { "name": "name", @@ -262,10 +262,9 @@ def test_post_creates_agent_optional_fields_saved(self, api_client, url, config) "required_test": "required_test", "optional_test": "optional_test", }, - "tools": [ - {"required_test": "required_test", "optional_test": "optional_test"}, - {"required_test": "required_test", "optional_test": "optional_test"}, - ], + "tool_config": { + "dummy_tool": {"required_test": "required_test", "optional_test": "optional_test"}, + }, } payload = { "name": "name", @@ -281,8 +280,7 @@ def test_post_creates_agent_optional_fields_saved(self, api_client, url, config) assert response.status_code == status.HTTP_201_CREATED created = Agent.objects.get(pk=response.data["id"]) assert created.name == "name" - assert created.config["tools"][0].get("optional_test") == "optional_test" - assert created.config["tools"][1].get("optional_test") == "optional_test" + assert created.config["tool_config"]["dummy_tool"].get("optional_test") == "optional_test" assert created.config["agent_args"].get("optional_test") == "optional_test" assert created.config["prompt_input"].get("optional_test") == "optional_test" assert created.config["prompt_extension"].get("optional_test") == "optional_test" @@ -299,7 +297,9 @@ def test_post_creates_agent_do_not_save_empty_field(self, api_client, url, confi "prompt_extension": { "required_test": "required_test", }, - "tools": [{"required_test": "required_test"}, {"required_test": "required_test"}], + "tool_config": { + "dummy_tool": {"required_test": "required_test"}, + }, } payload = { "name": "name", @@ -309,14 +309,11 @@ def test_post_creates_agent_do_not_save_empty_field(self, api_client, url, confi "agent_type": "agent_1", } - class NoArgsDummyTool(BaseFunctionTool): - CONFIGURATION_ARGS = None - class NoArgsDummyAgent: AGENT_ARGS = None PROMPT_INPUT = PromptInput PROMPT_EXTENSION = PromptExtension - TOOLS = [FunctionToolConfig(tool_class=NoArgsDummyTool), FunctionToolConfig(tool_class=DummyTool)] + TOOLS = [FunctionToolConfig(tool_class=DummyTool)] FILE_UPLOAD = False with patch( @@ -328,8 +325,10 @@ class NoArgsDummyAgent: created = Agent.objects.get(pk=response.data["id"]) assert created.name == "name" assert created.config["agent_args"] == {} - assert created.config["tools"][0] == {} - assert created.config["tools"][1] == {"required_test": "required_test", "optional_test": "default"} + assert created.config["tool_config"]["dummy_tool"] == { + "required_test": "required_test", + "optional_test": "default", + } class TestAgentDetailsView: @@ -370,10 +369,9 @@ def test_put_updates_agent(self, api_client, agent_instance, url, config): "required_test": "required_updated", "optional_test": "optional_updated", }, - "tools": [ - {"required_test": "required_upated", "optional_test": "optional_updated"}, - {"required_test": "required_updated", "optional_test": "optional_updated"}, - ], + "tool_config": { + "dummy_tool": {"required_test": "required_updated", "optional_test": "optional_updated"}, + }, } payload = { "name": "updated", @@ -407,10 +405,9 @@ def test_put_removes_corrupted_flag_for_correct_data(self, api_client, url, conf "required_test": "required_updated", "optional_test": "optional_updated", }, - "tools": [ - {"required_test": "required_upated", "optional_test": "optional_updated"}, - {"required_test": "required_updated", "optional_test": "optional_updated"}, - ], + "tool_config": { + "dummy_tool": {"required_test": "required_updated", "optional_test": "optional_updated"}, + }, } payload = { "name": "updated", diff --git a/server/agent/views.py b/server/agent/views.py index 11fd623f..62099e0b 100644 --- a/server/agent/views.py +++ b/server/agent/views.py @@ -254,10 +254,13 @@ def get(self, request): "agent_args": get_model_descriptor_from_class_field(agent_class, "AGENT_ARGS"), "prompt_input": get_model_descriptor_from_class_field(agent_class, "PROMPT_INPUT"), "prompt_extension": get_model_descriptor_from_class_field(agent_class, "PROMPT_EXTENSION"), - "tools": [ - get_model_descriptor_from_class_field(tool_config.tool_class, "CONFIGURATION_ARGS") + "tool_config": { + tool_config.tool_class.NAME: get_model_descriptor_from_class_field( + tool_config.tool_class, "CONFIGURATION_ARGS" + ) for tool_config in agent_class.TOOLS - ], + if tool_config.tool_class.CONFIGURATION_ARGS is not None + }, } ) response_serializer = AvailableAgentsResponseSerializer(data={"choices": choices}) diff --git a/specs/tool-configuration-model-refactor.md b/specs/tool-configuration-model-refactor.md new file mode 100644 index 00000000..e6f874cf --- /dev/null +++ b/specs/tool-configuration-model-refactor.md @@ -0,0 +1,112 @@ +# Tool Configuration Model Refactor + +## Summary + +Replaced the positional `tools: [{...}, {}]` array in agent DB configs with a named `tool_config: {"tool_name": {...}}` dict. This fixes a crash in agentic execution and removes positional coupling between the runtime tool list and stored configuration. + +--- + +## Problem + +Tools in the system define per-tool configuration via `CONFIGURATION_ARGS` on `BaseTool`. That configuration was persisted as a positional array under the `tools` key in the agent's JSON config: + +```json +{ + "agent_args": {}, + "prompt_input": {"auto_confirm": true}, + "prompt_extension": {}, + "tools": [{"field_name": ""}, {}] +} +``` + +At runtime, `BaseAgent.set_runtime_arguments()` injected config into each tool by array index: `self._tools[i].set_runtime_arguments(config["tools"][i])`. + +Two concrete problems: + +1. **Positional coupling.** `StopExecutionTool` is appended to the tool list during agentic execution at runtime. The DB config was created with N tools; at runtime there are N+1. The index access for the last tool raised `IndexError` — all agentic executions crashed on initialisation. + +2. **Distributed config responsibility.** Agent configuration was split across the agent (agent_args, prompt fields) and its tools (tools array). No single place held the full config. + +--- + +## Solution + +`tool_config` is a dict keyed by `tool.NAME`. Tools with no `CONFIGURATION_ARGS` have no entry. A missing entry is silently skipped — not an error. + +```json +{ + "agent_args": {}, + "prompt_input": {"auto_confirm": true}, + "prompt_extension": {}, + "tool_config": { + "tool_name": {"field_name": "value"} + } +} +``` + +--- + +## Backend Changes + +### `BaseAgent.set_runtime_arguments` (`enthusiast-common`) + +Named lookup replaces the positional index loop: + +```python +def set_runtime_arguments(self, runtime_arguments: Any) -> None: + tool_config = runtime_arguments.get("tool_config", {}) + for key, value in runtime_arguments.items(): + if key == "tool_config": + continue + # ... agent-level field injection unchanged ... + for tool in self._tools: + tool_runtime_args = tool_config.get(tool.NAME) + if tool_runtime_args is not None: + tool.set_runtime_arguments(tool_runtime_args) +``` + +Tools absent from `tool_config` are skipped without raising. This fixes the `StopExecutionTool` crash. + +### `AgentPreconfigurationService._build_default_agent_configuration` + +Builds a `tool_config` dict instead of a `tools` list. Tools with `CONFIGURATION_ARGS = None` are excluded: + +```python +"tool_config": { + tool_config.tool_class.NAME: get_model_descriptor_default_value_from_class( + tool_config.tool_class, "CONFIGURATION_ARGS" + ) + for tool_config in agent_class.TOOLS + if tool_config.tool_class.CONFIGURATION_ARGS is not None +}, +``` + +### `PydanticModelToolConfigField` (new, replaces `PydanticModelToolListField`) + +Validates the `tool_config` dict on API config updates: + +- Accepts `{"tool_name": {"field_name": "value"}}` instead of `[{"field_name": "value"}, {}]` +- Validates each key against the agent's tool list by `NAME` +- Validates values against the tool's `CONFIGURATION_ARGS` Pydantic schema +- Unknown tool names → validation error `"Unknown tool: {name}"` +- Tools with `CONFIGURATION_ARGS = None` are not valid keys + +`AgentConfigSerializer` and `AgentChoiceSerializer` were updated to use `tool_config`. `AgentTypesView` builds the `tool_config` dict filtered by `CONFIGURATION_ARGS is not None`. + +### `verifyagents` Management Command + +Updated to validate tools by named lookup via `tool_config_entry.tool_class.NAME` instead of positional index. Missing keys in `tool_config` mark the agent as corrupted. Non-dict config values (`TypeError`) are also caught. + +--- + +## Frontend Changes + +`AgentConfig.tool_config` and `AgentChoice.tool_config` replace the old `tools` array in the TypeScript types. + +The form flattens the two-level structure to `tool_config_${toolName}_${fieldName}` flat keys for React Hook Form state, then reconstructs the nested dict on submit. Server-side validation errors are mapped back to the same flat keys for inline field highlighting. The config form renders each tool group in its own labelled bordered card. + +--- + +## Breaking Change + +No DB migration. Existing agent configs with the old `tools: [...]` format are incompatible. **Agents must be re-preconfigured on deploy.**