
Refactor: scope dsv4 decode+prefill kernels under auto_scope=False #585

Open
zhangqi-chen wants to merge 12 commits into hw-native-sys:main from zhangqi-chen:scope-opt

Conversation

@zhangqi-chen

Collaborator

Summary

  • Restructure the DeepSeek-V4 decode and prefill kernels under @pl.jit(.inline)(auto_scope=False) with explicit pl.scope() blocks so scratch frees at scope exit instead of pinning the function frame once inlined. Behavior-preserving across the board.
  • decode_sparse_attn / prefill_sparse_attn: 3-scope split (S1 gather/bias, S2 attention, S3 inverse-RoPE + output projection); cross-scope buffers (sparse_kv, sparse_bias, attn_rope_stage, o_packed) stay at the frame, sparse_blk_*/q_flat live in S2.
  • decode_attention_csa and prefill_attention_{swa,hca,csa}: a pre-attention scope frees x_mixed/x_normed/qr/qr_scale before sparse_attn; the frame keeps only what is read after (post/comb/rope/q/kv). For prefill_attention_csa the compressor/indexer/sparse-idx build stay at the frame because they write the cache-state OUT-params the function RETURNS (a returned param written by a bare call inside a scope cannot bridge its SSA version to the return).
  • moe + expert_routed: scoped for auto_scope=False composition; create scratch (sh/ffn_out) at first use.
  • decode_layer / prefill_layer: auto_scope=False; the attention dispatch and the MoE call each get their own pl.scope() so attention scratch frees before MoE; MoE is a bare in-place call into x_next. Generated orchestration nests PTO2_SCOPE at most 3 levels deep (within the 4-level budget).
  • All sub-kernel calls converted to bare in-place form (no returned-handle rebind); inline args packed several-per-line; each create_tensor sits at first use unless a scope/return forces it to the frame.
  • decode_fwd: converted to [:] slice sugar with just-in-time per-layer defs.
  • Docs: add the "declare scratch at first use" convention to docs/pypto-coding-style.md (sec 8).
  • Add a --scope-stats CLI flag (runtime enable_scope_stats) to the affected attention / sparse_attn runners.
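The lifetime effect described in the summary can be sketched with a toy model in plain Python. This is not PyPTO code: `ScratchTracker`, its `create`/`scope` methods, and the byte sizes are all hypothetical and exist only to show why releasing scratch at scope exit lowers the peak compared with pinning everything to the function frame.

```python
from contextlib import contextmanager


class ScratchTracker:
    """Toy model (hypothetical): counts live scratch bytes and records the peak."""

    def __init__(self):
        self.live = 0
        self.peak = 0
        self._frames = [[]]  # index 0 plays the role of the function frame

    def create(self, nbytes):
        # Allocate in the innermost open scope (or the frame if none is open).
        self._frames[-1].append(nbytes)
        self.live += nbytes
        self.peak = max(self.peak, self.live)

    @contextmanager
    def scope(self):
        # Everything created inside the block is released when the block exits.
        self._frames.append([])
        try:
            yield
        finally:
            self.live -= sum(self._frames.pop())


def frame_only(t):
    # All buffers pinned to the frame: nothing is released until return.
    t.create(32)   # cross-stage buffer (made-up size)
    t.create(64)   # pre-attention scratch
    t.create(128)  # attention scratch
    t.create(96)   # MoE scratch


def scoped(t):
    t.create(32)   # cross-scope buffer stays at the frame
    with t.scope():
        t.create(64)   # pre-attention scratch, freed before attention
    with t.scope():
        t.create(128)  # attention scratch, freed before MoE
    with t.scope():
        t.create(96)   # MoE scratch


flat, nested = ScratchTracker(), ScratchTracker()
frame_only(flat)
scoped(nested)
print(flat.peak, nested.peak)
```

In this toy the frame-only layout peaks at the sum of all buffers, while the scoped layout peaks at the frame buffer plus the largest single scope.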

Verified on a2a3: prefill_attention_{swa,hca,csa} standalone PASS; prefill_layer EP2 (CSA, layer 2) PASS on cards 14,15. Decode-side scope refactors produce bit-identical output (kv_cache + x_out PASS).

Related Issues

N/A

@coderabbitai

coderabbitai Bot commented Jun 23, 2026


Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


Walkthrough

All DeepSeek V4 decode and prefill kernels (attention_swa/hca/csa, sparse_attn*, moe, expert_routed, prefill_attention_*, prefill_sparse_attn, decode_layer, prefill_layer) switch their JIT decorators to auto_scope=False and gain explicit pl.scope() blocks to control scratch tensor lifetimes. Every CLI harness gains a --scope-stats flag forwarded into run_jit as enable_scope_stats. A new coding style section documents the "declare scratch at first use" pattern.

Changes

Explicit pl.scope() lifetime management

| Layer | File(s) | Summary |
|---|---|---|
| Coding style: declare scratch at first use | docs/pypto-coding-style.md | Adds a "Declare scratch at first use" guidance section with a pl.scope() cross-scope exception and a Python example. |
| Layer-level orchestrators: auto_scope=False and explicit scope blocks | models/deepseek/v4/decode_layer.py, models/deepseek/v4/prefill_layer.py | Both top-level layer functions switch to @pl.jit(auto_scope=False), wrap attention selection and MoE FFN in separate pl.scope() blocks, and convert moe(...) from return-capture to out-parameter style. |
| Decode attention kernels: frame-level tensors and nested scopes | models/deepseek/v4/decode_attention_swa.py, models/deepseek/v4/decode_attention_hca.py, models/deepseek/v4/decode_attention_csa.py | attention_swa, attention_hca, and attention_csa switch to auto_scope=False, pre-allocate frame-level output tensors outside scopes, move short-lived scratch into pl.scope() blocks, and convert sub-kernel calls to bare in-place style. |
| Decode sparse attention kernels: explicit scope wrapping | models/deepseek/v4/decode_sparse_attn.py, models/deepseek/v4/decode_sparse_attn_hca.py, models/deepseek/v4/decode_sparse_attn_swa.py | sparse_attn, sparse_attn_hca, and sparse_attn_swa switch to auto_scope=False and introduce pl.scope() blocks around KV-build, QK/PV computation, online-softmax merge, inverse-RoPE, and grouped output projection stages. |
| MoE and expert_routed: scoped dispatch/combine phases | models/deepseek/v4/moe.py, models/deepseek/v4/expert_routed.py | moe, moe_ep1, and expert_routed switch to auto_scope=False; gate intermediates and large dispatch/combine receive buffers are scoped separately so they are freed before hc_post; per-tile work in expert_routed is wrapped in a pl.scope(). |
| Prefill attention kernels: frame-level tensors and nested scopes | models/deepseek/v4/prefill_attention_swa.py, models/deepseek/v4/prefill_attention_hca.py, models/deepseek/v4/prefill_attention_csa.py, models/deepseek/v4/prefill_sparse_attn.py | All prefill attention kernels switch to auto_scope=False; frame-level tensors are allocated outside scopes; pre-attention pipelines run inside pl.scope() via bare in-place calls; gather-KV staging and sparse lens assembly are reorganized to respect cross-scope lifetime rules. |
| --scope-stats CLI flag across all harnesses | models/deepseek/v4/decode_attention_*.py, models/deepseek/v4/decode_sparse_attn*.py, models/deepseek/v4/prefill_attention_*.py, models/deepseek/v4/prefill_sparse_attn.py, models/deepseek/v4/moe.py, models/deepseek/v4/expert_routed.py | All ten CLI harnesses gain --scope-stats; the value is wired into run_jit via runtime_cfg as enable_scope_stats to write per-scope occupancy stats to scope_stats/scope_stats.jsonl. |
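Since the stats land in a JSONL file, a generic reader is enough to start inspecting them. The sketch below assumes only that each non-empty line is one JSON object; the record schema is not described in this PR, so the `scope` and `bytes` keys in the sample are made up for illustration and `read_jsonl` is a hypothetical helper name.

```python
import json


def read_jsonl(lines):
    """Parse an iterable of JSONL lines into a list of dicts, skipping blanks."""
    records = []
    for raw in lines:
        raw = raw.strip()
        if raw:
            records.append(json.loads(raw))
    return records


# Hypothetical sample records; the real field names may differ.
sample = [
    '{"scope": "S1", "bytes": 1024}',
    "",
    '{"scope": "S2", "bytes": 4096}',
]
records = read_jsonl(sample)

# List whichever keys are present instead of assuming a schema.
fields = sorted({key for record in records for key in record})
print(len(records), fields)
```

Against a real run, the same helper can be fed an open file handle for scope_stats/scope_stats.jsonl, since a file object iterates line by line.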

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#533: The main PR's refactor updates multiple decode/prefill attention kernels' hc_pre call sites to use preallocated in-place outputs within explicit pl.scope() blocks, which is tightly connected to the retrieved PR's fused/scoped hc_pre SPMD restructuring.
  • hw-native-sys/pypto-lib#566: Both PRs modify DeepSeek-V4 MoE integration code (models/deepseek/v4/moe.py and its call sites like decode_layer.py)—one for renaming/EP1 vs distributed wiring, the other for refactoring moe's tensor lifetime handling with explicit pl.scope()/auto-scope disabled—so they're directly related at the MoE implementation level.
  • hw-native-sys/pypto-lib#577: The main PR's refactor of attention_csa/attention_hca to treat compressor_ratio4/compressor_ratio128 as in-place writers (stopping tuple/return-value rebinding) is directly aligned with the retrieved PR's updated compressor/indexer return contracts and reduced CMP_MAX_BLOCKS.

Suggested labels

enhancement

🐇 No auto-scope shall bind me tight,
With pl.scope() I keep things right!
Scratch lives short, frame tensors stay,
--scope-stats tracks the stats each day.
Hop hop, the lifetimes dance in place — 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 20.83% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main refactoring: restructuring DeepSeek-V4 kernels to use auto_scope=False with explicit scope blocks. |
| Description check | ✅ Passed | The description comprehensively explains the changes across multiple files, scope structuring, memory optimizations, and verification results, all directly related to the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request optimizes memory reuse across DeepSeek v4 model components (decode/prefill attention paths, decode layer, MoE, and expert routing) by disabling automatic scoping and manually declaring pl.scope() blocks to minimize tensor live ranges. It also introduces a --scope-stats CLI argument to collect occupancy statistics. The reviewer identified a bug in decode_attention_hca.py where the --scope-stats flag is non-functional because the argument is not passed to run_jit. Additionally, several style guide violations were found regarding the 'Declare scratch at first use' rule, specifically where scratch tensors (such as q_flat, cmp_sparse_work, and cmp_sparse_lens_2d) were declared too early or outside their minimal active scopes.


help="Fixture-only compatibility seed for position_ids and slot mappings; "
"otherwise use the default per-batch coverage pattern.")
parser.add_argument("--enable-l2-swimlane", action="store_true", default=False)
parser.add_argument("--scope-stats", action="store_true", default=False)

Contributor


medium

The --scope-stats CLI argument is defined here, but enable_scope_stats=args.scope_stats is not passed to the run_jit call in this file (unlike in decode_attention_csa.py and decode_attention_swa.py). This makes the CLI flag non-functional for this runner.
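The missing wiring can be illustrated with a small standalone sketch. `run_jit_stub` is a hypothetical stand-in, since the real `run_jit` signature is not shown in this thread; the walkthrough above only says the value travels through `runtime_cfg` as `enable_scope_stats`. The argparse behaviour is standard: `--scope-stats` is exposed as `args.scope_stats`.

```python
import argparse


def run_jit_stub(*, runtime_cfg):
    """Hypothetical stand-in for the runner's run_jit call; echoes its config."""
    return dict(runtime_cfg)


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--scope-stats", action="store_true", default=False)
    return parser


def main(argv):
    args = build_parser().parse_args(argv)
    # Defining the flag is not enough: the parsed value must be forwarded,
    # otherwise the option is accepted on the command line but has no effect.
    return run_jit_stub(runtime_cfg={"enable_scope_stats": args.scope_stats})


print(main(["--scope-stats"]))
print(main([]))
```

A runner that defines the flag but never reads `args.scope_stats` would behave like the second call regardless of what the user passes.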

Comment on lines +207 to +211
q_flat = pl.reshape(q, [T * H, HEAD_DIM])
attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32)

for qk_t in pl.spmd(T, name_hint="qk_pv"):
    qk_kv_base = qk_t * PADDED_TOPK
    qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE
    # Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only
    # on (token, sparse-block), so hoisting them above the head-tile loop lets
    # one KV/bias load serve all head-tiles instead of re-loading per head.
    for qk_sb in pl.range(SPARSE_BLOCKS):
        qk_s0 = qk_sb * ATTN_K_TILE
        qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
        qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
        qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE]

        # Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV
        # tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles
        # (2x reuse at QK_M_TILE=32) instead of per head-tile. The
        # [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row
        # stores at the SAME offsets as the per-head-tile path
        # (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the
        # sparse_blk_* layout and merge_norm are bit-identical.
        for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2):
            qk_h0 = qk_hb * QK_M_TILE
            qk_head_row = qk_t * H + qk_h0
            qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM]
            qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32)
            qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE)
            qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row))
            qk_mi = pl.row_max(qk_scores)
            # Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid
            # blocks die in the merge alpha/beta -- no mask multiply needed.
            qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi))
            qk_li = pl.row_sum(qk_exp)
            qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint")
            qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32)
            for qk_sub in pl.unroll(QK_M_TILE // H_TILE):
                qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub
                qk_r0 = qk_sub * H_TILE
                qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE
                qk_row = qk_blk_base + qk_sb * H_TILE
                sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1]
                sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1]
                sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM]

# Online-softmax merge across sparse-K tiles, then sink-norm.
for m_t in pl.spmd(T, name_hint="merge_norm"):
    m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE

    for m_h_idx in pl.range((H // H_TILE)):
        m_h0 = m_h_idx * H_TILE
        m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE
        m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1]
        m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1]
        m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM]

        # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the
        # single block's stats directly instead of an empty merge loop.
        if SPARSE_BLOCKS > 1:
            for m_sb in pl.range(1, SPARSE_BLOCKS):
                m_row = m_blk_base + m_sb * H_TILE
                m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1]
                m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1]
                m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM]
                m_mi_new = pl.maximum(m_mi, m_cur_mi)
                m_alpha = pl.exp(pl.sub(m_mi, m_mi_new))
                m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new))
                m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li))
                m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta))
                m_mi = m_mi_new

        n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1])
        n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias)
        n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi)))
        n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM]
        n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint")
        n_rope_row = m_t * H + m_h0
        attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM]

        for n_hi in pl.range(H_TILE):
            n_gh = m_h0 + n_hi
            n_g = n_gh // HEADS_PER_GROUP
            n_hh = n_gh - n_g * HEADS_PER_GROUP
            n_pack_row = n_g * T + m_t
            n_col = n_hh * HEAD_DIM
            o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM]
with pl.scope():
    sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)

Contributor


medium

According to the repository style guide rule Declare scratch at first use, q_flat should be declared immediately before the first block that uses it. Since q_flat is only used inside the with pl.scope(): block (S2), declaring it outside the scope extends its live range unnecessarily to the function frame level. It should be moved inside the with pl.scope(): block, consistent with decode_sparse_attn.py.

Suggested change
attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
with pl.scope():
    q_flat = pl.reshape(q, [T * H, HEAD_DIM])
    sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
References
  1. Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
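For readers following the merge_norm loop in the hunk quoted above, the online-softmax merge and the sink normalization can be checked with a scalar sketch in plain Python. This is an illustration under assumptions, not the kernel: one head row, made-up scores and values, and hypothetical helper names. It mirrors the alpha/beta rescaling of the quoted code and treats the sink as one extra logit that adds to the denominator but contributes no value.

```python
import math


def block_stats(scores, values):
    """Per-block partials: row max, exp-sum, and unnormalized weighted output."""
    mi = max(scores)
    weights = [math.exp(s - mi) for s in scores]
    li = sum(weights)
    dim = len(values[0])
    oi = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return mi, li, oi


def merge(acc, cur):
    """Rescale both partials to the shared max, then add (the alpha/beta step)."""
    mi_a, li_a, oi_a = acc
    mi_b, li_b, oi_b = cur
    mi_new = max(mi_a, mi_b)
    alpha = math.exp(mi_a - mi_new)
    beta = math.exp(mi_b - mi_new)
    li = alpha * li_a + beta * li_b
    oi = [alpha * x + beta * y for x, y in zip(oi_a, oi_b)]
    return mi_new, li, oi


def finalize_with_sink(stats, sink):
    """Sink adds exp(sink - max) to the denominator only."""
    mi, li, oi = stats
    denom = li + math.exp(sink - mi)
    return [x / denom for x in oi]


def reference(scores, values, sink):
    """One-shot softmax over all scores plus the sink logit, for comparison."""
    m = max(max(scores), sink)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink - m)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]


# Made-up data: six key positions split into blocks of two.
scores = [0.5, -1.0, 2.0, 0.0, 3.5, -2.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5], [3.0, 1.0]]
sink, block = 0.25, 2

acc = block_stats(scores[:block], values[:block])
for start in range(block, len(scores), block):
    acc = merge(acc, block_stats(scores[start:start + block], values[start:start + block]))

merged = finalize_with_sink(acc, sink)
expected = reference(scores, values, sink)
```

The blockwise result matches the one-shot computation up to floating-point rounding, which is why the per-block (mi, li, oi) triples are all that needs to be kept between the two loops.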

Comment on lines +207 to +211
q_flat = pl.reshape(q, [T * H, HEAD_DIM])
attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32)

for qk_t in pl.spmd(T, name_hint="qk_pv"):
qk_kv_base = qk_t * PADDED_TOPK
qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE
# Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only
# on (token, sparse-block), so hoisting them above the head-tile loop lets
# one KV/bias load serve all head-tiles instead of re-loading per head.
for qk_sb in pl.range(SPARSE_BLOCKS):
qk_s0 = qk_sb * ATTN_K_TILE
qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE]

# Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV
# tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles
# (2x reuse at QK_M_TILE=32) instead of per head-tile. The
# [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row
# stores at the SAME offsets as the per-head-tile path
# (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the
# sparse_blk_* layout and merge_norm are bit-identical.
for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2):
qk_h0 = qk_hb * QK_M_TILE
qk_head_row = qk_t * H + qk_h0
qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM]
qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32)
qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE)
qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row))
qk_mi = pl.row_max(qk_scores)
# Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid
# blocks die in the merge alpha/beta -- no mask multiply needed.
qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi))
qk_li = pl.row_sum(qk_exp)
qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint")
qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32)
for qk_sub in pl.unroll(QK_M_TILE // H_TILE):
qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub
qk_r0 = qk_sub * H_TILE
qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE
qk_row = qk_blk_base + qk_sb * H_TILE
sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1]
sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1]
sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM]

# Online-softmax merge across sparse-K tiles, then sink-norm.
for m_t in pl.spmd(T, name_hint="merge_norm"):
m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE

for m_h_idx in pl.range((H // H_TILE)):
m_h0 = m_h_idx * H_TILE
m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE
m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1]
m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1]
m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM]

# Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the
# single block's stats directly instead of an empty merge loop.
if SPARSE_BLOCKS > 1:
for m_sb in pl.range(1, SPARSE_BLOCKS):
m_row = m_blk_base + m_sb * H_TILE
m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1]
m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1]
m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM]
m_mi_new = pl.maximum(m_mi, m_cur_mi)
m_alpha = pl.exp(pl.sub(m_mi, m_mi_new))
m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new))
m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li))
m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta))
m_mi = m_mi_new

n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1])
n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias)
n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi)))
n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM]
n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint")
n_rope_row = m_t * H + m_h0
attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM]

for n_hi in pl.range(H_TILE):
n_gh = m_h0 + n_hi
n_g = n_gh // HEADS_PER_GROUP
n_hh = n_gh - n_g * HEADS_PER_GROUP
n_pack_row = n_g * T + m_t
n_col = n_hh * HEAD_DIM
o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM]
with pl.scope():
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to the repository style guide rule Declare scratch at first use, q_flat should be declared immediately before the first block that uses it. Since q_flat is only used inside the with pl.scope(): block (S2), declaring it outside the scope extends its live range unnecessarily to the function frame level. It should be moved inside the with pl.scope(): block, consistent with decode_sparse_attn.py.

Suggested change
q_flat = pl.reshape(q, [T * H, HEAD_DIM])
attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32)
for qk_t in pl.spmd(T, name_hint="qk_pv"):
qk_kv_base = qk_t * PADDED_TOPK
qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE
# Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only
# on (token, sparse-block), so hoisting them above the head-tile loop lets
# one KV/bias load serve all head-tiles instead of re-loading per head.
for qk_sb in pl.range(SPARSE_BLOCKS):
qk_s0 = qk_sb * ATTN_K_TILE
qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM]
qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE]
# Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV
# tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles
# (2x reuse at QK_M_TILE=32) instead of per head-tile. The
# [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row
# stores at the SAME offsets as the per-head-tile path
# (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the
# sparse_blk_* layout and merge_norm are bit-identical.
for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2):
qk_h0 = qk_hb * QK_M_TILE
qk_head_row = qk_t * H + qk_h0
qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM]
qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32)
qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE)
qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row))
qk_mi = pl.row_max(qk_scores)
# Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid
# blocks die in the merge alpha/beta -- no mask multiply needed.
qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi))
qk_li = pl.row_sum(qk_exp)
qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint")
qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32)
for qk_sub in pl.unroll(QK_M_TILE // H_TILE):
qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub
qk_r0 = qk_sub * H_TILE
qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE
qk_row = qk_blk_base + qk_sb * H_TILE
sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1]
sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1]
sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM]
# Online-softmax merge across sparse-K tiles, then sink-norm.
for m_t in pl.spmd(T, name_hint="merge_norm"):
    m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE
    for m_h_idx in pl.range((H // H_TILE)):
        m_h0 = m_h_idx * H_TILE
        m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE
        m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1]
        m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1]
        m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM]
        # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the
        # single block's stats directly instead of an empty merge loop.
        if SPARSE_BLOCKS > 1:
            for m_sb in pl.range(1, SPARSE_BLOCKS):
                m_row = m_blk_base + m_sb * H_TILE
                m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1]
                m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1]
                m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM]
                m_mi_new = pl.maximum(m_mi, m_cur_mi)
                m_alpha = pl.exp(pl.sub(m_mi, m_mi_new))
                m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new))
                m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li))
                m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta))
                m_mi = m_mi_new
        n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1])
        n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias)
        n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi)))
        n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM]
        n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint")
        n_rope_row = m_t * H + m_h0
        attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM]
        for n_hi in pl.range(H_TILE):
            n_gh = m_h0 + n_hi
            n_g = n_gh // HEADS_PER_GROUP
            n_hh = n_gh - n_g * HEADS_PER_GROUP
            n_pack_row = n_g * T + m_t
            n_col = n_hh * HEAD_DIM
            o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM]
with pl.scope():
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
with pl.scope():
q_flat = pl.reshape(q, [T * H, HEAD_DIM])
sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
References
  1. Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
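
For readers following the merge_norm loop in the excerpt above, the online-softmax merge and sink normalization can be sketched in plain Python for a single query row. This is an illustration only, not pypto code: block_stats, merge_with_sink, and reference are hypothetical helper names, and the scores, values, and sink logit are made-up inputs. It mirrors the kernel's update (rescale the running sum and output by exp(old_max - new_max), then add exp(sink - max) to the denominator) and compares the result with a direct softmax-with-sink computation.

```python
import math

def block_stats(scores, values):
    """Per-block stats: running max, exp-sum, and unnormalized output."""
    mi = max(scores)
    weights = [math.exp(s - mi) for s in scores]
    li = sum(weights)
    dim = len(values[0])
    oi = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return mi, li, oi

def merge_with_sink(blocks, sink):
    """Merge block stats left to right, then normalize with a sink logit."""
    mi, li, oi = blocks[0]
    for cur_mi, cur_li, cur_oi in blocks[1:]:  # no iterations for one block
        mi_new = max(mi, cur_mi)
        alpha = math.exp(mi - mi_new)      # rescales the running stats
        beta = math.exp(cur_mi - mi_new)   # rescales the incoming block
        li = alpha * li + beta * cur_li
        oi = [alpha * a + beta * b for a, b in zip(oi, cur_oi)]
        mi = mi_new
    # The sink contributes to the denominator only, not to the output.
    denom = li + math.exp(sink - mi)
    return [o / denom for o in oi]

def reference(scores, values, sink):
    """Direct softmax over all scores plus the sink logit."""
    m = max(max(scores), sink)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights) + math.exp(sink - m)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]

# Made-up inputs: six keys split into three blocks of two.
scores = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 1.0], [0.5, -1.0], [-1.0, 2.0]]
sink = 0.25

blocks = [block_stats(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
merged = merge_with_sink(blocks, sink)
expected = reference(scores, values, sink)
```

With a single block the loop body never runs and the block's own stats are normalized directly, which corresponds to the SPARSE_BLOCKS == 1 case the guard in the kernel is written for.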

Comment on lines +164 to +166
x_normed = pl.create_tensor([T, D], dtype=pl.BF16)
cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32)

medium

According to the repository style guide rule Declare scratch at first use, cmp_sparse_work and cmp_sparse_lens_2d should not be declared in a block at the top of the function. They should be moved down to immediately before the first block that uses them (line 250).

Suggested change
-    x_normed = pl.create_tensor([T, D], dtype=pl.BF16)
-    cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
-    cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32)
+    x_normed = pl.create_tensor([T, D], dtype=pl.BF16)

# [T] the kernel expects -- assemble + reshape is SSA-tracked so the gather's read
# is ordered after this write (a 1D in-place pl.write races the round-trip).
cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_csa_sparse_lens"):

medium

According to the repository style guide rule Declare scratch at first use, cmp_sparse_work and cmp_sparse_lens_2d should be declared immediately before the first block that uses them.

    cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
    cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32)
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_csa_sparse_lens"):

zhangqi-chen added a commit to zhangqi-chen/pypto-lib that referenced this pull request Jun 23, 2026
Address PR hw-native-sys#585 review (gemini-code-assist):
- decode_attention_hca: pass enable_scope_stats=args.scope_stats into the
  run_jit runtime_cfg; the --scope-stats CLI flag was defined but never
  wired, so it was a no-op (swa/csa already wired it).
- decode_sparse_attn_swa / decode_sparse_attn_hca: move q_flat inside the
  S2 pl.scope() (it is only read by qk_pv in S2), matching the canonical
  decode_sparse_attn.py and tightening its live range.
- prefill_attention_csa: declare cmp_sparse_work / cmp_sparse_lens_2d
  immediately before the sparse-index build (still frame-level, read by
  sparse_attn after) instead of in the top frame block.

Behavior-preserving: prefill_attention_csa, decode_sparse_attn_swa, and
decode_sparse_attn_hca standalone tests PASS on a2a3.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
models/deepseek/v4/decode_sparse_attn_swa.py (1)

210-213: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Move q_flat into the S2 scope.

q_flat is declared immediately before this scope but is only consumed by qk_pv inside it, so it still gets frame lifetime instead of the intended sparse-attention scope lifetime.

♻️ Proposed scoped placement
-    q_flat = pl.reshape(q, [T * H, HEAD_DIM])
     attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
     o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
     with pl.scope():
+        q_flat = pl.reshape(q, [T * H, HEAD_DIM])
         sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_sparse_attn_swa.py` around lines 210 - 213, The
variable q_flat is currently declared outside the pl.scope() block but is only
consumed by qk_pv which is inside the scope. Move the q_flat declaration from
before the pl.scope() block to inside it so that its lifetime is scoped to the
sparse-attention scope where it is actually used, rather than having the broader
frame lifetime.
models/deepseek/v4/decode_sparse_attn_hca.py (1)

210-213: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Move q_flat into the S2 scope.

q_flat is declared immediately before this scope but is only consumed by qk_pv inside it, so it still gets frame lifetime instead of the intended sparse-attention scope lifetime.

♻️ Proposed scoped placement
-    q_flat = pl.reshape(q, [T * H, HEAD_DIM])
     attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
     o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)
     with pl.scope():
+        q_flat = pl.reshape(q, [T * H, HEAD_DIM])
         sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_sparse_attn_hca.py` around lines 210 - 213, The
variable q_flat is currently declared outside the pl.scope() block but is only
consumed within the scope by qk_pv, which causes it to have frame lifetime
instead of the intended sparse-attention scope lifetime. Move the q_flat
declaration from outside the pl.scope() block to inside it, placing it before
the creation of sparse_blk_mi, sparse_blk_li, and sparse_blk_oi tensors, so that
its lifetime is properly scoped to the sparse-attention scope.

📥 Commits

Reviewing files that changed from the base of the PR and between b51a34c and 30568b8.

📒 Files selected for processing (15)
  • docs/pypto-coding-style.md
  • models/deepseek/v4/decode_attention_csa.py
  • models/deepseek/v4/decode_attention_hca.py
  • models/deepseek/v4/decode_attention_swa.py
  • models/deepseek/v4/decode_layer.py
  • models/deepseek/v4/decode_sparse_attn.py
  • models/deepseek/v4/decode_sparse_attn_hca.py
  • models/deepseek/v4/decode_sparse_attn_swa.py
  • models/deepseek/v4/expert_routed.py
  • models/deepseek/v4/moe.py
  • models/deepseek/v4/prefill_attention_csa.py
  • models/deepseek/v4/prefill_attention_hca.py
  • models/deepseek/v4/prefill_attention_swa.py
  • models/deepseek/v4/prefill_layer.py
  • models/deepseek/v4/prefill_sparse_attn.py

help="Fixture-only compatibility seed for position_ids and slot mappings; "
"otherwise use the default per-batch coverage pattern.")
parser.add_argument("--enable-l2-swimlane", action="store_true", default=False)
parser.add_argument("--scope-stats", action="store_true", default=False)

🎯 Functional Correctness | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify decode_attention_hca forwards --scope-stats into run_jit runtime_cfg.
# Expected: a nearby runtime_cfg entry like `enable_scope_stats=args.scope_stats`.
rg -n -C4 'scope_stats|enable_scope_stats|run_jit|runtime_cfg' models/deepseek/v4/decode_attention_hca.py

Repository: hw-native-sys/pypto-lib

Length of output: 1041


🏁 Script executed:

sed -n '705,730p' models/deepseek/v4/decode_attention_hca.py

Repository: hw-native-sys/pypto-lib

Length of output: 838


Add enable_scope_stats=args.scope_stats to the runtime_cfg dict.

The --scope-stats argument was added to the parser (line 698) but is not wired into runtime_cfg (lines 705–709). Without this entry, the flag has no effect. Add it alongside the existing enable_l2_swimlane entry to complete the wiring.
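
A minimal sketch of the wiring this comment asks for, using only argparse. build_runtime_cfg is a hypothetical helper introduced for illustration; the real runner builds runtime_cfg inline and passes it to run_jit, whose signature is not shown in this thread and is not reproduced here.

```python
import argparse

def build_runtime_cfg(argv=None):
    """Parse the runner flags and forward each one into the runtime config."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-l2-swimlane", action="store_true", default=False)
    parser.add_argument("--scope-stats", action="store_true", default=False)
    args = parser.parse_args(argv)
    # A flag that is parsed but not copied into this dict is silently a no-op.
    return dict(
        enable_l2_swimlane=args.enable_l2_swimlane,
        enable_scope_stats=args.scope_stats,
    )

cfg = build_runtime_cfg(["--scope-stats"])
```

Keeping one dict entry per parser flag, in the same order, makes a missing entry easy to spot in review.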

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_attention_hca.py` at line 698, The --scope-stats
argument was added to the argument parser but is not connected to the
runtime_cfg dictionary. Locate the runtime_cfg dictionary initialization (around
lines 705-709) and add a new entry enable_scope_stats=args.scope_stats alongside
the existing enable_l2_swimlane entry to wire the command-line argument into the
configuration that will be used at runtime.

Split sparse_attn into three pl.scope() phases (KV-gather /
attention / out-proj) under auto_scope=False, so scope-local
intermediates free at scope exit instead of living to function return.

- Move S2-internal scratch (sparse_blk_mi/li/oi, q_flat) inside the
  attention scope; cross-scope buffers (sparse_kv, sparse_bias,
  attn_rope_stage, o_packed) stay function-level since the sequential
  S1->S2->S3 data flow can't nest them tighter.
- Add a --scope-stats CLI flag wiring the runtime enable_scope_stats
  DFX option, so per-scope occupancy lands in scope_stats.jsonl.

Behavior-preserving: attn_out is numerically identical (PASS). Per
scope_stats the function-frame heap residency drops ~170 -> ~90 MB.
Restructure attention_csa under auto_scope=False with nested pl.scope()
so pre-attention scratch frees before sparse_attn instead of pinning the
function frame. Cuts the scope_stats peak heap 256 -> 154 MB (-40%),
output bit-identical (kv_cache + x_out PASS).

- Outer "pre-attention" scope (hc_pre .. sparse-index build) holds the
  cross-sub-phase scratch (step/qr/x_normed_t); only the tensors read
  after sparse_attn (post_t, comb_t, rope, q, kv, cmp_sparse_indices,
  attn_out) stay in the function frame.
- Two sibling inner scopes free the biggest short-lived buffers at their
  death: {hc_pre, rms_norm} frees x_mixed, {cmp_rope, compressor} frees
  cmp_cos/sin + cmp_out. Nesting kept at 2 levels so the total depth fits
  the 4-scope budget once composed under decode_layer.
- All sub-kernel calls are now bare in-place (no returned-handle rebind):
  rebinding a frame tensor to a callee's dynamic-T return across a scope
  boundary is rejected, and dropping it also removes the cmp_out_v1 unused
  warning.
- idx_topk_full is created next to its writer (indexer) and scoped with
  the idx_tile region; cmp_sparse_indices moves to the frame as the only
  bridge into sparse_attn.
Prepare the moe FFN path for an auto_scope=False decode_layer: turning
auto_scope off without explicit scopes regresses moe_ep1 to a flat
117 MB (nothing freed). Replacing the auto-inferred scopes with explicit
ones recovers and beats it -- moe_ep1 max-live 78.4 -> 72.7 MB, end
residency 78.4 -> 5.8 MB, at scope depth 2 (fits the 4-level budget once
composed under decode_layer).

- expert_routed: auto_scope=False + a per-tile pl.scope around the
  pl.parallel(n_tiles) body (matching where auto_scope placed PTO2_SCOPE),
  so h_tile_fp32 / h_tile_i8 free each tile. Standalone: 39 MB, residency 0.
- moe / moe_ep1 (and their test entries): auto_scope=False + two sibling
  orchestration scopes -- {hc_pre, gate} frees x_mixed/x_norm, and
  {dispatch, expert_routed, combine} frees the big recv_* / recv_y before
  hc_post. Long-lived / wide-bridging tensors (post_ffn, comb_ffn, sh,
  x_norm_i8, indices, weights, ffn_out) stay in the function frame.
- Add a --scope-stats CLI flag (runtime enable_scope_stats) to both
  runners for per-scope occupancy measurement.

Verified: ep1 single-card PASS (x_next), ep2 distributed PASS on the
real dispatch/combine HCCL path. Output bit-identical.
Apply the csa scope pattern to the SWA and HCA decode-attention paths so
they are ready for an auto_scope=False decode_layer (turning auto_scope
off without explicit scopes flattens the whole inlined tree and pins all
scratch). Output bit-identical: x_out + kv_cache PASS for both.

- sparse_attn_swa / sparse_attn_hca: auto_scope=False + the committed
  decode_sparse_attn 3-scope split (S1 gather / S2 attention / S3
  out-proj); sparse_kv / sparse_bias / attn_rope_stage / o_packed stay
  at the function frame, sparse_blk_* in S2.
- attention_swa / attention_hca (and test entries): auto_scope=False +
  one pre-attention scope (hc_pre .. overlay-topk) that frees its scratch
  (x_mixed, x_normed, qr/qr_scale, and HCA's cmp_cos/sin + cmp_kv_proj)
  before sparse_attn. Frame holds only tensors read by/after sparse_attn
  (post_t, comb_t, rope_cos_t/sin_t, q, kv, topk, attn_out).
- All sub-kernel calls are bare in-place (no returned-handle rebind);
  inline call args packed several-per-line; each create_tensor sits just
  before its first-use call (attn_out before sparse_attn, etc.) unless a
  scope forces it earlier.
- Add a --scope-stats CLI flag (runtime enable_scope_stats) to both
  attention runners.
sh is first written by expert_shared (body level, not scope-constrained)
and ffn_out by combine (scope C'), so move their create_tensor next to
those sites instead of the top frame block, matching the create-at-
first-use convention. Behavior-preserving: moe_ep1 ep1 PASS, peak
unchanged (72.7 MB / 5.8 MB residency).
Declare each create_tensor / pl.slice / pl.reshape just before the first
pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it,
not in a block at the top of the function -- this minimizes each tensor's
live range so the scope / memory-reuse passes free it early. Exception:
a tensor written inside a pl.scope() but read after it must precede that
scope (else the scope frees it at exit).
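
The effect of this rule, and of its exception, can be illustrated with a toy allocator in plain Python. This is a hedged model, not the pypto runtime: ToyHeap, its create_tensor and scope methods, and the byte sizes are all invented for the example. It assumes only what the rule states, namely that a tensor created inside a scope is freed when that scope exits, while a tensor created at the frame lives until the function returns.

```python
from contextlib import contextmanager

class ToyHeap:
    """Toy model: tensors created inside a scope are freed at scope exit."""

    def __init__(self):
        self.live = 0          # bytes currently allocated
        self.peak = 0          # high-water mark of live bytes
        self._frames = [[]]    # allocation sizes per open scope; [0] is the frame

    def create_tensor(self, nbytes):
        self._frames[-1].append(nbytes)
        self.live += nbytes
        self.peak = max(self.peak, self.live)

    @contextmanager
    def scope(self):
        self._frames.append([])
        try:
            yield
        finally:
            # Free everything that was created inside this scope.
            self.live -= sum(self._frames.pop())

def hoisted():
    """All scratch declared in a block at the top of the function."""
    heap = ToyHeap()
    heap.create_tensor(100)  # phase-1 scratch
    heap.create_tensor(20)   # bridge tensor read by phase 2
    heap.create_tensor(80)   # phase-2 scratch
    return heap

def scoped():
    """Scratch declared at first use inside per-phase scopes."""
    heap = ToyHeap()
    heap.create_tensor(20)   # bridge: written in scope 1, read in scope 2,
                             # so it must be declared before scope 1
    with heap.scope():
        heap.create_tensor(100)  # freed when phase 1 ends
    with heap.scope():
        heap.create_tensor(80)   # reuses the space phase 1 released
    return heap

flat, tight = hoisted(), scoped()
```

In this model the hoisted layout peaks at 200 bytes and still holds 200 at return, while the scoped layout peaks at 120 and ends with only the 20-byte bridge tensor live.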
Make decode_layer auto_scope=False and wrap the per-layer attention
dispatch and the MoE call each in its own pl.scope(), so the attention
sub-kernels' scratch frees before MoE runs. x_attn bridges the two
scopes (written by attention, read by MoE) so it stays at the function
frame; the MoE call is bare in-place into the x_next out-param. Composes
the already-scoped csa/hca/swa attention + moe (each auto_scope=False)
into the full layer within the 4-level scope-nesting budget.

Not yet device-verified (compiles; ep2 run pending).
…False

Apply the decode-side auto_scope=False + pl.scope() memory-reuse refactor
to the prefill path so scratch frees at scope exit instead of pinning the
function frame once inlined. Behavior-preserving: all standalone attention
tests and prefill_layer EP2 PASS (kv_cache + x_out/x_next).

- prefill_sparse_attn: auto_scope=False + 3-scope split (S1 gather/bias,
  S2 attention, S3 inverse-RoPE + out-proj). Cross-scope buffers
  (sparse_kv, sparse_bias, attn_rope_stage, o_packed) stay at the frame;
  sparse_blk_*/q_flat live in S2.
- prefill_attention_swa/hca: auto_scope=False + one pre-attention scope
  (hc_pre .. qkv/compressor) that frees x_mixed/x_normed/qr/qr_scale
  before sparse_attn; frame holds post/comb/rope/q/kv (read after).
- prefill_attention_csa: nested scopes (outer pre-attention, inner
  {hc_pre, rms_norm} to free x_mixed). The compressor/indexer/sparse-idx
  build stay at the frame because they write the cache-state OUT-params
  this function RETURNS -- a returned param written by a bare call inside
  a scope can't bridge its SSA version out to the return.
- prefill_layer: auto_scope=False; the attention dispatch and the MoE
  call each get their own pl.scope() so attention scratch frees before
  MoE; moe is a bare in-place call into x_next.
- All sub-kernel calls are bare in-place (no returned-handle rebind);
  inline args packed several-per-line; each create_tensor sits at first
  use unless a scope/return forces it to the frame.
- Add a --scope-stats CLI flag (runtime enable_scope_stats) to the swa,
  hca, csa, and sparse_attn runners.

Verified on a2a3: prefill_attention_{swa,hca,csa} standalone PASS;
prefill_layer EP2 (CSA, layer 2) PASS on cards 14,15.
- Move each pl.slice / pl.create_tensor result to immediately before
  the first call that consumes it, instead of hoisting all defs to the
  top of the scope (attention/moe weight + scratch tensors now sit
  right above their attention_swa/csa/hca / moe call)
- Drop the redundant local type annotations on these results; pypto
  infers shape/dtype from the producing op, so they were noise

Keeps pl.slice rather than the [:] subscript sugar: converting these
forwarded slices to sugar makes the specializer lose the static shape
across the @pl.jit.inline boundary (callees pl.reshape kv_cache/cmp_kv/
attn_sink/...), failing with 'missing inferred tensor metadata'. The
docs claim the two forms are equivalent, so this looks like a pypto bug.

No behavior change; offsets, sizes, and dim counts are preserved.
Make the 7-layer decode_fwd auto_scope=False and wrap each layer's
attention call and MoE call in its own pl.scope(), so each sub-kernel's
scratch frees at its scope exit rather than piling up across the whole
forward. x_attn_* / hidden carries stay at the function frame (written in
one scope, read by the next). Composes the already-scoped swa/hca/csa
attention + moe within the 4-level nesting budget.

Compile verified (--compile-only ep2 PASS, 21.9s); full distributed
device run pending.
Each per-layer 'hidden = moe(...)' / 'x_next = moe(...)' reassignment
inside a pl.scope() produced a scope-local SSA name (hidden__ssa_v1) that
the next layer referenced outside the scope, failing orchestration g++
with 'hidden__ssa_v1 was not declared in this scope' (a2a3 CI). moe
writes its out-param in place, so call it bare; the carry tensor is
created before the scope and persists.