diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking.hpp index 853d9434a4d5dc..2bfbec7caf9cee 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking.hpp @@ -14,6 +14,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API TransposeSinking; +class TRANSFORMATIONS_API TransposeFQ; class TRANSFORMATIONS_API TransposeConvert; class TRANSFORMATIONS_API TransposeEltwise; class TRANSFORMATIONS_API TransposeReduction; @@ -23,6 +24,16 @@ class TRANSFORMATIONS_API TransposeFuse; } // namespace pass } // namespace ov +/** + * @ingroup ov_transformation_common_api + * @brief TransposeFQ transformation sinks Transpose through FakeQuantize: the Transpose on the FakeQuantize data input is re-created on the FakeQuantize output; single-element range inputs are kept as-is, the other range inputs are expanded with leading axes up to the data rank and passed through the reversed transpose order + */ +class ov::pass::TransposeFQ : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("TransposeFQ"); + TransposeFQ(); +}; + /** * @ingroup ov_transformation_common_api * @brief TransposeReduction transformation sinks Transpose through Reduce operations @@ -83,6 +94,7 @@ class ov::pass::TransposeSinking : public ov::pass::GraphRewrite { public: OPENVINO_GRAPH_REWRITE_RTTI("TransposeSinking"); TransposeSinking() { + add_matcher(); add_matcher(); add_matcher(); add_matcher(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp index a883ba1275922d..60a667a72ac307 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp @@ -78,6 +78,77 @@ std::shared_ptr get_reversed_order_constant(const std::shared_ptr< } // namespace 
+ov::pass::TransposeFQ::TransposeFQ() { + /* Sinks the transpose found on FakeQuantize input 0 to the FakeQuantize output. Pattern: transpose_label(any static-rank input, transpose_order_m) feeds input 0 of fq_label; the four range inputs must have a static rank and fq_label carries the consumers_count(1) predicate. NOTE(review): the template arguments of wrap_type / as_type_ptr / make_try_fold / register_new_node are not visible in this view, so the exact op types are not restated here - confirm against the full file. */ MATCHER_SCOPE(TransposeFQ); + + auto transpose_order_m = wrap_type(); + auto transpose_label = wrap_type({any_input(pattern::has_static_rank()), transpose_order_m}); + auto fq_label = wrap_type({transpose_label, + any_input(ov::pass::pattern::has_static_rank()), + any_input(ov::pass::pattern::has_static_rank()), + any_input(ov::pass::pattern::has_static_rank()), + any_input(ov::pass::pattern::has_static_rank())}, + consumers_count(1)); + + matcher_pass_callback matcher_pass_callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto fq = pattern_to_output.at(fq_label).get_node_shared_ptr(); + auto transpose_order = + ov::as_type_ptr(pattern_to_output.at(transpose_order_m).get_node_shared_ptr()); + if (!transpose_order || !fq) + return false; + + ov::NodeVector new_ops; + + const auto& reverse_order_constant = get_reversed_order_constant(transpose_order); + new_ops.push_back(reverse_order_constant); + + const auto& input_rank = fq->get_input_partial_shape(0).rank().get_length(); + /* The new FakeQuantize reads the data that used to enter the transpose. */ ov::OutputVector fq_inputs = {transpose->input_value(0)}; + for (size_t i = 1; i < fq->inputs().size(); ++i) { + auto input = fq->input_value(i); + /* Range inputs known to hold exactly one element are reused unchanged. get_shape() throws on a dynamic shape and the pattern only guarantees a static rank, so the shape is checked for being static first; dynamically shaped ranges take the generic path below, which needs the rank only. */ if (input.get_partial_shape().is_static() && ov::shape_size(input.get_shape()) == 1) { + fq_inputs.push_back(input); + continue; + } + + const auto& range_rank = input.get_partial_shape().rank().get_length(); + /* A range with a higher rank than the data input cannot be aligned by prepending axes - leave the graph untouched. */ if (range_rank > input_rank) + return false; + + const auto& ranks_diff = input_rank - range_rank; + /* Prepend axes 0..ranks_diff-1 so the range has the same rank as the data input before the reversed order is applied to it. */ if (ranks_diff > 0) { + std::vector axes(ranks_diff); + std::iota(axes.begin(), axes.end(), 0); + const auto& axes_const = v0::Constant::create(element::i64, Shape{axes.size()}, axes); + new_ops.push_back(axes_const); + const auto& unsqueezed_input = op_util::make_try_fold(input, axes_const); + new_ops.push_back(unsqueezed_input); + input = unsqueezed_input->output(0); + } + const auto& transposed_input = 
op_util::make_try_fold(input, reverse_order_constant); + new_ops.push_back(transposed_input); + fq_inputs.push_back(transposed_input); + } + + auto new_fq = fq->clone_with_new_inputs(fq_inputs); + new_ops.push_back(new_fq); + + /* The node built from new_fq and the original transpose_order takes over the friendly name of fq and replaces fq in the graph. */ auto new_transpose = register_new_node(new_fq, transpose_order); + new_ops.push_back(new_transpose); + new_transpose->set_friendly_name(fq->get_friendly_name()); + + ov::copy_runtime_info({fq, transpose}, new_ops); + ov::replace_node(fq, new_transpose); + return true; + }; + + auto m = std::make_shared(fq_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + ov::pass::TransposeEltwise::TransposeEltwise() { MATCHER_SCOPE(TransposeEltwise); diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_test.cpp index a1287ada70f84f..1189cc2b62c4c2 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_test.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_test.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include "common_test_utils/ov_test_utils.hpp" @@ -105,6 +106,55 @@ class TransposeSinkingFQ : public ov::test::TestsCommon, } }; +struct TransposeFQTransposeFuseParams { + /* Shape of the four FakeQuantize range constants in the tested model (f). */ Shape ranges_shape; + /* Shape of the same constants in the reference model (f_ref). */ Shape expected_ranges_shape; +}; + +class TransposeSinkingFQTransposeFuse : public ov::test::TestsCommon, + public testing::WithParamInterface> { +public: + std::shared_ptr f, f_ref; + + void SetUp() override { + /* Range values are 0, 1, 2, ... (one per element of ranges_shape) and are shared by f and f_ref. */ const auto& test_case = std::get<0>(GetParam()); + std::vector fq_values(shape_size(test_case.ranges_shape)); + std::iota(fq_values.begin(), fq_values.end(), 0.f); + + { + /* f: input {1, 3, 4, 5} -> first_transpose -> fq -> second_transpose; second_order {0, 3, 1, 2} is the inverse permutation of first_order {0, 2, 3, 1}. */ auto input = std::make_shared(element::f32, PartialShape{1, 3, 4, 5}); + + auto first_order = + std::make_shared(element::i64, Shape{4}, std::vector{0, 2, 3, 1}); + auto first_transpose = std::make_shared(input, first_order); + + auto i_low = std::make_shared(element::f32, 
test_case.ranges_shape, fq_values); + auto i_high = std::make_shared(element::f32, test_case.ranges_shape, fq_values); + auto o_low = std::make_shared(element::f32, test_case.ranges_shape, fq_values); + auto o_high = std::make_shared(element::f32, test_case.ranges_shape, fq_values); + auto fq = std::make_shared(first_transpose, i_low, i_high, o_low, o_high, 256); + + auto second_order = + std::make_shared(element::i64, Shape{4}, std::vector{0, 3, 1, 2}); + auto second_transpose = std::make_shared(fq, second_order); + + f = std::make_shared(OutputVector{second_transpose}, ParameterVector{input}); + } + + { + /* f_ref: fq applied directly to the input, with range constants of expected_ranges_shape and no surrounding transposes. */ auto input = std::make_shared(element::f32, PartialShape{1, 3, 4, 5}); + + auto i_low = std::make_shared(element::f32, test_case.expected_ranges_shape, fq_values); + auto i_high = std::make_shared(element::f32, test_case.expected_ranges_shape, fq_values); + auto o_low = std::make_shared(element::f32, test_case.expected_ranges_shape, fq_values); + auto o_high = std::make_shared(element::f32, test_case.expected_ranges_shape, fq_values); + auto fq = std::make_shared(input, i_low, i_high, o_low, o_high, 256); + + f_ref = std::make_shared(OutputVector{fq}, ParameterVector{input}); + } + } +}; + TEST_P(TransposeSinkingFQ, TransposeFQReduce) { auto unh = std::make_shared(); pass::Manager manager; @@ -155,6 +205,34 @@ INSTANTIATE_TEST_SUITE_P(TransformationTest, {2, 3}, {0, 1}})); +TEST_P(TransposeSinkingFQTransposeFuse, TransposeFQTransposeFuse) { + /* Runs the registered passes on f, checks runtime info, then requires f to match f_ref in nodes, precisions and constant values. */ auto unh = std::make_shared(); + pass::Manager manager; + manager.register_pass(unh); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(unh); + manager.run_passes(f); + OV_ASSERT_NO_THROW(check_rt_info(f)); + + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::CONST_VALUES); + auto res = fc.compare(f, f_ref); + ASSERT_TRUE(res.valid) << res.message; +} + 
+INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommon, + TransposeSinkingFQTransposeFuse, + /* Each entry is {ranges_shape, expected_ranges_shape}: the shape of the FakeQuantize range constants in the tested model and in the reference model. */ testing::Values(TransposeFQTransposeFuseParams{{}, {}}, + TransposeFQTransposeFuseParams{{1}, {1}}, + TransposeFQTransposeFuseParams{{1, 1, 1, 1}, {1, 1, 1, 1}}, + TransposeFQTransposeFuseParams{{3}, {1, 3, 1, 1}}, + TransposeFQTransposeFuseParams{{1, 3}, {1, 3, 1, 1}}, + TransposeFQTransposeFuseParams{{1, 1, 3}, {1, 3, 1, 1}}, + TransposeFQTransposeFuseParams{{1, 1, 1, 3}, {1, 3, 1, 1}})); + struct TransposeReduceParams { // given params PartialShape transpose_input_shape;