Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions backends/cuda/tests/test_tq4_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,242 @@ def test_output_shape_and_dtype(self):
self.assertEqual(out.shape, (1, H_q, Lq, D))
self.assertEqual(out.dtype, torch.bfloat16)

# ------------------------------------------------------------------
# 128k code path: kv_len clamp (decode) + mask_is_causal (prefill)
#
# Every test above calls tq4_sdpa WITHOUT kv_len and WITHOUT
# mask_is_causal, so they only exercise the kv_len=None fallback
# (full-Lk loop) at short KV. The cases below drive the actual
# long-context paths used in production by the Gemma-4 31B global
# layers (head_dim=512, GQA 8:4) and Qwen 3.5 MoE (head_dim=256,
# GQA 16:2):
# * the on-device kv_len scalar that bounds the KV loop to the
# filled context (decode), and
# * the mask_is_causal per-tile causal block-skip (prefill).
#
# "GARBAGE TAIL": in production the KV cache is a fixed buffer
# pre-allocated to max_seq_len (e.g. 131072). At any step only the
# first kv_len positions hold real K/V; the rest is stale /
# uninitialized memory that attention must ignore. We simulate that
# tail by writing large-magnitude (x1000) values into [kv_len:]. If
# the clamp / block-skip works the kernel never reads the tail and
# the output matches a reference built from [0, kv_len) only; if it
# is broken the huge tail values dominate the softmax and the cosine
# collapses to ~0. So the garbage tail is a built-in negative control
# (verified: dropping kv_len drives the cosine to ~-0.01 and fails).
#
# CAUSAL ALIGNMENT (top-left vs bottom-right): when L_q < L_kv (a
# chunked prefill / decode, where the Lq new queries sit at the END
# of a kv_len-long context) there are two ways to place the causal
# triangle. PyTorch F.sdpa(is_causal=True) uses TOP-LEFT alignment
# (query row i attends to keys [0, i]) -- wrong for a KV cache. This
# kernel and gemma4_31b/model.py::_build_masks use BOTTOM-RIGHT
# alignment: query row i is absolute position (kv_len - Lq + i) and
# attends to keys [0, kv_len - Lq + i]. So the reference below builds
# an explicit bottom-right mask (q_pos >= cache_pos) rather than
# passing is_causal=True, which would otherwise mismatch the kernel.
# ------------------------------------------------------------------

