Skip to content

Commit ea75cd5

Browse files
Update
[ghstack-poisoned]
2 parents 09b19f5 + 1f65bd4 commit ea75cd5

35 files changed

Lines changed: 1219 additions & 703 deletions

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
6565
# Materialize output zero-point as a scalar tensor
6666
output_zp = super().call_scalar(out_zp_val, meta)
6767

68-
# Determine accumulator dtype for AVG_POOL2D: INT32 for integer inputs, FP32 otherwise
68+
# Determine accumulator dtype for AVG_POOL2D: INT32 for INT8/INT16 inputs, FP16 for FP8 inputs, FP32 otherwise.
6969
if x.data.dtype in (torch.int8, torch.int16):
7070
acc_type = torch.int32
71+
elif x.data.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
72+
acc_type = torch.float16
7173
else:
7274
acc_type = torch.float32
7375

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
import itertools
8-
from typing import Any, Set, Type
8+
from typing import Any, cast, Set, Type
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
@@ -39,6 +39,7 @@
3939
from executorch.exir.dialects._ops import ops as exir_ops
4040
from executorch.exir.pass_base import ExportPass, PassResult
4141

42+
from torch._subclasses.fake_tensor import FakeTensor
4243
from torch.export.graph_signature import InputKind
4344

4445

@@ -350,6 +351,68 @@ def _has_int32_rescale_user(self, node: torch.fx.Node) -> bool:
350351
return True
351352
return False
352353

354+
def _insert_output_conversion(
355+
self,
356+
graph_module: torch.fx.GraphModule,
357+
node: torch.fx.Node,
358+
tosa_op: torch.fx.Node,
359+
input_fake_tensor: torch.Tensor,
360+
tosa_node_fake_tensor: torch.Tensor,
361+
) -> tuple[torch.fx.Node, FakeTensor]:
362+
node_replacement: torch.fx.Node = tosa_op
363+
node_replacement_fake_tensor = tosa_node_fake_tensor
364+
if (
365+
tosa_node_fake_tensor.dtype == torch.int32
366+
and input_fake_tensor.dtype == torch.int8
367+
):
368+
node_replacement, node_replacement_fake_tensor = self.insert_output_rescale(
369+
graph_module, node, tosa_op, tosa_node_fake_tensor
370+
)
371+
elif (
372+
tosa_node_fake_tensor.dtype == torch.int32
373+
and input_fake_tensor.dtype == torch.int16
374+
):
375+
# Explicit layout paths require a post-conv permute, which does
376+
# not support INT48. Always rescale before post-permute.
377+
if self._has_int32_rescale_user(node):
378+
node_replacement, node_replacement_fake_tensor = (
379+
self.insert_identity_int32_rescale(
380+
graph_module, node, tosa_op, tosa_node_fake_tensor
381+
)
382+
)
383+
else:
384+
node_replacement, node_replacement_fake_tensor = (
385+
self.insert_output_rescale(
386+
graph_module, node, tosa_op, tosa_node_fake_tensor
387+
)
388+
)
389+
390+
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
391+
elif (
392+
tosa_node_fake_tensor.dtype == torch.float16
393+
and input_fake_tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
394+
):
395+
node_output_fake_tensor = get_first_fake_tensor(node)
396+
# TOSA FP8 conv widens the output. Cast back to the exported
397+
# graph dtype before the post-layout permute.
398+
node_replacement_fake_tensor = (
399+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default(
400+
tosa_node_fake_tensor,
401+
dtype=node_output_fake_tensor.dtype,
402+
)
403+
)
404+
with graph_module.graph.inserting_after(tosa_op):
405+
node_replacement = create_node(
406+
graph=graph_module.graph,
407+
op_target=exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
408+
args=(tosa_op,),
409+
kwargs={"dtype": node_output_fake_tensor.dtype},
410+
from_node=tosa_op,
411+
)
412+
node_replacement.meta["val"] = node_replacement_fake_tensor
413+
414+
return node_replacement, cast(FakeTensor, node_replacement_fake_tensor)
415+
353416
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
354417
modified = False
355418
for node in graph_module.graph.nodes:
@@ -561,37 +624,15 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
561624
)
562625
tosa_op.meta["val"] = tosa_node_fake_tensor
563626

