Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions tests/support_unit_tests/trace_llm_test_fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@ use ironclaw::llm::{ChatMessage, CompletionRequest, ToolCompletionRequest};
use crate::support::trace_provider::TraceLlm;
use crate::support::trace_types::LlmTrace;

/// Describes the text response that [`text_step`] turns into a trace step.
///
/// Groups the three values that `text_step` and `single_text_step_llm` accept
/// as a single `spec` argument.
#[derive(Copy, Clone, Debug)]
pub struct TextStepSpec<'a> {
/// Response text; `text_step` copies it into an owned `String`.
pub content: &'a str,
/// Value stored as the text response's `input_tokens`.
pub input_tokens: u32,
/// Value stored as the text response's `output_tokens`.
pub output_tokens: u32,
}

/// Builds a text-response trace step.
///
/// `content`, `input_tokens`, and `output_tokens` populate the response; the
/// returned [`TraceStep`] has no request hint or expected tool results.
pub fn text_step(content: &str, input_tokens: u32, output_tokens: u32) -> TraceStep {
/// The returned [`TraceStep`] has no request hint or expected tool results.
pub fn text_step(spec: TextStepSpec<'_>) -> TraceStep {
TraceStep {
request_hint: None,
response: TraceResponse::Text {
content: content.to_string(),
input_tokens,
output_tokens,
content: spec.content.to_string(),
input_tokens: spec.input_tokens,
output_tokens: spec.output_tokens,
},
expected_tool_results: Vec::new(),
}
Expand Down Expand Up @@ -67,18 +73,9 @@ pub fn make_completion_request(user_msg: &str) -> CompletionRequest {

/// Builds a [`TraceLlm`] backed by a single text-response step.
///
/// `user_msg` seeds the trace turn, while `content`, `input_tokens`, and
/// `output_tokens` configure the returned provider's only replayable response.
pub fn single_text_step_llm(
user_msg: &str,
content: &str,
input_tokens: u32,
output_tokens: u32,
) -> TraceLlm {
let trace = LlmTrace::single_turn(
"test-model",
user_msg,
vec![text_step(content, input_tokens, output_tokens)],
);
/// `user_msg` seeds the trace turn, while `spec` configures the returned
/// provider's only replayable response.
pub fn single_text_step_llm(user_msg: &str, spec: TextStepSpec<'_>) -> TraceLlm {
let trace = LlmTrace::single_turn("test-model", user_msg, vec![text_step(spec)]);
TraceLlm::from_trace(trace)
}
134 changes: 110 additions & 24 deletions tests/support_unit_tests/trace_llm_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
use crate::support::trace_provider::TraceLlm;
use crate::support::trace_types::{LlmTrace, TraceExpects, TraceTurn};
use crate::trace_llm_test_fixtures::{
make_completion_request, make_request, simple_tool_call, single_text_step_llm, text_step,
tool_calls_step,
TextStepSpec, make_completion_request, make_request, simple_tool_call, single_text_step_llm,
text_step, tool_calls_step,
};
use ironclaw::llm::{
ChatMessage, FinishReason, LlmProvider, Role, ToolCall, ToolCompletionRequest,
Expand All @@ -24,6 +24,19 @@ struct LlmCounterMinima {
output_tokens: u32,
}

/// Expected identity of a replayed tool call, checked by `assert_tool_call`.
#[derive(Copy, Clone, Debug)]
struct ExpectedToolCall<'a> {
/// Value the tool call's `name` must equal.
name: &'a str,
/// Value the tool call's `id` must equal.
id: &'a str,
}

/// Expected shape of captured request batches, checked by
/// `assert_captured_requests_shape`.
#[derive(Copy, Clone, Debug)]
struct CapturedRequestsExpectation<'a> {
/// Exact number of captured batches.
batches: usize,
/// Passed to `assert_msg` (with `Role::User`) for the final message of the
/// final batch.
last_user_contains: &'a str,
/// Lower bound on the number of messages in every captured batch.
min_msgs_per_batch: usize,
}

fn assert_msg(role: Role, msg: &ChatMessage, contains: &str) {
assert_eq!(msg.role, role);
assert!(
Expand All @@ -36,24 +49,23 @@ fn assert_msg(role: Role, msg: &ChatMessage, contains: &str) {

fn assert_captured_requests_shape(
captured: &[Vec<ChatMessage>],
expected_batches: usize,
last_user_contains: &str,
min_msgs_per_batch: usize,
expected: CapturedRequestsExpectation<'_>,
) {
assert_eq!(captured.len(), expected_batches);
assert_eq!(captured.len(), expected.batches);
assert!(
captured
.iter()
.all(|batch| batch.len() >= min_msgs_per_batch),
"expected every captured batch to contain at least {min_msgs_per_batch} messages"
.all(|batch| batch.len() >= expected.min_msgs_per_batch),
"expected every captured batch to contain at least {} messages",
expected.min_msgs_per_batch
);
let last_batch = captured
.last()
.expect("captured requests should contain at least one batch");
let last_message = last_batch
.last()
.expect("captured request batch should contain at least one message");
assert_msg(Role::User, last_message, last_user_contains);
assert_msg(Role::User, last_message, expected.last_user_contains);
}

fn assert_llm_counters(actual: LlmCounterSnapshot, min: LlmCounterMinima) {
Expand All @@ -77,15 +89,22 @@ fn assert_llm_counters(actual: LlmCounterSnapshot, min: LlmCounterMinima) {
);
}

fn assert_tool_call(call: &ToolCall, expected_name: &str, expected_id: &str) {
assert_eq!(call.name, expected_name);
assert_eq!(call.id, expected_id);
fn assert_tool_call(call: &ToolCall, expected: ExpectedToolCall<'_>) {
assert_eq!(call.name, expected.name);
assert_eq!(call.id, expected.id);
assert_eq!(call.arguments, serde_json::json!({"key": "value"}));
}

#[tokio::test]
async fn replays_text_response() {
let llm = single_text_step_llm("hi", "Hello world", 100, 20);
let llm = single_text_step_llm(
"hi",
TextStepSpec {
content: "Hello world",
input_tokens: 100,
output_tokens: 20,
},
);

let resp = llm.complete_with_tools(make_request("hi")).await.unwrap();

Expand Down Expand Up @@ -126,7 +145,13 @@ async fn replays_tool_calls() {

assert!(resp.content.is_none());
assert_eq!(resp.tool_calls.len(), 1);
assert_tool_call(&resp.tool_calls[0], "memory_search", "call_memory_search");
assert_tool_call(
&resp.tool_calls[0],
ExpectedToolCall {
name: "memory_search",
id: "call_memory_search",
},
);
assert_eq!(resp.finish_reason, FinishReason::ToolUse);
assert_llm_counters(
LlmCounterSnapshot {
Expand Down Expand Up @@ -172,7 +197,13 @@ async fn replays_non_templated_tool_calls_after_plain_text_tool_errors() {
.expect("plain-text tool errors should not be parsed when no templates are present");

assert_eq!(resp.tool_calls.len(), 1);
assert_tool_call(&resp.tool_calls[0], "write_file", "call_write_file");
assert_tool_call(
&resp.tool_calls[0],
ExpectedToolCall {
name: "write_file",
id: "call_write_file",
},
);
}

#[tokio::test]
Expand All @@ -182,7 +213,11 @@ async fn advances_through_steps() {
"do something",
vec![
tool_calls_step(vec![simple_tool_call("echo")], 50, 10),
text_step("Done!", 60, 5),
text_step(TextStepSpec {
content: "Done!",
input_tokens: 60,
output_tokens: 5,
}),
],
);
let llm = TraceLlm::from_trace(trace);
Expand All @@ -192,7 +227,13 @@ async fn advances_through_steps() {
.await
.unwrap();
assert_eq!(resp1.tool_calls.len(), 1);
assert_tool_call(&resp1.tool_calls[0], "echo", "call_echo");
assert_tool_call(
&resp1.tool_calls[0],
ExpectedToolCall {
name: "echo",
id: "call_echo",
},
);
assert_eq!(llm.calls(), 1);

let resp2 = llm
Expand All @@ -206,7 +247,15 @@ async fn advances_through_steps() {

#[tokio::test]
async fn errors_when_exhausted() {
let trace = LlmTrace::single_turn("test-model", "first", vec![text_step("only once", 10, 5)]);
let trace = LlmTrace::single_turn(
"test-model",
"first",
vec![text_step(TextStepSpec {
content: "only once",
input_tokens: 10,
output_tokens: 5,
})],
);
let llm = TraceLlm::from_trace(trace);

let resp1 = llm.complete_with_tools(make_request("first")).await;
Expand Down Expand Up @@ -256,7 +305,14 @@ async fn from_json_file() {

#[tokio::test]
async fn complete_text_step() {
let llm = single_text_step_llm("hi", "plain text", 30, 8);
let llm = single_text_step_llm(
"hi",
TextStepSpec {
content: "plain text",
input_tokens: 30,
output_tokens: 8,
},
);

let resp = llm.complete(make_completion_request("hi")).await.unwrap();

Expand All @@ -283,7 +339,11 @@ async fn complete_skips_tool_calls_step() {
"hi",
vec![
tool_calls_step(vec![simple_tool_call("echo")], 10, 5),
text_step("skipped past tools", 20, 8),
text_step(TextStepSpec {
content: "skipped past tools",
input_tokens: 20,
output_tokens: 8,
}),
],
);
let llm = TraceLlm::from_trace(trace);
Expand Down Expand Up @@ -314,7 +374,18 @@ async fn captured_requests() {
let trace = LlmTrace::single_turn(
"test-model",
"test",
vec![text_step("resp1", 10, 5), text_step("resp2", 10, 5)],
vec![
text_step(TextStepSpec {
content: "resp1",
input_tokens: 10,
output_tokens: 5,
}),
text_step(TextStepSpec {
content: "resp2",
input_tokens: 10,
output_tokens: 5,
}),
],
);
let llm = TraceLlm::from_trace(trace);

Expand All @@ -328,7 +399,14 @@ async fn captured_requests() {
let captured = llm
.captured_requests()
.expect("captured requests should be available");
assert_captured_requests_shape(&captured, 2, "second message", 1);
assert_captured_requests_shape(
&captured,
CapturedRequestsExpectation {
batches: 2,
last_user_contains: "second message",
min_msgs_per_batch: 1,
},
);
assert_msg(Role::User, &captured[0][0], "first message");
}

Expand All @@ -339,12 +417,20 @@ async fn multi_turn() {
vec![
TraceTurn {
user_input: "first".to_string(),
steps: vec![text_step("turn 1 response", 10, 5)],
steps: vec![text_step(TextStepSpec {
content: "turn 1 response",
input_tokens: 10,
output_tokens: 5,
})],
expects: TraceExpects::default(),
},
TraceTurn {
user_input: "second".to_string(),
steps: vec![text_step("turn 2 response", 20, 10)],
steps: vec![text_step(TextStepSpec {
content: "turn 2 response",
input_tokens: 20,
output_tokens: 10,
})],
expects: TraceExpects::default(),
},
],
Expand Down
Loading