diff --git a/dlblas/kernels/kernelswift_torch/level3/12_hc_post.py b/dlblas/kernels/kernelswift_torch/level3/12_hc_post.py index 76490b12..25c35f17 100644 --- a/dlblas/kernels/kernelswift_torch/level3/12_hc_post.py +++ b/dlblas/kernels/kernelswift_torch/level3/12_hc_post.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn + class Model(nn.Module): def __init__(self): super(Model, self).__init__() @@ -9,27 +10,41 @@ def forward( self, x: torch.Tensor, residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, ) -> torch.Tensor: - term2 = torch.einsum('abmn,abmc->abnc', comb_res_mix, residual.float()) - return (x.float().unsqueeze(-2) * post_layer_mix + term2).bfloat16() + x_f = x.float() + residual_f = residual.float() + post_f = post.float().unsqueeze(-1) + comb_f = comb.float().unsqueeze(-1) + output = post_f * x_f.unsqueeze(-2) + torch.sum( + comb_f * residual_f.unsqueeze(-2), dim=2 + ) + return output.bfloat16() + + +def generate_test_data(params): + batch_size = params['batch_size'] + seq_len = params['seq_len'] + hidden_size = params['hidden'] + hc_mult = params['hc'] + x_data = torch.randn(batch_size, seq_len, hidden_size, dtype=torch.bfloat16, device='cpu') + residual_data = torch.randn(batch_size, seq_len, hc_mult, hidden_size, dtype=torch.bfloat16, device='cpu') + post_data = torch.randn(batch_size, seq_len, hc_mult, dtype=torch.float32, device='cpu') + comb_data = torch.randn(batch_size, seq_len, hc_mult, hc_mult, dtype=torch.float32, device='cpu') + o_grad = torch.randn(batch_size, seq_len, hc_mult, hidden_size, dtype=torch.bfloat16, device='cpu') + return x_data, residual_data, post_data, comb_data, o_grad + + +def test_hc_post_fwd(): + return Model(*get_init_inputs()).forward(*get_inputs()) -n0 = 1 -n1 = 4096 -h = 1280 -mhc_mult = 4 -device = 'cuda' def get_inputs(): - x = torch.randn((n0, n1, h), dtype=torch.bfloat16, device=device) - residual = torch.randn((n0, n1, mhc_mult, h), dtype=torch.bfloat16, device=device) - post_layer_mix = torch.randn((n0, n1, mhc_mult, 1), dtype=torch.float32, device=device) - comb_res_mix = torch.randn((n0, n1, mhc_mult, mhc_mult), dtype=torch.float32, device=device) - - return [ - x, residual, post_layer_mix, comb_res_mix, - ] + params = {'batch_size': 1, 'seq_len': 4096, 'hidden': 1280, 'hc': 4} + x_data, residual_data, post_data, comb_data, o_grad = generate_test_data(params) + return [x_data, residual_data, post_data, comb_data] + def get_init_inputs(): return [] diff --git a/dlblas/kernels/kernelswift_torch/level3/25_compressor.py b/dlblas/kernels/kernelswift_torch/level3/25_compressor.py new file mode 100644 index 00000000..8dd05c72 --- /dev/null +++ b/dlblas/kernels/kernelswift_torch/level3/25_compressor.py @@ -0,0 +1,238 @@ +import math +from functools import lru_cache +import torch_npu # noqa: F401 + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Linear(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: torch.dtype = torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x_f32 = x.float() + x_f32 = x_f32 * torch.rsqrt(x_f32.square().mean(-1, keepdim=True) + self.eps) + return (x_f32 * self.weight).to(dtype) + + +@lru_cache(16) +def precompute_freqs_cis( + dim: int, + seqlen: int, + theta: float = 10000.0, +) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + t = torch.arange(seqlen, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor: + y = x + x_complex = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) + if inverse: + freqs_cis = freqs_cis.conj() + if x_complex.ndim == 3: + freqs_cis = freqs_cis.view(1, x_complex.size(1), x_complex.size(-1)) + else: + freqs_cis = freqs_cis.view(1, x_complex.size(1), 1, x_complex.size(-1)) + x_complex = torch.view_as_real(x_complex * freqs_cis).flatten(-2) + y.copy_(x_complex) + return y + + +class Model(nn.Module): + def __init__( + self, + max_batch_size: int = 4, + max_seq_len: int = 256, + dim: int = 512, + head_dim: int = 128, + rope_head_dim: int = 64, + compress_ratio: int = 4, + norm_eps: float = 1e-6, + ): + super(Model, self).__init__() + self.dim = dim + self.head_dim = head_dim + self.rope_head_dim = rope_head_dim + self.compress_ratio = compress_ratio + self.overlap = compress_ratio == 4 + coeff = 1 + int(self.overlap) + + self.ape = nn.Parameter(torch.empty(compress_ratio, coeff * head_dim, dtype=torch.float32)) + self.wkv = Linear(dim, coeff * head_dim, dtype=torch.float32) + self.wgate = Linear(dim, coeff * head_dim, dtype=torch.float32) + self.norm = RMSNorm(head_dim, norm_eps) + + self.register_buffer( + "kv_state", + torch.zeros(max_batch_size, coeff * compress_ratio, coeff * head_dim, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "score_state", + torch.full( + (max_batch_size, coeff * compress_ratio, coeff * head_dim), + float("-inf"), + dtype=torch.float32, + ), + persistent=False, + ) + self.register_buffer( + "kv_cache", + torch.zeros(max_batch_size, max_seq_len // compress_ratio, head_dim, dtype=torch.bfloat16), + persistent=False, + ) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(rope_head_dim, max_seq_len), + persistent=False, + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.normal_(self.ape, mean=0.0, std=0.02) + self.wkv.reset_parameters() + self.wgate.reset_parameters() + nn.init.ones_(self.norm.weight) + + def overlap_transform(self, tensor: torch.Tensor, fill_value: float) -> torch.Tensor: + batch_size, num_windows, ratio, _ = tensor.shape + output = tensor.new_full((batch_size, num_windows, 2 * ratio, self.head_dim), fill_value) + output[:, :, ratio:] = tensor[:, :, :, self.head_dim :] + output[:, 1:, :ratio] = tensor[:, :-1, :, : self.head_dim] + return output + + def reset_runtime_state(self) -> None: + self.kv_state.zero_() + self.score_state.fill_(float("-inf")) + self.kv_cache.zero_() + + def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor | None: + batch_size, seqlen, _ = x.shape + ratio = self.compress_ratio + head_dim = self.head_dim + rope_head_dim = self.rope_head_dim + overlap = self.overlap + dtype = x.dtype + + x = x.float() + kv = self.wkv(x) + score = self.wgate(x) + + if start_pos == 0: + should_compress = seqlen >= ratio + remainder = seqlen % ratio + cutoff = seqlen - remainder + offset = ratio if overlap else 0 + + if overlap and cutoff >= ratio: + self.kv_state[:batch_size, :ratio] = kv[:, cutoff - ratio : cutoff] + self.score_state[:batch_size, :ratio] = score[:, cutoff - ratio : cutoff] + self.ape + + if remainder > 0: + kv, self.kv_state[:batch_size, offset : offset + remainder] = kv.split([cutoff, remainder], dim=1) + self.score_state[:batch_size, offset : offset + remainder] = score[:, cutoff:] + self.ape[:remainder] + score = score[:, :cutoff] + + if not should_compress: + return None + + kv = kv.unflatten(1, (-1, ratio)) + score = score.unflatten(1, (-1, ratio)) + self.ape + if overlap: + kv = self.overlap_transform(kv, 0.0) + score = self.overlap_transform(score, float("-inf")) + kv = (kv * score.softmax(dim=2)).sum(dim=2) + else: + slot = start_pos % ratio + should_compress = (start_pos + 1) % ratio == 0 + score = score + self.ape[slot] + + if overlap: + self.kv_state[:batch_size, ratio + slot] = kv.squeeze(1) + self.score_state[:batch_size, ratio + slot] = score.squeeze(1) + if not should_compress: + return None + merged_kv = torch.cat( + [self.kv_state[:batch_size, :ratio, :head_dim], self.kv_state[:batch_size, ratio:, head_dim:]], + dim=1, + ) + merged_score = torch.cat( + [ + self.score_state[:batch_size, :ratio, :head_dim], + self.score_state[:batch_size, ratio:, head_dim:], + ], + dim=1, + ) + kv = (merged_kv * merged_score.softmax(dim=1)).sum(dim=1, keepdim=True) + self.kv_state[:batch_size, :ratio] = self.kv_state[:batch_size, ratio:] + self.score_state[:batch_size, :ratio] = self.score_state[:batch_size, ratio:] + else: + self.kv_state[:batch_size, slot] = kv.squeeze(1) + self.score_state[:batch_size, slot] = score.squeeze(1) + if not should_compress: + return None + kv = ( + self.kv_state[:batch_size, :ratio] + * self.score_state[:batch_size, :ratio].softmax(dim=1) + ).sum(dim=1, keepdim=True) + + kv = self.norm(kv.to(dtype)) + if start_pos == 0: + freqs_cis = self.freqs_cis[:cutoff:ratio].to(kv.device) + self.kv_cache[:batch_size, : seqlen // ratio] = kv + else: + freqs_cis = self.freqs_cis[start_pos + 1 - ratio].unsqueeze(0).to(kv.device) + self.kv_cache[:batch_size, start_pos // ratio] = kv.squeeze(1) + apply_rotary_emb(kv[..., -rope_head_dim:], freqs_cis) + return kv + + +def generate_test_data(params: dict) -> tuple[torch.Tensor, int]: + batch_size = params["batch_size"] + seq_len = params["seq_len"] + dim = params["dim"] + start_pos = params["start_pos"] + x = torch.randn(batch_size, seq_len, dim, dtype=torch.bfloat16, device="cpu") + return x, start_pos + + +def test_kv_compress(): + return Model(*get_init_inputs()).forward(*get_inputs()) + + +def get_inputs(): + params = {"batch_size": 1, "seq_len": 12, "dim": 448, "start_pos": 0} + return list(generate_test_data(params)) + + +def get_init_inputs(): + return [1, 256, 448, 32, 4, 4, 1e-6] diff --git a/dlblas/kernels/kernelswift_triton/level3/12_hc_post.py b/dlblas/kernels/kernelswift_triton/level3/12_hc_post.py index 57c55518..36d3b2e5 100644 --- a/dlblas/kernels/kernelswift_triton/level3/12_hc_post.py +++ b/dlblas/kernels/kernelswift_triton/level3/12_hc_post.py @@ -1,78 +1,89 @@ import torch import torch.nn as nn + import triton import triton.language as tl @triton.jit -def fused_einsum_axpy_kernel( - x_ptr, # [a, b, c] bf16 - residual_ptr, # [a, b, m, c] bf16 - post_layer_mix_ptr, # [a, b, n, 1] fp32 - comb_res_mix_ptr, # [a, b, m, n] fp32 - out_ptr, # [a, b, n, c] bf16 - n0, n1, C, # sizes: a, b, c - STRIDE_X_A, STRIDE_X_B, STRIDE_X_C, - STRIDE_R_A, STRIDE_R_B, STRIDE_R_M, STRIDE_R_C, - STRIDE_P_A, STRIDE_P_B, STRIDE_P_N, STRIDE_P_1, - STRIDE_CM_A, STRIDE_CM_B, STRIDE_CM_M, STRIDE_CM_N, - STRIDE_O_A, STRIDE_O_B, STRIDE_O_N, STRIDE_O_C, - N1, # n1 for decoding ab index - MHC: tl.constexpr, # mhc_mult (n dimension) - compile-time constant - BLOCK_C: tl.constexpr, # tile size along c +def hc_post_kernel( + x_ptr, + residual_ptr, + post_ptr, + comb_ptr, + y_ptr, + batch_size, + seq_len, + hc_mult, + hidden_size, + stride_x_b, + stride_x_s, + stride_x_d, + stride_r_b, + stride_r_s, + stride_r_h, + stride_r_d, + stride_p_b, + stride_p_s, + stride_p_h, + stride_c_b, + stride_c_s, + stride_c_m, + stride_c_h, + stride_y_b, + stride_y_s, + stride_y_h, + stride_y_d, + BLOCK_D: tl.constexpr, + HC: tl.constexpr, ): - pid_ab = tl.program_id(0) - pid_c = tl.program_id(1) - - a_idx = pid_ab // N1 - b_idx = pid_ab - a_idx * N1 - - offs_c = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) - c_mask = offs_c < C - - # Precompute AB base offsets to reduce repeated integer arithmetic - x_ab_base = a_idx * STRIDE_X_A + b_idx * STRIDE_X_B - r_ab_base = a_idx * STRIDE_R_A + b_idx * STRIDE_R_B - p_ab_base = a_idx * STRIDE_P_A + b_idx * STRIDE_P_B - cm_ab_base = a_idx * STRIDE_CM_A + b_idx * STRIDE_CM_B - o_ab_base = a_idx * STRIDE_O_A + b_idx * STRIDE_O_B - - # Load x[a, b, c] as bf16 -> fp32 (streaming, bypass L1) - x_ptrs = x_ptr + x_ab_base + offs_c * STRIDE_X_C - x_vec = tl.load(x_ptrs, mask=c_mask, other=0, cache_modifier='.cg').to(tl.float32) - - # Load post_layer_mix[a, b, n] as fp32 and keep in cache - offs_n = tl.arange(0, MHC) - plm_ptrs = post_layer_mix_ptr + p_ab_base + offs_n * STRIDE_P_N - plm_vec = tl.load(plm_ptrs, eviction_policy='evict_last') # fp32, length MHC - - # Initialize output tile [MHC, BLOCK_C] with x * post_layer_mix - out_tile = plm_vec[:, None] * x_vec[None, :] # fp32 - - # Software pipeline for m-reduction: prefetch next before using current - # Prefetch m=0 - res_ptrs = residual_ptr + r_ab_base + offs_c * STRIDE_R_C - res_vec = tl.load(res_ptrs, mask=c_mask, other=0, cache_modifier='.cg').to(tl.float32) - crm_ptrs = comb_res_mix_ptr + cm_ab_base + offs_n * STRIDE_CM_N - crm_vec = tl.load(crm_ptrs, eviction_policy='evict_last') - - for m in range(MHC): - # FMA: out += crm_vec[:, None] * res_vec[None, :] - out_tile += crm_vec[:, None] * res_vec[None, :] - - if m + 1 < MHC: - # Prefetch next m+1 - res_ptrs_next = residual_ptr + r_ab_base + (m + 1) * STRIDE_R_M + offs_c * STRIDE_R_C - res_vec_next = tl.load(res_ptrs_next, mask=c_mask, other=0, cache_modifier='.cg').to(tl.float32) - crm_ptrs_next = comb_res_mix_ptr + cm_ab_base + (m + 1) * STRIDE_CM_M + offs_n * STRIDE_CM_N - crm_vec_next = tl.load(crm_ptrs_next, eviction_policy='evict_last') - # Advance pipeline - res_vec = res_vec_next - crm_vec = crm_vec_next - - # Store result to out[a,b,n,c] as bf16 - out_ptrs = out_ptr + o_ab_base + offs_n[:, None] * STRIDE_O_N + offs_c[None, :] * STRIDE_O_C - tl.store(out_ptrs, out_tile.to(tl.bfloat16), mask=c_mask[None, :]) + bs = tl.program_id(0) + h = tl.program_id(1) + + b = bs // seq_len + s = bs % seq_len + offs_d = tl.arange(0, BLOCK_D) + + post_ptrs = post_ptr + b * stride_p_b + s * stride_p_s + h * stride_p_h + post_val = tl.load(post_ptrs).to(tl.float32) + + for d0 in range(0, hidden_size, BLOCK_D): + d_offsets = d0 + offs_d + mask_d = d_offsets < hidden_size + + x_ptrs = x_ptr + b * stride_x_b + s * stride_x_s + d_offsets * stride_x_d + x_vals = tl.load(x_ptrs, mask=mask_d, other=0).to(tl.float32) + + acc = post_val * x_vals + + for m in tl.static_range(0, HC): + comb_ptrs = ( + comb_ptr + + b * stride_c_b + + s * stride_c_s + + m * stride_c_m + + h * stride_c_h + ) + comb_val = tl.load(comb_ptrs).to(tl.float32) + + residual_ptrs = ( + residual_ptr + + b * stride_r_b + + s * stride_r_s + + m * stride_r_h + + d_offsets * stride_r_d + ) + residual_vals = tl.load(residual_ptrs, mask=mask_d, other=0).to(tl.float32) + acc += comb_val * residual_vals + + y_ptrs = ( + y_ptr + + b * stride_y_b + + s * stride_y_s + + h * stride_y_h + + d_offsets * stride_y_d + ) + tl.store(y_ptrs, acc.to(tl.bfloat16), mask=mask_d) class ModelNew(nn.Module): @@ -83,67 +94,87 @@ def forward( self, x: torch.Tensor, residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, ) -> torch.Tensor: - # Shapes - n0, n1, h = x.shape - mhc_mult = residual.shape[2] - assert comb_res_mix.shape[:3] == (n0, n1, mhc_mult) and comb_res_mix.shape[3] == mhc_mult - assert post_layer_mix.shape[:3] == (n0, n1, mhc_mult) - - # Allocate output - out = torch.empty((n0, n1, mhc_mult, h), dtype=torch.bfloat16, device=x.device) - - # Launch Triton kernel only on CUDA; fallback otherwise - if x.is_cuda: - # Compute grid - BLOCK_C = 256 - grid = (n0 * n1, triton.cdiv(h, BLOCK_C)) - - # Extract strides (in elements) - sx0, sx1, sx2 = x.stride() - sr0, sr1, sr2, sr3 = residual.stride() - sp0, sp1, sp2, sp3 = post_layer_mix.stride() - sc0, sc1, sc2, sc3 = comb_res_mix.stride() - so0, so1, so2, so3 = out.stride() - - fused_einsum_axpy_kernel[grid]( - x, residual, post_layer_mix, comb_res_mix, out, - n0, n1, h, - sx0, sx1, sx2, - sr0, sr1, sr2, sr3, - sp0, sp1, sp2, sp3, - sc0, sc1, sc2, sc3, - so0, so1, so2, so3, - n1, - MHC=mhc_mult, - BLOCK_C=BLOCK_C, - num_warps=8, - num_stages=3, - ) - return out - else: - # Fallback: reference implementation on CPU or non-CUDA - term2 = torch.einsum('abmn,abmc->abnc', comb_res_mix, residual.float()) - return (x.float().unsqueeze(-2) * post_layer_mix + term2).bfloat16() - + batch_size, seq_len, hidden_size = x.shape + hc_mult = residual.shape[2] + + assert residual.shape == (batch_size, seq_len, hc_mult, hidden_size) + assert post.shape == (batch_size, seq_len, hc_mult) + assert comb.shape == (batch_size, seq_len, hc_mult, hc_mult) + + y = torch.empty((batch_size, seq_len, hc_mult, hidden_size), device=x.device, dtype=torch.bfloat16) + + sx0, sx1, sx2 = x.stride() + sr0, sr1, sr2, sr3 = residual.stride() + sp0, sp1, sp2 = post.stride() + sc0, sc1, sc2, sc3 = comb.stride() + sy0, sy1, sy2, sy3 = y.stride() + + BLOCK_D = 256 + num_warps = 4 + num_stages = 2 + + grid = (batch_size * seq_len, hc_mult) + hc_post_kernel[grid]( + x, + residual, + post, + comb, + y, + batch_size, + seq_len, + hc_mult, + hidden_size, + sx0, + sx1, + sx2, + sr0, + sr1, + sr2, + sr3, + sp0, + sp1, + sp2, + sc0, + sc1, + sc2, + sc3, + sy0, + sy1, + sy2, + sy3, + BLOCK_D=BLOCK_D, + HC=hc_mult, + num_warps=num_warps, + num_stages=num_stages, + ) + return y + + +def generate_test_data(params): + batch_size = params['batch_size'] + seq_len = params['seq_len'] + hidden_size = params['hidden'] + hc_mult = params['hc'] + x_data = torch.randn(batch_size, seq_len, hidden_size, dtype=torch.bfloat16, device='cpu') + residual_data = torch.randn(batch_size, seq_len, hc_mult, hidden_size, dtype=torch.bfloat16, device='cpu') + post_data = torch.randn(batch_size, seq_len, hc_mult, dtype=torch.float32, device='cpu') + comb_data = torch.randn(batch_size, seq_len, hc_mult, hc_mult, dtype=torch.float32, device='cpu') + o_grad = torch.randn(batch_size, seq_len, hc_mult, hidden_size, dtype=torch.bfloat16, device='cpu') + return x_data, residual_data, post_data, comb_data, o_grad + + +def test_hc_post_fwd(): + return ModelNew(*get_init_inputs()).forward(*get_inputs()) -n0 = 1 -n1 = 4096 -h = 1280 -mhc_mult = 4 -device = 'cuda' def get_inputs(): - x = torch.randn((n0, n1, h), dtype=torch.bfloat16, device=device) - residual = torch.randn((n0, n1, mhc_mult, h), dtype=torch.bfloat16, device=device) - post_layer_mix = torch.randn((n0, n1, mhc_mult, 1), dtype=torch.float32, device=device) - comb_res_mix = torch.randn((n0, n1, mhc_mult, mhc_mult), dtype=torch.float32, device=device) - - return [ - x, residual, post_layer_mix, comb_res_mix, - ] + params = {'batch_size': 1, 'seq_len': 4096, 'hidden': 1280, 'hc': 4} + x_data, residual_data, post_data, comb_data, o_grad = generate_test_data(params) + return [x_data, residual_data, post_data, comb_data] + def get_init_inputs(): return [] diff --git a/dlblas/kernels/kernelswift_triton/level3/13_head_compute_mix_fwd.py b/dlblas/kernels/kernelswift_triton/level3/13_head_compute_mix_fwd.py index a00cf787..6db41bda 100644 --- a/dlblas/kernels/kernelswift_triton/level3/13_head_compute_mix_fwd.py +++ b/dlblas/kernels/kernelswift_triton/level3/13_head_compute_mix_fwd.py @@ -61,9 +61,7 @@ def forward( mhc_pre_eps: float, ) -> torch.Tensor: # Fallback to PyTorch if Triton unavailable or running on CPU - if (not _TRITON_AVAILABLE) or (not input_mix.is_cuda) or (not mhc_scale.is_cuda) or (not mhc_base.is_cuda): - mhc_head_layer_mix = input_mix * mhc_scale + mhc_base - return torch.sigmoid(mhc_head_layer_mix) + mhc_pre_eps + assert input_mix.dim() == 3, "input_mix must be 3D [B, N, C]" B, N, C = input_mix.shape @@ -87,7 +85,7 @@ def forward( _sigmoid_affine_kernel[grid]( x, scale, base, y, - total, C, float(mhc_pre_eps), + total, C, mhc_pre_eps.float(), BLOCK=BLOCK, num_warps=8, num_stages=2, diff --git a/dlblas/kernels/kernelswift_triton/level3/14_head_compute_mix_bwd.py b/dlblas/kernels/kernelswift_triton/level3/14_head_compute_mix_bwd.py index d9be1722..92042de8 100644 --- a/dlblas/kernels/kernelswift_triton/level3/14_head_compute_mix_bwd.py +++ b/dlblas/kernels/kernelswift_triton/level3/14_head_compute_mix_bwd.py @@ -1,70 +1,78 @@ import torch import torch.nn as nn - import triton import triton.language as tl @triton.jit def fused_backward_kernel( - input_ptr, # *f32, shape [Ni, Mh] - grad_out_ptr, # *f32, shape [Ni, Mh] - mhc_scale_ptr, # *f32, shape [1] - mhc_base_ptr, # *f32, shape [Mh] - grad_input_ptr, # *f32, shape [Ni, Mh] - grad_mhc_base_ptr, # *f32, shape [Mh] - grad_mhc_scale_ptr, # *f32, shape [1] - Ni: tl.constexpr, # int - Mh: tl.constexpr, # int - BLOCK_N: tl.constexpr, + input_ptr, # float* (n0, n1, k) + mhc_scale_ptr, # float* (1,) + mhc_base_ptr, # float* (k,) + grad_out_ptr, # float* (n0, n1, k) + grad_input_ptr, # float* (n0, n1, k) + grad_scale_ptr, # float* (1,) + grad_base_ptr, # float* (k,) + n0, n1, k, # int sizes + stride_in_0, stride_in_1, stride_in_2, # input strides + stride_go_0, stride_go_1, stride_go_2, # grad_out strides + stride_gi_0, stride_gi_1, stride_gi_2, # grad_input strides BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, ): - pid_n = tl.program_id(0) - pid_m = tl.program_id(1) + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + tl.max_contiguous(offs_k, BLOCK_K) - n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + M = n0 * n1 + mask_m = offs_m < M + mask_k = offs_k < k + mask = mask_m[:, None] & mask_k[None, :] - # Hints for better codegen/vectorization - tl.multiple_of(n_offsets, BLOCK_N) - tl.multiple_of(m_offsets, BLOCK_M) + # Decompose offs_m -> (i0, i1) + i0 = offs_m // n1 + i1 = offs_m - i0 * n1 # cheaper than modulo - n_mask = n_offsets < Ni - m_mask = m_offsets < Mh - mask = n_mask[:, None] & m_mask[None, :] + # Compute pointers + ptr_in = input_ptr + (i0[:, None] * stride_in_0 + i1[:, None] * stride_in_1 + offs_k[None, :] * stride_in_2) + ptr_go = grad_out_ptr + (i0[:, None] * stride_go_0 + i1[:, None] * stride_go_1 + offs_k[None, :] * stride_go_2) + ptr_gi = grad_input_ptr + (i0[:, None] * stride_gi_0 + i1[:, None] * stride_gi_1 + offs_k[None, :] * stride_gi_2) - # Compute base pointers for 2D tile - base_ptrs = n_offsets[:, None] * Mh + m_offsets[None, :] + # Loads with cache/eviction hints + x = tl.load(ptr_in, mask=mask, other=0.0, cache_modifier=".cg") + go = tl.load(ptr_go, mask=mask, other=0.0, cache_modifier=".cg") + scale = tl.load(mhc_scale_ptr, eviction_policy="evict_last") + base_k = tl.load(mhc_base_ptr + offs_k, mask=mask_k, other=0.0, eviction_policy="evict_last") - # Loads with cache hints: stream x/g, keep base in cache (reused across rows) - x = tl.load(input_ptr + base_ptrs, mask=mask, other=0.0, cache_modifier=".cg") - g = tl.load(grad_out_ptr + base_ptrs, mask=mask, other=0.0, cache_modifier=".cg") - s = tl.load(mhc_scale_ptr) # scalar - b_m = tl.load(mhc_base_ptr + m_offsets, mask=m_mask, other=0.0, cache_modifier=".ca") + # z = x * scale + base + z = x * scale + base_k[None, :] - # Broadcast mhc_base across rows and compute sigmoid and its derivative - z = tl.fma(x, s, b_m[None, :]) - sig = tl.sigmoid(z) - one_minus_sig = 1.0 - sig - gz = g * sig * one_minus_sig + # sigmoid and grad_z (use s - s*s for derivative) + s = tl.sigmoid(z) + t = s - s * s + grad_z = go * t # grad_input_mix - grad_input = gz * s - tl.store(grad_input_ptr + base_ptrs, grad_input, mask=mask) + gi = grad_z * scale + tl.store(ptr_gi, gi, mask=mask) - # Partial reductions for grad_mhc_base (sum over n for each m) - partial_base = tl.sum(gz, axis=0) # [BLOCK_M] - tl.atomic_add(grad_mhc_base_ptr + m_offsets, partial_base, mask=m_mask) + # grad_mhc_base: sum over m for each k, then atomic add + sum_m = tl.sum(grad_z, axis=0) + tl.atomic_add(grad_base_ptr + offs_k, sum_m, mask=mask_k) - # Partial reduction for grad_mhc_scale (sum over all n and m) - partial_scale_rows = tl.sum(gz * x, axis=1) # [BLOCK_N] - partial_scale = tl.sum(partial_scale_rows, axis=0) - tl.atomic_add(grad_mhc_scale_ptr, partial_scale) + # grad_mhc_scale: sum over all elements of grad_z * x + prod = grad_z * x + tile_row_sum = tl.sum(prod, axis=1) # [BLOCK_M] + tile_sum = tl.sum(tile_row_sum, axis=0) # scalar + tl.atomic_add(grad_scale_ptr, tile_sum) class ModelNew(nn.Module): """ - Model that computes manual backward of mhc_head_compute_mix using a fused Triton kernel when available. + Model that computes manual backward of mhc_head_compute_mix using a fused Triton kernel. """ def __init__(self): @@ -89,61 +97,42 @@ def forward( Returns: grad_input_mix, grad_mhc_scale, grad_mhc_base """ - # Fallback to PyTorch if inputs are not CUDA tensors or not float32 - if ( - (not input_mix.is_cuda) - or (not grad_out.is_cuda) - or (not mhc_scale.is_cuda) - or (not mhc_base.is_cuda) - or (input_mix.dtype != torch.float32) - or (grad_out.dtype != torch.float32) - or (mhc_scale.dtype != torch.float32) - or (mhc_base.dtype != torch.float32) - ): - z = input_mix * mhc_scale + mhc_base - sigmoid = torch.sigmoid(z) - grad_z = grad_out * sigmoid * (1 - sigmoid) - grad_input_mix = grad_z * mhc_scale - grad_mhc_base = grad_z.sum(dim=(0, 1), keepdim=True).view(-1) - grad_mhc_scale = (grad_z * input_mix).sum(dim=(0, 1, 2), keepdim=True).view(1) - return grad_input_mix, grad_mhc_scale, grad_mhc_base - - # Ensure contiguous memory layout - n0, n1, mh = input_mix.shape - Ni = n0 * n1 - Mh = mh - - x2d = input_mix.reshape(Ni, Mh).contiguous() - g2d = grad_out.reshape(Ni, Mh).contiguous() + # Prepare outputs grad_input_mix = torch.empty_like(input_mix) - grad_input_2d = grad_input_mix.view(Ni, Mh) + grad_mhc_scale = torch.zeros_like(mhc_scale) + grad_mhc_base = torch.zeros_like(mhc_base) + + n0, n1, k = input_mix.shape - grad_mhc_base = torch.zeros(Mh, device=input_mix.device, dtype=input_mix.dtype) - grad_mhc_scale = torch.zeros(1, device=input_mix.device, dtype=input_mix.dtype) + # Strides (in elements) + s_in0, s_in1, s_in2 = input_mix.stride() + s_go0, s_go1, s_go2 = grad_out.stride() + s_gi0, s_gi1, s_gi2 = grad_input_mix.stride() # Launch Triton kernel - BLOCK_N = 128 - BLOCK_M = 32 - grid = (triton.cdiv(Ni, BLOCK_N), triton.cdiv(Mh, BLOCK_M)) + BLOCK_M = 128 + BLOCK_K = 32 + grid = (triton.cdiv(n0 * n1, BLOCK_M), triton.cdiv(k, BLOCK_K)) fused_backward_kernel[grid]( - x2d, - g2d, + input_mix, mhc_scale, mhc_base, - grad_input_2d, - grad_mhc_base, + grad_out, + grad_input_mix, grad_mhc_scale, - Ni, - Mh, - BLOCK_N=BLOCK_N, + grad_mhc_base, + n0, n1, k, + s_in0, s_in1, s_in2, + s_go0, s_go1, s_go2, + s_gi0, s_gi1, s_gi2, BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, num_warps=4, num_stages=2, ) - return grad_input_mix, grad_mhc_scale, grad_mhc_base - + return grad_input_mix, grad_mhc_scale.view(1), grad_mhc_base batch0 = 2 batch1 = 1024 diff --git a/dlblas/kernels/kernelswift_triton/level3/15_expand_kenel_fwd.py b/dlblas/kernels/kernelswift_triton/level3/15_expand_kenel_fwd.py index c881b567..42950e52 100644 --- a/dlblas/kernels/kernelswift_triton/level3/15_expand_kenel_fwd.py +++ b/dlblas/kernels/kernelswift_triton/level3/15_expand_kenel_fwd.py @@ -6,36 +6,44 @@ @triton.jit def expand_to_mhc_kernel( - x_ptr, # *x* pointer, shape (L, H), contiguous - y_ptr, # *y* pointer, shape (L, M, H), contiguous - L, # total leading elements collapsed - M, # mhc_mult - H, # hidden size + x_ptr, y_ptr, + B, S, M, H, + stride_x_b, stride_x_s, stride_x_h, + stride_y_b, stride_y_s, stride_y_m, stride_y_h, BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, ): - # Program IDs along (L, ceil_div(M, BLOCK_M), ceil_div(H, BLOCK_H)) - pid_l = tl.program_id(0) - pid_mb = tl.program_id(1) - pid_ht = tl.program_id(2) + pid_bs = tl.program_id(0) + pid_h = tl.program_id(1) + pid_m = tl.program_id(2) - # Offsets along H and M for this program - offs_h = pid_ht * BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = offs_h < H - offs_m = pid_mb * BLOCK_M + tl.arange(0, BLOCK_M) - mask_m = offs_m < M + # Map program ids to (b, s, m) + b = pid_bs // S + s = pid_bs - b * S + m = pid_m - # Load a contiguous tile of H from x for this l - x_row_ptr = x_ptr + pid_l * H + offs_h - vals = tl.load(x_row_ptr, mask=mask_h, other=0) + # Offsets along hidden dimension + h_offsets = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + tl.max_contiguous(h_offsets, BLOCK_H) + tl.multiple_of(h_offsets, 16) + h_mask = h_offsets < H - # Prepare a 2D tile pointer for y with shape (BLOCK_M, BLOCK_H) - # y index: ((l * M + m) * H + h) - y_row_offsets = (pid_l * M + offs_m) * H - y_tile_ptrs = y_ptr + y_row_offsets[:, None] + offs_h[None, :] + # Precompute base pointers for (b, s, m) + base_x = b.to(tl.int64) * stride_x_b + s.to(tl.int64) * stride_x_s + base_y = (b.to(tl.int64) * stride_y_b + + s.to(tl.int64) * stride_y_s + + m.to(tl.int64) * stride_y_m) - # Broadcast vals across the M dimension and store once - tl.store(y_tile_ptrs, vals[None, :], mask=mask_m[:, None] & mask_h[None, :]) + # Compute input/output pointers + x_ptrs = x_ptr + base_x + h_offsets.to(tl.int64) * stride_x_h + y_ptrs = y_ptr + base_y + h_offsets.to(tl.int64) * stride_y_h + + # Hints for locality/coalescing + tl.max_contiguous(x_ptrs, BLOCK_H) + tl.max_contiguous(y_ptrs, BLOCK_H) + + # Load once and store; use L2 cache to improve reuse across M CTAs + x_vals = tl.load(x_ptrs, mask=h_mask, other=0, cache_modifier=".cg") + tl.store(y_ptrs, x_vals, mask=h_mask) class ModelNew(nn.Module): @@ -59,63 +67,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: 扩展后的张量,形状为 (batch, seq_len, mhc_mult, hidden_dim)。 """ - original_shape = x.shape - M = self.mhc_mult - H = original_shape[-1] - - # Fallback to reference path for CPU or degenerate cases - if (not x.is_cuda) or (H == 0) or (M == 0) or (x.numel() == 0): - return x.unsqueeze(-2).expand(*original_shape[:-1], M, H).contiguous() - - # Ensure contiguous and flatten leading dims into L - x_contig = x.contiguous() - L = 1 - for d in original_shape[:-1]: - L *= d - x_flat = x_contig.view(L, H) - - # Allocate output tensor as (L, M, H), then reshape to final - y_flat = torch.empty((L, M, H), dtype=x.dtype, device=x.device) - - # Kernel launch configuration tuning - # Larger BLOCK_H improves bandwidth by reducing grid overhead; choose based on H - if H >= 512: - BLOCK_H = 512 - num_warps = 8 - elif H >= 256: - BLOCK_H = 256 - num_warps = 8 - elif H >= 128: - BLOCK_H = 128 - num_warps = 4 - else: - BLOCK_H = 64 - num_warps = 2 - - # Replicate multiple M slices per program; choose a small power of two for occupancy - if M >= 16: - BLOCK_M = 16 - elif M >= 8: - BLOCK_M = 8 - elif M >= 4: - BLOCK_M = 4 - elif M >= 2: - BLOCK_M = 2 + # Fast path: Triton kernel for CUDA tensors with 3D shape (B, S, H) + if x.ndim == 3: + B, S, H = x.shape + M = self.mhc_mult + # Allocate contiguous output + y = torch.empty((B, S, M, H), dtype=x.dtype, device=x.device) + + # Get strides in element units + sx_b, sx_s, sx_h = x.stride() + sy_b, sy_s, sy_m, sy_h = y.stride() + + # Larger tile and more warps for better bandwidth utilization + BLOCK_H = 1024 + grid = (B * S, triton.cdiv(H, BLOCK_H), M) + + expand_to_mhc_kernel[grid]( + x, y, + B, S, M, H, + sx_b, sx_s, sx_h, + sy_b, sy_s, sy_m, sy_h, + BLOCK_H=BLOCK_H, + num_warps=8, + num_stages=2, + ) + return y else: - BLOCK_M = 1 - - grid = (L, (M + BLOCK_M - 1) // BLOCK_M, (H + BLOCK_H - 1) // BLOCK_H) - - expand_to_mhc_kernel[grid]( - x_flat, y_flat, - L, M, H, - BLOCK_H=BLOCK_H, - BLOCK_M=BLOCK_M, - num_warps=num_warps, - ) - - # Reshape back to (..., M, H) - return y_flat.view(*original_shape[:-1], M, H) + # Fallback to reference implementation for non-CUDA or non-3D cases + original_shape = x.shape + return x.unsqueeze(-2).expand(*original_shape[:-1], self.mhc_mult, original_shape[-1]).contiguous() def get_init_inputs(): @@ -134,5 +114,5 @@ def get_inputs(): batch_size = 1 seq_len = 1024 hidden_dim = 1280 - x = torch.randn(batch_size, seq_len, hidden_dim, device='cuda') + x = torch.randn(batch_size, seq_len, hidden_dim) return [x] \ No newline at end of file diff --git a/dlblas/kernels/kernelswift_triton/level3/24_multilayer_recompute.py b/dlblas/kernels/kernelswift_triton/level3/24_multilayer_recompute.py index 74395811..adbe1bfd 100644 --- a/dlblas/kernels/kernelswift_triton/level3/24_multilayer_recompute.py +++ b/dlblas/kernels/kernelswift_triton/level3/24_multilayer_recompute.py @@ -107,18 +107,18 @@ def _mhc_post_kernel( + offs_i[:, None] * stride_rm + offs_h[None, :] * stride_rh ).to(tl.float32) - comb_t = tl.load( + comb = tl.load( comb_mix_ptr + pid_n * stride_cn - + offs_o[:, None] * stride_co - + offs_i[None, :] * stride_ci - ) - term2 = tl.dot(comb_t, residual_block, input_precision="ieee", out_dtype=tl.float32) + + offs_i[:, None] * stride_ci + + offs_o[None, :] * stride_co + ).to(tl.float32) + term2 = tl.dot(tl.trans(comb), residual_block, out_dtype=tl.float32) post_mix_block = tl.load( post_mix_ptr + pid_n * stride_pn + offs_o[:, None] * stride_pm - ) + ).to(tl.float32) out_block = term2 + post_mix_block * layer_output[None, :] tl.store( out_ptr diff --git a/dlblas/kernels/kernelswift_triton/level3/25_compressor.py b/dlblas/kernels/kernelswift_triton/level3/25_compressor.py new file mode 100644 index 00000000..4624b5a4 --- /dev/null +++ b/dlblas/kernels/kernelswift_triton/level3/25_compressor.py @@ -0,0 +1,878 @@ +import math +import torch_npu # noqa: F401 +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + + +def _next_power_of_2(x: int) -> int: + return 1 if x <= 1 else 1 << (x - 1).bit_length() + + +def _num_warps(block_size: int) -> int: + if block_size <= 64: + return 1 + if block_size <= 128: + return 2 + if block_size <= 256: + return 4 + return 8 + + +class Linear(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: torch.dtype = torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight, self.bias) + + +if True: + @triton.jit + def _rmsnorm_fwd_kernel( + x_ptr, + weight_ptr, + out_ptr, + n_rows, + n_cols, + eps, + BLOCK: tl.constexpr, + ): + pid = tl.program_id(0) + offs = tl.arange(0, BLOCK) + row_start = pid * n_cols + idx = row_start + offs + mask = offs < n_cols + x = tl.load(x_ptr + idx, mask=mask, other=0.0) + x_f32 = x.to(tl.float32) + mean = tl.sum(x_f32 * x_f32, axis=0) * (1.0 / n_cols) + inv_rms = tl.math.rsqrt(mean + eps) + w = tl.load(weight_ptr + offs, mask=mask, other=0.0).to(tl.float32) + y = x_f32 * inv_rms * w + tl.store(out_ptr + idx, y, mask=mask) + + + @triton.jit + def _overlap_pack_kernel( + kv_in_ptr, + score_in_ptr, + kv_out_ptr, + score_out_ptr, + batch_size, + num_windows, + ratio, + head_dim, + score_fill, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + rows_per_batch = num_windows * (2 * ratio) + b = pid // rows_per_batch + row_in_batch = pid % rows_per_batch + w = row_in_batch // (2 * ratio) + rr = row_in_batch % (2 * ratio) + + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + use_curr = rr >= ratio + prev_w = tl.maximum(w - 1, 0) + src_w = tl.where(use_curr, w, prev_w) + src_r = tl.where(use_curr, rr - ratio, rr) + src_d_base = tl.where(use_curr, head_dim, 0) + valid_src = use_curr | (w > 0) + + in_row_offset = (((b * num_windows + src_w) * ratio + src_r) * (2 * head_dim)) + src_d_base + out_row_offset = (((b * num_windows + w) * (2 * ratio) + rr) * head_dim) + + kv_vals = tl.load(kv_in_ptr + in_row_offset + offs_d, mask=valid_src & mask_d, other=0.0) + score_vals = tl.load(score_in_ptr + in_row_offset + offs_d, mask=valid_src & mask_d, other=score_fill) + + tl.store(kv_out_ptr + out_row_offset + offs_d, kv_vals, mask=mask_d) + tl.store(score_out_ptr + out_row_offset + offs_d, score_vals, mask=mask_d) + + + @triton.jit + def _prefill_reduce_kernel( + kv_ptr, + score_ptr, + out_ptr, + num_rows, + ratio, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + if pid >= num_rows: + return + + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + base_offset = pid * ratio * head_dim + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for r in range(0, ratio): + row_offset = base_offset + r * head_dim + offs_d + score_vec = tl.load(score_ptr + row_offset, mask=mask_d, other=-float("inf")).to(tl.float32) + kv_vec = tl.load(kv_ptr + row_offset, mask=mask_d, other=0.0).to(tl.float32) + + new_max = tl.maximum(max_vec, score_vec) + prev_is_inf = max_vec == -float("inf") + score_is_inf = score_vec == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, score_vec) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + + acc_vec = acc_vec * alpha + kv_vec * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + + @triton.jit + def _prefill_reduce_overlap_kernel( + kv_ptr, + score_ptr, + out_ptr, + num_windows, + ratio, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + b = pid // num_windows + w = pid % num_windows + + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for rr in range(0, 2 * ratio): + use_curr = rr >= ratio + prev_w = tl.maximum(w - 1, 0) + src_w = tl.where(use_curr, w, prev_w) + src_r = tl.where(use_curr, rr - ratio, rr) + src_c = tl.where(use_curr, head_dim + offs_d, offs_d) + valid_src = use_curr | (w > 0) + + row_offset = (((b * num_windows + src_w) * ratio + src_r) * (2 * head_dim)) + src_c + score_vec = tl.load(score_ptr + row_offset, mask=valid_src & mask_d, other=-float("inf")).to(tl.float32) + kv_vec = tl.load(kv_ptr + row_offset, mask=valid_src & mask_d, other=0.0).to(tl.float32) + + new_max = tl.maximum(max_vec, score_vec) + prev_is_inf = max_vec == -float("inf") + score_is_inf = score_vec == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, score_vec) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + + acc_vec = acc_vec * alpha + kv_vec * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + + @triton.jit + def _decode_update_state_kernel( + kv_token_ptr, + score_token_ptr, + ape_ptr, + kv_state_ptr, + score_state_ptr, + state_rows, + state_slot, + state_width, + BLOCK_W: tl.constexpr, + ): + pid = tl.program_id(0) + offs = tl.arange(0, BLOCK_W) + mask = offs < state_width + + token_offset = pid * state_width + offs + state_offset = ((pid * state_rows) + state_slot) * state_width + offs + kv_vals = tl.load(kv_token_ptr + token_offset, mask=mask, other=0.0) + score_vals = tl.load(score_token_ptr + token_offset, mask=mask, other=0.0) + ape_vals = tl.load(ape_ptr + offs, mask=mask, other=0.0) + + tl.store(kv_state_ptr + state_offset, kv_vals, mask=mask) + tl.store(score_state_ptr + state_offset, score_vals + ape_vals, mask=mask) + + + @triton.jit + def _decode_reduce_nonoverlap_kernel( + kv_ptr, + score_ptr, + out_ptr, + ratio, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + base_offset = pid * ratio * head_dim + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for r in range(0, ratio): + row_offset = base_offset + r * head_dim + offs_d + score_vec = tl.load(score_ptr + row_offset, mask=mask_d, other=-float("inf")).to(tl.float32) + kv_vec = tl.load(kv_ptr + row_offset, mask=mask_d, other=0.0).to(tl.float32) + new_max = tl.maximum(max_vec, score_vec) + prev_is_inf = max_vec == -float("inf") + score_is_inf = score_vec == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, score_vec) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + acc_vec = acc_vec * alpha + kv_vec * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + + @triton.jit + def _decode_update_reduce_nonoverlap_kernel( + kv_token_ptr, + score_token_ptr, + ape_ptr, + kv_state_ptr, + score_state_ptr, + out_ptr, + ratio, + state_slot, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + token_offset = pid * head_dim + offs_d + state_base = pid * ratio * head_dim + slot_offset = state_base + state_slot * head_dim + offs_d + + kv_new = tl.load(kv_token_ptr + token_offset, mask=mask_d, other=0.0).to(tl.float32) + score_new = ( + tl.load(score_token_ptr + token_offset, mask=mask_d, other=0.0).to(tl.float32) + + tl.load(ape_ptr + offs_d, mask=mask_d, other=0.0).to(tl.float32) + ) + tl.store(kv_state_ptr + slot_offset, kv_new, mask=mask_d) + tl.store(score_state_ptr + slot_offset, score_new, mask=mask_d) + + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for r in range(0, ratio): + row_offset = state_base + r * head_dim + offs_d + kv_vec = tl.load(kv_state_ptr + row_offset, mask=mask_d, other=0.0).to(tl.float32) + score_vec = tl.load(score_state_ptr + row_offset, mask=mask_d, other=-float("inf")).to(tl.float32) + new_max = tl.maximum(max_vec, score_vec) + prev_is_inf = max_vec == -float("inf") + score_is_inf = score_vec == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, score_vec) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + acc_vec = acc_vec * alpha + kv_vec * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + + @triton.jit + def _decode_reduce_overlap_kernel( + kv_ptr, + score_ptr, + out_ptr, + ratio, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + state_rows = 2 * ratio + state_width = 2 * head_dim + base_offset = pid * state_rows * state_width + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for r in range(0, 2 * ratio): + src_row = tl.where(r < ratio, r, ratio + (r - ratio)) + src_col = tl.where(r < ratio, offs_d, head_dim + offs_d) + row_offset = base_offset + src_row * state_width + src_col + score_vec = tl.load(score_ptr + row_offset, mask=mask_d, other=-float("inf")).to(tl.float32) + kv_vec = tl.load(kv_ptr + row_offset, mask=mask_d, other=0.0).to(tl.float32) + new_max = tl.maximum(max_vec, score_vec) + prev_is_inf = max_vec == -float("inf") + score_is_inf = score_vec == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, score_vec) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + acc_vec = acc_vec * alpha + kv_vec * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + + @triton.jit + def _decode_update_reduce_overlap_kernel( + kv_token_ptr, + score_token_ptr, + ape_ptr, + kv_state_ptr, + score_state_ptr, + out_ptr, + ratio, + state_slot, + head_dim, + BLOCK_D: tl.constexpr, + ): + pid = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < head_dim + + state_rows = 2 * ratio + state_width = 2 * head_dim + state_base = pid * state_rows * state_width + token_base = pid * state_width + slot_row_base = state_base + (ratio + state_slot) * state_width + + kv_lo = tl.load(kv_token_ptr + token_base + offs_d, mask=mask_d, other=0.0).to(tl.float32) + kv_hi = tl.load(kv_token_ptr + token_base + head_dim + offs_d, mask=mask_d, other=0.0).to(tl.float32) + score_lo = ( + tl.load(score_token_ptr + token_base + offs_d, mask=mask_d, other=0.0).to(tl.float32) + + tl.load(ape_ptr + offs_d, mask=mask_d, other=0.0).to(tl.float32) + ) + score_hi = ( + tl.load(score_token_ptr + token_base + head_dim + offs_d, mask=mask_d, other=0.0).to(tl.float32) + + tl.load(ape_ptr + head_dim + offs_d, mask=mask_d, other=0.0).to(tl.float32) + ) + + tl.store(kv_state_ptr + slot_row_base + offs_d, kv_lo, mask=mask_d) + tl.store(kv_state_ptr + slot_row_base + head_dim + offs_d, kv_hi, mask=mask_d) + tl.store(score_state_ptr + slot_row_base + offs_d, score_lo, mask=mask_d) + tl.store(score_state_ptr + slot_row_base + head_dim + offs_d, score_hi, mask=mask_d) + + max_vec = tl.full([BLOCK_D], -float("inf"), dtype=tl.float32) + sum_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_vec = tl.zeros([BLOCK_D], dtype=tl.float32) + + for r in range(0, ratio): + prev_row = state_base + r * state_width + offs_d + curr_row = state_base + (ratio + r) * state_width + head_dim + offs_d + + prev_score = tl.load(score_state_ptr + prev_row, mask=mask_d, other=-float("inf")).to(tl.float32) + prev_kv = tl.load(kv_state_ptr + prev_row, mask=mask_d, other=0.0).to(tl.float32) + new_max = tl.maximum(max_vec, prev_score) + prev_is_inf = max_vec == -float("inf") + score_is_inf = prev_score == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, prev_score) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + acc_vec = acc_vec * alpha + prev_kv * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + + curr_score = tl.load(score_state_ptr + curr_row, mask=mask_d, other=-float("inf")).to(tl.float32) + curr_kv = tl.load(kv_state_ptr + curr_row, mask=mask_d, other=0.0).to(tl.float32) + new_max = tl.maximum(max_vec, curr_score) + prev_is_inf = max_vec == -float("inf") + score_is_inf = curr_score == -float("inf") + prev_delta = tl.where(prev_is_inf, 0.0, max_vec) - tl.where(prev_is_inf, 0.0, new_max) + score_delta = tl.where(score_is_inf, 0.0, curr_score) - tl.where(score_is_inf, 0.0, new_max) + alpha = tl.where(prev_is_inf, 0.0, tl.exp(prev_delta)) + beta = tl.where(score_is_inf, 0.0, tl.exp(score_delta)) + acc_vec = acc_vec * alpha + curr_kv * beta + sum_vec = sum_vec * alpha + beta + max_vec = new_max + valid_sum = sum_vec > 0.0 + safe_sum = tl.where(valid_sum, sum_vec, 1.0) + out_vec = tl.where(valid_sum, acc_vec / safe_sum, 0.0) + tl.store(out_ptr + pid * head_dim + offs_d, out_vec, mask=mask_d) + + for r in range(0, ratio): + src_row = state_base + (ratio + r) * state_width + dst_row = state_base + r * state_width + src_lo = tl.load(kv_state_ptr + src_row + offs_d, mask=mask_d, other=0.0) + src_hi = tl.load(kv_state_ptr + src_row + head_dim + offs_d, mask=mask_d, other=0.0) + src_score_lo = tl.load(score_state_ptr + src_row + offs_d, mask=mask_d, other=-float("inf")) + src_score_hi = tl.load(score_state_ptr + src_row + head_dim + offs_d, mask=mask_d, other=-float("inf")) + tl.store(kv_state_ptr + dst_row + offs_d, src_lo, mask=mask_d) + tl.store(kv_state_ptr + dst_row + head_dim + offs_d, src_hi, mask=mask_d) + tl.store(score_state_ptr + dst_row + offs_d, src_score_lo, mask=mask_d) + tl.store(score_state_ptr + dst_row + head_dim + offs_d, src_score_hi, mask=mask_d) + + + @triton.jit + def _overlap_roll_state_kernel( + kv_state_ptr, + score_state_ptr, + ratio, + state_rows, + state_width, + BLOCK_W: tl.constexpr, + ): + pid = tl.program_id(0) + blocks_per_row = tl.cdiv(state_width, BLOCK_W) + batch = pid // (ratio * blocks_per_row) + row = (pid // blocks_per_row) % ratio + block = pid % blocks_per_row + offs = block * BLOCK_W + tl.arange(0, BLOCK_W) + mask = offs < state_width + batch_base = batch * state_rows * state_width + + src_offset = batch_base + (ratio + row) * state_width + offs + dst_offset = batch_base + row * state_width + offs + + kv_vals = tl.load(kv_state_ptr + src_offset, mask=mask, other=0.0) + score_vals = tl.load(score_state_ptr + src_offset, mask=mask, other=0.0) + tl.store(kv_state_ptr + dst_offset, kv_vals, mask=mask) + tl.store(score_state_ptr + dst_offset, score_vals, mask=mask) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + n_cols = x.shape[-1] + x_flat = x.contiguous().view(-1, n_cols) + n_rows = x_flat.shape[0] + out_flat = torch.empty_like(x_flat) + block = 1 << (int(n_cols - 1).bit_length()) + block = max(64, min(1024, block)) + _rmsnorm_fwd_kernel[(n_rows,)]( + x_flat, + self.weight, + out_flat, + n_rows, + n_cols, + self.eps, + BLOCK=block, + num_warps=_num_warps(block), + num_stages=1, + ) + return out_flat.view_as(x).to(orig_dtype) + + +def precompute_freqs_cis(dim: int, seqlen: int, theta: float = 10000.0) -> torch.Tensor: + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + t = torch.arange(seqlen, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor: + y = x + x_complex = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) + if inverse: + freqs_cis = freqs_cis.conj() + if x_complex.ndim == 3: + freqs_cis = freqs_cis.view(1, x_complex.size(1), x_complex.size(-1)) + else: + freqs_cis = freqs_cis.view(1, x_complex.size(1), 1, x_complex.size(-1)) + x_complex = torch.view_as_real(x_complex * freqs_cis).flatten(-2) + y.copy_(x_complex) + return y + + +def overlap_pack( + kv_windows: torch.Tensor, + score_windows: torch.Tensor, + head_dim: int, + fill_score_neg_inf: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_windows, ratio, width = kv_windows.shape + if width != 2 * head_dim or score_windows.shape != kv_windows.shape: + raise ValueError("overlap_pack expects [..., ratio, 2 * head_dim] tensors") + kv_in = kv_windows.contiguous() + score_in = score_windows.contiguous() + packed_kv = torch.empty((batch_size, num_windows, 2 * ratio, head_dim), device=kv_in.device, dtype=kv_in.dtype) + packed_score = torch.empty( + (batch_size, num_windows, 2 * ratio, head_dim), + device=score_in.device, + dtype=score_in.dtype, + ) + grid = (batch_size * num_windows * (2 * ratio),) + block_d = min(256, _next_power_of_2(head_dim)) + score_fill = float("-inf") if fill_score_neg_inf else 0.0 + _overlap_pack_kernel[grid]( + kv_in, + score_in, + packed_kv, + packed_score, + batch_size, + num_windows, + ratio, + head_dim, + score_fill, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + return packed_kv, packed_score + + +def prefill_reduce(kv_windows: torch.Tensor, score_windows: torch.Tensor) -> torch.Tensor: + kv_in = kv_windows.contiguous() + score_in = score_windows.contiguous() + batch_size, num_windows, ratio, head_dim = kv_in.shape + out = torch.empty((batch_size, num_windows, head_dim), device=kv_in.device, dtype=torch.float32) + grid = (batch_size * num_windows,) + block_d = min(256, _next_power_of_2(head_dim)) + _prefill_reduce_kernel[grid]( + kv_in, + score_in, + out, + batch_size * num_windows, + ratio, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + return out.to(kv_in.dtype) + + +def prefill_reduce_overlap(kv_windows: torch.Tensor, score_windows: torch.Tensor, head_dim: int) -> torch.Tensor: + kv_in = kv_windows.contiguous() + score_in = score_windows.contiguous() + batch_size, num_windows, ratio, width = kv_in.shape + if width != 2 * head_dim: + raise ValueError("prefill_reduce_overlap expects [..., ratio, 2 * head_dim] tensors") + out = torch.empty((batch_size, num_windows, head_dim), device=kv_in.device, dtype=torch.float32) + block_d = min(256, _next_power_of_2(head_dim)) + _prefill_reduce_overlap_kernel[(batch_size * num_windows,)]( + kv_in, + score_in, + out, + num_windows, + ratio, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + return out.to(kv_in.dtype) + + +def decode_update_state( + kv_token: torch.Tensor, + score_token: torch.Tensor, + ape_row: torch.Tensor, + kv_state: torch.Tensor, + score_state: torch.Tensor, + slot: int, + overlap: bool, +) -> None: + ratio = kv_state.size(1) // (1 + int(overlap)) + state_slot = ratio + slot if overlap else slot + kv_in = kv_token.squeeze(1).contiguous() + score_in = score_token.squeeze(1).contiguous() + ape_in = ape_row.contiguous() + state_width = kv_state.size(-1) + block_w = min(256, _next_power_of_2(state_width)) + grid = (kv_state.size(0),) + _decode_update_state_kernel[grid]( + kv_in, + score_in, + ape_in, + kv_state, + score_state, + kv_state.size(1), + state_slot, + state_width, + BLOCK_W=block_w, + num_warps=_num_warps(block_w), + num_stages=1, + ) + + +def decode_update_reduce( + kv_token: torch.Tensor, + score_token: torch.Tensor, + ape_row: torch.Tensor, + kv_state: torch.Tensor, + score_state: torch.Tensor, + ratio: int, + head_dim: int, + slot: int, + overlap: bool, +) -> torch.Tensor: + out = torch.empty((kv_state.size(0), head_dim), device=kv_state.device, dtype=torch.float32) + block_d = min(256, _next_power_of_2(head_dim)) + if overlap: + _decode_update_reduce_overlap_kernel[(kv_state.size(0),)]( + kv_token.squeeze(1).contiguous(), + score_token.squeeze(1).contiguous(), + ape_row.contiguous(), + kv_state, + score_state, + out, + ratio, + slot, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + else: + _decode_update_reduce_nonoverlap_kernel[(kv_state.size(0),)]( + kv_token.squeeze(1).contiguous(), + score_token.squeeze(1).contiguous(), + ape_row.contiguous(), + kv_state, + score_state, + out, + ratio, + slot, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + return out.unsqueeze(1).to(kv_state.dtype) + + +def decode_reduce( + kv_state: torch.Tensor, + score_state: torch.Tensor, + ratio: int, + head_dim: int, + overlap: bool, +) -> torch.Tensor: + out = torch.empty((kv_state.size(0), head_dim), device=kv_state.device, dtype=torch.float32) + block_d = min(256, _next_power_of_2(head_dim)) + if overlap: + _decode_reduce_overlap_kernel[(kv_state.size(0),)]( + kv_state.contiguous(), + score_state.contiguous(), + out, + ratio, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + else: + _decode_reduce_nonoverlap_kernel[(kv_state.size(0),)]( + kv_state.contiguous(), + score_state.contiguous(), + out, + ratio, + head_dim, + BLOCK_D=block_d, + num_warps=_num_warps(block_d), + num_stages=1, + ) + return out.unsqueeze(1).to(kv_state.dtype) + + +class ModelNew(nn.Module): + def __init__( + self, + max_batch_size: int = 4, + max_seq_len: int = 256, + dim: int = 512, + head_dim: int = 128, + rope_head_dim: int = 64, + compress_ratio: int = 4, + norm_eps: float = 1e-6, + ): + super(ModelNew, self).__init__() + self.dim = dim + self.head_dim = head_dim + self.rope_head_dim = rope_head_dim + self.compress_ratio = compress_ratio + self.overlap = compress_ratio == 4 + coeff = 1 + int(self.overlap) + + self.ape = nn.Parameter(torch.empty(compress_ratio, coeff * head_dim, dtype=torch.float32)) + self.wkv = Linear(dim, coeff * head_dim, dtype=torch.float32) + self.wgate = Linear(dim, coeff * head_dim, dtype=torch.float32) + self.norm = RMSNorm(head_dim, norm_eps) + + self.register_buffer( + "kv_state", + torch.zeros(max_batch_size, coeff * compress_ratio, coeff * head_dim, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "score_state", + torch.full( + (max_batch_size, coeff * compress_ratio, coeff * head_dim), + float("-inf"), + dtype=torch.float32, + ), + persistent=False, + ) + self.register_buffer( + "kv_cache", + torch.zeros(max_batch_size, max_seq_len // compress_ratio, head_dim, dtype=torch.bfloat16), + persistent=False, + ) + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(rope_head_dim, max_seq_len), + persistent=False, + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.normal_(self.ape, mean=0.0, std=0.02) + self.wkv.reset_parameters() + self.wgate.reset_parameters() + nn.init.ones_(self.norm.weight) + + def reset_runtime_state(self) -> None: + self.kv_state.zero_() + self.score_state.fill_(float("-inf")) + self.kv_cache.zero_() + + def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor | None: + batch_size, seqlen, _ = x.shape + ratio = self.compress_ratio + head_dim = self.head_dim + rope_head_dim = self.rope_head_dim + overlap = self.overlap + dtype = x.dtype + + x = x.float() + kv = self.wkv(x) + score = self.wgate(x) + + if start_pos == 0: + should_compress = seqlen >= ratio + remainder = seqlen % ratio + cutoff = seqlen - remainder + offset = ratio if overlap else 0 + + if overlap and cutoff >= ratio: + self.kv_state[:batch_size, :ratio] = kv[:, cutoff - ratio : cutoff] + self.score_state[:batch_size, :ratio] = score[:, cutoff - ratio : cutoff] + self.ape + + if remainder > 0: + self.kv_state[:batch_size, offset : offset + remainder] = kv[:, cutoff:] + self.score_state[:batch_size, offset : offset + remainder] = score[:, cutoff:] + self.ape[:remainder] + + if not should_compress: + return None + + kv_windows = kv[:, :cutoff].unflatten(1, (-1, ratio)) + score_windows = score[:, :cutoff].unflatten(1, (-1, ratio)) + self.ape + if overlap: + kv = prefill_reduce_overlap(kv_windows, score_windows, head_dim) + else: + kv = prefill_reduce(kv_windows, score_windows) + else: + slot = start_pos % ratio + should_compress = (start_pos + 1) % ratio == 0 + if not should_compress: + decode_update_state( + kv_token=kv, + score_token=score, + ape_row=self.ape[slot], + kv_state=self.kv_state[:batch_size], + score_state=self.score_state[:batch_size], + slot=slot, + overlap=overlap, + ) + return None + kv = decode_update_reduce( + kv_token=kv, + score_token=score, + ape_row=self.ape[slot], + kv_state=self.kv_state[:batch_size], + score_state=self.score_state[:batch_size], + ratio=ratio, + head_dim=head_dim, + slot=slot, + overlap=overlap, + ) + + kv = self.norm(kv.to(dtype)) + if start_pos == 0: + freqs_cis = self.freqs_cis[:cutoff:ratio].to(kv.device) + self.kv_cache[:batch_size, : seqlen // ratio] = kv + else: + freqs_cis = self.freqs_cis[start_pos + 1 - ratio].unsqueeze(0).to(kv.device) + self.kv_cache[:batch_size, start_pos // ratio] = kv.squeeze(1) + apply_rotary_emb(kv[..., -rope_head_dim:], freqs_cis) + return kv + + +def generate_test_data(params: dict) -> tuple[torch.Tensor, int]: + batch_size = params["batch_size"] + seq_len = params["seq_len"] + dim = params["dim"] + start_pos = params["start_pos"] + x = torch.randn(batch_size, seq_len, dim, dtype=torch.bfloat16, device="cpu") + return x, start_pos + + +def test_kv_compress(): + return ModelNew(*get_init_inputs()).forward(*get_inputs()) + + +def get_inputs(): + params = {"batch_size": 1, "seq_len": 12, "dim": 448, "start_pos": 0} + return list(generate_test_data(params)) + + +def get_init_inputs(): + return [1, 256, 448, 32, 4, 4, 1e-6] diff --git a/dlblas/kernels/validate.py b/dlblas/kernels/validate.py index 5b539c19..fb0473e7 100644 --- a/dlblas/kernels/validate.py +++ b/dlblas/kernels/validate.py @@ -151,6 +151,25 @@ def _move_to_device(obj, device): return obj +def _outputs_match(expected, actual, tol): + if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor): + if expected.shape != actual.shape: + return False + return torch.allclose(expected, actual, atol=tol, rtol=tol) + + if isinstance(expected, (list, tuple)) and isinstance(actual, (list, tuple)): + if len(expected) != len(actual): + return False + return all(_outputs_match(x, y, tol) for x, y in zip(expected, actual)) + + if isinstance(expected, dict) and isinstance(actual, dict): + if expected.keys() != actual.keys(): + return False + return all(_outputs_match(expected[k], actual[k], tol) for k in expected) + + return expected == actual + + class KernelBenchDataset: """ 条目示例 @@ -303,19 +322,8 @@ def main(): inputs = _move_to_device(inputs, device) output = original_model(*inputs) output_new = custom_model(*inputs) - outputs = (output,) if not isinstance(output, tuple) else output - outputs_new = (output_new,) if not isinstance(output_new, tuple) else output_new - if len(outputs) != len(outputs_new): + if not _outputs_match(output, output_new, tol): correctness=False - # 遍历每个输出张量 - for i, (out, out_new) in enumerate(zip(outputs, outputs_new)): - # 检查形状是否一致 - if out.shape != out_new.shape: - correctness=False - - # 检查数值是否一致 - if not torch.allclose(out, out_new, atol=tol, rtol=tol): - correctness=False except Exception as e: print(f"{item['uid']} run with exception: {e}", flush=True) correctness=False @@ -336,4 +344,4 @@ def main(): print(f"保存 JSON 失败: {e}", flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main()