diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index fcf95653438..34841a52dc7 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -100,6 +100,14 @@ runtime.python_library( resources = [ "vgf/shaders/grid_sampler.glsl", "vgf/shaders/grid_sampler.spirv.b64", + "vgf/shaders/grid_sampler_sampler.glsl", + "vgf/shaders/grid_sampler_sampler.spirv.b64", + "vgf/shaders/grid_sampler_sampler_align_corners.glsl", + "vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64", + "vgf/shaders/grid_sampler_sampler_int8.glsl", + "vgf/shaders/grid_sampler_sampler_int8.spirv.b64", + "vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl", + "vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64", ], deps = [ ":arm_compile_spec", diff --git a/backends/arm/operators/op_tosa_custom.py b/backends/arm/operators/op_tosa_custom.py index 45a6097af43..efa8701ae9e 100644 --- a/backends/arm/operators/op_tosa_custom.py +++ b/backends/arm/operators/op_tosa_custom.py @@ -40,6 +40,7 @@ def _vk_format_component_count(vk_format: str) -> int: "VK_FORMAT_R32G32_SFLOAT": 2, "VK_FORMAT_R8G8B8A8_UINT": 4, "VK_FORMAT_R8G8B8A8_SINT": 4, + "VK_FORMAT_R8G8B8A8_SNORM": 4, "VK_FORMAT_R16G16B16A16_UINT": 4, "VK_FORMAT_R16G16B16A16_SINT": 4, "VK_FORMAT_R16G16B16A16_SFLOAT": 4, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 88b59b21d31..7810077a679 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -84,6 +84,8 @@ def __init__(self): class _QParams(NamedTuple): scale: float zero_point: int + quant_min: int | None = None + quant_max: int | None = None def _as_list(x): @@ -472,8 +474,53 @@ def _match_pattern( 8: _QParams((0.999 - (-0.999)) / (1 << 8), 0), 16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0), }, + # grid_sampler image input/output use SNORM-compatible qparams. The grid + # coordinate tensor is intentionally left unquantized. + torch.ops.aten.grid_sampler.default: { + 8: _QParams(1.0 / 127.0, 0, -127, 127), + }, +} + + +_fixed_output_qspec_ops: dict[Any, dict[int, _QParams]] = { + torch.ops.aten.grid_sampler.default: { + 8: _QParams(1.0 / 127.0, 0, -127, 127), + }, } + +def _get_fixed_qparams_qspec( + node_target: Any, + qparams_table: dict[Any, dict[int, _QParams]], + input_act_qspec: QuantizationSpecBase, +) -> FixedQParamsQuantizationSpec | None: + if not isinstance(input_act_qspec, QuantizationSpec): + raise ValueError("Fixed qparams require a QuantizationSpec input.") + + num_bits = torch.iinfo(input_act_qspec.dtype).bits + qparams = qparams_table[node_target].get(num_bits) + if qparams is None: + return None + + return FixedQParamsQuantizationSpec( + dtype=input_act_qspec.dtype, + scale=qparams.scale, + zero_point=qparams.zero_point, + quant_min=( + input_act_qspec.quant_min + if qparams.quant_min is None + else qparams.quant_min + ), + quant_max=( + input_act_qspec.quant_max + if qparams.quant_max is None + else qparams.quant_max + ), + qscheme=input_act_qspec.qscheme, + is_dynamic=input_act_qspec.is_dynamic, + ) + + _one_to_one: set[OpOverload] = { torch.ops.aten.abs.default, torch.ops.aten.ceil.default, @@ -762,6 +809,16 @@ def any_or_hardtanh_min_zero(n: Node): _QuantProperty(1, input_act_qspec), ] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target == torch.ops.aten.grid_sampler.default: + image_node = ensure_type(Node, node.args[0]) + grid_sampler_image_qspec = quantization_config.get_input_act_qspec( + node, image_node + ) + grid_sampler_output_qspec = quantization_config.get_output_act_qspec(node) + if grid_sampler_image_qspec is None or grid_sampler_output_qspec is None: + return None + quant_properties.quant_inputs = [_QuantProperty(0, grid_sampler_image_qspec)] + quant_properties.quant_output = _QuantProperty(0, grid_sampler_output_qspec) elif node.target in (torch.ops.aten.where.self,): true_node = ensure_type(Node, node.args[1]) input_qspec = ( @@ -825,21 +882,15 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in _fixed_input_qspec_ops: - num_bits = torch.iinfo(input_act_qspec.dtype).bits - qparams = _fixed_input_qspec_ops[node.target][num_bits] - + fixed_input_qspec = _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec + ) + if fixed_input_qspec is None: + return None quant_properties.quant_inputs = [ _QuantProperty( 0, - FixedQParamsQuantizationSpec( - dtype=input_act_qspec.dtype, - scale=qparams.scale, - zero_point=qparams.zero_point, - quant_min=input_act_qspec.quant_min, - quant_max=input_act_qspec.quant_max, - qscheme=input_act_qspec.qscheme, - is_dynamic=input_act_qspec.is_dynamic, - ), + fixed_input_qspec, ) ] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 0c64d147c84..a076d20743e 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -21,7 +21,6 @@ from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, - FixedQParamsQuantizationSpec, QuantizationSpec, QuantizationSpecBase, SharedQuantizationSpec, @@ -295,6 +294,7 @@ def get_input_act_qspec(self, node=None, input_node=None): # MLETORCH-1853: Fix lazy import when moving files around from executorch.backends.arm.quantizer.quantization_annotator import ( _fixed_input_qspec_ops, + _get_fixed_qparams_qspec, ) if node is None or input_node is None: @@ -305,28 +305,17 @@ def get_input_act_qspec(self, node=None, input_node=None): return super().get_input_act_qspec(node, input_node) else: return SharedQuantizationSpec((node.args[0], node)) + elif node.target == torch.ops.aten.grid_sampler.default: + if input_node != node.args[0]: + return None + input_act_qspec = super().get_input_act_qspec(node, input_node) + return _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec + ) elif node.target in _fixed_input_qspec_ops: - input_act_qspec = super().get_input_act_qspec(node, input_node) - if not hasattr(input_act_qspec, "dtype") or not isinstance( - input_act_qspec.dtype, torch.dtype - ): - raise ValueError( - f"{node.target} requires an input activation quantization " - "spec to use fixed input qparams." - ) - dtype = getattr(input_act_qspec, "dtype", None) - num_bits = torch.iinfo(dtype).bits - - qparams = _fixed_input_qspec_ops[node.target][num_bits] - return FixedQParamsQuantizationSpec( - dtype=dtype, - scale=qparams.scale, - zero_point=qparams.zero_point, - quant_min=input_act_qspec.quant_min, - quant_max=input_act_qspec.quant_max, - qscheme=input_act_qspec.qscheme, - is_dynamic=input_act_qspec.is_dynamic, + return _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec ) return super().get_input_act_qspec(node, input_node) @@ -371,6 +360,19 @@ def get_output_act_qspec( if node is None: return super().get_output_act_qspec() + # MLETORCH-1853: Fix lazy import when moving files around + from executorch.backends.arm.quantizer.quantization_annotator import ( + _fixed_output_qspec_ops, + _get_fixed_qparams_qspec, + ) + + if node.target in _fixed_output_qspec_ops: + output_act_qspec = super().get_output_act_qspec(node) + if output_act_qspec is None: + return None + return _get_fixed_qparams_qspec( + node.target, _fixed_output_qspec_ops, output_act_qspec + ) if node.target not in self.SHARED_OUTPUT_ACT_QSPEC_PATTERNS: return super().get_output_act_qspec() if len(node.args) == 0: diff --git a/backends/arm/quantizer/quantizer_support.py b/backends/arm/quantizer/quantizer_support.py index d6a725c2b06..a462e0227c8 100644 --- a/backends/arm/quantizer/quantizer_support.py +++ b/backends/arm/quantizer/quantizer_support.py @@ -174,6 +174,7 @@ def check_pattern(cls, pattern): (torch.ops.aten.acos.default,), (torch.ops.aten.atanh.default,), (torch.ops.aten.einsum.default,), + (torch.ops.aten.grid_sampler.default,), ] ) TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = { diff --git a/backends/arm/scripts/generate_grid_sampler_spirv.py b/backends/arm/scripts/generate_grid_sampler_spirv.py index f8956a86cda..d5ab01c214c 100644 --- a/backends/arm/scripts/generate_grid_sampler_spirv.py +++ b/backends/arm/scripts/generate_grid_sampler_spirv.py @@ -65,7 +65,13 @@ def main() -> None: with tempfile.TemporaryDirectory() as tmpdir: spirv_path = Path(tmpdir) / "grid_sampler.spirv" subprocess.run( # nosec B603 - glslc path is resolved explicitly. - [glslc, str(args.source), "-o", str(spirv_path)], + [ + glslc, + "-fshader-stage=compute", + str(args.source), + "-o", + str(spirv_path), + ], check=True, ) _write_base64_spirv(spirv_path, args.output) diff --git a/backends/arm/test/misc/test_custom_shader_payload.py b/backends/arm/test/misc/test_custom_shader_payload.py index 6243e8752ba..bfafc143c57 100644 --- a/backends/arm/test/misc/test_custom_shader_payload.py +++ b/backends/arm/test/misc/test_custom_shader_payload.py @@ -4,12 +4,24 @@ # LICENSE file in the root directory of this source tree. import base64 +from importlib.resources import files import pytest +import torch from executorch.backends.arm.vgf.shaders.grid_sampler import ( build_grid_sampler_2d_payload, decode_payload, encode_payload, + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT, + GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, GRID_SAMPLER_2D_SHADER_BINARY, GRID_SAMPLER_2D_SHADER_ENTRY_POINT, GRID_SAMPLER_2D_SHADER_LANGUAGE, @@ -19,6 +31,15 @@ ) +def _shader_code_from_resource(shader_file: str) -> str: + return "".join( + files("executorch.backends.arm.vgf.shaders") + .joinpath(shader_file) + .read_text(encoding="utf-8") + .split() + ) + + def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, @@ -45,6 +66,144 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip(): assert decoded["output_0_binding"] == 2 +def test_grid_sampler_2d_custom_shader_payload_no_target_uses_sampler_for_c4(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert base64.b64decode(payload["shader_code"])[:4] == b"\x03\x02\x23\x07" + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["input_0_sampler"] == { + "address_mode_u": "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER", + "address_mode_v": "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER", + "border_color": "VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK", + "mag_filter": "VK_FILTER_LINEAR", + "min_filter": "VK_FILTER_LINEAR", + } + + +def test_grid_sampler_2d_custom_shader_payload_no_target_uses_int8_sampler_for_c4(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.int8, + output_dtype=torch.int8, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert base64.b64decode(payload["shader_code"])[:4] == b"\x03\x02\x23\x07" + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + + +def test_grid_sampler_2d_custom_shader_payload_no_target_keeps_c3_on_buffer(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 3, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert payload["input_0_type"] == "Tensor" + assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert payload["output_0_type"] == "Tensor" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert "input_0_sampler" not in payload + + +def test_grid_sampler_2d_custom_shader_payload_no_target_align_corners_sampler(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=True, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["shader_code"] == _shader_code_from_resource( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY + ) + + +def test_grid_sampler_2d_custom_shader_payload_no_target_int8_align_corners_sampler(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=True, + input_shape=(1, 4, 8, 8), + input_dtype=torch.int8, + output_dtype=torch.int8, + ) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["shader_code"] == _shader_code_from_resource( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY + ) + + +def test_grid_sampler_2d_custom_shader_payload_no_target_bicubic_buffer(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=2, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["input_0_type"] == "Tensor" + assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert payload["output_0_type"] == "Tensor" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert "input_0_sampler" not in payload + + def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, @@ -61,6 +220,31 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv(): def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources(): assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl" assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64" + assert GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE == "grid_sampler_sampler.glsl" + assert GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY == "grid_sampler_sampler.spirv.b64" + assert ( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE + == "grid_sampler_sampler_align_corners.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY + == "grid_sampler_sampler_align_corners.spirv.b64" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE == "grid_sampler_sampler_int8.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY + == "grid_sampler_sampler_int8.spirv.b64" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE + == "grid_sampler_sampler_int8_align_corners.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY + == "grid_sampler_sampler_int8_align_corners.spirv.b64" + ) def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes(): diff --git a/backends/arm/test/misc/test_multiple_delegates.py b/backends/arm/test/misc/test_multiple_delegates.py index bbc0b2b1bce..e8a929a4e1f 100644 --- a/backends/arm/test/misc/test_multiple_delegates.py +++ b/backends/arm/test/misc/test_multiple_delegates.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ from executorch.backends.arm.test.tester.test_pipeline import ( TosaPipelineFP, TosaPipelineINT, + VgfPipeline, ) @@ -51,3 +52,20 @@ def test_multiple_delegates_tosa_INT(test_data: input_t1): "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} ) pipeline.run() + + +@common.parametrize("test_data", MultipleDelegatesModule.inputs) +@common.SkipIfNoModelConverter +def test_multiple_delegates_vgf_INT(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = VgfPipeline[input_t1]( + MultipleDelegatesModule(), + test_data, + aten_ops, + exir_ops, + quantize=True, + run_on_vulkan_runtime=False, + n_expected_delegates=2, + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py index ec7773dfdbc..eb4b5f23660 100644 --- a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py +++ b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py @@ -4,8 +4,14 @@ # LICENSE file in the root directory of this source tree. import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest import torch import torch.nn.functional as F +from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -17,6 +23,8 @@ CUSTOM_SHADER_DOMAIN_NAME, decode_payload, grid_sampler_2d_operator_name, + GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT, + GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, GRID_SAMPLER_2D_SHADER_ENTRY_POINT, GRID_SAMPLER_2D_SHADER_LANGUAGE, GRID_SAMPLER_2D_VK_FORMAT, @@ -25,6 +33,7 @@ from executorch.exir import to_edge from executorch.exir.dialects._ops import ops as exir_ops from torch.export import export +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class GridSampler2d(torch.nn.Module): @@ -35,10 +44,11 @@ def __init__(self): self.align_corners_ = False def forward(self, x, grid): + mode = ("bilinear", "nearest", "bicubic")[self.interpolation_mode_] return F.grid_sample( x, grid, - mode="bilinear" if self.interpolation_mode_ == 0 else "nearest", + mode=mode, padding_mode="zeros" if self.padding_mode_ == 0 else "border", align_corners=self.align_corners_, ) @@ -80,15 +90,150 @@ def test_rewrite_grid_sampler_to_tosa_custom_vgf_no_target(): assert payload["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT assert payload["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE - assert payload["input_0_type"] == "Tensor" - assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) assert payload["input_0_binding"] == 0 assert payload["input_0_descriptorset"] == 0 assert payload["input_1_type"] == "Tensor" assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" assert payload["input_1_binding"] == 1 assert payload["input_1_descriptorset"] == 0 - assert payload["output_0_type"] == "Tensor" - assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" assert payload["output_0_binding"] == 2 assert payload["output_0_descriptorset"] == 0 + assert any(node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes) + + +def test_rewrite_grid_sampler_to_tosa_custom_no_target_uses_sampler_for_c4(): + model = GridSampler2d() + example_inputs = ( + torch.randn(1, 4, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + + +@pytest.mark.parametrize("channels", [3, 4]) +@pytest.mark.parametrize("use_composable_quantizer", [False, True]) +def test_quantized_grid_sampler_uses_int8_sampler_payload( + channels, use_composable_quantizer +): + model = GridSampler2d().eval() + example_inputs = ( + torch.randn(1, channels, 8, 8), + torch.rand(1, 4, 4, 2), + ) + quantizer = TOSAQuantizer( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + use_composable_quantizer=use_composable_quantizer, + ) + quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False)) + + exported = export(model, example_inputs, strict=True) + prepared = prepare_pt2e(exported.module(), quantizer) + prepared(*example_inputs) + converted = convert_pt2e(prepared) + + edge_model = to_edge(export(converted, example_inputs, strict=True)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP+INT")): + edge_model = edge_model.transform( + [FoldAndAnnotateQParamsPass(), RewriteGridSamplerToTosaCustomPass()] + ) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert custom_node.meta["input_qparams"][0].qmin == -127 + assert custom_node.meta["input_qparams"][0].qmax == 127 + assert next(iter(custom_node.meta["output_qparams"].values())).qmin == -127 + assert next(iter(custom_node.meta["output_qparams"].values())).qmax == 127 + + +def test_rewrite_grid_sampler_to_tosa_custom_c3_pad_for_align_corners(): + model = GridSampler2d() + model.align_corners_ = True + example_inputs = ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert any(node.target == exir_ops.edge.aten.cat.default for node in nodes) + assert any(node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes) + + +def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_bicubic(): + model = GridSampler2d() + model.interpolation_mode_ = 2 + example_inputs = ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Tensor" + assert not any(node.target == exir_ops.edge.aten.cat.default for node in nodes) + assert not any( + node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes + ) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 1304c2b2e54..589c0a851a3 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -1223,6 +1223,8 @@ class VgfPipeline(BasePipeline, Generic[T]): use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + n_expected_delegates: Number of delegate calls expected after + partitioning. """ @@ -1252,6 +1254,7 @@ def __init__( tosa_spec: TosaSpecification | str | None = None, fold_quantize: bool = True, preserve_io_quantization: bool = False, + n_expected_delegates: int = 1, ): if tosa_spec is None: if tosa_version is None: @@ -1282,6 +1285,10 @@ def __init__( dynamic_shapes, transform_passes=transform_passes, ) + self.change_args( + "check_count.exir", + {"torch.ops.higher_order.executorch_call_delegate": n_expected_delegates}, + ) remove_torch_quant_nodes_stage = ( "to_edge_transform_and_lower" diff --git a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py index 9d4f17dc936..efff1730914 100644 --- a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py +++ b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py @@ -3,12 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import operator from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.tosa.dialect.ops.custom import register_fake_tosa from executorch.backends.arm.vgf.shaders.grid_sampler import ( @@ -84,6 +90,62 @@ def _set_fake_tensor_meta(node: torch.fx.Node, value) -> None: node.meta["tensor_meta"] = _extract_tensor_metadata(value) +def _is_static_nchw_with_channels(node: torch.fx.Node, channels: int) -> bool: + value = node.meta.get("val") + return ( + isinstance(value, torch.Tensor) + and len(value.shape) == 4 + and int(value.shape[1]) == channels + ) + + +def _can_pad_c3_for_sampler( + input_tensor: torch.fx.Node, + interpolation_mode: int, + align_corners: bool, # noqa: ARG001 +) -> bool: + value = input_tensor.meta.get("val") + return ( + isinstance(value, torch.Tensor) + and len(value.shape) == 4 + and int(value.shape[0]) == 1 + and int(value.shape[1]) == 3 + and value.dtype in (torch.float32, torch.int8) + and int(interpolation_mode) in (0, 1) + ) + + +def _uses_grid_sampler_int8_snorm_qparams(qparams: QuantArgs) -> bool: + return ( + not qparams.per_channel + and math.isclose( + qparams.get_scale_per_tensor(), 1.0 / 127.0, rel_tol=1e-6, abs_tol=1e-9 + ) + and qparams.get_zp_per_tensor() == 0 + and qparams.qmin == -127 + and qparams.qmax == 127 + and qparams.dtype == torch.int8 + ) + + +def _uses_grid_sampler_int8_snorm_metadata(node: torch.fx.Node) -> bool: + try: + input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node) + except ValueError: + return False + + image_qparams = input_qparams.get(0) + if image_qparams is None: + return False + if not output_qparams: + return False + + return _uses_grid_sampler_int8_snorm_qparams( + image_qparams + ) and _uses_grid_sampler_int8_snorm_qparams(next(iter(output_qparams.values()))) + + class RewriteGridSamplerToTosaCustomPass(ArmPass): """Rewrite ``aten.grid_sampler_2d`` nodes to ``tosa.CUSTOM``.""" @@ -92,15 +154,72 @@ class RewriteGridSamplerToTosaCustomPass(ArmPass): @staticmethod def _encode_payload( - interpolation_mode: int, padding_mode: int, align_corners: bool + interpolation_mode: int, + padding_mode: int, + align_corners: bool, + input_tensor: torch.fx.Node, + output_dtype: torch.dtype | None = None, ) -> list[int]: + input_val = input_tensor.meta.get("val") + if input_val is None: + raise RuntimeError("grid_sampler_2d input is missing tensor metadata") payload = build_grid_sampler_2d_payload( interpolation_mode=interpolation_mode, padding_mode=padding_mode, align_corners=align_corners, + input_shape=tuple(input_val.shape), + input_dtype=input_val.dtype, + output_dtype=output_dtype, ) return encode_payload(payload) + @staticmethod + def _pad_c3_input_to_c4( + graph_module: torch.fx.GraphModule, + input_tensor: torch.fx.Node, + ) -> torch.fx.Node: + input_val = input_tensor.meta["val"] + first_channel = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=(input_tensor, 1, 0, 1, 1), + from_node=input_tensor, + ) + first_channel_val = exir_ops.edge.aten.slice_copy.Tensor(input_val, 1, 0, 1, 1) + _set_fake_tensor_meta(first_channel, first_channel_val) + + padded_input = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.cat.default, + args=([input_tensor, first_channel], 1), + from_node=input_tensor, + ) + _set_fake_tensor_meta( + padded_input, + exir_ops.edge.aten.cat.default([input_val, first_channel_val], 1), + ) + return padded_input + + @staticmethod + def _slice_c4_output_to_c3( + graph_module: torch.fx.GraphModule, + output: torch.fx.Node, + original_node: torch.fx.Node, + ) -> torch.fx.Node: + output_val = output.meta["val"] + sliced_output = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=(output, 1, 0, 3, 1), + from_node=original_node, + ) + sliced_output.meta = dict(original_node.meta) + _set_fake_tensor_meta( + sliced_output, + exir_ops.edge.aten.slice_copy.Tensor(output_val, 1, 0, 3, 1), + ) + return sliced_output + def call(self, graph_module): modified = False for node in graph_module.graph.nodes: @@ -114,12 +233,14 @@ def call(self, graph_module): input_tensor, grid, interpolation_mode, padding_mode, align_corners = ( node.args ) - - implementation_attrs = self._encode_payload( - interpolation_mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners, + use_quantized_image_payload = _uses_grid_sampler_int8_snorm_metadata(node) + output_dtype = torch.int8 if use_quantized_image_payload else None + pad_c3_for_sampler = _can_pad_c3_for_sampler( + input_tensor, + interpolation_mode, + align_corners, ) + operator_name = grid_sampler_2d_operator_name( interpolation_mode=interpolation_mode, padding_mode=padding_mode, @@ -127,16 +248,28 @@ def call(self, graph_module): ) with graph_module.graph.inserting_before(node): + custom_input = ( + self._pad_c3_input_to_c4(graph_module, input_tensor) + if pad_c3_for_sampler + else input_tensor + ) + implementation_attrs = self._encode_payload( + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + input_tensor=custom_input, + output_dtype=output_dtype, + ) nhwc_input = create_node( graph_module.graph, op_target=exir_ops.edge.aten.permute_copy.default, - args=(input_tensor, list(NHWC_ORDER)), - from_node=input_tensor, + args=(custom_input, list(NHWC_ORDER)), + from_node=custom_input, ) _set_fake_tensor_meta( nhwc_input, exir_ops.edge.aten.permute_copy.default( - input_tensor.meta["val"], list(NHWC_ORDER) + custom_input.meta["val"], list(NHWC_ORDER) ), ) @@ -184,10 +317,18 @@ def call(self, graph_module): custom_output, list(NHWC_INVERSE_ORDER) ), ) - node.replace_all_uses_with(output) + if pad_c3_for_sampler: + with graph_module.graph.inserting_after(output): + replacement = self._slice_c4_output_to_c3( + graph_module, output, node + ) + else: + replacement = output + node.replace_all_uses_with(replacement) graph_module.graph.erase_node(node) if modified: + graph_module.graph.eliminate_dead_code() graph_module.graph.lint() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/vgf/shaders/grid_sampler.py b/backends/arm/vgf/shaders/grid_sampler.py index 800a4ec0013..5cd63ddf95b 100644 --- a/backends/arm/vgf/shaders/grid_sampler.py +++ b/backends/arm/vgf/shaders/grid_sampler.py @@ -15,6 +15,24 @@ GRID_SAMPLER_2D_VK_FORMAT = "VK_FORMAT_R32_SFLOAT" GRID_SAMPLER_2D_SHADER_SOURCE = "grid_sampler.glsl" GRID_SAMPLER_2D_SHADER_BINARY = "grid_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE = "grid_sampler_sampler.glsl" +GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY = "grid_sampler_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE = ( + "grid_sampler_sampler_align_corners.glsl" +) +GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY = ( + "grid_sampler_sampler_align_corners.spirv.b64" +) +GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE = "grid_sampler_sampler_int8.glsl" +GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY = "grid_sampler_sampler_int8.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE = ( + "grid_sampler_sampler_int8_align_corners.glsl" +) +GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY = ( + "grid_sampler_sampler_int8_align_corners.spirv.b64" +) +GRID_SAMPLER_2D_SAMPLER_VK_FORMAT = "VK_FORMAT_R32G32B32A32_SFLOAT" +GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT = "VK_FORMAT_R8G8B8A8_SNORM" _INTERPOLATION_MODE_NAMES = { 0: "bilinear", @@ -67,6 +85,9 @@ def build_grid_sampler_2d_payload( interpolation_mode: int, padding_mode: int, align_corners: bool, + input_shape: tuple[int, ...] | None = None, + input_dtype: Any | None = None, + output_dtype: Any | None = None, ) -> dict[str, Any]: _mode_name( int(interpolation_mode), @@ -78,34 +99,126 @@ def build_grid_sampler_2d_payload( _PADDING_MODE_NAMES, "padding_mode", ) + if output_dtype is None: + output_dtype = input_dtype + + sampler_vk_format = _sampler_vk_format(input_dtype, output_dtype) + use_sampler = ( + input_shape is not None + and len(input_shape) == 4 + and int(input_shape[0]) == 1 + and int(input_shape[1]) == 4 + and sampler_vk_format is not None + and int(interpolation_mode) in (0, 1) + ) + shader_file = ( + _sampler_shader_file(sampler_vk_format, align_corners=align_corners) + if use_sampler + else GRID_SAMPLER_2D_SHADER_BINARY + ) shader_code = "".join( - files(__package__) - .joinpath(GRID_SAMPLER_2D_SHADER_BINARY) - .read_text(encoding="utf-8") - .split() + files(__package__).joinpath(shader_file).read_text(encoding="utf-8").split() ) - return { + payload = { "entry_point": GRID_SAMPLER_2D_SHADER_ENTRY_POINT, "workgroup_sizes": GRID_SAMPLER_2D_WORKGROUP_SIZES, "shader_language": GRID_SAMPLER_2D_SHADER_LANGUAGE, "shader_code": shader_code, - "input_0_type": "Tensor", - "input_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "input_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "input_0_binding": 0, "input_0_descriptorset": 0, "input_1_type": "Tensor", "input_1_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "input_1_binding": 1, "input_1_descriptorset": 0, - "output_0_type": "Tensor", - "output_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "output_0_binding": 2, "output_0_descriptorset": 0, } + if use_sampler: + payload.update( + { + "input_0_type": "Image", + "input_0_vkformat": sampler_vk_format, + "input_0_vkdescriptortype": ( + "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ), + "input_0_sampler": _sampler_config( + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + ), + "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_TENSOR_ARM", + "output_0_type": "Image", + "output_0_vkformat": sampler_vk_format, + "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", + } + ) + else: + payload.update( + { + "input_0_type": "Tensor", + "input_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, + "input_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + "output_0_type": "Tensor", + "output_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, + "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + } + ) + return payload + + +def _sampler_vk_format(input_dtype: Any | None, output_dtype: Any | None) -> str | None: + if str(input_dtype) != str(output_dtype): + return None + if str(input_dtype) == "torch.float32": + return GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + if str(input_dtype) == "torch.int8": + return GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + return None + + +def _sampler_shader_file( + sampler_vk_format: str | None, + align_corners: bool, +) -> str: + if sampler_vk_format == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT: + if align_corners: + return GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY + return GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY + if align_corners: + return GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY + return GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY + + +def _sampler_config(interpolation_mode: int, padding_mode: int) -> dict[str, str]: + interpolation = _mode_name( + int(interpolation_mode), + _INTERPOLATION_MODE_NAMES, + "interpolation_mode", + ) + padding = _mode_name( + int(padding_mode), + _PADDING_MODE_NAMES, + "padding_mode", + ) + + filter_mode = ( + "VK_FILTER_NEAREST" if interpolation == "nearest" else "VK_FILTER_LINEAR" + ) + if padding == "zeros": + address_mode = "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER" + elif padding == "border": + address_mode = "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE" + else: + address_mode = "VK_SAMPLER_ADDRESS_MODE_MIRRORED_REPEAT" + + return { + "min_filter": filter_mode, + "mag_filter": filter_mode, + "address_mode_u": address_mode, + "address_mode_v": address_mode, + "border_color": "VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK", + } def encode_payload(payload: dict[str, Any]) -> list[int]: diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler.glsl new file mode 100644 index 00000000000..69fcbd3d59f --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler.glsl @@ -0,0 +1,35 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba32f) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + vec2 uv = (gridXY + vec2(1.0)) * 0.5; + imageStore(outImage, gid, texture(inputImage, uv)); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 new file mode 100644 index 00000000000..e460cbe8621 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AdQAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAARQAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABAATAAAAeENvb3JkcwAFAAQAHgAAAHlDb29yZHMABQAEACgAAABncmlkAAAAAAUABAAtAAAAeFZhbAAAAAAFAAQAMQAAAHlWYWwAAAAABQAEADwAAABvdXRTaXplAAUABQA/AAAAb3V0SW1hZ2UAAAAABQADAEIAAABnaWQABQAIAEUAAABnbF9HbG9iYWxJbnZvY2F0aW9uSUQAAAAFAAQAXQAAAGdyaWRYWQAABQAEAF4AAABwYXJhbQAAAAUAAwBhAAAAdXYAAAUABQBtAAAAaW5wdXRJbWFnZQAARwAEACgAAAAhAAAAAQAAAEcABAAoAAAAIgAAAAAAAABHAAMAPwAAABkAAABHAAQAPwAAACEAAAACAAAARwAEAD8AAAAiAAAAAAAAAEcABABFAAAACwAAABwAAABHAAQAbQAAACEAAAAAAAAARwAEAG0AAAAiAAAAAAAAAEcABAB0AAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAFQAEAA8AAAAgAAAAAAAAACsABAAPAAAAEAAAAAQAAAAcAAQAEQAAAA8AAAAQAAAAIAAEABIAAAAHAAAAEQAAACsABAAPAAAAFAAAAAAAAAArAAQADwAAABUAAAABAAAAIAAEABYAAAAHAAAABgAAAEMQBAAmAAAACQAAABAAAAAgAAQAJwAAAAAAAAAmAAAAOwAEACcAAAAoAAAAAAAAABwABAArAAAACQAAABUAAAAgAAQALAAAAAcAAAArAAAAKwAEAAYAAAAzAAAAAAAAACAABAA0AAAABwAAAAkAAAAZAAkAPQAAAAkAAAABAAAAAAAAAAAAAAAAAAAAAgAAAAEAAAAgAAQAPgAAAAAAAAA9AAAAOwAEAD4AAAA/AAAAAAAAABcABABDAAAADwAAAAMAAAAgAAQARAAAAAEAAABDAAAAOwAEAEQAAABFAAAAAQAAABcABABGAAAADwAAAAIAAAAUAAIASgAAACAABABcAAAABwAAAAoAAAArAAQACQAAAGMAAAAAAIA/LAAFAAoAAABkAAAAYwAAAGMAAAArAAQACQAAAGYAAAAAAAA/GQAJAGoAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAGwADAGsAAABqAAAAIAAEAGwAAAAAAAAAawAAADsABABsAAAAbQAAAAAAAAAXAAQAcAAAAAkAAAAEAAAAKwAEAAkAAABxAAAAAAAAACsABAAPAAAAcwAAAAgAAAAsAAYAQwAAAHQAAABzAAAAcwAAABUAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAAA8AAAABwAAADsABAAIAAAAQgAAAAcAAAA7AAQAXAAAAF0AAAAHAAAAOwAEAAgAAABeAAAABwAAADsABABcAAAAYQAAAAcAAAA9AAQAPQAAAEAAAAA/AAAAaAAEAAcAAABBAAAAQAAAAD4AAwA8AAAAQQAAAD0ABABDAAAARwAAAEUAAABPAAcARgAAAEgAAABHAAAARwAAAAAAAAABAAAAfAAEAAcAAABJAAAASAAAAD4AAwBCAAAASQAAAEEABQAWAAAASwAAAEIAAAAUAAAAPQAEAAYAAABMAAAASwAAAEEABQAWAAAATQAAADwAAAAUAAAAPQAEAAYAAABOAAAATQAAAK8ABQBKAAAATwAAAEwAAABOAAAAqAAEAEoAAABQAAAATwAAAPcAAwBSAAAAAAAAAPoABABQAAAAUQAAAFIAAAD4AAIAUQAAAEEABQAWAAAAUwAAAEIAAAAVAAAAPQAEAAYAAABUAAAAUwAAAEEABQAWAAAAVQAAADwAAAAVAAAAPQAEAAYAAABWAAAAVQAAAK8ABQBKAAAAVwAAAFQAAABWAAAA+QACAFIAAAD4AAIAUgAAAPUABwBKAAAAWAAAAE8AAAAFAAAAVwAAAFEAAAD3AAMAWgAAAAAAAAD6AAQAWAAAAFkAAABaAAAA+AACAFkAAAD9AAEA+AACAFoAAAA9AAQABwAAAF8AAABCAAAAPgADAF4AAABfAAAAOQAFAAoAAABgAAAADQAAAF4AAAA+AAMAXQAAAGAAAAA9AAQACgAAAGIAAABdAAAAgQAFAAoAAABlAAAAYgAAAGQAAACOAAUACgAAAGcAAABlAAAAZgAAAD4AAwBhAAAAZwAAAD0ABAA9AAAAaAAAAD8AAAA9AAQABwAAAGkAAABCAAAAPQAEAGsAAABuAAAAbQAAAD0ABAAKAAAAbwAAAGEAAABYAAcAcAAAAHIAAABuAAAAbwAAAAIAAABxAAAAYwAEAGgAAABpAAAAcgAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAASAAAAEwAAAAcAAAA7AAQAEgAAAB4AAAAHAAAAOwAEACwAAAAtAAAABwAAADsABAAsAAAAMQAAAAcAAABBAAUAFgAAABcAAAAMAAAAFQAAAD0ABAAGAAAAGAAAABcAAAB8AAQADwAAABkAAAAYAAAAQQAFABYAAAAaAAAADAAAABQAAAA9AAQABgAAABsAAAAaAAAAfAAEAA8AAAAcAAAAGwAAAFAABwARAAAAHQAAABQAAAAZAAAAHAAAABQAAAA+AAMAEwAAAB0AAABBAAUAFgAAAB8AAAAMAAAAFQAAAD0ABAAGAAAAIAAAAB8AAAB8AAQADwAAACEAAAAgAAAAQQAFABYAAAAiAAAADAAAABQAAAA9AAQABgAAACMAAAAiAAAAfAAEAA8AAAAkAAAAIwAAAFAABwARAAAAJQAAABQAAAAhAAAAJAAAABUAAAA+AAMAHgAAACUAAAA9AAQAJgAAACkAAAAoAAAAPQAEABEAAAAqAAAAEwAAAEQQBQArAAAALgAAACkAAAAqAAAAPgADAC0AAAAuAAAAPQAEACYAAAAvAAAAKAAAAD0ABAARAAAAMAAAAB4AAABEEAUAKwAAADIAAAAvAAAAMAAAAD4AAwAxAAAAMgAAAEEABQA0AAAANQAAAC0AAAAzAAAAPQAEAAkAAAA2AAAANQAAAEEABQA0AAAANwAAADEAAAAzAAAAPQAEAAkAAAA4AAAANwAAAFAABQAKAAAAOQAAADYAAAA4AAAA/gACADkAAAA4AAEA diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl new file mode 100644 index 00000000000..efa5f8c589b --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl @@ -0,0 +1,40 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba32f) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +vec2 alignCornersUv(vec2 gridXY) { + vec2 inputSize = vec2(textureSize(inputImage, 0)); + vec2 texel = (gridXY + vec2(1.0)) * vec2(0.5) * (inputSize - vec2(1.0)); + return (texel + vec2(0.5)) / inputSize; +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + imageStore(outImage, gid, texture(inputImage, alignCornersUv(gridXY))); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 new file mode 100644 index 00000000000..c000bc299a3 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AigAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAAZAAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABwASAAAAYWxpZ25Db3JuZXJzVXYodmYyOwAFAAQAEQAAAGdyaWRYWQAABQAEABgAAAB4Q29vcmRzAAUABAAjAAAAeUNvb3JkcwAFAAQALQAAAGdyaWQAAAAABQAEADIAAAB4VmFsAAAAAAUABAA2AAAAeVZhbAAAAAAFAAUAQQAAAGlucHV0U2l6ZQAAAAUABQBFAAAAaW5wdXRJbWFnZQAABQAEAEoAAAB0ZXhlbAAAAAUABABbAAAAb3V0U2l6ZQAFAAUAXgAAAG91dEltYWdlAAAAAAUAAwBhAAAAZ2lkAAUACABkAAAAZ2xfR2xvYmFsSW52b2NhdGlvbklEAAAABQAEAHsAAABncmlkWFkAAAUABAB8AAAAcGFyYW0AAAAFAAQAggAAAHBhcmFtAAAARwAEAC0AAAAhAAAAAQAAAEcABAAtAAAAIgAAAAAAAABHAAQARQAAACEAAAAAAAAARwAEAEUAAAAiAAAAAAAAAEcAAwBeAAAAGQAAAEcABABeAAAAIQAAAAIAAABHAAQAXgAAACIAAAAAAAAARwAEAGQAAAALAAAAHAAAAEcABACJAAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAIAAEAA8AAAAHAAAACgAAACEABAAQAAAACgAAAA8AAAAVAAQAFAAAACAAAAAAAAAAKwAEABQAAAAVAAAABAAAABwABAAWAAAAFAAAABUAAAAgAAQAFwAAAAcAAAAWAAAAKwAEABQAAAAZAAAAAAAAACsABAAUAAAAGgAAAAEAAAAgAAQAGwAAAAcAAAAGAAAAQxAEACsAAAAJAAAAFQAAACAABAAsAAAAAAAAACsAAAA7AAQALAAAAC0AAAAAAAAAHAAEADAAAAAJAAAAGgAAACAABAAxAAAABwAAADAAAAArAAQABgAAADgAAAAAAAAAIAAEADkAAAAHAAAACQAAABkACQBCAAAACQAAAAEAAAAAAAAAAAAAAAAAAAABAAAAAAAAABsAAwBDAAAAQgAAACAABABEAAAAAAAAAEMAAAA7AAQARAAAAEUAAAAAAAAAKwAEAAkAAABMAAAAAACAPywABQAKAAAATQAAAEwAAABMAAAAKwAEAAkAAABPAAAAAAAAPywABQAKAAAAUAAAAE8AAABPAAAAGQAJAFwAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAIAAAABAAAAIAAEAF0AAAAAAAAAXAAAADsABABdAAAAXgAAAAAAAAAXAAQAYgAAABQAAAADAAAAIAAEAGMAAAABAAAAYgAAADsABABjAAAAZAAAAAEAAAAXAAQAZQAAABQAAAACAAAAFAACAGkAAAAXAAQAhQAAAAkAAAAEAAAAKwAEAAkAAACGAAAAAAAAACsABAAUAAAAiAAAAAgAAAAsAAYAYgAAAIkAAACIAAAAiAAAABoAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAABbAAAABwAAADsABAAIAAAAYQAAAAcAAAA7AAQADwAAAHsAAAAHAAAAOwAEAAgAAAB8AAAABwAAADsABAAPAAAAggAAAAcAAAA9AAQAXAAAAF8AAABeAAAAaAAEAAcAAABgAAAAXwAAAD4AAwBbAAAAYAAAAD0ABABiAAAAZgAAAGQAAABPAAcAZQAAAGcAAABmAAAAZgAAAAAAAAABAAAAfAAEAAcAAABoAAAAZwAAAD4AAwBhAAAAaAAAAEEABQAbAAAAagAAAGEAAAAZAAAAPQAEAAYAAABrAAAAagAAAEEABQAbAAAAbAAAAFsAAAAZAAAAPQAEAAYAAABtAAAAbAAAAK8ABQBpAAAAbgAAAGsAAABtAAAAqAAEAGkAAABvAAAAbgAAAPcAAwBxAAAAAAAAAPoABABvAAAAcAAAAHEAAAD4AAIAcAAAAEEABQAbAAAAcgAAAGEAAAAaAAAAPQAEAAYAAABzAAAAcgAAAEEABQAbAAAAdAAAAFsAAAAaAAAAPQAEAAYAAAB1AAAAdAAAAK8ABQBpAAAAdgAAAHMAAAB1AAAA+QACAHEAAAD4AAIAcQAAAPUABwBpAAAAdwAAAG4AAAAFAAAAdgAAAHAAAAD3AAMAeQAAAAAAAAD6AAQAdwAAAHgAAAB5AAAA+AACAHgAAAD9AAEA+AACAHkAAAA9AAQABwAAAH0AAABhAAAAPgADAHwAAAB9AAAAOQAFAAoAAAB+AAAADQAAAHwAAAA+AAMAewAAAH4AAAA9AAQAXAAAAH8AAABeAAAAPQAEAAcAAACAAAAAYQAAAD0ABABDAAAAgQAAAEUAAAA9AAQACgAAAIMAAAB7AAAAPgADAIIAAACDAAAAOQAFAAoAAACEAAAAEgAAAIIAAABYAAcAhQAAAIcAAACBAAAAhAAAAAIAAACGAAAAYwAEAH8AAACAAAAAhwAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAAXAAAAGAAAAAcAAAA7AAQAFwAAACMAAAAHAAAAOwAEADEAAAAyAAAABwAAADsABAAxAAAANgAAAAcAAABBAAUAGwAAABwAAAAMAAAAGgAAAD0ABAAGAAAAHQAAABwAAAB8AAQAFAAAAB4AAAAdAAAAQQAFABsAAAAfAAAADAAAABkAAAA9AAQABgAAACAAAAAfAAAAfAAEABQAAAAhAAAAIAAAAFAABwAWAAAAIgAAABkAAAAeAAAAIQAAABkAAAA+AAMAGAAAACIAAABBAAUAGwAAACQAAAAMAAAAGgAAAD0ABAAGAAAAJQAAACQAAAB8AAQAFAAAACYAAAAlAAAAQQAFABsAAAAnAAAADAAAABkAAAA9AAQABgAAACgAAAAnAAAAfAAEABQAAAApAAAAKAAAAFAABwAWAAAAKgAAABkAAAAmAAAAKQAAABoAAAA+AAMAIwAAACoAAAA9AAQAKwAAAC4AAAAtAAAAPQAEABYAAAAvAAAAGAAAAEQQBQAwAAAAMwAAAC4AAAAvAAAAPgADADIAAAAzAAAAPQAEACsAAAA0AAAALQAAAD0ABAAWAAAANQAAACMAAABEEAUAMAAAADcAAAA0AAAANQAAAD4AAwA2AAAANwAAAEEABQA5AAAAOgAAADIAAAA4AAAAPQAEAAkAAAA7AAAAOgAAAEEABQA5AAAAPAAAADYAAAA4AAAAPQAEAAkAAAA9AAAAPAAAAFAABQAKAAAAPgAAADsAAAA9AAAA/gACAD4AAAA4AAEANgAFAAoAAAASAAAAAAAAABAAAAA3AAMADwAAABEAAAD4AAIAEwAAADsABAAPAAAAQQAAAAcAAAA7AAQADwAAAEoAAAAHAAAAPQAEAEMAAABGAAAARQAAAGQABABCAAAARwAAAEYAAABnAAUABwAAAEgAAABHAAAAOAAAAG8ABAAKAAAASQAAAEgAAAA+AAMAQQAAAEkAAAA9AAQACgAAAEsAAAARAAAAgQAFAAoAAABOAAAASwAAAE0AAACFAAUACgAAAFEAAABOAAAAUAAAAD0ABAAKAAAAUgAAAEEAAACDAAUACgAAAFMAAABSAAAATQAAAIUABQAKAAAAVAAAAFEAAABTAAAAPgADAEoAAABUAAAAPQAEAAoAAABVAAAASgAAAIEABQAKAAAAVgAAAFUAAABQAAAAPQAEAAoAAABXAAAAQQAAAIgABQAKAAAAWAAAAFYAAABXAAAA/gACAFgAAAA4AAEA diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl new file mode 100644 index 00000000000..132049210db --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl @@ -0,0 +1,35 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba8_snorm) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + vec2 uv = (gridXY + vec2(1.0)) * 0.5; + imageStore(outImage, gid, texture(inputImage, uv)); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 new file mode 100644 index 00000000000..07f4440d3b4 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AdQAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAARQAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABAATAAAAeENvb3JkcwAFAAQAHgAAAHlDb29yZHMABQAEACgAAABncmlkAAAAAAUABAAtAAAAeFZhbAAAAAAFAAQAMQAAAHlWYWwAAAAABQAEADwAAABvdXRTaXplAAUABQA/AAAAb3V0SW1hZ2UAAAAABQADAEIAAABnaWQABQAIAEUAAABnbF9HbG9iYWxJbnZvY2F0aW9uSUQAAAAFAAQAXQAAAGdyaWRYWQAABQAEAF4AAABwYXJhbQAAAAUAAwBhAAAAdXYAAAUABQBtAAAAaW5wdXRJbWFnZQAARwAEACgAAAAhAAAAAQAAAEcABAAoAAAAIgAAAAAAAABHAAMAPwAAABkAAABHAAQAPwAAACEAAAACAAAARwAEAD8AAAAiAAAAAAAAAEcABABFAAAACwAAABwAAABHAAQAbQAAACEAAAAAAAAARwAEAG0AAAAiAAAAAAAAAEcABAB0AAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAFQAEAA8AAAAgAAAAAAAAACsABAAPAAAAEAAAAAQAAAAcAAQAEQAAAA8AAAAQAAAAIAAEABIAAAAHAAAAEQAAACsABAAPAAAAFAAAAAAAAAArAAQADwAAABUAAAABAAAAIAAEABYAAAAHAAAABgAAAEMQBAAmAAAACQAAABAAAAAgAAQAJwAAAAAAAAAmAAAAOwAEACcAAAAoAAAAAAAAABwABAArAAAACQAAABUAAAAgAAQALAAAAAcAAAArAAAAKwAEAAYAAAAzAAAAAAAAACAABAA0AAAABwAAAAkAAAAZAAkAPQAAAAkAAAABAAAAAAAAAAAAAAAAAAAAAgAAAAUAAAAgAAQAPgAAAAAAAAA9AAAAOwAEAD4AAAA/AAAAAAAAABcABABDAAAADwAAAAMAAAAgAAQARAAAAAEAAABDAAAAOwAEAEQAAABFAAAAAQAAABcABABGAAAADwAAAAIAAAAUAAIASgAAACAABABcAAAABwAAAAoAAAArAAQACQAAAGMAAAAAAIA/LAAFAAoAAABkAAAAYwAAAGMAAAArAAQACQAAAGYAAAAAAAA/GQAJAGoAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAGwADAGsAAABqAAAAIAAEAGwAAAAAAAAAawAAADsABABsAAAAbQAAAAAAAAAXAAQAcAAAAAkAAAAEAAAAKwAEAAkAAABxAAAAAAAAACsABAAPAAAAcwAAAAgAAAAsAAYAQwAAAHQAAABzAAAAcwAAABUAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAAA8AAAABwAAADsABAAIAAAAQgAAAAcAAAA7AAQAXAAAAF0AAAAHAAAAOwAEAAgAAABeAAAABwAAADsABABcAAAAYQAAAAcAAAA9AAQAPQAAAEAAAAA/AAAAaAAEAAcAAABBAAAAQAAAAD4AAwA8AAAAQQAAAD0ABABDAAAARwAAAEUAAABPAAcARgAAAEgAAABHAAAARwAAAAAAAAABAAAAfAAEAAcAAABJAAAASAAAAD4AAwBCAAAASQAAAEEABQAWAAAASwAAAEIAAAAUAAAAPQAEAAYAAABMAAAASwAAAEEABQAWAAAATQAAADwAAAAUAAAAPQAEAAYAAABOAAAATQAAAK8ABQBKAAAATwAAAEwAAABOAAAAqAAEAEoAAABQAAAATwAAAPcAAwBSAAAAAAAAAPoABABQAAAAUQAAAFIAAAD4AAIAUQAAAEEABQAWAAAAUwAAAEIAAAAVAAAAPQAEAAYAAABUAAAAUwAAAEEABQAWAAAAVQAAADwAAAAVAAAAPQAEAAYAAABWAAAAVQAAAK8ABQBKAAAAVwAAAFQAAABWAAAA+QACAFIAAAD4AAIAUgAAAPUABwBKAAAAWAAAAE8AAAAFAAAAVwAAAFEAAAD3AAMAWgAAAAAAAAD6AAQAWAAAAFkAAABaAAAA+AACAFkAAAD9AAEA+AACAFoAAAA9AAQABwAAAF8AAABCAAAAPgADAF4AAABfAAAAOQAFAAoAAABgAAAADQAAAF4AAAA+AAMAXQAAAGAAAAA9AAQACgAAAGIAAABdAAAAgQAFAAoAAABlAAAAYgAAAGQAAACOAAUACgAAAGcAAABlAAAAZgAAAD4AAwBhAAAAZwAAAD0ABAA9AAAAaAAAAD8AAAA9AAQABwAAAGkAAABCAAAAPQAEAGsAAABuAAAAbQAAAD0ABAAKAAAAbwAAAGEAAABYAAcAcAAAAHIAAABuAAAAbwAAAAIAAABxAAAAYwAEAGgAAABpAAAAcgAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAASAAAAEwAAAAcAAAA7AAQAEgAAAB4AAAAHAAAAOwAEACwAAAAtAAAABwAAADsABAAsAAAAMQAAAAcAAABBAAUAFgAAABcAAAAMAAAAFQAAAD0ABAAGAAAAGAAAABcAAAB8AAQADwAAABkAAAAYAAAAQQAFABYAAAAaAAAADAAAABQAAAA9AAQABgAAABsAAAAaAAAAfAAEAA8AAAAcAAAAGwAAAFAABwARAAAAHQAAABQAAAAZAAAAHAAAABQAAAA+AAMAEwAAAB0AAABBAAUAFgAAAB8AAAAMAAAAFQAAAD0ABAAGAAAAIAAAAB8AAAB8AAQADwAAACEAAAAgAAAAQQAFABYAAAAiAAAADAAAABQAAAA9AAQABgAAACMAAAAiAAAAfAAEAA8AAAAkAAAAIwAAAFAABwARAAAAJQAAABQAAAAhAAAAJAAAABUAAAA+AAMAHgAAACUAAAA9AAQAJgAAACkAAAAoAAAAPQAEABEAAAAqAAAAEwAAAEQQBQArAAAALgAAACkAAAAqAAAAPgADAC0AAAAuAAAAPQAEACYAAAAvAAAAKAAAAD0ABAARAAAAMAAAAB4AAABEEAUAKwAAADIAAAAvAAAAMAAAAD4AAwAxAAAAMgAAAEEABQA0AAAANQAAAC0AAAAzAAAAPQAEAAkAAAA2AAAANQAAAEEABQA0AAAANwAAADEAAAAzAAAAPQAEAAkAAAA4AAAANwAAAFAABQAKAAAAOQAAADYAAAA4AAAA/gACADkAAAA4AAEA diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl new file mode 100644 index 00000000000..b0aa9d303fc --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl @@ -0,0 +1,40 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba8_snorm) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +vec2 alignCornersUv(vec2 gridXY) { + vec2 inputSize = vec2(textureSize(inputImage, 0)); + vec2 texel = (gridXY + vec2(1.0)) * vec2(0.5) * (inputSize - vec2(1.0)); + return (texel + vec2(0.5)) / inputSize; +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + imageStore(outImage, gid, texture(inputImage, alignCornersUv(gridXY))); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 new file mode 100644 index 00000000000..cd15b7a8264 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AigAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAAZAAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABwASAAAAYWxpZ25Db3JuZXJzVXYodmYyOwAFAAQAEQAAAGdyaWRYWQAABQAEABgAAAB4Q29vcmRzAAUABAAjAAAAeUNvb3JkcwAFAAQALQAAAGdyaWQAAAAABQAEADIAAAB4VmFsAAAAAAUABAA2AAAAeVZhbAAAAAAFAAUAQQAAAGlucHV0U2l6ZQAAAAUABQBFAAAAaW5wdXRJbWFnZQAABQAEAEoAAAB0ZXhlbAAAAAUABABbAAAAb3V0U2l6ZQAFAAUAXgAAAG91dEltYWdlAAAAAAUAAwBhAAAAZ2lkAAUACABkAAAAZ2xfR2xvYmFsSW52b2NhdGlvbklEAAAABQAEAHsAAABncmlkWFkAAAUABAB8AAAAcGFyYW0AAAAFAAQAggAAAHBhcmFtAAAARwAEAC0AAAAhAAAAAQAAAEcABAAtAAAAIgAAAAAAAABHAAQARQAAACEAAAAAAAAARwAEAEUAAAAiAAAAAAAAAEcAAwBeAAAAGQAAAEcABABeAAAAIQAAAAIAAABHAAQAXgAAACIAAAAAAAAARwAEAGQAAAALAAAAHAAAAEcABACJAAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAIAAEAA8AAAAHAAAACgAAACEABAAQAAAACgAAAA8AAAAVAAQAFAAAACAAAAAAAAAAKwAEABQAAAAVAAAABAAAABwABAAWAAAAFAAAABUAAAAgAAQAFwAAAAcAAAAWAAAAKwAEABQAAAAZAAAAAAAAACsABAAUAAAAGgAAAAEAAAAgAAQAGwAAAAcAAAAGAAAAQxAEACsAAAAJAAAAFQAAACAABAAsAAAAAAAAACsAAAA7AAQALAAAAC0AAAAAAAAAHAAEADAAAAAJAAAAGgAAACAABAAxAAAABwAAADAAAAArAAQABgAAADgAAAAAAAAAIAAEADkAAAAHAAAACQAAABkACQBCAAAACQAAAAEAAAAAAAAAAAAAAAAAAAABAAAAAAAAABsAAwBDAAAAQgAAACAABABEAAAAAAAAAEMAAAA7AAQARAAAAEUAAAAAAAAAKwAEAAkAAABMAAAAAACAPywABQAKAAAATQAAAEwAAABMAAAAKwAEAAkAAABPAAAAAAAAPywABQAKAAAAUAAAAE8AAABPAAAAGQAJAFwAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAIAAAAFAAAAIAAEAF0AAAAAAAAAXAAAADsABABdAAAAXgAAAAAAAAAXAAQAYgAAABQAAAADAAAAIAAEAGMAAAABAAAAYgAAADsABABjAAAAZAAAAAEAAAAXAAQAZQAAABQAAAACAAAAFAACAGkAAAAXAAQAhQAAAAkAAAAEAAAAKwAEAAkAAACGAAAAAAAAACsABAAUAAAAiAAAAAgAAAAsAAYAYgAAAIkAAACIAAAAiAAAABoAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAABbAAAABwAAADsABAAIAAAAYQAAAAcAAAA7AAQADwAAAHsAAAAHAAAAOwAEAAgAAAB8AAAABwAAADsABAAPAAAAggAAAAcAAAA9AAQAXAAAAF8AAABeAAAAaAAEAAcAAABgAAAAXwAAAD4AAwBbAAAAYAAAAD0ABABiAAAAZgAAAGQAAABPAAcAZQAAAGcAAABmAAAAZgAAAAAAAAABAAAAfAAEAAcAAABoAAAAZwAAAD4AAwBhAAAAaAAAAEEABQAbAAAAagAAAGEAAAAZAAAAPQAEAAYAAABrAAAAagAAAEEABQAbAAAAbAAAAFsAAAAZAAAAPQAEAAYAAABtAAAAbAAAAK8ABQBpAAAAbgAAAGsAAABtAAAAqAAEAGkAAABvAAAAbgAAAPcAAwBxAAAAAAAAAPoABABvAAAAcAAAAHEAAAD4AAIAcAAAAEEABQAbAAAAcgAAAGEAAAAaAAAAPQAEAAYAAABzAAAAcgAAAEEABQAbAAAAdAAAAFsAAAAaAAAAPQAEAAYAAAB1AAAAdAAAAK8ABQBpAAAAdgAAAHMAAAB1AAAA+QACAHEAAAD4AAIAcQAAAPUABwBpAAAAdwAAAG4AAAAFAAAAdgAAAHAAAAD3AAMAeQAAAAAAAAD6AAQAdwAAAHgAAAB5AAAA+AACAHgAAAD9AAEA+AACAHkAAAA9AAQABwAAAH0AAABhAAAAPgADAHwAAAB9AAAAOQAFAAoAAAB+AAAADQAAAHwAAAA+AAMAewAAAH4AAAA9AAQAXAAAAH8AAABeAAAAPQAEAAcAAACAAAAAYQAAAD0ABABDAAAAgQAAAEUAAAA9AAQACgAAAIMAAAB7AAAAPgADAIIAAACDAAAAOQAFAAoAAACEAAAAEgAAAIIAAABYAAcAhQAAAIcAAACBAAAAhAAAAAIAAACGAAAAYwAEAH8AAACAAAAAhwAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAAXAAAAGAAAAAcAAAA7AAQAFwAAACMAAAAHAAAAOwAEADEAAAAyAAAABwAAADsABAAxAAAANgAAAAcAAABBAAUAGwAAABwAAAAMAAAAGgAAAD0ABAAGAAAAHQAAABwAAAB8AAQAFAAAAB4AAAAdAAAAQQAFABsAAAAfAAAADAAAABkAAAA9AAQABgAAACAAAAAfAAAAfAAEABQAAAAhAAAAIAAAAFAABwAWAAAAIgAAABkAAAAeAAAAIQAAABkAAAA+AAMAGAAAACIAAABBAAUAGwAAACQAAAAMAAAAGgAAAD0ABAAGAAAAJQAAACQAAAB8AAQAFAAAACYAAAAlAAAAQQAFABsAAAAnAAAADAAAABkAAAA9AAQABgAAACgAAAAnAAAAfAAEABQAAAApAAAAKAAAAFAABwAWAAAAKgAAABkAAAAmAAAAKQAAABoAAAA+AAMAIwAAACoAAAA9AAQAKwAAAC4AAAAtAAAAPQAEABYAAAAvAAAAGAAAAEQQBQAwAAAAMwAAAC4AAAAvAAAAPgADADIAAAAzAAAAPQAEACsAAAA0AAAALQAAAD0ABAAWAAAANQAAACMAAABEEAUAMAAAADcAAAA0AAAANQAAAD4AAwA2AAAANwAAAEEABQA5AAAAOgAAADIAAAA4AAAAPQAEAAkAAAA7AAAAOgAAAEEABQA5AAAAPAAAADYAAAA4AAAAPQAEAAkAAAA9AAAAPAAAAFAABQAKAAAAPgAAADsAAAA9AAAA/gACAD4AAAA4AAEANgAFAAoAAAASAAAAAAAAABAAAAA3AAMADwAAABEAAAD4AAIAEwAAADsABAAPAAAAQQAAAAcAAAA7AAQADwAAAEoAAAAHAAAAPQAEAEMAAABGAAAARQAAAGQABABCAAAARwAAAEYAAABnAAUABwAAAEgAAABHAAAAOAAAAG8ABAAKAAAASQAAAEgAAAA+AAMAQQAAAEkAAAA9AAQACgAAAEsAAAARAAAAgQAFAAoAAABOAAAASwAAAE0AAACFAAUACgAAAFEAAABOAAAAUAAAAD0ABAAKAAAAUgAAAEEAAACDAAUACgAAAFMAAABSAAAATQAAAIUABQAKAAAAVAAAAFEAAABTAAAAPgADAEoAAABUAAAAPQAEAAoAAABVAAAASgAAAIEABQAKAAAAVgAAAFUAAABQAAAAPQAEAAoAAABXAAAAQQAAAIgABQAKAAAAWAAAAFYAAABXAAAA/gACAFgAAAA4AAEA