Skip to content

perf(dsv4 prefill sparse-attn): per-group o_proj decoupling + col_expand_add/A_K_TILE (-15%)#649

Merged
zhangqi-chen merged 2 commits into
hw-native-sys:mainfrom
zhaozhaozz:perf/dsv4-prefill-decode-migration
Jul 1, 2026
Merged

perf(dsv4 prefill sparse-attn): per-group o_proj decoupling + col_expand_add/A_K_TILE (-15%)#649
zhangqi-chen merged 2 commits into
hw-native-sys:mainfrom
zhaozhaozz:perf/dsv4-prefill-decode-migration

Conversation

@zhaozhaozz

Copy link
Copy Markdown
Contributor

Summary

Mirror two decode sparse_attn o_proj idioms into prefill sparse_attn, bringing prefill's grouped output projection in line with decode. Two commits, smaller-then-bigger:

1. col_expand_add + A_K_TILE 256 — numerically safe

  • qk_pv: fold the per-block softmax bias via pl.col_expand_add instead of col_expand onto a dead pl.full(0) base + a separate add. Bit-exact; drops two vector ops and a dead UB tile in the softmax band.
  • proj_a: A_K_TILE 128 -> 256. Under b_trans the wo_a operand is K-contiguous, so K*2B = 512B fills a full a2a3 L2 line (128 wasted half); O_GROUP_IN = 4096, so the cube K-frag loop just halves (32 -> 16). Output identical.

2. Per-group o_proj decoupling — mirrors decode

  • Replace the 3-spmd proj_a -> quant -> proj_b form, whose per-row GLOBAL amax barrier serialized cube and vector, with decode's pl.manual_scope 4-segment pipeline: proj_a_mm (pure cube) -> quant (PER-GROUP amax, no global barrier) -> proj_b_mm (pure cube INT32 partials) -> proj_b_act (pure vector dequant+sum), with fine-grained pl.array(TASK_ID) deps. merge_norm/rope become with-form spmd so proj_a can depend on the o_packed producers.
  • Golden switched to per-group amax to match the kernel (see Golden change).

Validation (real NPU, a2a3)

  • Golden PASS (ratio_allclose) for both commits.
  • L2 swimlane: makespan 787 us vs 903-968 us baseline (~ -15%), reproduced across two devices (787.14 / 786.86). The decoupling moves the dominant proj_b vector band off the critical path (proj_b_aiv ~112 us fused dequant -> proj_b_act ~29 us) and lets proj_b cube overlap later groups' proj_a/quant.
  • A_K_TILE swept 128/256/512: proj_a_aic 29.06 -> 26.99 us (-7.1%) at 256; 512 returns to 29.05 -- a clean cache-line optimum.

Reproduce:

python models/deepseek/v4/prefill_sparse_attn.py -p a2a3                        # correctness
python models/deepseek/v4/prefill_sparse_attn.py -p a2a3 --enable-l2-swimlane 4  # timing

Golden change -- why, and precision

The per-group commit changes the INT8 activation quant of o_r from per-row-global amax to per-group amax, so the golden was updated to model the same scheme. This aligns prefill with decode, which already uses per-group (kernel and golden) -- before this PR prefill was the outlier (per-row-global). Precision is not regressed: per-group is a finer activation quant, so it is closer to the official DeepSeek V4 o_proj path (act_quant(block_size=128), finer still). Measured INT8 quant error vs full precision: per-row-global 0.9-1.2% -> per-group ~0.78% (official-granularity per-128-block ~0.64%).

Risks / notes

  • Per-group adds a [T, O_GROUPS*D] INT32 partials intermediate (~16 MB at T=128); L2-swimlane profiling needs a larger PTO2_RING_HEAP than the default.
  • T=128 is the kernel's fixed prefill-chunk size; the perf numbers are at that size and accumulate across chunks for longer sequences.

@coderabbitai

coderabbitai Bot commented Jun 30, 2026

Copy link
Copy Markdown

Review Change Stack

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 34b6622d-1009-4e18-b984-e60d0480ee60

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Updates prefill_sparse_attn.py with expanded INT8 projection tiling constants, a fused col_expand_add for QK bias, explicit SPMD task-id handles for merge_norm and rope loops, a fully decoupled per-group INT8 output projection pipeline under pl.manual_scope(), and a matching per-group torch golden reference.

