diff --git a/backends/arm/test/misc/test_tosa_dialect_unary_ops.py b/backends/arm/test/misc/test_tosa_dialect_unary_ops.py new file mode 100644 index 00000000000..9bfd33d4e0c --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_unary_ops.py @@ -0,0 +1,394 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import ( + get_registered_tosa_ops, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +@pytest.mark.parametrize( + ("op_name", "spec", "input_tensor"), + [ + pytest.param( + "ABS", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 3), dtype=torch.int32), + id="ABS", + ), + pytest.param( + "BITWISE_NOT", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int8), + id="BITWISE_NOT", + ), + pytest.param( + "BITWISE_NOT", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int16), + id="BITWISE_NOT_INT16", + ), + pytest.param( + "CEIL", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="CEIL", + ), + pytest.param( + "CLZ", + "TOSA-1.1+INT", + torch.randint(1, 16, (2, 3), dtype=torch.int32), + id="CLZ", + ), + pytest.param( + "COS", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="COS", + ), + pytest.param( + "EXP", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="EXP", + ), + pytest.param( + "FLOOR", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="FLOOR", + ), + pytest.param( + "LOG", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32).abs() + 1.0, + id="LOG", + ), + pytest.param( + "LOGICAL_NOT", + "TOSA-1.1+FP", + torch.tensor([[True, False], [False, True]], dtype=torch.bool), + id="LOGICAL_NOT", + ), + pytest.param( + "NEGATE", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int32), + id="NEGATE", + ), + pytest.param( + "NEGATE", + "TOSA-1.1+INT", + torch.randint(-8, 8, (2, 3), dtype=torch.int16), + id="NEGATE_INT16", + ), + pytest.param( + "NEGATE", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="NEGATE_FP32", + ), + pytest.param( + "RECIPROCAL", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32).abs() + 1.0, + id="RECIPROCAL", + ), + pytest.param( + "RSQRT", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32).abs() + 1.0, + id="RSQRT", + ), + pytest.param( + "SIN", + "TOSA-1.1+FP", + torch.randn((2, 3), dtype=torch.float32), + id="SIN", + ), + ], +) +def test_tosa_unary_ops( + op_name: str, + spec: str, + input_tensor: torch.Tensor, +) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string(spec) + ), FakeTensorMode() as mode: + output = getattr(exir_ops.backend.tosa, op_name).default( + mode.from_tensor(input_tensor) + ) + + assert output.dtype == input_tensor.dtype + assert tuple(output.shape) == tuple(input_tensor.shape) + + +@pytest.mark.parametrize( + ("op", "spec", "expected"), + [ + pytest.param( + exir_ops.backend.tosa.BITWISE_NOT.default, + "TOSA-1.1+INT", + True, + id="bitwise_not_int", + ), + pytest.param( + exir_ops.backend.tosa.BITWISE_NOT.default, + "TOSA-1.1+FP", + False, + id="bitwise_not_fp", + ), + pytest.param( + exir_ops.backend.tosa.CLZ.default, + "TOSA-1.1+INT", + True, + id="clz_int", + ), + pytest.param( + exir_ops.backend.tosa.CLZ.default, + "TOSA-1.1+FP", + False, + id="clz_fp", + ), + ], +) +def test_tosa_integer_unary_ops_registered_for_int_profile_only( + op, + spec: str, + expected: bool, +) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert (op in registered_ops) is expected + + +@pytest.mark.parametrize( + ("op", "spec", "expected"), + [ + pytest.param( + exir_ops.backend.tosa.CEIL.default, + "TOSA-1.1+INT", + False, + id="ceil_int", + ), + pytest.param( + exir_ops.backend.tosa.CEIL.default, + "TOSA-1.1+FP", + True, + id="ceil_fp", + ), + pytest.param( + exir_ops.backend.tosa.COS.default, + "TOSA-1.1+INT", + False, + id="cos_int", + ), + pytest.param( + exir_ops.backend.tosa.COS.default, + "TOSA-1.1+FP", + True, + id="cos_fp", + ), + pytest.param( + exir_ops.backend.tosa.EXP.default, + "TOSA-1.1+INT", + False, + id="exp_int", + ), + pytest.param( + exir_ops.backend.tosa.EXP.default, + "TOSA-1.1+FP", + True, + id="exp_fp", + ), + pytest.param( + exir_ops.backend.tosa.FLOOR.default, + "TOSA-1.1+INT", + False, + id="floor_int", + ), + pytest.param( + exir_ops.backend.tosa.FLOOR.default, + "TOSA-1.1+FP", + True, + id="floor_fp", + ), + pytest.param( + exir_ops.backend.tosa.LOG.default, + "TOSA-1.1+INT", + False, + id="log_int", + ), + pytest.param( + exir_ops.backend.tosa.LOG.default, + "TOSA-1.1+FP", + True, + id="log_fp", + ), + pytest.param( + exir_ops.backend.tosa.RECIPROCAL.default, + "TOSA-1.1+INT", + False, + id="reciprocal_int", + ), + pytest.param( + exir_ops.backend.tosa.RECIPROCAL.default, + "TOSA-1.1+FP", + True, + id="reciprocal_fp", + ), + pytest.param( + exir_ops.backend.tosa.RSQRT.default, + "TOSA-1.1+INT", + False, + id="rsqrt_int", + ), + pytest.param( + exir_ops.backend.tosa.RSQRT.default, + "TOSA-1.1+FP", + True, + id="rsqrt_fp", + ), + pytest.param( + exir_ops.backend.tosa.SIN.default, + "TOSA-1.1+INT", + False, + id="sin_int", + ), + pytest.param( + exir_ops.backend.tosa.SIN.default, + "TOSA-1.1+FP", + True, + id="sin_fp", + ), + ], +) +def test_tosa_float_unary_ops_registered_for_fp_profile_only( + op, + spec: str, + expected: bool, +) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert (op in registered_ops) is expected + + +@pytest.mark.parametrize( + ("spec", "expected"), + [ + pytest.param("TOSA-1.1+INT", True, id="negate_int"), + pytest.param("TOSA-1.1+FP", True, id="negate_fp"), + ], +) +def test_tosa_negate_registered_for_int_and_fp_profiles( + spec: str, + expected: bool, +) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string(spec)): + registered_ops = get_registered_tosa_ops() + + assert (exir_ops.backend.tosa.NEGATE.default in registered_ops) is expected + + +@pytest.mark.parametrize( + ("op_name", "input_tensor"), + [ + pytest.param( + "CEIL", + torch.randn((2, 3), dtype=torch.bfloat16), + id="CEIL", + ), + pytest.param( + "COS", + torch.randn((2, 3), dtype=torch.bfloat16), + id="COS", + ), + pytest.param( + "EXP", + torch.randn((2, 3), dtype=torch.bfloat16), + id="EXP", + ), + pytest.param( + "FLOOR", + torch.randn((2, 3), dtype=torch.bfloat16), + id="FLOOR", + ), + pytest.param( + "LOG", + torch.randn((2, 3), dtype=torch.bfloat16).abs() + 1.0, + id="LOG", + ), + pytest.param( + "NEGATE", + torch.randn((2, 3), dtype=torch.bfloat16), + id="NEGATE", + ), + ], +) +def test_tosa_float_unary_ops_accept_bfloat16_with_bf16_extension( + op_name: str, + input_tensor: torch.Tensor, +) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+bf16") + ), FakeTensorMode() as mode: + output = getattr(exir_ops.backend.tosa, op_name).default( + mode.from_tensor(input_tensor) + ) + + assert output.dtype == torch.bfloat16 + assert tuple(output.shape) == tuple(input_tensor.shape) + + +def test_negate_rejects_bfloat16_without_bf16_extension() -> None: + sample_input = torch.randn((2, 3), dtype=torch.bfloat16) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support bfloat16"): + exir_ops.backend.tosa.NEGATE.default(mode.from_tensor(sample_input)) + + +def test_abs_rejects_int8() -> None: + sample_input = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="Unsupported dtype"): + exir_ops.backend.tosa.ABS.default(mode.from_tensor(sample_input)) + + +def test_floor_requires_float_profile() -> None: + sample_input = torch.randn((2, 3), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support"): + exir_ops.backend.tosa.FLOOR.default(mode.from_tensor(sample_input)) + + +def test_logical_not_rejects_non_bool() -> None: + sample_input = torch.randint(-8, 8, (2, 3), dtype=torch.int8) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="requires bool inputs"): + exir_ops.backend.tosa.LOGICAL_NOT.default(mode.from_tensor(sample_input)) diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 3a733e8827b..6d74b28b270 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -26,4 +26,5 @@ slice, table, transpose_conv2d, + unary_elementwise, ) diff --git a/backends/arm/tosa/dialect/ops/unary_elementwise.py b/backends/arm/tosa/dialect/ops/unary_elementwise.py new file mode 100644 index 00000000000..56ac8edf3cd --- /dev/null +++ b/backends/arm/tosa/dialect/ops/unary_elementwise.py @@ -0,0 +1,224 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + +FP_SPECS = TosaSpecification.all_versions_for_profile("FP") +INT_SPECS = TosaSpecification.all_versions_for_profile("INT") +DUAL_PROFILE_SPECS = [*INT_SPECS, *FP_SPECS] + + +def _validate_float_dtype(dtype: torch.dtype, op: str) -> None: + tosa_spec = get_context_spec() + + if dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {dtype} for {op}", + op=op, + ) + return + + if dtype == torch.bfloat16: + if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support bfloat16 for {op}", + op=op, + ) + return + + raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op) + + +def _validate_integer_dtype(dtype: torch.dtype, op: str) -> None: + tosa_spec = get_context_spec() + + if dtype in {torch.int8, torch.int16, torch.int32}: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {dtype} for {op}", + op=op, + ) + return + + raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op) + + +def _validate_abs_dtype(dtype: torch.dtype) -> None: + tosa_spec = get_context_spec() + + if dtype == torch.int32: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support int32 for ABS", + op="ABS", + ) + return + + if dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {dtype} for ABS", + op="ABS", + ) + return + + if dtype == torch.bfloat16: + if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support bfloat16 for ABS", + op="ABS", + ) + return + + raise TosaValueError(f"Unsupported dtype {dtype} for ABS", op="ABS") + + +def _validate_clz_dtype(dtype: torch.dtype) -> None: + tosa_spec = get_context_spec() + + if dtype != torch.int32: + raise TosaValueError(f"CLZ requires int32 inputs but got {dtype}", op="CLZ") + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support int32 for CLZ", + op="CLZ", + ) + + +def _validate_bool_dtype(dtype: torch.dtype, op: str) -> None: + if dtype != torch.bool: + raise TosaValueError(f"{op} requires bool inputs but got {dtype}", op=op) + + +def _validate_negate_dtype(dtype: torch.dtype) -> None: + if dtype in (torch.int8, torch.int16, torch.int32): + _validate_integer_dtype(dtype, "NEGATE") + return + + _validate_float_dtype(dtype, "NEGATE") + + +@register_fake_tosa_op( + "ABS(Tensor input1) -> Tensor", + DUAL_PROFILE_SPECS, +) +def ABS(input1: torch.Tensor) -> torch.Tensor: + _validate_abs_dtype(input1.dtype) + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "BITWISE_NOT(Tensor input1) -> Tensor", + INT_SPECS, +) +def BITWISE_NOT(input1: torch.Tensor) -> torch.Tensor: + _validate_integer_dtype(input1.dtype, "BITWISE_NOT") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "CEIL(Tensor input1) -> Tensor", + FP_SPECS, +) +def CEIL(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "CEIL") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "CLZ(Tensor input1) -> Tensor", + INT_SPECS, +) +def CLZ(input1: torch.Tensor) -> torch.Tensor: + _validate_clz_dtype(input1.dtype) + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "COS(Tensor input1) -> Tensor", + FP_SPECS, +) +def COS(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "COS") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "EXP(Tensor input1) -> Tensor", + FP_SPECS, +) +def EXP(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "EXP") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "FLOOR(Tensor input1) -> Tensor", + FP_SPECS, +) +def FLOOR(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "FLOOR") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "LOG(Tensor input1) -> Tensor", + FP_SPECS, +) +def LOG(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "LOG") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "LOGICAL_NOT(Tensor input1) -> Tensor", + DUAL_PROFILE_SPECS, +) +def LOGICAL_NOT(input1: torch.Tensor) -> torch.Tensor: + _validate_bool_dtype(input1.dtype, "LOGICAL_NOT") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "NEGATE(Tensor input1) -> Tensor", + DUAL_PROFILE_SPECS, +) +def NEGATE(input1: torch.Tensor) -> torch.Tensor: + _validate_negate_dtype(input1.dtype) + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "RECIPROCAL(Tensor input1) -> Tensor", + FP_SPECS, +) +def RECIPROCAL(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "RECIPROCAL") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "RSQRT(Tensor input1) -> Tensor", + FP_SPECS, +) +def RSQRT(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "RSQRT") + return torch.empty_like(input1, dtype=input1.dtype) + + +@register_fake_tosa_op( + "SIN(Tensor input1) -> Tensor", + FP_SPECS, +) +def SIN(input1: torch.Tensor) -> torch.Tensor: + _validate_float_dtype(input1.dtype, "SIN") + return torch.empty_like(input1, dtype=input1.dtype)