From 101490ab2b5651bb0347472db8752c8f461c238c Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 17 Feb 2026 12:19:52 -0800 Subject: [PATCH] [ET-VK][ez] Always partition batch norm as it will be fused The batch norm operator registration had a check_batch_norm_node guard that restricted partitioning to 4D input tensors only. Since batch norm is always fused with adjacent operations during graph compilation, this restriction is unnecessary and prevents valid models from being partitioned to the Vulkan backend. Remove the guard so batch norm is always eligible for Vulkan partitioning regardless of input dimensionality. Differential Revision: [D93511630](https://our.internmc.facebook.com/intern/diff/D93511630/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 721297dea37..853ba5d3777 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1233,25 +1233,11 @@ def register_embedding(): @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) def register_native_batch_norm_legit_no_training(): - def check_batch_norm_node(node: torch.fx.Node) -> bool: - x = node.args[0] - if not isinstance(x, torch.fx.Node): - return False - x_val = x.meta.get("val", None) - if x_val is None: - return False - x_shape = x_val.size() - # Only support 4-D input tensors since this is a restriction enforced by the - # operator implementation. - # TODO(ssjia): Add shape agnostic support for batch norm - return len(x_shape) == 4 - return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_T, supports_prepacking=True, supports_resize=True, - are_node_inputs_supported_fn=check_batch_norm_node, )