Skip to content

disable kv cache broadcast for better performance#36118

Open
nazanin-beheshti wants to merge 13 commits into
openvinotoolkit:masterfrom
nazanin-beheshti:naz/disable-kv-cache-broadcast
Open

disable kv cache broadcast for better performance#36118
nazanin-beheshti wants to merge 13 commits into
openvinotoolkit:masterfrom
nazanin-beheshti:naz/disable-kv-cache-broadcast

Conversation

@nazanin-beheshti
Copy link
Copy Markdown
Contributor

Details:

For MSFT Orca model, it has different num_head for query and key/value. It leads to a broadcast behavior in the graph which introduce a big overhead for the performance. This ticket is created to eliminate the broadcast and get good conformance for Orca pipeline.

stateless-before-after

Tickets:

AI Assistance:

  • AI assistance used: no

@nazanin-beheshti nazanin-beheshti requested review from a team as code owners May 28, 2026 18:00
@github-actions github-actions Bot added category: Core OpenVINO Core (aka ngraph) category: transformations OpenVINO Runtime library - Transformations labels May 28, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to improve performance for models using Grouped-Query Attention (e.g., MSFT Orca) by avoiding KV-cache/head broadcast behavior in the generated OpenVINO graph, and adjusting ScaledDotProductAttention (SDPA) shape inference/tests accordingly.

Changes:

  • Removes KV head “broadcast via concat/reshape” logic from GroupQueryAttentionDecomposition to avoid expanding KV tensors to Query head count.
  • Updates SDPA shape inference/tests to stop propagating Key/Value batch dims into the output shape (reducing broadcast-driven shape effects).
  • Updates SDPA type-prop expectations/error-message assertions to match the new shape inference behavior.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
src/core/tests/type_prop/scaled_dot_product_attention.cpp Updates SDPA type-prop expected output shapes and thrown-message substrings to match new inference behavior.
src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp Modifies SDPA shape validation/inference to avoid broadcast-merging Key/Value leading dims into output; currently contains a Value-shape validation bug.
src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp Removes explicit KV broadcasting in GQA decomposition to avoid performance overhead from KV expansion.
Comments suppressed due to low confidence (1)

src/core/shape_inference/include/scaled_dot_product_attention_shape_inference.hpp:50

  • [HIGH] After removing broadcast_merge_into(...) from Key/Value validation, the shape inference no longer checks that Key/Value leading dimensions are compatible with Query at all (e.g., mismatched batch dimension can now silently pass). If the intent is to support GQA without explicit broadcast nodes, consider reintroducing non-broadcasting compatibility checks for the true batch dims and a GQA head-dimension rule (Q heads divisible by KV heads), similar to FlashAttentionTile shape inference (src/plugins/intel_npu/src/ops/src/intel_npu/ops/flash_attention_tile.cpp:163-215). This would keep type inference rejecting genuinely incompatible shapes while still avoiding KV broadcast.
    const auto& key = input_shapes[1];
    const auto& key_rank = key.rank();
    if (key_rank.is_static()) {
        const bool& key_input_correctness =
            key_rank.get_length() >= 3 && DimType::merge(e_dim, e_dim, *(key.end() - 1));
        NODE_SHAPE_INFER_CHECK(op,
                               input_shapes,
                               key_input_correctness,
                               "Key input shape not compatible with other inputs.");
        s_dim = *(key.end() - 2);

Comment on lines +56 to +57
const bool& value_input_correctness =
value_rank.get_length() >= 3 &&
TRShape::broadcast_merge_into(n_dims,
TRShape(std::vector<DimType>(value.begin(), value.end() - 2)),
AutoBroadcastType::NUMPY) &&
DimType::merge(s_dim, s_dim, *(value.end() - 2));
key_rank.get_length() >= 3 && DimType::merge(e_dim, e_dim, *(key.end() - 1));
std::make_shared<op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, {4, 5}, {3, 7}}));
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{2, 4}, 3, {4, 5}, {3, 7}}));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How broadcast behavior impacts the shape which has range?

@Kotomi-Du
Copy link
Copy Markdown
Contributor

build_jenkins

Comment thread src/core/tests/type_prop/scaled_dot_product_attention.cpp Outdated
auto op = std::make_shared<op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Value input shape not compatible with other inputs."));
testing::HasSubstr("Attention mask input shape not compatible with other inputs."));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems you add some unintentional change which is incorrect.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

category: Core OpenVINO Core (aka ngraph) category: transformations OpenVINO Runtime library - Transformations

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants