diff --git a/models/deepseek/v4/decode_attention_csa.py b/models/deepseek/v4/decode_attention_csa.py index 2f0eb94b..e828a0df 100644 --- a/models/deepseek/v4/decode_attention_csa.py +++ b/models/deepseek/v4/decode_attention_csa.py @@ -98,7 +98,6 @@ CMP_BLOCK_NUM = DECODE_CMP_BLOCK_NUM # tiling -CSA_TOPK_TOKEN_TILE = 8 CSA_WB_TOKEN_TILE = 8 WRITEBACK_GUARD_TILE = 16 CSA_CMP_GE_BIAS = float(1 - (WIN + S)) # raw - (WIN + S) + 1, folded for the ge clamp @@ -237,52 +236,49 @@ def attention_csa( valid_block_mask = pl.create_tensor([T, SPARSE_BLOCKS], dtype=pl.INT32) idx_topk_flat = pl.reshape(idx_topk_full, [T, INDEXER_SCORE_LEN]) position_ids_t1 = pl.reshape(position_ids, [T, 1]) - for topk_block in pl.spmd(T // CSA_TOPK_TOKEN_TILE, name_hint="csa_sparse_idx_tile"): - topk_t0 = topk_block * CSA_TOPK_TOKEN_TILE - for topk_dt in pl.range(CSA_TOPK_TOKEN_TILE): - t_idx = topk_t0 + topk_dt - if t_idx < T: - topk_b = t_idx // S - topk_s = t_idx - topk_b * S - topk_abs_pos = pl.read(position_ids, [t_idx]) - - for topk_k in pl.range(WIN): - if topk_k <= topk_abs_pos: - topk_out = topk_k - for topk_os in pl.range(S): - if topk_os <= topk_s: - topk_overlay_t = topk_b * S + topk_os - topk_overlay_pos = pl.read(position_ids, [topk_overlay_t]) - if topk_k == topk_overlay_pos % WIN: - topk_out = WIN + topk_os - pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(topk_out, pl.INT32)) - else: - pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(-1, pl.INT32)) - pl.write(valid_block_mask, [t_idx, 0], pl.cast(1, pl.INT32)) - - # Compressed slots [WIN, WIN + IDX_TOPK): vectorized masked copy. Keep raw - # iff WIN + S <= raw < WIN + S + floor((pos + 1) / 4) (== the original - # min(..., kvlen//4); pos + 1 <= kvlen holds), as out = mask * (raw + 1) - 1. - c_raw = pl.cast(idx_topk_flat[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, 0 : IDX_TOPK], target_type=pl.FP32) - c_pos = pl.cast(position_ids_t1[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, 0 : 1], target_type=pl.FP32) + # Overlay build, one token per SPMD core. + for t_idx in pl.spmd(T, name_hint="csa_sparse_idx_tile"): + topk_b = t_idx // S + topk_s = t_idx - topk_b * S + topk_abs_pos = pl.read(position_ids, [t_idx]) + for topk_k in pl.range(WIN): + if topk_k <= topk_abs_pos: + topk_out = topk_k + for topk_os in pl.range(S): + if topk_os <= topk_s: + topk_overlay_t = topk_b * S + topk_os + topk_overlay_pos = pl.read(position_ids, [topk_overlay_t]) + if topk_k == topk_overlay_pos % WIN: + topk_out = WIN + topk_os + pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(topk_out, pl.INT32)) + else: + pl.write(cmp_sparse_indices, [t_idx, topk_k], pl.cast(-1, pl.INT32)) + + # Compressed slots [WIN, WIN + IDX_TOPK): vectorized masked copy over all T rows, + # keeping raw iff WIN + S <= raw < WIN + S + floor((pos + 1) / 4), as + # out = mask * (raw + 1) - 1. + with pl.at(level=pl.Level.CORE_GROUP, name_hint="csa_compressed_slots", allow_early_resolve=True): + c_raw = pl.cast(idx_topk_flat[0 : T, 0 : IDX_TOPK], target_type=pl.FP32) + c_pos = pl.cast(position_ids_t1[0 : T, 0 : 1], target_type=pl.FP32) c_pos_q = pl.cast(pl.cast(pl.mul(pl.add(c_pos, 1.0), COMPRESS_RATIO_INV), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) c_upper = pl.add(c_pos_q, CSA_CMP_WIN_S_F) - # row_expand to broadcast the per-token bound over IDX_TOPK cols (pl.sub - # does not broadcast a [ROW_TILE, 1] column). - c_upper_b = pl.row_expand_mul(pl.full([CSA_TOPK_TOKEN_TILE, IDX_TOPK], dtype=pl.FP32, value=1.0), c_upper) + # Broadcast the per-token bound over IDX_TOPK cols. + c_upper_b = pl.row_expand_mul(pl.full([T, IDX_TOPK], dtype=pl.FP32, value=1.0), c_upper) c_ge = pl.minimum(pl.maximum(pl.add(c_raw, CSA_CMP_GE_BIAS), 0.0), 1.0) c_lt = pl.minimum(pl.maximum(pl.sub(c_upper_b, c_raw), 0.0), 1.0) c_mask = pl.mul(c_ge, c_lt) c_out = pl.sub(pl.mul(c_mask, pl.add(c_raw, 1.0)), 1.0) - cmp_sparse_indices[topk_t0 : topk_t0 + CSA_TOPK_TOKEN_TILE, WIN : WIN + IDX_TOPK] = pl.cast(c_out, target_type=pl.INT32) + cmp_sparse_indices[0 : T, WIN : WIN + IDX_TOPK] = pl.cast(c_out, target_type=pl.INT32) + # Block 0 (sliding-window / overlay) is always live; write all of + # valid_block_mask from this single scope. + for c_t0 in pl.range(T): + pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32)) for c_sb in pl.range(1, SPARSE_BLOCKS): c_s0 = (c_sb - 1) * ATTN_K_TILE c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE]) - for c_dt in pl.range(CSA_TOPK_TOKEN_TILE): - c_t = topk_t0 + c_dt - if c_t < T: - c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) - pl.write(valid_block_mask, [c_t, c_sb], c_valid) + for c_dt in pl.range(T): + c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) + pl.write(valid_block_mask, [c_dt, c_sb], c_valid) attn_out = pl.create_tensor([T, D], dtype=pl.BF16) sparse_attn( diff --git a/models/deepseek/v4/decode_attention_hca.py b/models/deepseek/v4/decode_attention_hca.py index 55c1d1fd..934f0cd6 100644 --- a/models/deepseek/v4/decode_attention_hca.py +++ b/models/deepseek/v4/decode_attention_hca.py @@ -720,6 +720,8 @@ def init_wo_b(): help="Fixture-only compatibility seed for position_ids and slot mappings; " "otherwise use the default per-batch coverage pattern.") parser.add_argument("--enable-l2-swimlane", type=int, nargs="?", const=1, default=0, choices=(0, 1, 2)) + parser.add_argument("--runtime-dir", type=str, default=None) + parser.add_argument("--golden-data", type=str, default=None) parser.add_argument("--dump-passes", action="store_true", default=False) args = parser.parse_args() @@ -727,6 +729,8 @@ def init_wo_b(): fn=attention_hca_test, specs=build_tensor_specs(args.start_pos), golden_fn=golden_attention_hca, + runtime_dir=args.runtime_dir, + golden_data=args.golden_data, compile_cfg=dict(dump_passes=args.dump_passes), runtime_cfg=dict( platform=args.platform, diff --git a/models/deepseek/v4/decode_indexer.py b/models/deepseek/v4/decode_indexer.py index b60608f5..05e6f2b9 100644 --- a/models/deepseek/v4/decode_indexer.py +++ b/models/deepseek/v4/decode_indexer.py @@ -73,8 +73,9 @@ WEIGHTS_ROW_TILE = 8 QH_QUANT_TILE = 64 QH_HEAD_DIM_TILE = 64 -ROPE_ROW_BLOCK = S * IDX_N_HEADS # 128 rows = one batch +ROPE_ROW_BLOCK = S * IDX_N_HEADS ROPE_ROW_TILE = 32 +ROPE_SPMD_TILE = 32 @pl.jit.inline def indexer( @@ -128,20 +129,13 @@ def indexer( qr_proj_flat = pl.reshape(qr_proj, [T * IDX_N_HEADS, IDX_HEAD_DIM]) qr_rope_out = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM], dtype=pl.BF16) - # One task per batch: a batch owns S*IDX_N_HEADS contiguous rows and a single - # cos/sin row, so the rotation indices/sign and cos_il/sin_il are built ONCE per - # task and reused across the inner 32-row sub-blocks. This coarsens the loop from - # T*IDX_N_HEADS//32 (256) down to B (64) tasks while keeping the per-iter tile at - # 32 rows (same Vec UB footprint), amortizing the in-kernel index build 4x. - for idx in pl.spmd(T * IDX_N_HEADS // ROPE_ROW_BLOCK, name_hint="qr_rope"): - o0 = idx * ROPE_ROW_BLOCK - batch_idx = idx - # A3 interleaved swap-gather (same form as q_head_rms_nope_rope / kv_rms_norm_rope), - # replacing de-interleave gather + rotate + re-interleave scatter. The rotation - # indices/sign and the interleave-duplicated cos/sin are built ENTIRELY IN-KERNEL: - # swap_idx (j^1), sign ([-1,+1,...]) and dup_idx (j>>1) from pl.arange, and - # cos_il/sin_il dup-gathered from the per-batch cos/sin row (broadcast to 32 rows). - # out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j] + # spmd over ROPE_SPMD_TILE-row blocks; batch_idx = block base // ROPE_ROW_BLOCK + # picks the per-batch cos/sin row. Rotation indices/sign and cos_il/sin_il are + # built once per block, reused across the inner 32-row tiles. + # out[j] = x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j] (sign folded into sin_il_signed) + for idx in pl.spmd(T * IDX_N_HEADS // ROPE_SPMD_TILE, name_hint="qr_rope"): + o0 = idx * ROPE_SPMD_TILE + batch_idx = o0 // ROPE_ROW_BLOCK cos_b = cos[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2] sin_b = sin[batch_idx : batch_idx + 1, 0 : ROPE_HEAD_DIM // 2] rope_ones = pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM], dtype=pl.FP32, value=1.0) @@ -154,12 +148,13 @@ def indexer( cos_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), cos_b) sin_b32 = pl.col_expand_mul(pl.full([ROPE_ROW_TILE, ROPE_HEAD_DIM // 2], dtype=pl.FP32, value=1.0), sin_b) cos_il = pl.gather(cos_b32, dim=-1, index=rope_dup_idx) - sin_il = pl.gather(sin_b32, dim=-1, index=rope_dup_idx) - for ro in pl.range(0, ROPE_ROW_BLOCK, ROPE_ROW_TILE): + # fold sign into sin_il + sin_il_signed = pl.mul(pl.gather(sin_b32, dim=-1, index=rope_dup_idx), rope_sign) + for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE): r0 = o0 + ro qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM] qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx) - rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il)) + rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed)) qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint") qr_hadamard_i8 = pl.create_tensor([T * IDX_N_HEADS, IDX_HEAD_DIM], dtype=pl.INT8) @@ -653,7 +648,8 @@ def init_idx_slot_mapping(): parser.add_argument("-p", "--platform", type=str, default="a2a3", choices=["a2a3", "a2a3sim", "a5", "a5sim"]) parser.add_argument("-d", "--device", type=int, default=0) - parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) + parser.add_argument("--enable-l2-swimlane", type=int, default=0, choices=[0, 1, 2], + help="L2 swimlane level: 0=off, 1=AICore timing, 2=+AICPU timing.") parser.add_argument("--runtime-dir", type=str, default=None) parser.add_argument("--start-pos", type=int, default=None, help="Fixture-only compatibility seed for position_ids and slot mappings; " diff --git a/models/deepseek/v4/decode_layer.py b/models/deepseek/v4/decode_layer.py index f35e530e..0982d592 100644 --- a/models/deepseek/v4/decode_layer.py +++ b/models/deepseek/v4/decode_layer.py @@ -825,7 +825,7 @@ def init_input_ids(): atol=1e-3, compare_fn={ # Real-weight x_next over-thd fractions (frac>5e-3 / frac>1e-2): - # swa(L0) 0.4% / 0.003%, hca(L9) 1.9% / 1.0%, csa(L8) 3.8% / 0.6%. + # swa(L0) 0.7% / 0.006%, hca(L9) 0.5% / 0.003%, csa(L8) 3.8% / 0.4%. "x_next": ratio_reldiff(diff_thd=0.01, pct_thd=0.05), "kv_cache": ratio_allclose(atol=1e-4, rtol=1.0 / 128), }, diff --git a/models/deepseek/v4/decode_sparse_attn.py b/models/deepseek/v4/decode_sparse_attn.py index 674ab415..ce7aa1b2 100644 --- a/models/deepseek/v4/decode_sparse_attn.py +++ b/models/deepseek/v4/decode_sparse_attn.py @@ -211,7 +211,7 @@ def sparse_attn( # Additive softmax bias (0 valid / NEG_INF invalid) that qk_pv adds onto the # scaled scores, so invalid lanes exp to ~0 with no per-block mask multiply. - for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid"): + for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid", allow_early_resolve=True): v_t0 = v_blk * VALID_TOKEN_TILE v_idx_f = pl.cast(cmp_sparse_indices[v_t0 : v_t0 + VALID_TOKEN_TILE, 0 : TOPK], target_type=pl.FP32) # Index contract (line 138): raw == -1 invalid, raw >= 0 valid. min(idx, 0) diff --git a/models/deepseek/v4/decode_sparse_attn_hca.py b/models/deepseek/v4/decode_sparse_attn_hca.py index beb49921..b578a5b7 100644 --- a/models/deepseek/v4/decode_sparse_attn_hca.py +++ b/models/deepseek/v4/decode_sparse_attn_hca.py @@ -212,7 +212,7 @@ def sparse_attn_hca( # Additive softmax bias (0 valid / NEG_INF invalid) that qk_pv adds onto the # scaled scores, so invalid lanes exp to ~0 with no per-block mask multiply. - for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid"): + for v_blk in pl.spmd(T // VALID_TOKEN_TILE, name_hint="build_valid", allow_early_resolve=True): v_t0 = v_blk * VALID_TOKEN_TILE # Read the full PADDED_TOPK (256, 32-byte aligned); cmp_sparse_indices pads its # tail with -1, so the padded lanes get a NEG_INF bias for free. diff --git a/models/deepseek/v4/expert_shared.py b/models/deepseek/v4/expert_shared.py index ab4e44b3..0ac0bc60 100644 --- a/models/deepseek/v4/expert_shared.py +++ b/models/deepseek/v4/expert_shared.py @@ -31,7 +31,6 @@ SWIGLU_LIMIT = M.swiglu_limit # tiling -T_TILE = 8 SH_M_TILE = 16 T_PAD = ((T + SH_M_TILE - 1) // SH_M_TILE) * SH_M_TILE # Decode (T <= SH_M_TILE, single partial block) or prefill (T a multiple of @@ -40,11 +39,19 @@ assert T <= SH_M_TILE or T % SH_M_TILE == 0, \ "expert_shared needs T <= SH_M_TILE (decode) or T a multiple of SH_M_TILE (prefill)" SH_VALID_M = T if T < SH_M_TILE else SH_M_TILE +N_MTILES = T_PAD // SH_M_TILE + K_TILE = 512 INTER_K = 512 -SH_INTER_TILE = 64 -SH_D_OUT_TILE = 64 +MM_INTER_TILE = 256 +MM_GATE_INNER = 4 +ACT_INTER_TILE = 128 +ACT_GATE_INNER = 4 +D_OUT_TILE = 256 QUANT_TILE = 256 +D_OUT_TILE_ACT = 512 +W2_INNER = 4 +W2_ACT_INNER = 8 @pl.jit.inline @@ -59,85 +66,115 @@ def expert_shared( shared_w2_scale: pl.Tensor[[D], pl.FP32], sh: pl.Tensor[[T, D], pl.BF16], ): - sh_tile_fp32 = pl.create_tensor([T_PAD, MOE_INTER], dtype=pl.FP32) - sh_tile_i8 = pl.create_tensor([T_PAD, MOE_INTER], dtype=pl.INT8) - sh_tile_scale_dq = pl.create_tensor([T_PAD, 1], dtype=pl.FP32) sh_pad = pl.create_tensor([T_PAD, D], dtype=pl.BF16) - for gu_block in pl.spmd((T_PAD // SH_M_TILE) * (MOE_INTER // (8 * SH_INTER_TILE)), name_hint="sh_gate_up"): - gu_tb = gu_block // (MOE_INTER // (8 * SH_INTER_TILE)) - gu_nb = gu_block - gu_tb * (MOE_INTER // (8 * SH_INTER_TILE)) - ts0 = gu_tb * SH_M_TILE - n_base = gu_nb * (8 * SH_INTER_TILE) - # Static valid_shape: a dynamic one perturbs the Vec row_expand_mul. - x_local_scale_dq_tile = pl.slice(x_local_scale_dq, [SH_M_TILE, 1], [ts0, 0], valid_shape=[SH_VALID_M, 1]) - for ng in pl.range(8): - n0 = n_base + ng * SH_INTER_TILE - sh_gate_acc = pl.create_tensor([SH_M_TILE, SH_INTER_TILE], dtype=pl.INT32) - sh_up_acc = pl.create_tensor([SH_M_TILE, SH_INTER_TILE], dtype=pl.INT32) - for kb in pl.pipeline(0, D // K_TILE, stage=2): - k0 = kb * K_TILE - xs_k = pl.slice(x_local_i8, [SH_M_TILE, K_TILE], [ts0, k0], valid_shape=[SH_VALID_M, K_TILE]) - sw1_k = shared_w1[n0 : n0 + SH_INTER_TILE, k0 : k0 + K_TILE] - sw3_k = shared_w3[n0 : n0 + SH_INTER_TILE, k0 : k0 + K_TILE] - if k0 == 0: - sh_gate_acc = pl.matmul(xs_k, sw1_k, b_trans=True, out_dtype=pl.INT32) - sh_up_acc = pl.matmul(xs_k, sw3_k, b_trans=True, out_dtype=pl.INT32) - else: - sh_gate_acc = pl.matmul_acc(sh_gate_acc, xs_k, sw1_k, b_trans=True) - sh_up_acc = pl.matmul_acc(sh_up_acc, xs_k, sw3_k, b_trans=True) - - sw1_scale_chunk = pl.reshape(shared_w1_scale[n0 : n0 + SH_INTER_TILE], [1, SH_INTER_TILE]) - sw3_scale_chunk = pl.reshape(shared_w3_scale[n0 : n0 + SH_INTER_TILE], [1, SH_INTER_TILE]) - sh_gate = pl.cast(sh_gate_acc, target_type=pl.FP32, mode="none") - sh_up = pl.cast(sh_up_acc, target_type=pl.FP32, mode="none") - sh_gate = pl.col_expand_mul(pl.row_expand_mul(sh_gate, x_local_scale_dq_tile), sw1_scale_chunk) - sh_up = pl.col_expand_mul(pl.row_expand_mul(sh_up, x_local_scale_dq_tile), sw3_scale_chunk) - if SWIGLU_LIMIT > 0.0: - sh_gate = pl.minimum(sh_gate, SWIGLU_LIMIT) - sh_up = pl.maximum(pl.minimum(sh_up, SWIGLU_LIMIT), -SWIGLU_LIMIT) - sh_sigmoid = pl.recip(pl.add(pl.exp(pl.neg(sh_gate)), 1.0)) - sh_silu = pl.mul(sh_gate, sh_sigmoid) - sh_gated = pl.mul(sh_silu, sh_up) - sh_tile_fp32[ts0 : ts0 + SH_M_TILE, n0 : n0 + SH_INTER_TILE] = sh_gated - - for q_tb in pl.spmd(T // T_TILE, name_hint="sh_h_q"): - ts0 = q_tb * T_TILE - shq_amax = pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) - for k0 in pl.range(0, MOE_INTER, QUANT_TILE): - shq_a_f32 = sh_tile_fp32[ts0 : ts0 + T_TILE, k0 : k0 + QUANT_TILE] - shq_a_abs = pl.maximum(shq_a_f32, pl.neg(shq_a_f32)) - shq_a_max = pl.reshape(pl.row_max(shq_a_abs), [1, T_TILE]) - shq_amax = pl.maximum(shq_amax, shq_a_max) - shq_sq_row = pl.div(pl.full([1, T_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), shq_amax) - sh_tile_scale_dq[ts0 : ts0 + T_TILE, 0:1] = pl.reshape(pl.recip(shq_sq_row), [T_TILE, 1]) - shq_sq_col = pl.reshape(shq_sq_row, [T_TILE, 1]) - for k1 in pl.range(0, MOE_INTER, QUANT_TILE): - shq_q_f32 = sh_tile_fp32[ts0 : ts0 + T_TILE, k1 : k1 + QUANT_TILE] - shq_q_scaled = pl.row_expand_mul(shq_q_f32, shq_sq_col) - shq_q_i32 = pl.cast(shq_q_scaled, target_type=pl.INT32, mode="rint") - shq_q_half = pl.cast(shq_q_i32, target_type=pl.FP16, mode="round") - sh_tile_i8[ts0 : ts0 + T_TILE, k1 : k1 + QUANT_TILE] = pl.cast(shq_q_half, target_type=pl.INT8, mode="trunc") - for w2_block in pl.spmd((T_PAD // SH_M_TILE) * (D // (16 * SH_D_OUT_TILE)), name_hint="sh_w2"): - w2_tb = w2_block // (D // (16 * SH_D_OUT_TILE)) - w2_db = w2_block - w2_tb * (D // (16 * SH_D_OUT_TILE)) - ts0 = w2_tb * SH_M_TILE - d_base = w2_db * (16 * SH_D_OUT_TILE) - sh_tile_scale_dq_tile = sh_tile_scale_dq[ts0 : ts0 + SH_M_TILE, 0:1] - for dg in pl.range(16): - d0 = d_base + dg * SH_D_OUT_TILE - hs_init = sh_tile_i8[ts0 : ts0 + SH_M_TILE, 0 : INTER_K] - sw2_init = shared_w2[d0 : d0 + SH_D_OUT_TILE, 0 : INTER_K] - sh_y_acc = pl.matmul(hs_init, sw2_init, b_trans=True, out_dtype=pl.INT32) - for k0 in pl.range(INTER_K, MOE_INTER, INTER_K): - hs_k = sh_tile_i8[ts0 : ts0 + SH_M_TILE, k0 : k0 + INTER_K] - sw2_k = shared_w2[d0 : d0 + SH_D_OUT_TILE, k0 : k0 + INTER_K] - sh_y_acc = pl.matmul_acc(sh_y_acc, hs_k, sw2_k, b_trans=True) - - sw2_scale_chunk = pl.reshape(shared_w2_scale[d0 : d0 + SH_D_OUT_TILE], [1, SH_D_OUT_TILE]) - sh_y = pl.cast(sh_y_acc, target_type=pl.FP32, mode="none") - sh_y = pl.col_expand_mul(pl.row_expand_mul(sh_y, sh_tile_scale_dq_tile), sw2_scale_chunk) - sh_pad[ts0 : ts0 + SH_M_TILE, d0 : d0 + SH_D_OUT_TILE] = pl.cast(sh_y, target_type=pl.BF16, mode="rint") + # One M-tile of SH_M_TILE rows per iteration (decode: 1 tile, T<=16 rows valid; + # prefill: T_PAD/SH_M_TILE fully-valid tiles). + for mt in pl.parallel(N_MTILES): + ts0 = mt * SH_M_TILE + + h_tile_fp32 = pl.create_tensor([SH_M_TILE, MOE_INTER], dtype=pl.FP32) + gate_i32 = pl.create_tensor([SH_M_TILE, MOE_INTER], dtype=pl.INT32) + up_i32 = pl.create_tensor([SH_M_TILE, MOE_INTER], dtype=pl.INT32) + + # gate (w1) cube matmul -> INT32 GM accumulator. + for nb_idx in pl.spmd(MOE_INTER // (MM_GATE_INNER * MM_INTER_TILE), name_hint="sh_gate_mm"): + n_base = nb_idx * (MM_GATE_INNER * MM_INTER_TILE) + for ng in pl.range(MM_GATE_INNER): + n0 = n_base + ng * MM_INTER_TILE + gate_acc = pl.create_tensor([SH_M_TILE, MM_INTER_TILE], dtype=pl.INT32) + for k0 in pl.pipeline(0, D, K_TILE, stage=2): + xs_k = pl.slice(x_local_i8, [SH_M_TILE, K_TILE], [ts0, k0], valid_shape=[SH_VALID_M, K_TILE]) + sw1_k = shared_w1[n0 : n0 + MM_INTER_TILE, k0 : k0 + K_TILE] + if k0 == 0: + gate_acc = pl.matmul(xs_k, sw1_k, b_trans=True, out_dtype=pl.INT32) + else: + gate_acc = pl.matmul_acc(gate_acc, xs_k, sw1_k, b_trans=True) + gate_i32[:, n0 : n0 + MM_INTER_TILE] = gate_acc + + # up (w3) cube matmul -> INT32 GM accumulator. + for nb_idx in pl.spmd(MOE_INTER // (MM_GATE_INNER * MM_INTER_TILE), name_hint="sh_up_mm"): + n_base = nb_idx * (MM_GATE_INNER * MM_INTER_TILE) + for ng in pl.range(MM_GATE_INNER): + n0 = n_base + ng * MM_INTER_TILE + up_acc = pl.create_tensor([SH_M_TILE, MM_INTER_TILE], dtype=pl.INT32) + for k0 in pl.pipeline(0, D, K_TILE, stage=2): + xs_k = pl.slice(x_local_i8, [SH_M_TILE, K_TILE], [ts0, k0], valid_shape=[SH_VALID_M, K_TILE]) + sw3_k = shared_w3[n0 : n0 + MM_INTER_TILE, k0 : k0 + K_TILE] + if k0 == 0: + up_acc = pl.matmul(xs_k, sw3_k, b_trans=True, out_dtype=pl.INT32) + else: + up_acc = pl.matmul_acc(up_acc, xs_k, sw3_k, b_trans=True) + up_i32[:, n0 : n0 + MM_INTER_TILE] = up_acc + + # SwiGLU activation (dequant gate/up, clamp, silu*up) -> FP32 GM. + for nb_idx in pl.spmd(MOE_INTER // (ACT_GATE_INNER * ACT_INTER_TILE), name_hint="sh_gate_up_act"): + n_base = nb_idx * (ACT_GATE_INNER * ACT_INTER_TILE) + for ng in pl.pipeline(ACT_GATE_INNER, stage=2): + n0 = n_base + ng * ACT_INTER_TILE + gate_2d_i32 = gate_i32[:, n0 : n0 + ACT_INTER_TILE] + up_2d_i32 = up_i32[:, n0 : n0 + ACT_INTER_TILE] + # Static valid_shape: a dynamic one perturbs the Vec row_expand_mul. + x_local_scale_dq_tile = pl.slice(x_local_scale_dq, [SH_M_TILE, 1], [ts0, 0], valid_shape=[SH_VALID_M, 1]) + w1_scale_chunk = pl.reshape(shared_w1_scale[n0 : n0 + ACT_INTER_TILE], [1, ACT_INTER_TILE]) + w3_scale_chunk = pl.reshape(shared_w3_scale[n0 : n0 + ACT_INTER_TILE], [1, ACT_INTER_TILE]) + gate_2d = pl.cast(gate_2d_i32, target_type=pl.FP32, mode="none") + up_2d = pl.cast(up_2d_i32, target_type=pl.FP32, mode="none") + gate_2d = pl.col_expand_mul(pl.row_expand_mul(gate_2d, x_local_scale_dq_tile), w1_scale_chunk) + up_2d = pl.col_expand_mul(pl.row_expand_mul(up_2d, x_local_scale_dq_tile), w3_scale_chunk) + if SWIGLU_LIMIT > 0.0: + gate_2d = pl.minimum(gate_2d, SWIGLU_LIMIT) + up_2d = pl.maximum(pl.minimum(up_2d, SWIGLU_LIMIT), -SWIGLU_LIMIT) + sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_2d)), 1.0)) + silu = pl.mul(gate_2d, sigmoid) + gated = pl.mul(silu, up_2d) + h_tile_fp32[:, n0 : n0 + ACT_INTER_TILE] = gated + + # Per-row A8 requant of h_tile (amax across full MOE_INTER row). + h_tile_i8 = pl.create_tensor([SH_M_TILE, MOE_INTER], dtype=pl.INT8) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="sh_h_q"): + eh_amax = pl.full([1, SH_M_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) + for k0 in pl.pipeline(0, MOE_INTER, QUANT_TILE, stage=2): + eh_a_f32 = h_tile_fp32[:, k0 : k0 + QUANT_TILE] + eh_a_abs = pl.maximum(eh_a_f32, pl.neg(eh_a_f32)) + eh_a_max = pl.reshape(pl.row_max(eh_a_abs), [1, SH_M_TILE]) + eh_amax = pl.maximum(eh_amax, eh_a_max) + eh_sq_row = pl.div(pl.full([1, SH_M_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), eh_amax) + h_tile_scale_dq = pl.reshape(pl.recip(eh_sq_row), [SH_M_TILE, 1]) + eh_sq_col = pl.reshape(eh_sq_row, [SH_M_TILE, 1]) + for k1 in pl.pipeline(0, MOE_INTER, QUANT_TILE, stage=2): + eh_q_f32 = h_tile_fp32[:, k1 : k1 + QUANT_TILE] + eh_q_scaled = pl.row_expand_mul(eh_q_f32, eh_sq_col) + eh_q_i32 = pl.cast(eh_q_scaled, target_type=pl.INT32, mode="rint") + eh_q_half = pl.cast(eh_q_i32, target_type=pl.FP16, mode="round") + h_tile_i8[:, k1 : k1 + QUANT_TILE] = pl.cast(eh_q_half, target_type=pl.INT8, mode="trunc") + + # w2 (down) cube matmul -> INT32 GM accumulator. + y_i32 = pl.create_tensor([SH_M_TILE, D], dtype=pl.INT32) + for db_idx in pl.spmd(D // (W2_INNER * D_OUT_TILE), name_hint="sh_w2_mm"): + d_base = db_idx * (W2_INNER * D_OUT_TILE) + for dg in pl.range(W2_INNER): + d0 = d_base + dg * D_OUT_TILE + y_acc = pl.create_tensor([SH_M_TILE, D_OUT_TILE], dtype=pl.INT32) + for k0 in pl.pipeline(0, MOE_INTER, INTER_K, stage=2): + hs_k = h_tile_i8[:, k0 : k0 + INTER_K] + sw2_k = shared_w2[d0 : d0 + D_OUT_TILE, k0 : k0 + INTER_K] + if k0 == 0: + y_acc = pl.matmul(hs_k, sw2_k, b_trans=True, out_dtype=pl.INT32) + else: + y_acc = pl.matmul_acc(y_acc, hs_k, sw2_k, b_trans=True) + y_i32[:, d0 : d0 + D_OUT_TILE] = y_acc + + # Dequant w2 output (per-row h scale x per-channel w2 scale) -> BF16. + for db_idx in pl.spmd(D // (W2_ACT_INNER * D_OUT_TILE_ACT), name_hint="sh_w2_act"): + d_base = db_idx * (W2_ACT_INNER * D_OUT_TILE_ACT) + for dg in pl.pipeline(W2_ACT_INNER, stage=2): + d0 = d_base + dg * D_OUT_TILE_ACT + y_2d_i32 = y_i32[:, d0 : d0 + D_OUT_TILE_ACT] + w2_scale_chunk = pl.reshape(shared_w2_scale[d0 : d0 + D_OUT_TILE_ACT], [1, D_OUT_TILE_ACT]) + y_2d = pl.cast(y_2d_i32, target_type=pl.FP32, mode="none") + y_2d = pl.col_expand_mul(pl.row_expand_mul(y_2d, h_tile_scale_dq), w2_scale_chunk) + sh_pad[ts0 : ts0 + SH_M_TILE, d0 : d0 + D_OUT_TILE_ACT] = pl.cast(y_2d, target_type=pl.BF16, mode="rint") with pl.at(level=pl.Level.CORE_GROUP, name_hint="sh_output"): for t0 in pl.range(0, T, SH_M_TILE): diff --git a/models/deepseek/v4/hc_post.py b/models/deepseek/v4/hc_post.py index ea320fb6..a2e1adff 100644 --- a/models/deepseek/v4/hc_post.py +++ b/models/deepseek/v4/hc_post.py @@ -24,7 +24,7 @@ HC_DIM = M.hc_dim # tiling -T_TILE = 8 +T_TILE = 4 assert (DECODE_BATCH * DECODE_SEQ) % T_TILE == 0 assert (PREFILL_BATCH * PREFILL_SEQ) % T_TILE == 0 @@ -137,9 +137,9 @@ def init_comb(): parser.add_argument("-p", "--platform", type=str, default="a2a3", choices=["a2a3", "a2a3sim", "a5", "a5sim"]) parser.add_argument("-d", "--device", type=int, default=0) - parser.add_argument("--mode", choices=["decode", "prefill", "all"], default="all", + parser.add_argument("--mode", choices=["decode", "prefill", "all"], default="decode", help="Use decode or prefill batch sizes, or 'all' to test both.") - parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) + parser.add_argument("--enable-l2-swimlane", type=int, nargs="?", const=1, default=0, choices=(0, 1, 2)) parser.add_argument("--compile-only", action="store_true", default=False) parser.add_argument("--dump-passes", action="store_true", default=False) args = parser.parse_args() diff --git a/models/deepseek/v4/moe.py b/models/deepseek/v4/moe.py index 903396d2..513ceb9a 100644 --- a/models/deepseek/v4/moe.py +++ b/models/deepseek/v4/moe.py @@ -269,17 +269,15 @@ def combine( active_tokens = pl.cast(0, pl.INDEX) if active_tokens > T: active_tokens = pl.cast(T, pl.INDEX) - for tb in pl.spmd(T // 4, name_hint="shared_routed"): - for tt in pl.range(4): - t = tb * 4 + tt - if t < active_tokens: - acc = pl.cast(sh[t:t + 1, :], target_type=pl.FP32) - for k in pl.range(TOPK): - r = t * TOPK + k - acc = pl.add(acc, pl.cast(routed_y_buf[r:r + 1, :], target_type=pl.FP32)) - ffn_out[t:t + 1, :] = pl.cast(acc, target_type=pl.BF16, mode="rint") - else: - ffn_out[t:t + 1, :] = sh[t:t + 1, :] + for t in pl.spmd(T, name_hint="shared_routed"): + if t < active_tokens: + acc = pl.cast(sh[t:t + 1, :], target_type=pl.FP32) + for k in pl.range(TOPK): + r = t * TOPK + k + acc = pl.add(acc, pl.cast(routed_y_buf[r:r + 1, :], target_type=pl.FP32)) + ffn_out[t:t + 1, :] = pl.cast(acc, target_type=pl.BF16, mode="rint") + else: + ffn_out[t:t + 1, :] = sh[t:t + 1, :] @pl.jit.inline(auto_scope=False) diff --git a/models/deepseek/v4/prefill_layer.py b/models/deepseek/v4/prefill_layer.py index 01b47bd0..b61c81cd 100644 --- a/models/deepseek/v4/prefill_layer.py +++ b/models/deepseek/v4/prefill_layer.py @@ -981,7 +981,7 @@ def cmp(actual, expected, **kwargs): atol=1e-3, compare_fn={ # Real-weight x_next over-thd fractions (frac>5e-3 / frac>1e-2): - # swa(L0) 0.3% / 0.0%, hca(L9) 1.7% / 0.6%, csa(L8) 5.4% / 0.7%. + # swa(L0) 0.2% / 0.001%, hca(L9) 2.0% / 1.05%, csa(L8) 4.5% / 0.7%. "x_next": valid_ratio_reldiff(args.num_tokens, diff_thd=0.01, pct_thd=0.05), "kv_cache": ratio_allclose(atol=1e-4, rtol=1.0 / 128), }, diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index 75cb91ce..cd3e6a93 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -249,7 +249,7 @@ def qkv_proj_rope( # of the rotation and is folded into the writeback. # out[j] = inv_rms * (x[j]*cos_il[j] + x[j^1]*sign[j]*sin_il[j]) q_flat = pl.reshape(q, [t_dim, H * HEAD_DIM]) - for hg_idx in pl.spmd(H // 2, name_hint="q_head_rms_nope_rope"): + for hg_idx in pl.spmd(H // 2, name_hint="q_head_rms_nope_rope", allow_early_resolve=True): hg = hg_idx * 2 # In-kernel A3 index/sign build (per task, reused across the inner tg/h loop). q_ones = pl.full([Q_ROPE_T_TILE, ROPE_DIM], dtype=pl.FP32, value=1.0) @@ -266,29 +266,27 @@ def qkv_proj_rope( for h_inner in pl.range(2): h = hg + h_inner h0 = h * HEAD_DIM - # Pass 1: per-row sum of squares over the full HEAD_DIM -> inv_rms (no gamma). - q_head_sq_sum = pl.full([1, Q_ROPE_T_TILE], dtype=pl.FP32, value=0.0) - for db in pl.pipeline(HEAD_DIM // HEAD_TILE, stage=2): - d0 = h0 + db * HEAD_TILE - q_head_chunk = q_proj_fp32[tg : tg + Q_ROPE_T_TILE, d0 : d0 + HEAD_TILE] - q_head_sq_sum = pl.add( - q_head_sq_sum, - pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, Q_ROPE_T_TILE]), - ) + # Load each head's NOPE + RoPE columns once (fp32), reused for both the + # inv_rms reduction and the writeback. + q_nope_full = q_proj_fp32[tg : tg + Q_ROPE_T_TILE, h0 : h0 + NOPE_DIM] + q_rope_chunk = q_proj_fp32[tg : tg + Q_ROPE_T_TILE, h0 + NOPE_DIM : h0 + HEAD_DIM] + + # Pass 1: per-row sum of squares over the full HEAD_DIM = NOPE part + RoPE + # part -> inv_rms (no gamma). + q_head_sq_sum = pl.add( + pl.reshape(pl.row_sum(pl.mul(q_nope_full, q_nope_full)), [1, Q_ROPE_T_TILE]), + pl.reshape(pl.row_sum(pl.mul(q_rope_chunk, q_rope_chunk)), [1, Q_ROPE_T_TILE]), + ) q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [Q_ROPE_T_TILE, 1]) # NOPE writeback: rms-normalize columns [h0:h0+NOPE_DIM) (no gamma). - for nb in pl.pipeline(NOPE_DIM // HEAD_TILE, stage=2): - n0 = nb * HEAD_TILE - q_nope_chunk = q_proj_fp32[tg : tg + Q_ROPE_T_TILE, h0 + n0 : h0 + n0 + HEAD_TILE] - q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) - q_flat[tg : tg + Q_ROPE_T_TILE, h0 + n0 : h0 + n0 + HEAD_TILE] = pl.cast( - q_normed, target_type=pl.BF16, mode="rint" - ) + q_normed = pl.row_expand_mul(q_nope_full, q_head_inv_rms_t) + q_flat[tg : tg + Q_ROPE_T_TILE, h0 : h0 + NOPE_DIM] = pl.cast( + q_normed, target_type=pl.BF16, mode="rint" + ) # RoPE writeback on columns [h0+NOPE_DIM:h0+HEAD_DIM), inv_rms folded after. - q_rope_chunk = q_proj_fp32[tg : tg + Q_ROPE_T_TILE, h0 + NOPE_DIM : h0 + NOPE_DIM + ROPE_DIM] q_rope_swapped = pl.gather(q_rope_chunk, dim=-1, index=q_swap_idx) q_rope_rot = pl.add(pl.mul(q_rope_chunk, q_cos_il), pl.mul(pl.mul(q_rope_swapped, q_sign), q_sin_il)) q_flat[tg : tg + Q_ROPE_T_TILE, h0 + NOPE_DIM : h0 + NOPE_DIM + ROPE_DIM] = pl.cast( @@ -567,7 +565,9 @@ def init_gamma_ckv(): parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--mode", choices=["decode", "prefill", "all"], default="all", help="Use decode or prefill batch sizes, or 'all' to test both.") - parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) + parser.add_argument("--enable-l2-swimlane", type=int, choices=[0, 1, 2], default=0, + help="L2 swimlane level: 0=off, 1=per-kernel AICore timing " + "(prints the per-function Task Statistics table), 2=+AICPU timing.") parser.add_argument("--runtime-dir", type=str, default=None) parser.add_argument("--golden-data", type=str, default=None) parser.add_argument("--compile-only", action="store_true", default=False)