Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 8 additions & 99 deletions src/build/caller_utils_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,50 +142,6 @@ fn wit_type_to_rust(wit_type: &str) -> String {
}
}

// Generate default value for Rust type - IMPROVED with additional types
fn generate_default_value(rust_type: &str) -> String {
match rust_type {
// Integer types
"i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" | "isize" | "usize" => {
"0".to_string()
}
// Floating point types
"f32" | "f64" => "0.0".to_string(),
// String types
"String" => "String::new()".to_string(),
"&str" => "\"\"".to_string(),
// Other primitive types
"bool" => "false".to_string(),
"char" => "'\\0'".to_string(),
"()" => "()".to_string(),
// Collection types
t if t.starts_with("Vec<") => "Vec::new()".to_string(),
t if t.starts_with("Option<") => "None".to_string(),
t if t.starts_with("Result<") => {
// For Result, default to Ok with the default value of the success type
if let Some(success_type_end) = t.find(',') {
let success_type = &t[7..success_type_end];
format!("Ok({})", generate_default_value(success_type))
} else {
"Ok(())".to_string()
}
}
//t if t.starts_with("HashMap<") => "HashMap::new()".to_string(),
t if t.starts_with("(") => {
// Generate default tuple with default values for each element
let inner_part = t.trim_start_matches('(').trim_end_matches(')');
let parts: Vec<_> = inner_part.split(", ").collect();
let default_values: Vec<_> = parts
.iter()
.map(|part| generate_default_value(part))
.collect();
format!("({})", default_values.join(", "))
}
// For custom types, assume they implement Default
_ => format!("{}::default()", rust_type),
}
}

// Structure to represent a field in a WIT signature struct
#[derive(Debug)]
struct SignatureField {
Expand Down Expand Up @@ -346,7 +302,7 @@ fn parse_wit_file(file_path: &Path) -> Result<(Vec<SignatureStruct>, Vec<String>
}

// Generate a Rust async function from a signature struct
fn generate_async_function(signature: &SignatureStruct) -> String {
fn generate_async_function(signature: &SignatureStruct) -> Option<String> {
// Convert function name from kebab-case to snake_case
let snake_function_name = to_snake_case(&signature.function_name);

Expand Down Expand Up @@ -405,55 +361,7 @@ fn generate_async_function(signature: &SignatureStruct) -> String {

// For HTTP endpoints, generate commented-out implementation
if signature.attr_type == "http" {
debug!("Generating commented-out stub for HTTP endpoint");
let default_value = generate_default_value(&return_type);

// Add underscore prefix to all parameters for HTTP stubs
let all_params_with_underscore = if target_param.is_empty() {
params
.iter()
.map(|param| {
let parts: Vec<&str> = param.split(':').collect();
if parts.len() == 2 {
format!("_{}: {}", parts[0], parts[1])
} else {
warn!(param = %param, "Could not parse parameter for underscore prefix");
format!("_{}", param)
}
})
.collect::<Vec<String>>()
.join(", ")
} else {
let target_with_underscore = format!("_target: {}", target_param);
if params.is_empty() {
target_with_underscore
} else {
let params_with_underscore = params
.iter()
.map(|param| {
let parts: Vec<&str> = param.split(':').collect();
if parts.len() == 2 {
format!("_{}: {}", parts[0], parts[1])
} else {
warn!(param = %param, "Could not parse parameter for underscore prefix");
format!("_{}", param)
}
})
.collect::<Vec<String>>()
.join(", ");
format!("{}, {}", target_with_underscore, params_with_underscore)
}
};

return format!(
"// /// Generated stub for `{}` {} RPC call\n// /// HTTP endpoint - uncomment to implement\n// pub async fn {}({}) -> {} {{\n// // TODO: Implement HTTP endpoint\n// Ok({})\n// }}",
signature.function_name,
signature.attr_type,
full_function_name,
all_params_with_underscore,
wrapped_return_type,
default_value
);
return None;
}

// Format JSON parameters correctly
Expand All @@ -480,7 +388,7 @@ fn generate_async_function(signature: &SignatureStruct) -> String {

// Generate function with implementation using send
debug!("Generating standard RPC stub implementation");
format!(
Some(format!(
"/// Generated stub for `{}` {} RPC call\npub async fn {}({}) -> {} {{\n let body = {};\n let body = serde_json::to_vec(&body).unwrap();\n let request = Request::to(target)\n .body(body);\n send::<{}>(request).await\n}}",
signature.function_name,
signature.attr_type,
Expand All @@ -489,7 +397,7 @@ fn generate_async_function(signature: &SignatureStruct) -> String {
wrapped_return_type,
json_params,
return_type
)
))
}

// Create the caller-utils crate with a single lib.rs file
Expand Down Expand Up @@ -621,9 +529,10 @@ crate-type = ["cdylib", "lib"]

// Add function implementations
for signature in &signatures {
let function_impl = generate_async_function(signature);
mod_content.push_str(&function_impl);
mod_content.push_str("\n\n");
if let Some(function_impl) = generate_async_function(signature) {
mod_content.push_str(&function_impl);
mod_content.push_str("\n\n");
}
}

// Store the module content
Expand Down