Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "openvino/op/multiply.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
Expand All @@ -37,6 +38,12 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit
auto const_power_convert = pattern::optional<v0::Convert>(const_power);
auto power = pattern::wrap_type<v1::Power>({x, const_power_convert});

// x^2 via Multiply(x, x) — used by TFLite RMS norm decomposition
auto mul_square = pattern::wrap_type<v1::Multiply>({x, x});

// Either Power(x, 2) or Multiply(x, x)
auto square = std::make_shared<pattern::op::Or>(OutputVector{power, mul_square});

// ReduceMean(x^2,axes)
auto mean_axes = pattern::wrap_type<v0::Constant>([](const ov::Output<ov::Node>& output) {
auto const_node = ov::as_type_ptr<v0::Constant>(output.get_node_shared_ptr());
Expand All @@ -48,15 +55,18 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit
// RMS fusion is only valid when ReduceMean has exactly one axis.
return num_elems == 1;
});
auto mean = pattern::wrap_type<v1::ReduceMean>({power, mean_axes});
auto mean = pattern::wrap_type<v1::ReduceMean>({square, mean_axes});

// ReduceMean(x^2,axes)+eps
auto eps = pattern::wrap_type<v0::Constant>();
auto eps_convert = pattern::optional<v0::Convert>(eps);
auto add_eps = pattern::wrap_type<v1::Add>({mean, eps_convert});

// Optional Reshape between add_eps and sqrt/rsqrt (e.g., TFLite decomposition with keepdims=false)
auto add_eps_opt_reshape = pattern::optional<v1::Reshape>({add_eps, pattern::any_input()});

// Sqrt(ReduceMean(x^2,axes)+eps)
auto sqrt = pattern::wrap_type<v0::Sqrt>({add_eps});
auto sqrt = pattern::wrap_type<v0::Sqrt>({add_eps_opt_reshape});

// 1/Sqrt(ReduceMean(x^2,axes)+eps)
auto const_pow = pattern::wrap_type<v0::Constant>(pattern::value_matches("-1"));
Expand All @@ -70,7 +80,7 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit
// Power(ReduceMean(x^2,axes)+eps, -0.5) — direct rsqrt without Sqrt node
auto const_neg_half = pattern::wrap_type<v0::Constant>(pattern::value_matches("-0.5"));
auto const_neg_half_convert = pattern::optional<v0::Convert>(const_neg_half);
auto pow_direct = pattern::wrap_type<v1::Power>({add_eps, const_neg_half_convert});
auto pow_direct = pattern::wrap_type<v1::Power>({add_eps_opt_reshape, const_neg_half_convert});

std::shared_ptr<pattern::op::Or> div_or_pow = std::make_shared<pattern::op::Or>(OutputVector{div, pow, pow_direct});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "openvino/op/parameter.hpp"
#include "openvino/op/power.hpp"
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/sqrt.hpp"

using namespace testing;
Expand Down Expand Up @@ -569,3 +570,100 @@ TEST_F(TransformationTestsF, RMSNormFusionTest16_PowerNegHalf_F16Convert) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

// Multiply(x, x) instead of Power(x, 2) — TFLite RMS norm decomposition pattern
TEST_F(TransformationTestsF, RMSNormFusionTest17_MulSquare_PowerNegHalf) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 1, 2048});
// x^2 via Multiply(x, x)
auto mul_square = std::make_shared<ov::op::v1::Multiply>(input, input);
auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1});
auto mean = std::make_shared<ov::op::v1::ReduceMean>(mul_square, mean_axes, true);
auto eps = ov::op::v0::Constant::create(ov::element::f32, {1, 1, 1}, {1e-6f});
auto add_eps = std::make_shared<ov::op::v1::Add>(mean, eps);
auto neg_half = ov::op::v0::Constant::create(ov::element::f32, {}, {-0.5f});
auto rsqrt = std::make_shared<ov::op::v1::Power>(add_eps, neg_half);
auto mul1 = std::make_shared<ov::op::v1::Multiply>(input, rsqrt);
auto gamma = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{2048}, std::vector<float>(2048, 1.0f));
auto mul2 = std::make_shared<ov::op::v1::Multiply>(gamma, mul1);

