From 132073455e9462bb7d2035b9e5bb418ae82df3a5 Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 11 Jun 2025 09:59:41 -0700 Subject: [PATCH 1/2] [ET-VK][Ops] choose_qparams op shaders and impl Creating the choose_qparams per_tensor and per_token logic shaders and impl which are linked with the testing framework Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/) [ghstack-poisoned] --- .../graph/ops/glsl/choose_qparams.glsl | 564 ++++++++++++++++++ .../graph/ops/glsl/choose_qparams.yaml | 23 + .../runtime/graph/ops/impl/ChooseQParams.cpp | 378 ++++++++++++ .../test/op_tests/choose_qparams_test.cpp | 56 ++ 4 files changed, 1021 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl new file mode 100644 index 00000000000..0b0d4078806 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl @@ -0,0 +1,564 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)} +${layout_declare_tensor(B, "rw", "t_scale", "float", STORAGE)} +${layout_declare_tensor(B, "rw", "t_zero_point", "int", STORAGE)} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_in_strides")} + ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} +$else: + ${layout_declare_ubo(B, "ivec3", "t_in_limits")} + ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} + ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} + +#include "indexing_utils.h" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Constants for reduction +#define NWORKERS 64 + +// Constant for small scale threshold +#define SMALL_SCALE_THRESHOLD 6.1e-5 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +// Calculate scale and zero point from min and max values +void calculate_scale_and_zero_point( + float min_val, + float max_val, + int qmin, + int qmax, + out float scale_val, + out int zero_point_val) { + // Ensure the range includes zero + min_val = min(min_val, 0.0); + max_val = max(max_val, 0.0); + + // Calculate scale + scale_val = (max_val - min_val) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale_val == 0.0 || isinf(1.0 / scale_val)) { + scale_val = 0.1; + } + + // Cut off small scale + if (scale_val < SMALL_SCALE_THRESHOLD) { + float org_scale = scale_val; + scale_val = SMALL_SCALE_THRESHOLD; + + // Adjust min and max based on new scale + if (min_val == 0.0) { + max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else if (max_val == 0.0) { + min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + float zero_point_from_min = float(qmin) - min_val / scale_val; + float zero_point_from_max = float(qmax) - max_val / scale_val; + float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); + float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); + float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to integer + if (initial_zero_point < float(qmin)) { + zero_point_val = qmin; + } else if (initial_zero_point > float(qmax)) { + zero_point_val = qmax; + } else { + zero_point_val = int(round(initial_zero_point)); + } +} + +#ifdef USING_BUFFER + +$if MODE == "per_tensor": + void main() { + // Single-Pass Hierarchical Reduction for per-tensor min/max + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + // Calculate total number of elements in the input tensor + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + + // Phase 1: Each thread processes multiple elements with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + for (uint i = global_id; i < total_elements; i += total_threads) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Phase 2: Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Phase 3: Final result calculation (single workgroup only) + if (local_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + // Calculate final scale and zero_point + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + // Write final results + t_scale[0] = scale_val; + t_zero_point[0] = zero_point_val; + } + } +$else: + void main() { + // Per-token hierarchical reduction implementation with multiple tokens per workgroup + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + // Calculate total number of elements in the input tensor + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + uint token_size = total_elements / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + // This handles the case where we have more tokens than workgroups + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Early exit if this workgroup has no tokens to process + if (start_token >= uint(num_tokens)) { + return; + } + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the start and end indices for this token + uint token_start = token_id * token_size; + uint token_end = token_start + token_size; + + // Phase 1: Each thread processes multiple elements within the token with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process elements within this token only + for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Phase 2: Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Phase 3: Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + // Calculate scale and zero_point for this token + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + // Write results for this token + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } + + // Synchronize before processing next token + barrier(); + } + } + +#else // Texture storage + +$if MODE == "per_tensor": + void main() { + // Multi-workgroup texture-based per-tensor quantization parameter calculation + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + // Calculate total number of texels in the input tensor + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Phase 1: Each thread processes multiple texels with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels with stride across all threads + for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + // Load texel data (4 float values) + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Phase 2: Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Phase 3: Final result calculation (single workgroup only for reliability) + if (local_id == 0 && group_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + // Calculate final scale and zero_point + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + // Write final results to output textures + write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); + } + } +$else: + void main() { + // Texture-based per-token quantization parameter calculation + // Each token is processed by multiple workgroups for parallel reduction + + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + // Calculate total number of texels in the input tensor + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Calculate texels per token (assuming last dimension contains the token data) + // For per-token quantization, we assume tokens are along the last dimension + uint texels_per_token = total_texels / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the texel range for this token + uint token_start_texel = token_id * texels_per_token; + uint token_end_texel = token_start_texel + texels_per_token; + + // Phase 1: Each thread processes multiple texels within the token + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels within this token only + for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + // Load texel data (4 float values) + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Phase 2: Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Phase 3: Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + // Calculate scale and zero_point for this token + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + // Convert token_id to 3D coordinates for output texture + // Assuming output tensors have the same layout as input but with different dimensions + uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); + uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); + uint out_y = out_remainder / uint(t_scale_limits.x); + uint out_x = out_remainder % uint(t_scale_limits.x); + ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); + + // Write results for this token + write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); + } + + // Synchronize before processing next token + barrier(); + } + } + +#endif // USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.yaml new file mode 100644 index 00000000000..e84333e76ec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.yaml @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +choose_qparams: + parameter_names_with_default_values: + IN_DTYPE: float + STORAGE: texture3d + MODE: per_tensor + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + IN_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp new file mode 100644 index 00000000000..e41dd30afc6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -0,0 +1,378 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_choose_qparams_tensor_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr scale_out = graph->get_tensor(args[0].refs[0]); + vTensorPtr zero_point_out = graph->get_tensor(args[0].refs[1]); + + // Both scale and zero_point are scalar tensors for per-tensor quantization + // Since we use single workgroup approach, no extra buffer space needed + scale_out->virtual_resize({}); + zero_point_out->virtual_resize({}); +} + +void resize_choose_qparams_per_token_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr scale_out = graph->get_tensor(args[0].refs[0]); + vTensorPtr zero_point_out = graph->get_tensor(args[0].refs[1]); + vTensorPtr input = graph->get_tensor(args[1].refs[0]); + + // Calculate output sizes for scale and zero_point tensors + std::vector output_sizes; + for (size_t i = 0; i < input->sizes().size() - 1; i++) { + output_sizes.push_back(input->sizes()[i]); + } + output_sizes.push_back(1); + + scale_out->virtual_resize(output_sizes); + zero_point_out->virtual_resize(output_sizes); +} + +utils::uvec3 choose_qparams_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + // For global reduction, we need to process the entire input tensor + const ValueRef input = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(input)) { + const uint32_t local_threads = 64; // From choose_qparams_local_wg_size + + // For per-tensor quantization, use SINGLE WORKGROUP approach to avoid + // complex multi-workgroup synchronization issues that cause race + // conditions. A single workgroup with 64 threads can efficiently process + // large tensors by having each thread process multiple elements with + // stride. + + // Return single workgroup with 64 threads + return {local_threads, 1u, 1u}; + } else { + // For texture storage, use single workgroup approach for reliability + const uint32_t local_threads = 64; // From choose_qparams_local_wg_size + + // Return single workgroup with 64 threads + return {local_threads, 1u, 1u}; + } +} + +utils::uvec3 choose_qparams_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + (void)global_workgroup_size; + + const ValueRef input = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For hierarchical reduction, use 64 threads per work group for better + // efficiency This provides better GPU utilization while still being + // manageable for shared memory + + const uint32_t local_threads = 64; + return {local_threads, 1u, 1u}; + } else { + // For texture storage, use default local workgroup size + return graph->create_local_wg_size(global_workgroup_size); + } +} + +utils::uvec3 choose_qparams_per_token_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For per-token reduction, we need one workgroup per token + // Calculate number of tokens (product of all dimensions except the last + // one) + int64_t num_tokens = 1; + const auto input_sizes = graph->sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + // GPU hardware limits: Most GPUs support max ~65535 workgroups per + // dimension + const uint32_t max_workgroups = 65535; + const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size + + // Clamp number of workgroups to hardware limits + uint32_t clamped_workgroups = + std::min(static_cast(num_tokens), max_workgroups); + + // If we have more tokens than workgroups, each workgroup will process + // multiple tokens + + // Calculate total threads needed + const uint32_t total_threads_x = clamped_workgroups * local_x; + const uint32_t total_threads_y = 1u; + const uint32_t total_threads_z = 1u; + + return {total_threads_x, total_threads_y, total_threads_z}; + } else { + // For texture storage, calculate number of tokens + int64_t num_tokens = 1; + const auto input_sizes = graph->sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + // For texture storage, clamp to reasonable limits for performance + // Large token counts (>1024) can cause very slow execution + const uint32_t max_reasonable_tokens = 1024; + const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size + + uint32_t clamped_workgroups = + std::min(static_cast(num_tokens), max_reasonable_tokens); + + // Calculate total threads needed + const uint32_t total_threads_x = clamped_workgroups * local_x; + const uint32_t total_threads_y = 1u; + const uint32_t total_threads_z = 1u; + + return {total_threads_x, total_threads_y, total_threads_z}; + } +} + +utils::uvec3 choose_qparams_per_token_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + (void)global_workgroup_size; + + const ValueRef input = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For per-token reduction, each workgroup processes one token + // Use 64 threads per work group to match shared memory allocation + const uint32_t local_threads = 64; + + return {local_threads, 1u, 1u}; + } else { + // For texture storage, use default local workgroup size + return graph->create_local_wg_size(global_workgroup_size); + } +} + +} // namespace + +void add_choose_qparams_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_global_wg_size, + choose_qparams_local_wg_size, + // Inputs and Outputs + {{input, vkapi::kRead}, + {scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_tensor_output)); +} + +void add_choose_qparams_per_token_asymmetric_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_per_token_asymmetric"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + int num_tokens_val = static_cast(num_tokens); + int quant_min_val = -128; // Fixed for asymmetric quantization + int quant_max_val = 127; // Fixed for asymmetric quantization + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&num_tokens_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_per_token_global_wg_size, + choose_qparams_per_token_local_wg_size, + // Inputs and Outputs + {{input, vkapi::kRead}, + {scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_per_token_output)); +} + +void choose_qparams_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + add_choose_qparams_tensor_node( + graph, input, quant_min, quant_max, scale_out, zero_point_out); +} + +void choose_qparams_per_token_asymmetric_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + add_choose_qparams_per_token_asymmetric_node( + graph, input, scale_out, zero_point_out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + choose_qparams_per_token_asymmetric.default, + choose_qparams_per_token_asymmetric_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index cdbe20e633d..e8e3fb153fd 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -634,6 +634,38 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { at::kChar); } +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { + test_vulkan_choose_qparams_tensor( + {5, 3, 2, 4}, // input sizes + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { + test_vulkan_choose_qparams_tensor( + {5, 5}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { + test_vulkan_choose_qparams_tensor( + {12, 8, 2}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { + test_vulkan_choose_qparams_tensor( + {10, 10, 6, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + void test_reference_choose_qparams_per_token_asymmetric( const std::vector& input_sizes, at::ScalarType dtype) { @@ -791,3 +823,27 @@ TEST( {2, 3, 4}, // input sizes (2*3=6 tokens) at::kChar); } + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { + test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { + test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { + test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { + test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); +} From 232f5f5e080306b97ef9300b1a233bc81f262283 Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 11 Jun 2025 10:10:19 -0700 Subject: [PATCH 2/2] Update on "[ET-VK][Ops] choose_qparams op shaders and impl" Creating the choose_qparams per_tensor and per_token logic shaders and impl which are linked with the testing framework Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/) [ghstack-poisoned] --- .../graph/ops/glsl/choose_qparams.glsl | 52 +++------ .../runtime/graph/ops/impl/ChooseQParams.cpp | 101 ++++-------------- 2 files changed, 37 insertions(+), 116 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl index 0b0d4078806..09e7db52e57 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glsl @@ -52,10 +52,9 @@ $else: layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// Constants for reduction #define NWORKERS 64 -// Constant for small scale threshold +// equivalent of the eps defined in the cpu implemnetation #define SMALL_SCALE_THRESHOLD 6.1e-5 // Shared memory for reduction - must match local work group size @@ -70,11 +69,10 @@ void calculate_scale_and_zero_point( int qmax, out float scale_val, out int zero_point_val) { - // Ensure the range includes zero + // ensure we have zero included in our range min_val = min(min_val, 0.0); max_val = max(max_val, 0.0); - // Calculate scale scale_val = (max_val - min_val) / float(qmax - qmin); // Handle zero or very small scale @@ -122,16 +120,14 @@ void calculate_scale_and_zero_point( $if MODE == "per_tensor": void main() { - // Single-Pass Hierarchical Reduction for per-tensor min/max uint global_id = gl_GlobalInvocationID.x; uint local_id = gl_LocalInvocationID.x; uint group_id = gl_WorkGroupID.x; uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - // Calculate total number of elements in the input tensor uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - // Phase 1: Each thread processes multiple elements with stride + // Each thread processes multiple elements with stride float thread_min = 1.0/0.0; // +infinity float thread_max = -1.0/0.0; // -infinity bool found_valid = false; @@ -150,7 +146,7 @@ $if MODE == "per_tensor": } } - // Phase 2: Intra-group reduction using shared memory + // Intra-group reduction using shared memory shared_min[local_id] = thread_min; shared_max[local_id] = thread_max; barrier(); @@ -161,7 +157,6 @@ $if MODE == "per_tensor": float other_min = shared_min[local_id + stride]; float other_max = shared_max[local_id + stride]; - // Handle infinity values properly if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { shared_min[local_id] = other_min; } @@ -172,30 +167,26 @@ $if MODE == "per_tensor": barrier(); } - // Phase 3: Final result calculation (single workgroup only) + // Final result calculation (single workgroup only) if (local_id == 0) { float global_min = shared_min[0]; float global_max = shared_max[0]; - // Calculate final scale and zero_point float scale_val; int zero_point_val; calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); - // Write final results t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; } } $else: void main() { - // Per-token hierarchical reduction implementation with multiple tokens per workgroup uint global_id = gl_GlobalInvocationID.x; uint local_id = gl_LocalInvocationID.x; uint group_id = gl_WorkGroupID.x; uint total_workgroups = gl_NumWorkGroups.x; - // Calculate total number of elements in the input tensor uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); uint token_size = total_elements / uint(num_tokens); @@ -218,7 +209,7 @@ $else: uint token_start = token_id * token_size; uint token_end = token_start + token_size; - // Phase 1: Each thread processes multiple elements within the token with stride + // Each thread processes multiple elements within the token with stride float thread_min = 1.0/0.0; // +infinity float thread_max = -1.0/0.0; // -infinity bool found_valid = false; @@ -238,7 +229,7 @@ $else: } } - // Phase 2: Intra-group reduction using shared memory + // Intra-group reduction using shared memory shared_min[local_id] = thread_min; shared_max[local_id] = thread_max; barrier(); @@ -249,7 +240,6 @@ $else: float other_min = shared_min[local_id + stride]; float other_max = shared_max[local_id + stride]; - // Handle infinity values properly if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { shared_min[local_id] = other_min; } @@ -260,17 +250,15 @@ $else: barrier(); } - // Phase 3: Final calculation for this token + // Final calculation for this token if (local_id == 0) { float token_min = shared_min[0]; float token_max = shared_max[0]; - // Calculate scale and zero_point for this token float scale_val; int zero_point_val; calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); - // Write results for this token t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val; } @@ -284,16 +272,14 @@ $else: $if MODE == "per_tensor": void main() { - // Multi-workgroup texture-based per-tensor quantization parameter calculation uint global_id = gl_GlobalInvocationID.x; uint local_id = gl_LocalInvocationID.x; uint group_id = gl_WorkGroupID.x; uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - // Calculate total number of texels in the input tensor uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - // Phase 1: Each thread processes multiple texels with stride + // Each thread processes multiple texels with stride float thread_min = 1.0/0.0; // +infinity float thread_max = -1.0/0.0; // -infinity bool found_valid = false; @@ -307,7 +293,6 @@ $if MODE == "per_tensor": uint x = remainder % uint(t_in_limits.x); ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - // Load texel data (4 float values) FVEC4_T texel_data = load_texel(t_in, texel_pos); // For texture storage, we assume width-packed (packed_dim = 0) @@ -369,7 +354,7 @@ $if MODE == "per_tensor": } } - // Phase 2: Intra-workgroup reduction using shared memory + // Intra-workgroup reduction using shared memory shared_min[local_id] = thread_min; shared_max[local_id] = thread_max; barrier(); @@ -380,7 +365,6 @@ $if MODE == "per_tensor": float other_min = shared_min[local_id + stride]; float other_max = shared_max[local_id + stride]; - // Handle infinity values properly if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { shared_min[local_id] = other_min; } @@ -391,31 +375,26 @@ $if MODE == "per_tensor": barrier(); } - // Phase 3: Final result calculation (single workgroup only for reliability) + // Final result calculation (single workgroup only for reliability) if (local_id == 0 && group_id == 0) { float global_min = shared_min[0]; float global_max = shared_max[0]; - // Calculate final scale and zero_point float scale_val; int zero_point_val; calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); - // Write final results to output textures write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); } } $else: void main() { - // Texture-based per-token quantization parameter calculation // Each token is processed by multiple workgroups for parallel reduction - uint local_id = gl_LocalInvocationID.x; uint group_id = gl_WorkGroupID.x; uint total_workgroups = gl_NumWorkGroups.x; - // Calculate total number of texels in the input tensor uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); // Calculate texels per token (assuming last dimension contains the token data) @@ -435,7 +414,7 @@ $else: uint token_start_texel = token_id * texels_per_token; uint token_end_texel = token_start_texel + texels_per_token; - // Phase 1: Each thread processes multiple texels within the token + // Each thread processes multiple texels within the token float thread_min = 1.0/0.0; // +infinity float thread_max = -1.0/0.0; // -infinity bool found_valid = false; @@ -449,7 +428,6 @@ $else: uint x = remainder % uint(t_in_limits.x); ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - // Load texel data (4 float values) FVEC4_T texel_data = load_texel(t_in, texel_pos); // For texture storage, we assume width-packed (packed_dim = 0) @@ -511,7 +489,7 @@ $else: } } - // Phase 2: Intra-workgroup reduction using shared memory + // Intra-workgroup reduction using shared memory shared_min[local_id] = thread_min; shared_max[local_id] = thread_max; barrier(); @@ -533,12 +511,11 @@ $else: barrier(); } - // Phase 3: Final calculation for this token + // Final calculation for this token if (local_id == 0) { float token_min = shared_min[0]; float token_max = shared_max[0]; - // Calculate scale and zero_point for this token float scale_val; int zero_point_val; calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); @@ -551,7 +528,6 @@ $else: uint out_x = out_remainder % uint(t_scale_limits.x); ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); - // Write results for this token write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index e41dd30afc6..316efccf753 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -54,27 +54,9 @@ utils::uvec3 choose_qparams_global_wg_size( (void)shader; (void)resize_args; - // For global reduction, we need to process the entire input tensor - const ValueRef input = args.at(0).refs.at(0); - - if (graph->is_buffer_storage(input)) { - const uint32_t local_threads = 64; // From choose_qparams_local_wg_size - - // For per-tensor quantization, use SINGLE WORKGROUP approach to avoid - // complex multi-workgroup synchronization issues that cause race - // conditions. A single workgroup with 64 threads can efficiently process - // large tensors by having each thread process multiple elements with - // stride. - - // Return single workgroup with 64 threads - return {local_threads, 1u, 1u}; - } else { - // For texture storage, use single workgroup approach for reliability - const uint32_t local_threads = 64; // From choose_qparams_local_wg_size + const uint32_t local_threads = 64; // From choose_qparams_local_wg_size - // Return single workgroup with 64 threads - return {local_threads, 1u, 1u}; - } + return {local_threads, 1u, 1u}; } utils::uvec3 choose_qparams_local_wg_size( @@ -90,14 +72,10 @@ utils::uvec3 choose_qparams_local_wg_size( const ValueRef input = args.at(0).refs.at(0); if (graph->is_buffer_storage(input)) { - // For hierarchical reduction, use 64 threads per work group for better - // efficiency This provides better GPU utilization while still being - // manageable for shared memory - const uint32_t local_threads = 64; + return {local_threads, 1u, 1u}; } else { - // For texture storage, use default local workgroup size return graph->create_local_wg_size(global_workgroup_size); } } @@ -112,57 +90,27 @@ utils::uvec3 choose_qparams_per_token_global_wg_size( const ValueRef input = args.at(0).refs.at(0); - if (graph->is_buffer_storage(input)) { - // For per-token reduction, we need one workgroup per token - // Calculate number of tokens (product of all dimensions except the last - // one) - int64_t num_tokens = 1; - const auto input_sizes = graph->sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - // GPU hardware limits: Most GPUs support max ~65535 workgroups per - // dimension - const uint32_t max_workgroups = 65535; - const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size - - // Clamp number of workgroups to hardware limits - uint32_t clamped_workgroups = - std::min(static_cast(num_tokens), max_workgroups); - - // If we have more tokens than workgroups, each workgroup will process - // multiple tokens - - // Calculate total threads needed - const uint32_t total_threads_x = clamped_workgroups * local_x; - const uint32_t total_threads_y = 1u; - const uint32_t total_threads_z = 1u; - - return {total_threads_x, total_threads_y, total_threads_z}; - } else { - // For texture storage, calculate number of tokens - int64_t num_tokens = 1; - const auto input_sizes = graph->sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - // For texture storage, clamp to reasonable limits for performance - // Large token counts (>1024) can cause very slow execution - const uint32_t max_reasonable_tokens = 1024; - const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size - - uint32_t clamped_workgroups = - std::min(static_cast(num_tokens), max_reasonable_tokens); - - // Calculate total threads needed - const uint32_t total_threads_x = clamped_workgroups * local_x; - const uint32_t total_threads_y = 1u; - const uint32_t total_threads_z = 1u; - - return {total_threads_x, total_threads_y, total_threads_z}; + // For per-token reduction, we need one workgroup per token + // Calculate number of tokens (product of all dimensions except the last + // one) + int64_t num_tokens = 1; + const auto input_sizes = graph->sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; } + + const uint32_t max_workgroups = 65535; + const uint32_t local_x = 64u; // From choose_qparams_per_token_local_wg_size + + // Clamp number of workgroups to avoid being slow + uint32_t clamped_workgroups = + std::min(static_cast(num_tokens), max_workgroups); + + // If we have more tokens than workgroups, each workgroup will process + // multiple tokens + const uint32_t total_threads_x = clamped_workgroups * local_x; + + return {total_threads_x, 1u, 1u}; } utils::uvec3 choose_qparams_per_token_local_wg_size( @@ -178,13 +126,10 @@ utils::uvec3 choose_qparams_per_token_local_wg_size( const ValueRef input = args.at(0).refs.at(0); if (graph->is_buffer_storage(input)) { - // For per-token reduction, each workgroup processes one token - // Use 64 threads per work group to match shared memory allocation const uint32_t local_threads = 64; return {local_threads, 1u, 1u}; } else { - // For texture storage, use default local workgroup size return graph->create_local_wg_size(global_workgroup_size); } }