Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ runtime.python_library(
resources = [
"vgf/shaders/grid_sampler.glsl",
"vgf/shaders/grid_sampler.spirv.b64",
"vgf/shaders/grid_sampler_sampler.glsl",
"vgf/shaders/grid_sampler_sampler.spirv.b64",
"vgf/shaders/grid_sampler_sampler_align_corners.glsl",
"vgf/shaders/grid_sampler_sampler_align_corners.spirv.b64",
"vgf/shaders/grid_sampler_sampler_int8.glsl",
"vgf/shaders/grid_sampler_sampler_int8.spirv.b64",
"vgf/shaders/grid_sampler_sampler_int8_align_corners.glsl",
"vgf/shaders/grid_sampler_sampler_int8_align_corners.spirv.b64",
],
deps = [
":arm_compile_spec",
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/op_tosa_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _vk_format_component_count(vk_format: str) -> int:
"VK_FORMAT_R32G32_SFLOAT": 2,
"VK_FORMAT_R8G8B8A8_UINT": 4,
"VK_FORMAT_R8G8B8A8_SINT": 4,
"VK_FORMAT_R8G8B8A8_SNORM": 4,
"VK_FORMAT_R16G16B16A16_UINT": 4,
"VK_FORMAT_R16G16B16A16_SINT": 4,
"VK_FORMAT_R16G16B16A16_SFLOAT": 4,
Expand Down
75 changes: 63 additions & 12 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def __init__(self):
class _QParams(NamedTuple):
scale: float
zero_point: int
quant_min: int | None = None
quant_max: int | None = None


def _as_list(x):
Expand Down Expand Up @@ -472,8 +474,53 @@ def _match_pattern(
8: _QParams((0.999 - (-0.999)) / (1 << 8), 0),
16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0),
},
# grid_sampler image input/output use SNORM-compatible qparams. The grid
# coordinate tensor is intentionally left unquantized.
torch.ops.aten.grid_sampler.default: {
8: _QParams(1.0 / 127.0, 0, -127, 127),
},
}


_fixed_output_qspec_ops: dict[Any, dict[int, _QParams]] = {
torch.ops.aten.grid_sampler.default: {
8: _QParams(1.0 / 127.0, 0, -127, 127),
},
}


def _get_fixed_qparams_qspec(
node_target: Any,
qparams_table: dict[Any, dict[int, _QParams]],
input_act_qspec: QuantizationSpecBase,
) -> FixedQParamsQuantizationSpec | None:
if not isinstance(input_act_qspec, QuantizationSpec):
raise ValueError("Fixed qparams require a QuantizationSpec input.")

num_bits = torch.iinfo(input_act_qspec.dtype).bits
qparams = qparams_table[node_target].get(num_bits)
if qparams is None:
return None

return FixedQParamsQuantizationSpec(
dtype=input_act_qspec.dtype,
scale=qparams.scale,
zero_point=qparams.zero_point,
quant_min=(
input_act_qspec.quant_min
if qparams.quant_min is None
else qparams.quant_min
),
quant_max=(
input_act_qspec.quant_max
if qparams.quant_max is None
else qparams.quant_max
),
qscheme=input_act_qspec.qscheme,
is_dynamic=input_act_qspec.is_dynamic,
)


