Skip to content

[BUG] padding null token should not impact the final result #2

@JacoCheung

Description

@JacoCheung

The hstu_attn kernel takes jagged tensors and the actual length is indicated by cu_seqlens_q and cu_seqlens_k.
As per my understanding , it's safe to append some tokens to candidates/targets and feed the padded sequences to attn kernel and the padded token should not affect the previous tokens becuase it's causal.
However, I found out the results differed if I padded so.

Reproduce script:

note that the preprocess_input is a must (and usually the real use case), otherwise no diff occurs.

import torch

from hstu_attn import hstu_attn_varlen_func as ampere_attn_func
from hopper.hstu_attn_interface import hstu_attn_varlen_func as hopper_attn_func


def preprocess_input(input_tensor):
  D = input_tensor.shape[-1]
  normed_x = torch.nn.functional.layer_norm(input_tensor, (D,))
  linear_module = torch.nn.Linear(D, D * 3, bias=False).cuda().bfloat16()
  mixed_qkv = linear_module(normed_x)
  tq, tk, tv = torch.split(mixed_qkv, [D, D, D], dim=-1)
  return tq.contiguous(), tk.contiguous(), tv.contiguous()

def test_trunc_or_pad_attn_kernel():
  device = torch.device("cuda")
  max_seqlen = torch.randint(10, 100, (1,)).item()
  B = 32
  max_num_targets = 10
  max_num_contextuals = 4
  lengths = torch.randint(
      low=2,
      high=max_seqlen + 1,
      size=(B,),
      device=device,
      dtype=torch.int,
  )
  num_targets = torch.randint(
      low=0,
      high=max_num_targets + 1,
      size=(B,),
      device=device,
      dtype=torch.int32,
  )
  num_targets = torch.clamp(
      num_targets, max=lengths - 1, min=torch.zeros_like(num_targets)
  )  # at least 1 history

  num_contextuals = torch.randint(
      low=0,
      high=max_num_contextuals + 1,
      size=(B,),
      device=device,
      dtype=torch.int32,
  )
  num_contextuals = torch.clamp(
      num_contextuals,
      max=lengths - 1 - num_targets if num_targets is not None else lengths - 1,
      min=torch.zeros_like(num_contextuals),
  )  # at least 1 history!!
  seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
  L = int(seq_offsets[-1].item())
  H = 4
  D = 32
  x = torch.empty(
      (L, H * D),
      dtype=torch.bfloat16,
      device=device,
  ).uniform_(-0.1, 0.1)
  q, k, v = preprocess_input(x)
  q = q.view(-1, H, D)
  k = k.view(-1, H, D)
  v = v.view(-1, H, D)
  
  cu_seqlens_q = seq_offsets.clone()
  cu_seqlens_k = seq_offsets.clone()
  
  max_seqlen_q = max_seqlen
  max_seqlen_k = max_seqlen
  
  ampere_out = ampere_attn_func(
    q, 
    k, 
    v, 
    cu_seqlens_q, 
    cu_seqlens_k, 
    max_seqlen_q, 
    max_seqlen_k, 
    num_contextuals, 
    num_targets,
    rab=None,
    alpha=1.0 / (D**0.5),
    target_group_size=1,
    window_size=(-1, 0),
  )
  padded_q = torch.nn.functional.pad(q.detach().clone(), (0, 0, 0, 0, 0, 1), value=0.0)
  padded_k = torch.nn.functional.pad(k.detach().clone(), (0, 0, 0, 0, 0, 1), value=0.0)
  padded_v = torch.nn.functional.pad(v.detach().clone(), (0, 0, 0, 0, 0, 1), value=0.0)
  
  
  padded_cu_seqlens_q = cu_seqlens_q.clone()
  padded_cu_seqlens_q[-1] += 1
  padded_cu_seqlens_k = cu_seqlens_k.clone()
  padded_cu_seqlens_k[-1] += 1
  
  padded_ampere_out = ampere_attn_func(
    padded_q, 
    padded_k, 
    padded_v, 
    padded_cu_seqlens_q, 
    padded_cu_seqlens_k, 
    max_seqlen_q, 
    max_seqlen_k, 
    num_contextuals, 
    num_targets,
    rab=None,
    alpha=1.0 / (D**0.5),
    target_group_size=1,
    window_size=(-1, 0),
  )
  
  # remove pad
  padded_ampere_out = padded_ampere_out[: -1, :, :]
  import pdb; pdb.set_trace()
  torch.testing.assert_close(ampere_out, padded_ampere_out)
  

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions