diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 94378e885e5..a59b150e7ae 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -959,3 +959,32 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
 lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
 lib.impl(name, select_as_symint_impl, "Meta")
 select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
+
+################
+## rms_norm ##
+################
+
+
+def rms_norm_impl(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    """Reference RMSNorm over the last dimension of ``x``.
+
+    Computes ``x * rsqrt(mean(x^2, dim=-1, keepdim=True) + eps) * weight``.
+    The arithmetic is done in fp32 and the result is cast back to the dtype
+    of ``x``.
+    """
+    # Upcast once and reuse it for both the statistic and the scaling.
+    x_fp32 = x.float()
+    mean_sq = x_fp32.pow(2).mean(-1, keepdim=True)
+    inv_rms = torch.rsqrt(mean_sq + eps)
+    scaled = x_fp32 * inv_rms * weight.float()
+    return scaled.to(x.dtype)
+
+
+name = "rms_norm"
+lib.define(f"{name}(Tensor x, Tensor weight, float eps) -> Tensor")
+lib.impl(name, rms_norm_impl, "CompositeExplicitAutograd")
+rms_norm_op = getattr(getattr(torch.ops, namespace), name)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index f71fc2b03ee..ff056d76c3a 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -1606,6 +1606,23 @@ def register_native_layer_norm():
     )
 
 
+# =============================================================================
+# RmsNorm.cpp
+# =============================================================================
+
+
+@update_features(exir_ops.edge.et_vk.rms_norm.default)
+def register_rms_norm():
+    """Return the ``OpFeatures`` attached to ``et_vk.rms_norm`` by ``update_features``."""
+    rms_norm_features = OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        inputs_dtypes=utils.FP_T,
+        supports_prepacking=True,
+        supports_resize=True,
+    )
+    return rms_norm_features
+
+
 #######################
 ## Utility functions ##
 #######################
diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK
index 2e8f201f17f..7fa132fd5cb 100644
--- a/backends/vulkan/patterns/BUCK
+++ b/backends/vulkan/patterns/BUCK
@@ -16,6 +16,7 @@ fbcode_target(_kind = runtime.python_library,
         "quantized_convolution.py",
         "quantized_binary.py",
         "quantized_unary.py",
+        "rms_norm.py",
         "sdpa.py",
         "select_as_symint.py",
     ],
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
index ae29a817c9f..86fb82a03d2 100644
--- a/backends/vulkan/patterns/__init__.py
+++ b/backends/vulkan/patterns/__init__.py
@@ -16,6 +16,8 @@
 
 import executorch.backends.vulkan.patterns.quantized_unary  # noqa
 
+import executorch.backends.vulkan.patterns.rms_norm  # noqa
+
 import executorch.backends.vulkan.patterns.rope  # noqa
 
 import executorch.backends.vulkan.patterns.rope_hf  # noqa
