Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion crates/protocols/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;

use crate::{skills::MessagesSkillRef, validated::Normalizable};
use crate::{common::GenerationRequest, skills::MessagesSkillRef, validated::Normalizable};

// ============================================================================
// Request Types
Expand Down Expand Up @@ -105,6 +105,58 @@ impl CreateMessageRequest {
}
}

impl GenerationRequest for CreateMessageRequest {
fn is_stream(&self) -> bool {
self.stream.unwrap_or(false)
}

fn get_model(&self) -> Option<&str> {
Some(&self.model)
}

fn extract_text_for_routing(&self) -> String {
let mut buffer = String::new();
let mut has_content = false;

let push = |s: &str, has_content: &mut bool, buffer: &mut String| {
if s.is_empty() {
return;
}
if *has_content {
buffer.push(' ');
}
buffer.push_str(s);
*has_content = true;
};

if let Some(system) = &self.system {
match system {
SystemContent::String(s) => push(s, &mut has_content, &mut buffer),
SystemContent::Blocks(blocks) => {
for block in blocks {
push(&block.text, &mut has_content, &mut buffer);
}
}
}
}

for msg in &self.messages {
match &msg.content {
InputContent::String(s) => push(s, &mut has_content, &mut buffer),
InputContent::Blocks(blocks) => {
for block in blocks {
if let InputContentBlock::Text(text_block) = block {
push(&text_block.text, &mut has_content, &mut buffer);
Comment on lines +148 to +149
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Include tool_result content when extracting routing text

The new CreateMessageRequest::extract_text_for_routing only appends InputContentBlock::Text, so requests whose latest user turn is a tool_result (common in tool-calling loops) can produce an empty or incomplete routing key even though they contain substantial text. This degrades text-based worker selection and can misroute /v1/messages traffic compared with chat routing, which includes tool message content; consider extracting text from ToolResult payloads (string and text blocks) as well.

Useful? React with 👍 / 👎.

}
Comment on lines +148 to +150
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Extract document text when building routing key

CreateMessageRequest::extract_text_for_routing currently appends only InputContentBlock::Text, so requests whose prompt is carried in document blocks (for example DocumentSource::Text or document content blocks) produce an empty/partial routing key even though they contain substantial text. Because HTTP routing uses this extracted text for worker selection (route_typed_request), document-heavy /v1/messages traffic can be misrouted compared with equivalent chat payloads; include textual fields from document blocks when constructing the routing text.

Useful? React with 👍 / 👎.

}
}
}
}
Comment on lines +143 to +154
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of extract_text_for_routing only extracts text from InputContentBlock::Text. We should extend this to include ToolResult blocks to provide more context for routing decisions. However, we must ensure that control-plane items (such as McpApprovalRequest and McpApprovalResponse) are excluded to avoid noise and instability in routing. Additionally, Thinking blocks should be excluded as they contain internal reasoning rather than user-facing content, keeping the analysis channel distinct from routing logic.

        for msg in &self.messages {
            match &msg.content {
                InputContent::String(s) => push(s, &mut has_content, &mut buffer),
                InputContent::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            InputContentBlock::Text(text_block) => {
                                push(&text_block.text, &mut has_content, &mut buffer);
                            }
                            InputContentBlock::ToolResult(tool_result) => {
                                if tool_result.is_control_plane() {
                                    continue;
                                }
                                if let Some(content) = &tool_result.content {
                                    match content {
                                        ToolResultContent::String(s) => {
                                            push(s, &mut has_content, &mut buffer);
                                        }
                                        ToolResultContent::Blocks(blocks) => {
                                            for b in blocks {
                                                if let ToolResultContentBlock::Text(t) = b {
                                                    push(&t.text, &mut has_content, &mut buffer);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            _ => {}
                        }
                    }
                }
            }
        }
References
  1. To avoid noise and instability in routing decisions, do not extract text from control-plane items such as McpApprovalRequest and McpApprovalResponse.
  2. Mixing user-facing content with internal chain-of-thought (CoT) in the analysis channel conflates the two distinct purposes; routing should focus on user-facing or environment state.


buffer
}
}

impl Tool {
fn matches_tool_choice_name(&self, name: &str) -> bool {
match self {
Expand Down
12 changes: 12 additions & 0 deletions model_gateway/src/routers/http/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use openai_protocol::{
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
messages::CreateMessageRequest,
rerank::{RerankRequest, RerankResponse, RerankResult},
responses::ResponsesRequest,
transcription::{AudioFile, TranscriptionRequest},
Expand Down Expand Up @@ -1125,6 +1126,17 @@ impl RouterTrait for Router {
.await
}

async fn route_messages(
&self,
headers: Option<&HeaderMap>,
_tenant_meta: &TenantRequestMeta,
body: &CreateMessageRequest,
model_id: &str,
) -> Response {
self.route_typed_request(headers, body, "/v1/messages", model_id)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Important: route_to_endpoint() in grpc/utils/metrics.rs has no match arm for "/v1/messages", so it falls through to "other". All messages-API metrics (request counts, durations, errors, retries) from the HTTP router will be bucketed under the "other" endpoint label instead of the existing ENDPOINT_MESSAGES ("messages").

The gRPC routers already use metrics_labels::ENDPOINT_MESSAGES directly, but the HTTP router relies on route_to_endpoint(route) to derive the label from the path string.

Fix: add "/v1/messages" => metrics_labels::ENDPOINT_MESSAGES to the match in route_to_endpoint.

.await
}

async fn route_completion(
&self,
headers: Option<&HeaderMap>,
Expand Down
155 changes: 155 additions & 0 deletions model_gateway/tests/api/messages_api_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//! Integration tests for the Anthropic Messages API (`/v1/messages`)
//! against the HTTP backend, which proxies to sglang's native
//! `/v1/messages` endpoint.

use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, StatusCode},
};
use serde_json::json;
use tower::ServiceExt;

use crate::common::{
mock_worker::{HealthStatus, MockWorkerConfig, WorkerType},
AppTestContext,
};

/// A non-streaming `POST /v1/messages` routed to a single healthy mock worker
/// must come back as `200 OK` with an Anthropic-style message body.
#[tokio::test]
async fn test_v1_messages_proxy_success() {
    let worker = MockWorkerConfig {
        port: 18301,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    };
    let ctx = AppTestContext::new(vec![worker]).await;
    let app = ctx.create_app();

    let request_json = json!({
        "model": "mock-model",
        "max_tokens": 64,
        "messages": [
            {"role": "user", "content": "Hello, Claude!"}
        ]
    });
    let encoded = serde_json::to_string(&request_json).unwrap();

    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(encoded))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let message: serde_json::Value = serde_json::from_slice(&raw).unwrap();

    // Envelope fields of the returned message.
    assert_eq!(message["type"], "message");
    assert_eq!(message["role"], "assistant");
    assert_eq!(message["model"], "mock-model");
    assert_eq!(message["stop_reason"], "end_turn");

    // Exactly one text content block is expected.
    let blocks = message["content"].as_array().expect("content array");
    assert_eq!(blocks.len(), 1);
    assert_eq!(blocks[0]["type"], "text");

    assert!(message["usage"]["input_tokens"].is_number());

    ctx.shutdown().await;
}

/// A streaming `POST /v1/messages` must be answered with an SSE response whose
/// event sequence starts with `message_start`, ends with `message_stop`, and
/// contains at least one `content_block_delta`.
#[tokio::test]
async fn test_v1_messages_proxy_streaming() {
    let worker = MockWorkerConfig {
        port: 18302,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    };
    let ctx = AppTestContext::new(vec![worker]).await;
    let app = ctx.create_app();

    let request_json = json!({
        "model": "mock-model",
        "max_tokens": 64,
        "stream": true,
        "messages": [
            {"role": "user", "content": "Stream me a haiku"}
        ]
    });
    let encoded = serde_json::to_string(&request_json).unwrap();

    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(encoded))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    // A missing or non-UTF-8 header is treated as an empty content type.
    let content_type = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");
    assert!(
        content_type.contains("text/event-stream"),
        "expected SSE content-type, got {content_type:?}"
    );

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let sse_text = std::str::from_utf8(&raw).expect("utf8");

    // Wire format: `event: <type>\ndata: <json>\n\n`
    let event_types: Vec<&str> = sse_text
        .lines()
        .filter_map(|line| line.strip_prefix("event: "))
        .collect();

    assert_eq!(event_types.first(), Some(&"message_start"));
    assert_eq!(event_types.last(), Some(&"message_stop"));
    assert!(event_types.contains(&"content_block_delta"));

    ctx.shutdown().await;
}

/// When the only worker is configured to fail every request, the
/// `/v1/messages` response must carry `500 Internal Server Error`.
#[tokio::test]
async fn test_v1_messages_proxy_propagates_upstream_error() {
    let always_failing_worker = MockWorkerConfig {
        port: 18303,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        // A fail rate of 1.0 makes the mock worker fail on every call.
        fail_rate: 1.0,
    };
    let ctx = AppTestContext::new(vec![always_failing_worker]).await;
    let app = ctx.create_app();

    let request_json = json!({
        "model": "mock-model",
        "max_tokens": 16,
        "messages": [{"role": "user", "content": "fail please"}]
    });
    let encoded = serde_json::to_string(&request_json).unwrap();

    let request = Request::builder()
        .method("POST")
        .uri("/v1/messages")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(encoded))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

    ctx.shutdown().await;
}
1 change: 1 addition & 0 deletions model_gateway/tests/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! API endpoint integration tests

mod api_endpoints_test;
mod messages_api_test;
mod parser_endpoints_test;
mod request_formats_test;
mod responses_api_test;
Expand Down
112 changes: 112 additions & 0 deletions model_gateway/tests/common/mock_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl MockWorker {
.route("/get_model_info", get(model_info_handler))
.route("/generate", post(generate_handler))
.route("/v1/chat/completions", post(chat_completions_handler))
.route("/v1/messages", post(messages_handler))
.route("/v1/completions", post(completions_handler))
.route("/v1/rerank", post(rerank_handler))
.route("/v1/responses", post(responses_handler))
Expand Down Expand Up @@ -501,6 +502,117 @@ async fn chat_completions_handler(
}
}

/// Mock handler for `POST /v1/messages`.
///
/// Behaviour, in order:
/// 1. If `should_fail` reports a failure for the current worker config, reply
///    with `500 Internal Server Error` and an `api_error` JSON body.
/// 2. If `response_delay_ms` is non-zero, sleep for that many milliseconds.
/// 3. If the request body has `"stream": true`, reply with a fixed sequence of
///    server-sent events (`message_start` → `content_block_start` →
///    `content_block_delta` → `content_block_stop` → `message_delta` →
///    `message_stop`); otherwise reply with a single JSON message.
///
/// The `model` field of the request is echoed back, defaulting to
/// `"mock-model"` when absent or not a string. Every response gets a fresh
/// `msg_<uuid-v7>` identifier.
async fn messages_handler(
    State(config): State<Arc<RwLock<MockWorkerConfig>>>,
    Json(payload): Json<serde_json::Value>,
) -> Response {
    let config = config.read().await;

    if should_fail(&config) {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "type": "error",
                "error": {
                    "type": "api_error",
                    "message": "Random failure for testing"
                }
            })),
        )
            .into_response();
    }

    if config.response_delay_ms > 0 {
        tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
    }

    let is_stream = payload
        .get("stream")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    // Borrowed straight from the request payload: `json!` serializes
    // interpolated expressions by reference, so no owned copy is needed.
    let model = payload
        .get("model")
        .and_then(|v| v.as_str())
        .unwrap_or("mock-model");
    let message_id = format!("msg_{}", Uuid::now_v7());

    if is_stream {
        // All event payloads are built eagerly here, so the stream below owns
        // plain `serde_json::Value`s and borrows nothing from `payload`.
        let events = vec![
            (
                "message_start",
                json!({
                    "type": "message_start",
                    "message": {
                        "id": message_id,
                        "type": "message",
                        "role": "assistant",
                        "content": [],
                        "model": model,
                        "stop_reason": null,
                        "stop_sequence": null,
                        "usage": {"input_tokens": 10, "output_tokens": 0}
                    }
                }),
            ),
            (
                "content_block_start",
                json!({
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""}
                }),
            ),
            (
                "content_block_delta",
                json!({
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": "Mock streamed response."}
                }),
            ),
            (
                "content_block_stop",
                json!({"type": "content_block_stop", "index": 0}),
            ),
            (
                "message_delta",
                json!({
                    "type": "message_delta",
                    "delta": {"stop_reason": "end_turn", "stop_sequence": null},
                    "usage": {"output_tokens": 5}
                }),
            ),
            ("message_stop", json!({"type": "message_stop"})),
        ];

        let stream = stream::iter(events.into_iter().map(|(event, data)| {
            Ok::<_, Infallible>(Event::default().event(event).data(data.to_string()))
        }));

        Sse::new(stream)
            .keep_alive(KeepAlive::default())
            .into_response()
    } else {
        Json(json!({
            "id": message_id,
            "type": "message",
            "role": "assistant",
            "content": [
                {"type": "text", "text": "This is a mock messages response."}
            ],
            "model": model,
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {"input_tokens": 10, "output_tokens": 5}
        }))
        .into_response()
    }
}

#[expect(
clippy::unwrap_used,
reason = "test helper - panicking on failure is intentional"
Expand Down
Loading
Loading