Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

${define_active_storage_type("buffer")}

layout(std430) buffer;

#include "indexing.glslh"

// Output staging buffer: raw int8 data interpreted as int32 for device compat
${layout_declare_tensor(B, "w", "nchw_out", "int", "buffer")}
// Input buffer: packed int8x4 values (each int32 contains 4 packed int8)
${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")}

// Metadata for input tensor
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}

void main() {
// One thread per output int32 in the NCHW staging buffer.
// Each output int32 holds 4 consecutive NCHW bytes.
const uint out_int32_idx = gl_GlobalInvocationID.x;

const uint W = inp.sizes[0][0];
const uint H = inp.sizes[0][1];
const uint C = inp.sizes[0][2];
const uint N = inp.sizes[0][3];
const uint total_numel = W * H * C * N;
const uint num_out_int32s = (total_numel + 3u) / 4u;

if (out_int32_idx >= num_out_int32s) {
return;
}

int output_int32 = 0;
[[unroll]] for (int j = 0; j < 4; ++j) {
const uint nchw_idx = out_int32_idx * 4u + uint(j);
if (nchw_idx >= total_numel) {
break;
}

// Convert NCHW linear index to tensor4D (WHCN) coordinates.
const uint w = nchw_idx % W;
const uint h = (nchw_idx / W) % H;
const uint c = (nchw_idx / (W * H)) % C;
const uint n = nchw_idx / (W * H * C);

TensorIndex4D tidx;
tidx.data = ivec4(int(w), int(h), int(c), int(n));

// tensor4d_idx_to_buf_idx returns a linear element index where
// element_index / 4 is the int32 slot and element_index % 4 is the byte
// position within that int32. This matches the packing order used by
// nchw_to_int8x4_buffer when writing to the int8x4 buffer.
const int elem_buf_idx = tensor4d_idx_to_buf_idx(inp, tidx, inp_layout);
const int int8_val =
(t_inp[elem_buf_idx / 4] >> ((elem_buf_idx % 4) * 8)) & 0xFF;

output_int32 |= (int8_val << (j * 8));
}

nchw_out[out_int32_idx] = output_int32;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

int8x4_buffer_to_nchw:
parameter_names_with_default_values:
DTYPE: int
shader_variants:
- NAME: int8x4_buffer_to_nchw
140 changes: 140 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_prepack_int8x4_buffer_node(
ComputeGraph& graph,
const ValueRef tensor_data,
const ValueRef tensor) {
VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
// TODO(ssjia): Update shaders to handle high-dim tensors
VK_CHECK_COND(graph.dim_of(tensor) <= 4);

std::string kernel_name = "nchw_to_int8x4_buffer";

vkapi::ParamsBindList param_buffers;
param_buffers.append(graph.buffer_meta_ubo(tensor));

// One thread per texel (each texel = one int32 = 4 packed int8).
// Use padded_numel to account for dimension padding in packed int8 layouts
// (e.g., kPackedInt8_4C with C=3 pads to C=4).
uint32_t num_texels =
utils::safe_downcast<uint32_t>(graph.padded_numel_of(tensor) / 4);
utils::uvec3 global_wg_size = {num_texels, 1, 1};
utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

graph.prepack_nodes().emplace_back(new PrepackNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_wg_size,
local_wg_size,
// Input and Output
tensor_data,
tensor,
// Parameter Buffers
param_buffers,
// Specialization Constants
{graph.hashed_layout_of(tensor)}));
}

static utils::uvec3 staging_to_int8x4_buffer_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)resize_args;
const ValueRef out_tensor = args.at(0).refs.at(0);
const uint32_t num_texels =
utils::safe_downcast<uint32_t>(graph->padded_numel_of(out_tensor) / 4);
return {num_texels, 1, 1};
}

void add_staging_to_int8x4_buffer_node(
ComputeGraph& graph,
const ValueRef in_staging,
const ValueRef tensor) {
VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
// TODO(ssjia): Update shaders to handle high-dim tensors
VK_CHECK_COND(graph.dim_of(tensor) <= 4);

vkapi::ParamsBindList param_buffers;
param_buffers.append(graph.buffer_meta_ubo(tensor));

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR("nchw_to_int8x4_buffer"),
staging_to_int8x4_buffer_global_wg_size,
default_pick_local_wg_size,
// Input and Output
{{tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}},
// Parameter Buffers
param_buffers,
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(tensor)},
// Resize Args
{},
// Resizing Logic
nullptr));
}

static utils::uvec3 int8x4_buffer_to_staging_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)resize_args;
const ValueRef in_tensor = args.at(1).refs.at(0);
// One thread per output int32 in the NCHW staging buffer.
const int32_t numel = graph->numel_of(in_tensor);
const uint32_t num_out_int32s =
utils::safe_downcast<uint32_t>((numel + 3) / 4);
return {num_out_int32s, 1, 1};
}

void add_int8x4_buffer_to_staging_node(
ComputeGraph& graph,
const ValueRef tensor,
const ValueRef staging_data) {
VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
// TODO(ssjia): Update shaders to handle high-dim tensors
VK_CHECK_COND(graph.dim_of(tensor) <= 4);

vkapi::ParamsBindList param_buffers;
param_buffers.append(graph.buffer_meta_ubo(tensor));

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR("int8x4_buffer_to_nchw"),
int8x4_buffer_to_staging_global_wg_size,
default_pick_local_wg_size,
// Input and Output
{{staging_data, vkapi::kWrite}, {tensor, vkapi::kRead}},
// Parameter Buffers
param_buffers,
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(tensor)},
// Resize Args
{},
// Resizing Logic
nullptr));
}

} // namespace vkcompute
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@

namespace vkcompute {

void add_staging_to_int8x4_buffer_node(
void add_prepack_int8x4_buffer_node(
ComputeGraph& graph,
const ValueRef tensor_data,
const ValueRef tensor);

void add_staging_to_int8x4_buffer_node(
ComputeGraph& graph,
const ValueRef in_staging,
const ValueRef tensor);

void add_int8x4_buffer_to_staging_node(
ComputeGraph& graph,
const ValueRef tensor,
const ValueRef staging_data);

} // namespace vkcompute
49 changes: 0 additions & 49 deletions backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp

This file was deleted.

12 changes: 10 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

Expand All @@ -27,6 +27,10 @@ void add_staging_to_tensor_node(
const ValueRef out_tensor) {
VK_CHECK_COND(graph.val_is_staging(in_staging));

if (graph.dtype_of(out_tensor) == vkapi::kInt8x4) {
return add_staging_to_int8x4_buffer_node(graph, in_staging, out_tensor);
}

vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
graph,
out_tensor,
Expand Down Expand Up @@ -104,6 +108,10 @@ void add_tensor_to_staging_node(
const ValueRef out_staging) {
VK_CHECK_COND(graph.val_is_staging(out_staging));

if (graph.dtype_of(in_tensor) == vkapi::kInt8x4) {
return add_int8x4_buffer_to_staging_node(graph, in_tensor, out_staging);
}

vkapi::ShaderInfo shader = get_tensor_to_nchw_shader(
graph,
in_tensor,
Expand Down Expand Up @@ -329,7 +337,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(

void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
if (graph.dtype_of(args[1]) == vkapi::kInt8x4) {
return add_staging_to_int8x4_buffer_node(graph, args[0], args[1]);
return add_prepack_int8x4_buffer_node(graph, args[0], args[1]);
}
return add_prepack_standard_node(graph, args[0], args[1]);
}
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h>

namespace vkcompute {

Expand Down Expand Up @@ -62,7 +62,7 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
if (input_b_is_int8) {
// Input B is a pre-quantized int8 TensorRef; prepack directly into packed
// int8x4 format
add_staging_to_int8x4_buffer_node(graph, input_b, packed_int8_input_b);
add_prepack_int8x4_buffer_node(graph, input_b, packed_int8_input_b);
} else {
// Input B is a float tensor; quantize at runtime
add_q8ta_quantize_node(
Expand Down
Loading
Loading