67 changes: 54 additions & 13 deletions crates/braintrust-llm-router/src/router.rs
@@ -154,32 +154,34 @@ async fn prepare_provider_request(
     spec: &ModelSpec,
     format: ProviderFormat,
     stream: bool,
-) -> Result<(Bytes, Option<ProviderFormat>)> {
+) -> Result<(Bytes, Option<ProviderFormat>, ProviderFormat)> {
     if requires_bedrock_request_preparation(format) {
         let bytes = prepare_bedrock_request(body, spec, format).await?;
-        return Ok((bytes, Some(format)));
+        return Ok((bytes, Some(format), format));
     }
 
-    let (transformed, detected_format) =
+    let (transformed, detected_format, actual_format) =
         match lingua::transform_request(body.clone(), format, Some(&spec.model)) {
-            Ok(TransformResult::PassThrough(bytes)) => (bytes, None),
+            Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format),
            Ok(TransformResult::Transformed {
                bytes,
                source_format,
-            }) => (bytes, Some(source_format)),
-            Err(TransformError::UnsupportedTargetFormat(_)) => (body, None),
+                actual_target_format,
+            }) => (bytes, Some(source_format), actual_target_format),
+            Err(TransformError::UnsupportedTargetFormat(_)) => (body, None, format),
            Err(err) => return Err(err.into()),
        };
 
    if stream {
        // TODO: Fold streaming intent into `lingua::transform_request` once we
        // are ready to update its Rust/WASM/Python/TS call sites together.
        Ok((
-            enable_streaming_payload(transformed, format),
+            enable_streaming_payload(transformed, actual_format),
            detected_format,
+            actual_format,
        ))
    } else {
-        Ok((transformed, detected_format))
+        Ok((transformed, detected_format, actual_format))
    }
 }
 
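For context on the new match arms: they imply that lingua's `TransformResult::Transformed` variant now carries the target format the transform actually settled on. A minimal sketch of the shape the destructuring above assumes; the real definition lives in the `lingua` crate and may differ:

use bytes::Bytes;

// Sketch only: inferred from the destructuring in the hunk above, not the
// actual lingua types. `ProviderFormat` is the router crate's format enum.
pub enum TransformResult {
    /// The body already matches the target format; bytes pass through untouched.
    PassThrough(Bytes),
    /// The body was rewritten. `actual_target_format` records the format the
    /// transform settled on, which may be upgraded from the requested one
    /// (e.g. ChatCompletions with reasoning + tools -> Responses).
    Transformed {
        bytes: Bytes,
        source_format: ProviderFormat,
        actual_target_format: ProviderFormat,
    },
}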
@@ -214,22 +216,22 @@ impl Router {
            .first()
            .ok_or_else(|| Error::NoProvider(output_format))?;
        let (provider_alias, provider, auth, spec, format, strategy) = route;
-        let (payload, detected_format) =
+        let (payload, detected_format, actual_format) =
            prepare_provider_request(body, spec.as_ref(), *format, stream).await?;
        Ok((
            PreparedRequestInner {
                provider: provider.clone(),
                auth,
                spec: spec.clone(),
-                format: *format,
+                format: actual_format,
                payload,
                output_format,
                strategy: strategy.clone(),
            },
            RouterMetadata {
                detected_input_format: detected_format.unwrap_or(*format),
                provider_alias: provider_alias.clone(),
-                provider_format: *format,
+                provider_format: actual_format,
            },
        ))
    }
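Threading `actual_format` into both `PreparedRequestInner.format` and `RouterMetadata.provider_format` matters because downstream dispatch keys the endpoint off that field. A hypothetical illustration of the dependency, assuming OpenAI-style paths; the helper name is invented and the real provider code may be organized differently:

// Hypothetical sketch, not from this PR: maps the (possibly upgraded) format
// to the endpoint it implies. The tests below assert the /v1/responses path.
fn openai_endpoint_path(format: ProviderFormat) -> &'static str {
    match format {
        ProviderFormat::Responses => "/v1/responses",
        // Everything else falls back to the chat endpoint in this sketch; the
        // real router handles other formats (e.g. Bedrock) separately.
        _ => "/v1/chat/completions",
    }
}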
@@ -872,7 +874,7 @@ mod tests {
        );
        let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat);
 
-        let (payload, _) =
+        let (payload, _, _) =
            prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, true)
                .await
                .expect("request prepares");
@@ -890,7 +892,7 @@
        );
        let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat);
 
-        let (payload, _) =
+        let (payload, _, _) =
            prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false)
                .await
                .expect("request prepares");
@@ -900,6 +902,45 @@
        assert_eq!(parsed.get("stream_options"), None);
    }
 
+    #[tokio::test]
+    async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools()
+    {
+        // A chat-completions request with reasoning_effort + tools should have its actual_format
+        // upgraded to Responses so the router sends it to the correct endpoint.
+        let body = Bytes::from(
+            serde_json::json!({
+                "model": "gpt-5.4-mini",
+                "messages": [{"role": "user", "content": "Tokyo weather?"}],
+                "reasoning_effort": "medium",
+                "tools": [{
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "description": "Get weather",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {"location": {"type": "string"}},
+                            "required": ["location"]
+                        }
+                    }
+                }]
+            })
+            .to_string(),
+        );
+        let spec = openai_spec("gpt-5.4-mini", ModelFlavor::Chat);
+
+        let (_, _, actual_format) =
+            prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false)
+                .await
+                .expect("request prepares");
+
+        assert_eq!(
+            actual_format,
+            ProviderFormat::Responses,
+            "actual_format must be Responses so the router uses the /v1/responses endpoint"
+        );
+    }
+
    fn dummy_auth() -> AuthConfig {
        AuthConfig::ApiKey {
            key: "test".into(),
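One note on the streaming branch above: per the TODO, the stream flag is still injected after transformation. Judging by the non-streaming test asserting that `stream_options` stays absent, `enable_streaming_payload` plausibly mutates the JSON body in place; a sketch under that assumption (the real helper, including whether and when it sets `stream_options`, may differ):

use bytes::Bytes;

// Sketch only, inferred from the tests above; not the actual helper.
fn enable_streaming_payload(payload: Bytes, format: ProviderFormat) -> Bytes {
    let mut body: serde_json::Value =
        serde_json::from_slice(&payload).unwrap_or(serde_json::Value::Null);
    if let serde_json::Value::Object(map) = &mut body {
        map.insert("stream".into(), serde_json::Value::Bool(true));
        // Assumption: chat-completions streams also opt into usage reporting.
        if format == ProviderFormat::ChatCompletions {
            map.insert(
                "stream_options".into(),
                serde_json::json!({"include_usage": true}),
            );
        }
    }
    Bytes::from(serde_json::to_vec(&body).unwrap_or_default())
}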
145 changes: 145 additions & 0 deletions crates/braintrust-llm-router/tests/router.rs
@@ -654,6 +654,151 @@ async fn anthropic_native_catalog_format_resolves_to_anthropic() {
    );
 }
 
+/// Stub provider that supports both ChatCompletions and Responses formats and
+/// records the format it was called with.
+#[derive(Clone)]
+struct CapturingOpenAIStub {
+    recorded_format: Arc<Mutex<Option<ProviderFormat>>>,
+}
+
+#[async_trait]
+impl Provider for CapturingOpenAIStub {
+    fn id(&self) -> &'static str {
+        "openai"
+    }
+
+    fn provider_formats(&self) -> Vec<ProviderFormat> {
+        vec![ProviderFormat::ChatCompletions, ProviderFormat::Responses]
+    }
+
+    async fn complete(
+        &self,
+        _payload: Bytes,
+        _auth: &AuthConfig,
+        _spec: &ModelSpec,
+        format: ProviderFormat,
+        _client_headers: &ClientHeaders,
+    ) -> braintrust_llm_router::Result<Bytes> {
+        *self.recorded_format.lock().unwrap() = Some(format);
+        let response = braintrust_llm_router::serde_json::json!({
+            "id": "stub",
+            "object": "chat.completion",
+            "model": "gpt-5.4-mini",
+            "choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
+        });
+        Ok(Bytes::from(
+            braintrust_llm_router::serde_json::to_vec(&response).unwrap(),
+        ))
+    }
+
+    async fn complete_stream(
+        &self,
+        _payload: Bytes,
+        _auth: &AuthConfig,
+        _spec: &ModelSpec,
+        _format: ProviderFormat,
+        _client_headers: &ClientHeaders,
+    ) -> braintrust_llm_router::Result<RawResponseStream> {
+        Ok(Box::pin(tokio_stream::empty()))
+    }
+
+    async fn health_check(&self, _auth: &AuthConfig) -> braintrust_llm_router::Result<()> {
+        Ok(())
+    }
+}
+
+fn openai_router(catalog: Arc<ModelCatalog>) -> (Router, Arc<Mutex<Option<ProviderFormat>>>) {
+    let recorded_format = Arc::new(Mutex::new(None));
+    let stub = CapturingOpenAIStub {
+        recorded_format: Arc::clone(&recorded_format),
+    };
+    let router = RouterBuilder::new()
+        .with_catalog(catalog)
+        .add_provider(
+            "openai",
+            stub,
+            AuthConfig::ApiKey {
+                key: "test-key".into(),
+                header: Some("authorization".into()),
+                prefix: Some("Bearer".into()),
+            },
+            vec![ProviderFormat::ChatCompletions, ProviderFormat::Responses],
+        )
+        .build()
+        .expect("router builds");
+    (router, recorded_format)
+}
+
+#[tokio::test]
+async fn reasoning_effort_with_tools_upgrades_format_to_responses() {
+    // Use gpt-5.2-mini (minor version 2 < 3) so model_requires_responses_api() returns
+    // false. Only the body-level detection (reasoning_effort + tools) should trigger the
+    // upgrade from ChatCompletions to Responses.
+    let mut catalog = ModelCatalog::empty();
+    catalog.insert(
+        "gpt-5.2-mini".into(),
+        ModelSpec {
+            model: "gpt-5.2-mini".into(),
+            format: ProviderFormat::ChatCompletions,
+            flavor: ModelFlavor::Chat,
+            display_name: None,
+            parent: None,
+            input_cost_per_mil_tokens: None,
+            output_cost_per_mil_tokens: None,
+            input_cache_read_cost_per_mil_tokens: None,
+            multimodal: None,
+            reasoning: None,
+            max_input_tokens: None,
+            max_output_tokens: None,
+            supports_streaming: true,
+            extra: Default::default(),
+            available_providers: Default::default(),
+        },
+    );
+    let (router, recorded_format) = openai_router(Arc::new(catalog));
+
+    let body = to_body(json!({
+        "model": "gpt-5.2-mini",
+        "messages": [{"role": "user", "content": "Tokyo weather?"}],
+        "reasoning_effort": "medium",
+        "tools": [{
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get weather",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"location": {"type": "string"}},
+                    "required": ["location"]
+                }
+            }
+        }]
+    }));
+
+    let (request, metadata) = router
+        .create_request(body, "gpt-5.2-mini", ProviderFormat::ChatCompletions)
+        .await
+        .expect("create request");
+
+    assert_eq!(
+        metadata.provider_format,
+        ProviderFormat::Responses,
+        "provider_format in metadata should reflect the upgrade to Responses"
+    );
+
+    router
+        .complete(request, &ClientHeaders::default())
+        .await
+        .expect("complete");
+
+    assert_eq!(
+        *recorded_format.lock().unwrap(),
+        Some(ProviderFormat::Responses),
+        "provider.complete() must be called with Responses so the request hits /v1/responses"
+    );
+}
+
 fn retry_model_catalog() -> Arc<ModelCatalog> {
    let mut catalog = ModelCatalog::empty();
    catalog.insert(
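To make the metadata semantics in the test above concrete: `detected_input_format` reports the format the caller's body arrived in (falling back to the requested format on pass-through), while `provider_format` is what the provider is actually called with. A hypothetical caller, reusing the router and body from that test:

// Hypothetical usage sketch, assuming the setup from the test above.
let (request, metadata) = router
    .create_request(body, "gpt-5.2-mini", ProviderFormat::ChatCompletions)
    .await?;

// The body was written as chat-completions...
assert_eq!(metadata.detected_input_format, ProviderFormat::ChatCompletions);
// ...but the provider call was upgraded to the Responses format.
assert_eq!(metadata.provider_format, ProviderFormat::Responses);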