|
20 | 20 | import numpy as np |
21 | 21 | import torch |
22 | 22 | import torch.nn.functional as F |
23 | | - |
24 | 23 | from executorch.backends.cuda.cuda_backend import CudaBackend |
25 | 24 | from executorch.backends.cuda.cuda_partitioner import CudaPartitioner |
26 | 25 | from executorch.backends.cuda.triton.kernels.tq4_sdpa import tq4_sdpa |
@@ -440,38 +439,6 @@ def test_output_shape_and_dtype(self): |
440 | 439 |
|
441 | 440 | # ------------------------------------------------------------------ |
442 | 441 | # 128k code path: kv_len clamp (decode) + mask_is_causal (prefill) |
443 | | - # |
444 | | - # Every test above calls tq4_sdpa WITHOUT kv_len and WITHOUT |
445 | | - # mask_is_causal, so they only exercise the kv_len=None fallback |
446 | | - # (full-Lk loop) at short KV. The cases below drive the actual |
447 | | - # long-context paths used in production by the Gemma-4 31B global |
448 | | - # layers (head_dim=512, GQA 8:4) and Qwen 3.5 MoE (head_dim=256, |
449 | | - # GQA 16:2): |
450 | | - # * the on-device kv_len scalar that bounds the KV loop to the |
451 | | - # filled context (decode), and |
452 | | - # * the mask_is_causal per-tile causal block-skip (prefill). |
453 | | - # |
454 | | - # "GARBAGE TAIL": in production the KV cache is a fixed buffer |
455 | | - # pre-allocated to max_seq_len (e.g. 131072). At any step only the |
456 | | - # first kv_len positions hold real K/V; the rest is stale / |
457 | | - # uninitialized memory that attention must ignore. We simulate that |
458 | | - # tail by writing large-magnitude (x1000) values into [kv_len:]. If |
459 | | - # the clamp / block-skip works the kernel never reads the tail and |
460 | | - # the output matches a reference built from [0, kv_len) only; if it |
461 | | - # is broken the huge tail values dominate the softmax and the cosine |
462 | | - # collapses to ~0. So the garbage tail is a built-in negative control |
463 | | - # (verified: dropping kv_len drives the cosine to ~-0.01 and fails). |
464 | | - # |
465 | | - # CAUSAL ALIGNMENT (top-left vs bottom-right): when Lq < kv_len (a |
466 | | - # chunked prefill / decode, where the Lq new queries sit at the END |
467 | | - # of a kv_len-long context) there are two ways to place the causal |
468 | | - # triangle. PyTorch F.scaled_dot_product_attention(is_causal=True) uses TOP-LEFT alignment |
469 | | - # (query row i attends to keys [0, i]) -- wrong for a KV cache. This |
470 | | - # kernel and gemma4_31b/model.py::_build_masks use BOTTOM-RIGHT |
471 | | - # alignment: query row i is absolute position (kv_len - Lq + i) and |
472 | | - # attends to keys [0, kv_len - Lq + i]. So the reference below builds |
473 | | - # an explicit bottom-right mask (q_pos >= cache_pos) rather than |
474 | | - # passing is_causal=True, which would otherwise mismatch the kernel. |
475 | 442 | # ------------------------------------------------------------------ |
476 | 443 |
|
477 | 444 | def _run_long_kv_test( |
|
0 commit comments