Skip to content

Commit 3495635

Browse files
committed
NXP backend: Enable MM and AddMM with new Neutron flow.
1 parent fa5fc74 commit 3495635

10 files changed

Lines changed: 424 additions & 187 deletions

File tree

backends/nxp/backend/edge_helper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,43 @@ def node_is_effectively_static_tensor(
109109
)
110110

111111

112+
def weights_are_effectively_static(
113+
node: Node, parameters_mapping: dict[str, Parameter], weight_index: int = 1
114+
) -> bool:
115+
"""Neutron IR sometimes requires some weights to be static. This method checks if this is the case for the
116+
provided `node`.
117+
118+
Sometimes a `permute_copy` is inserted to transpose the weights during edge lowering. The `permute_copy` is
119+
then removed during conversion to Neutron IR if it transposes static data. In those cases, the weights will be
120+
static. Therefore, it is ok if the weights are produced by a `permute_copy` with a static input (reached through the `quantize` -> `dequantize` chain that feeds the weight input).
121+
122+
:param node: Operator node whose weight input (`node.args[weight_index]`) is checked for static data.
123+
:param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
124+
`state_dict` attribute of an edge program.
125+
:param weight_index: Index to the `node.args` where the weight is located. Defaults to 1.
126+
:return: True if the weight at the given index is effectively static.
127+
"""
128+
129+
def _is_permute_copy(node_: Node) -> bool:
130+
return hasattr(node_, "target") and node_.target == PermuteCopy
131+
132+
if (
133+
_is_dequantize(dq_node := node.args[weight_index])
134+
and _is_quantize(q_node := dq_node.args[0])
135+
and _is_permute_copy(permute_copy_node := q_node.args[0])
136+
):
137+
# The weights are produced by a `permute_copy`. Its input (the weights) must be static.
138+
return node_is_effectively_static_tensor(
139+
permute_copy_node.args[0], parameters_mapping
140+
)
141+
142+
else:
143+
# There is no `permute_copy`. The weights must be static directly.
144+
return node_is_effectively_static_tensor(
145+
node.args[weight_index], parameters_mapping
146+
)
147+
148+
112149
def try_get_tensor_constant_from_node(
113150
graph_module: GraphModule, node: Node
114151
) -> Parameter | None:

backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
# Copyright 2024-2025 NXP
1+
# Copyright 2024-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from executorch.backends.nxp.backend.edge_helper import input_rank
6+
import torch
7+
8+
from executorch.backends.nxp.backend.edge_helper import (
9+
input_rank,
10+
node_is_effectively_static_tensor,
11+
weights_are_effectively_static,
12+
)
713
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
814
from executorch.backends.nxp.backend.ir.converter.node_converter import (
915
CustomDelegationOptions,
@@ -12,10 +18,18 @@
1218
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1319
fully_connected_options,
1420
)
21+
22+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1523
from torch.fx import Node
1624
from torch.nn import Parameter
1725

1826

27+
# The edge operator signature is: aten.addmm(bias, input, weight, *, beta=1, alpha=1)
28+
MAIN_INPUT_IDX = 1
29+
WEIGHT_IDX = 2
30+
BIAS_IDX = 0
31+
32+
1933
class AddMMConverter(NodeConverter):
2034
"""Convert the `aten.addmm` operator to TFLite `FullyConnected` with a bias input."""
2135

