diff --git a/examples/chat/file_explorer_agent.py b/examples/chat/file_explorer_agent.py index 786d245053..b8c591df6b 100644 --- a/examples/chat/file_explorer_agent.py +++ b/examples/chat/file_explorer_agent.py @@ -564,12 +564,11 @@ async def chat( # noqa: PLR0912 """ Chat implementation with non-blocking confirmation support. - The agent will check context.tool_confirmations for any confirmations. - If a hook needs confirmation but hasn't been confirmed yet, it will - yield a ConfirmationRequest and exit. The frontend will then send a - new request with the confirmation in context.tool_confirmations. + Confirmed tools from the previous turn are executed directly by the chat layer + (via ``resolve_pending_confirmations``), then their results are injected into + history before the agent continues. The LLM never regenerates the tool call, + so the confirmation_id hash stays stable and there is no approval loop. """ - # Create agent with history passed explicitly agent: Agent = Agent( llm=self.llm, prompt=f""" @@ -592,10 +591,15 @@ async def chat( # noqa: PLR0912 history=history, ) - # Create agent context with tool_confirmations from the request context - agent_context: AgentRunContext = AgentRunContext() + # Execute tools the user just approved/declined: mutates agent.history in place + # with synthetic (tool_use, tool_result) pairs so the LLM continues from the + # results instead of re-deciding the confirmed call. + for response in await self.resolve_pending_confirmations(agent, context): + yield response - # Pass tool_confirmations from ChatContext to AgentRunContext + # Forward tool_confirmations to the agent context — supports any legacy + # hash-matched confirmations for tools not routed through direct execution. + agent_context: AgentRunContext = AgentRunContext() if context.tool_confirmations: agent_context.tool_confirmations = context.tool_confirmations @@ -616,7 +620,10 @@ async def chat( # noqa: PLR0912 yield self.create_live_update(response.id, LiveUpdateType.START, f"🔧 {response.name}") case ConfirmationRequest(): - # Confirmation needed - send to frontend and wait for user response + # Persist the pending confirmation so the next turn can resolve it + # directly (via resolve_pending_confirmations) instead of asking + # the LLM to regenerate the tool call. + yield self.create_state_update(self.create_pending_confirmation_state(response)) yield ConfirmationRequestResponse(content=ConfirmationRequestContent(confirmation_request=response)) case ToolCallResult(): diff --git a/packages/ragbits-agents/CHANGELOG.md b/packages/ragbits-agents/CHANGELOG.md index 8d897dc4be..b75d3a8cc9 100644 --- a/packages/ragbits-agents/CHANGELOG.md +++ b/packages/ragbits-agents/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Fix tool-confirmation loop across conversation turns caused by LLM argument drift (#969). Adds `Agent.execute_tool_directly` and an `inject_tool_call` helper so chat layers can resume confirmed tools without re-prompting the LLM. `HookManager` also falls back to matching by `tool_name` when the exact `confirmation_id` hash misses. `ConfirmationRequest` now carries `tool_call_id`. + ## 1.6.2 (2026-03-26) - ragbits-core updated to version v1.6.2 diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index 59a473083d..02987c0c4d 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -1,4 +1,5 @@ import asyncio +import json import types import uuid import warnings @@ -1061,6 +1062,57 @@ async def _process_tool_output( else: yield ToolReturn(value=tool_output, metadata=None) + async def execute_tool_directly( + self, + tool_call_id: str, + tool_name: str, + arguments: dict[str, Any], + context: AgentRunContext, + ) -> ToolCallResult: + """ + Execute a tool with caller-supplied arguments, returning its final result. + + Intended for chat layers resuming after a user confirmation: rather than asking + the LLM to regenerate the tool call (which risks argument drift and a broken + confirmation_id match), the chat layer stores the pre-confirmation arguments + and replays them directly through this method. + + PRE_TOOL hooks still run. If a hook requests confirmation, the caller is + responsible for having populated ``context.tool_confirmations`` with the + matching approval so the existing hash-match path resolves to ``pass``. + A ``deny`` decision from any PRE_TOOL hook is respected and short-circuits + execution. POST_TOOL hooks run on success. + + Args: + tool_call_id: Identifier to attach to the resulting ``ToolCallResult``. + tool_name: Name of the tool to invoke. + arguments: Arguments to pass to the tool (should be the original + pre-confirmation arguments to keep the confirmation_id stable). + context: Agent run context, including any prior ``tool_confirmations``. + + Returns: + The ``ToolCallResult`` yielded by the tool execution path. + + Raises: + AgentToolNotAvailableError: If the tool is not registered on this agent. + """ + tools_mapping = await self._get_all_tools() + # ToolCall declares arguments as dict but has a "before" validator that + # json.loads strings, so we pass the serialized form to satisfy the validator. + tool_call = ToolCall( + id=tool_call_id, + type="function", + name=tool_name, + arguments=json.dumps(arguments), # type: ignore[arg-type] + ) + result: ToolCallResult | None = None + async for item in self._execute_tool(tool_call=tool_call, tools_mapping=tools_mapping, context=context): + if isinstance(item, ToolCallResult): + result = item + if result is None: + raise RuntimeError(f"Tool {tool_name!r} produced no ToolCallResult") + return result + async def _execute_tool( self, tool_call: ToolCall, diff --git a/packages/ragbits-agents/src/ragbits/agents/confirmation.py b/packages/ragbits-agents/src/ragbits/agents/confirmation.py index 6cd0483b4c..15e1a4f36b 100644 --- a/packages/ragbits-agents/src/ragbits/agents/confirmation.py +++ b/packages/ragbits-agents/src/ragbits/agents/confirmation.py @@ -14,6 +14,9 @@ class ConfirmationRequest(BaseModel): confirmation_id: str """Unique identifier for this confirmation request.""" + tool_call_id: str + """Identifier of the originating ToolCall — threads the tool_use and tool_result messages + when the chat layer resumes execution via Agent.execute_tool_directly.""" tool_name: str """Name of the tool requiring confirmation.""" tool_description: str diff --git a/packages/ragbits-agents/src/ragbits/agents/history.py b/packages/ragbits-agents/src/ragbits/agents/history.py new file mode 100644 index 0000000000..da2cce24ce --- /dev/null +++ b/packages/ragbits-agents/src/ragbits/agents/history.py @@ -0,0 +1,51 @@ +"""Helpers for manipulating agent conversation history (ChatFormat).""" + +import json +from typing import Any + +from ragbits.core.prompt.base import ChatFormat + + +def inject_tool_call( + history: ChatFormat, + tool_call_id: str, + tool_name: str, + arguments: dict[str, Any], + result: Any, # noqa: ANN401 +) -> ChatFormat: + """ + Append a synthetic (tool_use, tool_result) pair to a conversation history. + + Used by the chat layer when it has executed a tool on the user's behalf + (e.g., after a confirmation was approved) and needs the LLM to see the + outcome without re-deciding the call itself. + + The returned list is a shallow copy with two messages appended in OpenAI's + tool-use format — an ``assistant`` turn carrying the ``tool_calls`` block + and a ``tool`` turn carrying the result keyed by ``tool_call_id``. + + Args: + history: Current conversation history. Not mutated. + tool_call_id: Identifier to thread the tool_use and tool messages. + tool_name: Name of the tool that was invoked. + arguments: Arguments the tool was invoked with. + result: Tool output. Coerced to ``str``. + + Returns: + A new ChatFormat with the two messages appended. + """ + return [ + *history, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": json.dumps(arguments)}, + } + ], + }, + {"role": "tool", "tool_call_id": tool_call_id, "content": str(result)}, + ] diff --git a/packages/ragbits-agents/src/ragbits/agents/hooks/manager.py b/packages/ragbits-agents/src/ragbits/agents/hooks/manager.py index 58550ce0f8..00c5746688 100644 --- a/packages/ragbits-agents/src/ragbits/agents/hooks/manager.py +++ b/packages/ragbits-agents/src/ragbits/agents/hooks/manager.py @@ -9,7 +9,7 @@ import json from collections import defaultdict from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Generic, Literal, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, overload from ragbits.agents.confirmation import ConfirmationRequest from ragbits.agents.hooks.base import Hook @@ -111,6 +111,48 @@ def get_hooks(self, event_type: EventType, tool_name: str | None = None) -> list return [hook for hook in hooks if hook.matches_tool(tool_name)] + @staticmethod + def _compute_confirmation_id(hook_name: str, tool_name: str, arguments: dict[str, Any]) -> str: + """ + Compute the confirmation_id for a given (hook, tool, arguments) triple. + + Exposed so the chat layer (or tests) can reproduce the same id when resuming + a paused confirmation flow. + """ + payload = f"{hook_name}:{tool_name}:{json.dumps(arguments, sort_keys=True)}" + return hashlib.sha256(payload.encode()).hexdigest()[:CONFIRMATION_ID_LENGTH] + + @staticmethod + def _find_confirmation( + tool_confirmations: list[dict[str, Any]], + confirmation_id: str, + tool_name: str, + ) -> dict[str, Any] | None: + """ + Find a matching confirmation entry. + + Tries exact confirmation_id match first, then falls back to tool_name match. + The tool_name fallback handles cross-turn scenarios where the LLM regenerates + arguments with cosmetic differences, changing the hash. + + Args: + tool_confirmations: List of confirmation entries from context + confirmation_id: The computed confirmation ID for this tool call + tool_name: The name of the tool being called + + Returns: + The matching confirmation entry, or None if not found + """ + for conf in tool_confirmations: + if conf.get("confirmation_id") == confirmation_id: + return conf + + for conf in tool_confirmations: + if conf.get("tool_name") == tool_name: + return conf + + return None + async def execute_pre_tool( self, tool_call: ToolCall, @@ -132,12 +174,11 @@ async def execute_pre_tool( current_tool_call = tool_call.model_copy() for hook in self.get_hooks(EventType.PRE_TOOL, tool_call.name): - # Generate confirmation_id: hash(hook_function_name + tool_name + arguments) - hook_name = hook.callback.__name__ - confirmation_id_str = ( - f"{hook_name}:{tool_call.name}:{json.dumps(current_tool_call.arguments, sort_keys=True)}" + confirmation_id = self._compute_confirmation_id( + hook_name=hook.callback.__name__, + tool_name=tool_call.name, + arguments=current_tool_call.arguments, ) - confirmation_id = hashlib.sha256(confirmation_id_str.encode()).hexdigest()[:CONFIRMATION_ID_LENGTH] result: ToolCall = await hook.callback(current_tool_call) @@ -148,28 +189,25 @@ async def execute_pre_tool( return result, None elif result.decision == "ask": - # Check if already confirmed/declined in context - for conf in context.tool_confirmations: - if conf.get("confirmation_id") == confirmation_id: - if conf.get("confirmed"): - # Approved → convert to "pass" and continue to next hook - result = result.model_copy(update={"decision": "pass"}) - break - else: - # Declined → convert to "deny" and stop immediately - return ( - result.model_copy( - update={ - "decision": "deny", - "reason": "❌ Action declined by user", - } - ), - None, - ) + matched = self._find_confirmation(context.tool_confirmations, confirmation_id, tool_call.name) + + if matched is not None: + if matched.get("confirmed"): + result = result.model_copy(update={"decision": "pass"}) + else: + return ( + result.model_copy( + update={ + "decision": "deny", + "reason": "❌ Action declined by user", + } + ), + None, + ) else: - # Not in context → return "ask" with ConfirmationRequest confirmation_request = ConfirmationRequest( confirmation_id=confirmation_id, + tool_call_id=tool_call.id, tool_name=tool_call.name, tool_description=result.reason, # type: ignore[arg-type] # guaranteed non-None by ValueError check above arguments=current_tool_call.arguments, diff --git a/packages/ragbits-agents/tests/unit/hooks/test_manager.py b/packages/ragbits-agents/tests/unit/hooks/test_manager.py index 7212b9a4c6..e923e08001 100644 --- a/packages/ragbits-agents/tests/unit/hooks/test_manager.py +++ b/packages/ragbits-agents/tests/unit/hooks/test_manager.py @@ -140,6 +140,73 @@ async def test_ask_with_prior_confirmation(self, tool_call: ToolCall, ask_hook: result, _ = await manager.execute_pre_tool(tool_call, ctx_declined) assert result.decision == "deny" + @pytest.mark.asyncio + async def test_ask_with_tool_name_fallback_approved(self, tool_call: ToolCall, ask_hook: PreToolCallback): + """When confirmation_id doesn't match (cross-turn hash drift), fall back to tool_name match.""" + manager: HookManager = HookManager(hooks=[Hook(event_type=EventType.PRE_TOOL, callback=ask_hook)]) + + # Simulate cross-turn: frontend sends back tool_name but with a stale confirmation_id + ctx: AgentRunContext = AgentRunContext( + tool_confirmations=[ + {"confirmation_id": "stale_id_from_previous_turn", "tool_name": "test_tool", "confirmed": True} + ] + ) + result, confirmation = await manager.execute_pre_tool(tool_call, ctx) + + assert result.decision == "pass" + assert confirmation is None + + @pytest.mark.asyncio + async def test_ask_with_tool_name_fallback_declined(self, tool_call: ToolCall, ask_hook: PreToolCallback): + """When confirmation_id doesn't match but tool_name does and user declined.""" + manager: HookManager = HookManager(hooks=[Hook(event_type=EventType.PRE_TOOL, callback=ask_hook)]) + + ctx: AgentRunContext = AgentRunContext( + tool_confirmations=[ + {"confirmation_id": "stale_id_from_previous_turn", "tool_name": "test_tool", "confirmed": False} + ] + ) + result, confirmation = await manager.execute_pre_tool(tool_call, ctx) + + assert result.decision == "deny" + assert confirmation is None + + @pytest.mark.asyncio + async def test_exact_confirmation_id_takes_priority_over_tool_name( + self, tool_call: ToolCall, ask_hook: PreToolCallback + ): + """Exact confirmation_id match should be used even if a tool_name entry also exists.""" + manager: HookManager = HookManager(hooks=[Hook(event_type=EventType.PRE_TOOL, callback=ask_hook)]) + exact_id = make_confirmation_id("ask_hook", "test_tool", {"arg1": "value1"}) + + ctx: AgentRunContext = AgentRunContext( + tool_confirmations=[ + # tool_name match says declined + {"confirmation_id": "wrong_id", "tool_name": "test_tool", "confirmed": False}, + # exact confirmation_id match says approved — should win + {"confirmation_id": exact_id, "confirmed": True}, + ] + ) + result, _ = await manager.execute_pre_tool(tool_call, ctx) + + assert result.decision == "pass" + + @pytest.mark.asyncio + async def test_tool_name_fallback_does_not_match_different_tool( + self, tool_call: ToolCall, ask_hook: PreToolCallback + ): + """tool_name fallback should not match confirmations for a different tool.""" + manager: HookManager = HookManager(hooks=[Hook(event_type=EventType.PRE_TOOL, callback=ask_hook)]) + + ctx: AgentRunContext = AgentRunContext( + tool_confirmations=[{"confirmation_id": "some_id", "tool_name": "other_tool", "confirmed": True}] + ) + result, confirmation = await manager.execute_pre_tool(tool_call, ctx) + + assert result.decision == "ask" + assert confirmation is not None + assert confirmation.tool_name == "test_tool" + @pytest.mark.asyncio async def test_chaining( self, tool_call: ToolCall, context: AgentRunContext, pre_tool_add_field: Callable[..., PreToolCallback] diff --git a/packages/ragbits-agents/tests/unit/test_agent.py b/packages/ragbits-agents/tests/unit/test_agent.py index ad01bad27f..05728c2188 100644 --- a/packages/ragbits-agents/tests/unit/test_agent.py +++ b/packages/ragbits-agents/tests/unit/test_agent.py @@ -913,6 +913,94 @@ async def test_pre_tool_hook_ask_with_confirmation_approved(llm_with_tool_call: assert "72" in result.tool_calls[0].result +async def test_execute_tool_directly_runs_tool_with_given_arguments(llm_without_tool_call: MockLLM): + """execute_tool_directly runs the named tool with explicit args, bypassing LLM tool selection.""" + agent = Agent(llm=llm_without_tool_call, prompt=CustomPrompt, tools=[get_weather]) + + context: AgentRunContext = AgentRunContext() + result = await agent.execute_tool_directly( + tool_call_id="call_1", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + context=context, + ) + + assert result.id == "call_1" + assert result.name == "get_weather" + assert "72" in result.result + + +async def test_execute_tool_directly_respects_prior_confirmation( + llm_without_tool_call: MockLLM, ask_hook: PreToolCallback +): + """A PRE_TOOL ask hook still runs, but a matching confirmation in context makes it pass.""" + from ragbits.agents.hooks.manager import HookManager + + hook = Hook(event_type=EventType.PRE_TOOL, callback=ask_hook) + agent = Agent(llm=llm_without_tool_call, prompt=CustomPrompt, tools=[get_weather], hooks=[hook]) + + arguments = {"location": "San Francisco"} + confirmation_id = HookManager._compute_confirmation_id( + hook_name="ask_hook", tool_name="get_weather", arguments=arguments + ) + context: AgentRunContext = AgentRunContext( + tool_confirmations=[{"confirmation_id": confirmation_id, "confirmed": True}] + ) + + result = await agent.execute_tool_directly( + tool_call_id="call_1", tool_name="get_weather", arguments=arguments, context=context + ) + + assert "72" in result.result + + +async def test_execute_tool_directly_deny_hook_blocks_execution( + llm_without_tool_call: MockLLM, deny_hook: PreToolCallback +): + """A non-confirmation PRE_TOOL hook returning deny still blocks execution.""" + hook = Hook(event_type=EventType.PRE_TOOL, callback=deny_hook) + agent = Agent(llm=llm_without_tool_call, prompt=CustomPrompt, tools=[get_weather], hooks=[hook]) + + result = await agent.execute_tool_directly( + tool_call_id="call_1", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + context=AgentRunContext(), + ) + + assert result.result == "Blocked by hook" + + +async def test_execute_tool_directly_unknown_tool_raises(llm_without_tool_call: MockLLM): + """Unknown tool name raises AgentToolNotAvailableError.""" + agent = Agent(llm=llm_without_tool_call, prompt=CustomPrompt, tools=[get_weather]) + + with pytest.raises(AgentToolNotAvailableError): + await agent.execute_tool_directly( + tool_call_id="call_1", + tool_name="not_a_tool", + arguments={}, + context=AgentRunContext(), + ) + + +async def test_execute_tool_directly_runs_post_tool_hooks( + llm_without_tool_call: MockLLM, post_tool_append: Callable[..., PostToolCallback] +): + """POST_TOOL hooks must run on the direct-execution result.""" + hook = Hook(event_type=EventType.POST_TOOL, callback=post_tool_append("[PT]", prepend=True)) + agent = Agent(llm=llm_without_tool_call, prompt=CustomPrompt, tools=[get_weather], hooks=[hook]) + + result = await agent.execute_tool_directly( + tool_call_id="call_1", + tool_name="get_weather", + arguments={"location": "San Francisco"}, + context=AgentRunContext(), + ) + + assert result.result.startswith("[PT]") + + async def test_hook_priority_order(llm_with_tool_call: MockLLM): """Test that hooks execute in priority order (lower first).""" execution_order: list[int] = [] diff --git a/packages/ragbits-agents/tests/unit/test_history.py b/packages/ragbits-agents/tests/unit/test_history.py new file mode 100644 index 0000000000..dde49d49cd --- /dev/null +++ b/packages/ragbits-agents/tests/unit/test_history.py @@ -0,0 +1,51 @@ +"""Tests for history manipulation helpers.""" + +import json + +from ragbits.agents.history import inject_tool_call + + +def test_inject_tool_call_appends_assistant_and_tool_messages(): + """Injection produces an assistant tool_use message followed by a tool result message.""" + history = [{"role": "user", "content": "delete foo.txt"}] + + result = inject_tool_call( + history, + tool_call_id="call_42", + tool_name="delete_file", + arguments={"path": "foo.txt"}, + result="Deleted foo.txt", + ) + + assert len(result) == 3 + assert result[0] == {"role": "user", "content": "delete foo.txt"} + + assistant_msg = result[1] + assert assistant_msg["role"] == "assistant" + assert assistant_msg["content"] is None + assert assistant_msg["tool_calls"] == [ + { + "id": "call_42", + "type": "function", + "function": {"name": "delete_file", "arguments": json.dumps({"path": "foo.txt"})}, + } + ] + + tool_msg = result[2] + assert tool_msg == {"role": "tool", "tool_call_id": "call_42", "content": "Deleted foo.txt"} + + +def test_inject_tool_call_does_not_mutate_input(): + """Helper returns a new list; caller's history is untouched.""" + history = [{"role": "user", "content": "hi"}] + + inject_tool_call(history, tool_call_id="c", tool_name="t", arguments={}, result="ok") + + assert history == [{"role": "user", "content": "hi"}] + + +def test_inject_tool_call_stringifies_non_string_result(): + """Non-string results are coerced to str (matching add_tool_use_message).""" + result = inject_tool_call([], tool_call_id="c", tool_name="t", arguments={}, result={"key": "value"}) + + assert result[1]["content"] == str({"key": "value"}) diff --git a/packages/ragbits-chat/CHANGELOG.md b/packages/ragbits-chat/CHANGELOG.md index 6ff581d382..357730f09d 100644 --- a/packages/ragbits-chat/CHANGELOG.md +++ b/packages/ragbits-chat/CHANGELOG.md @@ -2,6 +2,8 @@ ## Unreleased +- Add `ChatInterface.resolve_pending_confirmations` and `create_pending_confirmation_state` helpers so chat implementations can execute user-approved tools directly on the continuation turn, preventing the confirmation loop caused by LLM argument drift (#969). + ## 1.6.2 (2026-03-26) - ragbits-agents updated to version v1.6.2 diff --git a/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py b/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py index 40c3a83c71..680a779476 100644 --- a/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py +++ b/packages/ragbits-chat/src/ragbits/chat/interface/_interface.py @@ -11,6 +11,9 @@ from fastapi import UploadFile +from ragbits.agents._main import Agent, AgentRunContext +from ragbits.agents.confirmation import ConfirmationRequest +from ragbits.agents.history import inject_tool_call from ragbits.agents.tools.planning import Task from ragbits.chat.interface.summary import HeuristicSummaryGenerator, SummaryGenerator from ragbits.chat.interface.ui_customization import UICustomization @@ -289,6 +292,106 @@ def create_usage_response(usage: Usage) -> UsageResponse: def create_plan_item_response(task: Task) -> PlanItemResponse: return PlanItemResponse(content=PlanItemContent(task=task)) + @staticmethod + def create_pending_confirmation_state(request: ConfirmationRequest) -> dict[str, Any]: + """ + Build the state-dict fragment that records a pending confirmation. + + Merge this into ``ChatContext.state`` (via ``create_state_update``) whenever a + ``ConfirmationRequest`` is streamed to the frontend. On the continuation turn, + ``resolve_pending_confirmations`` uses this fragment to replay the approved + tool with its original arguments — avoiding LLM regeneration and the fragile + confirmation_id hash match. + """ + return { + "pending_confirmations": { + request.confirmation_id: { + "tool_call_id": request.tool_call_id, + "tool_name": request.tool_name, + "arguments": request.arguments, + } + } + } + + async def resolve_pending_confirmations( + self, + agent: Agent, + context: ChatContext, + ) -> list[ChatResponseUnion]: + """ + Execute tools whose confirmations the user just returned, updating ``agent.history`` in place. + + For each entry in ``context.tool_confirmations`` matched against + ``context.state["pending_confirmations"]``: + + - **confirmed**: invoke the tool via ``agent.execute_tool_directly`` with the + stored arguments, then append the (tool_use, tool_result) pair to history. + - **declined**: append a synthetic tool_result saying the user declined. + + When the caller next runs ``agent.run_streaming(...)``, the LLM sees the + injected tool results already in history and continues naturally — it is + never asked to regenerate the confirmed tool call, so there is no argument + drift or hash-mismatch loop. + + Args: + agent: The agent whose tools should be executed. ``agent.history`` is + mutated in place to include the injected (tool_use, tool_result) pairs. + context: The chat context from this turn. + + Returns: + UI responses (e.g. LiveUpdates) the caller should yield before continuing + the agent run. Empty list if there are no pending confirmations to resolve. + """ + pending_map: dict[str, dict[str, Any]] = context.state.get("pending_confirmations", {}) or {} + confirmations = context.tool_confirmations or [] + if not pending_map or not confirmations: + return [] + + agent_context: AgentRunContext = AgentRunContext(tool_confirmations=list(confirmations)) + responses: list[ChatResponseUnion] = [] + + for entry in confirmations: + confirmation_id = entry.get("confirmation_id") + pending = pending_map.get(confirmation_id) if confirmation_id else None + if pending is None: + continue + + tool_call_id = pending["tool_call_id"] + tool_name = pending["tool_name"] + arguments = pending["arguments"] + + if entry.get("confirmed"): + result = await agent.execute_tool_directly( + tool_call_id=tool_call_id, + tool_name=tool_name, + arguments=arguments, + context=agent_context, + ) + responses.append( + self.create_live_update( + tool_call_id, LiveUpdateType.FINISH, f"✅ {tool_name}", str(result.result)[:100] + ) + ) + agent.history = inject_tool_call( + agent.history, + tool_call_id=tool_call_id, + tool_name=tool_name, + arguments=arguments, + result=result.result, + ) + else: + decline_msg = "❌ User declined this action" + responses.append(self.create_live_update(tool_call_id, LiveUpdateType.FINISH, f"❌ {tool_name}")) + agent.history = inject_tool_call( + agent.history, + tool_call_id=tool_call_id, + tool_name=tool_name, + arguments=arguments, + result=decline_msg, + ) + + return responses + @staticmethod def _sign_state(state: dict[str, Any]) -> str: """ diff --git a/packages/ragbits-chat/tests/unit/test_confirmation_resolution.py b/packages/ragbits-chat/tests/unit/test_confirmation_resolution.py new file mode 100644 index 0000000000..97eb4abf24 --- /dev/null +++ b/packages/ragbits-chat/tests/unit/test_confirmation_resolution.py @@ -0,0 +1,147 @@ +"""Tests for ChatInterface confirmation-resolution helpers.""" + +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock + +import pytest + +from ragbits.agents._main import ToolCallResult +from ragbits.agents.confirmation import ConfirmationRequest +from ragbits.chat.interface import ChatInterface +from ragbits.chat.interface.types import ChatContext, ChatResponseUnion, TextContent, TextResponse +from ragbits.core.prompt.base import ChatFormat + + +class _Dummy(ChatInterface): + async def chat( # type: ignore[override] # noqa: PLR6301 + self, message: str, history: ChatFormat, context: ChatContext + ) -> AsyncGenerator[ChatResponseUnion, None]: + yield TextResponse(content=TextContent(text="ok")) + + +def test_create_pending_confirmation_state_shape(): + """Helper returns a state dict keyed by confirmation_id with tool_call_id, name, args.""" + request = ConfirmationRequest( + confirmation_id="conf_1", + tool_call_id="call_1", + tool_name="send_slack", + tool_description="Send a slack message", + arguments={"channel": "#general", "text": "hi"}, + ) + + state = ChatInterface.create_pending_confirmation_state(request) + + assert state == { + "pending_confirmations": { + "conf_1": { + "tool_call_id": "call_1", + "tool_name": "send_slack", + "arguments": {"channel": "#general", "text": "hi"}, + } + } + } + + +async def test_resolve_pending_confirmations_confirmed_executes_and_mutates_history(): + """A confirmed pending causes execute_tool_directly to run and agent.history to grow.""" + iface = _Dummy() + + agent = AsyncMock() + agent.history = [{"role": "user", "content": "send hi to #general"}] + agent.execute_tool_directly.return_value = ToolCallResult( + id="call_1", name="send_slack", arguments={"channel": "#general", "text": "hi"}, result="sent" + ) + + context = ChatContext( + state={ + "pending_confirmations": { + "conf_1": { + "tool_call_id": "call_1", + "tool_name": "send_slack", + "arguments": {"channel": "#general", "text": "hi"}, + } + } + }, + tool_confirmations=[{"confirmation_id": "conf_1", "confirmed": True}], + ) + + responses = await iface.resolve_pending_confirmations(agent, context) + + agent.execute_tool_directly.assert_awaited_once() + call_kwargs = agent.execute_tool_directly.call_args.kwargs + assert call_kwargs["tool_call_id"] == "call_1" + assert call_kwargs["tool_name"] == "send_slack" + assert call_kwargs["arguments"] == {"channel": "#general", "text": "hi"} + + assert len(agent.history) == 3 + assert agent.history[1]["role"] == "assistant" + assert agent.history[1]["tool_calls"][0]["id"] == "call_1" + assert agent.history[2] == {"role": "tool", "tool_call_id": "call_1", "content": "sent"} + assert responses # at least one UI response emitted + + +async def test_resolve_pending_confirmations_declined_skips_execution(): + """A declined pending injects a decline result; no tool execution happens.""" + iface = _Dummy() + + agent = AsyncMock() + agent.history = [] + + context = ChatContext( + state={ + "pending_confirmations": { + "conf_1": { + "tool_call_id": "call_1", + "tool_name": "delete_file", + "arguments": {"path": "x"}, + } + } + }, + tool_confirmations=[{"confirmation_id": "conf_1", "confirmed": False}], + ) + + await iface.resolve_pending_confirmations(agent, context) + + agent.execute_tool_directly.assert_not_called() + assert agent.history[-1]["role"] == "tool" + assert "declined" in agent.history[-1]["content"].lower() + + +async def test_resolve_pending_confirmations_no_pending_returns_input_unchanged(): + """No pending_confirmations in state → history untouched, no responses, no execution.""" + iface = _Dummy() + + agent = AsyncMock() + agent.history = [{"role": "user", "content": "hi"}] + + context = ChatContext(state={}, tool_confirmations=[]) + + responses = await iface.resolve_pending_confirmations(agent, context) + + agent.execute_tool_directly.assert_not_called() + assert agent.history == [{"role": "user", "content": "hi"}] + assert responses == [] + + +async def test_resolve_pending_confirmations_unknown_confirmation_id_is_ignored(): + """A tool_confirmations entry with no matching pending entry is a no-op (legacy flow still works).""" + iface = _Dummy() + + agent = AsyncMock() + agent.history = [] + + context = ChatContext( + state={"pending_confirmations": {}}, + tool_confirmations=[{"confirmation_id": "unknown", "confirmed": True}], + ) + + responses = await iface.resolve_pending_confirmations(agent, context) + + agent.execute_tool_directly.assert_not_called() + assert agent.history == [] + assert responses == [] + + +@pytest.fixture(autouse=True) +def _secret_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("RAGBITS_SECRET_KEY", "test-secret")