_one_to_one: set[OpOverload] = {
torch.ops.aten.abs.default,
torch.ops.aten.ceil.default,
Expand Down Expand Up @@ -762,6 +809,16 @@ def any_or_hardtanh_min_zero(n: Node):
_QuantProperty(1, input_act_qspec),
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target == torch.ops.aten.grid_sampler.default:
image_node = ensure_type(Node, node.args[0])
grid_sampler_image_qspec = quantization_config.get_input_act_qspec(
node, image_node
)
grid_sampler_output_qspec = quantization_config.get_output_act_qspec(node)
if grid_sampler_image_qspec is None or grid_sampler_output_qspec is None:
return None
quant_properties.quant_inputs = [_QuantProperty(0, grid_sampler_image_qspec)]
quant_properties.quant_output = _QuantProperty(0, grid_sampler_output_qspec)
elif node.target in (torch.ops.aten.where.self,):
true_node = ensure_type(Node, node.args[1])
input_qspec = (
Expand Down Expand Up @@ -825,21 +882,15 @@ def any_or_hardtanh_min_zero(n: Node):
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in _fixed_input_qspec_ops:
num_bits = torch.iinfo(input_act_qspec.dtype).bits
qparams = _fixed_input_qspec_ops[node.target][num_bits]

fixed_input_qspec = _get_fixed_qparams_qspec(
node.target, _fixed_input_qspec_ops, input_act_qspec
)
if fixed_input_qspec is None:
return None
quant_properties.quant_inputs = [
_QuantProperty(
0,
FixedQParamsQuantizationSpec(
dtype=input_act_qspec.dtype,
scale=qparams.scale,
zero_point=qparams.zero_point,
quant_min=input_act_qspec.quant_min,
quant_max=input_act_qspec.quant_max,
qscheme=input_act_qspec.qscheme,
is_dynamic=input_act_qspec.is_dynamic,
),
fixed_input_qspec,
)
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
Expand Down
44 changes: 23 additions & 21 deletions backends/arm/quantizer/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
QuantizationSpecBase,
SharedQuantizationSpec,
Expand Down Expand Up @@ -295,6 +294,7 @@ def get_input_act_qspec(self, node=None, input_node=None):
# MLETORCH-1853: Fix lazy import when moving files around
from executorch.backends.arm.quantizer.quantization_annotator import (
_fixed_input_qspec_ops,
_get_fixed_qparams_qspec,
)

if node is None or input_node is None:
Expand All @@ -305,28 +305,17 @@ def get_input_act_qspec(self, node=None, input_node=None):
return super().get_input_act_qspec(node, input_node)
else:
return SharedQuantizationSpec((node.args[0], node))
elif node.target == torch.ops.aten.grid_sampler.default:
if input_node != node.args[0]:
return None
input_act_qspec = super().get_input_act_qspec(node, input_node)
return _get_fixed_qparams_qspec(
node.target, _fixed_input_qspec_ops, input_act_qspec
)
elif node.target in _fixed_input_qspec_ops:

input_act_qspec = super().get_input_act_qspec(node, input_node)
if not hasattr(input_act_qspec, "dtype") or not isinstance(
input_act_qspec.dtype, torch.dtype
):
raise ValueError(
f"{node.target} requires an input activation quantization "
"spec to use fixed input qparams."
)
dtype = getattr(input_act_qspec, "dtype", None)
num_bits = torch.iinfo(dtype).bits

qparams = _fixed_input_qspec_ops[node.target][num_bits]
return FixedQParamsQuantizationSpec(
dtype=dtype,
scale=qparams.scale,
zero_point=qparams.zero_point,
quant_min=input_act_qspec.quant_min,
quant_max=input_act_qspec.quant_max,
qscheme=input_act_qspec.qscheme,
is_dynamic=input_act_qspec.is_dynamic,
return _get_fixed_qparams_qspec(
node.target, _fixed_input_qspec_ops, input_act_qspec
)

return super().get_input_act_qspec(node, input_node)
Expand Down Expand Up @@ -371,6 +360,19 @@ def get_output_act_qspec(

if node is None:
return super().get_output_act_qspec()
# MLETORCH-1853: Fix lazy import when moving files around
from executorch.backends.arm.quantizer.quantization_annotator import (
_fixed_output_qspec_ops,
_get_fixed_qparams_qspec,
)

if node.target in _fixed_output_qspec_ops:
output_act_qspec = super().get_output_act_qspec(node)
if output_act_qspec is None:
return None
return _get_fixed_qparams_qspec(
node.target, _fixed_output_qspec_ops, output_act_qspec
)
if node.target not in self.SHARED_OUTPUT_ACT_QSPEC_PATTERNS:
return super().get_output_act_qspec()
if len(node.args) == 0:
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantizer_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def check_pattern(cls, pattern):
(torch.ops.aten.acos.default,),
(torch.ops.aten.atanh.default,),
(torch.ops.aten.einsum.default,),
(torch.ops.aten.grid_sampler.default,),
]
)
TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = {
Expand Down
8 changes: 7 additions & 1 deletion backends/arm/scripts/generate_grid_sampler_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ def main() -> None:
with tempfile.TemporaryDirectory() as tmpdir:
spirv_path = Path(tmpdir) / "grid_sampler.spirv"
subprocess.run( # nosec B603 - glslc path is resolved explicitly.
[glslc, str(args.source), "-o", str(spirv_path)],
[
glslc,
"-fshader-stage=compute",
str(args.source),
"-o",
str(spirv_path),
],
check=True,
)
_write_base64_spirv(spirv_path, args.output)
Expand Down
Loading
Loading