From 44b193b1460ed4ee5c200e22333b69175c9fde28 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 1 Apr 2026 11:29:19 -0700 Subject: [PATCH] [ET-VK] Fix pack_fp_linear_weight for devices without VK_KHR_16bit_storage The `pack_fp_linear_weight` prepack shader crashes on devices that lack `VK_KHR_16bit_storage` support because the half-precision variant reads from a `float16_t[]` staging buffer, which requires that extension. This applies the same two-dtype pattern used by `nchw_to_image` and `conv2d_dw_prepack_weights`: a new `BUF_DTYPE` shader parameter allows the staging buffer to use float32 (`[half, float]` combo) while the packed output remains half-precision. The runtime selects the correct variant via `get_staging_dtype_for()`, which returns `kFloat` when the device lacks fp16 buffer support. All three call sites that construct the `pack_fp_linear_weight` shader name (Linear.cpp, Conv1dPW.cpp, Conv2dPW.cpp) are updated to append the staging dtype suffix. Authored with Claude. Differential Revision: [D99133993](https://our.internmc.facebook.com/intern/diff/D99133993/) [ghstack-poisoned] --- .../graph/ops/glsl/pack_fp_linear_weight.glsl | 9 +++++---- .../graph/ops/glsl/pack_fp_linear_weight.yaml | 15 +++++++++------ .../vulkan/runtime/graph/ops/impl/Conv1dPW.cpp | 1 + .../vulkan/runtime/graph/ops/impl/Conv2dPW.cpp | 1 + backends/vulkan/runtime/graph/ops/impl/Linear.cpp | 1 + 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl index 8976f4b8d69..a439500c97a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl @@ -9,15 +9,16 @@ #version 450 core #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, "buffer")} -#define T ${texel_load_component_type(DTYPE, "buffer")} +#define BUF_T ${buffer_scalar_type(BUF_DTYPE)} +#define VEC4_T ${texel_load_type(DTYPE, PACKED_STORAGE)} +#define T ${texel_load_component_type(DTYPE, PACKED_STORAGE)} $if PACKED_STORAGE == "buffer": #define OUTPUT_BUFFER #extension GL_EXT_control_flow_attributes : require -${define_required_extensions("buffer", DTYPE)} +${define_required_extensions("buffer", BUF_DTYPE)} $if PACKED_STORAGE != "buffer": ${define_required_extensions(PACKED_STORAGE, DTYPE)} @@ -29,7 +30,7 @@ $if PACKED_STORAGE == "buffer": ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)} $else: ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, PACKED_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_src", DTYPE, "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_weight_src", BUF_DTYPE, "buffer", is_scalar_array=True)} layout(push_constant) uniform restrict Block { int N; diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml index 34793634435..2da89934369 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml @@ -7,13 +7,16 @@ pack_fp_linear_weight: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float PACKED_STORAGE: texture2d generate_variant_forall: - PACKED_STORAGE: - - VALUE: texture2d - - VALUE: buffer - DTYPE: - - VALUE: float - - VALUE: half + combination: + parameter_names: [PACKED_STORAGE, DTYPE, BUF_DTYPE] + combos: + - parameter_values: [texture2d, float, float] + - parameter_values: [texture2d, half, half] + - parameter_values: [texture2d, half, float] + - parameter_values: [buffer, float, float] + - parameter_values: [buffer, half, half] shader_variants: - NAME: pack_fp_linear_weight diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp index f6db56fc581..90dada6b58e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp @@ -71,6 +71,7 @@ static ValueRef prepack_conv1d_pw_weight( std::string kernel_name = "pack_fp_linear_weight"; add_storage_type_suffix(kernel_name, weight_storage); add_dtype_suffix(kernel_name, graph.dtype_of(weight_data)); + add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data)); graph.prepack_nodes().emplace_back(new PrepackNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp index 2863d80aa0e..43657e7f4ec 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp @@ -130,6 +130,7 @@ ValueRef prepack_conv2d_pw_weight( std::string pack_kernel_name = "pack_fp_linear_weight"; add_storage_type_suffix(pack_kernel_name, weight_storage); add_dtype_suffix(pack_kernel_name, graph.dtype_of(weight_data)); + add_dtype_suffix(pack_kernel_name, graph.get_staging_dtype_for(weight_data)); graph.prepack_nodes().emplace_back(new PrepackNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index ca7f55e85f2..62266473351 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -81,6 +81,7 @@ ValueRef prepack_fp_linear_weight( std::string kernel_name = "pack_fp_linear_weight"; add_storage_type_suffix(kernel_name, weight_storage); add_dtype_suffix(kernel_name, graph.dtype_of(weight_data)); + add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data)); graph.prepack_nodes().emplace_back(new PrepackNode( graph,