Skip to content

Commit 366de24

Browse files
Arm backend: Add FP8 support for gather/scatter-based composite ops
Support FP8 tensors for the following composite ops:
- TOSA GATHER: embedding, index_select, index.Tensor, unfold_copy
- TOSA SCATTER: index_put, index_copy, slice_scatter

Run all FP8 tests through the TOSA reference model. For ops without eager CPU FP8 support, only execute the TOSA reference model; otherwise keep the default output comparison against eager.

Change-Id: I3d81cd6dd426f16b5f2db8937228cad12184b6a6
Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent e88fd04 commit 366de24

10 files changed

Lines changed: 327 additions & 14 deletions

backends/arm/operator_support/index_select_support.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,14 @@ def is_node_tosa_supported(
7777
f"{node.target}: dtype {values_dtype} requires INT profile.",
7878
)
7979
return False
80-
# fp16/fp32/bf16: either FP profile, or INT profile (via quantization)
81-
elif values_dtype in (torch.float16, torch.float32, torch.bfloat16):
80+
# fp16/fp32/bf16/fp8: either FP profile, or INT profile (via quantization)
81+
elif values_dtype in (
82+
torch.float16,
83+
torch.float32,
84+
torch.bfloat16,
85+
torch.float8_e4m3fn,
86+
torch.float8_e5m2,
87+
):
8288
if values_dtype == torch.bfloat16 and not tosa_spec.support_extension(
8389
"bf16"
8490
):
@@ -87,6 +93,22 @@ def is_node_tosa_supported(
8793
f"{node.target}: dtype {values_dtype} requires bf16 extension.",
8894
)
8995
return False
96+
if values_dtype == torch.float8_e4m3fn and not tosa_spec.support_extension(
97+
"fp8e4m3"
98+
):
99+
self.reporter.report_reject(
100+
node,
101+
f"{node.target}: dtype {values_dtype} requires fp8e4m3 extension.",
102+
)
103+
return False
104+
if values_dtype == torch.float8_e5m2 and not tosa_spec.support_extension(
105+
"fp8e5m2"
106+
):
107+
self.reporter.report_reject(
108+
node,
109+
f"{node.target}: dtype {values_dtype} requires fp8e5m2 extension.",
110+
)
111+
return False
90112
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
91113
self.reporter.report_reject(
92114
node,
@@ -98,7 +120,8 @@ def is_node_tosa_supported(
98120
self.reporter.report_reject(
99121
node,
100122
f"{node.target}: unsupported values dtype {values_dtype}; "
101-
"expected bool/int8/int16/int32/float16/bfloat16/float32.",
123+
"expected bool/int8/int16/int32/float16/bfloat16/float32/"
124+
"float8_e4m3fn/float8_e5m2.",
102125
)
103126
return False
104127

backends/arm/operator_support/index_tensor_support.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,13 @@ def is_node_tosa_supported(
144144
f"{node.target}: dtype {values_dtype} requires INT profile.",
145145
)
146146
return False
147-
elif values_dtype in (torch.float16, torch.float32, torch.bfloat16):
147+
elif values_dtype in (
148+
torch.float16,
149+
torch.float32,
150+
torch.bfloat16,
151+
torch.float8_e4m3fn,
152+
torch.float8_e5m2,
153+
):
148154
if values_dtype == torch.bfloat16 and not tosa_spec.support_extension(
149155
"bf16"
150156
):
@@ -153,6 +159,22 @@ def is_node_tosa_supported(
153159
f"{node.target}: dtype {values_dtype} requires bf16 extension.",
154160
)
155161
return False
162+
if values_dtype == torch.float8_e4m3fn and not tosa_spec.support_extension(
163+
"fp8e4m3"
164+
):
165+
self.reporter.report_reject(
166+
node,
167+
f"{node.target}: dtype {values_dtype} requires fp8e4m3 extension.",
168+
)
169+
return False
170+
if values_dtype == torch.float8_e5m2 and not tosa_spec.support_extension(
171+
"fp8e5m2"
172+
):
173+
self.reporter.report_reject(
174+
node,
175+
f"{node.target}: dtype {values_dtype} requires fp8e5m2 extension.",
176+
)
177+
return False
156178
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
157179
self.reporter.report_reject(
158180
node,
@@ -164,7 +186,7 @@ def is_node_tosa_supported(
164186
self.reporter.report_reject(
165187
node,
166188
f"{node.target}: unsupported values dtype {values_dtype}; "
167-
"expected bool/int8/int16/int32/float16/bfloat16/float32.",
189+
"expected bool/int8/int16/int32/float16/bfloat16/float32/float8_e4m3fn/float8_e5m2.",
168190
)
169191
return False
170192

backends/arm/operator_support/unfold_copy_support.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,14 @@ def is_node_tosa_supported(
8484
f"{node.target}: dtype {values_dtype} requires INT profile.",
8585
)
8686
return False
87-
# fp16/fp32/bf16: either FP profile, or INT profile (via quantization)
88-
elif values_dtype in (torch.float16, torch.float32, torch.bfloat16):
87+
# fp16/fp32/bf16/fp8: either FP profile, or INT profile (via quantization)
88+
elif values_dtype in (
89+
torch.float16,
90+
torch.float32,
91+
torch.bfloat16,
92+
torch.float8_e4m3fn,
93+
torch.float8_e5m2,
94+
):
8995
if values_dtype == torch.bfloat16 and not tosa_spec.support_extension(
9096
"bf16"
9197
):
@@ -94,6 +100,22 @@ def is_node_tosa_supported(
94100
f"{node.target}: dtype {values_dtype} requires bf16 extension.",
95101
)
96102
return False
103+
if values_dtype == torch.float8_e4m3fn and not tosa_spec.support_extension(
104+
"fp8e4m3"
105+
):
106+
self.reporter.report_reject(
107+
node,
108+
f"{node.target}: dtype {values_dtype} requires fp8e4m3 extension.",
109+
)
110+
return False
111+
if values_dtype == torch.float8_e5m2 and not tosa_spec.support_extension(
112+
"fp8e5m2"
113+
):
114+
self.reporter.report_reject(
115+
node,
116+
f"{node.target}: dtype {values_dtype} requires fp8e5m2 extension.",
117+
)
118+
return False
97119
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
98120
self.reporter.report_reject(
99121
node,
@@ -105,7 +127,8 @@ def is_node_tosa_supported(
105127
self.reporter.report_reject(
106128
node,
107129
f"{node.target}: unsupported values dtype {values_dtype}; "
108-
"expected bool/int8/int16/int32/float16/bfloat16/float32.",
130+
"expected bool/int8/int16/int32/float16/bfloat16/float32/"
131+
"float8_e4m3fn/float8_e5m2.",
109132
)
110133
return False
111134

backends/arm/test/ops/test_embedding.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -63,6 +63,22 @@ def forward(self, weights: torch.Tensor, indices: torch.Tensor):
6363
torch.randint(low=0, high=10, size=(4, 3, 2, 5), dtype=torch.int64),
6464
),
6565
}
66+
test_input_fp8: dict[str, tuple[input_params, str]] = {
67+
"test_fp8e4m3_int32_indices": (
68+
(
69+
torch.randn(10, 3, dtype=torch.float32).to(torch.float8_e4m3fn),
70+
torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32),
71+
),
72+
"fp8e4m3",
73+
),
74+
"test_fp8e5m2_int64_indices": (
75+
(
76+
torch.randn(11, 5, dtype=torch.float32).to(torch.float8_e5m2),
77+
torch.randint(low=0, high=10, size=(4, 3), dtype=torch.int64),
78+
),
79+
"fp8e5m2",
80+
),
81+
}
6682

6783

6884
@pytest.mark.skip(reason="MLETORCH-1274 Improve data type checks during partitioning")
@@ -74,7 +90,6 @@ def test_embedding_tosa_FP(test_input: input_params):
7490
test_input,
7591
op.aten_op,
7692
op.exir_op,
77-
use_to_edge_transform_and_lower=True,
7893
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
7994
)
8095
pipeline.run()
@@ -88,22 +103,36 @@ def test_embedding_tosa_INT(test_input: input_params):
88103
test_input,
89104
op.aten_op,
90105
op.exir_op,
91-
use_to_edge_transform_and_lower=True,
92106
)
93107
pipeline.pop_stage("check.aten")
94108
pipeline.pop_stage("check_count.exir")
95109

96110
pipeline.run()
97111

98112

113+
@common.parametrize("test_input", test_input_fp8)
114+
def test_embedding_tosa_FP_fp8(test_input):
115+
inputs, tosa_extension = test_input
116+
op = Embedding()
117+
pipeline = TosaPipelineFP[input_params](
118+
op,
119+
inputs,
120+
op.aten_op,
121+
op.exir_op,
122+
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
123+
compare_tosa_ref_model_outputs=False,
124+
tosa_extensions=[tosa_extension],
125+
)
126+
pipeline.run()
127+
128+
99129
def test_embedding_tosa_INT_expand():
100130
op = ExpandEmbedding()
101131
pipeline = TosaPipelineINT(
102132
op,
103133
ExpandEmbedding.example_inputs,
104134
ExpandEmbedding.aten_op,
105135
ExpandEmbedding.exir_op,
106-
use_to_edge_transform_and_lower=True,
107136
)
108137
pipeline.pop_stage("check.aten")
109138
pipeline.pop_stage("check_count.exir")
@@ -121,7 +150,6 @@ def test_embedding_vgf_no_quant(test_input: input_params):
121150
test_input,
122151
op.aten_op,
123152
op.exir_op,
124-
use_to_edge_transform_and_lower=True,
125153
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
126154
quantize=False,
127155
)
@@ -137,7 +165,6 @@ def test_embedding_vgf_quant(test_input: input_params):
137165
test_input,
138166
op.aten_op,
139167
op.exir_op,
140-
use_to_edge_transform_and_lower=True,
141168
quantize=True,
142169
)
143170
pipeline.pop_stage("check.aten")

