diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 453b4814637..46717a52014 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -104,6 +104,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "insert_dtype_promotion", + srcs = ["insert_dtype_promotion.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -133,6 +146,7 @@ runtime.python_library( ":fold_qdq", ":fuse_patterns", ":fuse_quantized_ops", + ":insert_dtype_promotion", ":insert_prepack_nodes", ":remove_asserts", ":remove_redundant_ops", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index d6a6823ca88..1afaf48dde7 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -11,6 +11,9 @@ from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, ) +from executorch.backends.vulkan._passes.insert_dtype_promotion import ( + InsertDtypePromotionPass, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.remove_asserts import ( remove_asserts, @@ -28,6 +31,7 @@ "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", + "InsertDtypePromotionPass", "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", diff --git a/backends/vulkan/_passes/insert_dtype_promotion.py b/backends/vulkan/_passes/insert_dtype_promotion.py new file mode 100644 index 00000000000..324273a69df --- /dev/null +++ b/backends/vulkan/_passes/insert_dtype_promotion.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] + +# Binary ops whose first two args are tensor inputs that may need promotion +BINARY_OPS: Set[OpType] = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.ge.Tensor, +} + + +def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype: + """Promote to common dtype following PyTorch type promotion rules.""" + if a == b: + return a + # Any mix of different dtypes promotes to float32 + return torch.float32 + + +class InsertDtypePromotionPass(ExportPass): + """ + Insert _to_copy nodes before binary ops when the two tensor inputs have + different dtypes, promoting both to a common dtype. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dirty = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in BINARY_OPS: + continue + + lhs = node.args[0] + rhs = node.args[1] + + if not isinstance(lhs, torch.fx.Node) or not isinstance(rhs, torch.fx.Node): + continue + + if "val" not in lhs.meta or "val" not in rhs.meta: + continue + + lhs_dtype = lhs.meta["val"].dtype + rhs_dtype = rhs.meta["val"].dtype + + if lhs_dtype == rhs_dtype: + continue + + promoted = _promote_dtype(lhs_dtype, rhs_dtype) + + if lhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_lhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (lhs,), + {"dtype": promoted}, + ) + cast_lhs.meta["val"] = lhs.meta["val"].to(promoted) + node.replace_input_with(lhs, cast_lhs) + dirty = True + + if rhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_rhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (rhs,), + {"dtype": promoted}, + ) + cast_rhs.meta["val"] = rhs.meta["val"].to(promoted) + node.replace_input_with(rhs, cast_rhs) + dirty = True + + if dirty: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + dead_code_elimination_pass(graph_module) + + return PassResult(graph_module, dirty) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index c45ed4ea25d..373b2a4d135 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Vulkan compute graph. This annotation is used in later graph passes. node.meta["etvk_tensorref"] = True - # Get the list of node users that do not handle their own prepacking + # Get the list of node users that need a prepack node inserted. This + # includes ops that don't handle their own prepacking, as well as ops + # that do handle their own prepacking but use this constant tensor as + # the primary input (since the op expects the primary input to be a GPU + # tensor, not a TensorRef). nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and not handles_own_prepacking(user.target): + if user.op != "call_function": + continue + + if not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) + continue + + # Most prepacking ops have the primary input at arg 0, but + # embedding is embedding(weight, indices, ...) where the + # primary input (indices) is at arg 1. + primary_arg_idx = 0 + if user.target == exir_ops.edge.aten.embedding.default: + primary_arg_idx = 1 + + if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) if len(nodes_to_replace_input) == 0: diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 62997ea956f..bb7c0562bad 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: return OpFeatures( inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.FP_T, - outputs_dtypes=utils.FP_T, + inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.FP_INT_T, supports_resize=True, are_node_inputs_supported_fn=check_to_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c3d5cd00204..bf07f1a58d2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -46,7 +46,7 @@ binary_op: - VALUE: half - VALUE: float - NAME: binary_eq_texture3d - OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5))) + OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5)) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -61,7 +61,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_lt_texture3d - OPERATOR: all(lessThan(X, Y)) + OPERATOR: lessThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -77,7 +77,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_le_texture3d - OPERATOR: all(lessThanEqual(X, Y)) + OPERATOR: lessThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -93,7 +93,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_gt_texture3d - OPERATOR: all(greaterThan(X, Y)) + OPERATOR: greaterThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -109,7 +109,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_ge_texture3d - OPERATOR: all(greaterThanEqual(X, Y)) + OPERATOR: greaterThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl new file mode 100644 index 00000000000..b2f946d6f36 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_VEC4_T ${texel_type(IN_DTYPE)} +#define OUT_VEC4_T ${texel_type(OUT_DTYPE)} + +${define_required_extensions("texture3d", IN_DTYPE)} +${define_required_extensions("texture3d", OUT_DTYPE)} + +#include "indexing_utils.h" + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); + + if (any(greaterThanEqual(idx, out_sizes))) { + return; + } + + IN_VEC4_T in_texel = IN_VEC4_T(texelFetch(t_in, pos, 0)); + imageStore(t_out, pos, OUT_VEC4_T(in_texel)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml new file mode 100644 index 00000000000..2aec8322dfe --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [float, int32] + shader_variants: + - NAME: view_convert_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index b7e0218823a..275023faa59 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace vkcompute { @@ -25,19 +25,29 @@ void resize_to_copy_op_node( graph->virtual_resize(out, graph->sizes_of(self)); } +bool is_float_type(vkapi::ScalarType dtype) { + return dtype == vkapi::ScalarType::Float || dtype == vkapi::ScalarType::Half; +} + void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { - static std::set supported_types = { - vkapi::ScalarType::Float, vkapi::ScalarType::Half}; - - VK_CHECK_COND( - supported_types.find(graph.dtype_of(in)) != supported_types.end() && - supported_types.find(graph.dtype_of(out)) != supported_types.end(), - "Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ", - vkapi::to_string(graph.dtype_of(in)), - " <-> ", - vkapi::to_string(graph.dtype_of(out))); - - graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + vkapi::ScalarType in_dtype = graph.dtype_of(in); + vkapi::ScalarType out_dtype = graph.dtype_of(out); + + // Same-dtype or float<->half conversions can use BlitNode + if (in_dtype == out_dtype || + (is_float_type(in_dtype) && is_float_type(out_dtype))) { + graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + return; + } + + // Other conversions (e.g. int<->float) use view_convert shaders + if (graph.is_buffer_storage(in)) { + add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_copy_op_node); + } else { + add_view_copy_convert_texture_node( + graph, in, out, {}, resize_to_copy_op_node); + } } void to_copy(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 3b5ffb4589f..d0b70460214 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -127,6 +127,35 @@ void add_view_copy_buffer_node( resize_fn)); } +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {}, + // Push Constants + {{graph.sizes_pc_of(out)}}, + // Specialization Constants + {graph.packed_dim_of(out)}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + void add_view_copy_convert_buffer_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index c8e52492417..a72cf04d4d3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -37,6 +37,18 @@ void add_view_copy_convert_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_texture compute shader. This can be used to + * convert between different data types for 3D texture tensors while + * preserving the texel positions. + */ +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index db1211883c7..e9d5613668a 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -21,6 +21,7 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, + InsertDtypePromotionPass, RemoveRedundantOpsTransform, SqueezeUnsqueezeInputs, TagMemoryMetaPass, @@ -165,6 +166,7 @@ def preprocess( # noqa: C901 AddmmToLinearTransform(), FuseBatchNormPass(program), AddmmToLinearTransform(), + InsertDtypePromotionPass(), FusePatternsPass(), FuseClampPass(), RemoveRedundantOpsTransform(),