import torch
from hstu_attn import hstu_attn_varlen_func as ampere_attn_func
from hopper.hstu_attn_interface import hstu_attn_varlen_func as hopper_attn_func
def preprocess_input(input_tensor):
    """Layer-normalize ``input_tensor`` and project it into query/key/value tensors.

    The last dimension of ``input_tensor`` (size ``D``) is layer-normalized and
    fed through a bias-free linear layer of width ``3 * D``; the result is split
    into three ``D``-wide pieces along the last dimension.

    Note that a new, randomly initialized ``torch.nn.Linear`` is created on
    every call, so two calls return different projections unless the RNG is
    seeded identically beforehand.

    Args:
        input_tensor: floating-point tensor whose last dimension is the
            feature dimension.

    Returns:
        A tuple ``(q, k, v)`` of contiguous tensors, each with the same shape
        as ``input_tensor``.
    """
    D = input_tensor.shape[-1]
    normed_x = torch.nn.functional.layer_norm(input_tensor, (D,))
    # Follow the input's device and dtype instead of hard-coding CUDA/bfloat16.
    # The layer is still initialized on the CPU first and then moved, so for a
    # CUDA bfloat16 input the weights are drawn exactly as before.
    linear_module = torch.nn.Linear(D, D * 3, bias=False).to(
        device=input_tensor.device, dtype=input_tensor.dtype
    )
    mixed_qkv = linear_module(normed_x)
    tq, tk, tv = torch.split(mixed_qkv, [D, D, D], dim=-1)
    # The split pieces are strided views into mixed_qkv; materialize each one.
    return tq.contiguous(), tk.contiguous(), tv.contiguous()
def test_trunc_or_pad_attn_kernel():
    """Compare the attention output of a jagged batch with and without one padded token.

    A random jagged batch is built and run through ``ampere_attn_func``. Then a
    single all-zero token is appended to q/k/v, the final entry of the
    cumulative sequence offsets is bumped by one (so the extra token belongs to
    the last sequence), and the kernel is run again. After dropping the output
    row of the extra token, both results are required to match.

    Requires a CUDA device, the fbgemm ``asynchronous_complete_cumsum`` op and
    the ``hstu_attn`` kernel.
    """
    device = torch.device("cuda")
    max_seqlen = torch.randint(10, 100, (1,)).item()
    B = 32
    max_num_targets = 10
    max_num_contextuals = 4

    # Per-sequence lengths in [2, max_seqlen].
    lengths = torch.randint(
        low=2,
        high=max_seqlen + 1,
        size=(B,),
        device=device,
        dtype=torch.int,
    )
    num_targets = torch.randint(
        low=0,
        high=max_num_targets + 1,
        size=(B,),
        device=device,
        dtype=torch.int32,
    )
    num_targets = torch.clamp(
        num_targets, max=lengths - 1, min=torch.zeros_like(num_targets)
    )  # at least 1 history
    num_contextuals = torch.randint(
        low=0,
        high=max_num_contextuals + 1,
        size=(B,),
        device=device,
        dtype=torch.int32,
    )
    # num_targets is always a tensor at this point, so the bound is unconditional.
    num_contextuals = torch.clamp(
        num_contextuals,
        max=lengths - 1 - num_targets,
        min=torch.zeros_like(num_contextuals),
    )  # at least 1 history

    seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    # The last offset is used as the total number of tokens in the jagged batch.
    L = int(seq_offsets[-1].item())
    H = 4
    D = 32
    x = torch.empty(
        (L, H * D),
        dtype=torch.bfloat16,
        device=device,
    ).uniform_(-0.1, 0.1)
    q, k, v = preprocess_input(x)
    q = q.view(-1, H, D)
    k = k.view(-1, H, D)
    v = v.view(-1, H, D)

    cu_seqlens_q = seq_offsets.clone()
    cu_seqlens_k = seq_offsets.clone()
    max_seqlen_q = max_seqlen
    max_seqlen_k = max_seqlen
    ampere_out = ampere_attn_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        num_contextuals,
        num_targets,
        rab=None,
        alpha=1.0 / (D**0.5),
        target_group_size=1,
        window_size=(-1, 0),
    )

    # Append one all-zero token along the token dimension (dim 0) of q/k/v.
    pad_spec = (0, 0, 0, 0, 0, 1)
    padded_q = torch.nn.functional.pad(q.detach().clone(), pad_spec, value=0.0)
    padded_k = torch.nn.functional.pad(k.detach().clone(), pad_spec, value=0.0)
    padded_v = torch.nn.functional.pad(v.detach().clone(), pad_spec, value=0.0)
    # Only the final offset grows, i.e. the extra token extends the last sequence.
    padded_cu_seqlens_q = cu_seqlens_q.clone()
    padded_cu_seqlens_q[-1] += 1
    padded_cu_seqlens_k = cu_seqlens_k.clone()
    padded_cu_seqlens_k[-1] += 1
    # The last sequence may already be max_seqlen long, so after padding it can
    # reach max_seqlen + 1; pass that as the bound for the padded call.
    # NOTE(review): assumes the kernel only needs max_seqlen_* to be an upper
    # bound on the per-sequence lengths -- confirm against the kernel docs.
    padded_max_seqlen_q = max_seqlen_q + 1
    padded_max_seqlen_k = max_seqlen_k + 1
    # NOTE(review): num_targets is passed unchanged although the last sequence
    # is one token longer. If the kernel treats the trailing num_targets tokens
    # of each sequence as targets (TODO confirm), the target span of the last
    # sequence shifts by one in this call, which could by itself explain an
    # output mismatch.
    padded_ampere_out = ampere_attn_func(
        padded_q,
        padded_k,
        padded_v,
        padded_cu_seqlens_q,
        padded_cu_seqlens_k,
        padded_max_seqlen_q,
        padded_max_seqlen_k,
        num_contextuals,
        num_targets,
        rab=None,
        alpha=1.0 / (D**0.5),
        target_group_size=1,
        window_size=(-1, 0),
    )
    # Drop the output row that belongs to the appended token.
    padded_ampere_out = padded_ampere_out[:-1, :, :]
    torch.testing.assert_close(ampere_out, padded_ampere_out)
The hstu_attn kernel takes jagged tensors, and the actual sequence lengths are indicated by
`cu_seqlens_q` and `cu_seqlens_k`. As I understand it, it should be safe to append some tokens to the candidates/targets and feed the padded sequences to the attention kernel: the padded tokens should not affect the preceding tokens because the attention is causal.
However, I found that the results differ when I pad the input this way.
Reproduction script:
Note that the call to `preprocess_input` is required (and reflects the usual real use case); without it, no difference occurs.