|
2 | 2 | # |
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
| 5 | +import torch |
5 | 6 |
|
6 | 7 | from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT |
7 | | -from executorch.backends.nxp.backend.edge_helper import input_rank |
| 8 | +from executorch.backends.nxp.backend.edge_helper import ( |
| 9 | + get_quantization_parameters_for, |
| 10 | + input_rank, |
| 11 | +) |
8 | 12 | from executorch.backends.nxp.backend.ir.converter.conversion import translator |
9 | 13 | from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList |
10 | 14 | from executorch.backends.nxp.backend.ir.converter.node_converter import ( |
|
14 | 18 | from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( |
15 | 19 | batch_mat_mul_options, |
16 | 20 | ) |
17 | | -from executorch.backends.nxp.backend.neutron_operator_support import ( |
18 | | - transposition_is_supported_on_neutron, |
19 | | -) |
20 | 21 | from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec |
21 | 22 | from torch.fx import Node |
22 | 23 | from torch.nn import Parameter |
@@ -44,35 +45,18 @@ def _is_supported_on_target( |
44 | 45 | parameters_mapping: dict[str, Parameter], |
45 | 46 | custom_delegation_options: CustomDelegationOptions, |
46 | 47 | ) -> bool: |
47 | | - is_ch_first_1 = node.args[0].meta[NXP_NODE_FORMAT].is_channels_first() |
48 | | - is_ch_first_2 = node.args[1].meta[NXP_NODE_FORMAT].is_channels_first() |
49 | | - # This combination of node formats is not supported on Neutron (`adj_x = True`, `adj_y = False`), |
50 | | - # but it should never happen because both input tensors are expected to share the same format. |
51 | | - if is_ch_first_1 and not is_ch_first_2: |
| 48 | + if not NodeConverter.uses_quantization_type_for_io( |
| 49 | + node, |
| 50 | + supported_types=[torch.int8, torch.uint8], |
| 51 | + input_indices=[0, 1], |
| 52 | + output_indices=[0], |
| 53 | + ): |
52 | 54 | return False |
53 | 55 |
|
54 | | - # In case we need to insert transpose after `BatchMatMul`, we also need to check if |
55 | | - # such transposition is supported. |
56 | | - if node.meta[NXP_NODE_FORMAT].is_channels_first(): |
57 | | - tensor_shape = node.meta["val"].shape |
58 | | - tensor_rank = len(tensor_shape) |
59 | | - perm = translator.create_channels_first_to_channels_last_permutation( |
60 | | - tensor_rank, return_list=True |
61 | | - ) |
62 | | - |
63 | | - tensor_shape_channels_last = [tensor_shape[i] for i in perm] |
64 | | - if not transposition_is_supported_on_neutron( |
65 | | - tensor_shape_channels_last, perm, neutron_target_spec |
66 | | - ): |
67 | | - return False |
68 | | - |
69 | | - _, d1, d2 = node.args[0].meta["val"].shape |
70 | | - _, d3, d4 = node.args[1].meta["val"].shape |
71 | | - |
72 | | - # The Neutron converter requires that every dimension participating in the |
73 | | - # multiplication is divisible by NUM_MACS. |
74 | | - num_macs = neutron_target_spec.get_num_macs() |
75 | | - if not all(m % num_macs == 0 for m in [d1, d2, d3, d4]): |
| 56 | + _, input_1_zp = get_quantization_parameters_for(node.args[0]) |
| 57 | + _, input_2_zp = get_quantization_parameters_for(node.args[1]) |
| 58 | + if not (input_1_zp == input_2_zp == 0): |
| 59 | + # Neutron requirement. |
76 | 60 | return False |
77 | 61 |
|
78 | 62 | return True |
|
0 commit comments