diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 721297dea37..853ba5d3777 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1233,25 +1233,11 @@ def register_embedding(): @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) def register_native_batch_norm_legit_no_training(): - def check_batch_norm_node(node: torch.fx.Node) -> bool: - x = node.args[0] - if not isinstance(x, torch.fx.Node): - return False - x_val = x.meta.get("val", None) - if x_val is None: - return False - x_shape = x_val.size() - # Only support 4-D input tensors since this is a restriction enforced by the - # operator implementation. - # TODO(ssjia): Add shape agnostic support for batch norm - return len(x_shape) == 4 - return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_T, supports_prepacking=True, supports_resize=True, - are_node_inputs_supported_fn=check_batch_norm_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl index 60f437fbdce..be93e800436 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl @@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "block_config", "0")} // Generate loading functions for input buffers @@ -71,7 +72,7 @@ void main() { ivec4 in_block_a = load_int8x4_block_from_t_in_a( in_a_meta, tidx, in_layout, block_outer_dim); ivec4 in_block_b = load_int8x4_block_from_t_in_b( - in_b_meta, tidx, in_layout, block_outer_dim); + in_b_meta, tidx, other_layout, block_outer_dim); ivec4 out_block; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp index af934b9b521..05bdd9431c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp @@ -42,6 +42,7 @@ void add_q8ta_binary_node( VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim); VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( input_a_info.packed_dim_block_size == output_info.packed_dim_block_size); VK_CHECK_COND( @@ -105,6 +106,7 @@ void add_q8ta_binary_node( // Specialization Constants {graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input_a), + graph.hashed_layout_of(packed_int8_input_b), block_config.as_packed_int()}, // Resize args {block_config_ref}, diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 3ccbdc8ab85..b276ffd16f5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -162,10 +162,10 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + AddmmToLinearTransform(), FuseBatchNormPass(program), FusePatternsPass(), FuseClampPass(), - AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), FoldQDQPass(),