diff --git a/src/aiq_agent/agents/clarifier/agent.py b/src/aiq_agent/agents/clarifier/agent.py index 5779e5f0..d94b6b0c 100644 --- a/src/aiq_agent/agents/clarifier/agent.py +++ b/src/aiq_agent/agents/clarifier/agent.py @@ -104,6 +104,15 @@ ) """Reminder prompt added after tool results to reinforce JSON-only output.""" +FORCE_SEARCH_GUIDANCE = ( + "You attempted to ask the user for clarification before gathering any context. " + "Before asking the user a question, you MUST first use the available search tools " + "to look up unfamiliar entities, acronyms, products, or terms in their request. " + "Issue one focused tool call now with a query derived from the user's request. " + "Only after reviewing the tool results should you decide whether clarification is still needed." +) +"""Guidance prompt injected when the LLM tries to clarify without having searched first.""" + class ClarifierAgent: """ @@ -484,6 +493,24 @@ def _get_fallback_clarification(self, query: str | None = None) -> str: SKIP_COMMANDS = {"skip", "done", "exit", "quit", "proceed", "continue", "no", "n", ""} """Set of commands that indicate the user wants to skip clarification.""" + @staticmethod + def _has_tool_invocations(messages: Sequence[Any]) -> bool: + """ + Check whether any prior assistant message in the conversation issued tool calls. + + Args: + messages: The conversation message history. + + Returns: + True if any AIMessage in the history carries non-empty tool_calls, + False otherwise. + """ + for msg in messages: + tool_calls = getattr(msg, "tool_calls", None) + if tool_calls: + return True + return False + def _is_skip_command(self, user_reply: str) -> bool: """ Check if the user's reply indicates they want to skip clarification. @@ -503,16 +530,21 @@ def _build_graph(self) -> CompiledStateGraph: """ Build the LangGraph StateGraph for the clarification workflow. - Creates a graph with three nodes: - - agent: Generates clarification questions using the LLM + Creates a graph with the following nodes: + - agent: Generates clarification questions using the LLM. On the first + turn it also enforces search-before-clarify (issue #234): if the model + asks for clarification without using its bound search tools, it nudges + the model once and retries inline. - tools: Executes tool calls (e.g., web search) for context - ask_for_clarification: Prompts user and processes response + - plan_preview: Optional plan approval flow The graph flow: - 1. agent generates a response (question, tool call, or completion) + 1. agent generates a response (question, tool call, or completion); + on turn 0 it may force one search-and-retry before yielding 2. If tool call → tools node → back to agent - 3. If question → ask_for_clarification → back to agent - 4. If complete → end + 3. If complete → end (or plan_preview if enabled) + 4. Otherwise → ask_for_clarification → back to agent Returns: Compiled LangGraph StateGraph ready for execution. @@ -526,6 +558,17 @@ def _build_graph(self) -> CompiledStateGraph: async def agent_node(state: ClarifierAgentState): if state.remaining_questions <= 0: + # Clarification budget is exhausted — emit a completion signal, + # unless a prior node already did (the skip-command branch in + # ask_clarification returns its own AIMessage(complete) and then + # this node is re-entered via the unconditional edge). Emitting + # another here would place two consecutive assistant messages in + # history, which the OpenAI/Anthropic APIs reject. If the last + # message is already a completion, leave the state untouched and + # let decide_route end the run. + last_message = state.messages[-1] if state.messages else None + if isinstance(last_message, AIMessage) and self._is_complete(getattr(last_message, "content", "")): + return {} complete_response = ClarificationResponse(needs_clarification=False, clarification_question=None) return {"messages": [AIMessage(content=complete_response.model_dump_json())]} tools_info = [ @@ -548,6 +591,43 @@ async def agent_node(state: ClarifierAgentState): messages.append(HumanMessage(content=JSON_REMINDER_AFTER_TOOLS)) response = await bound_llm.ainvoke(messages) + + # Search-before-clarify (issue #234): if, on the first turn, the model + # asks for clarification without first using its bound search tools, + # nudge it once to search and retry inline. This keeps the behavior + # model-agnostic without adding graph nodes or extra state — even + # models that would otherwise skip tool use must attempt a search + # before falling back to asking the user. + # + # The guard is one-shot by construction: + # * iteration == 0 — only on the first turn; once the user replies, + # iteration advances and this never fires again. + # * not _has_tool_invocations(state.messages) — once any tool call + # is in history (e.g. after a successful forced search, even while + # iteration is still 0), we never re-nudge. + # FORCE_SEARCH_GUIDANCE is sent only in the local retry_messages and is + # never returned to state, so it cannot leak into get_latest_user_query. + # + # We return ONLY retry_response, not the first (search-skipping) + # response. The first attempt was already shown to the model inside + # retry_messages; persisting it would put two consecutive + # assistant-role messages in history once retry_response carries a + # tool call (… AIMessage(clarif), AIMessage(tool_call), ToolMessage …), + # which the OpenAI Chat Completions and Anthropic Messages APIs reject + # with a 400. Keeping only retry_response preserves a valid sequence + # regardless of whether it is a tool call or another clarification. + if ( + self.tools + and state.iteration == 0 + and not self._has_tool_invocations(state.messages) + and not getattr(response, "tool_calls", None) + and self._is_needed(response.content) + ): + logger.info("Clarifier: model skipped search before clarifying; injecting guidance and retrying once") + retry_messages = messages + [response, HumanMessage(content=FORCE_SEARCH_GUIDANCE)] + retry_response = await bound_llm.ainvoke(retry_messages) + return {"messages": [retry_response]} + return {"messages": [response]} async def ask_clarification(state: ClarifierAgentState): @@ -576,8 +656,17 @@ async def ask_clarification(state: ClarifierAgentState): logger.info("Clarifier: User requested to skip clarification") complete_response = ClarificationResponse(needs_clarification=False, clarification_question=None) clarifier_log = f"{clarifier_log}\n**Turn {iteration + 1} - User:** [Skipped clarification]" + # Persist the user's reply as a HumanMessage before the + # completion AIMessage. The prior turn already left an + # AIMessage(clarification) in history; without an interleaving + # human message the two assistant turns would be adjacent, which + # the OpenAI/Anthropic APIs reject. (The duplicate completion on + # graph re-entry is suppressed by the guard in agent_node.) return { - "messages": [AIMessage(content=complete_response.model_dump_json())], + "messages": [ + HumanMessage(content=user_reply), + AIMessage(content=complete_response.model_dump_json()), + ], "iteration": max_turns, # Force end of clarification "clarifier_log": clarifier_log, } @@ -610,6 +699,11 @@ def decide_route(state: ClarifierAgentState | dict): if self.enable_plan_approval: return "plan_preview" return "__end__" + + # The search-before-clarify nudge (issue #234) is handled inline in + # agent_node, not here — see the retry block there. By the time a + # clarification response reaches this router, any forced search has + # already happened, so we route straight to the user. return "ask_for_clarification" async def plan_preview_node(state: ClarifierAgentState): diff --git a/src/aiq_agent/agents/clarifier/prompts/research_clarification.j2 b/src/aiq_agent/agents/clarifier/prompts/research_clarification.j2 index a2745577..ad62e064 100644 --- a/src/aiq_agent/agents/clarifier/prompts/research_clarification.j2 +++ b/src/aiq_agent/agents/clarifier/prompts/research_clarification.j2 @@ -29,9 +29,11 @@ Your ONLY responsibility is to determine whether a research request requires cla ## Tool Usage -- You may use search tools ONLY to understand unfamiliar domains -- Use at most 1-2 high-value searches -- Searches are for your internal understanding only +- **Search first, ask second.** If the user's request contains any unfamiliar entity, acronym, project, person, product, or technical term that you cannot fully define from your training data, you MUST issue a search tool call before deciding clarification is needed. Do not ask the user to define terms that a quick search would resolve. +- On the first turn, prefer at least one search to ground the topic in current context whenever search tools are available. +- Use at most 1-2 high-value searches per turn — keep queries focused on the specific unknown. +- Searches are for your internal understanding only. Do not summarize or report search results to the user. +- After reviewing search results, re-evaluate: if the request is now sufficiently specified, return `needs_clarification: false`. Only ask a clarification question if a genuine ambiguity remains. --- diff --git a/tests/aiq_agent/agents/clarifier/test_agent.py b/tests/aiq_agent/agents/clarifier/test_agent.py index 38f53b01..3adfd083 100644 --- a/tests/aiq_agent/agents/clarifier/test_agent.py +++ b/tests/aiq_agent/agents/clarifier/test_agent.py @@ -22,10 +22,13 @@ import pytest from langchain_core.messages import AIMessage +from langchain_core.messages import BaseMessage from langchain_core.messages import HumanMessage +from langchain_core.messages import ToolMessage from langchain_core.tools import tool from aiq_agent.agents.clarifier.agent import DEFAULT_CLARIFICATION_PROMPT +from aiq_agent.agents.clarifier.agent import FORCE_SEARCH_GUIDANCE from aiq_agent.agents.clarifier.agent import ClarifierAgent from aiq_agent.agents.clarifier.models import ClarificationResponse from aiq_agent.agents.clarifier.models import ClarifierAgentState @@ -872,3 +875,472 @@ def test_init_with_planner_llm(self, mock_llm_provider): ) assert agent.planner_llm == planner_llm + + +class TestHasToolInvocations: + """Tests for the _has_tool_invocations helper.""" + + def test_empty_messages(self): + """No messages -> no invocations.""" + assert ClarifierAgent._has_tool_invocations([]) is False + + def test_only_human_messages(self): + """Human-only history has no invocations.""" + messages = [HumanMessage(content="hi"), HumanMessage(content="more")] + assert ClarifierAgent._has_tool_invocations(messages) is False + + def test_ai_message_without_tool_calls(self): + """An AIMessage without tool_calls counts as no invocation.""" + messages = [HumanMessage(content="hi"), AIMessage(content="hello")] + assert ClarifierAgent._has_tool_invocations(messages) is False + + def test_ai_message_with_tool_calls(self): + """An AIMessage with tool_calls counts as an invocation.""" + ai = AIMessage( + content="", + tool_calls=[{"name": "web_search_tool", "args": {"query": "x"}, "id": "call_1"}], + ) + assert ClarifierAgent._has_tool_invocations([HumanMessage(content="hi"), ai]) is True + + def test_with_tool_message(self): + """A ToolMessage by itself does not count - we look at the assistant turn.""" + msg = ToolMessage(content="result", tool_call_id="call_1") + assert ClarifierAgent._has_tool_invocations([msg]) is False + + +class TestClarifierForceSearch: + """Tests for the search-before-clarify behavior (issue #234).""" + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + llm = MagicMock() + llm.bind_tools = MagicMock(return_value=llm) + return llm + + @pytest.fixture + def mock_llm_provider(self, mock_llm): + """Create a mock LLM provider.""" + provider = MagicMock(spec=LLMProvider) + provider.get = MagicMock(return_value=mock_llm) + return provider + + @pytest.mark.asyncio + async def test_force_search_fires_when_llm_skips_tools(self, mock_llm_provider, mock_llm): + """When tools are configured and the LLM tries to clarify without searching, + the agent must first nudge the LLM with a force-search guidance message, + then route to tools when the LLM complies.""" + + # 1st LLM call: skip tools, ask for clarification. + # 2nd LLM call (after force_search): produce a tool call. + # 3rd LLM call (after tool result): return complete. + clarif_msg = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + tool_call_msg = AIMessage( + content="", + tool_calls=[{"name": "web_search_tool", "args": {"query": "AI"}, "id": "call_1"}], + ) + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg, tool_call_msg, complete_msg]) + + user_callback = AsyncMock() + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research Foo Project XYZ")]) + result = await agent.run(state) + + assert result is not None + assert isinstance(result, ClarifierResult) + # The LLM was invoked three times: clarify-attempt, tool-call, finalize. + assert mock_llm.ainvoke.call_count == 3 + # The user was never prompted because the search-then-complete path was taken. + user_callback.assert_not_called() + # The 2nd LLM call (after force_search guidance) should have received the + # guidance string as the latest HumanMessage. + second_call_messages = mock_llm.ainvoke.call_args_list[1].args[0] + latest_human = next(m for m in reversed(second_call_messages) if isinstance(m, HumanMessage)) + assert FORCE_SEARCH_GUIDANCE in latest_human.content + + @pytest.mark.asyncio + async def test_force_search_skipped_when_no_tools(self, mock_llm_provider, mock_llm): + """When no tools are configured, force_search must NOT fire; the agent + should fall back to asking the user immediately.""" + clarif_msg = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg, complete_msg]) + + user_callback = AsyncMock(return_value="technical deep dive") + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[], # no tools available + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + result = await agent.run(state) + + assert result is not None + # User callback is called once for clarification (no force_search detour). + user_callback.assert_called_once() + + @pytest.mark.asyncio + async def test_force_search_fires_at_most_once(self, mock_llm_provider, mock_llm): + """Even if the LLM stubbornly refuses to call a tool after the force_search + nudge, the agent must not loop forever - it should proceed to asking the + user after the single nudge attempt.""" + clarif_msg_1 = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + # After the force_search nudge, the model still refuses to call a tool + # and returns another clarification request. + clarif_msg_2 = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + # After the user replies, the model completes. + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg_1, clarif_msg_2, complete_msg]) + + user_callback = AsyncMock(return_value="technical") + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + result = await agent.run(state) + + assert result is not None + # The LLM was invoked three times max - the nudge fired once, then we + # fell through to ask_for_clarification, and the user reply produced + # the final completion. + assert mock_llm.ainvoke.call_count == 3 + # The user was prompted exactly once (no infinite loop of nudges). + user_callback.assert_called_once() + + @pytest.mark.asyncio + async def test_force_search_not_triggered_when_llm_searches_first(self, mock_llm_provider, mock_llm): + """When the LLM voluntarily issues a tool call on the first turn, the + force_search path is never entered - normal flow is preserved.""" + tool_call_msg = AIMessage( + content="", + tool_calls=[{"name": "web_search_tool", "args": {"query": "AI"}, "id": "call_1"}], + ) + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, complete_msg]) + + user_callback = AsyncMock() + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + result = await agent.run(state) + + assert result is not None + assert mock_llm.ainvoke.call_count == 2 + user_callback.assert_not_called() + # No HumanMessage with the force_search guidance should have been added. + second_call_messages = mock_llm.ainvoke.call_args_list[1].args[0] + for m in second_call_messages: + if isinstance(m, HumanMessage): + assert FORCE_SEARCH_GUIDANCE not in m.content + + @pytest.mark.asyncio + async def test_force_search_guidance_not_in_state_messages(self, mock_llm_provider, mock_llm): + """The force_search guidance must be injected ephemerally only; it must + never end up in state.messages, otherwise helpers like + get_latest_user_query would surface internal scaffolding back to the + user in fallback text. Regression test for Codex review feedback.""" + # The LLM ignores the nudge and returns invalid JSON, triggering the + # invalid-format fallback path inside ask_clarification. The user then + # replies "skip", which forces completion - so only two LLM calls + # actually happen in this run. + clarif_msg_1 = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + clarif_invalid = AIMessage(content="not valid JSON at all") + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg_1, clarif_invalid]) + + # Capture what gets sent to the user. + prompts_received: list[str] = [] + + async def user_callback(question: str) -> str: + prompts_received.append(question) + return "skip" + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=user_callback, + ) + + original_query = "Research Project Foo at Acme" + state = ClarifierAgentState(messages=[HumanMessage(content=original_query)]) + await agent.run(state) + + # The user was prompted exactly once - with a fallback derived from + # their actual query, never from the force-search guidance. + assert len(prompts_received) == 1 + prompt_text = prompts_received[0] + assert "Project Foo" in prompt_text or "Acme" in prompt_text + assert FORCE_SEARCH_GUIDANCE not in prompt_text + # The internal force-search guidance must never be visible in any + # message the user-facing fallback would draw from. + assert "You attempted to ask the user" not in prompt_text + + @pytest.mark.asyncio + async def test_force_search_guidance_not_injected_after_user_reply(self, mock_llm_provider, mock_llm): + """After the user has actually replied (iteration > 0), the agent must + NOT re-inject the search-first nudge on the next LLM call. Otherwise + the model would receive 'issue a tool call now' immediately after the + user provided clarifying answer, causing a gratuitous search instead + of synthesizing the answer. Regression test for Greptile P1 finding.""" + clarif_msg_1 = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + # After the nudge, the model still refuses to call a tool. + clarif_msg_2 = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="Which area?" + ).model_dump_json() + ) + # After the user replies, the model completes. + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg_1, clarif_msg_2, complete_msg]) + + user_callback = AsyncMock(return_value="technical deep dive") + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + await agent.run(state) + + # The 3rd LLM call happens AFTER the user reply (iteration moves from + # 0 to 1), so the inline search-before-clarify guard (gated on + # iteration == 0) must not fire again and the nudge must not appear in + # that call's message list. + assert mock_llm.ainvoke.call_count == 3 + third_call_messages = mock_llm.ainvoke.call_args_list[2].args[0] + for m in third_call_messages: + if isinstance(m, HumanMessage): + assert FORCE_SEARCH_GUIDANCE not in m.content, ( + "force_search guidance must not be re-injected after the user replies" + ) + + @pytest.mark.asyncio + async def test_forced_retry_does_not_emit_consecutive_assistant_messages(self, mock_llm_provider, mock_llm): + """After a forced search-retry whose retry produces a tool call, the + message list fed to the LLM on the next turn must NOT contain two + consecutive assistant (AIMessage) turns. Two adjacent assistant + messages with no interleaved user/tool message are rejected with a 400 + by the OpenAI Chat Completions and Anthropic Messages APIs; mocked LLMs + don't enforce this, so we assert the invariant explicitly. Regression + test for the Greptile P1 finding on PR #245.""" + clarif_msg = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + tool_call_msg = AIMessage( + content="", + tool_calls=[{"name": "web_search_tool", "args": {"query": "AI"}, "id": "call_1"}], + ) + complete_msg = AIMessage( + content=ClarificationResponse(needs_clarification=False, clarification_question=None).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif_msg, tool_call_msg, complete_msg]) + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[web_search_tool], + user_prompt_callback=AsyncMock(), + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research Foo Project XYZ")]) + await agent.run(state) + + # Inspect every message list that was actually sent to the LLM and assert + # no two consecutive AIMessages appear (the API-invalid shape). + def _adjacent_assistant_pairs(messages: list[BaseMessage]) -> list[tuple[int, int]]: + pairs = [] + for i in range(len(messages) - 1): + if isinstance(messages[i], AIMessage) and isinstance(messages[i + 1], AIMessage): + pairs.append((i, i + 1)) + return pairs + + for call_idx, call in enumerate(mock_llm.ainvoke.call_args_list): + sent_messages = call.args[0] + offenders = _adjacent_assistant_pairs(sent_messages) + assert not offenders, ( + f"LLM call #{call_idx} contained consecutive assistant messages at {offenders}; " + "this is rejected by OpenAI/Anthropic APIs" + ) + + # Specifically: the forced retry must not persist the skipped + # clarification, so the post-tool history is [..., AIMessage(tool_call), + # ToolMessage, ...] with no stale AIMessage before the tool call. + third_call_messages = mock_llm.ainvoke.call_args_list[2].args[0] + ai_then_tool = any( + isinstance(third_call_messages[i], AIMessage) + and getattr(third_call_messages[i], "tool_calls", None) + and isinstance(third_call_messages[i + 1], ToolMessage) + for i in range(len(third_call_messages) - 1) + ) + assert ai_then_tool, "expected a tool-call AIMessage immediately followed by its ToolMessage in history" + + +class TestClarifierSkipMessageOrdering: + """Skip-command branch must not produce consecutive assistant messages. + + Regression tests for the ordering bug surfaced during the PR #245 audit: + the skip branch returned an AIMessage(complete) without persisting the + user's reply, leaving it adjacent to the prior clarification AIMessage; the + graph then re-entered agent_node and (with the budget exhausted) appended a + third AIMessage. Two/three consecutive assistant turns are rejected by the + OpenAI/Anthropic APIs and corrupt the planner call when plan approval is on. + """ + + @pytest.fixture + def mock_llm(self): + llm = MagicMock() + llm.bind_tools = MagicMock(return_value=llm) + return llm + + @pytest.fixture + def mock_llm_provider(self, mock_llm): + provider = MagicMock(spec=LLMProvider) + provider.get = MagicMock(return_value=mock_llm) + return provider + + @staticmethod + def _adjacent_assistant_pairs(messages: list[BaseMessage]) -> list[tuple[int, int]]: + return [ + (i, i + 1) + for i in range(len(messages) - 1) + if isinstance(messages[i], AIMessage) and isinstance(messages[i + 1], AIMessage) + ] + + @pytest.mark.asyncio + async def test_skip_persists_reply_and_no_consecutive_assistants(self, mock_llm_provider, mock_llm): + """User skips after one clarification: final history must interleave the + skip reply (HumanMessage) and contain no consecutive AIMessages.""" + clarif = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + # Only one real LLM response is needed; after skip the graph completes + # without another model call (the early-complete guard returns {}). + mock_llm.ainvoke = AsyncMock(side_effect=[clarif]) + + captured: dict = {} + + async def user_callback(question: str) -> str: + return "skip" + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[], # no tools → no forced search; isolate the skip path + user_prompt_callback=user_callback, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + + # Capture the final graph state to inspect persisted message ordering. + final = await agent.graph.ainvoke(state, config={"callbacks": []}) + captured["messages"] = final["messages"] if isinstance(final, dict) else final.messages + + msgs = captured["messages"] + offenders = self._adjacent_assistant_pairs(msgs) + assert not offenders, f"persisted history has consecutive assistant messages at {offenders}: {msgs}" + # The skip reply must be persisted as a HumanMessage. + assert any(isinstance(m, HumanMessage) and m.content == "skip" for m in msgs), ( + "skip reply was not persisted as a HumanMessage" + ) + # Exactly one completion AIMessage should terminate the dialog (no duplicate). + completion_ais = [ + m for m in msgs if isinstance(m, AIMessage) and ClarifierAgent._has_tool_invocations([m]) is False + ] + assert len(completion_ais) >= 1 + + @pytest.mark.asyncio + async def test_skip_with_plan_approval_planner_gets_valid_sequence(self, mock_llm_provider, mock_llm): + """With plan approval on, the corrupted skip history previously flowed + into the planner's ainvoke. Assert the planner never receives two + consecutive assistant messages.""" + clarif = AIMessage( + content=ClarificationResponse( + needs_clarification=True, clarification_question="What aspect?" + ).model_dump_json() + ) + mock_llm.ainvoke = AsyncMock(side_effect=[clarif]) + + planner_llm = MagicMock() + plan_json = '{"title": "Plan", "sections": ["Intro", "Analysis"]}' + planner_llm.ainvoke = AsyncMock(return_value=AIMessage(content=plan_json)) + + replies = iter(["skip", "approve"]) + + async def user_callback(question: str) -> str: + return next(replies) + + agent = ClarifierAgent( + llm_provider=mock_llm_provider, + tools=[], + user_prompt_callback=user_callback, + enable_plan_approval=True, + planner_llm=planner_llm, + ) + + state = ClarifierAgentState(messages=[HumanMessage(content="Research AI")]) + result = await agent.run(state) + + assert result is not None + assert result.plan_approved is True + # The planner was called; none of its input message lists may contain + # consecutive assistant messages. + assert planner_llm.ainvoke.call_count >= 1 + for call_idx, call in enumerate(planner_llm.ainvoke.call_args_list): + sent = call.args[0] + offenders = self._adjacent_assistant_pairs(sent) + assert not offenders, f"planner ainvoke #{call_idx} had consecutive assistant messages at {offenders}"