diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index f8405f37680..cd9af43b42c 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -5,6 +5,7 @@ import operator from abc import ABC, abstractmethod +from math import prod from typing import Callable import torch @@ -411,30 +412,31 @@ def uses_shape_broadcasting(node: Node) -> bool: ) @staticmethod - def at_least_one_input_shape_matches_the_output_shape(node: Node) -> bool: - """Determine if given PyTorch fx Node uses at least one input shape broadcasting for it's input nodes or not. + def inputs_satisfy_broadcast_condition(node: Node) -> bool: + """Determine if given PyTorch fx Node has inputs that satisfy broadcasting conditions for Neutron or not. :param node: PyTorch fx Node with 'all_input_nodes' initialized. - :return: True, if at least one input has the same shape as the output node. + :return: True, if at least one input has the same number of elements as the output node. False otherwise. """ if node.all_input_nodes is None: logger.e( logger.Code.INTERNAL_ERROR, - "node_converter.at_least_one_input_shape_matches_the_output_shape(): 'all_input_nodes' are None!", + "node_converter.inputs_satisfy_broadcast_condition(): 'all_input_nodes' are None!", ) if len(node.all_input_nodes) == 0: logger.e( logger.Code.INTERNAL_ERROR, - "node_converter.at_least_one_input_shape_matches_the_output_shape(): Operator has no inputs!", + "node_converter.inputs_satisfy_broadcast_condition(): Operator has no inputs!", ) output_shape = node.meta["val"].shape + num_elements = prod(output_shape) return any( - input_tensor.meta["val"].shape == output_shape + prod(input_tensor.meta["val"].shape) == num_elements for input_tensor in node.all_input_nodes ) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index 8b67f954df9..cd9cb8e0824 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -26,7 +26,7 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node): + if not NodeConverter.inputs_satisfy_broadcast_condition(node): return False # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py index cbbac02d708..379af0daaa6 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mul_tensor_converter.py @@ -25,7 +25,7 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node): + if not NodeConverter.inputs_satisfy_broadcast_condition(node): return False # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/prelu_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/prelu_converter.py index 003221a16fb..eb398764914 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/prelu_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/prelu_converter.py @@ -3,7 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT +import torch + from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -24,38 +25,17 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - node_shape = node.meta["val"].shape - rank = len(node_shape) - - # According to Neutron spec., PReLU can be done only on 4D tensors - if rank != 4: - return False - - ch_idx, h_idx, w_idx = PReLUConverter._get_channel_spatial_indices(node) - # According to Neutron spec., size of channels must be divisible by num_macs. - num_macs = neutron_target_spec.get_num_macs() - if node_shape[ch_idx] % num_macs != 0: + if not NodeConverter.inputs_satisfy_broadcast_condition(node): return False - # According to Neutron spec., height * width cannot be greater than a given constant. - if node_shape[w_idx] * node_shape[h_idx] > 4096: + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): return False return True - @staticmethod - def _get_channel_spatial_indices(node: Node): - if node.meta[NXP_NODE_FORMAT].is_channels_first(): - ch_idx = 1 - h_idx = 2 - w_idx = 3 - else: - ch_idx = 3 - h_idx = 1 - w_idx = 2 - - return ch_idx, h_idx, w_idx - @staticmethod def _is_supported_in_IR( node: Node, diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py index 105dbc09c7b..a13e018dcf2 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py @@ -26,7 +26,7 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node): + if not NodeConverter.inputs_satisfy_broadcast_condition(node): return False # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 6ac96e41cd1..768bf7ae339 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -104,6 +104,10 @@ def test__basic_nsys_inference_qat(self, mocker, request): pytest.param( [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." ), + pytest.param( + [ModelInputSpec((10,)), ModelInputSpec((1, 1))], + id="2 inputs 2D, num_elems of input == num_elems of output", + ), ], ) def test__broadcast(self, mocker, request, input_spec): diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py index d112ff1e1ac..f1f4a98a148 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py @@ -90,6 +90,10 @@ def test__basic_nsys_inference_qat(self, mocker, request, x_input_shape): pytest.param( [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D+2D." ), + pytest.param( + [ModelInputSpec((10,)), ModelInputSpec((1, 1))], + id="2 inputs 2D, num_elems of input == num_elems of output", + ), ], ) def test__correct_broadcast(self, input_spec, mocker, request): diff --git a/backends/nxp/tests/ir/converter/node_converter/test_prelu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_prelu_converter.py index fb25f02785a..cf154fbf9c6 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_prelu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_prelu_converter.py @@ -9,18 +9,32 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executors import ( - convert_run_compare, - graph_contains_any_of_ops, +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier +from executorch.backends.nxp.tests.model_output_comparator import ( + AllCloseOutputComparator, ) from executorch.backends.nxp.tests.models import ( + ConvPReLUModule, LinearPReLUModule, TwoPartitionPReLUModel, ) + +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddMm, + Convolution, + ExecutorchDelegateCall, + GtScalar, + MulTensor, + PermuteCopy, + Prelu, + ViewCopy, + WhereSelf, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program -from executorch.exir.dialects._ops import ops as exir_ops @pytest.fixture(autouse=True) @@ -29,123 +43,155 @@ def reseed_model_per_test_run(): np.random.seed(23) -# noinspection PyProtectedMember -ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate - - -@pytest.mark.parametrize( - "input_shape", - [ - pytest.param((1, 8, 24, 32), id="4D."), - ], -) -def test_prelu_with_linear_quant_conversion(mocker, input_shape): - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - - # Run conversion - channels = input_shape[-1] - edge_program = to_quantized_edge_program( - LinearPReLUModule(in_features=channels, out_features=channels), - input_shape, - ).exported_program() - - # Capture generated entities - neutron_ir_model, _ = converter_spy.spy_return - exported_program: ExportedProgram = converter_spy.call_args.args[1] - - # Check `prelu` was not decomposed into simpler edge operators - assert not graph_contains_any_of_ops( - exported_program.graph, +class TestPreluConverter: + @pytest.mark.parametrize( + "input_shape", [ - exir_ops.edge.aten.gt.Scalar, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.where.self, + pytest.param((1,), id="1D."), + pytest.param( + (36, 487), + id="2D incorrect results.", + marks=pytest.mark.xfail( + reason="AIR-14737: incorrect results", strict=True + ), + ), + pytest.param( + (87, 842), + id="2D incorrect results alt.", + marks=pytest.mark.xfail( + reason="AIR-14737: incorrect results", strict=True + ), + ), + pytest.param((7, 83), id="2D."), + pytest.param( + (1, 43, 183), + id="3D incorrect results alt.", + marks=pytest.mark.xfail( + reason="AIR-14737: incorrect results", strict=True + ), + ), + pytest.param((1, 43, 93), id="3D."), + pytest.param((1, 4, 7, 8), id="4D."), + pytest.param((1, 4, 3, 4, 14), id="5D."), ], ) - - assert graph_contains_any_of_ops( - exported_program.graph, - [exir_ops.edge.aten.prelu.default], - ) - - # Check `prelu` was delegated - assert not graph_contains_any_of_ops( - edge_program.graph, - [exir_ops.edge.aten.prelu.default], - ) - - input_data = ( - (2 * np.random.random(input_shape).astype(np.float32) - 1) * 50 - ).astype(np.int8) - - convert_run_compare(exported_program, input_data, tfl_model=neutron_ir_model) - - -@pytest.mark.parametrize( - "input_shape", - [ - pytest.param((1, 8, 24, 32), id="4D."), - ], -) -def test_prelu_2_partitions(mocker, input_shape): - # TODO (Martin) Add a channels last dim order variant of this test to verify correct partitioning. - # Run conversion - edge_program = to_quantized_edge_program( - TwoPartitionPReLUModel(), [input_shape, input_shape] - ).exported_program() - - # Check `prelu` was delegated - assert not graph_contains_any_of_ops( - edge_program.graph, - [exir_ops.edge.aten.prelu.default], - ) - - # Check there are two partitions - edge_nodes = list(edge_program.graph.nodes) - assert sum(n.target == ExecutorchDelegateCall for n in edge_nodes) == 2 - - -@pytest.mark.parametrize( - "input_shape", - [ - pytest.param((1,), id="1D not supported."), - pytest.param((1, 8), id="2D not supported."), - pytest.param((1, 8, 16), id="3D not supported."), - pytest.param((1, 8, 16, 32, 64), id="5D not supported."), - pytest.param((1, 8, 16, 31), id="channels must be divisible by NUM_MACS"), - pytest.param((1, 8, 1024, 8), id="width*height is too big (limit 4096)"), - ], -) -def test_prelu__no_delegation__unsupported_conversion(mocker, input_shape): - # Run conversion - channels = input_shape[-1] - edge_program = to_quantized_edge_program( - LinearPReLUModule(in_features=channels, out_features=channels), - input_shape, - ).exported_program() - - # Check `prelu` was not delegated (only `linear` was) - edge_nodes = list(edge_program.graph.nodes) - assert sum(n.target == ExecutorchDelegateCall for n in edge_nodes) == 1 - - # Check `prelu` was decomposed into simpler edge operators - assert graph_contains_any_of_ops( - edge_program.graph, - [ - exir_ops.edge.aten.gt.Scalar, - ], - ) - - assert graph_contains_any_of_ops( - edge_program.graph, - [ - exir_ops.edge.aten.mul.Tensor, - ], - ) - - assert graph_contains_any_of_ops( - edge_program.graph, + def test__basic_nsys_inference(self, mocker, request, input_shape): + channels = input_shape[-1] + rank = len(input_shape) + model = LinearPReLUModule(in_features=channels, out_features=channels) + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={ + Prelu: 1, + AddMm: 1, + PermuteCopy: 1, + ViewCopy: 0 if rank == 2 else 2, + }, + expected_non_delegated_ops={}, + ) + comparator = AllCloseOutputComparator(atol=1) + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + lower_run_compare( + model, + input_shape, + graph_verifier, + request, + output_comparator=comparator, + remove_quant_io_ops=True, + ) + + # Capture generated entities + neutron_ir_model, _ = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + # Check `prelu` was not decomposed into simpler edge operators + assert not graph_contains_any_of_ops( + exported_program.graph, + [ + GtScalar, + MulTensor, + WhereSelf, + ], + ) + + def test__num_parameters_param(self, mocker, request): + input_shape = (1, 43, 93) + channels = input_shape[-1] + rank = len(input_shape) + num_parameters = input_shape[1] + model = LinearPReLUModule( + in_features=channels, out_features=channels, num_parameters=num_parameters + ) + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={ + Prelu: 1, + AddMm: 1, + PermuteCopy: 1, + ViewCopy: 0 if rank == 2 else 2, + }, + expected_non_delegated_ops={}, + ) + comparator = AllCloseOutputComparator(atol=1) + + lower_run_compare( + model, + input_shape, + graph_verifier, + request, + output_comparator=comparator, + remove_quant_io_ops=True, + ) + + def test_prelu_2_partitions(self): + input_shape = (1, 8, 24, 32) + # Run conversion + edge_program = to_quantized_edge_program( + TwoPartitionPReLUModel(), [input_shape, input_shape] + ).exported_program() + + # Check `prelu` was delegated + assert not graph_contains_any_of_ops( + edge_program.graph, + [Prelu], + ) + + # Check there are two partitions + edge_nodes = list(edge_program.graph.nodes) + assert sum(n.target == ExecutorchDelegateCall for n in edge_nodes) == 2 + + @pytest.mark.parametrize( + "input_shape", [ - exir_ops.edge.aten.where.self, + pytest.param((1, 8, 42, 24), id="4D."), + pytest.param( + (1, 8, 42, 21), + id="4D incorrect results.", + marks=pytest.mark.xfail( + reason="AIR-14737: incorrect results", strict=True + ), + ), ], ) + def test__w_conv(self, mocker, request, input_shape): + channels = input_shape[1] + model = ConvPReLUModule(in_channels=channels) + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={ + Prelu: 1, + Convolution: 1, + }, + expected_non_delegated_ops={}, + ) + comparator = AllCloseOutputComparator(atol=1) + + lower_run_compare( + model, + input_shape, + graph_verifier, + request, + output_comparator=comparator, + remove_quant_io_ops=True, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index e71ff7e8af5..7699a5ad089 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -104,6 +104,10 @@ def test__basic_nsys_inference_qat(self, mocker, request): [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], id="2 inputs 3D.", ), + pytest.param( + [ModelInputSpec((10,)), ModelInputSpec((1, 1))], + id="2 inputs 2D, num_elems of input == num_elems of output", + ), ], ) def test__broadcast(self, mocker, request, input_spec): diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 7545dd940f2..529b7971489 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -887,6 +887,20 @@ def forward(self, x): return self.prelu(x) +class ConvPReLUModule(torch.nn.Module): + def __init__(self, in_channels, num_parameters=1): + super().__init__() + + self.conv = Conv2dModule( + in_channels=in_channels, out_channels=in_channels, stride=1, padding=1 + ) + self.prelu = torch.nn.PReLU(num_parameters) + + def forward(self, x): + x = self.conv(x) + return self.prelu(x) + + class TwoPartitionPReLUModel(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/nxp/tests/ops_aliases.py b/backends/nxp/tests/ops_aliases.py index aceb9707106..487d308cf7f 100644 --- a/backends/nxp/tests/ops_aliases.py +++ b/backends/nxp/tests/ops_aliases.py @@ -27,6 +27,7 @@ DequantizePerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate GetItem = operator.getitem +GtScalar = exir_ops.edge.aten.gt.Scalar HardTanh = exir_ops.edge.aten.hardtanh.default HardTanh_ = exir_ops.edge.aten.hardtanh_.default LeakyRelu = exir_ops.edge.aten.leaky_relu.default @@ -36,9 +37,9 @@ MeanDim = exir_ops.edge.aten.mean.dim MulTensor = exir_ops.edge.aten.mul.Tensor PermuteCopy = exir_ops.edge.aten.permute_copy.default +Prelu = exir_ops.edge.aten.prelu.default QuantizePerChannel = exir_ops.edge.quantized_decomposed.quantize_per_channel.default QuantizePerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default -PermuteCopy = exir_ops.edge.aten.permute_copy.default Relu = exir_ops.edge.aten.relu.default Sigmoid = exir_ops.edge.aten.sigmoid.default Slice = exir_ops.edge.aten.slice.Tensor @@ -54,3 +55,4 @@ UpsampleBilinear2D = exir_ops.edge.aten.upsample_bilinear2d.vec UpsampleNearest2D = exir_ops.edge.aten.upsample_nearest2d.vec ViewCopy = exir_ops.edge.aten.view_copy.default +WhereSelf = exir_ops.edge.aten.where.self