DeepSeek-V4 Prefill Sparse Attention Refactor

Layer / File(s) Summary
Projection/quantization tiling constants
models/deepseek/v4/prefill_sparse_attn.py
A_K_TILE increased; new derived tiles (QUANT_TILE, PROJ_*, PA_NFRAGS, PB_DCHUNKS, NUM_QUANT_T_CHUNKS, PROJ_B_ACT_*) and updated divisibility assertions added for grouped INT8 projection.
QK bias fusion and SPMD task-id exposure
models/deepseek/v4/prefill_sparse_attn.py
QK softmax bias switches from zero-tile + add to pl.col_expand_add; merge_norm and rope spmd loops refactored to with pl.spmd(...) as merge_tid/rope_tid to expose task-ids as explicit dependencies.
Decoupled per-group INT8 output projection pipeline
models/deepseek/v4/prefill_sparse_attn.py
INT8 output projection replaced with a pl.manual_scope() pipeline: per-group proj_a_mm (deps on merge_tid, rope_tid) → act_scale_dq quantization → proj_b_mm INT32 partials → proj_b_act dequantization into attn_out.
Updated torch golden reference
models/deepseek/v4/prefill_sparse_attn.py
Golden reference reshaped to [T, G, O_LORA] with per-group amax/scale, INT8 quant/dequant, and per-channel wo_b_scale, replacing the prior flattened per-row quantization path.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#394: Directly overlaps in merge+RoPE handling and INT8/quantization behavior in the same file.
  • hw-native-sys/pypto-lib#477: Restructures RoPE SPMD/task tiling and INT8 quantization tiling across output groups in the same kernel.
  • hw-native-sys/pypto-lib#569: Changes sparse-attn tiling constants, QK softmax bias, rope+merge SPMD structure, and grouped INT8 output-projection pipeline in the same file.

Suggested labels

enhancement

🐇 A rabbit hopped through tiling land,
Where INT8 groups were carefully planned.
col_expand_add fused the bias neat,
merge_tid and rope_tid took their seat.
Per-group pipelines, staged with care—
The golden reference now matches there! 🌟

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main perf refactor and optimization work in prefill sparse-attn.
Description check ✅ Passed The description is directly related and accurately explains the sparse-attn output-projection changes and validation.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the prefill sparse attention kernel to decouple the grouped INT8 output projection per-group, mirroring the decode sparse attention pipeline. Key changes include updating tile sizes, optimizing bias addition via pl.col_expand_add, establishing explicit task dependencies with pl.spmd contexts, and rewriting the projection, quantization, and activation steps under a manual scope to overlap execution. The golden reference implementation is also updated to match. The review feedback suggests optimizing the pipelined loops in both proj_a_mm and proj_b_mm by avoiding first-iteration peeling and instead using conditional checks inside the pl.pipeline loops to allow better load overlap on CANN/Ascend hardware.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread models/deepseek/v4/prefill_sparse_attn.py
Comment thread models/deepseek/v4/prefill_sparse_attn.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
models/deepseek/v4/prefill_sparse_attn.py (1)

84-88: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Add the missing A_K_TILE divisibility guard.

proj_a_mm peels 0:A_K_TILE and then covers O_GROUP_IN // A_K_TILE chunks, so a non-divisible O_GROUP_IN would leave a K-tail uncovered. Add this next to the other tile guards:

Proposed guard
 assert T % QUANT_TOKEN_TILE == 0
+assert O_GROUP_IN % A_K_TILE == 0
 assert D % PROJ_B_MM_N_TILE == 0 and D % PROJ_B_D_CHUNK == 0 and PROJ_B_D_CHUNK % PROJ_B_MM_N_TILE == 0

Based on learnings, DeepSeek v4 module-level correctness guards using assert are intentional.

