Skip to content

Commit b137126

Browse files
committed
NXP backend: Add format checks to prevent NodeFormatInference bugs.
1 parent c37b6c3 commit b137126

8 files changed

Lines changed: 114 additions & 33 deletions

File tree

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
67
import operator
78
from abc import ABC, abstractmethod
89
from typing import Callable
@@ -12,6 +13,7 @@
1213
from executorch.backends.nxp.backend.custom_delegation_options import (
1314
CustomDelegationOptions,
1415
)
16+
from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT
1517
from executorch.backends.nxp.backend.edge_helper import (
1618
input_quantization_type,
1719
output_quantization_type,
@@ -53,6 +55,23 @@ def is_not_qdq_node(node: torch.fx.Node) -> bool:
5355
return not (_is_quant_node(node) or _is_dequant_node(node))
5456

5557

58+
def requires_channels_first_format(cls):
59+
"""Class decorator for NodeConverter subclasses.
60+
61+
Marks a converter as requiring that both the node's main input and output
62+
use the channels-first data format (as inferred by NodeFormatInference).
63+
The check is automatically enforced via `NodeConverter.is_supported()`.
64+
65+
Usage::
66+
67+
@requires_channels_first_format
68+
class ConvConverter(NodeConverter):
69+
...
70+
"""
71+
cls._requires_channels_first_format = True
72+
return cls
73+
74+
5675
class NodeConverter(ABC):
5776
"""
5877
Classes which implement conversion of torch.Node to TFLite should inherit from this class and overwrite the
@@ -61,6 +80,11 @@ class NodeConverter(ABC):
6180

6281
context: ConversionContext
6382

83+
# If `True`, the `is_supported()` method will disallow delegation if the node's main input/output doesn't have the
84+
# channels first node format.
85+
# Subclasses decorated with @requires_channels_first_format will have this set to True.
86+
_requires_channels_first_format: bool = False
87+
6488
def __init__(self, context: ConversionContext):
6589
self.context = context
6690

@@ -115,6 +139,36 @@ def _is_supported_on_target(
115139
"""
116140
return True
117141

142+
@classmethod
143+
def _node_format_is_supported(cls, node: Node) -> bool:
144+
"""Check that the node's main input and output carry the channels-first data format, if the converter was
145+
decorated with `@requires_channels_first_format`.
146+
147+
When the decorator is not present the check returns True.
148+
149+
:param node: The node to inspect.
150+
:return: True when the format requirement is satisfied (or not applicable).
151+
"""
152+
if not cls._requires_channels_first_format:
153+
return True
154+
155+
def _is_channels_first(n: Node) -> bool:
156+
return (
157+
n.meta.get(NXP_NODE_FORMAT, DataFormat.NONE)
158+
is DataFormat.CHANNELS_FIRST
159+
)
160+
161+
format_requirement_satisfied = _is_channels_first(node) and _is_channels_first(
162+
node.args[0]
163+
)
164+
if not format_requirement_satisfied:
165+
logging.warning(
166+
f"NXP backend: Node `{node}` requires channels-first format for its input and output, but the inferred "
167+
"format does not satisfy this requirement. The node will not be delegated. Please report this issue."
168+
)
169+
170+
return format_requirement_satisfied
171+
118172
@classmethod
119173
def is_supported(
120174
cls,
@@ -133,10 +187,13 @@ def is_supported(
133187
be outdated.
134188
:param custom_delegation_options: Custom user options which affect node delegation.
135189
"""
136-
return cls._is_supported_in_IR(
137-
node, parameters_mapping, custom_delegation_options
138-
) and cls._is_supported_on_target(
139-
node, neutron_target_spec, parameters_mapping, custom_delegation_options
190+
191+
return (
192+
cls._is_supported_in_IR(node, parameters_mapping, custom_delegation_options)
193+
and cls._is_supported_on_target(
194+
node, neutron_target_spec, parameters_mapping, custom_delegation_options
195+
)
196+
and cls._node_format_is_supported(node)
140197
)
141198

142199
@classmethod

backends/nxp/backend/ir/converter/node_converters/ops_converters/adaptive_avg_pool_2d_converter.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
import logging
65

76
import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding
87
import torch
98

10-
from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
119
from executorch.backends.nxp.backend.ir.converter.conversion import common
1210
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1311
CustomDelegationOptions,
1412
NodeConverter,
13+
requires_channels_first_format,
1514
)
1615
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1716
average_pool_2d_options,
@@ -25,6 +24,7 @@
2524
Stride = tuple[int, int]
2625

2726

27+
@requires_channels_first_format
2828
class AdaptiveAvgPool2dConverter(NodeConverter):
2929

3030
@staticmethod
@@ -45,15 +45,6 @@ def _is_supported_in_IR(
4545
parameters_mapping: dict[str, Parameter],
4646
custom_delegation_options: CustomDelegationOptions,
4747
) -> bool:
48-
if (
49-
format_ := node.meta.get(NXP_NODE_FORMAT)
50-
) is None or not format_.is_channels_first():
51-
logging.warning(
52-
"NXP backend: `adaptive_avg_pool_2d` doesn't have the required input format for delegation. "
53-
"Please run `NodeFormatInference.identify_node_formats()` during lowering or report this issue."
54-
)
55-
return False
56-
5748
input_size = node.args[0].meta["val"].shape
5849
output_size = node.args[1]
5950

backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1717
CustomDelegationOptions,
1818
NodeConverter,
19+
requires_channels_first_format,
1920
)
2021
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
2122
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
@@ -26,6 +27,7 @@
2627
from torch.nn import Parameter
2728

2829

30+
@requires_channels_first_format
2931
class AvgPool2dConverter(NodeConverter):
3032

3133
@staticmethod

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from executorch.backends.nxp.backend.ir.converter.node_converter import (
2424
CustomDelegationOptions,
2525
NodeConverter,
26+
requires_channels_first_format,
2627
)
2728
from executorch.backends.nxp.backend.ir.converter.node_converters.shared import (
2829
conv_utils,
@@ -48,6 +49,7 @@
4849
from torch.nn import Parameter
4950

5051

52+
@requires_channels_first_format
5153
class ConvolutionConverter(NodeConverter):
5254
@staticmethod
5355
def _is_supported_on_target(

backends/nxp/backend/ir/converter/node_converters/ops_converters/max_pool2d_with_indices_converter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1717
CustomDelegationOptions,
1818
NodeConverter,
19+
requires_channels_first_format,
1920
)
2021
from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType
2122
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.max_pool_2d_options import (
@@ -32,6 +33,7 @@
3233
CeilMode = bool
3334

3435

36+
@requires_channels_first_format
3537
class MaxPool2DWithIndicesConverter(NodeConverter):
3638

3739
@staticmethod

backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_bilinear2d_converter.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import numpy as np
77
import torch
88

9-
from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT
109
from executorch.backends.nxp.backend.edge_helper import node_has_well_defined_shape
1110
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1211
CustomDelegationOptions,
1312
is_not_qdq_node,
1413
NodeConverter,
14+
requires_channels_first_format,
1515
)
1616
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.resize_bilinear_options import (
1717
ResizeBilinear,
@@ -23,6 +23,7 @@
2323

2424

2525
# noinspection SpellCheckingInspection
26+
@requires_channels_first_format
2627
class UpsampleBilinear2DConverter(NodeConverter):
2728

2829
@classmethod
@@ -53,14 +54,6 @@ def _is_supported_in_IR(
5354
parameters_mapping: dict[str, Parameter],
5455
custom_delegation_options: CustomDelegationOptions,
5556
) -> bool:
56-
57-
if node.meta.get(NXP_NODE_FORMAT, DataFormat.NONE) != DataFormat.CHANNELS_FIRST:
58-
# This should never happen.
59-
raise NotImplementedError(
60-
"NXP backend: `aten.upsample_bilinear2d.vec` didn't have correctly identified data"
61-
" format. Please report this."
62-
)
63-
6457
# The conversion requires the output shape to be known and static.
6558
if not node_has_well_defined_shape(node):
6659
return False

backends/nxp/backend/ir/converter/node_converters/ops_converters/upsample_nearest2d_converter.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import numpy as np
77
import torch
88

9-
from executorch.backends.nxp.backend.data_format import DataFormat, NXP_NODE_FORMAT
109
from executorch.backends.nxp.backend.edge_helper import node_has_well_defined_shape
1110
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1211
CustomDelegationOptions,
1312
is_not_qdq_node,
1413
NodeConverter,
14+
requires_channels_first_format,
1515
)
1616
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.resize_nearest_neighbor_options import (
1717
ResizeNearestNeighbor,
@@ -26,6 +26,7 @@
2626

2727

2828
# noinspection SpellCheckingInspection
29+
@requires_channels_first_format
2930
class UpsampleNearest2DConverter(NodeConverter):
3031

3132
@classmethod
@@ -55,14 +56,6 @@ def _is_supported_in_IR(
5556
parameters_mapping: dict[str, Parameter],
5657
custom_delegation_options: CustomDelegationOptions,
5758
) -> bool:
58-
59-
if node.meta.get(NXP_NODE_FORMAT, DataFormat.NONE) != DataFormat.CHANNELS_FIRST:
60-
# This should never happen.
61-
raise NotImplementedError(
62-
"NXP backend: `aten.upsample_nearest2d.vec` didn't have correctly identified data"
63-
" format. Please report this."
64-
)
65-
6659
# The conversion requires the output shape to be known and static.
6760
if not node_has_well_defined_shape(node):
6861
return False

backends/nxp/tests/generic_tests/test_node_format_inference.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
7+
68
import torch
79

810
from executorch import exir
@@ -11,12 +13,18 @@
1113
NodeFormatInference,
1214
NXP_NODE_FORMAT,
1315
)
16+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
17+
from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
1418

1519
from executorch.backends.nxp.tests.models import (
1620
Conv2dModule,
1721
MaxPool2dModule,
1822
SoftmaxModule,
1923
)
24+
from executorch.backends.nxp.tests.ops_aliases import (
25+
ExecutorchDelegateCall,
26+
MaxPool2DWithIndices,
27+
)
2028

2129

2230
def test_convolution():
@@ -77,3 +85,36 @@ def test_max_pool2d():
7785

7886
for node in epm.exported_program().graph.nodes:
7987
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]
88+
89+
90+
def test_unhandled_channels_first_node(caplog):
91+
# This test focuses on the case where some operator requires the channels first format, which is enforced in the
92+
# `NodeConverter`, but the `NodeFormatInference` fails to reflect this.
93+
# We use the `MaxPool` operator for this, and we temporarily modify the `NodeFormatInference` to trigger the issue.
94+
95+
model = MaxPool2dModule()
96+
input_shape = (1, 4, 32, 32)
97+
98+
# Temporarily "break" the NodeFormatInference.
99+
old_channels_first_ops = NodeFormatInference.ops_with_channels_first_nodes
100+
NodeFormatInference.ops_with_channels_first_nodes = {}
101+
102+
with caplog.at_level(
103+
logging.WARNING,
104+
logger="executorch.backends.nxp.backend.ir.converter.node_converter",
105+
):
106+
ep = to_quantized_edge_program(model, input_shape).exported_program()
107+
108+
# Make sure the `MaxPool` wasn't delegated.
109+
assert graph_contains_any_of_ops(ep.graph, [MaxPool2DWithIndices])
110+
assert not graph_contains_any_of_ops(ep.graph, [ExecutorchDelegateCall])
111+
112+
# Make sure the warning is printed.
113+
assert any(
114+
"`aten_max_pool2d_with_indices_default` requires channels-first format for its input and output, but the "
115+
"inferred format does not satisfy this requirement" in message
116+
for message in caplog.messages
117+
)
118+
119+
# Restore the original channels first ops configuration.
120+
NodeFormatInference.ops_with_channels_first_nodes = old_channels_first_ops

0 commit comments

Comments
 (0)