From 4f2bffefd32ed699f7fe75ae7097ee67e3cc7105 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 23 Mar 2026 07:22:49 -0700 Subject: [PATCH] [ET-VK][sdpa] Use numerically-stable softmax in attention weights The SDPA attention weights softmax shader computed naive softmax: exp(x) / sum(exp(x)). When attention weights are large (e.g., 151.29 for Phi-4-mini with head_dim=128), exp(x) overflows float32 (threshold ~88.7), producing Infinity and then NaN from inf/inf in the normalization step. This replaces the naive softmax with the standard numerically-stable variant: exp(x - max(x)) / sum(exp(x - max(x))). The implementation adds a cooperative max-finding pass (same workgroup reduction pattern as the existing exp_sum pass) before the exp_sum and normalization passes. The max subtraction ensures that the largest exponent is 0, preventing overflow. This fixes Phi-4-mini Vulkan inference which previously produced garbage output due to NaN propagation from the first transformer layer's attention. On-device A/B benchmarks on Samsung Galaxy S24 (Adreno 750) with Llama 3.2 1B (8da4w g128 q4emb, 677 MB) confirm no performance regression: Llama 3.2 1B (short prompt, 4 tokens, --warmup): Prefill: 67.2 tok/s | Decode: 59.4 tok/s | TTFT: 60 ms Llama 3.2 1B (medium prompt, 197 tokens, --warmup): Prefill: 723.5 tok/s | Decode: 53.3 tok/s | TTFT: 273 ms These numbers are within run-to-run variance of the baseline (no fix) measurements, confirming the additional max-finding pass has negligible overhead. Differential Revision: [D97757920](https://our.internmc.facebook.com/intern/diff/D97757920/) [ghstack-poisoned] --- .../ops/glsl/sdpa_attn_weights_softmax.glsl | 82 +++++++++++++++---- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl index 5560fa6e11c..e6c118b6ab2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -32,7 +32,8 @@ ${layout_declare_ubo(B, "int", "input_pos")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// Shared memory for cooperative exp sum finding +// Shared memory for cooperative max finding and exp sum reduction +shared T shared_max[NUM_WORKERS_PER_WG]; shared T shared_exp_sum[NUM_WORKERS_PER_WG]; VEC4_T load_attn_weights_c4( @@ -87,24 +88,24 @@ void main() { return; } - // Initialize thread-local min/max - T local_exp_sum = T(0); - const int context_len_aligned_down = context_len - mod_4(context_len); const int C4_limit = div_4(context_len_aligned_down); - // Each thread processes elements along a context_len row with a stride of the - // number of threads in the work group. + // ========================================================================= + // Pass 1: Find the maximum value across the row for numerical stability. + // Without this, exp(x) can overflow float32 when x > ~88.7. + // ========================================================================= + + T local_max = T(-1.0 / 0.0); // -infinity + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( c4, s, q_h, context_texel_len, S_aligned, Q_H); for (int comp = 0; comp < 4; comp++) { - local_exp_sum += exp(in_texel[comp]); + local_max = max(local_max, in_texel[comp]); } } - // First thread in the work group responsible for handling last texel if it - // contains any padded elements if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); @@ -113,19 +114,63 @@ void main() { [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - local_exp_sum += exp(in_texel[comp]); + local_max = max(local_max, in_texel[comp]); + } + } + } + } + + shared_max[worker_id] = local_max; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to find the global max + for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) { + if (worker_id < i) { + shared_max[worker_id] = max( + shared_max[worker_id], shared_max[worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + const T global_max = shared_max[0]; + + // ========================================================================= + // Pass 2: Compute sum(exp(x - max)) using the global max for stability + // ========================================================================= + + T local_exp_sum = T(0); + + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S_aligned, Q_H); + + for (int comp = 0; comp < 4; comp++) { + local_exp_sum += exp(in_texel[comp] - global_max); + } + } + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { + const int c_base = mul_4(c4); + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S_aligned, Q_H); + + [[unroll]] for (int comp = 0; comp < 4; comp++) { + if (c_base + comp < context_len) { + local_exp_sum += exp(in_texel[comp] - global_max); } } } } - // Store thread-local results in shared memory shared_exp_sum[worker_id] = local_exp_sum; memoryBarrierShared(); barrier(); - // Tree reduction to compute the overall result + // Tree reduction to compute the overall exp sum for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) { if (worker_id < i) { shared_exp_sum[worker_id] = shared_exp_sum[worker_id] + @@ -136,28 +181,29 @@ void main() { } local_exp_sum = shared_exp_sum[0]; - // Now go back through each element in the row and normalize + + // ========================================================================= + // Pass 3: Normalize each element: out = exp(x - max) / sum(exp(x - max)) + // ========================================================================= + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { VEC4_T in_texel = load_attn_weights_c4( c4, s, q_h, context_texel_len, S_aligned, Q_H); - VEC4_T out_texel = exp(in_texel) / local_exp_sum; + VEC4_T out_texel = exp(in_texel - global_max) / local_exp_sum; store_attn_weights_softmax_c4( out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H); } - // First thread in the work group responsible for handling last texel if it - // contains any padded elements if (worker_id == 0) { for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { const int c_base = mul_4(c4); VEC4_T in_texel = load_attn_weights_c4( c4, s, q_h, context_texel_len, S_aligned, Q_H); - // Ensure that padding elements are set to 0. VEC4_T out_texel = VEC4_T(0); [[unroll]] for (int comp = 0; comp < 4; comp++) { if (c_base + comp < context_len) { - out_texel[comp] = exp(in_texel[comp]) / local_exp_sum; + out_texel[comp] = exp(in_texel[comp] - global_max) / local_exp_sum; } } store_attn_weights_softmax_c4(