From e85005724b144812accaa313cc194f60e4179b99 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:22 -0800 Subject: [PATCH 1/6] [ET-VK][qconv] Fix depthwise weight_sums sum dimension Pull Request resolved: https://github.com/pytorch/executorch/pull/17504 The weight_sums tensor stores per-output-channel sums of quantized weight values, used to apply activation zero point correction during integer accumulation. For depthwise convolutions, the weight tensor is reshaped to (H, W, OC), but the sum was unconditionally computed along dim=1 (the W dimension). This produced a tensor of shape (H, OC) instead of (OC,), causing incorrect zero point correction and corrupted depthwise conv output. Fix by branching on is_depthwise_conv to sum over dims (0, 1) for the (H, W, OC) layout. ghstack-source-id: 342806069 @exported-using-ghexport Differential Revision: [D93511635](https://our.internmc.facebook.com/intern/diff/D93511635/) --- backends/vulkan/patterns/quantized_convolution.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index b89dfe9aaab..93140e15341 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -215,9 +215,17 @@ def make_q8ta_conv2d_custom_op( with graph_module.graph.inserting_before(first_graph_node): qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node) # Pre-compute the weight sums which are needed to apply activation zero point - # when using integer accumulation. For the reshaped 2D weight matrix (IC_per_group * H * W, OC), - # sum over dimension 0 to get sums per output channel - sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + # when using integer accumulation. Sum all weight elements per output channel. 
+ if is_depthwise_conv: + # weight_tensor shape is (H, W, OC); sum over spatial dims (H, W) + sum_per_output_channel = ( + weight_tensor.sum(dim=(0, 1)).to(torch.int32).contiguous() + ) + else: + # weight_tensor shape is (OC, H*W*IC_per_group); sum over dim 1 + sum_per_output_channel = ( + weight_tensor.sum(dim=1).to(torch.int32).contiguous() + ) sums_name = qweight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") From 0d727608aa232b472815e0db9d935dd995d7c5ee Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:25 -0800 Subject: [PATCH 2/6] [ET-VK][qconv] Pad weight_sums buffer to multiple-of-4 alignment Pull Request resolved: https://github.com/pytorch/executorch/pull/17505 The q8ta convolution shaders read weight_sums via ivec4 loads (4 int32 values at once), requiring the buffer to have at least align_up_4(OC) elements. The weight tensor, weight_scales, and bias are all padded via align_width_and_update_state_dict, but weight_sums was created as a 1D tensor of shape (OC,) without any padding. For OC values that are not a multiple of 4 (e.g. OC=1 in the final pointwise conv of MetaNet GreenScreen), this results in out-of-bounds GPU buffer reads. On host testing with ASAN, this manifests as a heap-buffer-overflow. Fix by padding sum_per_output_channel to align_up_4(OC) before creating the constant placeholder. Also fix the C++ test utility compute_weight_sums() which was incorrectly shrinking a pre-allocated aligned buffer. 
ghstack-source-id: 342806075 @exported-using-ghexport Differential Revision: [D93511633](https://our.internmc.facebook.com/intern/diff/D93511633/) --- .../vulkan/patterns/quantized_convolution.py | 10 +++++++++ .../test/custom_ops/test_q8ta_conv2d_pw.cpp | 22 +++++++++++++++++++ backends/vulkan/test/custom_ops/utils.cpp | 6 ++++- 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 93140e15341..9a6fb69bf87 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -226,6 +226,16 @@ def make_q8ta_conv2d_custom_op( sum_per_output_channel = ( weight_tensor.sum(dim=1).to(torch.int32).contiguous() ) + # Pad weight sums to align OC to multiple of 4, matching the alignment + # applied to weight, weight_scales, and bias above. Without this, the + # GPU shader would read out-of-bounds when OC is not a multiple of 4. + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = torch.nn.functional.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = qweight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 51095c649b6..6ce6671ec84 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -210,6 +210,28 @@ static std::vector generate_quantized_conv2d_pw_test_cases() { } std::vector configs = { + // OC < 4 cases to test edge cases with partial output channel blocks + {OutInChannels(1, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(2, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + 
Dilation(1, 1), + 1}, + {OutInChannels(3, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, // Pointwise convolutions: kernel size 1x1 {OutInChannels(32, 3), InputSize2D(64, 64), diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index b23c288a58f..2a50e7b5ec1 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -2064,7 +2064,11 @@ void compute_weight_sums( auto& weight_sums_data = weight_sums.get_int32_data(); auto& quantized_weight_data = quantized_weight.get_int8_data(); - weight_sums_data.resize(out_features); + // Don't resize down - the buffer may be pre-allocated with aligned size. + // Only resize up if needed. + if (weight_sums_data.size() < static_cast(out_features)) { + weight_sums_data.resize(out_features); + } // For each output feature, compute the sum of quantized weights for (int64_t out_f = 0; out_f < out_features; ++out_f) { From ca20b0ed1050931ddc952e440df8ab86848ebbe4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:27 -0800 Subject: [PATCH 3/6] [ET-VK][qconv] Add apply_relu support to q8ta conv operators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/17506 The quantized convolution pattern detector correctly identifies ReLU nodes between conv output and the output quantize node, but the pattern replacement did not pass this information to the fused q8ta operator. When the pattern replaced `dequant → conv → relu → quant` with `q8ta_conv2d`, the relu node was removed from the graph but its effect was not preserved. This silently removed all conv-relu non-linearity from int8 quantized models. 
Add an `activation` string argument (`"none"` or `"relu"`) throughout the full pipeline: - Custom op schemas and reference implementations (custom_ops_lib.py) - Pattern replacement (quantized_convolution.py) - C++ dispatch logic parses the activation string into an ActivationType and passes it as the activation_type spec constant (Q8taConv2d.cpp, Q8taConv2dDW.cpp, Q8taConv2dPW.cpp, Q8taConv2dIm2Col.cpp) - GLSL shaders apply conditional max(value, 0) after dequantization and before requantization (q8ta_conv2d.glsl, q8ta_conv2d_dw.glsl, q8ta_conv2d_pw.glsl) - Test operator wrappers updated with proper legacy path handling (TestQ8taConv2d.cpp) ghstack-source-id: 342806070 @exported-using-ghexport Differential Revision: [D93511632](https://our.internmc.facebook.com/intern/diff/D93511632/) --- backends/vulkan/custom_ops_lib.py | 17 +- .../vulkan/patterns/quantized_convolution.py | 1 + .../runtime/graph/ops/glsl/q8ta_conv2d.glsl | 8 + .../graph/ops/glsl/q8ta_conv2d_dw.glsl | 8 + .../graph/ops/glsl/q8ta_conv2d_pw.glsl | 9 + .../runtime/graph/ops/impl/Q8taConv2d.cpp | 18 +- .../runtime/graph/ops/impl/Q8taConv2d.h | 10 + .../runtime/graph/ops/impl/Q8taConv2dDW.cpp | 9 +- .../graph/ops/impl/Q8taConv2dIm2Col.cpp | 5 + .../runtime/graph/ops/impl/Q8taConv2dPW.cpp | 9 +- .../graph/ops/impl/QuantizedConvolution.cpp | 1 + .../test/custom_ops/impl/TestQ8taConv2d.cpp | 177 +++++++++++------- .../test/custom_ops/test_q8ta_conv2d.cpp | 6 + .../test/custom_ops/test_q8ta_conv2d_dw.cpp | 4 + .../test/custom_ops/test_q8ta_conv2d_pw.cpp | 6 + 15 files changed, 217 insertions(+), 71 deletions(-) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index f2e4482c9b9..3e77b0c0eea 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -376,6 +376,7 @@ def q8ta_conv2d( padding: list, dilation: list, groups: int, + activation: str, ): x = torch.ops.quantized_decomposed.dequantize_per_tensor( x, input_scale, input_zero_point, -128, 127, x.dtype @@ -418,6 +419,9 @@ def q8ta_conv2d( x, weights, bias, stride,
padding, dilation, groups ) + if activation == "relu": + out = torch.nn.functional.relu(out) + out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 ) @@ -442,7 +446,8 @@ def q8ta_conv2d( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") @@ -466,7 +471,8 @@ def q8ta_conv2d( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") @@ -488,6 +494,7 @@ def q8ta_conv2d_dw( padding: list, dilation: list, groups: int, + activation: str, ): x = torch.ops.quantized_decomposed.dequantize_per_tensor( x, input_scale, input_zero_point, -128, 127, x.dtype @@ -514,6 +521,9 @@ def q8ta_conv2d_dw( x, weights, bias, stride, padding, dilation, groups ) + if activation == "relu": + out = torch.nn.functional.relu(out) + out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 ) @@ -538,7 +548,8 @@ def q8ta_conv2d_dw( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 9a6fb69bf87..12ebbd1a382 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -281,6 +281,7 @@ def make_q8ta_conv2d_custom_op( match.padding, match.dilation, match.groups, + "relu" if match.relu_node is not None else "none", ), ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl index 623de3a5d9a..d693acbab3f 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl @@ -47,6 +47,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} // Layout specialization constants ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} @@ -220,6 +221,13 @@ void main() { } } + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + // Compute base output texel index (for subtile_w=0) const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); const int out_w_stride = int(outp.strides[0][0]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl index e6be92e7ba1..7f4d03887df 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl @@ -44,6 +44,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} // Layout specialization constants ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} @@ -197,6 +198,13 @@ void main() { } } + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + // Compute base output texel index (for subtile_w=0) const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); const int out_w_stride = int(outp.strides[0][0]); diff 
--git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl index e0963dfcf48..ec41d933114 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl @@ -57,6 +57,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} ${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")} // Layout specialization constants @@ -197,6 +198,10 @@ void main() { fma(vec4(accum_adjusted), vec4(weight_scales[n4]) * input_scale, vec4(bias[n4])); + // Apply ReLU if enabled + if (activation_type > 0) { + float_out_texel = max(float_out_texel, vec4(0.0)); + } // Requantize to int8 float_out_texel = round(float_out_texel * output_inv_scale) + output_zp; @@ -216,6 +221,10 @@ void main() { input_zp_vec * weight_sums[n4] + out_accum[m][n4]; vec4 float_out_texel = vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale); + // Apply ReLU if enabled + if (activation_type > 0) { + float_out_texel = max(float_out_texel, vec4(0.0)); + } // Requantize to int8 float_out_texel = round(float_out_texel * output_inv_scale) + output_zp; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index 4f047d414f8..33b7005a845 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -17,6 +17,15 @@ namespace vkcompute { +ActivationType activation_type_from_string(const std::string& activation) { + if (activation == "none") { + return ActivationType::kNone; + } else if (activation == "relu") { + return ActivationType::kRelu; + } + VK_THROW("Unknown activation type: ", activation); +} + bool 
q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) { return info.packed_dim == WHCN::kChannelsDim && info.packed_dim_block_size == 4 && @@ -231,6 +240,7 @@ void add_q8ta_conv2d_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output) { (void)packed_int8_input_im2col; // Not used in general shader @@ -288,9 +298,10 @@ void add_q8ta_conv2d_node( graph.buffer_meta_ubo(packed_int8_input), graph.create_params_buffer(conv_params)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, // Layout specialization constants graph.hashed_layout_of(packed_int8_input), graph.hashed_layout_of(packed_int8_output), @@ -341,8 +352,12 @@ void q8ta_conv2d_general( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using the conv2d weight packing for the general shader @@ -397,6 +412,7 @@ void q8ta_conv2d_general( padding, dilation, groups, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 9686c873c1b..2779a7445a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -13,6 +13,13 @@ namespace vkcompute { +enum class ActivationType : uint32_t { + kNone = 0, + kRelu = 1, +}; + +ActivationType activation_type_from_string(const std::string& activation); + bool q8ta_conv2d_check_packed_dim_info(const
api::PackedDimInfo& info); bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info); @@ -58,6 +65,7 @@ void add_q8ta_conv2d_dw_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output); void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node( @@ -97,6 +105,7 @@ void add_q8ta_conv2d_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output); void add_q8ta_conv2d_pw_node( @@ -111,6 +120,7 @@ void add_q8ta_conv2d_pw_node( const ValueRef output_zp, const ValueRef bias_data, const ValueRef packed_bias, + const uint32_t activation_type, const ValueRef packed_int8_output); void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp index d12bbc0574a..e690ff435a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp @@ -281,6 +281,7 @@ void add_q8ta_conv2d_dw_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output) { Conv2DParams conv_params = create_conv2d_params( graph, @@ -334,9 +335,10 @@ void add_q8ta_conv2d_dw_node( graph.buffer_meta_ubo(packed_int8_input), graph.create_params_buffer(conv_params)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, // Layout specialization constants graph.hashed_layout_of(packed_int8_input), graph.hashed_layout_of(packed_int8_output), @@ -385,8 +387,12 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector& args) { const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); 
const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using depthwise-specific packing @@ -432,6 +438,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector& args) { padding, dilation, groups, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp index e89ebc92aba..161b5e8fc24 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -197,6 +197,7 @@ void q8ta_conv2d_im2col( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); QuantizationConfig weight_quant_config(8, kPerChannel, {}); @@ -225,6 +226,9 @@ void q8ta_conv2d_im2col( prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); } + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + // Calculate im2col output sizes std::vector im2col_sizes = calculate_q8ta_im2col_sizes( &graph, packed_int8_input, packed_int8_output, kernel_size, groups); @@ -265,6 +269,7 @@ void q8ta_conv2d_im2col( output_zp, bias_data, packed_bias, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp index fc883eefeef..b72f5b78f53 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp @@ -199,6 +199,7 @@ void 
add_q8ta_conv2d_pw_node( const ValueRef output_zp, const ValueRef bias_data, const ValueRef packed_bias, + const uint32_t activation_type, const ValueRef packed_int8_output) { // Validate packed dim info for input and output tensors // To maximize performance, the input tensor must be in 4W4C layout @@ -242,9 +243,10 @@ void add_q8ta_conv2d_pw_node( graph.buffer_meta_ubo(packed_int8_output), graph.buffer_meta_ubo(packed_int8_input)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, K4_per_group, // Layout specialization constants graph.hashed_layout_of(packed_int8_output), @@ -296,8 +298,12 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector& args) { (void)args.at(idx++); // padding (void)args.at(idx++); // dilation (void)args.at(idx++); // groups + const ValueRef activation_ref = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation_ref))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using pointwise-specific packing @@ -342,6 +348,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector& args) { output_zp, bias_data, packed_bias, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index 1bfff6f1342..ebc276ee347 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -894,6 +894,7 @@ void add_conv2d_q8ta_q8csw_q8to_node( padding, dilation, groups, + static_cast(ActivationType::kNone), packed_int8_output); } } diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp 
b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index 4fed7461ce6..679ac33d11b 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -32,6 +32,7 @@ void test_q8ta_conv2d_dw( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -59,29 +60,43 @@ void test_q8ta_conv2d_dw( add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { - // Use the general quantized conv2d operator for legacy path + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); } else { - // Use the dedicated depthwise conv2d operator + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.default")(graph, conv_args); } @@ -106,6 +121,7 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { const ValueRef padding = args.at(idx++); const 
ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -133,36 +149,50 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { - // Use the general quantized conv2d operator for legacy path + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); - } else if (impl_selector == "im2col") { - // Use the im2col-based conv2d operator - VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args); - } else if (impl_selector == "general") { - // Use the general q8ta_conv2d operator (no im2col dispatch) - VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args); } else { - // Use the new general q8ta_conv2d operator - VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args); + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; + if (impl_selector == "im2col") { + 
VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args); + } else if (impl_selector == "general") { + VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args); + } else { + VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args); + } } // Dequantize packed int8 output to floating point @@ -188,6 +218,7 @@ void test_q8ta_conv2d_pw( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -219,27 +250,43 @@ void test_q8ta_conv2d_pw( add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); } else { + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.default")(graph, conv_args); } diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 17dd7a0fc53..bc95cc724f5 
100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -178,6 +178,10 @@ static TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); @@ -455,6 +459,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { const ValueSpec& padding_spec = test_case.inputs()[idx++]; const ValueSpec& dilation_spec = test_case.inputs()[idx++]; const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; // Not used in reference implementation const ValueSpec& layout_spec = test_case.inputs()[idx++]; (void)layout_spec; // Not used in reference implementation const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp index 7ef73d49802..0734e444d57 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp @@ -187,6 +187,10 @@ TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 6ce6671ec84..83b9f92fb3a 
100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -179,6 +179,10 @@ static TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); @@ -366,6 +370,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { const ValueSpec& padding_spec = test_case.inputs()[idx++]; const ValueSpec& dilation_spec = test_case.inputs()[idx++]; const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; // Not used in reference implementation const ValueSpec& layout_spec = test_case.inputs()[idx++]; (void)layout_spec; // Not used in reference implementation const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; From 116a75a9584a4e6de4488386843be02b03f942e1 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:29 -0800 Subject: [PATCH 4/6] [ET-VK] Add fused q8ta_relu unary operator for int8x4 tensors Pull Request resolved: https://github.com/pytorch/executorch/pull/17507 This adds a fused quantized unary operator (ReLU) that operates directly on int8x4 packed buffer tensors, avoiding the overhead of separate dequantize-relu-requantize dispatches. The implementation follows the same pattern as q8ta_binary: a single GLSL compute shader dequantizes int8x4 blocks to float, applies the unary operation, and requantizes back to int8x4 in one dispatch. The shader uses the OPERATOR macro for parameterization so additional unary ops can be added as YAML variants without new shader code. 
Components added: - GLSL shader (q8ta_unary.glsl) and YAML config with relu variant - C++ operator implementation (Q8taUnary.cpp/h) registering et_vk.q8ta_relu.default - Export graph fusion pattern (quantized_unary.py) that detects dequant->relu->quant sequences and replaces them with the fused op - Custom op definition (q8ta_relu in custom_ops_lib.py) for the export pipeline - Test harness (TestQ8taUnary.cpp, test_q8ta_unary.cpp) with reference implementation and coverage across multiple shapes and quantized layouts This diff was authored with Claude. ghstack-source-id: 342806073 @exported-using-ghexport Differential Revision: [D93511629](https://our.internmc.facebook.com/intern/diff/D93511629/) --- backends/vulkan/custom_ops_lib.py | 35 ++ backends/vulkan/op_registry.py | 14 +- backends/vulkan/patterns/BUCK | 1 + backends/vulkan/patterns/__init__.py | 2 + backends/vulkan/patterns/quantized_unary.py | 121 +++++++ .../runtime/graph/ops/glsl/q8ta_unary.glsl | 82 +++++ .../runtime/graph/ops/glsl/q8ta_unary.yaml | 12 + .../runtime/graph/ops/impl/Q8taUnary.cpp | 124 +++++++ .../vulkan/runtime/graph/ops/impl/Q8taUnary.h | 29 ++ .../test/custom_ops/impl/TestQ8taUnary.cpp | 70 ++++ .../test/custom_ops/test_q8ta_unary.cpp | 311 ++++++++++++++++++ 11 files changed, 800 insertions(+), 1 deletion(-) create mode 100644 backends/vulkan/patterns/quantized_unary.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_unary.cpp diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 3e77b0c0eea..e371338e904 100644 --- a/backends/vulkan/custom_ops_lib.py +++ 
b/backends/vulkan/custom_ops_lib.py @@ -616,6 +616,41 @@ def q8ta_add_impl( lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd") q8ta_add_op = getattr(getattr(torch.ops, namespace), name) +######################## +## q8ta_relu ## +######################## + + +def q8ta_relu_impl( + input: torch.Tensor, + input_scale: float, + input_zero_point: int, + output_scale: float, + output_zero_point: int, +): + # Dequantize input to float + dequant = torch.ops.quantized_decomposed.dequantize_per_tensor( + input, input_scale, input_zero_point, -128, 127, input.dtype + ) + + # Apply ReLU + result = torch.nn.functional.relu(dequant) + + # Quantize the result back to int8 + quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor( + result, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return quantized_result + + +name = "q8ta_relu" +lib.define( + f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor" +) +lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") +q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) + ############################# ## select_as_symint ## ############################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 55a92335bc7..721297dea37 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -514,7 +514,20 @@ def register_q8ta_add(): # ============================================================================= -# Reduce.cpp +# Q8taUnary.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_relu.default) +def register_q8ta_relu(): + return OpFeatures( + inputs_storage=utils.PACKED_INT8_BUFFER, + supports_resize=True, + ) + + +# ============================================================================= +# Reduce.cpp # ============================================================================= diff --git
a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index a7153b30967..711000f74ca 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library, "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "quantized_unary.py", "sdpa.py", "select_as_symint.py", ], diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 9b875def944..050680b024d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -12,6 +12,8 @@ import executorch.backends.vulkan.patterns.quantized_linear # noqa +import executorch.backends.vulkan.patterns.quantized_unary # noqa + import executorch.backends.vulkan.patterns.rope # noqa import executorch.backends.vulkan.patterns.sdpa # noqa diff --git a/backends/vulkan/patterns/quantized_unary.py b/backends/vulkan/patterns/quantized_unary.py new file mode 100644 index 00000000000..28dc84b7997 --- /dev/null +++ b/backends/vulkan/patterns/quantized_unary.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedUnaryMatch(PatternMatch): + def __init__(self, unary_node: torch.fx.Node) -> None: + self.anchor_node = unary_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # The unary op takes a single input which must be a dequantize node + if len(unary_node.args) < 1: + return + + input_node = unary_node.args[0] + assert isinstance(input_node, torch.fx.Node) + + if not utils.is_dequant_node(input_node): + return + + self.dequantize_input_node = input_node + + # Extract quantization parameters for the input + self.quantize_input_node = self.dequantize_input_node.args[0] + self.input_scales_node = self.dequantize_input_node.args[1] + self.input_zeros_node = self.dequantize_input_node.args[2] + + self.all_nodes.append(self.dequantize_input_node) + + # The unary op output must have exactly one user: a quantize node + self.output_node = self.anchor_node + + if len(self.output_node.users) != 1: + return + + cur_node = list(self.output_node.users)[0] + + if not utils.is_quant_node(cur_node): + return + + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] + + self.all_nodes.append(self.quantize_output_node) + + self.match_found = True + + +# Unary operation anchor nodes that we support +unary_anchor_nodes = { + exir_ops.edge.aten.relu.default, +} + + +@register_pattern_detector("quantized_unary") +def find_quantized_unary_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedUnaryMatch]: + if node.target not in unary_anchor_nodes: + return None + + matched_pattern = 
QuantizedUnaryMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("quantized_unary") +def make_q8ta_unary_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedUnaryMatch, +): + op_target = None + if match.anchor_node.target == exir_ops.edge.aten.relu.default: + op_target = exir_ops.edge.et_vk.q8ta_relu.default + else: + raise NotImplementedError( + f"Unsupported unary operation: {match.anchor_node.target}" + ) + + with graph_module.graph.inserting_before(match.output_node): + qunary_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.output_scales_node, + match.output_zeros_node, + ), + ) + + qunary_node.meta["val"] = match.output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qunary_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl new file mode 100644 index 00000000000..e97d6d47877 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +#define op(X) ${OPERATOR} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_int8x4_store.glslh" + +// Output buffer: packed int8x4 values +${layout_declare_tensor(B, "w", "t_out", "int", "buffer")} +// Input buffer: packed int8x4 values +${layout_declare_tensor(B, "r", "t_in", "int", "buffer")} + +// Metadata for output and input tensors +${layout_declare_ubo(B, "BufferMetadata", "out_meta")} +${layout_declare_ubo(B, "BufferMetadata", "in_meta")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "block_config", "0")} + +// Generate loading functions for input buffer +define_load_int8x4_buffer_fns(t_in) + +// Generate storing functions for output buffer +define_store_int8x4_buffer_fns(t_out) + +void main() { + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + out_meta, contig_block_idx, block_config); + + if (out_of_bounds(tidx, out_meta)) { + return; + } + + const int block_outer_dim = get_block_outer_dim(block_config); + + // Load int8x4 block from input + ivec4 in_block = load_int8x4_block_from_t_in( + in_meta, tidx, in_layout, block_outer_dim); + + ivec4 out_block; + + for (int row = 0; row < 4; row++) { + vec4 in_texel = unpack_and_dequantize( + in_block[row], input_scale, input_zp); + + vec4 out_texel = op(in_texel); + out_block[row] = quantize_and_pack(out_texel, 
output_inv_scale, output_zp); + } + + // Store to output buffer + store_int8x4_block_to_t_out( + out_meta, tidx, out_layout, block_outer_dim, out_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml new file mode 100644 index 00000000000..257f6a44205 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_unary: + parameter_names_with_default_values: + OPERATOR: X + shader_variants: + - NAME: q8ta_relu_buffer + OPERATOR: max(X, vec4(0.0)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp new file mode 100644 index 00000000000..f8b606f3dfa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +namespace vkcompute { + +void resize_q8ta_unary_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(self)); +} + +// +// Dispatch nodes +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name) { + const api::PackedDimInfo& output_info = + graph.packed_dim_info_of(packed_int8_output); + const api::PackedDimInfo& input_info = + graph.packed_dim_info_of(packed_int8_input); + + VK_CHECK_COND(input_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( + input_info.packed_dim_block_size == output_info.packed_dim_block_size); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "q8ta_" + op_name; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(packed_int8_output)); + param_buffers.append(graph.buffer_meta_ubo(packed_int8_input)); + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + const BlockConfig block_config = + create_block_config_for_tensor(graph, packed_int8_output); + + const ValueRef block_config_ref = + 
static_cast(block_config.as_packed_int()); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_linear_global_wg_with_block_config, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(packed_int8_output), + graph.hashed_layout_of(packed_int8_input), + block_config.as_packed_int()}, + // Resize args + {block_config_ref}, + // Resizing Logic + resize_q8ta_unary_node)); +} + +// +// High level operator impl +// + +void q8ta_relu(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_relu.default, q8ta_relu); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h new file mode 100644 index 00000000000..2b68fa53c22 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Unary operations for int8x4 tensors +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp new file mode 100644 index 00000000000..6212216686f --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void q8ta_unary_test(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef quant_layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + int32_t layout_value = graph.extract_scalar(quant_layout_int); + utils::GPUMemoryLayout quant_layout = + static_cast(layout_value); + + // Create temporary tensor for quantized input + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Create temporary tensor for quantized output + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Quantize: FP -> int8x4 + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + // Unary op: 
int8x4 -> int8x4 + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); + + // Dequantize: int8x4 -> FP + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.q8ta_unary_test.default, q8ta_unary_test); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp new file mode 100644 index 00000000000..bc184c6c182 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp @@ -0,0 +1,311 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +struct Q8taUnaryConfig { + std::vector shape; + std::string test_case_name = "placeholder"; + std::string op_name = "q8ta_unary_test"; +}; + +TestCase create_test_case_from_config( + const Q8taUnaryConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype, + utils::GPUMemoryLayout fp_memory_layout, + utils::GPUMemoryLayout quant_layout) { + TestCase test_case; + + std::string shape_str = shape_string(config.shape); + std::string test_name = config.test_case_name + " I=" + shape_str + " " + + repr_str(storage_type, fp_memory_layout) + "->" + + repr_str(utils::kBuffer, quant_layout); + test_case.set_name(test_name); + + std::string operator_name = "test_etvk." 
+ config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) + ValueSpec input_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec input_scale(scale_val); + + int32_t zero_point_val = 0; + ValueSpec input_zero_point(zero_point_val); + + // For relu, output scale and zero point can differ from input + float output_scale_val = 0.007112; + ValueSpec output_scale(output_scale_val); + + int32_t output_zp_val = 0; + ValueSpec output_zero_point(output_zp_val); + + int32_t layout_int = static_cast(quant_layout); + ValueSpec layout_spec(layout_int); + + // Output tensor (float) - same shape as input + ValueSpec output_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(layout_spec); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +std::vector generate_q8ta_unary_easy_cases() { + std::vector test_cases; + + Q8taUnaryConfig config = { + {1, 16, 16, 16}, + "ACCU", + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + for (const auto& 
input_dtype : float_types) { + test_cases.push_back(create_test_case_from_config( + config, storage_type, input_dtype, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +std::vector generate_q8ta_unary_test_cases() { + std::vector test_cases; + + std::vector> shapes = { + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Larger tensors + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + {1, 128, 128, 128}, + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + + for (const auto& shape : shapes) { + std::string prefix = "ACCU"; + for (const auto& dim : shape) { + if (dim > kRefDimSizeLimit) { + prefix = "PERF"; + break; + } + } + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + Q8taUnaryConfig config; + config.shape = shape; + config.test_case_name = prefix; + + test_cases.push_back(create_test_case_from_config( + config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +// Reference impl: quantize -> dequantize -> relu -> requantize -> dequantize +void q8ta_unary_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zp_spec = test_case.inputs()[idx++]; + const
ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + + int64_t num_elements = 1; + for (const auto& dim : input_sizes) { + num_elements *= dim; + } + + for (const auto& dim : input_sizes) { + if (dim > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference " + "implementation."); + } + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + + float input_scale = input_scale_spec.get_float_value(); + int32_t input_zp = input_zp_spec.get_int_value(); + float output_scale = output_scale_spec.get_float_value(); + int32_t output_zp = output_zp_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize with input scale/zp + float quantized_float = std::round(input_val / input_scale) + input_zp; + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize to float + float dequantized = (quantized_int - input_zp) * input_scale; + + // Apply ReLU + float activated = std::max(dequantized, 0.0f); + + // Requantize with output scale/zp + float requantized_float = std::round(activated / output_scale) + output_zp; + requantized_float = + std::max(requantized_float, static_cast(quant_min)); + requantized_float = + std::min(requantized_float, static_cast(quant_max)); + int32_t requantized_int = static_cast(requantized_float); + + // Dequantize back to float for comparison + ref_data[i] = (requantized_int - output_zp) * output_scale; + } +} + 
+int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(false); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8TA Unary (ReLU) Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q8ta_unary_reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_q8ta_unary_easy_cases, +#else + generate_q8ta_unary_test_cases, +#endif + "Q8taUnary", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} From 8a10718c51019386db9cc16d04a924c478f86a2e Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:31 -0800 Subject: [PATCH 5/6] [ET-VK][ez] Always partition batch norm as it will be fused Pull Request resolved: https://github.com/pytorch/executorch/pull/17508 The batch norm operator registration had a check_batch_norm_node guard that restricted partitioning to 4D input tensors only. Since batch norm is always fused with adjacent operations during graph compilation, this restriction is unnecessary and prevents valid models from being partitioned to the Vulkan backend. Remove the guard so batch norm is always eligible for Vulkan partitioning regardless of input dimensionality. 
ghstack-source-id: 342806074 @exported-using-ghexport Differential Revision: [D93511630](https://our.internmc.facebook.com/intern/diff/D93511630/) --- backends/vulkan/op_registry.py | 14 -------------- backends/vulkan/vulkan_preprocess.py | 2 +- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 721297dea37..853ba5d3777 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1233,25 +1233,11 @@ def register_embedding(): @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) def register_native_batch_norm_legit_no_training(): - def check_batch_norm_node(node: torch.fx.Node) -> bool: - x = node.args[0] - if not isinstance(x, torch.fx.Node): - return False - x_val = x.meta.get("val", None) - if x_val is None: - return False - x_shape = x_val.size() - # Only support 4-D input tensors since this is a restriction enforced by the - # operator implementation. - # TODO(ssjia): Add shape agnostic support for batch norm - return len(x_shape) == 4 - return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_T, supports_prepacking=True, supports_resize=True, - are_node_inputs_supported_fn=check_batch_norm_node, ) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 3ccbdc8ab85..b276ffd16f5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -162,10 +162,10 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + AddmmToLinearTransform(), FuseBatchNormPass(program), FusePatternsPass(), FuseClampPass(), - AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), FoldQDQPass(), From 673083770696968b21c9731ff0d4fe56a11ac282 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:35 -0800 Subject: [PATCH 6/6] [ET-VK] Support different input layouts in q8ta_binary operator Previously, 
the q8ta_binary operator required both inputs to use the same memory layout. This was enforced by using a single `in_layout` specialization constant for both input buffers. However, some models may have inputs with different layouts (e.g., 4W4C and 4C1W) that share the same packed dimension and block size, which should be compatible for binary operations. This change introduces a separate `other_layout` specialization constant for the second input, allowing the shader to correctly load from input_b using its actual layout while input_a continues to use `in_layout`. The C++ side now passes both layout hashes as separate specialization constants to the shader. Differential Revision: [D93768638](https://our.internmc.facebook.com/intern/diff/D93768638/) ghstack-source-id: 342806076 Pull Request resolved: https://github.com/pytorch/executorch/pull/17563 --- backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl | 3 ++- backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl index 60f437fbdce..be93e800436 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl @@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "block_config", "0")} // Generate loading functions for input buffers @@ -71,7 +72,7 @@ void main() { ivec4 in_block_a = load_int8x4_block_from_t_in_a( in_a_meta, tidx, in_layout, block_outer_dim); ivec4 in_block_b = load_int8x4_block_from_t_in_b( - in_b_meta, tidx, in_layout, block_outer_dim); + in_b_meta, tidx, 
other_layout, block_outer_dim); ivec4 out_block; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp index af934b9b521..05bdd9431c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp @@ -42,6 +42,7 @@ void add_q8ta_binary_node( VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim); VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( input_a_info.packed_dim_block_size == output_info.packed_dim_block_size); VK_CHECK_COND( @@ -105,6 +106,7 @@ void add_q8ta_binary_node( // Specialization Constants {graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input_a), + graph.hashed_layout_of(packed_int8_input_b), block_config.as_packed_int()}, // Resize args {block_config_ref},