Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ed6617a
[QNN EP]: Fusion of multiply and reciprocal with divide
ankipand-qti Apr 25, 2026
be5ccad
Adding a missing include header file
ankipand-qti Apr 27, 2026
0182c12
Removing extra header file
ankipand-qti Apr 28, 2026
0d380df
Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion
ankipand-qti Apr 28, 2026
eafe732
Using ORT datatype instead of MLFloat16
ankipand-qti Apr 28, 2026
4b77854
Lint runner fixed
ankipand-qti Apr 28, 2026
1d7ce45
Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion
ankipand-qti Apr 28, 2026
d4bc03f
Addressing review comments
ankipand-qti Apr 28, 2026
8c7232e
Merge branch 'dev/ankipand_qcom/reciprocal_multiply_fusion' of github…
ankipand-qti Apr 28, 2026
7a3fbb8
Addressing failing test cases
ankipand-qti Apr 28, 2026
f130384
Addressing the failing test cases
ankipand-qti Apr 28, 2026
e9ee222
Found a bug for reciprocal output during testing
ankipand-qti Apr 29, 2026
541d9c1
Fixing the bug in test code
ankipand-qti Apr 29, 2026
089d459
Fixing some bugs in test code
ankipand-qti Apr 29, 2026
4a9e694
Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion
ankipand-qti Apr 30, 2026
7b6f9f0
Addressing the review comments
ankipand-qti May 6, 2026
d811e37
Lint issues fix
ankipand-qti May 6, 2026
0e5c53d
Fixed test code problems
ankipand-qti May 6, 2026
b76b5a7
Addressing the review comments
ankipand-qti May 12, 2026
d3e7ef0
Adding V68 skips for Linux ARM64
ankipand-qti May 13, 2026
cce27f4
Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion
ankipand-qti May 22, 2026
bf743cc
Handling the review comments
ankipand-qti Jun 19, 2026
10418f7
Merge branch 'dev/ankipand_qcom/reciprocal_multiply_fusion' of github…
ankipand-qti Jun 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_def.h"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed, as there are no changes in this file.

#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"

Expand Down Expand Up @@ -37,7 +38,10 @@ Ort::Status ReciprocalOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrappe
const auto& outputs = node_unit.Outputs();
RETURN_IF_NOT(outputs.size() == 1, "Reciprocal operator must have exactly 1 output.");

// Check input type is float for CPU.
// On the QNN CPU backend only float32 is accepted; other backends (HTP, GPU)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not add anything. Please avoid this change.

// are gated by the QNN SDK's own op-validation call inside
// ProcessAttributesAndOutputs (do_op_validation=true), which will return an
// error if the backend cannot handle the resulting ElementWiseDivide node.
RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].type, ""));

