Skip to content

perf(qwen3-14b prefill): reduce task fragmentation in prefill scheduling#662

Draft
Leaf-Salix wants to merge 4 commits into
hw-native-sys:mainfrom
Leaf-Salix:codex/prefill-tuning
Draft

perf(qwen3-14b prefill): reduce task fragmentation in prefill scheduling#662
Leaf-Salix wants to merge 4 commits into
hw-native-sys:mainfrom
Leaf-Salix:codex/prefill-tuning

Conversation

@Leaf-Salix

@Leaf-Salix Leaf-Salix commented Jul 1, 2026

Copy link
Copy Markdown

Summary

  • Optimize Qwen3-14B prefill scheduling in models/qwen3/14b/prefill_fwd.py, reducing excessive fine-grained task dispatch while keeping the model interface and runner inputs unchanged.
  • Rework the prefill layer task layout around RoPE/KV-cache, attention, projection, RMSNorm, MLP, and residual phases so the generated orchestration uses fewer, larger, and better-balanced tasks than the original fragmented task graph.
  • Fix Q RoPE correctness in the optimized path so the optimized prefill remains numerically aligned with the original Qwen3-14B implementation.
  • Golden PASS on a2a3 for the Qwen3-14B prefill validation path, including batch=1 max_seq=128.
Scenario Metric Original / Revert Optimized Reduction Speedup / Ratio
40-layer batch=1 max_seq=128 L2 swimlane Task count 351050 27866 92.1% 12.6x fewer
40-layer batch=1 max_seq=128 L2 swimlane Total Test Time 536.352 ms 96.103 ms 82.1% 5.58x faster
Real-weight serving replay swimlane Prefill device span ~584.1 ms ~99.6 ms 83.0% 5.86x faster
  • The real-weight serving replay swimlane shows the same device-side trend as the synthetic L2 swimlane benchmark.

Notes

  • This PR is focused on Qwen3-14B prefill task scheduling and correctness. It does not require pypto runtime changes.
  • The timing numbers are L2 swimlane/device-side measurements. End-to-end serving wall time includes additional L3/runtime/dispatch/drain overhead and should be interpreted separately from the prefill kernel device span.

@chatgpt-codex-connector

Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.
To continue using code reviews, you can upgrade your account or add credits to your account and enable them for code reviews in your settings.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the Qwen3-14B prefill forward pass in prefill_fwd.py to utilize SPMD parallelization, pipelining, and micro-windowing optimizations across various stages, including RMSNorm, projections, attention, and MLP. Feedback on these changes highlights a critical correctness bug in _attention_micro_window where sequence lengths exceeding 128 tokens cause out-of-bounds indexing and bypass the necessary online softmax reduction. Additionally, a performance optimization is suggested to avoid peeling the first iteration of pl.pipeline loops, which would allow better overlapping of memory loads on CANN/Ascend hardware.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread models/qwen3/14b/prefill_fwd.py Outdated
Comment on lines +183 to +187
raw_row0 = (
gi * ATTN_TOK_GROUP * Q_HEAD_PAD
+ dd * Q_HEAD_PAD
+ sb * Q_HEAD_PAD
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Critical Correctness Bug: Out-of-Bounds Indexing and Missing Online Softmax Reduction for Multi-Block Contexts

There is a critical correctness issue in _attention_micro_window when the sequence length exceeds SEQ_TILE (128 tokens), which causes block_ctx_blocks > 1 and allows sb to be greater than 0.

  1. Out-of-Bounds Indexing:

    • raw_scores_group is allocated with shape [ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE].
    • In raw_row0 (lines 183–187), sb * Q_HEAD_PAD is added to the row index. When sb >= 1, the maximum index can reach TOTAL_Q_GROUPS * ATTN_TOK_GROUP * Q_HEAD_PAD, which is out of bounds (the maximum valid index is TOTAL_Q_GROUPS * ATTN_TOK_GROUP * Q_HEAD_PAD - 1).
    • The same out-of-bounds indexing occurs in li_row0 (lines 232–236) on cur_li_group and exp_row0 (lines 267–271) on exp_padded_group.
  2. Missing Online Softmax Reduction:

    • For sequence lengths spanning multiple blocks (block_ctx_blocks > 1), attention scores must be mathematically reduced across all sb blocks using online softmax (tracking running maximums mi and sums li).
    • In the optimized code, there is no reduction/accumulation of mi and li across different sb blocks. Instead, _finalize_attention_micro_window simply slices oi_tmp_group and cur_li_group using exp_base and li_base (which do not depend on sb), effectively ignoring all blocks other than sb = 0 (or writing/reading corrupted data due to the out-of-bounds indexing).

This will cause silent correctness corruption or runtime crashes for any prefill sequence length greater than 128 tokens. Since the validation test only ran with max_seq=128, this bug was not caught by the test harness.

Comment on lines +425 to +430
q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32)
for kb in pl.pipeline(1, HIDDEN_BLOCKS, stage=2):
k0 = kb * K_CHUNK
tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0])
tile_w_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [layer_hidden_base + k0, q0])
q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Optimization: Avoid Peeling the First Iteration of pl.pipeline Loops

In PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g., if kb == 0) inside a pl.pipeline loop is preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue, provided the compiler successfully pipelines the loop-index branch.

This optimization opportunity also applies to other stages in this file (Stage 1.3, Stage 3.1, Stage 3.3a, and Stage 3.3b).

References
  1. In PyPTO on CANN/Ascend hardware, keeping a conditional check (e.g., if db == 0) inside a pl.pipeline loop can be preferred over peeling the first iteration. This allows the first chunk's load to overlap with the rest of the pipeline rather than running as an un-pipelined prologue, provided the compiler successfully pipelines the loop-index branch.

@Leaf-Salix Leaf-Salix marked this pull request as draft July 1, 2026 08:56
@coderabbitai

coderabbitai Bot commented Jul 1, 2026

Copy link
Copy Markdown

Review Change Stack

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 12d19a4e-d85c-445e-9628-0051e25f3875

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

The Qwen3-14B prefill forward kernel is rewritten to replace the prior monolithic causal attention loop with an SPMD-tiled implementation. Q/K/V RMSNorm and projections now use chunked matmuls with FP32 accumulators, and attention is computed via micro-window scoring plus staged finalize across multiple finalize cores, feeding into the existing output/MLP/residual writeback path.

Changes

Prefill kernel rewrite

Layer / File(s) Summary
Tiling constants and micro-window helpers
models/qwen3/14b/prefill_fwd.py
New tiling/grouping constants and _attention_micro_window/_finalize_attention_micro_window helper kernels are added, plus a minor CI marker comment fix.
SPMD RMSNorm and Q/K/V projections
models/qwen3/14b/prefill_fwd.py
Stage 1 of prefill_layer switches to pl.spmd loops with chunked matmuls and FP32 accumulators for RMSNorm and Q/K/V projection, followed by per-head Q/K RMSNorm on the new tile layout.
RoPE, KV cache, and staged attention finalize
models/qwen3/14b/prefill_fwd.py
Scope 2 builds a padded Q tile after RoPE, updates the KV cache with RoPE, computes attention via repeated _attention_micro_window calls, and finalizes across multiple finalize cores into attn_tile.
Output projection, MLP, and residual writeback
models/qwen3/14b/prefill_fwd.py
Scope 3 output projection, post-attention RMSNorm, SwiGLU MLP, and down projection with residual are refactored to consume the new attention tile layout with updated valid_shape masking.

Estimated code review effort: 4 (Complex) | ~60 minutes

Sequence Diagram(s)

sequenceDiagram
  participant PrefillLayer
  participant RoPECacheUpdate
  participant AttentionMicroWindow
  participant FinalizeMicroWindow

  PrefillLayer->>RoPECacheUpdate: Apply RoPE, build padded Q tile, update KV cache
  RoPECacheUpdate->>AttentionMicroWindow: Compute per-window raw scores, softmax, SV scratch
  AttentionMicroWindow->>FinalizeMicroWindow: Pass scratch tensors per finalize core
  FinalizeMicroWindow->>PrefillLayer: Return finalized attn_tile
  PrefillLayer->>PrefillLayer: Output projection, MLP, residual writeback
Loading

Possibly related PRs

  • hw-native-sys/pypto-lib#143: Ports/rewrites the same Scope 2 attention and Scope 3 output/MLP tiling patterns in the same prefill_fwd.py file.
  • hw-native-sys/pypto-lib#161: Refactors the same Scope 2 attention softmax accumulation and finalize/ctx writeback structure.
  • hw-native-sys/pypto-lib#447: Refactors Q/K/V RoPE computation from looping into SPMD/tiled execution, analogous to this PR's projection/RoPE tiling changes.

Suggested labels: enhancement

Poem

A rabbit hops through tiles so neat,
Micro-windows dance to softmax beat,
RoPE spins fast, the cache renews,
Finalize cores compute the news,
Scope by scope, the kernel's fleet! 🐇✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly reflects the main change: reducing task fragmentation in Qwen3-14B prefill scheduling.
Description check ✅ Passed The description is directly related to the changeset and summarizes the scheduling optimization and correctness fix.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
models/qwen3/14b/prefill_fwd.py (1)

