Skip to content

Conversation

@xenova
Copy link
Contributor

@xenova xenova commented Dec 9, 2025

Description

This PR fixes the last tests that were failing in #26715 (comment), where rotary_interleaved=1 in GQA kernel. The root cause was that the rotary_interleaved parameter was not being propagated correctly, meaning it always defaulted to 0 in FusedQKRotaryEmbeddingProgram.

Testing on Providers: ['CPUExecutionProvider', 'WebGpuExecutionProvider']
=================================================================================
Prefill_ColdStart    | In:3 Past:0 Total:3 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Prefill_ColdStart    | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Decode_Early         | In:1 Past:16 Total:17 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Decode_Deep          | In:1 Past:64 Total:65 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Speculative_Dec      | In:4 Past:20 Total:24 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Batch_Prefill        | In:16 Past:0 Total:16 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Batch_Decode         | In:1 Past:32 Total:33 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
GQA_Prefill          | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Decode           | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
GQA_Batch_Dec        | In:1 Past:32 Total:33 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Prefill          | In:32 Past:0 Total:32 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
MQA_Decode           | In:1 Past:32 Total:33 H:8 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_MHA          | In:1 Past:16 Total:17 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgBatch_GQA          | In:1 Past:16 Total:17 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Odd_SeqLen           | In:7 Past:13 Total:20 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
Odd_Heads            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
HighHeads_MHA        | In:1 Past:32 Total:33 H:32 KV:32 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_GQA        | In:1 Past:32 Total:33 H:32 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.38e-07)
HighHeads_MQA        | In:1 Past:32 Total:33 H:32 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
LgCtx_Prefill        | In:128 Past:0 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 4.17e-07)
LgCtx_Decode         | In:1 Past:127 Total:128 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 2.98e-07)
TinyHead_MHA         | In:4 Past:4 Total:8 H:4 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
TinyHead_GQA         | In:4 Past:4 Total:8 H:8 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.49e-07)
LgHead_MHA           | In:2 Past:2 Total:4 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
LgHead_GQA           | In:2 Past:2 Total:4 H:4 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.79e-07)
Ratio_5_1            | In:1 Past:10 Total:11 H:5 KV:1 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_2            | In:1 Past:10 Total:11 H:6 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Ratio_6_3            | In:1 Past:10 Total:11 H:6 KV:3 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 5.96e-08)
Ratio_12_4           | In:1 Past:10 Total:11 H:12 KV:4 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Zero_Past            | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Single_Token_Prefill | In:1 Past:0 Total:1 H:2 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 0.00e+00)
Rotary_Cache_Test    | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Small         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Window_Large         | In:10 Past:0 Total:10 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)
Window_Decode        | In:1 Past:20 Total:21 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Softcap_Enabled      | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 5.98e-05)
Scale_0.5            | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_MHA: 1.19e-07)
Rotary_Interleaved   | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary               | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Half          | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary Interleaved 2 | In:4 Past:0 Total:4 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.19e-07)
Rotary_Window        | In:16 Past:0 Total:16 H:4 KV:2 -> ✅ PASS (Max Diff vs CPUExecutionProvider_GQA: 1.79e-07)

🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.

Motivation and Context

cc @qjia7 @guschmue

@xenova xenova changed the title Propagate rotary_interleaved parameter to GQA shader [webgpu] Propagate rotary_interleaved parameter to GQA shader Dec 9, 2025
@guschmue
Copy link
Contributor

guschmue commented Dec 9, 2025

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Dec 9, 2025
@guschmue guschmue enabled auto-merge (squash) December 9, 2025 18:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants