diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc index 8ed0207eb9..9875c5a814 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc @@ -3,6 +3,7 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_utils.h" @@ -37,7 +38,10 @@ Ort::Status ReciprocalOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrappe const auto& outputs = node_unit.Outputs(); RETURN_IF_NOT(outputs.size() == 1, "Reciprocal operator must have exactly 1 output."); - // Check input type is float for CPU. + // On the QNN CPU backend only float32 is accepted; other backends (HTP, GPU) + // are gated by the QNN SDK's own op-validation call inside + // ProcessAttributesAndOutputs (do_op_validation=true), which will return an + // error if the backend cannot handle the resulting ElementWiseDivide node. RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].type, "")); return Ort::Status(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index dc5b065bd2..d975e63056 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -23,11 +23,12 @@ #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_einsum_reshape.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" -#include "core/providers/qnn/builder/qnn_node_group/spacetodepth_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h" #include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/spacetodepth_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/transpose_reshape_transpose_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" @@ -96,6 +97,7 @@ static std::unordered_map> fusions = { {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Cast", {CastLoneQFusion::TryFusion}}, {"Erf", {GeluFusion::TryFusion}}, + {"Reciprocal", {ReciprocalMulFusion::TryFusion}}, {"ReduceMean", {LayerNormFusion::TryFusion}}, {"Einsum", {ReshapeEinsumReshapeNodeGroup::TryFusion}}, {"Reshape", {SpaceToDepthFusion::TryFusion, Rank6ToRank5Fusion::TryFusion}}, @@ -137,9 +139,9 @@ static std::unique_ptr TryQnnFusions( // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except // MatMul w/ LPBQ encodings, Erf, and Reshape. if (starting_node_unit.UnitType() != OrtNodeUnit::Type::SingleNode && + starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Gather" && starting_node_unit.OpType() != "MatMul" && - starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Reshape") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.cc new file mode 100644 index 0000000000..0f2750e35b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.cc @@ -0,0 +1,190 @@ +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// SPDX-License-Identifier: MIT + +// ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide. +// QDQGroup pattern avoided to preserve separate quantization of 1/b. + +#include "core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h" + +#include +#include +#include +#include +#include + +#include + +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +// Convenience macros for validation and creation paths. +#define ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/true) +#define CreateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/false) + +// Forward declaration. +static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const OrtNodeUnit& reciprocal_node_unit, + const OrtNodeUnit& mul_node_unit, + bool recip_is_mul_input0, + bool validate); + +// TryFusion: Matches Reciprocal->Mul pattern and validates fusion. +std::unique_ptr ReciprocalMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const OrtNodeUnit& reciprocal_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Ort::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Step 1: Check op-type and unit type. + // Only accept SingleNode Reciprocal to preserve separate quantization of 1/b. + if (reciprocal_node_unit.OpType() != "Reciprocal" || + reciprocal_node_unit.UnitType() != OrtNodeUnit::Type::SingleNode) { + return nullptr; + } + + // Step 2: Locate single Mul consumer (handles QDQ boundaries). + const OrtNodeUnit* mul_node_unit = + GetChildNodeUnitAllowQdq(qnn_model_wrapper, reciprocal_node_unit, "Mul", + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul_node_unit == nullptr) { + return nullptr; + } + + // Step 3: Determine which Mul input carries the Reciprocal output. + const auto& mul_inputs = mul_node_unit->Inputs(); + const std::string& recip_output_name = reciprocal_node_unit.Outputs()[0].name; + bool recip_is_mul_input0 = (mul_inputs[0].name == recip_output_name); + bool recip_is_mul_input1 = (mul_inputs[1].name == recip_output_name); + + if (!recip_is_mul_input0 && !recip_is_mul_input1) { + return nullptr; + } + + if (recip_is_mul_input0 && recip_is_mul_input1) { + return nullptr; // Both inputs same: would change semantics. + } + + // Step 4: QNN capability validation (dry-run). + if (Ort::Status status = ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, *mul_node_unit, recip_is_mul_input0); + !status.IsOK()) { + return nullptr; + } + + // Step 5: Construct fusion object. + return std::make_unique(reciprocal_node_unit, *mul_node_unit, recip_is_mul_input0); +} + +ReciprocalMulFusion::ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit, + const OrtNodeUnit& mul_node_unit, + bool recip_is_mul_input0) + : node_units_{&reciprocal_node_unit, &mul_node_unit}, + recip_is_mul_input0_{recip_is_mul_input0} { +} + +Ort::Status ReciprocalMulFusion::IsSupported(QnnModelWrapper& qmw, + const Ort::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1], recip_is_mul_input0_); +} + +Ort::Status ReciprocalMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, + const Ort::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1], recip_is_mul_input0_); +} + +gsl::span ReciprocalMulFusion::GetNodeUnits() const { + return node_units_; +} + +const OrtNodeUnit* ReciprocalMulFusion::GetTargetNodeUnit() const { + return node_units_[1]; // Mul is the convergence point. +} + +// CreateOrValidateOnQnn: Shared validate/build path. +static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const OrtNodeUnit& reciprocal_node_unit, + const OrtNodeUnit& mul_node_unit, + bool recip_is_mul_input0, + bool validate) { + RETURN_IF_NOT(reciprocal_node_unit.OpType() == "Reciprocal", + ("ReciprocalMulFusion: expected Reciprocal op, got " + reciprocal_node_unit.OpType()).c_str()); + RETURN_IF_NOT(mul_node_unit.OpType() == "Mul", + ("ReciprocalMulFusion: expected Mul op, got " + mul_node_unit.OpType()).c_str()); + + // Resolve tensor roles. + const OrtNodeUnitIODef& denominator_def = reciprocal_node_unit.Inputs()[0]; + const auto& mul_inputs = mul_node_unit.Inputs(); + const OrtNodeUnitIODef& numerator_def = recip_is_mul_input0 ? mul_inputs[1] : mul_inputs[0]; + const OrtNodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + const std::string node_name = utils::UniqueNameGenerator().New(reciprocal_node_unit); + + if (validate) { + // Dry-run: capability query only. + QnnTensorWrapper numerator_tensor; + QnnTensorWrapper denominator_tensor; + QnnTensorWrapper output_tensor; + + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(numerator_def, numerator_tensor)); + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(denominator_def, denominator_tensor)); + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode( + node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_DIVIDE, + /*input_tensors=*/{numerator_tensor.GetQnnTensor(), denominator_tensor.GetQnnTensor()}, + /*output_tensors=*/{output_tensor.GetQnnTensor()}, + /*params=*/{})); + } else { + // Build path: register tensors and create QNN node. + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(numerator_def.name)) { + QnnTensorWrapper numerator_tensor; + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(numerator_def, numerator_tensor)); + RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(numerator_tensor)), + "ReciprocalMulFusion: failed to add numerator tensor wrapper."); + } + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(denominator_def.name)) { + QnnTensorWrapper denominator_tensor; + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(denominator_def, denominator_tensor)); + RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(denominator_tensor)), + "ReciprocalMulFusion: failed to add denominator tensor wrapper."); + } + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(output_def.name)) { + QnnTensorWrapper output_tensor; + RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), + "ReciprocalMulFusion: failed to add output tensor wrapper."); + } + + // Create fused ElementWiseDivide node. + RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode( + node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_DIVIDE, + /*input_names=*/{numerator_def.name, denominator_def.name}, + /*output_names=*/{output_def.name}, + /*param_tensor_names=*/{}, + /*do_op_validation=*/validate), + "ReciprocalMulFusion: failed to create fused ElementWiseDivide node."); + } + + return Ort::Status(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h new file mode 100644 index 0000000000..438f7ac008 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h @@ -0,0 +1,48 @@ +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// SPDX-License-Identifier: MIT + +// ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide. + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// Fuses Reciprocal->Mul into ElementWiseDivide (SingleNode only to preserve quantization). +class ReciprocalMulFusion : public IQnnNodeGroup { + public: + ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit, const OrtNodeUnit& mul_node_unit, + bool recip_is_mul_input0); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ReciprocalMulFusion); + + // IQnnNodeGroup interface + Ort::Status IsSupported(QnnModelWrapper& qmw, const Ort::Logger& logger) const override; + Ort::Status AddToModelBuilder(QnnModelWrapper& qmw, const Ort::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const OrtNodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ReciprocalMulFusion"; } + + // Factory + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const OrtNodeUnit& reciprocal_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Ort::Logger& logger); + + private: + std::array node_units_; // [0]=Reciprocal, [1]=Mul + bool recip_is_mul_input0_{false}; // Which Mul input slot carries Reciprocal output. +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/reciprocal_mul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/reciprocal_mul_fusion_test.cc new file mode 100644 index 0000000000..962e30f40b --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/reciprocal_mul_fusion_test.cc @@ -0,0 +1,489 @@ +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// SPDX-License-Identifier: MIT + +// Tests for ReciprocalMulFusion: validates fusion of Reciprocal->Mul into ElementWiseDivide. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "test/providers/qnn/qnn_node_group/qnn_graph_checker.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +namespace { + +// Builds Reciprocal->Mul pattern (commute controls Mul input order). +GetTestModelFn BuildReciprocalMulTestCase(const TestInputDef& numerator_def, + const TestInputDef& denominator_def, + bool commute = false) { + return [numerator_def, denominator_def, commute](ModelTestBuilder& builder) -> void { + builder.graph_->set_name("reciprocal_mul_fusion_graph"); + + MakeTestInput(builder, "numerator", numerator_def); + MakeTestInput(builder, "denominator", denominator_def); + + // denominator -> Reciprocal -> recip_out + builder.AddNode("Reciprocal_node", + "Reciprocal", + {"denominator"}, + {"recip_out"}, + kOnnxDomain); + + // Mul(numerator, recip_out) or Mul(recip_out, numerator) + std::vector mul_inputs = commute + ? std::vector{"recip_out", "numerator"} + : std::vector{"numerator", "recip_out"}; + + builder.AddNode("Mul_node", + "Mul", + mul_inputs, + {"output"}, + kOnnxDomain); + + builder.MakeOutput("output"); + }; +} + +// FP16 variant of fusion pattern. +GetTestModelFn BuildReciprocalMulFP16TestCase(const TestInputDef& numerator_def, + const TestInputDef& denominator_def, + bool commute = false) { + const TestInputDef num_fp16_def = ConvertToFP16InputDef(numerator_def); + const TestInputDef den_fp16_def = ConvertToFP16InputDef(denominator_def); + + return [num_fp16_def, den_fp16_def, commute](ModelTestBuilder& builder) -> void { + builder.graph_->set_name("reciprocal_mul_fp16_fusion_graph"); + + MakeTestInput(builder, "numerator", num_fp16_def); + MakeTestInput(builder, "denominator", den_fp16_def); + + builder.AddNode("Reciprocal_node", + "Reciprocal", + {"denominator"}, + {"recip_out"}, + kOnnxDomain); + + std::vector mul_inputs = commute + ? std::vector{"recip_out", "numerator"} + : std::vector{"numerator", "recip_out"}; + + builder.AddNode("Mul_node", + "Mul", + mul_inputs, + {"output"}, + kOnnxDomain); + + builder.MakeOutput("output"); + }; +} + +// QDQ version of fusion pattern (SingleNode Reciprocal). +template +GetTestQDQModelFn BuildQDQReciprocalMulTestCase( + const TestInputDef& numerator_def, + const TestInputDef& denominator_def, + bool commute = false, + bool use_contrib_qdq = false) { + return [numerator_def, denominator_def, commute, use_contrib_qdq]( + ModelTestBuilder& builder, + std::vector>& output_qparams) -> void { + builder.graph_->set_name("qdq_reciprocal_mul_fusion_graph"); + + MakeTestInput(builder, "numerator", numerator_def); + MakeTestInput(builder, "denominator", denominator_def); + + const QuantParams num_qparams = GetTestInputQuantParams(numerator_def); + const QuantParams den_qparams = GetTestInputQuantParams(denominator_def); + + const std::string num_qdq = AddQDQNodePair( + builder, "qdq_num", "numerator", num_qparams.scale, num_qparams.zero_point, use_contrib_qdq); + const std::string den_qdq = AddQDQNodePair( + builder, "qdq_den", "denominator", den_qparams.scale, den_qparams.zero_point, use_contrib_qdq); + builder.AddNode("Reciprocal_node", + "Reciprocal", + {den_qdq}, + {"recip_out"}, + kOnnxDomain); + + const QuantParams recip_qparams = GetTestInputQuantParams(denominator_def); + const std::string recip_qdq = AddQDQNodePair( + builder, "qdq_recip", "recip_out", recip_qparams.scale, recip_qparams.zero_point, use_contrib_qdq); + + std::vector mul_inputs = commute + ? std::vector{recip_qdq, num_qdq} + : std::vector{num_qdq, recip_qdq}; + + builder.AddNode("Mul_node", + "Mul", + mul_inputs, + {"mul_out"}, + kOnnxDomain); + + AddQDQNodePairWithOutputAsGraphOutput( + builder, "qdq_out", "mul_out", + output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// No-fusion case: QDQ-wrapped Reciprocal with two Mul consumers. +template +GetTestQDQModelFn BuildQDQReciprocalMulNoFusionTestCase( + const TestInputDef& numerator_def, + const TestInputDef& denominator_def, + bool use_contrib_qdq = false) { + return [numerator_def, denominator_def, use_contrib_qdq]( + ModelTestBuilder& builder, + std::vector>& output_qparams) -> void { + builder.graph_->set_name("qdq_reciprocal_qdq_wrapped_no_fusion_graph"); + + MakeTestInput(builder, "numerator_a", numerator_def); + MakeTestInput(builder, "numerator_b", numerator_def); + MakeTestInput(builder, "denominator", denominator_def); + + const QuantParams num_qparams = GetTestInputQuantParams(numerator_def); + const QuantParams den_qparams = GetTestInputQuantParams(denominator_def); + + const std::string num_a_qdq = AddQDQNodePair( + builder, "qdq_num_a", "numerator_a", num_qparams.scale, num_qparams.zero_point, use_contrib_qdq); + const std::string num_b_qdq = AddQDQNodePair( + builder, "qdq_num_b", "numerator_b", num_qparams.scale, num_qparams.zero_point, use_contrib_qdq); + const std::string den_qdq = AddQDQNodePair( + builder, "qdq_den", "denominator", den_qparams.scale, den_qparams.zero_point, use_contrib_qdq); + + builder.AddNode("Reciprocal_node", + "Reciprocal", + {den_qdq}, + {"recip_out"}, + kOnnxDomain); + + const QuantParams recip_qparams = GetTestInputQuantParams(denominator_def); + const std::string recip_qdq = AddQDQNodePair( + builder, "qdq_recip", "recip_out", + recip_qparams.scale, recip_qparams.zero_point, use_contrib_qdq); + + builder.AddNode("Mul_A", + "Mul", + {num_a_qdq, recip_qdq}, + {"mul_out_a"}, + kOnnxDomain); + + builder.AddNode("Mul_B", + "Mul", + {num_b_qdq, recip_qdq}, + {"mul_out_b"}, + kOnnxDomain); + + AddQDQNodePairWithOutputAsGraphOutput( + builder, "qdq_out_a", "mul_out_a", + output_qparams[0].scale, output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput( + builder, "qdq_out_b", "mul_out_b", + output_qparams[1].scale, output_qparams[1].zero_point, use_contrib_qdq); + }; +} + +// No-fusion case: Reciprocal with two Mul consumers. +GetTestModelFn BuildReciprocalTwoConsumersTestCase(const TestInputDef& numerator_def, + const TestInputDef& denominator_def) { + return [numerator_def, denominator_def](ModelTestBuilder& builder) -> void { + builder.graph_->set_name("reciprocal_two_consumers_graph"); + + MakeTestInput(builder, "numerator_a", numerator_def); + MakeTestInput(builder, "numerator_b", numerator_def); + MakeTestInput(builder, "denominator", denominator_def); + + builder.AddNode("Reciprocal_node", + "Reciprocal", + {"denominator"}, + {"recip_out"}, + kOnnxDomain); + + builder.AddNode("Mul_A", + "Mul", + {"numerator_a", "recip_out"}, + {"out_a"}, + kOnnxDomain); + + builder.AddNode("Mul_B", + "Mul", + {"numerator_b", "recip_out"}, + {"out_b"}, + kOnnxDomain); + + builder.MakeOutput("out_a"); + builder.MakeOutput("out_b"); + }; +} + +// No-fusion case: Reciprocal output is a graph output. +GetTestModelFn BuildReciprocalOutputIsGraphOutputTestCase(const TestInputDef& numerator_def, + const TestInputDef& denominator_def) { + return [numerator_def, denominator_def](ModelTestBuilder& builder) -> void { + builder.graph_->set_name("reciprocal_output_is_graph_output_graph"); + + MakeTestInput(builder, "numerator", numerator_def); + MakeTestInput(builder, "denominator", denominator_def); + + builder.AddNode("Reciprocal_node", + "Reciprocal", + {"denominator"}, + {"recip_out"}, + kOnnxDomain); + + builder.AddNode("Mul_node", + "Mul", + {"numerator", "recip_out"}, + {"output"}, + kOnnxDomain); + + builder.MakeOutput("recip_out"); + builder.MakeOutput("output"); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; +#if defined(__linux__) && !defined(__aarch64__) + provider_options["soc_model"] = std::to_string(QNN_SOC_MODEL_SM8850); +#endif + return provider_options; +} + +} // namespace + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_Float32_4D_StandardOrder) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "FP32 HTP test skipped on architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_Float32_4D_StandardOrder"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + RunQnnModelTest(BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/false), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", 1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_Float32_4D_CommutedOrder) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "FP32 HTP test skipped on architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_Float32_4D_CommutedOrder"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + RunQnnModelTest(BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/true), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-3f); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", 1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_QDQ_U8_StandardOrder) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "QDQ test skipped on HTP architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_QDQ_U8_StandardOrder"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + TestQDQModelAccuracy( + BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/false), + BuildQDQReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/false), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_QDQ_U8_CommutedOrder) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "QDQ test skipped on HTP architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_QDQ_U8_CommutedOrder"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + TestQDQModelAccuracy( + BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/true), + BuildQDQReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/true), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_QDQ_U16_StandardOrder) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "uint16 QDQ requires HTP arch > v68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_QDQ_U16_StandardOrder"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + TestQDQModelAccuracy( + BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/false), + BuildQDQReciprocalMulTestCase(numerator_def, denominator_def, + /*commute=*/false, /*use_contrib_qdq=*/true), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_FP16) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "uint16 QDQ requires HTP arch > v68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_FP16"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + const auto fp32_model_fn = BuildReciprocalMulTestCase(numerator_def, denominator_def, /*commute=*/false); + const auto fp16_model_fn = BuildReciprocalMulFP16TestCase(numerator_def, denominator_def, /*commute=*/false); + + TestFp16ModelAccuracy(fp32_model_fn, + fp16_model_fn, + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/0.004f); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_ReciprocalOutputIsGraphOutput_NoFusion) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "FP32 HTP test skipped on architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_ReciprocalOutputIsGraphOutput_NoFusion"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + RunQnnModelTest(BuildReciprocalOutputIsGraphOutputTestCase(numerator_def, denominator_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/2e-3f); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseMultiply", /*count=*/1); +} + +TEST_F(QnnHTPBackendTests, ReciprocalMulFusion_QDQWrappedReciprocal_TwoConsumers_NoFusion) { + if (QnnHTPBackendTests::ShouldSkipIfHtpArchIsLessThanOrEqualTo(QNN_HTP_DEVICE_ARCH_V68)) { + GTEST_SKIP() << "QDQ test skipped on HTP architecture <= 68"; + } + + const std::filesystem::path json_qnn_graph_dir = "ReciprocalMulFusion_QDQWrappedReciprocal_TwoConsumers_NoFusion"; + std::filesystem::remove_all(json_qnn_graph_dir); + ASSERT_TRUE(std::filesystem::create_directory(json_qnn_graph_dir)); + auto cleanup = gsl::finally([&json_qnn_graph_dir]() { std::filesystem::remove_all(json_qnn_graph_dir); }); + + ProviderOptions provider_options = GetProviderOptions(); + provider_options["dump_json_qnn_graph"] = "1"; + provider_options["json_qnn_graph_dir"] = json_qnn_graph_dir.string(); + + const auto numerator_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + const auto denominator_def = TestInputDef({1, 2, 3, 4}, false, 0.5f, 2.0f); + + TestQDQModelAccuracy( + BuildReciprocalTwoConsumersTestCase(numerator_def, denominator_def), + BuildQDQReciprocalMulNoFusionTestCase(numerator_def, denominator_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); + + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseDivide", /*count=*/1); + AssertOpInQnnGraph(json_qnn_graph_dir, "ElementWiseMultiply", /*count=*/2); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD)