From 36d54290310084e3b05b6153dfb21d3e76b70923 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 16 Jun 2025 11:07:38 -0700 Subject: [PATCH] [Quantized DeConv Support] Dynamically Quantized Deconvolutions with groups ==1 Here we support dynamically quantized Deconvolutions. There is some refactoring of the previous diff, but in general, we just remove the constraint in the Dynamism check that the convolution isn't transposed. For the same reasons as before, this only supports channel_axis = 1 and groups = 1. Differential Revision: [D76638904](https://our.internmc.facebook.com/intern/diff/D76638904/) [ghstack-poisoned] --- .../xnnpack/quantizer/xnnpack_quantizer.py | 2 +- .../quantizer/xnnpack_quantizer_utils.py | 82 ++++++++++++------- backends/xnnpack/test/ops/test_conv2d.py | 81 ++++++++++-------- backends/xnnpack/utils/utils.py | 31 +++++++ 4 files changed, 133 insertions(+), 63 deletions(-) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 130eda03f88..c07d27e4231 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer): QuantPattern("linear_relu", False, False, LINEAR_TARGETS), QuantPattern("linear", True, False, LINEAR_TARGETS), QuantPattern("conv", True, False, CONV_TARGETS), - QuantPattern("conv_transpose", False, False, CONV_TARGETS), + QuantPattern("conv_transpose", True, False, CONV_TARGETS), QuantPattern("conv_relu", False, False, CONV_TARGETS), QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS), QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS), diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 2ebf69da4f5..3d687d0b513 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -4,7 +4,10 @@ import torch import torch.nn.functional as F -from executorch.backends.xnnpack.utils.utils import is_depthwise_conv +from executorch.backends.xnnpack.utils.utils import ( + get_groups_from_conv, + is_depthwise_conv, +) from torch._subclasses import FakeTensor from torch.fx import Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( @@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None: return decorator +def change_quantization_config( + original_qspec, + dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + ch_axis=None, + is_dynamic=None, + observer_or_fake_quant_ctr=None, +): + return QuantizationSpec( + dtype=dtype or original_qspec.dtype, + quant_min=quant_min or original_qspec.quant_min, + quant_max=quant_max or original_qspec.quant_max, + qscheme=qscheme or original_qspec.qscheme, + ch_axis=ch_axis or original_qspec.ch_axis, + is_dynamic=is_dynamic or original_qspec.is_dynamic, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr + or original_qspec.observer_or_fake_quant_ctr, + ) + + def is_relu_node(node: Node) -> bool: """ Check if a given node is a relu node @@ -231,6 +256,9 @@ def _do_annotate_conv( if is_relu_node(user): continue + # Tracks conditions for whether or not to skip + skip = False + input_qspec_map = {} input_act = conv_node.args[0] assert isinstance(input_act, Node) @@ -239,35 +267,33 @@ def _do_annotate_conv( weight = conv_node.args[1] assert isinstance(weight, Node) weight_qspec = get_weight_qspec(quantization_config) + num_groups = get_groups_from_conv(conv_node) + + # skip if transposed conv has more than 1 group + skip = skip or (is_conv_transpose and num_groups != 1) + print(f"{skip} conv transpose and num_groups") + if is_conv_transpose: # transposed convs per output channel quantization - weight_qspec = QuantizationSpec( - dtype=weight_qspec.dtype, - quant_min=weight_qspec.quant_min, - quant_max=weight_qspec.quant_max, - qscheme=weight_qspec.qscheme, - ch_axis=1, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr, - ) - input_qspec_map[weight] = weight_qspec + weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) - # Only annotate dynamically quantized conv if it's 2D and not depthwise - if ( + input_qspec_map[weight] = weight_qspec + is_dynamic = ( quantization_config and quantization_config.input_activation and quantization_config.input_activation.is_dynamic - ): + ) + + # Only annotate dynamically quantized conv if it's 2D and not depthwise + if is_dynamic: weight_val = weight.meta.get("val", None) weight_shape = getattr(weight_val, "shape", None) - # Skip if not a 4D weight tensor (i.e. not conv2d) - if weight_shape is not None and len(weight_shape) != 4: - continue - + skip = skip or (weight_shape is not None and len(weight_shape) != 4) # Skip if depthwise (default to groups=1 since it's not an arg) - if is_depthwise_conv(weight_shape, 1, is_conv_transpose): - continue + skip = skip or ( + not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False) + ) # adding weight node to the partition as well partition = [conv_node, conv_node.args[1]] @@ -277,7 +303,7 @@ def _do_annotate_conv( input_qspec_map[bias] = get_bias_qspec(quantization_config) partition.append(bias) - if _is_annotated(partition): + if _is_annotated(partition) or skip: continue if filter_fn and any(not filter_fn(n) for n in partition): @@ -324,17 +350,10 @@ def _do_annotate_conv_relu( weight = conv_node.args[1] assert isinstance(weight, Node) weight_qspec = get_weight_qspec(quantization_config) + groups = get_groups_from_conv(conv_node) if is_conv_transpose: # transposed convs per output channel quantization - weight_qspec = QuantizationSpec( - dtype=weight_qspec.dtype, - quant_min=weight_qspec.quant_min, - quant_max=weight_qspec.quant_max, - qscheme=weight_qspec.qscheme, - ch_axis=1, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr, - ) + weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) input_qspec_map[weight] = weight_qspec # adding weight node to the partition as well @@ -347,6 +366,9 @@ def _do_annotate_conv_relu( if _is_annotated(partition): continue + if is_conv_transpose and groups != 1: + continue + if filter_fn and any(not filter_fn(n) for n in partition): continue diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index d838ef0ffe9..2a0a82d99b6 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -174,14 +174,11 @@ def get_inputs(self): class Conv2dDQSeq(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=8, out_channels=10, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1) def forward(self, x): y = self.first(x) @@ -192,14 +189,11 @@ def get_inputs(self): class Conv2dDQParallel(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1) def forward(self, x): first = self.first(x) @@ -266,8 +260,7 @@ def _test_dq( ) DynamicallyQuantizedPartitioner = XnnpackPartitioner( - config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, - per_op_mode=True, + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True ) tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes) @@ -349,11 +342,10 @@ def test_fp32_conv2d_depthwise(self): ) def test_qs8_conv2d_depthwise(self): - for transpose in (True, False): - self._test( - Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose), - quant_config=get_symmetric_quantization_config(), - ) + self._test( + Conv2d(groups=2, in_channels=2, out_channels=6), + quant_config=get_symmetric_quantization_config(), + ) def test_fp32_conv2d_bn(self): class Conv2dBatchNorm(torch.nn.Module): @@ -515,17 +507,14 @@ def forward(self, x): def get_inputs(self): return (torch.randn(batches, in_channels, height, width) * 11,) - for transpose in (True, False): - for per_channel_quant in (False, True): - if transpose and per_channel_quant: - continue - model = ModelConvReLU(transpose=transpose) - self._test( - model, - quant_config=get_symmetric_quantization_config( - is_per_channel=per_channel_quant - ), - ) + for per_channel_quant in (False, True): + model = ModelConvReLU() + self._test( + model, + quant_config=get_symmetric_quantization_config( + is_per_channel=per_channel_quant + ), + ) def test_qs8_conv2d_relu_seq(self): class ConvReLUSeq(torch.nn.Module): @@ -728,3 +717,31 @@ def test_dq_conv2d_parallel(self) -> None: model = Conv2dDQParallel() conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d) self._test_dq(model, conv_count) + + def test_dq_conv2d_transpose(self) -> None: + model = Conv2d( + in_channels=3, + out_channels=10, + kernel_size=(3, 3), + stride=(1, 1), + padding=(0, 0), + batches=1, + width=8, + height=8, + transpose=True, + ) + self._test_dq(model) + + def test_dq_conv2d_transpose_seq(self) -> None: + model = Conv2dDQSeq(transpose=True) + conv_count = sum( + 1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d + ) + self._test_dq(model, conv_count) + + def test_dq_conv2d_transpose_parallel(self) -> None: + model = Conv2dDQParallel(transpose=True) + conv_count = sum( + 1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d + ) + self._test_dq(model, conv_count) diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index b23fd444117..a8f3178f98f 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -25,6 +25,7 @@ is_lifted_tensor_constant, is_param, ) +from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node ### XNNPACK Capture ### @@ -160,6 +161,36 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]: return source_fn[1] +def get_groups_from_conv(conv_node: torch.fx.Node) -> int: + if _is_conv_node(conv_node): + in_node = cast(torch.fx.Node, conv_node.args[0]) + weight_node = cast(torch.fx.Node, conv_node.args[1]) + # groups isn't given to us in the training graph so we deduce it from the weight shape + # and the input shape + + # input shape is (N, C_in, H_in, W_in) + in_channels = in_node.meta["val"].shape[1] + + # weight shape is (C_out, C_in/groups, kernel_size[0], kernel_size[1]) + in_groups = weight_node.meta["val"].shape[1] + + return in_channels // in_groups + elif _is_conv_transpose_node(conv_node): + weight_node = cast(torch.fx.Node, conv_node.args[1]) + # groups isn't given to us in the training graph so we deduce it from the weight shape + # and the output shape + + # weight shape is (C_in, C_out/groups, kernel_size[0], kernel_size[1]) + out_groups = weight_node.meta["val"].shape[1] + + # output shape is (N, C_out, H_out, W_out) + out_channels = conv_node.meta["val"].shape[1] + + return out_channels // out_groups + + raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node") + + def is_depthwise_conv( kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False ) -> bool: