diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..eb9d371 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,65 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on Keep a Changelog and this project follows Semantic Versioning. + +## [0.3.0] - to be released + +### Added +- Dynamic dimension representation in AST and WG grammar via `dyn("name", maxSize)` input dims. +- ONNX conversion support for preserving unresolved dynamic input dimensions in graph metadata. +- New `convert-onnx` CLI flag: `--experimental-dynamic-inputs` (opt-in dynamic input preservation). +- ONNX converter/operator support improvements: + - `ScatterND` + - `Where`, `Equal`, comparison operators, `Cos`, `Sin`, `TriLu`, `Tile` + - `ConstantOfShape`, `Range`, and additional constant-folding evaluators +- Built-in constant folding in `webnn-graph` (`--optimize`) to reduce dynamic-shape plumbing. +- Global debug switch for converter diagnostics (`--debug`). +- Pre-commit hook setup script and make-based local checks. + +### Changed +- ONNX conversion now supports static-lowering + dynamic metadata workflows in one pipeline. +- Graph parser/serializer now support richer values (including object literals in options). +- JS/HTML emitters and visualizer now render mixed static/dynamic shapes. +- Docs expanded and corrected: + - ONNX lowering behavior + - Dynamic dimension guidance + - SmolLM-135M conversion example from Hugging Face + +### Fixed +- Multiple ONNX conversion correctness fixes, including: + - dynamic reshape/expand conversion edge cases + - shape inflation prevention and post-conversion shape tracking + - `Unsqueeze` v14 handling + - identifier sanitization robustness (including `$` prefixes) + - clippy/robustness cleanup across converter and shape inference + +### Compatibility +- Existing static graphs remain supported. +- Validator/serializer support both graph versions `v1` and `v2`. +- Dynamic input metadata is experimental and must be enabled with + `--experimental-dynamic-inputs`. + +## [0.2.1] - 2025-12-28 + +### Added +- ONNX shape inference and `Expand` conversion support. +- Initial ONNX lowering documentation. + +### Fixed +- BERT conversion fixes. +- Identifier sanitization updates. + +## [0.2.0] - 2025-12-24 + +### Added +- Interactive HTML visualizer and `emit-html` command. +- Drag-and-drop `.webnn` loading and parser improvements. +- Graph/weights split workflow improvements and docs refinements. + +## [0.1.0] - 2025-12-24 + +### Added +- Initial release with core DSL parsing/serialization/validation scaffold. +- Binary weights support and foundational CLI commands. diff --git a/Cargo.lock b/Cargo.lock index cdf7aee..dc64b9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -660,7 +660,7 @@ dependencies = [ [[package]] name = "webnn-graph" -version = "0.2.1" +version = "0.3.0" dependencies = [ "anyhow", "base64", diff --git a/Cargo.toml b/Cargo.toml index 45d11f1..719a63e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "webnn-graph" -version = "0.2.1" +version = "0.3.0" edition = "2021" license = "Apache-2.0" description = "Simple DSL for WebNN graphs" diff --git a/README.md b/README.md index 0eecd9b..7ba661f 100644 --- a/README.md +++ b/README.md @@ -319,6 +319,28 @@ Observed output from that run: - `/tmp/smol_hf.weights`: ~513 MB - `/tmp/smol_hf.manifest.json`: ~423 KB +### Example: Converting all-MiniLM-L6-v2 + +Download the ONNX model: + +```bash +curl -L "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx?download=true" \ + -o /tmp/minilm.onnx +``` + +Convert with common sentence-embedding overrides: + +```bash +webnn-graph convert-onnx \ + --input /tmp/minilm.onnx \ + --optimize \ + --override-dim batch_size=1 \ + --override-dim sequence_length=128 \ + --output /tmp/minilm.webnn \ + --weights /tmp/minilm.weights \ + --manifest /tmp/minilm.manifest.json +``` + ### Supported ONNX Operations The converter focuses on NLP/Transformer operations: diff --git a/src/emit_js.rs b/src/emit_js.rs index 2b2267f..758db6e 100644 --- a/src/emit_js.rs +++ b/src/emit_js.rs @@ -27,6 +27,46 @@ fn shape_to_js(shape: &[Dimension]) -> String { format!("[{}]", dims.join(", ")) } +fn normalize_dtype_name(name: &str) -> Option<&'static str> { + match name.to_ascii_lowercase().as_str() { + "float32" => Some("float32"), + "float16" => Some("float16"), + "int4" => Some("int4"), + "uint4" => Some("uint4"), + "int32" => Some("int32"), + "uint32" => Some("uint32"), + "int64" => Some("int64"), + "uint64" => Some("uint64"), + "int8" => Some("int8"), + "uint8" => Some("uint8"), + _ => None, + } +} + +fn normalize_options_for_js(value: &mut serde_json::Value) { + match value { + serde_json::Value::Object(obj) => { + for (k, v) in obj.iter_mut() { + if k == "dataType" || k == "to" { + if let Some(s) = v.as_str() { + if let Some(norm) = normalize_dtype_name(s) { + *v = serde_json::Value::String(norm.to_string()); + continue; + } + } + } + normalize_options_for_js(v); + } + } + serde_json::Value::Array(arr) => { + for v in arr { + normalize_options_for_js(v); + } + } + _ => {} + } +} + /// Emit the WeightsFile helper class for loading weights pub fn emit_weights_loader_js() -> &'static str { r#"/** @@ -174,31 +214,64 @@ pub fn emit_builder_js(g: &GraphJson) -> String { s.push('\n'); for n in &g.nodes { + if n.op == "constant" { + let mut opts = serde_json::Value::Object(n.options.clone()); + normalize_options_for_js(&mut opts); + let dtype = opts + .get("dataType") + .and_then(|v| v.as_str()) + .and_then(normalize_dtype_name) + .unwrap_or("float32"); + let shape = opts + .get("shape") + .cloned() + .unwrap_or_else(|| serde_json::json!([])) + .to_string(); + let data = opts + .get("data") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + s.push_str(&format!( + " {{\n const b64 = {data:?};\n const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));\n env.set({id:?}, builder.constant({{ dataType: {dtype:?}, shape: {shape} }}, bytes.buffer));\n }}\n", + id = n.id, + data = data, + dtype = dtype, + shape = shape + )); + continue; + } + let ins = n .inputs .iter() .map(|x| format!("env.get({:?})", x)) .collect::>() .join(", "); - let opts = serde_json::Value::Object(n.options.clone()).to_string(); - if let Some(outs) = &n.outputs { - s.push_str(&format!( - " {{\n const tmp = builder[{op:?}]({ins}, {opts});\n", + let mut opts_val = serde_json::Value::Object(n.options.clone()); + normalize_options_for_js(&mut opts_val); + let opts = opts_val.to_string(); + let call = if ins.is_empty() { + format!("builder[{op:?}]({opts})", op = n.op, opts = opts) + } else { + format!( + "builder[{op:?}]({ins}, {opts})", op = n.op, ins = ins, opts = opts - )); + ) + }; + if let Some(outs) = &n.outputs { + s.push_str(&format!(" {{\n const tmp = {call};\n", call = call)); for (i, o) in outs.iter().enumerate() { s.push_str(&format!(" env.set({o:?}, tmp[{i}]);\n", o = o, i = i)); } s.push_str(" }\n"); } else { s.push_str(&format!( - " env.set({id:?}, builder[{op:?}]({ins}, {opts}));\n", + " env.set({id:?}, {call});\n", id = n.id, - op = n.op, - ins = ins, - opts = opts + call = call )); } } @@ -383,6 +456,51 @@ mod tests { assert!(js.contains("\"axis\":1")); } + #[test] + fn test_emit_cast_normalizes_dtype_option() { + let mut g = new_graph_json(); + g.inputs.insert( + "x".to_string(), + OperandDesc { + data_type: DataType::Float32, + shape: to_dimension_vector(&[1]), + }, + ); + let mut options = serde_json::Map::new(); + options.insert("to".to_string(), serde_json::json!("Int32")); + g.nodes.push(Node { + id: "y".to_string(), + op: "cast".to_string(), + inputs: vec!["x".to_string()], + options, + outputs: None, + }); + g.outputs.insert("y".to_string(), "y".to_string()); + let js = emit_builder_js(&g); + assert!(js.contains("\"to\":\"int32\"")); + } + + #[test] + fn test_emit_constant_node_uses_atob_decode() { + let mut g = new_graph_json(); + let mut options = serde_json::Map::new(); + options.insert("dataType".to_string(), serde_json::json!("Float32")); + options.insert("shape".to_string(), serde_json::json!([1])); + options.insert("data".to_string(), serde_json::json!("AAAAAA==")); + g.nodes.push(Node { + id: "c0".to_string(), + op: "constant".to_string(), + inputs: vec![], + options, + outputs: None, + }); + g.outputs.insert("c0".to_string(), "c0".to_string()); + let js = emit_builder_js(&g); + assert!(js.contains("atob(b64)")); + assert!(js.contains("dataType: \"float32\"")); + assert!(js.contains("builder.constant")); + } + #[test] fn test_emit_with_multi_outputs() { let mut g = new_graph_json(); diff --git a/src/onnx/convert.rs b/src/onnx/convert.rs index 8b9d40b..6ac1ea3 100644 --- a/src/onnx/convert.rs +++ b/src/onnx/convert.rs @@ -88,10 +88,7 @@ fn infer_shape( value_shapes.get(ins[0].as_str()).cloned() } - // Binary operations (with broadcasting) - prefer shape with FEWER dimensions - // This prevents shape inflation: constants remain compact, not broadcast-expanded - // Rationale: Broadcasting happens implicitly in WebNN ops; storing inflated shapes - // causes compatibility issues when converting back to ONNX + // Binary operations with NumPy-style broadcasting semantics. "Add" | "Sub" | "Mul" | "Div" | "Pow" => { let ins = node.input.as_slice(); if ins.len() < 2 { @@ -103,13 +100,21 @@ fn infer_shape( match (shape_a, shape_b) { (Some(a), Some(b)) => { - // Prefer the shape with FEWER dimensions to avoid shape inflation - // Example: [129] + [1, 128, 1] → keep [129], not [1, 128, 129] - if a.len() <= b.len() { - Some(a.clone()) - } else { - Some(b.clone()) + let rank = a.len().max(b.len()); + let mut out_rev = Vec::with_capacity(rank); + for i in 0..rank { + let da = a.get(a.len().wrapping_sub(1 + i)).copied().unwrap_or(1); + let db = b.get(b.len().wrapping_sub(1 + i)).copied().unwrap_or(1); + if da == db || da == 1 { + out_rev.push(db); + } else if db == 1 { + out_rev.push(da); + } else { + return None; + } } + out_rev.reverse(); + Some(out_rev) } (Some(a), None) => Some(a.clone()), (None, Some(b)) => Some(b.clone()), @@ -527,6 +532,144 @@ fn infer_shape( } } +fn shape_numel(shape: &[i64]) -> Option { + shape.iter().try_fold(1usize, |acc, &d| { + if d < 0 { + return None; + } + usize::try_from(d).ok().map(|v| acc.saturating_mul(v)) + }) +} + +fn const_shape_for_folding( + name: &str, + values: &[i64], + value_shapes: &HashMap>, +) -> Vec { + if let Some(shape) = value_shapes.get(name) { + if shape_numel(shape) == Some(values.len()) { + return shape.clone(); + } + } + + if values.len() == 1 { + Vec::new() + } else { + vec![values.len() as i64] + } +} + +fn broadcast_shape(shape_a: &[i64], shape_b: &[i64]) -> Option> { + let rank = shape_a.len().max(shape_b.len()); + let mut out_rev = Vec::with_capacity(rank); + for i in 0..rank { + let da = shape_a + .get(shape_a.len().wrapping_sub(1 + i)) + .copied() + .unwrap_or(1); + let db = shape_b + .get(shape_b.len().wrapping_sub(1 + i)) + .copied() + .unwrap_or(1); + if da <= 0 || db <= 0 { + return None; + } + if da == db || da == 1 { + out_rev.push(db); + } else if db == 1 { + out_rev.push(da); + } else { + return None; + } + } + out_rev.reverse(); + Some(out_rev) +} + +fn linear_index_for_broadcast_operand( + out_linear_idx: usize, + out_shape: &[i64], + in_shape: &[i64], +) -> Option { + if in_shape.is_empty() { + return Some(0); + } + + let in_rank = in_shape.len(); + let out_rank = out_shape.len(); + if in_rank > out_rank { + return None; + } + + let mut in_linear_idx = 0usize; + let mut in_stride = 1usize; + let mut rem = out_linear_idx; + + for out_axis_rev in 0..out_rank { + let out_axis = out_rank - 1 - out_axis_rev; + let out_dim = usize::try_from(out_shape[out_axis]).ok()?; + if out_dim == 0 { + return None; + } + let out_coord = rem % out_dim; + rem /= out_dim; + + if out_axis_rev < in_rank { + let in_axis = in_rank - 1 - out_axis_rev; + let in_dim = usize::try_from(in_shape[in_axis]).ok()?; + if in_dim == 0 { + return None; + } + let in_coord = if in_dim == 1 { 0 } else { out_coord }; + in_linear_idx = in_linear_idx.saturating_add(in_coord.saturating_mul(in_stride)); + in_stride = in_stride.saturating_mul(in_dim); + } + } + + Some(in_linear_idx) +} + +fn fold_binary_const_i64( + op_type: &str, + a_values: &[i64], + b_values: &[i64], + a_shape: &[i64], + b_shape: &[i64], +) -> Option<(Vec, Vec)> { + let out_shape = broadcast_shape(a_shape, b_shape)?; + let out_numel = shape_numel(&out_shape)?; + + let mut out_values = Vec::with_capacity(out_numel); + for out_idx in 0..out_numel { + let a_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, a_shape)?; + let b_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, b_shape)?; + let av = *a_values.get(a_idx)?; + let bv = *b_values.get(b_idx)?; + let v = match op_type { + "Add" => av + bv, + "Sub" => av - bv, + "Mul" => av * bv, + "Div" => { + if bv == 0 { + return None; + } + av / bv + } + "Equal" => { + if av == bv { + 1 + } else { + 0 + } + } + _ => return None, + }; + out_values.push(v); + } + + Some((out_values, out_shape)) +} + /// Conversion options for ONNX to WebNN #[derive(Debug, Clone)] pub struct ConvertOptions { @@ -938,11 +1081,19 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value { if let Some(shape_proto) = &tensor_type.shape { let mut shape: Vec = Vec::new(); + let mut unknown = false; for dim in &shape_proto.dim { if let Some(dim_value) = &dim.value { match dim_value { DimensionValue::DimValue(v) => { - shape.push(*v); + if *v > 0 { + shape.push(*v); + } else if options.experimental_dynamic_inputs { + shape.push(default_dynamic_max_size as i64); + } else { + unknown = true; + break; + } } DimensionValue::DimParam(dim_param) => { if let Some(v) = resolve_dim_for_inference( @@ -950,12 +1101,22 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", &mut inference_overrides, ) { shape.push(v as i64); + } else if options.experimental_dynamic_inputs { + shape.push(dynamic_max_for_dim(dim_param) as i64); + } else { + unknown = true; + break; } } } + } else if options.experimental_dynamic_inputs { + shape.push(default_dynamic_max_size as i64); + } else { + unknown = true; + break; } } - if !shape.is_empty() { + if !unknown && !shape.is_empty() { value_shapes.insert(raw_name.clone(), shape.clone()); value_shapes.insert(mapped_name.clone(), shape); } @@ -1028,7 +1189,14 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", if let Some(dim_value) = &dim.value { match dim_value { DimensionValue::DimValue(v) => { - shape.push(*v); + if *v > 0 { + shape.push(*v); + } else if options.experimental_dynamic_inputs { + shape.push(default_dynamic_max_size as i64); + } else { + unknown = true; + break; + } } DimensionValue::DimParam(dim_param) => { if let Some(v) = resolve_dim_for_inference( @@ -1036,12 +1204,16 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", &mut inference_overrides, ) { shape.push(v as i64); + } else if options.experimental_dynamic_inputs { + shape.push(dynamic_max_for_dim(dim_param) as i64); } else { unknown = true; break; } } } + } else if options.experimental_dynamic_inputs { + shape.push(default_dynamic_max_size as i64); } else { unknown = true; break; @@ -1170,42 +1342,68 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", // Run the static shape/type inference scaffold to seed shapes/types/constants // before lowering. Errors surface early if dynamic dims remain. - match crate::onnx::shape_inference::infer_static_shapes(&self.model, &inference_overrides) { - Ok(inferred) => { - // Initial seeding: use or_insert since these are the first values - // (no prior shapes to override) - for (k, v) in inferred.value_shapes { - value_shapes.entry(k).or_insert(v); - } - for (k, v) in inferred.value_types { - value_types.entry(k).or_insert(v); - } - for (k, v) in inferred.const_values { - // Use insert() instead of or_insert() to allow shape inference to correct - // earlier wrong values (e.g., Where operation heuristics) - if k.contains("rotary") && k.contains("Where") { - if let Some(old_val) = const_values.get(&k) { - crate::debug_println!( - "[CONVERT] Overwriting {} from {:?} to {:?}", - k, - old_val, - v - ); - } else { - crate::debug_println!("[CONVERT] Inserting new {} = {:?}", k, v); + let mut dynamic_inference_attempts: HashSet = HashSet::new(); + loop { + match crate::onnx::shape_inference::infer_static_shapes( + &self.model, + &inference_overrides, + ) { + Ok(inferred) => { + // Initial seeding: use or_insert since these are the first values + // (no prior shapes to override) + for (k, v) in inferred.value_shapes { + value_shapes.entry(k).or_insert(v); + } + for (k, v) in inferred.value_types { + value_types.entry(k).or_insert(v); + } + for (k, v) in inferred.const_values { + // Use insert() instead of or_insert() to allow shape inference to correct + // earlier wrong values (e.g., Where operation heuristics) + if k.contains("rotary") && k.contains("Where") { + if let Some(old_val) = const_values.get(&k) { + crate::debug_println!( + "[CONVERT] Overwriting {} from {:?} to {:?}", + k, + old_val, + v + ); + } else { + crate::debug_println!("[CONVERT] Inserting new {} = {:?}", k, v); + } } + const_values.insert(k, v); } - const_values.insert(k, v); + break; } - } - Err(crate::onnx::shape_inference::ShapeInferenceError::DynamicDim { input, dim }) => { - crate::debug_println!( - "[CONVERT] Skipping static shape inference due to unresolved dynamic dim '{}' on input '{}'", + Err(crate::onnx::shape_inference::ShapeInferenceError::DynamicDim { + input, dim, - input - ); + }) => { + if options.experimental_dynamic_inputs + && !dynamic_inference_attempts.contains(dim.as_str()) + { + let fallback = dynamic_max_for_dim(&dim); + inference_overrides.insert(dim.clone(), fallback); + dynamic_inference_attempts.insert(dim.clone()); + crate::debug_println!( + "[CONVERT] Retrying static shape inference with inferred override {}={} \ + (required by input '{}')", + dim, + fallback, + input + ); + continue; + } + crate::debug_println!( + "[CONVERT] Skipping static shape inference due to unresolved dynamic dim '{}' on input '{}'", + dim, + input + ); + break; + } + Err(e) => return Err(OnnxError::ShapeInference(e.to_string())), } - Err(e) => return Err(OnnxError::ShapeInference(e.to_string())), } // Propagate shapes and fold constant shape expressions in a few passes @@ -1373,54 +1571,31 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", } } else if matches!(op_type, "Add" | "Sub" | "Mul" | "Div") { if node.input.as_slice().len() >= 2 { - if let (Some(a), Some(b), Some(out)) = ( - node.input - .as_slice() - .first() - .and_then(|i| const_values.get(i)), - node.input - .as_slice() - .get(1) - .and_then(|i| const_values.get(i)), + if let (Some(a_name), Some(b_name), Some(out)) = ( + node.input.as_slice().first(), + node.input.as_slice().get(1), node.output.as_slice().first(), ) { - let mut result_vals = Vec::new(); - let (a_len, b_len) = (a.len(), b.len()); - let max_len = a_len.max(b_len); - for idx in 0..max_len { - let av = if a_len == 1 { a[0] } else { a[idx] }; - let bv = if b_len == 1 { b[0] } else { b[idx] }; - let v = match op_type { - "Add" => av + bv, - "Sub" => av - bv, - "Mul" => av * bv, - "Div" => { - if bv == 0 { - continue; - } - av / bv - } - _ => unreachable!(), - }; - result_vals.push(v); - } - if !result_vals.is_empty() { - const_values.insert(out.to_string(), result_vals.clone()); - let out_shape = if result_vals.len() == 1 { - Vec::new() - } else { - vec![result_vals.len() as i64] - }; - // Force the correct shape - Binary operations compute exact output shape - value_shapes.insert(out.to_string(), out_shape.clone()); - value_shapes.insert(sanitize_identifier(out), out_shape); - if let Some(dtype) = node - .input - .as_slice() - .iter() - .find_map(|i| value_types.get(i).cloned()) + let a = const_values.get(a_name); + let b = const_values.get(b_name); + if let (Some(a), Some(b)) = (a, b) { + let a_shape = const_shape_for_folding(a_name, a, &value_shapes); + let b_shape = const_shape_for_folding(b_name, b, &value_shapes); + if let Some((result_vals, out_shape)) = + fold_binary_const_i64(op_type, a, b, &a_shape, &b_shape) { - value_types.insert(out.to_string(), dtype); + const_values.insert(out.to_string(), result_vals.clone()); + // Force the correct shape - Binary operations compute exact output shape + value_shapes.insert(out.to_string(), out_shape.clone()); + value_shapes.insert(sanitize_identifier(out), out_shape); + if let Some(dtype) = node + .input + .as_slice() + .iter() + .find_map(|i| value_types.get(i).cloned()) + { + value_types.insert(out.to_string(), dtype); + } } } } @@ -1583,44 +1758,25 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", } else if op_type == "Equal" { // Equal(a, b) -> boolean tensor (represented as i64: 1 for true, 0 for false) if node.input.as_slice().len() >= 2 { - if let (Some(a), Some(b), Some(out)) = ( - node.input - .as_slice() - .first() - .and_then(|i| const_values.get(i)), - node.input - .as_slice() - .get(1) - .and_then(|i| const_values.get(i)), + if let (Some(a_name), Some(b_name), Some(out)) = ( + node.input.as_slice().first(), + node.input.as_slice().get(1), node.output.as_slice().first(), ) { - let mut result_vals = Vec::new(); - let (a_len, b_len) = (a.len(), b.len()); - let max_len = a_len.max(b_len); - for idx in 0..max_len { - let av = if a_len == 1 { - a[0] - } else { - a.get(idx).copied().unwrap_or(0) - }; - let bv = if b_len == 1 { - b[0] - } else { - b.get(idx).copied().unwrap_or(0) - }; - result_vals.push(if av == bv { 1 } else { 0 }); - } - if !result_vals.is_empty() { - const_values.insert(out.to_string(), result_vals.clone()); - let out_shape = if result_vals.len() == 1 { - Vec::new() - } else { - vec![result_vals.len() as i64] - }; - // Force the correct shape - Equal operation computes exact output shape - value_shapes.insert(out.to_string(), out_shape.clone()); - value_shapes.insert(sanitize_identifier(out), out_shape); - value_types.insert(out.to_string(), DataType::Int64); + let a = const_values.get(a_name); + let b = const_values.get(b_name); + if let (Some(a), Some(b)) = (a, b) { + let a_shape = const_shape_for_folding(a_name, a, &value_shapes); + let b_shape = const_shape_for_folding(b_name, b, &value_shapes); + if let Some((result_vals, out_shape)) = + fold_binary_const_i64("Equal", a, b, &a_shape, &b_shape) + { + const_values.insert(out.to_string(), result_vals.clone()); + // Force the correct shape - Equal operation computes exact output shape + value_shapes.insert(out.to_string(), out_shape.clone()); + value_shapes.insert(sanitize_identifier(out), out_shape); + value_types.insert(out.to_string(), DataType::Int64); + } } } } @@ -1645,7 +1801,6 @@ Provide --override-dim {}= or enable --experimental-dynamic-inputs.", if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") { crate::debug_println!("[NODE CONV] /model/rotary_emb/Where_output_0 = {:?}", val); } - for onnx_node in onnx_graph.node.as_slice() { // If all outputs are compile-time constants, emit them directly and skip conversion let outputs = onnx_node.output.as_slice(); @@ -2269,4 +2424,446 @@ mod tests { assert!(msg.contains("override-dim")); assert!(msg.contains("experimental-dynamic-inputs")); } + + #[test] + fn test_convert_dynamic_shape_concat_reshape_path_with_experimental_flag() { + use crate::protos::onnx::{tensor_shape_proto, type_proto}; + use crate::protos::onnx::{ + AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto, + ValueInfoProto, + }; + + let batch_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimValue(1)), + denotation: String::new(), + }; + let seq_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimParam( + "sequence_length".to_string(), + )), + denotation: String::new(), + }; + let hidden_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimValue(4)), + denotation: String::new(), + }; + let data_shape = TensorShapeProto { + dim: vec![batch_dim, seq_dim, hidden_dim], + }; + + let data_tensor_type = type_proto::Tensor { + elem_type: TensorProto_DataType::Float.into(), + shape: Some(data_shape), + }; + let data_type_proto = crate::protos::onnx::TypeProto { + value: Some(type_proto::Value::TensorType(data_tensor_type)), + denotation: String::new(), + }; + + let data_input = ValueInfoProto { + name: "data".to_string(), + r#type: Some(data_type_proto.clone()), + ..Default::default() + }; + let data_output = ValueInfoProto { + name: "out".to_string(), + r#type: Some(data_type_proto), + ..Default::default() + }; + + let idx0 = TensorProto { + name: "idx0".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![0], + ..Default::default() + }; + let idx1 = TensorProto { + name: "idx1".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![1], + ..Default::default() + }; + let last_dim = TensorProto { + name: "last_dim".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![4], + ..Default::default() + }; + + let shape_node = NodeProto { + op_type: "Shape".to_string(), + input: vec!["data".to_string()], + output: vec!["shape_out".to_string()], + ..Default::default() + }; + let gather0 = NodeProto { + op_type: "Gather".to_string(), + input: vec!["shape_out".to_string(), "idx0".to_string()], + output: vec!["dim0".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let gather1 = NodeProto { + op_type: "Gather".to_string(), + input: vec!["shape_out".to_string(), "idx1".to_string()], + output: vec!["dim1".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let concat_shape = NodeProto { + op_type: "Concat".to_string(), + input: vec![ + "dim0".to_string(), + "dim1".to_string(), + "last_dim".to_string(), + ], + output: vec!["shape_for_reshape".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let reshape = NodeProto { + op_type: "Reshape".to_string(), + input: vec!["data".to_string(), "shape_for_reshape".to_string()], + output: vec!["out".to_string()], + ..Default::default() + }; + + let model = ModelProto { + graph: Some(GraphProto { + input: vec![data_input], + output: vec![data_output], + initializer: vec![idx0, idx1, last_dim], + node: vec![shape_node, gather0, gather1, concat_shape, reshape], + ..Default::default() + }), + ..Default::default() + }; + + let converter = OnnxConverter::new(model).expect("converter"); + let graph = converter + .convert(&ConvertOptions { + optimize: true, + experimental_dynamic_inputs: true, + extract_weights: false, + ..ConvertOptions::default() + }) + .expect("dynamic reshape path should convert"); + + let reshape_node = graph + .nodes + .iter() + .find(|n| n.op == "reshape") + .expect("reshape node should exist"); + let shape = reshape_node + .options + .get("newShape") + .and_then(|v| v.as_array()) + .expect("newShape should be an array"); + assert_eq!(shape.len(), 3); + assert_eq!(shape[0].as_u64(), Some(1)); + assert_eq!(shape[2].as_u64(), Some(4)); + assert!( + shape[1].as_u64().is_some_and(|v| v > 0), + "sequence dimension should be concretized for lowering" + ); + } + + #[test] + fn test_convert_reshape_shape_path_survives_add_broadcast() { + use crate::protos::onnx::{tensor_shape_proto, type_proto}; + use crate::protos::onnx::{ + AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto, + ValueInfoProto, + }; + + let batch_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimValue(1)), + denotation: String::new(), + }; + let seq_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimValue(128)), + denotation: String::new(), + }; + let hidden_dim = tensor_shape_proto::Dimension { + value: Some(tensor_shape_proto::dimension::Value::DimValue(4)), + denotation: String::new(), + }; + let data_shape = TensorShapeProto { + dim: vec![batch_dim, seq_dim, hidden_dim], + }; + + let data_tensor_type = type_proto::Tensor { + elem_type: TensorProto_DataType::Float.into(), + shape: Some(data_shape), + }; + let data_type_proto = crate::protos::onnx::TypeProto { + value: Some(type_proto::Value::TensorType(data_tensor_type)), + denotation: String::new(), + }; + + let data_input = ValueInfoProto { + name: "data".to_string(), + r#type: Some(data_type_proto.clone()), + ..Default::default() + }; + let data_output = ValueInfoProto { + name: "out".to_string(), + r#type: Some(data_type_proto), + ..Default::default() + }; + + let bias = TensorProto { + name: "bias".to_string(), + data_type: TensorProto_DataType::Float as i32, + dims: vec![4], + float_data: vec![0.0, 0.0, 0.0, 0.0], + ..Default::default() + }; + let idx0 = TensorProto { + name: "idx0".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![0], + ..Default::default() + }; + let idx1 = TensorProto { + name: "idx1".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![1], + ..Default::default() + }; + let last_dim = TensorProto { + name: "last_dim".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![1], + int64_data: vec![4], + ..Default::default() + }; + + let add_node = NodeProto { + op_type: "Add".to_string(), + input: vec!["data".to_string(), "bias".to_string()], + output: vec!["add_out".to_string()], + ..Default::default() + }; + let shape_node = NodeProto { + op_type: "Shape".to_string(), + input: vec!["add_out".to_string()], + output: vec!["shape_out".to_string()], + ..Default::default() + }; + let gather0 = NodeProto { + op_type: "Gather".to_string(), + input: vec!["shape_out".to_string(), "idx0".to_string()], + output: vec!["dim0".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let gather1 = NodeProto { + op_type: "Gather".to_string(), + input: vec!["shape_out".to_string(), "idx1".to_string()], + output: vec!["dim1".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let concat_shape = NodeProto { + op_type: "Concat".to_string(), + input: vec![ + "dim0".to_string(), + "dim1".to_string(), + "last_dim".to_string(), + ], + output: vec!["shape_for_reshape".to_string()], + attribute: vec![AttributeProto { + name: "axis".to_string(), + i: 0, + ..Default::default() + }], + ..Default::default() + }; + let reshape = NodeProto { + op_type: "Reshape".to_string(), + input: vec!["add_out".to_string(), "shape_for_reshape".to_string()], + output: vec!["out".to_string()], + ..Default::default() + }; + + let model = ModelProto { + graph: Some(GraphProto { + input: vec![data_input], + output: vec![data_output], + initializer: vec![bias, idx0, idx1, last_dim], + node: vec![ + add_node, + shape_node, + gather0, + gather1, + concat_shape, + reshape, + ], + ..Default::default() + }), + ..Default::default() + }; + + let converter = OnnxConverter::new(model).expect("converter"); + let graph = converter + .convert(&ConvertOptions { + optimize: true, + extract_weights: false, + ..ConvertOptions::default() + }) + .expect("broadcasted shape path should convert"); + + let reshape_node = graph + .nodes + .iter() + .find(|n| n.op == "reshape") + .expect("reshape node should exist"); + assert_eq!( + reshape_node.options.get("newShape"), + Some(&serde_json::json!([1, 128, 4])) + ); + } + + #[test] + fn test_binary_const_folding_preserves_broadcast_shape() { + let a = vec![-1]; + let b = vec![1, 2, 3, 4].repeat(128); + let a_shape = Vec::::new(); + let b_shape = vec![1, 128, 4]; + let (out, out_shape) = + fold_binary_const_i64("Mul", &a, &b, &a_shape, &b_shape).expect("broadcast fold"); + assert_eq!(out_shape, vec![1, 128, 4]); + assert_eq!(out.len(), 512); + assert_eq!(out[0], -1); + assert_eq!(out[1], -2); + assert_eq!(out[2], -3); + assert_eq!(out[3], -4); + } + + #[test] + fn test_convert_equal_broadcast_path_does_not_flatten_const_shape() { + use crate::protos::onnx::{ + type_proto, AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, + }; + + let a = TensorProto { + name: "shape_vec".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![4], + int64_data: vec![1, 128, 4, 8], + ..Default::default() + }; + let shape3 = TensorProto { + name: "shape3".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![3], + int64_data: vec![1, 128, 4], + ..Default::default() + }; + let neg1 = TensorProto { + name: "neg1".to_string(), + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![], + int64_data: vec![-1], + ..Default::default() + }; + let cos_fill = TensorProto { + data_type: TensorProto_DataType::Int64 as i32, + dims: vec![], + int64_data: vec![1], + ..Default::default() + }; + + let cos = NodeProto { + op_type: "ConstantOfShape".to_string(), + input: vec!["shape3".to_string()], + output: vec!["cos_out".to_string()], + attribute: vec![AttributeProto { + name: "value".to_string(), + t: Some(cos_fill), + ..Default::default() + }], + ..Default::default() + }; + let mul = NodeProto { + op_type: "Mul".to_string(), + input: vec!["cos_out".to_string(), "neg1".to_string()], + output: vec!["mul_out".to_string()], + ..Default::default() + }; + let eq = NodeProto { + op_type: "Equal".to_string(), + input: vec!["shape_vec".to_string(), "mul_out".to_string()], + output: vec!["eq_out".to_string()], + ..Default::default() + }; + + let output_type = crate::protos::onnx::TypeProto { + value: Some(type_proto::Value::TensorType(type_proto::Tensor { + elem_type: TensorProto_DataType::Bool.into(), + shape: None, + })), + denotation: String::new(), + }; + + let model = ModelProto { + graph: Some(GraphProto { + initializer: vec![a, shape3, neg1], + node: vec![cos, mul, eq], + output: vec![crate::protos::onnx::ValueInfoProto { + name: "eq_out".to_string(), + r#type: Some(output_type), + ..Default::default() + }], + ..Default::default() + }), + ..Default::default() + }; + + let converter = OnnxConverter::new(model).expect("converter"); + let graph = converter + .convert(&ConvertOptions { + optimize: true, + extract_weights: false, + ..ConvertOptions::default() + }) + .expect("convert"); + + let mul_const = graph.consts.get("mul_out").expect("mul_out const"); + assert_eq!(mul_const.shape, vec![1, 128, 4]); + assert!( + graph.consts.get("eq_out").is_none() + || graph + .consts + .get("eq_out") + .is_some_and(|decl| decl.shape == vec![1, 128, 4]), + "eq_out constant must not be flattened" + ); + } } diff --git a/src/onnx/ops/conversion.rs b/src/onnx/ops/conversion.rs index 571b0e2..6c37d71 100644 --- a/src/onnx/ops/conversion.rs +++ b/src/onnx/ops/conversion.rs @@ -8,6 +8,21 @@ use serde_json::Map; pub struct ConversionHandler; +fn dtype_to_webnn_string(dt: &crate::ast::DataType) -> &'static str { + match dt { + crate::ast::DataType::Float32 => "float32", + crate::ast::DataType::Float16 => "float16", + crate::ast::DataType::Int4 => "int4", + crate::ast::DataType::Uint4 => "uint4", + crate::ast::DataType::Int32 => "int32", + crate::ast::DataType::Uint32 => "uint32", + crate::ast::DataType::Int64 => "int64", + crate::ast::DataType::Uint64 => "uint64", + crate::ast::DataType::Int8 => "int8", + crate::ast::DataType::Uint8 => "uint8", + } +} + impl OpHandler for ConversionHandler { fn supports(&self, op_type: &str) -> bool { matches!(op_type, "Cast" | "Constant") @@ -82,7 +97,7 @@ impl ConversionHandler { let mut options = Map::new(); options.insert( "to".to_string(), - serde_json::json!(format!("{:?}", target_type)), + serde_json::json!(dtype_to_webnn_string(&target_type)), ); let mut result = ConversionResult::new(vec![Node { @@ -140,7 +155,7 @@ impl ConversionHandler { let mut options = Map::new(); options.insert( "dataType".to_string(), - serde_json::json!(format!("{:?}", data_type)), + serde_json::json!(dtype_to_webnn_string(&data_type)), ); options.insert("shape".to_string(), serde_json::json!(shape)); @@ -222,5 +237,47 @@ mod tests { assert_eq!(result.nodes[0].op, "cast"); assert_eq!(result.nodes[0].inputs, vec!["x"]); assert!(result.nodes[0].options.contains_key("to")); + assert_eq!( + result.nodes[0].options.get("to"), + Some(&serde_json::json!("int64")) + ); + } + + #[test] + fn test_convert_constant_uses_lowercase_dtype_and_base64_data() { + let handler = ConversionHandler; + let mut node = create_test_node("Constant", vec![], vec!["c0"]); + let tensor = crate::protos::onnx::TensorProto { + data_type: crate::protos::onnx::TensorProto_DataType::Float as i32, + dims: vec![1], + raw_data: vec![0, 0, 128, 63], // 1.0f32 + ..Default::default() + }; + node.attribute.push(AttributeProto { + name: "value".to_string(), + t: Some(tensor), + ..Default::default() + }); + + let result = handler + .convert( + &node, + &ConversionContext { + initializers: &std::collections::HashMap::new(), + value_shapes: &std::collections::HashMap::new(), + value_shape_dims: crate::onnx::ops::empty_value_shape_dims(), + const_values: &std::collections::HashMap::new(), + value_ids: &std::collections::HashMap::new(), + value_types: &std::collections::HashMap::new(), + }, + ) + .unwrap(); + + assert_eq!(result.nodes.len(), 1); + assert_eq!( + result.nodes[0].options.get("dataType"), + Some(&serde_json::json!("float32")) + ); + assert!(result.nodes[0].options.get("data").is_some()); } } diff --git a/src/onnx/ops/mod.rs b/src/onnx/ops/mod.rs index 7953858..379b50e 100644 --- a/src/onnx/ops/mod.rs +++ b/src/onnx/ops/mod.rs @@ -59,6 +59,45 @@ impl<'a> ConversionContext<'a> { sanitized } + + pub fn resolve_shape(&self, name: &str) -> Option<&Vec> { + let sanitized = crate::onnx::convert::sanitize_identifier(name); + let trimmed = name.trim_start_matches('/'); + self.value_shapes + .get(name) + .or_else(|| self.value_shapes.get(&sanitized)) + .or_else(|| self.value_shapes.get(trimmed)) + } + + pub fn input_rank(&self, name: &str) -> Option { + self.resolve_shape(name).map(|s| s.len()) + } +} + +pub fn normalize_axis(axis: i64, rank: usize) -> Result { + let rank_i64 = rank as i64; + let normalized = if axis < 0 { axis + rank_i64 } else { axis }; + if normalized < 0 || normalized >= rank_i64 { + return Err(OnnxError::InvalidShape(format!( + "axis {} is out of bounds for rank {}", + axis, rank + ))); + } + Ok(normalized) +} + +pub fn normalize_axes(axes: &[i64], rank: usize) -> Result, OnnxError> { + axes.iter().map(|&a| normalize_axis(a, rank)).collect() +} + +pub fn normalize_axis_best_effort(axis: i64, rank: usize) -> i64 { + normalize_axis(axis, rank).unwrap_or(axis) +} + +pub fn normalize_axes_best_effort(axes: &[i64], rank: usize) -> Vec { + axes.iter() + .map(|&a| normalize_axis_best_effort(a, rank)) + .collect() } pub fn empty_value_shape_dims() -> &'static HashMap> { diff --git a/src/onnx/ops/normalization.rs b/src/onnx/ops/normalization.rs index 53c41da..d90181d 100644 --- a/src/onnx/ops/normalization.rs +++ b/src/onnx/ops/normalization.rs @@ -2,7 +2,9 @@ use crate::ast::Node; use crate::onnx::convert::{sanitize_identifier, OnnxError}; -use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::onnx::ops::{ + normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler, +}; use crate::protos::onnx::NodeProto; use serde_json::Map; @@ -80,9 +82,11 @@ impl NormalizationHandler { let mut options = Map::new(); options.insert("epsilon".to_string(), serde_json::json!(epsilon)); - // WebNN layerNormalization uses axes parameter (array) - // Convert ONNX axis to axes array - if axis != -1 { + // WebNN layerNormalization uses positive axes. + if let Some(rank) = context.input_rank(inputs[0].as_str()) { + let normalized_axis = normalize_axis_best_effort(axis, rank); + options.insert("axes".to_string(), serde_json::json!([normalized_axis])); + } else if axis != -1 { options.insert("axes".to_string(), serde_json::json!([axis])); } @@ -152,8 +156,13 @@ impl NormalizationHandler { let input0 = context.resolve_input(&inputs[0]); + let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axis_best_effort(axis, rank) + } else { + axis + }; + let mut options = Map::new(); - // WebNN softmax uses axis parameter (single value) options.insert("axis".to_string(), serde_json::json!(axis)); let mut result = ConversionResult::new(vec![Node { @@ -212,7 +221,8 @@ mod tests { let mut node = create_test_node("Softmax", vec!["x"], vec!["y"]); add_int_attribute(&mut node, "axis", -1); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 128, 384]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -231,14 +241,21 @@ mod tests { assert_eq!(result.nodes[0].inputs, vec!["x"]); assert_eq!(result.nodes[0].id, "y"); assert!(result.nodes[0].options.contains_key("axis")); + assert_eq!( + result.nodes[0].options.get("axis"), + Some(&serde_json::json!(2)) + ); } #[test] fn test_convert_layer_norm() { let handler = NormalizationHandler; - let node = create_test_node("LayerNormalization", vec!["x", "scale", "bias"], vec!["y"]); + let mut node = + create_test_node("LayerNormalization", vec!["x", "scale", "bias"], vec!["y"]); + add_int_attribute(&mut node, "axis", -1); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 128, 384]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -256,5 +273,9 @@ mod tests { assert_eq!(result.nodes[0].op, "layerNormalization"); assert_eq!(result.nodes[0].inputs.len(), 3); assert!(result.nodes[0].options.contains_key("epsilon")); + assert_eq!( + result.nodes[0].options.get("axes"), + Some(&serde_json::json!([2])) + ); } } diff --git a/src/onnx/ops/reduction.rs b/src/onnx/ops/reduction.rs index ac7b952..5c730e7 100644 --- a/src/onnx/ops/reduction.rs +++ b/src/onnx/ops/reduction.rs @@ -2,7 +2,9 @@ use crate::ast::Node; use crate::onnx::convert::{sanitize_identifier, OnnxError}; -use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::onnx::ops::{ + normalize_axes_best_effort, ConversionContext, ConversionResult, OpHandler, +}; use crate::protos::onnx::NodeProto; use serde_json::Map; @@ -88,6 +90,11 @@ impl ReductionHandler { // Add axes if specified if let Some(axes_values) = axes { + let axes_values = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axes_best_effort(&axes_values, rank) + } else { + axes_values + }; options.insert("axes".to_string(), serde_json::json!(axes_values)); } @@ -190,9 +197,10 @@ mod tests { fn test_convert_reduce_sum() { let handler = ReductionHandler; let mut node = create_test_node("ReduceSum", vec!["x"], vec!["y"]); - add_ints_attribute(&mut node, "axes", vec![0]); + add_ints_attribute(&mut node, "axes", vec![-1]); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("x".to_string(), vec![2, 3, 4]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -208,5 +216,9 @@ mod tests { let result = handler.convert(&node, &context).unwrap(); assert_eq!(result.nodes.len(), 1); assert_eq!(result.nodes[0].op, "reduceSum"); + assert_eq!( + result.nodes[0].options.get("axes"), + Some(&serde_json::json!([2])) + ); } } diff --git a/src/onnx/ops/reshape.rs b/src/onnx/ops/reshape.rs index d06d399..c23b3cb 100644 --- a/src/onnx/ops/reshape.rs +++ b/src/onnx/ops/reshape.rs @@ -2,7 +2,10 @@ use crate::ast::Node; use crate::onnx::convert::{sanitize_identifier, OnnxError}; -use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::onnx::ops::{ + normalize_axes_best_effort, normalize_axis_best_effort, ConversionContext, ConversionResult, + OpHandler, +}; use crate::protos::onnx::{NodeProto, TensorProto_DataType}; use serde_json::Map; @@ -187,6 +190,23 @@ impl ReshapeHandler { let shape_from_const = !shape_values.is_empty(); // Fallback: derive shape from known input shape when the shape tensor isn't const. + if shape_values.is_empty() { + if let Some(out_name) = node.output.as_slice().first() { + let out_s = out_name.to_string(); + let known_output_shape = context + .value_shapes + .get(&out_s) + .or_else(|| context.value_shapes.get(&sanitize_identifier(&out_s))) + .or_else(|| context.value_shapes.get(out_s.trim_start_matches('/'))) + .cloned(); + if let Some(out_shape) = known_output_shape { + if !out_shape.is_empty() && out_shape.iter().all(|&d| d > 0) { + shape_values = out_shape; + } + } + } + } + if shape_values.is_empty() { if let Some(ds) = context.value_shapes.get(data_input_raw.as_str()) { if ds.len() >= 3 { @@ -268,8 +288,9 @@ impl ReshapeHandler { } return Err(OnnxError::InvalidShape(format!( - "Reshape shape input '{}' must be a constant (initializer/constant-folded) or input shape must be known.", - shape_input_raw + "Reshape shape input '{}' must be a constant (initializer/constant-folded) or input shape must be known. \ + data input='{}', resolved='{}'.", + shape_input_raw, data_input_raw, data_input ))); } } else if shape_from_const @@ -792,6 +813,12 @@ impl ReshapeHandler { let sanitized_inputs: Vec = inputs.iter().map(|s| context.resolve_input(s)).collect(); + let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axis_best_effort(axis, rank) + } else { + axis + }; + let mut options = Map::new(); options.insert("axis".to_string(), serde_json::json!(axis)); @@ -857,6 +884,12 @@ impl ReshapeHandler { .map(|s| sanitize_identifier(&s.to_string())) .collect(); + let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axis_best_effort(axis, rank) + } else { + axis + }; + let mut options = Map::new(); options.insert("axis".to_string(), serde_json::json!(axis)); if let Some(split_values) = splits { @@ -930,6 +963,11 @@ impl ReshapeHandler { from_const } }; + let axes_values = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axes_best_effort(&axes_values, rank) + } else { + axes_values + }; let mut options = Map::new(); options.insert("axes".to_string(), serde_json::json!(axes_values.clone())); @@ -993,6 +1031,11 @@ impl ReshapeHandler { } else { self.read_axes_from_attr_or_const(node, context)? }; + let axes_values = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axes_best_effort(&axes_values, rank) + } else { + axes_values + }; let mut options = Map::new(); options.insert("axes".to_string(), serde_json::json!(axes_values)); @@ -1245,6 +1288,64 @@ mod tests { ); } + #[test] + fn test_convert_reshape_errors_when_shape_non_const_and_input_unknown() { + let handler = ReshapeHandler; + let node = create_test_node("Reshape", vec!["data", "shape_dyn"], vec!["reshaped"]); + + let initializers = std::collections::HashMap::new(); + let value_shapes = std::collections::HashMap::new(); + let const_values = std::collections::HashMap::new(); + let value_ids = std::collections::HashMap::new(); + let value_types = std::collections::HashMap::new(); + let context = ConversionContext { + initializers: &initializers, + value_shapes: &value_shapes, + value_shape_dims: crate::onnx::ops::empty_value_shape_dims(), + const_values: &const_values, + value_ids: &value_ids, + value_types: &value_types, + }; + + let err = handler + .convert(&node, &context) + .expect_err("expected reshape error"); + let msg = err.to_string(); + assert!(msg.contains("shape input")); + assert!(msg.contains("must be a constant")); + } + + #[test] + fn test_convert_reshape_uses_known_output_shape_when_shape_input_non_const() { + let handler = ReshapeHandler; + let node = create_test_node("Reshape", vec!["data", "shape_dyn"], vec!["reshaped"]); + + let initializers = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("reshaped".to_string(), vec![1, 128, 384]); + let const_values = std::collections::HashMap::new(); + let value_ids = std::collections::HashMap::new(); + let value_types = std::collections::HashMap::new(); + let context = ConversionContext { + initializers: &initializers, + value_shapes: &value_shapes, + value_shape_dims: crate::onnx::ops::empty_value_shape_dims(), + const_values: &const_values, + value_ids: &value_ids, + value_types: &value_types, + }; + + let result = handler + .convert(&node, &context) + .expect("reshape should convert"); + assert_eq!(result.nodes.len(), 1); + assert_eq!(result.nodes[0].op, "reshape"); + assert_eq!( + result.nodes[0].options.get("newShape"), + Some(&serde_json::json!([1, 128, 384])) + ); + } + #[test] fn test_convert_transpose() { let handler = ReshapeHandler; @@ -1307,9 +1408,12 @@ mod tests { fn test_convert_concat() { let handler = ReshapeHandler; let mut node = create_test_node("Concat", vec!["a", "b", "c"], vec!["result"]); - add_int_attribute(&mut node, "axis", 1); + add_int_attribute(&mut node, "axis", -1); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("a".to_string(), vec![1, 2, 3]); + value_shapes.insert("b".to_string(), vec![1, 2, 3]); + value_shapes.insert("c".to_string(), vec![1, 2, 3]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -1327,15 +1431,20 @@ mod tests { assert_eq!(result.nodes[0].op, "concat"); assert_eq!(result.nodes[0].inputs.len(), 3); assert!(result.nodes[0].options.contains_key("axis")); + assert_eq!( + result.nodes[0].options.get("axis"), + Some(&serde_json::json!(2)) + ); } #[test] fn test_convert_split() { let handler = ReshapeHandler; let mut node = create_test_node("Split", vec!["x"], vec!["y1", "y2"]); - add_int_attribute(&mut node, "axis", 0); + add_int_attribute(&mut node, "axis", -1); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 2, 4]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -1352,6 +1461,10 @@ mod tests { assert_eq!(result.nodes.len(), 1); assert_eq!(result.nodes[0].op, "split"); assert!(result.nodes[0].outputs.is_some()); + assert_eq!( + result.nodes[0].options.get("axis"), + Some(&serde_json::json!(2)) + ); } #[test] diff --git a/src/onnx/ops/utility.rs b/src/onnx/ops/utility.rs index 8aba0c5..c4938ee 100644 --- a/src/onnx/ops/utility.rs +++ b/src/onnx/ops/utility.rs @@ -3,7 +3,9 @@ use crate::ast::Node; use crate::ast::{ConstDecl, ConstInit, DataType}; use crate::onnx::convert::{sanitize_identifier, OnnxError}; -use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::onnx::ops::{ + normalize_axis_best_effort, ConversionContext, ConversionResult, OpHandler, +}; use crate::protos::onnx::NodeProto; use serde_json::{json, Map}; @@ -428,6 +430,12 @@ impl UtilityHandler { let input0 = context.resolve_input(&inputs[0]); let input1 = context.resolve_input(&inputs[1]); + let axis = if let Some(rank) = context.input_rank(inputs[0].as_str()) { + normalize_axis_best_effort(axis, rank) + } else { + axis + }; + let mut options = Map::new(); options.insert("axis".to_string(), serde_json::json!(axis)); @@ -436,10 +444,7 @@ impl UtilityHandler { context.value_shapes.get(&inputs[0]), context.value_shapes.get(&inputs[1]), ) { - let mut resolved_axis = axis; - if resolved_axis < 0 { - resolved_axis += data_shape.len() as i64; - } + let resolved_axis = axis; if resolved_axis >= 0 && (resolved_axis as usize) < data_shape.len() { let axis_idx = resolved_axis as usize; let mut out_shape = Vec::new(); @@ -628,16 +633,90 @@ impl UtilityHandler { ends_norm.resize(desired_len, fill); } - options.insert("starts".to_string(), serde_json::json!(starts_norm)); - options.insert("ends".to_string(), serde_json::json!(ends_norm)); + if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) { + let rank = input_shape.len(); + let mut axes = if let Some(a) = axes_opt { + if a.is_empty() { + (0..desired_len as i64).collect::>() + } else { + a + } + } else { + (0..desired_len as i64).collect::>() + }; + if axes.len() != desired_len { + axes.resize(desired_len, 0); + } + let axes: Vec = axes + .iter() + .map(|&a| normalize_axis_best_effort(a, rank)) + .collect(); + + let mut steps = if inputs.len() >= 5 { + let steps_name = inputs[4].as_str(); + read_ints(steps_name, context).unwrap_or_default() + } else { + Vec::new() + }; + if steps.len() > desired_len { + steps.truncate(desired_len); + } else { + steps.resize(desired_len, 1); + } - if let Some(axes) = axes_opt { - options.insert("axes".to_string(), serde_json::json!(axes)); - } - if inputs.len() >= 5 { - let steps_name = inputs[4].as_str(); - if let Some(steps) = read_ints(steps_name, context) { - options.insert("steps".to_string(), serde_json::json!(steps)); + let mut dense_starts = vec![0i64; rank]; + let mut dense_sizes: Vec = input_shape.clone(); + let mut dense_strides = vec![1i64; rank]; + + for i in 0..desired_len { + let axis = axes[i] as usize; + let dim = input_shape[axis]; + let step = steps[i]; + if step <= 0 { + return Err(OnnxError::InvalidShape( + "Slice currently requires positive step values".to_string(), + )); + } + + let mut start = starts_norm[i]; + let mut end = ends_norm[i]; + if start < 0 { + start += dim; + } + if end == i64::MAX { + end = dim; + } else if end < 0 { + end += dim; + } + start = start.clamp(0, dim); + end = end.clamp(0, dim); + + let size = if end <= start { + 0 + } else { + (end - start + step - 1) / step + }; + + dense_starts[axis] = start; + dense_sizes[axis] = size; + dense_strides[axis] = step; + } + + options.insert("starts".to_string(), serde_json::json!(dense_starts)); + options.insert("sizes".to_string(), serde_json::json!(dense_sizes)); + options.insert("strides".to_string(), serde_json::json!(dense_strides)); + } else { + // Fallback for unknown-rank tensors: keep ONNX-style static slice options. + options.insert("starts".to_string(), serde_json::json!(starts_norm)); + options.insert("ends".to_string(), serde_json::json!(ends_norm)); + if let Some(axes) = axes_opt { + options.insert("axes".to_string(), serde_json::json!(axes)); + } + if inputs.len() >= 5 { + let steps_name = inputs[4].as_str(); + if let Some(steps) = read_ints(steps_name, context) { + options.insert("steps".to_string(), serde_json::json!(steps)); + } } } } else { @@ -665,6 +744,89 @@ impl UtilityHandler { "Slice requires static starts/ends".to_string(), )); } + + if let Some(input_shape) = context.resolve_shape(inputs[0].as_str()) { + let rank = input_shape.len(); + let starts = options + .remove("starts") + .and_then(|v| serde_json::from_value::>(v).ok()) + .ok_or_else(|| OnnxError::InvalidShape("Slice starts malformed".to_string()))?; + let ends = options + .remove("ends") + .and_then(|v| serde_json::from_value::>(v).ok()) + .ok_or_else(|| OnnxError::InvalidShape("Slice ends malformed".to_string()))?; + let axes = options + .remove("axes") + .and_then(|v| serde_json::from_value::>(v).ok()) + .unwrap_or_else(|| (0..starts.len() as i64).collect::>()); + let mut steps = options + .remove("steps") + .and_then(|v| serde_json::from_value::>(v).ok()) + .unwrap_or_else(|| vec![1; starts.len()]); + + let desired_len = starts.len().max(ends.len()).max(axes.len()); + let mut starts = starts; + let mut ends = ends; + let mut axes = axes; + if starts.len() < desired_len { + starts.resize(desired_len, 0); + } + if ends.len() < desired_len { + ends.resize(desired_len, i64::MAX); + } + if axes.len() < desired_len { + axes.resize(desired_len, 0); + } + if steps.len() < desired_len { + steps.resize(desired_len, 1); + } + + let axes: Vec = axes + .iter() + .map(|&a| normalize_axis_best_effort(a, rank)) + .collect(); + let mut dense_starts = vec![0i64; rank]; + let mut dense_sizes: Vec = input_shape.clone(); + let mut dense_strides = vec![1i64; rank]; + + for i in 0..desired_len { + let axis = axes[i] as usize; + let dim = input_shape[axis]; + let step = steps[i]; + if step <= 0 { + return Err(OnnxError::InvalidShape( + "Slice currently requires positive step values".to_string(), + )); + } + + let mut start = starts[i]; + let mut end = ends[i]; + if start < 0 { + start += dim; + } + if end == i64::MAX { + end = dim; + } else if end < 0 { + end += dim; + } + start = start.clamp(0, dim); + end = end.clamp(0, dim); + + let size = if end <= start { + 0 + } else { + (end - start + step - 1) / step + }; + + dense_starts[axis] = start; + dense_sizes[axis] = size; + dense_strides[axis] = step; + } + + options.insert("starts".to_string(), serde_json::json!(dense_starts)); + options.insert("sizes".to_string(), serde_json::json!(dense_sizes)); + options.insert("strides".to_string(), serde_json::json!(dense_strides)); + } } let mut result = ConversionResult::new(vec![Node { @@ -753,9 +915,11 @@ mod tests { fn test_convert_gather() { let handler = UtilityHandler; let mut node = create_test_node("Gather", vec!["data", "indices"], vec!["output"]); - add_int_attribute(&mut node, "axis", 1); + add_int_attribute(&mut node, "axis", -1); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("data".to_string(), vec![2, 3, 4]); + value_shapes.insert("indices".to_string(), vec![2]); let const_values = std::collections::HashMap::new(); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); @@ -773,17 +937,28 @@ mod tests { assert_eq!(result.nodes[0].op, "gather"); assert_eq!(result.nodes[0].inputs.len(), 2); assert!(result.nodes[0].options.contains_key("axis")); + assert_eq!( + result.nodes[0].options.get("axis"), + Some(&serde_json::json!(2)) + ); } #[test] fn test_convert_slice() { let handler = UtilityHandler; - let node = create_test_node("Slice", vec!["x", "starts", "ends"], vec!["output"]); + let node = create_test_node( + "Slice", + vec!["x", "starts", "ends", "axes", "steps"], + vec!["output"], + ); let initializers = std::collections::HashMap::new(); - let value_shapes = std::collections::HashMap::new(); + let mut value_shapes = std::collections::HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 128]); let mut const_values = std::collections::HashMap::new(); - const_values.insert("starts".to_string(), vec![0, 1]); - const_values.insert("ends".to_string(), vec![3, 3]); + const_values.insert("starts".to_string(), vec![0]); + const_values.insert("ends".to_string(), vec![128]); + const_values.insert("axes".to_string(), vec![1]); + const_values.insert("steps".to_string(), vec![1]); let value_ids = std::collections::HashMap::new(); let value_types = std::collections::HashMap::new(); let context = ConversionContext { @@ -800,6 +975,21 @@ mod tests { assert_eq!(result.nodes[0].op, "slice"); assert_eq!(result.nodes[0].inputs, vec!["x"]); assert!(result.nodes[0].options.contains_key("starts")); + assert_eq!( + result.nodes[0].options.get("starts"), + Some(&serde_json::json!([0, 0])) + ); + assert_eq!( + result.nodes[0].options.get("sizes"), + Some(&serde_json::json!([1, 128])) + ); + assert_eq!( + result.nodes[0].options.get("strides"), + Some(&serde_json::json!([1, 1])) + ); + assert!(!result.nodes[0].options.contains_key("ends")); + assert!(!result.nodes[0].options.contains_key("axes")); + assert!(!result.nodes[0].options.contains_key("steps")); } #[test]