Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions backends/arm/test/misc/test_tosa_dialect_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import pytest
import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
from executorch.backends.arm.tosa.dialect.ops_registration import (
get_registered_tosa_ops,
)
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode


def _to_fake(mode: FakeTensorMode, *values):
return [
mode.from_tensor(value) if isinstance(value, torch.Tensor) else value
for value in values
]


@pytest.mark.parametrize(
("op_name", "spec", "input_tensor", "args", "kwargs"),
[
pytest.param(
"CLAMP",
"TOSA-1.1+INT",
torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8),
(-3, 3),
{},
id="CLAMP",
),
pytest.param(
"ERF",
"TOSA-1.1+FP",
torch.randn((2, 3, 4), dtype=torch.float32),
(),
{},
id="ERF",
),
pytest.param(
"SIGMOID",
"TOSA-1.1+FP",
torch.randn((2, 3, 4), dtype=torch.float32),
(),
{},
id="SIGMOID",
),
pytest.param(
"TANH",
"TOSA-1.1+FP",
torch.randn((2, 3, 4), dtype=torch.float32),
(),
{},
id="TANH",
),
],
)
def test_tosa_activation_ops(
op_name: str,
spec: str,
input_tensor: torch.Tensor,
args: tuple[object, ...],
kwargs: dict[str, object],
) -> None:
with TosaLoweringContext(
TosaSpecification.create_from_string(spec)
), FakeTensorMode() as mode:
output = getattr(exir_ops.backend.tosa, op_name).default(
*_to_fake(mode, input_tensor, *args),
**kwargs,
)

assert output.dtype == input_tensor.dtype
assert tuple(output.shape) == tuple(input_tensor.shape)


@pytest.mark.parametrize(
("op", "spec", "expected"),
[
pytest.param(
exir_ops.backend.tosa.ERF.default, "TOSA-1.1+INT", False, id="erf_int"
),
pytest.param(
exir_ops.backend.tosa.SIGMOID.default,
"TOSA-1.1+INT",
False,
id="sigmoid_int",
),
pytest.param(
exir_ops.backend.tosa.TANH.default, "TOSA-1.1+INT", False, id="tanh_int"
),
pytest.param(
exir_ops.backend.tosa.ERF.default, "TOSA-1.1+FP", True, id="erf_fp"
),
pytest.param(
exir_ops.backend.tosa.SIGMOID.default, "TOSA-1.1+FP", True, id="sigmoid_fp"
),
pytest.param(
exir_ops.backend.tosa.TANH.default, "TOSA-1.1+FP", True, id="tanh_fp"
),
],
)
def test_tosa_transcendentals_registered_for_fp_profile_only(
op,
spec: str,
expected: bool,
) -> None:
with TosaLoweringContext(TosaSpecification.create_from_string(spec)):
registered_ops = get_registered_tosa_ops()

assert (op in registered_ops) is expected


@pytest.mark.parametrize(
("op_name", "input_tensor"),
[
pytest.param(
"ERF",
torch.randn((2, 3, 4), dtype=torch.bfloat16),
id="ERF",
),
pytest.param(
"SIGMOID",
torch.randn((2, 3, 4), dtype=torch.bfloat16),
id="SIGMOID",
),
pytest.param(
"TANH",
torch.randn((2, 3, 4), dtype=torch.bfloat16),
id="TANH",
),
],
)
def test_tosa_transcendentals_accept_bfloat16_with_bf16_extension(
op_name: str,
input_tensor: torch.Tensor,
) -> None:
with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+bf16")
), FakeTensorMode() as mode:
output = getattr(exir_ops.backend.tosa, op_name).default(
mode.from_tensor(input_tensor)
)

assert output.dtype == torch.bfloat16
assert tuple(output.shape) == tuple(input_tensor.shape)


def test_clamp_rejects_invalid_range() -> None:
sample_input = torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+INT")
), FakeTensorMode() as mode:
with pytest.raises(
TosaValueError,
match="max_val must be greater than or equal to min_val",
):
exir_ops.backend.tosa.CLAMP.default(
mode.from_tensor(sample_input),
4,
-4,
)


@pytest.mark.parametrize(
("min_val", "max_val", "match"),
[
pytest.param(-1.5, 1.5, "must be an integer", id="non_integral"),
pytest.param(-200, 200, "must be in \\[-128, 127\\]", id="out_of_range"),
],
)
def test_clamp_rejects_invalid_integer_bounds(
min_val: int | float,
max_val: int | float,
match: str,
) -> None:
sample_input = torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+INT")
), FakeTensorMode() as mode:
with pytest.raises(TosaValueError, match=match):
exir_ops.backend.tosa.CLAMP.default(
mode.from_tensor(sample_input),
min_val,
max_val,
)
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
activation,
avg_pool2d,
avg_pool2d_adaptive,
cast_to_block_scaled,
Expand Down
16 changes: 16 additions & 0 deletions backends/arm/tosa/dialect/ops/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.arm.tosa.dialect.lib import TosaValueError

_VALID_NAN_MODES = {"PROPAGATE", "IGNORE"}


def validate_nan_mode(nan_mode: str, op: str) -> None:
if nan_mode not in _VALID_NAN_MODES:
raise TosaValueError(
f"Unsupported nan_mode {nan_mode}. Expected one of {_VALID_NAN_MODES}",
op=op,
)
140 changes: 140 additions & 0 deletions backends/arm/tosa/dialect/ops/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
from executorch.backends.arm.tosa.dialect.ops._common import validate_nan_mode
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
from executorch.backends.arm.tosa.specification import (
get_context_spec,
TosaSpecification,
)

FP_SPECS = TosaSpecification.all_versions_for_profile("FP")


def _validate_clamp_dtype(dtype: torch.dtype, op: str) -> None:
tosa_spec = get_context_spec()

if dtype == torch.int8:
if not tosa_spec.support_integer():
raise TosaValueError(
f"TOSA spec {tosa_spec} doesn't support int8 for {op}",
op=op,
)
return

if dtype == torch.int16:
if not (tosa_spec.support_integer() and tosa_spec.support_extension("int16")):
raise TosaValueError(
f"TOSA spec {tosa_spec} doesn't support int16 for {op}",
op=op,
)
return

_validate_float_dtype(dtype, op)
return

raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op)


def _validate_float_dtype(dtype: torch.dtype, op: str) -> None:
tosa_spec = get_context_spec()

if dtype in (torch.float16, torch.float32):
if not tosa_spec.support_float():
raise TosaValueError(
f"TOSA spec {tosa_spec} doesn't support {dtype} for {op}",
op=op,
)
return

if dtype == torch.bfloat16:
if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")):
raise TosaValueError(
f"TOSA spec {tosa_spec} doesn't support bfloat16 for {op}",
op=op,
)
return

raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op)


def _validate_integer_clamp_bounds(
dtype: torch.dtype,
min_val,
max_val,
) -> None:
if dtype not in (torch.int8, torch.int16):
return

dtype_info = torch.iinfo(dtype)
for name, value in (("min_val", min_val), ("max_val", max_val)):
if not isinstance(value, int) or isinstance(value, bool):
raise TosaValueError(
f"{name} must be an integer for {dtype} CLAMP",
op="CLAMP",
)
if value < dtype_info.min or value > dtype_info.max:
raise TosaValueError(
f"{name} must be in [{dtype_info.min}, {dtype_info.max}] for {dtype} CLAMP",
op="CLAMP",
)


@register_fake_tosa_op(
'CLAMP(Tensor input, Scalar min_val, Scalar max_val, *, str nan_mode="PROPAGATE") -> Tensor',
TosaSpecification.all_versions_and_profiles(),
)
def CLAMP(
input: torch.Tensor,
min_val,
max_val,
*,
nan_mode: str = "PROPAGATE",
) -> torch.Tensor:
validate_nan_mode(nan_mode, "CLAMP")
_validate_clamp_dtype(input.dtype, "CLAMP")
_validate_integer_clamp_bounds(input.dtype, min_val, max_val)

if isinstance(min_val, float) and math.isnan(min_val):
raise TosaValueError("min_val cannot be NaN", op="CLAMP")
if isinstance(max_val, float) and math.isnan(max_val):
raise TosaValueError("max_val cannot be NaN", op="CLAMP")
if min_val > max_val:
raise TosaValueError(
"max_val must be greater than or equal to min_val", op="CLAMP"
)

return torch.empty_like(input, dtype=input.dtype)


@register_fake_tosa_op(
"ERF(Tensor input) -> Tensor",
FP_SPECS,
)
def ERF(input: torch.Tensor) -> torch.Tensor:
_validate_float_dtype(input.dtype, "ERF")
return torch.empty_like(input, dtype=input.dtype)


@register_fake_tosa_op(
"SIGMOID(Tensor input) -> Tensor",
FP_SPECS,
)
def SIGMOID(input: torch.Tensor) -> torch.Tensor:
_validate_float_dtype(input.dtype, "SIGMOID")
return torch.empty_like(input, dtype=input.dtype)


@register_fake_tosa_op(
"TANH(Tensor input) -> Tensor",
FP_SPECS,
)
def TANH(input: torch.Tensor) -> torch.Tensor:
_validate_float_dtype(input.dtype, "TANH")
return torch.empty_like(input, dtype=input.dtype)
Loading