diff --git a/crates/adaptive/tests/unit/acg/telemetry_tests.rs b/crates/adaptive/tests/unit/acg/telemetry_tests.rs index 7ce86135..53d56f34 100644 --- a/crates/adaptive/tests/unit/acg/telemetry_tests.rs +++ b/crates/adaptive/tests/unit/acg/telemetry_tests.rs @@ -168,6 +168,7 @@ fn test_anthropic_cache_telemetry_event_reconstructs_total_prompt_tokens() { total_tokens: None, cache_read_tokens: Some(500), cache_write_tokens: Some(200), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -196,6 +197,7 @@ fn test_anthropic_cache_telemetry_event_maps_write_only_zero_read_to_cold_start( total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(700), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -221,6 +223,7 @@ fn test_anthropic_cache_telemetry_event_returns_none_without_prompt_tokens() { total_tokens: None, cache_read_tokens: Some(500), cache_write_tokens: Some(200), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -243,6 +246,7 @@ fn test_openai_cache_telemetry_event_normalizes_creation_tokens_to_zero() { total_tokens: None, cache_read_tokens: Some(600), cache_write_tokens: Some(999), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -270,6 +274,7 @@ fn test_openai_cache_telemetry_event_maps_zero_read_to_unknown() { total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(999), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -303,6 +308,7 @@ fn telemetry_observability_keeps_request_facts_optional_for_anthropic_unknown_mi total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }; let event = CacheTelemetryEvent::from_usage( @@ -340,6 +346,7 @@ fn test_from_usage_uses_prefix_mismatch_diagnosis_when_request_facts_are_availab total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }; let request_facts = CacheRequestFacts { provider: "openai".to_string(), @@ -408,6 +415,7 @@ fn test_cache_miss_diagnosis_prefix_mismatch_is_bounded_and_serialized() { total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }, sample_timestamp(), Some(&request_facts), @@ -475,6 +483,7 @@ fn test_cache_miss_diagnosis_below_minimum_threshold_reports_exact_token_counts( total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }, sample_timestamp(), Some(&request_facts), @@ -529,6 +538,7 @@ fn test_cache_miss_diagnosis_retention_expired_reports_gap_and_window() { total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }, sample_timestamp(), Some(&request_facts), @@ -584,6 +594,7 @@ fn test_cache_miss_diagnosis_unknown_preserves_missing_facts() { total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }, sample_timestamp(), Some(&request_facts), @@ -636,6 +647,7 @@ fn test_no_write_anthropic_cache_miss_diagnosis_uses_threshold_facts_without_loc total_tokens: None, cache_read_tokens: Some(0), cache_write_tokens: Some(0), + cost: None, }, sample_timestamp(), Some(&request_facts), @@ -676,6 +688,7 @@ fn test_anthropic_multi_breakpoint_telemetry_event_uses_normalized_usage_totals( total_tokens: None, cache_read_tokens: Some(900), cache_write_tokens: Some(600), + cost: None, }, sample_timestamp(), Some(&request_facts), diff --git a/crates/adaptive/tests/unit/drain_tests.rs b/crates/adaptive/tests/unit/drain_tests.rs index 63d43494..25c8b9b0 100644 --- a/crates/adaptive/tests/unit/drain_tests.rs +++ b/crates/adaptive/tests/unit/drain_tests.rs @@ -870,6 +870,7 @@ fn test_accumulator_extracts_annotated_response() { total_tokens: Some(150), cache_read_tokens: None, cache_write_tokens: None, + cost: None, }), api_specific: None, extra: serde_json::Map::new(), diff --git a/crates/cli/src/config.rs b/crates/cli/src/config.rs index 2307065f..4125e469 100644 --- a/crates/cli/src/config.rs +++ b/crates/cli/src/config.rs @@ -79,6 +79,8 @@ pub(crate) enum Command { Config(ConfigCommand), /// Create or edit plugin configuration (writes `plugins.toml`) Plugins(PluginsCommand), + /// Validate and configure model pricing catalogs. + Pricing(PricingCommand), /// Diagnose env, agents, config, observability (optionally scoped to one agent) Doctor(DoctorCommand), /// List supported and locally-detected agents (use `--json` for machine output) @@ -162,6 +164,93 @@ pub(crate) enum PluginsSubcommand { Edit(PluginsEditCommand), } +/// Args for `nemo-relay pricing`. +#[derive(Debug, Clone, Args)] +pub(crate) struct PricingCommand { + #[command(subcommand)] + pub(crate) command: PricingSubcommand, +} + +/// Pricing catalog and resolver subcommands. +#[derive(Debug, Clone, Subcommand)] +pub(crate) enum PricingSubcommand { + /// Validate a pricing catalog JSON file. + Validate(PricingValidateCommand), + /// Initialize the pricing plugin component in `plugins.toml`. + Init(PricingInitCommand), + /// Add a pricing catalog file source to `plugins.toml`. + AddSource(PricingAddSourceCommand), + /// Resolve which pricing entry matches a model and optional usage. + Resolve(PricingResolveCommand), +} + +/// Common target-scope flags for pricing config mutations. +#[derive(Debug, Clone, Default, Args)] +#[command(group( + ArgGroup::new("pricing_scope") + .args(["user", "project", "global"]) + .multiple(false) +))] +pub(crate) struct PricingScopeArgs { + /// Edit the user config at `$XDG_CONFIG_HOME/nemo-relay/plugins.toml`. + #[arg(long)] + pub(crate) user: bool, + /// Edit the nearest project config at `.nemo-relay/plugins.toml`. + #[arg(long)] + pub(crate) project: bool, + /// Edit the system config at `/etc/nemo-relay/plugins.toml`. + #[arg(long)] + pub(crate) global: bool, +} + +/// Args for `nemo-relay pricing validate`. +#[derive(Debug, Clone, Args)] +pub(crate) struct PricingValidateCommand { + /// Path to a Relay pricing catalog JSON file. + pub(crate) path: PathBuf, +} + +/// Args for `nemo-relay pricing init`. +#[derive(Debug, Clone, Args)] +pub(crate) struct PricingInitCommand { + #[command(flatten)] + pub(crate) scope: PricingScopeArgs, +} + +/// Args for `nemo-relay pricing add-source`. +#[derive(Debug, Clone, Args)] +pub(crate) struct PricingAddSourceCommand { + #[command(flatten)] + pub(crate) scope: PricingScopeArgs, + /// Path to a Relay pricing catalog JSON file. + pub(crate) path: PathBuf, + /// Append as a lower-priority source instead of prepending as the highest-priority override. + #[arg(long)] + pub(crate) append: bool, +} + +/// Args for `nemo-relay pricing resolve`. +#[derive(Debug, Clone, Args)] +pub(crate) struct PricingResolveCommand { + /// Model ID or routed model name to look up. + pub(crate) model: String, + /// Optional provider or route, such as `openai`, `anthropic`, or `azure/openai`. + #[arg(long)] + pub(crate) provider: Option, + /// Prompt/input token count to use for an estimate. + #[arg(long)] + pub(crate) prompt_tokens: Option, + /// Completion/output token count to use for an estimate. + #[arg(long)] + pub(crate) completion_tokens: Option, + /// Prompt-cache read token count to use for an estimate. + #[arg(long)] + pub(crate) cache_read_tokens: Option, + /// Prompt-cache write token count to use for an estimate. + #[arg(long)] + pub(crate) cache_write_tokens: Option, +} + /// Args for `nemo-relay plugins edit`. #[derive(Debug, Clone, Default, Args)] #[command(group( @@ -898,13 +987,47 @@ fn merge_plugin_components(left: &mut toml::Value, right: toml::Value) { .iter_mut() .find(|candidate| component_kind(candidate) == Some(kind.as_str())) { - merge_toml(existing, component); + if kind == "pricing" { + merge_pricing_component(existing, component); + } else { + merge_toml(existing, component); + } } else { left_components.push(component); } } } +fn merge_pricing_component(existing: &mut toml::Value, higher_priority: toml::Value) { + let lower_priority_sources = pricing_component_sources(existing).cloned(); + let higher_priority_sources = pricing_component_sources(&higher_priority).cloned(); + merge_toml(existing, higher_priority); + + let Some(mut sources) = higher_priority_sources else { + return; + }; + if let Some(lower_priority_sources) = lower_priority_sources { + sources.extend(lower_priority_sources); + } + set_pricing_component_sources(existing, sources); +} + +fn pricing_component_sources(component: &toml::Value) -> Option<&Vec> { + component + .get("config") + .and_then(|config| config.get("sources")) + .and_then(toml::Value::as_array) +} + +fn set_pricing_component_sources(component: &mut toml::Value, sources: Vec) { + if let Some(config) = component + .get_mut("config") + .and_then(toml::Value::as_table_mut) + { + config.insert("sources".into(), toml::Value::Array(sources)); + } +} + fn component_kind(component: &toml::Value) -> Option<&str> { component .as_table() diff --git a/crates/cli/src/doctor.rs b/crates/cli/src/doctor.rs index 7b2898e2..51d6250b 100644 --- a/crates/cli/src/doctor.rs +++ b/crates/cli/src/doctor.rs @@ -15,6 +15,7 @@ use std::time::Duration; use futures_util::SinkExt; use nemo_relay::api::event::{BaseEvent, Event, MarkEvent}; +use nemo_relay::codec::pricing::{PricingCatalog, PricingConfig, PricingSourceConfig}; use nemo_relay::observability::plugin_component::OBSERVABILITY_PLUGIN_KIND; use nemo_relay::plugin::{DiagnosticLevel, PluginConfig, validate_plugin_config}; use nemo_relay_adaptive::plugin_component::register_adaptive_component; @@ -30,6 +31,7 @@ use crate::config::{ use crate::error::CliError; const NETWORK_TIMEOUT: Duration = Duration::from_secs(2); +const PRICING_PLUGIN_KIND: &str = "pricing"; /// Outcome of one check inside the doctor report. The `details` field carries human-readable /// supplementary text; the `status` is the bottom-line signal callers (and CI) use to decide @@ -635,6 +637,7 @@ async fn collect_observability(gateway: &GatewayConfig) -> Vec { details: "component not configured".into(), }); } + collect_pricing_component_checks(&mut checks, &plugin_config); checks } @@ -716,6 +719,89 @@ fn observability_component_config(plugin_value: &Value) -> Option<&Value> { .and_then(|component| component.get("config")) } +fn collect_pricing_component_checks(checks: &mut Vec, plugin_config: &PluginConfig) { + let Some(component) = plugin_config + .components + .iter() + .find(|component| component.kind == PRICING_PLUGIN_KIND) + else { + checks.push(Check { + name: "Pricing", + status: Status::Info, + details: "component not configured".into(), + }); + return; + }; + + if !component.enabled { + checks.push(Check { + name: "Pricing", + status: Status::Info, + details: "component disabled".into(), + }); + return; + } + + let config = + match serde_json::from_value::(Value::Object(component.config.clone())) { + Ok(config) => config, + Err(error) => { + checks.push(Check { + name: "Pricing", + status: Status::Fail, + details: format!("invalid config: {error}"), + }); + return; + } + }; + + if config.sources.is_empty() { + checks.push(Check { + name: "Pricing", + status: Status::Info, + details: "component configured with no sources".into(), + }); + return; + } + + for (index, source) in config.sources.iter().enumerate() { + checks.push(pricing_source_check(index, source)); + } +} + +fn pricing_source_check(index: usize, source: &PricingSourceConfig) -> Check { + match source { + PricingSourceConfig::Inline { catalog } => Check { + name: "Pricing source", + status: Status::Pass, + details: format!("inline:{index} valid ({} entries)", catalog.entries.len()), + }, + PricingSourceConfig::File { path } => match std::fs::read_to_string(path) { + Ok(raw) => match PricingCatalog::from_json_str(&raw) { + Ok(catalog) => Check { + name: "Pricing source", + status: Status::Pass, + details: format!( + "file:{} valid ({} entries)", + path.display(), + catalog.entries.len() + ), + }, + Err(error) => Check { + name: "Pricing source", + status: Status::Fail, + details: format!("file:{} invalid catalog: {error}", path.display()), + }, + }, + Err(error) => Check { + name: "Pricing source", + status: Status::Fail, + details: format!("file:{} unreadable: {error}", path.display()), + }, + }, + } +} + fn section_enabled(config: &Value, section: &str) -> bool { config .get(section) diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 91a2eaa3..8bd4e2b2 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -15,6 +15,7 @@ mod installer; mod launcher; mod model; mod plugins; +mod pricing; mod server; mod session; mod setup; @@ -24,7 +25,7 @@ use std::process::ExitCode; use clap::Parser; -use crate::config::{Cli, CodingAgent, Command, PluginsSubcommand}; +use crate::config::{Cli, CodingAgent, Command, PluginsSubcommand, PricingSubcommand}; #[tokio::main] // Runs the async CLI entrypoint and converts any surfaced gateway error into a non-zero process @@ -81,6 +82,15 @@ async fn run() -> Result { } Ok(ExitCode::SUCCESS) } + Some(Command::Pricing(command)) => { + match command.command { + PricingSubcommand::Validate(command) => pricing::validate(command)?, + PricingSubcommand::Init(command) => pricing::init(command)?, + PricingSubcommand::AddSource(command) => pricing::add_source(command)?, + PricingSubcommand::Resolve(command) => pricing::resolve(command)?, + } + Ok(ExitCode::SUCCESS) + } Some(Command::Doctor(command)) => doctor::run_doctor(command.agent, command.json).await, Some(Command::Agents(command)) => doctor::run_agents(command.json).await, Some(Command::Completions(command)) => { diff --git a/crates/cli/src/plugins.rs b/crates/cli/src/plugins.rs index b2120fd3..f6f44800 100644 --- a/crates/cli/src/plugins.rs +++ b/crates/cli/src/plugins.rs @@ -20,7 +20,7 @@ use serde_json::{Value, json}; use crate::config::PluginsEditCommand; use crate::error::CliError; -mod config_io; +pub(crate) mod config_io; mod editor_model; use self::config_io::*; diff --git a/crates/cli/src/plugins/config_io.rs b/crates/cli/src/plugins/config_io.rs index c256dc8a..98b35b26 100644 --- a/crates/cli/src/plugins/config_io.rs +++ b/crates/cli/src/plugins/config_io.rs @@ -17,13 +17,13 @@ use crate::config::{ use crate::error::CliError; #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum TargetScope { +pub(crate) enum TargetScope { User, Project, Global, } -pub(super) fn target_scope(command: &PluginsEditCommand) -> Result { +pub(crate) fn target_scope(command: &PluginsEditCommand) -> Result { let selected = [command.user, command.project, command.global] .into_iter() .filter(|selected| *selected) @@ -42,7 +42,7 @@ pub(super) fn target_scope(command: &PluginsEditCommand) -> Result Result { +pub(crate) fn target_path(scope: TargetScope) -> Result { match scope { TargetScope::User => user_plugin_config_path().ok_or_else(|| { CliError::Config( @@ -57,7 +57,7 @@ pub(super) fn target_path(scope: TargetScope) -> Result { } } -pub(super) fn read_plugin_config(path: &Path) -> Result { +pub(crate) fn read_plugin_config(path: &Path) -> Result { if !path.exists() { return Ok(PluginConfig::default()); } @@ -78,7 +78,7 @@ pub(super) fn read_plugin_config(path: &Path) -> Result .map_err(|error| CliError::Config(format!("invalid plugin config: {error}"))) } -pub(super) fn write_plugin_config(path: &Path, config: &PluginConfig) -> Result<(), CliError> { +pub(crate) fn write_plugin_config(path: &Path, config: &PluginConfig) -> Result<(), CliError> { let mut value = serde_json::to_value(config) .map_err(|error| CliError::Config(format!("could not serialize plugin config: {error}")))?; prune_plugin_defaults(&mut value); @@ -115,7 +115,7 @@ pub(super) fn print_preview(config: &PluginConfig) -> Result<(), CliError> { Ok(()) } -pub(super) fn validate_config(config: &PluginConfig) -> Result<(), CliError> { +pub(crate) fn validate_config(config: &PluginConfig) -> Result<(), CliError> { register_adaptive_component().map_err(|error| { CliError::Config(format!("adaptive plugin registration failed: {error}")) })?; diff --git a/crates/cli/src/pricing.rs b/crates/cli/src/pricing.rs new file mode 100644 index 00000000..24d02e60 --- /dev/null +++ b/crates/cli/src/pricing.rs @@ -0,0 +1,278 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Pricing catalog CLI helpers. + +use std::path::Path; + +use nemo_relay::codec::pricing::{ + ModelPricing, PricingCatalog, PricingConfig, PricingSourceConfig, +}; +use nemo_relay::codec::response::Usage; +use nemo_relay::plugin::{PluginComponentSpec, PluginConfig}; +use serde_json::Value; + +use crate::config::{ + PricingAddSourceCommand, PricingInitCommand, PricingResolveCommand, PricingScopeArgs, + PricingValidateCommand, ServerArgs, resolve_server_config, +}; +use crate::error::CliError; +use crate::plugins::config_io::{ + TargetScope, read_plugin_config, target_path, validate_config, write_plugin_config, +}; + +const PRICING_PLUGIN_KIND: &str = "pricing"; + +pub(crate) fn validate(command: PricingValidateCommand) -> Result<(), CliError> { + let catalog = read_pricing_catalog(&command.path)?; + let entries = catalog.entries.len(); + println!( + "Valid pricing catalog: {} ({entries} {})", + command.path.display(), + plural(entries, "entry", "entries") + ); + Ok(()) +} + +pub(crate) fn init(command: PricingInitCommand) -> Result<(), CliError> { + let scope = target_pricing_scope(&command.scope)?; + let path = target_path(scope)?; + let mut plugin_config = read_plugin_config(&path)?; + let index = ensure_pricing_component(&mut plugin_config)?; + let pricing_config = pricing_config_from_component(&plugin_config.components[index])?; + store_pricing_config(&mut plugin_config.components[index], &pricing_config)?; + plugin_config.components[index].enabled = true; + validate_config(&plugin_config)?; + write_plugin_config(&path, &plugin_config)?; + println!("Initialized pricing config: {}", path.display()); + Ok(()) +} + +pub(crate) fn add_source(command: PricingAddSourceCommand) -> Result<(), CliError> { + let source_path = std::fs::canonicalize(&command.path).map_err(|source| { + CliError::Config(format!( + "could not canonicalize pricing catalog '{}': {source}", + command.path.display() + )) + })?; + read_pricing_catalog(&source_path)?; + let scope = target_pricing_scope(&command.scope)?; + let path = target_path(scope)?; + let mut plugin_config = read_plugin_config(&path)?; + let index = ensure_pricing_component(&mut plugin_config)?; + let mut pricing_config = pricing_config_from_component(&plugin_config.components[index])?; + let source = PricingSourceConfig::File { path: source_path }; + + if !pricing_config.sources.contains(&source) { + if command.append { + pricing_config.sources.push(source); + } else { + pricing_config.sources.insert(0, source); + } + } + + store_pricing_config(&mut plugin_config.components[index], &pricing_config)?; + plugin_config.components[index].enabled = true; + validate_config(&plugin_config)?; + write_plugin_config(&path, &plugin_config)?; + println!( + "Added pricing source: {} -> {}", + command.path.display(), + path.display() + ); + Ok(()) +} + +pub(crate) fn resolve(command: PricingResolveCommand) -> Result<(), CliError> { + let sources = pricing_catalog_sources_from_current_config()?; + if sources.is_empty() { + return Err(CliError::Config( + "no pricing sources configured; run `nemo-relay pricing add-source ` or enable the pricing component".into(), + )); + } + let resolved = resolve_pricing(&sources, command.provider.as_deref(), &command.model) + .ok_or_else(|| { + CliError::Config(format!( + "no pricing entry matched provider={} model={}", + command.provider.as_deref().unwrap_or(""), + command.model + )) + })?; + let pricing = resolved.pricing; + + println!("Resolved pricing"); + println!("source = {}", resolved.source); + println!("provider = {}", pricing.provider); + println!("model = {}", pricing.model_id); + println!("pricing_as_of = {}", pricing.pricing_as_of); + println!("pricing_source = {}", pricing.pricing_source); + + let usage = Usage { + prompt_tokens: command.prompt_tokens, + completion_tokens: command.completion_tokens, + total_tokens: None, + cache_read_tokens: command.cache_read_tokens, + cache_write_tokens: command.cache_write_tokens, + cost: None, + }; + if usage_has_tokens(&usage) { + if let Some(cost) = pricing.estimate_cost(&usage) { + if let Some(total) = cost.total { + println!("estimated_total = {total}"); + println!("currency = {}", cost.currency); + } else { + println!("estimated_total = unavailable"); + } + } else { + println!("estimated_total = unavailable"); + } + } + Ok(()) +} + +fn read_pricing_catalog(path: &Path) -> Result { + let raw = std::fs::read_to_string(path).map_err(|source| { + CliError::Config(format!( + "could not read pricing catalog '{}': {source}", + path.display() + )) + })?; + PricingCatalog::from_json_str(&raw).map_err(|error| { + CliError::Config(format!( + "invalid pricing catalog '{}': {error}", + path.display() + )) + }) +} + +#[derive(Debug, Clone)] +struct PricingCatalogSource { + label: String, + catalog: PricingCatalog, +} + +#[derive(Debug, Clone)] +struct ResolvedPricing { + source: String, + pricing: ModelPricing, +} + +fn pricing_catalog_sources_from_current_config() -> Result, CliError> { + let resolved = resolve_server_config(&ServerArgs::default())?; + let Some(plugin_config) = resolved.gateway.plugin_config else { + return Ok(vec![]); + }; + let config: PluginConfig = serde_json::from_value(plugin_config) + .map_err(|error| CliError::Config(format!("invalid plugin config: {error}")))?; + let Some(component) = config + .components + .iter() + .find(|component| component.kind == PRICING_PLUGIN_KIND && component.enabled) + else { + return Ok(vec![]); + }; + let pricing_config = pricing_config_from_component(component)?; + pricing_catalog_sources_from_config(&pricing_config) +} + +fn pricing_catalog_sources_from_config( + config: &PricingConfig, +) -> Result, CliError> { + let mut sources = Vec::new(); + for (index, source) in config.sources.iter().enumerate() { + match source { + PricingSourceConfig::Inline { catalog } => sources.push(PricingCatalogSource { + label: format!("inline:{index}"), + catalog: catalog.clone(), + }), + PricingSourceConfig::File { path } => sources.push(PricingCatalogSource { + label: format!("file:{}", path.display()), + catalog: read_pricing_catalog(path)?, + }), + } + } + Ok(sources) +} + +fn resolve_pricing( + sources: &[PricingCatalogSource], + provider: Option<&str>, + model: &str, +) -> Option { + sources.iter().find_map(|source| { + source + .catalog + .pricing_for(provider, model) + .map(|pricing| ResolvedPricing { + source: source.label.clone(), + pricing, + }) + }) +} + +fn target_pricing_scope(scope: &PricingScopeArgs) -> Result { + let selected = [scope.user, scope.project, scope.global] + .into_iter() + .filter(|selected| *selected) + .count(); + if selected > 1 { + return Err(CliError::Config( + "choose only one of --user, --project, or --global".into(), + )); + } + if scope.project { + Ok(TargetScope::Project) + } else if scope.global { + Ok(TargetScope::Global) + } else { + Ok(TargetScope::User) + } +} + +fn ensure_pricing_component(config: &mut PluginConfig) -> Result { + if let Some(index) = config + .components + .iter() + .position(|component| component.kind == PRICING_PLUGIN_KIND) + { + return Ok(index); + } + let mut component = PluginComponentSpec::new(PRICING_PLUGIN_KIND); + store_pricing_config(&mut component, &PricingConfig::default())?; + config.components.push(component); + Ok(config.components.len() - 1) +} + +fn pricing_config_from_component( + component: &PluginComponentSpec, +) -> Result { + serde_json::from_value(Value::Object(component.config.clone())) + .map_err(|error| CliError::Config(format!("invalid pricing config: {error}"))) +} + +fn store_pricing_config( + component: &mut PluginComponentSpec, + config: &PricingConfig, +) -> Result<(), CliError> { + let value = serde_json::to_value(config).map_err(|error| { + CliError::Config(format!("could not serialize pricing config: {error}")) + })?; + let Value::Object(object) = value else { + return Err(CliError::Config( + "could not serialize pricing config as an object".into(), + )); + }; + component.config = object; + Ok(()) +} + +fn usage_has_tokens(usage: &Usage) -> bool { + usage.prompt_tokens.is_some() + || usage.completion_tokens.is_some() + || usage.cache_read_tokens.is_some() + || usage.cache_write_tokens.is_some() +} + +fn plural<'a>(count: usize, singular: &'a str, plural: &'a str) -> &'a str { + if count == 1 { singular } else { plural } +} diff --git a/crates/cli/tests/cli_tests.rs b/crates/cli/tests/cli_tests.rs index 3a67b9d4..88ad0e8e 100644 --- a/crates/cli/tests/cli_tests.rs +++ b/crates/cli/tests/cli_tests.rs @@ -13,6 +13,34 @@ fn gateway_bin() -> &'static str { env!("CARGO_BIN_EXE_nemo-relay") } +fn toml_basic_string(value: &str) -> String { + let escaped = value + .chars() + .map(|character| match character { + '\\' => "\\\\".to_string(), + '"' => "\\\"".to_string(), + '\n' => "\\n".to_string(), + '\t' => "\\t".to_string(), + '\r' => "\\r".to_string(), + '\u{08}' => "\\b".to_string(), + '\u{0c}' => "\\f".to_string(), + '\u{00}'..='\u{1f}' | '\u{7f}' => { + format!("\\u{:04X}", character as u32) + } + character => character.to_string(), + }) + .collect::(); + format!("\"{escaped}\"") +} + +#[test] +fn toml_basic_string_escapes_toml_control_characters() { + assert_eq!( + toml_basic_string("a\\b\"c\nd\te\rf\u{08}g\u{0c}h\u{01}\u{7f}"), + "\"a\\\\b\\\"c\\nd\\te\\rf\\bg\\fh\\u0001\\u007F\"" + ); +} + #[test] fn cli_help_exits_successfully() { let output = Command::new(gateway_bin()).arg("--help").output().unwrap(); @@ -52,7 +80,10 @@ fn cli_agents_json_emits_supported_agent_shapes() { #[test] fn cli_doctor_json_emits_versioned_report() { let temp = tempfile::tempdir().unwrap(); + let cwd = temp.path().join("workdir"); + std::fs::create_dir_all(&cwd).unwrap(); let output = Command::new(gateway_bin()) + .current_dir(&cwd) .env("XDG_CONFIG_HOME", temp.path().join("xdg")) .env("HOME", temp.path()) .args(["doctor", "--json"]) @@ -97,6 +128,187 @@ fn cli_plugins_edit_requires_tty() { ); } +#[test] +fn cli_pricing_validate_accepts_valid_catalog() { + let temp = tempfile::tempdir().unwrap(); + let catalog = temp.path().join("pricing.json"); + std::fs::write(&catalog, pricing_catalog_json("test-model")).unwrap(); + + let output = Command::new(gateway_bin()) + .args(["pricing", "validate"]) + .arg(&catalog) + .output() + .unwrap(); + + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Valid pricing catalog")); + assert!(stdout.contains("1 entry")); +} + +#[test] +fn cli_pricing_validate_rejects_invalid_catalog() { + let temp = tempfile::tempdir().unwrap(); + let catalog = temp.path().join("pricing.json"); + std::fs::write( + &catalog, + r#"{ + "version": 1, + "entries": [{ + "provider": "test", + "model_id": "bad-model", + "prompt_cache": { "read_accounting": "included_in_prompt_tokens" }, + "pricing_as_of": "2026-06-05", + "pricing_source": "test" + }] +}"#, + ) + .unwrap(); + + let output = Command::new(gateway_bin()) + .args(["pricing", "validate"]) + .arg(&catalog) + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("invalid pricing catalog")); + assert!(stderr.contains("rates or rate_schedule")); +} + +#[test] +fn cli_pricing_init_creates_project_pricing_component() { + let temp = tempfile::tempdir().unwrap(); + let project = temp.path().join("project"); + std::fs::create_dir_all(&project).unwrap(); + + let output = Command::new(gateway_bin()) + .current_dir(&project) + .args(["pricing", "init", "--project"]) + .output() + .unwrap(); + + assert!(output.status.success()); + let path = project.join(".nemo-relay/plugins.toml"); + let rendered = std::fs::read_to_string(path).unwrap(); + assert!(rendered.contains("kind = \"pricing\"")); + assert!(!rendered.contains("include_bundled")); +} + +#[test] +fn cli_pricing_add_source_validates_and_updates_user_plugin_config() { + let temp = tempfile::tempdir().unwrap(); + let catalog = temp.path().join("pricing.json"); + std::fs::write(&catalog, pricing_catalog_json("custom-model")).unwrap(); + let cwd = temp.path().join("workdir"); + std::fs::create_dir_all(&cwd).unwrap(); + std::fs::copy(&catalog, cwd.join("pricing.json")).unwrap(); + let canonical = std::fs::canonicalize(cwd.join("pricing.json")).unwrap(); + + let output = Command::new(gateway_bin()) + .current_dir(&cwd) + .env("XDG_CONFIG_HOME", temp.path().join("xdg")) + .env("HOME", temp.path()) + .args(["pricing", "add-source"]) + .arg("pricing.json") + .output() + .unwrap(); + + assert!(output.status.success()); + let rendered = std::fs::read_to_string( + temp.path() + .join("xdg") + .join("nemo-relay") + .join("plugins.toml"), + ) + .unwrap(); + assert!(rendered.contains("kind = \"pricing\"")); + assert!(rendered.contains("type = \"file\"")); + assert!(rendered.contains(canonical.to_str().unwrap())); +} + +#[test] +fn cli_pricing_resolve_reports_source_match_and_estimate() { + let temp = tempfile::tempdir().unwrap(); + let catalog = temp.path().join("pricing.json"); + let xdg = temp.path().join("xdg/nemo-relay"); + let project = temp.path().join("project"); + std::fs::create_dir_all(&xdg).unwrap(); + std::fs::create_dir_all(&project).unwrap(); + std::fs::write(&catalog, pricing_catalog_json("custom-model")).unwrap(); + std::fs::write( + xdg.join("plugins.toml"), + format!( + r#" +[[components]] +kind = "pricing" + +[components.config] +[[components.config.sources]] +type = "file" +path = {} +"#, + toml_basic_string(&catalog.display().to_string()) + ), + ) + .unwrap(); + + let output = Command::new(gateway_bin()) + .current_dir(&project) + .env("XDG_CONFIG_HOME", temp.path().join("xdg")) + .env("HOME", temp.path()) + .args([ + "pricing", + "resolve", + "custom-model", + "--provider", + "test", + "--prompt-tokens", + "1000", + "--completion-tokens", + "500", + ]) + .output() + .unwrap(); + + assert!( + output.status.success(), + "stderr was:\n{}\nstdout was:\n{}", + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Resolved pricing")); + assert!(stdout.contains(&format!("source = file:{}", catalog.display()))); + assert!(stdout.contains("provider = test")); + assert!(stdout.contains("model = custom-model")); + assert!(stdout.contains("estimated_total")); + assert!(stdout.contains("currency = USD")); +} + +#[test] +fn cli_pricing_resolve_reports_missing_sources_distinctly() { + let temp = tempfile::tempdir().unwrap(); + let cwd = temp.path().join("workdir"); + std::fs::create_dir_all(&cwd).unwrap(); + + let output = Command::new(gateway_bin()) + .current_dir(&cwd) + .env("XDG_CONFIG_HOME", temp.path().join("xdg")) + .env("HOME", temp.path()) + .args(["pricing", "resolve", "custom-model"]) + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("no pricing sources configured"), + "expected missing pricing source error, got:\n{stderr}" + ); +} + #[test] fn cli_help_lists_easy_path_agent_shortcuts() { let output = Command::new(gateway_bin()).arg("--help").output().unwrap(); @@ -560,3 +772,23 @@ fn read_http_request(stream: &mut std::net::TcpStream) -> String { fn find_header_end(buffer: &[u8]) -> Option { buffer.windows(4).position(|window| window == b"\r\n\r\n") } + +fn pricing_catalog_json(model_id: &str) -> String { + format!( + r#"{{ + "version": 1, + "entries": [{{ + "provider": "test", + "model_id": "{model_id}", + "rates": {{ + "input_per_million": 1.0, + "output_per_million": 2.0, + "cache_read_per_million": 0.1 + }}, + "prompt_cache": {{ "read_accounting": "included_in_prompt_tokens" }}, + "pricing_as_of": "2026-06-05", + "pricing_source": "test" + }}] +}}"# + ) +} diff --git a/crates/cli/tests/coverage/config_tests.rs b/crates/cli/tests/coverage/config_tests.rs index d7426da4..e6ea8dda 100644 --- a/crates/cli/tests/coverage/config_tests.rs +++ b/crates/cli/tests/coverage/config_tests.rs @@ -403,6 +403,70 @@ source = "user" ); } +#[test] +fn discovered_pricing_plugin_sources_layer_user_before_lower_priority_sources() { + let temp = tempfile::tempdir().unwrap(); + let system_plugin = temp.path().join("system-plugins.toml"); + let user_plugin = temp.path().join("user-plugins.toml"); + std::fs::write( + &system_plugin, + r#" +version = 1 + +[[components]] +kind = "pricing" +enabled = true + +[[components.config.sources]] +type = "file" +path = "/etc/nemo-relay/pricing.json" +"#, + ) + .unwrap(); + std::fs::write( + &user_plugin, + r#" +version = 1 + +[[components]] +kind = "pricing" +enabled = true + +[[components.config.sources]] +type = "file" +path = "/home/user/.config/nemo-relay/pricing.json" +"#, + ) + .unwrap(); + + let resolved = load_plugin_toml_config_from_paths(vec![system_plugin, user_plugin]).unwrap(); + + assert_eq!( + resolved.map(|config| config.value), + Some(json!({ + "version": 1, + "components": [ + { + "kind": "pricing", + "enabled": true, + "config": { + "sources": [ + { + "type": "file", + "path": "/home/user/.config/nemo-relay/pricing.json" + }, + { + "type": "file", + "path": "/etc/nemo-relay/pricing.json" + } + ] + } + } + ] + })) + ); +} + #[test] fn discovered_plugins_toml_can_disable_lower_priority_observability_section() { let temp = tempfile::tempdir().unwrap(); diff --git a/crates/cli/tests/coverage/doctor_tests.rs b/crates/cli/tests/coverage/doctor_tests.rs index f2ef6891..8f5d1453 100644 --- a/crates/cli/tests/coverage/doctor_tests.rs +++ b/crates/cli/tests/coverage/doctor_tests.rs @@ -777,6 +777,91 @@ async fn collect_observability_rejects_websocket_endpoint_http_scheme() { assert!(endpoint.details.contains("must be ws or wss")); } +#[tokio::test] +async fn collect_observability_validates_pricing_file_source() { + let temp = tempfile::tempdir().unwrap(); + let catalog = temp.path().join("pricing.json"); + std::fs::write( + &catalog, + serde_json::json!({ + "version": 1, + "entries": [{ + "provider": "openai", + "model_id": "gpt-test", + "currency": "USD", + "unit": "per_token", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0 + }, + "prompt_cache": { + "read_accounting": "separate" + }, + "pricing_as_of": "2026-06-06", + "pricing_source": "test" + }] + }) + .to_string(), + ) + .unwrap(); + let gateway = GatewayConfig { + plugin_config: Some(serde_json::json!({ + "version": 1, + "components": [{ + "kind": "pricing", + "config": { + "sources": [{ + "type": "file", + "path": catalog + }] + } + }] + })), + ..GatewayConfig::default() + }; + + let checks = collect_observability(&gateway).await; + + let pricing = checks + .iter() + .find(|check| check.name == "Pricing source") + .expect("pricing source check"); + assert_eq!(pricing.status, Status::Pass); + assert!(pricing.details.contains("valid (1 entries)")); +} + +#[tokio::test] +async fn collect_observability_fails_for_missing_pricing_file_source() { + let missing = tempfile::tempdir() + .unwrap() + .path() + .join("missing-pricing.json"); + let gateway = GatewayConfig { + plugin_config: Some(serde_json::json!({ + "version": 1, + "components": [{ + "kind": "pricing", + "config": { + "sources": [{ + "type": "file", + "path": missing + }] + } + }] + })), + ..GatewayConfig::default() + }; + + let checks = collect_observability(&gateway).await; + + let pricing = checks + .iter() + .find(|check| check.name == "Pricing source") + .expect("pricing source check"); + assert_eq!(pricing.status, Status::Fail); + assert!(pricing.details.contains("unreadable")); +} + #[test] fn format_agents_human_lists_supported_and_separates_detected() { let agents = vec![ diff --git a/crates/cli/tests/coverage/session_tests.rs b/crates/cli/tests/coverage/session_tests.rs index 2f8480fc..14891adf 100644 --- a/crates/cli/tests/coverage/session_tests.rs +++ b/crates/cli/tests/coverage/session_tests.rs @@ -1943,9 +1943,85 @@ async fn writes_hermes_api_hook_usage_to_atif_metrics() { assert_eq!(atif["steps"][1]["metrics"]["prompt_tokens"], json!(10)); assert_eq!(atif["steps"][1]["metrics"]["completion_tokens"], json!(5)); assert_eq!(atif["steps"][1]["metrics"]["cached_tokens"], json!(3)); + assert!(atif["steps"][1]["metrics"].get("cost_usd").is_none()); assert_eq!(atif["final_metrics"]["total_prompt_tokens"], json!(10)); assert_eq!(atif["final_metrics"]["total_completion_tokens"], json!(5)); assert_eq!(atif["final_metrics"]["total_cached_tokens"], json!(3)); + assert!(atif["final_metrics"].get("total_cost_usd").is_none()); +} + +#[tokio::test] +async fn writes_hermes_api_hook_reported_cost_to_atif_metrics() { + let _guard = OBSERVABILITY_PLUGIN_TEST_LOCK.lock().await; + let temp = tempfile::tempdir().unwrap(); + let atif_dir = temp.path().join("atif"); + install_test_atif_plugin(&atif_dir).await; + let config = GatewayConfig { + bind: "127.0.0.1:0".parse().unwrap(), + openai_base_url: "http://127.0.0.1".into(), + + anthropic_base_url: "http://127.0.0.1".into(), + metadata: None, + plugin_config: None, + }; + let manager = SessionManager::new(config); + let headers = HeaderMap::new(); + + manager + .apply_events( + &headers, + vec![ + NormalizedEvent::AgentStarted(SessionEvent { + session_id: "hermes-cost".into(), + agent_kind: AgentKind::Hermes, + event_name: "on_session_start".into(), + payload: json!({}), + metadata: json!({}), + }), + NormalizedEvent::LlmStarted(LlmEvent { + session_id: "hermes-cost".into(), + agent_kind: AgentKind::Hermes, + event_name: "pre_api_request".into(), + api_call_id: "hermes-cost:task-1:1".into(), + provider: "custom".into(), + model_name: Some("qwen".into()), + request: json!({ "model": "qwen" }), + response: Value::Null, + metadata: json!({}), + }), + NormalizedEvent::LlmEnded(LlmEvent { + session_id: "hermes-cost".into(), + agent_kind: AgentKind::Hermes, + event_name: "post_api_request".into(), + api_call_id: "hermes-cost:task-1:1".into(), + provider: "custom".into(), + model_name: Some("qwen".into()), + request: json!({}), + response: json!({ + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "cost_usd": 0.123 + } + }), + metadata: json!({}), + }), + NormalizedEvent::AgentEnded(SessionEvent { + session_id: "hermes-cost".into(), + agent_kind: AgentKind::Hermes, + event_name: "on_session_finalize".into(), + payload: json!({}), + metadata: json!({}), + }), + ], + ) + .await + .unwrap(); + + clear_plugin_configuration().unwrap(); + let atif = read_atif_for_session(&atif_dir, "hermes-cost"); + assert_eq!(atif["steps"][1]["metrics"]["cost_usd"], json!(0.123)); + assert_eq!(atif["final_metrics"]["total_cost_usd"], json!(0.123)); } #[tokio::test] diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 91c149e6..cb585c72 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -23,7 +23,7 @@ use crate::api::shared::{ snapshot_event_subscribers, }; use crate::codec::request::AnnotatedLlmRequest; -use crate::codec::response::AnnotatedLlmResponse; +use crate::codec::response::{AnnotatedLlmResponse, attach_estimated_cost_for_provider}; use crate::codec::traits::{LlmCodec, LlmResponseCodec}; use crate::error::{FlowError, Result}; use crate::json::Json; @@ -585,7 +585,11 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { Ok(response) => { let annotated_response = response_codec .as_ref() - .and_then(|codec| codec.decode_response(&response).ok()) + .and_then(|codec| { + let mut decoded = codec.decode_response(&response).ok()?; + attach_estimated_cost_for_provider(&mut decoded, Some(&name)); + Some(decoded) + }) .map(Arc::new); llm_call_end( LlmCallEndParams::builder() diff --git a/crates/core/src/codec/anthropic.rs b/crates/core/src/codec/anthropic.rs index 1fbf953e..837e6333 100644 --- a/crates/core/src/codec/anthropic.rs +++ b/crates/core/src/codec/anthropic.rs @@ -27,7 +27,8 @@ use super::request::{ ToolChoiceFunction, ToolChoiceFunctionName, ToolDefinition, }; use super::response::{ - AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage, + AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, RawUsageCost, ResponseToolCall, Usage, + estimate_cost_for_provider, infer_model_provider, provider_reported_cost, }; use super::traits::{LlmCodec, LlmResponseCodec}; @@ -65,6 +66,9 @@ struct RawAnthropicUsage { output_tokens: Option, cache_read_input_tokens: Option, cache_creation_input_tokens: Option, + #[serde(rename = "cost_usd")] + provider_cost: Option, + cost: Option, } // --------------------------------------------------------------------------- @@ -349,10 +353,12 @@ impl LlmResponseCodec for AnthropicMessagesCodec { let finish_reason = raw.stop_reason.as_deref().map(map_anthropic_stop_reason); // Map usage. + let model_for_pricing = raw.model.as_deref(); + let model_provider = infer_model_provider("anthropic", model_for_pricing); let usage = raw.usage.map(|u| { let prompt = u.input_tokens; let completion = u.output_tokens; - Usage { + let mut usage = Usage { prompt_tokens: prompt, completion_tokens: completion, // Anthropic does not supply total_tokens; compute it. @@ -362,7 +368,14 @@ impl LlmResponseCodec for AnthropicMessagesCodec { }, cache_read_tokens: u.cache_read_input_tokens, cache_write_tokens: u.cache_creation_input_tokens, + cost: provider_reported_cost(u.provider_cost, u.cost), + }; + if usage.cost.is_none() { + usage.cost = model_for_pricing.and_then(|model| { + estimate_cost_for_provider(model_provider.as_deref(), model, &usage) + }); } + usage }); // Build API-specific fields: all content blocks + stop_sequence. diff --git a/crates/core/src/codec/mod.rs b/crates/core/src/codec/mod.rs index 82e46198..20653c14 100644 --- a/crates/core/src/codec/mod.rs +++ b/crates/core/src/codec/mod.rs @@ -14,6 +14,7 @@ pub mod anthropic; pub mod openai_chat; pub mod openai_responses; +pub mod pricing; pub mod request; pub mod response; pub mod streaming; diff --git a/crates/core/src/codec/openai_chat.rs b/crates/core/src/codec/openai_chat.rs index e04b688d..5bf992d5 100644 --- a/crates/core/src/codec/openai_chat.rs +++ b/crates/core/src/codec/openai_chat.rs @@ -14,7 +14,8 @@ use crate::json::Json; use super::request::{AnnotatedLlmRequest, GenerationParams, Message, ToolChoice, ToolDefinition}; use super::response::{ - AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage, + AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, RawUsageCost, ResponseToolCall, Usage, + estimate_cost_for_provider, infer_model_provider, provider_reported_cost, }; use super::traits::{LlmCodec, LlmResponseCodec}; @@ -72,6 +73,9 @@ struct RawChatUsage { completion_tokens: Option, total_tokens: Option, prompt_tokens_details: Option, + #[serde(rename = "cost_usd")] + provider_cost: Option, + cost: Option, } #[derive(Deserialize)] @@ -169,12 +173,23 @@ impl LlmResponseCodec for OpenAIChatCodec { .map(map_chat_finish_reason); // Map usage. - let usage = raw.usage.map(|u| Usage { - prompt_tokens: u.prompt_tokens, - completion_tokens: u.completion_tokens, - total_tokens: u.total_tokens, - cache_read_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens), - cache_write_tokens: None, + let model_for_pricing = raw.model.as_deref(); + let model_provider = infer_model_provider("openai", model_for_pricing); + let usage = raw.usage.map(|u| { + let mut usage = Usage { + prompt_tokens: u.prompt_tokens, + completion_tokens: u.completion_tokens, + total_tokens: u.total_tokens, + cache_read_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens), + cache_write_tokens: None, + cost: provider_reported_cost(u.provider_cost, u.cost), + }; + if usage.cost.is_none() { + usage.cost = model_for_pricing.and_then(|model| { + estimate_cost_for_provider(model_provider.as_deref(), model, &usage) + }); + } + usage }); // Build API-specific fields. diff --git a/crates/core/src/codec/openai_responses.rs b/crates/core/src/codec/openai_responses.rs index 8d57951a..f2661eef 100644 --- a/crates/core/src/codec/openai_responses.rs +++ b/crates/core/src/codec/openai_responses.rs @@ -26,7 +26,8 @@ use super::request::{ ToolChoiceFunctionName, ToolDefinition, }; use super::response::{ - AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage, + AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, RawUsageCost, ResponseToolCall, Usage, + estimate_cost_for_provider, infer_model_provider, provider_reported_cost, }; use super::traits::{LlmCodec, LlmResponseCodec}; @@ -65,6 +66,9 @@ struct RawResponsesUsage { total_tokens: Option, input_tokens_details: Option, output_tokens_details: Option, + #[serde(rename = "cost_usd")] + provider_cost: Option, + cost: Option, } #[derive(Deserialize, Clone)] @@ -471,15 +475,26 @@ impl LlmResponseCodec for OpenAIResponsesCodec { }); // Map usage. - let usage = raw.usage.map(|u| Usage { - prompt_tokens: u.input_tokens, - completion_tokens: u.output_tokens, - total_tokens: u.total_tokens, - cache_read_tokens: u - .input_tokens_details - .as_ref() - .and_then(|d| d.cached_tokens), - cache_write_tokens: None, + let model_for_pricing = raw.model.as_deref(); + let model_provider = infer_model_provider("openai", model_for_pricing); + let usage = raw.usage.map(|u| { + let mut usage = Usage { + prompt_tokens: u.input_tokens, + completion_tokens: u.output_tokens, + total_tokens: u.total_tokens, + cache_read_tokens: u + .input_tokens_details + .as_ref() + .and_then(|d| d.cached_tokens), + cache_write_tokens: None, + cost: provider_reported_cost(u.provider_cost, u.cost), + }; + if usage.cost.is_none() { + usage.cost = model_for_pricing.and_then(|model| { + estimate_cost_for_provider(model_provider.as_deref(), model, &usage) + }); + } + usage }); // Build API-specific fields. diff --git a/crates/core/src/codec/pricing.rs b/crates/core/src/codec/pricing.rs new file mode 100644 index 00000000..8aa56cf0 --- /dev/null +++ b/crates/core/src/codec/pricing.rs @@ -0,0 +1,817 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Data-driven LLM model pricing used to layer cost estimates onto usage. +//! +//! Pricing is deliberately separate from response normalization so adding +//! providers, aliases, or cache-accounting rules does not require editing +//! [`AnnotatedLlmResponse`](super::response::AnnotatedLlmResponse). + +use std::collections::HashSet; +use std::path::PathBuf; +#[cfg(test)] +use std::sync::Mutex; +use std::sync::{Arc, LazyLock, RwLock}; + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::response::{AnnotatedLlmResponse, CostEstimate, CostSource, Usage}; + +const PRICING_CATALOG_VERSION: u32 = 1; + +static ACTIVE_PRICING_RESOLVER: LazyLock>> = + LazyLock::new(|| RwLock::new(Arc::new(PricingResolver::default()))); + +#[cfg(test)] +pub(crate) fn pricing_test_mutex() -> &'static Mutex<()> { + static PRICING_TEST_MUTEX: LazyLock> = LazyLock::new(|| Mutex::new(())); + &PRICING_TEST_MUTEX +} + +/// Errors produced while parsing or validating a pricing catalog. +#[derive(Debug, Error)] +pub enum PricingCatalogError { + /// The catalog was not valid JSON for the catalog schema. + #[error("invalid pricing catalog JSON: {0}")] + Json(#[from] serde_json::Error), + /// Two entries or aliases normalize to the same model key. + #[error("duplicate pricing model alias '{model}'")] + DuplicateModelAlias { + /// Normalized model key that appeared more than once. + model: String, + }, + /// The catalog schema version is not supported by this Relay build. + #[error("unsupported pricing catalog version {version}")] + UnsupportedVersion { + /// Version number from the catalog payload. + version: u32, + }, + /// A required text field was empty. + #[error("pricing entry {entry_index} has empty {field}")] + EmptyField { + /// Zero-based index of the invalid catalog entry. + entry_index: usize, + /// Name of the invalid field. + field: String, + }, + /// A price was negative or non-finite. + #[error("pricing entry {entry_index} has invalid {field}: {value}")] + InvalidRate { + /// Zero-based index of the invalid catalog entry. + entry_index: usize, + /// Name of the invalid rate field. + field: String, + /// Invalid field value. + value: f64, + }, + /// A pricing catalog file could not be read. + #[error("could not read pricing catalog file '{}': {source}", path.display())] + FileRead { + /// Catalog path. + path: PathBuf, + /// Underlying I/O error. + source: std::io::Error, + }, + /// The active pricing resolver lock was poisoned. + #[error("pricing resolver lock poisoned: {0}")] + LockPoisoned(String), +} + +/// Collection of model pricing entries. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PricingCatalog { + /// Catalog schema version. + pub version: u32, + /// Pricing entries keyed by canonical model ID plus aliases. + pub entries: Vec, +} + +impl PricingCatalog { + /// Parses and validates a pricing catalog from JSON. + pub fn from_json_str(catalog_json: &str) -> Result { + let catalog: Self = serde_json::from_str(catalog_json)?; + catalog.validate()?; + Ok(catalog) + } + + /// Finds pricing for a canonical model ID or alias. + #[must_use] + pub fn pricing_for_model(&self, model: &str) -> Option { + self.pricing_for(None, model) + } + + /// Finds pricing for a provider/model pair, with model-only fallback. + #[must_use] + pub fn pricing_for(&self, provider: Option<&str>, model: &str) -> Option { + let model_keys = normalized_model_lookup_keys(provider, model); + if model_keys.is_empty() { + return None; + } + + model_keys.iter().find_map(|model_key| { + self.entries + .iter() + .find(|entry| entry.matches_model(model_key)) + .cloned() + }) + } + + fn validate(&self) -> Result<(), PricingCatalogError> { + if self.version != PRICING_CATALOG_VERSION { + return Err(PricingCatalogError::UnsupportedVersion { + version: self.version, + }); + } + + let mut seen = HashSet::new(); + + for (entry_index, entry) in self.entries.iter().enumerate() { + entry.validate(entry_index)?; + + for model_key in entry.provider_model_keys() { + if !seen.insert(model_key.clone()) { + return Err(PricingCatalogError::DuplicateModelAlias { model: model_key }); + } + } + } + + Ok(()) + } +} + +/// Runtime pricing resolver configuration. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct PricingConfig { + /// Pricing sources in precedence order. + #[serde(default)] + pub sources: Vec, +} + +/// Declarative pricing source supported by Relay configuration. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum PricingSourceConfig { + /// Inline catalog entries from project, user, system, or plugin config. + Inline { + /// Inline catalog payload. + catalog: PricingCatalog, + }, + /// Catalog loaded from a JSON file. + File { + /// JSON pricing catalog path. + path: PathBuf, + }, +} + +/// Pluggable pricing source interface. +/// +/// Database, service-backed, or enterprise-managed pricing integrations should +/// implement this trait and return a validated catalog snapshot. The LLM hot +/// path uses [`PricingResolver`], so sources can refresh out-of-band without +/// making each response decode perform network or database I/O. +pub trait PricingSource: Send + Sync { + /// Stable source name for diagnostics. + fn source_name(&self) -> &str; + + /// Loads a catalog snapshot from this source. + fn load_catalog(&self) -> Result, PricingCatalogError>; +} + +/// Ordered pricing lookup chain. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct PricingResolver { + catalogs: Vec, +} + +impl PricingResolver { + /// Builds a resolver from already-loaded catalogs in precedence order. + #[must_use] + pub fn from_catalogs(catalogs: Vec) -> Self { + Self { catalogs } + } + + /// Builds a resolver from declarative config. + pub fn from_config(config: &PricingConfig) -> Result { + let mut catalogs = Vec::new(); + for source in &config.sources { + match source { + PricingSourceConfig::Inline { catalog } => { + catalog.validate()?; + catalogs.push(catalog.clone()); + } + PricingSourceConfig::File { path } => { + let raw = std::fs::read_to_string(path).map_err(|source| { + PricingCatalogError::FileRead { + path: path.clone(), + source, + } + })?; + catalogs.push(PricingCatalog::from_json_str(&raw)?); + } + } + } + Ok(Self { catalogs }) + } + + /// Builds a resolver from imperative source implementations. + pub fn from_sources(sources: Vec>) -> Result { + let mut catalogs = Vec::new(); + for source in sources { + if let Some(catalog) = source.load_catalog()? { + catalog.validate()?; + catalogs.push(catalog); + } + } + Ok(Self { catalogs }) + } + + /// Finds pricing for a canonical model ID or alias. + #[must_use] + pub fn pricing_for_model(&self, model: &str) -> Option { + self.pricing_for(None, model) + } + + /// Finds pricing for a provider/model pair, with model-only fallback. + #[must_use] + pub fn pricing_for(&self, provider: Option<&str>, model: &str) -> Option { + self.catalogs + .iter() + .find_map(|catalog| catalog.pricing_for(provider, model)) + } + + /// Estimates cost for a model/usage pair when pricing is known. + #[must_use] + pub fn estimate_cost(&self, model: &str, usage: &Usage) -> Option { + self.estimate_cost_for_provider(None, model, usage) + } + + /// Estimates cost for a provider/model pair when pricing is known. + #[must_use] + pub fn estimate_cost_for_provider( + &self, + provider: Option<&str>, + model: &str, + usage: &Usage, + ) -> Option { + self.pricing_for(provider, model) + .and_then(|pricing| pricing.estimate_cost(usage)) + } +} + +/// Per-token pricing for a model, expressed in USD per one million tokens. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ModelPricing { + /// Provider that owns this pricing entry. + pub provider: String, + /// Canonical model ID for this pricing entry. + pub model_id: String, + /// Additional model IDs that should use this pricing. + #[serde(default)] + pub aliases: Vec, + /// ISO 4217 currency for this pricing entry. + #[serde(default = "default_pricing_currency")] + pub currency: String, + /// Billing unit represented by this pricing entry. + #[serde(default)] + pub unit: PricingUnit, + /// Token rates expressed as USD per one million tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub rates: Option, + /// Data-driven token rate schedule for threshold-based provider pricing. + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_schedule: Option, + /// Prompt-cache accounting model for this provider/model. + pub prompt_cache: PromptCachePricing, + /// Date this pricing entry was last verified. + pub pricing_as_of: String, + /// Source URL for this pricing entry. + pub pricing_source: String, +} + +impl ModelPricing { + /// Estimates cost for the provided token usage. + #[must_use] + pub fn estimate_cost(&self, usage: &Usage) -> Option { + if self.unit != PricingUnit::PerToken { + return None; + } + let prompt_tokens = usage.prompt_tokens.unwrap_or(0); + let completion_tokens = usage.completion_tokens.unwrap_or(0); + let cache_read_tokens = usage.cache_read_tokens.unwrap_or(0); + let cache_write_tokens = usage.cache_write_tokens.unwrap_or(0); + let rates = self.rates_for_usage(usage)?; + + if prompt_tokens == 0 + && completion_tokens == 0 + && cache_read_tokens == 0 + && cache_write_tokens == 0 + { + return None; + } + + let billable_prompt_tokens = + if self.prompt_cache.read_accounting == CacheReadAccounting::IncludedInPromptTokens { + prompt_tokens.saturating_sub(cache_read_tokens) + } else { + prompt_tokens + }; + + let input_cost = cost_component_if_nonzero(billable_prompt_tokens, rates.input_per_million); + let output_cost = cost_component_if_nonzero(completion_tokens, rates.output_per_million); + let cache_read_cost = rates + .cache_read_per_million + .and_then(|price| cost_component_if_nonzero(cache_read_tokens, price)); + let cache_write_cost = rates + .cache_write_per_million + .and_then(|price| cost_component_if_nonzero(cache_write_tokens, price)); + + let total: f64 = [input_cost, output_cost, cache_read_cost, cache_write_cost] + .into_iter() + .flatten() + .sum(); + + Some(CostEstimate { + total: Some(round_cost_amount(total)), + currency: self.currency.clone(), + input: input_cost, + output: output_cost, + cache_read: cache_read_cost, + cache_write: cache_write_cost, + source: CostSource::ModelPricing, + pricing_provider: Some(self.provider.clone()), + pricing_model: Some(self.model_id.clone()), + pricing_as_of: Some(self.pricing_as_of.clone()), + pricing_source: Some(self.pricing_source.clone()), + }) + } + + fn rates_for_usage(&self, usage: &Usage) -> Option { + if let Some(schedule) = &self.rate_schedule { + return schedule.rates_for_usage(usage); + } + self.rates + } + + fn matches_model(&self, lookup: &ModelLookupKey) -> bool { + if let Some(provider) = lookup.provider.as_deref() + && normalized_provider_name(&self.provider) != provider + { + return false; + } + + self.model_keys().any(|key| key == lookup.model) + } + + fn model_keys(&self) -> impl Iterator + '_ { + std::iter::once(&self.model_id) + .chain(self.aliases.iter()) + .map(|model| normalized_model_name(model)) + .filter(|model| !model.is_empty()) + } + + fn provider_model_keys(&self) -> impl Iterator + '_ { + let provider = normalized_provider_name(&self.provider); + self.model_keys() + .map(move |model| format!("{provider}/{model}")) + } + + fn validate(&self, entry_index: usize) -> Result<(), PricingCatalogError> { + validate_nonempty(entry_index, "provider", &self.provider)?; + validate_nonempty(entry_index, "model_id", &self.model_id)?; + validate_nonempty(entry_index, "currency", &self.currency)?; + validate_nonempty(entry_index, "pricing_as_of", &self.pricing_as_of)?; + validate_nonempty(entry_index, "pricing_source", &self.pricing_source)?; + + if self.unit == PricingUnit::PerToken + && self.rates.is_none() + && self.rate_schedule.is_none() + { + return Err(PricingCatalogError::EmptyField { + entry_index, + field: "rates or rate_schedule".to_string(), + }); + } + if let Some(rates) = &self.rates { + rates.validate(entry_index, "rates")?; + } + if let Some(schedule) = &self.rate_schedule { + schedule.validate(entry_index)?; + } + + Ok(()) + } +} + +/// Billing unit represented by a pricing entry. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum PricingUnit { + /// Token-based pricing. + #[default] + PerToken, + /// Request-based pricing, reserved for future estimation. + PerRequest, + /// Time-based pricing, reserved for future estimation. + PerSecond, + /// GPU-hour amortized pricing for self-hosted models, reserved for future estimation. + GpuHour, +} + +/// Token rates expressed as USD per one million tokens. +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub struct TokenPricingRates { + /// Uncached prompt/input token price. + pub input_per_million: f64, + /// Completion/output token price. + pub output_per_million: f64, + /// Cached prompt/input token read price. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_per_million: Option, + /// Prompt cache write price. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_per_million: Option, +} + +impl TokenPricingRates { + fn validate(&self, entry_index: usize, field_prefix: &str) -> Result<(), PricingCatalogError> { + validate_rate( + entry_index, + format!("{field_prefix}.input_per_million"), + self.input_per_million, + )?; + validate_rate( + entry_index, + format!("{field_prefix}.output_per_million"), + self.output_per_million, + )?; + if let Some(value) = self.cache_read_per_million { + validate_rate( + entry_index, + format!("{field_prefix}.cache_read_per_million"), + value, + )?; + } + if let Some(value) = self.cache_write_per_million { + validate_rate( + entry_index, + format!("{field_prefix}.cache_write_per_million"), + value, + )?; + } + Ok(()) + } +} + +/// Data-driven token rate schedule for provider pricing with request thresholds. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum TokenRateSchedule { + /// Selects one full-request rate tier based on prompt/input tokens. + PromptTokenThreshold { + /// How selected tier rates apply to tokens. + #[serde(default)] + applies_to: RateScheduleApplication, + /// Ordered threshold tiers. + tiers: Vec, + }, +} + +impl TokenRateSchedule { + fn rates_for_usage(&self, usage: &Usage) -> Option { + match self { + Self::PromptTokenThreshold { applies_to, tiers } => { + if *applies_to != RateScheduleApplication::FullRequest { + return None; + } + let prompt_tokens = usage.prompt_tokens?; + tiers + .iter() + .find(|tier| tier.matches_prompt_tokens(prompt_tokens)) + .map(|tier| tier.rates) + } + } + } + + fn validate(&self, entry_index: usize) -> Result<(), PricingCatalogError> { + match self { + Self::PromptTokenThreshold { tiers, .. } if tiers.is_empty() => { + Err(PricingCatalogError::EmptyField { + entry_index, + field: "rate_schedule.tiers".to_string(), + }) + } + Self::PromptTokenThreshold { tiers, .. } => { + for (tier_index, tier) in tiers.iter().enumerate() { + tier.validate(entry_index, tier_index)?; + } + Ok(()) + } + } + } +} + +/// How a selected rate-schedule tier applies to billable usage. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum RateScheduleApplication { + /// Apply the selected tier rates to the entire request. + #[default] + FullRequest, +} + +/// A token pricing tier selected by prompt/input token count. +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub struct TokenRateTier { + /// Inclusive lower bound for prompt tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub min_prompt_tokens: Option, + /// Inclusive upper bound for prompt tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + /// Rates to apply when this tier is selected. + pub rates: TokenPricingRates, +} + +impl TokenRateTier { + fn matches_prompt_tokens(&self, prompt_tokens: u64) -> bool { + self.min_prompt_tokens + .is_none_or(|min| prompt_tokens >= min) + && self + .max_prompt_tokens + .is_none_or(|max| prompt_tokens <= max) + } + + fn validate(&self, entry_index: usize, tier_index: usize) -> Result<(), PricingCatalogError> { + if let (Some(min), Some(max)) = (self.min_prompt_tokens, self.max_prompt_tokens) + && min > max + { + return Err(PricingCatalogError::InvalidRate { + entry_index, + field: "rate_schedule.tiers.prompt_tokens".to_string(), + value: min as f64, + }); + } + self.rates.validate( + entry_index, + &format!("rate_schedule.tiers[{tier_index}].rates"), + ) + } +} + +/// Prompt-cache accounting rules for a model pricing entry. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCachePricing { + /// Whether cache-read tokens are included in `prompt_tokens`. + pub read_accounting: CacheReadAccounting, +} + +/// How cache-read tokens relate to prompt token counts in provider usage. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CacheReadAccounting { + /// `cache_read_tokens` are already included in `prompt_tokens`. + IncludedInPromptTokens, + /// `cache_read_tokens` are separate from `prompt_tokens`. + Separate, +} + +/// Returns known pricing for a model ID. +/// +/// Unknown models return `None` so response handling and observability export +/// can continue without inventing a cost. +#[must_use] +pub fn pricing_for_model(model: &str) -> Option { + active_pricing_resolver().pricing_for_model(model) +} + +/// Returns known pricing for a provider/model pair. +#[must_use] +pub fn pricing_for_provider(provider: Option<&str>, model: &str) -> Option { + active_pricing_resolver().pricing_for(provider, model) +} + +/// Estimates USD cost for a model/usage pair when pricing is known. +#[must_use] +pub fn estimate_cost(model: &str, usage: &Usage) -> Option { + active_pricing_resolver().estimate_cost(model, usage) +} + +/// Estimates USD cost for a provider/model pair when pricing is known. +#[must_use] +pub fn estimate_cost_for_provider( + provider: Option<&str>, + model: &str, + usage: &Usage, +) -> Option { + active_pricing_resolver().estimate_cost_for_provider(provider, model, usage) +} + +/// Estimates USD cost using the provided catalog. +#[must_use] +pub fn estimate_cost_with_catalog( + catalog: &PricingCatalog, + model: &str, + usage: &Usage, +) -> Option { + catalog + .pricing_for_model(model) + .and_then(|pricing| pricing.estimate_cost(usage)) +} + +/// Estimates USD cost using the provided catalog and provider/model pair. +#[must_use] +pub fn estimate_cost_with_provider( + catalog: &PricingCatalog, + provider: Option<&str>, + model: &str, + usage: &Usage, +) -> Option { + catalog + .pricing_for(provider, model) + .and_then(|pricing| pricing.estimate_cost(usage)) +} + +/// Returns the active process-wide pricing resolver. +#[must_use] +pub fn active_pricing_resolver() -> Arc { + ACTIVE_PRICING_RESOLVER + .read() + .map(|resolver| Arc::clone(&resolver)) + .unwrap_or_else(|_| Arc::new(PricingResolver::default())) +} + +/// Replaces the active process-wide pricing resolver. +pub fn set_active_pricing_resolver(resolver: PricingResolver) -> Result<(), PricingCatalogError> { + let mut guard = ACTIVE_PRICING_RESOLVER + .write() + .map_err(|err| PricingCatalogError::LockPoisoned(err.to_string()))?; + *guard = Arc::new(resolver); + Ok(()) +} + +/// Restores the active process-wide pricing resolver to an empty resolver. +pub fn reset_active_pricing_resolver() -> Result<(), PricingCatalogError> { + set_active_pricing_resolver(PricingResolver::default()) +} + +/// Adds a model-pricing estimate to a normalized response when cost is missing. +/// +/// Existing provider-reported or caller-supplied costs are preserved. +pub fn attach_estimated_cost(response: &mut AnnotatedLlmResponse) { + attach_estimated_cost_for_provider(response, None); +} + +/// Adds a provider-aware model-pricing estimate to a normalized response when cost is missing. +/// +/// Existing provider-reported or caller-supplied costs are preserved. +pub fn attach_estimated_cost_for_provider( + response: &mut AnnotatedLlmResponse, + provider: Option<&str>, +) { + if response + .usage + .as_ref() + .and_then(|usage| usage.cost.as_ref()) + .is_some() + { + return; + } + + let Some(model) = response.model.clone() else { + return; + }; + let Some(usage) = response.usage.as_mut() else { + return; + }; + + usage.cost = estimate_cost_for_provider(provider, &model, usage); +} + +fn validate_nonempty( + entry_index: usize, + field: &'static str, + value: &str, +) -> Result<(), PricingCatalogError> { + if value.trim().is_empty() { + return Err(PricingCatalogError::EmptyField { + entry_index, + field: field.to_string(), + }); + } + + Ok(()) +} + +fn validate_rate( + entry_index: usize, + field: impl Into, + value: f64, +) -> Result<(), PricingCatalogError> { + if !value.is_finite() || value < 0.0 { + return Err(PricingCatalogError::InvalidRate { + entry_index, + field: field.into(), + value, + }); + } + + Ok(()) +} + +fn default_pricing_currency() -> String { + "USD".into() +} + +fn normalized_model_name(model: &str) -> String { + model.trim().to_ascii_lowercase() +} + +fn normalized_provider_name(provider: &str) -> String { + provider.trim().trim_matches('/').to_ascii_lowercase() +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ModelLookupKey { + provider: Option, + model: String, +} + +fn normalized_model_lookup_keys(provider: Option<&str>, model: &str) -> Vec { + let normalized = normalized_model_name(model); + if normalized.is_empty() { + return vec![]; + } + + let parts: Vec<&str> = normalized + .split('/') + .map(str::trim) + .filter(|part| !part.is_empty()) + .collect(); + let mut keys = Vec::with_capacity(parts.len() + 3); + let explicit_provider = provider + .map(normalized_provider_name) + .filter(|provider| !provider.is_empty()); + let terminal_model = parts + .last() + .copied() + .unwrap_or(normalized.as_str()) + .to_string(); + + if let Some(provider) = explicit_provider { + push_lookup_key(&mut keys, Some(provider.clone()), normalized.clone()); + push_lookup_key(&mut keys, Some(provider), terminal_model.clone()); + } else if parts.len() > 1 { + push_lookup_key( + &mut keys, + Some(parts[..parts.len() - 1].join("/")), + terminal_model, + ); + } + + for start in 0..parts.len() { + let key = parts[start..].join("/"); + push_lookup_key(&mut keys, None, key); + } + keys +} + +fn push_lookup_key(keys: &mut Vec, provider: Option, model: String) { + let key = ModelLookupKey { provider, model }; + if !key.model.is_empty() && !keys.contains(&key) { + keys.push(key); + } +} + +/// Infers a provider/route value for a decoded model. +#[must_use] +pub fn infer_model_provider(default_provider: &str, model: Option<&str>) -> Option { + let normalized_default = normalized_provider_name(default_provider); + if let Some(model) = model { + let normalized = normalized_model_name(model); + let parts: Vec<&str> = normalized + .split('/') + .map(str::trim) + .filter(|part| !part.is_empty()) + .collect(); + if parts.len() > 1 { + return Some(parts[..parts.len() - 1].join("/")); + } + } + + (!normalized_default.is_empty()).then_some(normalized_default) +} + +fn cost_component(tokens: u64, price_per_million: f64) -> f64 { + tokens as f64 * price_per_million / 1_000_000.0 +} + +fn cost_component_if_nonzero(tokens: u64, price_per_million: f64) -> Option { + (tokens > 0).then(|| round_cost_amount(cost_component(tokens, price_per_million))) +} + +fn round_cost_amount(cost: f64) -> f64 { + const SCALE: f64 = 1_000_000_000_000.0; + (cost * SCALE).round() / SCALE +} diff --git a/crates/core/src/codec/response.rs b/crates/core/src/codec/response.rs index 1ea83a89..325be5fd 100644 --- a/crates/core/src/codec/response.rs +++ b/crates/core/src/codec/response.rs @@ -10,6 +10,15 @@ use serde::{Deserialize, Serialize}; use crate::json::Json; +pub use super::pricing::{ + CacheReadAccounting, ModelPricing, PricingCatalog, PricingCatalogError, PricingConfig, + PricingResolver, PricingSource, PricingSourceConfig, PricingUnit, PromptCachePricing, + TokenPricingRates, active_pricing_resolver, attach_estimated_cost, + attach_estimated_cost_for_provider, estimate_cost, estimate_cost_for_provider, + estimate_cost_with_catalog, estimate_cost_with_provider, infer_model_provider, + pricing_for_model, pricing_for_provider, reset_active_pricing_resolver, + set_active_pricing_resolver, +}; use super::request::MessageContent; // --------------------------------------------------------------------------- @@ -68,7 +77,7 @@ pub struct AnnotatedLlmResponse { /// All fields are `Option` because not every provider supplies every /// field. For example, cache token counts are only available from providers /// that support prompt caching. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] pub struct Usage { /// Tokens consumed by the prompt/input. #[serde(skip_serializing_if = "Option::is_none")] @@ -85,6 +94,176 @@ pub struct Usage { /// Tokens written to prompt cache. #[serde(skip_serializing_if = "Option::is_none")] pub cache_write_tokens: Option, + /// Optional cost reported by provider data or estimated from Relay pricing. + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, +} + +/// Source of a normalized cost value. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CostSource { + /// Cost was estimated by applying Relay's model pricing table to usage. + ModelPricing, + /// Cost was reported directly by a provider or framework payload. + ProviderReported, +} + +/// Normalized LLM response cost. +/// +/// Provider-reported cost is preserved as-is. Model-pricing estimates include +/// source and as-of metadata so downstream systems can audit stale pricing +/// tables without losing a usable estimate. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CostEstimate { + /// Total cost in `currency`. + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + /// ISO 4217 currency code for the cost fields. + #[serde(default = "default_cost_currency")] + pub currency: String, + /// Uncached prompt/input token cost in `currency`. + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + /// Completion/output token cost in `currency`. + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, + /// Prompt cache read cost in `currency`. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read: Option, + /// Prompt cache write cost in `currency`. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write: Option, + /// Origin of this cost value. + pub source: CostSource, + /// Provider associated with the cost or pricing estimate, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub pricing_provider: Option, + /// Model ID associated with the cost or pricing estimate, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub pricing_model: Option, + /// Date the pricing value was last verified, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub pricing_as_of: Option, + /// Source URL or label for the pricing value, if known. + #[serde(skip_serializing_if = "Option::is_none")] + pub pricing_source: Option, +} + +impl CostEstimate { + /// Returns the explicit total, or the sum of component costs when no total was supplied. + #[must_use] + pub fn total_or_component_sum(&self) -> Option { + self.total.or_else(|| { + let (has_component, total) = + [self.input, self.output, self.cache_read, self.cache_write] + .into_iter() + .flatten() + .fold((false, 0.0), |(_, total), value| (true, total + value)); + has_component.then_some(total) + }) + } + + /// Returns the total only when it is denominated in the requested currency. + #[must_use] + pub fn total_for_currency(&self, currency: &str) -> Option { + self.currency + .eq_ignore_ascii_case(currency) + .then_some(self.total) + .flatten() + } + + /// Returns the explicit or component-derived total in the requested currency. + #[must_use] + pub fn total_or_component_sum_for_currency(&self, currency: &str) -> Option { + self.currency + .eq_ignore_ascii_case(currency) + .then(|| self.total_or_component_sum()) + .flatten() + } +} + +/// Provider/framework cost object accepted by built-in response codecs. +#[derive(Debug, Clone, Default, Deserialize)] +pub(crate) struct RawUsageCost { + /// Normalized total cost in the supplied currency. + pub total: Option, + /// Uncached prompt/input token cost in the supplied currency. + pub input: Option, + /// Completion/output token cost in the supplied currency. + pub output: Option, + /// Prompt cache read cost in the supplied currency. + pub cache_read: Option, + /// Prompt cache write cost in the supplied currency. + pub cache_write: Option, + /// Optional currency override from provider data. + pub currency: Option, + /// Optional provider provenance. + pub pricing_provider: Option, + /// Optional model provenance. + pub pricing_model: Option, + /// Optional as-of provenance. + pub pricing_as_of: Option, + /// Optional source provenance. + pub pricing_source: Option, +} + +pub(crate) fn provider_reported_cost( + provider_total_cost: Option, + cost: Option, +) -> Option { + let cost = cost.unwrap_or_default(); + let input = cost.input; + let output = cost.output; + let cache_read = cost.cache_read; + let cache_write = cost.cache_write; + let has_currency_native_amount = cost.total.is_some() + || cost.input.is_some() + || cost.output.is_some() + || cost.cache_read.is_some() + || cost.cache_write.is_some(); + let component_total = [input, output, cache_read, cache_write] + .into_iter() + .flatten() + .sum(); + let has_component_cost = + input.is_some() || output.is_some() || cache_read.is_some() || cache_write.is_some(); + let total = provider_total_cost + .or(cost.total) + .or_else(|| has_component_cost.then_some(component_total)); + + if total.is_none() + && input.is_none() + && output.is_none() + && cache_read.is_none() + && cache_write.is_none() + { + return None; + } + + Some(CostEstimate { + total, + currency: if provider_total_cost.is_some() { + default_cost_currency() + } else if has_currency_native_amount { + cost.currency.unwrap_or_else(default_cost_currency) + } else { + default_cost_currency() + }, + input, + output, + cache_read, + cache_write, + source: CostSource::ProviderReported, + pricing_provider: cost.pricing_provider, + pricing_model: cost.pricing_model, + pricing_as_of: cost.pricing_as_of, + pricing_source: cost.pricing_source, + }) +} + +fn default_cost_currency() -> String { + "USD".into() } // --------------------------------------------------------------------------- diff --git a/crates/core/src/observability/atif.rs b/crates/core/src/observability/atif.rs index cbb73644..8b7b9e9f 100644 --- a/crates/core/src/observability/atif.rs +++ b/crates/core/src/observability/atif.rs @@ -38,6 +38,7 @@ use uuid::Uuid; use crate::api::event::Event; use crate::api::runtime::EventSubscriberFn; use crate::api::subscriber::flush_subscribers; +use crate::codec::response::{Usage, estimate_cost_for_provider}; use crate::error::Result; use crate::json::Json; @@ -648,10 +649,16 @@ fn collect_openai_responses_content_text( /// Known keys in token_usage that we extract to dedicated fields. const TOKEN_USAGE_KNOWN_KEYS: &[&str] = &[ "prompt_tokens", + "input_tokens", "completion_tokens", + "output_tokens", "cached_tokens", + "cache_read_input_tokens", + "cache_creation_input_tokens", + "cache_write_tokens", "cost_usd", "cost", + "prompt_tokens_details", "prompt_token_ids", "completion_token_ids", "logprobs", @@ -662,23 +669,48 @@ const TOKEN_USAGE_KNOWN_KEYS: &[&str] = &[ /// Supports NeMo Relay `token_usage` and provider-native `usage` payloads. /// Populates `extra` with any unknown usage keys (e.g. reasoning_tokens or total_tokens). /// Returns `None` if the response has no recognized token or cost metrics. -fn extract_metrics(output: &Json) -> Option { +fn extract_metrics( + output: &Json, + provider: Option<&str>, + model_name: Option<&str>, +) -> Option { let usage = token_usage_object(output)?; let prompt = usage_u64(usage, &["prompt_tokens", "input_tokens"]); let completion = usage_u64(usage, &["completion_tokens", "output_tokens"]); - let cached = usage_u64(usage, &["cached_tokens"]) + let cache_read = usage_u64(usage, &["cached_tokens"]) .or_else(|| prompt_tokens_detail_u64(usage, "cached_tokens")) .or_else(|| input_tokens_detail_u64(usage, "cached_tokens")) - .or_else(|| { - sum_usage_u64( - usage, - &["cache_read_input_tokens", "cache_creation_input_tokens"], - ) - }); - let cost = usage + .or_else(|| usage_u64(usage, &["cache_read_input_tokens"])); + let cache_write = usage_u64( + usage, + &["cache_creation_input_tokens", "cache_write_tokens"], + ); + let cached = sum_options(cache_read, cache_write); + let explicit_cost = usage .get("cost_usd") .and_then(Json::as_f64) - .or_else(|| usage.get("cost")?.as_object()?.get("total")?.as_f64()); + .or_else(|| usage.get("cost").and_then(cost_usd_from_cost_object)); + let has_reported_cost = usage.get("cost").is_some(); + let cost = if has_reported_cost { + explicit_cost + } else { + explicit_cost.or_else(|| { + let model_name = model_name.or_else(|| response_model_name(output))?; + estimate_cost_for_provider( + provider, + model_name, + &Usage { + prompt_tokens: prompt, + completion_tokens: completion, + total_tokens: usage_u64(usage, &["total_tokens"]), + cache_read_tokens: cache_read, + cache_write_tokens: cache_write, + cost: None, + }, + ) + .and_then(|cost| cost.total_for_currency("USD")) + }) + }; let prompt_ids = usage .get("prompt_token_ids") .and_then(Json::as_array) @@ -717,6 +749,30 @@ fn extract_metrics(output: &Json) -> Option { }) } +fn cost_usd_from_cost_object(cost: &Json) -> Option { + let cost = cost.as_object()?; + let currency = cost.get("currency").and_then(Json::as_str); + let is_relay_normalized_cost = cost + .get("source") + .and_then(Json::as_str) + .is_some_and(|source| matches!(source, "provider_reported" | "model_pricing")); + let has_legacy_provider_total = + currency.is_none() && cost.get("total").and_then(Json::as_f64).is_some(); + let is_usd_cost = currency.is_some_and(|currency| currency.eq_ignore_ascii_case("USD")) + || currency.is_none() && (is_relay_normalized_cost || has_legacy_provider_total); + if !is_usd_cost { + return None; + } + + cost.get("total").and_then(Json::as_f64).or_else(|| { + let (has_component, component_total) = ["input", "output", "cache_read", "cache_write"] + .iter() + .filter_map(|field| cost.get(*field).and_then(Json::as_f64)) + .fold((false, 0.0), |(_, total), value| (true, total + value)); + has_component.then_some(component_total) + }) +} + fn merge_metrics( primary: Option, supplemental: Option<&AtifMetrics>, @@ -787,16 +843,18 @@ fn usage_u64(usage: &serde_json::Map, keys: &[&str]) -> Option, keys: &[&str]) -> Option { - let mut total = 0; - let mut found = false; - for key in keys { - if let Some(value) = usage.get(*key).and_then(Json::as_u64) { - total += value; - found = true; - } +fn response_model_name(output: &Json) -> Option<&str> { + output + .as_object() + .and_then(|object| object.get("model").and_then(Json::as_str)) +} + +fn sum_options(left: Option, right: Option) -> Option { + match (left, right) { + (Some(left), Some(right)) => Some(left + right), + (Some(value), None) | (None, Some(value)) => Some(value), + (None, None) => None, } - found.then_some(total) } fn prompt_tokens_detail_u64(usage: &serde_json::Map, key: &str) -> Option { @@ -1349,7 +1407,9 @@ impl LlmSpanCandidate { .or_else(|| end.model_name()) .map(ToOwned::to_owned), fidelity_score: llm_event_fidelity_score(start).max(llm_event_fidelity_score(end)), - end_metrics: end.data().and_then(extract_metrics), + end_metrics: end + .data() + .and_then(|output| extract_metrics(output, Some(end.name()), end.model_name())), hook_instrumentation: is_hook_instrumented_llm_event(start) || is_hook_instrumented_llm_event(end), gateway_instrumentation: is_gateway_instrumented_llm_event(start) @@ -2168,7 +2228,7 @@ impl StepConversionState { ); let metrics = merge_metrics( - extract_metrics(output), + extract_metrics(output, Some(event.name()), event.model_name()), lookups.supplemental_llm_metrics.get(&event.uuid()), ); diff --git a/crates/core/src/observability/openinference.rs b/crates/core/src/observability/openinference.rs index 092ed669..98b6b239 100644 --- a/crates/core/src/observability/openinference.rs +++ b/crates/core/src/observability/openinference.rs @@ -27,7 +27,9 @@ use crate::api::subscriber::{deregister_subscriber, flush_subscribers, register_ use crate::codec::request::{ AnnotatedLlmRequest, ContentPart, Message, MessageContent, ToolDefinition, }; -use crate::codec::response::{AnnotatedLlmResponse, FinishReason, ResponseToolCall, Usage}; +use crate::codec::response::{ + AnnotatedLlmResponse, FinishReason, ResponseToolCall, Usage, estimate_cost_for_provider, +}; use crate::error::FlowError; use crate::json::Json; use chrono::{DateTime, Utc}; @@ -750,7 +752,7 @@ fn end_attributes(event: &Event) -> Vec { )); } } - if is_llm && let Some(cost_total) = cost_total_from_manual_llm_output(event.output()) { + if is_llm && let Some(cost_total) = cost_total_from_llm_event(event, fallback_usage.as_ref()) { attributes.push(KeyValue::new(oi::llm::cost::TOTAL, cost_total)); } if is_llm { @@ -1096,11 +1098,51 @@ fn cost_total_from_manual_llm_output(output: Option<&Json>) -> Option { .or_else(|| token_usage.and_then(cost_total_from_usage)) } +fn cost_total_from_llm_event(event: &Event, fallback_usage: Option<&Usage>) -> Option { + if let Some(cost) = cost_total_from_manual_llm_output(event.output()) { + return Some(cost); + } + + if let Some(response) = event.annotated_response() + && let Some(usage) = response.usage.as_ref() + { + if let Some(cost) = usage.cost.as_ref() { + return cost.total_or_component_sum_for_currency("USD"); + } + if let Some(model_name) = response.model.as_deref().or_else(|| event.model_name()) { + return estimate_cost_for_provider(Some(event.name()), model_name, usage) + .and_then(|cost| cost.total_for_currency("USD")); + } + } + + let usage = fallback_usage?; + let model_name = event + .model_name() + .or_else(|| model_name_from_manual_llm_output(event.output()))?; + estimate_cost_for_provider(Some(event.name()), model_name, usage) + .and_then(|cost| cost.total_for_currency("USD")) +} + +fn model_name_from_manual_llm_output(output: Option<&Json>) -> Option<&str> { + output?.as_object()?.get("model").and_then(Json::as_str) +} + fn cost_total_from_usage(usage: &serde_json::Map) -> Option { - usage - .get("cost_usd") - .and_then(Json::as_f64) - .or_else(|| usage.get("cost")?.as_object()?.get("total")?.as_f64()) + usage.get("cost_usd").and_then(Json::as_f64).or_else(|| { + let cost = usage.get("cost")?.as_object()?; + let currency = cost.get("currency").and_then(Json::as_str); + let is_usd_cost = currency.is_none_or(|currency| currency.eq_ignore_ascii_case("USD")); + if !is_usd_cost { + return None; + } + cost.get("total").and_then(Json::as_f64).or_else(|| { + let (has_component, component_total) = ["input", "output", "cache_read", "cache_write"] + .iter() + .filter_map(|field| cost.get(*field).and_then(Json::as_f64)) + .fold((false, 0.0), |(_, total), value| (true, total + value)); + has_component.then_some(component_total) + }) + }) } fn usage_from_manual_llm_output(output: Option<&Json>) -> Option { @@ -1190,6 +1232,7 @@ fn usage_from_manual_llm_output(output: Option<&Json>) -> Option { total_tokens, cache_read_tokens, cache_write_tokens, + cost: None, }) } diff --git a/crates/core/src/observability/otel.rs b/crates/core/src/observability/otel.rs index f7b42c6c..3cce7742 100644 --- a/crates/core/src/observability/otel.rs +++ b/crates/core/src/observability/otel.rs @@ -25,7 +25,9 @@ use crate::api::event::ScopeCategory; use crate::api::runtime::EventSubscriberFn; use crate::api::scope::ScopeType; use crate::api::subscriber::{deregister_subscriber, flush_subscribers, register_subscriber}; +use crate::codec::response::{CostEstimate, Usage, estimate_cost_for_provider}; use crate::error::FlowError; +use crate::json::Json; use chrono::{DateTime, Utc}; use opentelemetry::trace::{ Span as _, SpanContext, SpanKind, TraceContextExt, Tracer, TracerProvider as _, @@ -667,9 +669,222 @@ fn end_attributes(event: &Event) -> Vec { "nemo_relay.end.output_json", event.output(), ); + if event + .category() + .is_some_and(|category| category.as_str() == "llm") + && let Some((cost, currency)) = cost_from_llm_event(event) + { + attributes.push(KeyValue::new("nemo_relay.llm.cost.total", cost)); + attributes.push(KeyValue::new("nemo_relay.llm.cost.currency", currency)); + } attributes } +fn cost_from_llm_event(event: &Event) -> Option<(f64, String)> { + if let Some(cost) = cost_from_manual_llm_output(event.output()) { + return Some(cost); + } + if let Some(response) = event.annotated_response() + && let Some(usage) = response.usage.as_ref() + { + if let Some(cost) = usage.cost.as_ref() { + return cost_total_and_currency(cost); + } + if let Some(model_name) = response.model.as_deref().or_else(|| event.model_name()) { + return estimate_cost_for_provider(Some(event.name()), model_name, usage) + .and_then(|cost| cost_total_and_currency(&cost)); + } + } + let usage = usage_from_manual_llm_output(event.output())?; + let model_name = event + .model_name() + .or_else(|| model_name_from_manual_llm_output(event.output()))?; + estimate_cost_for_provider(Some(event.name()), model_name, &usage) + .and_then(|cost| cost_total_and_currency(&cost)) +} + +fn cost_total_and_currency(cost: &CostEstimate) -> Option<(f64, String)> { + Some((cost.total_or_component_sum()?, cost.currency.clone())) +} + +fn cost_from_manual_llm_output(output: Option<&Json>) -> Option<(f64, String)> { + let object = output?.as_object()?; + let usage = object.get("usage").and_then(Json::as_object); + let token_usage = object.get("token_usage").and_then(Json::as_object); + usage + .and_then(cost_from_manual_usage) + .or_else(|| token_usage.and_then(cost_from_manual_usage)) +} + +fn cost_from_manual_usage(usage: &serde_json::Map) -> Option<(f64, String)> { + usage + .get("cost_usd") + .and_then(Json::as_f64) + .map(|total| (total, "USD".to_string())) + .or_else(|| { + let cost = usage.get("cost")?.as_object()?; + let total = cost.get("total").and_then(Json::as_f64).or_else(|| { + let (has_component, component_total) = + ["input", "output", "cache_read", "cache_write"] + .iter() + .filter_map(|field| cost.get(*field).and_then(Json::as_f64)) + .fold((false, 0.0), |(_, total), value| (true, total + value)); + has_component.then_some(component_total) + })?; + Some(( + total, + cost.get("currency") + .and_then(Json::as_str) + .unwrap_or("USD") + .to_string(), + )) + }) +} + +fn usage_from_manual_llm_output(output: Option<&Json>) -> Option { + let object = output?.as_object()?; + let usage = object.get("usage").and_then(Json::as_object); + let token_usage = object.get("token_usage").and_then(Json::as_object); + if usage.is_none() && token_usage.is_none() { + return None; + } + + let prompt_tokens = first_u64_from_manual_usage( + usage, + token_usage, + &["prompt_tokens", "input_tokens", "inputTokens", "input"], + ); + let completion_tokens = first_u64_from_manual_usage( + usage, + token_usage, + &[ + "completion_tokens", + "output_tokens", + "completionTokens", + "outputTokens", + "output", + ], + ); + let reported_total_tokens = first_u64_from_manual_usage( + usage, + token_usage, + &["total_tokens", "totalTokens", "total"], + ); + let cache_read_tokens = first_u64_from_manual_usage( + usage, + token_usage, + &[ + "cache_read_tokens", + "cached_tokens", + "cache_read_input_tokens", + "cacheReadTokens", + "cachedTokens", + "cacheReadInputTokens", + "cacheRead", + ], + ) + .or_else(|| { + first_nested_u64_from_manual_usage( + usage, + token_usage, + "input_tokens_details", + "cached_tokens", + ) + }) + .or_else(|| { + first_nested_u64_from_manual_usage( + usage, + token_usage, + "prompt_tokens_details", + "cached_tokens", + ) + }); + let cache_write_tokens = first_u64_from_manual_usage( + usage, + token_usage, + &[ + "cache_write_tokens", + "cache_creation_input_tokens", + "cacheWriteTokens", + "cacheCreationInputTokens", + "cacheWrite", + ], + ); + + if prompt_tokens.is_none() + && completion_tokens.is_none() + && reported_total_tokens.is_none() + && cache_read_tokens.is_none() + && cache_write_tokens.is_none() + { + return None; + } + + Some(Usage { + prompt_tokens, + completion_tokens, + total_tokens: normalize_total_tokens( + reported_total_tokens, + prompt_tokens, + completion_tokens, + ), + cache_read_tokens, + cache_write_tokens, + cost: None, + }) +} + +fn model_name_from_manual_llm_output(output: Option<&Json>) -> Option<&str> { + output?.as_object()?.get("model").and_then(Json::as_str) +} + +fn first_u64_from_manual_usage( + usage: Option<&serde_json::Map>, + token_usage: Option<&serde_json::Map>, + keys: &[&str], +) -> Option { + keys.iter().find_map(|key| { + usage + .and_then(|usage| usage.get(*key).and_then(Json::as_u64)) + .or_else(|| token_usage.and_then(|usage| usage.get(*key).and_then(Json::as_u64))) + }) +} + +fn first_nested_u64_from_manual_usage( + usage: Option<&serde_json::Map>, + token_usage: Option<&serde_json::Map>, + parent: &str, + key: &str, +) -> Option { + usage + .and_then(|usage| usage.get(parent).and_then(Json::as_object)) + .and_then(|details| details.get(key).and_then(Json::as_u64)) + .or_else(|| { + token_usage + .and_then(|usage| usage.get(parent).and_then(Json::as_object)) + .and_then(|details| details.get(key).and_then(Json::as_u64)) + }) +} + +fn normalize_total_tokens( + reported_total_tokens: Option, + prompt_tokens: Option, + completion_tokens: Option, +) -> Option { + let calculated_total = match (prompt_tokens, completion_tokens) { + (Some(prompt), Some(completion)) => Some(prompt + completion), + (Some(prompt), None) => Some(prompt), + (None, Some(completion)) => Some(completion), + (None, None) => None, + }; + match (reported_total_tokens, calculated_total) { + (Some(reported), Some(calculated)) if reported >= calculated => Some(reported), + (Some(_), Some(calculated)) => Some(calculated), + (Some(reported), None) => Some(reported), + (None, calculated) => calculated, + } +} + fn mark_attributes(event: &Event) -> Vec { let handle_attributes = event.attributes(); let mut attributes = vec![ diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index f9731a3e..e3d90715 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -764,7 +764,8 @@ pub fn register_plugin(plugin: Arc) -> Result<()> { pub fn ensure_builtin_plugins_registered() -> Result<()> { let register_builtins = || { crate::observability::plugin_component::register_observability_component()?; - crate::plugins::nemo_guardrails::component::register_nemo_guardrails_component() + crate::plugins::nemo_guardrails::component::register_nemo_guardrails_component()?; + crate::plugins::pricing::register_pricing_component() }; match BUILTIN_PLUGIN_REGISTRATION.get_or_init(register_builtins) { Ok(()) => Ok(()), diff --git a/crates/core/src/plugins/mod.rs b/crates/core/src/plugins/mod.rs index de69a84f..d6cef9c1 100644 --- a/crates/core/src/plugins/mod.rs +++ b/crates/core/src/plugins/mod.rs @@ -4,3 +4,4 @@ //! First-party plugin implementations for NeMo Relay Core. pub mod nemo_guardrails; +pub mod pricing; diff --git a/crates/core/src/plugins/pricing.rs b/crates/core/src/plugins/pricing.rs new file mode 100644 index 00000000..36e04a0b --- /dev/null +++ b/crates/core/src/plugins/pricing.rs @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Built-in pricing plugin component. + +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use serde_json::{Map, Value as Json}; + +use crate::codec::response::{ + PricingConfig, PricingResolver, reset_active_pricing_resolver, set_active_pricing_resolver, +}; +use crate::plugin::{ + ConfigDiagnostic, DiagnosticLevel, Plugin, PluginError, PluginRegistration, + PluginRegistrationContext, Result, register_plugin, +}; + +/// Plugin kind used by the pricing component. +pub const PRICING_PLUGIN_KIND: &str = "pricing"; + +/// Registers the built-in pricing component. +pub fn register_pricing_component() -> Result<()> { + match register_plugin(Arc::new(PricingPlugin)) { + Ok(()) => Ok(()), + Err(PluginError::RegistrationFailed(message)) + if message.contains("plugin 'pricing' is already registered") => + { + Ok(()) + } + Err(err) => Err(err), + } +} + +struct PricingPlugin; + +impl Plugin for PricingPlugin { + fn plugin_kind(&self) -> &str { + PRICING_PLUGIN_KIND + } + + fn allows_multiple_components(&self) -> bool { + false + } + + fn validate(&self, plugin_config: &Map) -> Vec { + let config = + match serde_json::from_value::(Json::Object(plugin_config.clone())) { + Ok(config) => config, + Err(error) => { + return vec![ConfigDiagnostic { + level: DiagnosticLevel::Error, + code: "pricing.invalid_config".into(), + component: Some(PRICING_PLUGIN_KIND.into()), + field: None, + message: format!("invalid pricing config: {error}"), + }]; + } + }; + match PricingResolver::from_config(&config) { + Ok(_) => vec![], + Err(error) => vec![ConfigDiagnostic { + level: DiagnosticLevel::Error, + code: "pricing.invalid_config".into(), + component: Some(PRICING_PLUGIN_KIND.into()), + field: None, + message: format!("invalid pricing config: {error}"), + }], + } + } + + fn register<'a>( + &'a self, + plugin_config: &Map, + ctx: &'a mut PluginRegistrationContext, + ) -> Pin> + Send + 'a>> { + let plugin_config = plugin_config.clone(); + Box::pin(async move { + let config: PricingConfig = serde_json::from_value(Json::Object(plugin_config))?; + let resolver = PricingResolver::from_config(&config) + .map_err(|error| PluginError::InvalidConfig(error.to_string()))?; + set_active_pricing_resolver(resolver) + .map_err(|error| PluginError::RegistrationFailed(error.to_string()))?; + ctx.add_registration(PluginRegistration::new( + "plugin", + ctx.qualify_name("pricing"), + Box::new(|| { + reset_active_pricing_resolver() + .map_err(|error| PluginError::RegistrationFailed(error.to_string())) + }), + )); + Ok(()) + }) + } +} diff --git a/crates/core/src/stream.rs b/crates/core/src/stream.rs index 18eeb326..057d8cdb 100644 --- a/crates/core/src/stream.rs +++ b/crates/core/src/stream.rs @@ -36,7 +36,7 @@ use crate::api::llm::LlmHandle; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; use crate::api::runtime::{ScopeStackHandle, current_scope_stack}; -use crate::codec::response::AnnotatedLlmResponse; +use crate::codec::response::{AnnotatedLlmResponse, attach_estimated_cost_for_provider}; use crate::codec::traits::LlmResponseCodec; use crate::error::Result; use crate::json::Json; @@ -144,7 +144,11 @@ impl LlmStreamWrapper { let annotated_response: Option> = self .response_codec .as_ref() - .and_then(|c| c.decode_response(&aggregated).ok()) + .and_then(|c| { + let mut decoded = c.decode_response(&aggregated).ok()?; + attach_estimated_cost_for_provider(&mut decoded, Some(&self.handle.name)); + Some(decoded) + }) .map(Arc::new); let event_snapshot = { diff --git a/crates/core/tests/integration/pipeline_tests.rs b/crates/core/tests/integration/pipeline_tests.rs index b2200150..6b9da3ef 100644 --- a/crates/core/tests/integration/pipeline_tests.rs +++ b/crates/core/tests/integration/pipeline_tests.rs @@ -28,8 +28,11 @@ use nemo_relay::api::scope::ScopeType; use nemo_relay::api::subscriber::{deregister_subscriber, flush_subscribers, register_subscriber}; use nemo_relay::codec::request::AnnotatedLlmRequest; use nemo_relay::codec::request::MessageContent; -use nemo_relay::codec::response::AnnotatedLlmResponse; use nemo_relay::codec::response::FinishReason; +use nemo_relay::codec::response::{ + AnnotatedLlmResponse, PricingCatalog, PricingResolver, Usage, reset_active_pricing_resolver, + set_active_pricing_resolver, +}; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; use nemo_relay::error::{FlowError, Result}; use nemo_relay::json::Json; @@ -60,6 +63,33 @@ fn captured_events_snapshot(events: &Arc>>) -> Vec { events.lock().unwrap().clone() } +fn install_mock_response_pricing() { + let catalog = PricingCatalog::from_json_str( + &json!({ + "version": 1, + "entries": [ + { + "provider": "openai", + "model_id": "gpt-4o-mini", + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ] + }) + .to_string(), + ) + .unwrap(); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); +} + // --------------------------------------------------------------------------- // TrackingCodec — records decode/encode calls and performs real transformations // --------------------------------------------------------------------------- @@ -890,11 +920,18 @@ impl LlmResponseCodec for MockResponseCodec { fn decode_response(&self, _response: &Json) -> Result { Ok(AnnotatedLlmResponse { id: Some("mock-resp-id".into()), - model: Some("mock-model".into()), + model: Some("gpt-4o-mini".into()), message: Some(MessageContent::Text("mock response text".into())), tool_calls: None, finish_reason: Some(FinishReason::Complete), - usage: None, + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }), api_specific: None, extra: serde_json::Map::new(), }) @@ -915,6 +952,7 @@ async fn test_response_codec_populates_annotated_response() { let _lock = TEST_MUTEX.lock().unwrap(); reset_global(); setup_isolated_thread(); + install_mock_response_pricing(); let events = Arc::new(Mutex::new(Vec::new())); let ec = events.clone(); @@ -931,7 +969,7 @@ async fn test_response_codec_populates_annotated_response() { let _result = llm_call_execute( LlmCallExecuteParams::builder() - .name("test_llm") + .name("openai") .request(request) .func(noop_exec_fn()) .response_codec(response_codec) @@ -951,8 +989,16 @@ async fn test_response_codec_populates_annotated_response() { .expect("annotated_response should be Some when response codec is active"); assert_eq!(ann.id, Some("mock-resp-id".into())); assert_eq!(ann.response_text(), Some("mock response text")); + assert_eq!( + ann.usage + .as_ref() + .and_then(|usage| usage.cost.as_ref()) + .and_then(|cost| cost.total), + Some(0.000_435) + ); deregister_subscriber("resp_codec_sub").unwrap(); + reset_active_pricing_resolver().unwrap(); } #[tokio::test] @@ -1100,6 +1146,7 @@ async fn test_stream_response_codec_populates_annotated_response() { let _lock = TEST_MUTEX.lock().unwrap(); reset_global(); setup_isolated_thread(); + install_mock_response_pricing(); let events = Arc::new(Mutex::new(Vec::new())); let ec = events.clone(); @@ -1120,7 +1167,7 @@ async fn test_stream_response_codec_populates_annotated_response() { let mut stream = llm_stream_call_execute( LlmStreamCallExecuteParams::builder() - .name("test_stream") + .name("openai") .request(request) .func(noop_stream_exec_fn()) .collector(collector) @@ -1145,6 +1192,14 @@ async fn test_stream_response_codec_populates_annotated_response() { .expect("annotated_response should be Some on stream path when response codec is active"); assert_eq!(ann.id, Some("mock-resp-id".into())); assert_eq!(ann.response_text(), Some("mock response text")); + assert_eq!( + ann.usage + .as_ref() + .and_then(|usage| usage.cost.as_ref()) + .and_then(|cost| cost.total), + Some(0.000_435) + ); deregister_subscriber("stream_resp_codec_sub").unwrap(); + reset_active_pricing_resolver().unwrap(); } diff --git a/crates/core/tests/unit/atif_tests.rs b/crates/core/tests/unit/atif_tests.rs index b409abae..1bfe2cdd 100644 --- a/crates/core/tests/unit/atif_tests.rs +++ b/crates/core/tests/unit/atif_tests.rs @@ -11,9 +11,48 @@ use crate::api::event::{ use crate::api::llm::LlmAttributes; use crate::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; use crate::api::tool::ToolAttributes; +use crate::codec::pricing::pricing_test_mutex; +use crate::codec::response::{ + PricingCatalog, PricingResolver, reset_active_pricing_resolver, set_active_pricing_resolver, +}; use serde_json::json; use std::collections::HashSet; +struct ResetPricingResolverGuard; + +impl Drop for ResetPricingResolverGuard { + fn drop(&mut self) { + let _ = reset_active_pricing_resolver(); + } +} + +fn install_test_pricing(model_id: &str) { + let catalog = PricingCatalog::from_json_str( + &json!({ + "version": 1, + "entries": [ + { + "provider": "test", + "model_id": model_id, + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ] + }) + .to_string(), + ) + .unwrap(); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); +} + #[derive(Debug, Clone, Copy)] enum EventType { Start, @@ -678,17 +717,21 @@ fn test_exporter_llm_lifecycle() { #[test] fn test_extract_metrics_supports_provider_usage_payloads() { - let openai_metrics = extract_metrics(&json!({ - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - "cost_usd": 0.001, - "prompt_tokens_details": { - "cached_tokens": 4 + let openai_metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30, + "cost_usd": 0.001, + "prompt_tokens_details": { + "cached_tokens": 4 + } } - } - })) + }), + None, + None, + ) .unwrap(); assert_eq!(openai_metrics.prompt_tokens, Some(10)); assert_eq!(openai_metrics.completion_tokens, Some(20)); @@ -699,17 +742,21 @@ fn test_extract_metrics_supports_provider_usage_payloads() { ); assert_eq!(openai_metrics.cost_usd, Some(0.001)); - let responses_metrics = extract_metrics(&json!({ - "usage": { - "input_tokens": 75, - "output_tokens": 20, - "total_tokens": 95, - "input_tokens_details": { - "cached_tokens": 10 - }, - "cost_usd": 0.005 - } - })) + let responses_metrics = extract_metrics( + &json!({ + "usage": { + "input_tokens": 75, + "output_tokens": 20, + "total_tokens": 95, + "input_tokens_details": { + "cached_tokens": 10 + }, + "cost_usd": 0.005 + } + }), + None, + None, + ) .unwrap(); assert_eq!(responses_metrics.prompt_tokens, Some(75)); assert_eq!(responses_metrics.completion_tokens, Some(20)); @@ -720,20 +767,216 @@ fn test_extract_metrics_supports_provider_usage_payloads() { ); assert_eq!(responses_metrics.cost_usd, Some(0.005)); - let anthropic_metrics = extract_metrics(&json!({ - "usage": { - "input_tokens": 11, - "output_tokens": 22, - "cache_read_input_tokens": 3, - "cache_creation_input_tokens": 5, - "cost": { "total": 0.0042 } - } - })) + let anthropic_metrics = extract_metrics( + &json!({ + "usage": { + "input_tokens": 11, + "output_tokens": 22, + "cache_read_input_tokens": 3, + "cache_creation_input_tokens": 5, + "cost": { "total": 0.0042, "currency": "USD" } + } + }), + None, + None, + ) .unwrap(); assert_eq!(anthropic_metrics.prompt_tokens, Some(11)); assert_eq!(anthropic_metrics.completion_tokens, Some(22)); assert_eq!(anthropic_metrics.cached_tokens, Some(8)); assert_eq!(anthropic_metrics.cost_usd, Some(0.0042)); + + let non_usd_metrics = extract_metrics( + &json!({ + "usage": { + "input_tokens": 11, + "output_tokens": 22, + "cost": { "total": 0.0042, "currency": "EUR" } + } + }), + None, + None, + ) + .unwrap(); + assert_eq!(non_usd_metrics.cost_usd, None); +} + +#[test] +fn test_reported_cost_object_blocks_model_pricing_estimation() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + + let non_usd_metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "total_tokens": 1500, + "cost": { + "total": 0.42, + "currency": "EUR" + } + } + }), + Some("test"), + Some("priced-model"), + ) + .unwrap(); + + assert_eq!(non_usd_metrics.cost_usd, None); + + let missing_total_metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "total_tokens": 1500, + "cost": { + "currency": "USD" + } + } + }), + Some("test"), + Some("priced-model"), + ) + .unwrap(); + + assert_eq!(missing_total_metrics.cost_usd, None); + + let legacy_missing_currency_metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "total_tokens": 1500, + "cost": { + "total": 0.42 + } + } + }), + Some("test"), + Some("priced-model"), + ) + .unwrap(); + + assert_eq!(legacy_missing_currency_metrics.cost_usd, Some(0.42)); + + let component_metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "total_tokens": 1500, + "cost": { + "currency": "usd", + "input": 0.25, + "output": 0.5, + "cache_read": 0.125 + } + } + }), + Some("test"), + Some("priced-model"), + ) + .unwrap(); + + assert_eq!(component_metrics.cost_usd, Some(0.875)); +} + +#[test] +fn test_exporter_derives_llm_cost_from_model_pricing() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + + let exporter = AtifExporter::new("session-1".to_string(), make_agent_info()); + let llm_uuid = Uuid::now_v7(); + + let end = event_builder(llm_uuid, EventType::End) + .name("gpt-4o-mini") + .scope_type(ScopeType::Llm) + .output(json!({ + "content": "priced response", + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "total_tokens": 1500, + "prompt_tokens_details": {"cached_tokens": 200} + } + })) + .model_name("priced-model") + .build(); + + { + let mut state = exporter.state.lock().unwrap(); + state.events.push(end); + } + + let trajectory = exporter.export().unwrap(); + let metrics = trajectory.steps[0].metrics.as_ref().unwrap(); + assert_eq!(metrics.cost_usd, Some(0.000_435)); + assert_eq!( + trajectory.final_metrics.as_ref().unwrap().total_cost_usd, + Some(0.000_435) + ); +} + +#[test] +fn test_exporter_uses_normalized_usage_cost_before_model_pricing() { + let exporter = AtifExporter::new("session-1".to_string(), make_agent_info()); + let llm_uuid = Uuid::now_v7(); + + let end = event_builder(llm_uuid, EventType::End) + .name("unknown-model") + .scope_type(ScopeType::Llm) + .output(json!({ + "content": "priced response", + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500, + "cost": { + "total": 0.42, + "source": "provider_reported", + "pricing_model": "external-model", + "pricing_as_of": "2026-06-04" + } + } + })) + .model_name("unknown-model") + .build(); + + { + let mut state = exporter.state.lock().unwrap(); + state.events.push(end); + } + + let trajectory = exporter.export().unwrap(); + let metrics = trajectory.steps[0].metrics.as_ref().unwrap(); + assert_eq!(metrics.cost_usd, Some(0.42)); + assert_eq!( + trajectory.final_metrics.as_ref().unwrap().total_cost_usd, + Some(0.42) + ); +} + +#[test] +fn test_exporter_omits_cost_for_unknown_model_pricing() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + reset_active_pricing_resolver().unwrap(); + let metrics = extract_metrics( + &json!({ + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 500 + } + }), + None, + Some("unknown-model"), + ) + .unwrap(); + + assert_eq!(metrics.cost_usd, None); } #[test] @@ -3935,13 +4178,14 @@ fn test_metrics_extra_captures_unknown_token_usage_keys() { let metrics = trajectory.steps[0].metrics.as_ref().unwrap(); assert_eq!(metrics.prompt_tokens, Some(20)); assert_eq!(metrics.completion_tokens, Some(10)); + assert_eq!(metrics.cached_tokens, Some(5)); // Unknown keys land in extra let extra = metrics.extra.as_ref().unwrap(); assert_eq!(extra["reasoning_tokens"], json!(150)); - assert_eq!(extra["cache_creation_input_tokens"], json!(5)); // Known keys do not appear in extra assert!(extra.get("prompt_tokens").is_none()); assert!(extra.get("completion_tokens").is_none()); + assert!(extra.get("cache_creation_input_tokens").is_none()); } #[test] diff --git a/crates/core/tests/unit/codec/openai_chat_tests.rs b/crates/core/tests/unit/codec/openai_chat_tests.rs index f54562af..eb418a03 100644 --- a/crates/core/tests/unit/codec/openai_chat_tests.rs +++ b/crates/core/tests/unit/codec/openai_chat_tests.rs @@ -7,7 +7,7 @@ use super::*; use serde_json::json; use super::super::request::{ContentPart, MessageContent, OpenAiImageUrl}; -use super::super::response::{ApiSpecificResponse, FinishReason}; +use super::super::response::{ApiSpecificResponse, CostSource, FinishReason}; // ------------------------------------------------------------------- // Helpers @@ -115,6 +115,38 @@ fn test_decode_response_cached_tokens() { assert_eq!(usage.cache_read_tokens, Some(42)); } +#[test] +fn test_decode_response_provider_reported_cost() { + let codec = OpenAIChatCodec; + let response = json!({ + "id": "chatcmpl_cost", + "object": "chat.completion", + "model": "gpt-4o-mini", + "choices": [{ + "message": {"role": "assistant", "content": "ok"}, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "cost": { + "total": 0.0123, + "input": 0.004, + "output": 0.0083 + } + } + }); + + let resp = codec.decode_response(&response).unwrap(); + let cost = resp.usage.unwrap().cost.unwrap(); + + assert_eq!(cost.total, Some(0.0123)); + assert_eq!(cost.input, Some(0.004)); + assert_eq!(cost.output, Some(0.0083)); + assert_eq!(cost.source, CostSource::ProviderReported); +} + #[test] fn test_decode_response_finish_reason_stop() { let codec = OpenAIChatCodec; diff --git a/crates/core/tests/unit/codec/response_tests.rs b/crates/core/tests/unit/codec/response_tests.rs index 945f0fc1..d184a9cb 100644 --- a/crates/core/tests/unit/codec/response_tests.rs +++ b/crates/core/tests/unit/codec/response_tests.rs @@ -4,11 +4,33 @@ //! Unit tests for response in the NeMo Relay core crate. use super::*; -use serde_json::json; +use serde_json::{Value, json}; +use std::fs; +use std::time::{SystemTime, UNIX_EPOCH}; use super::super::request::ContentPart; use super::super::traits::LlmResponseCodec; +use crate::codec::pricing::pricing_test_mutex; use crate::error::FlowError; +use crate::plugin::{ + PluginComponentSpec, PluginConfig, clear_plugin_configuration, initialize_plugins, +}; + +struct ResetPricingResolverGuard; + +impl Drop for ResetPricingResolverGuard { + fn drop(&mut self) { + let _ = reset_active_pricing_resolver(); + } +} + +struct ClearPluginConfigurationGuard; + +impl Drop for ClearPluginConfigurationGuard { + fn drop(&mut self) { + let _ = clear_plugin_configuration(); + } +} /// Helper: build a fully-populated AnnotatedLlmResponse. fn full_response() -> AnnotatedLlmResponse { @@ -28,6 +50,7 @@ fn full_response() -> AnnotatedLlmResponse { total_tokens: Some(30), cache_read_tokens: Some(5), cache_write_tokens: Some(3), + cost: None, }), api_specific: Some(ApiSpecificResponse::OpenAIChat { logprobs: None, @@ -52,6 +75,72 @@ fn minimal_response() -> AnnotatedLlmResponse { } } +fn pricing_catalog(entries: Value) -> PricingCatalog { + PricingCatalog::from_json_str(&json!({ "version": 1, "entries": entries }).to_string()).unwrap() +} + +fn pricing_catalog_error(entries: Value) -> PricingCatalogError { + PricingCatalog::from_json_str(&json!({ "version": 1, "entries": entries }).to_string()) + .unwrap_err() +} + +fn flat_pricing_entry( + provider: &str, + model_id: &str, + input_per_million: f64, + output_per_million: f64, +) -> Value { + json!({ + "provider": provider, + "model_id": model_id, + "pricing_as_of": "2026-06-04", + "pricing_source": format!("https://example.test/{provider}"), + "rates": { + "input_per_million": input_per_million, + "output_per_million": output_per_million + }, + "prompt_cache": { + "read_accounting": "separate" + } + }) +} + +fn threshold_pricing_catalog(read_accounting: &str) -> PricingCatalog { + pricing_catalog(json!([ + { + "provider": "threshold-ai", + "model_id": "threshold-model", + "pricing_as_of": "2026-06-05", + "pricing_source": "https://example.test/pricing", + "rate_schedule": { + "type": "prompt_token_threshold", + "applies_to": "full_request", + "tiers": [ + { + "max_prompt_tokens": 200000, + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0, + "cache_read_per_million": 0.1 + } + }, + { + "min_prompt_tokens": 200001, + "rates": { + "input_per_million": 10.0, + "output_per_million": 20.0, + "cache_read_per_million": 1.0 + } + } + ] + }, + "prompt_cache": { + "read_accounting": read_accounting + } + } + ])) +} + // ------------------------------------------------------------------- // AnnotatedLlmResponse serialization // ------------------------------------------------------------------- @@ -97,12 +186,816 @@ fn test_usage_all_populated_round_trip() { total_tokens: Some(150), cache_read_tokens: Some(20), cache_write_tokens: Some(10), + cost: None, + }; + let json_val = serde_json::to_value(&usage).unwrap(); + let deserialized: Usage = serde_json::from_value(json_val).unwrap(); + assert_eq!(usage, deserialized); +} + +#[test] +fn test_default_pricing_resolver_has_no_model_prices() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + reset_active_pricing_resolver().unwrap(); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }; + + assert_eq!(estimate_cost("configured-model", &usage), None); +} + +#[test] +fn test_configured_model_pricing_estimates_total_cost() { + let catalog = pricing_catalog(json!([ + { + "provider": "configured", + "model_id": "configured-model", + "pricing_as_of": "2026-06-04", + "pricing_source": "file:///tmp/pricing.json", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }; + + let cost = estimate_cost_with_catalog(&catalog, "configured-model", &usage).unwrap(); + + assert_eq!(cost.total, Some(0.000_435)); + assert_eq!(cost.currency, "USD"); + assert_eq!(cost.input, Some(0.000_12)); + assert_eq!(cost.output, Some(0.000_3)); + assert_eq!(cost.cache_read, Some(0.000_015)); + assert_eq!(cost.cache_write, None); + assert_eq!(cost.source, CostSource::ModelPricing); + assert_eq!(cost.pricing_provider.as_deref(), Some("configured")); + assert_eq!(cost.pricing_model.as_deref(), Some("configured-model")); + assert_eq!(cost.pricing_as_of.as_deref(), Some("2026-06-04")); + assert_eq!( + cost.pricing_source.as_deref(), + Some("file:///tmp/pricing.json") + ); +} + +#[test] +fn test_pricing_catalog_uses_data_driven_alias_entries() { + let catalog = pricing_catalog(json!([ + { + "provider": "configured", + "model_id": "configured-model", + "aliases": ["configured-model-2026-06-04"], + "pricing_as_of": "2026-06-04", + "pricing_source": "file:///tmp/pricing.json", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let pricing = catalog + .pricing_for_model("CONFIGURED-MODEL-2026-06-04") + .expect("alias should resolve from configured catalog"); + + assert_eq!(pricing.provider, "configured"); + assert_eq!(pricing.model_id, "configured-model"); + assert_eq!(pricing.currency, "USD"); + assert_eq!(pricing.unit, PricingUnit::PerToken); + assert_eq!(pricing.pricing_as_of, "2026-06-04"); + let rates = pricing.rates.as_ref().unwrap(); + assert_eq!(rates.input_per_million, 0.15); + assert_eq!(rates.output_per_million, 0.60); + assert_eq!( + pricing.prompt_cache.read_accounting, + CacheReadAccounting::IncludedInPromptTokens + ); +} + +#[test] +fn test_pricing_catalog_preserves_currency_and_unit() { + let catalog = pricing_catalog(json!([ + { + "provider": "enterprise", + "model_id": "regional-model", + "currency": "EUR", + "unit": "per_token", + "pricing_as_of": "2026-06-04", + "pricing_source": "postgres://pricing/model_prices", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0 + }, + "prompt_cache": { + "read_accounting": "separate" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let pricing = catalog.pricing_for_model("regional-model").unwrap(); + let cost = estimate_cost_with_catalog(&catalog, "regional-model", &usage).unwrap(); + + assert_eq!(pricing.currency, "EUR"); + assert_eq!(pricing.unit, PricingUnit::PerToken); + assert_eq!(cost.currency, "EUR"); + assert_eq!(cost.total, Some(0.002)); +} + +#[test] +fn test_non_token_pricing_units_are_representable_but_not_estimated() { + let catalog = pricing_catalog(json!([ + { + "provider": "self-hosted", + "model_id": "nemotron-owned", + "unit": "gpu_hour", + "pricing_as_of": "2026-06-04", + "pricing_source": "internal-owned-fleet-snapshot", + "prompt_cache": { + "read_accounting": "separate" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let pricing = catalog.pricing_for_model("nemotron-owned").unwrap(); + + assert_eq!(pricing.unit, PricingUnit::GpuHour); + assert_eq!(pricing.rates, None); + assert_eq!( + estimate_cost_with_catalog(&catalog, "nemotron-owned", &usage), + None + ); +} + +#[test] +fn test_per_token_pricing_requires_token_rates() { + let err = pricing_catalog_error(json!([ + { + "provider": "broken", + "model_id": "missing-rates", + "unit": "per_token", + "pricing_as_of": "2026-06-04", + "pricing_source": "test", + "prompt_cache": { + "read_accounting": "separate" + } + } + ])); + + assert!(err.to_string().contains("empty rates or rate_schedule")); +} + +#[test] +fn test_pricing_catalog_normalizes_routed_model_names() { + let catalog = pricing_catalog(json!([flat_pricing_entry( + "openai", + "gpt-4o-mini", + 0.15, + 0.6 + )])); + + let azure_pricing = catalog + .pricing_for_model("azure/openai/gpt-4o-mini") + .expect("routed provider/model name should resolve"); + let openai_pricing = catalog + .pricing_for_model("openai/openai/gpt-4o-mini") + .expect("routed provider/model-owner name should resolve"); + + assert_eq!(azure_pricing.provider, "openai"); + assert_eq!(azure_pricing.model_id, "gpt-4o-mini"); + assert_eq!(openai_pricing.provider, "openai"); + assert_eq!(openai_pricing.model_id, "gpt-4o-mini"); +} + +#[test] +fn test_pricing_resolver_prefers_exact_routed_model_before_suffix_fallback() { + let catalog = pricing_catalog(json!([ + { + "provider": "openai", + "model_id": "gpt-4o-mini", + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/openai", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.6 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + }, + { + "provider": "azure-openai", + "model_id": "azure/openai/gpt-4o-mini", + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/azure-openai", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let resolver = PricingResolver::from_catalogs(vec![catalog]); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let cost = resolver + .estimate_cost("azure/openai/gpt-4o-mini", &usage) + .unwrap(); + + assert_eq!(cost.total, Some(0.002)); + assert_eq!(cost.pricing_provider.as_deref(), Some("azure-openai")); + assert_eq!( + cost.pricing_model.as_deref(), + Some("azure/openai/gpt-4o-mini") + ); +} + +#[test] +fn test_pricing_catalog_allows_same_model_id_for_distinct_providers() { + let catalog = pricing_catalog(json!([ + { + "provider": "openai", + "model_id": "gpt-4o-mini", + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/openai", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.6 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + }, + { + "provider": "azure/openai", + "model_id": "gpt-4o-mini", + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/azure-openai", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let openai = estimate_cost_with_provider(&catalog, Some("openai"), "gpt-4o-mini", &usage) + .expect("openai provider price should resolve"); + let azure = estimate_cost_with_provider(&catalog, Some("azure/openai"), "gpt-4o-mini", &usage) + .expect("azure/openai provider price should resolve"); + + assert_eq!(openai.total, Some(0.000_45)); + assert_eq!(openai.pricing_provider.as_deref(), Some("openai")); + assert_eq!(azure.total, Some(0.002)); + assert_eq!(azure.pricing_provider.as_deref(), Some("azure/openai")); +} + +#[test] +fn test_attach_estimated_cost_uses_event_provider() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + let catalog = pricing_catalog(json!([ + flat_pricing_entry("openai", "same-model", 1.0, 2.0), + flat_pricing_entry("azure/openai", "same-model", 10.0, 20.0) + ])); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); + let _reset_guard = ResetPricingResolverGuard; + + let mut response = AnnotatedLlmResponse { + model: Some("same-model".into()), + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }), + ..minimal_response() + }; + + attach_estimated_cost_for_provider(&mut response, Some("azure/openai")); + + let cost = response.usage.unwrap().cost.unwrap(); + assert_eq!(cost.total, Some(0.02)); + assert_eq!(cost.pricing_provider.as_deref(), Some("azure/openai")); +} + +#[test] +fn test_custom_pricing_catalog_supports_future_models_without_code_changes() { + let catalog = pricing_catalog(json!([ + { + "provider": "future-ai", + "model_id": "future-model", + "aliases": ["future-model-latest"], + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/pricing", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0, + "cache_read_per_million": 0.25, + "cache_write_per_million": 1.5 + }, + "prompt_cache": { + "read_accounting": "separate" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(2_000), + cache_read_tokens: Some(3_000), + cache_write_tokens: Some(4_000), + ..Usage::default() + }; + + let cost = estimate_cost_with_catalog(&catalog, "future-model-latest", &usage).unwrap(); + + assert_eq!(cost.total, Some(0.011_75)); + assert_eq!(cost.input, Some(0.001)); + assert_eq!(cost.output, Some(0.004)); + assert_eq!(cost.cache_read, Some(0.000_75)); + assert_eq!(cost.cache_write, Some(0.006)); + assert_eq!(cost.pricing_provider.as_deref(), Some("future-ai")); + assert_eq!(cost.pricing_model.as_deref(), Some("future-model")); +} + +#[test] +fn test_prompt_threshold_pricing_applies_selected_tier_to_full_request() { + let catalog = threshold_pricing_catalog("included_in_prompt_tokens"); + let usage = Usage { + prompt_tokens: Some(200_001), + completion_tokens: Some(1_000), + cache_read_tokens: Some(1_000), + ..Usage::default() + }; + + let cost = estimate_cost_with_catalog(&catalog, "threshold-model", &usage).unwrap(); + + assert_eq!(cost.input, Some(1.990_01)); + assert_eq!(cost.output, Some(0.02)); + assert_eq!(cost.cache_read, Some(0.001)); + assert_eq!(cost.total, Some(2.011_01)); +} + +#[test] +fn test_prompt_threshold_pricing_uses_lower_tier_at_boundary() { + let catalog = threshold_pricing_catalog("separate"); + let usage = Usage { + prompt_tokens: Some(200_000), + completion_tokens: Some(1_000), + ..Usage::default() + }; + + let cost = estimate_cost_with_catalog(&catalog, "threshold-model", &usage).unwrap(); + + assert_eq!(cost.input, Some(0.2)); + assert_eq!(cost.output, Some(0.002)); + assert_eq!(cost.total, Some(0.202)); +} + +#[test] +fn test_prompt_threshold_pricing_requires_prompt_tokens() { + let catalog = threshold_pricing_catalog("separate"); + let usage = Usage { + completion_tokens: Some(1_000), + ..Usage::default() + }; + + assert!(estimate_cost_with_catalog(&catalog, "threshold-model", &usage).is_none()); +} + +#[test] +fn test_pricing_resolver_uses_first_matching_source() { + let override_catalog = pricing_catalog(json!([ + { + "provider": "local-override", + "model_id": "gpt-4o-mini", + "pricing_as_of": "2026-06-04", + "pricing_source": "file:///tmp/local-pricing.json", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0, + "cache_read_per_million": 0.5 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let resolver = PricingResolver::from_catalogs(vec![override_catalog]); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cache_read_tokens: Some(200), + ..Usage::default() + }; + + let cost = resolver.estimate_cost("gpt-4o-mini", &usage).unwrap(); + + assert_eq!(cost.total, Some(0.001_9)); + assert_eq!(cost.pricing_provider.as_deref(), Some("local-override")); + assert!(resolver.estimate_cost("missing-model", &usage).is_none()); +} + +#[test] +fn test_pricing_resolver_loads_inline_and_file_sources_in_order() { + let temp = std::env::temp_dir().join(format!( + "nemo-relay-pricing-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + fs::create_dir_all(&temp).unwrap(); + let file_path = temp.join("pricing.json"); + fs::write( + &file_path, + json!({ + "version": 1, + "entries": [flat_pricing_entry("global-file", "file-model", 4.0, 8.0)] + }) + .to_string(), + ) + .unwrap(); + let inline_catalog = pricing_catalog(json!([flat_pricing_entry( + "project-inline", + "inline-model", + 1.0, + 2.0 + )])); + let config = PricingConfig { + sources: vec![ + PricingSourceConfig::Inline { + catalog: inline_catalog, + }, + PricingSourceConfig::File { path: file_path }, + ], + }; + let resolver = PricingResolver::from_config(&config).unwrap(); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let inline = resolver.estimate_cost("inline-model", &usage).unwrap(); + let file = resolver.estimate_cost("file-model", &usage).unwrap(); + + assert_eq!(inline.total, Some(0.002)); + assert_eq!(inline.pricing_provider.as_deref(), Some("project-inline")); + assert_eq!(file.total, Some(0.008)); + assert_eq!(file.pricing_provider.as_deref(), Some("global-file")); + assert!(resolver.estimate_cost("gpt-4o-mini", &usage).is_none()); + fs::remove_dir_all(temp).unwrap(); +} + +#[test] +fn test_pricing_resolver_validates_inline_catalogs() { + let config = PricingConfig { + sources: vec![PricingSourceConfig::Inline { + catalog: PricingCatalog { + version: 2, + entries: vec![], + }, + }], + }; + + let err = PricingResolver::from_config(&config).unwrap_err(); + + assert!( + err.to_string() + .contains("unsupported pricing catalog version 2") + ); +} + +#[test] +fn test_pricing_resolver_accepts_custom_database_backed_sources() { + struct TestDatabasePricingSource { + catalog: PricingCatalog, + } + + impl PricingSource for TestDatabasePricingSource { + fn source_name(&self) -> &str { + "test-db" + } + + fn load_catalog(&self) -> Result, PricingCatalogError> { + Ok(Some(self.catalog.clone())) + } + } + + let catalog = pricing_catalog(json!([flat_pricing_entry( + "database", "db-model", 10.0, 20.0 + )])); + let resolver = + PricingResolver::from_sources(vec![Box::new(TestDatabasePricingSource { catalog })]) + .unwrap(); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() }; + + let cost = resolver.estimate_cost("db-model", &usage).unwrap(); + + assert_eq!(cost.total, Some(0.02)); + assert_eq!(cost.pricing_provider.as_deref(), Some("database")); +} + +#[test] +fn test_pricing_resolver_validates_custom_source_catalogs() { + struct InvalidDatabasePricingSource; + + impl PricingSource for InvalidDatabasePricingSource { + fn source_name(&self) -> &str { + "invalid-test-db" + } + + fn load_catalog(&self) -> Result, PricingCatalogError> { + Ok(Some(PricingCatalog { + version: 2, + entries: vec![], + })) + } + } + + let err = + PricingResolver::from_sources(vec![Box::new(InvalidDatabasePricingSource)]).unwrap_err(); + + assert!( + err.to_string() + .contains("unsupported pricing catalog version 2") + ); +} + +#[test] +fn test_pricing_plugin_configures_process_resolver_and_clears_to_default() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + let mut component = PluginComponentSpec::new("pricing"); + component.config = serde_json::from_value(json!({ + "sources": [ + { + "type": "inline", + "catalog": { + "version": 1, + "entries": [ + { + "provider": "plugin-inline", + "model_id": "plugin-model", + "pricing_as_of": "2026-06-04", + "pricing_source": "plugins.toml", + "rates": { + "input_per_million": 1.0, + "output_per_million": 2.0 + }, + "prompt_cache": { + "read_accounting": "separate" + } + } + ] + } + } + ] + })) + .unwrap(); + let mut config = PluginConfig::default(); + config.components.push(component); + + tokio::runtime::Runtime::new() + .unwrap() + .block_on(async { initialize_plugins(config).await.unwrap() }); + let _clear_guard = ClearPluginConfigurationGuard; + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + let configured = estimate_cost("plugin-model", &usage).unwrap(); + assert_eq!(configured.total, Some(0.002)); + assert!(estimate_cost("gpt-4o-mini", &usage).is_none()); + + clear_plugin_configuration().unwrap(); + + assert!(estimate_cost("plugin-model", &usage).is_none()); + assert!(estimate_cost("gpt-4o-mini", &usage).is_none()); +} + +#[test] +fn test_pricing_catalog_rejects_duplicate_model_aliases() { + let err = pricing_catalog_error(json!([ + flat_pricing_entry("a", "same-model", 1.0, 1.0), + { + "provider": "a", + "model_id": "other-model", + "aliases": ["SAME-MODEL"], + "pricing_as_of": "2026-06-04", + "pricing_source": "https://example.test/b", + "rates": { + "input_per_million": 1.0, + "output_per_million": 1.0 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + + assert!( + err.to_string() + .contains("duplicate pricing model alias 'a/same-model'") + ); +} + +#[test] +fn test_pricing_catalog_rejects_unsupported_schema_version() { + let err = PricingCatalog::from_json_str( + r#"{ + "version": 2, + "entries": [] + }"#, + ) + .unwrap_err(); + + assert!( + err.to_string() + .contains("unsupported pricing catalog version 2") + ); +} + +#[test] +fn test_missing_token_pricing_returns_none_without_fabricating_zero_cost() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + reset_active_pricing_resolver().unwrap(); + assert_eq!(estimate_cost("gpt-4o-mini", &Usage::default()), None); + assert_eq!(estimate_cost("gpt-4o-mini", &Usage::default()), None); +} + +#[test] +fn test_provider_reported_cost_sums_components_and_defaults_currency() { + let cost = provider_reported_cost( + None, + Some(RawUsageCost { + input: Some(0.12), + output: Some(0.30), + cache_read: Some(0.01), + cache_write: Some(0.02), + ..RawUsageCost::default() + }), + ) + .expect("component-only provider cost should be retained"); + + assert_eq!(cost.total, Some(0.45)); + assert_eq!(cost.currency, "USD"); + assert_eq!(cost.source, CostSource::ProviderReported); +} + +#[test] +fn test_provider_reported_cost_keeps_top_level_provider_usd_currency() { + let cost = provider_reported_cost( + Some(0.42), + Some(RawUsageCost { + currency: Some("EUR".to_string()), + ..RawUsageCost::default() + }), + ) + .expect("top-level provider USD cost should be retained"); + + assert_eq!(cost.total, Some(0.42)); + assert_eq!(cost.currency, "USD"); + assert_eq!(cost.total_for_currency("USD"), Some(0.42)); +} + +#[test] +fn test_usage_cost_round_trip_preserves_model_pricing_codec_compatibility() { + let catalog = pricing_catalog(json!([ + { + "provider": "configured", + "model_id": "configured-model", + "pricing_as_of": "2026-06-04", + "pricing_source": "file:///tmp/pricing.json", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ])); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: estimate_cost_with_catalog( + &catalog, + "configured-model", + &Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }, + ), + }; + let json_val = serde_json::to_value(&usage).unwrap(); + + assert_eq!(json_val["cost"]["total"], json!(0.000_435)); + assert_eq!(json_val["cost"]["source"], json!("model_pricing")); + assert_eq!(json_val["cost"]["pricing_as_of"], json!("2026-06-04")); let deserialized: Usage = serde_json::from_value(json_val).unwrap(); assert_eq!(usage, deserialized); } +#[test] +fn test_usage_cost_round_trip_preserves_provider_reported_codec_compatibility() { + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: Some(CostEstimate { + total: Some(0.42), + currency: "USD".into(), + input: Some(0.12), + output: Some(0.30), + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: None, + pricing_model: None, + pricing_as_of: None, + pricing_source: None, + }), + }; + + let json_val = serde_json::to_value(&usage).unwrap(); + + assert_eq!(json_val["cost"]["total"], json!(0.42)); + assert_eq!(json_val["cost"]["source"], json!("provider_reported")); + let deserialized: Usage = serde_json::from_value(json_val).unwrap(); + assert_eq!(usage, deserialized); +} + +#[test] +fn test_unknown_model_pricing_returns_none_without_blocking_usage() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + reset_active_pricing_resolver().unwrap(); + let usage = Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }; + + assert_eq!(estimate_cost("unknown-model", &usage), None); + assert_eq!(usage.prompt_tokens, Some(1_000)); +} + // ------------------------------------------------------------------- // FinishReason serialization // ------------------------------------------------------------------- diff --git a/crates/core/tests/unit/observability/openinference_tests.rs b/crates/core/tests/unit/observability/openinference_tests.rs index 4488b4b9..0f23e687 100644 --- a/crates/core/tests/unit/observability/openinference_tests.rs +++ b/crates/core/tests/unit/observability/openinference_tests.rs @@ -13,11 +13,15 @@ use crate::api::runtime::global_context; use crate::api::scope::ScopeType; use crate::api::scope::{event, pop_scope, push_scope}; use crate::api::tool::ToolAttributes; +use crate::codec::pricing::pricing_test_mutex; use crate::codec::request::{ AnnotatedLlmRequest, FunctionDefinition, GenerationParams, Message, MessageContent, ToolDefinition, }; -use crate::codec::response::{AnnotatedLlmResponse, FinishReason, ResponseToolCall, Usage}; +use crate::codec::response::{ + AnnotatedLlmResponse, CostEstimate, CostSource, FinishReason, PricingCatalog, PricingResolver, + ResponseToolCall, Usage, reset_active_pricing_resolver, set_active_pricing_resolver, +}; use crate::json::Json; use crate::observability::atif::{AtifAgentInfo, AtifExporter, AtifStepExtra}; use opentelemetry_sdk::trace::InMemorySpanExporterBuilder; @@ -29,6 +33,14 @@ use std::sync::mpsc; use std::thread; use uuid::Uuid; +struct ResetPricingResolverGuard; + +impl Drop for ResetPricingResolverGuard { + fn drop(&mut self) { + let _ = reset_active_pricing_resolver(); + } +} + fn reset_global() { crate::shared_runtime::reset_runtime_owner_for_tests(); let context = global_context(); @@ -118,6 +130,33 @@ fn empty_annotated_response() -> AnnotatedLlmResponse { } } +fn install_test_pricing(model_id: &str) { + let catalog = PricingCatalog::from_json_str( + &json!({ + "version": 1, + "entries": [ + { + "provider": "test", + "model_id": model_id, + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ] + }) + .to_string(), + ) + .unwrap(); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); +} + fn sample_openinference_annotated_request() -> AnnotatedLlmRequest { AnnotatedLlmRequest { messages: vec![ @@ -2107,6 +2146,7 @@ fn llm_end_with_usage_emits_token_count_attributes() { total_tokens: Some(150), cache_read_tokens: Some(25), cache_write_tokens: Some(10), + cost: None, }), api_specific: None, extra: serde_json::Map::new(), @@ -2143,6 +2183,282 @@ fn llm_end_with_usage_emits_token_count_attributes() { assert!(!attributes.contains_key("llm.output_messages.0.message.role")); } +#[test] +fn llm_end_with_known_model_usage_emits_derived_cost_attribute() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "chat", ScopeType::Llm, None)); + processor.process(&make_scope_event_with_profile( + ScopeCategory::End, + uuid, + None, + "chat", + ScopeType::Llm, + Some(json!({"message": "hello"})), + Some( + CategoryProfile::builder() + .model_name("priced-model") + .annotated_response(Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }), + ..empty_annotated_response() + })) + .build(), + ), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert_eq!( + attributes.get("llm.cost.total"), + Some(&"0.000435".to_string()) + ); +} + +#[test] +fn llm_end_with_manual_usage_and_output_model_emits_derived_cost_attribute() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "chat", ScopeType::Llm, None)); + processor.process(&make_end_event( + uuid, + None, + "chat", + ScopeType::Llm, + Some(json!({ + "model": "priced-model", + "usage": { + "prompt_tokens": 1_000, + "completion_tokens": 500, + "total_tokens": 1_500, + "prompt_tokens_details": {"cached_tokens": 200} + } + })), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert_eq!( + attributes.get("llm.cost.total"), + Some(&"0.000435".to_string()) + ); +} + +#[test] +fn llm_end_with_normalized_usage_cost_emits_cost_attribute() { + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "chat", ScopeType::Llm, None)); + processor.process(&make_scope_event_with_profile( + ScopeCategory::End, + uuid, + None, + "chat", + ScopeType::Llm, + Some(json!({"message": "hello"})), + Some( + CategoryProfile::builder() + .model_name("unknown-model") + .annotated_response(Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cost: Some(CostEstimate { + total: Some(0.42), + currency: "USD".into(), + input: None, + output: None, + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("external".to_string()), + pricing_model: Some("external-model".to_string()), + pricing_as_of: Some("2026-06-04".to_string()), + pricing_source: None, + }), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert_eq!(attributes.get("llm.cost.total"), Some(&"0.42".to_string())); +} + +#[test] +fn llm_end_with_component_only_usd_usage_cost_emits_cost_attribute() { + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "chat", ScopeType::Llm, None)); + processor.process(&make_scope_event_with_profile( + ScopeCategory::End, + uuid, + None, + "chat", + ScopeType::Llm, + Some(json!({"message": "hello"})), + Some( + CategoryProfile::builder() + .model_name("unknown-model") + .annotated_response(Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cost: Some(CostEstimate { + total: None, + currency: "usd".into(), + input: Some(0.25), + output: Some(0.5), + cache_read: Some(0.125), + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("external".to_string()), + pricing_model: Some("external-model".to_string()), + pricing_as_of: Some("2026-06-04".to_string()), + pricing_source: None, + }), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert_eq!(attributes.get("llm.cost.total"), Some(&"0.875".to_string())); +} + +#[test] +fn llm_end_with_non_usd_normalized_usage_cost_blocks_model_pricing_estimate() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "test", ScopeType::Llm, None)); + processor.process(&make_scope_event_with_profile( + ScopeCategory::End, + uuid, + None, + "test", + ScopeType::Llm, + Some(json!({"message": "hello"})), + Some( + CategoryProfile::builder() + .model_name("priced-model") + .annotated_response(Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cost: Some(CostEstimate { + total: Some(0.42), + currency: "EUR".into(), + input: None, + output: None, + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("external".to_string()), + pricing_model: Some("external-model".to_string()), + pricing_as_of: Some("2026-06-04".to_string()), + pricing_source: None, + }), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert!(!attributes.contains_key("llm.cost.total")); +} + +#[test] +fn llm_end_with_unknown_model_usage_omits_derived_cost_attribute() { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + reset_active_pricing_resolver().unwrap(); + let _reset_guard = ResetPricingResolverGuard; + let (provider, exporter) = make_provider(); + let mut processor = + OpenInferenceEventProcessor::new(provider.clone(), "test-scope".to_string()); + let uuid = Uuid::now_v7(); + + processor.process(&make_start_event(uuid, None, "chat", ScopeType::Llm, None)); + processor.process(&make_scope_event_with_profile( + ScopeCategory::End, + uuid, + None, + "chat", + ScopeType::Llm, + Some(json!({"message": "hello"})), + Some( + CategoryProfile::builder() + .model_name("unknown-model") + .annotated_response(Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + )); + + processor.force_flush().unwrap(); + + let spans = exporter.get_finished_spans().unwrap(); + let attributes = attr_map(&spans[0].attributes); + assert!(!attributes.contains_key("llm.cost.total")); +} + #[test] fn llm_end_with_manual_usage_payload_emits_token_count_attributes() { let (provider, exporter) = make_provider(); @@ -2692,6 +3008,7 @@ fn llm_end_with_partial_usage_emits_only_present_fields() { total_tokens: None, cache_read_tokens: None, cache_write_tokens: None, + cost: None, }), api_specific: None, extra: serde_json::Map::new(), diff --git a/crates/core/tests/unit/observability/otel_tests.rs b/crates/core/tests/unit/observability/otel_tests.rs index 424de1a1..cf00cc84 100644 --- a/crates/core/tests/unit/observability/otel_tests.rs +++ b/crates/core/tests/unit/observability/otel_tests.rs @@ -13,6 +13,11 @@ use crate::api::runtime::global_context; use crate::api::scope::ScopeType; use crate::api::scope::{event, pop_scope, push_scope}; use crate::api::tool::ToolAttributes; +use crate::codec::pricing::pricing_test_mutex; +use crate::codec::response::{ + AnnotatedLlmResponse, CostEstimate, CostSource, PricingCatalog, PricingResolver, Usage, + reset_active_pricing_resolver, set_active_pricing_resolver, +}; use crate::json::Json; use crate::observability::atif::{AtifAgentInfo, AtifExporter, AtifStepExtra}; use opentelemetry_sdk::trace::InMemorySpanExporterBuilder; @@ -24,6 +29,94 @@ use std::sync::mpsc; use std::thread; use uuid::Uuid; +struct ResetPricingResolverGuard; + +impl Drop for ResetPricingResolverGuard { + fn drop(&mut self) { + let _ = reset_active_pricing_resolver(); + } +} + +fn empty_annotated_response() -> AnnotatedLlmResponse { + AnnotatedLlmResponse { + id: None, + model: None, + message: None, + tool_calls: None, + finish_reason: None, + usage: None, + api_specific: None, + extra: serde_json::Map::new(), + } +} + +fn install_test_pricing(model_id: &str) { + let catalog = PricingCatalog::from_json_str( + &json!({ + "version": 1, + "entries": [ + { + "provider": "test", + "model_id": model_id, + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ] + }) + .to_string(), + ) + .unwrap(); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); +} + +fn install_provider_disambiguation_pricing(model_id: &str) { + let catalog = PricingCatalog::from_json_str( + &json!({ + "version": 1, + "entries": [ + { + "provider": "other", + "model_id": model_id, + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 1000.0, + "output_per_million": 1000.0 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + }, + { + "provider": "test", + "model_id": model_id, + "pricing_as_of": "2026-06-05", + "pricing_source": "test", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.60, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } + } + ] + }) + .to_string(), + ) + .unwrap(); + set_active_pricing_resolver(PricingResolver::from_catalogs(vec![catalog])).unwrap(); +} + fn reset_global() { crate::shared_runtime::reset_runtime_owner_for_tests(); let context = global_context(); @@ -773,7 +866,7 @@ fn helper_functions_cover_additional_otel_branches() { Some(&"{\"meta\":true}".to_string()) ); - let end_attributes = attr_map(&end_attributes(&Event::Scope(ScopeEvent::new( + let tool_end_attributes = attr_map(&end_attributes(&Event::Scope(ScopeEvent::new( BaseEvent::builder() .name("lookup") .metadata(json!({"phase": "complete"})) @@ -785,10 +878,256 @@ fn helper_functions_cover_additional_otel_branches() { Some(CategoryProfile::builder().tool_call_id("call-456").build()), )))); assert_eq!( - end_attributes.get("nemo_relay.end.output_json"), + tool_end_attributes.get("nemo_relay.end.output_json"), Some(&"{\"result\":true}".to_string()) ); + { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let llm_cost_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "chat", + ScopeType::Llm, + Some(json!({"answer": "ok"})), + Some( + CategoryProfile::builder() + .model_name("priced-model") + .annotated_response(std::sync::Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }), + ..empty_annotated_response() + })) + .build(), + ), + ); + let llm_cost_attributes = attr_map(&end_attributes(&llm_cost_event)); + assert_eq!( + llm_cost_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.000435".to_string()) + ); + assert_eq!( + llm_cost_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"USD".to_string()) + ); + } + + { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_provider_disambiguation_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let provider_qualified_cost_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "test", + ScopeType::Llm, + Some(json!({"answer": "ok"})), + Some( + CategoryProfile::builder() + .model_name("priced-model") + .annotated_response(std::sync::Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + cache_write_tokens: None, + cost: None, + }), + ..empty_annotated_response() + })) + .build(), + ), + ); + let provider_qualified_cost_attributes = + attr_map(&end_attributes(&provider_qualified_cost_event)); + assert_eq!( + provider_qualified_cost_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.000435".to_string()) + ); + assert_eq!( + provider_qualified_cost_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"USD".to_string()) + ); + } + + let normalized_cost_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "chat", + ScopeType::Llm, + Some(json!({"answer": "ok"})), + Some( + CategoryProfile::builder() + .model_name("unknown-model") + .annotated_response(std::sync::Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cost: Some(CostEstimate { + total: Some(0.42), + currency: "USD".into(), + input: None, + output: None, + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("external".to_string()), + pricing_model: Some("external-model".to_string()), + pricing_as_of: Some("2026-06-04".to_string()), + pricing_source: None, + }), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + ); + let normalized_cost_attributes = attr_map(&end_attributes(&normalized_cost_event)); + assert_eq!( + normalized_cost_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.42".to_string()) + ); + assert_eq!( + normalized_cost_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"USD".to_string()) + ); + + { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let reported_cost_without_total_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "test", + ScopeType::Llm, + Some(json!({"answer": "ok"})), + Some( + CategoryProfile::builder() + .model_name("priced-model") + .annotated_response(std::sync::Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + cost: Some(CostEstimate { + total: None, + currency: "EUR".into(), + input: Some(0.10), + output: None, + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("external".to_string()), + pricing_model: Some("external-model".to_string()), + pricing_as_of: Some("2026-06-04".to_string()), + pricing_source: None, + }), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + ); + let reported_cost_without_total_attributes = + attr_map(&end_attributes(&reported_cost_without_total_event)); + assert_eq!( + reported_cost_without_total_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.1".to_string()) + ); + assert_eq!( + reported_cost_without_total_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"EUR".to_string()) + ); + } + + { + let _pricing_guard = pricing_test_mutex().lock().unwrap(); + install_test_pricing("priced-model"); + let _reset_guard = ResetPricingResolverGuard; + let manual_cost_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "chat", + ScopeType::Llm, + Some(json!({ + "model": "priced-model", + "usage": { + "prompt_tokens": 1_000, + "completion_tokens": 500, + "total_tokens": 1_500, + "prompt_tokens_details": {"cached_tokens": 200} + } + })), + None, + ); + let manual_cost_attributes = attr_map(&end_attributes(&manual_cost_event)); + assert_eq!( + manual_cost_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.000435".to_string()) + ); + assert_eq!( + manual_cost_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"USD".to_string()) + ); + + let annotated_without_model_event = make_scope_event_with_profile( + ScopeCategory::End, + Uuid::now_v7(), + None, + "chat", + ScopeType::Llm, + Some(json!({ + "model": "priced-model", + "usage": { + "prompt_tokens": 1_000, + "completion_tokens": 500, + "total_tokens": 1_500, + "prompt_tokens_details": {"cached_tokens": 200} + } + })), + Some( + CategoryProfile::builder() + .annotated_response(std::sync::Arc::new(AnnotatedLlmResponse { + usage: Some(Usage { + prompt_tokens: Some(1_000), + completion_tokens: Some(500), + total_tokens: Some(1_500), + cache_read_tokens: Some(200), + ..Usage::default() + }), + ..empty_annotated_response() + })) + .build(), + ), + ); + let annotated_without_model_attributes = + attr_map(&end_attributes(&annotated_without_model_event)); + assert_eq!( + annotated_without_model_attributes.get("nemo_relay.llm.cost.total"), + Some(&"0.000435".to_string()) + ); + assert_eq!( + annotated_without_model_attributes.get("nemo_relay.llm.cost.currency"), + Some(&"USD".to_string()) + ); + } + let mark = Event::Mark(MarkEvent::new( BaseEvent::builder() .parent_uuid(Uuid::now_v7()) diff --git a/crates/ffi/tests/unit/types_tests.rs b/crates/ffi/tests/unit/types_tests.rs index 839ca7d9..22481ec4 100644 --- a/crates/ffi/tests/unit/types_tests.rs +++ b/crates/ffi/tests/unit/types_tests.rs @@ -569,7 +569,26 @@ fn test_annotated_event_accessors_and_codec_handles() { )), tool_calls: None, finish_reason: Some(nemo_relay::codec::response::FinishReason::Complete), - usage: None, + usage: Some(nemo_relay::codec::response::Usage { + prompt_tokens: Some(10), + completion_tokens: Some(5), + total_tokens: Some(15), + cache_read_tokens: None, + cache_write_tokens: None, + cost: Some(nemo_relay::codec::response::CostEstimate { + total: Some(0.000_01), + currency: "USD".into(), + input: Some(0.000_004), + output: Some(0.000_006), + cache_read: None, + cache_write: None, + source: nemo_relay::codec::response::CostSource::ProviderReported, + pricing_provider: Some("ffi-provider".into()), + pricing_model: Some("gpt-test".into()), + pricing_as_of: Some("2026-06-04".into()), + pricing_source: Some("https://example.test/pricing".into()), + }), + }), api_specific: None, extra: serde_json::Map::from_iter([("trace".into(), json!(true))]), }; @@ -595,7 +614,12 @@ fn test_annotated_event_accessors_and_codec_handles() { let annotated_response_value: serde_json::Value = serde_json::from_str(&annotated_response_json).unwrap(); assert_eq!(annotated_response_value["id"], json!("resp_123")); + assert!(annotated_response_value.get("model_provider").is_none()); assert_eq!(annotated_response_value["trace"], json!(true)); + assert_eq!( + annotated_response_value["usage"]["cost"]["pricing_provider"], + json!("ffi-provider") + ); assert!(unsafe { nemo_relay_event_annotated_request(&ffi_end) }.is_null()); let scope_event = FfiEvent(make_scope_event(ScopeEventFixture { diff --git a/crates/node/tests/typed_tests.mjs b/crates/node/tests/typed_tests.mjs index cdc27cb9..2e2f9b65 100644 --- a/crates/node/tests/typed_tests.mjs +++ b/crates/node/tests/typed_tests.mjs @@ -510,7 +510,7 @@ describe('typedLlmExecute', () => { makeAnthropicRequest(), () => ({ id: 'msg_123', - model: 'claude-3-5-sonnet', + model: 'claude-sonnet-4', content: [ { type: 'text', @@ -521,6 +521,12 @@ describe('typedLlmExecute', () => { usage: { input_tokens: 5, output_tokens: 3, + cost: { + total: 0.00006, + source: 'provider_reported', + pricing_provider: 'anthropic', + pricing_model: 'claude-sonnet-4', + }, }, }), new JsonPassthrough(), @@ -552,9 +558,12 @@ describe('typedLlmExecute', () => { event.scope_category === 'end' && event.name === 'typed_anthropic_codec_llm', ); - assert.equal(endEvent.category_profile.annotated_response.model, 'claude-3-5-sonnet'); + assert.equal(endEvent.category_profile.annotated_response.model, 'claude-sonnet-4'); assert.equal(endEvent.category_profile.annotated_response.message, 'Anthropic hello'); assert.equal(endEvent.category_profile.annotated_response.finish_reason, 'complete'); + assert.equal(endEvent.category_profile.annotated_response.usage.cost.total, 0.00006); + assert.equal(endEvent.category_profile.annotated_response.usage.cost.pricing_provider, 'anthropic'); + assert.equal(endEvent.category_profile.annotated_response.usage.cost.pricing_model, 'claude-sonnet-4'); } finally { deregisterSubscriber('typed_anthropic_codec_sub'); popScope(scope); diff --git a/crates/python/tests/coverage/py_types_coverage_tests.rs b/crates/python/tests/coverage/py_types_coverage_tests.rs index 5fccf2c5..3198eac8 100644 --- a/crates/python/tests/coverage/py_types_coverage_tests.rs +++ b/crates/python/tests/coverage/py_types_coverage_tests.rs @@ -19,8 +19,8 @@ use nemo_relay::codec::request::{ AnnotatedLlmRequest as AnnotatedLLMRequest, Message, MessageContent, }; use nemo_relay::codec::response::{ - AnnotatedLlmResponse as AnnotatedLLMResponse, ApiSpecificResponse, FinishReason, - ResponseToolCall, Usage, + AnnotatedLlmResponse as AnnotatedLLMResponse, ApiSpecificResponse, CostEstimate, CostSource, + FinishReason, ResponseToolCall, Usage, }; use pyo3::types::{PyDict, PyList, PyModule}; use serde_json::json; @@ -617,6 +617,7 @@ fn test_stream_request_event_and_handle_wrappers_cover_remaining_methods() { total_tokens: Some(3), cache_read_tokens: None, cache_write_tokens: None, + cost: None, }), api_specific: Some(ApiSpecificResponse::Custom { api_name: "custom".into(), @@ -1198,6 +1199,19 @@ fn test_annotated_llm_types_and_builtin_codecs_cover_mutators_and_codecs() { total_tokens: Some(5), cache_read_tokens: Some(1), cache_write_tokens: None, + cost: Some(CostEstimate { + total: Some(0.000_001), + currency: "USD".into(), + input: Some(0.000_000_2), + output: Some(0.000_000_8), + cache_read: None, + cache_write: None, + source: CostSource::ProviderReported, + pricing_provider: Some("test-provider".into()), + pricing_model: Some("demo-model".into()), + pricing_as_of: Some("2026-06-04".into()), + pricing_source: Some("https://example.test/pricing".into()), + }), }), api_specific: Some(ApiSpecificResponse::Custom { api_name: "custom".into(), @@ -1221,6 +1235,10 @@ fn test_annotated_llm_types_and_builtin_codecs_cover_mutators_and_codecs() { py_to_json(response.usage(py).unwrap().bind(py)).unwrap()["total_tokens"], json!(5) ); + assert_eq!( + py_to_json(response.usage(py).unwrap().bind(py)).unwrap()["cost"]["pricing_provider"], + json!("test-provider") + ); assert_eq!( py_to_json(response.api_specific(py).unwrap().bind(py)).unwrap()["api_name"], json!("custom") @@ -1454,6 +1472,7 @@ fn test_forced_serialization_error_hooks_cover_unreachable_wrappers() { total_tokens: Some(5), cache_read_tokens: Some(1), cache_write_tokens: None, + cost: None, }), api_specific: Some(ApiSpecificResponse::Custom { api_name: "custom".into(), diff --git a/crates/wasm/tests-js/typed_tests.mjs b/crates/wasm/tests-js/typed_tests.mjs index 22f664c4..68ee8522 100644 --- a/crates/wasm/tests-js/typed_tests.mjs +++ b/crates/wasm/tests-js/typed_tests.mjs @@ -91,25 +91,74 @@ test('WebAssembly typed tool wrappers execute asynchronous flows', async () => { test('WebAssembly typed llm wrappers support response codecs', async () => { const passthrough = new JsonPassthrough(); + const responseEvents = []; + const subscriberName = unique('wrapper_llm_cost'); + wasm.registerSubscriber(subscriberName, (event) => responseEvents.push(event)); - const llmResult = await typedLlmExecute( - 'wrapper_llm', - makeLlmRequest('test-model'), - () => ({ + try { + const llmResult = await typedLlmExecute( + 'wrapper_llm', + makeLlmRequest('gpt-4o-mini'), + () => ({ + response: 'ok', + model: 'gpt-4o-mini', + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + cost: { + total: 0.00001, + source: 'provider_reported', + pricing_provider: 'wasm-provider', + pricing_model: 'gpt-4o-mini', + pricing_as_of: '2026-06-04', + pricing_source: 'https://example.test/pricing', + }, + }, + }), + passthrough, + { + responseCodec: { + decodeResponse(response) { + return response; + }, + }, + }, + ); + assert.deepEqual(llmResult, { response: 'ok', - }), - passthrough, - { - responseCodec: { - decodeResponse(response) { - return response; + model: 'gpt-4o-mini', + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + cost: { + total: 0.00001, + source: 'provider_reported', + pricing_provider: 'wasm-provider', + pricing_model: 'gpt-4o-mini', + pricing_as_of: '2026-06-04', + pricing_source: 'https://example.test/pricing', }, }, - }, - ); - assert.deepEqual(llmResult, { - response: 'ok', - }); + }); + + const endEvent = await waitFor(() => + responseEvents.find( + (event) => + event.kind === 'scope' && + event.category === 'llm' && + event.scope_category === 'end' && + event.name === 'wrapper_llm', + ), + ); + assert.equal( + endEvent.category_profile.annotated_response.usage.cost.pricing_provider, + 'wasm-provider', + ); + } finally { + wasm.deregisterSubscriber(subscriberName); + } }); test('WebAssembly typed llm wrappers execute synchronous flows', async () => { diff --git a/docs/build-plugins/about.mdx b/docs/build-plugins/about.mdx index 4d7e83d2..6bbae23d 100644 --- a/docs/build-plugins/about.mdx +++ b/docs/build-plugins/about.mdx @@ -18,12 +18,22 @@ Plugins prevent repeated registration code for policies, request transforms, exporters, and related runtime components. They give shared behavior a stable kind name, a structured config document, and a clear activation lifecycle. +NeMo Relay also ships built-in plugin components for shared runtime behavior. +For example, the `pricing` component installs model-pricing sources that +response codecs can use to annotate managed LLM responses with cost estimates. +Applications, eval harnesses, custom agents, and framework integrations can +activate that component through the same plugin APIs they use for custom +plugins; the `nemo-relay` CLI is only one host that can load plugin +configuration from files. + ## Start Here When Use these signals to decide whether this documentation path matches your current task. - Ship policy bundles across applications - Install observability exporters consistently +- Install pricing sources for cost estimates across applications, harnesses, or + agent integrations - Package framework-agnostic request transforms - Validate operator-supplied config before runtime behavior changes diff --git a/docs/build-plugins/plugin-configuration-files.mdx b/docs/build-plugins/plugin-configuration-files.mdx index 815e8c56..da0b25bd 100644 --- a/docs/build-plugins/plugin-configuration-files.mdx +++ b/docs/build-plugins/plugin-configuration-files.mdx @@ -188,6 +188,12 @@ across files. A higher-precedence component with the same `kind` merges into the lower-precedence component. A component with a different `kind` is added to the effective configuration. +The `pricing` component has one additional merge rule: when both lower- and +higher-precedence files define `components.config.sources`, the +higher-precedence sources are placed before lower-precedence sources instead of +replacing them. This lets a user or project file override one model while still +falling back to fleet-managed pricing from `/etc/nemo-relay/plugins.toml`. + Declare each `kind` at most once inside one `plugins.toml` file. Duplicate component kinds in the same file fail before merge. Duplicate singleton components that reach plugin validation also fail validation. @@ -244,6 +250,8 @@ Use `nemo-relay doctor` to inspect the resolved gateway configuration and plugin diagnostics. For Observability, doctor also reports enabled exporter sections, checks writable file exporter directories, probes configured ATOF streaming endpoints, and checks reachable OTLP endpoints when those settings are present. +For Pricing, doctor validates enabled file and inline sources and fails when a +source is unreadable or the catalog schema is invalid. ## Relationship To `config.toml` diff --git a/docs/instrument-applications/instrument-llm-call.mdx b/docs/instrument-applications/instrument-llm-call.mdx index d019edcf..388ebc36 100644 --- a/docs/instrument-applications/instrument-llm-call.mdx +++ b/docs/instrument-applications/instrument-llm-call.mdx @@ -32,6 +32,14 @@ Create a scope for the active request or agent run before adding LLM instrumenta The request and response payloads must be JSON-compatible. If your provider SDK uses clients, streams, callbacks, or other opaque objects, keep those objects in the provider callback and pass only a serializable request projection into NeMo Relay. +If you want Relay to add cost estimates, initialize the built-in `pricing` +plugin before the LLM call and attach a response codec that decodes `model` and +token `usage` from the provider response. Provider- or framework-reported cost +is preserved when present. Otherwise Relay estimates cost only when a configured +pricing source matches the response model and usage fields. For catalog setup +and embedded plugin examples, see +[Provider Response Codecs](/integrate-into-frameworks/provider-response-codecs#cost-estimation). + ## Integration Pattern Follow these steps to route the provider invocation through NeMo Relay: @@ -42,7 +50,9 @@ Follow these steps to route the provider invocation through NeMo Relay: 4. Build an LLM request object with provider headers and content. 5. Replace the direct provider invocation with the managed LLM execute helper. 6. Pass the active scope handle and a stable `model_name`. -7. Check that the provider result is unchanged and lifecycle events are emitted. +7. Attach a response codec when subscribers or exporters need normalized + response usage, tool calls, or cost annotations. +8. Check that the provider result is unchanged and lifecycle events are emitted. ## Minimal Example @@ -213,6 +223,9 @@ Check both behavior and instrumentation: - The provider result matches what the application returned before the wrapper was added. - The subscriber prints an agent or request scope event. - The subscriber prints LLM start and LLM end events for `demo-provider`. +- If pricing is configured, LLM end events include + `annotated_response.usage.cost` only when a response codec decoded model and + usage fields and a source matched the model. Native subscriber delivery is asynchronous. Flush subscribers before validating printed output. In Node.js, also wait one event-loop tick after @@ -233,6 +246,8 @@ Before deploying to production, ensure the following checklist is completed: - Keep request and response payloads JSON-compatible. - Keep SDK clients and transport objects inside the provider callback. - Use codecs when middleware needs normalized provider request or response semantics. +- Use response codecs and the `pricing` plugin when exporters need cost + estimates from model pricing. - Use sanitize guardrails before exporting prompts or model responses in production. ## Common Issues diff --git a/docs/integrate-into-frameworks/provider-response-codecs.mdx b/docs/integrate-into-frameworks/provider-response-codecs.mdx index 5cc367d4..e70ed3d7 100644 --- a/docs/integrate-into-frameworks/provider-response-codecs.mdx +++ b/docs/integrate-into-frameworks/provider-response-codecs.mdx @@ -40,12 +40,280 @@ Response codecs normalize provider output into fields that subscribers can inspe | `message` | Primary assistant message content. | | `tool_calls` | Tool calls requested by the model. | | `finish_reason` | Normalized completion reason, such as `complete`, `length`, `tool_use`, or `content_filter`. | -| `usage` | Token accounting, including cache-read and cache-write counts when available. | +| `usage` | Token accounting, including cache-read and cache-write counts when available. May also include normalized `cost` when the provider reports cost or Relay can estimate it from known model pricing. | | `api_specific` | Provider-specific fields that do not fit the common model. | | `extra` | Additional unmodeled response fields. | Use these annotations for observability, export, and debugging. Keep business logic that changes the caller-visible response in the framework or provider adapter, not in the response codec. +## Cost Estimation + +Response codecs should keep reporting provider usage fields without rewriting +the caller-visible response. If a provider or framework reports cost, map it to +`Usage.cost` with `source: "provider_reported"`. Otherwise Relay can layer cost +estimation onto `AnnotatedLlmResponse.usage.cost` when all required inputs are +available: + +- The decoded response includes `model`. +- The managed LLM call name identifies the provider or route, such as `openai`, + `anthropic`, or `azure/openai`, when provider-specific pricing is needed. +- The decoded response includes prompt and/or completion token usage. +- Relay has an explicit pricing entry for that model or alias. + +Pricing estimates carry `pricing_provider`, `pricing_model`, `pricing_as_of`, +`pricing_source`, and `currency` metadata so stale pricing can be audited +without failing response decoding. Normalized cost uses currency-neutral amount +fields such as `total`, `input`, `output`, `cache_read`, and `cache_write`. +Unknown model pricing and missing token data are non-fatal: Relay omits the cost +field and still exports token metrics and response annotations. + +Relay resolves pricing through an active `PricingResolver` source chain. Provider +or framework-reported cost remains authoritative; the resolver is used only when +`Usage.cost` is missing. Relay does not ship provider price data by default: +estimates require a configured inline, file, or embedding-provided pricing +source. With no configured source, every model is treated as unknown for pricing. + +Pricing is runtime state, not a CLI-only feature. Any host that initializes +Relay plugins can activate the built-in `pricing` component before it runs +managed LLM calls. This includes application code, eval harnesses, custom +agents, framework integrations, and third-party patches. The CLI commands below +are a file-management convenience for the local gateway; embedded hosts can pass +the same component config directly through the plugin APIs. + +Source precedence is deployment controlled: + +1. Project or application overrides. +2. User/global device pricing. +3. Enterprise-managed sources, such as a remotely synced file or a service + backed by a database. + +The built-in `pricing` plugin component accepts inline catalogs or JSON catalog +files in precedence order. In discovered `plugins.toml` config, system config +loads first, project config loads next, and user config loads last. For the +`pricing` component, higher-priority `sources` are prepended instead of +replacing lower-priority sources, so a user override can win for one model while +enterprise or fleet pricing remains available for everything else: + +```toml +[[components]] +kind = "pricing" +enabled = true + +[[components.config.sources]] +type = "file" +path = "/etc/nemo-relay/pricing.json" + +[[components.config.sources]] +type = "inline" +[components.config.sources.catalog] +version = 1 +entries = [] +``` + +Each catalog entry declares: + +- `provider` and canonical `model_id`. +- `aliases` for dated or provider-specific model IDs. +- `currency`, defaulting to `USD`. +- `unit`, defaulting to `per_token`. Relay estimates only `per_token` entries in this version; `per_request`, `per_second`, and `gpu_hour` are representable for future source integrations but are not estimated. +- `rates` per one million input, output, cache-read, and cache-write tokens for flat `per_token` entries. +- `rate_schedule` for data-driven threshold pricing, such as models whose full-request input/output rates change after a prompt-token threshold. +- `prompt_cache.read_accounting`, which tells Relay whether cache-read tokens are already included in prompt tokens. +- `pricing_as_of` and `pricing_source` for auditability. + +Relay validates catalogs at startup and rejects duplicate canonical IDs or +aliases within the same normalized provider/model key. The same model ID can +appear under distinct providers, such as `openai/gpt-4o-mini` and +`azure/openai/gpt-4o-mini`. Adding a model should be a catalog/source update +plus tests; it should not require adding another Rust `match` arm. + +Use the CLI to validate catalog files and manage file-backed pricing sources: + +```bash +nemo-relay pricing validate /path/to/pricing.json +nemo-relay pricing init --project +nemo-relay pricing add-source /path/to/pricing.json --project +nemo-relay pricing resolve gpt-4o-mini --provider openai --prompt-tokens 1000 --completion-tokens 500 +``` + +`pricing init` creates or enables the `pricing` plugin component in the selected +`plugins.toml`. The initialized component has an empty `sources` list; use +`pricing add-source` or an inline config edit to provide pricing data. + +`pricing add-source` validates the referenced JSON catalog before updating +`plugins.toml`. It creates the pricing component if needed and prepends the new +file source by default, making it the highest-priority source in that scope. Use +`--append` when the file should be a lower-priority fallback. Both commands +default to user config at `$XDG_CONFIG_HOME/nemo-relay/plugins.toml`; pass +`--project` for `.nemo-relay/plugins.toml` or `--global` for +`/etc/nemo-relay/plugins.toml`. + +`pricing resolve` uses the same discovered config path as the gateway. It +reports the winning catalog source, matched provider/model, and, when token +counts are supplied, the estimated total cost. The source line is one of +`file:` or `inline:`, which makes overlapping project/user/fleet +entries debuggable. This is a dry diagnostic command; it does not mutate +configuration. + +`nemo-relay doctor` also validates enabled pricing sources and reports missing, +unreadable, or invalid catalogs before the gateway starts. + +Model lookup is provider-aware and route-aware. Relay uses the managed LLM call +name as the provider/route and first tries provider-scoped keys for the full +model and terminal model name, then falls back to model-only suffixes. For +example, a call named `azure/openai` with response `model = "gpt-4o-mini"` tries +`azure/openai/gpt-4o-mini` before generic `gpt-4o-mini`. If the model string is +itself routed, such as `azure/openai/gpt-4o-mini`, Relay can infer +`azure/openai` for the terminal model before trying slash-delimited model-only +suffixes. This keeps route-specific enterprise pricing authoritative when +configured while still allowing generic model pricing to apply to routed names. + +For threshold pricing, use `rate_schedule.type = "prompt_token_threshold"`. +Relay selects exactly one tier from `prompt_tokens` and applies that tier to the +full request; it does not price only the overflow tokens at the higher rate. +This matches providers that publish "short context" and "long context" prices +for the entire request/session. If `prompt_tokens` is missing for a thresholded +entry, Relay omits the estimate instead of guessing. + +```json +{ + "provider": "google", + "model_id": "gemini-3.1-pro-preview", + "aliases": ["gemini-3.1-pro-preview-customtools"], + "pricing_as_of": "2026-06-05", + "pricing_source": "https://ai.google.dev/gemini-api/docs/pricing", + "rate_schedule": { + "type": "prompt_token_threshold", + "applies_to": "full_request", + "tiers": [ + { + "max_prompt_tokens": 200000, + "rates": { + "input_per_million": 2.0, + "output_per_million": 12.0, + "cache_read_per_million": 0.2 + } + }, + { + "min_prompt_tokens": 200001, + "rates": { + "input_per_million": 4.0, + "output_per_million": 18.0, + "cache_read_per_million": 0.4 + } + } + ] + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + } +} +``` + +Database-backed or remote pricing should be implemented as a source that returns +a validated `PricingCatalog` snapshot to Relay. Keep database queries, service +auth, refresh cadence, and caching outside the LLM response hot path. A fleet +deployment can refresh `/etc/nemo-relay/pricing.json` from an IT-managed service, +or embed a custom Rust `PricingSource` that reads from a database and installs a +`PricingResolver` snapshot during process startup. + +External pricing catalogs should be converted into Relay catalog JSON +out-of-band and then loaded through a `file` source, unless the embedding +application installs a custom Rust `PricingSource` directly. + +Embedded applications and eval harnesses can initialize the built-in pricing +component directly: + + + +```python +import nemo_relay + +config = nemo_relay.plugin.PluginConfig( + components=[ + nemo_relay.plugin.ComponentSpec( + kind="pricing", + config={ + "sources": [ + {"type": "file", "path": "./pricing.json"}, + ], + }, + ) + ] +) + +report = nemo_relay.plugin.validate(config) +if any(diagnostic["level"] == "error" for diagnostic in report["diagnostics"]): + raise RuntimeError(report["diagnostics"]) + +await nemo_relay.plugin.initialize(config) +``` + + + +```ts +import * as plugin from 'nemo-relay-node/plugin'; + +const config = plugin.defaultConfig(); +config.components = [ + plugin.ComponentSpec('pricing', { + sources: [{ type: 'file', path: './pricing.json' }], + }), +]; + +const report = plugin.validate(config); +if (report.diagnostics.some((diagnostic) => diagnostic.level === 'error')) { + throw new Error(JSON.stringify(report.diagnostics)); +} + +await plugin.initialize(config); +``` + + + +```rust +use nemo_relay::plugin::{initialize_plugins, validate_plugin_config, PluginConfig}; +use serde_json::json; + +let config: PluginConfig = serde_json::from_value(json!({ + "version": 1, + "components": [{ + "kind": "pricing", + "config": { + "sources": [ + {"type": "file", "path": "./pricing.json"} + ] + } + }] +}))?; + +let report = validate_plugin_config(&config); +if report.has_errors() { + return Err("invalid pricing plugin config".into()); +} + +initialize_plugins(config).await?; +``` + + + +Initialize pricing once during process or harness startup, before the managed +LLM calls whose responses should be cost-annotated. In tests or reusable +harnesses, clear plugin configuration during teardown if later cases need a +different resolver. + +Built-in response codecs attach estimated cost directly to +`AnnotatedLlmResponse.usage.cost` when pricing is known. Managed LLM wrappers +also enrich decoded custom response-codec output when the custom codec returns +`model` and `usage` but omits `usage.cost`. Existing cost values are preserved, +so provider-reported cost remains authoritative in the annotation. + +Observability exporters prefer an explicit cost in the raw payload, then +normalized `Usage.cost`, then a derived estimate from model pricing. When cost is +available, ATIF step metrics and final metrics include `cost_usd`, +OpenInference includes the USD-denominated `llm.cost.total`, and OpenTelemetry +includes `nemo_relay.llm.cost.total` and `nemo_relay.llm.cost.currency`. + ## Built-in Response Codecs The built-in provider codecs also implement response decoding: @@ -185,6 +453,7 @@ def on_event(event): print("model", annotated.model) print("text", annotated.response_text()) print("usage", annotated.usage) + print("cost", (annotated.usage or {}).get("cost")) nemo_relay.subscribers.register("response-debugger", on_event) ``` @@ -203,6 +472,7 @@ registerSubscriber('response-debugger', (event) => { console.log('model', annotated.model); console.log('message', annotated.message); console.log('usage', annotated.usage); + console.log('cost', annotated.usage?.cost); }); ``` @@ -308,6 +578,7 @@ impl LlmResponseCodec for FrameworkResponseCodec { total_tokens, cache_read_tokens: None, cache_write_tokens: None, + cost: None, }), api_specific: None, extra: Map::new(), diff --git a/docs/nemo-relay-cli/about.mdx b/docs/nemo-relay-cli/about.mdx index 9ca8cfc7..9462c000 100644 --- a/docs/nemo-relay-cli/about.mdx +++ b/docs/nemo-relay-cli/about.mdx @@ -24,6 +24,7 @@ Use these guides when you need to: - Observe Claude Code, Codex, Cursor, or Hermes Agent sessions locally. - Configure coding-agent hooks for NeMo Relay lifecycle events. - Route model-provider traffic through the local NeMo Relay gateway. +- Validate and install model-pricing catalog sources for local cost estimates. - Export local sessions to Agent Trajectory Interchange Format (ATIF), Agent Trajectory Observability Format (ATOF) JSONL, OpenTelemetry, or OpenInference. diff --git a/docs/nemo-relay-cli/basic-usage.mdx b/docs/nemo-relay-cli/basic-usage.mdx index effc2c3b..b421a604 100644 --- a/docs/nemo-relay-cli/basic-usage.mdx +++ b/docs/nemo-relay-cli/basic-usage.mdx @@ -147,6 +147,85 @@ enabled = true endpoint = "http://127.0.0.1:4318/v1/traces" ``` +## Add Pricing For Cost Estimates + +Pricing is configured with the same `plugins.toml` discovery path as +Observability. The configured sources apply to transparent agent runs, +standalone gateway runs, and evals or custom agents that initialize the same +gateway plugin config. Framework patches, harnesses, and custom integrations do +not need their own pricing logic when they emit managed LLM calls with response +codecs: Relay attaches cost to the annotated response when the provider reports +cost or when a configured pricing source matches the response model and token +usage. + +Create a Relay pricing catalog JSON file: + +```json +{ + "version": 1, + "entries": [ + { + "provider": "openai", + "model_id": "gpt-4o-mini", + "aliases": ["openai/openai/gpt-4o-mini"], + "currency": "USD", + "unit": "per_token", + "rates": { + "input_per_million": 0.15, + "output_per_million": 0.6, + "cache_read_per_million": 0.075 + }, + "prompt_cache": { + "read_accounting": "included_in_prompt_tokens" + }, + "pricing_as_of": "2026-06-06", + "pricing_source": "internal-pricing-snapshot" + } + ] +} +``` + +Validate and add the file-backed source: + +```bash +nemo-relay pricing validate /path/to/pricing.json +nemo-relay pricing init --project +nemo-relay pricing add-source /path/to/pricing.json --project +``` + +Use `--user` instead of `--project` for a device-wide user config, or +`--global` for `/etc/nemo-relay/plugins.toml`. `pricing add-source` prepends the +source by default, so the new file becomes the highest-priority source for that +scope. Use `--append` to add it as a lower-priority fallback. + +Resolve a model before running an agent: + +```bash +nemo-relay pricing resolve gpt-4o-mini \ + --provider openai \ + --prompt-tokens 1000 \ + --completion-tokens 500 +``` + +`pricing resolve` prints the source that won, the matched provider/model, and an +estimated total when token counts are supplied. Use it to debug overlapping +fleet, project, and user pricing files. + +Run doctor to validate the active pricing sources alongside exporter checks: + +```bash +nemo-relay doctor codex +``` + +Doctor fails when an enabled pricing source is unreadable or contains an invalid +catalog, and it reports passing sources as `Pricing source`. + +Relay does not ship a canonical price catalog. Unknown models and missing token +fields leave cost absent instead of defaulting to zero. For the catalog schema, +provider-aware lookup behavior, threshold pricing, and custom `PricingSource` +integrations, see +[Provider Response Codecs](/integrate-into-frameworks/provider-response-codecs#cost-estimation). + Transparent runs always bind the managed gateway to `127.0.0.1:0`. The selected port is discovered by the wrapper and exposed to hooks through `NEMO_RELAY_GATEWAY_URL`. diff --git a/go/nemo_relay/llm_test.go b/go/nemo_relay/llm_test.go index ce1ca3f4..6996187b 100644 --- a/go/nemo_relay/llm_test.go +++ b/go/nemo_relay/llm_test.go @@ -192,7 +192,31 @@ func TestLlmCallExecuteWithRequestAndResponseCodecs(t *testing.T) { startEvent, endEvent := requireLlmScopeEvents(t, events) _ = startEvent.Attributes() _ = startEvent.AnnotatedRequest() - _ = endEvent.AnnotatedResponse() + var annotatedResponse map[string]any + if err := json.Unmarshal(endEvent.AnnotatedResponse(), &annotatedResponse); err != nil { + t.Fatalf("AnnotatedResponse JSON did not parse: %v", err) + } + usage, ok := annotatedResponse["usage"].(map[string]any) + if !ok { + t.Fatalf("expected annotated response usage, got %#v", annotatedResponse) + } + cost, ok := usage["cost"].(map[string]any) + if !ok { + t.Fatalf("expected annotated response cost, got %#v", usage) + } + if cost["pricing_provider"] != "openai" { + t.Fatalf("expected openai pricing provider, got %#v", cost["pricing_provider"]) + } + if cost["pricing_model"] != "gpt-4o-mini" { + t.Fatalf("expected gpt-4o-mini pricing model, got %#v", cost["pricing_model"]) + } + total, ok := cost["total"].(float64) + if !ok { + t.Fatalf("expected numeric total, got %#v", cost["total"]) + } + if diff := total - 0.0000435; diff > 1e-12 || diff < -1e-12 { + t.Fatalf("expected total 0.0000435, got %#v", total) + } } func llmRequestResponseCodec() CodecFunc { @@ -246,7 +270,7 @@ func requireEncodedModelExecutor(t *testing.T) func(json.RawMessage) (json.RawMe if request.Content["model"] != "encoded-model" { t.Fatalf("expected encoded model in execution payload, got %#v", request.Content) } - return json.RawMessage(`{"id":"chatcmpl-1","object":"chat.completion","created":1,"model":"gpt-test","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}]}`), nil + return json.RawMessage(`{"id":"chatcmpl-1","object":"chat.completion","created":1,"model":"gpt-4o-mini","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":20},"cost":{"total":0.0000435,"source":"provider_reported","pricing_provider":"openai","pricing_model":"gpt-4o-mini"}}}`), nil } }