backends/arm/test/ops/test_index_copy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ class IndexCopyModule(torch.nn.Module):
7474
("in_place", True),
7575
)
7676
}
77+
test_data_fp8 = {
78+
"rand_single_index_fp8e4m3_out_of_place": (
79+
lambda: (
80+
(
81+
0,
82+
torch.rand((4, 5), dtype=torch.float32).to(torch.float8_e4m3fn),
83+
torch.LongTensor([0]),
84+
torch.zeros((1, 5), dtype=torch.float32).to(torch.float8_e4m3fn),
85+
),
86+
False,
87+
"fp8e4m3",
88+
)
89+
),
90+
"rand_3d_dim_1_fp8e5m2_in_place": (
91+
lambda: (
92+
(
93+
1,
94+
torch.rand((4, 2, 3), dtype=torch.float32).to(torch.float8_e5m2),
95+
torch.LongTensor([0, 1]),
96+
torch.ones((4, 2, 3), dtype=torch.float32).to(torch.float8_e5m2),
97+
),
98+
True,
99+
"fp8e5m2",
100+
)
101+
),
102+
}
77103

78104
aten_ops = {
79105
False: ["torch.ops.aten.index_put.default"],
@@ -112,6 +138,21 @@ def test_index_copy_tosa_FP(test_data):
112138
pipeline.run()
113139

114140

141+
@common.parametrize("test_data", IndexCopyModule.test_data_fp8)
142+
def test_index_copy_tosa_FP_fp8(test_data):
143+
inputs, inplace, tosa_extension = test_data()
144+
module = IndexCopyModule(inplace=inplace)
145+
pipeline = TosaPipelineFP(
146+
module=module,
147+
test_data=inputs,
148+
aten_op=[],
149+
compare_tosa_ref_model_outputs=False,
150+
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
151+
tosa_extensions=[tosa_extension],
152+
)
153+
pipeline.run()
154+
155+
115156
@common.parametrize("test_data", IndexCopyModule.test_data)
116157
def test_index_copy_tosa_INT(test_data):
117158
inputs, inplace = test_data()

backends/arm/test/ops/test_index_put.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,29 @@
333333
0,
334334
),
335335
}
336+
test_data_suite_fp8 = {
337+
"rank2_fp8e4m3": (
338+
lambda: (
339+
torch.rand((4, 5), dtype=torch.float32).to(torch.float8_e4m3fn),
340+
(torch.tensor([0, 2], dtype=torch.int32),),
341+
torch.rand((2, 5), dtype=torch.float32).to(torch.float8_e4m3fn),
342+
False,
343+
),
344+
"fp8e4m3",
345+
),
346+
"rank3_fp8e5m2": (
347+
lambda: (
348+
torch.rand((3, 4, 2), dtype=torch.float32).to(torch.float8_e5m2),
349+
(
350+
torch.tensor([0, 2], dtype=torch.int32),
351+
torch.tensor([1, 3], dtype=torch.int32),
352+
),
353+
torch.rand((2, 2), dtype=torch.float32).to(torch.float8_e5m2),
354+
False,
355+
),
356+
"fp8e5m2",
357+
),
358+
}
336359

