Skip to content

Commit 9d0be16

Browse files
committed
Arm backend: Add TOSA CUSTOM dialect op, visitor
- add tosa.CUSTOM fake op registration in the dialect - for backend passes to create+use tosa.custom only within partition - register a TOSA CUSTOM node visitor for serialization - needing shape for the wrapped op, adding a decorator for tosa shape Signed-off-by: Rob Elliott <Robert.Elliott@arm.com> Change-Id: I085ffe8656ffc1edf22d70c92bfc80aa2e602694
1 parent b606983 commit 9d0be16

5 files changed

Lines changed: 250 additions & 0 deletions

File tree

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
op_clamp,
2525
op_cond_if,
2626
op_cos,
27+
op_tosa_custom,
2728
op_eq,
2829
op_erf,
2930
op_exp,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa.mapping import TosaArg
16+
17+
18+
@register_node_visitor
19+
class CustomVisitor(NodeVisitor):
20+
"""Lower the TOSA CUSTOM op from the TOSA backend dialect."""
21+
22+
target = "tosa.CUSTOM.default"
23+
24+
def define_node(
25+
self,
26+
node: torch.fx.Node,
27+
tosa_graph: Any,
28+
inputs: List[TosaArg],
29+
output: TosaArg,
30+
) -> None:
31+
allowed_kwargs = {"operator_name", "domain_name", "implementation_attrs"}
32+
unexpected = set(node.kwargs.keys()) - allowed_kwargs
33+
if unexpected:
34+
raise ValueError(
35+
f"tosa.CUSTOM received unexpected kwargs: {sorted(unexpected)}"
36+
)
37+
38+
operator_name = node.kwargs.get("operator_name")
39+
domain_name = node.kwargs.get("domain_name")
40+
implementation_attrs = node.kwargs.get("implementation_attrs")
41+
42+
if operator_name is None or domain_name is None:
43+
raise ValueError(
44+
"tosa.CUSTOM requires operator_name and domain_name in kwargs"
45+
)
46+
47+
if implementation_attrs is None:
48+
impl_list = []
49+
elif isinstance(implementation_attrs, list):
50+
# NOTE: PyTorch schemas do not support a bytes type; we pass
51+
# implementation_attrs as int[] representing raw bytes.
52+
impl_list = [int(x) for x in implementation_attrs]
53+
else:
54+
raise TypeError(
55+
"implementation_attrs must be None or list[int]; "
56+
f"got {type(implementation_attrs)}"
57+
)
58+
59+
attr = ts.TosaSerializerAttribute()
60+
attr.CustomAttribute(
61+
operator_name=operator_name,
62+
domain_name=domain_name,
63+
implementation_attrs=impl_list,
64+
)
65+
66+
expanded = [TosaArg(item, self.tosa_spec) for item in inputs[0].special]
67+
input_names = [arg.name for arg in expanded]
68+
output_names = (
69+
output.multiple_output_names
70+
if getattr(output, "multiple_output_names", None)
71+
else [output.name]
72+
)
73+
if len(output_names) != 1:
74+
# TODO: Support multi-output CUSTOM ops with per-output meta/shape.
75+
raise ValueError(
76+
f"tosa.CUSTOM currently requires a single output, got {len(output_names)}"
77+
)
78+
self._serialize_operator(
79+
node,
80+
tosa_graph,
81+
ts.Op.CUSTOM,
82+
input_names,
83+
output_names,
84+
attr,
85+
)

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
77
conv2d,
88
conv3d,
9+
custom,
910
depthwise_conv2d,
1011
gather,
1112
matmul,
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Fake-op support for the generic TOSA ``CUSTOM`` dialect op.
7+
8+
The serialized TOSA ``CUSTOM`` op is intentionally generic: it carries a
9+
stable operator identity (for example ``thribrary.threee_pleee``) plus an
10+
opaque payload in ``implementation_attrs``. That is enough for serialization,
11+
but not enough for FakeTensor propagation unless we also teach the compiler how
12+
to model the output tensors of the specific wrapped op.
13+
14+
This module provides a lightweight registration mechanism for those compiler
15+
side fake implementations:
16+
17+
1. A lowering pass rewrites an op to ``exir_ops.backend.tosa.CUSTOM.default``.
18+
2. The wrapped custom op registers a thin adapter with
19+
``@register_fake_tosa("namespace::op")``.
20+
3. The generic ``CUSTOM`` fake implementation looks up that adapter by the
21+
``operator_name`` argument and invokes it with the full custom-op calling
22+
convention ``(inputs, operator_name, domain_name, implementation_attrs)``.
23+
24+
The adapter should stay thin: it should only translate from the generic TOSA
25+
CUSTOM signature back to the wrapped op's fake semantics. The real semantic
26+
logic should continue to live in the original fake implementation where
27+
possible.
28+
"""
29+
30+
import inspect
31+
from collections.abc import Callable
32+
33+
import torch
34+
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
35+
36+
from executorch.backends.arm.tosa.specification import (
37+
get_context_spec,
38+
TosaSpecification,
39+
)
40+
41+
_TOSA_CUSTOM_FAKE_IMPLS: dict[str, Callable] = {}
42+
43+
44+
def _normalize_tosa_custom_operator_name(operator_name: str) -> str:
45+
"""Normalize operator names so ``ns::op`` and ``ns.op`` map identically."""
46+
return operator_name.replace("::", ".")
47+
48+
49+
def validate_tosa_custom_fake_impl(fake_impl: object) -> Callable:
50+
"""Validate the signature expected by ``register_fake_tosa``.
51+
52+
Registered fake implementations must accept the generic TOSA CUSTOM fake
53+
calling convention:
54+
55+
``(inputs, operator_name, domain_name, implementation_attrs)``
56+
57+
and return ``list[Tensor]``.
58+
"""
59+
if not callable(fake_impl):
60+
raise TypeError(
61+
"Expected tosa.CUSTOM fake impl to be callable, "
62+
f"got {type(fake_impl)}"
63+
)
64+
65+
params = tuple(inspect.signature(fake_impl).parameters.values())
66+
positional_kinds = {
67+
inspect.Parameter.POSITIONAL_ONLY,
68+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
69+
}
70+
if len(params) != 4 or any(param.kind not in positional_kinds for param in params):
71+
raise TypeError(
72+
"tosa.CUSTOM fake impl must have signature "
73+
"(inputs, operator_name, domain_name, implementation_attrs)"
74+
)
75+
return fake_impl
76+
77+
78+
def register_fake_tosa(operator_name: str) -> Callable[[Callable], Callable]:
79+
"""Register a fake implementation for a specific wrapped TOSA custom op.
80+
81+
Args:
82+
operator_name: Stable custom operator identifier. Both ``ns::op`` and
83+
``ns.op`` spellings are accepted.
84+
85+
Returns:
86+
A decorator that registers a callable with signature
87+
``(inputs, operator_name, domain_name, implementation_attrs)`` and
88+
returning ``list[Tensor]``.
89+
90+
Example:
91+
``@register_fake_tosa("thribrary::threee_pleee")``
92+
"""
93+
normalized_name = _normalize_tosa_custom_operator_name(operator_name)
94+
95+
def decorator(fake_impl: Callable) -> Callable:
96+
validated = validate_tosa_custom_fake_impl(fake_impl)
97+
_TOSA_CUSTOM_FAKE_IMPLS[normalized_name] = validated
98+
return fake_impl
99+
100+
return decorator
101+
102+
103+
def has_fake_tosa_impl(operator_name: str) -> bool:
104+
"""Return whether a wrapped custom op has a registered fake impl."""
105+
normalized_name = _normalize_tosa_custom_operator_name(operator_name)
106+
return normalized_name in _TOSA_CUSTOM_FAKE_IMPLS
107+
108+
109+
def run_registered_fake_tosa_impl(
110+
inputs: list[torch.Tensor],
111+
operator_name: str,
112+
domain_name: str,
113+
implementation_attrs: list[int],
114+
) -> list[torch.Tensor]:
115+
"""Invoke the registered fake implementation for a wrapped custom op."""
116+
normalized_name = _normalize_tosa_custom_operator_name(operator_name)
117+
fake_impl = _TOSA_CUSTOM_FAKE_IMPLS.get(normalized_name)
118+
if fake_impl is None:
119+
raise RuntimeError(
120+
f"tosa.CUSTOM requires a registered fake impl for {normalized_name}"
121+
)
122+
outputs = fake_impl(inputs, operator_name, domain_name, implementation_attrs)
123+
if not isinstance(outputs, list):
124+
raise TypeError(
125+
"tosa.CUSTOM fake impl must return list[Tensor], "
126+
f"got {type(outputs)}"
127+
)
128+
if not outputs:
129+
raise RuntimeError("tosa.CUSTOM fake impl must return at least one output")
130+
if not all(isinstance(output, torch.Tensor) for output in outputs):
131+
raise TypeError("tosa.CUSTOM fake impl must return list[Tensor]")
132+
return outputs
133+
134+
135+
@register_fake_tosa_op(
136+
"CUSTOM(Tensor[] inputs, str operator_name, str domain_name, int[] implementation_attrs) -> Tensor[]",
137+
TosaSpecification.all_versions_and_profiles(),
138+
)
139+
def CUSTOM(
140+
inputs: list[torch.Tensor],
141+
operator_name: str,
142+
domain_name: str,
143+
implementation_attrs: list[int],
144+
) -> list[torch.Tensor]:
145+
"""Fake implementation for TOSA CUSTOM op.
146+
147+
The CUSTOM op is backend-defined. The fake implementation dispatches to a
148+
registered compiler-side fake implementation for the specific custom op.
149+
"""
150+
_ = get_context_spec() # ensure a spec context exists
151+
if not inputs:
152+
raise RuntimeError("tosa.CUSTOM requires at least one input tensor")
153+
return run_registered_fake_tosa_impl(
154+
inputs,
155+
operator_name,
156+
domain_name,
157+
implementation_attrs,
158+
)

backends/arm/tosa/mapping.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def extract_tensor_meta(meta):
139139
if type(val) is tuple:
140140
# TODO: should use first concrete representation
141141
val = val[0]
142+
if isinstance(val, list):
143+
if not val:
144+
raise ValueError("Expected node.meta['val'] list to be non-empty")
145+
# Use first concrete representation for multi-output ops.
146+
val = val[0]
142147

143148
if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
144149
raise ValueError(

0 commit comments

Comments
 (0)