From 4050ddbc2e1c5bebee7e4e3920c01d86be31cc7b Mon Sep 17 00:00:00 2001
From: "Michael J. Sullivan"
Date: Tue, 9 Jun 2026 18:12:15 -0700
Subject: [PATCH] Dedupe MessageAggregator snapshots by id, not just consecutively

Snapshots of the same message are not always consecutive in the event
stream: a tool-result message can land between two snapshots of the
assistant message that called it. The tail-only check then appended the
later snapshot as a new entry, so bundles accumulated duplicate copies
of the same message - one real subagent run produced 20 usage-bearing
entries for 8 distinct messages, bloating the persisted result and
overcounting for anyone summing usage across the bundle.

Track an index by message id and replace in place, keeping the message
at its first position.

Co-authored-by: anthropic/claude-fable-5, via tau
---
 src/ai/agents/agent.py                  | 14 ++++++--
 tests/agents/test_message_aggregator.py | 47 +++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 3 deletions(-)
 create mode 100644 tests/agents/test_message_aggregator.py

diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py
index f6ecc06..06bea9e 100644
--- a/src/ai/agents/agent.py
+++ b/src/ai/agents/agent.py
@@ -225,6 +225,7 @@ class MessageAggregator(
 ):
     def __init__(self) -> None:
         self._messages: list[types.messages.Message] = []
+        self._index_by_id: dict[str, int] = {}
 
     def feed(self, item: events_.AgentEvent) -> None:
         if isinstance(item, events_.PartialToolCallResult):
@@ -232,10 +233,17 @@ def feed(self, item: events_.AgentEvent) -> None:
         msg = item.message
         if msg is None:
             return
-        if self._messages and self._messages[-1].id == msg.id:
-            self._messages[-1] = msg
-        else:
+        # Later snapshots of a message replace earlier ones in place.
+        # Snapshots of the same message are not always consecutive —
+        # e.g. a tool-result message can land between two snapshots of
+        # the assistant message that called it — so dedupe by id, not
+        # by checking the tail.
+        index = self._index_by_id.get(msg.id)
+        if index is None:
+            self._index_by_id[msg.id] = len(self._messages)
             self._messages.append(msg)
+        else:
+            self._messages[index] = msg
 
     def snapshot(self) -> MessageBundle:
         return MessageBundle(messages=tuple(self._messages))
diff --git a/tests/agents/test_message_aggregator.py b/tests/agents/test_message_aggregator.py
new file mode 100644
index 0000000..f4a081c
--- /dev/null
+++ b/tests/agents/test_message_aggregator.py
@@ -0,0 +1,47 @@
+"""MessageAggregator — deduping message snapshots by id."""
+
+from __future__ import annotations
+
+import ai
+from ai.types import events as events_
+
+from ..conftest import text_msg, tool_result_msg
+
+
+def test_consecutive_snapshots_replace() -> None:
+    """Later snapshots of the same message replace the earlier ones."""
+    agg = ai.agents.MessageAggregator()
+    agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a")))
+    agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a")))
+
+    bundle = agg.snapshot()
+    assert [m.text for m in bundle.messages] == ["complete"]
+
+
+def test_interleaved_snapshots_replace() -> None:
+    """Snapshots of the same message dedupe even when another message
+    lands in between (e.g. a tool-result message mid-stream)."""
+    agg = ai.agents.MessageAggregator()
+    tool = tool_result_msg(tc_id="tc-1", result="r")
+    agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a")))
+    agg.feed(events_.StreamEnd(message=tool))
+    agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a")))
+
+    bundle = agg.snapshot()
+    assert [m.id for m in bundle.messages] == ["msg-a", tool.id]
+    assert bundle.messages[0].text == "complete"
+
+
+def test_first_occurrence_position_is_kept() -> None:
+    """Replacement keeps the message at its original position."""
+    agg = ai.agents.MessageAggregator()
+    tool = tool_result_msg(tc_id="tc-1", result="r")
+    agg.feed(events_.StreamEnd(message=text_msg("a1", id="msg-a")))
+    agg.feed(events_.StreamEnd(message=tool))
+    agg.feed(events_.StreamEnd(message=text_msg("b1", id="msg-b")))
+    agg.feed(events_.StreamEnd(message=text_msg("a2", id="msg-a")))
+    agg.feed(events_.StreamEnd(message=text_msg("b2", id="msg-b")))
+
+    bundle = agg.snapshot()
+    assert [m.id for m in bundle.messages] == ["msg-a", tool.id, "msg-b"]
+    assert [m.text for m in bundle.messages] == ["a2", "", "b2"]