Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1030,16 +1030,17 @@ JitConstants SDPAMicroGenerator::get_jit_constants(const kernel_impl_params& par

auto data_inputs_num = micro_get_input_num(params, config);

size_t attn_input_idx = 3;
size_t scale_input_idx = 4;
jit.make("IS_CAUSAL", config.is_causal);
if (!config.is_paged_attention) {
const auto desc = params.typed_desc<scaled_dot_product_attention>();
const bool has_attn_mask_input = has_runtime_attn_mask_input(params, *desc);
if (config.has_const_attn_mask_val) {
jit.make("WITH_ATTN_MASK", 0);
jit.make("STATIC_SCALAR_ATTN_MASK_VALUE", config.attn_mask_val);
// scale_input_idx -= 1;
} else {
jit.make("WITH_ATTN_MASK", data_inputs_num > attn_input_idx);
jit.make("WITH_ATTN_MASK", has_attn_mask_input ? 1 : 0);
}
} else {
jit.make("WITH_ATTN_MASK", 0);
Expand Down Expand Up @@ -1261,7 +1262,7 @@ JitConstants SDPAMicroGenerator::get_jit_constants(const kernel_impl_params& par
jit.add(unit_parameters("VAL"));
jit.add(unit_parameters("DST"));

if (data_inputs_num > 3 && !config.has_const_attn_mask_val) {
if (data_inputs_num > 3 && !config.is_paged_attention && has_runtime_attn_mask_input(params, *params.typed_desc<scaled_dot_product_attention>())) {
jit.add(convert_strides("MSK", "INPUT3", {0, 1, 2, 3}));
jit.add(unit_parameters("MSK"));
}
Expand Down Expand Up @@ -1325,8 +1326,9 @@ Arguments SDPAMicroGenerator::get_arguments_desc(const kernel_impl_params& param
args.push_back({ArgumentDescriptor::Types::INPUT, ScaledDotProductAttentionInputIdx::VALUE}); // V
args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); // A

const auto desc = params.typed_desc<scaled_dot_product_attention>();
const uint32_t attn_mask_idx = ScaledDotProductAttentionInputIdx::ATTN_MASK;
if (config.input_num > attn_mask_idx && !config.has_const_attn_mask_val)
if (has_runtime_attn_mask_input(params, *desc))
args.push_back({ArgumentDescriptor::Types::INPUT, attn_mask_idx}); // mask
const uint32_t scale_idx = ScaledDotProductAttentionInputIdx::SCALE;
if (config.input_num > scale_idx && !config.has_const_scale_val)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ JitConstants SDPAOptGeneratorBase::get_jit_constants_base(const kernel_impl_para

size_t data_inputs_num = get_data_inputs_num(*desc);
size_t attn_mask_idx = ScaledDotProductAttentionInputIdx::ATTN_MASK;
const bool has_attn_mask_input = has_runtime_attn_mask_input(params, *desc);
if (desc->attn_mask_val.has_value()) {
jit.make("STATIC_SCALAR_ATTN_MASK_VALUE", desc->attn_mask_val.value());
jit.make("HAS_ATTN_MASK_INPUT", 0);
} else {
const bool has_attn_mask_input = data_inputs_num > attn_mask_idx;
jit.make("HAS_ATTN_MASK_INPUT", has_attn_mask_input ? 1 : 0);
if (has_attn_mask_input) {
const auto& attn_mask_layout = params.get_input_layout(attn_mask_idx);
Expand Down Expand Up @@ -122,8 +122,9 @@ Arguments SDPAOptGeneratorBase::get_arguments_desc_impl(const kernel_impl_params

const size_t attn_mask_idx = ScaledDotProductAttentionInputIdx::ATTN_MASK;
const size_t scale_idx = ScaledDotProductAttentionInputIdx::SCALE;
const bool has_attn_mask_input = has_runtime_attn_mask_input(params, *desc);
for (uint32_t i = 0; i < data_inputs_num; i++) {
if (i == attn_mask_idx && desc->attn_mask_val.has_value())
if (i == attn_mask_idx && !has_attn_mask_input)
continue;
if (i == scale_idx && desc->scale_val.has_value())
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class SDPARefGenerator : public SDPABase {
jit.add(make_type_jit_constants("ACCUMULATOR", get_accumulator_type(params)));

size_t data_inputs_num = get_data_inputs_num(*desc);
size_t attn_mask_idx = ScaledDotProductAttentionInputIdx::ATTN_MASK;
if (data_inputs_num > attn_mask_idx) {
if (has_runtime_attn_mask_input(params, *desc)) {
jit.make("HAS_ATTN_MASK_INPUT", 1);
}
size_t scale_idx = ScaledDotProductAttentionInputIdx::SCALE;
Expand Down Expand Up @@ -67,6 +66,9 @@ class SDPARefGenerator : public SDPABase {
}

for (uint32_t i = 0; i < data_inputs_num; i++) {
if (i == ScaledDotProductAttentionInputIdx::ATTN_MASK && !has_runtime_attn_mask_input(params, *desc)) {
continue;
}
args.push_back({ArgumentDescriptor::Types::INPUT, i});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ inline size_t get_data_inputs_num(const cldnn::scaled_dot_product_attention& des
return data_inputs_num;
}

// Tells whether the SDPA primitive has an attention-mask input that should be handled as a real runtime mask.
// Returns false when the mask is supplied as a constant (desc.attn_mask_val is set), when the primitive has
// no data input at the ATTN_MASK index, or when that input has a static rank of 0 or 1.
// An input whose rank is not static is reported as a runtime mask.
inline bool has_runtime_attn_mask_input(const cldnn::kernel_impl_params& params, const cldnn::scaled_dot_product_attention& desc) {
    const size_t mask_idx = cldnn::scaled_dot_product_attention::ScaledDotProductAttentionInputIdx::ATTN_MASK;

    // The input layout is queried only after the mask input is known to be present.
    const bool mask_is_connected = !desc.attn_mask_val.has_value() && get_data_inputs_num(desc) > mask_idx;
    if (!mask_is_connected) {
        return false;
    }

    const auto& mask_pshape = params.get_input_layout(mask_idx).get_partial_shape();
    const auto mask_rank = mask_pshape.rank();

    // Scalar and 1D placeholders stay out of the real attention-mask path.
    return !mask_rank.is_static() || mask_rank.get_length() > 1;
}

inline size_t get_key_cache_id(const cldnn::scaled_dot_product_attention& desc) {
size_t key_cache_id = desc.input_size();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,44 @@ ConvertMatMulToFullyConnected::ConvertMatMulToFullyConnected(bool supports_immad
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);

// Returns true when going from 'original_shape' to 'broadcasted_shape' turns a batch dimension
// (any dimension except the last two) that equals 1 into a static dimension different from 1.
// 'original_shape' is treated as if it were left-padded with 1s up to the rank of 'broadcasted_shape'.
// Returns false when either rank is dynamic, when the broadcasted rank is below 2,
// or when the original rank exceeds the broadcasted one.
auto introduces_non_trivial_batch_broadcast = [](const ov::PartialShape& original_shape,
                                                 const ov::PartialShape& broadcasted_shape) {
    const auto src_rank = original_shape.rank();
    const auto dst_rank = broadcasted_shape.rank();
    if (!(src_rank.is_static() && dst_rank.is_static())) {
        return false;
    }

    const auto src_dims = static_cast<size_t>(src_rank.get_length());
    const auto dst_dims = static_cast<size_t>(dst_rank.get_length());
    if (dst_dims < 2 || src_dims > dst_dims) {
        return false;
    }

    // Count of leading dimensions absent from 'original_shape'; each of them acts as an implicit 1.
    const size_t pad = dst_dims - src_dims;
    const size_t batch_dims = dst_dims - 2;
    for (size_t i = 0; i < batch_dims; ++i) {
        const bool original_is_one = i < pad || original_shape[i - pad] == 1;
        const auto& target_dim = broadcasted_shape[i];
        if (original_is_one && target_dim.is_static() && target_dim.get_length() != 1) {
            return true;
        }
    }

    return false;
};

auto mul2_it = pattern_map.find(mul2_m);
if (mul2_it != pattern_map.end() && mul2_it->second.get_node_shared_ptr() == fc_input_b.get_node_shared_ptr()) {
const auto reshape_output = pattern_map.at(reshape_m);
// Keep valid 3D compressed FC cases enabled. Only reject the extra post-reshape
// multiply when it reintroduces a real batched RHS through broadcasting.
if (introduces_non_trivial_batch_broadcast(reshape_output.get_partial_shape(), fc_input_b.get_partial_shape())) {
return false;
}
}

// If 'fc_input_b' is shared with another matmul, transposing 'fc_input_b' is restricted.
// If it is connected to the 'input_a' of another matmul, do not transpose
// If it is connected to the 'input_b' of another matmul and the transpose option differs between the two matmuls, do not transpose.
Expand Down Expand Up @@ -153,7 +191,7 @@ ConvertMatMulToFullyConnected::ConvertMatMulToFullyConnected(bool supports_immad
};

bool is_compressed_weight = ((pattern_map.find(compressed_weights_input_m) != pattern_map.end())
&& (pattern_map.at(compressed_weights_input_m).get_node_shared_ptr() != nullptr));
&& (pattern_map.at(compressed_weights_input_m).get_node_shared_ptr() != nullptr));
bool success = true;
ov::PartialShape shape_a_aligned;
ov::PartialShape shape_b_aligned;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,17 @@ void ScaledAttnLayerGPUTest::SetUp() {
manager.run_passes(functionRefs);

auto it = std::find_if(inputShapes[1].second.begin(), inputShapes[1].second.end(), [&](const ov::Shape& shape){
return shape[0] >= 128 || shape[2] >= 384 || shape[3] >= 128;
if (shape.empty()) {
return false;
}

const auto rank = shape.size();
const auto seq_idx = rank >= 2 ? rank - 2 : 0;
const auto head_idx = rank - 1;
return shape[0] >= 128 || shape[seq_idx] >= 384 || shape[head_idx] >= 128;
});

bool has_diff_head_size = inputShapes[1].first.begin()[3] != inputShapes[2].first.begin()[3];
bool has_diff_head_size = inputShapes[1].first[-1] != inputShapes[2].first[-1];

bool has_long_seq = it != inputShapes[1].second.end();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,19 @@ const std::vector<bool> transpose_weights = {true, false};
const std::vector<bool> param_weights = {true, false};
const std::vector<ShapeParams> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {16, 32}},
{{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
{{{}, {{1, 4, 16}}}, {1, 16, 32}},
{{{}, {{1, 4, 48}}}, {48, 256}},
{{{}, {{11, 339, 377}}}, {377, 335}}
};

// Shape set consumed by the smoke_MatMulCompressedWeights_extra_multiply instantiation.
// NOTE(review): the trailing 2ul is presumably the ShapeParams weights group size — confirm against the struct.
const std::vector<ShapeParams> input_shapes_extra_multiply_supported = {
{{{}, {{1, 4, 2}}}, {2, 32}, 2ul},
};

// Shape set consumed by the smoke_MatMulCompressedWeights_extra_multiply_broadcast_3D instantiation.
// NOTE(review): the trailing 2ul is presumably the ShapeParams weights group size — confirm against the struct.
const std::vector<ShapeParams> input_shapes_extra_multiply_broadcast_3d = {
{{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
Expand All @@ -428,7 +435,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
::testing::Combine(::testing::ValuesIn(input_shapes_extra_multiply_supported),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(activations_precisions),
::testing::Values(false),
::testing::Values(false),
::testing::Values(false),
::testing::Values(true),
::testing::Values(false),
::testing::ValuesIn(param_weights),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply_broadcast_3D,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_extra_multiply_broadcast_3d),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(activations_precisions),
::testing::Values(false),
Expand Down
88 changes: 88 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/sdpa_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,92 @@ TEST(sdpa_gpu_custom, single_token_cond_attn_mask_clamp) {
ASSERT_NEAR(static_cast<float>(ref_ptr[hs]), static_cast<float>(output_ptr[hs]), 1e-2f);
}
}

// Runs the same SDPA primitive twice on identical random Q/K/V data — once with three inputs and once with
// an additional rank-0 f16 input in the mask position — and checks that both outputs agree within 1e-3.
// The "sdpa_opt" implementation is forced for the primitive and the scale is set through scale_val.
// NOTE(review): the placeholder holds a single uniform value (1.0). If the mask is applied additively before
// softmax, a uniform shift would leave the result unchanged, so matching outputs alone may not prove the
// placeholder was ignored — consider a second run with a different value or an extra check; confirm.
TEST(sdpa_gpu_custom, scalar_placeholder_mask_matches_scale_only) {
tests::random_generator rg;
rg.set_seed(GET_SUITE_NAME);
auto& engine = get_test_engine();

const int batch = 1;
const int seq_length_q = 4;
const int seq_length_kv = 6;
const int num_heads = 2;
const int head_size = 32;
const float scale_val = 0.35f;

// K and V share the same dimensions; Q differs only in its second dimension (seq_length_q vs seq_length_kv).
const layout q_layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
const layout k_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
const layout v_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
// Rank-0 layout used for the placeholder input that occupies the mask slot.
const layout scalar_mask_layout{ov::PartialShape{}, data_types::f16, format::bfyx};

auto q_mem = engine.allocate_memory(q_layout);
auto k_mem = engine.allocate_memory(k_layout);
auto v_mem = engine.allocate_memory(v_layout);
auto scalar_mask_mem = engine.allocate_memory(scalar_mask_layout);

// Fills a buffer with random f16 values generated in the [-1, 1] range, one per element of its shape.
auto fill_random = [&](const memory::ptr& mem) {
const auto shape = mem->get_layout().get_shape();
const size_t elements_num = ov::shape_size(shape);
auto data = rg.generate_random_1d<ov::float16>(elements_num, -1.0f, 1.0f);
set_values(mem, data);
};

fill_random(q_mem);
fill_random(k_mem);
fill_random(v_mem);
set_values(scalar_mask_mem, {ov::float16(1.0f)});

// Builds and executes a network around one SDPA primitive and returns the memory of the "result" output.
// When use_placeholder_mask is true, the rank-0 "mask" input is appended as the fourth primitive input.
auto run_sdpa = [&](bool use_placeholder_mask) {
topology topo;
topo.add(input_layout("q", q_layout));
topo.add(input_layout("k", k_layout));
topo.add(input_layout("v", v_layout));
std::vector<input_info> inputs = {input_info("q"), input_info("k"), input_info("v")};
if (use_placeholder_mask) {
topo.add(input_layout("mask", scalar_mask_layout));
inputs.push_back(input_info("mask"));
}

auto sdpa_prim = scaled_dot_product_attention("sdpa",
inputs,
false,
-1,
{0, 2, 1, 3},
{0, 2, 1, 3},
{0, 2, 1, 3},
{0, 1, 2, 3},
{},
false);
// The scale is provided as a constant on the primitive rather than as an extra input.
sdpa_prim.scale_val = scale_val;

topo.add(sdpa_prim);
topo.add(reorder("result", input_info("sdpa"), format::bfyx, data_types::f16));

ExecutionConfig cfg = get_test_default_config(engine);
cfg.set_property(ov::intel_gpu::allow_new_shape_infer(true));
// Pin the implementation so that both runs exercise the "sdpa_opt" kernel.
cfg.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{{"sdpa", {format::type::bfyx, "sdpa_opt"}}}));

auto network = get_network(engine, topo, cfg, get_test_stream_ptr(), false);
network->set_input_data("q", q_mem);
network->set_input_data("k", k_mem);
network->set_input_data("v", v_mem);
if (use_placeholder_mask) {
network->set_input_data("mask", scalar_mask_mem);
}

return network->execute().at("result").get_memory();
};

auto output_without_mask = run_sdpa(false);
auto output_with_placeholder_mask = run_sdpa(true);

cldnn::mem_lock<ov::float16, mem_lock_type::read> without_mask_ptr(output_without_mask, get_test_stream());
cldnn::mem_lock<ov::float16, mem_lock_type::read> with_placeholder_mask_ptr(output_with_placeholder_mask, get_test_stream());

// Element-wise comparison of the two outputs with an absolute tolerance of 1e-3.
ASSERT_EQ(without_mask_ptr.size(), with_placeholder_mask_ptr.size());
for (size_t i = 0; i < without_mask_ptr.size(); ++i) {
ASSERT_NEAR(static_cast<float>(without_mask_ptr[i]), static_cast<float>(with_placeholder_mask_ptr[i]), 1e-3f)
<< "Mismatch at index " << i;
}
}
} // namespace
Loading
Loading