Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,5 @@ reports/
.askui_cache/*

bom.json

*playground*
4,213 changes: 1,933 additions & 2,280 deletions pdm.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ authors = [
]
dependencies = [
"askui-agent-os>=26.1.1",
"anthropic>=0.72.0",
"anthropic>=0.86.0",
"fastapi>=0.115.12",
"fastmcp>=2.3.0",
"gradio-client>=1.4.3",
"grpcio>=1.73.1",
"grpcio>=1.73.1,<1.80.0",
"httpx>=0.28.1",
"Jinja2>=3.1.4",
"openai>=1.61.1",
Expand Down
283 changes: 217 additions & 66 deletions src/askui/callbacks/usage_tracking_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING

from opentelemetry import trace
from pydantic import BaseModel
from typing_extensions import override
from pydantic import BaseModel, Field
from typing_extensions import Self, override

from askui.callbacks.conversation_callback import ConversationCallback
from askui.reporting import NULL_REPORTER
Expand All @@ -18,6 +18,8 @@
from askui.speaker.speaker import SpeakerResult
from askui.utils.model_pricing import ModelPricing

_USD_CURRENCY = "USD"


class UsageSummary(BaseModel):
"""Accumulated token usage and optional cost breakdown for a conversation.
Expand All @@ -27,9 +29,13 @@ class UsageSummary(BaseModel):
output_tokens (int | None): Total output tokens generated.
cache_creation_input_tokens (int | None): Tokens used for cache creation.
cache_read_input_tokens (int | None): Tokens read from cache.
input_cost (float | None): Computed input cost in `currency`.
output_cost (float | None): Computed output cost in `currency`.
total_cost (float | None): Sum of `input_cost` and `output_cost`.
input_token_cost (float | None): Computed cost for input tokens in `currency`.
output_token_cost (float | None): Computed cost for output tokens in `currency`.
cache_write_token_cost (float | None): Computed cost for cache write tokens in
`currency`.
cache_read_token_cost (float | None): Computed cost for cache read tokens in
`currency`.
total_cost (float | None): Sum of all computed cost values.
currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
        input_cost_per_million_tokens (float | None): Rate used to compute
            `input_token_cost`.
        output_cost_per_million_tokens (float | None): Rate used to compute
            `output_token_cost`.
Expand All @@ -39,12 +45,138 @@ class UsageSummary(BaseModel):
output_tokens: int | None = None
cache_creation_input_tokens: int | None = None
cache_read_input_tokens: int | None = None
input_cost: float | None = None
output_cost: float | None = None
input_token_cost: float | None = None
output_token_cost: float | None = None
cache_write_token_cost: float | None = None
cache_read_token_cost: float | None = None
total_cost: float | None = None
currency: str | None = None
input_cost_per_million_tokens: float | None = None
output_cost_per_million_tokens: float | None = None
cache_write_cost_per_million_tokens: float | None = None
cache_read_cost_per_million_tokens: float | None = None
per_conversation_summaries: list[ConversationUsageSummary] | None = None

@classmethod
def create(cls, pricing: ModelPricing | None = None) -> "UsageSummary":
"""Create a summary configured with optional model pricing."""
if pricing is None:
return cls()
return cls(
input_cost_per_million_tokens=pricing.input_cost_per_million_tokens,
output_cost_per_million_tokens=pricing.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
pricing.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
pricing.cache_read_cost_per_million_tokens
),
)

@classmethod
def create_from(cls, summary: "UsageSummary") -> "UsageSummary":
"""Create a new summary that reuses pricing fields from `summary`."""
return cls(
input_cost_per_million_tokens=summary.input_cost_per_million_tokens,
output_cost_per_million_tokens=summary.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
summary.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
summary.cache_read_cost_per_million_tokens
),
)

def add_usage(self, usage: UsageParam) -> None:
"""Add token counts from `usage`."""
self.input_tokens = (self.input_tokens or 0) + (usage.input_tokens or 0)
self.output_tokens = (self.output_tokens or 0) + (usage.output_tokens or 0)
self.cache_creation_input_tokens = (self.cache_creation_input_tokens or 0) + (
usage.cache_creation_input_tokens or 0
)
self.cache_read_input_tokens = (self.cache_read_input_tokens or 0) + (
usage.cache_read_input_tokens or 0
)

def generate(self) -> Self:
"""Compute and populate cost fields from current token and pricing fields."""
if not self._has_pricing():
self._clear_cost_fields()
return self

input_tokens = self.input_tokens or 0
output_tokens = self.output_tokens or 0
cache_write_tokens = self.cache_creation_input_tokens or 0
cache_read_tokens = self.cache_read_input_tokens or 0

assert self.input_cost_per_million_tokens is not None
assert self.output_cost_per_million_tokens is not None
assert self.cache_write_cost_per_million_tokens is not None
assert self.cache_read_cost_per_million_tokens is not None

self.input_token_cost = self._calculate_cost(
input_tokens, self.input_cost_per_million_tokens
)
self.output_token_cost = self._calculate_cost(
output_tokens, self.output_cost_per_million_tokens
)
self.cache_write_token_cost = self._calculate_cost(
cache_write_tokens, self.cache_write_cost_per_million_tokens
)
self.cache_read_token_cost = self._calculate_cost(
cache_read_tokens, self.cache_read_cost_per_million_tokens
)
self.total_cost = (
(self.input_token_cost or 0.0)
+ (self.output_token_cost or 0.0)
+ (self.cache_write_token_cost or 0.0)
+ (self.cache_read_token_cost or 0.0)
)
self.currency = _USD_CURRENCY
return self

def token_attributes(self) -> dict[str, int]:
"""Return token fields for telemetry attributes."""
return {
"input_tokens": self.input_tokens or 0,
"output_tokens": self.output_tokens or 0,
"cache_creation_input_tokens": self.cache_creation_input_tokens or 0,
"cache_read_input_tokens": self.cache_read_input_tokens or 0,
}

def _has_pricing(self) -> bool:
return (
self.input_cost_per_million_tokens is not None
and self.output_cost_per_million_tokens is not None
and self.cache_write_cost_per_million_tokens is not None
and self.cache_read_cost_per_million_tokens is not None
)

def _clear_cost_fields(self) -> None:
self.input_token_cost = None
self.output_token_cost = None
self.cache_write_token_cost = None
self.cache_read_token_cost = None
self.total_cost = None
self.currency = None

@staticmethod
def _calculate_cost(tokens: int, rate_per_million_tokens: float) -> float:
return rate_per_million_tokens * tokens / 1e6


class StepUsageSummary(UsageSummary):
    """Usage summary for a single step.

    Extends `UsageSummary` with the position of the step.

    Attributes:
        step_index (int): Index of the step within its conversation, as
            passed to `UsageTrackingCallback.on_step_end`.
    """

    step_index: int


class ConversationUsageSummary(UsageSummary):
    """Usage summary for one conversation including per-step breakdown.

    Attributes:
        conversation_index (int): Running number of the conversation within
            the tracking callback (starts at 1 for the first conversation).
        conversation_id (str): Identifier of the summarized conversation.
        step_summaries (list[StepUsageSummary]): Per-step usage summaries,
            in the order the steps were observed.
    """

    conversation_index: int
    conversation_id: str
    step_summaries: list[StepUsageSummary] = Field(default_factory=list)


class UsageTrackingCallback(ConversationCallback):
Expand All @@ -62,12 +194,17 @@ def __init__(
pricing: ModelPricing | None = None,
) -> None:
self._reporter = reporter
self._pricing = pricing
self._summary = UsageSummary()
self._summary: UsageSummary = UsageSummary.create(pricing)
self._per_conversation_usage: UsageSummary = UsageSummary.create(pricing)
self._per_conversation_summaries: list[ConversationUsageSummary] = []
self._per_step_summaries: list[StepUsageSummary] = []
self._conversation_index: int = 0

@override
def on_conversation_start(self, conversation: Conversation) -> None:
self._summary = UsageSummary()
self._per_conversation_usage = UsageSummary.create_from(self._summary)
self._per_step_summaries = []
self._conversation_index += 1

@override
def on_step_end(
Expand All @@ -76,71 +213,85 @@ def on_step_end(
step_index: int,
result: SpeakerResult,
) -> None:
if result.usage:
self._accumulate(result.usage)
step_usage: UsageParam | None = result.usage
if step_usage is None:
return

step_summary = self._create_step_summary(
step_index=step_index, usage=step_usage
)
self._per_step_summaries.append(step_summary)
self._per_conversation_usage.add_usage(step_usage)
self._summary.add_usage(step_usage)

current_span = trace.get_current_span()
current_span.set_attributes(step_summary.token_attributes())

@override
def on_conversation_end(self, conversation: Conversation) -> None:
self._reporter.add_usage_summary(self._summary)
generated_steps: list[StepUsageSummary] = [
step_summary.generate() for step_summary in self._per_step_summaries
]
conversation_summary = self._create_conversation_summary(
conversation=conversation,
generated_step_summaries=generated_steps,
)
self._per_conversation_summaries.append(conversation_summary)
self._summary.per_conversation_summaries = list(
self._per_conversation_summaries
)
self._reporter.add_usage_summary(self._summary.generate().model_copy(deep=True))

    @property
    def accumulated_usage(self) -> UsageSummary:
        """Current accumulated usage statistics.

        Returns the live internal summary object (not a copy); cost fields
        reflect the state as of the last `generate()` call, which happens in
        `on_conversation_end`.
        """
        return self._summary

def _accumulate(self, step_usage: UsageParam) -> None:
# Add step tokens to running totals (None counts as 0)
self._summary.input_tokens = (self._summary.input_tokens or 0) + (
step_usage.input_tokens or 0
)
self._summary.output_tokens = (self._summary.output_tokens or 0) + (
step_usage.output_tokens or 0
)
self._summary.cache_creation_input_tokens = (
self._summary.cache_creation_input_tokens or 0
) + (step_usage.cache_creation_input_tokens or 0)
self._summary.cache_read_input_tokens = (
self._summary.cache_read_input_tokens or 0
) + (step_usage.cache_read_input_tokens or 0)

# Record per-step token counts on the current OTel span
current_span = trace.get_current_span()
current_span.set_attributes(
{
"input_tokens": step_usage.input_tokens or 0,
"output_tokens": step_usage.output_tokens or 0,
"cache_creation_input_tokens": (
step_usage.cache_creation_input_tokens or 0
),
"cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0),
}
def _create_step_summary(
self, step_index: int, usage: UsageParam
) -> StepUsageSummary:
return StepUsageSummary(
step_index=step_index,
input_tokens=usage.input_tokens or 0,
output_tokens=usage.output_tokens or 0,
cache_creation_input_tokens=usage.cache_creation_input_tokens or 0,
cache_read_input_tokens=usage.cache_read_input_tokens or 0,
input_cost_per_million_tokens=self._summary.input_cost_per_million_tokens,
output_cost_per_million_tokens=self._summary.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
self._summary.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
self._summary.cache_read_cost_per_million_tokens
),
)

# Update costs from updated totals if pricing values are set
if not (
self._pricing
and self._pricing.input_cost_per_million_tokens
and self._pricing.output_cost_per_million_tokens
):
return

input_cost = (
self._summary.input_tokens
* self._pricing.input_cost_per_million_tokens
/ 1e6
)
output_cost = (
self._summary.output_tokens
* self._pricing.output_cost_per_million_tokens
/ 1e6
)
self._summary.input_cost = input_cost
self._summary.output_cost = output_cost
self._summary.total_cost = input_cost + output_cost
self._summary.currency = self._pricing.currency
self._summary.input_cost_per_million_tokens = (
self._pricing.input_cost_per_million_tokens
)
self._summary.output_cost_per_million_tokens = (
self._pricing.output_cost_per_million_tokens
def _create_conversation_summary(
self,
conversation: Conversation,
generated_step_summaries: list[StepUsageSummary],
) -> ConversationUsageSummary:
conversation_summary = ConversationUsageSummary(
conversation_index=self._conversation_index,
conversation_id=conversation.conversation_id,
step_summaries=generated_step_summaries,
input_tokens=self._per_conversation_usage.input_tokens,
output_tokens=self._per_conversation_usage.output_tokens,
cache_creation_input_tokens=(
self._per_conversation_usage.cache_creation_input_tokens
),
cache_read_input_tokens=self._per_conversation_usage.cache_read_input_tokens,
input_cost_per_million_tokens=(
self._per_conversation_usage.input_cost_per_million_tokens
),
output_cost_per_million_tokens=(
self._per_conversation_usage.output_cost_per_million_tokens
),
cache_write_cost_per_million_tokens=(
self._per_conversation_usage.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
self._per_conversation_usage.cache_read_cost_per_million_tokens
),
)
return conversation_summary.generate()
12 changes: 10 additions & 2 deletions src/askui/model_providers/anthropic_vlm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ class AnthropicVlmProvider(VlmProvider):
client (Anthropic | None, optional): Pre-configured Anthropic client.
If provided, other connection parameters are ignored.
input_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M input tokens. Both cost params must be set
to override the built-in defaults.
cost in USD per 1M input tokens. All override pricing params must be set to
override the built-in defaults.
output_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M output tokens.
cache_write_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M cache write input tokens.
cache_read_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M cache read input tokens.

Example:
```python
Expand All @@ -68,6 +72,8 @@ def __init__(
client: Anthropic | None = None,
input_cost_per_million_tokens: float | None = None,
output_cost_per_million_tokens: float | None = None,
cache_write_cost_per_million_tokens: float | None = None,
cache_read_cost_per_million_tokens: float | None = None,
) -> None:
self._model_id_value = (
model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID
Expand All @@ -84,6 +90,8 @@ def __init__(
self._model_id_value,
input_cost_per_million_tokens=input_cost_per_million_tokens,
output_cost_per_million_tokens=output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=cache_write_cost_per_million_tokens,
cache_read_cost_per_million_tokens=cache_read_cost_per_million_tokens,
)

@property
Expand Down
Loading
Loading