From bb093ef4c84e82b2f529dae99d0746e4c05731fd Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:35 -0800 Subject: [PATCH] [ET-VK] Fix mixed-dtype binary ops and comparison op padding bugs Two bugs caused incorrect outputs in models with mixed-dtype binary operations (e.g. EdgeTAM remaining frames): 1. Mixed-dtype binary ops (e.g. int arange vs float tensor) were fed to shaders that declare both inputs with the same DTYPE, causing data misinterpretation. This is now fixed by adding an `InsertDtypePromotionPass` export pass that inserts `_to_copy` nodes to promote inputs to a common dtype at compile time. The `_to_copy` op is extended to support int<->float conversions via new `view_convert_texture` shaders, and the previous float/half-only restriction in ToCopy.cpp is replaced with branching logic that uses BlitNode for same-dtype/float<->half and view_convert shaders for other conversions. 2. Texture3d comparison operators (gt, lt, le, ge, eq) used `all()` to reduce component-wise `bvec4` results to a single bool. With packed textures where padding components are zero, `all()` always returned false because padding zeros fail comparison against non-zero values. Fixed by removing `all()` so the result stays as a component-wise `bvec4`, which is correctly converted to `uvec4` for the Bool output texture. Additional changes: - New `view_convert_texture.glsl` shader and YAML for texture dtype conversion - `add_view_copy_convert_texture_node` added to View.cpp/h - `_to_copy` op registry updated to accept int dtypes (FP_INT_T) Differential Revision: [D95217948](https://our.internmc.facebook.com/intern/diff/D95217948/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 14 +++ backends/vulkan/_passes/__init__.py | 4 + .../vulkan/_passes/insert_dtype_promotion.py | 102 ++++++++++++++++++ backends/vulkan/op_registry.py | 4 +- .../runtime/graph/ops/glsl/binary_op.yaml | 10 +- .../graph/ops/glsl/view_convert_texture.glsl | 44 ++++++++ .../graph/ops/glsl/view_convert_texture.yaml | 19 ++++ .../vulkan/runtime/graph/ops/impl/ToCopy.cpp | 36 ++++--- .../vulkan/runtime/graph/ops/impl/View.cpp | 29 +++++ backends/vulkan/runtime/graph/ops/impl/View.h | 12 +++ backends/vulkan/vulkan_preprocess.py | 2 + 11 files changed, 256 insertions(+), 20 deletions(-) create mode 100644 backends/vulkan/_passes/insert_dtype_promotion.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 453b4814637..46717a52014 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -104,6 +104,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "insert_dtype_promotion", + srcs = ["insert_dtype_promotion.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -133,6 +146,7 @@ runtime.python_library( ":fold_qdq", ":fuse_patterns", ":fuse_quantized_ops", + ":insert_dtype_promotion", ":insert_prepack_nodes", ":remove_asserts", ":remove_redundant_ops", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index d6a6823ca88..1afaf48dde7 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -11,6 +11,9 @@ from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, ) +from executorch.backends.vulkan._passes.insert_dtype_promotion import ( + InsertDtypePromotionPass, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.remove_asserts import ( remove_asserts, @@ -28,6 +31,7 @@ "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", + "InsertDtypePromotionPass", "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", diff --git a/backends/vulkan/_passes/insert_dtype_promotion.py b/backends/vulkan/_passes/insert_dtype_promotion.py new file mode 100644 index 00000000000..324273a69df --- /dev/null +++ b/backends/vulkan/_passes/insert_dtype_promotion.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] + +# Binary ops whose first two args are tensor inputs that may need promotion +BINARY_OPS: Set[OpType] = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.ge.Tensor, +} + + +def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype: + """Promote to common dtype following PyTorch type promotion rules.""" + if a == b: + return a + # Any mix of different dtypes promotes to float32 + return torch.float32 + + +class InsertDtypePromotionPass(ExportPass): + """ + Insert _to_copy nodes before binary ops when the two tensor inputs have + different dtypes, promoting both to a common dtype. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dirty = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in BINARY_OPS: + continue + + lhs = node.args[0] + rhs = node.args[1] + + if not isinstance(lhs, torch.fx.Node) or not isinstance(rhs, torch.fx.Node): + continue + + if "val" not in lhs.meta or "val" not in rhs.meta: + continue + + lhs_dtype = lhs.meta["val"].dtype + rhs_dtype = rhs.meta["val"].dtype + + if lhs_dtype == rhs_dtype: + continue + + promoted = _promote_dtype(lhs_dtype, rhs_dtype) + + if lhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_lhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (lhs,), + {"dtype": promoted}, + ) + cast_lhs.meta["val"] = lhs.meta["val"].to(promoted) + node.replace_input_with(lhs, cast_lhs) + dirty = True + + if rhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_rhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (rhs,), + {"dtype": promoted}, + ) + cast_rhs.meta["val"] = rhs.meta["val"].to(promoted) + node.replace_input_with(rhs, cast_rhs) + dirty = True + + if dirty: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + dead_code_elimination_pass(graph_module) + + return PassResult(graph_module, dirty) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 62997ea956f..bb7c0562bad 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: return OpFeatures( inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.FP_T, - outputs_dtypes=utils.FP_T, + inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.FP_INT_T, supports_resize=True, are_node_inputs_supported_fn=check_to_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c3d5cd00204..bf07f1a58d2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -46,7 +46,7 @@ binary_op: - VALUE: half - VALUE: float - NAME: binary_eq_texture3d - OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5))) + OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5)) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -61,7 +61,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_lt_texture3d - OPERATOR: all(lessThan(X, Y)) + OPERATOR: lessThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -77,7 +77,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_le_texture3d - OPERATOR: all(lessThanEqual(X, Y)) + OPERATOR: lessThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -93,7 +93,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_gt_texture3d - OPERATOR: all(greaterThan(X, Y)) + OPERATOR: greaterThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -109,7 +109,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_ge_texture3d - OPERATOR: all(greaterThanEqual(X, Y)) + OPERATOR: greaterThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl new file mode 100644 index 00000000000..b2f946d6f36 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_VEC4_T ${texel_type(IN_DTYPE)} +#define OUT_VEC4_T ${texel_type(OUT_DTYPE)} + +${define_required_extensions("texture3d", IN_DTYPE)} +${define_required_extensions("texture3d", OUT_DTYPE)} + +#include "indexing_utils.h" + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); + + if (any(greaterThanEqual(idx, out_sizes))) { + return; + } + + IN_VEC4_T in_texel = IN_VEC4_T(texelFetch(t_in, pos, 0)); + imageStore(t_out, pos, OUT_VEC4_T(in_texel)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml new file mode 100644 index 00000000000..2aec8322dfe --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [float, int32] + shader_variants: + - NAME: view_convert_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index b7e0218823a..275023faa59 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace vkcompute { @@ -25,19 +25,29 @@ void resize_to_copy_op_node( graph->virtual_resize(out, graph->sizes_of(self)); } +bool is_float_type(vkapi::ScalarType dtype) { + return dtype == vkapi::ScalarType::Float || dtype == vkapi::ScalarType::Half; +} + void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { - static std::set supported_types = { - vkapi::ScalarType::Float, vkapi::ScalarType::Half}; - - VK_CHECK_COND( - supported_types.find(graph.dtype_of(in)) != supported_types.end() && - supported_types.find(graph.dtype_of(out)) != supported_types.end(), - "Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ", - vkapi::to_string(graph.dtype_of(in)), - " <-> ", - vkapi::to_string(graph.dtype_of(out))); - - graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + vkapi::ScalarType in_dtype = graph.dtype_of(in); + vkapi::ScalarType out_dtype = graph.dtype_of(out); + + // Same-dtype or float<->half conversions can use BlitNode + if (in_dtype == out_dtype || + (is_float_type(in_dtype) && is_float_type(out_dtype))) { + graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + return; + } + + // Other conversions (e.g. int<->float) use view_convert shaders + if (graph.is_buffer_storage(in)) { + add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_copy_op_node); + } else { + add_view_copy_convert_texture_node( + graph, in, out, {}, resize_to_copy_op_node); + } } void to_copy(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 3b5ffb4589f..d0b70460214 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -127,6 +127,35 @@ void add_view_copy_buffer_node( resize_fn)); } +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {}, + // Push Constants + {{graph.sizes_pc_of(out)}}, + // Specialization Constants + {graph.packed_dim_of(out)}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + void add_view_copy_convert_buffer_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index c8e52492417..a72cf04d4d3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -37,6 +37,18 @@ void add_view_copy_convert_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_texture compute shader. This can be used to + * convert between different data types for 3D texture tensors while + * preserving the texel positions. + */ +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index db1211883c7..e9d5613668a 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -21,6 +21,7 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, + InsertDtypePromotionPass, RemoveRedundantOpsTransform, SqueezeUnsqueezeInputs, TagMemoryMetaPass, @@ -165,6 +166,7 @@ def preprocess( # noqa: C901 AddmmToLinearTransform(), FuseBatchNormPass(program), AddmmToLinearTransform(), + InsertDtypePromotionPass(), FusePatternsPass(), FuseClampPass(), RemoveRedundantOpsTransform(),