def _run_long_kv_test(
    self,
    *,
    H_q,
    H_kv,
    D,
    Lq,
    kv_len,
    buffer_len,
    causal=False,
    garbage=True,
    pass_kv_len=True,
    min_cosine=0.99,
    seed=42,
):
    """Run tq4_sdpa on a partially filled KV buffer and check it against
    a reference that only ever sees the filled prefix.

    K/V are allocated with ``buffer_len`` positions. Positions
    ``[0, kv_len)`` hold ordinary random data; when ``garbage`` is set and
    the buffer is longer than ``kv_len``, positions ``[kv_len:]`` are
    overwritten with random values scaled by 1000 so that any read of the
    tail visibly corrupts the result. The kernel is handed the whole
    compressed buffer, while the reference (``_reference_tq4_sdpa``) is
    handed only ``k[:, :, :kv_len]`` / ``v[:, :, :kv_len]``.

    Args:
        H_q, H_kv: number of query / key-value heads.
        D: head dimension.
        Lq: number of query positions.
        kv_len: number of leading buffer positions treated as real.
        buffer_len: total number of positions allocated for K/V.
        causal: when True, build a boolean mask in which query row ``i``
            sits at absolute position ``kv_len - Lq + i`` and may attend
            to every cache position ``<=`` its own (bottom-right
            alignment), and pass ``mask_is_causal=True`` to the kernel.
            ``is_causal`` is always passed as False; the explicit mask is
            used instead because, per the notes in this file,
            ``F.sdpa(is_causal=True)`` aligns top-left when L_q < L_kv.
        garbage: whether to poison the ``[kv_len:]`` tail.
        pass_kv_len: when True the kernel receives ``kv_len`` as a
            one-element int32 CUDA tensor; otherwise it receives ``None``.
        min_cosine: the cosine similarity must be strictly greater than
            this for the test to pass.
        seed: value passed to ``torch.manual_seed``.

    Returns:
        The value returned by ``_cosine_sim(out, ref)``.
    """
    torch.manual_seed(seed)
    centroids, boundaries, rotation = (
        t.cuda() for t in _make_codebook_and_rotation(D)
    )

    batch = 1

    def draw(*shape):
        # Single place for the dtype/device used by every random tensor.
        return torch.randn(*shape, dtype=torch.bfloat16, device="cuda")

    # NOTE: the order of the random draws below (k, v, tail-k, tail-v, q)
    # determines the generated data for a given seed; keep it stable.
    k = draw(batch, H_kv, buffer_len, D)
    v = draw(batch, H_kv, buffer_len, D)

    tail_len = buffer_len - kv_len
    if garbage and tail_len > 0:
        # Large-magnitude tail: if the kernel touches it, the output
        # diverges from the prefix-only reference.
        k[:, :, kv_len:, :] = draw(batch, H_kv, tail_len, D) * 1000.0
        v[:, :, kv_len:, :] = draw(batch, H_kv, tail_len, D) * 1000.0

    q = draw(batch, H_q, Lq, D)

    k_packed, k_norms = _compress(k, boundaries, rotation)
    v_packed, v_norms = _compress(v, boundaries, rotation)

    if causal:
        # Bottom-right aligned mask over the whole buffer: the Lq queries
        # occupy absolute positions [kv_len - Lq, kv_len).
        key_idx = torch.arange(buffer_len, device="cuda")
        query_idx = torch.arange(kv_len - Lq, kv_len, device="cuda")
        allowed = query_idx.unsqueeze(1) >= key_idx.unsqueeze(0)
        attn_mask = allowed.view(1, 1, Lq, buffer_len)
    else:
        attn_mask = None

    kv_len_t = None
    if pass_kv_len:
        kv_len_t = torch.tensor([kv_len], dtype=torch.int32, device="cuda")

    out = self.tq4_sdpa(
        q,
        k_packed,
        k_norms,
        v_packed,
        v_norms,
        centroids,
        rotation,
        attn_mask=attn_mask,
        is_causal=False,
        scale=None,
        kv_len=kv_len_t,
        mask_is_causal=causal,
    )

    # Reference path: _reference_tq4_sdpa over the first kv_len positions
    # only (mask cropped to match), so the poisoned tail cannot reach it.
    ref_mask = None if attn_mask is None else attn_mask[:, :, :, :kv_len]
    ref, *_ = _reference_tq4_sdpa(
        q,
        k[:, :, :kv_len],
        v[:, :, :kv_len],
        centroids,
        boundaries,
        rotation,
        attn_mask=ref_mask,
    )

    self.assertFalse(torch.isnan(out).any(), "NaN in output")

    cos = _cosine_sim(out, ref)
    failure_msg = (
        f"Cosine {cos:.5f} < {min_cosine} "
        f"(H_q={H_q} H_kv={H_kv} D={D} Lq={Lq} kv_len={kv_len} "
        f"buffer={buffer_len} causal={causal} kv_len_passed={pass_kv_len})"
    )
    self.assertGreater(cos, min_cosine, failure_msg)
    return cos

def test_kv_len_clamp_decode_gemma_global(self):
    """Single-query decode with kv_len passed, at the Gemma-4 31B
    global-layer shape (head_dim=512, 8 query heads over 4 KV heads).

    Two fill levels of a 32768-entry buffer are checked as sub-tests:
    N=8192 (24576 poisoned tail positions beyond kv_len) and N=32768
    (buffer completely filled, so there is no tail)."""
    shape = dict(H_q=8, H_kv=4, D=512, Lq=1, buffer_len=32768)
    for filled in (8192, 32768):
        with self.subTest(N=filled):
            self._run_long_kv_test(kv_len=filled, **shape)