564-
node_replacement: torch.fx.Node = tosa_op
565-
node_replacement_fake_tensor = tosa_node_fake_tensor
566-
if (
567-
tosa_node_fake_tensor.dtype == torch.int32
568-
and input_fake_tensor.dtype == torch.int8
569-
):
570-
output_rescale, output_rescale_fake = self.insert_output_rescale(
571-
graph_module, node, tosa_op, tosa_node_fake_tensor
627+
node_replacement, node_replacement_fake_tensor = (
628+
self._insert_output_conversion(
629+
graph_module,
630+
node,
631+
tosa_op,
632+
input_fake_tensor,
633+
tosa_node_fake_tensor,
572634
)
573-
node_replacement = output_rescale
574-
node_replacement_fake_tensor = output_rescale_fake
575-
elif (
576-
tosa_node_fake_tensor.dtype == torch.int32
577-
and input_fake_tensor.dtype == torch.int16
578-
):
579-
# Explicit layout paths require a post-conv permute, which does
580-
# not support INT48. Always rescale before post-permute.
581-
if self._has_int32_rescale_user(node):
582-
output_rescale, output_rescale_fake = (
583-
self.insert_identity_int32_rescale(
584-
graph_module, node, tosa_op, tosa_node_fake_tensor
585-
)
586-
)
587-
else:
588-
output_rescale, output_rescale_fake = self.insert_output_rescale(
589-
graph_module, node, tosa_op, tosa_node_fake_tensor
590-
)
591-
node_replacement = output_rescale
592-
node_replacement_fake_tensor = output_rescale_fake
593-
594-
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
635+
)
595636

596637
if post_permute_dims is None:
597638
raise RuntimeError("Expected post permute dims for explicit layout")

backends/arm/_passes/rewrite_matmul.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,21 @@
2121

2222

2323
class RewriteMatmulPass(ArmPass):
24-
"""Rewrites aten.bmm to tosa.MATMUL and inserts a tosa.RESCALE op if
24+
"""Rewrites aten.bmm to tosa.MATMUL and inserts a tosa.RESCALE or cast op if
2525
needed.
2626
"""
2727

2828
_passes_required_after: Set[Type[ExportPass]] = set()
2929

30+
# TOSA MATMUL widens these floating-point input types, so outputs may need
31+
# casting back to preserve the original PyTorch node semantics.
32+
_WIDENING_INPUT_DTYPES = (
33+
torch.float16,
34+
torch.bfloat16,
35+
torch.float8_e4m3fn,
36+
torch.float8_e5m2,
37+
)
38+
3039
def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
3140
input_qparams = get_input_qparams(node)
3241
output_qparams = get_output_qparams(node)[0]
@@ -94,17 +103,18 @@ def call(self, graph_module):
94103
TosaSpecialDtype.INT48
95104
)
96105
elif (
97-
x1_fake_tensor.dtype in [torch.float16, torch.bfloat16]
98-
and x2_fake_tensor.dtype in [torch.float16, torch.bfloat16]
99-
and output_fake_tensor.dtype not in [torch.float16, torch.bfloat16]
106+
x1_fake_tensor.dtype in self._WIDENING_INPUT_DTYPES
107+
and x2_fake_tensor.dtype in self._WIDENING_INPUT_DTYPES
108+
and output_fake_tensor.dtype not in self._WIDENING_INPUT_DTYPES
100109
):
101-
# A TOSA BF16/FP16 MATMUL outputs FP32 whereas pytorch outputs BF16/FP16.
102-
# Cast back to BF16/FP16 to get matching semantics.
110+
# TOSA BF16/FP16/FP8 MATMUL outputs FP32, while the original
111+
# exported node outputs BF16/FP16/FP8. Cast back to preserve
112+
# the exported graph dtype.
103113
with graph_module.graph.inserting_after(tosa_matmul_node):
104114
cast_node = create_node(
105115
graph_module.graph,
106116
op_target=exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
107-
kwargs={"dtype": x1_fake_tensor.dtype},
117+
kwargs={"dtype": node_output_fake_tensor.dtype},
108118
from_node=tosa_matmul_node,
109119
)
110120
tosa_matmul_node.replace_all_uses_with(cast_node)

backends/arm/operators/op_tosa_avg_pool2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def define_node(
4343

4444
if self.tosa_spec.support_extension("int16"):
4545
supported.append(ts.DType.INT16)
46+
if self.tosa_spec.support_extension("fp8e4m3"):
47+
supported.append(ts.DType.FP8E4M3)
48+
if self.tosa_spec.support_extension("fp8e5m2"):
49+
supported.append(ts.DType.FP8E5M2)
4650

4751
validate_valid_dtype(self.target, [input, output], supported, self.tosa_spec)
4852

backends/arm/operators/op_tosa_conv2d.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def define_node(
6767
)
6868
if self.tosa_spec.support_extension("bf16"):
6969
valid_input_dtypes.append(ts.DType.BF16)
70+
if self.tosa_spec.support_extension("fp8e4m3"):
71+
valid_input_dtypes.append(ts.DType.FP8E4M3)
72+
if self.tosa_spec.support_extension("fp8e5m2"):
73+
valid_input_dtypes.append(ts.DType.FP8E5M2)
7074

7175
validate_valid_dtype(
7276
self.target,
@@ -82,8 +86,13 @@ def define_node(
8286

8387
conv2d_output_name = output.name
8488
acc_type = output.dtype
85-
if output.dtype in [ts.DType.BF16, ts.DType.FP16]:
86-
# Accumulate BF16, FP16 inputs in FP32 for better precision.
89+
if input.dtype in [ts.DType.FP8E4M3, ts.DType.FP8E5M2]:
90+
acc_type = ts.DType.FP16
91+
elif output.dtype in [
92+
ts.DType.BF16,
93+
ts.DType.FP16,
94+
]:
95+
# Accumulate in FP32 for better precision when the output is BF16 or FP16.
8796
acc_type = ts.DType.FP32
8897

8998
input_zp_name, weight_zp_name = add_input_weight_zp_consts(

backends/arm/operators/op_tosa_matmul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def define_node(
5454
ts.DType.FP16,
5555
ts.DType.FP32,
5656
ts.DType.BF16,
57+
ts.DType.FP8E4M3,
58+
ts.DType.FP8E5M2,
5759
],
5860
self.tosa_spec,
5961
)

backends/arm/operators/op_tosa_max_pool2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def define_node(
4242
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16]
4343
if self.tosa_spec.support_extension("int16"):
4444
supported_dtypes.append(ts.DType.INT16)
45+
if self.tosa_spec.support_extension("fp8e4m3"):
46+
supported_dtypes.append(ts.DType.FP8E4M3)
47+
if self.tosa_spec.support_extension("fp8e5m2"):
48+
supported_dtypes.append(ts.DType.FP8E5M2)
4549
validate_valid_dtype(
4650
self.target,
4751
[input_tensor, output],

backends/arm/operators/op_tosa_transpose_conv2d.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,24 @@ def define_node(
7373
validate_valid_dtype(
7474
self.target, [inputs[2]], [ts.DType.BF16], self.tosa_spec
7575
)
76+
if self.tosa_spec.support_extension("fp8e4m3"):
77+
valid_input_dtypes.append(ts.DType.FP8E4M3)
78+
if inputs[0].dtype == ts.DType.FP8E4M3:
79+
validate_valid_dtype(
80+
self.target, [inputs[1]], [ts.DType.FP8E4M3], self.tosa_spec
81+
)
82+
validate_valid_dtype(
83+
self.target, [inputs[2]], [ts.DType.FP8E4M3], self.tosa_spec
84+
)
85+
if self.tosa_spec.support_extension("fp8e5m2"):
86+
valid_input_dtypes.append(ts.DType.FP8E5M2)
87+
if inputs[0].dtype == ts.DType.FP8E5M2:
88+
validate_valid_dtype(
89+
self.target, [inputs[1]], [ts.DType.FP8E5M2], self.tosa_spec
90+
)
91+
validate_valid_dtype(
92+
self.target, [inputs[2]], [ts.DType.FP8E5M2], self.tosa_spec
93+
)
7694

7795
validate_valid_dtype(
7896
self.target,

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,7 @@ def main() -> None: # noqa: C901
10341034
args.calibration_data, example_inputs
10351035
)
10361036
model = original_model.eval()
1037+
model.requires_grad_(False)
10371038

10381039
# Export under the assumption that we quantize; the exported form also works
10391040
# in to_edge if we don't quantize
@@ -1115,8 +1116,6 @@ def main() -> None: # noqa: C901
11151116

11161117
dump_delegation_info(edge, args.intermediates)
11171118

1118-
edge_program_manager_copy = copy.deepcopy(edge)
1119-
11201119
try:
11211120
exec_prog = edge.to_executorch(
11221121
config=ExecutorchBackendConfig(extract_delegate_segments=False)
@@ -1175,6 +1174,7 @@ def main() -> None: # noqa: C901
11751174
if args.bundleio or args.etrecord:
11761175
etrecord_file_name = os.path.splitext(output_file_name)[0] + "_etrecord.bin"
11771176
try:
1177+
edge_program_manager_copy = copy.deepcopy(edge)
11781178
generate_etrecord(etrecord_file_name, edge_program_manager_copy, exec_prog)
11791179
print(f"ETRecord saved as {etrecord_file_name}")
11801180
except Exception as e:

backends/arm/scripts/build_executor_runner.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ help() {
5454
echo " --et_build_root=<FOLDER> Build output root folder to use, defaults to ${et_build_root}"
5555
echo " --ethosu_tools_dir=<FOLDER> Path to your Ethos-U tools dir if you not using default: ${ethosu_tools_dir}"
5656
echo " --toolchain=<TOOLCHAIN> Toolchain can be specified (arm-none-eabi-gcc, arm-zephyr-eabi-gcc). Default: ${toolchain}"
57-
echo " --select_ops_list=<OPS> Comma separated list of portable (non delagated) kernels to include Default: ${select_ops_list}"
57+
echo " --select_ops_list=<OPS> Comma separated list of portable (non-delegated) kernels to include Default: ${select_ops_list}"
5858
echo " NOTE: This is used when select_ops_model is not possible to use, e.g. for semihosting or bundleio."
5959
echo " See https://docs.pytorch.org/executorch/stable/kernel-library-selective-build.html for more information."
6060
exit 0

0 commit comments

Comments
 (0)