diff --git a/csrc/composable_kernel b/csrc/composable_kernel index e7b6286441a..ca1a816d6fe 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit e7b6286441aae59d3a87db67f42369d3cc2636a4 +Subproject commit ca1a816d6fe3d796ac335577800d0949bbd7d8ed diff --git a/csrc/flash_attn_ck/flash_common.cpp b/csrc/flash_attn_ck/flash_common.cpp index fb80d05cefe..492fd15c855 100644 --- a/csrc/flash_attn_ck/flash_common.cpp +++ b/csrc/flash_attn_ck/flash_common.cpp @@ -5,28 +5,65 @@ #include "flash_common.hpp" namespace flash { -int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +int override_num_splits_if_necessary(int batch, + int nhead, + int max_seqlen_q, + int hdim_q, + int hdim_v, + float p_drop, + bool is_prefill, + int num_splits) { int device; auto status = hipGetDevice(&device); if(status != hipSuccess) + { return num_splits; + } hipDeviceProp_t props{}; status = hipGetDeviceProperties(&props, device); if(status != hipSuccess) + { return num_splits; + } - // TODO - tile size should match the TileFmhaShape, hardcode for now - const int kM0 = 128; - const int kN1 = hdim_v; + const int kM0 = [&] { + // get kM0 for prefill phase + if(is_prefill) + { + return 128; + } + + // get kM0 for decode phase + /// TODO: take dtype=fp8/bf8 into consideration + const std::map hdim_to_m0 = { + {32, 32}, + {64, 64}, + // {96, 64}, + {128, 64}, + {256, 64}, + }; + + for(auto [hdim, m0] : hdim_to_m0) + { + if(hdim_q <= hdim && hdim_v <= hdim) + { + return m0; + } + } + + return 64; // meet unsupported hdim_q/hdim_v + }(); + // const int kN1 = hdim_v; const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; + // const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; // always 1 if(num_splits < 1 && p_drop == 0.0f) - return num_splits_heuristic_ck( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + { + return num_splits_heuristic_ck(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8); + } return num_splits; } diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp index cc86546ea54..e8158cc0f39 100644 --- a/csrc/flash_attn_ck/flash_common.hpp +++ b/csrc/flash_attn_ck/flash_common.hpp @@ -35,42 +35,49 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r } } -inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { +inline int num_splits_heuristic_ck(int batch_nhead_mblocks, int num_SMs, int max_splits) +{ // If we have enough to almost fill the SMs, then just use 1 split - if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + + max_splits = std::min({max_splits, num_SMs}); + + constexpr std::array num_splits_array = {1, 2, 4, 8, 16}; + float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { - efficiency.push_back(0.f); - } else { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if (eff > max_efficiency) { max_efficiency = eff; } - efficiency.push_back(eff); + std::array efficiency; + + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs; + float eff = n_blocks / std::ceil(n_blocks); + + if(eff > max_efficiency) + { + max_efficiency = eff; } + efficiency[idx] = eff; } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { continue; } - if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + if(efficiency[idx] >= 0.85 * max_efficiency) + { + return num_splits_array[idx]; } } return 1; } -int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits); +int override_num_splits_if_necessary(int batch, + int nhead, + int max_seqlen_q, + int hdim_q, + int hdim_v, + float p_drop, + bool is_prefill, + int num_splits); } // namespace flash diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp index 2f8b6436307..3ff388acbe3 100644 --- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp +++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp @@ -287,6 +287,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits) { + TORCH_CHECK(false, "vllm layout does not support mha_fwd_kvcache for now"); + auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); @@ -471,7 +473,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); } - num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits); + num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, head_size_8x, + /*p_drop=*/0, /*is_prefill=*/false, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index b6a274f4fa5..570cfd18fa7 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m head_size, dtype, true, // is_group_mode - true, // is_v_rowmajor + false, // is_v_rowmajor mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -183,8 +183,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, at::Tensor out_acc) { // q: (total_q, nheads, d) - // k: (num_blocks, page_block_size, num_heads_k, d) - // v: (num_blocks, page_block_size, num_heads_k, d) + // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8) + // v: (num_blocks, num_heads_k, d, page_block_size) // o: (total_q, nheads, d) // alibi_slopes:(batch_size, nheads) or (nhead) @@ -241,12 +241,12 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, args.nhead_stride_q = q.stride(1); args.batch_stride_k = k.stride(0); - args.stride_k = k.stride(1); - args.nhead_stride_k = k.stride(2); + args.nhead_stride_k = k.stride(1); + args.stride_k = k.stride(2); args.batch_stride_v = v.stride(0); - args.stride_v = v.stride(1); - args.nhead_stride_v = v.stride(2); + args.nhead_stride_v = v.stride(1); + args.stride_v = v.stride(2); args.batch_stride_o = 0; args.stride_o = out.stride(0); @@ -292,8 +292,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size / 8 x page_block_size x 8 if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x page_block_size x head_size if there's a block_table. c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 @@ -335,6 +335,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(k); + CHECK_CONTIGUOUS(v); } TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -348,11 +350,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size = sizes[2]; - const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + const int num_heads_k = k.size(1); const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); - const int page_block_size = !paged_KV ? 1 : k.size(1); + const int page_block_size = !paged_KV ? 1 : k.size(3); TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case @@ -394,8 +396,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); } else { - CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8); + CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } @@ -444,7 +446,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si } int num_splits = 0; - num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits); + num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, head_size, + /*p_drop=*/0, /*is_prefill=*/true, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c3..e2bb1187cf5 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -922,6 +922,21 @@ def test_flash_attn_varlen_causal( dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + if paged_kv_block_size is not None: + # vllm layout + if d % 8 != 0: + pytest.skip() + nblock = k_cache_paged.shape[0] + + k_cache_paged = rearrange(k_cache_paged, + 'nblock block_size nheads (d1 d2) -> nblock nheads d1 block_size d2', + block_size=paged_kv_block_size, d1=d // 8, d2=8).contiguous() + + v_cache_paged = rearrange(v_cache_paged, + 'nblock block_size nheads d -> nblock nheads d block_size', + block_size=paged_kv_block_size).contiguous() + out_unpad = flash_attn_varlen_func( q_unpad, k_unpad if paged_kv_block_size is None else k_cache_paged, @@ -1020,271 +1035,271 @@ def test_flash_attn_varlen_causal( # TODO - Support has_leftpad -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("num_splits", [1, 0]) -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("new_kv", [False, True]) -@pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("local", [False, True]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) -@pytest.mark.parametrize("rotary_interleaved", [False, True]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) -@pytest.mark.parametrize("has_leftpad", [False]) -@pytest.mark.parametrize("has_batch_idx", [False, True]) -@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) -@pytest.mark.parametrize( - "seqlen_q,seqlen_k", - [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), - (128, 128), - ], -) -def test_flash_attn_kvcache( - seqlen_q, - seqlen_k, - d, - has_batch_idx, - has_leftpad, - paged_kv_block_size, - rotary_fraction, - rotary_interleaved, - seqlen_new_eq_seqlen_q, - causal, - local, - alibi, - new_kv, - mha_type, - num_splits, - dtype, -): - if seqlen_q > seqlen_k and new_kv: - pytest.skip() - if not new_kv and rotary_fraction > 0.0: - pytest.skip() - if has_batch_idx and paged_kv_block_size is not None: - pytest.skip() - if has_leftpad and paged_kv_block_size is not None: - pytest.skip() - device = "cuda" - # set seed - torch.random.manual_seed(0) - batch_size = 1 - batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 - nheads = 6 - # rotary_dim must be a multiple of 16, and must be <= d - rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) - assert nheads % nheads_k == 0 - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - else: - k, v = None, None - if paged_kv_block_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - block_table = None - else: - ( - k_cache, - v_cache, - block_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - if alibi: - alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 - attn_bias = attn_bias_from_alibi_slopes( - alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad - ) - else: - alibi_slopes, attn_bias = None, None - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, - rotary_dim // 2, - device=device, - ) - * 2 - * math.pi - ) - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved - ) - else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, - ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved - ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = ( - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] - ).clone() - v_cache_ref = ( - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] - ).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out = flash_attn_with_kvcache( - q, - k_cache if paged_kv_block_size is None else k_cache_paged, - v_cache if paged_kv_block_size is None else v_cache_paged, - k, - v, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - block_table=block_table, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - alibi_slopes=alibi_slopes, - num_splits=num_splits, - ) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - attn_bias, - 0.0, - None, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - attn_bias, - 0.0, - None, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if paged_kv_block_size is None: - k_cache_select = ( - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] - ) - v_cache_select = ( - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] - ) - else: - k_cache_select = rearrange( - k_cache_paged[block_table.to(dtype=torch.long).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - v_cache_select = rearrange( - v_cache_paged[block_table.to(dtype=torch.long).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k] - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) - assert torch.equal(v_cache_select, v_cache_ref) - # mult = 3 if f16, bf16 need 4 - mult = 4 if not alibi else 5 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 +# @pytest.mark.parametrize("dtype", [torch.float16]) +# @pytest.mark.parametrize("num_splits", [1, 0]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("new_kv", [False, True]) +# @pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) +# @pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize( +# "seqlen_q,seqlen_k", +# [ +# (1, 128), +# (1, 339), +# (3, 1024), +# (64, 800), +# (64, 256), +# (3, 799), +# (64, 2048), +# (16, 20000), +# (1, 128 * 1024), +# (16, 128 * 1024), +# (128, 128), +# ], +# ) +# def test_flash_attn_kvcache( +# seqlen_q, +# seqlen_k, +# d, +# has_batch_idx, +# has_leftpad, +# paged_kv_block_size, +# rotary_fraction, +# rotary_interleaved, +# seqlen_new_eq_seqlen_q, +# causal, +# local, +# alibi, +# new_kv, +# mha_type, +# num_splits, +# dtype, +# ): +# if seqlen_q > seqlen_k and new_kv: +# pytest.skip() +# if not new_kv and rotary_fraction > 0.0: +# pytest.skip() +# if has_batch_idx and paged_kv_block_size is not None: +# pytest.skip() +# if has_leftpad and paged_kv_block_size is not None: +# pytest.skip() +# device = "cuda" +# # set seed +# torch.random.manual_seed(0) +# batch_size = 1 +# batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 +# nheads = 6 +# # rotary_dim must be a multiple of 16, and must be <= d +# rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 +# nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) +# assert nheads % nheads_k == 0 +# window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) +# q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) +# seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() +# if new_kv: +# k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) +# else: +# k, v = None, None +# if paged_kv_block_size is None: +# k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) +# block_table = None +# else: +# ( +# k_cache, +# v_cache, +# block_table, +# k_cache_paged, +# v_cache_paged, +# num_blocks, +# ) = _generate_block_kvcache( +# seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype +# ) +# cache_seqlens = torch.randint( +# 0 if new_kv else 1, +# # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough +# ( +# (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) +# if new_kv +# else (seqlen_k + 1) +# ), +# (batch_size,), +# dtype=torch.int32, +# device=device, +# ) +# if has_leftpad: +# cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) +# if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) +# for i in range(batch_size)]) +# else: +# cache_leftpad = None +# arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") +# cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") +# key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) +# if has_leftpad: +# key_padding_mask = torch.logical_and( +# key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) +# ) +# if has_batch_idx: +# cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ +# :batch_size +# ] +# else: +# cache_batch_idx = None +# if alibi: +# alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 +# attn_bias = attn_bias_from_alibi_slopes( +# alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad +# ) +# else: +# alibi_slopes, attn_bias = None, None +# # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) +# if rotary_dim > 0: +# angle = ( +# torch.rand( +# seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, +# rotary_dim // 2, +# device=device, +# ) +# * 2 +# * math.pi +# ) +# cos = torch.cos(angle).to(dtype=dtype) +# sin = torch.sin(angle).to(dtype=dtype) +# if causal or local: +# q_ro = apply_rotary_emb( +# q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# q_ro = rearrange( +# apply_rotary_emb( +# rearrange(q, "b s h d -> b 1 (s h) d"), +# cos, +# sin, +# seqlen_offsets=cache_seqlens, +# interleaved=rotary_interleaved, +# ), +# "b 1 (s h) d -> b s h d", +# s=seqlen_q, +# ) +# # q_ro = q +# k_ro = apply_rotary_emb( +# k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved +# ) +# else: +# cos, sin = None, None +# q_ro, k_ro = q, k +# # k_cache[:, 64:] = -1 +# k_cache_ref = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# v_cache_ref = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ).clone() +# if new_kv: +# update_mask = torch.logical_and( +# cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new +# ) +# k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") +# v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") +# k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) +# out = flash_attn_with_kvcache( +# q, +# k_cache if paged_kv_block_size is None else k_cache_paged, +# v_cache if paged_kv_block_size is None else v_cache_paged, +# k, +# v, +# rotary_cos=cos, +# rotary_sin=sin, +# cache_seqlens=cache_seqlens, +# cache_batch_idx=cache_batch_idx, +# cache_leftpad=cache_leftpad, +# block_table=block_table, +# causal=causal, +# window_size=window_size, +# rotary_interleaved=rotary_interleaved, +# alibi_slopes=alibi_slopes, +# num_splits=num_splits, +# ) +# # out = flash_attn_with_kvcache( +# # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size +# # ) +# # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) +# # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) +# # m = qk.amax(-1, keepdim=True) +# # s_tmp = torch.exp((qk - m) / math.sqrt(d)) +# # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) +# # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) +# # probs = torch.softmax(qk, dim=-1) +# out_ref, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# key_leftpad=cache_leftpad, +# ) +# out_pt, _ = attention_ref( +# q_ro, +# k_cache_rep, +# v_cache_rep, +# None, +# key_padding_mask, +# attn_bias, +# 0.0, +# None, +# causal=causal, +# window_size=window_size, +# upcast=False, +# reorder_ops=True, +# key_leftpad=cache_leftpad, +# ) +# print(f"Output max diff: {(out - out_ref).abs().max().item()}") +# print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") +# print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") +# print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + +# # Check that FlashAttention's numerical error is at most twice the numerical error +# # of a Pytorch implementation. +# if new_kv: +# if paged_kv_block_size is None: +# k_cache_select = ( +# k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# v_cache_select = ( +# v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] +# ) +# else: +# k_cache_select = rearrange( +# k_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... -> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# v_cache_select = rearrange( +# v_cache_paged[block_table.to(dtype=torch.long).flatten()], +# "(b nblocks) block_size ... -> b (nblocks block_size) ...", +# b=batch_size, +# )[:, :seqlen_k] +# assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) +# assert torch.equal(v_cache_select, v_cache_ref) +# # mult = 3 if f16, bf16 need 4 +# mult = 4 if not alibi else 5 +# assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5