diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 453b4814637..46717a52014 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -104,6 +104,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "insert_dtype_promotion", + srcs = ["insert_dtype_promotion.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -133,6 +146,7 @@ runtime.python_library( ":fold_qdq", ":fuse_patterns", ":fuse_quantized_ops", + ":insert_dtype_promotion", ":insert_prepack_nodes", ":remove_asserts", ":remove_redundant_ops", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index d6a6823ca88..1afaf48dde7 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -11,6 +11,9 @@ from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, ) +from executorch.backends.vulkan._passes.insert_dtype_promotion import ( + InsertDtypePromotionPass, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.remove_asserts import ( remove_asserts, @@ -28,6 +31,7 @@ "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", + "InsertDtypePromotionPass", "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", diff --git a/backends/vulkan/_passes/insert_dtype_promotion.py b/backends/vulkan/_passes/insert_dtype_promotion.py new file mode 100644 index 00000000000..324273a69df --- /dev/null +++ b/backends/vulkan/_passes/insert_dtype_promotion.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype:
    """Return the common dtype for two operand dtypes.

    The result follows PyTorch's dtype promotion lattice as implemented by
    ``torch.promote_types``: the smallest dtype that is neither smaller nor of
    a lower category (bool < integer < floating point) than either input.
    For example ``int32``/``int64`` promotes to ``int64`` and
    ``int32``/``float16`` promotes to ``float16``; only pairs whose lattice
    join is ``float32`` (e.g. ``int32``/``float32`` or ``float16``/``float32``)
    yield ``float32``.

    Args:
        a: dtype of the first operand.
        b: dtype of the second operand.

    Returns:
        The dtype both operands should be cast to before the binary op.
    """
    if a == b:
        # Fast path: nothing to promote.
        return a
    # Defer to PyTorch so the promoted dtype agrees with what eager/export
    # recorded for the consuming op, instead of unconditionally widening every
    # mixed pair to float32 (which would turn integer arithmetic into float).
    return torch.promote_types(a, b)
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
    """
    Insert `_to_copy` casts in front of binary ops with mismatched operands.

    For every `call_function` node whose target is in `BINARY_OPS`, the first
    two positional args are inspected. When both are FX nodes carrying a
    `"val"` meta entry and their dtypes differ, each operand whose dtype is
    not the promoted dtype is routed through a new
    `exir_ops.edge.aten._to_copy.default` node placed immediately before the
    binary op.

    Returns:
        A `PassResult` wrapping the (possibly mutated) graph module and a flag
        that is True when at least one cast node was inserted.
    """
    graph = graph_module.graph
    modified = False

    for op_node in graph.nodes:
        is_binary_call = (
            op_node.op == "call_function" and op_node.target in BINARY_OPS
        )
        if not is_binary_call:
            continue

        operands = (op_node.args[0], op_node.args[1])

        # Scalars / constants that are not graph nodes are left alone.
        if not all(isinstance(arg, torch.fx.Node) for arg in operands):
            continue

        # Without a "val" meta entry the operand dtype is unknown.
        if any("val" not in arg.meta for arg in operands):
            continue

        operand_dtypes = tuple(arg.meta["val"].dtype for arg in operands)
        if operand_dtypes[0] == operand_dtypes[1]:
            continue

        common = _promote_dtype(operand_dtypes[0], operand_dtypes[1])

        # Handle the left operand first, then the right one, casting only the
        # side(s) that are not already in the promoted dtype.
        for arg, arg_dtype in zip(operands, operand_dtypes):
            if arg_dtype == common:
                continue
            with graph.inserting_before(op_node):
                cast_node = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten._to_copy.default,
                    (arg,),
                    {"dtype": common},
                )
            # Give the new node a "val" meta entry with the promoted dtype.
            cast_node.meta["val"] = arg.meta["val"].to(common)
            op_node.replace_input_with(arg, cast_node)
            modified = True

    if modified:
        graph.eliminate_dead_code()
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)

    return PassResult(graph_module, modified)
This + # includes ops that don't handle their own prepacking, as well as ops + # that do handle their own prepacking but use this constant tensor as + # the primary input (since the op expects the primary input to be a GPU + # tensor, not a TensorRef). nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and not handles_own_prepacking(user.target): + if user.op != "call_function": + continue + + if not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) + continue + + # Most prepacking ops have the primary input at arg 0, but + # embedding is embedding(weight, indices, ...) where the + # primary input (indices) is at arg 1. + primary_arg_idx = 0 + if user.target == exir_ops.edge.aten.embedding.default: + primary_arg_idx = 1 + + if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) if len(nodes_to_replace_input) == 0: diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 62997ea956f..bb7c0562bad 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: return OpFeatures( inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.FP_T, - outputs_dtypes=utils.FP_T, + inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.FP_INT_T, supports_resize=True, are_node_inputs_supported_fn=check_to_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c3d5cd00204..bf07f1a58d2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -46,7 +46,7 @@ binary_op: - VALUE: half - VALUE: float - NAME: binary_eq_texture3d - OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5))) + OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5)) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -61,7 +61,7 @@ binary_op: - 
VALUE: float - VALUE: int32 - NAME: binary_lt_texture3d - OPERATOR: all(lessThan(X, Y)) + OPERATOR: lessThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -77,7 +77,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_le_texture3d - OPERATOR: all(lessThanEqual(X, Y)) + OPERATOR: lessThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -93,7 +93,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_gt_texture3d - OPERATOR: all(greaterThan(X, Y)) + OPERATOR: greaterThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -109,7 +109,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_ge_texture3d - OPERATOR: all(greaterThanEqual(X, Y)) + OPERATOR: greaterThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl index 9b44d5c5a94..3176d0142bb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl @@ -142,7 +142,25 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { for (int i = tid.x; i < tin_sizes[reduce_dim]; i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements); - vec4 outtex = op2(numerators, denominators); + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const vec4 safe_denom = max(denominators, vec4(1e-37)); + vec4 outtex = op2(numerators, safe_denom); + // Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation. 
+ // This avoids isnan()/x!=x which may not work reliably on all GPU drivers: + // - OpIsNan may have driver bugs for certain NaN bit patterns + // - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics) + // NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000 + { + uvec4 bits = floatBitsToUint(outtex); + // Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0 + uvec4 nan_inf_mask = uvec4( + ((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u); + // Zero out bits where NaN/Inf: normal values are unchanged + outtex = uintBitsToFloat(bits & ~nan_inf_mask); + } // For the last texel in the packed dim, make sure that the padding elements // are explicitly set to 0. Otherwise, they may influence computations later // down the line. @@ -153,6 +171,9 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { } write_texel(tout, scan_pos, outtex); } + // Flush outstanding imageStore writes so they're committed to memory and + // visible to subsequent GPU operations on this image. + memoryBarrierImage(); } /* @@ -173,7 +194,12 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { const int reduce_len = tin_sizes[packed_dim] - nspill; scan_pos[reduce_dim] = tid.x; - vec4 max_elements = vec4(load_texel(tin, scan_pos).x); + // Initialize with -FLT_MAX to avoid contaminating the maximum with out-of- + // bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12 + // has 3 texels but NWORKERS=4), worker threads with no valid texels would + // otherwise load from an OOB index and get 0, which corrupts the max for + // rows where all values are negative and causes denominator underflow -> NaN. 
+ vec4 max_elements = vec4(-3.402823e+38); for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { max_elements = max(max_elements, load_texel(tin, scan_pos)); @@ -230,19 +256,21 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { [[unroll]] for (int i = 0; i < 4; ++i) { denominator += denominators[i]; } + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const float safe_denominator = max(denominator, 1e-37); scan_pos[reduce_dim] = tid.x; for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element); - write_texel(tout, scan_pos, op2(numerators, denominator)); + write_texel(tout, scan_pos, op2(numerators, safe_denominator)); } // For the last texel in the dim, if there are padding elements then the // padding elements need to be set to 0 explicitly, otherwise they may // influence subsequent operations. if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) { const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element); - vec4 outtex = op2(numerator, denominator); + vec4 outtex = op2(numerator, safe_denominator); [[unroll]] for (int i = nspill; i < 4; ++i) { outtex[i] = 0; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl new file mode 100644 index 00000000000..b2f946d6f36 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
/*
 * Converts one texel from the input dtype to the output dtype. Each
 * invocation reads the input texture at its global invocation position and
 * writes the converted texel to the same position of the output image.
 */
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);
  // Tensor index that to_tensor_idx() reports for this texel position.
  const ivec4 tensor_idx = to_tensor_idx(texel_pos, out_sizes, packed_dim);

  // Only invocations whose index lies strictly inside out_sizes on every
  // component do any work; the rest fall through without writing.
  const bool in_bounds = all(lessThan(tensor_idx, out_sizes));
  if (in_bounds) {
    const IN_VEC4_T src_texel = IN_VEC4_T(texelFetch(t_in, texel_pos, 0));
    imageStore(t_out, texel_pos, OUT_VEC4_T(src_texel));
  }
}
+ +view_convert_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [float, int32] + shader_variants: + - NAME: view_convert_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index efd61848af1..2bf3f8f726d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -367,7 +367,21 @@ utils::uvec3 conv2d_global_wg_size( utils::uvec3 wg_size = create_conv2d_global_wg_size( *graph, method, out, weight_data, stride_equals_dilation); - if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { + if (method == Conv2dMethod::Depthwise) { + // The output_tile shaders (conv2d_dw_output_tile, + // conv2d_dw_sned_output_tile) use a 2D dispatch: (x_tile, y_tile) packed + // into glb_x, channel in glb_y. The base conv2d_dw shader uses a 1D + // dispatch: all (x, y, channel) packed into glb_x. For the base shader, we + // must use {W*H*C_packed, 1, 1}. 
+ const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + if (uses_output_tile) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + } else { + const utils::uvec3 base_extents = graph->create_global_wg_size(out); + wg_size = {base_extents[0] * base_extents[1] * base_extents[2], 1, 1}; + } + } else if (method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; if (shader.kernel_name.find("s1p0") != std::string::npos) { @@ -562,29 +576,45 @@ void add_conv2d_node( PushConstantDataInfo(¶m, sizeof(param)), }; } else if (method == Conv2dMethod::Depthwise) { - const utils::ivec4 kernel_param_size_stride = { - kernel_params.kernel_size[0], - kernel_params.kernel_size[1], - kernel_params.stride[0], - kernel_params.stride[1]}; - - const utils::ivec4 kernel_param_pad_dial = { - kernel_params.padding[0], - kernel_params.padding[1], - kernel_params.dilation[0], - kernel_params.dilation[1]}; - - push_constants = { - graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo( - &kernel_param_size_stride, sizeof(kernel_param_size_stride)), - PushConstantDataInfo( - &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), - PushConstantDataInfo( - &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), - PushConstantDataInfo(&out_params, sizeof(out_params)), - }; + // output_tile variants use push constants; the base conv2d_dw shader uses + // UBOs. Distinguish by checking if "_output_tile" is in the shader name. 
+ const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + + if (uses_output_tile) { + const utils::ivec4 kernel_param_size_stride = { + kernel_params.kernel_size[0], + kernel_params.kernel_size[1], + kernel_params.stride[0], + kernel_params.stride[1]}; + + const utils::ivec4 kernel_param_pad_dial = { + kernel_params.padding[0], + kernel_params.padding[1], + kernel_params.dilation[0], + kernel_params.dilation[1]}; + + push_constants = { + graph.logical_limits_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo( + &kernel_param_size_stride, sizeof(kernel_param_size_stride)), + PushConstantDataInfo( + &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), + PushConstantDataInfo( + &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), + PushConstantDataInfo(&out_params, sizeof(out_params)), + }; + } else { + // Base conv2d_dw shader uses UBOs, same as SlidingWindow case + param_buffers = { + graph.logical_limits_ubo(out), + graph.sizes_ubo(in), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(extra_params), + graph.create_params_buffer(out_params), + }; + } } else { param_buffers = { graph.logical_limits_ubo(out), diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index b7e0218823a..275023faa59 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace vkcompute { @@ -25,19 +25,29 @@ void resize_to_copy_op_node( graph->virtual_resize(out, graph->sizes_of(self)); } +bool is_float_type(vkapi::ScalarType dtype) { + return dtype == vkapi::ScalarType::Float || dtype == vkapi::ScalarType::Half; +} + void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { - static std::set supported_types = { - vkapi::ScalarType::Float, vkapi::ScalarType::Half}; - - VK_CHECK_COND( - 
supported_types.find(graph.dtype_of(in)) != supported_types.end() && - supported_types.find(graph.dtype_of(out)) != supported_types.end(), - "Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ", - vkapi::to_string(graph.dtype_of(in)), - " <-> ", - vkapi::to_string(graph.dtype_of(out))); - - graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + vkapi::ScalarType in_dtype = graph.dtype_of(in); + vkapi::ScalarType out_dtype = graph.dtype_of(out); + + // Same-dtype or float<->half conversions can use BlitNode + if (in_dtype == out_dtype || + (is_float_type(in_dtype) && is_float_type(out_dtype))) { + graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + return; + } + + // Other conversions (e.g. int<->float) use view_convert shaders + if (graph.is_buffer_storage(in)) { + add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_copy_op_node); + } else { + add_view_copy_convert_texture_node( + graph, in, out, {}, resize_to_copy_op_node); + } } void to_copy(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 3b5ffb4589f..d0b70460214 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -127,6 +127,35 @@ void add_view_copy_buffer_node( resize_fn)); } +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // 
Parameter Buffers + {}, + // Push Constants + {{graph.sizes_pc_of(out)}}, + // Specialization Constants + {graph.packed_dim_of(out)}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + void add_view_copy_convert_buffer_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index c8e52492417..a72cf04d4d3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -37,6 +37,18 @@ void add_view_copy_convert_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_texture compute shader. This can be used to + * convert between different data types for 3D texture tensors while + * preserving the texel positions. + */ +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fe2e4169f05..6a9db70adaa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1588,7 +1588,23 @@ def get_softmax_inputs(): "utils::kWidthPacked", "utils::kChannelsPacked", ] - return test_suite + + # Large negative values regression test (edgeTAM attention scores that + # produced NaN due to missing max-shift in softmax numerics) + large_neg_test_suite = VkTestSuite( + [ + ((1, 8, 512, 12), -1, False), + ] + ) + large_neg_test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + large_neg_test_suite.data_range = (-1.8e10, -6.5e9) + large_neg_test_suite.test_name_suffix = "large_negative" + large_neg_test_suite.dtypes = ["at::kFloat"] + + return [test_suite, large_neg_test_suite] @register_test_suite( diff --git 
a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index db1211883c7..e9d5613668a 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -21,6 +21,7 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, + InsertDtypePromotionPass, RemoveRedundantOpsTransform, SqueezeUnsqueezeInputs, TagMemoryMetaPass, @@ -165,6 +166,7 @@ def preprocess( # noqa: C901 AddmmToLinearTransform(), FuseBatchNormPass(program), AddmmToLinearTransform(), + InsertDtypePromotionPass(), FusePatternsPass(), FuseClampPass(), RemoveRedundantOpsTransform(),