Skip to content

Commit 545cdf3

Browse files
committed
Arm backend: Add real impls to all TOSA dialect ops
Additionally, - Remove special case in ComputeOpsAOT pass for such ops, since they can now be executed. - Start running the model in tests were this was previously impossible due to ops not having a real impl. Signed-off-by: Erik Lundell <erik.lundell@arm.com> Change-Id: I94ed6aa08842d8cd57e9f0fb331edc5261b8d044
1 parent e21f6d4 commit 545cdf3

22 files changed

Lines changed: 65 additions & 63 deletions

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,10 @@ def call(self, graph_module):
174174
for node in graph_module.graph.nodes:
175175
if node.op != "call_function":
176176
continue
177-
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
178-
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
179-
if (
180-
self._is_tosa_dialect_op(node.target)
181-
or self._arg_contains_symbolic_shape(node.args)
182-
or self._arg_contains_symbolic_shape(node.kwargs)
183-
):
177+
# Don't fuse ops whose explicit args/kwargs include symbolic shape values.
178+
if self._arg_contains_symbolic_shape(
179+
node.args
180+
) or self._arg_contains_symbolic_shape(node.kwargs):
184181
continue
185182

186183
input_nodes = node.all_input_nodes

backends/arm/operators/node_visitor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
250250

251251

252252
def get_node_visitor(target: str, tosa_spec: TosaSpecification):
253+
# Ensure all operator modules are imported so visitors are registered.
254+
import executorch.backends.arm.operators # noqa: F401
255+
253256
node_visitor_tuples = _node_visitor_tuples.get(tosa_spec)
254257
for target_name, node_visitor_cls in node_visitor_tuples:
255258
if target_name == target:

backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_outp
3535
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
3636
},
3737
)
38-
pipeline.pop_stage("run_method_and_compare_outputs")
3938
pipeline.run()
4039

4140
graph_module = (
@@ -62,5 +61,4 @@ def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -
6261
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
6362
],
6463
)
65-
pipeline.pop_stage("run_method_and_compare_outputs")
6664
pipeline.run()

backends/arm/test/passes/test_rewrite_conv_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,6 @@ def test_rewrite_conv_tosa_FP():
213213
pipeline = PassPipeline(
214214
module, module.get_inputs(), passes_with_exported_program=[RewriteConvPass]
215215
)
216-
# We cannot run TOSA backend dialect operators in eager mode.
217-
pipeline.pop_stage("run_method_and_compare_outputs")
218216
pipeline.run()
219217

220218

backends/arm/test/passes/test_rewrite_max_pool2d_pass.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3737

3838
class MaxPool2dWithoutStride(torch.nn.Module):
3939
def get_inputs(self) -> input_t:
40-
return (torch.rand(1, 3, 8, 8),)
40+
return (torch.rand(1, 3, 9, 9),)
4141

4242
def forward(self, x: torch.Tensor) -> torch.Tensor:
4343
return torch.nn.functional.max_pool2d(x, kernel_size=3)
4444

4545

4646
class MaxPool2dListKernel(torch.nn.Module):
4747
def get_inputs(self) -> input_t:
48-
return (torch.rand(1, 3, 8, 8),)
48+
return (torch.rand(1, 3, 8, 9),)
4949

5050
def forward(self, x: torch.Tensor) -> torch.Tensor:
5151
return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3])
@@ -56,7 +56,7 @@ def get_inputs(self) -> input_t:
5656
return (torch.rand(1, 3, 8, 8),)
5757

5858
def forward(self, x: torch.Tensor) -> torch.Tensor:
59-
return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3], stride=[])
59+
return torch.nn.functional.max_pool2d(x, kernel_size=[2, 2], stride=[])
6060

6161

6262
class MaxPool2dDynamic(torch.nn.Module):
@@ -94,9 +94,6 @@ def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None:
9494
},
9595
pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass],
9696
)
97-
pipeline.pop_stage(
98-
"run_method_and_compare_outputs"
99-
) # Cannnot run aten graph with tosa dialect ops
10097
pipeline.run()
10198

