Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,45 @@ ConvertMatMulToFullyConnected::ConvertMatMulToFullyConnected(bool supports_immad
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);

auto introduces_non_trivial_batch_broadcast = [](const ov::PartialShape& original_shape,
const ov::PartialShape& broadcasted_shape) {
if (!original_shape.rank().is_static() || !broadcasted_shape.rank().is_static()) {
return false;
}

const auto original_rank = static_cast<size_t>(original_shape.rank().get_length());
const auto broadcasted_rank = static_cast<size_t>(broadcasted_shape.rank().get_length());
if (broadcasted_rank < 2 || original_rank > broadcasted_rank) {
return false;
}

ov::PartialShape aligned_original_shape = original_shape;
for (size_t i = 0, cnt = broadcasted_rank - original_rank; i < cnt; ++i) {
aligned_original_shape.insert(aligned_original_shape.begin(), 1);
}

for (size_t i = 0; i < broadcasted_rank - 2; ++i) {
const auto& original_dim = aligned_original_shape[i];
const auto& broadcasted_dim = broadcasted_shape[i];
if (original_dim == 1 && broadcasted_dim.is_static() && broadcasted_dim.get_length() != 1) {
return true;
}
}

return false;
};

auto mul2_it = pattern_map.find(mul2_m);
if (mul2_it != pattern_map.end() && mul2_it->second.get_node_shared_ptr() == fc_input_b.get_node_shared_ptr()) {
const auto reshape_output = pattern_map.at(reshape_m);
// Keep valid 3D compressed FC cases enabled. Only reject the extra post-reshape multiply when broadcasting changes the weights
// from a shared matrix into data with real batch dimensions. For example, reshape may first squeeze the weights to [16, 32],
// then an extra multiply with scale [8, 1, 32] broadcasts them to [8, 16, 32], which makes the weights effectively batched again.
if (introduces_non_trivial_batch_broadcast(reshape_output.get_partial_shape(), fc_input_b.get_partial_shape())) {
return false;
}
}

// If 'fc_input_b' is shared with another matmul, transposing 'fc_input_b' is restricted.
// If it is connected to the 'input_a' of another matmul, do not transpose
// If it is connected to the 'input_b' of another matmul and the transpose option differs between the two matmuls, do not transpose.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,15 @@ const std::vector<ShapeParams> input_shapes_basic = {
{{{}, {{11, 339, 377}}}, {377, 335}}
};

const std::vector<ShapeParams> input_shapes_extra_multiply = {
{{{}, {{1, 4, 2}}}, {2, 32}, 2ul},
{{{}, {{1, 4, 16}}}, {1, 16, 32}},
};

const std::vector<ShapeParams> input_shapes_extra_multiply_non_trivial_batch_broadcast = {
{{{}, {{1, 4, 16}}}, {16, 32}, 2ul},
};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
Expand All @@ -428,7 +437,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
::testing::Combine(::testing::ValuesIn(input_shapes_extra_multiply),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(activations_precisions),
::testing::Values(false),
::testing::Values(false),
::testing::Values(false),
::testing::Values(true),
::testing::Values(false),
::testing::ValuesIn(param_weights),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply_non_trivial_batch_broadcast_no_convert,
MatmulWeightsDecompression,
Comment thread
yuanxion marked this conversation as resolved.
::testing::Combine(::testing::ValuesIn(input_shapes_extra_multiply_non_trivial_batch_broadcast),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(activations_precisions),
::testing::Values(false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,72 @@ TEST_F(TransformationTestsF, ConvertMatMulToFullyConnectedTest_compressed_u8_par
}
}

TEST_F(TransformationTestsF, ConvertMatMulToFullyConnectedTest_compressed_u8_weights_extra_multiply) {
{
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 4, 2});
auto weights = ov::opset1::Constant::create(ov::element::u8, ov::Shape{1, 2, 32}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(weights, ov::element::f16);
auto mul_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 1, 32}, {1});
auto mul = std::make_shared<ov::opset1::Multiply>(convert, mul_const);
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 32});
auto reshape = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false);
auto extra_mul = std::make_shared<ov::opset1::Multiply>(reshape, mul_const);
auto matmul = std::make_shared<ov::opset1::MatMul>(data, extra_mul);

model = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
bool support_immad = true;
manager.register_pass<ConvertMatMulToFullyConnected>(support_immad);
}
{
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 4, 2});
auto weights = ov::opset1::Constant::create(ov::element::u8, ov::Shape{1, 2, 32}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(weights, ov::element::f16);
auto mul_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{1, 1, 32}, {1});
auto mul = std::make_shared<ov::opset1::Multiply>(convert, mul_const);
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 32});
auto reshape = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false);
auto extra_mul = std::make_shared<ov::opset1::Multiply>(reshape, mul_const);

auto transpose_const = ov::opset1::Constant::create(ov::element::i32, {3}, {0, 2, 1});
auto transpose = std::make_shared<ov::opset1::Transpose>(extra_mul, transpose_const);
auto no_bias = std::make_shared<ov::intel_gpu::op::Placeholder>();
auto matmul = std::make_shared<op::FullyConnected>(data, transpose, no_bias);

model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
}
}

TEST_F(TransformationTestsF, ConvertMatMulToFullyConnectedTest_compressed_u8_weights_extra_multiply_non_trivial_batch_broadcast) {
{
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 4, 16});
auto weights = ov::opset1::Constant::create(ov::element::u8, ov::Shape{8, 2, 32}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(weights, ov::element::f16);
auto mul_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{8, 1, 32}, {1});
auto mul = std::make_shared<ov::opset1::Multiply>(convert, mul_const);
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {16, 32});
auto reshape = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false);
auto extra_mul = std::make_shared<ov::opset1::Multiply>(reshape, mul_const);
auto matmul = std::make_shared<ov::opset1::MatMul>(data, extra_mul);

model = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
bool support_immad = true;
manager.register_pass<ConvertMatMulToFullyConnected>(support_immad);
}
{
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{1, 4, 16});
auto weights = ov::opset1::Constant::create(ov::element::u8, ov::Shape{8, 2, 32}, {1});
auto convert = std::make_shared<ov::opset1::Convert>(weights, ov::element::f16);
auto mul_const = ov::opset1::Constant::create(ov::element::f16, ov::Shape{8, 1, 32}, {1});
auto mul = std::make_shared<ov::opset1::Multiply>(convert, mul_const);
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {16, 32});
auto reshape = std::make_shared<ov::opset1::Reshape>(mul, reshape_const, false);
auto extra_mul = std::make_shared<ov::opset1::Multiply>(reshape, mul_const);
auto matmul = std::make_shared<ov::opset1::MatMul>(data, extra_mul);

model_ref = std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
}
}

TEST_F(TransformationTestsF, ConvertMatMulToFullyConnectedTest_compressed_u4_weights_3D) {
{
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f16, ov::Shape{3, 2, 2});
Expand Down
Loading