Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions examples/coding-agent/1_raw_stream.py

This file was deleted.

8 changes: 5 additions & 3 deletions examples/samples/explicit_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

async def main() -> None:
try:
async for event in ai.models.stream(model, messages):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
async with ai.stream(model, messages) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
print()
finally:
# Explicit clients need explicit cleanup.
await client.aclose()


Expand Down
25 changes: 12 additions & 13 deletions examples/samples/inline_image.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Inline image generation — LLM that outputs images alongside text.

Models like Gemini 3 Pro Image can generate images as part of their
language model response. The images arrive as FileParts on the final
MessageEnd message.
language model response. The images arrive as ``FileEvent`` events
during the stream and end up as ``FilePart``s on the aggregated
``Stream.message``.
"""

import asyncio
Expand All @@ -22,20 +23,18 @@


async def main() -> None:
last_msg: ai.Message | None = None

# Stream — text deltas arrive as events, images arrive on MessageEnd
async for event in ai.stream(model, messages):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
elif isinstance(event, ai.MessageEnd):
last_msg = event.message
# Stream — text deltas arrive as TextDelta events, generated images
# arrive as FileEvent events and accumulate on s.message.
async with ai.stream(model, messages) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)

print()

# Check for images in the final message
if last_msg and last_msg.images:
for i, img in enumerate(last_msg.images):
# Check for images in the aggregated message.
if s.message.images:
for i, img in enumerate(s.message.images):
filename = f"inline_{i}.png"
data = (
img.data if isinstance(img.data, bytes) else base64.b64decode(img.data)
Expand Down
7 changes: 4 additions & 3 deletions examples/samples/multimodal_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@


async def main() -> None:
async for event in ai.stream(model, messages):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
async with ai.stream(model, messages) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
print()


Expand Down
7 changes: 4 additions & 3 deletions examples/samples/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@


async def main() -> None:
async for event in ai.stream(model, messages):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
async with ai.stream(model, messages) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
print()


Expand Down
21 changes: 12 additions & 9 deletions examples/samples/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ class Recipe(pydantic.BaseModel):


async def main() -> None:
# Stream with structured output — watch JSON arrive, get validated at the end
async for event in ai.stream(model, messages, output_type=Recipe):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
elif isinstance(event, ai.MessageEnd) and event.message.output:
recipe: Recipe = event.message.output
print(f"\n\nParsed recipe: {recipe.name}")
print(f" Ingredients: {', '.join(recipe.ingredients)}")
print(f" Prep time: {recipe.prep_time_minutes} min")
# Stream with structured output — watch JSON arrive, get validated at the end.
async with ai.stream(model, messages, output_type=Recipe) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)

# After iteration, s.output is the validated pydantic model.
recipe: Recipe | None = s.output
if recipe is not None:
print(f"\n\nParsed recipe: {recipe.name}")
print(f" Ingredients: {', '.join(recipe.ingredients)}")
print(f" Prep time: {recipe.prep_time_minutes} min")


if __name__ == "__main__":
Expand Down
16 changes: 9 additions & 7 deletions examples/samples/tools_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@


async def main() -> None:
# Stream with tools — the model may emit tool calls
async for event in ai.stream(model, messages, tools=[get_weather]):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
elif isinstance(event, ai.MessageEnd):
for tc in event.message.tool_calls:
print(f"\nTool call: {tc.tool_name}({tc.tool_args})")
# Stream with tools — the model may emit tool calls.
async with ai.stream(model, messages, tools=[get_weather]) as s:
async for event in s:
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
print()

# After iteration, s.tool_calls collects every tool call from the response.
for tc in s.tool_calls:
print(f"Tool call: {tc.tool_name}({tc.tool_args})")


if __name__ == "__main__":
asyncio.run(main())
30 changes: 18 additions & 12 deletions src/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
from .middleware import AgentRunContext, Middleware
from .models import (
Client,
Executor,
GenerateExecutor,
GenerateRequest,
ImageParams,
Model,
Provider,
StreamResult,
Stream,
StreamExecutor,
StreamRequest,
VideoParams,
ai_gateway,
anthropic,
Expand All @@ -32,22 +37,20 @@

# Re-export core types
from .types import (
End,
Event,
FileEvent,
FilePart,
HookPart,
HookResolution,
HookSuspention,
Message,
MessageEnd,
MessageStart,
Part,
ReasoningDelta,
ReasoningEnd,
ReasoningPart,
ReasoningStart,
Start,
StreamResultLike,
StreamEnd,
StreamStart,
StructuredOutputPart,
TextDelta,
TextEnd,
Expand All @@ -74,12 +77,10 @@

__all__ = [
# Types (from types/)
"Start",
"End",
"Event",
"Message",
"MessageStart",
"MessageEnd",
"StreamStart",
"StreamEnd",
"Part",
"TextPart",
"TextStart",
Expand All @@ -94,6 +95,7 @@
"ReasoningStart",
"ReasoningDelta",
"ReasoningEnd",
"FileEvent",
"FilePart",
"HookPart",
"HookSuspention",
Expand All @@ -116,8 +118,12 @@
"ImageParams",
"VideoParams",
"Client",
"StreamResult",
"StreamResultLike",
"Stream",
"StreamRequest",
"GenerateRequest",
"Executor",
"StreamExecutor",
"GenerateExecutor",
"check_connection",
"stream",
"generate",
Expand Down
45 changes: 28 additions & 17 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .. import middleware as middleware_
from .. import models, types
from ..types import builders
from . import events as events_
from . import runtime


Expand Down Expand Up @@ -199,21 +200,23 @@ def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]:
]


StreamItem = types.Event | types.Message
StreamItem = events_.AgentEvent | types.Message


class LoopFn(Protocol):
def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ...


async def _message_events(message: types.Message) -> AsyncGenerator[types.Event]:
yield types.MessageStart(message=message)
yield types.MessageEnd(message=message)
async def _message_events(
message: types.Message,
) -> AsyncGenerator[events_.AgentEvent]:
yield events_.MessageStart(message=message)
yield events_.MessageEnd(message=message)


async def _coerce_events(
source: AsyncIterable[StreamItem],
) -> AsyncGenerator[types.Event]:
) -> AsyncGenerator[events_.AgentEvent]:
async for item in source:
if isinstance(item, types.Message):
async for event in _message_events(item):
Expand All @@ -222,15 +225,23 @@ async def _coerce_events(
yield item


async def _default_loop(context: Context) -> AsyncGenerator[types.Event]:
async def _default_loop(context: Context) -> AsyncGenerator[events_.AgentEvent]:
while True:
stream = models.stream(
context.model,
context.messages,
tools=context.tools,
)
async for event in stream:
yield event
async for stream_event in stream:
yield stream_event

# Bridge: emit MessageStart/MessageEnd around the assistant message
# the model stream just produced, so _collect_messages and downstream
# consumers (AI-SDK outbound, label stamping) see the same boundary

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is "AI-SDK outbound"?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the part of the AI SDK UI adapter (the one that has `useChat`) that converts the stream from the backend format to the frontend format.

# events they did under the previous adapter contract.
if stream.message is not None and stream.message.parts:
async for boundary in _message_events(stream.message):
yield boundary

tool_calls = context.resolve(stream.tool_calls)
if not tool_calls:
Expand All @@ -244,14 +255,14 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Event]:
# Left un-stamped: the tool result is the input of the *next* turn,
# so the next stream() call will stamp it with that turn's id.
tool_msg = builders.tool_message(*(t.result() for t in tasks))
async for event in _message_events(tool_msg):
yield event
async for boundary in _message_events(tool_msg):
yield boundary


async def _collect_messages(
source: AsyncIterable[StreamItem],
messages: list[types.Message],
) -> AsyncGenerator[types.Event]:
) -> AsyncGenerator[events_.AgentEvent]:
"""Intercept yielded events and collect MessageEnd messages into *messages*.

This runs on the **producer** side (same coroutine as the loop function),
Expand All @@ -260,7 +271,7 @@ async def _collect_messages(
happened on the consumer side of the runtime queue.
"""
async for event in _coerce_events(source):
if isinstance(event, types.MessageEnd):
if isinstance(event, events_.MessageEnd):
message = event.message
for i, existing in enumerate(messages):
if existing.id == message.id:
Expand Down Expand Up @@ -292,7 +303,7 @@ async def yield_from(source: AsyncIterable[StreamItem]) -> str:
last: types.Message | None = None
async for item in _coerce_events(source):
await rt.put_event(item)
if isinstance(item, types.MessageEnd):
if isinstance(item, events_.MessageEnd):
last = item.message
return last.text if last else ""

Expand Down Expand Up @@ -325,7 +336,7 @@ async def run(
*,
label: str | None = None,
middleware: list[middleware_.Middleware] | None = None,
) -> AsyncGenerator[types.Event]:
) -> AsyncGenerator[events_.AgentEvent]:
"""Run the agent loop, yielding events to the consumer.

Args:
Expand All @@ -349,7 +360,7 @@ async def run(

async def _real(
call: middleware_.AgentRunContext,
) -> AsyncGenerator[types.Event]:
) -> AsyncGenerator[events_.AgentEvent]:
context = Context(
model=call.model,
messages=list(call.messages),
Expand All @@ -359,8 +370,8 @@ async def _real(
async for event in runtime.run(source):
if call.label is not None:
event_message: types.Message | None = None
if isinstance(event, types.MessageEnd) or (
isinstance(event, types.MessageStart)
if isinstance(event, events_.MessageEnd) or (
isinstance(event, events_.MessageStart)
and event.message is not None
):
event_message = event.message
Expand Down
Loading
Loading