[NPUW] Add block-based KV cache support for HFA and Pyramid attention#35014
Open
intelgaoxiong wants to merge 4 commits into
Open
[NPUW] Add block-based KV cache support for HFA and Pyramid attention#35014intelgaoxiong wants to merge 4 commits into
intelgaoxiong wants to merge 4 commits into
Conversation
bcef02a to
88c1504
Compare
88c1504 to
f1f9080
Compare
dylanneve1
added a commit
to dylanneve1/openvino
that referenced
this pull request
May 27, 2026
…ock-aware HFA + Pyramid
86b3295 to
2a50694
Compare
Refactor the SDPA index structures and attention metadata to accommodate block-based KV cache layouts where past_key/value are split into N fixed- size block tensors instead of a single contiguous buffer. Key changes: - attention.hpp: rename SDPAIndices past_key/value -> past_key_blocks/ past_value_blocks (vector<size_t>); extend PyramidAttentionInfo with per-variant block port sets and global param index lists - sdpa_utils.cpp/hpp: new shared helpers extracted from pyramid_attention and host_flash_attention (build_sdpa_param_mapping, etc.) - host_flash_attention: use block index loop in build_sdpa_param_mapping - pyramid_attention: add is_block_split path; shrink_concat_inputs / collect_concat_block_indices helpers; populate block port metadata - attn/attn_subgraph.cpp: alias block slots on main request; name-based input sharing for pyramid variants in block mode - partitioning/patterns/sdpa.cpp: relax Concat input-count check to allow variable number of block inputs - serialization.cpp + test: update field names to match SDPAIndices rename Related-to: EISW-206740 Fixed clang-format. Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
Signed-off-by: intelgaoxiong <xiong.gao@intel.com>
2a50694 to
3eaa3f4
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Details:
What this PR does
Extends Host Flash Attention (HFA) and Pyramid Attention to operate with the block-split KV cache produced by
SplitKVCacheIntoBlocks. After that transformation a singlepast_key/past_valueparameter is replaced by N independent block parameters each feeding one input of a multi-inputConcatnode. Both attention backends mustdetect this new layout and adapt their compilation-time model shaping and inference-time tensor dispatch accordingly.
This is Part 3/4 of the block-based KV cache feature series:
SplitKVCacheIntoBlocksgraph transformationKVCacheBlockManager— block allocation and lifecycleChanges by section
Section 1 — Shared infrastructure
util.hpp/cpp— renameisPastKeyValues{Key,Value}→isPastKeyParam/isPastValueParam; add…Contiguousvariants that match only single-parameter (non-block) names; used by both HFA and Pyramid to distinguish contiguous from block-split layouts.sdpa_utils.hpp/cpp(new file) — extractsSDPAPatternNodes(holdingvector<…> past_key_param_nodes/past_value_param_nodesfor 1 or N elements),find_sdpa_pattern_nodes(), andfind_mask_parameter(). Previously duplicated between HFA and Pyramid; now shared.attention.hpp— extendfunction::Attentionwithpast_key/value_block_variant_param_indices(ordered by Concat input); extendcompiled::SDPAIndiceswithpast_key_blocks/past_value_blocksvectors.Section 2 — Host Flash Attention
Compilation (
host_flash_attention.cpp):build_sdpa_param_mapping()now iterates all Concat inputs (not only the first) to discover and record every block-parameter index into_past_key_block_indices/_past_value_block_indices.compiled::SDPAIndices.past_key_blocks/past_value_blocksfor use at inference time.HFA_Tile,HFA_Final_Tile) are generated once and remain layout-agnostic; only the tensor sourcing strategy differs at inference time.Inference (
attn_subgraph.cpp):context_size / tile_sizeoffsets, slicing or zero-copy-viewing into the tensor at eachkv_offset.past_key_blocks(one entry per block tensor); the final tile usespresent_key_tensor. All tiles are dispatched through the sameprocess_tilelambda regardless of source.Section 3 — Pyramid Attention
Compilation (
pyramid_attention.cpp):is_block_split = (past_key_param_nodes.size() > 1)or name does not matchisPastKeyParamContiguous().current_past_length→reshape()→validate_nodes_and_infer_types().shrink_concat_inputs()to keep exactlymodel_idxpast block inputs in the Concat (model[0] → 0 blocks, model[1] → 1 block, …, model[k] → k blocks) →patch_broadcast_constants+patch_reshape_constants→validate_nodes_and_infer_types→collect_concat_block_indicesto populatepast_key/value_block_variant_param_indices. Precomputepast_key_block_port_map(global index → variant port) andpast_key_block_port_setfor O(1) lookup at inference time.Inference (
attn_subgraph.cpp,just_sync_infer_request.cpp):bind_function_input()callsutil::view(tensor, param.dim, 0, past_len)to present each pyramidvariant with a correctly sized KV slice; mask is rebuilt per-variant in
prologue().alias_block_slots()pre-wires all global block-slot ports on the main request toblock_0's buffer as a placeholder, so earlier generic binding code never touches unallocated slots.bind_function_input()→try_bind_block()looks uppast_key_block_port_map[global_idx]to dispatch each incoming block tensor to the correct variant-local port. Variants that expose no port for a given block index (e.g. model[0] with 0 past blocks) consume the call silently withoutset_tensor.just_sync_infer_request.cpp:share_kv_block_buffers()shares block KV buffers across pyramid variant sub-requests to avoid redundant allocations.Other:
base_sync_infer_request.cpp— replace scalarpast_key/past_value checks with anis_past_kv()lambda; addblock_mode+bind_block_ports()lambda inbind_pyramid_attention_inputs()`.partitioning/patterns/sdpa.cpp— relax Concat input-count guard to allow multi-block inputs.Tickets:
AI Assistance: