From 17d7a2bd9a2b7c7d7cc481772696c6c3922b165a Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 10 Jul 2025 14:33:43 -0700 Subject: [PATCH] [ET-VK][Ops] affine quantization operators registration # Context In order to enable dynamic quantization, especially for the source transform method using `Int8DynActInt4WeightQuantizer`, we need Vulkan versions of `quantize_affine`, `dequantize_affine`, and `choose_qparams_affine`. Currently we do not have a shader that performs the block-based quantization that these operators expect, so we delegate to the per_tensor variant just to get unblocked. At a later stage, this will likely be developed further in order to ensure we don't get too much accuracy loss. # Changes This creates a schema reference in the TorchAO library for out variants of these respective operators. A VK_REGISTER_OP is then done on each of them to ensure that they are properly registered when lowering the ET model with Vulkan. The vulkan_quantizer is also changed a bit to enable a dynamic quantization config so that we aren't purely working with just static quantization anymore. Furthermore, we keep `_annotate_for_static_quantization_config` for parity/legacy reasons, and we simply create an equivalent dynamic quantization config method. We also changed `Linear.cpp`, particularly to allow a passthrough for weight_data, since during dynamic quantization it's possible that it will be a tensor_data rather than a tensor_ref. 
Differential Revision: [D78035354](https://our.internmc.facebook.com/intern/diff/D78035354/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 22 ++++++ backends/vulkan/quantizer/vulkan_quantizer.py | 57 ++++++++++++-- .../runtime/graph/ops/impl/ChooseQParams.cpp | 78 +++++++++++++++++-- .../runtime/graph/ops/impl/Dequantize.cpp | 44 +++++++++++ .../vulkan/runtime/graph/ops/impl/Linear.cpp | 2 +- .../runtime/graph/ops/impl/Quantize.cpp | 35 +++++++++ 6 files changed, 226 insertions(+), 12 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 8d47afb4525..4613ccb5e83 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -272,6 +272,28 @@ def register_quantization_op(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.torchao.quantize_affine.default, + exir_ops.edge.torchao.dequantize_affine.default, + exir_ops.edge.torchao.choose_qparams_affine.default, + ] +) +def register_torchao_quantization_op(features: OpFeatures): + # TorchAO quantization operators - default to per-tensor behavior + # Same features as standard quantization ops + features.texture_impl = TextureImplFeatures( + uses_axis_map=True, + valid_packed_dims={ + PackedDim.WIDTH, + }, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.BUFFER + return features + + @update_features( [ exir_ops.edge.aten.add.Tensor, diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index a82c2091cf6..3696d18c1b3 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -18,7 +18,10 @@ propagate_annotation, ) from torch.fx import Node -from torchao.quantization.pt2e import PerChannelMinMaxObserver +from torchao.quantization.pt2e import ( + PerChannelMinMaxObserver, + PlaceholderObserver, +) from torchao.quantization.pt2e.quantizer import ( QuantizationConfig, 
QuantizationSpec, @@ -77,6 +80,38 @@ def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfi ) +@functools.lru_cache +def get_dynamic_activation_qconfig( + weight_bits: int = 4, + act_qmin: int = -128, + act_qmax: int = 127, +) -> QuantizationConfig: + """ + Return a QuantizationConfig for dynamic activation quantization with 4-bit weights. + This is compatible with Vulkan backend's quantized_decomposed operators. + """ + # Dynamic activation quantization spec + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=True, + observer_or_fake_quant_ctr=PlaceholderObserver, + ) + + # Weight quantization spec (per-channel symmetric) + weight_qspec = get_linear_weight_qcs_qspec(weight_bits) + + return QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=None, + weight=weight_qspec, + bias=None, + is_qat=False, + ) + + _SUPPORTED_OPS = [ "linear", ] @@ -99,12 +134,15 @@ def transform_for_annotation( return _convert_scalars_to_attrs(model) def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - # currently only support static quant on Vulkan - model = self._annotate_for_static_quantization_config(model) + # Support both static and dynamic quantization + if self.global_config and self.global_config.input_activation and self.global_config.input_activation.is_dynamic: + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) propagate_annotation(model) return model - def _annotate_all_static_patterns( + def _annotate_all_patterns( self, model: torch.fx.GraphModule, quantization_config: Optional[QuantizationConfig], @@ -120,7 +158,16 @@ def _annotate_all_static_patterns( def _annotate_for_static_quantization_config( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: - self._annotate_all_static_patterns( + 
self._annotate_all_patterns( + model, + self.global_config, + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + self._annotate_all_patterns( model, self.global_config, ) diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 5e9599b91e6..426ed5dda2e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -306,10 +306,12 @@ void choose_qparams_tensor_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - only accept Vulkan-supported types - // The Vulkan backend only supports float32 and int32, not float64/int64 + // Verify output types - accept both int32 and float32 for zero_point + // TorchAO may use float32 for zero_point in some cases VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kFloat); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -352,21 +354,85 @@ void choose_qparams_per_token_asymmetric_impl( graph.dtype_of(input) == vkapi::kHalf || graph.dtype_of(input) == vkapi::kDouble); - // Verify output types - only accept Vulkan-supported types - // The Vulkan backend only supports float32 and int32, not float64/int64 + // Verify output types - accept both int32 and float32 for zero_point + // TorchAO may use float32 for zero_point in some cases VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kFloat); add_choose_qparams_per_token_asymmetric_node( graph, 
input, scale_out, zero_point_out); } +void choose_qparams_affine_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef target_dtype = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef eps = args[arg_idx++]; + const ValueRef scale_dtype = args[arg_idx++]; + const ValueRef zero_point_dtype = args[arg_idx++]; + const ValueRef out_tuple_ref = args[arg_idx++]; + + // Suppress unused variable warnings + (void)mapping_type; + (void)block_size; + (void)target_dtype; + (void)scale_dtype; + (void)zero_point_dtype; + + ValueRef scale_out = kDummyValueRef; + ValueRef zero_point_out = kDummyValueRef; + + { + const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); + scale_out = out_tuple->at(0); + zero_point_out = out_tuple->at(1); + } + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept both int32 and float32 for zero_point + // TorchAO may use float32 for zero_point in some cases + VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kFloat); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } + + // Default to per-tensor quantization parameter calculation for TorchAO affine ops + 
add_choose_qparams_tensor_node( + graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); +} + REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); VK_REGISTER_OP( quantized_decomposed.choose_qparams_per_token_asymmetric.default, choose_qparams_per_token_asymmetric_impl); + + // TorchAO affine choose_qparams operators + VK_REGISTER_OP(torchao.choose_qparams_affine.default, choose_qparams_affine_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 1578b515f55..caf7e2e4fb5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -508,6 +508,47 @@ void dequantize_per_channel_impl( graph, input, scale, zero_point, axis, quant_min, quant_max, output); } +void dequantize_affine_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef input_dtype = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Suppress unused variable warnings + (void)block_size; + (void)input_dtype; + (void)output_dtype; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == 
vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Default to per-tensor dequantization for TorchAO affine ops + add_dequantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_tensor.tensor, @@ -518,6 +559,9 @@ REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_channel.default, dequantize_per_channel_impl); + + // TorchAO affine dequantization operators + VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 86df735acbe..f48635f3940 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -351,7 +351,7 @@ void linear(ComputeGraph& graph, const std::vector& args) { ValueRef bias = args.at(2); ValueRef out = args.at(3); ValueRef weight = prepack_standard( - graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked); + graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked, /*passthrough = */ true); ValueRef mat2_is_transposed = graph.add_scalar(true); if (graph.val_is_none(bias)) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 0105a384042..210d95a08ee 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -480,6 +480,38 @@ void quantize_per_channel_impl( graph, input, scale, zero_point, axis, quant_min, quant_max, output); } +void quantize_affine_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef scale = 
args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Suppress unused variable warnings + (void)block_size; + (void)output_dtype; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Default to per-tensor quantization for TorchAO affine ops + add_quantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.quantize_per_tensor.tensor, @@ -489,6 +521,9 @@ REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.quantize_per_channel.default, quantize_per_channel_impl); + + // TorchAO affine quantization operators + VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl); } } // namespace vkcompute