From 7bc669b06e78780315085ddd6057c39c473279cf Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:33 -0800 Subject: [PATCH 1/3] [ET-VK] Fix softmax NaN and depthwise conv correctness bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix three bugs causing incorrect output when running the edgeTAM model with the Vulkan backend. Together these fixes bring the model from producing all-NaN output to matching the reference within fp32 tolerance. **Bug 1 — softmax_packed_dim OOB max contamination (softmax.glsl)** In `softmax_packed_dim`, each workgroup uses NWORKERS=4 threads to collaboratively reduce along the packed dimension. Before the main loop, each worker initializes `max_elements` by loading from texel index `tid.x`. When NWORKERS exceeds the number of texels (e.g., a 12-element dim has only 3 texels, but worker 3 tries to load texel index 3), the load is out-of-bounds and returns 0 per Vulkan spec. This 0 enters the cross-worker max reduction, so for any row where all actual values are negative, the computed max becomes 0 instead of the true (negative) max. Then `exp(value - 0)` underflows to 0 for all elements, giving denominator=0 and NaN output. Fixed by initializing `max_elements = vec4(-3.402823e+38)` (i.e., -FLT_MAX) so that workers with no valid texels contribute -inf to the reduction. Also added a `safe_denominator = max(denominator, 1e-37)` clamp as a secondary safety net against any remaining underflow edge cases. This affected the edgeTAM attention softmax over 12 key positions, where ~15% of query rows had all-negative attention scores and produced NaN. **Bug 1b — softmax_nonpacked_dim defensive hardening (softmax.glsl)** Applied similar defensive fixes to `softmax_nonpacked_dim`: - Clamped denominator via `max(denominators, vec4(1e-37))` to prevent 0/0 = NaN if all exp values underflow. - Added IEEE 754 bit-level NaN/Inf → 0 sanitization on output texels. This uses `floatBitsToUint`/`uintBitsToFloat` with exponent-bit masking rather than `isnan()` or `x != x`, which may not work reliably on all GPU drivers due to OpIsNan bugs and ordered comparison semantics. - Added `memoryBarrierImage()` after the output write loop to flush imageStore writes so they're visible to subsequent GPU operations. **Bug 2 — conv2d_dw parameter binding mismatch (Convolution.cpp)** The depthwise convolution code path in `add_conv2d_node` unconditionally passed kernel parameters (stride, padding, dilation, etc.) via push constants. However, the base `conv2d_dw.glsl` shader (used for non-3x3 and non-5x5 kernels, such as 1x1 depthwise convolutions) declares these parameters as UBOs at binding points 4–8, not as push constants. The `_output_tile` shader variants do use push constants, so 3x3 and 5x5 depthwise convolutions worked correctly. For 1x1 depthwise convolutions, the shader read from unbound UBOs, getting zeros for stride, padding, dilation, and overlay_region. With stride=0 and overlay_region=(0,0), the convolution loop never executed, producing output equal to just the bias (effectively zero for small biases). Fixed by checking whether the selected shader name contains `_output_tile`. If not, parameters are passed via UBOs (matching the shader's declarations) instead of push constants. **Bug 3 — conv2d_dw workgroup size mismatch (Convolution.cpp)** The base `conv2d_dw.glsl` shader uses a fully 1D thread mapping where `gl_GlobalInvocationID.x` encodes all three output dimensions: `pos.x = gid.x % W`, `pos.y = (gid.x / W) % H`, `pos.z = gid.x / (W * H)`. The `_output_tile` variants use a 2D mapping with spatial tiles in `.x` and channels in `.y`. The `conv2d_global_wg_size` callback was dispatching all depthwise shaders with workgroup size `{W*H, C_packed, 1}`, which is correct for `_output_tile` but wrong for the base shader. With this size, all threads have `gid.x < W*H`, so `pos.z = gid.x / (W*H) = 0` — only channel texel 0 (channels 0–3 out of e.g. 192) gets computed. Fixed by dispatching `{W*H*C_packed, 1, 1}` for the base shader so that `gid.x` ranges over all spatial × channel positions. Differential Revision: [D95217947](https://our.internmc.facebook.com/intern/diff/D95217947/) ghstack-source-id: 347411472 Pull Request resolved: https://github.com/pytorch/executorch/pull/17848 --- .../runtime/graph/ops/glsl/softmax.glsl | 36 ++++++++- .../runtime/graph/ops/impl/Convolution.cpp | 78 +++++++++++++------ backends/vulkan/test/op_tests/cases.py | 18 ++++- 3 files changed, 103 insertions(+), 29 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl index 9b44d5c5a94..3176d0142bb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl @@ -142,7 +142,25 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { for (int i = tid.x; i < tin_sizes[reduce_dim]; i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements); - vec4 outtex = op2(numerators, denominators); + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const vec4 safe_denom = max(denominators, vec4(1e-37)); + vec4 outtex = op2(numerators, safe_denom); + // Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation. + // This avoids isnan()/x!=x which may not work reliably on all GPU drivers: + // - OpIsNan may have driver bugs for certain NaN bit patterns + // - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics) + // NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000 + { + uvec4 bits = floatBitsToUint(outtex); + // Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0 + uvec4 nan_inf_mask = uvec4( + ((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u); + // Zero out bits where NaN/Inf: normal values are unchanged + outtex = uintBitsToFloat(bits & ~nan_inf_mask); + } // For the last texel in the packed dim, make sure that the padding elements // are explicitly set to 0. Otherwise, they may influence computations later // down the line. @@ -153,6 +171,9 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { } write_texel(tout, scan_pos, outtex); } + // Flush outstanding imageStore writes so they're committed to memory and + // visible to subsequent GPU operations on this image. + memoryBarrierImage(); } /* @@ -173,7 +194,12 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { const int reduce_len = tin_sizes[packed_dim] - nspill; scan_pos[reduce_dim] = tid.x; - vec4 max_elements = vec4(load_texel(tin, scan_pos).x); + // Initialize with -FLT_MAX to avoid contaminating the maximum with out-of- + // bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12 + // has 3 texels but NWORKERS=4), worker threads with no valid texels would + // otherwise load from an OOB index and get 0, which corrupts the max for + // rows where all values are negative and causes denominator underflow -> NaN. + vec4 max_elements = vec4(-3.402823e+38); for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { max_elements = max(max_elements, load_texel(tin, scan_pos)); @@ -230,19 +256,21 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { [[unroll]] for (int i = 0; i < 4; ++i) { denominator += denominators[i]; } + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const float safe_denominator = max(denominator, 1e-37); scan_pos[reduce_dim] = tid.x; for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element); - write_texel(tout, scan_pos, op2(numerators, denominator)); + write_texel(tout, scan_pos, op2(numerators, safe_denominator)); } // For the last texel in the dim, if there are padding elements then the // padding elements need to be set to 0 explicitly, otherwise they may // influence subsequent operations. if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) { const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element); - vec4 outtex = op2(numerator, denominator); + vec4 outtex = op2(numerator, safe_denominator); [[unroll]] for (int i = nspill; i < 4; ++i) { outtex[i] = 0; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index efd61848af1..2bf3f8f726d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -367,7 +367,21 @@ utils::uvec3 conv2d_global_wg_size( utils::uvec3 wg_size = create_conv2d_global_wg_size( *graph, method, out, weight_data, stride_equals_dilation); - if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { + if (method == Conv2dMethod::Depthwise) { + // The output_tile shaders (conv2d_dw_output_tile, + // conv2d_dw_sned_output_tile) use a 2D dispatch: (x_tile, y_tile) packed + // into glb_x, channel in glb_y. The base conv2d_dw shader uses a 1D + // dispatch: all (x, y, channel) packed into glb_x. For the base shader, we + // must use {W*H*C_packed, 1, 1}. + const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + if (uses_output_tile) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + } else { + const utils::uvec3 base_extents = graph->create_global_wg_size(out); + wg_size = {base_extents[0] * base_extents[1] * base_extents[2], 1, 1}; + } + } else if (method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; if (shader.kernel_name.find("s1p0") != std::string::npos) { @@ -562,29 +576,45 @@ void add_conv2d_node( PushConstantDataInfo(¶m, sizeof(param)), }; } else if (method == Conv2dMethod::Depthwise) { - const utils::ivec4 kernel_param_size_stride = { - kernel_params.kernel_size[0], - kernel_params.kernel_size[1], - kernel_params.stride[0], - kernel_params.stride[1]}; - - const utils::ivec4 kernel_param_pad_dial = { - kernel_params.padding[0], - kernel_params.padding[1], - kernel_params.dilation[0], - kernel_params.dilation[1]}; - - push_constants = { - graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo( - &kernel_param_size_stride, sizeof(kernel_param_size_stride)), - PushConstantDataInfo( - &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), - PushConstantDataInfo( - &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), - PushConstantDataInfo(&out_params, sizeof(out_params)), - }; + // output_tile variants use push constants; the base conv2d_dw shader uses + // UBOs. Distinguish by checking if "_output_tile" is in the shader name. + const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + + if (uses_output_tile) { + const utils::ivec4 kernel_param_size_stride = { + kernel_params.kernel_size[0], + kernel_params.kernel_size[1], + kernel_params.stride[0], + kernel_params.stride[1]}; + + const utils::ivec4 kernel_param_pad_dial = { + kernel_params.padding[0], + kernel_params.padding[1], + kernel_params.dilation[0], + kernel_params.dilation[1]}; + + push_constants = { + graph.logical_limits_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo( + &kernel_param_size_stride, sizeof(kernel_param_size_stride)), + PushConstantDataInfo( + &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), + PushConstantDataInfo( + &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), + PushConstantDataInfo(&out_params, sizeof(out_params)), + }; + } else { + // Base conv2d_dw shader uses UBOs, same as SlidingWindow case + param_buffers = { + graph.logical_limits_ubo(out), + graph.sizes_ubo(in), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(extra_params), + graph.create_params_buffer(out_params), + }; + } } else { param_buffers = { graph.logical_limits_ubo(out), diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fe2e4169f05..6a9db70adaa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1588,7 +1588,23 @@ def get_softmax_inputs(): "utils::kWidthPacked", "utils::kChannelsPacked", ] - return test_suite + + # Large negative values regression test (edgeTAM attention scores that + # produced NaN due to missing max-shift in softmax numerics) + large_neg_test_suite = VkTestSuite( + [ + ((1, 8, 512, 12), -1, False), + ] + ) + large_neg_test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + large_neg_test_suite.data_range = (-1.8e10, -6.5e9) + large_neg_test_suite.test_name_suffix = "large_negative" + large_neg_test_suite.dtypes = ["at::kFloat"] + + return [test_suite, large_neg_test_suite] @register_test_suite( From 1582219de361829ab29c5e80e48b4becfe6d6be4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:38 -0800 Subject: [PATCH 2/3] [ET-VK] Fix mixed-dtype binary ops and comparison op padding bugs Two bugs caused incorrect outputs in models with mixed-dtype binary operations (e.g. EdgeTAM remaining frames): 1. Mixed-dtype binary ops (e.g. int arange vs float tensor) were fed to shaders that declare both inputs with the same DTYPE, causing data misinterpretation. This is now fixed by adding an `InsertDtypePromotionPass` export pass that inserts `_to_copy` nodes to promote inputs to a common dtype at compile time. The `_to_copy` op is extended to support int<->float conversions via new `view_convert_texture` shaders, and the previous float/half-only restriction in ToCopy.cpp is replaced with branching logic that uses BlitNode for same-dtype/float<->half and view_convert shaders for other conversions. 2. Texture3d comparison operators (gt, lt, le, ge, eq) used `all()` to reduce component-wise `bvec4` results to a single bool. With packed textures where padding components are zero, `all()` always returned false because padding zeros fail comparison against non-zero values. Fixed by removing `all()` so the result stays as a component-wise `bvec4`, which is correctly converted to `uvec4` for the Bool output texture. Additional changes: - New `view_convert_texture.glsl` shader and YAML for texture dtype conversion - `add_view_copy_convert_texture_node` added to View.cpp/h - `_to_copy` op registry updated to accept int dtypes (FP_INT_T) Differential Revision: [D95217948](https://our.internmc.facebook.com/intern/diff/D95217948/) ghstack-source-id: 347411474 Pull Request resolved: https://github.com/pytorch/executorch/pull/17849 --- backends/vulkan/_passes/TARGETS | 14 +++ backends/vulkan/_passes/__init__.py | 4 + .../vulkan/_passes/insert_dtype_promotion.py | 102 ++++++++++++++++++ backends/vulkan/op_registry.py | 4 +- .../runtime/graph/ops/glsl/binary_op.yaml | 10 +- .../graph/ops/glsl/view_convert_texture.glsl | 44 ++++++++ .../graph/ops/glsl/view_convert_texture.yaml | 19 ++++ .../vulkan/runtime/graph/ops/impl/ToCopy.cpp | 36 ++++--- .../vulkan/runtime/graph/ops/impl/View.cpp | 29 +++++ backends/vulkan/runtime/graph/ops/impl/View.h | 12 +++ backends/vulkan/vulkan_preprocess.py | 2 + 11 files changed, 256 insertions(+), 20 deletions(-) create mode 100644 backends/vulkan/_passes/insert_dtype_promotion.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 453b4814637..46717a52014 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -104,6 +104,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "insert_dtype_promotion", + srcs = ["insert_dtype_promotion.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -133,6 +146,7 @@ runtime.python_library( ":fold_qdq", ":fuse_patterns", ":fuse_quantized_ops", + ":insert_dtype_promotion", ":insert_prepack_nodes", ":remove_asserts", ":remove_redundant_ops", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index d6a6823ca88..1afaf48dde7 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -11,6 +11,9 @@ from executorch.backends.vulkan._passes.fuse_quantized_ops import ( FuseQuantizedOpsTransform, ) +from executorch.backends.vulkan._passes.insert_dtype_promotion import ( + InsertDtypePromotionPass, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.remove_asserts import ( remove_asserts, @@ -28,6 +31,7 @@ "FoldQDQPass", "FusePatternsPass", "FuseQuantizedOpsTransform", + "InsertDtypePromotionPass", "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", diff --git a/backends/vulkan/_passes/insert_dtype_promotion.py b/backends/vulkan/_passes/insert_dtype_promotion.py new file mode 100644 index 00000000000..324273a69df --- /dev/null +++ b/backends/vulkan/_passes/insert_dtype_promotion.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] + +# Binary ops whose first two args are tensor inputs that may need promotion +BINARY_OPS: Set[OpType] = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.ge.Tensor, +} + + +def _promote_dtype(a: torch.dtype, b: torch.dtype) -> torch.dtype: + """Promote to common dtype following PyTorch type promotion rules.""" + if a == b: + return a + # Any mix of different dtypes promotes to float32 + return torch.float32 + + +class InsertDtypePromotionPass(ExportPass): + """ + Insert _to_copy nodes before binary ops when the two tensor inputs have + different dtypes, promoting both to a common dtype. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dirty = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in BINARY_OPS: + continue + + lhs = node.args[0] + rhs = node.args[1] + + if not isinstance(lhs, torch.fx.Node) or not isinstance(rhs, torch.fx.Node): + continue + + if "val" not in lhs.meta or "val" not in rhs.meta: + continue + + lhs_dtype = lhs.meta["val"].dtype + rhs_dtype = rhs.meta["val"].dtype + + if lhs_dtype == rhs_dtype: + continue + + promoted = _promote_dtype(lhs_dtype, rhs_dtype) + + if lhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_lhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (lhs,), + {"dtype": promoted}, + ) + cast_lhs.meta["val"] = lhs.meta["val"].to(promoted) + node.replace_input_with(lhs, cast_lhs) + dirty = True + + if rhs_dtype != promoted: + with graph_module.graph.inserting_before(node): + cast_rhs = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (rhs,), + {"dtype": promoted}, + ) + cast_rhs.meta["val"] = rhs.meta["val"].to(promoted) + node.replace_input_with(rhs, cast_rhs) + dirty = True + + if dirty: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + dead_code_elimination_pass(graph_module) + + return PassResult(graph_module, dirty) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 62997ea956f..bb7c0562bad 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -307,8 +307,8 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: return OpFeatures( inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.FP_T, - outputs_dtypes=utils.FP_T, + inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.FP_INT_T, supports_resize=True, are_node_inputs_supported_fn=check_to_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c3d5cd00204..bf07f1a58d2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -46,7 +46,7 @@ binary_op: - VALUE: half - VALUE: float - NAME: binary_eq_texture3d - OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5))) + OPERATOR: lessThanEqual(abs(X - Y), VEC4_T(1e-5)) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -61,7 +61,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_lt_texture3d - OPERATOR: all(lessThan(X, Y)) + OPERATOR: lessThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -77,7 +77,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_le_texture3d - OPERATOR: all(lessThanEqual(X, Y)) + OPERATOR: lessThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -93,7 +93,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_gt_texture3d - OPERATOR: all(greaterThan(X, Y)) + OPERATOR: greaterThan(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: @@ -109,7 +109,7 @@ binary_op: - VALUE: float - VALUE: int32 - NAME: binary_ge_texture3d - OPERATOR: all(greaterThanEqual(X, Y)) + OPERATOR: greaterThanEqual(X, Y) STORAGE: texture3d generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl new file mode 100644 index 00000000000..b2f946d6f36 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.glsl @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_VEC4_T ${texel_type(IN_DTYPE)} +#define OUT_VEC4_T ${texel_type(OUT_DTYPE)} + +${define_required_extensions("texture3d", IN_DTYPE)} +${define_required_extensions("texture3d", OUT_DTYPE)} + +#include "indexing_utils.h" + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); + + if (any(greaterThanEqual(idx, out_sizes))) { + return; + } + + IN_VEC4_T in_texel = IN_VEC4_T(texelFetch(t_in, pos, 0)); + imageStore(t_out, pos, OUT_VEC4_T(in_texel)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml new file mode 100644 index 00000000000..2aec8322dfe --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_texture.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: texture3d + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [float, int32] + shader_variants: + - NAME: view_convert_texture diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index b7e0218823a..275023faa59 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace vkcompute { @@ -25,19 +25,29 @@ void resize_to_copy_op_node( graph->virtual_resize(out, graph->sizes_of(self)); } +bool is_float_type(vkapi::ScalarType dtype) { + return dtype == vkapi::ScalarType::Float || dtype == vkapi::ScalarType::Half; +} + void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { - static std::set supported_types = { - vkapi::ScalarType::Float, vkapi::ScalarType::Half}; - - VK_CHECK_COND( - supported_types.find(graph.dtype_of(in)) != supported_types.end() && - supported_types.find(graph.dtype_of(out)) != supported_types.end(), - "Unsupported dtype for to_copy, only Float and Half are currently supported, recieved ", - vkapi::to_string(graph.dtype_of(in)), - " <-> ", - vkapi::to_string(graph.dtype_of(out))); - - graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + vkapi::ScalarType in_dtype = graph.dtype_of(in); + vkapi::ScalarType out_dtype = graph.dtype_of(out); + + // Same-dtype or float<->half conversions can use BlitNode + if (in_dtype == out_dtype || + (is_float_type(in_dtype) && is_float_type(out_dtype))) { + graph.execute_nodes().emplace_back(new BlitNode(graph, in, out)); + return; + } + + // Other conversions (e.g. int<->float) use view_convert shaders + if (graph.is_buffer_storage(in)) { + add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_copy_op_node); + } else { + add_view_copy_convert_texture_node( + graph, in, out, {}, resize_to_copy_op_node); + } } void to_copy(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 3b5ffb4589f..d0b70460214 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -127,6 +127,35 @@ void add_view_copy_buffer_node( resize_fn)); } +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {}, + // Push Constants + {{graph.sizes_pc_of(out)}}, + // Specialization Constants + {graph.packed_dim_of(out)}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + void add_view_copy_convert_buffer_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index c8e52492417..a72cf04d4d3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -37,6 +37,18 @@ void add_view_copy_convert_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_texture compute shader. This can be used to + * convert between different data types for 3D texture tensors while + * preserving the texel positions. + */ +void add_view_copy_convert_texture_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index db1211883c7..e9d5613668a 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -21,6 +21,7 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, + InsertDtypePromotionPass, RemoveRedundantOpsTransform, SqueezeUnsqueezeInputs, TagMemoryMetaPass, @@ -165,6 +166,7 @@ def preprocess( # noqa: C901 AddmmToLinearTransform(), FuseBatchNormPass(program), AddmmToLinearTransform(), + InsertDtypePromotionPass(), FusePatternsPass(), FuseClampPass(), RemoveRedundantOpsTransform(), From 8691147c2948873cf59feda08fc393f6f772358c Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:43 -0800 Subject: [PATCH 3/3] [ET-VK] Insert prepack nodes for constant primary inputs of prepacking ops The insert_prepack_nodes pass was skipping prepack node insertion for all constant tensor args of ops with supports_prepacking=True. However, these ops only handle prepacking for weight/bias tensors internally; the primary input tensor is still expected to be a GPU tensor. If the primary input happens to be a constant tensor (serialized as TensorRef), the op throws an exception at runtime. Fix this by detecting the primary input index directly in insert_prepack_nodes. Most prepacking ops have the primary input at arg 0, but embedding uses arg 1 since its signature is embedding(weight, indices, ...). The pass now checks whether a constant tensor is used as the primary input of a prepacking op, and if so, still inserts a prepack node for it. Differential Revision: [D95217949](https://our.internmc.facebook.com/intern/diff/D95217949/) ghstack-source-id: 347411473 Pull Request resolved: https://github.com/pytorch/executorch/pull/17850 --- .../vulkan/_passes/insert_prepack_nodes.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index c45ed4ea25d..373b2a4d135 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Vulkan compute graph. This annotation is used in later graph passes. node.meta["etvk_tensorref"] = True - # Get the list of node users that do not handle their own prepacking + # Get the list of node users that need a prepack node inserted. This + # includes ops that don't handle their own prepacking, as well as ops + # that do handle their own prepacking but use this constant tensor as + # the primary input (since the op expects the primary input to be a GPU + # tensor, not a TensorRef). nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and not handles_own_prepacking(user.target): + if user.op != "call_function": + continue + + if not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) + continue + + # Most prepacking ops have the primary input at arg 0, but + # embedding is embedding(weight, indices, ...) where the + # primary input (indices) is at arg 1. + primary_arg_idx = 0 + if user.target == exir_ops.edge.aten.embedding.default: + primary_arg_idx = 1 + + if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) if len(nodes_to_replace_input) == 0: