diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.glsl new file mode 100644 index 00000000000..76e6a6c6238 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.glsl @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output staging buffer: raw int8 data interpreted as int32 for device compat +${layout_declare_tensor(B, "w", "nchw_out", "int", "buffer")} +// Input buffer: packed int8x4 values (each int32 contains 4 packed int8) +${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")} + +// Metadata for input tensor +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} + +void main() { + // One thread per output int32 in the NCHW staging buffer. + // Each output int32 holds 4 consecutive NCHW bytes. + const uint out_int32_idx = gl_GlobalInvocationID.x; + + const uint W = inp.sizes[0][0]; + const uint H = inp.sizes[0][1]; + const uint C = inp.sizes[0][2]; + const uint N = inp.sizes[0][3]; + const uint total_numel = W * H * C * N; + const uint num_out_int32s = (total_numel + 3u) / 4u; + + if (out_int32_idx >= num_out_int32s) { + return; + } + + int output_int32 = 0; + [[unroll]] for (int j = 0; j < 4; ++j) { + const uint nchw_idx = out_int32_idx * 4u + uint(j); + if (nchw_idx >= total_numel) { + break; + } + + // Convert NCHW linear index to tensor4D (WHCN) coordinates. + const uint w = nchw_idx % W; + const uint h = (nchw_idx / W) % H; + const uint c = (nchw_idx / (W * H)) % C; + const uint n = nchw_idx / (W * H * C); + + TensorIndex4D tidx; + tidx.data = ivec4(int(w), int(h), int(c), int(n)); + + // tensor4d_idx_to_buf_idx returns a linear element index where + // element_index / 4 is the int32 slot and element_index % 4 is the byte + // position within that int32. This matches the packing order used by + // nchw_to_int8x4_buffer when writing to the int8x4 buffer. + const int elem_buf_idx = tensor4d_idx_to_buf_idx(inp, tidx, inp_layout); + const int int8_val = + (t_inp[elem_buf_idx / 4] >> ((elem_buf_idx % 4) * 8)) & 0xFF; + + output_int32 |= (int8_val << (j * 8)); + } + + nchw_out[out_int32_idx] = output_int32; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.yaml new file mode 100644 index 00000000000..1ee9728779a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +int8x4_buffer_to_nchw: + parameter_names_with_default_values: + DTYPE: int + shader_variants: + - NAME: int8x4_buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.cpp new file mode 100644 index 00000000000..eb1d9965f30 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +void add_prepack_int8x4_buffer_node( + ComputeGraph& graph, + const ValueRef tensor_data, + const ValueRef tensor) { + VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4); + // TODO(ssjia): Update shaders to handle high-dim tensors + VK_CHECK_COND(graph.dim_of(tensor) <= 4); + + std::string kernel_name = "nchw_to_int8x4_buffer"; + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(tensor)); + + // One thread per texel (each texel = one int32 = 4 packed int8). + // Use padded_numel to account for dimension padding in packed int8 layouts + // (e.g., kPackedInt8_4C with C=3 pads to C=4). + uint32_t num_texels = + utils::safe_downcast(graph.padded_numel_of(tensor) / 4); + utils::uvec3 global_wg_size = {num_texels, 1, 1}; + utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Input and Output + tensor_data, + tensor, + // Parameter Buffers + param_buffers, + // Specialization Constants + {graph.hashed_layout_of(tensor)})); +} + +static utils::uvec3 staging_to_int8x4_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out_tensor = args.at(0).refs.at(0); + const uint32_t num_texels = + utils::safe_downcast(graph->padded_numel_of(out_tensor) / 4); + return {num_texels, 1, 1}; +} + +void add_staging_to_int8x4_buffer_node( + ComputeGraph& graph, + const ValueRef in_staging, + const ValueRef tensor) { + VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4); + // TODO(ssjia): Update shaders to handle high-dim tensors + VK_CHECK_COND(graph.dim_of(tensor) <= 4); + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(tensor)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR("nchw_to_int8x4_buffer"), + staging_to_int8x4_buffer_global_wg_size, + default_pick_local_wg_size, + // Input and Output + {{tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}}, + // Parameter Buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {graph.hashed_layout_of(tensor)}, + // Resize Args + {}, + // Resizing Logic + nullptr)); +} + +static utils::uvec3 int8x4_buffer_to_staging_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef in_tensor = args.at(1).refs.at(0); + // One thread per output int32 in the NCHW staging buffer. + const int32_t numel = graph->numel_of(in_tensor); + const uint32_t num_out_int32s = + utils::safe_downcast((numel + 3) / 4); + return {num_out_int32s, 1, 1}; +} + +void add_int8x4_buffer_to_staging_node( + ComputeGraph& graph, + const ValueRef tensor, + const ValueRef staging_data) { + VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4); + // TODO(ssjia): Update shaders to handle high-dim tensors + VK_CHECK_COND(graph.dim_of(tensor) <= 4); + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(tensor)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR("int8x4_buffer_to_nchw"), + int8x4_buffer_to_staging_global_wg_size, + default_pick_local_wg_size, + // Input and Output + {{staging_data, vkapi::kWrite}, {tensor, vkapi::kRead}}, + // Parameter Buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {graph.hashed_layout_of(tensor)}, + // Resize Args + {}, + // Resizing Logic + nullptr)); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h b/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h similarity index 65% rename from backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h rename to backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h index 40386551e36..659ed696cd1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h @@ -12,9 +12,19 @@ namespace vkcompute { -void add_staging_to_int8x4_buffer_node( +void add_prepack_int8x4_buffer_node( ComputeGraph& graph, const ValueRef tensor_data, const ValueRef tensor); +void add_staging_to_int8x4_buffer_node( + ComputeGraph& graph, + const ValueRef in_staging, + const ValueRef tensor); + +void add_int8x4_buffer_to_staging_node( + ComputeGraph& graph, + const ValueRef tensor, + const ValueRef staging_data); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp deleted file mode 100644 index 8dc3f8156f8..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include - -namespace vkcompute { - -void add_staging_to_int8x4_buffer_node( - ComputeGraph& graph, - const ValueRef tensor_data, - const ValueRef tensor) { - VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4); - - std::string kernel_name = "nchw_to_int8x4_buffer"; - - vkapi::ParamsBindList param_buffers; - param_buffers.append(graph.buffer_meta_ubo(tensor)); - - // One thread per texel (each texel = one int32 = 4 packed int8). - // Use padded_numel to account for dimension padding in packed int8 layouts - // (e.g., kPackedInt8_4C with C=3 pads to C=4). - uint32_t num_texels = - utils::safe_downcast(graph.padded_numel_of(tensor) / 4); - utils::uvec3 global_wg_size = {num_texels, 1, 1}; - utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.prepack_nodes().emplace_back(new PrepackNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, - // Input and Output - tensor_data, - tensor, - // Parameter Buffers - param_buffers, - // Specialization Constants - {graph.hashed_layout_of(tensor)})); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index adcad9f9817..c418a3681c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -12,7 +12,7 @@ #include #include -#include +#include #include #include @@ -27,6 +27,10 @@ void add_staging_to_tensor_node( const ValueRef out_tensor) { VK_CHECK_COND(graph.val_is_staging(in_staging)); + if (graph.dtype_of(out_tensor) == vkapi::kInt8x4) { + return add_staging_to_int8x4_buffer_node(graph, in_staging, out_tensor); + } + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( graph, out_tensor, @@ -104,6 +108,10 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); + if (graph.dtype_of(in_tensor) == vkapi::kInt8x4) { + return add_int8x4_buffer_to_staging_node(graph, in_tensor, out_staging); + } + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( graph, in_tensor, @@ -329,7 +337,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( void prepack_op(ComputeGraph& graph, const std::vector& args) { if (graph.dtype_of(args[1]) == vkapi::kInt8x4) { - return add_staging_to_int8x4_buffer_node(graph, args[0], args[1]); + return add_prepack_int8x4_buffer_node(graph, args[0], args[1]); } return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp index f5214221359..e3c3e6e2642 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp @@ -8,9 +8,9 @@ #include +#include #include #include -#include namespace vkcompute { @@ -62,7 +62,7 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector& args) { if (input_b_is_int8) { // Input B is a pre-quantized int8 TensorRef; prepack directly into packed // int8x4 format - add_staging_to_int8x4_buffer_node(graph, input_b, packed_int8_input_b); + add_prepack_int8x4_buffer_node(graph, input_b, packed_int8_input_b); } else { // Input B is a float tensor; quantize at runtime add_q8ta_quantize_node( diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 261d3f72d01..746fa2c5253 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -29,6 +29,8 @@ #include +#include + using namespace vkcompute; using namespace vkcompute::api; @@ -3490,3 +3492,85 @@ void test_dynamic_dispatch(int M, int N) { TEST(VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) { test_dynamic_dispatch(128, 128); } + +// +// Int8x4 Staging Tests +// + +void test_int8x4_staging_round_trip( + const std::vector& sizes, + const utils::GPUMemoryLayout layout) { + GraphConfig config; + ComputeGraph graph(config); + + const int32_t numel = utils::multiply_integers(sizes); + + // Build graph: + // staging_in (kInt8x4) -> [execute: nchw_to_int8x4_buffer] -> tensor + // (kInt8x4) + // -> [execute: int8x4_buffer_to_nchw] -> staging_out + ValueRef tensor = + graph.add_tensor(sizes, vkapi::kInt8x4, utils::kBuffer, layout); + + ValueRef staging_in = graph.set_input_tensor(tensor); + ValueRef staging_out = graph.set_output_tensor(tensor); + + // staging_buffer_numel_of returns padded_numel / 4 (number of int32 + // elements). Multiply by 4 to get the byte count, which is used to zero-pad + // the input. + const size_t staging_numel = graph.staging_buffer_numel_of(tensor); + // Create NCHW int8 input data zero-padded to the full staging buffer size. + std::vector data_in(staging_numel * 4, 0); + for (int32_t i = 0; i < numel; ++i) { + data_in[i] = static_cast(static_cast(i * 37 + 13)); + } + + graph.prepare(); + // prepack() allocates Vulkan memory for all tensors even when there are no + // prepack nodes; it must be called before execute(). + graph.prepack(); + + // Copy NCHW int8 data into the input staging buffer. The staging buffer has + // kInt8x4 dtype (staging_numel int32 elements), so reinterpret the int8 data + // as int32 for the copy call. + graph.maybe_cast_and_copy_into_staging( + staging_in, + reinterpret_cast(data_in.data()), + staging_numel, + vkapi::kInt8x4); + + graph.execute(); + + // Read back packed int32s from staging. The staging dtype is kInt8x4 (4 + // bytes per element = one packed int32 holding 4 int8 values). + std::vector data_out_packed(staging_numel); + graph.maybe_cast_and_copy_from_staging( + staging_out, data_out_packed.data(), staging_numel, vkapi::kInt8x4); + + // Verify each int8 element matches the round-trip + for (int32_t i = 0; i < numel; ++i) { + const uint8_t byte = static_cast( + static_cast(data_out_packed[i / 4]) >> ((i % 4) * 8)); + const int8_t actual = static_cast(byte); + EXPECT_EQ(actual, data_in[i]) + << "Mismatch at nchw index " << i << " for sizes [" << sizes[0] + << (sizes.size() > 1 ? ", " + std::to_string(sizes[1]) : "") + << (sizes.size() > 2 ? ", " + std::to_string(sizes[2]) : "") + << (sizes.size() > 3 ? ", " + std::to_string(sizes[3]) : "") + << "] layout " << layout; + } +} + +TEST(VulkanComputeGraphTest, test_int8x4_staging_round_trip) { + const std::vector layouts = { + utils::kPackedInt8_4C, + utils::kPackedInt8_4W, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4C1W, + }; + for (const auto& sizes : standard_sizes_to_test) { + for (const auto layout : layouts) { + test_int8x4_staging_round_trip(sizes, layout); + } + } +}