return Ort::Status();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
#include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
#include "core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/reshape_einsum_reshape.h"
#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/spacetodepth_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h"
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/spacetodepth_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/transpose_reshape_transpose_fusion.h"
#include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h"
#include "core/providers/qnn/builder/qnn_utils.h"
Expand Down Expand Up @@ -96,6 +97,7 @@ static std::unordered_map<std::string, std::vector<FusionFunc>> fusions = {
{"Mul", {ScaleSoftmaxFusion::TryFusion}},
{"Cast", {CastLoneQFusion::TryFusion}},
{"Erf", {GeluFusion::TryFusion}},
{"Reciprocal", {ReciprocalMulFusion::TryFusion}},
{"ReduceMean", {LayerNormFusion::TryFusion}},
{"Einsum", {ReshapeEinsumReshapeNodeGroup::TryFusion}},
{"Reshape", {SpaceToDepthFusion::TryFusion, Rank6ToRank5Fusion::TryFusion}},
Expand Down Expand Up @@ -137,9 +139,9 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except
// Gather, MatMul w/ LPBQ encodings, Erf, and Reshape.
if (starting_node_unit.UnitType() != OrtNodeUnit::Type::SingleNode &&
starting_node_unit.OpType() != "Erf" &&
starting_node_unit.OpType() != "Gather" &&
starting_node_unit.OpType() != "MatMul" &&
starting_node_unit.OpType() != "Erf" &&
starting_node_unit.OpType() != "Reshape") {
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// SPDX-License-Identifier: MIT

// ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide.
// QDQGroup pattern avoided to preserve separate quantization of 1/b.

#include "core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h"

#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include <gsl/gsl>

#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/ort_api.h"

namespace onnxruntime {
namespace qnn {

// Convenience macros that forward to CreateOrValidateOnQnn (declared below) with a fixed
// value for its trailing `validate` flag:
//   ValidateOnQnn(...) -> CreateOrValidateOnQnn(..., /*validate=*/true)
//   CreateOnQnn(...)   -> CreateOrValidateOnQnn(..., /*validate=*/false)
// Every macro argument is parenthesized before being forwarded.
#define ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \
CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/true)
#define CreateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \
CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/false)

// Forward declaration: the ReciprocalMulFusion member functions below call this helper
// (through the ValidateOnQnn / CreateOnQnn macros) before its definition later in this file.
//   reciprocal_node_unit: the Reciprocal node unit of the matched pattern.
//   mul_node_unit:        the Mul node unit of the matched pattern.
//   recip_is_mul_input0:  true when the Reciprocal output is the Mul's input 0, false when it is input 1.
//   validate:             true for the validation-only path, false for the build path.
static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
const OrtNodeUnit& reciprocal_node_unit,
const OrtNodeUnit& mul_node_unit,
bool recip_is_mul_input0,
bool validate);

// TryFusion: Pattern matcher for Reciprocal -> Mul.
// Returns a ReciprocalMulFusion when `reciprocal_node_unit` is a standalone Reciprocal whose
// output is exactly one of the two inputs of a Mul child and the fused node passes the
// validation-only path; returns nullptr in every other case. `logger` is unused.
std::unique_ptr<IQnnNodeGroup> ReciprocalMulFusion::TryFusion(
    QnnModelWrapper& qnn_model_wrapper,
    const OrtNodeUnit& reciprocal_node_unit,
    const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_to_node_unit,
    const std::unordered_map<const OrtNodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
    const Ort::Logger& logger) {
  ORT_UNUSED_PARAMETER(logger);

  // The starting unit has to be a Reciprocal that is not wrapped in a QDQ group, so that a
  // separately quantized 1/b is never folded away.
  const bool is_standalone_reciprocal =
      reciprocal_node_unit.OpType() == "Reciprocal" &&
      reciprocal_node_unit.UnitType() == OrtNodeUnit::Type::SingleNode;
  if (!is_standalone_reciprocal) {
    return nullptr;
  }

  // Ask the shared helper for a "Mul" child of the Reciprocal; nullptr means there is no
  // consumer this fusion can use.
  const OrtNodeUnit* const mul_unit =
      GetChildNodeUnitAllowQdq(qnn_model_wrapper, reciprocal_node_unit, "Mul",
                               node_to_node_unit, node_unit_to_qnn_node_group);
  if (!mul_unit) {
    return nullptr;
  }

  // Work out which Mul input slot is fed by the Reciprocal output.
  const auto& mul_operands = mul_unit->Inputs();
  const std::string& reciprocal_output = reciprocal_node_unit.Outputs()[0].name;
  const bool feeds_input0 = (mul_operands[0].name == reciprocal_output);
  const bool feeds_input1 = (mul_operands[1].name == reciprocal_output);

  // Exactly one slot must match. Neither matching means the pattern is absent; both matching
  // is Mul(r, r), which is not a division, so fusing it would change semantics.
  if (feeds_input0 == feeds_input1) {
    return nullptr;
  }

  // Dry-run the fused node; any failure simply means "do not fuse".
  const Ort::Status validation = ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, *mul_unit, feeds_input0);
  if (!validation.IsOK()) {
    return nullptr;
  }

  return std::make_unique<ReciprocalMulFusion>(reciprocal_node_unit, *mul_unit, feeds_input0);
}

ReciprocalMulFusion::ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit,
const OrtNodeUnit& mul_node_unit,
bool recip_is_mul_input0)
: node_units_{&reciprocal_node_unit, &mul_node_unit},
recip_is_mul_input0_{recip_is_mul_input0} {
}

