Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API TransposeSinking;
class TRANSFORMATIONS_API TransposeFQ;
class TRANSFORMATIONS_API TransposeConvert;
class TRANSFORMATIONS_API TransposeEltwise;
class TRANSFORMATIONS_API TransposeReduction;
Expand All @@ -23,6 +24,16 @@ class TRANSFORMATIONS_API TransposeFuse;
} // namespace pass
} // namespace ov

/**
* @ingroup ov_transformation_common_api
* @brief TransposeFQ transformation sinks Transpose through FakeQuantize
*/
class ov::pass::TransposeFQ : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("TransposeFQ");
TransposeFQ();
};

/**
* @ingroup ov_transformation_common_api
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
Expand Down Expand Up @@ -83,6 +94,7 @@ class ov::pass::TransposeSinking : public ov::pass::GraphRewrite {
public:
OPENVINO_GRAPH_REWRITE_RTTI("TransposeSinking");
TransposeSinking() {
add_matcher<ov::pass::TransposeFQ>();
add_matcher<ov::pass::TransposeFQReduction>();
add_matcher<ov::pass::TransposeReduction>();
add_matcher<ov::pass::TransposeConvert>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,77 @@ std::shared_ptr<v0::Constant> get_reversed_order_constant(const std::shared_ptr<

} // namespace

ov::pass::TransposeFQ::TransposeFQ() {
MATCHER_SCOPE(TransposeFQ);

auto transpose_order_m = wrap_type<v0::Constant>();
auto transpose_label = wrap_type<v1::Transpose>({any_input(pattern::has_static_rank()), transpose_order_m});
auto fq_label = wrap_type<v0::FakeQuantize>({transpose_label,
any_input(ov::pass::pattern::has_static_rank()),
any_input(ov::pass::pattern::has_static_rank()),
any_input(ov::pass::pattern::has_static_rank()),
any_input(ov::pass::pattern::has_static_rank())},
consumers_count(1));

matcher_pass_callback matcher_pass_callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();

auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto fq = pattern_to_output.at(fq_label).get_node_shared_ptr();
auto transpose_order =
ov::as_type_ptr<v0::Constant>(pattern_to_output.at(transpose_order_m).get_node_shared_ptr());
if (!transpose_order || !fq)
return false;

ov::NodeVector new_ops;

const auto& reverse_order_constant = get_reversed_order_constant(transpose_order);
new_ops.push_back(reverse_order_constant);

const auto& input_rank = fq->get_input_partial_shape(0).rank().get_length();
ov::OutputVector fq_inputs = {transpose->input_value(0)};
for (size_t i = 1; i < fq->inputs().size(); ++i) {
auto input = fq->input_value(i);
if (ov::shape_size(input.get_shape()) == 1) {
fq_inputs.push_back(input);
continue;
}

const auto& range_rank = input.get_partial_shape().rank().get_length();
if (range_rank > input_rank)
return false;

const auto& ranks_diff = input_rank - range_rank;
if (ranks_diff > 0) {
std::vector<int64_t> axes(ranks_diff);
std::iota(axes.begin(), axes.end(), 0);
const auto& axes_const = v0::Constant::create(element::i64, Shape{axes.size()}, axes);
new_ops.push_back(axes_const);
const auto& unsqueezed_input = op_util::make_try_fold<v0::Unsqueeze>(input, axes_const);
new_ops.push_back(unsqueezed_input);
input = unsqueezed_input->output(0);
}
const auto& transposed_input = op_util::make_try_fold<v1::Transpose>(input, reverse_order_constant);
Comment thread
v-Golubev marked this conversation as resolved.
new_ops.push_back(transposed_input);
fq_inputs.push_back(transposed_input);
}

auto new_fq = fq->clone_with_new_inputs(fq_inputs);
new_ops.push_back(new_fq);

auto new_transpose = register_new_node<v1::Transpose>(new_fq, transpose_order);
new_ops.push_back(new_transpose);
new_transpose->set_friendly_name(fq->get_friendly_name());

ov::copy_runtime_info({fq, transpose}, new_ops);
ov::replace_node(fq, new_transpose);
return true;
};

auto m = std::make_shared<Matcher>(fq_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

ov::pass::TransposeEltwise::TransposeEltwise() {
MATCHER_SCOPE(TransposeEltwise);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <gtest/gtest.h>

#include <memory>
#include <numeric>
#include <string>

#include "common_test_utils/ov_test_utils.hpp"
Expand Down Expand Up @@ -105,6 +106,55 @@ class TransposeSinkingFQ : public ov::test::TestsCommon,
}
};

struct TransposeFQTransposeFuseParams {
Shape ranges_shape;
Shape expected_ranges_shape;
};

class TransposeSinkingFQTransposeFuse : public ov::test::TestsCommon,
public testing::WithParamInterface<std::tuple<TransposeFQTransposeFuseParams>> {
public:
std::shared_ptr<Model> f, f_ref;

void SetUp() override {
const auto& test_case = std::get<0>(GetParam());
std::vector<float> fq_values(shape_size(test_case.ranges_shape));
std::iota(fq_values.begin(), fq_values.end(), 0.f);

{
auto input = std::make_shared<opset6::Parameter>(element::f32, PartialShape{1, 3, 4, 5});

auto first_order =
std::make_shared<opset6::Constant>(element::i64, Shape{4}, std::vector<int64_t>{0, 2, 3, 1});
auto first_transpose = std::make_shared<opset6::Transpose>(input, first_order);

auto i_low = std::make_shared<opset6::Constant>(element::f32, test_case.ranges_shape, fq_values);
auto i_high = std::make_shared<opset6::Constant>(element::f32, test_case.ranges_shape, fq_values);
auto o_low = std::make_shared<opset6::Constant>(element::f32, test_case.ranges_shape, fq_values);
auto o_high = std::make_shared<opset6::Constant>(element::f32, test_case.ranges_shape, fq_values);
auto fq = std::make_shared<opset6::FakeQuantize>(first_transpose, i_low, i_high, o_low, o_high, 256);

auto second_order =
std::make_shared<opset6::Constant>(element::i64, Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
auto second_transpose = std::make_shared<opset6::Transpose>(fq, second_order);

f = std::make_shared<ov::Model>(OutputVector{second_transpose}, ParameterVector{input});
}

{
auto input = std::make_shared<opset6::Parameter>(element::f32, PartialShape{1, 3, 4, 5});

auto i_low = std::make_shared<opset6::Constant>(element::f32, test_case.expected_ranges_shape, fq_values);
auto i_high = std::make_shared<opset6::Constant>(element::f32, test_case.expected_ranges_shape, fq_values);
auto o_low = std::make_shared<opset6::Constant>(element::f32, test_case.expected_ranges_shape, fq_values);
auto o_high = std::make_shared<opset6::Constant>(element::f32, test_case.expected_ranges_shape, fq_values);
auto fq = std::make_shared<opset6::FakeQuantize>(input, i_low, i_high, o_low, o_high, 256);

f_ref = std::make_shared<ov::Model>(OutputVector{fq}, ParameterVector{input});
}
}
};

TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
auto unh = std::make_shared<ov::pass::UniqueNamesHolder>();
pass::Manager manager;
Expand Down Expand Up @@ -155,6 +205,34 @@ INSTANTIATE_TEST_SUITE_P(TransformationTest,
{2, 3},
{0, 1}}));

TEST_P(TransposeSinkingFQTransposeFuse, TransposeFQTransposeFuse) {
auto unh = std::make_shared<ov::pass::UniqueNamesHolder>();
pass::Manager manager;
manager.register_pass<ov::pass::InitUniqueNames>(unh);
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::TransposeSinking>();
manager.register_pass<ov::pass::CheckUniqueNames>(unh);
manager.run_passes(f);
OV_ASSERT_NO_THROW(check_rt_info(f));

auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommon,
TransposeSinkingFQTransposeFuse,
testing::Values(TransposeFQTransposeFuseParams{{}, {}},
TransposeFQTransposeFuseParams{{1}, {1}},
Comment thread
alvoron marked this conversation as resolved.
TransposeFQTransposeFuseParams{{1, 1, 1, 1}, {1, 1, 1, 1}},
TransposeFQTransposeFuseParams{{3}, {1, 3, 1, 1}},
TransposeFQTransposeFuseParams{{1, 3}, {1, 3, 1, 1}},
TransposeFQTransposeFuseParams{{1, 1, 3}, {1, 3, 1, 1}},
TransposeFQTransposeFuseParams{{1, 1, 1, 3}, {1, 3, 1, 1}}));

struct TransposeReduceParams {
// given params
PartialShape transpose_input_shape;
Expand Down
Loading