337360

338361
class IndexPut(torch.nn.Module):
@@ -375,6 +398,19 @@ def test_index_put_tosa_FP(test_module: input_t):
375398
pipeline.run()
376399

377400

401+
@common.parametrize("test_module", test_data_suite_fp8)
402+
def test_index_put_tosa_FP_fp8(test_module):
403+
test_data, tosa_extension = test_module
404+
pipeline = TosaPipelineFP(
405+
IndexPut(),
406+
test_data(),
407+
aten_op=IndexPut.aten_op,
408+
exir_op=IndexPut.exir_op,
409+
tosa_extensions=[tosa_extension],
410+
)
411+
pipeline.run()
412+
413+
378414
@common.parametrize("test_module", test_data_suite_fp | test_data_int, xfails=xfails)
379415
def test_index_put_tosa_INT(test_module: input_t):
380416
pipeline = TosaPipelineINT[input_t](

backends/arm/test/ops/test_index_select.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,26 @@ def forward(self, input_: torch.Tensor, dim: int, index_: torch.Tensor):
8181
torch.tensor([1, 0], dtype=torch.int32), # [W=2]
8282
),
8383
}
84+
test_data_fp8: dict[str, input_params] = {
85+
# Rank-3: [N, K, C] -> index_select dim=1 => [N, W, C]
86+
"test_fp8e4m3_rank3_dim1": (
87+
torch.randn(2, 4, 3, dtype=torch.float32).to(
88+
torch.float8_e4m3fn
89+
), # [N=2, K=4, C=3]
90+
1,
91+
torch.tensor([1, 3], dtype=torch.int32), # [W=2]
92+
"fp8e4m3",
93+
),
94+
# Rank-4: [A, B, K, C] -> index_select dim=2 => [A, B, W, C]
95+
"test_fp8e5m2_rank4_dim2": (
96+
torch.randn(2, 3, 4, 5, dtype=torch.float32).to(
97+
torch.float8_e5m2
98+
), # [A=2, B=3, K=4, C=5]
99+
2,
100+
torch.tensor([3, 1], dtype=torch.int32), # [W=2]
101+
"fp8e5m2",
102+
),
103+
}
84104

85105
# ---- INT profile: integer inputs + bool ----
86106
test_data_int: dict[str, input_params] = {
@@ -136,6 +156,20 @@ def test_index_select_tosa_FP_bf16(test_data: input_params):
136156
pipeline.run()
137157

138158

159+
@common.parametrize("test_data", test_data_fp8)
160+
def test_index_select_tosa_FP_fp8(test_data):
161+
input_, dim, index_, tosa_extension = test_data
162+
pipeline = TosaPipelineFP[input_params](
163+
IndexSelect(),
164+
(input_, dim, index_),
165+
aten_op=IndexSelect.aten_op,
166+
exir_op=IndexSelect.exir_op,
167+
compare_tosa_ref_model_outputs=False,
168+
tosa_extensions=[tosa_extension],
169+
)
170+
pipeline.run()
171+
172+
139173
@common.parametrize("test_data", test_data_int | test_data_fp)
140174
def test_index_select_tosa_INT(test_data: input_params):
141175
# INT profile runs quantized, so we test both int inputs and float inputs here.

0 commit comments

Comments
 (0)