// IQnnNodeGroup::IsSupported: runs the validation-only path (ValidateOnQnn) for the stored
// Reciprocal/Mul pair and returns its status. `logger` is unused.
Ort::Status ReciprocalMulFusion::IsSupported(QnnModelWrapper& qmw,
                                             const Ort::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);
  const OrtNodeUnit& reciprocal_unit = *node_units_[0];
  const OrtNodeUnit& mul_unit = *node_units_[1];
  return ValidateOnQnn(qmw, reciprocal_unit, mul_unit, recip_is_mul_input0_);
}

// IQnnNodeGroup::AddToModelBuilder: runs the build path (CreateOnQnn) for the stored
// Reciprocal/Mul pair and returns its status. `logger` is unused.
Ort::Status ReciprocalMulFusion::AddToModelBuilder(QnnModelWrapper& qmw,
                                                   const Ort::Logger& logger) const {
  ORT_UNUSED_PARAMETER(logger);
  const OrtNodeUnit& reciprocal_unit = *node_units_[0];
  const OrtNodeUnit& mul_unit = *node_units_[1];
  return CreateOnQnn(qmw, reciprocal_unit, mul_unit, recip_is_mul_input0_);
}

// Returns a view over the two stored node-unit pointers, in storage order
// (Reciprocal first, then Mul).
gsl::span<const OrtNodeUnit* const> ReciprocalMulFusion::GetNodeUnits() const {
  const auto& fused_units = node_units_;
  return fused_units;
}

// Returns the Mul node unit (slot 1 of the stored pair) as the group's target node unit.
const OrtNodeUnit* ReciprocalMulFusion::GetTargetNodeUnit() const {
  const OrtNodeUnit* const mul_unit = node_units_[1];
  return mul_unit;
}

// CreateOrValidateOnQnn: Single implementation behind the ValidateOnQnn / CreateOnQnn macros.
// Expresses Mul(x, Reciprocal(d)) -- with the Reciprocal in either Mul input slot -- as one
// QNN ElementWiseDivide whose inputs are (x, d) and whose output is the Mul output.
//
//   recip_is_mul_input0: true when the Reciprocal output is Mul input 0, so the dividend is
//                        Mul input 1; false for the opposite arrangement.
//   validate == true:    makes local tensor wrappers and calls ValidateQnnNode; this path
//                        makes no AddTensorWrapper or CreateQnnNode call.
//   validate == false:   adds a tensor wrapper for each of the three tensors that does not
//                        already exist, then calls CreateQnnNode with op validation off.
//
// Returns an OK status on success, otherwise the first failing status.
static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
                                         const OrtNodeUnit& reciprocal_node_unit,
                                         const OrtNodeUnit& mul_node_unit,
                                         bool recip_is_mul_input0,
                                         bool validate) {
  RETURN_IF_NOT(reciprocal_node_unit.OpType() == "Reciprocal",
                ("ReciprocalMulFusion: expected Reciprocal op, got " + reciprocal_node_unit.OpType()).c_str());
  RETURN_IF_NOT(mul_node_unit.OpType() == "Mul",
                ("ReciprocalMulFusion: expected Mul op, got " + mul_node_unit.OpType()).c_str());

  // Assign division roles: the Reciprocal's input is the divisor, the Mul input that is NOT
  // the Reciprocal output is the dividend, and the Mul output is the quotient.
  const OrtNodeUnitIODef& divisor_def = reciprocal_node_unit.Inputs()[0];
  const auto& mul_operands = mul_node_unit.Inputs();
  const OrtNodeUnitIODef& dividend_def = mul_operands[recip_is_mul_input0 ? 1 : 0];
  const OrtNodeUnitIODef& quotient_def = mul_node_unit.Outputs()[0];
  const std::string node_name = utils::UniqueNameGenerator().New(reciprocal_node_unit);

  if (validate) {
    // Validation-only path: the wrappers are locals and are never handed to the model wrapper.
    QnnTensorWrapper dividend_tensor;
    QnnTensorWrapper divisor_tensor;
    QnnTensorWrapper quotient_tensor;

    RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(dividend_def, dividend_tensor));
    RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(divisor_def, divisor_tensor));
    RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(quotient_def, quotient_tensor));

    RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(
        node_name,
        QNN_OP_PACKAGE_NAME_QTI_AISW,
        QNN_OP_ELEMENT_WISE_DIVIDE,
        /*input_tensors=*/{dividend_tensor.GetQnnTensor(), divisor_tensor.GetQnnTensor()},
        /*output_tensors=*/{quotient_tensor.GetQnnTensor()},
        /*params=*/{}));

    return Ort::Status();
  }

  // Build path. Adds a tensor wrapper for `io_def` unless one with that name already exists;
  // `add_failure_message` is reported when AddTensorWrapper returns false.
  const auto ensure_tensor_registered = [&qnn_model_wrapper](const OrtNodeUnitIODef& io_def,
                                                              const char* add_failure_message) -> Ort::Status {
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(io_def.name)) {
      return Ort::Status();
    }

    QnnTensorWrapper tensor_wrapper;
    RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(io_def, tensor_wrapper));
    RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(tensor_wrapper)), add_failure_message);
    return Ort::Status();
  };

  RETURN_IF_ERROR(ensure_tensor_registered(
      dividend_def, "ReciprocalMulFusion: failed to add numerator tensor wrapper."));
  RETURN_IF_ERROR(ensure_tensor_registered(
      divisor_def, "ReciprocalMulFusion: failed to add denominator tensor wrapper."));
  RETURN_IF_ERROR(ensure_tensor_registered(
      quotient_def, "ReciprocalMulFusion: failed to add output tensor wrapper."));

  // Emit the fused ElementWiseDivide node. `validate` is false on this path, so op validation
  // is not requested here.
  RETURN_IF_NOT(
      qnn_model_wrapper.CreateQnnNode(
          node_name,
          QNN_OP_PACKAGE_NAME_QTI_AISW,
          QNN_OP_ELEMENT_WISE_DIVIDE,
          /*input_names=*/{dividend_def.name, divisor_def.name},
          /*output_names=*/{quotient_def.name},
          /*param_tensor_names=*/{},
          /*do_op_validation=*/false),
      "ReciprocalMulFusion: failed to create fused ElementWiseDivide node.");

  return Ort::Status();
}

} // namespace qnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// SPDX-License-Identifier: MIT

// ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide.

#pragma once

#include <array>
#include <memory>
#include <unordered_map>

#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
#include "core/providers/qnn/ort_api.h"

namespace onnxruntime {
namespace qnn {

class QnnModelWrapper;

/// Node group that represents a Reciprocal feeding one input of a Mul as a single QNN
/// ElementWiseDivide. Only a SingleNode (non-QDQ) Reciprocal is matched, so that a separately
/// quantized reciprocal is preserved rather than fused away.
class ReciprocalMulFusion : public IQnnNodeGroup {
 public:
  /// Factory used by the fusion registry. Returns a ReciprocalMulFusion when the pattern
  /// starting at `reciprocal_node_unit` matches and validates, or nullptr otherwise.
  static std::unique_ptr<IQnnNodeGroup> TryFusion(
      QnnModelWrapper& qnn_model_wrapper,
      const OrtNodeUnit& reciprocal_node_unit,
      const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_to_node_unit,
      const std::unordered_map<const OrtNodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
      const Ort::Logger& logger);

  /// Binds the group to an already matched Reciprocal/Mul pair. The node units are stored by
  /// pointer (non-owning). `recip_is_mul_input0` states which Mul input slot carries the
  /// Reciprocal output.
  ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit, const OrtNodeUnit& mul_node_unit,
                      bool recip_is_mul_input0);
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(ReciprocalMulFusion);

  // IQnnNodeGroup overrides.
  std::string_view Type() const override { return "ReciprocalMulFusion"; }
  gsl::span<const OrtNodeUnit* const> GetNodeUnits() const override;
  const OrtNodeUnit* GetTargetNodeUnit() const override;
  Ort::Status IsSupported(QnnModelWrapper& qmw, const Ort::Logger& logger) const override;
  Ort::Status AddToModelBuilder(QnnModelWrapper& qmw, const Ort::Logger& logger) const override;

 private:
  // Fused node units: index 0 is the Reciprocal, index 1 is the Mul.
  std::array<const OrtNodeUnit*, 2> node_units_;
  // True when the Reciprocal output is the Mul's input 0; false when it is input 1.
  bool recip_is_mul_input0_{false};
};

} // namespace qnn
} // namespace onnxruntime
Loading
Loading