Refactor: scope dsv4 decode+prefill kernels under auto_scope=False#585
Refactor: scope dsv4 decode+prefill kernels under auto_scope=False#585zhangqi-chen wants to merge 12 commits into
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAll DeepSeek V4 decode and prefill kernels ( ChangesExplicit
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request optimizes memory reuse across DeepSeek v4 model components (decode/prefill attention paths, decode layer, MoE, and expert routing) by disabling automatic scoping and manually declaring pl.scope() blocks to minimize tensor live ranges. It also introduces a --scope-stats CLI argument to collect occupancy statistics. The reviewer identified a bug in decode_attention_hca.py where the --scope-stats flag is non-functional because the argument is not passed to run_jit. Additionally, several style guide violations were found regarding the 'Declare scratch at first use' rule, specifically where scratch tensors (such as q_flat, cmp_sparse_work, and cmp_sparse_lens_2d) were declared too early or outside their minimal active scopes.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| help="Fixture-only compatibility seed for position_ids and slot mappings; " | ||
| "otherwise use the default per-batch coverage pattern.") | ||
| parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) | ||
| parser.add_argument("--scope-stats", action="store_true", default=False) |
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | ||
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | ||
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | ||
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | ||
| sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | ||
| sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32) | ||
|
|
||
| for qk_t in pl.spmd(T, name_hint="qk_pv"): | ||
| qk_kv_base = qk_t * PADDED_TOPK | ||
| qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | ||
| # Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only | ||
| # on (token, sparse-block), so hoisting them above the head-tile loop lets | ||
| # one KV/bias load serve all head-tiles instead of re-loading per head. | ||
| for qk_sb in pl.range(SPARSE_BLOCKS): | ||
| qk_s0 = qk_sb * ATTN_K_TILE | ||
| qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | ||
| qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | ||
| qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE] | ||
|
|
||
| # Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV | ||
| # tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles | ||
| # (2x reuse at QK_M_TILE=32) instead of per head-tile. The | ||
| # [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row | ||
| # stores at the SAME offsets as the per-head-tile path | ||
| # (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the | ||
| # sparse_blk_* layout and merge_norm are bit-identical. | ||
| for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2): | ||
| qk_h0 = qk_hb * QK_M_TILE | ||
| qk_head_row = qk_t * H + qk_h0 | ||
| qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM] | ||
| qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32) | ||
| qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE) | ||
| qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row)) | ||
| qk_mi = pl.row_max(qk_scores) | ||
| # Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid | ||
| # blocks die in the merge alpha/beta -- no mask multiply needed. | ||
| qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) | ||
| qk_li = pl.row_sum(qk_exp) | ||
| qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint") | ||
| qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32) | ||
| for qk_sub in pl.unroll(QK_M_TILE // H_TILE): | ||
| qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub | ||
| qk_r0 = qk_sub * H_TILE | ||
| qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE | ||
| qk_row = qk_blk_base + qk_sb * H_TILE | ||
| sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1] | ||
| sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1] | ||
| sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM] | ||
|
|
||
| # Online-softmax merge across sparse-K tiles, then sink-norm. | ||
| for m_t in pl.spmd(T, name_hint="merge_norm"): | ||
| m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | ||
|
|
||
| for m_h_idx in pl.range((H // H_TILE)): | ||
| m_h0 = m_h_idx * H_TILE | ||
| m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE | ||
| m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1] | ||
| m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1] | ||
| m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM] | ||
|
|
||
| # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the | ||
| # single block's stats directly instead of an empty merge loop. | ||
| if SPARSE_BLOCKS > 1: | ||
| for m_sb in pl.range(1, SPARSE_BLOCKS): | ||
| m_row = m_blk_base + m_sb * H_TILE | ||
| m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1] | ||
| m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1] | ||
| m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM] | ||
| m_mi_new = pl.maximum(m_mi, m_cur_mi) | ||
| m_alpha = pl.exp(pl.sub(m_mi, m_mi_new)) | ||
| m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new)) | ||
| m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li)) | ||
| m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta)) | ||
| m_mi = m_mi_new | ||
|
|
||
| n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1]) | ||
| n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias) | ||
| n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi))) | ||
| n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM] | ||
| n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint") | ||
| n_rope_row = m_t * H + m_h0 | ||
| attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM] | ||
|
|
||
| for n_hi in pl.range(H_TILE): | ||
| n_gh = m_h0 + n_hi | ||
| n_g = n_gh // HEADS_PER_GROUP | ||
| n_hh = n_gh - n_g * HEADS_PER_GROUP | ||
| n_pack_row = n_g * T + m_t | ||
| n_col = n_hh * HEAD_DIM | ||
| o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM] | ||
| with pl.scope(): | ||
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) |
There was a problem hiding this comment.
According to the repository style guide rule Declare scratch at first use, q_flat should be declared immediately before the first block that uses it. Since q_flat is only used inside the with pl.scope(): block (S2), declaring it outside the scope extends its live range unnecessarily to the function frame level. It should be moved inside the with pl.scope(): block, consistent with decode_sparse_attn.py.
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | |
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | |
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32) | |
| for qk_t in pl.spmd(T, name_hint="qk_pv"): | |
| qk_kv_base = qk_t * PADDED_TOPK | |
| qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | |
| # Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only | |
| # on (token, sparse-block), so hoisting them above the head-tile loop lets | |
| # one KV/bias load serve all head-tiles instead of re-loading per head. | |
| for qk_sb in pl.range(SPARSE_BLOCKS): | |
| qk_s0 = qk_sb * ATTN_K_TILE | |
| qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | |
| qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | |
| qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE] | |
| # Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV | |
| # tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles | |
| # (2x reuse at QK_M_TILE=32) instead of per head-tile. The | |
| # [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row | |
| # stores at the SAME offsets as the per-head-tile path | |
| # (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the | |
| # sparse_blk_* layout and merge_norm are bit-identical. | |
| for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2): | |
| qk_h0 = qk_hb * QK_M_TILE | |
| qk_head_row = qk_t * H + qk_h0 | |
| qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM] | |
| qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32) | |
| qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE) | |
| qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row)) | |
| qk_mi = pl.row_max(qk_scores) | |
| # Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid | |
| # blocks die in the merge alpha/beta -- no mask multiply needed. | |
| qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) | |
| qk_li = pl.row_sum(qk_exp) | |
| qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint") | |
| qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32) | |
| for qk_sub in pl.unroll(QK_M_TILE // H_TILE): | |
| qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub | |
| qk_r0 = qk_sub * H_TILE | |
| qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE | |
| qk_row = qk_blk_base + qk_sb * H_TILE | |
| sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1] | |
| sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1] | |
| sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM] | |
| # Online-softmax merge across sparse-K tiles, then sink-norm. | |
| for m_t in pl.spmd(T, name_hint="merge_norm"): | |
| m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | |
| for m_h_idx in pl.range((H // H_TILE)): | |
| m_h0 = m_h_idx * H_TILE | |
| m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE | |
| m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1] | |
| m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1] | |
| m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM] | |
| # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the | |
| # single block's stats directly instead of an empty merge loop. | |
| if SPARSE_BLOCKS > 1: | |
| for m_sb in pl.range(1, SPARSE_BLOCKS): | |
| m_row = m_blk_base + m_sb * H_TILE | |
| m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1] | |
| m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1] | |
| m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM] | |
| m_mi_new = pl.maximum(m_mi, m_cur_mi) | |
| m_alpha = pl.exp(pl.sub(m_mi, m_mi_new)) | |
| m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new)) | |
| m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li)) | |
| m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta)) | |
| m_mi = m_mi_new | |
| n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1]) | |
| n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias) | |
| n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi))) | |
| n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM] | |
| n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint") | |
| n_rope_row = m_t * H + m_h0 | |
| attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM] | |
| for n_hi in pl.range(H_TILE): | |
| n_gh = m_h0 + n_hi | |
| n_g = n_gh // HEADS_PER_GROUP | |
| n_hh = n_gh - n_g * HEADS_PER_GROUP | |
| n_pack_row = n_g * T + m_t | |
| n_col = n_hh * HEAD_DIM | |
| o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM] | |
| with pl.scope(): | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | |
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | |
| with pl.scope(): | |
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) |
References
- Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | ||
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | ||
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | ||
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | ||
| sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | ||
| sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32) | ||
|
|
||
| for qk_t in pl.spmd(T, name_hint="qk_pv"): | ||
| qk_kv_base = qk_t * PADDED_TOPK | ||
| qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | ||
| # Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only | ||
| # on (token, sparse-block), so hoisting them above the head-tile loop lets | ||
| # one KV/bias load serve all head-tiles instead of re-loading per head. | ||
| for qk_sb in pl.range(SPARSE_BLOCKS): | ||
| qk_s0 = qk_sb * ATTN_K_TILE | ||
| qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | ||
| qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | ||
| qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE] | ||
|
|
||
| # Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV | ||
| # tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles | ||
| # (2x reuse at QK_M_TILE=32) instead of per head-tile. The | ||
| # [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row | ||
| # stores at the SAME offsets as the per-head-tile path | ||
| # (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the | ||
| # sparse_blk_* layout and merge_norm are bit-identical. | ||
| for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2): | ||
| qk_h0 = qk_hb * QK_M_TILE | ||
| qk_head_row = qk_t * H + qk_h0 | ||
| qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM] | ||
| qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32) | ||
| qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE) | ||
| qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row)) | ||
| qk_mi = pl.row_max(qk_scores) | ||
| # Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid | ||
| # blocks die in the merge alpha/beta -- no mask multiply needed. | ||
| qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) | ||
| qk_li = pl.row_sum(qk_exp) | ||
| qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint") | ||
| qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32) | ||
| for qk_sub in pl.unroll(QK_M_TILE // H_TILE): | ||
| qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub | ||
| qk_r0 = qk_sub * H_TILE | ||
| qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE | ||
| qk_row = qk_blk_base + qk_sb * H_TILE | ||
| sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1] | ||
| sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1] | ||
| sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM] | ||
|
|
||
| # Online-softmax merge across sparse-K tiles, then sink-norm. | ||
| for m_t in pl.spmd(T, name_hint="merge_norm"): | ||
| m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | ||
|
|
||
| for m_h_idx in pl.range((H // H_TILE)): | ||
| m_h0 = m_h_idx * H_TILE | ||
| m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE | ||
| m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1] | ||
| m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1] | ||
| m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM] | ||
|
|
||
| # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the | ||
| # single block's stats directly instead of an empty merge loop. | ||
| if SPARSE_BLOCKS > 1: | ||
| for m_sb in pl.range(1, SPARSE_BLOCKS): | ||
| m_row = m_blk_base + m_sb * H_TILE | ||
| m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1] | ||
| m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1] | ||
| m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM] | ||
| m_mi_new = pl.maximum(m_mi, m_cur_mi) | ||
| m_alpha = pl.exp(pl.sub(m_mi, m_mi_new)) | ||
| m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new)) | ||
| m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li)) | ||
| m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta)) | ||
| m_mi = m_mi_new | ||
|
|
||
| n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1]) | ||
| n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias) | ||
| n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi))) | ||
| n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM] | ||
| n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint") | ||
| n_rope_row = m_t * H + m_h0 | ||
| attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM] | ||
|
|
||
| for n_hi in pl.range(H_TILE): | ||
| n_gh = m_h0 + n_hi | ||
| n_g = n_gh // HEADS_PER_GROUP | ||
| n_hh = n_gh - n_g * HEADS_PER_GROUP | ||
| n_pack_row = n_g * T + m_t | ||
| n_col = n_hh * HEAD_DIM | ||
| o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM] | ||
| with pl.scope(): | ||
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) |
There was a problem hiding this comment.
According to the repository style guide rule Declare scratch at first use, q_flat should be declared immediately before the first block that uses it. Since q_flat is only used inside the with pl.scope(): block (S2), declaring it outside the scope extends its live range unnecessarily to the function frame level. It should be moved inside the with pl.scope(): block, consistent with decode_sparse_attn.py.
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | |
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | |
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, HEAD_DIM], dtype=pl.FP32) | |
| for qk_t in pl.spmd(T, name_hint="qk_pv"): | |
| qk_kv_base = qk_t * PADDED_TOPK | |
| qk_token_base = qk_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | |
| # Sparse-block OUTER / head-tile INNER: the KV tile and bias depend only | |
| # on (token, sparse-block), so hoisting them above the head-tile loop lets | |
| # one KV/bias load serve all head-tiles instead of re-loading per head. | |
| for qk_sb in pl.range(SPARSE_BLOCKS): | |
| qk_s0 = qk_sb * ATTN_K_TILE | |
| qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | |
| qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] | |
| qk_bias_row = sparse_bias[qk_t : qk_t + 1, qk_s0 : qk_s0 + ATTN_K_TILE] | |
| # Cube-batch QK_M_TILE head rows per QK/PV matmul so the shared KV | |
| # tile is extracted L1->L0 once per QK_M_TILE/H_TILE head-tiles | |
| # (2x reuse at QK_M_TILE=32) instead of per head-tile. The | |
| # [QK_M_TILE, ...] softmax result is sliced back into H_TILE-row | |
| # stores at the SAME offsets as the per-head-tile path | |
| # (qk_h_idx == qk_hb * (QK_M_TILE // H_TILE) + qk_sub), so the | |
| # sparse_blk_* layout and merge_norm are bit-identical. | |
| for qk_hb in pl.pipeline(H // QK_M_TILE, stage=2): | |
| qk_h0 = qk_hb * QK_M_TILE | |
| qk_head_row = qk_t * H + qk_h0 | |
| qk_q_tile = q_flat[qk_head_row : qk_head_row + QK_M_TILE, 0 : HEAD_DIM] | |
| qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32) | |
| qk_scaled = pl.mul(qk_raw, SOFTMAX_SCALE) | |
| qk_scores = pl.add(qk_scaled, pl.col_expand(pl.full([QK_M_TILE, ATTN_K_TILE], dtype=pl.FP32, value=0.0), qk_bias_row)) | |
| qk_mi = pl.row_max(qk_scores) | |
| # Invalid lanes (NEG_INF bias, zero kv rows) exp to ~0; all-invalid | |
| # blocks die in the merge alpha/beta -- no mask multiply needed. | |
| qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) | |
| qk_li = pl.row_sum(qk_exp) | |
| qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16, mode="rint") | |
| qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32) | |
| for qk_sub in pl.unroll(QK_M_TILE // H_TILE): | |
| qk_h_idx = qk_hb * (QK_M_TILE // H_TILE) + qk_sub | |
| qk_r0 = qk_sub * H_TILE | |
| qk_blk_base = qk_token_base + qk_h_idx * SPARSE_BLOCKS * H_TILE | |
| qk_row = qk_blk_base + qk_sb * H_TILE | |
| sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi[qk_r0 : qk_r0 + H_TILE, 0 : 1] | |
| sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li[qk_r0 : qk_r0 + H_TILE, 0 : 1] | |
| sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi[qk_r0 : qk_r0 + H_TILE, 0 : HEAD_DIM] | |
| # Online-softmax merge across sparse-K tiles, then sink-norm. | |
| for m_t in pl.spmd(T, name_hint="merge_norm"): | |
| m_token_base = m_t * (H // H_TILE) * SPARSE_BLOCKS * H_TILE | |
| for m_h_idx in pl.range((H // H_TILE)): | |
| m_h0 = m_h_idx * H_TILE | |
| m_blk_base = m_token_base + m_h_idx * SPARSE_BLOCKS * H_TILE | |
| m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1] | |
| m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1] | |
| m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM] | |
| # Guarded so the SWA (SPARSE_BLOCKS == 1) specialization uses the | |
| # single block's stats directly instead of an empty merge loop. | |
| if SPARSE_BLOCKS > 1: | |
| for m_sb in pl.range(1, SPARSE_BLOCKS): | |
| m_row = m_blk_base + m_sb * H_TILE | |
| m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1] | |
| m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1] | |
| m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM] | |
| m_mi_new = pl.maximum(m_mi, m_cur_mi) | |
| m_alpha = pl.exp(pl.sub(m_mi, m_mi_new)) | |
| m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new)) | |
| m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li)) | |
| m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta)) | |
| m_mi = m_mi_new | |
| n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1]) | |
| n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias) | |
| n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi))) | |
| n_full = pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM] | |
| n_bf16 = pl.cast(n_full, target_type=pl.BF16, mode="rint") | |
| n_rope_row = m_t * H + m_h0 | |
| attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_full[0 : H_TILE, NOPE_DIM : HEAD_DIM] | |
| for n_hi in pl.range(H_TILE): | |
| n_gh = m_h0 + n_hi | |
| n_g = n_gh // HEADS_PER_GROUP | |
| n_hh = n_gh - n_g * HEADS_PER_GROUP | |
| n_pack_row = n_g * T + m_t | |
| n_col = n_hh * HEAD_DIM | |
| o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_bf16[n_hi : n_hi + 1, 0 : NOPE_DIM] | |
| with pl.scope(): | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) | |
| attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) | |
| o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) | |
| with pl.scope(): | |
| q_flat = pl.reshape(q, [T * H, HEAD_DIM]) | |
| sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32) |
References
- Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
| x_normed = pl.create_tensor([T, D], dtype=pl.BF16) | ||
| cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32) | ||
| cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32) |
There was a problem hiding this comment.
According to the repository style guide rule Declare scratch at first use, cmp_sparse_work and cmp_sparse_lens_2d should not be declared in a block at the top of the function. They should be moved down to immediately before the first block that uses them (line 250).
| x_normed = pl.create_tensor([T, D], dtype=pl.BF16) | |
| cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32) | |
| cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32) | |
| x_normed = pl.create_tensor([T, D], dtype=pl.BF16) |
References
- Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
| # [T] the kernel expects -- assemble + reshape is SSA-tracked so the gather's read | ||
| # is ordered after this write (a 1D in-place pl.write races the round-trip). | ||
| cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_csa_sparse_lens"): |
There was a problem hiding this comment.
According to the repository style guide rule Declare scratch at first use, cmp_sparse_work and cmp_sparse_lens_2d should be declared immediately before the first block that uses them.
cmp_sparse_work = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32)
cmp_sparse_lens_2d = pl.create_tensor([1, T], dtype=pl.INT32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_csa_sparse_lens"):References
- Declare each pl.create_tensor / pl.slice / pl.reshape immediately before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it — not in a block at the top of the function. Tight placement keeps a tensor's live range minimal, which lets the scope / memory-reuse passes free it as early as possible. (link)
Address PR hw-native-sys#585 review (gemini-code-assist): - decode_attention_hca: pass enable_scope_stats=args.scope_stats into the run_jit runtime_cfg; the --scope-stats CLI flag was defined but never wired, so it was a no-op (swa/csa already wired it). - decode_sparse_attn_swa / decode_sparse_attn_hca: move q_flat inside the S2 pl.scope() (it is only read by qk_pv in S2), matching the canonical decode_sparse_attn.py and tightening its live range. - prefill_attention_csa: declare cmp_sparse_work / cmp_sparse_lens_2d immediately before the sparse-index build (still frame-level, read by sparse_attn after) instead of in the top frame block. Behavior-preserving: prefill_attention_csa, decode_sparse_attn_swa, and decode_sparse_attn_hca standalone tests PASS on a2a3.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (2)
models/deepseek/v4/decode_sparse_attn_swa.py (1)
210-213: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick winMove
q_flatinto the S2 scope.
q_flatis declared immediately before this scope but is only consumed byqk_pvinside it, so it still gets frame lifetime instead of the intended sparse-attention scope lifetime.♻️ Proposed scoped placement
- q_flat = pl.reshape(q, [T * H, HEAD_DIM]) attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) with pl.scope(): + q_flat = pl.reshape(q, [T * H, HEAD_DIM]) sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/decode_sparse_attn_swa.py` around lines 210 - 213, The variable q_flat is currently declared outside the pl.scope() block but is only consumed by qk_pv which is inside the scope. Move the q_flat declaration from before the pl.scope() block to inside it so that its lifetime is scoped to the sparse-attention scope where it is actually used, rather than having the broader frame lifetime.models/deepseek/v4/decode_sparse_attn_hca.py (1)
210-213: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick winMove
q_flatinto the S2 scope.
q_flatis declared immediately before this scope but is only consumed byqk_pvinside it, so it still gets frame lifetime instead of the intended sparse-attention scope lifetime.♻️ Proposed scoped placement
- q_flat = pl.reshape(q, [T * H, HEAD_DIM]) attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) with pl.scope(): + q_flat = pl.reshape(q, [T * H, HEAD_DIM]) sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * SPARSE_BLOCKS * H_TILE, 1], dtype=pl.FP32)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/decode_sparse_attn_hca.py` around lines 210 - 213, The variable q_flat is currently declared outside the pl.scope() block but is only consumed within the scope by qk_pv, which causes it to have frame lifetime instead of the intended sparse-attention scope lifetime. Move the q_flat declaration from outside the pl.scope() block to inside it, placing it before the creation of sparse_blk_mi, sparse_blk_li, and sparse_blk_oi tensors, so that its lifetime is properly scoped to the sparse-attention scope.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/decode_attention_hca.py`:
- Line 698: The --scope-stats argument was added to the argument parser but is
not connected to the runtime_cfg dictionary. Locate the runtime_cfg dictionary
initialization (around lines 705-709) and add a new entry
enable_scope_stats=args.scope_stats alongside the existing enable_l2_swimlane
entry to wire the command-line argument into the configuration that will be used
at runtime.
---
Nitpick comments:
In `@models/deepseek/v4/decode_sparse_attn_hca.py`:
- Around line 210-213: The variable q_flat is currently declared outside the
pl.scope() block but is only consumed within the scope by qk_pv, which causes it
to have frame lifetime instead of the intended sparse-attention scope lifetime.
Move the q_flat declaration from outside the pl.scope() block to inside it,
placing it before the creation of sparse_blk_mi, sparse_blk_li, and
sparse_blk_oi tensors, so that its lifetime is properly scoped to the
sparse-attention scope.
In `@models/deepseek/v4/decode_sparse_attn_swa.py`:
- Around line 210-213: The variable q_flat is currently declared outside the
pl.scope() block but is only consumed by qk_pv which is inside the scope. Move
the q_flat declaration from before the pl.scope() block to inside it so that its
lifetime is scoped to the sparse-attention scope where it is actually used,
rather than having the broader frame lifetime.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: f870a35e-70f3-4e39-b937-a21830f55302
📒 Files selected for processing (15)
docs/pypto-coding-style.mdmodels/deepseek/v4/decode_attention_csa.pymodels/deepseek/v4/decode_attention_hca.pymodels/deepseek/v4/decode_attention_swa.pymodels/deepseek/v4/decode_layer.pymodels/deepseek/v4/decode_sparse_attn.pymodels/deepseek/v4/decode_sparse_attn_hca.pymodels/deepseek/v4/decode_sparse_attn_swa.pymodels/deepseek/v4/expert_routed.pymodels/deepseek/v4/moe.pymodels/deepseek/v4/prefill_attention_csa.pymodels/deepseek/v4/prefill_attention_hca.pymodels/deepseek/v4/prefill_attention_swa.pymodels/deepseek/v4/prefill_layer.pymodels/deepseek/v4/prefill_sparse_attn.py
| help="Fixture-only compatibility seed for position_ids and slot mappings; " | ||
| "otherwise use the default per-batch coverage pattern.") | ||
| parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) | ||
| parser.add_argument("--scope-stats", action="store_true", default=False) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify decode_attention_hca forwards --scope-stats into run_jit runtime_cfg.
# Expected: a nearby runtime_cfg entry like `enable_scope_stats=args.scope_stats`.
rg -n -C4 'scope_stats|enable_scope_stats|run_jit|runtime_cfg' models/deepseek/v4/decode_attention_hca.pyRepository: hw-native-sys/pypto-lib
Length of output: 1041
🏁 Script executed:
sed -n '705,730p' models/deepseek/v4/decode_attention_hca.pyRepository: hw-native-sys/pypto-lib
Length of output: 838
Add enable_scope_stats=args.scope_stats to the runtime_cfg dict.
The --scope-stats argument was added to the parser (line 698) but is not wired into runtime_cfg (lines 705–709). Without this entry, the flag has no effect. Add it alongside the existing enable_l2_swimlane entry to complete the wiring.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/decode_attention_hca.py` at line 698, The --scope-stats
argument was added to the argument parser but is not connected to the
runtime_cfg dictionary. Locate the runtime_cfg dictionary initialization (around
lines 705-709) and add a new entry enable_scope_stats=args.scope_stats alongside
the existing enable_l2_swimlane entry to wire the command-line argument into the
configuration that will be used at runtime.
Split sparse_attn into three with pl.scope() phases (KV-gather / attention / out-proj) under auto_scope=False, so scope-local intermediates free at scope exit instead of living to function return. - Move S2-internal scratch (sparse_blk_mi/li/oi, q_flat) inside the attention scope; cross-scope buffers (sparse_kv, sparse_bias, attn_rope_stage, o_packed) stay function-level since the sequential S1->S2->S3 data flow can't nest them tighter. - Add a --scope-stats CLI flag wiring the runtime enable_scope_stats DFX option, so per-scope occupancy lands in scope_stats.jsonl. Behavior-preserving: attn_out is numerically identical (PASS). Per scope_stats the function-frame heap residency drops ~170 -> ~90 MB.
Restructure attention_csa under auto_scope=False with nested pl.scope()
so pre-attention scratch frees before sparse_attn instead of pinning the
function frame. Cuts the scope_stats peak heap 256 -> 154 MB (-40%),
output bit-identical (kv_cache + x_out PASS).
- Outer "pre-attention" scope (hc_pre .. sparse-index build) holds the
cross-sub-phase scratch (step/qr/x_normed_t); only the tensors read
after sparse_attn (post_t, comb_t, rope, q, kv, cmp_sparse_indices,
attn_out) stay in the function frame.
- Two sibling inner scopes free the biggest short-lived buffers at their
death: {hc_pre, rms_norm} frees x_mixed, {cmp_rope, compressor} frees
cmp_cos/sin + cmp_out. Nesting kept at 2 levels so the total depth fits
the 4-scope budget once composed under decode_layer.
- All sub-kernel calls are now bare in-place (no returned-handle rebind):
rebinding a frame tensor to a callee's dynamic-T return across a scope
boundary is rejected, and dropping it also removes the cmp_out_v1 unused
warning.
- idx_topk_full is created next to its writer (indexer) and scoped with
the idx_tile region; cmp_sparse_indices moves to the frame as the only
bridge into sparse_attn.
Prepare the moe FFN path for an auto_scope=False decode_layer: turning
auto_scope off without explicit scopes regresses moe_ep1 to a flat
117 MB (nothing freed). Replacing the auto-inferred scopes with explicit
ones recovers and beats it -- moe_ep1 max-live 78.4 -> 72.7 MB, end
residency 78.4 -> 5.8 MB, at scope depth 2 (fits the 4-level budget once
composed under decode_layer).
- expert_routed: auto_scope=False + a per-tile pl.scope around the
pl.parallel(n_tiles) body (matching where auto_scope placed PTO2_SCOPE),
so h_tile_fp32 / h_tile_i8 free each tile. Standalone: 39 MB, residency 0.
- moe / moe_ep1 (and their test entries): auto_scope=False + two sibling
orchestration scopes -- {hc_pre, gate} frees x_mixed/x_norm, and
{dispatch, expert_routed, combine} frees the big recv_* / recv_y before
hc_post. Long-lived / wide-bridging tensors (post_ffn, comb_ffn, sh,
x_norm_i8, indices, weights, ffn_out) stay in the function frame.
- Add a --scope-stats CLI flag (runtime enable_scope_stats) to both
runners for per-scope occupancy measurement.
Verified: ep1 single-card PASS (x_next), ep2 distributed PASS on the
real dispatch/combine HCCL path. Output bit-identical.
Apply the csa scope pattern to the SWA and HCA decode-attention paths so they are ready for an auto_scope=False decode_layer (turning auto_scope off without explicit scopes flattens the whole inlined tree and pins all scratch). Output bit-identical: x_out + kv_cache PASS for both. - sparse_attn_swa / sparse_attn_hca: auto_scope=False + the committed decode_sparse_attn 3-scope split (S1 gather / S2 attention / S3 out-proj); sparse_kv / sparse_bias / attn_rope_stage / o_packed stay at the function frame, sparse_blk_* in S2. - attention_swa / attention_hca (and test entries): auto_scope=False + one pre-attention scope (hc_pre .. overlay-topk) that frees its scratch (x_mixed, x_normed, qr/qr_scale, and HCA's cmp_cos/sin + cmp_kv_proj) before sparse_attn. Frame holds only tensors read by/after sparse_attn (post_t, comb_t, rope_cos_t/sin_t, q, kv, topk, attn_out). - All sub-kernel calls are bare in-place (no returned-handle rebind); inline call args packed several-per-line; each create_tensor sits just before its first-use call (attn_out before sparse_attn, etc.) unless a scope forces it earlier. - Add a --scope-stats CLI flag (runtime enable_scope_stats) to both attention runners.
sh is first written by expert_shared (body level, not scope-constrained) and ffn_out by combine (scope C'), so move their create_tensor next to those sites instead of the top frame block, matching the create-at- first-use convention. Behavior-preserving: moe_ep1 ep1 PASS, peak unchanged (72.7 MB / 5.8 MB residency).
Declare each create_tensor / pl.slice / pl.reshape just before the first pl.at / pl.spmd / pl.range / pl.scope (or sub-kernel call) that uses it, not in a block at the top of the function -- this minimizes each tensor's live range so the scope / memory-reuse passes free it early. Exception: a tensor written inside a pl.scope() but read after it must precede that scope (else the scope frees it at exit).
Make decode_layer auto_scope=False and wrap the per-layer attention dispatch and the MoE call each in its own pl.scope(), so the attention sub-kernels' scratch frees before MoE runs. x_attn bridges the two scopes (written by attention, read by MoE) so it stays at the function frame; the MoE call is bare in-place into the x_next out-param. Composes the already-scoped csa/hca/swa attention + moe (each auto_scope=False) into the full layer within the 4-level scope-nesting budget. Not yet device-verified (compiles; ep2 run pending).
…False
Apply the decode-side auto_scope=False + pl.scope() memory-reuse refactor
to the prefill path so scratch frees at scope exit instead of pinning the
function frame once inlined. Behavior-preserving: all standalone attention
tests and prefill_layer EP2 PASS (kv_cache + x_out/x_next).
- prefill_sparse_attn: auto_scope=False + 3-scope split (S1 gather/bias,
S2 attention, S3 inverse-RoPE + out-proj). Cross-scope buffers
(sparse_kv, sparse_bias, attn_rope_stage, o_packed) stay at the frame;
sparse_blk_*/q_flat live in S2.
- prefill_attention_swa/hca: auto_scope=False + one pre-attention scope
(hc_pre .. qkv/compressor) that frees x_mixed/x_normed/qr/qr_scale
before sparse_attn; frame holds post/comb/rope/q/kv (read after).
- prefill_attention_csa: nested scopes (outer pre-attention, inner
{hc_pre, rms_norm} to free x_mixed). The compressor/indexer/sparse-idx
build stay at the frame because they write the cache-state OUT-params
this function RETURNS -- a returned param written by a bare call inside
a scope can't bridge its SSA version out to the return.
- prefill_layer: auto_scope=False; the attention dispatch and the MoE
call each get their own pl.scope() so attention scratch frees before
MoE; moe is a bare in-place call into x_next.
- All sub-kernel calls are bare in-place (no returned-handle rebind);
inline args packed several-per-line; each create_tensor sits at first
use unless a scope/return forces it to the frame.
- Add a --scope-stats CLI flag (runtime enable_scope_stats) to the swa,
hca, csa, and sparse_attn runners.
Verified on a2a3: prefill_attention_{swa,hca,csa} standalone PASS;
prefill_layer EP2 (CSA, layer 2) PASS on cards 14,15.
Address PR hw-native-sys#585 review (gemini-code-assist): - decode_attention_hca: pass enable_scope_stats=args.scope_stats into the run_jit runtime_cfg; the --scope-stats CLI flag was defined but never wired, so it was a no-op (swa/csa already wired it). - decode_sparse_attn_swa / decode_sparse_attn_hca: move q_flat inside the S2 pl.scope() (it is only read by qk_pv in S2), matching the canonical decode_sparse_attn.py and tightening its live range. - prefill_attention_csa: declare cmp_sparse_work / cmp_sparse_lens_2d immediately before the sparse-index build (still frame-level, read by sparse_attn after) instead of in the top frame block. Behavior-preserving: prefill_attention_csa, decode_sparse_attn_swa, and decode_sparse_attn_hca standalone tests PASS on a2a3.
- Move each pl.slice / pl.create_tensor result to immediately before the first call that consumes it, instead of hoisting all defs to the top of the scope (attention/moe weight + scratch tensors now sit right above their attention_swa/csa/hca / moe call) - Drop the redundant local type annotations on these results; pypto infers shape/dtype from the producing op, so they were noise Keeps pl.slice rather than the [:] subscript sugar: converting these forwarded slices to sugar makes the specializer lose the static shape across the @pl.jit.inline boundary (callees pl.reshape kv_cache/cmp_kv/ attn_sink/...), failing with 'missing inferred tensor metadata'. The docs claim the two forms are equivalent, so this looks like a pypto bug. No behavior change; offsets, sizes, and dim counts are preserved.
Make the 7-layer decode_fwd auto_scope=False and wrap each layer's attention call and MoE call in its own pl.scope(), so each sub-kernel's scratch frees at its scope exit rather than piling up across the whole forward. x_attn_* / hidden carries stay at the function frame (written in one scope, read by the next). Composes the already-scoped swa/hca/csa attention + moe within the 4-level nesting budget. Compile verified (--compile-only ep2 PASS, 21.9s); full distributed device run pending.
Address PR hw-native-sys#585 review (gemini-code-assist): - decode_attention_hca: pass enable_scope_stats=args.scope_stats into the run_jit runtime_cfg; the --scope-stats CLI flag was defined but never wired, so it was a no-op (swa/csa already wired it). - decode_sparse_attn_swa / decode_sparse_attn_hca: move q_flat inside the S2 pl.scope() (it is only read by qk_pv in S2), matching the canonical decode_sparse_attn.py and tightening its live range. - prefill_attention_csa: declare cmp_sparse_work / cmp_sparse_lens_2d immediately before the sparse-index build (still frame-level, read by sparse_attn after) instead of in the top frame block. Behavior-preserving: prefill_attention_csa, decode_sparse_attn_swa, and decode_sparse_attn_hca standalone tests PASS on a2a3.
Each per-layer 'hidden = moe(...)' / 'x_next = moe(...)' reassignment inside a pl.scope() produced a scope-local SSA name (hidden__ssa_v1) that the next layer referenced outside the scope, failing orchestration g++ with 'hidden__ssa_v1 was not declared in this scope' (a2a3 CI). moe writes its out-param in place, so call it bare; the carry tensor is created before the scope and persists.
Summary
@pl.jit(.inline)(auto_scope=False)with explicitpl.scope()blocks so scratch frees at scope exit instead of pinning the function frame once inlined. Behavior-preserving across the board.decode_sparse_attn/prefill_sparse_attn: 3-scope split (S1 gather/bias, S2 attention, S3 inverse-RoPE + output projection); cross-scope buffers (sparse_kv,sparse_bias,attn_rope_stage,o_packed) stay at the frame,sparse_blk_*/q_flatlive in S2.decode_attention_csaandprefill_attention_{swa,hca,csa}: a pre-attention scope freesx_mixed/x_normed/qr/qr_scalebeforesparse_attn; the frame keeps only what is read after (post/comb/rope/q/kv). Forprefill_attention_csathe compressor/indexer/sparse-idx build stay at the frame because they write the cache-state OUT-params the function RETURNS (a returned param written by a bare call inside a scope cannot bridge its SSA version to the return).moe+expert_routed: scoped forauto_scope=Falsecomposition; create scratch (sh/ffn_out) at first use.decode_layer/prefill_layer:auto_scope=False; the attention dispatch and the MoE call each get their ownpl.scope()so attention scratch frees before MoE; MoE is a bare in-place call intox_next. Generated orchestration nestsPTO2_SCOPEat most 3 levels deep (within the 4-level budget).create_tensorsits at first use unless a scope/return forces it to the frame.decode_fwd: converted to[:]slice sugar with just-in-time per-layer defs.docs/pypto-coding-style.md(sec 8).--scope-statsCLI flag (runtimeenable_scope_stats) to the affected attention / sparse_attn runners.Verified on a2a3:
prefill_attention_{swa,hca,csa}standalone PASS;prefill_layerEP2 (CSA, layer 2) PASS on cards 14,15. Decode-side scope refactors output bit-identical (kv_cache + x_out PASS).Related Issues
N/A