Also applies to: 356-363

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/prefill_sparse_attn.py` around lines 84 - 88, Add the
missing divisibility assert for A_K_TILE in the module-level guard block around
the existing tensor tile checks in prefill_sparse_attn.py. The issue is that
proj_a_mm peels 0:A_K_TILE and then iterates over O_GROUP_IN // A_K_TILE, so
O_GROUP_IN must be divisible by A_K_TILE to avoid an uncovered K-tail; place
this next to the other assert guards and keep the same correctness-check style
used throughout the DeepSeek v4 module.

Source: Learnings

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@models/deepseek/v4/prefill_sparse_attn.py`:
- Around line 84-88: Add the missing divisibility assert for A_K_TILE in the
module-level guard block around the existing tensor tile checks in
prefill_sparse_attn.py. The issue is that proj_a_mm peels 0:A_K_TILE and then
iterates over O_GROUP_IN // A_K_TILE, so O_GROUP_IN must be divisible by
A_K_TILE to avoid an uncovered K-tail; place this next to the other assert
guards and keep the same correctness-check style used throughout the DeepSeek v4
module.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: c985c0a0-a5cc-405b-b6ef-63052c4ea130

📥 Commits

Reviewing files that changed from the base of the PR and between 1a23bd3 and 17a1435.

📒 Files selected for processing (1)
  • models/deepseek/v4/prefill_sparse_attn.py

@zhaozhaozz zhaozhaozz force-pushed the perf/dsv4-prefill-decode-migration branch from 17a1435 to 1042228 Compare June 30, 2026 08:43
… -7.1%)

Mirror two numerically-safe decode sparse-attn idioms into prefill:

- qk_pv: fold the per-block softmax bias into pl.col_expand_add instead of
  col_expand into a dead pl.full(0) base + a separate add. Bit-exact; removes
  two vector ops and a [QK_M_TILE, PREFILL_ATTN_TILE] FP32 dead UB tile in the
  attn_qk_softmax band.
- proj_a: A_K_TILE 128 -> 256. With b_trans, wo_a is K-contiguous, so K*2B =
  512B fills a full a2a3 L2 cache line (128 wastes half). O_GROUP_IN = 4096, so
  the cube K-frag loop just halves (32 -> 16 frags); output is identical.

Real-NPU a2a3 swimlane sweep: proj_a_aic 29.06 -> 26.99us (-7.1%) at 256, 512
returns to 29.05 (no gain). All values golden PASS (ratio_allclose).
@zhaozhaozz zhaozhaozz force-pushed the perf/dsv4-prefill-decode-migration branch from 1042228 to c13a93d Compare July 1, 2026 03:13
…pan)

Mirror decode's per-group output-projection pipeline into prefill, replacing the
3-spmd proj_a -> quant -> proj_b form -- whose per-row GLOBAL amax barrier
between proj_a and proj_b serialized cube and vector -- with a manual_scope
4-segment pipeline:

- proj_a_mm (pure cube): BF16 grouped GEMM -> o_r.
- quant (pure vector): PER-GROUP amax (one per O_LORA group, not the full
  O_GROUPS*O_LORA-row reduction) + INT8 quant -> act_scale_dq[G, T]. The
  per-group reduction removes the proj_a<->proj_b barrier.
- proj_b_mm (pure cube): INT8 GEMM -> INT32 partials[:, g*D+n], deps on quant[g]
  only, so group g's proj_b overlaps proj_a/quant of later groups.
- proj_b_act (pure vector): sum the O_GROUPS partials by their per-group act
  scales, apply the per-channel weight scale -> BF16.

Fine-grained deps via pl.manual_scope + pl.array(TASK_ID); merge_norm/rope become
with-form spmd so proj_a can depend on the o_packed producers. Golden switched to
per-group amax to match the kernel.

Real-NPU a2a3 swimlane: golden PASS; makespan 787us vs 903-968us baseline. The
decoupling moves the dominant proj_b vector band off the critical path
(proj_b_aiv ~112us fused dequant -> proj_b_act ~29us) and lets proj_b cube
overlap later groups -- the T=128 e2e win the per-row barrier forbade.
@zhaozhaozz zhaozhaozz force-pushed the perf/dsv4-prefill-decode-migration branch from c13a93d to 6dd3c70 Compare July 1, 2026 06:54
@zhangqi-chen zhangqi-chen merged commit 6b8a92d into hw-native-sys:main Jul 1, 2026
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants