From 1f7ea0e7aaa27ef45d0d23573baff84d446a8dae Mon Sep 17 00:00:00 2001
From: Abhijeet Prasad
Date: Mon, 11 May 2026 12:26:36 -0400
Subject: [PATCH] fix(langchain): normalize anthropic cache token metrics

LangChain Anthropic usage reports cache read/write tokens separately
from raw input and total counts, which could produce Braintrust spans
where cache creation tokens exceeded total tokens. Fold cache read and
creation tokens into prompt_tokens and total_tokens when LangChain
appears to expose uncached prompt totals.

This keeps Braintrust token metrics internally consistent for users
relying on prompt caching: before, spans could underreport prompt/total
tokens; after, prompt_tokens includes uncached, cached, and
cache-creation tokens, and total_tokens remains prompt plus completion
tokens.

Extend the existing VCR prompt-caching regression test to assert the
normalized metric convention for cache creation and cache reads.
---
 .../braintrust/integrations/langchain/callbacks.py | 13 +++++++++++++
 .../integrations/langchain/test_callbacks.py       |  4 ++++
 2 files changed, 17 insertions(+)

diff --git a/py/src/braintrust/integrations/langchain/callbacks.py b/py/src/braintrust/integrations/langchain/callbacks.py
index a80a5625..fc40b635 100644
--- a/py/src/braintrust/integrations/langchain/callbacks.py
+++ b/py/src/braintrust/integrations/langchain/callbacks.py
@@ -656,6 +656,19 @@ def _get_metrics_from_response(response: LLMResult):
     if cache_creation is not None:
         metrics["prompt_cache_creation_tokens"] = cache_creation
 
+    cache_tokens = (cache_read or 0) + (cache_creation or 0)
+    prompt_tokens = metrics.get("prompt_tokens")
+    completion_tokens = metrics.get("completion_tokens")
+    total_tokens = metrics.get("total_tokens")
+    if (
+        cache_tokens
+        and prompt_tokens is not None
+        and completion_tokens is not None
+        and total_tokens == prompt_tokens + completion_tokens
+    ):
+        metrics["prompt_tokens"] = prompt_tokens + cache_tokens
+        metrics["total_tokens"] = total_tokens + cache_tokens
+
     if not metrics or not any(metrics.values()):
         llm_output: dict[str, Any] = response.llm_output or {}
         metrics = llm_output.get("token_usage") or llm_output.get("estimatedTokens") or {}
diff --git a/py/src/braintrust/integrations/langchain/test_callbacks.py b/py/src/braintrust/integrations/langchain/test_callbacks.py
index adeaa37d..11fecd1a 100644
--- a/py/src/braintrust/integrations/langchain/test_callbacks.py
+++ b/py/src/braintrust/integrations/langchain/test_callbacks.py
@@ -1087,6 +1087,8 @@ def test_prompt_caching_tokens(logger_memory_logger):
     assert "prompt_cache_creation_tokens" in first_metrics
     assert first_metrics["prompt_cache_creation_tokens"] > 0
     assert first_metrics["prompt_cached_tokens"] == 0
+    assert first_metrics["prompt_tokens"] >= first_metrics["prompt_cache_creation_tokens"]
+    assert first_metrics["total_tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
 
     second_metrics = None
     for attempt in range(3):
@@ -1116,6 +1118,8 @@ def test_prompt_caching_tokens(logger_memory_logger):
 
     assert second_metrics is not None
     assert second_metrics["prompt_cached_tokens"] > 0
+    assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
+    assert second_metrics["total_tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
 
 
 @pytest.mark.vcr
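
Note (illustration, not part of the applied patch): a minimal sketch of the
normalized convention, using hypothetical token counts rather than values from
the recorded VCR cassette. It assumes LangChain reported prompt/total counts
that exclude cached and cache-creation tokens, which is the case the new code
detects.

# Illustrative sketch of the folding done in _get_metrics_from_response.
# The numbers are hypothetical.
metrics = {
    "prompt_tokens": 12,                   # uncached input tokens
    "completion_tokens": 40,
    "total_tokens": 52,                    # 12 + 40, cache tokens missing
    "prompt_cached_tokens": 500,           # cache read
    "prompt_cache_creation_tokens": 200,   # cache write/creation
}

cache_tokens = metrics["prompt_cached_tokens"] + metrics["prompt_cache_creation_tokens"]
if cache_tokens and metrics["total_tokens"] == metrics["prompt_tokens"] + metrics["completion_tokens"]:
    metrics["prompt_tokens"] += cache_tokens   # 712 = uncached + cached + cache-creation
    metrics["total_tokens"] += cache_tokens    # 752 = prompt + completion

assert metrics["total_tokens"] == metrics["prompt_tokens"] + metrics["completion_tokens"]

The folding only fires when the reported total already equals prompt plus
completion tokens, so providers that report cache-inclusive totals are left
untouched.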