diff --git a/backends/vulkan/patterns/rms_norm.py b/backends/vulkan/patterns/rms_norm.py
new file mode 100644
index 00000000000..beb5e677ead
--- /dev/null
+++ b/backends/vulkan/patterns/rms_norm.py
@@ -0,0 +1,332 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+_CAST_OPS = {
+    exir_ops.edge.aten._to_copy.default,
+    exir_ops.edge.aten.to.dtype,
+}
+
+
+def _skip_casts(node: torch.fx.Node) -> torch.fx.Node:
+    """Unwrap chains of dtype-cast nodes to find the underlying value."""
+    while node.target in _CAST_OPS:
+        arg0 = node.args[0] if node.args else None
+        if not isinstance(arg0, torch.fx.Node):
+            break
+        node = arg0
+    # pyre-ignore[7]: node is always a Node; Pyre cannot narrow through loops
+    return node
+
+
+def _meta_tensor(node) -> Optional[torch.Tensor]:
+    """Return the tensor stored in ``node.meta["val"]``, or None if absent."""
+    if not isinstance(node, torch.fx.Node):
+        return None
+    val = node.meta.get("val")
+    return val if isinstance(val, torch.Tensor) else None
+
+
+class RmsNormMatch(PatternMatch):
+    """
+    Detects the decomposed RMSNorm pattern, including variants where dtype
+    casts (to_copy) are inserted around the computation.
+
+    The canonical pattern emitted by the Llama RMSNorm implementation is:
+
+        x_orig (any dtype)
+          -> to_copy(fp32) -> x_f32
+          -> mul(x_f32, x_f32) -> mean(dim=-1, keepdim=True)
+          -> add(eps) -> rsqrt -> rstd_f32
+          -> mul(x_f32, rstd_f32) -> norm_f32
+          -> to_copy(orig dtype) -> norm_cast
+        weight -> to_copy(orig dtype) -> weight_cast
+          -> mul(norm_cast, weight_cast)  ← anchor node
+
+    Cast nodes on the normalization path are looked through when comparing
+    tensor identities, so the match also handles fp32-only models where no
+    casts are present. The weight operand is taken exactly as it appears on
+    the anchor mul; casts in front of it are not unwrapped.
+
+    The anchor node is the final mul (scale by weight).
+
+    After a successful match (``match_found`` is True):
+      * ``input_node`` is the tensor being normalized, before any upcast;
+      * ``weight_node`` is the scale operand of the anchor mul;
+      * ``eps_node`` is the non-mean operand of the add feeding rsqrt;
+      * ``all_nodes`` lists the nodes recognized as part of the pattern.
+    """
+
+    def __init__(self, final_mul_node: torch.fx.Node) -> None:  # noqa: C901
+        self.anchor_node = final_mul_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+        # Define every result attribute up front so a failed match still
+        # exposes a consistent set of attributes.
+        self.weight_node = None
+        self.eps_node = None
+        self.input_node = None
+
+        # final_mul: mul(normalized_cast, weight_cast)
+        # Unwrap casts on the normalized operand to reach the norm mul.
+        norm_mul_node, self.weight_node = self._identify_norm_mul_and_weight(
+            final_mul_node
+        )
+        if norm_mul_node is None:
+            return
+        # The fused op takes the weight as a Tensor argument, so a literal
+        # scalar multiplier cannot be fused.
+        if not isinstance(self.weight_node, torch.fx.Node):
+            return
+
+        self.all_nodes.append(norm_mul_node)
+
+        # norm_mul: mul(x_f32, rstd_f32)
+        rsqrt_node, x_for_norm = self._identify_rsqrt_and_input(norm_mul_node)
+        if rsqrt_node is None or not isinstance(x_for_norm, torch.fx.Node):
+            return
+
+        self.all_nodes.append(rsqrt_node)
+
+        # rsqrt -> add(mean_sq, eps) -> mean(x_sq, dim=-1, keepdim=True)
+        add_node = self._get_single_arg_node(
+            rsqrt_node, exir_ops.edge.aten.rsqrt.default
+        )
+        if add_node is None or add_node.target != exir_ops.edge.aten.add.Tensor:
+            return
+
+        self.all_nodes.append(add_node)
+
+        mean_node = None
+        for arg in add_node.args[:2]:
+            if (
+                isinstance(arg, torch.fx.Node)
+                and arg.target == exir_ops.edge.aten.mean.dim
+            ):
+                mean_node = arg
+            else:
+                self.eps_node = arg
+
+        if mean_node is None or self.eps_node is None:
+            return
+
+        # Epsilon must be convertible to a Python float now; otherwise the
+        # replacement step would raise instead of just skipping this candidate.
+        try:
+            _extract_eps_value(self.eps_node)
+        except (ValueError, RuntimeError):
+            return
+
+        self.all_nodes.append(mean_node)
+
+        if not self._is_last_dim_keepdim_mean(mean_node):
+            return
+
+        # mean's input should be x_sq = mul(x, x) or pow(x, 2)
+        sq_node = mean_node.args[0]
+        if not isinstance(sq_node, torch.fx.Node):
+            return
+
+        self.all_nodes.append(sq_node)
+
+        # The original tensor before any fp32 upcast. Both the squaring op and
+        # the norm mul must trace back to it, and it is what the fused op
+        # consumes. Without casts this is x_for_norm itself.
+        x_source = _skip_casts(x_for_norm)
+        if not self._is_square_of(sq_node, x_source):
+            return
+        self.input_node = x_source
+
+        # The fused kernels use a single dtype for input and output, so a
+        # subgraph that ends in a different dtype than it started with must
+        # not be fused.
+        if not self._io_dtypes_agree():
+            return
+
+        # Also collect the intermediate cast nodes so they can be cleaned up
+        cast_node = x_for_norm
+        while (
+            isinstance(cast_node, torch.fx.Node)
+            and cast_node.target in _CAST_OPS
+            and cast_node not in self.all_nodes
+        ):
+            self.all_nodes.append(cast_node)
+            cast_node = cast_node.args[0] if cast_node.args else cast_node
+
+        self.match_found = True
+
+    def _is_last_dim_keepdim_mean(self, mean_node) -> bool:
+        """Check that mean_node is ``mean(x, dim=<last dim>, keepdim=True)``."""
+        if len(mean_node.args) < 3 or not mean_node.args[2]:
+            return False
+        dims = mean_node.args[1]
+        if isinstance(dims, int):
+            dims = [dims]
+        if not isinstance(dims, (list, tuple)) or len(dims) != 1:
+            return False
+        if dims[0] == -1:
+            return True
+        # A non-negative index names the last dim only if it equals rank - 1,
+        # which can be confirmed only when the input's metadata is recorded.
+        x_val = _meta_tensor(mean_node.args[0])
+        return x_val is not None and dims[0] == x_val.dim() - 1
+
+    def _is_square_of(self, sq_node, x_source) -> bool:
+        """Check that sq_node squares x_source: mul(x, x) or pow(x, 2), where
+        x is x_source possibly seen through casts."""
+        args = sq_node.args
+        if len(args) < 2:
+            return False
+        if sq_node.target == exir_ops.edge.aten.mul.Tensor:
+            if args[0] != args[1]:
+                return False
+        elif sq_node.target == exir_ops.edge.aten.pow.Tensor_Scalar:
+            # 2 == 2.0 in Python, so one comparison covers int and float.
+            if args[1] != 2:
+                return False
+        else:
+            return False
+        base = args[0]
+        return isinstance(base, torch.fx.Node) and _skip_casts(base) == x_source
+
+    def _io_dtypes_agree(self) -> bool:
+        """Check that the fused input and the anchor output share a dtype.
+
+        Returns True when either dtype is unknown (no tensor metadata), so the
+        check only ever rejects a confirmed mismatch.
+        """
+        in_val = _meta_tensor(self.input_node)
+        out_val = _meta_tensor(self.anchor_node)
+        if in_val is None or out_val is None:
+            return True
+        return in_val.dtype == out_val.dtype
+
+    def _identify_norm_mul_and_weight(self, final_mul_node):
+        """From mul(norm_cast, weight_cast), unwrap casts on the norm operand
+        and return (norm_mul_node, weight operand as given), or (None, None)."""
+        if len(final_mul_node.args) < 2:
+            return None, None
+
+        a, b = final_mul_node.args[0], final_mul_node.args[1]
+
+        for norm_candidate_raw, weight_candidate_raw in [(a, b), (b, a)]:
+            if not isinstance(norm_candidate_raw, torch.fx.Node):
+                continue
+            norm_candidate = _skip_casts(norm_candidate_raw)
+            if (
+                norm_candidate.target == exir_ops.edge.aten.mul.Tensor
+                and self._has_rsqrt_ancestor(norm_candidate)
+            ):
+                return norm_candidate, weight_candidate_raw
+
+        return None, None
+
+    def _has_rsqrt_ancestor(self, mul_node):
+        """Check if one of mul_node's args is an rsqrt node (possibly through casts)."""
+        for arg in mul_node.args[:2]:
+            if not isinstance(arg, torch.fx.Node):
+                continue
+            if _skip_casts(arg).target == exir_ops.edge.aten.rsqrt.default:
+                return True
+        return False
+
+    def _identify_rsqrt_and_input(self, norm_mul_node):
+        """From mul(x, rstd), find the rsqrt node and the input x.
+        The rsqrt may be wrapped in a cast node."""
+        if len(norm_mul_node.args) < 2:
+            return None, None
+
+        a, b = norm_mul_node.args[0], norm_mul_node.args[1]
+
+        for rsqrt_candidate_raw, input_candidate in [(a, b), (b, a)]:
+            if not isinstance(rsqrt_candidate_raw, torch.fx.Node):
+                continue
+            rsqrt_candidate = _skip_casts(rsqrt_candidate_raw)
+            if rsqrt_candidate.target == exir_ops.edge.aten.rsqrt.default:
+                return rsqrt_candidate, input_candidate
+
+        return None, None
+
+    def _get_single_arg_node(self, node, expected_target):
+        """Get the single input arg of a unary op node."""
+        if node.target != expected_target:
+            return None
+        if len(node.args) < 1 or not isinstance(node.args[0], torch.fx.Node):
+            return None
+        return node.args[0]
+
+
+@register_pattern_detector("rms_norm")
+def find_rms_norm_patterns(
+    node: torch.fx.Node,
+) -> Optional[RmsNormMatch]:
+    """Return an RmsNormMatch anchored at node if it completes the pattern."""
+    if node.target != exir_ops.edge.aten.mul.Tensor:
+        return None
+
+    matched_pattern = RmsNormMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+def _extract_eps_value(eps_node) -> float:
+    """Convert the matched epsilon (a Python number, or a node whose
+    ``meta["val"]`` holds a number or tensor) to float; raise ValueError
+    when it is neither."""
+    if isinstance(eps_node, (int, float)):
+        return float(eps_node)
+    if isinstance(eps_node, torch.fx.Node) and "val" in eps_node.meta:
+        val = eps_node.meta["val"]
+        if isinstance(val, torch.Tensor):
+            return float(val.item())
+        if isinstance(val, (int, float)):
+            return float(val)
+    raise ValueError(f"Cannot extract epsilon value from {eps_node}")
+
+
+@register_pattern_replacement("rms_norm")
+def replace_rms_norm_with_fused_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: RmsNormMatch,
+):
+    """Insert a fused ``et_vk.rms_norm`` node before the anchor mul and
+    redirect all uses of the anchor to it."""
+    eps_val = _extract_eps_value(match.eps_node)
+
+    with graph_module.graph.inserting_before(match.anchor_node):
+        rms_norm_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.rms_norm.default,
+            args=(
+                match.input_node,
+                match.weight_node,
+                eps_val,
+            ),
+        )
+
+    rms_norm_node.meta["val"] = match.anchor_node.meta["val"]
+    match.anchor_node.replace_all_uses_with(rms_norm_node)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.glsl
new file mode 100644
index 00000000000..a869cc1188f
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.glsl
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+${define_required_extensions("buffer", DTYPE)}
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type("buffer")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(push_constant) uniform PRECISION restrict Block {
+  float epsilon;
+};
+
+// Invocations cooperating on one row. Must match the x extent of the local
+// workgroup size selected on the host side.
+#define NUM_WORKERS 64
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
+
+shared float shared_sum[NUM_WORKERS];
+
+// Tree-reduce shared_sum in place. Every invocation of the workgroup must call
+// this; on return shared_sum[0] holds the total and is visible to all of them.
+void reduce_shared(const uint worker_id) {
+  memoryBarrierShared();
+  barrier();
+
+  [[unroll]] for (int stride = NUM_WORKERS / 2; stride > 0; stride >>= 1) {
+    if (worker_id < stride) {
+      shared_sum[worker_id] += shared_sum[worker_id + stride];
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+}
+
+void main() {
+  const uint row_idx = gl_GlobalInvocationID.y;
+  const uint worker_id = gl_LocalInvocationID.x;
+
+  const uint row_width = width(inp);
+  const uint num_rows = numel(inp) / row_width;
+
+  // With a {NUM_WORKERS, 1, 1} local size row_idx is uniform per workgroup, so
+  // all invocations take this exit together and the barriers stay uniform.
+  if (row_idx >= num_rows) {
+    return;
+  }
+
+  // Decompose row_idx into H, C, N indices using inp sizes (WHCN order)
+  uint remaining = row_idx;
+  const uint H = uint(inp.sizes[0][1]);
+  const uint C = uint(inp.sizes[0][2]);
+
+  const uint h = remaining % H;
+  remaining /= H;
+  const uint c = remaining % C;
+  const uint n = remaining / C;
+
+  // Build tensor index with w=0 to get the base buffer index of this row in
+  // both the input and the output, each according to its own metadata.
+  TensorIndex tidx;
+  tidx.data[0] = uvec4(0, h, c, n);
+  tidx.data[1] = uvec4(0);
+  const uint in_base_bufi = tensor_idx_to_linear_idx(inp, tidx);
+  const uint in_width_stride = stride_at(inp, 0);
+  const uint out_base_bufi = tensor_idx_to_linear_idx(outp, tidx);
+  const uint out_width_stride = stride_at(outp, 0);
+
+  // Phase 1: Compute mean(x^2) via cooperative reduction in fp32
+  float local_sq_sum = 0.0;
+  for (uint x = worker_id; x < row_width; x += NUM_WORKERS) {
+    const uint in_bufi = in_base_bufi + x * in_width_stride;
+    const float val = float(t_in[in_bufi]);
+    local_sq_sum += val * val;
+  }
+
+  shared_sum[worker_id] = local_sq_sum;
+  reduce_shared(worker_id);
+
+  const float mean_sq = shared_sum[0] / float(row_width);
+  const float rstd = inversesqrt(mean_sq + epsilon);
+
+  // Phase 2: Normalize and write output
+  for (uint x = worker_id; x < row_width; x += NUM_WORKERS) {
+    const uint in_bufi = in_base_bufi + x * in_width_stride;
+    const uint out_bufi = out_base_bufi + x * out_width_stride;
+    const float val = float(t_in[in_bufi]);
+    const float w = float(t_weight[x]);
+    t_out[out_bufi] = T(val * rstd * w);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.yaml
new file mode 100644
index 00000000000..56358f406fd
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.yaml
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+rms_norm_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: rms_norm_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.glsl
new file mode 100644
index 00000000000..838736919e2
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.glsl
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+${define_required_extensions("texture3d", DTYPE)}
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
+#define T ${texel_load_component_type(DTYPE, "texture3d")}
+
+${define_active_storage_type("texture3d")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
+
+${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}
+${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture3d")}
+
+${layout_declare_ubo(B, "TextureMetadata", "out_meta")}
+${layout_declare_ubo(B, "TextureMetadata", "in_meta")}
+
+layout(push_constant) uniform PRECISION restrict Block {
+  float epsilon;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
+
+// This shader assumes width-packed layout.
+// Dispatch: global = {1, num_rows, 1}, local = {NUM_WORKERS, 1, 1}
+// Each workgroup processes one row; workers cooperatively reduce across texels.
+
+#define NUM_WORKERS 64
+
+shared float shared_sum[NUM_WORKERS];
+
+// Tree-reduce shared_sum in place. Every invocation of the workgroup must call
+// this; on return shared_sum[0] holds the total and is visible to all of them.
+void reduce_shared(const uint worker_id) {
+  memoryBarrierShared();
+  barrier();
+
+  [[unroll]] for (int stride = NUM_WORKERS / 2; stride > 0; stride >>= 1) {
+    if (worker_id < stride) {
+      shared_sum[worker_id] += shared_sum[worker_id + stride];
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+}
+
+void main() {
+  const uint row_idx = gl_GlobalInvocationID.y;
+  const uint worker_id = gl_LocalInvocationID.x;
+
+  const int width = in_meta.sizes.x;
+  const int num_texels = div_up_4(width);
+  const int remain = width & 3;
+
+  // Decompose row_idx into (y, z) texture coordinates. When the tensor has more
+  // than one Z slice (e.g. 4D tensors), row_idx encodes both Y and Z.
+  const int tex_H = in_meta.limits.y;
+  const int y = int(row_idx) % tex_H;
+  const int z = int(row_idx) / tex_H;
+
+  // First pass: compute mean(x^2) via cooperative reduction in fp32
+  float local_sq_sum = 0.0;
+  for (int tx = int(worker_id); tx < num_texels; tx += NUM_WORKERS) {
+    ivec3 pos = ivec3(tx, y, z);
+
+    VEC4_T in_val = texelFetch(t_in, pos, 0);
+
+    // Zero the lanes of the last texel that lie past the end of the row so
+    // that padding does not contribute to the sum of squares.
+    if (tx == num_texels - 1 && remain != 0) {
+      const int remain_inv = 4 - remain;
+      in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+      in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+      in_val.w = mix(in_val.w, T(0), remain_inv > 0);
+    }
+    vec4 v = vec4(in_val);
+    local_sq_sum += dot(v, v);
+  }
+
+  shared_sum[worker_id] = local_sq_sum;
+  reduce_shared(worker_id);
+
+  // reduce_shared() ends with a barrier and shared memory is not written
+  // again, so shared_sum[0] can be read here without further synchronization.
+  const float mean_sq = shared_sum[0] / float(width);
+  const float rstd = inversesqrt(mean_sq + epsilon);
+
+  // Second pass: normalize and write output
+  for (int tx = int(worker_id); tx < num_texels; tx += NUM_WORKERS) {
+    ivec3 pos = ivec3(tx, y, z);
+
+    const VEC4_T in_val = texelFetch(t_in, pos, 0);
+    const VEC4_T weight = texelFetch(t_weight, ivec3(tx, 0, 0), 0);
+    const VEC4_T outtex =
+        VEC4_T(vec4(in_val) * rstd * vec4(weight));
+    imageStore(t_out, pos, outtex);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.yaml new file mode 100644 index 00000000000..3cd2c2d9436 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rms_norm_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +rms_norm_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: rms_norm_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/RmsNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/RmsNorm.cpp new file mode 100644 index 00000000000..1eb267f9794 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/RmsNorm.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +#include + +namespace vkcompute { + +void resize_rms_norm_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +utils::uvec3 rms_norm_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef in = args.at(1).refs.at(0); + const auto& sizes = graph->sizes_of(in); + const int64_t hidden = sizes.back(); + const int64_t numel = graph->numel_of(in); + const uint32_t num_rows = utils::safe_downcast(numel / hidden); + return {1u, num_rows, 1u}; +} + +utils::uvec3 rms_norm_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + return {64u, 1u, 1u}; +} + +void add_rms_norm_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef eps, + const ValueRef out) { + ValueRef arg_weight = prepack_standard_like(graph, weight_data, in); + + float epsilon = graph.extract_scalar(eps); + + const bool is_buffer = graph.is_buffer_storage(in); + + std::string kernel_name("rms_norm"); + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + if (!is_buffer) { + VK_CHECK_COND(check_same_packed_dim(graph, in, out)); + VK_CHECK_COND( + graph.packed_dim_of(in) == WHCN::kWidthDim, + "RmsNorm texture path requires width-packed input"); + } + + vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out), graph.meta_ubo(in)}; + 
vkapi::SpecVarList spec_constants = {graph.hashed_layout_of(in)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + rms_norm_global_wg_size, + rms_norm_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {PushConstantDataInfo(&epsilon, sizeof(epsilon))}, + // Specialization Constants + spec_constants, + // Resize Args + {}, + // Resizing Logic + resize_rms_norm_node)); +} + +void rms_norm(ComputeGraph& graph, const std::vector& args) { + // et_vk.rms_norm(input, weight, epsilon) -> output + return add_rms_norm_node(graph, args[0], args[1], args[2], args[3]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.rms_norm.default, rms_norm); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/rms_norm_test.cpp b/backends/vulkan/test/op_tests/rms_norm_test.cpp new file mode 100644 index 00000000000..057eda4445f --- /dev/null +++ b/backends/vulkan/test/op_tests/rms_norm_test.cpp @@ -0,0 +1,301 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +#include +#include + +#include +#include + +using executorch::aten::Half; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::TensorFactory; + +// +// Helpers +// + +std::vector rand_floats(size_t n, unsigned seed = 42) { + std::mt19937 gen(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + std::vector data(n); + std::generate(data.begin(), data.end(), [&]() { return dist(gen); }); + return data; +} + +size_t numel(const std::vector& sizes) { + size_t n = 1; + for (auto s : sizes) { + n *= static_cast(s); + } + return n; +} + +std::vector to_int32(const std::vector& v) { + return std::vector(v.begin(), v.end()); +} + +// +// Reference Implementation (pure C++) +// + +std::vector rms_norm_ref( + const std::vector& x, + const std::vector& weight, + const std::vector& shape, + float eps) { + const size_t hidden = static_cast(shape.back()); + const size_t num_rows = x.size() / hidden; + std::vector out(x.size()); + + for (size_t r = 0; r < num_rows; ++r) { + const size_t off = r * hidden; + float sq_sum = 0.0f; + for (size_t i = 0; i < hidden; ++i) { + sq_sum += x[off + i] * x[off + i]; + } + float rstd = 1.0f / std::sqrt(sq_sum / static_cast(hidden) + eps); + for (size_t i = 0; i < hidden; ++i) { + out[off + i] = x[off + i] * rstd * weight[i]; + } + } + return out; +} + +// +// Test function +// + +void test_rms_norm( + const std::vector& input_shape, + const float eps = 1e-5f, + const vkcompute::vkapi::ScalarType dtype = vkcompute::vkapi::kFloat, + const vkcompute::utils::StorageType storage_type = + vkcompute::utils::kTexture3D) { + const int64_t hidden_size = input_shape.back(); + const size_t input_numel = numel(input_shape); + const size_t weight_numel = static_cast(hidden_size); + + std::vector x_data = rand_floats(input_numel, 42); + std::vector w_data = rand_floats(weight_numel, 123); + + // For fp16: round-trip through Half so the reference uses 
the same precision + // as the GPU input. + std::vector x_half, w_half; + if (dtype == vkcompute::vkapi::kHalf) { + x_half.resize(input_numel); + w_half.resize(weight_numel); + for (size_t i = 0; i < input_numel; ++i) { + x_half[i] = static_cast(x_data[i]); + x_data[i] = static_cast(x_half[i]); + } + for (size_t i = 0; i < weight_numel; ++i) { + w_half[i] = static_cast(w_data[i]); + w_data[i] = static_cast(w_half[i]); + } + } + + std::vector ref_data = rms_norm_ref(x_data, w_data, input_shape, eps); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(storage_type); + ComputeGraph graph(config); + + IOValueRef r_x = graph.add_input_tensor(input_shape, dtype); + + ValueRef r_weight = (dtype == vkapi::kHalf) + ? graph.add_tensorref({hidden_size}, vkapi::kHalf, w_half.data()) + : graph.add_tensorref({hidden_size}, vkapi::kFloat, w_data.data()); + + const ValueRef r_eps = graph.add_scalar(static_cast(eps)); + const ValueRef r_out = graph.add_tensor(input_shape, dtype); + + VK_GET_OP_FN("et_vk.rms_norm.default") + (graph, {r_x.value, r_weight, r_eps, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.propagate_resize(); + + if (dtype == vkapi::kHalf) { + graph.maybe_cast_and_copy_into_staging( + r_x.staging, x_half.data(), input_numel, vkapi::kHalf); + } else { + graph.maybe_cast_and_copy_into_staging( + r_x.staging, x_data.data(), input_numel, vkapi::kFloat); + } + + graph.execute(); + + // Read output — fp16 staging returns Half, convert to float for comparison. 
+ std::vector vk_data(input_numel); + if (dtype == vkapi::kHalf) { + std::vector vk_half(input_numel); + graph.maybe_cast_and_copy_from_staging( + staging_out, vk_half.data(), input_numel, vkapi::kHalf); + for (size_t i = 0; i < input_numel; ++i) { + vk_data[i] = static_cast(vk_half[i]); + } + } else { + graph.maybe_cast_and_copy_from_staging( + staging_out, vk_data.data(), input_numel, vkapi::kFloat); + } + + TensorFactory tf; + Tensor ref_tensor = tf.make(to_int32(input_shape), ref_data); + Tensor vk_tensor = tf.make(to_int32(input_shape), vk_data); + + const double rtol = (dtype == vkapi::kHalf) ? 1e-2 : 1e-3; + const double atol = (dtype == vkapi::kHalf) ? 1e-2 : 1e-3; + EXPECT_TENSOR_CLOSE_WITH_TOL(ref_tensor, vk_tensor, rtol, atol); +} + +// +// Texture storage tests +// + +TEST(VulkanRmsNormTest, basic_small_texture) { + test_rms_norm({1, 1, 1, 64}); +} + +TEST(VulkanRmsNormTest, llm_hidden_size_texture) { + test_rms_norm({1, 1, 1, 896}, 1e-6f); +} + +TEST(VulkanRmsNormTest, fp16_texture) { + test_rms_norm({1, 1, 1, 896}, 1e-6f, vkcompute::vkapi::kHalf); +} + +TEST(VulkanRmsNormTest, multi_row_texture) { + test_rms_norm({1, 1, 7, 896}); +} + +TEST(VulkanRmsNormTest, multi_z_slice_texture) { + // C=7 maps to multiple texture Z slices, exercising the y/z decomposition + test_rms_norm({1, 7, 1, 896}); +} + +TEST(VulkanRmsNormTest, 4d_multi_z_slice_texture) { + // 4D shape similar to model's QK norm with multiple Z slices + test_rms_norm({1, 5, 4, 128}); +} + +// +// Buffer storage tests +// + +TEST(VulkanRmsNormTest, basic_small_buffer) { + test_rms_norm( + {1, 1, 1, 64}, + 1e-5f, + vkcompute::vkapi::kFloat, + vkcompute::utils::kBuffer); +} + +TEST(VulkanRmsNormTest, fp16_buffer) { + test_rms_norm( + {1, 1, 1, 896}, + 1e-6f, + vkcompute::vkapi::kHalf, + vkcompute::utils::kBuffer); +} + +// +// Dynamic resize test +// + +TEST(VulkanRmsNormTest, dynamic_resize_texture) { + const int64_t hidden_size = 896; + const float eps = 1e-6f; + const int prefill_seq_len = 
7; + + std::vector w_data = rand_floats(static_cast(hidden_size), 99); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + IOValueRef r_x = graph.add_input_tensor( + {1, 1, prefill_seq_len, hidden_size}, vkapi::kFloat); + ValueRef r_weight = + graph.add_tensorref({hidden_size}, vkapi::kFloat, w_data.data()); + + const ValueRef r_eps = graph.add_scalar(static_cast(eps)); + const ValueRef r_out = + graph.add_tensor({1, 1, prefill_seq_len, hidden_size}, vkapi::kFloat); + + VK_GET_OP_FN("et_vk.rms_norm.default") + (graph, {r_x.value, r_weight, r_eps, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + + TensorFactory tf; + + // --- Prefill run (seq_len = 7) --- + { + std::vector shape = {1, 1, prefill_seq_len, hidden_size}; + size_t n = numel(shape); + std::vector x_data = rand_floats(n, 200); + std::vector ref_data = rms_norm_ref(x_data, w_data, shape, eps); + + graph.resize_input(0, shape); + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_x.staging, x_data.data(), n, vkapi::kFloat); + + graph.execute(); + + std::vector vk_data(n); + graph.maybe_cast_and_copy_from_staging( + staging_out, vk_data.data(), n, vkapi::kFloat); + + Tensor ref_t = tf.make(to_int32(shape), ref_data); + Tensor vk_t = tf.make(to_int32(shape), vk_data); + EXPECT_TENSOR_CLOSE_WITH_TOL(ref_t, vk_t, 1e-3, 1e-3) << "Prefill mismatch"; + } + + // --- Decode run (seq_len = 1) --- + { + std::vector shape = {1, 1, 1, hidden_size}; + size_t n = numel(shape); + std::vector x_data = rand_floats(n, 300); + std::vector ref_data = rms_norm_ref(x_data, w_data, shape, eps); + + graph.resize_input(0, shape); + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_x.staging, x_data.data(), n, vkapi::kFloat); + + graph.execute(); + + std::vector vk_data(n); + graph.maybe_cast_and_copy_from_staging( + staging_out, 
vk_data.data(), n, vkapi::kFloat); + + Tensor ref_t = tf.make(to_int32(shape), ref_data); + Tensor vk_t = tf.make(to_int32(shape), vk_data); + EXPECT_TENSOR_CLOSE_WITH_TOL(ref_t, vk_t, 1e-3, 1e-3) << "Decode mismatch"; + } +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 4d11886f590..383a2d67eaa 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -180,6 +180,12 @@ def define_common_targets(is_fbcode = False): ":test_utils", ] ) + define_test_targets( + "rms_norm_test", + extra_deps = [ + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + ] + ) define_test_targets( "rotary_embedding_test", extra_deps = [