diff --git a/crates/forge_domain/src/error.rs b/crates/forge_domain/src/error.rs index 02d8f60529..c25c51b719 100644 --- a/crates/forge_domain/src/error.rs +++ b/crates/forge_domain/src/error.rs @@ -148,6 +148,17 @@ impl Error { pub fn sync_failed(count: usize) -> Self { Self::SyncFailed { count } } + + /// Convert a JSON error to a Retryable error for retry on malformed + /// arguments. This wraps a serde_json::Error in a ToolCallArgument + /// error which is then made retryable. + pub fn tool_call_json_error(error: serde_json::Error, args: String) -> Self { + let json_error = forge_json_repair::JsonRepairError::JsonError(error); + Self::Retryable(anyhow::anyhow!(Self::ToolCallArgument { + error: json_error, + args, + })) + } } #[cfg(test)] diff --git a/crates/forge_domain/src/result_stream_ext.rs b/crates/forge_domain/src/result_stream_ext.rs index 9250ff0adc..96247fbf34 100644 --- a/crates/forge_domain/src/result_stream_ext.rs +++ b/crates/forge_domain/src/result_stream_ext.rs @@ -287,10 +287,7 @@ mod tests { use pretty_assertions::assert_eq; use super::*; - use crate::{ - BoxStream, Content, FinishReason, TokenCount, ToolCall, ToolCallArguments, ToolCallId, - ToolName, - }; + use crate::{BoxStream, Content, FinishReason, TokenCount, ToolCall, ToolCallId, ToolName}; #[tokio::test] async fn test_into_full_basic() { @@ -746,35 +743,33 @@ mod tests { #[tokio::test] async fn test_into_full_with_tool_call_parse_failure_creates_retryable_error() { use crate::{ToolCallId, ToolCallPart, ToolName}; - - // Fixture: Create a stream with invalid tool call JSON let invalid_tool_call_part = ToolCallPart { call_id: Some(ToolCallId::new("call_123")), name: Some(ToolName::new("test_tool")), arguments_part: "invalid json {".to_string(), // Invalid JSON thought_signature: None, }; - let messages = vec![Ok(ChatCompletionMessage::default() .content(Content::part("Processing...")) .add_tool_call(ToolCall::Part(invalid_tool_call_part)))]; - let result_stream: BoxStream = Box::pin(tokio_stream::iter(messages)); - - // Actual: Convert stream to full message - let actual = result_stream.into_full(false).await; - - // Expected: Should not fail with invalid tool calls - assert!(actual.is_ok()); - let actual = actual.unwrap(); - let expected = ToolCallFull { - name: ToolName::new("test_tool"), - call_id: Some(ToolCallId::new("call_123")), - arguments: ToolCallArguments::from_json("invalid json {"), - thought_signature: None, - }; - assert_eq!(actual.tool_calls[0], expected); + let actual: Result = + result_stream.into_full(false).await; + // Expected: Should fail with a Retryable error since the tool call has invalid + // JSON + assert!( + actual.is_err(), + "Invalid tool call JSON should create a Retryable error" + ); + let err = actual.unwrap_err(); + // The error should be a Retryable error + // Check the error chain for Retryable + let is_retryable = err + .downcast_ref::() + .map(|e| matches!(e, crate::Error::Retryable(_))) + .unwrap_or(false); + assert!(is_retryable, "Error should be Retryable, got: {:?}", err); } #[tokio::test] diff --git a/crates/forge_domain/src/tools/call/args.rs b/crates/forge_domain/src/tools/call/args.rs index 2a6bf77130..bd93d7ef38 100644 --- a/crates/forge_domain/src/tools/call/args.rs +++ b/crates/forge_domain/src/tools/call/args.rs @@ -106,9 +106,40 @@ impl ToolCallArguments { } pub fn from_json(str: &str) -> Self { + // Always store as Unparsed to preserve the original JSON string format + // for serialization. The parse_json method should be used when you need + // to handle parsing errors gracefully. ToolCallArguments::Unparsed(str.to_string()) } + #[cfg(test)] + pub(crate) fn from_json_for_test(str: &str) -> Self { + serde_json::from_str::(str) + .map(ToolCallArguments::Parsed) + .unwrap_or_else(|_| ToolCallArguments::Unparsed(str.to_string())) + } + + /// Parse a JSON string into ToolCallArguments with proper error handling. + /// + /// This attempts to parse the string as JSON using json_repair for + /// recovery. Returns a retryable error if parsing fails, allowing the + /// system to retry the request with the model. + pub fn parse_json(str: &str) -> crate::Result { + match serde_json::from_str::(str) { + Ok(value) => Ok(ToolCallArguments::Parsed(value)), + Err(serde_err) => { + // Try json_repair as a fallback + match json_repair(str) { + Ok(repaired) => Ok(ToolCallArguments::Parsed(repaired)), + Err(_json_err) => Err(crate::Error::tool_call_json_error( + serde_err, + str.to_string(), + )), + } + } + } + } + pub fn from_parameters(object: BTreeMap) -> ToolCallArguments { let mut map = Map::new(); diff --git a/crates/forge_domain/src/tools/call/tool_call.rs b/crates/forge_domain/src/tools/call/tool_call.rs index 317906e6f6..39f13c4810 100644 --- a/crates/forge_domain/src/tools/call/tool_call.rs +++ b/crates/forge_domain/src/tools/call/tool_call.rs @@ -148,9 +148,8 @@ impl ToolCallFull { let arguments = if current_arguments.is_empty() { ToolCallArguments::default() } else { - ToolCallArguments::from_json(current_arguments.as_str()) + ToolCallArguments::parse_json(¤t_arguments)? }; - tool_calls.push(ToolCallFull { name: tool_name, call_id: Some(existing_call_id.clone()), @@ -188,9 +187,8 @@ impl ToolCallFull { let arguments = if current_arguments.is_empty() { ToolCallArguments::default() } else { - ToolCallArguments::from_json(current_arguments.as_str()) + ToolCallArguments::parse_json(¤t_arguments)? }; - tool_calls.push(ToolCallFull { name: tool_name, call_id: current_call_id, @@ -386,13 +384,15 @@ mod tests { call_id: Some(ToolCallId("call_1".to_string())), arguments: ToolCallArguments::from_json( r#"{"path": "crates/forge_services/src/fixtures/mascot.md"}"#, - ), + ) + .normalize(), thought_signature: None, }, ToolCallFull { name: ToolName::new("read"), call_id: Some(ToolCallId("call_2".to_string())), - arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#), + arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#) + .normalize(), thought_signature: None, }, ToolCallFull { @@ -400,7 +400,8 @@ mod tests { call_id: Some(ToolCallId("call_3".to_string())), arguments: ToolCallArguments::from_json( r#"{"path": "crates/forge_services/src/service/service.md"}"#, - ), + ) + .normalize(), thought_signature: None, }, ]; @@ -522,7 +523,9 @@ mod tests { let expected = vec![ToolCallFull { call_id: Some(ToolCallId("call_1".to_string())), name: ToolName::new("read"), - arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#), + arguments: ToolCallArguments::from_json(r#"{"path": "docs/onboarding.md"}"#) + .normalize() + .normalize(), thought_signature: None, }]; @@ -604,7 +607,7 @@ mod tests { let expected = vec![ToolCallFull { call_id: Some(ToolCallId("0".to_string())), name: ToolName::new("read"), - arguments: ToolCallArguments::from_json(r#"{"path": "/test/file.md"}"#), + arguments: ToolCallArguments::from_json(r#"{"path": "/test/file.md"}"#).normalize(), thought_signature: None, }]; @@ -712,7 +715,7 @@ mod tests { let expected = vec![ToolCallFull { call_id: Some(ToolCallId("call_1".to_string())), name: ToolName::new("shell"), - arguments: ToolCallArguments::from_json(r#"{"command": "date"}"#), + arguments: ToolCallArguments::from_json(r#"{"command": "date"}"#).normalize(), thought_signature: Some("signature_abc123".to_string()), }]; @@ -742,13 +745,13 @@ mod tests { ToolCallFull { call_id: Some(ToolCallId("call_1".to_string())), name: ToolName::new("read"), - arguments: ToolCallArguments::from_json(r#"{"path": "file1.txt"}"#), + arguments: ToolCallArguments::from_json(r#"{"path": "file1.txt"}"#).normalize(), thought_signature: Some("sig_1".to_string()), }, ToolCallFull { call_id: Some(ToolCallId("call_2".to_string())), name: ToolName::new("read"), - arguments: ToolCallArguments::from_json(r#"{"path": "file2.txt"}"#), + arguments: ToolCallArguments::from_json(r#"{"path": "file2.txt"}"#).normalize(), thought_signature: Some("sig_2".to_string()), }, ];