diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 84f9f22b..d79e5ca0 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -11,6 +11,7 @@ import pydantic from ... import middleware as middleware_ +from ...types import integrity as integrity_ from ...types import messages as messages_ from ...types import stream as stream_ from ...types import tools as tools_ @@ -37,6 +38,8 @@ async def stream( The client is resolved from the model: ``model.client`` if set, otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. """ + messages = integrity_.prepare_messages(messages) + call = middleware_.ModelContext( model=model, messages=messages, @@ -79,6 +82,8 @@ async def generate( * :class:`ImageParams` — image generation (``/image-model``). * :class:`VideoParams` — video generation (``/video-model``). """ + messages = integrity_.prepare_messages(messages) + call = middleware_.GenerateContext( model=model, messages=messages, diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py new file mode 100644 index 00000000..f2b8ca87 --- /dev/null +++ b/src/ai/types/integrity.py @@ -0,0 +1,261 @@ +import json +import logging +from typing import Literal + +from . import builders +from . import messages as messages_ + +logger = logging.getLogger(__name__) + +Mode = Literal["strict", "auto"] + +IssueKind = Literal[ + "duplicate-tool-call", + "duplicate-tool-result", + "internal-part", + "invalid-tool-args", + "orphaned-tool-call", + "orphaned-tool-result", + "signal-message", +] + + +class IntegrityError(ValueError): + def __init__(self, issues: list[IssueKind]) -> None: + self.issues = issues + super().__init__( + f"Message history has {len(issues)} issue(s): " + ", ".join(issues) + ) + + +# used for stripping internal parts +_LLM_PART_TYPES = ( + messages_.TextPart, + messages_.ToolCallPart, + messages_.ToolResultPart, + messages_.ReasoningPart, + messages_.FilePart, +) + + +def _clean_messages( + messages: list[messages_.Message], mode: Mode +) -> tuple[list[messages_.Message], list[IssueKind]]: + """Strip internal messages, fix broken tool args""" + + issues: list[IssueKind] = [] + result: list[messages_.Message] = [] + + for msg in messages: + # 1. drop signal messages emitted by hooks + if msg.role == "signal": + issues.append("signal-message") + if mode == "strict": + result.append(msg) + continue + + parts: list[messages_.Part] = list(msg.parts) + changed = False + + # 2. strip everything that isn't an LLM part + kept: list[messages_.Part] = [ + p for p in parts if isinstance(p, _LLM_PART_TYPES) + ] + if len(kept) < len(parts): + issues.append("internal-part") + if mode == "auto": + parts = kept + changed = True + + # 3. ensure tool args are json-decodable + new_parts: list[messages_.Part] = [] + for part in parts: + if isinstance(part, messages_.ToolCallPart): + try: + json.loads(part.tool_args) + except (json.JSONDecodeError, TypeError): + if mode == "auto": + part = part.model_copy(update={"tool_args": "{}"}) + issues.append("invalid-tool-args") + changed = True + new_parts.append(part) + + if changed and mode == "auto": + parts = new_parts + + # 4. drop empty messages + if mode == "auto" and not parts: + continue + + if changed and mode == "auto": + # messages are immutable so we have to do this + result.append(msg.model_copy(update={"parts": parts})) + else: + result.append(msg) + + return result, issues + + +def _validate_tool_ids(messages: list[messages_.Message]) -> list[IssueKind]: + """Check for fatal issues: duplicate tool ids, orphaned tool results.""" + + issues: list[IssueKind] = [] + seen_call_ids: set[str] = set() + seen_result_ids: set[str] = set() + pending_call_ids: set[str] = set() + + duplicate_call = False + duplicate_result = False + orphaned_result = False + + for msg in messages: + if msg.role in ("user", "assistant") and pending_call_ids: + # result should have been in a tool message before this + # if it wasn't then it's a stray call, will be auto-fixed later + pending_call_ids.clear() + + if msg.role == "assistant": + # check if tool call is duplicate + # if not, mark it and append it to pending + for part in msg.parts: + if not isinstance(part, messages_.ToolCallPart): + continue + if part.tool_call_id in seen_call_ids: + duplicate_call = True + else: + seen_call_ids.add(part.tool_call_id) + pending_call_ids.add(part.tool_call_id) + + elif msg.role == "tool": + # check that this tool result is not duplicate and that + # there's a pending call from previous assistant message + for part in msg.parts: + if not isinstance(part, messages_.ToolResultPart): + continue + if part.tool_call_id in seen_result_ids: + duplicate_result = True + else: + seen_result_ids.add(part.tool_call_id) + if part.tool_call_id not in pending_call_ids: + orphaned_result = True + continue + pending_call_ids.remove(part.tool_call_id) + + if duplicate_call: + issues.append("duplicate-tool-call") + if duplicate_result: + issues.append("duplicate-tool-result") + if orphaned_result: + issues.append("orphaned-tool-result") + + return issues + + +def _fix_missing_results( + messages: list[messages_.Message], mode: Mode +) -> tuple[list[messages_.Message], list[IssueKind]]: + """Insert fake error results for stray tool calls.""" + issues: list[IssueKind] = [] + result: list[messages_.Message] = [] + + # 1. collect all result ids + answered: set[str] = set() + for msg in messages: + if msg.role == "tool": + for part in msg.parts: + if isinstance(part, messages_.ToolResultPart): + answered.add(part.tool_call_id) + + # pending tool calls from the current assistant turn + pending: dict[str, messages_.ToolCallPart] = {} + + def _flush_pending() -> None: + if not pending: + return + issues.append("orphaned-tool-call") + if mode == "auto": + synthetic = builders.tool_message( + *( + messages_.ToolResultPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + result="Tool result not available", + is_error=True, + ) + for tc in pending.values() + ) + ) + result.append(synthetic) + + for msg in messages: + # if we're seeing a user / assistant message, then + # all pending tool calls are strays, because their results + # should have followed immediately after in a tool message + if msg.role in ("user", "assistant") and pending: + _flush_pending() + pending.clear() + + # 2. track calls + if msg.role == "assistant": + for part in msg.parts: + if ( + isinstance(part, messages_.ToolCallPart) + and part.tool_call_id not in answered + ): + pending[part.tool_call_id] = part + result.append(msg) + # 3. match results with calls + elif msg.role == "tool": + for part in msg.parts: + if isinstance(part, messages_.ToolResultPart): + pending.pop(part.tool_call_id, None) + result.append(msg) + else: + result.append(msg) + + _flush_pending() + + return result, issues + + +def prepare_messages( + messages: list[messages_.Message], + *, + mode: Mode = "auto", +) -> list[messages_.Message]: + """Fix and validate message list. + + ``"auto"`` (default) -- silently fixes recoverable issues (signal + messages, internal parts, invalid tool args, missing tool results). + ``"strict"`` -- collects every recoverable issue and raises + :class:`IntegrityError`. + + Duplicate tool-call IDs, duplicate tool-result IDs, and orphaned + tool results always raise :class:`IntegrityError` regardless of mode. + + Always returns a **new** list; never mutates the input. + """ + issues: list[IssueKind] = [] + + result, phase1_issues = _clean_messages(list(messages), mode) + issues.extend(phase1_issues) + + # never auto-fixed + fatal_issues = _validate_tool_ids(result) + issues.extend(fatal_issues) + + if not fatal_issues: + result, phase3_issues = _fix_missing_results(result, mode) + issues.extend(phase3_issues) + + if fatal_issues or (mode == "strict" and issues): + raise IntegrityError(issues) + + if issues: + logger.warning( + "Auto-fixed %d message issue(s): %s", + len(issues), + ", ".join(issues), + ) + + return result diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 8d7a88ce..0a8ecf14 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -7,6 +7,7 @@ from typing import Any import pydantic +import pytest import ai from ai import middleware, models @@ -335,12 +336,12 @@ async def gen_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: async def test_wrap_tool_context_fields_flow_to_result() -> None: - """ToolContext.tool_call_id and tool_name are used in the result message.""" + """ToolContext.tool_name is used in the result message.""" class Rewriter(ai.Middleware): async def wrap_tool(self, call: middleware.ToolContext, next: Any) -> Any: - # Rewrite the tool_call_id via dataclasses.replace. - modified = dataclasses.replace(call, tool_call_id="rewritten-id") + # Rewrite the tool_name via dataclasses.replace. + modified = dataclasses.replace(call, tool_name="rewritten-name") return await next(modified) @ai.tool @@ -349,7 +350,7 @@ async def echo(x: int) -> int: return x my_agent = ai.agent(tools=[echo]) - call1 = [tool_call_msg(tc_id="original-id", name="echo", args='{"x": 42}')] + call1 = [tool_call_msg(tc_id="tc-1", name="echo", args='{"x": 42}')] call2 = [text_msg("done")] mock_llm([call1, call2]) @@ -361,8 +362,36 @@ async def echo(x: int) -> int: tool_result_msgs.append(m) assert len(tool_result_msgs) >= 1 - # The result message should use the rewritten ID, not the original. - assert tool_result_msgs[0].tool_results[0].tool_call_id == "rewritten-id" + # The result message should use the rewritten name, not the original. + assert tool_result_msgs[0].tool_results[0].tool_name == "rewritten-name" + + +async def test_wrap_tool_rewriting_tool_call_id_breaks_history() -> None: + """tool_call_id is a correlation key and must stay stable.""" + + class Rewriter(ai.Middleware): + async def wrap_tool(self, call: middleware.ToolContext, next: Any) -> Any: + modified = dataclasses.replace(call, tool_call_id="rewritten-id") + return await next(modified) + + @ai.tool + async def echo(x: int) -> int: + """Echo a number.""" + return x + + my_agent = ai.agent(tools=[echo]) + call1 = [tool_call_msg(tc_id="original-id", name="echo", args='{"x": 42}')] + call2 = [text_msg("done")] + mock_llm([call1, call2]) + + with pytest.raises(ExceptionGroup) as exc_info: + async for _m in my_agent.run( + MOCK_MODEL, [ai.user_message("go")], middleware=[Rewriter()] + ): + pass + + assert len(exc_info.value.exceptions) == 1 + assert "orphaned-tool-result" in str(exc_info.value.exceptions[0]) # ── StreamResult wrapping ─────────────────────────────────────── diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py new file mode 100644 index 00000000..ac308752 --- /dev/null +++ b/tests/types/test_integrity.py @@ -0,0 +1,573 @@ +"""Tests for message integrity checker.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Literal +from unittest.mock import patch + +import pydantic +import pytest + +import ai +from ai import models +from ai.types import builders, messages +from ai.types.integrity import IntegrityError, prepare_messages + +from ..conftest import MOCK_MODEL, mock_generate, mock_llm, text_msg + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _assistant_with_tool_call( + tool_call_id: str = "tc-1", + tool_name: str = "calc", + tool_args: str = '{"x": 1}', +) -> messages.Message: + return messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_args=tool_args, + ) + ], + ) + + +def _tool_result( + tool_call_id: str = "tc-1", + tool_name: str = "calc", + result: str = "42", +) -> messages.Message: + return builders.tool_message( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result, + ) + + +def _assert_raises_issue( + msgs: list[messages.Message], + issue: str, + *, + mode: Literal["auto", "strict"] = "auto", +) -> None: + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs, mode=mode) + assert issue in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Clean passthrough +# --------------------------------------------------------------------------- + + +def test_clean_messages_pass_through() -> None: + msgs = [ + builders.user_message("hello"), + builders.assistant_message("world"), + ] + result = prepare_messages(msgs) + assert len(result) == 2 + assert result[0].text == "hello" + assert result[1].text == "world" + + +def test_idempotent() -> None: + msgs = [ + builders.user_message("hi"), + _assistant_with_tool_call(), + _tool_result(), + builders.assistant_message("done"), + ] + once = prepare_messages(msgs) + twice = prepare_messages(once) + assert len(once) == len(twice) + for a, b in zip(once, twice, strict=True): + assert a.role == b.role + assert len(a.parts) == len(b.parts) + + +# --------------------------------------------------------------------------- +# Signal messages +# --------------------------------------------------------------------------- + + +def test_drops_signal_messages() -> None: + msgs = [ + builders.user_message("hi"), + messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + builders.assistant_message("hello"), + ] + result = prepare_messages(msgs) + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + + +def test_signal_strict_raises() -> None: + msgs = [ + messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs, mode="strict") + assert "signal-message" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Internal parts (HookPart, StructuredOutputPart) +# --------------------------------------------------------------------------- + + +def test_strips_internal_parts() -> None: + msg = messages.Message( + role="assistant", + parts=[ + messages.TextPart(text="hi"), + messages.HookPart(hook_id="h1", hook_type="confirm", status="resolved"), + ], + ) + result = prepare_messages([msg]) + assert len(result) == 1 + assert len(result[0].parts) == 1 + assert isinstance(result[0].parts[0], messages.TextPart) + + +def test_strips_internal_parts_drops_empty_message() -> None: + """Message with only internal parts becomes empty and is dropped.""" + msg = messages.Message( + role="assistant", + parts=[ + messages.HookPart(hook_id="h1", hook_type="confirm", status="resolved"), + ], + ) + result = prepare_messages([msg]) + assert len(result) == 0 + + +def test_internal_parts_strict_raises() -> None: + msg = messages.Message( + role="assistant", + parts=[ + messages.HookPart(hook_id="h1", hook_type="confirm", status="resolved"), + ], + ) + with pytest.raises(IntegrityError) as exc_info: + prepare_messages([msg], mode="strict") + assert "internal-part" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Invalid tool args +# --------------------------------------------------------------------------- + + +def test_fixes_invalid_tool_args() -> None: + msg = _assistant_with_tool_call(tool_args="not json {{{") + result = prepare_messages([msg]) + tc = result[0].parts[0] + assert isinstance(tc, messages.ToolCallPart) + assert tc.tool_args == "{}" + + +def test_preserves_valid_tool_args() -> None: + msg = _assistant_with_tool_call(tool_args='{"key": "value"}') + result = prepare_messages([msg]) + tc = result[0].parts[0] + assert isinstance(tc, messages.ToolCallPart) + assert tc.tool_args == '{"key": "value"}' + + +def test_invalid_tool_args_strict_raises() -> None: + msg = _assistant_with_tool_call(tool_args="broken") + with pytest.raises(IntegrityError) as exc_info: + prepare_messages([msg], mode="strict") + assert "invalid-tool-args" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Orphaned tool calls (no matching result) — auto-fixable +# --------------------------------------------------------------------------- + + +def test_inserts_synthetic_result_for_orphaned_call_at_end() -> None: + """Tool call at end of history with no result gets a synthetic one.""" + msgs = [ + builders.user_message("calc 2+2"), + _assistant_with_tool_call(), + ] + result = prepare_messages(msgs) + assert len(result) == 3 + assert result[2].role == "tool" + tr = result[2].tool_results[0] + assert tr.tool_call_id == "tc-1" + assert tr.is_error is True + + +def test_inserts_synthetic_result_before_user_interruption() -> None: + """User message interrupting tool flow triggers synthetic results.""" + msgs = [ + builders.user_message("calc 2+2"), + _assistant_with_tool_call(), + builders.user_message("never mind"), + ] + result = prepare_messages(msgs) + assert len(result) == 4 + # Synthetic result inserted before the user message. + assert result[2].role == "tool" + assert result[2].tool_results[0].is_error is True + assert result[3].role == "user" + assert result[3].text == "never mind" + + +def test_inserts_synthetic_result_before_next_assistant() -> None: + """New assistant message while tool calls pending triggers synthetic results.""" + msgs = [ + builders.user_message("calc 2+2"), + _assistant_with_tool_call(), + builders.assistant_message("actually, the answer is 4"), + ] + result = prepare_messages(msgs) + assert len(result) == 4 + assert result[2].role == "tool" + assert result[2].tool_results[0].is_error is True + assert result[3].role == "assistant" + + +def test_multiple_orphaned_calls_get_individual_results() -> None: + msg = messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart(tool_call_id="tc-1", tool_name="a", tool_args="{}"), + messages.ToolCallPart(tool_call_id="tc-2", tool_name="b", tool_args="{}"), + ], + ) + result = prepare_messages([builders.user_message("go"), msg]) + # Synthetic tool message should have results for both calls. + synthetic = result[2] + assert synthetic.role == "tool" + ids = {tr.tool_call_id for tr in synthetic.tool_results} + assert ids == {"tc-1", "tc-2"} + + +def test_partial_results_only_fills_missing() -> None: + """If some results exist, only the missing ones get synthetic fills.""" + msgs = [ + builders.user_message("go"), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", tool_name="a", tool_args="{}" + ), + messages.ToolCallPart( + tool_call_id="tc-2", tool_name="b", tool_args="{}" + ), + ], + ), + _tool_result(tool_call_id="tc-1"), + # tc-2 is missing, then user interrupts + builders.user_message("stop"), + ] + result = prepare_messages(msgs) + # user, assistant, tool(tc-1), synthetic-tool(tc-2), user + assert len(result) == 5 + synthetic = result[3] + assert synthetic.role == "tool" + assert len(synthetic.tool_results) == 1 + assert synthetic.tool_results[0].tool_call_id == "tc-2" + assert synthetic.tool_results[0].is_error is True + + +def test_orphaned_tool_call_strict_raises() -> None: + msgs = [ + builders.user_message("go"), + _assistant_with_tool_call(), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs, mode="strict") + assert "orphaned-tool-call" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Orphaned tool results (no matching call) — always raises +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("mode", ["auto", "strict"]) +def test_orphaned_tool_result_always_raises( + mode: Literal["auto", "strict"], +) -> None: + """Tool result referencing a nonexistent call always raises.""" + msgs = [ + builders.user_message("hi"), + _tool_result(tool_call_id="nonexistent"), + ] + _assert_raises_issue(msgs, "orphaned-tool-result", mode=mode) + + +def test_out_of_sequence_tool_result_raises() -> None: + """A late tool result cannot arrive after another conversation turn.""" + msgs = [ + builders.user_message("go"), + _assistant_with_tool_call(), + builders.user_message("never mind"), + _tool_result(), + ] + _assert_raises_issue(msgs, "orphaned-tool-result") + + +# --------------------------------------------------------------------------- +# Complete tool flow (no issues) +# --------------------------------------------------------------------------- + + +def test_complete_tool_flow_unchanged() -> None: + """A properly paired tool flow passes through without modification.""" + msgs = [ + builders.user_message("calc 2+2"), + _assistant_with_tool_call(), + _tool_result(), + builders.assistant_message("The answer is 4"), + ] + result = prepare_messages(msgs) + assert len(result) == 4 + assert [m.role for m in result] == ["user", "assistant", "tool", "assistant"] + + +# --------------------------------------------------------------------------- +# Strict mode collects multiple issues +# --------------------------------------------------------------------------- + + +def test_strict_collects_all_issues() -> None: + msgs = [ + messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message( + role="assistant", + parts=[ + messages.TextPart(text="hi"), + messages.HookPart(hook_id="h1", hook_type="confirm", status="resolved"), + ], + ), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs, mode="strict") + issues = exc_info.value.issues + assert "signal-message" in issues + assert "internal-part" in issues + + +def test_strict_keeps_recoverable_issues_when_history_is_corrupt() -> None: + msgs = [ + messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + builders.user_message("go"), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", tool_name="a", tool_args="{}" + ), + messages.ToolCallPart( + tool_call_id="tc-1", tool_name="b", tool_args="{}" + ), + ], + ), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs, mode="strict") + assert "signal-message" in exc_info.value.issues + assert "duplicate-tool-call" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Duplicate tool call IDs — always raises +# --------------------------------------------------------------------------- + + +def test_duplicate_tool_calls_raises_in_auto() -> None: + """Two assistant messages using the same tool_call_id always raises.""" + msgs = [ + builders.user_message("go"), + _assistant_with_tool_call(tool_call_id="tc-1", tool_args='{"v": 1}'), + _tool_result(tool_call_id="tc-1", result="old"), + _assistant_with_tool_call(tool_call_id="tc-1", tool_args='{"v": 2}'), + _tool_result(tool_call_id="tc-1", result="new"), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs) + assert "duplicate-tool-call" in exc_info.value.issues + assert "duplicate-tool-result" in exc_info.value.issues + + +def test_duplicate_tool_calls_within_same_message_raises() -> None: + """Two tool calls with the same ID in one assistant message raises.""" + msg = messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", tool_name="a", tool_args='{"v": 1}' + ), + messages.ToolCallPart( + tool_call_id="tc-1", tool_name="a", tool_args='{"v": 2}' + ), + ], + ) + with pytest.raises(IntegrityError) as exc_info: + prepare_messages([builders.user_message("go"), msg]) + assert "duplicate-tool-call" in exc_info.value.issues + + +def test_duplicate_tool_results_raises_in_auto() -> None: + """Two tool messages with results for the same call always raises.""" + msgs = [ + builders.user_message("go"), + _assistant_with_tool_call(tool_call_id="tc-1"), + _tool_result(tool_call_id="tc-1", result="first"), + _tool_result(tool_call_id="tc-1", result="second"), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs) + assert "duplicate-tool-result" in exc_info.value.issues + + +def test_duplicate_tool_results_within_same_message_raises() -> None: + """Two results for the same call ID in one tool message raises.""" + msgs = [ + builders.user_message("go"), + _assistant_with_tool_call(tool_call_id="tc-1"), + messages.Message( + role="tool", + parts=[ + builders.tool_result("tc-1", result="first"), + builders.tool_result("tc-1", result="second"), + ], + ), + ] + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(msgs) + assert "duplicate-tool-result" in exc_info.value.issues + + +# --------------------------------------------------------------------------- +# Does not mutate input +# --------------------------------------------------------------------------- + + +def test_does_not_mutate_input() -> None: + original = [ + builders.user_message("hi"), + _assistant_with_tool_call(), + ] + original_len = len(original) + _ = prepare_messages(original) + assert len(original) == original_len + + +# --------------------------------------------------------------------------- +# Wiring: stream() and generate() run prepare_messages on input +# --------------------------------------------------------------------------- + + +async def test_stream_calls_prepare_messages() -> None: + """stream() should invoke prepare_messages before hitting the adapter.""" + mock_llm([[text_msg("ok")]]) + msgs = [ai.user_message("hi")] + + with patch( + "ai.models.core.api.integrity_.prepare_messages", wraps=lambda m: m + ) as spy: + s = await models.stream(MOCK_MODEL, msgs) + async for _ in s: + pass + spy.assert_called_once_with(msgs) + + +async def test_stream_sanitizes_signal_messages() -> None: + """Signal messages are stripped before reaching the adapter.""" + received: list[list[messages.Message]] = [] + mock = mock_llm([[text_msg("ok")]]) + + # Wrap the mock adapter to capture messages it receives + original_adapter = mock.stream + + async def _spy_stream( + client: models.Client, + model: models.Model, + messages: list[messages.Message], + *, + tools: Sequence[ai.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[messages.Message]: + received.append(list(messages)) + async for m in original_adapter( + client, model, messages, tools=tools, output_type=output_type, **kwargs + ): + yield m + + models.register_stream("mock", _spy_stream) + + msgs = [ + ai.user_message("hi"), + messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + ai.assistant_message("hello"), + ] + s = await models.stream(MOCK_MODEL, msgs) + async for _ in s: + pass + + # The adapter should have received only 2 messages (signal stripped) + assert len(received) == 1 + assert len(received[0]) == 2 + assert all(m.role != "signal" for m in received[0]) + + +async def test_generate_calls_prepare_messages() -> None: + """generate() should invoke prepare_messages before hitting the adapter.""" + sentinel = messages.Message( + role="assistant", + parts=[messages.FilePart(data=b"\x89PNG", media_type="image/png")], + ) + mock_generate([sentinel]) + msgs = [ai.user_message("A cat")] + + with patch( + "ai.models.core.api.integrity_.prepare_messages", wraps=lambda m: m + ) as spy: + await models.generate(MOCK_MODEL, msgs, models.ImageParams(n=1)) + spy.assert_called_once_with(msgs) + + +async def test_generate_sanitizes_signal_messages() -> None: + """Signal messages are stripped before reaching generate adapter.""" + received: list[list[messages.Message]] = [] + sentinel = messages.Message( + role="assistant", + parts=[messages.FilePart(data=b"\x89PNG", media_type="image/png")], + ) + + async def _spy_gen( + client: models.Client, + model: models.Model, + messages: list[messages.Message], + params: Any, + ) -> messages.Message: + received.append(list(messages)) + return sentinel + + models.register_generate("mock", _spy_gen) + + msgs = [ + ai.user_message("A cat"), + messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + ] + await models.generate(MOCK_MODEL, msgs, models.ImageParams(n=1)) + + assert len(received) == 1 + assert len(received[0]) == 1 + assert received[0][0].role == "user"