From 73fc9a9d9872d74a1af41df75f1458764c7ac62d Mon Sep 17 00:00:00 2001 From: morelos Date: Fri, 13 Jun 2025 15:49:24 -0700 Subject: [PATCH] [ET-VK] additional dtype runtime support along with their aliases Pull Request resolved: https://github.com/pytorch/executorch/pull/11365 # Context This diff generally aims provide improvements to the existing framework for defining dtype GLSL shader variants, along with setting up support that would be necessary for future shader implementations that wish to support int64 and double dtypes. In order to allow doubles as input/output dtypes for dequantization and quantization, this diff will create the dtype runtime support on the Vulkan backend in Executorch by establishing the relationship between different tensor types and different GLSL types. # Changes The main changes are included in `gen_vulkan_spv.py` which maps the relationship between different dtypes and their GLSL types. For instance, we add aliases for every common dtype which includes `uint8`, `int8`, `uint16`, `int16`, `uint32`, `int32`, `uint64`, `int64`, and `double`. We maintain support for `int`, `uint`, and `bool` alises such that we can avoid making the change overly complex while supporting the most common recognizable alias (int). Furthermore, this diff also modifies the vulkan api to incorporate new types, namely `uint32_t`, `double`, the int16 and int64 variants. We then make sure that the `ShaderNameUtils` (which is commonly used by most operators for creating their variant names), utilizes the new aliasing. Beyond that we also throw an exception to disallow YAML files to include just "int", and to be more specific, like with "int32". We then modify dozens of files to switch to the new alias of int32. Furthermore, we also include double in certain shaders that are used as intermediaries for image to buffer to nchw converisons. ghstack-source-id: 290376491 @exported-using-ghexport Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) --- backends/vulkan/runtime/gen_vulkan_spv.py | 136 +++++++++++++----- .../vulkan/runtime/graph/ops/glsl/arange.yaml | 4 +- .../runtime/graph/ops/glsl/avg_pool2d.yaml | 2 +- .../runtime/graph/ops/glsl/binary_op.yaml | 2 +- .../graph/ops/glsl/buffer_to_buffer.yaml | 3 +- .../graph/ops/glsl/buffer_to_nchw.yaml | 3 +- .../graph/ops/glsl/copy_channel_offset.yaml | 2 +- .../runtime/graph/ops/glsl/copy_offset.yaml | 2 +- .../ops/glsl/copy_packed_dim_offset.yaml | 2 +- .../runtime/graph/ops/glsl/embedding.yaml | 2 +- .../vulkan/runtime/graph/ops/glsl/flip.yaml | 3 +- .../runtime/graph/ops/glsl/image_to_nchw.yaml | 3 +- .../runtime/graph/ops/glsl/index_select.yaml | 2 +- .../graph/ops/glsl/index_select_channel.yaml | 2 +- .../graph/ops/glsl/nchw_to_buffer.yaml | 3 +- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 6 +- .../runtime/graph/ops/glsl/nchw_to_image.yaml | 3 +- .../vulkan/runtime/graph/ops/glsl/no_op.yaml | 2 +- .../runtime/graph/ops/glsl/permute.yaml | 2 +- .../vulkan/runtime/graph/ops/glsl/repeat.yaml | 2 +- .../runtime/graph/ops/glsl/unary_op.yaml | 4 +- .../vulkan/runtime/graph/ops/glsl/view.yaml | 2 +- .../graph/ops/utils/ShaderNameUtils.cpp | 26 +++- backends/vulkan/runtime/vk_api/Types.h | 30 +++- backends/vulkan/test/glsl/all_shaders.yaml | 2 +- .../test/op_tests/utils/gen_benchmark_vk.py | 4 + .../test/op_tests/utils/gen_correctness_vk.py | 2 + .../vulkan/tools/gpuinfo/glsl/warp_size.yaml | 2 +- 28 files changed, 190 insertions(+), 68 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..a137a7d538f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" and DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: + return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: