diff --git a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp index e2179c16f559..fd03b5ab2e35 100644 --- a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp @@ -12,6 +12,7 @@ #include "openvino/op/multiply.hpp" #include "openvino/op/power.hpp" #include "openvino/op/reduce_mean.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/sqrt.hpp" #include "openvino/pass/manager.hpp" #include "openvino/pass/pattern/op/optional.hpp" @@ -37,6 +38,12 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit auto const_power_convert = pattern::optional(const_power); auto power = pattern::wrap_type({x, const_power_convert}); + // x^2 via Multiply(x, x) — used by TFLite RMS norm decomposition + auto mul_square = pattern::wrap_type({x, x}); + + // Either Power(x, 2) or Multiply(x, x) + auto square = std::make_shared(OutputVector{power, mul_square}); + // ReduceMean(x^2,axes) auto mean_axes = pattern::wrap_type([](const ov::Output& output) { auto const_node = ov::as_type_ptr(output.get_node_shared_ptr()); @@ -48,15 +55,18 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit // RMS fusion is only valid when ReduceMean has exactly one axis. return num_elems == 1; }); - auto mean = pattern::wrap_type({power, mean_axes}); + auto mean = pattern::wrap_type({square, mean_axes}); // ReduceMean(x^2,axes)+eps auto eps = pattern::wrap_type(); auto eps_convert = pattern::optional(eps); auto add_eps = pattern::wrap_type({mean, eps_convert}); + // Optional Reshape between add_eps and sqrt/rsqrt (e.g., TFLite decomposition with keepdims=false) + auto add_eps_opt_reshape = pattern::optional({add_eps, pattern::any_input()}); + // Sqrt(ReduceMean(x^2,axes)+eps) - auto sqrt = pattern::wrap_type({add_eps}); + auto sqrt = pattern::wrap_type({add_eps_opt_reshape}); // 1/Sqrt(ReduceMean(x^2,axes)+eps) auto const_pow = pattern::wrap_type(pattern::value_matches("-1")); @@ -70,7 +80,7 @@ RMSFusion::RMSFusion(bool force_tail_convert, bool enable_div_x, bool enable_wit // Power(ReduceMean(x^2,axes)+eps, -0.5) — direct rsqrt without Sqrt node auto const_neg_half = pattern::wrap_type(pattern::value_matches("-0.5")); auto const_neg_half_convert = pattern::optional(const_neg_half); - auto pow_direct = pattern::wrap_type({add_eps, const_neg_half_convert}); + auto pow_direct = pattern::wrap_type({add_eps_opt_reshape, const_neg_half_convert}); std::shared_ptr div_or_pow = std::make_shared(OutputVector{div, pow, pow_direct}); diff --git a/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp b/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp index 3ce950117f1e..8b6eba057674 100644 --- a/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp +++ b/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp @@ -21,6 +21,7 @@ #include "openvino/op/parameter.hpp" #include "openvino/op/power.hpp" #include "openvino/op/reduce_mean.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/sqrt.hpp" using namespace testing; @@ -569,3 +570,100 @@ TEST_F(TransformationTestsF, RMSNormFusionTest16_PowerNegHalf_F16Convert) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); } + +// Multiply(x, x) instead of Power(x, 2) — TFLite RMS norm decomposition pattern +TEST_F(TransformationTestsF, RMSNormFusionTest17_MulSquare_PowerNegHalf) { + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, 2048}); + // x^2 via Multiply(x, x) + auto mul_square = std::make_shared(input, input); + auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1}); + auto mean = std::make_shared(mul_square, mean_axes, true); + auto eps = ov::op::v0::Constant::create(ov::element::f32, {1, 1, 1}, {1e-6f}); + auto add_eps = std::make_shared(mean, eps); + auto neg_half = ov::op::v0::Constant::create(ov::element::f32, {}, {-0.5f}); + auto rsqrt = std::make_shared(add_eps, neg_half); + auto mul1 = std::make_shared(input, rsqrt); + auto gamma = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{2048}, std::vector(2048, 1.0f)); + auto mul2 = std::make_shared(gamma, mul1); + + model = std::make_shared(ov::OutputVector{mul2}, ov::ParameterVector{input}); + manager.register_pass(false); + } + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, 2048}); + auto gamma = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{2048}, std::vector(2048, 1.0f)); + auto rms = std::make_shared(input, gamma, 1e-6f, ov::element::f32); + + model_ref = std::make_shared(ov::OutputVector{rms}, ov::ParameterVector{input}); + } + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +// Multiply(x, x) with optional Reshape before rsqrt — TFLite pattern with keepdims=false +TEST_F(TransformationTestsF, RMSNormFusionTest18_MulSquare_Reshape_PowerNegHalf) { + { + auto input = std::make_shared(ov::element::f16, ov::PartialShape{1, 1, 30, 256}); + // x^2 via Multiply(x, x) + auto mul_square = std::make_shared(input, input); + auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1}); + auto mean = std::make_shared(mul_square, mean_axes, false); + auto eps = ov::op::v0::Constant::create(ov::element::f16, {1, 1, 1}, {1e-6f}); + auto add_eps = std::make_shared(mean, eps); + // Reshape to add back the reduced dimension for broadcasting + auto reshape_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {1, 1, 30, 1}); + auto reshape = std::make_shared(add_eps, reshape_shape, false); + auto neg_half = ov::op::v0::Constant::create(ov::element::f16, {}, {-0.5f}); + auto rsqrt = std::make_shared(reshape, neg_half); + auto mul1 = std::make_shared(input, rsqrt); + auto gamma = + ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 1, 1, 256}, std::vector(256, 1.0f)); + auto mul2 = std::make_shared(gamma, mul1); + + model = std::make_shared(ov::OutputVector{mul2}, ov::ParameterVector{input}); + manager.register_pass(false); + } + { + auto input = std::make_shared(ov::element::f16, ov::PartialShape{1, 1, 30, 256}); + auto gamma = + ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 1, 1, 256}, std::vector(256, 1.0f)); + auto rms = std::make_shared(input, gamma, 1e-6f, ov::element::f16); + + model_ref = std::make_shared(ov::OutputVector{rms}, ov::ParameterVector{input}); + } + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +// Multiply(x, x) without gamma, with dynamic scale — TFLite pattern +TEST_F(TransformationTestsF, RMSNormFusionTest19_MulSquare_NoGamma_DynamicScale) { + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2048}); + auto scale = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2048}); + // x^2 via Multiply(x, x) + auto mul_square = std::make_shared(input, input); + auto mean_axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {-1}); + auto mean = std::make_shared(mul_square, mean_axes, true); + auto eps = ov::op::v0::Constant::create(ov::element::f32, {}, {1e-6f}); + auto add_eps = std::make_shared(mean, eps); + auto neg_half = ov::op::v0::Constant::create(ov::element::f32, {}, {-0.5f}); + auto rsqrt = std::make_shared(add_eps, neg_half); + auto mul1 = std::make_shared(input, rsqrt); + auto mul2 = std::make_shared(mul1, scale); + + model = std::make_shared(ov::OutputVector{mul2}, ov::ParameterVector{input, scale}); + manager.register_pass(false, false, true); + } + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2048}); + auto scale = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 2048}); + auto rms = std::make_shared(input, 1e-6f, ov::element::f32); + auto mul = std::make_shared(rms, scale); + + model_ref = std::make_shared(ov::OutputVector{mul}, ov::ParameterVector{input, scale}); + } + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); +} \ No newline at end of file