diff --git a/py/src/braintrust/integrations/langchain/callbacks.py b/py/src/braintrust/integrations/langchain/callbacks.py
index a80a5625..fc40b635 100644
--- a/py/src/braintrust/integrations/langchain/callbacks.py
+++ b/py/src/braintrust/integrations/langchain/callbacks.py
@@ -656,6 +656,19 @@ def _get_metrics_from_response(response: LLMResult):
         if cache_creation is not None:
             metrics["prompt_cache_creation_tokens"] = cache_creation
 
+    cache_tokens = (cache_read or 0) + (cache_creation or 0)
+    prompt_tokens = metrics.get("prompt_tokens")
+    completion_tokens = metrics.get("completion_tokens")
+    total_tokens = metrics.get("total_tokens")
+    if (
+        cache_tokens
+        and prompt_tokens is not None
+        and completion_tokens is not None
+        and total_tokens == prompt_tokens + completion_tokens
+    ):
+        metrics["prompt_tokens"] = prompt_tokens + cache_tokens
+        metrics["total_tokens"] = total_tokens + cache_tokens
+
     if not metrics or not any(metrics.values()):
         llm_output: dict[str, Any] = response.llm_output or {}
         metrics = llm_output.get("token_usage") or llm_output.get("estimatedTokens") or {}
diff --git a/py/src/braintrust/integrations/langchain/test_callbacks.py b/py/src/braintrust/integrations/langchain/test_callbacks.py
index adeaa37d..11fecd1a 100644
--- a/py/src/braintrust/integrations/langchain/test_callbacks.py
+++ b/py/src/braintrust/integrations/langchain/test_callbacks.py
@@ -1087,6 +1087,8 @@ def test_prompt_caching_tokens(logger_memory_logger):
     assert "prompt_cache_creation_tokens" in first_metrics
     assert first_metrics["prompt_cache_creation_tokens"] > 0
     assert first_metrics["prompt_cached_tokens"] == 0
+    assert first_metrics["prompt_tokens"] >= first_metrics["prompt_cache_creation_tokens"]
+    assert first_metrics["total_tokens"] == first_metrics["prompt_tokens"] + first_metrics["completion_tokens"]
 
     second_metrics = None
     for attempt in range(3):
@@ -1116,6 +1118,8 @@
     assert second_metrics is not None
     assert second_metrics["prompt_cached_tokens"] > 0
+    assert second_metrics["prompt_tokens"] >= second_metrics["prompt_cached_tokens"]
+    assert second_metrics["total_tokens"] == second_metrics["prompt_tokens"] + second_metrics["completion_tokens"]
 
 
 @pytest.mark.vcr