perf(qwen3-14b prefill): reduce task fragmentation in prefill scheduling #662
Leaf-Salix wants to merge 4 commits into
Code Review
This pull request refactors the Qwen3-14B prefill forward pass in prefill_fwd.py to utilize SPMD parallelization, pipelining, and micro-windowing optimizations across various stages, including RMSNorm, projections, attention, and MLP. Feedback on these changes highlights a critical correctness bug in _attention_micro_window where sequence lengths exceeding 128 tokens cause out-of-bounds indexing and bypass the necessary online softmax reduction. Additionally, a performance optimization is suggested to avoid peeling the first iteration of pl.pipeline loops, which would allow better overlapping of memory loads on CANN/Ascend hardware.
    raw_row0 = (
        gi * ATTN_TOK_GROUP * Q_HEAD_PAD
        + dd * Q_HEAD_PAD
        + sb * Q_HEAD_PAD
    )
Critical Correctness Bug: Out-of-Bounds Indexing and Missing Online Softmax Reduction for Multi-Block Contexts
There is a critical correctness issue in _attention_micro_window when the sequence length exceeds SEQ_TILE (128 tokens), which causes block_ctx_blocks > 1 and allows sb to be greater than 0.
- Out-of-Bounds Indexing:
  - `raw_scores_group` is allocated with shape `[ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE]`.
  - In `raw_row0` (lines 183–187), `sb * Q_HEAD_PAD` is added to the row index. When `sb >= 1`, the maximum index can reach `TOTAL_Q_GROUPS * ATTN_TOK_GROUP * Q_HEAD_PAD`, which is out of bounds (the maximum valid index is `TOTAL_Q_GROUPS * ATTN_TOK_GROUP * Q_HEAD_PAD - 1`).
  - The same out-of-bounds indexing occurs in `li_row0` (lines 232–236) on `cur_li_group` and `exp_row0` (lines 267–271) on `exp_padded_group`.
- Missing Online Softmax Reduction:
  - For sequence lengths spanning multiple blocks (`block_ctx_blocks > 1`), attention scores must be mathematically reduced across all `sb` blocks using online softmax (tracking running maximums `mi` and sums `li`).
  - In the optimized code, there is no reduction/accumulation of `mi` and `li` across different `sb` blocks. Instead, `_finalize_attention_micro_window` simply slices `oi_tmp_group` and `cur_li_group` using `exp_base` and `li_base` (which do not depend on `sb`), effectively ignoring all blocks other than `sb = 0` (or writing/reading corrupted data due to the out-of-bounds indexing).
This will cause silent correctness corruption or runtime crashes for any prefill sequence length greater than 128 tokens. Since the validation test only ran with max_seq=128, this bug was not caught by the test harness.
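The row arithmetic described above can be checked with a small standalone sketch. The constant values and the loop ranges for `gi` and `dd` below are assumptions chosen for illustration (the real values live in `prefill_fwd.py`), and `first_out_of_bounds` is a hypothetical helper, not part of the PR.

```python
# Hypothetical constants for illustration only; not the values used in prefill_fwd.py.
ATTN_TOK_GROUP = 4
TOTAL_Q_GROUPS = 5
Q_HEAD_PAD = 16

# Row count of the scratch tensor as quoted in the review.
NUM_ROWS = ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD


def raw_row0(gi, dd, sb):
    """Row offset formula quoted from the diff."""
    return gi * ATTN_TOK_GROUP * Q_HEAD_PAD + dd * Q_HEAD_PAD + sb * Q_HEAD_PAD


def first_out_of_bounds(ctx_blocks):
    """Return the first (gi, dd, sb) whose Q_HEAD_PAD-row slab leaves the tensor.

    Assumes gi ranges over TOTAL_Q_GROUPS and dd over ATTN_TOK_GROUP.
    """
    for gi in range(TOTAL_Q_GROUPS):
        for dd in range(ATTN_TOK_GROUP):
            for sb in range(ctx_blocks):
                if raw_row0(gi, dd, sb) + Q_HEAD_PAD > NUM_ROWS:
                    return (gi, dd, sb)
    return None


single_block = first_out_of_bounds(ctx_blocks=1)
two_blocks = first_out_of_bounds(ctx_blocks=2)
# With sb > 0, a second-block slab also lands on another token's first-block slab.
aliases = raw_row0(0, 0, 1) == raw_row0(0, 1, 0)
```

Under these assumptions, a single context block stays in bounds, while a second block both overruns the last slab and aliases rows that belong to a neighbouring token.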
    q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32)
    for kb in pl.pipeline(1, HIDDEN_BLOCKS, stage=2):
        k0 = kb * K_CHUNK
        tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0])
        tile_w_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [layer_hidden_base + k0, q0])
        q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i)
Performance Optimization: Avoid Peeling the First Iteration of pl.pipeline Loops
In PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g., if kb == 0) inside a pl.pipeline loop is preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue, provided the compiler successfully pipelines the loop-index branch.
This optimization opportunity also applies to other stages in this file (Stage 1.3, Stage 3.1, Stage 3.3a, and Stage 3.3b).
References
- In PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g., if db == 0) inside a pl.pipeline loop can be preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue, provided the compiler successfully pipelines the loop-index branch.
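The two loop shapes being compared can be sketched in plain Python. This is a stand-in for the structure only: it uses ordinary lists and `range`, not the PyPTO `pl.pipeline` API, and the function names are hypothetical. It shows that peeling the first chunk and branching on `kb == 0` inside the loop compute the same accumulation; whether the in-loop form pipelines better depends on the compiler, as the comment notes.

```python
def chunked_dot_peeled(a, w, k_chunk):
    """First chunk computed outside the loop, remaining chunks accumulated inside."""
    blocks = len(a) // k_chunk
    acc = sum(x * y for x, y in zip(a[:k_chunk], w[:k_chunk]))
    for kb in range(1, blocks):
        k0 = kb * k_chunk
        acc += sum(x * y for x, y in zip(a[k0:k0 + k_chunk], w[k0:k0 + k_chunk]))
    return acc


def chunked_dot_branch(a, w, k_chunk):
    """Every chunk goes through the loop; kb == 0 initialises the accumulator."""
    blocks = len(a) // k_chunk
    acc = 0
    for kb in range(blocks):
        k0 = kb * k_chunk
        partial = sum(x * y for x, y in zip(a[k0:k0 + k_chunk], w[k0:k0 + k_chunk]))
        if kb == 0:
            acc = partial
        else:
            acc += partial
    return acc


a = list(range(8))
w = [2] * 8
peeled = chunked_dot_peeled(a, w, k_chunk=2)
branched = chunked_dot_branch(a, w, k_chunk=2)
```

Both variants return the same value, so the choice between them is a scheduling question rather than a numerical one.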
📝 Walkthrough
The Qwen3-14B prefill forward kernel is rewritten to replace the prior monolithic causal attention loop with an SPMD-tiled implementation. Q/K/V RMSNorm and projections now use chunked matmuls with FP32 accumulators, and attention is computed via micro-window scoring plus staged finalize across multiple finalize cores, feeding into the existing output/MLP/residual writeback path.
Changes
Prefill kernel rewrite
Estimated code review effort: 4 (Complex) | ~60 minutes
Sequence Diagram(s)
    sequenceDiagram
        participant PrefillLayer
        participant RoPECacheUpdate
        participant AttentionMicroWindow
        participant FinalizeMicroWindow
        PrefillLayer->>RoPECacheUpdate: Apply RoPE, build padded Q tile, update KV cache
        RoPECacheUpdate->>AttentionMicroWindow: Compute per-window raw scores, softmax, SV scratch
        AttentionMicroWindow->>FinalizeMicroWindow: Pass scratch tensors per finalize core
        FinalizeMicroWindow->>PrefillLayer: Return finalized attn_tile
        PrefillLayer->>PrefillLayer: Output projection, MLP, residual writeback
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 1
🧹 Nitpick comments (1)
models/qwen3/14b/prefill_fwd.py (1)
87-88: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Guard the manually unrolled micro-window coverage.
The 10 hard-coded calls cover at most `10 * ATTN_TOK_GROUP` relative tokens. Since `FINALIZE_TOK_GROUP` follows `TOK_TILE`, a future tile retune can silently leave tail tokens uncomputed.
Proposed guard
     FINALIZE_SPMD_BLOCKS = 48
     FINALIZE_TOK_GROUP = TOK_TILE
     ATTN_MICRO_WORK_ITEMS = ATTN_TOK_GROUP * TOTAL_Q_GROUPS
    +ATTN_MICRO_WINDOWS = 10
    +assert FINALIZE_TOK_GROUP <= ATTN_MICRO_WINDOWS * ATTN_TOK_GROUP
Also applies to: 620-868, 870-965
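To make the coverage concern concrete, here is a minimal sketch of the arithmetic behind the proposed assertion. The helper `uncovered_tail_tokens` and the tile sizes passed to it are hypothetical, chosen only to show one configuration that is covered and one that is not.

```python
def uncovered_tail_tokens(finalize_tok_group, attn_tok_group, micro_windows=10):
    """Tokens of a finalize group that a fixed number of micro-window calls never reach."""
    covered = micro_windows * attn_tok_group
    return max(0, finalize_tok_group - covered)


# Hypothetical tile sizes: 10 windows of 2 tokens cover a 16-token group ...
covered_case = uncovered_tail_tokens(finalize_tok_group=16, attn_tok_group=2)
# ... but a retune to a 32-token group would leave a 12-token tail uncomputed.
retuned_case = uncovered_tail_tokens(finalize_tok_group=32, attn_tok_group=2)
```

A non-zero result corresponds to the situation in which the suggested `assert` would fire at import time instead of letting tail tokens go uncomputed.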
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/prefill_fwd.py` around lines 87 - 88, The manually unrolled micro-window calls in prefill_fwd.py only cover a fixed range, so a future change to FINALIZE_TOK_GROUP/TOK_TILE can leave tail tokens unprocessed. Update the logic around the unrolled attention work in the prefill path to add an explicit guard that checks the covered range against the total tokens needed and falls back to a safe loop or assertion when the fixed 10-call span is insufficient. Use the existing symbols FINALIZE_TOK_GROUP, ATTN_MICRO_WORK_ITEMS, and the unrolled prefill/attention block to locate and protect all affected call sites.
📒 Files selected for processing (1)
models/qwen3/14b/prefill_fwd.py
    raw_scores_group = pl.create_tensor(
        [ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE],
        dtype=pl.FP32,
    )
    exp_padded_group = pl.create_tensor(
        [ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE],
        dtype=pl.BF16,
    )
🎯 Functional Correctness | 🔴 Critical | 🏗️ Heavy lift
Preserve sequence-block state through attention finalize.
sb is encoded into row offsets, but the scratch tensors only have rows for (gi, token, q_head), and the finalizer reads one row per (gi, token). When ctx_blocks > 1, block rows alias other token/group rows and O_i/L_i is normalized per block instead of across the full causal prefix, producing incorrect attention for contexts longer than SEQ_TILE.
Fix by carrying online softmax state across all sb blocks, or by adding an explicit sequence-block scratch/reduction dimension before the final divide.
Also applies to: 183-191, 232-245, 267-279, 303-307
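The first suggested fix, carrying online softmax state across all `sb` blocks, can be sketched in plain Python. This is a scalar illustration under stated assumptions, not the kernel code: `scores` and `values` stand for one query row and one output element, `seq_tile` plays the role of `SEQ_TILE`, and the names `m_i`, `l_i`, and `o_i` mirror the running maximum, running sum, and running output mentioned in the review.

```python
import math


def reference_attention(scores, values):
    """Softmax-weighted sum computed over the whole context at once."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)


def online_softmax_attention(scores, values, seq_tile):
    """Same result, processed one block at a time with rescaled running state."""
    m_i = float("-inf")  # running maximum
    l_i = 0.0            # running sum of exponentials
    o_i = 0.0            # running unnormalised output
    for start in range(0, len(scores), seq_tile):
        s_blk = scores[start:start + seq_tile]
        v_blk = values[start:start + seq_tile]
        m_new = max(m_i, max(s_blk))
        scale = math.exp(m_i - m_new)  # 0.0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        l_i = l_i * scale + sum(p)
        o_i = o_i * scale + sum(pj * vj for pj, vj in zip(p, v_blk))
        m_i = m_new
    return o_i / l_i  # single final divide after all blocks


scores = [0.5, 2.0, -1.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
full = reference_attention(scores, values)
blocked = online_softmax_attention(scores, values, seq_tile=2)
first_block_only = reference_attention(scores[:2], values[:2])
```

The blocked result matches the full-context reference, while normalising the first block on its own gives a different answer, which is the failure mode described above for contexts longer than one tile.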
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/prefill_fwd.py` around lines 140 - 147, The attention
finalize path in prefill_fwd.py is losing sequence-block state because the
scratch tensors are indexed only by (gi, token, q_head) while the finalizer
reads one row per (gi, token), so blocks alias each other when ctx_blocks > 1.
Update the online softmax flow in the affected helpers to preserve O_i/L_i
across all sb blocks, either by adding an explicit sequence-block
scratch/reduction dimension or by accumulating and reducing the state before the
final divide. Make the fix consistently in the referenced attention finalize
sections so the row layout matches the way sb is encoded in the offsets.
Summary
models/qwen3/14b/prefill_fwd.py, reducing excessive fine-grained task dispatch while keeping the model interface and runner inputs unchanged.
batch=1 max_seq=128.
[Image: batch=1 max_seq=128 L2 swimlane]
[Image: batch=1 max_seq=128 L2 swimlane]
Notes