@@ -29,12 +43,67 @@ def _is_supported_in_IR(
2943
return False
3044

3145
# The weights must be 2D.
32-
if input_rank(node, 2) != 2:
46+
if input_rank(node, WEIGHT_IDX) != 2:
47+
return False
48+
49+
alpha, beta = node.kwargs.get("alpha", 1), node.kwargs.get("beta", 1)
50+
if alpha != 1 or beta != 1:
51+
# As these cases seem rare, conversion is not implemented for the time being.
52+
return False
53+
54+
return True
55+
56+
@staticmethod
57+
def _is_supported_on_target(
58+
node: Node,
59+
neutron_target_spec: NeutronTargetSpec,
60+
parameters_mapping: dict[str, Parameter],
61+
custom_delegation_options: CustomDelegationOptions,
62+
) -> bool:
63+
# Main input and output must be `int8` or `uint8`.
64+
if not NodeConverter.uses_quantization_type_for_io(
65+
node, [torch.int8, torch.uint8], [MAIN_INPUT_IDX], [0]
66+
):
67+
return False
68+
69+
# Weights must be `int8`.
70+
if not NodeConverter.uses_quantization_type_for_io(
71+
node, [torch.int8], [WEIGHT_IDX], []
72+
):
73+
return False
74+
75+
# Bias must be `int32`.
76+
if not NodeConverter.uses_quantization_type_for_io(
77+
node, [torch.int32], [BIAS_IDX], []
78+
):
79+
return False
80+
81+
# Weights must be constant.
82+
if not weights_are_effectively_static(
83+
node, parameters_mapping, weight_index=WEIGHT_IDX
84+
):
85+
return False
86+
87+
# The bias must be constant.
88+
if not node_is_effectively_static_tensor(
89+
node.args[BIAS_IDX], parameters_mapping
90+
):
3391
return False
3492

3593
return True
3694

3795
def convert(self, node: Node):
96+
"""Convert the `aten.addmm` operator to Neutron IR `FullyConnected`.
97+
The schema is:
98+
addmm(
99+
Tensor self,
100+
Tensor mat1,
101+
Tensor mat2,
102+
*,
103+
Scalar beta=1,
104+
Scalar alpha=1
105+
) -> Tensor
106+
"""
38107
self.assert_convertible(node)
39108

40109
t_op = self._create_tflite_op_with_io_tensors(node)
@@ -47,14 +116,14 @@ def convert(self, node: Node):
47116
w = t_op.tmp_inputs[2]
48117
y = t_op.tmp_outputs[0]
49118

50-
# Assign the operator its TFLite inputs and outputs
119+
# Assign the operator its Neutron IR inputs and outputs
51120
t_op.tmp_inputs = [x, w, bias]
52121
t_op.tmp_outputs = [y]
53122

54123
ops = OpsList(middle_op=t_op)
55124

56125
# The `aten.addmm` uses main input with shape [M, N] and the weights have the shape [N, O].
57-
# TFLite `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]).
126+
# Neutron IR `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]).
58127
# Insert a `Transpose` operator to permute the weights to achieve correct conversion. (The `Transpose` will not
59128
# be present in the output model if the weights are static.)
60129
ops.add_pre(self.builder.create_transpose_operator_before(t_op, 1, [1, 0]))

backends/nxp/backend/ir/converter/node_converters/ops_converters/mm_converter.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
# Copyright 2024-2025 NXP
1+
# Copyright 2024-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from executorch.backends.nxp.backend.edge_helper import input_rank
6+
import torch
7+
8+
from executorch.backends.nxp.backend.edge_helper import (
9+
input_rank,
10+
weights_are_effectively_static,
11+
)
712
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
813
from executorch.backends.nxp.backend.ir.converter.node_converter import (
914
CustomDelegationOptions,
@@ -12,6 +17,7 @@
1217
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1318
fully_connected_options,
1419
)
20+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1521
from torch.fx import Node
1622
from torch.nn import Parameter
1723

@@ -33,8 +39,37 @@ def _is_supported_in_IR(
3339

3440
return True
3541

42+
@staticmethod
43+
def _is_supported_on_target(
44+
node: Node,
45+
neutron_target_spec: NeutronTargetSpec,
46+
parameters_mapping: dict[str, Parameter],
47+
custom_delegation_options: CustomDelegationOptions,
48+
) -> bool:
49+
# Main input and output must be `int8` or `uint8`.
50+
if not NodeConverter.uses_quantization_type_for_io(
51+
node, [torch.int8, torch.uint8], [0], [0]
52+
):
53+
return False
54+
55+
# Weights must be `int8`.
56+
if not NodeConverter.uses_quantization_type_for_io(node, [torch.int8], [1], []):
57+
return False
58+
59+
# Weights must be static.
60+
if not weights_are_effectively_static(node, parameters_mapping):
61+
return False
62+
63+
return True
64+
3665
def convert(self, node: Node):
37-
"""Convert the `aten.mm` operator to TFLite `FullyConnected` without a bias input."""
66+
"""Convert the `aten.mm` operator to Neutron IR `FullyConnected` without a bias input.
67+
The schema is:
68+
mm(
69+
Tensor self,
70+
Tensor mat2
71+
) -> Tensor
72+
"""
3873
self.assert_convertible(node)
3974

4075
t_op = self._create_tflite_op_with_io_tensors(node)
@@ -44,14 +79,14 @@ def convert(self, node: Node):
4479
w = t_op.tmp_inputs[1]
4580
y = t_op.tmp_outputs[0]
4681

47-
# Assign the operator its TFLite inputs and outputs
82+
# Assign the operator its Neutron IR inputs and outputs
4883
t_op.tmp_inputs = [x, w]
4984
t_op.tmp_outputs = [y]
5085

5186
ops = OpsList(middle_op=t_op)
5287

5388
# The `aten.mm` uses main input with shape [M, N] and the weights have the shape [N, O].
54-
# TFLite `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]).
89+
# Neutron IR `FullyConnected` requires the weights to have shape [O, N] (if the main input has shape [M, N]).
5590
# Insert a `Transpose` operator to permute the weights to achieve correct conversion. (The `Transpose` will not
5691
# be present in the output model if the weights are static.)
5792
ops.add_pre(self.builder.create_transpose_operator_before(t_op, 1, [1, 0]))

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
1111
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
12+
from executorch.backends.nxp.tests.ops_aliases import PermuteCopy
1213

1314
# noinspection PyProtectedMember
1415
from executorch.exir.dialects._ops import ops as exir_ops
@@ -109,9 +110,11 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
109110
main_cluster_node_to_auxiliary_nodes = {
110111
AddMM: [
111112
ViewCopy,
113+
PermuteCopy,
112114
],
113115
MM: [
114116
ViewCopy,
117+
PermuteCopy,
115118
],
116119
ViewCopy: [Clone, CloneDimOrder],
117120
Conv: [

0 commit comments

Comments
 (0)