Skip to content

Commit 9c4471c

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for div.Tensor_mode core ATen op (#19785)
### Summary

Added support for the core ATen op `div.Tensor_mode` using a decomposition pass built on the `div`, `trunc`, and `floor` ops, chosen according to the selected rounding mode:

```
div(x, y, rounding_mode=None)    -> div(x, y)
div(x, y, rounding_mode="trunc") -> trunc(div(x, y))
div(x, y, rounding_mode="floor") -> floor(div(x, y))
```

### Test plan

```
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_div_mode --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_div_mode --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_div_scalar_mode --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_div_scalar_mode --model SM8750 --host aisw-vm15-labsd --device 545ee4aa --build_folder build-android
```

cc @cccclai @cbilgin @abhinaykukkadapu
1 parent e257a71 commit 9c4471c

8 files changed

Lines changed: 213 additions & 1 deletion

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .decompose_binary_alpha import DecomposeBinaryAlpha
2020
from .decompose_cdist import DecomposeCDist
2121
from .decompose_col_im import DecomposeColIm
22+
from .decompose_div_mode import DecomposeDivMode
2223
from .decompose_einsum import DecomposeEinsum
2324
from .decompose_expm1 import DecomposeExpM1
2425
from .decompose_fill import DecomposeFill
@@ -82,6 +83,7 @@
8283
DecomposeBinaryAlpha,
8384
DecomposeCDist,
8485
DecomposeColIm,
86+
DecomposeDivMode,
8587
DecomposeEinsum,
8688
DecomposeExpM1,
8789
DecomposeFill,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
from .utils import copy_meta
13+
14+
15+
class DecomposeDivMode(ExportPass):
    """
    Decompose aten.div.Tensor_mode into supported primitives.

        div(x, y, rounding_mode=None)    -> div(x, y)
        div(x, y, rounding_mode="trunc") -> trunc(div(x, y))
        div(x, y, rounding_mode="floor") -> floor(div(x, y))

    Both the ATen overload and its edge-dialect counterpart are matched, and
    the replacement ops are taken from the same dialect as the matched node.

    Note: div.Scalar_mode is handled by LiftConstantScalarOperands, which
    converts it to div.Tensor_mode before this pass runs.

    NOTE(review): every inserted node gets copy_meta() of the original node's
    meta. For integer operands torch.div keeps the integer dtype when a
    rounding mode is given, while plain div promotes to float, so the copied
    meta may not describe the intermediate div node -- confirm that integer
    inputs are not expected to reach this pass.
    """

    # The only rounding modes this pass knows how to lower.
    _SUPPORTED_ROUNDING_MODES = (None, "trunc", "floor")

    def __init__(self):
        super().__init__()
        self.targets = {
            torch.ops.aten.div.Tensor_mode,
            exir_ops.edge.aten.div.Tensor_mode,
        }

    @staticmethod
    def _get_rounding_mode(node):
        """Return the node's rounding mode, given as a keyword or as the third positional arg."""
        rounding_mode = node.kwargs.get("rounding_mode", None)
        if rounding_mode is None and len(node.args) > 2:
            rounding_mode = node.args[2]
        return rounding_mode

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every div.Tensor_mode call of ``graph_module`` in place.

        Args:
            graph_module: module whose graph is scanned and mutated.

        Returns:
            A PassResult wrapping the same module; it is always flagged as
            modified, whether or not a div.Tensor_mode node was found.

        Raises:
            ValueError: if a matched node carries a rounding mode other than
                None, "trunc" or "floor". Such a node used to be lowered to a
                plain division silently, which would change its numerics.
        """
        graph = graph_module.graph

        # Iterate over a snapshot: new nodes are inserted while walking.
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.targets:
                continue

            rounding_mode = self._get_rounding_mode(node)
            if rounding_mode not in self._SUPPORTED_ROUNDING_MODES:
                raise ValueError(
                    f"DecomposeDivMode: unsupported rounding_mode "
                    f"{rounding_mode!r} on node {node.name}"
                )

            # Stay in the dialect of the node being replaced.
            if isinstance(node.target, EdgeOpOverload):
                aten = exir_ops.edge.aten
            else:
                aten = torch.ops.aten

            x_node = node.args[0]
            y_node = node.args[1]

            with graph.inserting_before(node):
                # Step 1: plain division.
                result_node = graph.create_node(
                    "call_function", aten.div.Tensor, (x_node, y_node)
                )
                result_node.meta = copy_meta(node.meta)

                # Step 2: apply the rounding op on top of it, if requested.
                if rounding_mode is not None:
                    rounding_op = (
                        aten.trunc.default
                        if rounding_mode == "trunc"
                        else aten.floor.default
                    )
                    result_node = graph.create_node(
                        "call_function", rounding_op, (result_node,)
                    )
                    result_node.meta = copy_meta(node.meta)

            # The original node becomes dead and is removed below.
            node.replace_all_uses_with(result_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TensorOpInfo:
4343
# For below cases, refer to LiftAddTensor Model in UT for sample
4444
aten.add.Tensor: TensorOpInfo(aten.add.Tensor, False, False),
4545
aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False),
46+
aten.div.Scalar_mode: TensorOpInfo(aten.div.Tensor_mode, False, False),
4647
aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False),
4748
aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False),
4849
aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DecomposeBinaryAlpha,
2626
DecomposeCDist,
2727
DecomposeColIm,
28+
DecomposeDivMode,
2829
DecomposeEinsum,
2930
DecomposeExpM1,
3031
DecomposeFill,
@@ -127,6 +128,7 @@ def get_default_pass_activations(cls):
127128
(DecomposeAtan2, True),
128129
(DecomposeColIm, True),
129130
(DecomposeCDist, True),
131+
(DecomposeDivMode, True),
130132
(DecomposeFill, True),
131133
(DecomposeLogVariants, True),
132134
(DecomposeMaxPool3d, True),
@@ -164,6 +166,7 @@ def get_annotation_passes(cls):
164166
DecomposeAtan2,
165167
DecomposeBinaryAlpha,
166168
DecomposeCDist,
169+
DecomposeDivMode,
167170
DecomposeMaxPool3d,
168171
DecomposePad,
169172
DecomposeScaledDotProductAttention,
@@ -280,6 +283,7 @@ def get_passes_dependency_for_capture_program(cls):
280283
DecomposeAtan2: [RemoveRedundancy],
281284
DecomposeColIm: [FoldQDQ],
282285
DecomposeCDist: [RemoveRedundancy],
286+
DecomposeDivMode: [RemoveRedundancy],
283287
DecomposeFill: [RemoveRedundancy],
284288
DecomposeLinalgVectorNorm: [RemoveRedundancy],
285289
DecomposeLogVariants: [RemoveRedundancy],

backends/qualcomm/builders/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,8 @@ The following PyTorch operators are supported through decomposition or annotatio
503503
| `aten.atan2.default`, `aten.atan2.out` | `DecomposeAtan2` |
504504
| `aten.add` (with alpha), `aten.sub` (with alpha) | `DecomposeBinaryAlpha` |
505505
| `aten.cdist`, `aten._cdist_forward` | `DecomposeCDist` |
506+
| `aten.div.Tensor_mode` | `DecomposeDivMode` |
507+
| `aten.div.Scalar_mode` | `LiftConstantScalarOperands` → `DecomposeDivMode` |
506508
| `aten.im2col`, `aten.col2im` | `DecomposeColIm` |
507509
| `aten.einsum` | `DecomposeEinsum` |
508510
| `aten.special_expm1` | `DecomposeExpM1` |

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
to_be_implemented_operator = [
2222
exir_ops.edge.aten.adaptive_max_pool3d.default,
23-
exir_ops.edge.aten.div.Tensor_mode,
2423
exir_ops.edge.aten.max_pool3d_with_indices.default,
2524
exir_ops.edge.aten.median.default,
2625
exir_ops.edge.aten.median.dim,

backends/qualcomm/tests/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,25 @@ def forward(self, x):
10211021
return x / 10
10221022

10231023

1024+
class DivMode(torch.nn.Module):
    """Element-wise ``torch.div`` of two tensors with a fixed rounding mode."""

    def __init__(self, rounding_mode=None):
        """
        Args:
            rounding_mode: value forwarded as ``rounding_mode`` to
                ``torch.div`` on every call; defaults to None.
        """
        super().__init__()
        self.rounding_mode = rounding_mode

    def forward(self, x, y):
        """Return ``torch.div(x, y)`` using the rounding mode given at construction."""
        return torch.div(x, y, rounding_mode=self.rounding_mode)
1031+
1032+
1033+
class DivScalarMode(torch.nn.Module):
    """``torch.div`` of a tensor by a fixed Python scalar with a fixed rounding mode."""

    def __init__(self, scalar=2.0, rounding_mode=None):
        """
        Args:
            scalar: divisor applied to every input; defaults to 2.0.
            rounding_mode: value forwarded as ``rounding_mode`` to
                ``torch.div`` on every call; defaults to None.
        """
        super().__init__()
        self.scalar = scalar
        self.rounding_mode = rounding_mode

    def forward(self, x):
        """Return ``torch.div(x, scalar)`` using the rounding mode given at construction."""
        return torch.div(x, self.scalar, rounding_mode=self.rounding_mode)
1041+
1042+
10241043
class DrawGraphModel(torch.nn.Module):
10251044
def __init__(self):
10261045
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,52 @@ def test_qnn_backend_cumsum(self):
633633
index += 1
634634
self.lower_module_and_test_output(module, sample_input)
635635

636+
def test_qnn_backend_div_mode(self):
    """Run DivMode, once per rounding mode, through lower_module_and_test_output."""
    test_comb = [
        {
            QCOM_MODULE: [
                DivMode(rounding_mode=None),  # noqa: F405
                DivMode(rounding_mode="trunc"),  # noqa: F405
                DivMode(rounding_mode="floor"),  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor([7.0, 5.0, -3.0, 8.0, 1.0, 9.0]).reshape(2, 3),
                    torch.tensor([2.0, 3.0, 2.0, 5.0, 4.0, 2.0]).reshape(2, 3),
                ),
            ],
        },
    ]

    # Flatten to (module, inputs) pairs so each pair gets its own subTest index.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            self.lower_module_and_test_output(module, sample_input)
660+
661+
def test_qnn_backend_div_scalar_mode(self):
    """Run DivScalarMode ("trunc" and "floor") through lower_module_and_test_output."""
    test_comb = [
        {
            QCOM_MODULE: [
                DivScalarMode(scalar=2.0, rounding_mode="trunc"),  # noqa: F405
                DivScalarMode(scalar=3.0, rounding_mode="floor"),  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [
                (torch.tensor([7.0, 5.0, -3.0, 8.0, 1.0, 9.0]).reshape(2, 3),),
            ],
        },
    ]

    # Flatten to (module, inputs) pairs so each pair gets its own subTest index.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            self.lower_module_and_test_output(module, sample_input)
681+
636682
def test_qnn_backend_einsum_outer_product(self):
637683
module = EinsumOuterProduct() # noqa: F405
638684
x = torch.randn(5)
@@ -3434,6 +3480,54 @@ def test_qnn_backend_cumsum(self):
34343480
module = self.get_qdq_module(module, sample_input)
34353481
self.lower_module_and_test_output(module, sample_input)
34363482

3483+
def test_qnn_backend_div_mode(self):
    """Run DivMode, once per rounding mode, through get_qdq_module and lower_module_and_test_output."""
    test_comb = [
        {
            QCOM_MODULE: [
                DivMode(rounding_mode=None),  # noqa: F405
                DivMode(rounding_mode="trunc"),  # noqa: F405
                DivMode(rounding_mode="floor"),  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [
                (
                    torch.tensor([7.0, 5.0, -3.0, 8.0, 1.0, 9.0]).reshape(2, 3),
                    torch.tensor([2.0, 3.0, 2.0, 5.0, 4.0, 2.0]).reshape(2, 3),
                ),
            ],
        },
    ]

    # Flatten to (module, inputs) pairs so each pair gets its own subTest index.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            # The QDQ module is built inside the subTest so a failure there
            # is attributed to this case only.
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)
3508+
3509+
def test_qnn_backend_div_scalar_mode(self):
    """Run DivScalarMode ("trunc" and "floor") through get_qdq_module and lower_module_and_test_output."""
    test_comb = [
        {
            QCOM_MODULE: [
                DivScalarMode(scalar=2.0, rounding_mode="trunc"),  # noqa: F405
                DivScalarMode(scalar=3.0, rounding_mode="floor"),  # noqa: F405
            ],
            QCOM_SAMPLE_INPUTS: [
                (torch.tensor([7.0, 5.0, -3.0, 8.0, 1.0, 9.0]).reshape(2, 3),),
            ],
        },
    ]

    # Flatten to (module, inputs) pairs so each pair gets its own subTest index.
    cases = [
        (module, sample_input)
        for comb in test_comb
        for module in comb[QCOM_MODULE]
        for sample_input in comb[QCOM_SAMPLE_INPUTS]
    ]
    for index, (module, sample_input) in enumerate(cases):
        with self.subTest(i=index):
            # The QDQ module is built inside the subTest so a failure there
            # is attributed to this case only.
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)
3530+
34373531
def test_qnn_backend_einsum_outer_product(self):
34383532
module = EinsumOuterProduct() # noqa: F405
34393533
x = torch.randn(5)

0 commit comments

Comments
 (0)