Skip to content

Commit f6bdb2c

Browse files
committed
remove reduntatn comments
1 parent 6e773c0 commit f6bdb2c

1 file changed

Lines changed: 0 additions & 33 deletions

File tree

backends/cuda/tests/test_tq4_sdpa.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121
import torch
2222
import torch.nn.functional as F
23-
2423
from executorch.backends.cuda.cuda_backend import CudaBackend
2524
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
2625
from executorch.backends.cuda.triton.kernels.tq4_sdpa import tq4_sdpa
@@ -440,38 +439,6 @@ def test_output_shape_and_dtype(self):
440439

441440
# ------------------------------------------------------------------
442441
# 128k code path: kv_len clamp (decode) + mask_is_causal (prefill)
443-
#
444-
# Every test above calls tq4_sdpa WITHOUT kv_len and WITHOUT
445-
# mask_is_causal, so they only exercise the kv_len=None fallback
446-
# (full-Lk loop) at short KV. The cases below drive the actual
447-
# long-context paths used in production by the Gemma-4 31B global
448-
# layers (head_dim=512, GQA 8:4) and Qwen 3.5 MoE (head_dim=256,
449-
# GQA 16:2):
450-
# * the on-device kv_len scalar that bounds the KV loop to the
451-
# filled context (decode), and
452-
# * the mask_is_causal per-tile causal block-skip (prefill).
453-
#
454-
# "GARBAGE TAIL": in production the KV cache is a fixed buffer
455-
# pre-allocated to max_seq_len (e.g. 131072). At any step only the
456-
# first kv_len positions hold real K/V; the rest is stale /
457-
# uninitialized memory that attention must ignore. We simulate that
458-
# tail by writing large-magnitude (x1000) values into [kv_len:]. If
459-
# the clamp / block-skip works the kernel never reads the tail and
460-
# the output matches a reference built from [0, kv_len) only; if it
461-
# is broken the huge tail values dominate the softmax and the cosine
462-
# collapses to ~0. So the garbage tail is a built-in negative control
463-
# (verified: dropping kv_len drives the cosine to ~-0.01 and fails).
464-
#
465-
# CAUSAL ALIGNMENT (top-left vs bottom-right): when L_q < L_kv (a
466-
# chunked prefill / decode, where the Lq new queries sit at the END
467-
# of a kv_len-long context) there are two ways to place the causal
468-
# triangle. PyTorch F.sdpa(is_causal=True) uses TOP-LEFT alignment
469-
# (query row i attends to keys [0, i]) -- wrong for a KV cache. This
470-
# kernel and gemma4_31b/model.py::_build_masks use BOTTOM-RIGHT
471-
# alignment: query row i is absolute position (kv_len - Lq + i) and
472-
# attends to keys [0, kv_len - Lq + i]. So the reference below builds
473-
# an explicit bottom-right mask (q_pos >= cache_pos) rather than
474-
# passing is_causal=True, which would otherwise mismatch the kernel.
475442
# ------------------------------------------------------------------
476443

477444
def _run_long_kv_test(

0 commit comments

Comments
 (0)