[NPUW] Enable mask skipping for fused flash attention #36077
Pull request overview
Enables skipping attention-mask processing for fused Host Flash Attention (HFA) regular tiles by introducing an alternate “no mask input” tiled subgraph and wiring runtime selection between the masked vs no-mask variants.
Changes:
- Add generation/compilation plumbing for an additional regular-tile HFA model variant without the mask input.
- Extend HFA runtime selector interface with current_length() to support mask-skipping decisions.
- Update HFA runtime request setup/execution to optionally use the no-mask regular-tile infer request/model.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.hpp | Adds storage for a no-mask tile model/compiled model and extends selector API with current_length(). |
| src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp | Builds an optional no-mask tile model (fused path) and implements PositionIDs::current_length(). |
| src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp | Compiles/dumps the additional no-mask tile model when present. |
| src/plugins/intel_npu/src/plugin/npuw/attn/attn_subgraph.cpp | Creates/shares infer requests for the no-mask model and selects masked vs no-mask execution at runtime. |
```diff
  kv_tile_offset,
  mask_tile_offset,
- tile_size);
+ tile_size,
+ false,
+ use_mask);
```
```cpp
// If the regular tile is not fully filled, need to use the mask
const bool use_mask = (actual_kv_length + 1) % tile_size != 0;
const bool use_no_mask_model = !use_mask && hfa_desc->_compiled_tile_no_mask_model;
auto& regular_tile_request =
use_no_mask_model ? state.hfa_requests.infer_requests[HFARequestSet::REGULAR_TILE_NO_MASK]
```
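The selection rule in this hunk can be isolated as a pure function, which makes it easier to reason about and to unit-test. The sketch below is only an illustration: `TileVariant`, `needs_mask`, and `select_variant` are hypothetical names that do not appear in the PR, and the sketch assumes `tile_size > 0`.

```cpp
#include <cstddef>

// Hypothetical names for illustration; they are not part of the PR.
enum class TileVariant { Masked, NoMask };

// Mirrors the check from the hunk: the mask is needed whenever
// (actual_kv_length + 1) is not a multiple of tile_size.
// Assumes tile_size > 0.
constexpr bool needs_mask(std::size_t actual_kv_length, std::size_t tile_size) {
    return (actual_kv_length + 1) % tile_size != 0;
}

// Falls back to the masked variant when no no-mask model was compiled,
// matching the `!use_mask && _compiled_tile_no_mask_model` condition.
constexpr TileVariant select_variant(std::size_t actual_kv_length,
                                     std::size_t tile_size,
                                     bool has_no_mask_model) {
    const bool use_mask = needs_mask(actual_kv_length, tile_size);
    return (!use_mask && has_no_mask_model) ? TileVariant::NoMask : TileVariant::Masked;
}
```

For example, with `tile_size = 1024`, a value of `actual_kv_length = 1023` selects the no-mask variant only when that model exists, while `actual_kv_length = 1000` always selects the masked variant.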
esmirno left a comment
Overall LGTM, but it is better to always add tests; otherwise it is not clear what this change is fixing.
```diff
- REGULAR_TILE = 0,
- FINAL_TILE = 1,
- COUNT = 2,
+ REGULAR_TILE_MASK = 0,
```
If this is the regular tile with a mask, it may be better to name the entries TILE_WITH_MASK and REGULAR_TILE.
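One possible reading of this naming suggestion, sketched as a standalone enum. The struct name `HFARequestSetSketch`, the numeric values, and the ordering are assumptions made for illustration; the PR itself uses REGULAR_TILE_MASK and REGULAR_TILE_NO_MASK.

```cpp
#include <array>
#include <cstddef>

// Hypothetical layout following the reviewer's naming; not the merged code.
struct HFARequestSetSketch {
    enum Index : std::size_t {
        REGULAR_TILE = 0,    // plain regular tile (no mask input)
        TILE_WITH_MASK = 1,  // regular tile that still consumes the mask
        FINAL_TILE = 2,
        COUNT = 3,           // number of request slots
    };

    // Stand-in for the per-variant infer requests; int is a placeholder type.
    std::array<int, COUNT> infer_requests{};
};
```

With this naming the default entry is the plain tile and the mask-consuming variant is the explicitly qualified one.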
```diff
- state.hfa_requests.pipeline_requests[HFARequestSet::REGULAR_TILE] =
+ state.hfa_requests.pipeline_requests[HFARequestSet::REGULAR_TILE_MASK] =
  hfa->_compiled_tile_model->create_infer_request();
+ if (hfa->_compiled_tile_no_mask_model) {
```
This is a clear place to spot the problem: the name no_mask_model might be read as referring to a missing model, so it is better to have compiled_tile and compiled_tile_with_mask.
```diff
  // ========================================================================
  HostFlashAttention hfa;
  hfa._tile_model = tile_model;
+ if (fused_flash_attention) {
```
Do we have any tests for HFA? I think @intelgaoxiong introduced one. It is interesting that none of them fail with such a change. Could you please add some tests that show the mask behavior?
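A minimal sketch of the property such a test could check, written against a scalar reference attention rather than the NPUW plugin: for a fully filled tile, an all-zero additive mask must give the same result as no mask, while a partially filled tile needs the mask to exclude padding. All names here (`attend`, `full_tile_equivalence_holds`, `partial_tile_needs_mask`) are hypothetical, and the sketch assumes at least one position stays unmasked.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference single-query attention over one KV tile, with scalar values
// for brevity. `mask` is additive (0 keeps a position, -inf drops it);
// pass nullptr to model the no-mask variant.
inline float attend(const std::vector<float>& scores,
                    const std::vector<float>& values,
                    const std::vector<float>* mask) {
    float max_s = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const float s = scores[i] + (mask ? (*mask)[i] : 0.0f);
        if (s > max_s) max_s = s;
    }
    float denom = 0.0f;
    float acc = 0.0f;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const float s = scores[i] + (mask ? (*mask)[i] : 0.0f);
        const float w = std::exp(s - max_s);  // exp(-inf) == 0 for masked slots
        denom += w;
        acc += w * values[i];
    }
    return acc / denom;
}

// Fully filled tile: a zero mask and no mask must agree.
inline bool full_tile_equivalence_holds() {
    const std::vector<float> scores{0.1f, 0.7f, -0.3f, 1.2f};
    const std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
    const std::vector<float> zero_mask(scores.size(), 0.0f);
    return std::fabs(attend(scores, values, &zero_mask) -
                     attend(scores, values, nullptr)) < 1e-6f;
}

// Partially filled tile: dropping the mask changes the result, and the
// masked result matches attention over the valid prefix only.
inline bool partial_tile_needs_mask() {
    const float ninf = -std::numeric_limits<float>::infinity();
    const std::vector<float> scores{0.1f, 0.7f, -0.3f, 1.2f};
    const std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
    const std::vector<float> mask{0.0f, 0.0f, ninf, ninf};
    const float masked = attend(scores, values, &mask);
    const float prefix = attend({0.1f, 0.7f}, {1.0f, 2.0f}, nullptr);
    const float unmasked = attend(scores, values, nullptr);
    return std::fabs(masked - prefix) < 1e-6f &&
           std::fabs(masked - unmasked) > 1e-3f;
}
```

A plugin-level test would compare the outputs of the masked and no-mask tile models in the same way, once at a tile boundary and once inside a tile.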
Details:
The PR adds logic for working with an attention mask:
A new subgraph without the mask input is added.
Tickets:
AI Assistance: