From 8d3e8010672b85a471098309bb75d66fc4519fa9 Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Fri, 12 Jun 2026 11:10:44 +0100 Subject: [PATCH 1/4] Arm backend: Use sampler path for VGF grid sampler Enable the VGF grid_sampler custom payload to select an image sampler shader for NCHW float32 C4 inputs, while keeping the existing storage-buffer shader for unsupported layouts. Pad static C3 inputs to C4 in the rewrite pass so RIFE feature warps can use texture samplers, then slice outputs back to C3. Add shader resources and tests covering sampler and fallback paths. Signed-off-by: Baris Demir Change-Id: I8737436d2b5920e74804103dae2a609ce2d24183 --- backends/arm/TARGETS | 2 + .../scripts/generate_grid_sampler_spirv.py | 8 +- .../test/misc/test_custom_shader_payload.py | 87 ++++++++++++ ...ewrite_grid_sampler_to_tosa_custom_pass.py | 97 ++++++++++++- .../rewrite_grid_sampler_to_tosa_custom.py | 131 ++++++++++++++++-- backends/arm/vgf/shaders/grid_sampler.py | 95 +++++++++++-- .../arm/vgf/shaders/grid_sampler_sampler.glsl | 35 +++++ .../shaders/grid_sampler_sampler.spirv.b64 | 1 + 8 files changed, 428 insertions(+), 28 deletions(-) create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler.glsl create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index fcf95653438..5c66d713e5a 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -100,6 +100,8 @@ runtime.python_library( resources = [ "vgf/shaders/grid_sampler.glsl", "vgf/shaders/grid_sampler.spirv.b64", + "vgf/shaders/grid_sampler_sampler.glsl", + "vgf/shaders/grid_sampler_sampler.spirv.b64", ], deps = [ ":arm_compile_spec", diff --git a/backends/arm/scripts/generate_grid_sampler_spirv.py b/backends/arm/scripts/generate_grid_sampler_spirv.py index f8956a86cda..d5ab01c214c 100644 --- a/backends/arm/scripts/generate_grid_sampler_spirv.py +++ b/backends/arm/scripts/generate_grid_sampler_spirv.py @@ -65,7 +65,13 @@ def main() -> None: with tempfile.TemporaryDirectory() as tmpdir: spirv_path = Path(tmpdir) / "grid_sampler.spirv" subprocess.run( # nosec B603 - glslc path is resolved explicitly. - [glslc, str(args.source), "-o", str(spirv_path)], + [ + glslc, + "-fshader-stage=compute", + str(args.source), + "-o", + str(spirv_path), + ], check=True, ) _write_base64_spirv(spirv_path, args.output) diff --git a/backends/arm/test/misc/test_custom_shader_payload.py b/backends/arm/test/misc/test_custom_shader_payload.py index 6243e8752ba..e9f0cc431ce 100644 --- a/backends/arm/test/misc/test_custom_shader_payload.py +++ b/backends/arm/test/misc/test_custom_shader_payload.py @@ -6,10 +6,14 @@ import base64 import pytest +import torch from executorch.backends.arm.vgf.shaders.grid_sampler import ( build_grid_sampler_2d_payload, decode_payload, encode_payload, + GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, GRID_SAMPLER_2D_SHADER_BINARY, GRID_SAMPLER_2D_SHADER_ENTRY_POINT, GRID_SAMPLER_2D_SHADER_LANGUAGE, @@ -45,6 +49,87 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip(): assert decoded["output_0_binding"] == 2 +def test_grid_sampler_2d_custom_shader_payload_no_target_uses_sampler_for_c4(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert base64.b64decode(payload["shader_code"])[:4] == b"\x03\x02\x23\x07" + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["input_0_sampler"] == { + "address_mode_u": "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER", + "address_mode_v": "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER", + "border_color": "VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK", + "mag_filter": "VK_FILTER_LINEAR", + "min_filter": "VK_FILTER_LINEAR", + } + + +def test_grid_sampler_2d_custom_shader_payload_no_target_keeps_c3_on_buffer(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 3, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert payload["input_0_type"] == "Tensor" + assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert payload["output_0_type"] == "Tensor" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert "input_0_sampler" not in payload + + +def test_grid_sampler_2d_custom_shader_payload_no_target_align_corners_buffer(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=True, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["input_0_type"] == "Tensor" + assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert payload["output_0_type"] == "Tensor" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert "input_0_sampler" not in payload + + +def test_grid_sampler_2d_custom_shader_payload_no_target_bicubic_buffer(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=2, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.float32, + ) + + assert payload["input_0_type"] == "Tensor" + assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert payload["output_0_type"] == "Tensor" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" + assert "input_0_sampler" not in payload + + def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, @@ -61,6 +146,8 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv(): def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources(): assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl" assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64" + assert GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE == "grid_sampler_sampler.glsl" + assert GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY == "grid_sampler_sampler.spirv.b64" def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes(): diff --git a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py index ec7773dfdbc..7fe2644f1e7 100644 --- a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py +++ b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py @@ -17,6 +17,7 @@ CUSTOM_SHADER_DOMAIN_NAME, decode_payload, grid_sampler_2d_operator_name, + GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, GRID_SAMPLER_2D_SHADER_ENTRY_POINT, GRID_SAMPLER_2D_SHADER_LANGUAGE, GRID_SAMPLER_2D_VK_FORMAT, @@ -35,10 +36,11 @@ def __init__(self): self.align_corners_ = False def forward(self, x, grid): + mode = ("bilinear", "nearest", "bicubic")[self.interpolation_mode_] return F.grid_sample( x, grid, - mode="bilinear" if self.interpolation_mode_ == 0 else "nearest", + mode=mode, padding_mode="zeros" if self.padding_mode_ == 0 else "border", align_corners=self.align_corners_, ) @@ -80,15 +82,100 @@ def test_rewrite_grid_sampler_to_tosa_custom_vgf_no_target(): assert payload["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT assert payload["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE - assert payload["input_0_type"] == "Tensor" - assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) assert payload["input_0_binding"] == 0 assert payload["input_0_descriptorset"] == 0 assert payload["input_1_type"] == "Tensor" assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" assert payload["input_1_binding"] == 1 assert payload["input_1_descriptorset"] == 0 - assert payload["output_0_type"] == "Tensor" - assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" assert payload["output_0_binding"] == 2 assert payload["output_0_descriptorset"] == 0 + assert any(node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes) + + +def test_rewrite_grid_sampler_to_tosa_custom_no_target_uses_sampler_for_c4(): + model = GridSampler2d() + example_inputs = ( + torch.randn(1, 4, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + + +def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_align_corners(): + model = GridSampler2d() + model.align_corners_ = True + example_inputs = ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Tensor" + assert not any(node.target == exir_ops.edge.aten.cat.default for node in nodes) + assert not any( + node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes + ) + + +def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_bicubic(): + model = GridSampler2d() + model.interpolation_mode_ = 2 + example_inputs = ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 4, 4, 2), + ) + + edge_model = to_edge(export(model, example_inputs)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()]) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Tensor" + assert not any(node.target == exir_ops.edge.aten.cat.default for node in nodes) + assert not any( + node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes + ) diff --git a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py index 9d4f17dc936..32c899c68db 100644 --- a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py +++ b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py @@ -84,6 +84,32 @@ def _set_fake_tensor_meta(node: torch.fx.Node, value) -> None: node.meta["tensor_meta"] = _extract_tensor_metadata(value) +def _is_static_nchw_with_channels(node: torch.fx.Node, channels: int) -> bool: + value = node.meta.get("val") + return ( + isinstance(value, torch.Tensor) + and len(value.shape) == 4 + and int(value.shape[1]) == channels + ) + + +def _can_pad_c3_for_sampler( + input_tensor: torch.fx.Node, + interpolation_mode: int, + align_corners: bool, +) -> bool: + value = input_tensor.meta.get("val") + return ( + isinstance(value, torch.Tensor) + and len(value.shape) == 4 + and int(value.shape[0]) == 1 + and int(value.shape[1]) == 3 + and value.dtype is torch.float32 + and int(interpolation_mode) in (0, 1) + and not bool(align_corners) + ) + + class RewriteGridSamplerToTosaCustomPass(ArmPass): """Rewrite ``aten.grid_sampler_2d`` nodes to ``tosa.CUSTOM``.""" @@ -92,15 +118,82 @@ class RewriteGridSamplerToTosaCustomPass(ArmPass): @staticmethod def _encode_payload( - interpolation_mode: int, padding_mode: int, align_corners: bool + interpolation_mode: int, + padding_mode: int, + align_corners: bool, + input_tensor: torch.fx.Node, ) -> list[int]: + input_val = input_tensor.meta.get("val") + if input_val is None: + raise RuntimeError("grid_sampler_2d input is missing tensor metadata") payload = build_grid_sampler_2d_payload( interpolation_mode=interpolation_mode, padding_mode=padding_mode, align_corners=align_corners, + input_shape=tuple(input_val.shape), + input_dtype=input_val.dtype, ) return encode_payload(payload) + @staticmethod + def _pad_c3_input_to_c4( + graph_module: torch.fx.GraphModule, + input_tensor: torch.fx.Node, + ) -> torch.fx.Node: + input_val = input_tensor.meta["val"] + first_channel = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=(input_tensor, 1, 0, 1, 1), + from_node=input_tensor, + ) + first_channel_val = exir_ops.edge.aten.slice_copy.Tensor(input_val, 1, 0, 1, 1) + _set_fake_tensor_meta(first_channel, first_channel_val) + + zero_channel = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.sub.Tensor, + args=(first_channel, first_channel), + kwargs={"alpha": 1}, + from_node=input_tensor, + ) + _set_fake_tensor_meta( + zero_channel, + exir_ops.edge.aten.sub.Tensor(first_channel_val, first_channel_val), + ) + + padded_input = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.cat.default, + args=([input_tensor, zero_channel], 1), + from_node=input_tensor, + ) + _set_fake_tensor_meta( + padded_input, + exir_ops.edge.aten.cat.default([input_val, zero_channel.meta["val"]], 1), + ) + return padded_input + + @staticmethod + def _slice_c4_output_to_c3( + graph_module: torch.fx.GraphModule, + output: torch.fx.Node, + original_node: torch.fx.Node, + ) -> torch.fx.Node: + output_val = output.meta["val"] + sliced_output = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.slice_copy.Tensor, + args=(output, 1, 0, 3, 1), + from_node=original_node, + ) + sliced_output.meta = dict(original_node.meta) + _set_fake_tensor_meta( + sliced_output, + exir_ops.edge.aten.slice_copy.Tensor(output_val, 1, 0, 3, 1), + ) + return sliced_output + def call(self, graph_module): modified = False for node in graph_module.graph.nodes: @@ -114,12 +207,12 @@ def call(self, graph_module): input_tensor, grid, interpolation_mode, padding_mode, align_corners = ( node.args ) - - implementation_attrs = self._encode_payload( - interpolation_mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners, + pad_c3_for_sampler = _can_pad_c3_for_sampler( + input_tensor, + interpolation_mode, + align_corners, ) + operator_name = grid_sampler_2d_operator_name( interpolation_mode=interpolation_mode, padding_mode=padding_mode, @@ -127,16 +220,27 @@ def call(self, graph_module): ) with graph_module.graph.inserting_before(node): + custom_input = ( + self._pad_c3_input_to_c4(graph_module, input_tensor) + if pad_c3_for_sampler + else input_tensor + ) + implementation_attrs = self._encode_payload( + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + input_tensor=custom_input, + ) nhwc_input = create_node( graph_module.graph, op_target=exir_ops.edge.aten.permute_copy.default, - args=(input_tensor, list(NHWC_ORDER)), - from_node=input_tensor, + args=(custom_input, list(NHWC_ORDER)), + from_node=custom_input, ) _set_fake_tensor_meta( nhwc_input, exir_ops.edge.aten.permute_copy.default( - input_tensor.meta["val"], list(NHWC_ORDER) + custom_input.meta["val"], list(NHWC_ORDER) ), ) @@ -184,7 +288,14 @@ def call(self, graph_module): custom_output, list(NHWC_INVERSE_ORDER) ), ) - node.replace_all_uses_with(output) + if pad_c3_for_sampler: + with graph_module.graph.inserting_after(output): + replacement = self._slice_c4_output_to_c3( + graph_module, output, node + ) + else: + replacement = output + node.replace_all_uses_with(replacement) graph_module.graph.erase_node(node) if modified: diff --git a/backends/arm/vgf/shaders/grid_sampler.py b/backends/arm/vgf/shaders/grid_sampler.py index 800a4ec0013..12aa86e6d35 100644 --- a/backends/arm/vgf/shaders/grid_sampler.py +++ b/backends/arm/vgf/shaders/grid_sampler.py @@ -15,6 +15,9 @@ GRID_SAMPLER_2D_VK_FORMAT = "VK_FORMAT_R32_SFLOAT" GRID_SAMPLER_2D_SHADER_SOURCE = "grid_sampler.glsl" GRID_SAMPLER_2D_SHADER_BINARY = "grid_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE = "grid_sampler_sampler.glsl" +GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY = "grid_sampler_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_VK_FORMAT = "VK_FORMAT_R32G32B32A32_SFLOAT" _INTERPOLATION_MODE_NAMES = { 0: "bilinear", @@ -67,6 +70,8 @@ def build_grid_sampler_2d_payload( interpolation_mode: int, padding_mode: int, align_corners: bool, + input_shape: tuple[int, ...] | None = None, + input_dtype: Any | None = None, ) -> dict[str, Any]: _mode_name( int(interpolation_mode), @@ -78,34 +83,100 @@ def build_grid_sampler_2d_payload( _PADDING_MODE_NAMES, "padding_mode", ) + use_sampler = ( + input_shape is not None + and len(input_shape) == 4 + and int(input_shape[0]) == 1 + and int(input_shape[1]) == 4 + and str(input_dtype) == "torch.float32" + and int(interpolation_mode) in (0, 1) + and not bool(align_corners) + ) + shader_file = ( + GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY + if use_sampler + else GRID_SAMPLER_2D_SHADER_BINARY + ) shader_code = "".join( - files(__package__) - .joinpath(GRID_SAMPLER_2D_SHADER_BINARY) - .read_text(encoding="utf-8") - .split() + files(__package__).joinpath(shader_file).read_text(encoding="utf-8").split() ) - return { + payload = { "entry_point": GRID_SAMPLER_2D_SHADER_ENTRY_POINT, "workgroup_sizes": GRID_SAMPLER_2D_WORKGROUP_SIZES, "shader_language": GRID_SAMPLER_2D_SHADER_LANGUAGE, "shader_code": shader_code, - "input_0_type": "Tensor", - "input_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "input_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "input_0_binding": 0, "input_0_descriptorset": 0, "input_1_type": "Tensor", "input_1_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "input_1_binding": 1, "input_1_descriptorset": 0, - "output_0_type": "Tensor", - "output_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, - "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", "output_0_binding": 2, "output_0_descriptorset": 0, } + if use_sampler: + payload.update( + { + "input_0_type": "Image", + "input_0_vkformat": GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, + "input_0_vkdescriptortype": ( + "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ), + "input_0_sampler": _sampler_config( + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + ), + "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_TENSOR_ARM", + "output_0_type": "Image", + "output_0_vkformat": GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, + "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", + } + ) + else: + payload.update( + { + "input_0_type": "Tensor", + "input_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, + "input_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + "output_0_type": "Tensor", + "output_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT, + "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", + } + ) + return payload + + +def _sampler_config(interpolation_mode: int, padding_mode: int) -> dict[str, str]: + interpolation = _mode_name( + int(interpolation_mode), + _INTERPOLATION_MODE_NAMES, + "interpolation_mode", + ) + padding = _mode_name( + int(padding_mode), + _PADDING_MODE_NAMES, + "padding_mode", + ) + + filter_mode = ( + "VK_FILTER_NEAREST" if interpolation == "nearest" else "VK_FILTER_LINEAR" + ) + if padding == "zeros": + address_mode = "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_BORDER" + elif padding == "border": + address_mode = "VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE" + else: + address_mode = "VK_SAMPLER_ADDRESS_MODE_MIRRORED_REPEAT" + + return { + "min_filter": filter_mode, + "mag_filter": filter_mode, + "address_mode_u": address_mode, + "address_mode_v": address_mode, + "border_color": "VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK", + } def encode_payload(payload: dict[str, Any]) -> list[int]: diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler.glsl new file mode 100644 index 00000000000..69fcbd3d59f --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler.glsl @@ -0,0 +1,35 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba32f) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + vec2 uv = (gridXY + vec2(1.0)) * 0.5; + imageStore(outImage, gid, texture(inputImage, uv)); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 new file mode 100644 index 00000000000..e460cbe8621 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AdQAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAARQAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABAATAAAAeENvb3JkcwAFAAQAHgAAAHlDb29yZHMABQAEACgAAABncmlkAAAAAAUABAAtAAAAeFZhbAAAAAAFAAQAMQAAAHlWYWwAAAAABQAEADwAAABvdXRTaXplAAUABQA/AAAAb3V0SW1hZ2UAAAAABQADAEIAAABnaWQABQAIAEUAAABnbF9HbG9iYWxJbnZvY2F0aW9uSUQAAAAFAAQAXQAAAGdyaWRYWQAABQAEAF4AAABwYXJhbQAAAAUAAwBhAAAAdXYAAAUABQBtAAAAaW5wdXRJbWFnZQAARwAEACgAAAAhAAAAAQAAAEcABAAoAAAAIgAAAAAAAABHAAMAPwAAABkAAABHAAQAPwAAACEAAAACAAAARwAEAD8AAAAiAAAAAAAAAEcABABFAAAACwAAABwAAABHAAQAbQAAACEAAAAAAAAARwAEAG0AAAAiAAAAAAAAAEcABAB0AAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAFQAEAA8AAAAgAAAAAAAAACsABAAPAAAAEAAAAAQAAAAcAAQAEQAAAA8AAAAQAAAAIAAEABIAAAAHAAAAEQAAACsABAAPAAAAFAAAAAAAAAArAAQADwAAABUAAAABAAAAIAAEABYAAAAHAAAABgAAAEMQBAAmAAAACQAAABAAAAAgAAQAJwAAAAAAAAAmAAAAOwAEACcAAAAoAAAAAAAAABwABAArAAAACQAAABUAAAAgAAQALAAAAAcAAAArAAAAKwAEAAYAAAAzAAAAAAAAACAABAA0AAAABwAAAAkAAAAZAAkAPQAAAAkAAAABAAAAAAAAAAAAAAAAAAAAAgAAAAEAAAAgAAQAPgAAAAAAAAA9AAAAOwAEAD4AAAA/AAAAAAAAABcABABDAAAADwAAAAMAAAAgAAQARAAAAAEAAABDAAAAOwAEAEQAAABFAAAAAQAAABcABABGAAAADwAAAAIAAAAUAAIASgAAACAABABcAAAABwAAAAoAAAArAAQACQAAAGMAAAAAAIA/LAAFAAoAAABkAAAAYwAAAGMAAAArAAQACQAAAGYAAAAAAAA/GQAJAGoAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAGwADAGsAAABqAAAAIAAEAGwAAAAAAAAAawAAADsABABsAAAAbQAAAAAAAAAXAAQAcAAAAAkAAAAEAAAAKwAEAAkAAABxAAAAAAAAACsABAAPAAAAcwAAAAgAAAAsAAYAQwAAAHQAAABzAAAAcwAAABUAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAAA8AAAABwAAADsABAAIAAAAQgAAAAcAAAA7AAQAXAAAAF0AAAAHAAAAOwAEAAgAAABeAAAABwAAADsABABcAAAAYQAAAAcAAAA9AAQAPQAAAEAAAAA/AAAAaAAEAAcAAABBAAAAQAAAAD4AAwA8AAAAQQAAAD0ABABDAAAARwAAAEUAAABPAAcARgAAAEgAAABHAAAARwAAAAAAAAABAAAAfAAEAAcAAABJAAAASAAAAD4AAwBCAAAASQAAAEEABQAWAAAASwAAAEIAAAAUAAAAPQAEAAYAAABMAAAASwAAAEEABQAWAAAATQAAADwAAAAUAAAAPQAEAAYAAABOAAAATQAAAK8ABQBKAAAATwAAAEwAAABOAAAAqAAEAEoAAABQAAAATwAAAPcAAwBSAAAAAAAAAPoABABQAAAAUQAAAFIAAAD4AAIAUQAAAEEABQAWAAAAUwAAAEIAAAAVAAAAPQAEAAYAAABUAAAAUwAAAEEABQAWAAAAVQAAADwAAAAVAAAAPQAEAAYAAABWAAAAVQAAAK8ABQBKAAAAVwAAAFQAAABWAAAA+QACAFIAAAD4AAIAUgAAAPUABwBKAAAAWAAAAE8AAAAFAAAAVwAAAFEAAAD3AAMAWgAAAAAAAAD6AAQAWAAAAFkAAABaAAAA+AACAFkAAAD9AAEA+AACAFoAAAA9AAQABwAAAF8AAABCAAAAPgADAF4AAABfAAAAOQAFAAoAAABgAAAADQAAAF4AAAA+AAMAXQAAAGAAAAA9AAQACgAAAGIAAABdAAAAgQAFAAoAAABlAAAAYgAAAGQAAACOAAUACgAAAGcAAABlAAAAZgAAAD4AAwBhAAAAZwAAAD0ABAA9AAAAaAAAAD8AAAA9AAQABwAAAGkAAABCAAAAPQAEAGsAAABuAAAAbQAAAD0ABAAKAAAAbwAAAGEAAABYAAcAcAAAAHIAAABuAAAAbwAAAAIAAABxAAAAYwAEAGgAAABpAAAAcgAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAASAAAAEwAAAAcAAAA7AAQAEgAAAB4AAAAHAAAAOwAEACwAAAAtAAAABwAAADsABAAsAAAAMQAAAAcAAABBAAUAFgAAABcAAAAMAAAAFQAAAD0ABAAGAAAAGAAAABcAAAB8AAQADwAAABkAAAAYAAAAQQAFABYAAAAaAAAADAAAABQAAAA9AAQABgAAABsAAAAaAAAAfAAEAA8AAAAcAAAAGwAAAFAABwARAAAAHQAAABQAAAAZAAAAHAAAABQAAAA+AAMAEwAAAB0AAABBAAUAFgAAAB8AAAAMAAAAFQAAAD0ABAAGAAAAIAAAAB8AAAB8AAQADwAAACEAAAAgAAAAQQAFABYAAAAiAAAADAAAABQAAAA9AAQABgAAACMAAAAiAAAAfAAEAA8AAAAkAAAAIwAAAFAABwARAAAAJQAAABQAAAAhAAAAJAAAABUAAAA+AAMAHgAAACUAAAA9AAQAJgAAACkAAAAoAAAAPQAEABEAAAAqAAAAEwAAAEQQBQArAAAALgAAACkAAAAqAAAAPgADAC0AAAAuAAAAPQAEACYAAAAvAAAAKAAAAD0ABAARAAAAMAAAAB4AAABEEAUAKwAAADIAAAAvAAAAMAAAAD4AAwAxAAAAMgAAAEEABQA0AAAANQAAAC0AAAAzAAAAPQAEAAkAAAA2AAAANQAAAEEABQA0AAAANwAAADEAAAAzAAAAPQAEAAkAAAA4AAAANwAAAFAABQAKAAAAOQAAADYAAAA4AAAA/gACADkAAAA4AAEA From 056d58dffefceace67448ec7dfbdd571f41f612e Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Fri, 12 Jun 2026 12:30:55 +0100 Subject: [PATCH 2/4] Arm backend: Allow VGF delegate count checks Allow VgfPipeline callers to set the expected number of executorch_call_delegate nodes while keeping the default at one. This is a test-framework-only change for multi-delegate VGF tests. It does not affect Arm backend lowering or runtime behavior. Signed-off-by: Baris Demir Change-Id: Iac60bb75bd40cab424ccb9841a2d02ee3767f11c --- .../arm/test/misc/test_multiple_delegates.py | 20 ++++++++++++++++++- backends/arm/test/tester/test_pipeline.py | 7 +++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/misc/test_multiple_delegates.py b/backends/arm/test/misc/test_multiple_delegates.py index bbc0b2b1bce..e8a929a4e1f 100644 --- a/backends/arm/test/misc/test_multiple_delegates.py +++ b/backends/arm/test/misc/test_multiple_delegates.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ from executorch.backends.arm.test.tester.test_pipeline import ( TosaPipelineFP, TosaPipelineINT, + VgfPipeline, ) @@ -51,3 +52,20 @@ def test_multiple_delegates_tosa_INT(test_data: input_t1): "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2} ) pipeline.run() + + +@common.parametrize("test_data", MultipleDelegatesModule.inputs) +@common.SkipIfNoModelConverter +def test_multiple_delegates_vgf_INT(test_data: input_t1): + aten_ops: list[str] = [] + exir_ops: list[str] = [] + pipeline = VgfPipeline[input_t1]( + MultipleDelegatesModule(), + test_data, + aten_ops, + exir_ops, + quantize=True, + run_on_vulkan_runtime=False, + n_expected_delegates=2, + ) + pipeline.run() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 1304c2b2e54..589c0a851a3 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -1223,6 +1223,8 @@ class VgfPipeline(BasePipeline, Generic[T]): use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module. custom_path : Path to dump intermediate artifacts such as tosa and pte to. + n_expected_delegates: Number of delegate calls expected after + partitioning. """ @@ -1252,6 +1254,7 @@ def __init__( tosa_spec: TosaSpecification | str | None = None, fold_quantize: bool = True, preserve_io_quantization: bool = False, + n_expected_delegates: int = 1, ): if tosa_spec is None: if tosa_version is None: @@ -1282,6 +1285,10 @@ def __init__( dynamic_shapes, transform_passes=transform_passes, ) + self.change_args( + "check_count.exir", + {"torch.ops.higher_order.executorch_call_delegate": n_expected_delegates}, + ) remove_torch_quant_nodes_stage = ( "to_edge_transform_and_lower" From 587e438a44945ee620e5f8ef288af32947384b17 Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Fri, 12 Jun 2026 15:58:21 +0100 Subject: [PATCH 3/4] Arm backend: Use int8 grid sampler payload Add an int8 sampler shader payload for quantized grid_sample boundaries using fixed SNORM-compatible qparams. Annotate grid_sampler in the Arm quantizer with the same fixed int8 image qparams so PT2E places the expected Q/DQ boundary before the VGF custom-op rewrite. The rewrite pass consumes that quantized boundary when selecting the int8 image payload, avoiding a float image payload and the extra runtime Q/DQ around the custom shader. Signed-off-by: Baris Demir Change-Id: I7c85dde518f871a7572386ca6cfe67e865ee6745 --- backends/arm/TARGETS | 2 + backends/arm/operators/op_tosa_custom.py | 1 + .../arm/quantizer/quantization_annotator.py | 75 ++++++++++++++++--- backends/arm/quantizer/quantization_config.py | 44 +++++------ backends/arm/quantizer/quantizer_support.py | 1 + .../test/misc/test_custom_shader_payload.py | 36 +++++++++ ...ewrite_grid_sampler_to_tosa_custom_pass.py | 53 +++++++++++++ .../rewrite_grid_sampler_to_tosa_custom.py | 61 +++++++++++---- backends/arm/vgf/shaders/grid_sampler.py | 32 +++++++- .../shaders/grid_sampler_sampler_int8.glsl | 35 +++++++++ .../grid_sampler_sampler_int8.spirv.b64 | 1 + 11 files changed, 289 insertions(+), 52 deletions(-) create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index 5c66d713e5a..c0f5ac7612e 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -102,6 +102,8 @@ runtime.python_library( "vgf/shaders/grid_sampler.spirv.b64", "vgf/shaders/grid_sampler_sampler.glsl", "vgf/shaders/grid_sampler_sampler.spirv.b64", + "vgf/shaders/grid_sampler_sampler_int8.glsl", + "vgf/shaders/grid_sampler_sampler_int8.spirv.b64", ], deps = [ ":arm_compile_spec", diff --git a/backends/arm/operators/op_tosa_custom.py b/backends/arm/operators/op_tosa_custom.py index 45a6097af43..efa8701ae9e 100644 --- a/backends/arm/operators/op_tosa_custom.py +++ b/backends/arm/operators/op_tosa_custom.py @@ -40,6 +40,7 @@ def _vk_format_component_count(vk_format: str) -> int: "VK_FORMAT_R32G32_SFLOAT": 2, "VK_FORMAT_R8G8B8A8_UINT": 4, "VK_FORMAT_R8G8B8A8_SINT": 4, + "VK_FORMAT_R8G8B8A8_SNORM": 4, "VK_FORMAT_R16G16B16A16_UINT": 4, "VK_FORMAT_R16G16B16A16_SINT": 4, "VK_FORMAT_R16G16B16A16_SFLOAT": 4, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 88b59b21d31..7810077a679 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -84,6 +84,8 @@ def __init__(self): class _QParams(NamedTuple): scale: float zero_point: int + quant_min: int | None = None + quant_max: int | None = None def _as_list(x): @@ -472,8 +474,53 @@ def _match_pattern( 8: _QParams((0.999 - (-0.999)) / (1 << 8), 0), 16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0), }, + # grid_sampler image input/output use SNORM-compatible qparams. The grid + # coordinate tensor is intentionally left unquantized. + torch.ops.aten.grid_sampler.default: { + 8: _QParams(1.0 / 127.0, 0, -127, 127), + }, +} + + +_fixed_output_qspec_ops: dict[Any, dict[int, _QParams]] = { + torch.ops.aten.grid_sampler.default: { + 8: _QParams(1.0 / 127.0, 0, -127, 127), + }, } + +def _get_fixed_qparams_qspec( + node_target: Any, + qparams_table: dict[Any, dict[int, _QParams]], + input_act_qspec: QuantizationSpecBase, +) -> FixedQParamsQuantizationSpec | None: + if not isinstance(input_act_qspec, QuantizationSpec): + raise ValueError("Fixed qparams require a QuantizationSpec input.") + + num_bits = torch.iinfo(input_act_qspec.dtype).bits + qparams = qparams_table[node_target].get(num_bits) + if qparams is None: + return None + + return FixedQParamsQuantizationSpec( + dtype=input_act_qspec.dtype, + scale=qparams.scale, + zero_point=qparams.zero_point, + quant_min=( + input_act_qspec.quant_min + if qparams.quant_min is None + else qparams.quant_min + ), + quant_max=( + input_act_qspec.quant_max + if qparams.quant_max is None + else qparams.quant_max + ), + qscheme=input_act_qspec.qscheme, + is_dynamic=input_act_qspec.is_dynamic, + ) + + _one_to_one: set[OpOverload] = { torch.ops.aten.abs.default, torch.ops.aten.ceil.default, @@ -762,6 +809,16 @@ def any_or_hardtanh_min_zero(n: Node): _QuantProperty(1, input_act_qspec), ] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target == torch.ops.aten.grid_sampler.default: + image_node = ensure_type(Node, node.args[0]) + grid_sampler_image_qspec = quantization_config.get_input_act_qspec( + node, image_node + ) + grid_sampler_output_qspec = quantization_config.get_output_act_qspec(node) + if grid_sampler_image_qspec is None or grid_sampler_output_qspec is None: + return None + quant_properties.quant_inputs = [_QuantProperty(0, grid_sampler_image_qspec)] + quant_properties.quant_output = _QuantProperty(0, grid_sampler_output_qspec) elif node.target in (torch.ops.aten.where.self,): true_node = ensure_type(Node, node.args[1]) input_qspec = ( @@ -825,21 +882,15 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in _fixed_input_qspec_ops: - num_bits = torch.iinfo(input_act_qspec.dtype).bits - qparams = _fixed_input_qspec_ops[node.target][num_bits] - + fixed_input_qspec = _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec + ) + if fixed_input_qspec is None: + return None quant_properties.quant_inputs = [ _QuantProperty( 0, - FixedQParamsQuantizationSpec( - dtype=input_act_qspec.dtype, - scale=qparams.scale, - zero_point=qparams.zero_point, - quant_min=input_act_qspec.quant_min, - quant_max=input_act_qspec.quant_max, - qscheme=input_act_qspec.qscheme, - is_dynamic=input_act_qspec.is_dynamic, - ), + fixed_input_qspec, ) ] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 0c64d147c84..a076d20743e 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -21,7 +21,6 @@ from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, - FixedQParamsQuantizationSpec, QuantizationSpec, QuantizationSpecBase, SharedQuantizationSpec, @@ -295,6 +294,7 @@ def get_input_act_qspec(self, node=None, input_node=None): # MLETORCH-1853: Fix lazy import when moving files around from executorch.backends.arm.quantizer.quantization_annotator import ( _fixed_input_qspec_ops, + _get_fixed_qparams_qspec, ) if node is None or input_node is None: @@ -305,28 +305,17 @@ def get_input_act_qspec(self, node=None, input_node=None): return super().get_input_act_qspec(node, input_node) else: return SharedQuantizationSpec((node.args[0], node)) + elif node.target == torch.ops.aten.grid_sampler.default: + if input_node != node.args[0]: + return None + input_act_qspec = super().get_input_act_qspec(node, input_node) + return _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec + ) elif node.target in _fixed_input_qspec_ops: - input_act_qspec = super().get_input_act_qspec(node, input_node) - if not hasattr(input_act_qspec, "dtype") or not isinstance( - input_act_qspec.dtype, torch.dtype - ): - raise ValueError( - f"{node.target} requires an input activation quantization " - "spec to use fixed input qparams." - ) - dtype = getattr(input_act_qspec, "dtype", None) - num_bits = torch.iinfo(dtype).bits - - qparams = _fixed_input_qspec_ops[node.target][num_bits] - return FixedQParamsQuantizationSpec( - dtype=dtype, - scale=qparams.scale, - zero_point=qparams.zero_point, - quant_min=input_act_qspec.quant_min, - quant_max=input_act_qspec.quant_max, - qscheme=input_act_qspec.qscheme, - is_dynamic=input_act_qspec.is_dynamic, + return _get_fixed_qparams_qspec( + node.target, _fixed_input_qspec_ops, input_act_qspec ) return super().get_input_act_qspec(node, input_node) @@ -371,6 +360,19 @@ def get_output_act_qspec( if node is None: return super().get_output_act_qspec() + # MLETORCH-1853: Fix lazy import when moving files around + from executorch.backends.arm.quantizer.quantization_annotator import ( + _fixed_output_qspec_ops, + _get_fixed_qparams_qspec, + ) + + if node.target in _fixed_output_qspec_ops: + output_act_qspec = super().get_output_act_qspec(node) + if output_act_qspec is None: + return None + return _get_fixed_qparams_qspec( + node.target, _fixed_output_qspec_ops, output_act_qspec + ) if node.target not in self.SHARED_OUTPUT_ACT_QSPEC_PATTERNS: return super().get_output_act_qspec() if len(node.args) == 0: diff --git a/backends/arm/quantizer/quantizer_support.py b/backends/arm/quantizer/quantizer_support.py index d6a725c2b06..a462e0227c8 100644 --- a/backends/arm/quantizer/quantizer_support.py +++ b/backends/arm/quantizer/quantizer_support.py @@ -174,6 +174,7 @@ def check_pattern(cls, pattern): (torch.ops.aten.acos.default,), (torch.ops.aten.atanh.default,), (torch.ops.aten.einsum.default,), + (torch.ops.aten.grid_sampler.default,), ] ) TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = { diff --git a/backends/arm/test/misc/test_custom_shader_payload.py b/backends/arm/test/misc/test_custom_shader_payload.py index e9f0cc431ce..f1529ce88f0 100644 --- a/backends/arm/test/misc/test_custom_shader_payload.py +++ b/backends/arm/test/misc/test_custom_shader_payload.py @@ -11,6 +11,9 @@ build_grid_sampler_2d_payload, decode_payload, encode_payload, + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT, GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY, GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE, GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, @@ -81,6 +84,32 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_uses_sampler_for_c4(): } +def test_grid_sampler_2d_custom_shader_payload_no_target_uses_int8_sampler_for_c4(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=False, + input_shape=(1, 4, 8, 8), + input_dtype=torch.int8, + output_dtype=torch.int8, + ) + + assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE + assert base64.b64decode(payload["shader_code"])[:4] == b"\x03\x02\x23\x07" + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_TENSOR_ARM" + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + + def test_grid_sampler_2d_custom_shader_payload_no_target_keeps_c3_on_buffer(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, @@ -148,6 +177,13 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources(): assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64" assert GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE == "grid_sampler_sampler.glsl" assert GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY == "grid_sampler_sampler.spirv.b64" + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE == "grid_sampler_sampler_int8.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY + == "grid_sampler_sampler_int8.spirv.b64" + ) def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes(): diff --git a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py index 7fe2644f1e7..b8aedd7c038 100644 --- a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py +++ b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py @@ -4,8 +4,14 @@ # LICENSE file in the root directory of this source tree. import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest import torch import torch.nn.functional as F +from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -17,6 +23,7 @@ CUSTOM_SHADER_DOMAIN_NAME, decode_payload, grid_sampler_2d_operator_name, + GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT, GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, GRID_SAMPLER_2D_SHADER_ENTRY_POINT, GRID_SAMPLER_2D_SHADER_LANGUAGE, @@ -26,6 +33,7 @@ from executorch.exir import to_edge from executorch.exir.dialects._ops import ops as exir_ops from torch.export import export +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class GridSampler2d(torch.nn.Module): @@ -131,6 +139,51 @@ def test_rewrite_grid_sampler_to_tosa_custom_no_target_uses_sampler_for_c4(): assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT +@pytest.mark.parametrize("channels", [3, 4]) +@pytest.mark.parametrize("use_composable_quantizer", [False, True]) +def test_quantized_grid_sampler_uses_int8_sampler_payload( + channels, use_composable_quantizer +): + model = GridSampler2d().eval() + example_inputs = ( + torch.randn(1, channels, 8, 8), + torch.rand(1, 4, 4, 2), + ) + quantizer = TOSAQuantizer( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + use_composable_quantizer=use_composable_quantizer, + ) + quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False)) + + exported = export(model, example_inputs, strict=True) + prepared = prepare_pt2e(exported.module(), quantizer) + prepared(*example_inputs) + converted = convert_pt2e(prepared) + + edge_model = to_edge(export(converted, example_inputs, strict=True)) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP+INT")): + edge_model = edge_model.transform( + [FoldAndAnnotateQParamsPass(), RewriteGridSamplerToTosaCustomPass()] + ) + nodes = list(edge_model.exported_program().graph.nodes) + + custom_node = next( + node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default + ) + payload = decode_payload(custom_node.kwargs["implementation_attrs"]) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["input_1_type"] == "Tensor" + assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert custom_node.meta["input_qparams"][0].qmin == -127 + assert custom_node.meta["input_qparams"][0].qmax == 127 + assert next(iter(custom_node.meta["output_qparams"].values())).qmin == -127 + assert next(iter(custom_node.meta["output_qparams"].values())).qmax == 127 + + def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_align_corners(): model = GridSampler2d() model.align_corners_ = True diff --git a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py index 32c899c68db..fd52164051f 100644 --- a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py +++ b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py @@ -3,12 +3,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import operator from typing import Set, Type import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.tosa.dialect.ops.custom import register_fake_tosa from executorch.backends.arm.vgf.shaders.grid_sampler import ( @@ -104,12 +110,43 @@ def _can_pad_c3_for_sampler( and len(value.shape) == 4 and int(value.shape[0]) == 1 and int(value.shape[1]) == 3 - and value.dtype is torch.float32 + and value.dtype in (torch.float32, torch.int8) and int(interpolation_mode) in (0, 1) and not bool(align_corners) ) +def _uses_grid_sampler_int8_snorm_qparams(qparams: QuantArgs) -> bool: + return ( + not qparams.per_channel + and math.isclose( + qparams.get_scale_per_tensor(), 1.0 / 127.0, rel_tol=1e-6, abs_tol=1e-9 + ) + and qparams.get_zp_per_tensor() == 0 + and qparams.qmin == -127 + and qparams.qmax == 127 + and qparams.dtype == torch.int8 + ) + + +def _uses_grid_sampler_int8_snorm_metadata(node: torch.fx.Node) -> bool: + try: + input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node) + except ValueError: + return False + + image_qparams = input_qparams.get(0) + if image_qparams is None: + return False + if not output_qparams: + return False + + return _uses_grid_sampler_int8_snorm_qparams( + image_qparams + ) and _uses_grid_sampler_int8_snorm_qparams(next(iter(output_qparams.values()))) + + class RewriteGridSamplerToTosaCustomPass(ArmPass): """Rewrite ``aten.grid_sampler_2d`` nodes to ``tosa.CUSTOM``.""" @@ -122,6 +159,7 @@ def _encode_payload( padding_mode: int, align_corners: bool, input_tensor: torch.fx.Node, + output_dtype: torch.dtype | None = None, ) -> list[int]: input_val = input_tensor.meta.get("val") if input_val is None: @@ -132,6 +170,7 @@ def _encode_payload( align_corners=align_corners, input_shape=tuple(input_val.shape), input_dtype=input_val.dtype, + output_dtype=output_dtype, ) return encode_payload(payload) @@ -150,27 +189,15 @@ def _pad_c3_input_to_c4( first_channel_val = exir_ops.edge.aten.slice_copy.Tensor(input_val, 1, 0, 1, 1) _set_fake_tensor_meta(first_channel, first_channel_val) - zero_channel = create_node( - graph_module.graph, - op_target=exir_ops.edge.aten.sub.Tensor, - args=(first_channel, first_channel), - kwargs={"alpha": 1}, - from_node=input_tensor, - ) - _set_fake_tensor_meta( - zero_channel, - exir_ops.edge.aten.sub.Tensor(first_channel_val, first_channel_val), - ) - padded_input = create_node( graph_module.graph, op_target=exir_ops.edge.aten.cat.default, - args=([input_tensor, zero_channel], 1), + args=([input_tensor, first_channel], 1), from_node=input_tensor, ) _set_fake_tensor_meta( padded_input, - exir_ops.edge.aten.cat.default([input_val, zero_channel.meta["val"]], 1), + exir_ops.edge.aten.cat.default([input_val, first_channel_val], 1), ) return padded_input @@ -207,6 +234,8 @@ def call(self, graph_module): input_tensor, grid, interpolation_mode, padding_mode, align_corners = ( node.args ) + use_quantized_image_payload = _uses_grid_sampler_int8_snorm_metadata(node) + output_dtype = torch.int8 if use_quantized_image_payload else None pad_c3_for_sampler = _can_pad_c3_for_sampler( input_tensor, interpolation_mode, @@ -230,6 +259,7 @@ def call(self, graph_module): padding_mode=padding_mode, align_corners=align_corners, input_tensor=custom_input, + output_dtype=output_dtype, ) nhwc_input = create_node( graph_module.graph, @@ -299,6 +329,7 @@ def call(self, graph_module): graph_module.graph.erase_node(node) if modified: + graph_module.graph.eliminate_dead_code() graph_module.graph.lint() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/vgf/shaders/grid_sampler.py b/backends/arm/vgf/shaders/grid_sampler.py index 12aa86e6d35..e81fb30518d 100644 --- a/backends/arm/vgf/shaders/grid_sampler.py +++ b/backends/arm/vgf/shaders/grid_sampler.py @@ -17,7 +17,10 @@ GRID_SAMPLER_2D_SHADER_BINARY = "grid_sampler.spirv.b64" GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE = "grid_sampler_sampler.glsl" GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY = "grid_sampler_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE = "grid_sampler_sampler_int8.glsl" +GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY = "grid_sampler_sampler_int8.spirv.b64" GRID_SAMPLER_2D_SAMPLER_VK_FORMAT = "VK_FORMAT_R32G32B32A32_SFLOAT" +GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT = "VK_FORMAT_R8G8B8A8_SNORM" _INTERPOLATION_MODE_NAMES = { 0: "bilinear", @@ -72,6 +75,7 @@ def build_grid_sampler_2d_payload( align_corners: bool, input_shape: tuple[int, ...] | None = None, input_dtype: Any | None = None, + output_dtype: Any | None = None, ) -> dict[str, Any]: _mode_name( int(interpolation_mode), @@ -83,17 +87,21 @@ def build_grid_sampler_2d_payload( _PADDING_MODE_NAMES, "padding_mode", ) + if output_dtype is None: + output_dtype = input_dtype + + sampler_vk_format = _sampler_vk_format(input_dtype, output_dtype) use_sampler = ( input_shape is not None and len(input_shape) == 4 and int(input_shape[0]) == 1 and int(input_shape[1]) == 4 - and str(input_dtype) == "torch.float32" + and sampler_vk_format is not None and int(interpolation_mode) in (0, 1) and not bool(align_corners) ) shader_file = ( - GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY + _sampler_shader_file(sampler_vk_format) if use_sampler else GRID_SAMPLER_2D_SHADER_BINARY ) @@ -119,7 +127,7 @@ def build_grid_sampler_2d_payload( payload.update( { "input_0_type": "Image", - "input_0_vkformat": GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, + "input_0_vkformat": sampler_vk_format, "input_0_vkdescriptortype": ( "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" ), @@ -129,7 +137,7 @@ def build_grid_sampler_2d_payload( ), "input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_TENSOR_ARM", "output_0_type": "Image", - "output_0_vkformat": GRID_SAMPLER_2D_SAMPLER_VK_FORMAT, + "output_0_vkformat": sampler_vk_format, "output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", } ) @@ -148,6 +156,22 @@ def build_grid_sampler_2d_payload( return payload +def _sampler_vk_format(input_dtype: Any | None, output_dtype: Any | None) -> str | None: + if str(input_dtype) != str(output_dtype): + return None + if str(input_dtype) == "torch.float32": + return GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + if str(input_dtype) == "torch.int8": + return GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + return None + + +def _sampler_shader_file(sampler_vk_format: str | None) -> str: + if sampler_vk_format == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT: + return GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY + return GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY + + def _sampler_config(interpolation_mode: int, padding_mode: int) -> dict[str, str]: interpolation = _mode_name( int(interpolation_mode), diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl new file mode 100644 index 00000000000..132049210db --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.glsl @@ -0,0 +1,35 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba8_snorm) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + vec2 uv = (gridXY + vec2(1.0)) * 0.5; + imageStore(outImage, gid, texture(inputImage, uv)); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 new file mode 100644 index 00000000000..07f4440d3b4 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AdQAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAARQAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABAATAAAAeENvb3JkcwAFAAQAHgAAAHlDb29yZHMABQAEACgAAABncmlkAAAAAAUABAAtAAAAeFZhbAAAAAAFAAQAMQAAAHlWYWwAAAAABQAEADwAAABvdXRTaXplAAUABQA/AAAAb3V0SW1hZ2UAAAAABQADAEIAAABnaWQABQAIAEUAAABnbF9HbG9iYWxJbnZvY2F0aW9uSUQAAAAFAAQAXQAAAGdyaWRYWQAABQAEAF4AAABwYXJhbQAAAAUAAwBhAAAAdXYAAAUABQBtAAAAaW5wdXRJbWFnZQAARwAEACgAAAAhAAAAAQAAAEcABAAoAAAAIgAAAAAAAABHAAMAPwAAABkAAABHAAQAPwAAACEAAAACAAAARwAEAD8AAAAiAAAAAAAAAEcABABFAAAACwAAABwAAABHAAQAbQAAACEAAAAAAAAARwAEAG0AAAAiAAAAAAAAAEcABAB0AAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAFQAEAA8AAAAgAAAAAAAAACsABAAPAAAAEAAAAAQAAAAcAAQAEQAAAA8AAAAQAAAAIAAEABIAAAAHAAAAEQAAACsABAAPAAAAFAAAAAAAAAArAAQADwAAABUAAAABAAAAIAAEABYAAAAHAAAABgAAAEMQBAAmAAAACQAAABAAAAAgAAQAJwAAAAAAAAAmAAAAOwAEACcAAAAoAAAAAAAAABwABAArAAAACQAAABUAAAAgAAQALAAAAAcAAAArAAAAKwAEAAYAAAAzAAAAAAAAACAABAA0AAAABwAAAAkAAAAZAAkAPQAAAAkAAAABAAAAAAAAAAAAAAAAAAAAAgAAAAUAAAAgAAQAPgAAAAAAAAA9AAAAOwAEAD4AAAA/AAAAAAAAABcABABDAAAADwAAAAMAAAAgAAQARAAAAAEAAABDAAAAOwAEAEQAAABFAAAAAQAAABcABABGAAAADwAAAAIAAAAUAAIASgAAACAABABcAAAABwAAAAoAAAArAAQACQAAAGMAAAAAAIA/LAAFAAoAAABkAAAAYwAAAGMAAAArAAQACQAAAGYAAAAAAAA/GQAJAGoAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAGwADAGsAAABqAAAAIAAEAGwAAAAAAAAAawAAADsABABsAAAAbQAAAAAAAAAXAAQAcAAAAAkAAAAEAAAAKwAEAAkAAABxAAAAAAAAACsABAAPAAAAcwAAAAgAAAAsAAYAQwAAAHQAAABzAAAAcwAAABUAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAAA8AAAABwAAADsABAAIAAAAQgAAAAcAAAA7AAQAXAAAAF0AAAAHAAAAOwAEAAgAAABeAAAABwAAADsABABcAAAAYQAAAAcAAAA9AAQAPQAAAEAAAAA/AAAAaAAEAAcAAABBAAAAQAAAAD4AAwA8AAAAQQAAAD0ABABDAAAARwAAAEUAAABPAAcARgAAAEgAAABHAAAARwAAAAAAAAABAAAAfAAEAAcAAABJAAAASAAAAD4AAwBCAAAASQAAAEEABQAWAAAASwAAAEIAAAAUAAAAPQAEAAYAAABMAAAASwAAAEEABQAWAAAATQAAADwAAAAUAAAAPQAEAAYAAABOAAAATQAAAK8ABQBKAAAATwAAAEwAAABOAAAAqAAEAEoAAABQAAAATwAAAPcAAwBSAAAAAAAAAPoABABQAAAAUQAAAFIAAAD4AAIAUQAAAEEABQAWAAAAUwAAAEIAAAAVAAAAPQAEAAYAAABUAAAAUwAAAEEABQAWAAAAVQAAADwAAAAVAAAAPQAEAAYAAABWAAAAVQAAAK8ABQBKAAAAVwAAAFQAAABWAAAA+QACAFIAAAD4AAIAUgAAAPUABwBKAAAAWAAAAE8AAAAFAAAAVwAAAFEAAAD3AAMAWgAAAAAAAAD6AAQAWAAAAFkAAABaAAAA+AACAFkAAAD9AAEA+AACAFoAAAA9AAQABwAAAF8AAABCAAAAPgADAF4AAABfAAAAOQAFAAoAAABgAAAADQAAAF4AAAA+AAMAXQAAAGAAAAA9AAQACgAAAGIAAABdAAAAgQAFAAoAAABlAAAAYgAAAGQAAACOAAUACgAAAGcAAABlAAAAZgAAAD4AAwBhAAAAZwAAAD0ABAA9AAAAaAAAAD8AAAA9AAQABwAAAGkAAABCAAAAPQAEAGsAAABuAAAAbQAAAD0ABAAKAAAAbwAAAGEAAABYAAcAcAAAAHIAAABuAAAAbwAAAAIAAABxAAAAYwAEAGgAAABpAAAAcgAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAASAAAAEwAAAAcAAAA7AAQAEgAAAB4AAAAHAAAAOwAEACwAAAAtAAAABwAAADsABAAsAAAAMQAAAAcAAABBAAUAFgAAABcAAAAMAAAAFQAAAD0ABAAGAAAAGAAAABcAAAB8AAQADwAAABkAAAAYAAAAQQAFABYAAAAaAAAADAAAABQAAAA9AAQABgAAABsAAAAaAAAAfAAEAA8AAAAcAAAAGwAAAFAABwARAAAAHQAAABQAAAAZAAAAHAAAABQAAAA+AAMAEwAAAB0AAABBAAUAFgAAAB8AAAAMAAAAFQAAAD0ABAAGAAAAIAAAAB8AAAB8AAQADwAAACEAAAAgAAAAQQAFABYAAAAiAAAADAAAABQAAAA9AAQABgAAACMAAAAiAAAAfAAEAA8AAAAkAAAAIwAAAFAABwARAAAAJQAAABQAAAAhAAAAJAAAABUAAAA+AAMAHgAAACUAAAA9AAQAJgAAACkAAAAoAAAAPQAEABEAAAAqAAAAEwAAAEQQBQArAAAALgAAACkAAAAqAAAAPgADAC0AAAAuAAAAPQAEACYAAAAvAAAAKAAAAD0ABAARAAAAMAAAAB4AAABEEAUAKwAAADIAAAAvAAAAMAAAAD4AAwAxAAAAMgAAAEEABQA0AAAANQAAAC0AAAAzAAAAPQAEAAkAAAA2AAAANQAAAEEABQA0AAAANwAAADEAAAAzAAAAPQAEAAkAAAA4AAAANwAAAFAABQAKAAAAOQAAADYAAAA4AAAA/gACADkAAAA4AAEA From 1d9b2c9d6e3eee1d4f6f0a39f329089f5d16f882 Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Mon, 15 Jun 2026 16:51:14 +0100 Subject: [PATCH 4/4] Arm backend: Support aligned VGF grid sampler Add sampler shader variants for grid_sample with align_corners=True. Select the aligned sampler shader from the VGF custom payload when the operator requests align_corners=True, while keeping bicubic on the storage-buffer fallback path. Allow C3 inputs to use the sampler padding path for both align modes so RIFE warps can use the int8 SNORM sampler payload without changing the model's grid_sample semantics. Signed-off-by: Baris Demir Change-Id: I83f8cdb8024d3d77c8706278c9b12c9dbe77b58f --- backends/arm/TARGETS | 4 + .../test/misc/test_custom_shader_payload.py | 73 +++++++++++++++++-- ...ewrite_grid_sampler_to_tosa_custom_pass.py | 15 ++-- .../rewrite_grid_sampler_to_tosa_custom.py | 3 +- backends/arm/vgf/shaders/grid_sampler.py | 24 +++++- .../grid_sampler_sampler_align_corners.glsl | 40 ++++++++++ ...id_sampler_sampler_align_corners.spirv.b64 | 1 + ...id_sampler_sampler_int8_align_corners.glsl | 40 ++++++++++ ...mpler_sampler_int8_align_corners.spirv.b64 | 1 + 9 files changed, 185 insertions(+), 16 deletions(-) create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl create mode 100644 backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index c0f5ac7612e..34841a52dc7 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -102,8 +102,12 @@ runtime.python_library( "vgf/shaders/grid_sampler.spirv.b64", "vgf/shaders/grid_sampler_sampler.glsl", "vgf/shaders/grid_sampler_sampler.spirv.b64", + "vgf/shaders/grid_sampler_sampler_align_corners.glsl", + "vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64", "vgf/shaders/grid_sampler_sampler_int8.glsl", "vgf/shaders/grid_sampler_sampler_int8.spirv.b64", + "vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl", + "vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64", ], deps = [ ":arm_compile_spec", diff --git a/backends/arm/test/misc/test_custom_shader_payload.py b/backends/arm/test/misc/test_custom_shader_payload.py index f1529ce88f0..bfafc143c57 100644 --- a/backends/arm/test/misc/test_custom_shader_payload.py +++ b/backends/arm/test/misc/test_custom_shader_payload.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import base64 +from importlib.resources import files import pytest import torch @@ -11,6 +12,10 @@ build_grid_sampler_2d_payload, decode_payload, encode_payload, + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE, + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY, + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE, GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY, GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE, GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT, @@ -26,6 +31,15 @@ ) +def _shader_code_from_resource(shader_file: str) -> str: + return "".join( + files("executorch.backends.arm.vgf.shaders") + .joinpath(shader_file) + .read_text(encoding="utf-8") + .split() + ) + + def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, @@ -127,7 +141,7 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_keeps_c3_on_buffer(): assert "input_0_sampler" not in payload -def test_grid_sampler_2d_custom_shader_payload_no_target_align_corners_buffer(): +def test_grid_sampler_2d_custom_shader_payload_no_target_align_corners_sampler(): payload = build_grid_sampler_2d_payload( interpolation_mode=0, padding_mode=0, @@ -136,11 +150,42 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_align_corners_buffer(): input_dtype=torch.float32, ) - assert payload["input_0_type"] == "Tensor" - assert payload["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" - assert payload["output_0_type"] == "Tensor" - assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER" - assert "input_0_sampler" not in payload + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["shader_code"] == _shader_code_from_resource( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY + ) + + +def test_grid_sampler_2d_custom_shader_payload_no_target_int8_align_corners_sampler(): + payload = build_grid_sampler_2d_payload( + interpolation_mode=0, + padding_mode=0, + align_corners=True, + input_shape=(1, 4, 8, 8), + input_dtype=torch.int8, + output_dtype=torch.int8, + ) + + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" + ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT + assert payload["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE" + assert payload["shader_code"] == _shader_code_from_resource( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY + ) def test_grid_sampler_2d_custom_shader_payload_no_target_bicubic_buffer(): @@ -177,6 +222,14 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources(): assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64" assert GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE == "grid_sampler_sampler.glsl" assert GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY == "grid_sampler_sampler.spirv.b64" + assert ( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE + == "grid_sampler_sampler_align_corners.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY + == "grid_sampler_sampler_align_corners.spirv.b64" + ) assert ( GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE == "grid_sampler_sampler_int8.glsl" ) @@ -184,6 +237,14 @@ def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources(): GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY == "grid_sampler_sampler_int8.spirv.b64" ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE + == "grid_sampler_sampler_int8_align_corners.glsl" + ) + assert ( + GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY + == "grid_sampler_sampler_int8_align_corners.spirv.b64" + ) def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes(): diff --git a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py index b8aedd7c038..eb4b5f23660 100644 --- a/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py +++ b/backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py @@ -184,7 +184,7 @@ def test_quantized_grid_sampler_uses_int8_sampler_payload( assert next(iter(custom_node.meta["output_qparams"].values())).qmax == 127 -def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_align_corners(): +def test_rewrite_grid_sampler_to_tosa_custom_c3_pad_for_align_corners(): model = GridSampler2d() model.align_corners_ = True example_inputs = ( @@ -202,11 +202,16 @@ def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_align_corners(): ) payload = decode_payload(custom_node.kwargs["implementation_attrs"]) - assert payload["input_0_type"] == "Tensor" - assert not any(node.target == exir_ops.edge.aten.cat.default for node in nodes) - assert not any( - node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes + assert payload["input_0_type"] == "Image" + assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert ( + payload["input_0_vkdescriptortype"] + == "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER" ) + assert payload["output_0_type"] == "Image" + assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_SAMPLER_VK_FORMAT + assert any(node.target == exir_ops.edge.aten.cat.default for node in nodes) + assert any(node.target == exir_ops.edge.aten.slice_copy.Tensor for node in nodes) def test_rewrite_grid_sampler_to_tosa_custom_no_c3_pad_for_bicubic(): diff --git a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py index fd52164051f..efff1730914 100644 --- a/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py +++ b/backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py @@ -102,7 +102,7 @@ def _is_static_nchw_with_channels(node: torch.fx.Node, channels: int) -> bool: def _can_pad_c3_for_sampler( input_tensor: torch.fx.Node, interpolation_mode: int, - align_corners: bool, + align_corners: bool, # noqa: ARG001 ) -> bool: value = input_tensor.meta.get("val") return ( @@ -112,7 +112,6 @@ def _can_pad_c3_for_sampler( and int(value.shape[1]) == 3 and value.dtype in (torch.float32, torch.int8) and int(interpolation_mode) in (0, 1) - and not bool(align_corners) ) diff --git a/backends/arm/vgf/shaders/grid_sampler.py b/backends/arm/vgf/shaders/grid_sampler.py index e81fb30518d..5cd63ddf95b 100644 --- a/backends/arm/vgf/shaders/grid_sampler.py +++ b/backends/arm/vgf/shaders/grid_sampler.py @@ -17,8 +17,20 @@ GRID_SAMPLER_2D_SHADER_BINARY = "grid_sampler.spirv.b64" GRID_SAMPLER_2D_SAMPLER_SHADER_SOURCE = "grid_sampler_sampler.glsl" GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY = "grid_sampler_sampler.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_SOURCE = ( + "grid_sampler_sampler_align_corners.glsl" +) +GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY = ( + "grid_sampler_sampler_align_corners.spirv.b64" +) GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_SOURCE = "grid_sampler_sampler_int8.glsl" GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY = "grid_sampler_sampler_int8.spirv.b64" +GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_SOURCE = ( + "grid_sampler_sampler_int8_align_corners.glsl" +) +GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY = ( + "grid_sampler_sampler_int8_align_corners.spirv.b64" +) GRID_SAMPLER_2D_SAMPLER_VK_FORMAT = "VK_FORMAT_R32G32B32A32_SFLOAT" GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT = "VK_FORMAT_R8G8B8A8_SNORM" @@ -98,10 +110,9 @@ def build_grid_sampler_2d_payload( and int(input_shape[1]) == 4 and sampler_vk_format is not None and int(interpolation_mode) in (0, 1) - and not bool(align_corners) ) shader_file = ( - _sampler_shader_file(sampler_vk_format) + _sampler_shader_file(sampler_vk_format, align_corners=align_corners) if use_sampler else GRID_SAMPLER_2D_SHADER_BINARY ) @@ -166,9 +177,16 @@ def _sampler_vk_format(input_dtype: Any | None, output_dtype: Any | None) -> str return None -def _sampler_shader_file(sampler_vk_format: str | None) -> str: +def _sampler_shader_file( + sampler_vk_format: str | None, + align_corners: bool, +) -> str: if sampler_vk_format == GRID_SAMPLER_2D_SAMPLER_INT8_VK_FORMAT: + if align_corners: + return GRID_SAMPLER_2D_SAMPLER_INT8_ALIGN_CORNERS_SHADER_BINARY return GRID_SAMPLER_2D_SAMPLER_INT8_SHADER_BINARY + if align_corners: + return GRID_SAMPLER_2D_SAMPLER_ALIGN_CORNERS_SHADER_BINARY return GRID_SAMPLER_2D_SAMPLER_SHADER_BINARY diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl new file mode 100644 index 00000000000..efa5f8c589b --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.glsl @@ -0,0 +1,40 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba32f) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +vec2 alignCornersUv(vec2 gridXY) { + vec2 inputSize = vec2(textureSize(inputImage, 0)); + vec2 texel = (gridXY + vec2(1.0)) * vec2(0.5) * (inputSize - vec2(1.0)); + return (texel + vec2(0.5)) / inputSize; +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + imageStore(outImage, gid, texture(inputImage, alignCornersUv(gridXY))); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 new file mode 100644 index 00000000000..c000bc299a3 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AigAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAAZAAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABwASAAAAYWxpZ25Db3JuZXJzVXYodmYyOwAFAAQAEQAAAGdyaWRYWQAABQAEABgAAAB4Q29vcmRzAAUABAAjAAAAeUNvb3JkcwAFAAQALQAAAGdyaWQAAAAABQAEADIAAAB4VmFsAAAAAAUABAA2AAAAeVZhbAAAAAAFAAUAQQAAAGlucHV0U2l6ZQAAAAUABQBFAAAAaW5wdXRJbWFnZQAABQAEAEoAAAB0ZXhlbAAAAAUABABbAAAAb3V0U2l6ZQAFAAUAXgAAAG91dEltYWdlAAAAAAUAAwBhAAAAZ2lkAAUACABkAAAAZ2xfR2xvYmFsSW52b2NhdGlvbklEAAAABQAEAHsAAABncmlkWFkAAAUABAB8AAAAcGFyYW0AAAAFAAQAggAAAHBhcmFtAAAARwAEAC0AAAAhAAAAAQAAAEcABAAtAAAAIgAAAAAAAABHAAQARQAAACEAAAAAAAAARwAEAEUAAAAiAAAAAAAAAEcAAwBeAAAAGQAAAEcABABeAAAAIQAAAAIAAABHAAQAXgAAACIAAAAAAAAARwAEAGQAAAALAAAAHAAAAEcABACJAAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAIAAEAA8AAAAHAAAACgAAACEABAAQAAAACgAAAA8AAAAVAAQAFAAAACAAAAAAAAAAKwAEABQAAAAVAAAABAAAABwABAAWAAAAFAAAABUAAAAgAAQAFwAAAAcAAAAWAAAAKwAEABQAAAAZAAAAAAAAACsABAAUAAAAGgAAAAEAAAAgAAQAGwAAAAcAAAAGAAAAQxAEACsAAAAJAAAAFQAAACAABAAsAAAAAAAAACsAAAA7AAQALAAAAC0AAAAAAAAAHAAEADAAAAAJAAAAGgAAACAABAAxAAAABwAAADAAAAArAAQABgAAADgAAAAAAAAAIAAEADkAAAAHAAAACQAAABkACQBCAAAACQAAAAEAAAAAAAAAAAAAAAAAAAABAAAAAAAAABsAAwBDAAAAQgAAACAABABEAAAAAAAAAEMAAAA7AAQARAAAAEUAAAAAAAAAKwAEAAkAAABMAAAAAACAPywABQAKAAAATQAAAEwAAABMAAAAKwAEAAkAAABPAAAAAAAAPywABQAKAAAAUAAAAE8AAABPAAAAGQAJAFwAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAIAAAABAAAAIAAEAF0AAAAAAAAAXAAAADsABABdAAAAXgAAAAAAAAAXAAQAYgAAABQAAAADAAAAIAAEAGMAAAABAAAAYgAAADsABABjAAAAZAAAAAEAAAAXAAQAZQAAABQAAAACAAAAFAACAGkAAAAXAAQAhQAAAAkAAAAEAAAAKwAEAAkAAACGAAAAAAAAACsABAAUAAAAiAAAAAgAAAAsAAYAYgAAAIkAAACIAAAAiAAAABoAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAABbAAAABwAAADsABAAIAAAAYQAAAAcAAAA7AAQADwAAAHsAAAAHAAAAOwAEAAgAAAB8AAAABwAAADsABAAPAAAAggAAAAcAAAA9AAQAXAAAAF8AAABeAAAAaAAEAAcAAABgAAAAXwAAAD4AAwBbAAAAYAAAAD0ABABiAAAAZgAAAGQAAABPAAcAZQAAAGcAAABmAAAAZgAAAAAAAAABAAAAfAAEAAcAAABoAAAAZwAAAD4AAwBhAAAAaAAAAEEABQAbAAAAagAAAGEAAAAZAAAAPQAEAAYAAABrAAAAagAAAEEABQAbAAAAbAAAAFsAAAAZAAAAPQAEAAYAAABtAAAAbAAAAK8ABQBpAAAAbgAAAGsAAABtAAAAqAAEAGkAAABvAAAAbgAAAPcAAwBxAAAAAAAAAPoABABvAAAAcAAAAHEAAAD4AAIAcAAAAEEABQAbAAAAcgAAAGEAAAAaAAAAPQAEAAYAAABzAAAAcgAAAEEABQAbAAAAdAAAAFsAAAAaAAAAPQAEAAYAAAB1AAAAdAAAAK8ABQBpAAAAdgAAAHMAAAB1AAAA+QACAHEAAAD4AAIAcQAAAPUABwBpAAAAdwAAAG4AAAAFAAAAdgAAAHAAAAD3AAMAeQAAAAAAAAD6AAQAdwAAAHgAAAB5AAAA+AACAHgAAAD9AAEA+AACAHkAAAA9AAQABwAAAH0AAABhAAAAPgADAHwAAAB9AAAAOQAFAAoAAAB+AAAADQAAAHwAAAA+AAMAewAAAH4AAAA9AAQAXAAAAH8AAABeAAAAPQAEAAcAAACAAAAAYQAAAD0ABABDAAAAgQAAAEUAAAA9AAQACgAAAIMAAAB7AAAAPgADAIIAAACDAAAAOQAFAAoAAACEAAAAEgAAAIIAAABYAAcAhQAAAIcAAACBAAAAhAAAAAIAAACGAAAAYwAEAH8AAACAAAAAhwAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAAXAAAAGAAAAAcAAAA7AAQAFwAAACMAAAAHAAAAOwAEADEAAAAyAAAABwAAADsABAAxAAAANgAAAAcAAABBAAUAGwAAABwAAAAMAAAAGgAAAD0ABAAGAAAAHQAAABwAAAB8AAQAFAAAAB4AAAAdAAAAQQAFABsAAAAfAAAADAAAABkAAAA9AAQABgAAACAAAAAfAAAAfAAEABQAAAAhAAAAIAAAAFAABwAWAAAAIgAAABkAAAAeAAAAIQAAABkAAAA+AAMAGAAAACIAAABBAAUAGwAAACQAAAAMAAAAGgAAAD0ABAAGAAAAJQAAACQAAAB8AAQAFAAAACYAAAAlAAAAQQAFABsAAAAnAAAADAAAABkAAAA9AAQABgAAACgAAAAnAAAAfAAEABQAAAApAAAAKAAAAFAABwAWAAAAKgAAABkAAAAmAAAAKQAAABoAAAA+AAMAIwAAACoAAAA9AAQAKwAAAC4AAAAtAAAAPQAEABYAAAAvAAAAGAAAAEQQBQAwAAAAMwAAAC4AAAAvAAAAPgADADIAAAAzAAAAPQAEACsAAAA0AAAALQAAAD0ABAAWAAAANQAAACMAAABEEAUAMAAAADcAAAA0AAAANQAAAD4AAwA2AAAANwAAAEEABQA5AAAAOgAAADIAAAA4AAAAPQAEAAkAAAA7AAAAOgAAAEEABQA5AAAAPAAAADYAAAA4AAAAPQAEAAkAAAA9AAAAPAAAAFAABQAKAAAAPgAAADsAAAA9AAAA/gACAD4AAAA4AAEANgAFAAoAAAASAAAAAAAAABAAAAA3AAMADwAAABEAAAD4AAIAEwAAADsABAAPAAAAQQAAAAcAAAA7AAQADwAAAEoAAAAHAAAAPQAEAEMAAABGAAAARQAAAGQABABCAAAARwAAAEYAAABnAAUABwAAAEgAAABHAAAAOAAAAG8ABAAKAAAASQAAAEgAAAA+AAMAQQAAAEkAAAA9AAQACgAAAEsAAAARAAAAgQAFAAoAAABOAAAASwAAAE0AAACFAAUACgAAAFEAAABOAAAAUAAAAD0ABAAKAAAAUgAAAEEAAACDAAUACgAAAFMAAABSAAAATQAAAIUABQAKAAAAVAAAAFEAAABTAAAAPgADAEoAAABUAAAAPQAEAAoAAABVAAAASgAAAIEABQAKAAAAVgAAAFUAAABQAAAAPQAEAAoAAABXAAAAQQAAAIgABQAKAAAAWAAAAFYAAABXAAAA/gACAFgAAAA4AAEA diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl new file mode 100644 index 00000000000..b0aa9d303fc --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl @@ -0,0 +1,40 @@ +// Copyright 2026 Arm Limited and/or its affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#version 450 +#extension GL_ARM_tensors : require + +layout(set = 0, binding = 0) uniform sampler2D inputImage; +layout(set = 0, binding = 1) uniform tensorARM grid; +layout(set = 0, binding = 2, rgba8_snorm) uniform writeonly image2D outImage; + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +vec2 readGridXY(ivec2 p) { + uint xCoords[4] = uint[](0u, uint(p.y), uint(p.x), 0u); + uint yCoords[4] = uint[](0u, uint(p.y), uint(p.x), 1u); + float xVal[1]; + float yVal[1]; + tensorReadARM(grid, xCoords, xVal); + tensorReadARM(grid, yCoords, yVal); + return vec2(xVal[0], yVal[0]); +} + +vec2 alignCornersUv(vec2 gridXY) { + vec2 inputSize = vec2(textureSize(inputImage, 0)); + vec2 texel = (gridXY + vec2(1.0)) * vec2(0.5) * (inputSize - vec2(1.0)); + return (texel + vec2(0.5)) / inputSize; +} + +void main() { + ivec2 outSize = imageSize(outImage); + ivec2 gid = ivec2(gl_GlobalInvocationID.xy); + if (gid.x >= outSize.x || gid.y >= outSize.y) { + return; + } + + vec2 gridXY = readGridXY(gid); + imageStore(outImage, gid, texture(inputImage, alignCornersUv(gridXY))); +} diff --git a/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 new file mode 100644 index 00000000000..cd15b7a8264 --- /dev/null +++ b/backends/arm/vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64 @@ -0,0 +1 @@ +AwIjBwAAAQALAA0AigAAAAAAAAARAAIAAQAAABEAAgAyAAAAEQACAE4QAAAKAAUAU1BWX0FSTV90ZW5zb3JzAAsABgABAAAAR0xTTC5zdGQuNDUwAAAAAA4AAwAAAAAAAQAAAA8ABgAFAAAABAAAAG1haW4AAAAAZAAAABAABgAEAAAAEQAAAAgAAAAIAAAAAQAAAAMAAwACAAAAwgEAAAQABQBHTF9BUk1fdGVuc29ycwAABAAKAEdMX0dPT0dMRV9jcHBfc3R5bGVfbGluZV9kaXJlY3RpdmUAAAQACABHTF9HT09HTEVfaW5jbHVkZV9kaXJlY3RpdmUABQAEAAQAAABtYWluAAAAAAUABgANAAAAcmVhZEdyaWRYWSh2aTI7AAUAAwAMAAAAcAAAAAUABwASAAAAYWxpZ25Db3JuZXJzVXYodmYyOwAFAAQAEQAAAGdyaWRYWQAABQAEABgAAAB4Q29vcmRzAAUABAAjAAAAeUNvb3JkcwAFAAQALQAAAGdyaWQAAAAABQAEADIAAAB4VmFsAAAAAAUABAA2AAAAeVZhbAAAAAAFAAUAQQAAAGlucHV0U2l6ZQAAAAUABQBFAAAAaW5wdXRJbWFnZQAABQAEAEoAAAB0ZXhlbAAAAAUABABbAAAAb3V0U2l6ZQAFAAUAXgAAAG91dEltYWdlAAAAAAUAAwBhAAAAZ2lkAAUACABkAAAAZ2xfR2xvYmFsSW52b2NhdGlvbklEAAAABQAEAHsAAABncmlkWFkAAAUABAB8AAAAcGFyYW0AAAAFAAQAggAAAHBhcmFtAAAARwAEAC0AAAAhAAAAAQAAAEcABAAtAAAAIgAAAAAAAABHAAQARQAAACEAAAAAAAAARwAEAEUAAAAiAAAAAAAAAEcAAwBeAAAAGQAAAEcABABeAAAAIQAAAAIAAABHAAQAXgAAACIAAAAAAAAARwAEAGQAAAALAAAAHAAAAEcABACJAAAACwAAABkAAAATAAIAAgAAACEAAwADAAAAAgAAABUABAAGAAAAIAAAAAEAAAAXAAQABwAAAAYAAAACAAAAIAAEAAgAAAAHAAAABwAAABYAAwAJAAAAIAAAABcABAAKAAAACQAAAAIAAAAhAAQACwAAAAoAAAAIAAAAIAAEAA8AAAAHAAAACgAAACEABAAQAAAACgAAAA8AAAAVAAQAFAAAACAAAAAAAAAAKwAEABQAAAAVAAAABAAAABwABAAWAAAAFAAAABUAAAAgAAQAFwAAAAcAAAAWAAAAKwAEABQAAAAZAAAAAAAAACsABAAUAAAAGgAAAAEAAAAgAAQAGwAAAAcAAAAGAAAAQxAEACsAAAAJAAAAFQAAACAABAAsAAAAAAAAACsAAAA7AAQALAAAAC0AAAAAAAAAHAAEADAAAAAJAAAAGgAAACAABAAxAAAABwAAADAAAAArAAQABgAAADgAAAAAAAAAIAAEADkAAAAHAAAACQAAABkACQBCAAAACQAAAAEAAAAAAAAAAAAAAAAAAAABAAAAAAAAABsAAwBDAAAAQgAAACAABABEAAAAAAAAAEMAAAA7AAQARAAAAEUAAAAAAAAAKwAEAAkAAABMAAAAAACAPywABQAKAAAATQAAAEwAAABMAAAAKwAEAAkAAABPAAAAAAAAPywABQAKAAAAUAAAAE8AAABPAAAAGQAJAFwAAAAJAAAAAQAAAAAAAAAAAAAAAAAAAAIAAAAFAAAAIAAEAF0AAAAAAAAAXAAAADsABABdAAAAXgAAAAAAAAAXAAQAYgAAABQAAAADAAAAIAAEAGMAAAABAAAAYgAAADsABABjAAAAZAAAAAEAAAAXAAQAZQAAABQAAAACAAAAFAACAGkAAAAXAAQAhQAAAAkAAAAEAAAAKwAEAAkAAACGAAAAAAAAACsABAAUAAAAiAAAAAgAAAAsAAYAYgAAAIkAAACIAAAAiAAAABoAAAA2AAUAAgAAAAQAAAAAAAAAAwAAAPgAAgAFAAAAOwAEAAgAAABbAAAABwAAADsABAAIAAAAYQAAAAcAAAA7AAQADwAAAHsAAAAHAAAAOwAEAAgAAAB8AAAABwAAADsABAAPAAAAggAAAAcAAAA9AAQAXAAAAF8AAABeAAAAaAAEAAcAAABgAAAAXwAAAD4AAwBbAAAAYAAAAD0ABABiAAAAZgAAAGQAAABPAAcAZQAAAGcAAABmAAAAZgAAAAAAAAABAAAAfAAEAAcAAABoAAAAZwAAAD4AAwBhAAAAaAAAAEEABQAbAAAAagAAAGEAAAAZAAAAPQAEAAYAAABrAAAAagAAAEEABQAbAAAAbAAAAFsAAAAZAAAAPQAEAAYAAABtAAAAbAAAAK8ABQBpAAAAbgAAAGsAAABtAAAAqAAEAGkAAABvAAAAbgAAAPcAAwBxAAAAAAAAAPoABABvAAAAcAAAAHEAAAD4AAIAcAAAAEEABQAbAAAAcgAAAGEAAAAaAAAAPQAEAAYAAABzAAAAcgAAAEEABQAbAAAAdAAAAFsAAAAaAAAAPQAEAAYAAAB1AAAAdAAAAK8ABQBpAAAAdgAAAHMAAAB1AAAA+QACAHEAAAD4AAIAcQAAAPUABwBpAAAAdwAAAG4AAAAFAAAAdgAAAHAAAAD3AAMAeQAAAAAAAAD6AAQAdwAAAHgAAAB5AAAA+AACAHgAAAD9AAEA+AACAHkAAAA9AAQABwAAAH0AAABhAAAAPgADAHwAAAB9AAAAOQAFAAoAAAB+AAAADQAAAHwAAAA+AAMAewAAAH4AAAA9AAQAXAAAAH8AAABeAAAAPQAEAAcAAACAAAAAYQAAAD0ABABDAAAAgQAAAEUAAAA9AAQACgAAAIMAAAB7AAAAPgADAIIAAACDAAAAOQAFAAoAAACEAAAAEgAAAIIAAABYAAcAhQAAAIcAAACBAAAAhAAAAAIAAACGAAAAYwAEAH8AAACAAAAAhwAAAP0AAQA4AAEANgAFAAoAAAANAAAAAAAAAAsAAAA3AAMACAAAAAwAAAD4AAIADgAAADsABAAXAAAAGAAAAAcAAAA7AAQAFwAAACMAAAAHAAAAOwAEADEAAAAyAAAABwAAADsABAAxAAAANgAAAAcAAABBAAUAGwAAABwAAAAMAAAAGgAAAD0ABAAGAAAAHQAAABwAAAB8AAQAFAAAAB4AAAAdAAAAQQAFABsAAAAfAAAADAAAABkAAAA9AAQABgAAACAAAAAfAAAAfAAEABQAAAAhAAAAIAAAAFAABwAWAAAAIgAAABkAAAAeAAAAIQAAABkAAAA+AAMAGAAAACIAAABBAAUAGwAAACQAAAAMAAAAGgAAAD0ABAAGAAAAJQAAACQAAAB8AAQAFAAAACYAAAAlAAAAQQAFABsAAAAnAAAADAAAABkAAAA9AAQABgAAACgAAAAnAAAAfAAEABQAAAApAAAAKAAAAFAABwAWAAAAKgAAABkAAAAmAAAAKQAAABoAAAA+AAMAIwAAACoAAAA9AAQAKwAAAC4AAAAtAAAAPQAEABYAAAAvAAAAGAAAAEQQBQAwAAAAMwAAAC4AAAAvAAAAPgADADIAAAAzAAAAPQAEACsAAAA0AAAALQAAAD0ABAAWAAAANQAAACMAAABEEAUAMAAAADcAAAA0AAAANQAAAD4AAwA2AAAANwAAAEEABQA5AAAAOgAAADIAAAA4AAAAPQAEAAkAAAA7AAAAOgAAAEEABQA5AAAAPAAAADYAAAA4AAAAPQAEAAkAAAA9AAAAPAAAAFAABQAKAAAAPgAAADsAAAA9AAAA/gACAD4AAAA4AAEANgAFAAoAAAASAAAAAAAAABAAAAA3AAMADwAAABEAAAD4AAIAEwAAADsABAAPAAAAQQAAAAcAAAA7AAQADwAAAEoAAAAHAAAAPQAEAEMAAABGAAAARQAAAGQABABCAAAARwAAAEYAAABnAAUABwAAAEgAAABHAAAAOAAAAG8ABAAKAAAASQAAAEgAAAA+AAMAQQAAAEkAAAA9AAQACgAAAEsAAAARAAAAgQAFAAoAAABOAAAASwAAAE0AAACFAAUACgAAAFEAAABOAAAAUAAAAD0ABAAKAAAAUgAAAEEAAACDAAUACgAAAFMAAABSAAAATQAAAIUABQAKAAAAVAAAAFEAAABTAAAAPgADAEoAAABUAAAAPQAEAAoAAABVAAAASgAAAIEABQAKAAAAVgAAAFUAAABQAAAAPQAEAAoAAABXAAAAQQAAAIgABQAKAAAAWAAAAFYAAABXAAAA/gACAFgAAAA4AAEA