[QNN EP]: Fusion of multiply and reciprocal to divide #302
Open
ankipand-qti wants to merge 23 commits into main from dev/ankipand_qcom/reciprocal_multiply_fusion
ed6617a [QNN EP]: Fusion of multiply and reciprocal with divide (ankipand-qti)
be5ccad Adding a missing include header file (ankipand-qti)
0182c12 Removing extra header file (ankipand-qti)
0d380df Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion (ankipand-qti)
eafe732 Using ORT datatype instead of MLFloat16 (ankipand-qti)
4b77854 Lint runner fixed (ankipand-qti)
1d7ce45 Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion (ankipand-qti)
d4bc03f Addressing review comments (ankipand-qti)
8c7232e Merge branch 'dev/ankipand_qcom/reciprocal_multiply_fusion' of github… (ankipand-qti)
7a3fbb8 Addressing failing test cases (ankipand-qti)
f130384 Addressing the failing test cases (ankipand-qti)
e9ee222 Found a bug for reciprocal output during testing (ankipand-qti)
541d9c1 Fixing the bug in test code (ankipand-qti)
089d459 Fixing some bugs in test code (ankipand-qti)
4a9e694 Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion (ankipand-qti)
7b6f9f0 Addressing the review comments (ankipand-qti)
d811e37 Lint issues fix (ankipand-qti)
0e5c53d Fixed test code problems (ankipand-qti)
b76b5a7 Addressing the review comments (ankipand-qti)
d3e7ef0 Adding V68 skips for Linux ARM64 (ankipand-qti)
cce27f4 Merge branch 'main' into dev/ankipand_qcom/reciprocal_multiply_fusion (ankipand-qti)
bf743cc Handling the review comments (ankipand-qti)
10418f7 Merge branch 'dev/ankipand_qcom/reciprocal_multiply_fusion' of github… (ankipand-qti)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
|
|
||
| #include "core/providers/qnn/builder/op_builder_factory.h" | ||
| #include "core/providers/qnn/builder/opbuilder/base_op_builder.h" | ||
| #include "core/providers/qnn/builder/qnn_def.h" | ||
| #include "core/providers/qnn/builder/qnn_model_wrapper.h" | ||
| #include "core/providers/qnn/builder/qnn_utils.h" | ||
|
|
||
|
|
@@ -37,7 +38,10 @@ Ort::Status ReciprocalOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrappe | |
| const auto& outputs = node_unit.Outputs(); | ||
| RETURN_IF_NOT(outputs.size() == 1, "Reciprocal operator must have exactly 1 output."); | ||
|
|
||
| // Check input type is float for CPU. | ||
| // On the QNN CPU backend only float32 is accepted; other backends (HTP, GPU) | ||
Review comment (Collaborator): This is not adding anything. Please avoid this change.
| // are gated by the QNN SDK's own op-validation call inside | ||
| // ProcessAttributesAndOutputs (do_op_validation=true), which will return an | ||
| // error if the backend cannot handle the resulting ElementWiseDivide node. | ||
| RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].type, "")); | ||
|
|
||
| return Ort::Status(); | ||
|
|
||
190 changes: 190 additions & 0 deletions
onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.cc
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| // ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide. | ||
| // QDQGroup pattern avoided to preserve separate quantization of 1/b. | ||
|
|
||
| #include "core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h" | ||
|
|
||
| #include <array> | ||
| #include <memory> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <utility> | ||
|
|
||
| #include <gsl/gsl> | ||
|
|
||
| #include "core/providers/qnn/builder/op_builder_factory.h" | ||
| #include "core/providers/qnn/builder/qnn_model_wrapper.h" | ||
| #include "core/providers/qnn/builder/qnn_node_group/utils.h" | ||
| #include "core/providers/qnn/builder/qnn_utils.h" | ||
| #include "core/providers/qnn/ort_api.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace qnn { | ||
|
|
||
| // Convenience macros for validation and creation paths. | ||
| #define ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \ | ||
| CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/true) | ||
| #define CreateOnQnn(qnn_model_wrapper, reciprocal_node_unit, mul_node_unit, recip_is_mul_input0) \ | ||
| CreateOrValidateOnQnn((qnn_model_wrapper), (reciprocal_node_unit), (mul_node_unit), (recip_is_mul_input0), /*validate=*/false) | ||
|
|
||
| // Forward declaration. | ||
| static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, | ||
| const OrtNodeUnit& reciprocal_node_unit, | ||
| const OrtNodeUnit& mul_node_unit, | ||
| bool recip_is_mul_input0, | ||
| bool validate); | ||
|
|
||
| // TryFusion: Matches Reciprocal->Mul pattern and validates fusion. | ||
| std::unique_ptr<IQnnNodeGroup> ReciprocalMulFusion::TryFusion( | ||
| QnnModelWrapper& qnn_model_wrapper, | ||
| const OrtNodeUnit& reciprocal_node_unit, | ||
| const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_to_node_unit, | ||
| const std::unordered_map<const OrtNodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group, | ||
| const Ort::Logger& logger) { | ||
| ORT_UNUSED_PARAMETER(logger); | ||
|
|
||
| // Step 1: Check op-type and unit type. | ||
| // Only accept SingleNode Reciprocal to preserve separate quantization of 1/b. | ||
| if (reciprocal_node_unit.OpType() != "Reciprocal" || | ||
| reciprocal_node_unit.UnitType() != OrtNodeUnit::Type::SingleNode) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| // Step 2: Locate single Mul consumer (handles QDQ boundaries). | ||
| const OrtNodeUnit* mul_node_unit = | ||
| GetChildNodeUnitAllowQdq(qnn_model_wrapper, reciprocal_node_unit, "Mul", | ||
| node_to_node_unit, node_unit_to_qnn_node_group); | ||
| if (mul_node_unit == nullptr) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| // Step 3: Determine which Mul input carries the Reciprocal output. | ||
| const auto& mul_inputs = mul_node_unit->Inputs(); | ||
| const std::string& recip_output_name = reciprocal_node_unit.Outputs()[0].name; | ||
| bool recip_is_mul_input0 = (mul_inputs[0].name == recip_output_name); | ||
| bool recip_is_mul_input1 = (mul_inputs[1].name == recip_output_name); | ||
|
|
||
| if (!recip_is_mul_input0 && !recip_is_mul_input1) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| if (recip_is_mul_input0 && recip_is_mul_input1) { | ||
| return nullptr; // Both inputs same: would change semantics. | ||
| } | ||
|
|
||
| // Step 4: QNN capability validation (dry-run). | ||
| if (Ort::Status status = ValidateOnQnn(qnn_model_wrapper, reciprocal_node_unit, *mul_node_unit, recip_is_mul_input0); | ||
| !status.IsOK()) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| // Step 5: Construct fusion object. | ||
| return std::make_unique<ReciprocalMulFusion>(reciprocal_node_unit, *mul_node_unit, recip_is_mul_input0); | ||
| } | ||
|
|
||
| ReciprocalMulFusion::ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit, | ||
| const OrtNodeUnit& mul_node_unit, | ||
| bool recip_is_mul_input0) | ||
| : node_units_{&reciprocal_node_unit, &mul_node_unit}, | ||
| recip_is_mul_input0_{recip_is_mul_input0} { | ||
| } | ||
|
|
||
| Ort::Status ReciprocalMulFusion::IsSupported(QnnModelWrapper& qmw, | ||
| const Ort::Logger& logger) const { | ||
| ORT_UNUSED_PARAMETER(logger); | ||
| return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1], recip_is_mul_input0_); | ||
| } | ||
|
|
||
| Ort::Status ReciprocalMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, | ||
| const Ort::Logger& logger) const { | ||
| ORT_UNUSED_PARAMETER(logger); | ||
| return CreateOnQnn(qmw, *node_units_[0], *node_units_[1], recip_is_mul_input0_); | ||
| } | ||
|
|
||
| gsl::span<const OrtNodeUnit* const> ReciprocalMulFusion::GetNodeUnits() const { | ||
| return node_units_; | ||
| } | ||
|
|
||
| const OrtNodeUnit* ReciprocalMulFusion::GetTargetNodeUnit() const { | ||
| return node_units_[1]; // Mul is the convergence point. | ||
| } | ||
|
|
||
| // CreateOrValidateOnQnn: Shared validate/build path. | ||
| static Ort::Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, | ||
| const OrtNodeUnit& reciprocal_node_unit, | ||
| const OrtNodeUnit& mul_node_unit, | ||
| bool recip_is_mul_input0, | ||
| bool validate) { | ||
| RETURN_IF_NOT(reciprocal_node_unit.OpType() == "Reciprocal", | ||
| ("ReciprocalMulFusion: expected Reciprocal op, got " + reciprocal_node_unit.OpType()).c_str()); | ||
| RETURN_IF_NOT(mul_node_unit.OpType() == "Mul", | ||
| ("ReciprocalMulFusion: expected Mul op, got " + mul_node_unit.OpType()).c_str()); | ||
|
|
||
| // Resolve tensor roles. | ||
| const OrtNodeUnitIODef& denominator_def = reciprocal_node_unit.Inputs()[0]; | ||
| const auto& mul_inputs = mul_node_unit.Inputs(); | ||
| const OrtNodeUnitIODef& numerator_def = recip_is_mul_input0 ? mul_inputs[1] : mul_inputs[0]; | ||
| const OrtNodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; | ||
| const std::string node_name = utils::UniqueNameGenerator().New(reciprocal_node_unit); | ||
|
|
||
| if (validate) { | ||
| // Dry-run: capability query only. | ||
| QnnTensorWrapper numerator_tensor; | ||
| QnnTensorWrapper denominator_tensor; | ||
| QnnTensorWrapper output_tensor; | ||
|
|
||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(numerator_def, numerator_tensor)); | ||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(denominator_def, denominator_tensor)); | ||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); | ||
|
|
||
| RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode( | ||
| node_name, | ||
| QNN_OP_PACKAGE_NAME_QTI_AISW, | ||
| QNN_OP_ELEMENT_WISE_DIVIDE, | ||
| /*input_tensors=*/{numerator_tensor.GetQnnTensor(), denominator_tensor.GetQnnTensor()}, | ||
| /*output_tensors=*/{output_tensor.GetQnnTensor()}, | ||
| /*params=*/{})); | ||
| } else { | ||
| // Build path: register tensors and create QNN node. | ||
|
|
||
| if (!qnn_model_wrapper.IsQnnTensorWrapperExist(numerator_def.name)) { | ||
| QnnTensorWrapper numerator_tensor; | ||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(numerator_def, numerator_tensor)); | ||
| RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(numerator_tensor)), | ||
| "ReciprocalMulFusion: failed to add numerator tensor wrapper."); | ||
| } | ||
|
|
||
| if (!qnn_model_wrapper.IsQnnTensorWrapperExist(denominator_def.name)) { | ||
| QnnTensorWrapper denominator_tensor; | ||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(denominator_def, denominator_tensor)); | ||
| RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(denominator_tensor)), | ||
| "ReciprocalMulFusion: failed to add denominator tensor wrapper."); | ||
| } | ||
|
|
||
| if (!qnn_model_wrapper.IsQnnTensorWrapperExist(output_def.name)) { | ||
| QnnTensorWrapper output_tensor; | ||
| RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); | ||
| RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), | ||
| "ReciprocalMulFusion: failed to add output tensor wrapper."); | ||
| } | ||
|
|
||
| // Create fused ElementWiseDivide node. | ||
| RETURN_IF_NOT( | ||
| qnn_model_wrapper.CreateQnnNode( | ||
| node_name, | ||
| QNN_OP_PACKAGE_NAME_QTI_AISW, | ||
| QNN_OP_ELEMENT_WISE_DIVIDE, | ||
| /*input_names=*/{numerator_def.name, denominator_def.name}, | ||
| /*output_names=*/{output_def.name}, | ||
| /*param_tensor_names=*/{}, | ||
| /*do_op_validation=*/validate), | ||
| "ReciprocalMulFusion: failed to create fused ElementWiseDivide node."); | ||
| } | ||
|
|
||
| return Ort::Status(); | ||
| } | ||
|
|
||
| } // namespace qnn | ||
| } // namespace onnxruntime |
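For readers following Step 3 of `TryFusion` and the "Resolve tensor roles" part of `CreateOrValidateOnQnn`, the operand selection can be sketched in isolation. This is a minimal illustration, not the PR's code: `DivideOperands` and `ResolveDivideOperands` are hypothetical names, and plain strings stand in for the `OrtNodeUnitIODef` tensor names used in the diff.

```cpp
#include <array>
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the two operands of the fused a / b node.
struct DivideOperands {
  std::string numerator;    // the Mul input that is NOT the Reciprocal output
  std::string denominator;  // the Reciprocal node's own input
};

// Resolves the operands for Mul(a, Reciprocal(b)) or Mul(Reciprocal(b), a).
// Returns std::nullopt when the pattern cannot be rewritten as one divide.
std::optional<DivideOperands> ResolveDivideOperands(
    const std::array<std::string, 2>& mul_inputs,
    const std::string& recip_input,
    const std::string& recip_output) {
  assert(!recip_output.empty());  // an empty name cannot identify a tensor

  const bool recip_is_input0 = (mul_inputs[0] == recip_output);
  const bool recip_is_input1 = (mul_inputs[1] == recip_output);

  // The Mul does not consume the Reciprocal output at all.
  if (!recip_is_input0 && !recip_is_input1) {
    return std::nullopt;
  }

  // Mul(r, r) with r = 1 / b computes 1 / (b * b); a single a / b node
  // cannot express that, so the pattern is rejected.
  if (recip_is_input0 && recip_is_input1) {
    return std::nullopt;
  }

  // The other Mul input becomes the numerator. Because multiplication is
  // commutative, the Reciprocal input is the denominator in either slot.
  return DivideOperands{recip_is_input0 ? mul_inputs[1] : mul_inputs[0],
                        recip_input};
}
```

With `mul_inputs = {"a", "r"}`, `recip_input = "b"` and `recip_output = "r"`, the sketch yields numerator `a` and denominator `b`. Swapping the two Mul inputs gives the same result, which matches the PR's use of a single `recip_is_mul_input0` flag.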
48 changes: 48 additions & 0 deletions
onnxruntime/core/providers/qnn/builder/qnn_node_group/reciprocal_mul_fusion.h
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
| // SPDX-License-Identifier: MIT | ||
|
|
||
| // ReciprocalMulFusion: Fuses SingleNode Reciprocal->Mul into ElementWiseDivide. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <array> | ||
| #include <memory> | ||
| #include <unordered_map> | ||
|
|
||
| #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" | ||
| #include "core/providers/qnn/ort_api.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace qnn { | ||
|
|
||
| class QnnModelWrapper; | ||
|
|
||
| /// Fuses Reciprocal->Mul into ElementWiseDivide (SingleNode only to preserve quantization). | ||
| class ReciprocalMulFusion : public IQnnNodeGroup { | ||
| public: | ||
| ReciprocalMulFusion(const OrtNodeUnit& reciprocal_node_unit, const OrtNodeUnit& mul_node_unit, | ||
| bool recip_is_mul_input0); | ||
| ORT_DISALLOW_COPY_AND_ASSIGNMENT(ReciprocalMulFusion); | ||
|
|
||
| // IQnnNodeGroup interface | ||
| Ort::Status IsSupported(QnnModelWrapper& qmw, const Ort::Logger& logger) const override; | ||
| Ort::Status AddToModelBuilder(QnnModelWrapper& qmw, const Ort::Logger& logger) const override; | ||
| gsl::span<const OrtNodeUnit* const> GetNodeUnits() const override; | ||
| const OrtNodeUnit* GetTargetNodeUnit() const override; | ||
| std::string_view Type() const override { return "ReciprocalMulFusion"; } | ||
|
|
||
| // Factory | ||
| static std::unique_ptr<IQnnNodeGroup> TryFusion( | ||
| QnnModelWrapper& qnn_model_wrapper, | ||
| const OrtNodeUnit& reciprocal_node_unit, | ||
| const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_to_node_unit, | ||
| const std::unordered_map<const OrtNodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group, | ||
| const Ort::Logger& logger); | ||
|
|
||
| private: | ||
| std::array<const OrtNodeUnit*, 2> node_units_; // [0]=Reciprocal, [1]=Mul | ||
| bool recip_is_mul_input0_{false}; // Which Mul input slot carries Reciprocal output. | ||
| }; | ||
|
|
||
| } // namespace qnn | ||
| } // namespace onnxruntime |
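The rewrite that `ReciprocalMulFusion` performs relies on the identity a * (1 / b) = a / b. The sketch below compares the two forms element-wise in float32 on the host. It is only an illustration under stated assumptions: the helper names are hypothetical, broadcasting is not handled, and it does not model how any QNN backend rounds or quantizes `ElementWiseDivide`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Unfused reference: out[i] = a[i] * (1 / b[i]), i.e. Reciprocal then Mul.
std::vector<float> ReciprocalThenMul(const std::vector<float>& a,
                                     const std::vector<float>& b) {
  assert(a.size() == b.size());  // no broadcasting in this sketch
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    const float recip = 1.0f / b[i];  // intermediate value is rounded once
    out[i] = a[i] * recip;            // and the product is rounded again
  }
  return out;
}

// Fused form: out[i] = a[i] / b[i], with a single rounding per element.
std::vector<float> Divide(const std::vector<float>& a,
                          const std::vector<float>& b) {
  assert(a.size() == b.size());
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] / b[i];
  }
  return out;
}

// Largest absolute element-wise difference between two equally sized vectors.
float MaxAbsDiff(const std::vector<float>& x, const std::vector<float>& y) {
  assert(x.size() == y.size());
  float max_diff = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    max_diff = std::max(max_diff, std::fabs(x[i] - y[i]));
  }
  return max_diff;
}
```

Calling `MaxAbsDiff(ReciprocalThenMul(a, b), Divide(a, b))` on sample inputs gives a quick tolerance check. The two forms can differ in the last bits because the unfused path rounds twice, so a comparison of this kind would normally use a small tolerance rather than exact equality.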
Review comment: This is not needed, as there are no changes in this file.