model = std::make_shared<ov::Model>(ov::OutputVector{mul2}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>(false);
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 1, 2048});
auto gamma = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{2048}, std::vector<float>(2048, 1.0f));
auto rms = std::make_shared<ov::op::internal::RMS>(input, gamma, 1e-6f, ov::element::f32);

model_ref = std::make_shared<ov::Model>(ov::OutputVector{rms}, ov::ParameterVector{input});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

// Multiply(x, x) with optional Reshape before rsqrt — TFLite pattern with keepdims=false
TEST_F(TransformationTestsF, RMSNormFusionTest18_MulSquare_Reshape_PowerNegHalf) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 1, 30, 256});
// x^2 via Multiply(x, x)
auto mul_square = std::make_shared<ov::op::v1::Multiply>(input, input);
auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1});
auto mean = std::make_shared<ov::op::v1::ReduceMean>(mul_square, mean_axes, false);
auto eps = ov::op::v0::Constant::create(ov::element::f16, {1, 1, 1}, {1e-6f});
auto add_eps = std::make_shared<ov::op::v1::Add>(mean, eps);
// Reshape to add back the reduced dimension for broadcasting
auto reshape_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {1, 1, 30, 1});
auto reshape = std::make_shared<ov::op::v1::Reshape>(add_eps, reshape_shape, false);
auto neg_half = ov::op::v0::Constant::create(ov::element::f16, {}, {-0.5f});
auto rsqrt = std::make_shared<ov::op::v1::Power>(reshape, neg_half);
auto mul1 = std::make_shared<ov::op::v1::Multiply>(input, rsqrt);
auto gamma =
ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 1, 1, 256}, std::vector<float>(256, 1.0f));
auto mul2 = std::make_shared<ov::op::v1::Multiply>(gamma, mul1);

model = std::make_shared<ov::Model>(ov::OutputVector{mul2}, ov::ParameterVector{input});
manager.register_pass<RMSFusion>(false);
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 1, 30, 256});
auto gamma =
ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 1, 1, 256}, std::vector<float>(256, 1.0f));
auto rms = std::make_shared<ov::op::internal::RMS>(input, gamma, 1e-6f, ov::element::f16);

model_ref = std::make_shared<ov::Model>(ov::OutputVector{rms}, ov::ParameterVector{input});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

// Multiply(x, x) without gamma, with dynamic scale — TFLite pattern
TEST_F(TransformationTestsF, RMSNormFusionTest19_MulSquare_NoGamma_DynamicScale) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 2048});
auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 2048});
// x^2 via Multiply(x, x)
auto mul_square = std::make_shared<ov::op::v1::Multiply>(input, input);
auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1});
auto mean = std::make_shared<ov::op::v1::ReduceMean>(mul_square, mean_axes, true);
auto eps = ov::op::v0::Constant::create(ov::element::f32, {}, {1e-6f});
auto add_eps = std::make_shared<ov::op::v1::Add>(mean, eps);
auto neg_half = ov::op::v0::Constant::create(ov::element::f32, {}, {-0.5f});
auto rsqrt = std::make_shared<ov::op::v1::Power>(add_eps, neg_half);
auto mul1 = std::make_shared<ov::op::v1::Multiply>(input, rsqrt);
auto mul2 = std::make_shared<ov::op::v1::Multiply>(mul1, scale);

model = std::make_shared<ov::Model>(ov::OutputVector{mul2}, ov::ParameterVector{input, scale});
manager.register_pass<RMSFusion>(false, false, true);
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 2048});
auto scale = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 2048});
auto rms = std::make_shared<ov::op::internal::RMS>(input, 1e-6f, ov::element::f32);
auto mul = std::make_shared<ov::op::v1::Multiply>(rms, scale);

model_ref = std::make_shared<ov::Model>(ov::OutputVector{mul}, ov::ParameterVector{input, scale});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
Loading