def test_kv_len_clamp_decode_qwen(self):
    """Single-query decode with kv_len passed, at the Qwen 3.5 MoE shape
    (head_dim=256, 16 query heads over 2 KV heads).

    Runs one sub-test per fill level (8192 and 32768) of a 32768-entry
    buffer."""
    shape = dict(H_q=16, H_kv=2, D=256, Lq=1, buffer_len=32768)
    for filled in (8192, 32768):
        with self.subTest(N=filled):
            self._run_long_kv_test(kv_len=filled, **shape)

def test_mask_is_causal_prefill_gemma_global(self):
    """Multi-query (chunked prefill) run with ``causal=True`` at the Gemma
    global-layer shape (head_dim=512, 8 query heads over 4 KV heads).

    Each case is (Lq, kv_len, buffer_len): the Lq queries are placed at
    the end of a kv_len-long context, and the buffer is twice kv_len so a
    poisoned tail is present as well. The helper compares the kernel
    output with the causal reference over the first kv_len positions."""
    cases = (
        (256, 4096, 8192),
        (2048, 8192, 16384),
    )
    for n_query, n_filled, n_buffer in cases:
        with self.subTest(Lq=n_query, kv_len=n_filled):
            self._run_long_kv_test(
                H_q=8,
                H_kv=4,
                D=512,
                Lq=n_query,
                kv_len=n_filled,
                buffer_len=n_buffer,
                causal=True,
            )

def test_mask_is_causal_prefill_qwen(self):
    """Multi-query (chunked prefill) run with ``causal=True`` at the Qwen
    shape (head_dim=256, 16 query heads over 2 KV heads).

    Each case is (Lq, kv_len, buffer_len); one sub-test per case."""
    cases = (
        (256, 4096, 8192),
        (2048, 8192, 16384),
    )
    for n_query, n_filled, n_buffer in cases:
        with self.subTest(Lq=n_query, kv_len=n_filled):
            self._run_long_kv_test(
                H_q=16,
                H_kv=2,
                D=256,
                Lq=n_query,
                kv_len=n_filled,
                buffer_len=n_buffer,
                causal=True,
            )

def test_kv_len_none_fallback_qwen(self):
    """Regression check for the ``kv_len=None`` fallback at the Qwen shape.

    The buffer is exactly kv_len long (256) with no poisoned tail, and the
    helper is told not to pass kv_len, so the kernel is invoked with
    ``kv_len=None`` and ``mask_is_causal=False``. The result must still
    match the fp32 reference, guarding the pre-existing behaviour that the
    kv_len feature has to preserve."""
    fallback_case = dict(
        H_q=16,
        H_kv=2,
        D=256,
        Lq=1,
        kv_len=256,
        buffer_len=256,
        garbage=False,
        pass_kv_len=False,
    )
    self._run_long_kv_test(**fallback_case)

@unittest.skipUnless(
    os.environ.get("TQ4_RUN_128K") == "1",
    "128k case is heavy for the 24GB CI runner; set TQ4_RUN_128K=1 to run",
)
def test_kv_len_clamp_128k(self):
    """Decode (Lq=1) against a full 131072-entry buffer at the Qwen shape
    (head_dim=256, 16 query heads over 2 KV heads).

    Two scenarios, each reported as its own sub-test so that a failure in
    one does not prevent the other from running:

    * kv_len=8192 with the remaining ~123k positions poisoned -- the
      kv_len bound must keep the kernel away from the tail.
    * kv_len=131072 with no tail -- correctness at true 128k scale.

    Only runs when the environment variable ``TQ4_RUN_128K`` is ``"1"``;
    the second scenario's fp32 reference is noted as needing more than
    ~6GB, which is heavy for the 24GB A10G CI runner."""
    buffer_len = 131072
    # (kv_len, garbage): the fully filled case has no tail to poison.
    scenarios = ((8192, True), (buffer_len, False))
    for kv_len, garbage in scenarios:
        with self.subTest(kv_len=kv_len):
            self._run_long_kv_test(
                H_q=16,
                H_kv=2,
                D=256,
                Lq=1,
                kv_len=kv_len,
                buffer_len=buffer_len,
                garbage=garbage,
            )

# ------------------------------------------------------------------
# Validation errors
# ------------------------------------------------------------------
Expand Down
Loading
Loading