Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 34 additions & 38 deletions models/deepseek/v4/decode_attention_csa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -93,7 +93,6 @@
CMP_BLOCK_NUM = B * CMP_MAX_BLOCKS

# tiling
CSA_TOPK_TOKEN_TILE = 8
CSA_WB_TOKEN_TILE = 8
WRITEBACK_GUARD_TILE = 16
CSA_CMP_GE_BIAS = float(1 - (WIN + S)) # raw - (WIN + S) + 1, folded for the ge clamp
Expand Down Expand Up @@ -232,52 +231,49 @@
valid_block_mask = pl.create_tensor([T, SPARSE_BLOCKS], dtype=pl.INT32)
idx_topk_flat = pl.reshape(idx_topk_full, [T, INDEXER_SCORE_LEN])
position_ids_t1 = pl.reshape(position_ids, [T, 1])
for topk_block in pl.spmd(T // CSA_TOPK_TOKEN_TILE, name_hint="csa_sparse_idx_tile"):
topk_t0 = topk_block * CSA_TOPK_TOKEN_TILE
for topk_dt in pl.range(CSA_TOPK_TOKEN_TILE):
t_idx = topk_t0 + topk_dt
if t_idx < T:
topk_b = t_idx // S
topk_s = t_idx - topk_b * S
topk_abs_pos = pl.read(position_ids, [t_idx])

for topk_k in pl.range(WIN):
if topk_k <= topk_abs_pos:
topk_out = topk_k
for topk_os in pl.range(S):
if topk_os <= topk_s:
topk_overlay_t = topk_b * S + topk_os
topk_overlay_pos = pl.read(position_ids, [topk_overlay_t])
if topk_k == topk_overlay_pos % WIN:
topk_out = WIN + topk_os
pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(topk_out, pl.INT32))
else:
pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(-1, pl.INT32))
pl.write(valid_block_mask, [t_idx, 0], pl.cast(1, pl.INT32))

# Compressed slots [WIN, WIN + IDX_TOPK): vectorized masked copy. Keep raw
# iff WIN + S <= raw < WIN + S + floor((pos + 1) / 4) (== the original
# min(..., kvlen//4); pos + 1 <= kvlen holds), as out = mask * (raw + 1) - 1.
c_raw = pl.cast(idx_topk_flat[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, 0 : IDX_TOPK], target_type=pl.FP32)
c_pos = pl.cast(position_ids_t1[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, 0 : 1], target_type=pl.FP32)
# Overlay build, one token per SPMD core.
for t_idx in pl.spmd(T, name_hint="csa_sparse_idx_tile"):
topk_b = t_idx // S
topk_s = t_idx - topk_b * S
topk_abs_pos = pl.read(position_ids, [t_idx])
for topk_k in pl.range(WIN):
if topk_k <= topk_abs_pos:
topk_out = topk_k
for topk_os in pl.range(S):
if topk_os <= topk_s:
topk_overlay_t = topk_b * S + topk_os
topk_overlay_pos = pl.read(position_ids, [topk_overlay_t])
if topk_k == topk_overlay_pos % WIN:
topk_out = WIN + topk_os
pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(topk_out, pl.INT32))
else:
pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(-1, pl.INT32))

# Compressed slots [WIN, WIN + IDX_TOPK): vectorized masked copy over all T rows,
# keeping raw iff WIN + S <= raw < WIN + S + floor((pos + 1) / 4), as
# out = mask * (raw + 1) - 1.
with pl.at(level=pl.Level.CORE_GROUP, name_hint="csa_compressed_slots", allow_early_resolve=True):
c_raw = pl.cast(idx_topk_flat[0 : T, 0 : IDX_TOPK], target_type=pl.FP32)
c_pos = pl.cast(position_ids_t1[0 : T, 0 : 1], target_type=pl.FP32)
c_pos_q = pl.cast(pl.cast(pl.mul(pl.add(c_pos, 1.0), COMPRESS_RATIO_INV), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
c_upper = pl.add(c_pos_q, CSA_CMP_WIN_S_F)
# row_expand to broadcast the per-token bound over IDX_TOPK cols (pl.sub
# does not broadcast a [ROW_TILE, 1] column).
c_upper_b = pl.row_expand_mul(pl.full([CSA_TOPK_TOKEN_TILE, IDX_TOPK], dtype=pl.FP32, value=1.0), c_upper)
# Broadcast the per-token bound over IDX_TOPK cols.
c_upper_b = pl.row_expand_mul(pl.full([T, IDX_TOPK], dtype=pl.FP32, value=1.0), c_upper)
c_ge = pl.minimum(pl.maximum(pl.add(c_raw, CSA_CMP_GE_BIAS), 0.0), 1.0)
c_lt = pl.minimum(pl.maximum(pl.sub(c_upper_b, c_raw), 0.0), 1.0)
c_mask = pl.mul(c_ge, c_lt)
c_out = pl.sub(pl.mul(c_mask, pl.add(c_raw, 1.0)), 1.0)
cmp_sparse_indices[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, WIN : WIN + IDX_TOPK] = pl.cast(c_out, target_type=pl.INT32)
cmp_sparse_indices[0 : T, WIN : WIN + IDX_TOPK] = pl.cast(c_out, target_type=pl.INT32)
# Block 0 (sliding-window / overlay) is always live; write all of
# valid_block_mask from this single scope.
for c_t0 in pl.range(T):
pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32))
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
for c_dt in pl.range(CSA_TOPK_TOKEN_TILE):
c_t = topk_t0 + c_dt
if c_t < T:
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_t, c_sb], c_valid)
for c_dt in pl.range(T):
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_dt, c_sb], c_valid)
Comment on lines +269 to +276

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since valid_block_mask and c_blk_valid are tensors, we can replace both per-row scalar read/write loops over T (which is a small static constant) with vectorized slice assignments. This is more idiomatic in PyPTO and avoids the loop overhead.

