disable kv cache broadcast for better performance#36118
Open
nazanin-beheshti wants to merge 13 commits into
Open
disable kv cache broadcast for better performance#36118nazanin-beheshti wants to merge 13 commits into
nazanin-beheshti wants to merge 13 commits into
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR aims to improve performance for models using Grouped-Query Attention (e.g., MSFT Orca) by avoiding KV-cache/head broadcast behavior in the generated OpenVINO graph, and adjusting ScaledDotProductAttention (SDPA) shape inference/tests accordingly.
Changes:
- Removes KV head “broadcast via concat/reshape” logic from
GroupQueryAttentionDecompositionto avoid expanding KV tensors to Query head count. - Updates SDPA shape inference/tests to stop propagating Key/Value batch dims into the output shape (reducing broadcast-driven shape effects).
- Updates SDPA type-prop expectations/error-message assertions to match the new shape inference behavior.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
src/core/tests/type_prop/scaled_dot_product_attention.cpp |
Updates SDPA type-prop expected output shapes and thrown-message substrings to match new inference behavior. |
src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp |
Modifies SDPA shape validation/inference to avoid broadcast-merging Key/Value leading dims into output; currently contains a Value-shape validation bug. |
src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp |
Removes explicit KV broadcasting in GQA decomposition to avoid performance overhead from KV expansion. |
Comments suppressed due to low confidence (1)
src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp:50
- [HIGH] After removing
broadcast_merge_into(...)from Key/Value validation, the shape inference no longer checks that Key/Value leading dimensions are compatible with Query at all (e.g., mismatched batch dimension can now silently pass). If the intent is to support GQA without explicit broadcast nodes, consider reintroducing non-broadcasting compatibility checks for the true batch dims and a GQA head-dimension rule (Q heads divisible by KV heads), similar toFlashAttentionTileshape inference (src/plugins/intel_npu/src/ops/src/intel_npu/ops/flash_attention_tile.cpp:163-215). This would keep type inference rejecting genuinely incompatible shapes while still avoiding KV broadcast.
const auto& key = input_shapes[1];
const auto& key_rank = key.rank();
if (key_rank.is_static()) {
const bool& key_input_correctness =
key_rank.get_length() >= 3 && DimType::merge(e_dim, e_dim, *(key.end() - 1));
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
key_input_correctness,
"Key input shape not compatible with other inputs.");
s_dim = *(key.end() - 2);
Comment on lines
+56
to
+57
| const bool& value_input_correctness = | ||
| value_rank.get_length() >= 3 && | ||
| TRShape::broadcast_merge_into(n_dims, | ||
| TRShape(std::vector<DimType>(value.begin(), value.end() - 2)), | ||
| AutoBroadcastType::NUMPY) && | ||
| DimType::merge(s_dim, s_dim, *(value.end() - 2)); | ||
| key_rank.get_length() >= 3 && DimType::merge(e_dim, e_dim, *(key.end() - 1)); |
Kotomi-Du
reviewed
Jun 2, 2026
| std::make_shared<op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal); | ||
| EXPECT_EQ(op->get_output_element_type(0), element::f64); | ||
| EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {4, 5}, {3, 7}})); | ||
| EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{2, 4}, 3, {4, 5}, {3, 7}})); |
Contributor
There was a problem hiding this comment.
How broadcast behavior impacts the shape which has range?
Contributor
|
build_jenkins |
Kotomi-Du
reviewed
Jun 2, 2026
Kotomi-Du
reviewed
Jun 3, 2026
Kotomi-Du
reviewed
Jun 3, 2026
Kotomi-Du
reviewed
Jun 3, 2026
| auto op = std::make_shared<op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, causal), | ||
| AssertFailure, | ||
| testing::HasSubstr("Value input shape not compatible with other inputs.")); | ||
| testing::HasSubstr("Attention mask input shape not compatible with other inputs.")); |
Contributor
There was a problem hiding this comment.
it seems you add some unintentional change which is incorrect.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Details:
For MSFT Orca model, it has different num_head for query and key/value. It leads to a broadcast behavior in the graph which introduce a big overhead for the performance. This ticket is created to eliminate the broadcast and get good conformance for Orca pipeline.
Tickets:
AI Assistance: