-
Notifications
You must be signed in to change notification settings - Fork 45
perf(dsv4 decode): retile csa/indexer/expert_shared/hc_post/moe/qkv #672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,8 +63,9 @@ | |||||||||||||||||||||||||||
| WEIGHTS_ROW_TILE = 8 | ||||||||||||||||||||||||||||
| QH_QUANT_TILE = 64 | ||||||||||||||||||||||||||||
| QH_HEAD_DIM_TILE = 64 | ||||||||||||||||||||||||||||
| ROPE_ROW_BLOCK = S * IDX_N_HEADS # 128 rows = one batch | ||||||||||||||||||||||||||||
| ROPE_ROW_BLOCK = S * IDX_N_HEADS | ||||||||||||||||||||||||||||
| ROPE_ROW_TILE = 32 | ||||||||||||||||||||||||||||
| ROPE_SPMD_TILE = 32 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pl.jit.inline | ||||||||||||||||||||||||||||
| def indexer( | ||||||||||||||||||||||||||||
|
|
@@ -118,20 +119,13 @@ def indexer( | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| qr_proj_flat = pl.reshape(qr_proj, [T * IDX_N_HEADS, IDX_HEAD_DIM]) | ||||||||||||||||||||||||||||
| qr_rope_out = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM], dtype=pl.BF16) | ||||||||||||||||||||||||||||
| # One task per batch: a batch owns S*IDX_N_HEADS contiguous rows and a single | ||||||||||||||||||||||||||||
| # cos/sin row, so the rotation indices/sign and cos_il/sin_il are built ONCE per | ||||||||||||||||||||||||||||
| # task and reused across the inner 32-row sub-blocks. This coarsens the loop from | ||||||||||||||||||||||||||||
| # T*IDX_N_HEADS//32 (256) down to B (64) tasks while keeping the per-iter tile at | ||||||||||||||||||||||||||||
| # 32 rows (same Vec UB footprint), amortizing the in-kernel index build 4x. | ||||||||||||||||||||||||||||
| for idx in pl.spmd(T * IDX_N_HEADS // ROPE_ROW_BLOCK, name_hint="qr_rope"): | ||||||||||||||||||||||||||||
| o0 = idx * ROPE_ROW_BLOCK | ||||||||||||||||||||||||||||
| batch_idx = idx | ||||||||||||||||||||||||||||
| # A3 interleaved swap-gather (same form as q_head_rms_nope_rope / kv_rms_norm_rope), | ||||||||||||||||||||||||||||
| # replacing de-interleave gather + rotate + re-interleave scatter. The rotation | ||||||||||||||||||||||||||||
| # indices/sign and the interleave-duplicated cos/sin are built ENTIRELY IN-KERNEL: | ||||||||||||||||||||||||||||
| # swap_idx (j^1), sign ([-1,+1,...]) and dup_idx (j>>1) from pl.arange, and | ||||||||||||||||||||||||||||
| # cos_il/sin_il dup-gathered from the per-batch cos/sin row (broadcast to 32 rows). | ||||||||||||||||||||||||||||
| # out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j] | ||||||||||||||||||||||||||||
| # spmd over ROPE_SPMD_TILE-row blocks; batch_idx = block base // ROPE_ROW_BLOCK | ||||||||||||||||||||||||||||
| # picks the per-batch cos/sin row. Rotation indices/sign and cos_il/sin_il are | ||||||||||||||||||||||||||||
| # built once per block, reused across the inner 32-row tiles. | ||||||||||||||||||||||||||||
| # out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j] (sign folded into sin_il_signed) | ||||||||||||||||||||||||||||
| for idx in pl.spmd(T * IDX_N_HEADS // ROPE_SPMD_TILE, name_hint="qr_rope"): | ||||||||||||||||||||||||||||
| o0 = idx * ROPE_SPMD_TILE | ||||||||||||||||||||||||||||
| batch_idx = o0 // ROPE_ROW_BLOCK | ||||||||||||||||||||||||||||
| cos_b = cos[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2] | ||||||||||||||||||||||||||||
| sin_b = sin[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2] | ||||||||||||||||||||||||||||
| rope_ones = pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=1.0) | ||||||||||||||||||||||||||||
|
|
@@ -144,12 +138,13 @@ def indexer( | |||||||||||||||||||||||||||
| cos_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), cos_b) | ||||||||||||||||||||||||||||
| sin_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), sin_b) | ||||||||||||||||||||||||||||
| cos_il = pl.gather(cos_b32, dim=-1, index=rope_dup_idx) | ||||||||||||||||||||||||||||
| sin_il = pl.gather(sin_b32, dim=-1, index=rope_dup_idx) | ||||||||||||||||||||||||||||
| for ro in pl.range(0, ROPE_ROW_BLOCK, ROPE_ROW_TILE): | ||||||||||||||||||||||||||||
| # fold sign into sin_il | ||||||||||||||||||||||||||||
| sin_il_signed = pl.mul(pl.gather(sin_b32, dim=-1, index=rope_dup_idx), rope_sign) | ||||||||||||||||||||||||||||
| for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE): | ||||||||||||||||||||||||||||
| r0 = o0 + ro | ||||||||||||||||||||||||||||
| qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM] | ||||||||||||||||||||||||||||
| qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx) | ||||||||||||||||||||||||||||
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il)) | ||||||||||||||||||||||||||||
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed)) | ||||||||||||||||||||||||||||
| qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint") | ||||||||||||||||||||||||||||
|
Comment on lines
+143
to
148
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since
Suggested change
References
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| qr_hadamard_i8 = pl.create_tensor([T * IDX_N_HEADS, IDX_HEAD_DIM], dtype=pl.INT8) | ||||||||||||||||||||||||||||
|
|
@@ -643,7 +638,8 @@ def init_idx_slot_mapping(): | |||||||||||||||||||||||||||
| parser.add_argument("-p", "--platform", type=str, default="a2a3", | ||||||||||||||||||||||||||||
| choices=["a2a3", "a2a3sim", "a5", "a5sim"]) | ||||||||||||||||||||||||||||
| parser.add_argument("-d", "--device", type=int, default=0) | ||||||||||||||||||||||||||||
| parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) | ||||||||||||||||||||||||||||
| parser.add_argument("--enable-l2-swimlane", type=int, default=0, choices=[0, 1, 2], | ||||||||||||||||||||||||||||
| help="L2 swimlane level: 0=off, 1=AICore timing, 2=+AICPU timing.") | ||||||||||||||||||||||||||||
| parser.add_argument("--runtime-dir", type=str, default=None) | ||||||||||||||||||||||||||||
| parser.add_argument("--start-pos", type=int, default=None, | ||||||||||||||||||||||||||||
| help="Fixture-only compatibility seed for position_ids and slot mappings; " | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
valid_block_maskandc_blk_validare tensors, we can avoid the loop overT(which is a small static constant) by using vectorized slice assignment. This is more idiomatic in PyPTO and avoids loop overhead.