Suggested change
for c_t0 in pl.range(T):
pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32))
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
for c_dt in pl.range(CSA_TOPK_TOKEN_TILE):
c_t = topk_t0 + c_dt
if c_t < T:
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_t, c_sb], c_valid)
for c_dt in pl.range(T):
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_dt, c_sb], c_valid)
valid_block_mask[0 : T, 0 : 1] = pl.full([T, 1], dtype=pl.INT32, value=1)
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
valid_block_mask[0 : T, c_sb : c_sb + 1] = pl.cast(c_blk_valid, target_type=pl.INT32)


attn_out = pl.create_tensor([T, D], dtype=pl.BF16)
sparse_attn(
Expand Down
4 changes: 4 additions & 0 deletions models/deepseek/v4/decode_attention_hca.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -716,13 +716,17 @@
help="Fixture-only compatibility seed for position_ids and slot mappings; "
"otherwise use the default per-batch coverage pattern.")
parser.add_argument("--enable-l2-swimlane", type=int, nargs="?", const=1, default=0, choices=(0, 1, 2))
parser.add_argument("--runtime-dir", type=str, default=None)
parser.add_argument("--golden-data", type=str, default=None)
parser.add_argument("--dump-passes", action="store_true", default=False)
args = parser.parse_args()

result = run_jit(
fn=attention_hca_test,
specs=build_tensor_specs(args.start_pos),
golden_fn=golden_attention_hca,
runtime_dir=args.runtime_dir,
golden_data=args.golden_data,
compile_cfg=dict(dump_passes=args.dump_passes),
runtime_cfg=dict(
platform=args.platform,
Expand Down
34 changes: 15 additions & 19 deletions models/deepseek/v4/decode_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@
WEIGHTS_ROW_TILE = 8
QH_QUANT_TILE = 64
QH_HEAD_DIM_TILE = 64
ROPE_ROW_BLOCK = S * IDX_N_HEADS # 128 rows = one batch
ROPE_ROW_BLOCK = S * IDX_N_HEADS
ROPE_ROW_TILE = 32
ROPE_SPMD_TILE = 32

@pl.jit.inline
def indexer(
Expand Down Expand Up @@ -118,20 +119,13 @@ def indexer(

qr_proj_flat = pl.reshape(qr_proj, [T * IDX_N_HEADS, IDX_HEAD_DIM])
qr_rope_out = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM], dtype=pl.BF16)
# One task per batch: a batch owns S*IDX_N_HEADS contiguous rows and a single
# cos/sin row, so the rotation indices/sign and cos_il/sin_il are built ONCE per
# task and reused across the inner 32-row sub-blocks. This coarsens the loop from
# T*IDX_N_HEADS//32 (256) down to B (64) tasks while keeping the per-iter tile at
# 32 rows (same Vec UB footprint), amortizing the in-kernel index build 4x.
for idx in pl.spmd(T * IDX_N_HEADS // ROPE_ROW_BLOCK, name_hint="qr_rope"):
o0 = idx * ROPE_ROW_BLOCK
batch_idx = idx
# A3 interleaved swap-gather (same form as q_head_rms_nope_rope / kv_rms_norm_rope),
# replacing de-interleave gather + rotate + re-interleave scatter. The rotation
# indices/sign and the interleave-duplicated cos/sin are built ENTIRELY IN-KERNEL:
# swap_idx (j^1), sign ([-1,+1,...]) and dup_idx (j>>1) from pl.arange, and
# cos_il/sin_il dup-gathered from the per-batch cos/sin row (broadcast to 32 rows).
# out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j]
# spmd over ROPE_SPMD_TILE-row blocks; batch_idx = block base // ROPE_ROW_BLOCK
# picks the per-batch cos/sin row. Rotation indices/sign and cos_il/sin_il are
# built once per block, reused across the inner 32-row tiles.
# out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j] (sign folded into sin_il_signed)
for idx in pl.spmd(T * IDX_N_HEADS // ROPE_SPMD_TILE, name_hint="qr_rope"):
o0 = idx * ROPE_SPMD_TILE
batch_idx = o0 // ROPE_ROW_BLOCK
cos_b = cos[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2]
sin_b = sin[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2]
rope_ones = pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=1.0)
Expand All @@ -144,12 +138,13 @@ def indexer(
cos_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), cos_b)
sin_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), sin_b)
cos_il = pl.gather(cos_b32, dim=-1, index=rope_dup_idx)
sin_il = pl.gather(sin_b32, dim=-1, index=rope_dup_idx)
for ro in pl.range(0, ROPE_ROW_BLOCK, ROPE_ROW_TILE):
# fold sign into sin_il
sin_il_signed = pl.mul(pl.gather(sin_b32, dim=-1, index=rope_dup_idx), rope_sign)
for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE):
r0 = o0 + ro
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il))
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")
Comment on lines +143 to 148

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since ROPE_SPMD_TILE and ROPE_ROW_TILE are both hardcoded to 32, the loop for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE) always executes exactly one iteration (ro = 0). We can simplify the code and eliminate the loop overhead by removing the loop entirely. However, because this relies on single-tile coverage, the invariant must be made explicit with an assertion (e.g., assert ROPE_SPMD_TILE == ROPE_ROW_TILE) to prevent silent correctness issues if the configuration changes in the future.

