From b2e15412177aae4f7fda5d2ae4982d9ba9a3f147 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:29 -0800 Subject: [PATCH] [ET-VK] Fix softmax NaN and depthwise conv correctness bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix three bugs causing incorrect output when running the edgeTAM model with the Vulkan backend. Together these fixes bring the model from producing all-NaN output to matching the reference within fp32 tolerance. **Bug 1 — softmax_packed_dim OOB max contamination (softmax.glsl)** In `softmax_packed_dim`, each workgroup uses NWORKERS=4 threads to collaboratively reduce along the packed dimension. Before the main loop, each worker initializes `max_elements` by loading from texel index `tid.x`. When NWORKERS exceeds the number of texels (e.g., a 12-element dim has only 3 texels, but worker 3 tries to load texel index 3), the load is out-of-bounds and returns 0 per Vulkan spec. This 0 enters the cross-worker max reduction, so for any row where all actual values are negative, the computed max becomes 0 instead of the true (negative) max. Then `exp(value - 0)` underflows to 0 for all elements, giving denominator=0 and NaN output. Fixed by initializing `max_elements = vec4(-3.402823e+38)` (i.e., -FLT_MAX) so that workers with no valid texels contribute -inf to the reduction. Also added a `safe_denominator = max(denominator, 1e-37)` clamp as a secondary safety net against any remaining underflow edge cases. This affected the edgeTAM attention softmax over 12 key positions, where ~15% of query rows had all-negative attention scores and produced NaN. **Bug 1b — softmax_nonpacked_dim defensive hardening (softmax.glsl)** Applied similar defensive fixes to `softmax_nonpacked_dim`: - Clamped denominator via `max(denominators, vec4(1e-37))` to prevent 0/0 = NaN if all exp values underflow. - Added IEEE 754 bit-level NaN/Inf → 0 sanitization on output texels. This uses `floatBitsToUint`/`uintBitsToFloat` with exponent-bit masking rather than `isnan()` or `x != x`, which may not work reliably on all GPU drivers due to OpIsNan bugs and ordered comparison semantics. - Added `memoryBarrierImage()` after the output write loop to flush imageStore writes so they're visible to subsequent GPU operations. **Bug 2 — conv2d_dw parameter binding mismatch (Convolution.cpp)** The depthwise convolution code path in `add_conv2d_node` unconditionally passed kernel parameters (stride, padding, dilation, etc.) via push constants. However, the base `conv2d_dw.glsl` shader (used for non-3x3 and non-5x5 kernels, such as 1x1 depthwise convolutions) declares these parameters as UBOs at binding points 4–8, not as push constants. The `_output_tile` shader variants do use push constants, so 3x3 and 5x5 depthwise convolutions worked correctly. For 1x1 depthwise convolutions, the shader read from unbound UBOs, getting zeros for stride, padding, dilation, and overlay_region. With stride=0 and overlay_region=(0,0), the convolution loop never executed, producing output equal to just the bias (effectively zero for small biases). Fixed by checking whether the selected shader name contains `_output_tile`. If not, parameters are passed via UBOs (matching the shader's declarations) instead of push constants. **Bug 3 — conv2d_dw workgroup size mismatch (Convolution.cpp)** The base `conv2d_dw.glsl` shader uses a fully 1D thread mapping where `gl_GlobalInvocationID.x` encodes all three output dimensions: `pos.x = gid.x % W`, `pos.y = (gid.x / W) % H`, `pos.z = gid.x / (W * H)`. The `_output_tile` variants use a 2D mapping with spatial tiles in `.x` and channels in `.y`. The `conv2d_global_wg_size` callback was dispatching all depthwise shaders with workgroup size `{W*H, C_packed, 1}`, which is correct for `_output_tile` but wrong for the base shader. With this size, all threads have `gid.x < W*H`, so `pos.z = gid.x / (W*H) = 0` — only channel texel 0 (channels 0–3 out of e.g. 192) gets computed. Fixed by dispatching `{W*H*C_packed, 1, 1}` for the base shader so that `gid.x` ranges over all spatial × channel positions. Differential Revision: [D95217947](https://our.internmc.facebook.com/intern/diff/D95217947/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/softmax.glsl | 36 ++++++++- .../runtime/graph/ops/impl/Convolution.cpp | 78 +++++++++++++------ backends/vulkan/test/op_tests/cases.py | 18 ++++- 3 files changed, 103 insertions(+), 29 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl index 9b44d5c5a94..3176d0142bb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/softmax.glsl @@ -142,7 +142,25 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { for (int i = tid.x; i < tin_sizes[reduce_dim]; i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements); - vec4 outtex = op2(numerators, denominators); + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const vec4 safe_denom = max(denominators, vec4(1e-37)); + vec4 outtex = op2(numerators, safe_denom); + // Replace any NaN/Inf with 0 using IEEE 754 bit-level manipulation. + // This avoids isnan()/x!=x which may not work reliably on all GPU drivers: + // - OpIsNan may have driver bugs for certain NaN bit patterns + // - OpFOrdNotEqual(NaN,NaN) = false (ordered comparison semantics) + // NaN/Inf pattern: all exponent bits set = (bits & 0x7F800000) == 0x7F800000 + { + uvec4 bits = floatBitsToUint(outtex); + // Build a mask: 0xFFFFFFFF where NaN/Inf (exponent all-ones), else 0 + uvec4 nan_inf_mask = uvec4( + ((bits.x & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.y & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.z & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u, + ((bits.w & 0x7F800000u) == 0x7F800000u) ? 0xFFFFFFFFu : 0u); + // Zero out bits where NaN/Inf: normal values are unchanged + outtex = uintBitsToFloat(bits & ~nan_inf_mask); + } // For the last texel in the packed dim, make sure that the padding elements // are explicitly set to 0. Otherwise, they may influence computations later // down the line. @@ -153,6 +171,9 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { } write_texel(tout, scan_pos, outtex); } + // Flush outstanding imageStore writes so they're committed to memory and + // visible to subsequent GPU operations on this image. + memoryBarrierImage(); } /* @@ -173,7 +194,12 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { const int reduce_len = tin_sizes[packed_dim] - nspill; scan_pos[reduce_dim] = tid.x; - vec4 max_elements = vec4(load_texel(tin, scan_pos).x); + // Initialize with -FLT_MAX to avoid contaminating the maximum with out-of- + // bounds texture reads. When NWORKERS > number of texels (e.g. reduce_len=12 + // has 3 texels but NWORKERS=4), worker threads with no valid texels would + // otherwise load from an OOB index and get 0, which corrupts the max for + // rows where all values are negative and causes denominator underflow -> NaN. + vec4 max_elements = vec4(-3.402823e+38); for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { max_elements = max(max_elements, load_texel(tin, scan_pos)); @@ -230,19 +256,21 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) { [[unroll]] for (int i = 0; i < 4; ++i) { denominator += denominators[i]; } + // Clamp denominator to avoid 0/0 = NaN when all exp values underflow. + const float safe_denominator = max(denominator, 1e-37); scan_pos[reduce_dim] = tid.x; for (int i = tid.x * 4; i < reduce_len; i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element); - write_texel(tout, scan_pos, op2(numerators, denominator)); + write_texel(tout, scan_pos, op2(numerators, safe_denominator)); } // For the last texel in the dim, if there are padding elements then the // padding elements need to be set to 0 explicitly, otherwise they may // influence subsequent operations. if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) { const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element); - vec4 outtex = op2(numerator, denominator); + vec4 outtex = op2(numerator, safe_denominator); [[unroll]] for (int i = nspill; i < 4; ++i) { outtex[i] = 0; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index efd61848af1..2bf3f8f726d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -367,7 +367,21 @@ utils::uvec3 conv2d_global_wg_size( utils::uvec3 wg_size = create_conv2d_global_wg_size( *graph, method, out, weight_data, stride_equals_dilation); - if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { + if (method == Conv2dMethod::Depthwise) { + // The output_tile shaders (conv2d_dw_output_tile, + // conv2d_dw_sned_output_tile) use a 2D dispatch: (x_tile, y_tile) packed + // into glb_x, channel in glb_y. The base conv2d_dw shader uses a 1D + // dispatch: all (x, y, channel) packed into glb_x. For the base shader, we + // must use {W*H*C_packed, 1, 1}. + const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + if (uses_output_tile) { + wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + } else { + const utils::uvec3 base_extents = graph->create_global_wg_size(out); + wg_size = {base_extents[0] * base_extents[1] * base_extents[2], 1, 1}; + } + } else if (method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; if (shader.kernel_name.find("s1p0") != std::string::npos) { @@ -562,29 +576,45 @@ void add_conv2d_node( PushConstantDataInfo(¶m, sizeof(param)), }; } else if (method == Conv2dMethod::Depthwise) { - const utils::ivec4 kernel_param_size_stride = { - kernel_params.kernel_size[0], - kernel_params.kernel_size[1], - kernel_params.stride[0], - kernel_params.stride[1]}; - - const utils::ivec4 kernel_param_pad_dial = { - kernel_params.padding[0], - kernel_params.padding[1], - kernel_params.dilation[0], - kernel_params.dilation[1]}; - - push_constants = { - graph.logical_limits_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo( - &kernel_param_size_stride, sizeof(kernel_param_size_stride)), - PushConstantDataInfo( - &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), - PushConstantDataInfo( - &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), - PushConstantDataInfo(&out_params, sizeof(out_params)), - }; + // output_tile variants use push constants; the base conv2d_dw shader uses + // UBOs. Distinguish by checking if "_output_tile" is in the shader name. + const bool uses_output_tile = + shader.kernel_name.find("_output_tile") != std::string::npos; + + if (uses_output_tile) { + const utils::ivec4 kernel_param_size_stride = { + kernel_params.kernel_size[0], + kernel_params.kernel_size[1], + kernel_params.stride[0], + kernel_params.stride[1]}; + + const utils::ivec4 kernel_param_pad_dial = { + kernel_params.padding[0], + kernel_params.padding[1], + kernel_params.dilation[0], + kernel_params.dilation[1]}; + + push_constants = { + graph.logical_limits_pc_of(out), + graph.sizes_pc_of(in), + PushConstantDataInfo( + &kernel_param_size_stride, sizeof(kernel_param_size_stride)), + PushConstantDataInfo( + &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)), + PushConstantDataInfo( + &extra_params, sizeof(extra_params), sizeof(utils::ivec4)), + PushConstantDataInfo(&out_params, sizeof(out_params)), + }; + } else { + // Base conv2d_dw shader uses UBOs, same as SlidingWindow case + param_buffers = { + graph.logical_limits_ubo(out), + graph.sizes_ubo(in), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(extra_params), + graph.create_params_buffer(out_params), + }; + } } else { param_buffers = { graph.logical_limits_ubo(out), diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fe2e4169f05..6a9db70adaa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1588,7 +1588,23 @@ def get_softmax_inputs(): "utils::kWidthPacked", "utils::kChannelsPacked", ] - return test_suite + + # Large negative values regression test (edgeTAM attention scores that + # produced NaN due to missing max-shift in softmax numerics) + large_neg_test_suite = VkTestSuite( + [ + ((1, 8, 512, 12), -1, False), + ] + ) + large_neg_test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + large_neg_test_suite.data_range = (-1.8e10, -6.5e9) + large_neg_test_suite.test_name_suffix = "large_negative" + large_neg_test_suite.dtypes = ["at::kFloat"] + + return [test_suite, large_neg_test_suite] @register_test_suite(