Skip to content

Commit 8d2e488

Browse files
committed
NXP backend: Enable aten.bmm with new Neutron flow.
1 parent b094b0e commit 8d2e488

5 files changed

Lines changed: 135 additions & 240 deletions

File tree

backends/nxp/backend/ir/converter/node_converters/ops_converters/bmm_converter.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
import torch
56

67
from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
7-
from executorch.backends.nxp.backend.edge_helper import input_rank
8+
from executorch.backends.nxp.backend.edge_helper import (
9+
get_quantization_parameters_for,
10+
input_rank,
11+
)
812
from executorch.backends.nxp.backend.ir.converter.conversion import translator
913
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1014
from executorch.backends.nxp.backend.ir.converter.node_converter import (
@@ -14,9 +18,6 @@
1418
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1519
batch_mat_mul_options,
1620
)
17-
from executorch.backends.nxp.backend.neutron_operator_support import (
18-
transposition_is_supported_on_neutron,
19-
)
2021
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2122
from torch.fx import Node
2223
from torch.nn import Parameter
@@ -44,35 +45,18 @@ def _is_supported_on_target(
4445
parameters_mapping: dict[str, Parameter],
4546
custom_delegation_options: CustomDelegationOptions,
4647
) -> bool:
47-
is_ch_first_1 = node.args[0].meta[NXP_NODE_FORMAT].is_channels_first()
48-
is_ch_first_2 = node.args[1].meta[NXP_NODE_FORMAT].is_channels_first()
49-
# This combination of node formats is not supported on Neutron (`adj_x = True`, `adj_y = False`),
50-
# but it should never happen because both input tensors are expected to share the same format.
51-
if is_ch_first_1 and not is_ch_first_2:
48+
if not NodeConverter.uses_quantization_type_for_io(
49+
node,
50+
supported_types=[torch.int8, torch.uint8],
51+
input_indices=[0, 1],
52+
output_indices=[0],
53+
):
5254
return False
5355

54-
# In case we need to insert transpose after `BatchMatMul`, we also need to check if
55-
# such transposition is supported.
56-
if node.meta[NXP_NODE_FORMAT].is_channels_first():
57-
tensor_shape = node.meta["val"].shape
58-
tensor_rank = len(tensor_shape)
59-
perm = translator.create_channels_first_to_channels_last_permutation(
60-
tensor_rank, return_list=True
61-
)
62-
63-
tensor_shape_channels_last = [tensor_shape[i] for i in perm]
64-
if not transposition_is_supported_on_neutron(
65-
tensor_shape_channels_last, perm, neutron_target_spec
66-
):
67-
return False
68-
69-
_, d1, d2 = node.args[0].meta["val"].shape
70-
_, d3, d4 = node.args[1].meta["val"].shape
71-
72-
# The Neutron converter requires that every dimension participating in the
73-
# multiplication is divisible by NUM_MACS.
74-
num_macs = neutron_target_spec.get_num_macs()
75-
if not all(m % num_macs == 0 for m in [d1, d2, d3, d4]):
56+
_, input_1_zp = get_quantization_parameters_for(node.args[0])
57+
_, input_2_zp = get_quantization_parameters_for(node.args[1])
58+
if not (input_1_zp == input_2_zp == 0):
59+
# Neutron requirement.
7660
return False
7761

7862
return True

backends/nxp/quantizer/patterns.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from functools import partial
1111

1212
import torch
13+
1314
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.clamp_converter import (
1415
_is_convertible_to_relu,
1516
)
@@ -22,6 +23,8 @@
2223
from torch.fx import Node
2324
from torchao.quantization.pt2e import (
2425
FakeQuantize,
26+
MinMaxObserver,
27+
MovingAverageMinMaxObserver,
2528
MovingAveragePerChannelMinMaxObserver,
2629
PerChannelMinMaxObserver,
2730
)
@@ -326,10 +329,24 @@ def get_anchors(
326329
) -> PartitionAnchors | None:
327330
bmm_node = fused_partition[0].nodes[-1]
328331

332+
# Use per_tensor_symmetric to enforce zero_point=0 for both inputs
333+
observer_or_fake_quant_ctr = (
334+
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
335+
if self.is_qat
336+
else MinMaxObserver
337+
)
338+
input_quantization_spec = QuantizationSpec(
339+
dtype=torch.int8,
340+
observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
341+
quant_min=-128,
342+
quant_max=127,
343+
qscheme=torch.per_tensor_symmetric, # Neutron requires the inputs to have zero point = 0.
344+
)
345+
329346
return PartitionAnchors(
330347
inputs=[
331-
(bmm_node, NodeArgsIdx(0)),
332-
(bmm_node, NodeArgsIdx(1)),
348+
(bmm_node, NodeArgsIdx(0), input_quantization_spec),
349+
(bmm_node, NodeArgsIdx(1), input_quantization_spec),
333350
],
334351
biases=[],
335352
output=[(bmm_node,)],

0 commit comments

Comments
 (0)