From 33ee4a56de872f3b6b550819b76ccad8036383c7 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 29 Apr 2026 18:32:19 -0700 Subject: [PATCH] Replace MessageStart/MessageEnd with ToolCallResult and HookEvent MessageStart and MessageEnd were always emitted back-to-back after the models rework, and MessageStart didn't carry any useful information anymore. A bunch of downstream examples were broken. (Some are still a bit broken after this; I have some more pending tool streaming changes.) Instead, we'll add two special-purpose event types: - ToolCallResult: emitted after tool execution with the result message - HookEvent: emitted when a hook suspends/resolves/cancels (Previously those were signaled with synthetic MessageStart/MessageEnd pairs.) We also add a `TerminalEvent = types.StreamEnd | ToolCallResult | HookEvent` that represents all of the AgentEvents that are "terminal". I also wanted to make `message` nonoptional in all the events (since from a *user* perspective, it should be!), so I've made the default be a `_DUMMY_MESSAGE`. This is an annoying hack but improves the user facing interface a fair amount, if they are using typing. --- CLAUDE.md | 1 + README.md | 8 +- examples/fastapi-vite/README.md | 2 +- examples/fastapi-vite/backend/agent.py | 6 +- examples/fastapi-vite/backend/main.py | 8 +- examples/fastapi-vite/backend/pyproject.toml | 1 - examples/fastapi-vite/frontend/vite.config.ts | 8 + examples/multiagent-textual/client.py | 72 +++-- examples/multiagent-textual/server.py | 43 +-- examples/samples/agent_custom_loop.py | 9 +- examples/samples/agent_hooks.py | 32 +-- examples/samples/agent_hooks_serverless.py | 29 +- examples/samples/streaming_tool.py | 25 +- examples/temporal-direct/main.py | 12 +- examples/temporal-middleware/main.py | 47 +++- src/ai/__init__.py | 12 +- src/ai/agents/__init__.py | 15 +- src/ai/agents/agent.py | 117 ++++---- src/ai/agents/events.py | 39 +-- src/ai/agents/hooks.py | 59 ++-- src/ai/agents/runtime.py | 9 +- src/ai/agents/ui/ai_sdk/outbound/_state.py | 242 ++++++---------- src/ai/agents/ui/ai_sdk/outbound/stream.py | 16 +- src/ai/middleware.py | 2 +- src/ai/types/builders.py | 28 +- src/ai/types/events.py | 8 +- tests/agents/test_generator_tools.py | 69 +++-- tests/agents/test_hooks.py | 43 ++- tests/agents/ui/ai_sdk/outbound/test_sse.py | 15 +- .../agents/ui/ai_sdk/outbound/test_stream.py | 262 ++++++++---------- tests/conftest.py | 2 +- tests/test_middleware.py | 5 +- tests/types/test_builders.py | 4 +- tests/types/test_integrity.py | 4 +- 34 files changed, 566 insertions(+), 688 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8ea7fedf..c444a9ec 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,6 +13,7 @@ - UNLESS it's `typing` — then `from typing import Foo` (there are too many of them). - if the module name shadows a local variable in the same file, add a trailing underscore to the import: `from ..types import messages as messages_`. do not add trailing underscores preemptively — only when there is an actual collision. 4. tests directory structure mirrors `src` +5. to run examples that have their own `pyproject.toml`: `uv run --frozen --with-editable ~/src/py-ai/` ## design principles diff --git a/README.md b/README.md index 1ca41863..8daa5bd1 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ ai.yield_from(...) forward nested agent / streaming tool output ``` ai.system_message ai.user_message ai.assistant_message ai.tool_message -ai.tool_result ai.file_part ai.thinking +ai.tool_result ai.tool_result_part ai.file_part ai.thinking ``` ### Middleware @@ -109,15 +109,15 @@ async def custom(context: ai.Context): s = ai.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event + if s.message is not None: + yield s.message tool_calls = context.resolve(s.tool_calls) if not tool_calls: return results = [await tc() for tc in tool_calls] - tool_msg = ai.tool_message(*results) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*results) ``` ## Examples diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index 9fff5052..0418c296 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -16,7 +16,7 @@ to suspend execution whenever the LLM wants to call a tool. The flow is: 1. LLM emits a tool call 2. Backend calls `await ai.hook(...)` with `payload=ai.ToolApproval` -3. The runtime emits a `MessageEnd` event containing an internal `HookPart` +3. The runtime emits a `HookEvent` containing the `HookPart` 4. The frontend renders Approve / Reject buttons via the `` component (from AI Elements) 5. When the user clicks a button, `addToolApprovalResponse()` patches diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 4a5ad165..3f1303d2 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -37,6 +37,8 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]: s = ai.models.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event + if s.message is not None: + yield s.message tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -45,9 +47,7 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]: results = await asyncio.gather( *(_execute_with_approval(tc) for tc in tool_calls) ) - tool_msg = ai.tool_message(*results) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*results) async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message: diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index e4e49d5a..32b1c607 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -35,21 +35,21 @@ async def health() -> dict[str, str]: class ChatRequest(pydantic.BaseModel): """Request body for the chat endpoint.""" - messages: list[ai.ai_sdk_ui.UIMessage] + messages: list[ai.agents.ui.ai_sdk.UIMessage] session_id: str | None = None @app.post("/chat") async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: """Handle chat requests and stream responses.""" - messages = ai.ai_sdk_ui.to_messages(request.messages) + messages = ai.agents.ui.ai_sdk.to_messages(request.messages) result = agent_.chat_agent.run(agent_.MODEL, messages) async def stream_response() -> AsyncGenerator[str]: - async for chunk in ai.ai_sdk_ui.to_sse(result): + async for chunk in ai.agents.ui.ai_sdk.to_sse(result): yield chunk return fastapi.responses.StreamingResponse( stream_response(), - headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, + headers=ai.agents.ui.ai_sdk.UI_MESSAGE_STREAM_HEADERS, ) diff --git a/examples/fastapi-vite/backend/pyproject.toml b/examples/fastapi-vite/backend/pyproject.toml index 3fa5b347..1909ae57 100644 --- a/examples/fastapi-vite/backend/pyproject.toml +++ b/examples/fastapi-vite/backend/pyproject.toml @@ -7,4 +7,3 @@ dependencies = [ "fastapi[standard]>=0.128.1", "vercel-ai-sdk>=0.0.1.dev5", ] - diff --git a/examples/fastapi-vite/frontend/vite.config.ts b/examples/fastapi-vite/frontend/vite.config.ts index 0cca23e1..705c4337 100644 --- a/examples/fastapi-vite/frontend/vite.config.ts +++ b/examples/fastapi-vite/frontend/vite.config.ts @@ -11,4 +11,12 @@ export default defineConfig({ '@': path.resolve(__dirname, './src'), }, }, + server: { + proxy: { + '/api': { + target: 'http://localhost:8000', + rewrite: (path) => path.replace(/^\/api/, ''), + }, + }, + }, }) diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index d6123af5..5badedcd 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -112,7 +112,7 @@ def __init__(self) -> None: self._hook_queue: asyncio.Queue[ai.HookPart] = asyncio.Queue() self._current_hook: ai.HookPart | None = None self._ws: websockets.ClientConnection | None = None - self._event_adapter = pydantic.TypeAdapter(ai.Event) + self._event_adapter = pydantic.TypeAdapter(ai.AgentEvent) self._current_label = "unknown" def compose(self) -> textual.app.ComposeResult: @@ -150,7 +150,10 @@ async def run_websocket(self) -> None: self._on_run_complete() break - event = self._event_adapter.validate_python(data) + try: + event = self._event_adapter.validate_python(data) + except pydantic.ValidationError: + continue self._handle_event(event) except (ConnectionRefusedError, OSError) as exc: @@ -158,54 +161,47 @@ async def run_websocket(self) -> None: # ------------------------------------------------------------------ # Event routing + # + # TODO: streaming events (TextDelta, etc.) don't carry a source + # label, so _current_label is only updated when a ToolCallResult + # or HookEvent arrives. With concurrent sub-agents, streaming + # text can route to the wrong panel. The hooks in this demo + # serialize the flow enough that it works in practice, but a + # proper fix needs labels on streaming events. # ------------------------------------------------------------------ - def _handle_event(self, event: ai.Event) -> None: - if isinstance(event, ai.MessageStart) and event.message is not None: - self._current_label = event.message.source_label or "unknown" - panel = self._get_panel(self._current_label) - if panel is not None and panel.status == "idle": - panel.status = "streaming..." + def _handle_event(self, event: ai.AgentEvent) -> None: + if isinstance(event, ai.ToolCallResult): + label = event.message.source_label or self._current_label + self._current_label = label + panel = self._get_panel(label) + if panel is not None: + for part in event.message.parts: + match part: + case ai.ToolCallPart(tool_name=name, tool_args=args): + panel.append_line(f"> {name}({args})") + case ai.ToolResultPart(tool_name=name, result=result): + panel.append_line(f"< {name} = {result}") + return + + if isinstance(event, ai.HookEvent): + if event.hook.status == "pending": + self._on_hook_pending(event.hook) + elif event.hook.status == "resolved": + self._on_hook_resolved(event.hook) return if isinstance(event, ai.TextDelta): panel = self._get_panel(self._current_label) if panel is not None: panel.append_text(event.chunk) - return + if panel.status == "idle": + panel.status = "streaming..." - if isinstance(event, ai.ReasoningDelta | ai.ToolDelta): + elif isinstance(event, ai.ReasoningDelta | ai.ToolDelta): panel = self._get_panel(self._current_label) if panel is not None: panel.append_text(event.chunk, style="dim") - return - - if not isinstance(event, ai.MessageEnd): - return - - msg = event.message - label = msg.source_label or self._current_label - - hook_parts = [p for p in msg.parts if isinstance(p, ai.HookPart)] - if hook_parts: - hook_part = hook_parts[0] - if hook_part.status == "pending": - self._on_hook_pending(hook_part) - return - if hook_part.status == "resolved": - self._on_hook_resolved(hook_part) - return - - panel = self._get_panel(label) - if panel is None: - return - - for part in msg.parts: - match part: - case ai.ToolCallPart(tool_name=name, tool_args=args): - panel.append_line(f"> {name}({args})") - case ai.ToolResultPart(tool_name=name, result=result): - panel.append_line(f"< {name} = {result}") # ------------------------------------------------------------------ # Hook lifecycle diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index f23f76e3..a673fe2c 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -83,6 +83,8 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: s = ai.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event + if s.message is not None: + yield s.message tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -110,9 +112,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: else: results.append(await tc()) - tool_msg = ai.tool_message(*results) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*results) return gated @@ -176,8 +176,9 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: combined = f"Mothership: {r1}\nData centers: {r2}" - # Fan in: summarise. - s = ai.stream( + # Fan in: summarise via a labelled sub-agent. + summary_agent = ai.agent() + async for event in summary_agent.run( context.model, [ ai.system_message( @@ -185,26 +186,9 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: ), ai.user_message(combined), ], - ) - async for event in s: - if isinstance(event, ai.MessageEnd): - yield event.model_copy( - update={ - "message": event.message.model_copy( - update={"source_label": "summary"} - ) - } - ) - elif isinstance(event, ai.MessageStart) and event.message is not None: - yield event.model_copy( - update={ - "message": event.message.model_copy( - update={"source_label": "summary"} - ) - } - ) - else: - yield event + label="summary", + ): + yield event # --------------------------------------------------------------------------- @@ -262,13 +246,8 @@ async def read_resolutions() -> None: data = _normalise_event(event.model_dump()) await websocket.send_json(data) - if isinstance(event, ai.MessageEnd) and event.message.role == "internal": - hook_parts = [ - p for p in event.message.parts if isinstance(p, ai.HookPart) - ] - if hook_parts: - hook_part = hook_parts[0] - print(f" Hook {hook_part.status}: {hook_part.hook_id}") + if isinstance(event, ai.HookEvent): + print(f" Hook {event.hook.status}: {event.hook.hook_id}") finally: reader.cancel() with contextlib.suppress(asyncio.CancelledError): diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 6342b2ea..7bcfe152 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -32,6 +32,10 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in s: yield event + # Yield the assistant message for silent history collection. + if s.message is not None: + yield s.message + tool_calls = context.resolve(s.tool_calls) if not tool_calls: return @@ -46,10 +50,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: async with asyncio.TaskGroup() as tg: tasks = [tg.create_task(tc()) for tc in tool_calls] - # Yield one merged tool-result message — history auto-collects it. - tool_msg = ai.tool_message(*(t.result() for t in tasks)) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*(t.result() for t in tasks)) async for event in my_agent.run( model, diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index 8fe4d887..c103035c 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -3,7 +3,7 @@ Demonstrates the function-based hook API: - await hook("label", payload=Model) to suspend inside the loop - resolve_hook("label", data) to unblock from outside - - Hook messages arrive as MessageEnd events with role="internal" + - Hook signals arrive as HookEvent events """ import asyncio @@ -38,6 +38,8 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: ) async for event in s: yield event + if s.message is not None: + yield s.message tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -66,9 +68,7 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: else: results.append(await tc()) - tool_msg = ai.tool_message(*results) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*results) messages = [ ai.system_message( @@ -82,19 +82,17 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: print(event.chunk, end="", flush=True) continue - # Hook signals arrive as internal MessageEnd events. - if isinstance(event, ai.MessageEnd) and event.message.role == "internal": - hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] - hook_part = hook_parts[0] if hook_parts else None - if hook_part is not None and hook_part.status == "pending": - answer = input(f"Approve {hook_part.hook_id}? [y/n] ") - ai.resolve_hook( - hook_part.hook_id, - Approval( - granted=answer.strip().lower() in ("y", "yes"), - reason="operator decision", - ), - ) + # Hook signals arrive as HookEvent events. + if isinstance(event, ai.HookEvent) and event.hook.status == "pending": + hook_part = event.hook + answer = input(f"Approve {hook_part.hook_id}? [y/n] ") + ai.resolve_hook( + hook_part.hook_id, + Approval( + granted=answer.strip().lower() in ("y", "yes"), + reason="operator decision", + ), + ) print() diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 7372ca6b..8262c9ab 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -43,6 +43,8 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: ) async for event in s: yield event + if s.message is not None: + yield s.message tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -73,9 +75,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: ) ) - tool_msg = ai.tool_message(*results) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*results) messages = [ ai.system_message("Delete files when asked. Always use the delete_file tool."), @@ -89,15 +89,13 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in my_agent.run(model, messages): if isinstance(event, ai.TextDelta): print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd) and event.message.role == "internal": - hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] - hook_part = hook_parts[0] if hook_parts else None - if hook_part is not None and hook_part.status == "pending": - pending_hook_labels.append(hook_part.hook_id) - print( - f" Hook pending: {hook_part.hook_id}" - f" (metadata={hook_part.metadata})" - ) + elif isinstance(event, ai.HookEvent) and event.hook.status == "pending": + hook_part = event.hook + pending_hook_labels.append(hook_part.hook_id) + print( + f" Hook pending: {hook_part.hook_id}" + f" (metadata={hook_part.metadata})" + ) print("\n Run interrupted; approval will be pre-registered for re-entry.\n") @@ -109,11 +107,8 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in my_agent.run(model, messages): if isinstance(event, ai.TextDelta): print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd) and event.message.role == "internal": - hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] - hook_part = hook_parts[0] if hook_parts else None - if hook_part is not None: - print(f" Hook {hook_part.status}: {hook_part.hook_id}") + elif isinstance(event, ai.HookEvent): + print(f" Hook {event.hook.status}: {event.hook.hook_id}") print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 59e5b15d..277e3b33 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -15,22 +15,17 @@ async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Event]: """Ask the mothership a question. Streams progress back to the caller.""" for step in ["Connecting...", "Transmitting...", "Awaiting response..."]: - msg = ai.Message( - role="assistant", - parts=[ai.TextPart(text=step)], - source_label="tool_progress", - ) - yield ai.MessageStart(message=msg) - yield ai.MessageEnd(message=msg) + yield ai.TextStart(block_id=f"progress-{step}") + yield ai.TextDelta(block_id=f"progress-{step}", chunk=step) + yield ai.TextEnd(block_id=f"progress-{step}") await asyncio.sleep(0.3) # The final yielded message's text is returned as the tool result. - msg = ai.Message( + final = ai.Message( role="assistant", parts=[ai.TextPart(text="The mothership says: Soon.")], ) - yield ai.MessageStart(message=msg) - yield ai.MessageEnd(message=msg) + yield final async def main() -> None: @@ -45,12 +40,10 @@ async def main() -> None: async for event in my_agent.run(model, messages): if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif ( - isinstance(event, ai.MessageEnd) - and event.message.source_label == "tool_progress" - ): - print(f" [{event.message.text}]") + if event.block_id.startswith("progress-"): + print(f" [{event.chunk}]") + else: + print(event.chunk, end="", flush=True) print() diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index e834fedb..3af2bf0e 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -136,8 +136,8 @@ async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: retry_policy=temporalio.common.RetryPolicy(maximum_attempts=3), ) msg = ai.Message.model_validate(result.message) - yield ai.MessageStart(message=msg) - yield ai.MessageEnd(message=msg) + yield ai.StreamEnd(message=msg) + yield msg # 2. No tool calls → done if not msg.tool_calls: @@ -154,15 +154,13 @@ async def run_tool(tc: ai.ToolCallPart) -> ai.ToolResultPart: args=list(kwargs.values()), start_to_close_timeout=datetime.timedelta(minutes=2), ) - return ai.tool_result( + return ai.tool_result_part( tc.tool_call_id, tool_name=tc.tool_name, result=result ) tasks = [asyncio.ensure_future(run_tool(tc)) for tc in msg.tool_calls] parts = await asyncio.gather(*tasks) - tool_msg = ai.tool_message(*parts) - yield ai.MessageStart(message=tool_msg) - yield ai.MessageEnd(message=tool_msg) + yield ai.tool_result(*parts) # ── Workflow ───────────────────────────────────────────────────── @@ -182,7 +180,7 @@ async def run(self, user_query: str) -> str: final_text = "" async for event in weather_agent.run(model, messages): - if isinstance(event, ai.MessageEnd) and event.message.text: + if isinstance(event, ai.TerminalEvent): final_text = event.message.text return final_text diff --git a/examples/temporal-middleware/main.py b/examples/temporal-middleware/main.py index d71b514d..0eebe327 100644 --- a/examples/temporal-middleware/main.py +++ b/examples/temporal-middleware/main.py @@ -125,11 +125,39 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: return LLMResult(message=s.message.model_dump()) +async def _replay_as_stream(msg: ai.Message) -> AsyncGenerator[ai.Event]: + """Replay a complete message as streaming events for ``ai.Stream``. + + TODO: This exists because wrap_model must return a Stream, and Stream + aggregates from streaming deltas. A complete message has to be + decomposed into synthetic events so Stream can rebuild it. The + middleware contract should support returning a complete Message + directly. + """ + yield ai.StreamStart() + for i, part in enumerate(msg.parts): + if isinstance(part, ai.TextPart) and part.text: + bid = f"text-{i}" + yield ai.TextStart(block_id=bid) + yield ai.TextDelta(block_id=bid, chunk=part.text) + yield ai.TextEnd(block_id=bid) + elif isinstance(part, ai.ToolCallPart): + yield ai.ToolStart( + tool_call_id=part.tool_call_id, tool_name=part.tool_name + ) + if part.tool_args: + yield ai.ToolDelta( + tool_call_id=part.tool_call_id, chunk=part.tool_args + ) + yield ai.ToolEnd(tool_call_id=part.tool_call_id) + yield ai.StreamEnd() + + # ── Middleware ─────────────────────────────────────────────────── # # Intercepts wrap_model and wrap_tool to replace real I/O with # Temporal activities. The default agent loop runs unchanged — -# it just sees a StreamResult from wrap_model and a Message from +# it just sees a Stream from wrap_model and a Message from # wrap_tool, same as without middleware. @@ -143,8 +171,12 @@ async def wrap_model( self, call: ai.middleware.ModelContext, next: Any, - ) -> ai.StreamResultLike: - """LLM call → Temporal activity.""" + ) -> Any: + """LLM call → Temporal activity. + + Returns an ``ai.Stream`` that replays the complete message as + streaming events so the default loop can iterate it normally. + """ result = await temporalio.workflow.execute_activity( llm_call_activity, LLMParams( @@ -155,12 +187,7 @@ async def wrap_model( retry_policy=temporalio.common.RetryPolicy(maximum_attempts=3), ) msg = ai.Message.model_validate(result.message) - - async def _single() -> AsyncGenerator[ai.Event]: - yield ai.MessageStart(message=msg) - yield ai.MessageEnd(message=msg) - - return ai.StreamResult.from_generator(_single()) + return ai.Stream(_replay_as_stream(msg)) async def wrap_tool( self, @@ -218,7 +245,7 @@ async def run(self, user_query: str) -> str: final_text = "" async for event in weather_agent.run(model, messages, middleware=[mw]): - if isinstance(event, ai.MessageEnd) and event.message.text: + if isinstance(event, ai.TerminalEvent): final_text = event.message.text return final_text diff --git a/src/ai/__init__.py b/src/ai/__init__.py index d985b9d9..816a506f 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -2,16 +2,21 @@ from .agents import ( TOOL_APPROVAL_HOOK_TYPE, Agent, + AgentEvent, Context, + HookEvent, + TerminalEvent, Tool, ToolApproval, ToolCall, + ToolCallResult, agent, cancel_hook, hook, mcp, resolve_hook, tool, + tool_result, yield_from, ) from .middleware import AgentRunContext, Middleware @@ -71,7 +76,7 @@ system_message, thinking, tool_message, - tool_result, + tool_result_part, user_message, ) @@ -110,6 +115,7 @@ "system_message", "tool_message", "tool_result", + "tool_result_part", "file_part", "thinking", # Models (from models/) @@ -134,11 +140,13 @@ "ai_gateway", # Agents — primary API "Agent", + "AgentEvent", "agent", "Context", # Agents — tools "Tool", "ToolCall", + "ToolCallResult", "tool", # Agents — composition "yield_from", @@ -146,6 +154,8 @@ "hook", "resolve_hook", "cancel_hook", + "HookEvent", + "TerminalEvent", "ToolApproval", "TOOL_APPROVAL_HOOK_TYPE", # Middleware diff --git a/src/ai/agents/__init__.py b/src/ai/agents/__init__.py index fe8ab878..575ab7eb 100644 --- a/src/ai/agents/__init__.py +++ b/src/ai/agents/__init__.py @@ -1,5 +1,6 @@ -from . import mcp -from .agent import Agent, Context, Tool, ToolCall, agent, tool, yield_from +from . import mcp, ui +from .agent import Agent, Context, Tool, ToolCall, agent, tool, tool_result, yield_from +from .events import AgentEvent, HookEvent, TerminalEvent, ToolCallResult from .hooks import ( TOOL_APPROVAL_HOOK_TYPE, ToolApproval, @@ -10,16 +11,22 @@ __all__ = [ "Agent", + "AgentEvent", "Context", + "HookEvent", "Tool", + "ToolApproval", "ToolCall", + "TerminalEvent", + "ToolCallResult", + "TOOL_APPROVAL_HOOK_TYPE", "agent", "cancel_hook", "hook", "mcp", "resolve_hook", "tool", + "tool_result", + "ui", "yield_from", - "ToolApproval", - "TOOL_APPROVAL_HOOK_TYPE", ] diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 57735f99..701f9e7f 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -16,6 +16,10 @@ from . import events as events_ from . import runtime +# What loop functions yield: AgentEvents pass through to the consumer, +# bare Messages are silently collected into history. +StreamItem = events_.AgentEvent | types.Message + class Tool[**P, R]: """Wraps async function, introspects schema, attaches a validator""" @@ -200,52 +204,55 @@ def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]: ] -StreamItem = events_.AgentEvent | types.Message - - class LoopFn(Protocol): def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ... -async def _message_events( - message: types.Message, -) -> AsyncGenerator[events_.AgentEvent]: - yield events_.MessageStart(message=message) - yield events_.MessageEnd(message=message) +def tool_result( + *items: types.Message | types.ToolResultPart | list[types.Message], +) -> events_.ToolCallResult: + """Create a :class:`ToolCallResult` from tool messages or parts. + Wraps :func:`ai.tool_message` and returns a ``ToolCallResult`` + event ready to yield from a custom loop:: + + yield ai.tool_result(*(t.result() for t in tasks)) + """ + msg = builders.tool_message(*items) + return events_.ToolCallResult(message=msg, results=msg.tool_results) -async def _coerce_events( - source: AsyncIterable[StreamItem], -) -> AsyncGenerator[events_.AgentEvent]: - async for item in source: - if isinstance(item, types.Message): - async for event in _message_events(item): - yield event - else: - yield item + +def _upsert_message(messages: list[types.Message], message: types.Message) -> None: + """Insert or replace *message* in the history list.""" + for i, existing in enumerate(messages): + if existing.id == message.id: + messages[i] = message + return + messages.append(message) async def _collect_messages( source: AsyncIterable[StreamItem], messages: list[types.Message], ) -> AsyncGenerator[events_.AgentEvent]: - """Intercept yielded events and collect MessageEnd messages into *messages*. + """Intercept yielded items and maintain the *messages* history list. + + * Bare ``Message`` — silently collected (not forwarded to consumer). + * ``ToolCallResult`` — collected *and* forwarded. + * Any other ``AgentEvent`` — forwarded as-is. This runs on the **producer** side (same coroutine as the loop function), so ``messages`` is always up-to-date by the time the loop reads it for - the next model call — avoiding the race that would occur if collection - happened on the consumer side of the runtime queue. + the next model call. """ - async for event in _coerce_events(source): - if isinstance(event, events_.MessageEnd): - message = event.message - for i, existing in enumerate(messages): - if existing.id == message.id: - messages[i] = message - break - else: - messages.append(message) - yield event + async for item in source: + if isinstance(item, types.Message): + _upsert_message(messages, item) + elif isinstance(item, events_.ToolCallResult): + _upsert_message(messages, item.message) + yield item + else: + yield item async def yield_from(source: AsyncIterable[StreamItem]) -> str: @@ -267,9 +274,12 @@ async def yield_from(source: AsyncIterable[StreamItem]) -> str: """ rt = runtime.get_runtime() last: types.Message | None = None - async for item in _coerce_events(source): + async for item in source: + if isinstance(item, types.Message): + last = item + continue await rt.put_event(item) - if isinstance(item, events_.MessageEnd): + if isinstance(item, events_.TerminalEvent): last = item.message return last.text if last else "" @@ -296,9 +306,7 @@ def loop(self, fn: LoopFn) -> LoopFn: self._loop_fn = fn return fn - async def default_loop( - self, context: Context - ) -> AsyncGenerator[events_.AgentEvent]: + async def default_loop(self, context: Context) -> AsyncGenerator[StreamItem]: while True: stream = models.stream( context.model, @@ -308,13 +316,9 @@ async def default_loop( async for stream_event in stream: yield stream_event - # Bridge: emit MessageStart/MessageEnd around the assistant message - # the model stream just produced, so _collect_messages and downstream - # consumers (AI-SDK outbound, label stamping) see the same boundary - # events they did under the previous adapter contract. + # Yield the assistant message for silent history collection. if stream.message is not None and stream.message.parts: - async for boundary in _message_events(stream.message): - yield boundary + yield stream.message tool_calls = context.resolve(stream.tool_calls) if not tool_calls: @@ -324,12 +328,7 @@ async def default_loop( async with asyncio.TaskGroup() as tg: tasks = [tg.create_task(tc()) for tc in tool_calls] - # Yield one merged tool-result message — history auto-collects it. - # Left un-stamped: the tool result is the input of the *next* turn, - # so the next stream() call will stamp it with that turn's id. - tool_msg = builders.tool_message(*(t.result() for t in tasks)) - async for boundary in _message_events(tool_msg): - yield boundary + yield tool_result(*(t.result() for t in tasks)) async def run( self, @@ -370,22 +369,14 @@ async def _real( ) source = _collect_messages(loop_fn(context), context.messages) async for event in runtime.run(source): - if call.label is not None: - event_message: types.Message | None = None - if isinstance(event, events_.MessageEnd) or ( - isinstance(event, events_.MessageStart) - and event.message is not None - ): - event_message = event.message - - if event_message is not None: - event = event.model_copy( - update={ - "message": event_message.model_copy( - update={"source_label": call.label} - ) - } - ) + if call.label is not None and isinstance(event, events_.ToolCallResult): + event = event.model_copy( + update={ + "message": event.message.model_copy( + update={"source_label": call.label} + ) + } + ) yield event # Activate middleware for this run (and everything it calls). diff --git a/src/ai/agents/events.py b/src/ai/agents/events.py index 12108464..ac3fbc51 100644 --- a/src/ai/agents/events.py +++ b/src/ai/agents/events.py @@ -1,10 +1,8 @@ """Agent-layer event types. The model layer emits ``StreamStart`` / ``StreamEnd`` plus block-level -deltas. The agent layer wraps those with ``MessageStart`` / ``MessageEnd`` -boundaries that delimit complete messages — assistant turns produced by -the model, plus synthetic user / tool / hook messages injected into the -runtime queue. +deltas. The agent layer adds ``ToolCallResult`` (tool execution outcomes) +and ``HookEvent`` (human-in-the-loop suspension points). These types live here (rather than in ``ai.types.events``) because they are an agent-runtime concern, not part of the public model-streaming @@ -13,33 +11,42 @@ from __future__ import annotations -from typing import Literal +from collections.abc import Sequence +from typing import Any, Literal import pydantic from .. import types -class MessageStart(pydantic.BaseModel): - message: types.Message | None = None +class ToolCallResult(pydantic.BaseModel): + """Emitted after tool calls execute — carries the result message.""" - kind: Literal["message_start"] = "message_start" + message: types.Message + results: Sequence[types.ToolResultPart] + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + kind: Literal["tool_call_result"] = "tool_call_result" -class MessageEnd(pydantic.BaseModel): +class HookEvent(pydantic.BaseModel): + """Emitted when a hook suspends, resolves, or is cancelled.""" + message: types.Message - usage: types.Usage | None = None + hook: types.HookPart[Any] + + kind: Literal["hook"] = "hook" - kind: Literal["message_end"] = "message_end" +AgentEvent = types.Event | ToolCallResult | HookEvent -# Widened event alias used inside agents/. Not part of ``types.Event``'s -# discriminated union — these wrappers do not flow through model adapters. -AgentEvent = types.Event | MessageStart | MessageEnd +TerminalEvent = types.StreamEnd | ToolCallResult | HookEvent __all__ = [ "AgentEvent", - "MessageEnd", - "MessageStart", + "HookEvent", + "TerminalEvent", + "ToolCallResult", ] diff --git a/src/ai/agents/hooks.py b/src/ai/agents/hooks.py index 6205dcb0..b30bfaf2 100644 --- a/src/ai/agents/hooks.py +++ b/src/ai/agents/hooks.py @@ -122,18 +122,13 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: _live_hooks[label] = (future, hook_metadata, rt) rt.track_hook_label(label) - # Emit pending signal message. - await rt.put_message( - messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type=payload.__name__, - status="pending", - metadata=hook_metadata, - ) - ], + # Emit pending signal. + await rt.put_hook( + messages_.HookPart( + hook_id=label, + hook_type=payload.__name__, + status="pending", + metadata=hook_metadata, ) ) @@ -150,19 +145,14 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: # Clean up live registry. _live_hooks.pop(label, None) - # Emit resolved internal message. - await rt.put_message( - messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type=payload.__name__, - status="resolved", - metadata=hook_metadata, - resolution=resolution, - ) - ], + # Emit resolved signal. + await rt.put_hook( + messages_.HookPart( + hook_id=label, + hook_type=payload.__name__, + status="resolved", + metadata=hook_metadata, + resolution=resolution, ) ) @@ -229,18 +219,13 @@ async def cancel_hook(label: str, *, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - # Emit cancelled internal message. - await rt.put_message( - messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type="", # not available at cancel site - status="cancelled", - metadata=hook_metadata, - ) - ], + # Emit cancelled signal. + await rt.put_hook( + messages_.HookPart( + hook_id=label, + hook_type="", # not available at cancel site + status="cancelled", + metadata=hook_metadata, ) ) diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index a8005748..b2b09d1d 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -5,8 +5,9 @@ import asyncio import contextvars from collections.abc import AsyncGenerator, AsyncIterable, Awaitable +from typing import Any -from .. import types +from ..types import messages as messages_ from . import events as events_ from . import hooks as hooks_ from .mcp import client as mcp_client @@ -29,9 +30,9 @@ def __init__(self) -> None: async def put_event(self, event: events_.AgentEvent) -> None: await self._event_queue.put(event) - async def put_message(self, message: types.Message) -> None: - await self.put_event(events_.MessageStart(message=message)) - await self.put_event(events_.MessageEnd(message=message)) + async def put_hook(self, hook_part: messages_.HookPart[Any]) -> None: + msg = messages_.Message(role="internal", parts=[hook_part]) + await self.put_event(events_.HookEvent(message=msg, hook=hook_part)) async def signal_done(self) -> None: await self._event_queue.put(self._SENTINEL) diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index e6c6eed3..9d041a8a 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -6,6 +6,7 @@ from .....types import events as events_ from .....types import messages as messages_ +from ....events import HookEvent, ToolCallResult from .. import _approvals, protocol @@ -25,16 +26,11 @@ class _StreamState: """Single-pass state across one ``to_stream()`` call.""" def __init__(self) -> None: - self.current_turn_id: str | None = None self.current_agent: str | None = None self.ui_message_id: str | None = None self.emitted_start: bool = False self.in_step: bool = False - self.seen_done: set[str] = set() - self.skip_current_message: bool = False - self.started_current_message: bool = False - self.started_tool_inputs: set[str] = set() self.tool_names: dict[str, str] = {} self.input_available_emitted: set[str] = set() @@ -76,35 +72,18 @@ def _reset_step_tracking(self) -> None: self.emitted_tool_results.clear() self.emitted_approval_requests.clear() - @staticmethod - def _is_visible_message(msg: messages_.Message) -> bool: - return msg.role not in ("user", "system") - - # -- phase: message start ------------------------------------------------ - - def on_message_start( - self, msg: messages_.Message | None + def _ensure_started( + self, *, source_label: str | None = None ) -> list[protocol.UIMessageStreamPart]: - self.started_current_message = False - self.skip_current_message = False - if msg is None: - return [] - if msg.id in self.seen_done or not self._is_visible_message(msg): - self.skip_current_message = True - return [] - self.started_current_message = True - return self.on_message(msg) - - def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: - """Emit UIMessage/step boundary parts for *msg*.""" - if not self._is_visible_message(msg): - return [] + """Lazily emit StartPart / StartStepPart on the first event. + Also handles agent-change boundaries in multi-agent scenarios. + """ parts: list[protocol.UIMessageStreamPart] = [] agent_changed = ( self.emitted_start - and msg.source_label is not None - and msg.source_label != self.current_agent + and source_label is not None + and source_label != self.current_agent ) if not self.emitted_start or agent_changed: @@ -112,92 +91,72 @@ def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPar if self.emitted_start: parts.append(protocol.FinishPart(finish_reason="stop")) - self.ui_message_id = msg.id - parts.append(protocol.StartPart(message_id=msg.id)) + parts.append(protocol.StartPart(message_id=None)) parts.append(protocol.StartStepPart()) self.emitted_start = True self.in_step = True - self.current_agent = msg.source_label - self.current_turn_id = msg.turn_id - self._reset_step_tracking() - return parts - - if ( - msg.turn_id is not None - and self.current_turn_id is not None - and msg.turn_id != self.current_turn_id - ): - parts.extend(self._finish_step()) - parts.append(protocol.StartStepPart()) - self.in_step = True + self.current_agent = source_label self._reset_step_tracking() - self.current_turn_id = msg.turn_id - elif msg.turn_id is not None and self.current_turn_id is None: - self.current_turn_id = msg.turn_id return parts # -- phase: streaming events -------------------------------------------- def on_event(self, event: events_.Event) -> list[protocol.UIMessageStreamPart]: - if self.skip_current_message: - return [] + out: list[protocol.UIMessageStreamPart] = [] + + # Lazily open the UI message on the first streaming event. + if not self.emitted_start: + out.extend(self._ensure_started()) match event: case events_.TextStart(block_id=pid): self.open_text_ids.add(pid) - return [protocol.TextStartPart(id=pid)] + out.append(protocol.TextStartPart(id=pid)) case events_.TextDelta(block_id=pid, chunk=chunk): - out: list[protocol.UIMessageStreamPart] = [] if pid not in self.open_text_ids: self.open_text_ids.add(pid) out.append(protocol.TextStartPart(id=pid)) self.text_delta_ids.add(pid) out.append(protocol.TextDeltaPart(id=pid, delta=chunk)) - return out case events_.TextEnd(block_id=pid): if pid in self.open_text_ids: self.open_text_ids.discard(pid) self.completed_text_ids.add(pid) - return [protocol.TextEndPart(id=pid)] - return [] + out.append(protocol.TextEndPart(id=pid)) case events_.ReasoningStart(block_id=pid): self.open_reasoning_ids.add(pid) - return [protocol.ReasoningStartPart(id=pid)] + out.append(protocol.ReasoningStartPart(id=pid)) case events_.ReasoningDelta(block_id=pid, chunk=chunk): - out = [] if pid not in self.open_reasoning_ids: self.open_reasoning_ids.add(pid) out.append(protocol.ReasoningStartPart(id=pid)) self.reasoning_delta_ids.add(pid) out.append(protocol.ReasoningDeltaPart(id=pid, delta=chunk)) - return out case events_.ReasoningEnd(block_id=pid): if pid in self.open_reasoning_ids: self.open_reasoning_ids.discard(pid) self.completed_reasoning_ids.add(pid) - return [protocol.ReasoningEndPart(id=pid)] - return [] + out.append(protocol.ReasoningEndPart(id=pid)) case events_.ToolStart(tool_call_id=tcid, tool_name=name): self.tool_names[tcid] = name if tcid in self.started_tool_inputs: - return [] + return out self.started_tool_inputs.add(tcid) - return [ + out.append( protocol.ToolInputStartPart( tool_call_id=tcid, tool_name=name, ) - ] + ) case events_.ToolDelta(tool_call_id=tcid, chunk=chunk): - out = [] if tcid not in self.started_tool_inputs: self.started_tool_inputs.add(tcid) out.append( @@ -212,64 +171,26 @@ def on_event(self, event: events_.Event) -> list[protocol.UIMessageStreamPart]: input_text_delta=chunk, ) ) - return out case events_.ToolEnd(): - return [] - - return [] - - # -- phase: terminal message -------------------------------------------- - - def _static_content( - self, msg: messages_.Message - ) -> list[protocol.UIMessageStreamPart]: - out: list[protocol.UIMessageStreamPart] = [] - - for part in msg.parts: - if isinstance(part, messages_.ReasoningPart): - if part.id not in self.completed_reasoning_ids: - if part.id not in self.open_reasoning_ids: - out.append(protocol.ReasoningStartPart(id=part.id)) - if part.text and part.id not in self.reasoning_delta_ids: - out.append( - protocol.ReasoningDeltaPart(id=part.id, delta=part.text) - ) - out.append(protocol.ReasoningEndPart(id=part.id)) - self.open_reasoning_ids.discard(part.id) - self.completed_reasoning_ids.add(part.id) - - elif isinstance(part, messages_.TextPart): - if part.id not in self.completed_text_ids: - if part.id not in self.open_text_ids: - out.append(protocol.TextStartPart(id=part.id)) - if part.text and part.id not in self.text_delta_ids: - out.append(protocol.TextDeltaPart(id=part.id, delta=part.text)) - out.append(protocol.TextEndPart(id=part.id)) - self.open_text_ids.discard(part.id) - self.completed_text_ids.add(part.id) - - elif isinstance(part, messages_.FilePart): - out.append( - protocol.FilePart( - url=part.data if isinstance(part.data, str) else "", - media_type=part.media_type, - ) - ) + pass return out - def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: - if msg.id in self.seen_done or not self._is_visible_message(msg): - self.seen_done.add(msg.id) - return [] + # -- phase: tool results ------------------------------------------------ + def on_tool_result( + self, event: ToolCallResult + ) -> list[protocol.UIMessageStreamPart]: + """Handle a ``ToolCallResult`` — emit tool input/output parts.""" + msg = event.message out: list[protocol.UIMessageStreamPart] = [] - if not self.started_current_message: - out.extend(self.on_message(msg)) - out.extend(self._static_content(msg)) + # Ensure the UI message is started (handles agent-change too). + out.extend(self._ensure_started(source_label=msg.source_label)) + # Emit ToolInputAvailable for each tool call that triggered + # these results (from the assistant message's ToolCallParts). for part in msg.parts: if isinstance(part, messages_.ToolCallPart): if part.tool_call_id in self.input_available_emitted: @@ -291,55 +212,64 @@ def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPa ) ) - elif isinstance(part, messages_.ToolResultPart): - if part.tool_call_id in self.emitted_tool_results: - continue - self.emitted_tool_results.add(part.tool_call_id) - if part.is_error: - out.append( - protocol.ToolOutputErrorPart( - tool_call_id=part.tool_call_id, - error_text=_tool_error_text(part), - ) + # Emit tool results. + for part in event.results: + if part.tool_call_id in self.emitted_tool_results: + continue + self.emitted_tool_results.add(part.tool_call_id) + if part.is_error: + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=part.tool_call_id, + error_text=_tool_error_text(part), ) - else: - out.append( - protocol.ToolOutputAvailablePart( - tool_call_id=part.tool_call_id, - output=part.result, - ) + ) + else: + out.append( + protocol.ToolOutputAvailablePart( + tool_call_id=part.tool_call_id, + output=part.result, ) + ) - elif isinstance(part, messages_.HookPart): - tc_id = _approvals.tool_call_id_for(part) - if tc_id is None: - continue + return out - if part.status == "pending": - if tc_id in self.emitted_approval_requests: - continue - self.emitted_approval_requests.add(tc_id) - out.append( - protocol.ToolApprovalRequestPart( - approval_id=part.hook_id, - tool_call_id=tc_id, - ) - ) - elif part.status == "resolved": - resolution: dict[str, Any] = part.resolution or {} - if not resolution.get("granted", False): - out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) - elif part.status == "cancelled": - out.append( - protocol.ToolOutputErrorPart( - tool_call_id=tc_id, - error_text="Hook cancelled", - ) - ) + # -- phase: hooks ------------------------------------------------------- + + def on_hook(self, event: HookEvent) -> list[protocol.UIMessageStreamPart]: + """Handle a ``HookEvent`` — emit approval parts.""" + hook_part = event.hook + out: list[protocol.UIMessageStreamPart] = [] + + # Ensure the UI message is started. + out.extend(self._ensure_started()) + + tc_id = _approvals.tool_call_id_for(hook_part) + if tc_id is None: + return out + + if hook_part.status == "pending": + if tc_id in self.emitted_approval_requests: + return out + self.emitted_approval_requests.add(tc_id) + out.append( + protocol.ToolApprovalRequestPart( + approval_id=hook_part.hook_id, + tool_call_id=tc_id, + ) + ) + elif hook_part.status == "resolved": + resolution: dict[str, Any] = hook_part.resolution or {} + if not resolution.get("granted", False): + out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) + elif hook_part.status == "cancelled": + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=tc_id, + error_text="Hook cancelled", + ) + ) - self.seen_done.add(msg.id) - self.skip_current_message = False - self.started_current_message = False return out # -- phase: stream finish ------------------------------------------------ diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index a569342f..58e86655 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator, AsyncIterable -from ....events import AgentEvent, MessageEnd, MessageStart +from ....events import AgentEvent, HookEvent, ToolCallResult from .. import protocol from ._state import _StreamState @@ -14,18 +14,18 @@ async def to_stream( ) -> AsyncGenerator[protocol.UIMessageStreamPart]: """Walk ``events`` once, emitting AI SDK UI stream parts. - Streaming text/reasoning/tool-input deltas come from public events. - Terminal tool results, approvals, and files come from - ``MessageEnd.message``. + Streaming text/reasoning/tool-input deltas come from model events. + Tool results come from ``ToolCallResult``. Hook signals come from + ``HookEvent``. """ state = _StreamState() async for event in events: - if isinstance(event, MessageStart): - for part in state.on_message_start(event.message): + if isinstance(event, ToolCallResult): + for part in state.on_tool_result(event): yield part - elif isinstance(event, MessageEnd): - for part in state.on_terminal(event.message): + elif isinstance(event, HookEvent): + for part in state.on_hook(event): yield part else: for part in state.on_event(event): diff --git a/src/ai/middleware.py b/src/ai/middleware.py index afac14b6..8e0c22ff 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -121,7 +121,7 @@ def __post_init__(self) -> None: # Event/message aliases for brevity in signatures. ``_Event`` is intentionally # typed as ``Any`` so the agent-run chain accepts the wider ``AgentEvent`` -# union (which includes ``MessageStart``/``MessageEnd``) without a circular +# union (which includes ``ToolCallResult``/``HookEvent``) without a circular # import from ``ai.agents``. _Event = Any _Message = messages_.Message diff --git a/src/ai/types/builders.py b/src/ai/types/builders.py index 94ddda5e..566158d3 100644 --- a/src/ai/types/builders.py +++ b/src/ai/types/builders.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Any, overload +from typing import Any from .messages import ( FilePart, @@ -111,24 +111,6 @@ def _tool_results_from_messages(messages: list[Message]) -> list[ToolResultPart] return parts -@overload -def tool_message(*messages: Message | list[Message]) -> Message: ... - - -@overload -def tool_message(*parts: ToolResultPart) -> Message: ... - - -@overload -def tool_message( - *, - tool_call_id: str, - result: Any = None, - tool_name: str = "", - is_error: bool = False, -) -> Message: ... - - def tool_message( *items: Message | ToolResultPart | list[Message], tool_call_id: str | None = None, @@ -138,7 +120,7 @@ def tool_message( ) -> Message: """Create or merge a tool-result message. - >>> part = ai.tool_result("tc-1", result=72, tool_name="weather") + >>> part = ai.tool_result_part("tc-1", result=72, tool_name="weather") >>> ai.tool_message(part) >>> ai.tool_message(tool_call_id="tc-1", result=72, tool_name="weather") """ @@ -156,7 +138,7 @@ def tool_message( return Message( role="tool", parts=[ - tool_result( + tool_result_part( tool_call_id, result=result, tool_name=tool_name, @@ -204,7 +186,7 @@ def tool_message( return Message(role="tool", parts=tool_parts) -def tool_result( +def tool_result_part( tool_call_id: str, *, result: Any = None, @@ -213,7 +195,7 @@ def tool_result( ) -> ToolResultPart: """Create a :class:`ToolResultPart`. - >>> ai.tool_result("tc-1", result={"temp": 72}, tool_name="weather") + >>> ai.tool_result_part("tc-1", result={"temp": 72}, tool_name="weather") """ return ToolResultPart( tool_call_id=tool_call_id, diff --git a/src/ai/types/events.py b/src/ai/types/events.py index 8ce0edae..480f55ab 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -9,6 +9,12 @@ # serialization border in the case of durable execution +# Placeholder so BaseEvent.message is typed as Message (not Message | None). +# Stream.__anext__ stamps the real in-progress message before yielding, +# so consumers never see this value. +_DUMMY_MESSAGE = messages.Message(id="", role="assistant", parts=[]) + + class BaseEvent(pydantic.BaseModel): """Common fields stamped onto every event by the streaming wrapper. @@ -18,7 +24,7 @@ class BaseEvent(pydantic.BaseModel): usage value reported by the provider (latest-wins across the stream). """ - message: messages.Message | None = None + message: messages.Message = _DUMMY_MESSAGE usage: usage_.Usage | None = None model_config = pydantic.ConfigDict(frozen=True) diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 2c720b9f..63d58339 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -15,7 +15,6 @@ from ..conftest import ( MOCK_MODEL, - collect_messages, emit_events_for_messages, mock_llm, text_msg, @@ -28,13 +27,11 @@ @ai.tool # type: ignore[arg-type] -async def progress_tool(query: str) -> AsyncGenerator[ai.Message]: +async def progress_tool(query: str) -> AsyncGenerator[events_.Event | ai.Message]: """Tool that streams progress, then returns a final answer.""" - yield ai.Message( - role="assistant", - parts=[messages_.TextPart(text="Working...")], - source_label="progress", - ) + yield events_.TextStart(block_id="progress") + yield events_.TextDelta(block_id="progress", chunk="Working...") + yield events_.TextEnd(block_id="progress") yield ai.Message( role="assistant", parts=[messages_.TextPart(text=f"Answer for {query}")], @@ -42,7 +39,7 @@ async def progress_tool(query: str) -> AsyncGenerator[ai.Message]: async def test_generator_tool_streams_and_returns_result() -> None: - """Generator tool yields intermediate messages visible to consumer; + """Generator tool yields streaming events visible to consumer; final text becomes the tool result fed back to the LLM.""" my_agent = ai.agent(tools=[progress_tool]) @@ -52,24 +49,27 @@ async def test_generator_tool_streams_and_returns_result() -> None: reply = [text_msg("Done!", id="msg-2")] llm = mock_llm([call, reply]) - collected = await collect_messages( - my_agent.run(MOCK_MODEL, [ai.user_message("Go")]) - ) + all_events: list[agent_events_.AgentEvent] = [] + async for event in my_agent.run(MOCK_MODEL, [ai.user_message("Go")]): + all_events.append(event) assert llm.call_count == 2 - # Intermediate progress message was forwarded to consumer. - progress = [m for m in collected if m.source_label == "progress"] - assert len(progress) == 1 - assert progress[0].text == "Working..." + # Intermediate progress events were forwarded to consumer. + progress_deltas = [ + e + for e in all_events + if isinstance(e, events_.TextDelta) and e.block_id == "progress" + ] + assert len(progress_deltas) == 1 + assert progress_deltas[0].chunk == "Working..." # Tool result was fed back to LLM. - tool_results = [m for m in collected if m.role == "tool"] + tool_results = [ + e for e in all_events if isinstance(e, agent_events_.ToolCallResult) + ] assert len(tool_results) >= 1 - assert tool_results[0].tool_results[0].result == "Answer for test" - - # Final response arrived. - assert any(m.text == "Done!" for m in collected) + assert tool_results[0].results[0].result == "Answer for test" # --------------------------------------------------------------------------- @@ -127,12 +127,12 @@ async def research_tool(topic: str) -> AsyncGenerator[agent_events_.AgentEvent]: async def test_yield_from_nested_agent() -> None: - """yield_from forwards inner messages to the consumer but does NOT + """yield_from forwards inner events to the consumer but does NOT add them to the outer agent's history (context.messages). - The critical contract from agent.py:292: yield_from streams messages - through the runtime queue without going through _collect_messages, - so the parent agent's context.messages stays clean. + The critical contract: yield_from streams events through the runtime + queue without going through _collect_messages, so the parent agent's + context.messages stays clean. """ outer = ai.agent(tools=[research_tool]) @@ -149,15 +149,24 @@ async def test_yield_from_nested_agent() -> None: adapter = _CapturingAdapter([outer_call, inner_reply, outer_reply]) models.register_stream("mock", adapter.stream) - collected = await collect_messages( - outer.run(MOCK_MODEL, [ai.user_message("Tell me about Mars")]) - ) + all_events: list[agent_events_.AgentEvent] = [] + async for event in outer.run(MOCK_MODEL, [ai.user_message("Tell me about Mars")]): + all_events.append(event) assert adapter.call_count == 3 - # Inner messages were forwarded to the consumer with label="inner". - inner_msgs = [m for m in collected if m.source_label == "inner"] - assert len(inner_msgs) > 0 + # Inner text events were forwarded to the consumer. + inner_text = [ + e + for e in all_events + if isinstance(e, events_.TextDelta) and e.chunk == "Mars has two moons." + ] + assert len(inner_text) > 0 + + tool_results = [ + e for e in all_events if isinstance(e, agent_events_.ToolCallResult) + ] + assert tool_results[0].results[0].result == "Mars has two moons." # The outer LLM's second call (index 2) must NOT contain any inner # agent messages. It should only see: the original user message, diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index 73d88d22..14fd40f2 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -4,6 +4,7 @@ import asyncio from collections.abc import AsyncGenerator +from typing import Any import pydantic import pytest @@ -38,11 +39,10 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, agent_events_.MessageEnd): + if not isinstance(event, agent_events_.HookEvent): continue - msg = event.message - # When we see the pending hook message, resolve it. - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + # When we see the pending hook, resolve it. + if event.hook.status == "pending": ai.resolve_hook("confirm_1", {"approved": True, "reason": "looks good"}) assert resolved_value is not None @@ -71,10 +71,9 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, agent_events_.MessageEnd): + if not isinstance(event, agent_events_.HookEvent): continue - msg = event.message - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + if event.hook.status == "pending": await ai.cancel_hook("cancel_me", reason="denied") assert was_cancelled @@ -143,22 +142,17 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) - msgs: list[ai.Message] = [] + hooks: list[ai.HookPart[Any]] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, agent_events_.MessageEnd): + if not isinstance(event, agent_events_.HookEvent): continue - msg = event.message - msgs.append(msg) - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + hooks.append(event.hook) + if event.hook.status == "pending": ai.resolve_hook("emit_test", {"approved": False}) - hook_msgs = [ - m - for m in msgs - if any(isinstance(p, ai.HookPart) and p.status == "resolved" for p in m.parts) - ] - assert len(hook_msgs) == 1 - assert hook_msgs[0].parts[0].resolution == {"approved": False} # type: ignore[union-attr] + resolved = [h for h in hooks if h.status == "resolved"] + assert len(resolved) == 1 + assert resolved[0].resolution == {"approved": False} # -- Hook metadata surfaces in pending message ----------------------------- @@ -179,11 +173,10 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: ) mock_llm([[text_msg("OK")]]) - msgs: list[ai.Message] = [] + hooks: list[ai.HookPart[Any]] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if isinstance(event, agent_events_.MessageEnd): - msgs.append(event.message) + if isinstance(event, agent_events_.HookEvent): + hooks.append(event.hook) - hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] - assert len(hook_msgs) >= 1 - assert hook_msgs[0].parts[0].metadata == {"tool": "rm -rf", "path": "/"} # type: ignore[union-attr] + assert len(hooks) >= 1 + assert hooks[0].metadata == {"tool": "rm -rf", "path": "/"} diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 4c3ba93d..0d2d6362 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -6,7 +6,7 @@ from ai.agents import events as agent_events_ from ai.agents.ui.ai_sdk import protocol, to_sse from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part -from ai.types import messages as messages_ +from ai.types import events as events_ def test_serialize_part_camelcases_keys() -> None: @@ -37,24 +37,19 @@ async def _gen( async def test_to_sse_emits_data_prefixed_lines() -> None: - msg = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(text="hi")], - ) lines = [ line async for line in to_sse( _gen( [ - agent_events_.MessageStart(message=msg), - agent_events_.MessageEnd(message=msg), + events_.TextStart(block_id="t1"), + events_.TextDelta(block_id="t1", chunk="hi"), + events_.TextEnd(block_id="t1"), ] ) ) ] assert all(line.startswith("data: ") for line in lines) - # first line is the start part + # first line is the start part (lazy open) first = json.loads(lines[0].removeprefix("data: ").rstrip()) assert first["type"] == "start" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index e193b359..bc329a0c 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -21,43 +21,18 @@ async def _collect( return [part async for part in to_stream(_gen(stream_events))] -def _assistant_start( - msg_id: str = "m1", - *, - turn_id: str | None = "t1", - source_label: str | None = None, -) -> agent_events_.MessageStart: - return agent_events_.MessageStart( - message=messages_.Message( - id=msg_id, - role="assistant", - turn_id=turn_id, - source_label=source_label, - parts=[], - ) - ) - - async def test_event_driven_text_streaming() -> None: + """Streaming text events lazily open a UI message.""" text_id = "txt1" - final = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(id=text_id, text="hi")], - ) out = await _collect( [ - _assistant_start("m1"), events_.TextStart(block_id=text_id), events_.TextDelta(block_id=text_id, chunk="hi"), events_.TextEnd(block_id=text_id), - agent_events_.MessageEnd(message=final), ] ) assert isinstance(out[0], protocol.StartPart) - assert out[0].message_id == "m1" assert isinstance(out[1], protocol.StartStepPart) assert isinstance(out[2], protocol.TextStartPart) and out[2].id == text_id assert isinstance(out[3], protocol.TextDeltaPart) and out[3].delta == "hi" @@ -66,108 +41,67 @@ async def test_event_driven_text_streaming() -> None: assert isinstance(out[6], protocol.FinishPart) -async def test_static_text_message_emits_text_parts() -> None: - msg = messages_.Message( - id="m1", - role="assistant", - parts=[messages_.TextPart(id="txt1", text="hello")], - ) - out = await _collect( - [agent_events_.MessageStart(message=msg), agent_events_.MessageEnd(message=msg)] - ) - assert any(isinstance(part, protocol.TextDeltaPart) for part in out) - - -async def test_turn_id_change_emits_step_boundary() -> None: - msg1 = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(text="hello")], - ) - msg2 = messages_.Message( - id="m2", - role="assistant", - turn_id="t2", - parts=[messages_.TextPart(text="world")], - ) - out = await _collect( - [ - agent_events_.MessageStart(message=msg1), - agent_events_.MessageEnd(message=msg1), - agent_events_.MessageStart(message=msg2), - agent_events_.MessageEnd(message=msg2), - ] - ) - has_mid_step_boundary = any( - isinstance(out[i], protocol.FinishStepPart) - and i + 1 < len(out) - and isinstance(out[i + 1], protocol.StartStepPart) - for i in range(1, len(out) - 1) - ) - assert has_mid_step_boundary - - -async def test_agent_change_emits_message_boundary() -> None: - msg1 = messages_.Message( - id="m1", - role="assistant", - source_label="a1", - parts=[messages_.TextPart(text="from a")], - ) - msg2 = messages_.Message( - id="m2", - role="assistant", - source_label="a2", - parts=[messages_.TextPart(text="from b")], +async def test_tool_call_and_result_emit_terminal_parts() -> None: + """ToolCallResult emits tool input and output parts.""" + tool_result_msg = messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], ) out = await _collect( [ - agent_events_.MessageStart(message=msg1), - agent_events_.MessageEnd(message=msg1), - agent_events_.MessageStart(message=msg2), - agent_events_.MessageEnd(message=msg2), + # Streaming tool input events from the model + events_.ToolStart(tool_call_id="tc1", tool_name="search"), + events_.ToolDelta(tool_call_id="tc1", chunk='{"q":"x"}'), + events_.ToolEnd( + tool_call_id="tc1", + tool_call=messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + ), + # Tool execution result + agent_events_.ToolCallResult( + message=tool_result_msg, + results=tool_result_msg.tool_results, + ), ] ) - has_mid_msg_boundary = any( - isinstance(out[i], protocol.FinishPart) - and i + 1 < len(out) - and isinstance(out[i + 1], protocol.StartPart) - for i in range(1, len(out) - 1) - ) - assert has_mid_msg_boundary + types = [type(part).__name__ for part in out] + assert "ToolInputStartPart" in types + assert "ToolOutputAvailablePart" in types -async def test_tool_call_and_result_emit_terminal_parts() -> None: - tool_call = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", +async def test_tool_result_without_streaming_emits_input_start() -> None: + """ToolCallResult for a non-streamed tool emits input + output parts.""" + tool_result_msg = messages_.Message( + role="tool", parts=[ messages_.ToolCallPart( id="tc1", tool_call_id="tc1", tool_name="search", tool_args='{"q":"x"}', - ) - ], - ) - tool_result = messages_.Message( - role="tool", - parts=[ + ), messages_.ToolResultPart( tool_call_id="tc1", tool_name="search", result={"hits": 1}, - ) + ), ], ) out = await _collect( [ - agent_events_.MessageStart(message=tool_call), - agent_events_.MessageEnd(message=tool_call), - agent_events_.MessageStart(message=tool_result), - agent_events_.MessageEnd(message=tool_result), + agent_events_.ToolCallResult( + message=tool_result_msg, + results=tool_result_msg.tool_results, + ), ] ) types = [type(part).__name__ for part in out] @@ -177,57 +111,91 @@ async def test_tool_call_and_result_emit_terminal_parts() -> None: async def test_approval_request_hook_emits_approval_part() -> None: - tool_call = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", + """HookEvent with pending status emits a ToolApprovalRequestPart.""" + out = await _collect( + [ + # Streaming tool events first + events_.ToolStart(tool_call_id="tc1", tool_name="delete"), + events_.ToolDelta(tool_call_id="tc1", chunk="{}"), + events_.ToolEnd( + tool_call_id="tc1", + tool_call=messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ), + ), + # Hook requesting approval + agent_events_.HookEvent( + message=messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + hook=messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ), + ), + ] + ) + approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] + assert len(approval_parts) == 1 + assert approval_parts[0].tool_call_id == "tc1" + assert approval_parts[0].approval_id == "approve_tc1" + + +async def test_agent_change_emits_message_boundary() -> None: + """ToolCallResult from a different agent triggers a new StartPart.""" + tool_result_a = messages_.Message( + role="tool", + source_label="a1", parts=[ - messages_.ToolCallPart( - id="tc1", + messages_.ToolResultPart( tool_call_id="tc1", - tool_name="delete", - tool_args="{}", + tool_name="foo", + result="ok", ) ], ) - hook = messages_.Message( - role="internal", + tool_result_b = messages_.Message( + role="tool", + source_label="a2", parts=[ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", + messages_.ToolResultPart( + tool_call_id="tc2", + tool_name="bar", + result="ok", ) ], ) out = await _collect( [ - agent_events_.MessageStart(message=tool_call), - agent_events_.MessageEnd(message=tool_call), - agent_events_.MessageStart(message=hook), - agent_events_.MessageEnd(message=hook), + # Agent a1 does text + tool + events_.TextStart(block_id="t1"), + events_.TextDelta(block_id="t1", chunk="from a"), + events_.TextEnd(block_id="t1"), + agent_events_.ToolCallResult( + message=tool_result_a, + results=tool_result_a.tool_results, + ), + # Agent a2 does text + tool — should trigger new StartPart + agent_events_.ToolCallResult( + message=tool_result_b, + results=tool_result_b.tool_results, + ), ] ) - approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] - assert len(approval_parts) == 1 - assert approval_parts[0].tool_call_id == "tc1" - assert approval_parts[0].approval_id == "approve_tc1" - - -async def test_dedup_on_reemitted_message_id() -> None: - msg = messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(id="txt1", text="hi")], + has_mid_msg_boundary = any( + isinstance(out[i], protocol.FinishPart) + and i + 1 < len(out) + and isinstance(out[i + 1], protocol.StartPart) + for i in range(1, len(out) - 1) ) - stream_events: list[agent_events_.AgentEvent] = [ - agent_events_.MessageStart(message=msg), - events_.TextStart(block_id="txt1"), - events_.TextDelta(block_id="txt1", chunk="hi"), - events_.TextEnd(block_id="txt1"), - agent_events_.MessageEnd(message=msg), - ] - out = await _collect([*stream_events, *stream_events]) - text_deltas = [part for part in out if isinstance(part, protocol.TextDeltaPart)] - assert len(text_deltas) == 1 + assert has_mid_msg_boundary diff --git a/tests/conftest.py b/tests/conftest.py index 4be7dae2..7af695f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -189,7 +189,7 @@ async def collect_messages( """Collect terminal messages from an event stream.""" result: list[messages_.Message] = [] async for event in source: - if isinstance(event, agent_events_.MessageEnd): + if isinstance(event, agent_events_.TerminalEvent): result.append(event.message) return result diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 64be953d..51aab302 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -87,10 +87,9 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in my_agent.run( MOCK_MODEL, [ai.user_message("go")], middleware=[Spy()] ): - if not isinstance(event, agent_events_.MessageEnd): + if not isinstance(event, agent_events_.HookEvent): continue - msg = event.message - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + if event.hook.status == "pending": ai.resolve_hook("test_hook", {"approved": True, "reason": "ok"}) assert len(hook_calls) == 1 diff --git a/tests/types/test_builders.py b/tests/types/test_builders.py index 4232929e..46b2cc59 100644 --- a/tests/types/test_builders.py +++ b/tests/types/test_builders.py @@ -38,11 +38,11 @@ def test_file_part_from_bytes_unknown_raises() -> None: def test_tool_message_merges_tool_messages() -> None: m1 = messages.Message( role="tool", - parts=[builders.tool_result("tc-1", result=1, tool_name="a")], + parts=[builders.tool_result_part("tc-1", result=1, tool_name="a")], ) m2 = messages.Message( role="tool", - parts=[builders.tool_result("tc-2", result=2, tool_name="b")], + parts=[builders.tool_result_part("tc-2", result=2, tool_name="b")], ) merged = builders.tool_message(m1, m2) diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index 756a3acb..524664c8 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -444,8 +444,8 @@ def test_duplicate_tool_results_within_same_message_raises() -> None: messages.Message( role="tool", parts=[ - builders.tool_result("tc-1", result="first"), - builders.tool_result("tc-1", result="second"), + builders.tool_result_part("tc-1", result="first"), + builders.tool_result_part("tc-1", result="second"), ], ), ]