From ff24498c4a9a00c89fd9e641eddd042552fa77de Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 2 Dec 2024 10:06:59 -0500 Subject: [PATCH 01/17] support vllm splitkv layout --- csrc/composable_kernel | 2 +- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index e7b6286441a..c6cb8c52c16 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit e7b6286441aae59d3a87db67f42369d3cc2636a4 +Subproject commit c6cb8c52c168fcc63ca5fc63fbe9650f81052a26 diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index b6a274f4fa5..b305c7b8436 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m head_size, dtype, true, // is_group_mode - true, // is_v_rowmajor + false, // is_v_rowmajor mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -183,8 +183,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, at::Tensor out_acc) { // q: (total_q, nheads, d) - // k: (num_blocks, page_block_size, num_heads_k, d) - // v: (num_blocks, page_block_size, num_heads_k, d) + // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8) + // v: (num_blocks, num_heads_k, d, page_block_size) // o: (total_q, nheads, d) // alibi_slopes:(batch_size, nheads) or (nhead) @@ -241,12 +241,12 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, args.nhead_stride_q = q.stride(1); args.batch_stride_k = k.stride(0); - args.stride_k = k.stride(1); - args.nhead_stride_k = k.stride(2); + args.nhead_stride_k = k.stride(1); + args.stride_k = k.stride(2); args.batch_stride_v = v.stride(0); - args.stride_v = v.stride(1); - args.nhead_stride_v = v.stride(2); + args.nhead_stride_v = v.stride(1); + args.stride_v = v.stride(2); args.batch_stride_o = 0; args.stride_o = out.stride(0); @@ -292,8 +292,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size / 8 x page_block_size x 8 if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x page_block_size x head_size if there's a block_table. c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -348,11 +348,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size = sizes[2]; - const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + const int num_heads_k = k.size(1); const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); - const int page_block_size = !paged_KV ? 1 : k.size(1); + const int page_block_size = !paged_KV ? 1 : k.size(3); TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case @@ -394,8 +394,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); } else { - CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8); + CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } From efc01f5a2898c09c41a6829d52cffcfa628e127e Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 2 Dec 2024 12:42:25 -0500 Subject: [PATCH 02/17] Add test script --- tests/test_flash_attn_ck_page_varlen.py | 195 ++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 tests/test_flash_attn_ck_page_varlen.py diff --git a/tests/test_flash_attn_ck_page_varlen.py b/tests/test_flash_attn_ck_page_varlen.py new file mode 100644 index 00000000000..7f651301206 --- /dev/null +++ b/tests/test_flash_attn_ck_page_varlen.py @@ -0,0 +1,195 @@ +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, +) + +from test_flash_attn import ( + attn_bias_from_alibi_slopes, + convert_flash_attn_S_to_softmax, + generate_qkv, + generate_random_padding_mask, + _generate_block_kvcache, + attention_ref, + attention_kvpacked_ref, + attention_qkvpacked_ref, +) + +from flash_attn.layers.rotary import apply_rotary_emb + +def is_bwd_hdim_supported(d): + return d <= 256 + + +def ck_randval_to_dropout_mask(randval, p): + # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout + # randval in 255 * [0, 0.7] will be kept + # If return dropout_mask >=0, value will be kept + return math.floor(255.0 * (1 - p)) - randval.to(torch.float32) + + +def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded): + """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded] + Arguments: + S_dmask: (nheads, total_q, max_seqlen_k) + cu_seqlens_q: (b + 1) + Output: + S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded) + """ + batch_size = cu_seqlens_q.numel() - 1 + seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q + seqlens_q = seqlens_q[0:batch_size].tolist() + S_dmask = torch.split(S_dmask, seqlens_q, dim=1) + # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)] + masks = () + for mask in S_dmask: + # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded) + mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1) + masks = masks + (mask, ) + S_dmask = torch.cat(masks, dim=1) + + S_dmask = S_dmask.transpose(0, 1) + return S_dmask + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("d", [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [128, 256, 512]) +def test_flash_attn_varlen_causal( + seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype +): + if max(seqlen_q, seqlen_k) >= 2048: + pytest.skip() + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 8 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + + if paged_kv_block_size is None: + k = torch.randn( + batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + block_table = None + else: + k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( + seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype + ) + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + if paged_kv_block_size is not None: + nblock = k_cache_paged.shape[0] + print(k_cache_paged.shape) + print(v_cache_paged.shape) + + k_cache_paged = rearrange(k_cache_paged, + 'nblock block_size nheads (d1 d2) -> nblock nheads d1 block_size d2', + block_size=paged_kv_block_size, d1=d // 8, d2=8) + + v_cache_paged = rearrange(v_cache_paged, + 'nblock block_size nheads d -> nblock nheads d block_size', + block_size=paged_kv_block_size) + + print(k_cache_paged.shape) + print(v_cache_paged.shape) + + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad if paged_kv_block_size is None else k_cache_paged, + v_unpad if paged_kv_block_size is None else v_cache_paged, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + window_size=window_size, + block_table=block_table, + ) + out = output_pad_fn(out_unpad) + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + None, + 0.0, + None, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + None, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") From 7df044766760fc1f74742d9ebf84816896d0a8fa Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 3 Dec 2024 02:30:23 -0500 Subject: [PATCH 03/17] Add missing assert --- tests/test_flash_attn_ck_page_varlen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_flash_attn_ck_page_varlen.py b/tests/test_flash_attn_ck_page_varlen.py index 7f651301206..cc8f7fcab8a 100644 --- a/tests/test_flash_attn_ck_page_varlen.py +++ b/tests/test_flash_attn_ck_page_varlen.py @@ -193,3 +193,7 @@ def test_flash_attn_varlen_causal( print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 From c63ac38d64ee78b8cc59524d6eb2149ffb41cde5 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 3 Dec 2024 16:23:41 -0500 Subject: [PATCH 04/17] Fix bug for k and v --- csrc/composable_kernel | 2 +- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 2 ++ tests/test_flash_attn_ck_page_varlen.py | 8 ++------ 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index c6cb8c52c16..ead9c3cbf38 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit c6cb8c52c168fcc63ca5fc63fbe9650f81052a26 +Subproject commit ead9c3cbf38ab8065c6024f8b8d637862571595f diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index b305c7b8436..4d7053ace6b 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -340,6 +340,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(k); + CHECK_CONTIGUOUS(v); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); diff --git a/tests/test_flash_attn_ck_page_varlen.py b/tests/test_flash_attn_ck_page_varlen.py index cc8f7fcab8a..451493a03a5 100644 --- a/tests/test_flash_attn_ck_page_varlen.py +++ b/tests/test_flash_attn_ck_page_varlen.py @@ -134,19 +134,15 @@ def test_flash_attn_varlen_causal( ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) if paged_kv_block_size is not None: nblock = k_cache_paged.shape[0] - print(k_cache_paged.shape) - print(v_cache_paged.shape) k_cache_paged = rearrange(k_cache_paged, 'nblock block_size nheads (d1 d2) -> nblock nheads d1 block_size d2', - block_size=paged_kv_block_size, d1=d // 8, d2=8) + block_size=paged_kv_block_size, d1=d // 8, d2=8).contiguous() v_cache_paged = rearrange(v_cache_paged, 'nblock block_size nheads d -> nblock nheads d block_size', - block_size=paged_kv_block_size) + block_size=paged_kv_block_size).contiguous() - print(k_cache_paged.shape) - print(v_cache_paged.shape) out_unpad = flash_attn_varlen_func( q_unpad, From d3d271b9bd8e3e7309a7b324c41f58cff870943f Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 13 Dec 2024 02:47:33 -0500 Subject: [PATCH 05/17] Merge new layout to test --- tests/test_flash_attn_ck.py | 15 ++ tests/test_flash_attn_ck_page_varlen.py | 195 ------------------------ 2 files changed, 15 insertions(+), 195 deletions(-) delete mode 100644 tests/test_flash_attn_ck_page_varlen.py diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c3..b9091f0dc19 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -922,6 +922,21 @@ def test_flash_attn_varlen_causal( dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + if paged_kv_block_size is not None: + # vllm layout + if d % 8 != 0: + pytest.skip() + nblock = k_cache_paged.shape[0] + + k_cache_paged = rearrange(k_cache_paged, + 'nblock block_size nheads (d1 d2) -> nblock nheads d1 block_size d2', + block_size=paged_kv_block_size, d1=d // 8, d2=8).contiguous() + + v_cache_paged = rearrange(v_cache_paged, + 'nblock block_size nheads d -> nblock nheads d block_size', + block_size=paged_kv_block_size).contiguous() + out_unpad = flash_attn_varlen_func( q_unpad, k_unpad if paged_kv_block_size is None else k_cache_paged, diff --git a/tests/test_flash_attn_ck_page_varlen.py b/tests/test_flash_attn_ck_page_varlen.py deleted file mode 100644 index 451493a03a5..00000000000 --- a/tests/test_flash_attn_ck_page_varlen.py +++ /dev/null @@ -1,195 +0,0 @@ -import math - -import pytest -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from flash_attn import ( - flash_attn_func, - flash_attn_kvpacked_func, - flash_attn_qkvpacked_func, - flash_attn_varlen_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, - flash_attn_with_kvcache, -) - -from test_flash_attn import ( - attn_bias_from_alibi_slopes, - convert_flash_attn_S_to_softmax, - generate_qkv, - generate_random_padding_mask, - _generate_block_kvcache, - attention_ref, - attention_kvpacked_ref, - attention_qkvpacked_ref, -) - -from flash_attn.layers.rotary import apply_rotary_emb - -def is_bwd_hdim_supported(d): - return d <= 256 - - -def ck_randval_to_dropout_mask(randval, p): - # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout - # randval in 255 * [0, 0.7] will be kept - # If return dropout_mask >=0, value will be kept - return math.floor(255.0 * (1 - p)) - randval.to(torch.float32) - - -def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded): - """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded] - Arguments: - S_dmask: (nheads, total_q, max_seqlen_k) - cu_seqlens_q: (b + 1) - Output: - S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded) - """ - batch_size = cu_seqlens_q.numel() - 1 - seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q - seqlens_q = seqlens_q[0:batch_size].tolist() - S_dmask = torch.split(S_dmask, seqlens_q, dim=1) - # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)] - masks = () - for mask in S_dmask: - # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded) - mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1) - masks = masks + (mask, ) - S_dmask = torch.cat(masks, dim=1) - - S_dmask = S_dmask.transpose(0, 1) - return S_dmask - - -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("d", [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize("d", [32]) -@pytest.mark.parametrize("swap_sq_sk", [False, True]) -# @pytest.mark.parametrize("swap_sq_sk", [False]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - # (1, 239), - (3, 799), - (127, 512), - (127, 513), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (1023, 1024), - ], -) -# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) -@pytest.mark.parametrize("paged_kv_block_size", [128, 256, 512]) -def test_flash_attn_varlen_causal( - seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype -): - if max(seqlen_q, seqlen_k) >= 2048: - pytest.skip() - if swap_sq_sk: - seqlen_q, seqlen_k = seqlen_k, seqlen_q - device = "cuda" - causal = True - # set seed - torch.random.manual_seed(0) - batch_size = 8 - nheads = 9 - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) - - if paged_kv_block_size is None: - k = torch.randn( - batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True - ) - v = torch.randn( - batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True - ) - block_table = None - else: - k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( - seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype - ) - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") - ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) - if paged_kv_block_size is not None: - nblock = k_cache_paged.shape[0] - - k_cache_paged = rearrange(k_cache_paged, - 'nblock block_size nheads (d1 d2) -> nblock nheads d1 block_size d2', - block_size=paged_kv_block_size, d1=d // 8, d2=8).contiguous() - - v_cache_paged = rearrange(v_cache_paged, - 'nblock block_size nheads d -> nblock nheads d block_size', - block_size=paged_kv_block_size).contiguous() - - - out_unpad = flash_attn_varlen_func( - q_unpad, - k_unpad if paged_kv_block_size is None else k_cache_paged, - v_unpad if paged_kv_block_size is None else v_cache_paged, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - 0.0, - causal=causal, - window_size=window_size, - block_table=block_table, - ) - out = output_pad_fn(out_unpad) - out_ref, attn_ref = attention_ref( - q, - k, - v, - query_padding_mask, - key_padding_mask, - None, - 0.0, - None, - causal=causal, - window_size=window_size, - ) - out_pt, attn_pt = attention_ref( - q, - k, - v, - query_padding_mask, - key_padding_mask, - None, - 0.0, - None, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - ) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 From 84c153fa64476b97e50eb22f502c889431c75c37 Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 13 Dec 2024 03:08:57 -0500 Subject: [PATCH 06/17] Fix bug of qkv pack --- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 4d7053ace6b..4da8b63c33f 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -335,13 +335,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(k); + CHECK_CONTIGUOUS(v); } TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(k); - CHECK_CONTIGUOUS(v); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); From 3c655c51615b94883600b1ab7e4f49ccb4c2cfb8 Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 13 Dec 2024 08:40:00 -0500 Subject: [PATCH 07/17] Disable kvcache api for now, waiting CK to support correct layout for append kv --- csrc/flash_attn_ck/mha_fwd_kvcache.cpp | 2 + tests/test_flash_attn_ck.py | 530 ++++++++++++------------- 2 files changed, 267 insertions(+), 265 deletions(-) diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp index 2f8b6436307..b769d626adf 100644 --- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp +++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp @@ -287,6 +287,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits) { + TORCH_CHECK(false, "vllm layout does not support mha_fwd_kvcache for now"); + auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index b9091f0dc19..e2bb1187cf5 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1035,271 +1035,271 @@ def test_flash_attn_varlen_causal( # TODO - Support has_leftpad -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("num_splits", [1, 0]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("new_kv", [False, True]) -@pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) -@pytest.mark.parametrize("rotary_interleaved", [False, True]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) -@pytest.mark.parametrize("has_leftpad", [False]) -@pytest.mark.parametrize("has_batch_idx", [False, True]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), - (128, 128), - ], -) -def test_flash_attn_kvcache( - seqlen_q, - seqlen_k, - d, - has_batch_idx, - has_leftpad, - paged_kv_block_size, - rotary_fraction, - rotary_interleaved, - seqlen_new_eq_seqlen_q, - causal, - local, - alibi, - new_kv, - mha_type, - num_splits, - dtype, -): - if seqlen_q > seqlen_k and new_kv: - pytest.skip() - if not new_kv and rotary_fraction > 0.0: - pytest.skip() - if has_batch_idx and paged_kv_block_size is not None: - pytest.skip() - if has_leftpad and paged_kv_block_size is not None: - pytest.skip() - device = "cuda" - # set seed - torch.random.manual_seed(0) - batch_size = 1 - batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 6 - # rotary_dim must be a multiple of 16, and must be <= d - rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) - assert nheads % nheads_k == 0 - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - else: - k, v = None, None - if paged_kv_block_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - block_table = None - else: - ( - k_cache, - v_cache, - block_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - if alibi: - alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 - attn_bias = attn_bias_from_alibi_slopes( - alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad - ) - else: - alibi_slopes, attn_bias = None, None - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, - rotary_dim // 2, - device=device, - ) - * 2 - * math.pi - ) - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved - ) - else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, - ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved - ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = ( - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] - ).clone() - v_cache_ref = ( - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] - ).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out = flash_attn_with_kvcache( - q, - k_cache if paged_kv_block_size is None else k_cache_paged, - v_cache if paged_kv_block_size is None else v_cache_paged, - k, - v, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - block_table=block_table, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - alibi_slopes=alibi_slopes, - num_splits=num_splits, - ) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - attn_bias, - 0.0, - None, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - attn_bias, - 0.0, - None, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if paged_kv_block_size is None: - k_cache_select = ( - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] - ) - v_cache_select = ( - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] - ) - else: - k_cache_select = rearrange( - k_cache_paged[block_table.to(dtype=torch.long).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - v_cache_select = rearrange( - v_cache_paged[block_table.to(dtype=torch.long).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) - assert torch.equal(v_cache_select, v_cache_ref) - # mult = 3 if f16, bf16 need 4 - mult = 4 if not alibi else 5 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 +# @pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("num_splits", [1, 0]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("new_kv", [False, True]) +# @pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) +# @pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize( +# "seqlen_q,seqlen_k", +# [ +# (1, 128), +# (1, 339), +# (3, 1024), +# (64, 800), +# (64, 256), +# (3, 799), +# (64, 2048), +# (16, 20000), +# (1, 128 * 1024), +# (16, 128 * 1024), +# (128, 128), +# ], +# ) +# def test_flash_attn_kvcache( +# seqlen_q, +# seqlen_k, +# d, +# has_batch_idx, +# has_leftpad, +# paged_kv_block_size, +# rotary_fraction, +# rotary_interleaved, +# seqlen_new_eq_seqlen_q, +# causal, +# local, +# alibi, +# new_kv, +# mha_type, +# num_splits, +# dtype, +# ): +# if seqlen_q > seqlen_k and new_kv: +# pytest.skip() +# if not new_kv and rotary_fraction > 0.0: +# pytest.skip() +# if has_batch_idx and paged_kv_block_size is not None: +# pytest.skip() +# if has_leftpad and paged_kv_block_size is not None: +# pytest.skip() +# device = "cuda" +# # set seed +# torch.random.manual_seed(0) +# batch_size = 1 +# batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 +# nheads = 6 +# # rotary_dim must be a multiple of 16, and must be <= d +# rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 +# nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) +# assert nheads % nheads_k == 0 +# window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) +# q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) +# seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() +# if new_kv: +# k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# else: +# k, v = None, None +# if paged_kv_block_size is None: +# k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# block_table = None +# else: +# ( +# k_cache, +# v_cache, +# block_table, +# k_cache_paged, +# v_cache_paged, +# num_blocks, +# ) = _generate_block_kvcache( +# seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype +# ) +# cache_seqlens = torch.randint( +# 0 if new_kv else 1, +# # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough +# ( +# (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) +# if new_kv +# else (seqlen_k + 1) +# ), +# (batch_size,), +# dtype=torch.int32, +# device=device, +# ) +# if has_leftpad: +# cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) +# if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) +# for i in range(batch_size)]) +# else: +# cache_leftpad = None +# arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") +# cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") +# key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) +# if has_leftpad: +# key_padding_mask = torch.logical_and( +# key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) +# ) +# if has_batch_idx: +# cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ +# :batch_size +# ] +# else: +# cache_batch_idx = None +# if alibi: +# alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 +# attn_bias = attn_bias_from_alibi_slopes( +# alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad +# ) +# else: +# alibi_slopes, attn_bias = None, None +# # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) +# if rotary_dim > 0: +# angle = ( +# torch.rand( +# seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, +# rotary_dim // 2, +# device=device, +# ) +# * 2 +# * math.pi +# ) +# cos = torch.cos(angle).to(dtype=dtype) +# sin = torch.sin(angle).to(dtype=dtype) +# if causal or local: +# q_ro = apply_rotary_emb( +# q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# q_ro = rearrange( +# apply_rotary_emb( +# rearrange(q, "b s h d -> b 1 (s h) d"), +# cos, +# sin, +# seqlen_offsets=cache_seqlens, +# interleaved=rotary_interleaved, +# ), +# "b 1 (s h) d -> b s h d", +# s=seqlen_q, +# ) +# # q_ro = q +# k_ro = apply_rotary_emb( +# k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# cos, sin = None, None +# q_ro, k_ro = q, k +# # k_cache[:, 64:] = -1 +# k_cache_ref = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# v_cache_ref = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# if new_kv: +# update_mask = torch.logical_and( +# cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new +# ) +# k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") +# v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") +# k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# out = flash_attn_with_kvcache( +# q, +# k_cache if paged_kv_block_size is None else k_cache_paged, +# v_cache if paged_kv_block_size is None else v_cache_paged, +# k, +# v, +# rotary_cos=cos, +# rotary_sin=sin, +# cache_seqlens=cache_seqlens, +# cache_batch_idx=cache_batch_idx, +# cache_leftpad=cache_leftpad, +# block_table=block_table, +# causal=causal, +# window_size=window_size, +# rotary_interleaved=rotary_interleaved, +# alibi_slopes=alibi_slopes, +# num_splits=num_splits, +# ) +# # out = flash_attn_with_kvcache( +# # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size +# # ) +# # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) +# # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) +# # m = qk.amax(-1, keepdim=True) +# # s_tmp = torch.exp((qk - m) / math.sqrt(d)) +# # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) +# # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) +# # probs = torch.softmax(qk, dim=-1) +# out_ref, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# key_leftpad=cache_leftpad, +# ) +# out_pt, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# upcast=False, +# reorder_ops=True, +# key_leftpad=cache_leftpad, +# ) +# print(f"Output max diff: {(out - out_ref).abs().max().item()}") +# print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") +# print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") +# print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + +# # Check that FlashAttention's numerical error is at most twice the numerical error +# # of a Pytorch implementation. +# if new_kv: +# if paged_kv_block_size is None: +# k_cache_select = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# v_cache_select = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# else: +# k_cache_select = rearrange( +# k_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... -> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# v_cache_select = rearrange( +# v_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... -> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) +# assert torch.equal(v_cache_select, v_cache_ref) +# # mult = 3 if f16, bf16 need 4 +# mult = 4 if not alibi else 5 +# assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 From e5c5435a576c88f0146022d6c700881db38903ca Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 17 Dec 2024 19:12:49 +0000 Subject: [PATCH 08/17] Update num_splits heuristics --- csrc/flash_attn_ck/flash_common.cpp | 51 ++++++++++++++++++--- csrc/flash_attn_ck/flash_common.hpp | 63 ++++++++++++++------------ csrc/flash_attn_ck/mha_fwd_kvcache.cpp | 3 +- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 3 +- 4 files changed, 83 insertions(+), 37 deletions(-) diff --git a/csrc/flash_attn_ck/flash_common.cpp b/csrc/flash_attn_ck/flash_common.cpp index fb80d05cefe..492fd15c855 100644 --- a/csrc/flash_attn_ck/flash_common.cpp +++ b/csrc/flash_attn_ck/flash_common.cpp @@ -5,28 +5,65 @@ #include "flash_common.hpp" namespace flash { -int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +int override_num_splits_if_necessary(int batch, + int nhead, + int max_seqlen_q, + int hdim_q, + int hdim_v, + float p_drop, + bool is_prefill, + int num_splits) { int device; auto status = hipGetDevice(&device); if(status != hipSuccess) + { return num_splits; + } hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) + { return num_splits; + } - // TODO - tile size should match the TileFmhaShape, hardcode for now - const int kM0 = 128; - const int kN1 = hdim_v; + const int kM0 = [&] { + // get kM0 for prefill phase + if(is_prefill) + { + return 128; + } + + // get kM0 for decode phase + /// TODO: take dtype=fp8/bf8 into consideration + const std::map hdim_to_m0 = { + {32, 32}, + {64, 64}, + // {96, 64}, + {128, 64}, + {256, 64}, + }; + + for(auto [hdim, m0] : hdim_to_m0) + { + if(hdim_q <= hdim && hdim_v <= hdim) + { + return m0; + } + } + + return 64; // meet unsupported hdim_q/hdim_v + }(); + // const int kN1 = hdim_v; const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; + // const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; // always 1 if(num_splits < 1 && p_drop == 0.0f) - return num_splits_heuristic_ck( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + { + return num_splits_heuristic_ck(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8); + } return num_splits; } diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp index cc86546ea54..e8158cc0f39 100644 --- a/csrc/flash_attn_ck/flash_common.hpp +++ b/csrc/flash_attn_ck/flash_common.hpp @@ -35,42 +35,49 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r } } -inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { +inline int num_splits_heuristic_ck(int batch_nhead_mblocks, int num_SMs, int max_splits) +{ // If we have enough to almost fill the SMs, then just use 1 split - if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + + max_splits = std::min({max_splits, num_SMs}); + + constexpr std::array num_splits_array = {1, 2, 4, 8, 16}; + float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { - efficiency.push_back(0.f); - } else { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if (eff > max_efficiency) { max_efficiency = eff; } - efficiency.push_back(eff); + std::array efficiency; + + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs; + float eff = n_blocks / std::ceil(n_blocks); + + if(eff > max_efficiency) + { + max_efficiency = eff; } + efficiency[idx] = eff; } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { continue; } - if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + if(efficiency[idx] >= 0.85 * max_efficiency) + { + return num_splits_array[idx]; } } return 1; } -int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits); +int override_num_splits_if_necessary(int batch, + int nhead, + int max_seqlen_q, + int hdim_q, + int hdim_v, + float p_drop, + bool is_prefill, + int num_splits); } // namespace flash diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp index b769d626adf..3ff388acbe3 100644 --- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp +++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp @@ -473,7 +473,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); } - num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits); + num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, head_size_8x, + /*p_drop=*/0, /*is_prefill=*/false, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 4da8b63c33f..570cfd18fa7 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -446,7 +446,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si } int num_splits = 0; - num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits); + num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, head_size, + /*p_drop=*/0, /*is_prefill=*/true, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); From 89a144c4f1a53a73fe1aad5999bb41fe5b5a0761 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 17 Dec 2024 19:28:20 +0000 Subject: [PATCH 09/17] Update CK changes --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index ead9c3cbf38..a176a65d2f7 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit ead9c3cbf38ab8065c6024f8b8d637862571595f +Subproject commit a176a65d2f77fe8a173044d22c3ee0932614bbd8 From f8ffd2bff0942e050b943cacaf48ccdc52897f12 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Thu, 19 Dec 2024 08:31:57 +0000 Subject: [PATCH 10/17] Use experimental branch instead --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index a176a65d2f7..43596386192 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit a176a65d2f77fe8a173044d22c3ee0932614bbd8 +Subproject commit 435963861922f746cf1a7256b78c6b587070d15a From ddcc375dcb80b11d9acf19da162eb3b2a68d6923 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Dec 2024 09:31:51 -0600 Subject: [PATCH 11/17] Include splitkv optimizations --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 43596386192..c5083c0f1b2 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 435963861922f746cf1a7256b78c6b587070d15a +Subproject commit c5083c0f1b20902d3a2f3d2584e13cc064178e78 From 5804c42a514bf8c6c19f34526e407144de654a48 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Dec 2024 14:33:40 -0600 Subject: [PATCH 12/17] Fix wrong V layout used for splitkv kernel --- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 570cfd18fa7..d8218a98f29 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m head_size, dtype, true, // is_group_mode - false, // is_v_rowmajor + true, // is_v_rowmajor mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, From 469f7c8399387754931bce44a73becc35eaaf0b6 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Dec 2024 14:34:49 -0600 Subject: [PATCH 13/17] Update codegen logic --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index c5083c0f1b2..bb0934704e7 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit c5083c0f1b20902d3a2f3d2584e13cc064178e78 +Subproject commit bb0934704e70a4ea69d91f53b1650aaf5cac87e0 From 230c51301b90e94b17a5d06821cbaa9226227be1 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Dec 2024 21:12:10 -0600 Subject: [PATCH 14/17] Revert "Fix wrong V layout used for splitkv kernel" This reverts commit 5804c42a514bf8c6c19f34526e407144de654a48. --- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index d8218a98f29..570cfd18fa7 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m head_size, dtype, true, // is_group_mode - true, // is_v_rowmajor + false, // is_v_rowmajor mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, From 37dffb0175ba460af038e90f1a4c4a3e380bd17d Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 23 Dec 2024 21:47:51 -0600 Subject: [PATCH 15/17] Sync new fwd splitkv codegen logics --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index bb0934704e7..1fef9106526 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit bb0934704e70a4ea69d91f53b1650aaf5cac87e0 +Subproject commit 1fef9106526259bc30454c75250487f319ad4014 From ff0bcf4591ddf036b3b1952a3e67a0e329279967 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 24 Dec 2024 12:17:11 -0600 Subject: [PATCH 16/17] Use vector load if paged-vcache is in column major --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 1fef9106526..65bbe6ea5c5 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 1fef9106526259bc30454c75250487f319ad4014 +Subproject commit 65bbe6ea5c527b8f557789ec8392a8fdf9a0a26c From 766f2f6df50c4c34be677b96c32a4ddba5151265 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sun, 29 Dec 2024 11:57:11 -0600 Subject: [PATCH 17/17] Sync block-mapping for fmha splitkv kernel --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 65bbe6ea5c5..ca1a816d6fe 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 65bbe6ea5c527b8f557789ec8392a8fdf9a0a26c +Subproject commit ca1a816d6fe3d796ac335577800d0949bbd7d8ed