From 80ecb39dc7a1400477b3dc6ab0b76c92aeb7c1d4 Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 4 Jun 2025 11:02:51 -0700 Subject: [PATCH] [ET-VK] double, short, and uint16 dtype runtime support Creating support for double, short, and uint16 for quantization ops. Registering the short keyword since there's already support for it. Also changing the CPU implementation to support half. Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) [ghstack-poisoned] --- backends/vulkan/runtime/gen_vulkan_spv.py | 37 +++++++++++++++++-- .../graph/ops/glsl/buffer_to_buffer.yaml | 2 + .../graph/ops/glsl/buffer_to_nchw.yaml | 2 + .../runtime/graph/ops/glsl/image_to_nchw.yaml | 2 + .../graph/ops/glsl/nchw_to_buffer.yaml | 2 + .../runtime/graph/ops/glsl/nchw_to_image.yaml | 2 + .../graph/ops/utils/ShaderNameUtils.cpp | 9 +++++ backends/vulkan/runtime/vk_api/Types.h | 13 ++++++- .../test/op_tests/utils/gen_correctness_vk.py | 4 +- kernels/quantized/cpu/op_quantize.cpp | 4 +- 10 files changed, 68 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..c36309b889f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,6 +56,7 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", "int": "iimage3D", @@ -63,8 +64,11 @@ "int8": "iimage3D", "uint8": "uimage3D", "bool": "uimage3D", + "short": "iimage3D", + "uint16": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", "int": "iimage2D", @@ -72,10 +76,13 @@ "int8": "iimage2D", "uint8": "uimage2D", "bool": "uimage2D", + "short": "iimage2D", + "uint16": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", "int": "isampler3D", @@ -83,8 +90,11 @@ "int8": "isampler3D", "uint8": "usampler3D", "bool": "usampler3D", + "short": 
"isampler3D", + "uint16": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", "int": "isampler2D", @@ -92,9 +102,12 @@ "int8": "isampler2D", "uint8": "usampler2D", "bool": "usampler2D", + "short": "isampler2D", + "uint16": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba64f", "float": "rgba32f", "half": "rgba16f", "int": "rgba32i", @@ -102,6 +115,8 @@ "int8": "rgba8i", "uint8": "rgba8ui", "bool": "rgba8ui", + "short": "rgba16i", + "uint16": "rgba16ui", }, } @@ -118,10 +133,16 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + elif dtype == "short": + return "int16_t" elif dtype == "bool": return "uint8_t" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -135,8 +156,14 @@ def buffer_gvec_type(dtype: str, n: int) -> str: return f"uvec{n}" elif dtype == "half": return f"f16vec{n}" + elif dtype == "double": + return f"dvec{n}" elif dtype == "int": return f"ivec{n}" + elif dtype == "short": + return f"i16vec{n}" + elif dtype == "uint16": + return f"u16vec{n}" elif dtype == "int8": return f"i8vec{n}" elif dtype == "uint8": @@ -365,12 +392,14 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": + elif dtype == "short" or dtype == "int16" or dtype == "uint16": nbit = "16bit" glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "bool" or dtype == "int8" or dtype == "uint8": nbit = "8bit" glsl_type = "int8" + elif dtype == "double" or dtype == "float64": + out_str += "#extension GL_ARB_gpu_shader_fp64 : require\n" if nbit is not None and glsl_type is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..6a25a31d925 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -15,5 +15,7 @@ buffer_to_buffer: - VALUE: int - VALUE: int8 - VALUE: uint8 + - VALUE: short + - VALUE: uint16 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index 25b3657c2eb..5bdc6454c4b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -15,5 +15,7 @@ buffer_to_nchw: - VALUE: int - VALUE: int8 - VALUE: uint8 + - VALUE: short + - VALUE: uint16 shader_variants: - NAME: buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index c1045d93afc..f0f45cd37a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -16,6 +16,8 @@ image_to_nchw: - VALUE: int - VALUE: int8 - VALUE: uint8 + - VALUE: short + - VALUE: uint16 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..864649aec64 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -16,6 +16,8 @@ nchw_to_buffer: - VALUE: int - VALUE: int8 - VALUE: uint8 + - VALUE: short + - VALUE: uint16 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml 
b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..fa4f2d9c827 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -17,6 +17,8 @@ nchw_to_image: - VALUE: int - VALUE: int8 - VALUE: uint8 + - VALUE: short + - VALUE: uint16 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..2b68af3a9fd 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,6 +34,9 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; @@ -43,6 +46,9 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { case vkapi::kInt: kernel_name += "_int"; break; + case vkapi::kShort: + kernel_name += "_short"; + break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; @@ -52,6 +58,9 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { case vkapi::kBool: kernel_name += "_uint8"; break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index 6531bf4710c..f34cd30e5a6 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -25,11 +25,14 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, 
VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -88,10 +91,16 @@ inline ScalarType element_scalartype(const VkFormat vkformat) { return kByte; case VK_FORMAT_R32G32B32A32_SINT: return kInt; + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; case VK_FORMAT_R32G32B32A32_SFLOAT: return kFloat; case VK_FORMAT_R16G16B16A16_SFLOAT: return kHalf; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index ce6ab32ce60..c368c23c539 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -109,6 +109,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -119,7 +121,7 @@ def gen_parameterization(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; - case c10::kBool: + case c10::kBool: return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 632bddd58c4..5c42f090c99 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) 
{ - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -347,7 +347,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false,