Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions backends/nxp/backend/ir/converter/node_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import operator
from abc import ABC, abstractmethod
from math import prod
from typing import Callable

import torch
Expand Down Expand Up @@ -411,30 +412,31 @@ def uses_shape_broadcasting(node: Node) -> bool:
)

@staticmethod
def at_least_one_input_shape_matches_the_output_shape(node: Node) -> bool:
"""Determine if given PyTorch fx Node uses at least one input shape broadcasting for it's input nodes or not.
def inputs_satisfy_broadcast_condition(node: Node) -> bool:
"""Determine if given PyTorch fx Node has inputs that satisfy broadcasting conditions for Neutron or not.

:param node: PyTorch fx Node with 'all_input_nodes' initialized.
:return: True, if at least one input has the same shape as the output node.
:return: True, if at least one input has the same number of elements as the output node.
False otherwise.
"""

if node.all_input_nodes is None:
logger.e(
logger.Code.INTERNAL_ERROR,
"node_converter.at_least_one_input_shape_matches_the_output_shape(): 'all_input_nodes' are None!",
"node_converter.inputs_satisfy_broadcast_condition(): 'all_input_nodes' are None!",
)

if len(node.all_input_nodes) == 0:
logger.e(
logger.Code.INTERNAL_ERROR,
"node_converter.at_least_one_input_shape_matches_the_output_shape(): Operator has no inputs!",
"node_converter.inputs_satisfy_broadcast_condition(): Operator has no inputs!",
)

output_shape = node.meta["val"].shape
num_elements = prod(output_shape)

return any(
input_tensor.meta["val"].shape == output_shape
prod(input_tensor.meta["val"].shape) == num_elements
for input_tensor in node.all_input_nodes
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _is_supported_on_target(
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node):
if not NodeConverter.inputs_satisfy_broadcast_condition(node):
return False

# If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _is_supported_on_target(
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node):
if not NodeConverter.inputs_satisfy_broadcast_condition(node):
return False

# If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
import torch

from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
Expand All @@ -24,38 +25,17 @@ def _is_supported_on_target(
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
node_shape = node.meta["val"].shape
rank = len(node_shape)

# According to Neutron spec., PReLU can be done only on 4D tensors
if rank != 4:
return False

ch_idx, h_idx, w_idx = PReLUConverter._get_channel_spatial_indices(node)
# According to Neutron spec., size of channels must be divisible by num_macs.
num_macs = neutron_target_spec.get_num_macs()
if node_shape[ch_idx] % num_macs != 0:
if not NodeConverter.inputs_satisfy_broadcast_condition(node):
return False

# According to Neutron spec., height * width cannot be greater than a given constant.
if node_shape[w_idx] * node_shape[h_idx] > 4096:
supported_types = [torch.int8, torch.uint8]
if not NodeConverter.uses_quantization_type_for_io(
node, supported_types, [0, 1], [0]
):
return False

return True

@staticmethod
def _get_channel_spatial_indices(node: Node):
if node.meta[NXP_NODE_FORMAT].is_channels_first():
ch_idx = 1
h_idx = 2
w_idx = 3
else:
ch_idx = 3
h_idx = 1
w_idx = 2

return ch_idx, h_idx, w_idx

@staticmethod
def _is_supported_in_IR(
node: Node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _is_supported_on_target(
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if not NodeConverter.at_least_one_input_shape_matches_the_output_shape(node):
if not NodeConverter.inputs_satisfy_broadcast_condition(node):
return False

# If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def test__basic_nsys_inference_qat(self, mocker, request):
pytest.param(
[ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D."
),
pytest.param(
[ModelInputSpec((10,)), ModelInputSpec((1, 1))],
id="2 inputs 2D, num_elems of input == num_elems of output",
),
],
)
def test__broadcast(self, mocker, request, input_spec):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def test__basic_nsys_inference_qat(self, mocker, request, x_input_shape):
pytest.param(
[ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D+2D."
),
pytest.param(
[ModelInputSpec((10,)), ModelInputSpec((1, 1))],
id="2 inputs 2D, num_elems of input == num_elems of output",
),
],
)
def test__correct_broadcast(self, input_spec, mocker, request):
Expand Down
Loading
Loading