Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/forge_domain/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ impl Error {
pub fn sync_failed(count: usize) -> Self {
Self::SyncFailed { count }
}

/// Convert a JSON error to a Retryable error for retry on malformed arguments.
/// This wraps a serde_json::Error in a ToolCallArgument error which is then made retryable.
pub fn tool_call_json_error(error: serde_json::Error, args: String) -> Self {
let json_error = forge_json_repair::JsonRepairError::JsonError(error);
Self::Retryable(anyhow::anyhow!(Self::ToolCallArgument {
error: json_error,
args,
}))
}
}

#[cfg(test)]
Expand Down
29 changes: 11 additions & 18 deletions crates/forge_domain/src/result_stream_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,35 +746,28 @@ mod tests {
#[tokio::test]
async fn test_into_full_with_tool_call_parse_failure_creates_retryable_error() {
use crate::{ToolCallId, ToolCallPart, ToolName};

// Fixture: Create a stream with invalid tool call JSON
let invalid_tool_call_part = ToolCallPart {
call_id: Some(ToolCallId::new("call_123")),
name: Some(ToolName::new("test_tool")),
arguments_part: "invalid json {".to_string(), // Invalid JSON
thought_signature: None,
};

let messages = vec![Ok(ChatCompletionMessage::default()
.content(Content::part("Processing..."))
.add_tool_call(ToolCall::Part(invalid_tool_call_part)))];

let result_stream: BoxStream<ChatCompletionMessage, anyhow::Error> =
Box::pin(tokio_stream::iter(messages));

// Actual: Convert stream to full message
let actual = result_stream.into_full(false).await;

// Expected: Should not fail with invalid tool calls
assert!(actual.is_ok());
let actual = actual.unwrap();
let expected = ToolCallFull {
name: ToolName::new("test_tool"),
call_id: Some(ToolCallId::new("call_123")),
arguments: ToolCallArguments::from_json("invalid json {"),
thought_signature: None,
};
assert_eq!(actual.tool_calls[0], expected);
let actual: Result<ChatCompletionMessageFull, anyhow::Error> = result_stream.into_full(false).await;
// Expected: Should fail with a Retryable error since the tool call has invalid JSON
assert!(actual.is_err(), "Invalid tool call JSON should create a Retryable error");
let err = actual.unwrap_err();
// The error should be a Retryable error
// Check the error chain for Retryable
let is_retryable = err
.downcast_ref::<crate::Error>()
.map(|e| matches!(e, crate::Error::Retryable(_)))
.unwrap_or(false);
assert!(is_retryable, "Error should be Retryable, got: {:?}", err);
}

#[tokio::test]
Expand Down
31 changes: 31 additions & 0 deletions crates/forge_domain/src/tools/call/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,40 @@ impl ToolCallArguments {
}

pub fn from_json(str: &str) -> Self {
// Always store as Unparsed to preserve the original JSON string format
// for serialization. The parse_json method should be used when you need
// to handle parsing errors gracefully.
ToolCallArguments::Unparsed(str.to_string())
}

#[cfg(test)]
pub(crate) fn from_json_for_test(str: &str) -> Self {
serde_json::from_str::<Value>(str)
.map(ToolCallArguments::Parsed)
.unwrap_or_else(|_| ToolCallArguments::Unparsed(str.to_string()))
}

/// Parse a JSON string into ToolCallArguments with proper error handling.
///
/// This attempts to parse the string as JSON using json_repair for recovery.
/// Returns a retryable error if parsing fails, allowing the system to retry
/// the request with the model.
pub fn parse_json(str: &str) -> crate::Result<Self> {
match serde_json::from_str::<Value>(str) {
Ok(value) => Ok(ToolCallArguments::Parsed(value)),
Err(serde_err) => {
// Try json_repair as a fallback
match json_repair(str) {
Ok(repaired) => Ok(ToolCallArguments::Parsed(repaired)),
Err(_json_err) => Err(crate::Error::tool_call_json_error(
serde_err,
str.to_string(),
)),
}
}
}
}

pub fn from_parameters(object: BTreeMap<String, String>) -> ToolCallArguments {
let mut map = Map::new();

Expand Down
26 changes: 10 additions & 16 deletions crates/forge_domain/src/tools/call/tool_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ impl ToolCallFull {
let arguments = if current_arguments.is_empty() {
ToolCallArguments::default()
} else {
ToolCallArguments::from_json(current_arguments.as_str())
ToolCallArguments::parse_json(&current_arguments)?
};

tool_calls.push(ToolCallFull {
name: tool_name,
call_id: Some(existing_call_id.clone()),
Expand Down Expand Up @@ -188,9 +187,8 @@ impl ToolCallFull {
let arguments = if current_arguments.is_empty() {
ToolCallArguments::default()
} else {
ToolCallArguments::from_json(current_arguments.as_str())
ToolCallArguments::parse_json(&current_arguments)?
};

tool_calls.push(ToolCallFull {
name: tool_name,
call_id: current_call_id,
Expand Down Expand Up @@ -384,23 +382,19 @@ mod tests {
ToolCallFull {
name: ToolName::new("read"),
call_id: Some(ToolCallId("call_1".to_string())),
arguments: ToolCallArguments::from_json(
r#"{"path": "crates/forge_services/src/fixtures/mascot.md"}"#,
),
arguments: ToolCallArguments::from_json(r#"{"path": "crates/forge_services/src/fixtures/mascot.md"}"#).normalize(),
thought_signature: None,
},
ToolCallFull {
name: ToolName::new("read"),
call_id: Some(ToolCallId("call_2".to_string())),
arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#),
arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#).normalize(),
thought_signature: None,
},
ToolCallFull {
name: ToolName::new("read"),
call_id: Some(ToolCallId("call_3".to_string())),
arguments: ToolCallArguments::from_json(
r#"{"path": "crates/forge_services/src/service/service.md"}"#,
),
arguments: ToolCallArguments::from_json(r#"{"path": "crates/forge_services/src/service/service.md"}"#).normalize(),
thought_signature: None,
},
];
Expand Down Expand Up @@ -522,7 +516,7 @@ mod tests {
let expected = vec![ToolCallFull {
call_id: Some(ToolCallId("call_1".to_string())),
name: ToolName::new("read"),
arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#),
arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#).normalize().normalize(),
thought_signature: None,
}];

Expand Down Expand Up @@ -604,7 +598,7 @@ mod tests {
let expected = vec![ToolCallFull {
call_id: Some(ToolCallId("0".to_string())),
name: ToolName::new("read"),
arguments: ToolCallArguments::from_json(r#"{"path": "/test/file.md"}"#),
arguments: ToolCallArguments::from_json(r#"{"path": "/test/file.md"}"#).normalize(),
thought_signature: None,
}];

Expand Down Expand Up @@ -712,7 +706,7 @@ mod tests {
let expected = vec![ToolCallFull {
call_id: Some(ToolCallId("call_1".to_string())),
name: ToolName::new("shell"),
arguments: ToolCallArguments::from_json(r#"{"command": "date"}"#),
arguments: ToolCallArguments::from_json(r#"{"command": "date"}"#).normalize(),
thought_signature: Some("signature_abc123".to_string()),
}];

Expand Down Expand Up @@ -742,13 +736,13 @@ mod tests {
ToolCallFull {
call_id: Some(ToolCallId("call_1".to_string())),
name: ToolName::new("read"),
arguments: ToolCallArguments::from_json(r#"{"path": "file1.txt"}"#),
arguments: ToolCallArguments::from_json(r#"{"path": "file1.txt"}"#).normalize(),
thought_signature: Some("sig_1".to_string()),
},
ToolCallFull {
call_id: Some(ToolCallId("call_2".to_string())),
name: ToolName::new("read"),
arguments: ToolCallArguments::from_json(r#"{"path": "file2.txt"}"#),
arguments: ToolCallArguments::from_json(r#"{"path": "file2.txt"}"#).normalize(),
thought_signature: Some("sig_2".to_string()),
},
];
Expand Down
Loading