Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,25 @@ class MessageAggregator(
):
def __init__(self) -> None:
self._messages: list[types.messages.Message] = []
self._index_by_id: dict[str, int] = {}

def feed(self, item: events_.AgentEvent) -> None:
if isinstance(item, events_.PartialToolCallResult):
return
msg = item.message
if msg is None:
return
if self._messages and self._messages[-1].id == msg.id:
self._messages[-1] = msg
else:
# Later snapshots of a message replace earlier ones in place.
# Snapshots of the same message are not always consecutive —
# e.g. a tool-result message can land between two snapshots of
# the assistant message that called it — so dedupe by id, not
# by checking the tail.
index = self._index_by_id.get(msg.id)
if index is None:
self._index_by_id[msg.id] = len(self._messages)
self._messages.append(msg)
else:
self._messages[index] = msg

def snapshot(self) -> MessageBundle:
return MessageBundle(messages=tuple(self._messages))
Expand Down
47 changes: 47 additions & 0 deletions tests/agents/test_message_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""MessageAggregator — deduping message snapshots by id."""

from __future__ import annotations

import ai
from ai.types import events as events_

from ..conftest import text_msg, tool_result_msg


def test_consecutive_snapshots_replace() -> None:
"""Later snapshots of the same message replace the earlier ones."""
agg = ai.agents.MessageAggregator()
agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a")))
agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a")))

bundle = agg.snapshot()
assert [m.text for m in bundle.messages] == ["complete"]


def test_interleaved_snapshots_replace() -> None:
"""Snapshots of the same message dedupe even when another message
lands in between (e.g. a tool-result message mid-stream)."""
agg = ai.agents.MessageAggregator()
tool = tool_result_msg(tc_id="tc-1", result="r")
agg.feed(events_.StreamEnd(message=text_msg("partial", id="msg-a")))
agg.feed(events_.StreamEnd(message=tool))
agg.feed(events_.StreamEnd(message=text_msg("complete", id="msg-a")))

bundle = agg.snapshot()
assert [m.id for m in bundle.messages] == ["msg-a", tool.id]
assert bundle.messages[0].text == "complete"


def test_first_occurrence_position_is_kept() -> None:
"""Replacement keeps the message at its original position."""
agg = ai.agents.MessageAggregator()
tool = tool_result_msg(tc_id="tc-1", result="r")
agg.feed(events_.StreamEnd(message=text_msg("a1", id="msg-a")))
agg.feed(events_.StreamEnd(message=tool))
agg.feed(events_.StreamEnd(message=text_msg("b1", id="msg-b")))
agg.feed(events_.StreamEnd(message=text_msg("a2", id="msg-a")))
agg.feed(events_.StreamEnd(message=text_msg("b2", id="msg-b")))

bundle = agg.snapshot()
assert [m.id for m in bundle.messages] == ["msg-a", tool.id, "msg-b"]
assert [m.text for m in bundle.messages] == ["a2", "", "b2"]
Loading