From 6eeefc5b597beb629d3a99f6b1bb5f5afef24db2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:55 -0700 Subject: [PATCH 1/2] [ET-VK][qconv] Add q8ta_conv2d_transposed operator Pull Request resolved: https://github.com/pytorch/executorch/pull/18016 Implement quantized transposed 2D convolution for the Vulkan backend, enabling int8 transposed convolutions used in decoder/upsampling networks. The GLSL shader iterates over all kernel positions and derives valid input positions via (output + padding - kernel) / stride. Invalid positions use input_zp_packed so the precomputed weight_sums zero-point correction remains consistent. Reuses the existing q8ta_conv2d weight packing and workgroup size selection since, after the pattern matcher reshapes transposed weights from (IC, OC, KH, KW) to (OC, KH*KW*IC), the layout is identical to regular conv2d. Supports hardware int8 dot product with software fallback, grouped convolutions, optional bias and ReLU activation. Only dilation=1 is supported (matching the ATen conv_transpose2d constraint). This diff was authored with Claude. ghstack-source-id: 349646651 @exported-using-ghexport Differential Revision: [D95807070](https://our.internmc.facebook.com/intern/diff/D95807070/) --- backends/vulkan/custom_ops_lib.py | 99 +++ backends/vulkan/op_registry.py | 33 + .../vulkan/patterns/quantized_convolution.py | 101 ++- .../ops/glsl/q8ta_conv2d_transposed.glsl | 254 ++++++++ .../ops/glsl/q8ta_conv2d_transposed.yaml | 17 + .../runtime/graph/ops/impl/Q8taConv2d.h | 7 + .../graph/ops/impl/Q8taConv2dTransposed.cpp | 267 ++++++++ .../impl/TestQ8taConv2dTransposed.cpp | 87 +++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../test_q8ta_conv2d_transposed.cpp | 601 ++++++++++++++++++ 10 files changed, 1444 insertions(+), 23 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 87506f0b773..7f687bb10f4 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -685,6 +685,105 @@ def q8ta_conv2d_dw( lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name) + +def q8ta_conv2d_transposed( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor], + kernel_size: list, + stride: list, + padding: list, + output_padding: list, + dilation: list, + groups: int, + activation: str, +): + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + OC = weights.shape[0] + IC_per_group = int(x.shape[1] / groups) + K_h, K_w = kernel_size[0], kernel_size[1] + + orig_weight_K_dim = K_h * K_w * IC_per_group + if weights.shape[-1] > orig_weight_K_dim: + weights = weights[:, :orig_weight_K_dim] + + if weight_scales.shape[0] > OC: + weight_scales = weight_scales[:OC] + if bias is not None: + bias = bias[:OC] + + # Reshape to (OC, IC_per_group, K_h, K_w) then transpose to + # (IC_per_group * groups, OC_per_group, K_h, K_w) for conv_transpose2d + weights = weights.view(OC, IC_per_group, K_h, K_w) + OC_per_group = OC // groups + weights = ( + weights.view(groups, OC_per_group, IC_per_group, K_h, K_w) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(IC_per_group * groups, OC_per_group, K_h, K_w) + ) + + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + # Dequantize per OC channel. For transposed weight (IC, OC_per_group, KH, KW), + # OC is at axis=1. + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales[:OC_per_group].repeat(groups) if groups > 1 else weight_scales, + weight_zeros[:OC_per_group].repeat(groups) if groups > 1 else weight_zeros, + 1, + -127, + 127, + torch.int8, + ) + + out = torch.nn.functional.conv_transpose2d( + x, weights, bias, stride, padding, output_padding, groups, dilation + ) + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_conv2d_transposed" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] output_padding, + SymInt[] dilation, + SymInt groups, + str activation) -> Tensor + """ +) +lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd") +q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 4bcdbadeea5..af2389d72f9 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -865,6 +865,39 @@ def register_q8ta_conv2d_ops(): ) +@update_features( + [ + exir_ops.edge.et_vk.q8ta_conv2d_transposed.default, + ] +) +def register_q8ta_conv2d_transposed_op(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_CONV2D_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # kernel_size (non tensor) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # output_padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # groups (non tensor) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_CHANNELS_PACKED_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + # ============================================================================= # Q8taLinear.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 12ebbd1a382..d291d4009b7 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import cast, List, Optional import executorch.backends.vulkan.utils as utils @@ -33,12 +33,27 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.match_found = False self.all_nodes = [self.anchor_node] + # Determine if this is a transposed convolution + self.transposed = False + self.output_padding = [0, 0] + if conv_node.target == exir_ops.edge.aten.convolution.default: + transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False + if transposed_flag: + self.transposed = True + self.output_padding = ( + cast(List[int], conv_node.args[7]) if len(conv_node.args) > 7 else [0, 0] + ) + # Extract convolution parameters self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1] self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0] self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1] self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1 + # Transposed conv only supported with dilation=[1,1] + if self.transposed and cast(List[int], self.dilation) != [1, 1]: + return + const_node, arg_chain = utils.trace_args_until_placeholder( self.anchor_node.args[1] ) @@ -60,6 +75,16 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.dequantize_weight_node = dequantize_weight_node self.all_nodes.extend(arg_chain) + # For transposed conv, verify per-channel quantization is on the OC dimension. + # Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization + # should be on axis=1. If axis=0, that's per-IC which is not supported. + if self.transposed and utils.is_dequant_per_channel_node( + self.dequantize_weight_node + ): + quant_axis = self.dequantize_weight_node.args[3] + if quant_axis != 1: + return + # Identify weight quantization parameter nodes self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( self.dequantize_weight_node.args[1] @@ -177,9 +202,30 @@ def make_q8ta_conv2d_custom_op( bias_tensor = get_param_tensor(ep, match.bias_node) assert bias_tensor is not None - OC, IC_per_group, H, W = weight_tensor.shape + if match.transposed: + # Transposed conv weight shape: (IC, OC_per_group, H, W) + IC, OC_per_group, H, W = weight_tensor.shape + OC = OC_per_group * match.groups + IC_per_group = IC // match.groups + # Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based + # transposed convolution. + # (IC, OC_per_group, H, W) -> + # (groups, IC_per_group, OC_per_group, H, W) -> + # (groups, OC_per_group, H, W, IC_per_group) -> + # (OC, H*W*IC_per_group) + weight_tensor = ( + weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W) + .permute(0, 2, 3, 4, 1) + .contiguous() + .reshape(OC, H * W * IC_per_group) + .contiguous() + ) + else: + OC, IC_per_group, H, W = weight_tensor.shape - is_depthwise_conv = IC_per_group == 1 and match.groups == OC + is_depthwise_conv = ( + not match.transposed and IC_per_group == 1 and match.groups == OC + ) if is_depthwise_conv: assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4" @@ -188,7 +234,7 @@ def make_q8ta_conv2d_custom_op( weight_tensor = ( weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous() ) - else: + elif not match.transposed: # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group) # (i.e. matrix format). This prepares the weights for Im2Col-based convolution. weight_tensor = ( @@ -257,32 +303,41 @@ def make_q8ta_conv2d_custom_op( ) with graph_module.graph.inserting_before(match.output_node): - op_target = exir_ops.edge.et_vk.q8ta_conv2d.default - if is_depthwise_conv: + if match.transposed: + op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default + elif is_depthwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default elif is_pointwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default + else: + op_target = exir_ops.edge.et_vk.q8ta_conv2d.default + + op_args = ( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, + match.bias_node, + [H, W], + match.stride, + match.padding, + ) + if match.transposed: + op_args = op_args + (match.output_padding,) + op_args = op_args + ( + match.dilation, + match.groups, + "relu" if match.relu_node is not None else "none", + ) qconv_node = graph_module.graph.create_node( "call_function", op_target, - args=( - match.quantize_input_node, - match.input_scales_node, - match.input_zeros_node, - match.weight_node, - weight_sums_node, - match.weight_scales_node, - match.output_scales_node, - match.output_zeros_node, - match.bias_node, # Add bias after weight_scales - [H, W], # Pass kernel size information before stride - match.stride, - match.padding, - match.dilation, - match.groups, - "relu" if match.relu_node is not None else "none", - ), + args=op_args, ) qconv_node.meta["val"] = match.output_node.meta["val"] diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl new file mode 100644 index 00000000000..efed2e3a95b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl @@ -0,0 +1,254 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + +#extension GL_EXT_control_flow_attributes : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +// Metadata for input/output tensors (memory layout agnostic) +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +// Load weight block for a given (ic4, kx, ky, oc4) position. +// Weight texture layout (from pack_q8_conv2d_weights.glsl): +// block_x = oc4 * K_w + kx +// block_y = ky * IC4 + ic4 +// Each texel ivec4 has 4 components (4 output channels), each component is +// a packed int32 containing 4 int8 values for 4 consecutive input channels. +ivec4 load_weight_block(int ic4, int kx, int ky, int oc4, int IC4, int KW) { + const int block_x = oc4 * KW + kx; + const int block_y = ky * IC4 + ic4; + return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); +} + +ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) { + vec4 quantized = round(texel * inv_scale) + zp; + return clamp(ivec4(quantized), -128, 127); +} + +void main() { + // Thread mapping: same as q8ta_conv2d + // Each thread handles a 4W x 4C tile of output + int oc4 = int(gl_GlobalInvocationID.z); + int w4 = int(gl_GlobalInvocationID.x); + + // Initialize output tensor index (WHCN order) + TensorIndex4D outp_tidx; + outp_tidx.data[0] = w4 * 4; + outp_tidx.data[1] = int(gl_GlobalInvocationID.y); + outp_tidx.data[2] = oc4 * 4; + outp_tidx.data[3] = 0; + + const int W = int(outp.sizes[0][0]); + const int OC = int(outp.sizes[0][2]); + + // Bounds check + if (any(greaterThanEqual(outp_tidx.data, ivec4(outp.sizes[0])))) { + return; + } + + // Input dimensions + const int inp_W = int(inp.sizes[0][0]); + const int inp_H = int(inp.sizes[0][1]); + const int IC = int(inp.sizes[0][2]); + + // Compute channels per group + const int OC_per_group = OC / conv2d_params.groups; + const int IC_per_group = IC / conv2d_params.groups; + const int IC4_per_group = div_up_4(IC_per_group); + + // Determine which group this output channel block belongs to + const int group_idx = outp_tidx.data[2] / OC_per_group; + const int ic_group_start = group_idx * IC_per_group; + + // Get strides for efficient indexing + const int inp_w_stride = int(inp.strides[0][0]); + const int inp_h_stride = int(inp.strides[0][1]); + + // Create packed input zero point (4 copies of input_zp packed into int32) + const int input_zp_packed = pack_into_int32(ivec4(input_zp)); + + // Initialize accumulators for 4 width positions x 4 output channels each + ivec4 acc[4]; + [[unroll]] for (int i = 0; i < 4; ++i) { + acc[i] = ivec4(0); + } + + // Transposed convolution loop structure: + // Iterate over all kernel positions (ky, kx, ic4). For each position, + // compute the corresponding input position. If the input position is valid + // (in bounds and on a stride-aligned position), load the actual input; + // otherwise use input_zp_packed. This ensures weight_sums correction is + // consistent with the accumulation (all kernel positions are accounted for). + // + // For transposed convolution, the input position for a given (output, kernel) + // is: input = (output + padding - kernel) / stride, which must be exact + // (remainder == 0) and within [0, input_size). + + const int KH = conv2d_params.kernel_size.y; + const int KW = conv2d_params.kernel_size.x; + const int stride_x = conv2d_params.stride.x; + const int stride_y = conv2d_params.stride.y; + const int pad_x = conv2d_params.padding.x; + const int pad_y = conv2d_params.padding.y; + + for (int ky = 0; ky < KH; ky++) { + // Check if this kernel row maps to a valid input row + const int in_y_numer = outp_tidx.data[1] + pad_y - ky; + const bool y_stride_valid = (in_y_numer >= 0) && ((in_y_numer % stride_y) == 0); + const int iy = y_stride_valid ? (in_y_numer / stride_y) : 0; + const bool h_in_bounds = y_stride_valid && (iy < inp_H); + + // Loop order: ic4 before kx for better weight texture cache locality. + // Consecutive ic4 values at fixed kx access consecutive y-rows in the + // weight texture, but consecutive kx at fixed ic4 access consecutive + // x-coordinates (same row) which is the fast dimension for texture + // cache lines. By iterating ic4 in the outer loop and kx in the inner + // loop, each ic4 iteration sweeps kx across a texture row. + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { + for (int kx = 0; kx < KW; kx++) { + // Load weight block: 4 output channels x 4 input channels + const ivec4 weight_block = load_weight_block( + ic4, kx, ky, oc4, IC4_per_group, KW); + + // Process 4 adjacent width positions + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + const int ow = outp_tidx.data[0] + subtile_w; + const int in_x_numer = ow + pad_x - kx; + + // Load packed input, or use zero point if out of bounds + int packed_input = input_zp_packed; + if (h_in_bounds && in_x_numer >= 0 && (in_x_numer % stride_x) == 0) { + const int ix = in_x_numer / stride_x; + if (ix < inp_W) { + TensorIndex4D inp_tidx; + inp_tidx.data[0] = ix; + inp_tidx.data[1] = iy; + inp_tidx.data[2] = ic_group_start + ic4 * 4; + inp_tidx.data[3] = 0; + + int inp_texel_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout); + } else { + const int w4_inp = div_4(ix); + const int inp_c4 = div_4(inp_tidx.data[2]); + inp_texel_idx = (iy * inp_h_stride + w4_inp * inp_w_stride + inp_c4) * 4 + mod_4(ix); + } + packed_input = t_packed_int8_input[inp_texel_idx]; + } + } + + // Accumulate using packed int8 dot product for each output channel + [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) { + acc[subtile_w][oc_offset] = dotPacked4x8AccSat( + packed_input, + weight_block[oc_offset], + acc[subtile_w][oc_offset]); + } + } + } + } + } + + // Apply input zero point correction via weight_sums + const vec4 weight_sums = vec4(t_weight_sums[oc4]); + const vec4 weight_scales = vec4(t_weight_scales[oc4]); + + // Convert to float, apply dequantization, and optionally add bias + vec4 facc[4]; + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = vec4(acc[subtile_w]); + facc[subtile_w] -= weight_sums * input_zp; + facc[subtile_w] *= weight_scales * input_scale; + } + + // Apply bias if enabled + if (apply_bias > 0) { + const vec4 bias = vec4(t_bias[oc4]); + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] += bias; + } + } + + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + + // Compute base output texel index (for subtile_w=0) + const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); + const int out_w_stride = int(outp.strides[0][0]); + + // Quantize and store outputs using stride offsets + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + // Skip out-of-bounds width positions + if (outp_tidx.data[0] >= W) { + continue; + } + + const ivec4 quantized_out = quantize(facc[subtile_w], output_inv_scale, output_zp); + const int packed_out = pack_into_int32(quantized_out); + + // Store using stride offset from base + int outp_texel_idx; + if (get_outer_packed_dim_block_size(outp_layout) == 1) { + outp_texel_idx = base_outp_texel_idx + subtile_w * out_w_stride; + } else { + outp_texel_idx = base_outp_texel_idx + subtile_w; + } + + t_packed_int8_output[outp_texel_idx] = packed_out; + + outp_tidx.data[0] += 1; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml new file mode 100644 index 00000000000..69469fabd95 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_conv2d_transposed: + parameter_names_with_default_values: + DTYPE: float + USE_INT8_DOT_PRODUCT_EXT: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_conv2d_transposed + - NAME: q8ta_conv2d_transposed_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 6da98fbef74..f463589c50a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include namespace vkcompute { @@ -145,4 +146,10 @@ void add_q8ta_im2col_node( void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); +// Transposed convolution + +void q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp new file mode 100644 index 00000000000..bdbdaa14fec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp @@ -0,0 +1,267 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// Dedicated workgroup size functions for transposed convolution. +// Unlike regular conv2d, transposed conv with stride > 1 causes branch +// divergence along the height dimension (different rows have different +// stride-alignment patterns). Keeping local_y=1 ensures all threads in a +// workgroup process the same height row, maximizing branch coherence. + +utils::uvec3 pick_q8ta_conv2d_transposed_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, output); + const uint32_t H = graph->size_at(-2, output); + const uint32_t C = graph->size_at(-3, output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +utils::uvec3 pick_q8ta_conv2d_transposed_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + (void)graph; + (void)args; + + // Always keep local_y=1 to avoid branch divergence between height rows. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) { + return {8u, 1u, 8u}; + } + if (global_workgroup_size[0u] < 2u) { + return {1u, 1u, 64u}; + } + if (global_workgroup_size[2u] < 2u) { + return {64u, 1u, 1u}; + } + return {16u, 1u, 4u}; +} + +void add_q8ta_conv2d_transposed_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // Transposed convolution only supports dilation=1 + VK_CHECK_COND( + conv_params.dilation[0] == 1 && conv_params.dilation[1] == 1, + "q8ta_conv2d_transposed only supports dilation=1"); + + // The implementation requires that for grouped convolutions, the input + // channels per group is a multiple of 4. + if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); + } + + // Validate packed dim info for input and output tensors + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + // Validate dtype is kInt8x4 + VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4); + VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + const bool use_hw_dot = + graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = + use_hw_dot ? "q8ta_conv2d_transposed" : "q8ta_conv2d_transposed_fallback"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + // Pass metadata for both output and input tensors + vkapi::ParamsBindList param_buffers = { + graph.buffer_meta_ubo(packed_int8_output), + graph.buffer_meta_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + // Build spec constants: apply_bias, activation_type + layout constants + vkapi::SpecVarList spec_constants = { + apply_bias, + activation_type, + // Layout specialization constants + graph.hashed_layout_of(packed_int8_input), + graph.hashed_layout_of(packed_int8_output), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_conv2d_transposed_global_wg_size, + pick_q8ta_conv2d_transposed_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_constants, + // Resize args + {})); +} + +void q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + args.at(idx++); // output_padding: only affects output size, not shader + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + + // Reuse the conv2d weight packing (after the pattern matcher reshapes the + // transposed weight to (OC, KH*KW*IC_per_group), the weight layout is + // identical to regular conv2d) + ValueRef packed_weight = prepack_quantized_conv2d_weight( + graph, + weight_quant_config, + weight_data, + packed_int8_input, + packed_int8_output, + groups, + kernel_size); + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + add_q8ta_conv2d_transposed_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_conv2d_transposed.default, q8ta_conv2d_transposed); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp new file mode 100644 index 00000000000..894ce71fed9 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +void test_q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef output_padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + int32_t layout_value = graph.extract_scalar(layout_int); + utils::GPUMemoryLayout layout = + static_cast(layout_value); + + TmpTensor packed_int8_input( + &graph, graph.sizes_of(fp_input), vkapi::kInt8x4, utils::kBuffer, layout); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + layout); + + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + activation, + packed_int8_output}; + VK_GET_OP_FN("et_vk.q8ta_conv2d_transposed.default")(graph, conv_args); + + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + test_etvk.test_q8ta_conv2d_transposed.default, + test_q8ta_conv2d_transposed); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index badba5666fa..ba4873af603 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -98,3 +98,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d_pw") define_custom_op_test_binary("test_q8ta_conv2d_dw") define_custom_op_test_binary("test_q8ta_linear") + define_custom_op_test_binary("test_q8ta_conv2d_transposed") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp new file mode 100644 index 00000000000..903a9c678b1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp @@ -0,0 +1,601 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "conv2d_utils.h" +#include "utils.h" + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Transposed convolution output size formula: +// H_out = (H_in - 1) * stride_h - 2 * pad_h + dilation_h * (K_h - 1) +// + output_pad_h + 1 +static int64_t get_transpose_output_height( + const Conv2dConfig& config, + int32_t output_pad_h) { + return (config.input_size.h - 1) * config.stride.h - 2 * config.padding.h + + config.dilation.h * (config.kernel.h - 1) + output_pad_h + 1; +} + +static int64_t get_transpose_output_width( + const Conv2dConfig& config, + int32_t output_pad_w) { + return (config.input_size.w - 1) * config.stride.w - 2 * config.padding.w + + config.dilation.w * (config.kernel.w - 1) + output_pad_w + 1; +} + +// Utility function to create a test case from a Conv2dConfig for transposed +// convolution +static TestCase create_test_case_from_config( + const Conv2dConfig& config, + int32_t output_pad_h, + int32_t output_pad_w, + vkapi::ScalarType input_dtype, + utils::StorageType fp_storage_type, + utils::GPUMemoryLayout int8_memory_layout) { + TestCase test_case; + + int64_t H_out = get_transpose_output_height(config, output_pad_h); + int64_t W_out = get_transpose_output_width(config, output_pad_w); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + // For transposed conv, C_in is typically larger (downsampled channels) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + utils::GPUMemoryLayout fp_memory_layout = fp_storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + + // Create test case name + std::string prefix = config.test_case_name.substr(0, 4); + std::string test_name = prefix + " " + std::to_string(config.channels.in) + + "->" + std::to_string(config.channels.out) + " " + + "I=" + std::to_string(config.input_size.h) + "," + + std::to_string(config.input_size.w) + " " + + "g=" + std::to_string(config.groups) + " " + + "k=" + std::to_string(config.kernel.h) + " " + + "op=" + std::to_string(output_pad_h) + "," + + std::to_string(output_pad_w) + " " + + repr_str(utils::kBuffer, int8_memory_layout); + test_case.set_name(test_name); + + test_case.set_operator_name("test_etvk.test_q8ta_conv2d_transposed.default"); + + ValueSpec input_tensor( + input_size, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, align_up_4(C_in_per_group * K_h * + // K_w)] After the pattern matcher reshapes, the transposed conv weight has + // the same layout as regular conv2d + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + const int64_t aligned_out_channels = utils::align_up_4(config.channels.out); + + ValueSpec weight_scales( + {aligned_out_channels}, + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {aligned_out_channels}, + vkapi::kInt, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + ValueSpec bias( + {aligned_out_channels}, + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + ValueSpec output_padding({output_pad_h, output_pad_w}); + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor - [1, C_out, H_out, W_out] + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(output_padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + + ValueSpec layout_int(static_cast(int8_memory_layout)); + test_case.add_input_spec(layout_int); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// Generate easy test cases for debugging +std::vector generate_quantized_conv2d_transposed_easy_cases() { + std::vector test_cases; + + Conv2dConfig config = { + OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1, + }; + + std::vector int8_memory_layouts = { + utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C}; + + for (const utils::GPUMemoryLayout int8_memory_layout : int8_memory_layouts) { + config.test_case_name = + make_test_case_name(config, false, utils::kTexture3D, utils::kBuffer); + test_cases.push_back(create_test_case_from_config( + config, + /*output_pad_h=*/1, + /*output_pad_w=*/1, + vkapi::kFloat, + utils::kTexture3D, + int8_memory_layout)); + } + + return test_cases; +} + +// Generate test cases for quantized transposed conv2d +static std::vector generate_quantized_conv2d_transposed_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + // Each entry: {config, output_pad_h, output_pad_w} + struct TransposedConvTestConfig { + Conv2dConfig config; + int32_t output_pad_h; + int32_t output_pad_w; + }; + + std::vector configs = { + // Basic transposed conv (stride=2, common in decoder networks) + {{OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + {{OutInChannels(32, 64), + InputSize2D(4, 4), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + // No output padding + {{OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Stride=1 (degenerate case) + {{OutInChannels(16, 16), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Grouped transposed conv + {{OutInChannels(32, 64), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 2}, + 1, + 1}, + // Larger spatial + {{OutInChannels(64, 128), + InputSize2D(16, 16), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Performance cases + {{OutInChannels(64, 128), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + {{OutInChannels(128, 256), + InputSize2D(16, 16), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + }; + + std::vector int8_memory_layouts = { + utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C}; + + for (auto& tc : configs) { + auto& config = tc.config; + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + + for (const utils::GPUMemoryLayout int8_memory_layout : + int8_memory_layouts) { + config.test_case_name = make_test_case_name( + config, is_performance, utils::kTexture3D, utils::kBuffer); + + test_cases.push_back(create_test_case_from_config( + config, + tc.output_pad_h, + tc.output_pad_w, + vkapi::kFloat, + utils::kTexture3D, + int8_memory_layout)); + } + } + + return test_cases; +} + +// Reference implementation for quantized transposed conv2d +static void conv2d_transposed_q8ta_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& output_padding_spec = test_case.inputs()[idx++]; + (void)output_padding_spec; // output_padding only affects output size + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + auto output_sizes = output_spec.get_tensor_sizes(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int64_t in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Transposed convolution reference implementation. + // For transposed conv, we scatter each input element across the output + // rather than gather. But for the reference we compute it by iterating + // over output positions and finding which input positions contribute. + // + // For each output position (oh, ow), an input position (iy, ix) contributes + // via kernel position (kh, kw) if: + // oh + pad_h - kh * dilation_h == iy * stride_h + // ow + pad_w - kw * dilation_w == ix * stride_w + // i.e., (oh + pad_h - kh * dilation_h) must be divisible by stride_h + // and the quotient must be a valid input index. + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; + + for (int64_t kh = 0; kh < K_h; ++kh) { + int64_t h_offset = out_h + pad_h - kh * dilation_h; + if (h_offset < 0 || h_offset % stride_h != 0) { + continue; + } + int64_t iy = h_offset / stride_h; + if (iy >= H_in) { + continue; + } + + for (int64_t kw = 0; kw < K_w; ++kw) { + int64_t w_offset = out_w + pad_w - kw * dilation_w; + if (w_offset < 0 || w_offset % stride_w != 0) { + continue; + } + int64_t ix = w_offset / stride_w; + if (ix >= W_in) { + continue; + } + + for (int64_t ic_local = 0; ic_local < C_in_per_group; + ++ic_local) { + int64_t in_c = in_c_start + ic_local; + + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + iy * W_in + ix; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Weight layout: [C_out, align_up_4(C_in_per_group * K_h * + // K_w)] Inner dimension order: kh, kw, ic_local + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + ic_local); + int8_t quantized_weight = weight_data[weight_idx]; + + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + weight_sum += static_cast(quantized_weight); + } + } + } + + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + float_result += bias_data[out_c]; + + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +static void reference_impl(TestCase& test_case) { + conv2d_transposed_q8ta_reference_impl(test_case); +} + +static int64_t quantized_conv2d_transposed_flop_calculator( + const TestCase& test_case) { + int kernel_idx = 9; + + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + return output_elements * ops_per_output; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(true); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Transposed Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_quantized_conv2d_transposed_easy_cases, +#else + generate_quantized_conv2d_transposed_test_cases, +#endif + quantized_conv2d_transposed_flop_calculator, + "QuantizedTransposedConv2d", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} From 52e7fe75b4d61dcd95ab3b5527b3a9873df5a010 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:57 -0700 Subject: [PATCH 2/2] [ET-VK][qlinear] Add bmm support to quantized linear pattern detector Pull Request resolved: https://github.com/pytorch/executorch/pull/18017 Some quantized linear projections (e.g. in EdgeTAM's SpatialPerceiver / mask decoder) decompose as aten.bmm instead of aten.mm. Add aten.bmm.default as an anchor node in the quantized linear pattern detector so these nodes can be fused into custom quantized linear ops. Reject bmm nodes with batch dim > 1 since the custom ops assume a single batch. ghstack-source-id: 349646654 @exported-using-ghexport Differential Revision: [D95807072](https://our.internmc.facebook.com/intern/diff/D95807072/) --- backends/vulkan/patterns/quantized_linear.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 85e3476cad3..b9b307e14f1 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node + # bmm with batch dim > 1 is not supported + is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default + if is_bmm and self.output_node.meta["val"].shape[0] != 1: + return + # Identify primary input node of the anchor. Due to decomposition of aten.linear # there may be a view_copy node between the original input tensor to the linear # op and the actual linear op node. @@ -268,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: exir_ops.edge.aten.linear.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.bmm.default, }