diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index f6ecc06..06bea9e 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -225,6 +225,7 @@ class MessageAggregator( ): def __init__(self) -> None: self._messages: list[types.messages.Message] = [] + self._index_by_id: dict[str, int] = {} def feed(self, item: events_.AgentEvent) -> None: if isinstance(item, events_.PartialToolCallResult): @@ -232,10 +233,17 @@ def feed(self, item: events_.AgentEvent) -> None: msg = item.message if msg is None: return - if self._messages and self._messages[-1].id == msg.id: - self._messages[-1] = msg - else: + # Later snapshots of a message replace earlier ones in place. + # Snapshots of the same message are not always consecutive — + # e.g. a tool-result message can land between two snapshots of + # the assistant message that called it — so dedupe by id, not + # by checking the tail. + index = self._index_by_id.get(msg.id) + if index is None: + self._index_by_id[msg.id] = len(self._messages) self._messages.append(msg) + else: + self._messages[index] = msg def snapshot(self) -> MessageBundle: return MessageBundle(messages=tuple(self._messages)) diff --git a/tests/agents/test_message_aggregator.py b/tests/agents/test_message_aggregator.py new file mode 100644 index 0000000..f4a081c --- /dev/null +++ b/tests/agents/test_message_aggregator.py @@ -0,0 +1,47 @@ +"""MessageAggregator — deduping message snapshots by id.""" + +from __future__ import annotations + +import ai +from ai.types import events as events_ + +from ..conftest import text_msg, tool_result_msg + + +def test_consecutive_snapshots_replace() -> None: + """Later snapshots of the same message replace the earlier ones.""" + agg = ai.agents.MessageAggregator() + agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a"))) + agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a"))) + + bundle = agg.snapshot() + assert [m.text for m in bundle.messages] == ["complete"] + + +def test_interleaved_snapshots_replace() -> None: + """Snapshots of the same message dedupe even when another message + lands in between (e.g. a tool-result message mid-stream).""" + agg = ai.agents.MessageAggregator() + tool = tool_result_msg(tc_id="tc-1", result="r") + agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a"))) + agg.feed(events_.StreamEnd(message=tool)) + agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a"))) + + bundle = agg.snapshot() + assert [m.id for m in bundle.messages] == ["msg-a", tool.id] + assert bundle.messages[0].text == "complete" + + +def test_first_occurrence_position_is_kept() -> None: + """Replacement keeps the message at its original position.""" + agg = ai.agents.MessageAggregator() + tool = tool_result_msg(tc_id="tc-1", result="r") + agg.feed(events_.StreamEnd(message=text_msg("a1", id="msg-a"))) + agg.feed(events_.StreamEnd(message=tool)) + agg.feed(events_.StreamEnd(message=text_msg("b1", id="msg-b"))) + agg.feed(events_.StreamEnd(message=text_msg("a2", id="msg-a"))) + agg.feed(events_.StreamEnd(message=text_msg("b2", id="msg-b"))) + + bundle = agg.snapshot() + assert [m.id for m in bundle.messages] == ["msg-a", tool.id, "msg-b"] + assert [m.text for m in bundle.messages] == ["a2", "", "b2"]