From 179f62b14cd11d87819eb896ad70cde9bf719629 Mon Sep 17 00:00:00 2001 From: Alan Morelos Date: Mon, 16 Jun 2025 08:44:12 -0700 Subject: [PATCH] Revert vulkan changes from D76646172 fixup patch Summary: # Context Need these changes that were reverted in the weekend. Original stack of commits were unable to be merged into main due to an existing lintrunner issue blocking the merge. All the changes already went through [review](https://github.com/pytorch/executorch/pull/11479) and approved. Differential Revision: D76737404 --- backends/vulkan/runtime/gen_vulkan_spv.py | 136 ++- .../vulkan/runtime/graph/ops/glsl/arange.yaml | 4 +- .../runtime/graph/ops/glsl/avg_pool2d.yaml | 2 +- .../runtime/graph/ops/glsl/binary_op.yaml | 2 +- .../graph/ops/glsl/buffer_to_buffer.yaml | 3 +- .../graph/ops/glsl/buffer_to_nchw.yaml | 3 +- .../graph/ops/glsl/copy_channel_offset.yaml | 2 +- .../runtime/graph/ops/glsl/copy_offset.yaml | 2 +- .../ops/glsl/copy_packed_dim_offset.yaml | 2 +- .../runtime/graph/ops/glsl/embedding.yaml | 2 +- .../vulkan/runtime/graph/ops/glsl/flip.yaml | 3 +- .../runtime/graph/ops/glsl/image_to_nchw.yaml | 3 +- .../runtime/graph/ops/glsl/index_select.yaml | 2 +- .../graph/ops/glsl/index_select_channel.yaml | 2 +- .../graph/ops/glsl/nchw_to_buffer.yaml | 3 +- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 6 +- .../runtime/graph/ops/glsl/nchw_to_image.yaml | 3 +- .../vulkan/runtime/graph/ops/glsl/no_op.yaml | 2 +- .../runtime/graph/ops/glsl/permute.yaml | 2 +- .../vulkan/runtime/graph/ops/glsl/repeat.yaml | 2 +- .../runtime/graph/ops/glsl/unary_op.yaml | 4 +- .../vulkan/runtime/graph/ops/glsl/view.yaml | 2 +- .../graph/ops/utils/ShaderNameUtils.cpp | 26 +- backends/vulkan/runtime/vk_api/Types.h | 30 +- backends/vulkan/test/glsl/all_shaders.yaml | 2 +- .../test/op_tests/choose_qparams_test.cpp | 675 +++++++++++ .../vulkan/test/op_tests/dequantize_test.cpp | 1061 +++++++++++++++++ .../test/op_tests/linear_weight_int4_test.cpp | 22 +- .../vulkan/test/op_tests/quantize_test.cpp | 843 +++++++++++++ .../test/op_tests/rotary_embedding_test.cpp | 22 +- backends/vulkan/test/op_tests/sdpa_test.cpp | 20 +- backends/vulkan/test/op_tests/targets.bzl | 64 +- backends/vulkan/test/op_tests/test_utils.cpp | 114 ++ backends/vulkan/test/op_tests/test_utils.h | 32 + .../test/op_tests/utils/gen_benchmark_vk.py | 4 + .../test/op_tests/utils/gen_correctness_vk.py | 2 + .../vulkan/tools/gpuinfo/glsl/warp_size.yaml | 2 +- kernels/quantized/cpu/op_dequantize.cpp | 46 +- kernels/quantized/cpu/op_quantize.cpp | 4 +- kernels/quantized/test/op_dequantize_test.cpp | 90 ++ kernels/quantized/test/op_quantize_test.cpp | 65 + .../core/exec_aten/util/scalar_type_util.h | 5 + 42 files changed, 3169 insertions(+), 152 deletions(-) create mode 100644 backends/vulkan/test/op_tests/choose_qparams_test.cpp create mode 100644 backends/vulkan/test/op_tests/dequantize_test.cpp create mode 100644 backends/vulkan/test/op_tests/quantize_test.cpp create mode 100644 backends/vulkan/test/op_tests/test_utils.cpp create mode 100644 backends/vulkan/test/op_tests/test_utils.h diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..a137a7d538f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" and DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: + return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp new file mode 100644 index 00000000000..24c856e9d46 --- /dev/null +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -0,0 +1,675 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +std::tuple choose_qparams_tensor_out( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +std::tuple choose_qparams_per_token_asymmetric_out( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +// Wrapper function for choose_qparams_tensor_out without context +Tensor& choose_qparams_tensor_out_no_context( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_tensor_out( + input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); + return scale_out; +} + +// Wrapper function for choose_qparams_per_token_asymmetric_out without context +Tensor& choose_qparams_per_token_asymmetric_out_no_context( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_per_token_asymmetric_out( + input, dtype, scale_out, zero_point_out); + return scale_out; +} + +// ATen wrapper for choose_qparams_tensor +std::tuple choose_qparams_tensor_aten( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + double eps = 1e-7; + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) + (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +// ATen wrapper for choose_qparams_per_token_asymmetric +std::tuple choose_qparams_per_token_asymmetric_aten( + const at::Tensor& input, + at::ScalarType dtype) { + // Calculate output sizes for scale and zero_point tensors + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + auto scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) + (input, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +} // namespace native +} // namespace executor +} // namespace torch + +// +// Reference Implementation +// + +/* + * Reference implementation of choose_qparams_tensor + */ +std::tuple choose_qparams_tensor_reference_impl( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max) { + // Create output tensors + at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + + // Find min and max values in the input tensor + float min_val = input.min().item(); + float max_val = input.max().item(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = quant_min - min_val / static_cast(scale); + double zero_point_from_max = quant_max - max_val / static_cast(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::nearbyint(static_cast(initial_zero_point)); + } + + // Set output values - use item_mutable() for scalar tensors + scale_out.fill_(scale); + zero_point_out.fill_(nudged_zero_point); + + return std::make_tuple(scale_out, zero_point_out); +} + +/* + * Reference implementation of choose_qparams_per_token_asymmetric + */ +std::tuple +choose_qparams_per_token_asymmetric_reference_impl( + const at::Tensor& input, + at::ScalarType dtype) { + // For per-token quantization, we need to compute scale and zero_point for + // each token + int64_t quant_min = -128; + int64_t quant_max = 127; + + // Calculate output sizes + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Create output tensors + at::Tensor scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + // Calculate number of tokens + int64_t num_tokens = 1; + for (int64_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + + // Process each token + for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { + at::Tensor token = reshaped_input[token_idx]; + + // Find min and max values for this token + float min_val = token.min().item(); + float max_val = token.max().item(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = + quant_min - min_val / static_cast(scale); + double zero_point_from_max = + quant_max - max_val / static_cast(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = + std::nearbyint(static_cast(initial_zero_point)); + } + + // Set output values for this token - use index_put_ for safety + scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); + zero_point_out.view({num_tokens, 1}) + .index_put_({token_idx, 0}, nudged_zero_point); + } + + return std::make_tuple(scale_out, zero_point_out); +} + +// Forward declaration of implementation functions +void test_vulkan_choose_qparams_tensor_impl( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_tensor( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_per_token_asymmetric( + const std::vector& input_sizes, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_choose_qparams_tensor( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_tensor_reference_impl(input, quant_min, quant_max); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_tensor_impl( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Build Vulkan choose_qparams_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + // Output tensors + const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams.tensor") + (graph, + { + r_input.value, + r_quant_min, + r_quant_max, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); + at::Tensor vk_zero_point = + at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + // make sure that there arent a ton of elements in the input tensor + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { + test_reference_choose_qparams_tensor( + {2, 3, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +void test_reference_choose_qparams_per_token_asymmetric( + const std::vector& input_sizes, + at::ScalarType dtype) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_per_token_asymmetric_reference_impl(input, dtype); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Calculate output sizes + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Build Vulkan choose_qparams_per_token_asymmetric graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Output tensors + const ValueRef r_scale = + graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = + graph.add_tensor(output_sizes, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_per_token_asymmetric + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) + .contiguous(); + at::Tensor vk_zero_point = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) + .contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST( + VulkanChooseQparamsTest, + test_reference_choose_qparams_per_token_asymmetric_int8) { + test_reference_choose_qparams_per_token_asymmetric( + {2, 3, 4}, // input sizes (2*3=6 tokens) + at::kChar); +} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..7b155c8f98b --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,1061 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { + using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kHalf: + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast(out_dtype), + ")"); + } +} + +// +// Reference Implementation +// + +/* + * Reference implementation of dequantize_per_tensor + */ +at::Tensor dequantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Dequantize the input tensor + at::Tensor flat_input = input.flatten(); + at::Tensor flat_out = out.flatten(); + + // Store casted values to avoid repeated casting + const int32_t zero_point_int32 = static_cast(zero_point); + const float scale_float = static_cast(scale); + + for (int i = 0; i < flat_input.numel(); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + flat_out[i] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + flat_out[i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + flat_out[i] = static_cast(dequantized_value); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of dequantize_per_token + */ +at::Tensor dequantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Dequantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Get scale and zero_point for this token + float token_scale = scale[token_idx].item(); + int64_t token_zero_point = zero_point[token_idx].item(); + + // Store casted values to avoid repeated casting + const int32_t token_zero_point_int32 = + static_cast(token_zero_point); + + // Dequantize the token + for (int i = 0; i < input.size(-1); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kChar) { + int8_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kShort) { + int16_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kInt) { + int32_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kLong) { + int64_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + reshaped_out[token_idx][i] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + reshaped_out[token_idx][i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + reshaped_out[token_idx][i] = static_cast(dequantized_value); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_dequantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_dequantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_dequantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = dequantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = + torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Build Vulkan dequantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + + const ValueRef r_scale = graph.add_scalar(scale); + const ValueRef r_zero_point = graph.add_scalar(zero_point); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan dequantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_tensor +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_float) { + test_reference_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int8_to_float) { + test_reference_dequantize_per_tensor( + {3, 4, 5}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_float) { + test_reference_dequantize_per_tensor( + {4, 6, 2}, // input sizes + 0.2, // scale + 2, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_half) { + test_reference_dequantize_per_tensor( + {7, 4}, // input sizes + 0.1, // scale + 10, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype (uint8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_half) { + test_reference_dequantize_per_tensor( + {2, 6, 5}, // input sizes + 0.3, // scale + -10, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +void test_reference_dequantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = dequantize_per_token_reference_impl( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Build Vulkan dequantize_per_token graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_token +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_uint8_to_float) { + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_reference_dequantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_float) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 5}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_float) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 10}, // input sizes (2*2=4 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_half) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {4, 1, 5}, // input sizes (4*1=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype (int8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_half) { + std::vector scales = {0.05, 0.1}; + std::vector zero_points = {0, -5}; + + test_reference_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index b95b7b3aa6d..e48042c4620 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_linear_qga4w_impl( const int B, const int M, diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp new file mode 100644 index 00000000000..8b79dc1ce6b --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -0,0 +1,843 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& quantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +// Wrapper function for quantize_per_tensor_out without context +Tensor& quantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// Wrapper function for quantize_per_token_out without context +Tensor& quantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// ATen wrapper for quantize_per_tensor +at::Tensor quantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +// ATen wrapper for quantize_per_token +at::Tensor quantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_quantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType out_dtype) { + using namespace vkcompute; + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + switch (out_dtype) { + case c10::kByte: + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + break; + case c10::kChar: + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + break; + case c10::kBits16: + case c10::kUInt16: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + case c10::kShort: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + case c10::kInt: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + default: + VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); + } + VK_CHECK_COND( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: ", + quant_min_lower_bound, + " actual quant_min: ", + quant_min); + + VK_CHECK_COND( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: ", + quant_max_upper_bound, + " actual quant_max: ", + quant_max); +} + +// +// Reference Implementation +// + +/* + * Reference implementation of quantize_per_tensor + */ +at::Tensor quantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Quantize the input tensor + float inv_scale = 1.0 / scale; + + // Iterate through the tensor and quantize each element + at::Tensor float_input = input.to(at::kFloat); + at::Tensor float_values = float_input.flatten(); + + auto out_flat = out.flatten(); + + for (int i = 0; i < float_values.numel(); i++) { + float value = float_values[i].item(); + int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + + if (dtype == at::kByte) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + out_flat[i] = static_cast(qvalue); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of quantize_per_token + */ +at::Tensor quantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Quantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Use float for scale since Vulkan doesn't support double + float token_scale = scale[token_idx].item(); + // Use int for zero_point since Vulkan doesn't support int64_t + int token_zero_point = zero_point[token_idx].item(); + + float inv_scale = 1.0 / token_scale; + + // Quantize the token + for (int i = 0; i < input.size(-1); i++) { + float value = reshaped_input[token_idx][i].item(); + int qvalue = token_zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + + if (dtype == at::kByte) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_quantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_quantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_quantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = quantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - impl_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "my_reference:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Build Vulkan quantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_scale = graph.add_scalar(scale); + const ValueRef r_zero_point = graph.add_scalar(zero_point); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan quantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + // For quantized types, we need to compare the actual integer values + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.04, // scale + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_uint8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.2, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kHalf, + at::kInt); +} + +void test_reference_quantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0 / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scales and zero_points + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = quantize_per_token_reference_impl( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "my_reference:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output to show what we would compare against + at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int8) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int32) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_int32) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kHalf, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_uint8) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index 534bb577e7a..eebbb89ab40 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -55,26 +57,6 @@ std::pair rotary_embedding_impl( // Test functions // -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_reference( const int n_heads = 4, const int n_kv_heads = 2, diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index 772039eda6a..79b679674a5 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -18,6 +18,8 @@ #include #include +#include "test_utils.h" + #include #include @@ -261,24 +263,6 @@ void test_reference_sdpa( } } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_sdpa( const int start_input_pos, const int base_sequence_len, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 5c9afa40762..0d014c7ef29 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), ) + runtime.cxx_library( + name = "test_utils", + srcs = [ + "test_utils.cpp", + ], + headers = [ + "test_utils.h", + ], + exported_headers = [ + "test_utils.h", + ], + deps = [ + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/runtime/core/exec_aten:lib", + runtime.external_dep_location("libtorch"), + ], + visibility = [ + "//executorch/backends/vulkan/test/op_tests/...", + "@EXECUTORCH_CLIENTS", + ], + ) + define_test_targets( "compute_graph_op_tests", src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" @@ -150,9 +172,47 @@ def define_common_targets(is_fbcode = False): define_test_targets( "sdpa_test", extra_deps = [ + ":test_utils", "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", ] ) - define_test_targets("linear_weight_int4_test") - define_test_targets("rotary_embedding_test") + define_test_targets( + "quantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_quantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "dequantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "choose_qparams_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_choose_qparams", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "linear_weight_int4_test", + extra_deps = [ + ":test_utils", + ] + ) + define_test_targets( + "rotary_embedding_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp new file mode 100644 index 00000000000..196f079be2c --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "test_utils.h" + +#include + +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype) { + using ScalarType = executorch::aten::ScalarType; + switch (dtype) { + case at::kByte: + return ScalarType::Byte; + case at::kChar: + return ScalarType::Char; + case at::kShort: + return ScalarType::Short; + case at::kInt: + return ScalarType::Int; + case at::kLong: + return ScalarType::Long; + case at::kHalf: + return ScalarType::Half; + case at::kFloat: + return ScalarType::Float; + case at::kDouble: + return ScalarType::Double; + default: + throw std::runtime_error("Unsupported dtype"); + } +} + +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + case c10::kFloat: + return "c10::kFloat"; + case c10::kHalf: + return "c10::kHalf"; + case c10::kInt: + return "c10::kInt"; + case c10::kChar: + return "c10::kChar"; + case c10::kByte: + return "c10::kByte"; + case c10::kDouble: + return "c10::kDouble"; + case c10::kUInt16: + return "c10::kUInt16"; + case c10::kBits16: + return "c10::kBits16"; + default: + return "Unknown(" + std::to_string(static_cast(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kHalf: + return vkapi::kHalf; + case c10::kFloat: + return vkapi::kFloat; + case c10::kDouble: + return vkapi::kDouble; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kLong; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast(at_scalartype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/test_utils.h b/backends/vulkan/test/op_tests/test_utils.h new file mode 100644 index 00000000000..369767007e0 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +/** + * Convert at::ScalarType to executorch::ScalarType + */ +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype); + +/** + * Get the string name of a c10::ScalarType for better error messages + */ +std::string scalar_type_name(c10::ScalarType dtype); + +/** + * Convert c10::ScalarType to vkcompute::vkapi::ScalarType + */ +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype); diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index c1f2770d3d6..876099598dc 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out( static_cast(scale)); \ } \ } break; -#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { @@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out( } \ out_data_ptr[current_ix] = \ static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ + input_data_ptr[current_ix] - \ + static_cast(zero_point)) * \ _scale; \ } \ }, \ @@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out( apply_over_dim_list( \ [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ + (input_data_ptr[in_ix] - static_cast(_zero_point)) * \ + _scale); \ }, \ input, \ optional_dim_list, \ channel_ix); \ } \ break; -#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 4665c3d665b..d0b7c882f8e 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index bbda1590a10..4a0c195e3ab 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype(); } +/// Test all supported output dtypes for dequantization +template +void test_output_dtype() { + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 100); + double scale = 0.5; + int64_t zero_point = 30; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (100 - 30) * 0.5 = 35 + Tensor expected = tfo.full({3, 5}, 35); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(OUT_DTYPE), + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, AllOutputDtypesSupported) { + et_pal_init(); + test_output_dtype(); + test_output_dtype(); + test_output_dtype(); +} + +TEST(OpDequantizeOutTest, HalfOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (10 - 100000) * 0.5 = -49995 + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Half), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, DoubleOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Double), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpDequantizeOutTest, NonWholeNumbers) { et_pal_init(); TensorFactory tf; diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 704d8d06c5c..5cd17223d80 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -49,6 +49,32 @@ void test_dtype() { EXPECT_TENSOR_EQ(out, expected); } +template +void test_input_dtype() { + TensorFactory tf_input; + + Tensor input = tf_input.full({3, 5}, 4); + double scale = 0.5; + int64_t zero_point = 108; + int64_t quant_min = 0; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // 4 / 0.5 + 108 = 116 + Tensor expected = tfo.full({3, 5}, 116); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, AllInputDtypesSupported) { + test_input_dtype(); + test_input_dtype(); + test_input_dtype(); +} + TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); test_dtype(); @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); } +TEST(OpQuantizeOutTest, DoubleInputTest) { + TensorFactory tf_double; + + // Test with a more complex value that might have precision differences + Tensor input = tf_double.full({2, 3}, 3.14159265359); + double scale = 0.01; + int64_t zero_point = -100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 3.14159265359 / 0.01 - 100 = 214.159265359 + Tensor expected = tfo.full({2, 3}, 214); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, HalfInputTest) { + TensorFactory tf_half; + + Tensor input = tf_half.full({2, 3}, 2.5); + double scale = 0.5; + int64_t zero_point = 10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 2.5 / 0.5 + 10 = 15 + Tensor expected = tfo.full({2, 3}, 15); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpQuantizeOutTest, TensorArgOverload) { TensorFactory tf_float; TensorFactory tf_double; diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 6f81146e925..d81b3ad4d0f 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT, float, Float) \ _(ANOTHER_INPUT, double, Double) +#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, ::executorch::aten::Half, Half) + #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)