diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 721297dea37..853ba5d3777 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1233,25 +1233,11 @@ def register_embedding(): @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) def register_native_batch_norm_legit_no_training(): - def check_batch_norm_node(node: torch.fx.Node) -> bool: - x = node.args[0] - if not isinstance(x, torch.fx.Node): - return False - x_val = x.meta.get("val", None) - if x_val is None: - return False - x_shape = x_val.size() - # Only support 4-D input tensors since this is a restriction enforced by the - # operator implementation. - # TODO(ssjia): Add shape agnostic support for batch norm - return len(x_shape) == 4 - return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_T, supports_prepacking=True, supports_resize=True, - are_node_inputs_supported_fn=check_batch_norm_node, ) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 3ccbdc8ab85..b276ffd16f5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -162,10 +162,10 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + AddmmToLinearTransform(), FuseBatchNormPass(program), FusePatternsPass(), FuseClampPass(), - AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), FoldQDQPass(),