10299

@@ -131,11 +128,10 @@ def test_rewrite_max_pool2d_tosa_empty_stride_uses_kernel_size() -> None:
131128
},
132129
pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass],
133130
)
134-
pipeline.pop_stage("run_method_and_compare_outputs")
135131
pipeline.run()
136132

137133
tosa_node = _get_tosa_max_pool2d_node(pipeline)
138-
assert tosa_node.args[2] == [2, 3]
134+
assert tosa_node.args[2] == [2, 2]
139135

140136

141137
def test_rewrite_max_pool2d_tosa_dynamic_shape() -> None:

backends/arm/tosa/dialect/ops/avg_pool2d_adaptive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
compute_avg_pool2d_output_shape,
1212
validate_avg_pool2d_dtype,
1313
)
14-
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
14+
from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op
1515
from executorch.backends.arm.tosa.specification import (
1616
get_context_shape_env,
1717
get_context_spec,
@@ -36,7 +36,7 @@ def _is_directly_representable(
3636
return remainder in (0, 1)
3737

3838

39-
@register_fake_tosa_op(
39+
@register_tosa_op(
4040
"AVG_POOL2D_ADAPTIVE(Tensor input, Tensor input_zp, Tensor output_zp, SymInt[2] kernel, SymInt[2] stride, SymInt[4] pad, ScalarType acc_type) -> Tensor",
4141
TosaSpecification.all_profiles_for_version("1.1"),
4242
)

backends/arm/tosa/dialect/ops/conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import torch
1010
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
11-
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
11+
from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op
1212
from executorch.backends.arm.tosa.specification import (
1313
get_context_spec,
1414
TosaSpecification,
@@ -77,7 +77,7 @@ def validate_conv2d_args_dtypes(
7777
return output_dtype
7878

7979

80-
@register_fake_tosa_op(
80+
@register_tosa_op(
8181
"CONV2D(Tensor input, "
8282
"Tensor weight, "
8383
"Tensor bias, "

backends/arm/tosa/dialect/ops/conv3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
1111
from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes
12-
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
12+
from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op
1313
from executorch.backends.arm.tosa.specification import (
1414
get_context_spec,
1515
TosaSpecification,
@@ -30,7 +30,7 @@ def validate_conv3d_args_dtypes(
3030
return validate_conv2d_args_dtypes(tosa_spec, x, weight, bias, op="CONV3D")
3131

3232

33-
@register_fake_tosa_op(
33+
@register_tosa_op(
3434
"CONV3D(Tensor input, "
3535
"Tensor weight, "
3636
"Tensor bias, "

backends/arm/tosa/dialect/ops/custom.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from collections.abc import Callable
3232

3333
import torch
34-
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
34+
from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op
3535

3636
from executorch.backends.arm.tosa.specification import (
3737
get_context_spec,
@@ -132,7 +132,7 @@ def run_registered_fake_tosa_impl(
132132
return outputs
133133

134134

135-
@register_fake_tosa_op(
135+
@register_tosa_op(
136136
"CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]",
137137
TosaSpecification.all_versions_and_profiles(),
138138
)

backends/arm/tosa/dialect/ops/depthwise_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77

88
import torch
99
from executorch.backends.arm.tosa.dialect.ops.conv2d import validate_conv2d_args_dtypes
10-
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
10+
from executorch.backends.arm.tosa.dialect.ops_registration import register_tosa_op
1111

1212
from executorch.backends.arm.tosa.specification import (
1313
get_context_spec,
1414
TosaSpecification,
1515
)
1616

1717

18-
@register_fake_tosa_op(
18+
@register_tosa_op(
1919
"DEPTHWISE_CONV2D(Tensor input, "
2020
"Tensor weight, "
2121
"Tensor bias, "

0 commit comments

Comments
 (0)