Suggested change
for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE):
r0 = o0 + ro
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il))
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")
assert ROPE_SPMD_TILE == ROPE_ROW_TILE, "ROPE_SPMD_TILE must match ROPE_ROW_TILE for single-tile coverage"
r0 = o0
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")
References
  1. When a kernel assumes single-tile coverage (e.g., T_PAD <= MM_ROW_TILE), make this invariant explicit with a module-level assertion (e.g., assert T_PAD == MM_ROW_TILE) to prevent silent correctness issues if configurations change.


qr_hadamard_i8 = pl.create_tensor([T * IDX_N_HEADS, IDX_HEAD_DIM], dtype=pl.INT8)
Expand Down Expand Up @@ -643,7 +638,8 @@ def init_idx_slot_mapping():
parser.add_argument("-p", "--platform", type=str, default="a2a3",
choices=["a2a3", "a2a3sim", "a5", "a5sim"])
parser.add_argument("-d", "--device", type=int, default=0)
parser.add_argument("--enable-l2-swimlane", action="store_true", default=False)
parser.add_argument("--enable-l2-swimlane", type=int, default=0, choices=[0, 1, 2],
help="L2 swimlane level: 0=off, 1=AICore timing, 2=+AICPU timing.")
parser.add_argument("--runtime-dir", type=str, default=None)
parser.add_argument("--start-pos", type=int, default=None,
help="Fixture-only compatibility seed for position_ids and slot mappings; "
Expand Down
2 changes: 1 addition & 1 deletion models/deepseek/v4/decode_layer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -833,7 +833,7 @@
atol=1e-3,
compare_fn={
# Real-weight x_next over-thd fractions (frac>5e-3 / frac>1e-2):
# swa(L0) 0.4% / 0.003%, hca(L9) 1.9% / 1.0%, csa(L8) 3.8% / 0.6%.
# swa(L0) 0.7% / 0.006%, hca(L9) 0.5% / 0.003%, csa(L8) 3.8% / 0.4%.
"x_next": ratio_reldiff(diff_thd=0.01, pct_thd=0.05),
"kv_cache": ratio_allclose(atol=1e-4, rtol=1.0 / 128),
},
Expand Down
2 changes: 1 addition & 1 deletion models/deepseek/v4/decode_sparse_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -200,7 +200,7 @@

# Additive softmax bias (0 valid / NEG_INF invalid) that qk_pv adds onto the
# scaled scores, so invalid lanes exp to ~0 with no per-block mask multiply.
for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid"):
for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid", allow_early_resolve=True):
v_t0 = v_blk * VALID_TOKEN_TILE
v_idx_f = pl.cast(cmp_sparse_indices[v_t0 : v_t0 + VALID_TOKEN_TILE, 0 : TOPK], target_type=pl.FP32)
# Index contract (line 138): raw == -1 invalid, raw >= 0 valid. min(idx, 0)
Expand Down
2 changes: 1 addition & 1 deletion models/deepseek/v4/decode_sparse_attn_hca.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -201,7 +201,7 @@

# Additive softmax bias (0 valid / NEG_INF invalid) that qk_pv adds onto the
# scaled scores, so invalid lanes exp to ~0 with no per-block mask multiply.
for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid"):
for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid", allow_early_resolve=True):
v_t0 = v_blk * VALID_TOKEN_TILE
# Read the full PADDED_TOPK (256, 32-byte aligned); cmp_sparse_indices pads its
# tail with -1, so the padded lanes get a NEG_INF bias for free.
Expand Down
Loading
Loading