diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index f9f02c03..a493958e 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -14,6 +14,7 @@ TextPart, ToolDelta, ToolPart, + Usage, make_messages, ) from .core.runtime import ( @@ -41,6 +42,7 @@ "ToolLike", "ToolSchema", "Tool", + "Usage", "LanguageModel", "Runtime", "RunResult", diff --git a/src/vercel_ai_sdk/anthropic/__init__.py b/src/vercel_ai_sdk/anthropic/__init__.py index fbb2f440..bce6f5f6 100644 --- a/src/vercel_ai_sdk/anthropic/__init__.py +++ b/src/vercel_ai_sdk/anthropic/__init__.py @@ -208,8 +208,21 @@ async def stream_events( if tool_id: yield core.llm.ToolEnd(tool_call_id=tool_id) - elif event.type == "message_stop": - yield core.llm.MessageDone() + # The Anthropic SDK accumulates usage across message_start and + # message_delta events into current_message_snapshot. Read it + # once here instead of tracking state ourselves. + snapshot = stream.current_message_snapshot + sdk_usage = snapshot.usage + usage = core.messages.Usage( + input_tokens=sdk_usage.input_tokens or 0, + output_tokens=sdk_usage.output_tokens or 0, + cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), + cache_write_tokens=getattr( + sdk_usage, "cache_creation_input_tokens", None + ), + raw=sdk_usage.model_dump(exclude_none=True) or None, + ) + yield core.llm.MessageDone(usage=usage) @override async def stream( diff --git a/src/vercel_ai_sdk/core/llm.py b/src/vercel_ai_sdk/core/llm.py index 2c262d3e..e967503b 100644 --- a/src/vercel_ai_sdk/core/llm.py +++ b/src/vercel_ai_sdk/core/llm.py @@ -63,6 +63,7 @@ class ToolEnd: @dataclasses.dataclass class MessageDone: finish_reason: str | None = None + usage: messages_.Usage | None = None StreamEvent = ( @@ -104,6 +105,7 @@ class StreamHandler: _active_tool_ids: set[str] = dataclasses.field(default_factory=set) _is_done: bool = False + _usage: messages_.Usage | None = None def handle_event(self, event: StreamEvent) -> 
class Usage(pydantic.BaseModel):
    """Normalized token usage from a single LLM call.

    Provides a provider-agnostic view of token consumption.

    ``input_tokens`` and ``output_tokens`` always hold an integer and
    default to ``0``.  The optional breakdown fields (``reasoning_tokens``,
    ``cache_read_tokens``, ``cache_write_tokens``) are left as ``None``
    (not zero) when they are not supplied, so callers can distinguish
    "not reported" from "zero tokens used".

    Instances can be accumulated with ``+`` and with the built-in
    ``sum()``; see ``__add__`` and ``__radd__``.
    """

    input_tokens: int = 0
    output_tokens: int = 0

    # Optional breakdowns — not all providers report these.
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    # Pass-through of the raw provider usage payload so callers can access
    # provider-specific fields (e.g. OpenAI's accepted_prediction_tokens).
    raw: dict[str, Any] | None = None

    @property
    def total_tokens(self) -> int:
        """input_tokens + output_tokens (always consistent)."""
        return self.input_tokens + self.output_tokens

    def __add__(self, other: Usage) -> Usage:
        """Accumulate usage across multiple LLM calls.

        Returns a new ``Usage``; neither operand is modified.  Optional
        breakdown fields stay ``None`` only when both operands have
        ``None`` for that field.  ``raw`` is never carried over to the
        result because it is per-call and provider-specific.

        Returns ``NotImplemented`` for non-``Usage`` operands so Python
        raises the standard ``TypeError`` instead of an ``AttributeError``
        from deep inside this method.
        """
        if not isinstance(other, Usage):
            return NotImplemented

        def _add_optional(a: int | None, b: int | None) -> int | None:
            """Add two optional ints. Returns None only if both are None."""
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        return Usage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            reasoning_tokens=_add_optional(
                self.reasoning_tokens, other.reasoning_tokens
            ),
            cache_read_tokens=_add_optional(
                self.cache_read_tokens, other.cache_read_tokens
            ),
            cache_write_tokens=_add_optional(
                self.cache_write_tokens, other.cache_write_tokens
            ),
            # Don't merge raw — it's per-call and provider-specific.
        )

    def __radd__(self, other: object) -> Usage:
        """Support ``sum(usages)``, which starts from the integer ``0``.

        The integer ``0`` is treated as the additive identity and this
        instance is returned unchanged.  Any other left operand yields
        ``NotImplemented`` so Python raises ``TypeError``.
        """
        if isinstance(other, int) and other == 0:
            return self
        return NotImplemented
+ """ + total: messages_.Usage | None = None + for msg in self.messages: + if msg.usage is not None: + total = msg.usage if total is None else total + msg.usage + return total + Stream = Callable[[], AsyncGenerator[messages_.Message]] # maybe it should have a name and an id inferred from LLM outputs diff --git a/src/vercel_ai_sdk/openai/__init__.py b/src/vercel_ai_sdk/openai/__init__.py index ff7d3fe3..1b42b50e 100644 --- a/src/vercel_ai_sdk/openai/__init__.py +++ b/src/vercel_ai_sdk/openai/__init__.py @@ -153,6 +153,7 @@ async def stream_events( } if openai_tools: kwargs["tools"] = openai_tools + kwargs["stream_options"] = {"include_usage": True} if output_type is not None: from openai.lib._pydantic import to_strict_json_schema @@ -183,8 +184,35 @@ async def stream_events( text_started = False reasoning_started = False tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, started} + finish_reason: str | None = None + usage: core.messages.Usage | None = None async for chunk in stream: + # Extract usage from any chunk that carries it (typically the final + # chunk when stream_options.include_usage is True). 
+ if chunk.usage is not None: + raw = chunk.usage.model_dump(exclude_none=True) + # Extract optional breakdowns + reasoning_tokens: int | None = None + cache_read: int | None = None + completion_details = getattr( + chunk.usage, "completion_tokens_details", None + ) + if completion_details: + reasoning_tokens = getattr( + completion_details, "reasoning_tokens", None + ) + prompt_details = getattr(chunk.usage, "prompt_tokens_details", None) + if prompt_details: + cache_read = getattr(prompt_details, "cached_tokens", None) + usage = core.messages.Usage( + input_tokens=chunk.usage.prompt_tokens or 0, + output_tokens=chunk.usage.completion_tokens or 0, + reasoning_tokens=reasoning_tokens, + cache_read_tokens=cache_read, + raw=raw, + ) + if not chunk.choices: continue @@ -247,6 +275,7 @@ async def stream_events( ) if choice.finish_reason is not None: + finish_reason = choice.finish_reason # Close any open blocks if reasoning_started: yield core.llm.ReasoningEnd(block_id="reasoning") @@ -256,8 +285,10 @@ async def stream_events( if tc["started"] and tc["id"]: yield core.llm.ToolEnd(tool_call_id=tc["id"]) - yield core.llm.MessageDone(finish_reason=choice.finish_reason) - return + # Don't return yet — the usage chunk may arrive after + # finish_reason. We'll emit MessageDone after the loop. 
def test_usage_add_merges_optional_fields() -> None:
    """__add__ accumulates tokens and treats None vs populated correctly.

    Both operands carry a ``raw`` payload so the final "raw is not merged"
    assertion actually exercises the drop, rather than passing trivially
    because neither side had one.
    """
    a = Usage(
        input_tokens=100,
        output_tokens=50,
        cache_read_tokens=20,
        # reasoning_tokens and cache_write_tokens left as None
        raw={"input_tokens": 100, "output_tokens": 50},
    )
    b = Usage(
        input_tokens=200,
        output_tokens=80,
        reasoning_tokens=10,
        # cache_read_tokens left as None, cache_write_tokens left as None
        raw={"prompt_tokens": 200, "completion_tokens": 80},
    )
    total = a + b

    assert total.input_tokens == 300
    assert total.output_tokens == 130
    assert total.total_tokens == 430

    # None + int -> int (not None)
    assert total.reasoning_tokens == 10
    # int + None -> int (not None)
    assert total.cache_read_tokens == 20
    # None + None -> None (not zero)
    assert total.cache_write_tokens is None

    # raw is intentionally not merged, even when both operands have one
    assert total.raw is None

    # Addition builds a new object; the operands must be left untouched.
    assert a.input_tokens == 100
    assert a.output_tokens == 50
    assert a.reasoning_tokens is None
    assert a.raw == {"input_tokens": 100, "output_tokens": 50}
    assert b.input_tokens == 200
    assert b.output_tokens == 80
    assert b.cache_read_tokens is None
    assert b.raw == {"prompt_tokens": 200, "completion_tokens": 80}