From 55bb0ea3c347c5a3c62e532c83dc00125cca6546 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 10 Jun 2026 10:42:21 -0700 Subject: [PATCH] Raise on provider streams that end without a finish event When the transport drops mid-response, provider event generators just exhaust without error, causing the loop to terminate in confusing ways. Track StreamEnd in models.core.api.Stream and raise a new ProviderIncompleteResponseError (retryable by default) when the generator exhausts without one. As a follow-up, we should consider whether we ought to provide access to the completed/cancelled tool runs somewhere in the `Agent` after a failure? Currently a consumer wanting to understand the state when an error happens would need to reconstruct it from stream events. Co-authored-by: anthropic/claude-fable-5, via tau --- src/ai/errors.py | 11 +++++ src/ai/models/core/api.py | 29 ++++++++++--- tests/models/core/test_api.py | 46 +++++++++++++++++++++ tests/providers/ai_gateway/test_protocol.py | 1 + 4 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/ai/errors.py b/src/ai/errors.py index cdb110e5..f80d3338 100644 --- a/src/ai/errors.py +++ b/src/ai/errors.py @@ -162,6 +162,16 @@ class ProviderResponseError(ProviderAPIError): """Provider returned a malformed or unexpected response.""" +class ProviderIncompleteResponseError(ProviderResponseError): + """Provider stream ended before the response was complete. + + Raised when a streaming response terminates without the provider's + finish event — e.g. the transport connection dropped mid-response — + leaving a partial message (reasoning-only output, or a tool call + with truncated arguments). + """ + + class ProviderStatusError(ProviderAPIError): """Provider returned a non-success HTTP status code.""" @@ -326,6 +336,7 @@ def _is_retryable_status(status_code: int | None) -> bool: "ProviderConnectionError", "ProviderDeadlineExceededError", "ProviderError", + "ProviderIncompleteResponseError", "ProviderInternalServerError", "ProviderModelNotFoundError", "ProviderNotConfiguredError", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 9916daff..42f5c202 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -20,7 +20,7 @@ # Use the typing_extensions backport so this works on 3.12 too. from typing_extensions import TypeVar -from ... import types +from ... import errors, types from ...types import integrity if TYPE_CHECKING: @@ -128,6 +128,12 @@ def __init__( # param for ergonomics; internally we know it's a Pydantic model # subclass (or None for the text-default case). self._output_type = cast("type[pydantic.BaseModel] | None", output_type) + # Whether the provider signalled completion (``StreamEnd``). A + # stream that exhausts without it died mid-response (transport + # drop): the message is partial — possibly reasoning-only or a + # tool call with truncated args — so exhaustion must raise + # rather than look like a normal end of turn. + self._ended = False async def aclose(self) -> None: await self._gen.aclose() @@ -150,9 +156,15 @@ def __aiter__(self) -> Self: async def __anext__(self: Self) -> types.events.Event: try: event = await self._gen.__anext__() - except Exception: - # Usually this fires on StopAsyncIteration, but could be a - # real exception too + except StopAsyncIteration: + if not self._ended: + raise errors.ProviderIncompleteResponseError( + "provider stream ended without a finish event; " + "the response is incomplete", + # Premature termination is a transient transport or + # provider failure: worth retrying. + is_retryable=True, + ) from None raise updates = self._aggregate_event(event) return event.model_copy(update={"message": self._message, **updates}) @@ -189,8 +201,10 @@ def _aggregate_event(self, event: types.events.Event) -> dict[str, Any]: updates: dict[str, Any] = {} # Replay events carry no new state — the seeded message already - # has everything they would have produced. + # has everything they would have produced. A replayed turn is + # complete by construction, so it also counts as ended. if event.replay: + self._ended = True return updates # grab usage from any event that carries one @@ -323,6 +337,7 @@ def _aggregate_event(self, event: types.events.Event) -> dict[str, Any]: self._parts[fp.id] = fp case types.events.StreamEnd(provider_metadata=pm): + self._ended = True if pm is not None: self._message.provider_metadata = pm case _: @@ -474,6 +489,10 @@ async def _stream( seed_message=last.model_copy(deep=True), output_type=cast("type[Any] | None", output_type), ) + # The replayed turn is a complete persisted message; don't + # demand a finish event from the synthetic replay generator + # (it yields nothing when the turn has no tool calls). + s._ended = True else: prepared = integrity.prepare_messages(messages) request = StreamRequest( diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index e7c740f0..3f7e81ab 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -483,3 +483,49 @@ async def _spy_stream( pass assert called is True + + +async def test_stream_raises_when_stream_ends_without_finish() -> None: + """A transport drop mid-response must raise, not look like a normal end. + + A completed response always carries a ``StreamEnd``; when the SSE + connection dies the provider generator just exhausts, which used to + end the iteration silently with a partial message (reasoning-only, + or a tool call with truncated args). + """ + + async def _dying_stream( + model: models.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[events_.Event]: + yield events_.StreamStart() + yield events_.ReasoningDelta(block_id="r-1", chunk="thinking…") + yield events_.ToolStart(tool_call_id="tc-1", tool_name="bash") + yield events_.ToolDelta(tool_call_id="tc-1", chunk='{"command": "uv ru') + # connection drops: no ToolEnd, no StreamEnd + + MOCK_PROVIDER._stream_impl = _dying_stream + + with pytest.raises(ai.errors.ProviderIncompleteResponseError) as excinfo: + async with models.stream(MOCK_MODEL, [ai.user_message("Hi")]) as stream: + async for _ in stream: + pass + + assert excinfo.value.is_retryable is True + # The partial message stays inspectable on the stream object. + assert stream.message.usage is None + assert stream.message.tool_calls[0].tool_args == '{"command": "uv ru' + + +async def test_stream_does_not_raise_on_early_consumer_exit() -> None: + """Breaking out of iteration early is not an incomplete response.""" + mock_llm([[text_msg("Hello world")]]) + + async with models.stream(MOCK_MODEL, [ai.user_message("Hi")]) as stream: + async for event in stream: + if isinstance(event, events_.TextDelta): + break diff --git a/tests/providers/ai_gateway/test_protocol.py b/tests/providers/ai_gateway/test_protocol.py index 3422dc20..fb93aa0e 100644 --- a/tests/providers/ai_gateway/test_protocol.py +++ b/tests/providers/ai_gateway/test_protocol.py @@ -620,6 +620,7 @@ async def test_signature_survives_round_trip(self) -> None: "providerMetadata": {"anthropic": {"signature": "ErMJsig=="}}, }, {"type": "reasoning-end", "id": "0"}, + {"type": "finish"}, ] async def _gen() -> AsyncGenerator[events_.Event]: