Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/vercel_ai_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TextPart,
ToolDelta,
ToolPart,
Usage,
make_messages,
)
from .core.runtime import (
Expand Down Expand Up @@ -41,6 +42,7 @@
"ToolLike",
"ToolSchema",
"Tool",
"Usage",
"LanguageModel",
"Runtime",
"RunResult",
Expand Down
17 changes: 15 additions & 2 deletions src/vercel_ai_sdk/anthropic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,21 @@ async def stream_events(
if tool_id:
yield core.llm.ToolEnd(tool_call_id=tool_id)

elif event.type == "message_stop":
yield core.llm.MessageDone()
# The Anthropic SDK accumulates usage across message_start and
# message_delta events into current_message_snapshot. Read it
# once here instead of tracking state ourselves.
snapshot = stream.current_message_snapshot
sdk_usage = snapshot.usage
usage = core.messages.Usage(
input_tokens=sdk_usage.input_tokens or 0,
output_tokens=sdk_usage.output_tokens or 0,
cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None),
cache_write_tokens=getattr(
sdk_usage, "cache_creation_input_tokens", None
),
raw=sdk_usage.model_dump(exclude_none=True) or None,
)
yield core.llm.MessageDone(usage=usage)

@override
async def stream(
Expand Down
6 changes: 5 additions & 1 deletion src/vercel_ai_sdk/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ToolEnd:
@dataclasses.dataclass
class MessageDone:
    """Stream event signalling that the assistant message is complete.

    Both fields are optional so providers can emit a bare ``MessageDone()``
    when they have nothing extra to report.
    """

    # Provider-reported reason the generation stopped; None when the
    # provider adapter does not supply one.
    finish_reason: str | None = None
    # Normalized token usage for the call that produced this message, or
    # None when the provider adapter did not report usage.
    usage: messages_.Usage | None = None


StreamEvent = (
Expand Down Expand Up @@ -104,6 +105,7 @@ class StreamHandler:
_active_tool_ids: set[str] = dataclasses.field(default_factory=set)

_is_done: bool = False
_usage: messages_.Usage | None = None

def handle_event(self, event: StreamEvent) -> messages_.Message:
"""Process event and return current Message state."""
Expand Down Expand Up @@ -153,8 +155,9 @@ def handle_event(self, event: StreamEvent) -> messages_.Message:
case ToolEnd(tool_call_id=tcid):
self._active_tool_ids.discard(tcid)

case MessageDone():
case MessageDone(usage=usage):
self._is_done = True
self._usage = usage
self._active_text_id = None
self._active_reasoning_id = None
self._active_tool_ids.clear()
Expand Down Expand Up @@ -209,6 +212,7 @@ def _build_message(
id=self.message_id,
role="assistant",
parts=parts,
usage=self._usage if self._is_done else None,
)


Expand Down
51 changes: 51 additions & 0 deletions src/vercel_ai_sdk/core/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,56 @@ def value(self) -> Any:
]


class Usage(pydantic.BaseModel):
    """Normalized token usage from a single LLM call.

    Provides a provider-agnostic view of token consumption.

    ``input_tokens`` and ``output_tokens`` are always integers and default
    to ``0``. The optional breakdown fields (``reasoning_tokens``,
    ``cache_read_tokens``, ``cache_write_tokens``) are left as ``None``
    (not zero) when a provider does not report them, so callers can
    distinguish "not reported" from "zero tokens used".

    Instances can be accumulated with ``+`` or the built-in ``sum()``.
    """

    input_tokens: int = 0
    output_tokens: int = 0

    # Optional breakdowns — not all providers report these.
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    # Pass-through of the raw provider usage payload so callers can access
    # provider-specific fields (e.g. OpenAI's accepted_prediction_tokens).
    raw: dict[str, Any] | None = None

    @property
    def total_tokens(self) -> int:
        """input_tokens + output_tokens (always consistent)."""
        return self.input_tokens + self.output_tokens

    def __add__(self, other: Usage) -> Usage:
        """Accumulate usage across multiple LLM calls.

        Returns a new ``Usage``; neither operand is modified. ``raw`` is
        not carried over because it is per-call and provider-specific.
        Returns ``NotImplemented`` for operands that are not ``Usage`` so
        Python raises the standard ``TypeError`` for unsupported operands.
        """
        if not isinstance(other, Usage):
            return NotImplemented

        def _add_optional(a: int | None, b: int | None) -> int | None:
            """Add two optional ints. Returns None only if both are None."""
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        return Usage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            reasoning_tokens=_add_optional(
                self.reasoning_tokens, other.reasoning_tokens
            ),
            cache_read_tokens=_add_optional(
                self.cache_read_tokens, other.cache_read_tokens
            ),
            cache_write_tokens=_add_optional(
                self.cache_write_tokens, other.cache_write_tokens
            ),
            # Don't merge raw — it's per-call and provider-specific.
        )

    def __radd__(self, other: object) -> Usage:
        """Support ``sum(usages)``, which starts from the integer ``0``.

        ``0 + usage`` returns ``usage`` itself; any other left operand is
        rejected with ``NotImplemented``.
        """
        if isinstance(other, int) and other == 0:
            return self
        return NotImplemented


def _gen_id() -> str:
return uuid.uuid4().hex[:12]

Expand All @@ -130,6 +180,7 @@ class Message(pydantic.BaseModel):
parts: list[Part]
id: str = pydantic.Field(default_factory=_gen_id)
label: str | None = None
usage: Usage | None = None

@property
def output(self) -> Any:
Expand Down
21 changes: 21 additions & 0 deletions src/vercel_ai_sdk/core/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ def output(self) -> Any:
return self.last_message.output
return None

@property
def usage(self) -> messages_.Usage | None:
    """Token usage attached to the last message, if there is one.

    Returns ``None`` when there is no last message; otherwise returns that
    message's ``usage`` value, which may itself be ``None``.
    """
    return self.last_message.usage if self.last_message else None

@property
def total_usage(self) -> messages_.Usage | None:
    """Combined token usage over every message in this result.

    Messages whose ``usage`` is ``None`` are ignored. The remaining usage
    objects are added together in message order. Returns ``None`` if no
    message reported usage; if exactly one did, that usage object is
    returned as-is.
    """
    reported = [m.usage for m in self.messages if m.usage is not None]
    if not reported:
        return None

    running = reported[0]
    for extra in reported[1:]:
        running = running + extra
    return running


Stream = Callable[[], AsyncGenerator[messages_.Message]]
# maybe it should have a name and an id inferred from LLM outputs
Expand Down
35 changes: 33 additions & 2 deletions src/vercel_ai_sdk/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ async def stream_events(
}
if openai_tools:
kwargs["tools"] = openai_tools
kwargs["stream_options"] = {"include_usage": True}

if output_type is not None:
from openai.lib._pydantic import to_strict_json_schema
Expand Down Expand Up @@ -183,8 +184,35 @@ async def stream_events(
text_started = False
reasoning_started = False
tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, started}
finish_reason: str | None = None
usage: core.messages.Usage | None = None

async for chunk in stream:
# Extract usage from any chunk that carries it (typically the final
# chunk when stream_options.include_usage is True).
if chunk.usage is not None:
raw = chunk.usage.model_dump(exclude_none=True)
# Extract optional breakdowns
reasoning_tokens: int | None = None
cache_read: int | None = None
completion_details = getattr(
chunk.usage, "completion_tokens_details", None
)
if completion_details:
reasoning_tokens = getattr(
completion_details, "reasoning_tokens", None
)
prompt_details = getattr(chunk.usage, "prompt_tokens_details", None)
if prompt_details:
cache_read = getattr(prompt_details, "cached_tokens", None)
usage = core.messages.Usage(
input_tokens=chunk.usage.prompt_tokens or 0,
output_tokens=chunk.usage.completion_tokens or 0,
reasoning_tokens=reasoning_tokens,
cache_read_tokens=cache_read,
raw=raw,
)

if not chunk.choices:
continue

Expand Down Expand Up @@ -247,6 +275,7 @@ async def stream_events(
)

if choice.finish_reason is not None:
finish_reason = choice.finish_reason
# Close any open blocks
if reasoning_started:
yield core.llm.ReasoningEnd(block_id="reasoning")
Expand All @@ -256,8 +285,10 @@ async def stream_events(
if tc["started"] and tc["id"]:
yield core.llm.ToolEnd(tool_call_id=tc["id"])

yield core.llm.MessageDone(finish_reason=choice.finish_reason)
return
# Don't return yet — the usage chunk may arrive after
# finish_reason. We'll emit MessageDone after the loop.

yield core.llm.MessageDone(finish_reason=finish_reason, usage=usage)

@override
async def stream(
Expand Down
20 changes: 19 additions & 1 deletion tests/core/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ToolEnd,
ToolStart,
)
from vercel_ai_sdk.core.messages import ReasoningPart, TextPart, ToolPart
from vercel_ai_sdk.core.messages import ReasoningPart, TextPart, ToolPart, Usage

from ..conftest import MockLLM, text_msg

Expand Down Expand Up @@ -175,6 +175,24 @@ def test_message_done_finalizes_all() -> None:
assert m.is_done


def test_message_done_propagates_usage() -> None:
    """The usage carried by MessageDone appears on the finished Message."""
    handler = StreamHandler(message_id="m1")
    for event in (
        TextStart(block_id="t1"),
        TextDelta(block_id="t1", delta="hi"),
    ):
        handler.handle_event(event)

    # While the message is still in flight, no usage is exposed.
    in_flight = handler.handle_event(TextEnd(block_id="t1"))
    assert in_flight.usage is None

    finished = handler.handle_event(
        MessageDone(usage=Usage(input_tokens=10, output_tokens=20))
    )
    reported = finished.usage
    assert reported is not None
    assert (reported.input_tokens, reported.output_tokens) == (10, 20)
    assert reported.total_tokens == 30


# -- Message properties propagate ------------------------------------------


Expand Down
35 changes: 35 additions & 0 deletions tests/core/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
StructuredOutputPart,
TextPart,
ToolPart,
Usage,
make_messages,
)

Expand Down Expand Up @@ -282,3 +283,37 @@ def test_structured_output_round_trip() -> None:
restored = Message.model_validate(m.model_dump())
assert isinstance(restored.output, _Weather)
assert restored.output.city == "SF"


# -- Usage -----------------------------------------------------------------


def test_usage_add_merges_optional_fields() -> None:
    """Adding two Usage objects sums counts and keeps None only for None + None."""
    # Only cache_read_tokens is populated among the optional fields.
    first = Usage(input_tokens=100, output_tokens=50, cache_read_tokens=20)
    # Only reasoning_tokens is populated among the optional fields.
    second = Usage(input_tokens=200, output_tokens=80, reasoning_tokens=10)

    combined = first + second

    # Required counters are plain sums.
    assert (
        combined.input_tokens,
        combined.output_tokens,
        combined.total_tokens,
    ) == (300, 130, 430)

    # One side None, other side populated -> the populated value survives.
    assert combined.reasoning_tokens == 10
    assert combined.cache_read_tokens == 20

    # Both sides None -> stays None rather than collapsing to zero.
    assert combined.cache_write_tokens is None

    # The raw provider payload is deliberately dropped when accumulating.
    assert combined.raw is None
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.