87-88: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Guard the manually unrolled micro-window coverage.

The 10 hard-coded calls cover at most 10 * ATTN_TOK_GROUP relative tokens. Since FINALIZE_TOK_GROUP follows TOK_TILE, a future tile retune can silently leave tail tokens uncomputed.

Proposed guard
 FINALIZE_SPMD_BLOCKS = 48
 FINALIZE_TOK_GROUP = TOK_TILE
 ATTN_MICRO_WORK_ITEMS = ATTN_TOK_GROUP * TOTAL_Q_GROUPS
+ATTN_MICRO_WINDOWS = 10
+assert FINALIZE_TOK_GROUP <= ATTN_MICRO_WINDOWS * ATTN_TOK_GROUP

Also applies to: 620-868, 870-965

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/prefill_fwd.py` around lines 87 - 88, The manually unrolled
micro-window calls in prefill_fwd.py only cover a fixed range, so a future
change to FINALIZE_TOK_GROUP/TOK_TILE can leave tail tokens unprocessed. Update
the logic around the unrolled attention work in the prefill path to add an
explicit guard that checks the covered range against the total tokens needed and
falls back to a safe loop or assertion when the fixed 10-call span is
insufficient. Use the existing symbols FINALIZE_TOK_GROUP,
ATTN_MICRO_WORK_ITEMS, and the unrolled prefill/attention block to locate and
protect all affected call sites.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/qwen3/14b/prefill_fwd.py`:
- Around line 140-147: The attention finalize path in prefill_fwd.py is losing
sequence-block state because the scratch tensors are indexed only by (gi, token,
q_head) while the finalizer reads one row per (gi, token), so blocks alias each
other when ctx_blocks > 1. Update the online softmax flow in the affected
helpers to preserve O_i/L_i across all sb blocks, either by adding an explicit
sequence-block scratch/reduction dimension or by accumulating and reducing the
state before the final divide. Make the fix consistently in the referenced
attention finalize sections so the row layout matches the way sb is encoded in
the offsets.

---

Nitpick comments:
In `@models/qwen3/14b/prefill_fwd.py`:
- Around line 87-88: The manually unrolled micro-window calls in prefill_fwd.py
only cover a fixed range, so a future change to FINALIZE_TOK_GROUP/TOK_TILE can
leave tail tokens unprocessed. Update the logic around the unrolled attention
work in the prefill path to add an explicit guard that checks the covered range
against the total tokens needed and falls back to a safe loop or assertion when
the fixed 10-call span is insufficient. Use the existing symbols
FINALIZE_TOK_GROUP, ATTN_MICRO_WORK_ITEMS, and the unrolled prefill/attention
block to locate and protect all affected call sites.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c1b6e97a-2e0f-43bd-b577-8375d667662e

📥 Commits

Reviewing files that changed from the base of the PR and between 57772f3 and 4f31d09.

📒 Files selected for processing (1)
  • models/qwen3/14b/prefill_fwd.py

Comment on lines +140 to +147
raw_scores_group = pl.create_tensor(
[ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE],
dtype=pl.FP32,
)
exp_padded_group = pl.create_tensor(
[ATTN_TOK_GROUP * TOTAL_Q_GROUPS * Q_HEAD_PAD, SEQ_TILE],
dtype=pl.BF16,
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🔴 Critical | 🏗️ Heavy lift

Preserve sequence-block state through attention finalize.

sb is encoded into row offsets, but the scratch tensors only have rows for (gi, token, q_head), and the finalizer reads one row per (gi, token). When ctx_blocks > 1, block rows alias other token/group rows and O_i/L_i is normalized per block instead of across the full causal prefix, producing incorrect attention for contexts longer than SEQ_TILE.

Fix by carrying online softmax state across all sb blocks, or by adding an explicit sequence-block scratch/reduction dimension before the final divide.

Also applies to: 183-191, 232-245, 267-279, 303-307

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/prefill_fwd.py` around lines 140 - 147, The attention
finalize path in prefill_fwd.py is losing sequence-block state because the
scratch tensors are indexed only by (gi, token, q_head) while the finalizer
reads one row per (gi, token), so blocks alias each other when ctx_blocks > 1.
Update the online softmax flow in the affected helpers to preserve O_i/L_i
across all sb blocks, either by adding an explicit sequence-block
scratch/reduction dimension or by accumulating and reducing the state before the
final divide. Make the fix consistently in the referenced attention finalize
sections so the row layout matches the way sb is encoded in the offsets.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant