Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ class ProviderResponseError(ProviderAPIError):
"""Provider returned a malformed or unexpected response."""


class ProviderIncompleteResponseError(ProviderResponseError):
    """A streamed response was cut off before the provider finished it.

    Signals that the stream stopped without delivering the provider's
    finish event (for instance, the transport connection dropped
    mid-response), so only a partial message exists: it may contain
    nothing but reasoning output, or a tool call whose arguments are
    truncated.
    """


class ProviderStatusError(ProviderAPIError):
"""Provider returned a non-success HTTP status code."""

Expand Down Expand Up @@ -326,6 +336,7 @@ def _is_retryable_status(status_code: int | None) -> bool:
"ProviderConnectionError",
"ProviderDeadlineExceededError",
"ProviderError",
"ProviderIncompleteResponseError",
"ProviderInternalServerError",
"ProviderModelNotFoundError",
"ProviderNotConfiguredError",
Expand Down
29 changes: 24 additions & 5 deletions src/ai/models/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Use the typing_extensions backport so this works on 3.12 too.
from typing_extensions import TypeVar

from ... import types
from ... import errors, types
from ...types import integrity

if TYPE_CHECKING:
Expand Down Expand Up @@ -128,6 +128,12 @@ def __init__(
# param for ergonomics; internally we know it's a Pydantic model
# subclass (or None for the text-default case).
self._output_type = cast("type[pydantic.BaseModel] | None", output_type)
# Whether the provider signalled completion (``StreamEnd``). A
# stream that exhausts without it died mid-response (transport
# drop): the message is partial — possibly reasoning-only or a
# tool call with truncated args — so exhaustion must raise
# rather than look like a normal end of turn.
self._ended = False

async def aclose(self) -> None:
await self._gen.aclose()
Expand All @@ -150,9 +156,15 @@ def __aiter__(self) -> Self:
async def __anext__(self: Self) -> types.events.Event:
    """Return the next provider event with the aggregated message attached.

    Pulls one event from the wrapped generator, folds it into the running
    message via ``_aggregate_event``, and returns a copy of the event whose
    ``message`` field is the aggregated message (plus any extra updates the
    aggregation step produced).

    Raises:
        errors.ProviderIncompleteResponseError: The wrapped generator was
            exhausted while ``self._ended`` was still false, i.e. no
            completion was recorded for this stream. The error is marked
            retryable.
        StopAsyncIteration: The wrapped generator was exhausted after
            ``self._ended`` had been set (normal end of iteration).

    Any other exception raised by the wrapped generator propagates
    unchanged.
    """
    try:
        event = await self._gen.__anext__()
    except StopAsyncIteration:
        # Only exhaustion is intercepted here; a broader handler listed
        # first would shadow this one and hide the completeness check.
        if not self._ended:
            raise errors.ProviderIncompleteResponseError(
                "provider stream ended without a finish event; "
                "the response is incomplete",
                # Premature termination is a transient transport or
                # provider failure: worth retrying.
                is_retryable=True,
            ) from None
        # Completion was recorded, so this is an ordinary end of stream.
        raise
    updates = self._aggregate_event(event)
    return event.model_copy(update={"message": self._message, **updates})
Expand Down Expand Up @@ -189,8 +201,10 @@ def _aggregate_event(self, event: types.events.Event) -> dict[str, Any]:
updates: dict[str, Any] = {}

# Replay events carry no new state — the seeded message already
# has everything they would have produced.
# has everything they would have produced. A replayed turn is
# complete by construction, so it also counts as ended.
if event.replay:
self._ended = True
return updates

# grab usage from any event that carries one
Expand Down Expand Up @@ -323,6 +337,7 @@ def _aggregate_event(self, event: types.events.Event) -> dict[str, Any]:
self._parts[fp.id] = fp

case types.events.StreamEnd(provider_metadata=pm):
self._ended = True
if pm is not None:
self._message.provider_metadata = pm
case _:
Expand Down Expand Up @@ -474,6 +489,10 @@ async def _stream(
seed_message=last.model_copy(deep=True),
output_type=cast("type[Any] | None", output_type),
)
# The replayed turn is a complete persisted message; don't
# demand a finish event from the synthetic replay generator
# (it yields nothing when the turn has no tool calls).
s._ended = True
else:
prepared = integrity.prepare_messages(messages)
request = StreamRequest(
Expand Down
46 changes: 46 additions & 0 deletions tests/models/core/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,49 @@ async def _spy_stream(
pass

assert called is True


async def test_stream_raises_when_stream_ends_without_finish() -> None:
    """An exhausted stream that never emitted ``StreamEnd`` must raise.

    Every completed response carries a ``StreamEnd``. When the SSE
    connection dies, the provider generator simply runs out of events;
    that used to finish iteration silently and leave a partial message
    behind (reasoning-only, or a tool call with truncated args).
    """
    partial_args = '{"command": "uv ru'

    async def _truncated_stream(
        model: models.Model,
        messages: list[messages_.Message],
        *,
        tools: Sequence[ai.tools.Tool] | None = None,
        output_type: type[pydantic.BaseModel] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[events_.Event]:
        yield events_.StreamStart()
        yield events_.ReasoningDelta(block_id="r-1", chunk="thinking…")
        yield events_.ToolStart(tool_call_id="tc-1", tool_name="bash")
        yield events_.ToolDelta(tool_call_id="tc-1", chunk=partial_args)
        # Simulated connection drop: the generator returns here without
        # ever yielding a ToolEnd or a StreamEnd.

    MOCK_PROVIDER._stream_impl = _truncated_stream

    with pytest.raises(ai.errors.ProviderIncompleteResponseError) as raised:
        async with models.stream(MOCK_MODEL, [ai.user_message("Hi")]) as stream:
            async for _event in stream:
                pass

    assert raised.value.is_retryable is True
    # Whatever was aggregated before the drop is still readable from the
    # stream object after the error.
    partial_message = stream.message
    assert partial_message.usage is None
    assert partial_message.tool_calls[0].tool_args == partial_args


async def test_stream_does_not_raise_on_early_consumer_exit() -> None:
    """A consumer that stops iterating early must not trigger an error.

    The test passes as long as leaving the loop and the ``async with``
    block raises nothing.
    """
    mock_llm([[text_msg("Hello world")]])

    async with models.stream(MOCK_MODEL, [ai.user_message("Hi")]) as stream:
        async for event in stream:
            # Skip everything up to the first text delta, then walk away
            # from the stream before it has finished.
            if not isinstance(event, events_.TextDelta):
                continue
            break
1 change: 1 addition & 0 deletions tests/providers/ai_gateway/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ async def test_signature_survives_round_trip(self) -> None:
"providerMetadata": {"anthropic": {"signature": "ErMJsig=="}},
},
{"type": "reasoning-end", "id": "0"},
{"type": "finish"},
]

async def _gen() -> AsyncGenerator[events_.Event]:
Expand Down
Loading