From 48db8ce400e4affb1d59722efe83645aab864fb2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 30 Mar 2026 14:45:02 -0700 Subject: [PATCH] [ET-VK] Add fused HuggingFace RoPE operator (apply_rotary_emb_hf) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/18592 Add a fused rotary positional embedding operator for the HuggingFace RoPE convention used by Qwen3, Phi-4-mini, and other HF-based models. The existing `et_vk.apply_rotary_emb` only matches the stock Meta/Llama RoPE pattern (interleaved pairs via reshape+unbind+stack+flatten). HF models use a different convention (split-half via slice+neg+cat), causing Qwen3's RoPE to decompose into ~560 GPU dispatches per decode step instead of 16 fused dispatches (~1,295 µs/decode, 7% of total). This commit adds `et_vk.apply_rotary_emb_hf` with: - Pattern matching: `HfRotaryEmbeddingPattern` in `patterns/rope_hf.py` using SubgraphMatcher to detect the HF RoPE graph and replace with fused op. Supports both full rotation (freqs_dim == head_dim) and partial rotation (freqs_dim < head_dim, e.g. Phi-4-mini with partial_rotary_factor=0.75) by registering two pattern variants in get_hf_rope_graphs(). - GLSL shader: `rotary_embedding_hf.glsl` which pairs elements at distance D/2 (half-apart) instead of adjacent pairs, computing half_dim from the metadata UBO for dynamic shape support - C++ dispatch: `add_rotary_embedding_hf_node` with corrected assertion (head_dim == freqs_dim, not freqs_dim*2) since HF freqs are full-dim - Custom op registration in both xplat and fbcode - Op tests covering multiple configurations and dynamic prefill→decode resize Also adds a convert_phi4_mini_weights binary target to the phi_4_mini TARGETS file to enable converting HF checkpoint weights to Meta format. Authored with Claude. ghstack-source-id: 359963407 @exported-using-ghexport Differential Revision: [D98741178](https://our.internmc.facebook.com/intern/diff/D98741178/) --- backends/vulkan/custom_ops_lib.py | 26 + backends/vulkan/op_registry.py | 10 + backends/vulkan/patterns/BUCK | 1 + backends/vulkan/patterns/__init__.py | 4 + backends/vulkan/patterns/rope_hf.py | 188 +++++ .../graph/ops/glsl/rotary_embedding_hf.glsl | 137 ++++ .../graph/ops/glsl/rotary_embedding_hf.yaml | 13 + .../graph/ops/impl/RotaryEmbedding.cpp | 103 +++ .../test/op_tests/rotary_embedding_test.cpp | 641 ++++++++++++++++++ 9 files changed, 1123 insertions(+) create mode 100644 backends/vulkan/patterns/rope_hf.py create mode 100644 backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.yaml diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 7b0a0544662..94378e885e5 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -802,6 +802,32 @@ def apply_rotary_emb_impl( lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd") apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name) +######################### +## apply_rotary_emb_hf ## +######################### + + +def apply_rotary_emb_hf_impl( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + start_pos: int, +): + seq_len = xq.shape[1] + freqs_cos = freqs_cos[start_pos : start_pos + seq_len] + freqs_sin = freqs_sin[start_pos : start_pos + seq_len] + pattern = vk_patterns.HfRotaryEmbeddingPattern() + return pattern.forward(xq, xk, freqs_cos, freqs_sin) + + +name = "apply_rotary_emb_hf" +lib.define( + f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin, SymInt start_pos) -> (Tensor, Tensor)" +) +lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd") +apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name) + ######################## ## q8ta_add ## ######################## diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index a06849c57a3..f71fc2b03ee 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1086,6 +1086,16 @@ def register_apply_rotary_emb(): ) +@update_features(exir_ops.edge.et_vk.apply_rotary_emb_hf.default) +def register_apply_rotary_emb_hf(): + return OpFeatures( + inputs_storage=utils.CONTIGUOUS_ANY, + inputs_dtypes=utils.FP_T, + supports_resize=True, + supports_highdim=True, + ) + + # ============================================================================= # Permute.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index 73bdc7edd1e..2e8f201f17f 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -10,6 +10,7 @@ fbcode_target(_kind = runtime.python_library, "__init__.py", "pattern_registry.py", "rope.py", + "rope_hf.py", "quantized_embedding.py", "quantized_linear.py", "quantized_convolution.py", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index a9323a57b09..ae29a817c9f 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -18,6 +18,8 @@ import executorch.backends.vulkan.patterns.rope # noqa +import executorch.backends.vulkan.patterns.rope_hf # noqa + import executorch.backends.vulkan.patterns.sdpa # noqa import executorch.backends.vulkan.patterns.select_as_symint # noqa @@ -37,6 +39,7 @@ ) from executorch.backends.vulkan.patterns.rope import RotaryEmbeddingPattern +from executorch.backends.vulkan.patterns.rope_hf import HfRotaryEmbeddingPattern from executorch.exir import ExportedProgram @@ -49,6 +52,7 @@ "DetectorFn", "CreateReplacementFn", "RotaryEmbeddingPattern", + "HfRotaryEmbeddingPattern", "fusable_patterns", "register_pattern_graph", "register_pattern_detector", diff --git a/backends/vulkan/patterns/rope_hf.py b/backends/vulkan/patterns/rope_hf.py new file mode 100644 index 00000000000..1514ab403b5 --- /dev/null +++ b/backends/vulkan/patterns/rope_hf.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator + +from functools import lru_cache +from typing import List, Optional + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_graph, + register_pattern_replacement, +) + +from executorch.exir import EdgeCompileConfig, ExportedProgram, to_edge +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.export import export + + +class HfRotaryEmbeddingPattern(torch.nn.Module): + """ + HuggingFace-style RoPE using rotate_half convention. + Matches the hf_apply_rotary_emb function in examples/models/llama/rope.py. + """ + + def __init__(self): + super().__init__() + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + cos = freqs_cos.unsqueeze(1) + sin = freqs_sin.unsqueeze(1) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = xq[..., :rotary_dim], xq[..., rotary_dim:] + k_rot, k_pass = xk[..., :rotary_dim], xk[..., rotary_dim:] + + q_embed = torch.cat( + [(q_rot.float() * cos) + (self._rotate_half(q_rot.float()) * sin), q_pass], + dim=-1, + ) + k_embed = torch.cat( + [(k_rot.float() * cos) + (self._rotate_half(k_rot.float()) * sin), k_pass], + dim=-1, + ) + return q_embed.type_as(xq), k_embed.type_as(xk) + + @staticmethod + def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@lru_cache(maxsize=2) +@register_pattern_graph("hf_rope") +def get_hf_rope_graphs() -> List[torch.fx.GraphModule]: + batch_size = 1 + seq_len = 1 + n_heads = 4 + n_kv_heads = 2 + head_dim = 32 + + graphs = [] + dtype = torch.float32 + + # Full rotation pattern (partial_rotary_factor == 1.0): freqs_dim == head_dim + xq = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype) + xk = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=dtype) + freqs_cos = torch.randn(seq_len, head_dim, dtype=dtype) + freqs_sin = torch.randn(seq_len, head_dim, dtype=dtype) + + edge = to_edge( + export( + HfRotaryEmbeddingPattern(), + (xq, xk, freqs_cos, freqs_sin), + strict=True, + ), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm = edge.exported_program().graph_module + graphs.append(gm) + + # Partial rotation pattern (partial_rotary_factor < 1.0): freqs_dim < head_dim + # e.g. head_dim=32, rotary_dim=24 (0.75 factor), so q_pass is non-empty + rotary_dim = 24 + xq_partial = torch.randn(batch_size, seq_len, n_heads, head_dim, dtype=dtype) + xk_partial = torch.randn(batch_size, seq_len, n_kv_heads, head_dim, dtype=dtype) + freqs_cos_partial = torch.randn(seq_len, rotary_dim, dtype=dtype) + freqs_sin_partial = torch.randn(seq_len, rotary_dim, dtype=dtype) + + edge_partial = to_edge( + export( + HfRotaryEmbeddingPattern(), + (xq_partial, xk_partial, freqs_cos_partial, freqs_sin_partial), + strict=True, + ), + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + gm_partial = edge_partial.exported_program().graph_module + graphs.append(gm_partial) + + return graphs + + +def identify_hf_rotary_emb_io_nodes( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: PatternMatch, +) -> Optional[List[torch.fx.Node]]: + input_nodes = match.input_nodes + if len(input_nodes) != 4: + return None + + xq, xk, freqs_cos, freqs_sin = input_nodes + + output_nodes = match.output_nodes + if len(output_nodes) != 2: + return None + + xq_out, xk_out = output_nodes + + return [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out] + + +@register_pattern_replacement("hf_rope") +def create_hf_rotary_emb_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: PatternMatch, +): + io_nodes = identify_hf_rotary_emb_io_nodes(ep, graph_module, match) + if io_nodes is None: + return + + assert len(io_nodes) == 6 + xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes + + # Check if freqs come from slice_copy and extract full table + start_pos + if ( + freqs_cos.op == "call_function" + and freqs_cos.target == exir_ops.edge.aten.slice_copy.Tensor + ): + full_freqs_cos = freqs_cos.args[0] + start_pos = freqs_cos.args[2] + full_freqs_sin = freqs_sin.args[0] + freqs_cos = full_freqs_cos + freqs_sin = full_freqs_sin + else: + start_pos = 0 + + with graph_module.graph.inserting_before(xq_out): + rotary_emb_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.apply_rotary_emb_hf.default, + args=(xq, xk, freqs_cos, freqs_sin, start_pos), + ) + + with graph_module.graph.inserting_after(rotary_emb_node): + getitem_0 = graph_module.graph.create_node( + "call_function", + operator.getitem, + args=(rotary_emb_node, 0), + ) + getitem_1 = graph_module.graph.create_node( + "call_function", + operator.getitem, + args=(rotary_emb_node, 1), + ) + + if hasattr(xq_out, "meta") and "val" in xq_out.meta: + getitem_0.meta["val"] = xq_out.meta["val"] + if hasattr(xk_out, "meta") and "val" in xk_out.meta: + getitem_1.meta["val"] = xk_out.meta["val"] + + xq_out.replace_all_uses_with(getitem_0) + xk_out.replace_all_uses_with(getitem_1) diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.glsl new file mode 100644 index 00000000000..09cd31940c7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.glsl @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions(STORAGE, DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)} + +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "xqout")} + ${layout_declare_ubo(B, "BufferMetadata", "xkout")} + ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "xqout")} + ${layout_declare_ubo(B, "TextureMetadata", "xkout")} + ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")} + +${layout_declare_ubo(B, "int", "start_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "xqout_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "freqs_layout", "CONTIG_LAYOUT_INT")} +// 0 = full rotation (rotary_dim == head_dim), 1 = partial rotation with +// passthrough region. Resolved at pipeline creation time, so the driver +// eliminates the dead branch entirely. +${layout_declare_spec_const(C, "int", "partial_rotary", "0")} + +// Load/store helpers that abstract buffer vs texture access. The `layout` +// parameter is only used in the texture path; the buffer path ignores it. +#ifdef USING_BUFFER +#define LOAD(tensor, meta, tidx, layout) \ + tensor[div_4(tensor4d_idx_to_linear_idx(meta, tidx))] +#define STORE(tensor, meta, tidx, layout, val) \ + tensor[div_4(tensor4d_idx_to_linear_idx(meta, tidx))] = val +#else +#define LOAD(tensor, meta, tidx, layout) \ + texelFetch(tensor, tensor4d_idx_to_texel_pos_simple(meta, tidx, layout), 0) +#define STORE(tensor, meta, tidx, layout, val) \ + imageStore(tensor, tensor4d_idx_to_texel_pos_simple(meta, tidx, layout), val) +#endif + +/* + * HuggingFace-style rotary positional embeddings. + * + * Input tensors: + * xq (batch, seq_len, n_heads, head_dim) + * xk (batch, seq_len, n_kv_heads, head_dim) + * freqs_cos (max_seq_len, rotary_dim) rotary_dim <= head_dim + * freqs_sin (max_seq_len, rotary_dim) + * start_pos (int) offset into freqs table + * + * For i in [0, rotary_half): + * out[i] = x[i]*cos[i] - x[i+rotary_half]*sin[i] + * out[i+rotary_half] = x[i+rotary_half]*cos[i] + x[i]*sin[i] + * When partial_rotary == 1, for i in [rotary_dim, head_dim): + * out[i] = x[i] (passthrough) + * + * Each thread handles one texel (4 elements) along head_dim. + * All input tensors must be width-packed. + */ +void main() { +#ifdef USING_BUFFER + const int rotary_half = int(width(freqs_cos)) / 2; +#else + const int rotary_half = freqs_cos.sizes.x / 2; +#endif + + const int x = int(gl_GlobalInvocationID.x) * 4; + + TensorIndex4D tidx = zero_tensor4d_idx(); + tidx.data.x = x; + tidx.data.yz = ivec2(gl_GlobalInvocationID.yz); + + if (out_of_bounds(tidx, xqout)) { + return; + } + + const bool process_k = !out_of_bounds(tidx, xkout); + + // Passthrough region (only reachable when partial_rotary == 1). + if (partial_rotary == 1 && x >= rotary_half * 2) { + STORE(t_xqout, xqout, tidx, xqout_layout, LOAD(t_xq, xqout, tidx, xqout_layout)); + if (process_k) { + STORE(t_xkout, xkout, tidx, xqout_layout, LOAD(t_xk, xkout, tidx, xqout_layout)); + } + return; + } + + // Rotation region: determine pair and freqs indices. + const bool is_second_half = (x >= rotary_half); + + TensorIndex4D pair_tidx = tidx; + pair_tidx.data.x = is_second_half ? (x - rotary_half) : (x + rotary_half); + + TensorIndex4D freqs_tidx = zero_tensor4d_idx(); + freqs_tidx.data.x = is_second_half ? (x - rotary_half) : x; + freqs_tidx.data.y = tidx.data.z + start_pos; + + const VEC4_T cos_val = LOAD(t_freqs_cos, freqs_cos, freqs_tidx, freqs_layout); + const VEC4_T sin_val = LOAD(t_freqs_sin, freqs_cos, freqs_tidx, freqs_layout); + + // First half: out = x*cos - pair*sin + // Second half: out = x*cos + pair*sin + const VEC4_T xq_val = LOAD(t_xq, xqout, tidx, xqout_layout); + const VEC4_T xq_pair = LOAD(t_xq, xqout, pair_tidx, xqout_layout); + const VEC4_T sign = VEC4_T(is_second_half ? 1.0 : -1.0); + + STORE(t_xqout, xqout, tidx, xqout_layout, xq_val * cos_val + sign * xq_pair * sin_val); + + if (process_k) { + const VEC4_T xk_val = LOAD(t_xk, xkout, tidx, xqout_layout); + const VEC4_T xk_pair = LOAD(t_xk, xkout, pair_tidx, xqout_layout); + STORE(t_xkout, xkout, tidx, xqout_layout, xk_val * cos_val + sign * xk_pair * sin_val); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.yaml new file mode 100644 index 00000000000..3005015da49 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding_hf.yaml @@ -0,0 +1,13 @@ +rotary_embedding_hf: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: rotary_embedding_hf diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index 17275ef9b10..7f90e2557cb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -117,8 +117,111 @@ void apply_rotary_emb(ComputeGraph& graph, const std::vector& args) { graph, args[0], args[1], args[2], args[3], xq_out, xk_out); } +// +// HuggingFace RoPE variant +// + +utils::uvec3 rotary_embedding_hf_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef xq_out = args.at(0).refs.at(0); + + // Each invocation handles one texel (4 elements) along head_dim. + // Dispatch for all head_dim elements so that both the rotary region and the + // passthrough region (partial_rotary_factor < 1.0) are covered. + const uint32_t D4 = utils::div_up_4(graph->size_at(-1, xq_out)); + + const uint32_t QH = graph->size_at(-2, xq_out); + const uint32_t S = graph->size_at(-3, xq_out); + + return {D4, QH, S}; +} + +void add_rotary_embedding_hf_node( + ComputeGraph& graph, + const ValueRef xq, + const ValueRef xk, + const ValueRef freqs_cos, + const ValueRef freqs_sin, + const ValueRef start_pos, + const ValueRef xq_out, + const ValueRef xk_out) { + VK_CHECK_COND(graph.size_at(-1, xq) == graph.size_at(-1, xk)); + VK_CHECK_COND(graph.size_at(-3, xq) == graph.size_at(-3, xk)); + // HF convention: freqs rotary_dim <= head_dim (supports + // partial_rotary_factor) + VK_CHECK_COND( + graph.size_at(-1, freqs_cos) <= graph.size_at(-1, xq)); + // freqs_cos rotary_dim must be even (pairs required for rotation) + VK_CHECK_COND(graph.size_at(-1, freqs_cos) % 8 == 0); + VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin)); + // freqs dim 0 is max_seq_len which must be >= current seq_len + VK_CHECK_COND( + graph.size_at(-2, freqs_cos) >= graph.size_at(-3, xq)); + + VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim); + VK_CHECK_COND(graph.has_standard_axis_map(xq)); + VK_CHECK_COND(graph.has_standard_axis_map(xk)); + VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos)); + VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); + + const int32_t partial_rotary = + graph.size_at(-1, freqs_cos) < graph.size_at(-1, xq) ? 1 : 0; + + std::string kernel_name = "rotary_embedding_hf"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out)); + add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(xq_out), + graph.meta_ubo(xk_out), + graph.meta_ubo(freqs_cos), + graph.get_or_create_int_param_buffer(start_pos)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + rotary_embedding_hf_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{{xq_out, xk_out}, vkapi::kWrite}, + {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, + // Parameter buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {graph.hashed_layout_of(xq_out), + graph.hashed_layout_of(freqs_cos), + partial_rotary}, + // Resize Args + {}, + // Resizing Logic + resize_rotary_embedding_node)); +} + +void apply_rotary_emb_hf( + ComputeGraph& graph, + const std::vector& args) { + const ValueListPtr out_tuple = graph.get_value_list(args[5]); + const ValueRef xq_out = out_tuple->at(0); + const ValueRef xk_out = out_tuple->at(1); + + add_rotary_embedding_hf_node( + graph, args[0], args[1], args[2], args[3], args[4], xq_out, xk_out); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb); + VK_REGISTER_OP(et_vk.apply_rotary_emb_hf.default, apply_rotary_emb_hf); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index d75d611de7e..e2be1526a4a 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -179,3 +179,644 @@ TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) { /*dim=*/2048, /*seq_len=*/3); } + +// +// HuggingFace RoPE reference and tests +// + +std::pair rotary_embedding_hf_impl( + const at::Tensor& xq, + const at::Tensor& xk, + const at::Tensor& freqs_cos, + const at::Tensor& freqs_sin) { + const int64_t head_dim = xq.size(3); + const int64_t half_dim = head_dim / 2; + + // Split into first half and second half along head_dim + at::Tensor xq_first = xq.slice(/*dim=*/3, /*start=*/0, /*end=*/half_dim); + at::Tensor xq_second = xq.slice(/*dim=*/3, /*start=*/half_dim); + at::Tensor xk_first = xk.slice(/*dim=*/3, /*start=*/0, /*end=*/half_dim); + at::Tensor xk_second = xk.slice(/*dim=*/3, /*start=*/half_dim); + + // freqs are (seq_len, head_dim) but duplicated; use first half only + at::Tensor cos_half = + freqs_cos.slice(/*dim=*/1, /*start=*/0, /*end=*/half_dim); + at::Tensor sin_half = + freqs_sin.slice(/*dim=*/1, /*start=*/0, /*end=*/half_dim); + + at::Tensor cos_reshape = + cos_half.reshape({1, cos_half.size(0), 1, cos_half.size(1)}); + at::Tensor sin_reshape = + sin_half.reshape({1, sin_half.size(0), 1, sin_half.size(1)}); + + // out[i] = x[i] * cos[i] - x[i+D/2] * sin[i] + // out[i+D/2] = x[i+D/2] * cos[i] + x[i] * sin[i] + at::Tensor xq_out_first = xq_first * cos_reshape - xq_second * sin_reshape; + at::Tensor xq_out_second = xq_second * cos_reshape + xq_first * sin_reshape; + at::Tensor xk_out_first = xk_first * cos_reshape - xk_second * sin_reshape; + at::Tensor xk_out_second = xk_second * cos_reshape + xk_first * sin_reshape; + + at::Tensor xq_out = at::cat({xq_out_first, xq_out_second}, /*dim=*/3); + at::Tensor xk_out = at::cat({xk_out_first, xk_out_second}, /*dim=*/3); + + return std::make_pair(xq_out, xk_out); +} + +void test_reference_hf( + const int n_heads = 4, + const int n_kv_heads = 2, + const int dim = 32, + const int seq_len = 1) { + const int head_dim = dim / n_heads; + + at::Tensor xq = at::rand( + {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, seq_len, n_kv_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + // HF convention: freqs are full head_dim (duplicated) + at::Tensor freqs_cos = + at::rand({seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin = + at::rand({seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + + std::pair outs = + rotary_embedding_hf_impl(xq, xk, freqs_cos, freqs_sin); + at::Tensor& xq_out = outs.first; + at::Tensor& xk_out = outs.second; + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + IOValueRef r_xq = graph.add_input_tensor( + xq.sizes().vec(), from_at_scalartype(xq.scalar_type())); + IOValueRef r_xk = graph.add_input_tensor( + xk.sizes().vec(), from_at_scalartype(xk.scalar_type())); + IOValueRef r_freqs_cos = graph.add_input_tensor( + freqs_cos.sizes().vec(), from_at_scalartype(freqs_cos.scalar_type())); + IOValueRef r_freqs_sin = graph.add_input_tensor( + freqs_sin.sizes().vec(), from_at_scalartype(freqs_sin.scalar_type())); + + const ValueRef r_xq_out = graph.add_tensor( + xq_out.sizes().vec(), from_at_scalartype(xq_out.scalar_type())); + const ValueRef r_xk_out = graph.add_tensor( + xk_out.sizes().vec(), from_at_scalartype(xk_out.scalar_type())); + + const ValueRef r_start_pos = graph.add_scalar(0); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_hf.default") + (graph, + {r_xq.value, + r_xk.value, + r_freqs_cos.value, + r_freqs_sin.value, + r_start_pos, + graph.add_value_list({r_xq_out, r_xk_out})}); + + ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out); + ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out); + + graph.prepare(); + graph.prepack(); + + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_xq.staging, + xq.const_data_ptr(), + xq.numel(), + from_at_scalartype(xq.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_xk.staging, + xk.const_data_ptr(), + xk.numel(), + from_at_scalartype(xk.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cos.staging, + freqs_cos.const_data_ptr(), + freqs_cos.numel(), + from_at_scalartype(freqs_cos.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_sin.staging, + freqs_sin.const_data_ptr(), + freqs_sin.numel(), + from_at_scalartype(freqs_sin.scalar_type())); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(xq_out); + graph.maybe_cast_and_copy_from_staging( + staging_xq_out, + vk_xq_out.mutable_data_ptr(), + vk_xq_out.numel(), + from_at_scalartype(vk_xq_out.scalar_type())); + + at::Tensor vk_xk_out = at::empty_like(xk_out); + graph.maybe_cast_and_copy_from_staging( + staging_xk_out, + vk_xk_out.mutable_data_ptr(), + vk_xk_out.numel(), + from_at_scalartype(vk_xk_out.scalar_type())); + + EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4)); + EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4)); +} + +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_test) { + test_reference_hf(); +} + +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_llama3_params_test) { + test_reference_hf( + /*n_heads=*/32, + /*n_kv_heads=*/8, + /*dim=*/2048); +} + +TEST( + VulkanRotaryEmbeddingHFTest, + rotary_embedding_hf_llama3_params_test_seq_len_3) { + test_reference_hf( + /*n_heads=*/32, + /*n_kv_heads=*/8, + /*dim=*/2048, + /*seq_len=*/3); +} + +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_head_dim_128) { + test_reference_hf( + /*n_heads=*/8, + /*n_kv_heads=*/4, + /*dim=*/1024, + /*seq_len=*/5); +} + +// Tests dynamic resize from prefill (seq_len=N) to decode (seq_len=1), +// simulating the actual LLM inference pattern that was previously broken. +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_dynamic_resize_qwen3) { + const int n_heads = 16; + const int n_kv_heads = 8; + const int head_dim = 128; + const int prefill_seq_len = 7; + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + // Build graph with prefill shapes (max size) + IOValueRef r_xq = graph.add_input_tensor( + {1, prefill_seq_len, n_heads, head_dim}, vkapi::kFloat); + IOValueRef r_xk = graph.add_input_tensor( + {1, prefill_seq_len, n_kv_heads, head_dim}, vkapi::kFloat); + IOValueRef r_freqs_cos = + graph.add_input_tensor({prefill_seq_len, head_dim}, vkapi::kFloat); + IOValueRef r_freqs_sin = + graph.add_input_tensor({prefill_seq_len, head_dim}, vkapi::kFloat); + + const ValueRef r_xq_out = + graph.add_tensor({1, prefill_seq_len, n_heads, head_dim}, vkapi::kFloat); + const ValueRef r_xk_out = graph.add_tensor( + {1, prefill_seq_len, n_kv_heads, head_dim}, vkapi::kFloat); + + const ValueRef r_start_pos = graph.add_scalar(0); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_hf.default") + (graph, + {r_xq.value, + r_xk.value, + r_freqs_cos.value, + r_freqs_sin.value, + r_start_pos, + graph.add_value_list({r_xq_out, r_xk_out})}); + + ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out); + ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out); + + graph.prepare(); + graph.prepack(); + + // --- Prefill run (seq_len = 7) --- + { + at::Tensor xq = at::rand( + {1, prefill_seq_len, n_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, prefill_seq_len, n_kv_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_cos = at::rand( + {prefill_seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin = at::rand( + {prefill_seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + + auto ref = rotary_embedding_hf_impl(xq, xk, freqs_cos, freqs_sin); + + graph.resize_input(0, xq.sizes().vec()); + graph.resize_input(1, xk.sizes().vec()); + graph.resize_input(2, freqs_cos.sizes().vec()); + graph.resize_input(3, freqs_sin.sizes().vec()); + graph.propagate_resize(); + + graph.maybe_cast_and_copy_into_staging( + r_xq.staging, xq.const_data_ptr(), xq.numel(), vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_xk.staging, xk.const_data_ptr(), xk.numel(), vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cos.staging, + freqs_cos.const_data_ptr(), + freqs_cos.numel(), + vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_freqs_sin.staging, + freqs_sin.const_data_ptr(), + freqs_sin.numel(), + vkapi::kFloat); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(ref.first); + graph.maybe_cast_and_copy_from_staging( + staging_xq_out, + vk_xq_out.mutable_data_ptr(), + vk_xq_out.numel(), + vkapi::kFloat); + at::Tensor vk_xk_out = at::empty_like(ref.second); + graph.maybe_cast_and_copy_from_staging( + staging_xk_out, + vk_xk_out.mutable_data_ptr(), + vk_xk_out.numel(), + vkapi::kFloat); + + EXPECT_TRUE(at::allclose(ref.first, vk_xq_out, 1e-4, 1e-4)) + << "Prefill xq_out mismatch"; + EXPECT_TRUE(at::allclose(ref.second, vk_xk_out, 1e-4, 1e-4)) + << "Prefill xk_out mismatch"; + } + + // --- Decode run (seq_len = 1) --- + { + at::Tensor xq = at::rand( + {1, 1, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, 1, n_kv_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_cos = + at::rand({1, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin = + at::rand({1, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + + auto ref = rotary_embedding_hf_impl(xq, xk, freqs_cos, freqs_sin); + + graph.resize_input(0, xq.sizes().vec()); + graph.resize_input(1, xk.sizes().vec()); + graph.resize_input(2, freqs_cos.sizes().vec()); + graph.resize_input(3, freqs_sin.sizes().vec()); + graph.propagate_resize(); + + graph.maybe_cast_and_copy_into_staging( + r_xq.staging, xq.const_data_ptr(), xq.numel(), vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_xk.staging, xk.const_data_ptr(), xk.numel(), vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cos.staging, + freqs_cos.const_data_ptr(), + freqs_cos.numel(), + vkapi::kFloat); + graph.maybe_cast_and_copy_into_staging( + r_freqs_sin.staging, + freqs_sin.const_data_ptr(), + freqs_sin.numel(), + vkapi::kFloat); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(ref.first); + graph.maybe_cast_and_copy_from_staging( + staging_xq_out, + vk_xq_out.mutable_data_ptr(), + vk_xq_out.numel(), + vkapi::kFloat); + at::Tensor vk_xk_out = at::empty_like(ref.second); + graph.maybe_cast_and_copy_from_staging( + staging_xk_out, + vk_xk_out.mutable_data_ptr(), + vk_xk_out.numel(), + vkapi::kFloat); + + EXPECT_TRUE(at::allclose(ref.first, vk_xq_out, 1e-4, 1e-4)) + << "Decode xq_out mismatch"; + EXPECT_TRUE(at::allclose(ref.second, vk_xk_out, 1e-4, 1e-4)) + << "Decode xk_out mismatch"; + } +} + +// Tests that start_pos correctly offsets into the full freqs table. +// The Vulkan op receives the full [max_seq_len, head_dim] freqs table plus a +// start_pos offset, while the reference impl receives pre-sliced freqs. +void test_reference_hf_with_start_pos( + const int n_heads = 8, + const int n_kv_heads = 4, + const int head_dim = 128, + const int seq_len = 3, + const int start_pos = 7, + const int max_seq_len = 32) { + at::Tensor xq = at::rand( + {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, seq_len, n_kv_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + + // Full freqs table of size [max_seq_len, head_dim] + at::Tensor freqs_cos_full = + at::rand({max_seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin_full = + at::rand({max_seq_len, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + + // Slice freqs for the reference implementation + at::Tensor freqs_cos_sliced = + freqs_cos_full.slice(/*dim=*/0, start_pos, start_pos + seq_len); + at::Tensor freqs_sin_sliced = + freqs_sin_full.slice(/*dim=*/0, start_pos, start_pos + seq_len); + + // Reference uses pre-sliced freqs + std::pair ref = + rotary_embedding_hf_impl(xq, xk, freqs_cos_sliced, freqs_sin_sliced); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + IOValueRef r_xq = graph.add_input_tensor( + xq.sizes().vec(), from_at_scalartype(xq.scalar_type())); + IOValueRef r_xk = graph.add_input_tensor( + xk.sizes().vec(), from_at_scalartype(xk.scalar_type())); + // Vulkan op receives full freqs table + IOValueRef r_freqs_cos = graph.add_input_tensor( + freqs_cos_full.sizes().vec(), + from_at_scalartype(freqs_cos_full.scalar_type())); + IOValueRef r_freqs_sin = graph.add_input_tensor( + freqs_sin_full.sizes().vec(), + from_at_scalartype(freqs_sin_full.scalar_type())); + + const ValueRef r_xq_out = graph.add_tensor( + ref.first.sizes().vec(), from_at_scalartype(ref.first.scalar_type())); + const ValueRef r_xk_out = graph.add_tensor( + ref.second.sizes().vec(), from_at_scalartype(ref.second.scalar_type())); + + const ValueRef r_start_pos = graph.add_scalar(start_pos); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_hf.default") + (graph, + {r_xq.value, + r_xk.value, + r_freqs_cos.value, + r_freqs_sin.value, + r_start_pos, + graph.add_value_list({r_xq_out, r_xk_out})}); + + ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out); + ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out); + + graph.prepare(); + graph.prepack(); + + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_xq.staging, + xq.const_data_ptr(), + xq.numel(), + from_at_scalartype(xq.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_xk.staging, + xk.const_data_ptr(), + xk.numel(), + from_at_scalartype(xk.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cos.staging, + freqs_cos_full.const_data_ptr(), + freqs_cos_full.numel(), + from_at_scalartype(freqs_cos_full.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_sin.staging, + freqs_sin_full.const_data_ptr(), + freqs_sin_full.numel(), + from_at_scalartype(freqs_sin_full.scalar_type())); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(ref.first); + graph.maybe_cast_and_copy_from_staging( + staging_xq_out, + vk_xq_out.mutable_data_ptr(), + vk_xq_out.numel(), + from_at_scalartype(vk_xq_out.scalar_type())); + + at::Tensor vk_xk_out = at::empty_like(ref.second); + graph.maybe_cast_and_copy_from_staging( + staging_xk_out, + vk_xk_out.mutable_data_ptr(), + vk_xk_out.numel(), + from_at_scalartype(vk_xk_out.scalar_type())); + + EXPECT_TRUE(at::allclose(ref.first, vk_xq_out, 1e-4, 1e-4)); + EXPECT_TRUE(at::allclose(ref.second, vk_xk_out, 1e-4, 1e-4)); +} + +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_start_pos_offset) { + test_reference_hf_with_start_pos(); +} + +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_start_pos_decode) { + test_reference_hf_with_start_pos( + /*n_heads=*/16, + /*n_kv_heads=*/8, + /*head_dim=*/128, + /*seq_len=*/1, + /*start_pos=*/15, + /*max_seq_len=*/64); +} + +// +// Partial rotary tests (partial_rotary_factor < 1.0) +// + +// Reference impl for partial rotary: only first rotary_dim elements are +// rotated, the rest pass through unchanged. +std::pair rotary_embedding_hf_partial_impl( + const at::Tensor& xq, + const at::Tensor& xk, + const at::Tensor& freqs_cos, + const at::Tensor& freqs_sin) { + const int64_t rotary_dim = freqs_cos.size(1); + const int64_t rotary_half = rotary_dim / 2; + + // Split into rotary and passthrough regions + at::Tensor xq_rot = xq.slice(/*dim=*/3, /*start=*/0, /*end=*/rotary_dim); + at::Tensor xq_pass = xq.slice(/*dim=*/3, /*start=*/rotary_dim); + at::Tensor xk_rot = xk.slice(/*dim=*/3, /*start=*/0, /*end=*/rotary_dim); + at::Tensor xk_pass = xk.slice(/*dim=*/3, /*start=*/rotary_dim); + + // Split rotary region into first and second halves + at::Tensor xq_first = + xq_rot.slice(/*dim=*/3, /*start=*/0, /*end=*/rotary_half); + at::Tensor xq_second = xq_rot.slice(/*dim=*/3, /*start=*/rotary_half); + at::Tensor xk_first = + xk_rot.slice(/*dim=*/3, /*start=*/0, /*end=*/rotary_half); + at::Tensor xk_second = xk_rot.slice(/*dim=*/3, /*start=*/rotary_half); + + // freqs are (seq_len, rotary_dim); use first half only + at::Tensor cos_half = + freqs_cos.slice(/*dim=*/1, /*start=*/0, /*end=*/rotary_half); + at::Tensor sin_half = + freqs_sin.slice(/*dim=*/1, /*start=*/0, /*end=*/rotary_half); + + at::Tensor cos_reshape = + cos_half.reshape({1, cos_half.size(0), 1, cos_half.size(1)}); + at::Tensor sin_reshape = + sin_half.reshape({1, sin_half.size(0), 1, sin_half.size(1)}); + + at::Tensor xq_out_first = xq_first * cos_reshape - xq_second * sin_reshape; + at::Tensor xq_out_second = xq_second * cos_reshape + xq_first * sin_reshape; + at::Tensor xk_out_first = xk_first * cos_reshape - xk_second * sin_reshape; + at::Tensor xk_out_second = xk_second * cos_reshape + xk_first * sin_reshape; + + at::Tensor xq_out = + at::cat({xq_out_first, xq_out_second, xq_pass}, /*dim=*/3); + at::Tensor xk_out = + at::cat({xk_out_first, xk_out_second, xk_pass}, /*dim=*/3); + + return std::make_pair(xq_out, xk_out); +} + +void test_reference_hf_partial_rotary( + const int n_heads = 8, + const int n_kv_heads = 4, + const int head_dim = 128, + const int rotary_dim = 96, + const int seq_len = 3, + const int start_pos = 0, + const int max_seq_len = 32) { + at::Tensor xq = at::rand( + {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor xk = at::rand( + {1, seq_len, n_kv_heads, head_dim}, + at::device(at::kCPU).dtype(at::kFloat)); + + // Full freqs table with rotary_dim < head_dim + at::Tensor freqs_cos_full = at::rand( + {max_seq_len, rotary_dim}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor freqs_sin_full = at::rand( + {max_seq_len, rotary_dim}, at::device(at::kCPU).dtype(at::kFloat)); + + // Slice freqs for reference + at::Tensor freqs_cos_sliced = + freqs_cos_full.slice(/*dim=*/0, start_pos, start_pos + seq_len); + at::Tensor freqs_sin_sliced = + freqs_sin_full.slice(/*dim=*/0, start_pos, start_pos + seq_len); + + auto ref = rotary_embedding_hf_partial_impl( + xq, xk, freqs_cos_sliced, freqs_sin_sliced); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + IOValueRef r_xq = graph.add_input_tensor( + xq.sizes().vec(), from_at_scalartype(xq.scalar_type())); + IOValueRef r_xk = graph.add_input_tensor( + xk.sizes().vec(), from_at_scalartype(xk.scalar_type())); + IOValueRef r_freqs_cos = graph.add_input_tensor( + freqs_cos_full.sizes().vec(), + from_at_scalartype(freqs_cos_full.scalar_type())); + IOValueRef r_freqs_sin = graph.add_input_tensor( + freqs_sin_full.sizes().vec(), + from_at_scalartype(freqs_sin_full.scalar_type())); + + const ValueRef r_xq_out = graph.add_tensor( + ref.first.sizes().vec(), from_at_scalartype(ref.first.scalar_type())); + const ValueRef r_xk_out = graph.add_tensor( + ref.second.sizes().vec(), from_at_scalartype(ref.second.scalar_type())); + + const ValueRef r_start_pos = graph.add_scalar(start_pos); + + VK_GET_OP_FN("et_vk.apply_rotary_emb_hf.default") + (graph, + {r_xq.value, + r_xk.value, + r_freqs_cos.value, + r_freqs_sin.value, + r_start_pos, + graph.add_value_list({r_xq_out, r_xk_out})}); + + ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out); + ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out); + + graph.prepare(); + graph.prepack(); + + graph.propagate_resize(); + graph.maybe_cast_and_copy_into_staging( + r_xq.staging, + xq.const_data_ptr(), + xq.numel(), + from_at_scalartype(xq.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_xk.staging, + xk.const_data_ptr(), + xk.numel(), + from_at_scalartype(xk.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_cos.staging, + freqs_cos_full.const_data_ptr(), + freqs_cos_full.numel(), + from_at_scalartype(freqs_cos_full.scalar_type())); + graph.maybe_cast_and_copy_into_staging( + r_freqs_sin.staging, + freqs_sin_full.const_data_ptr(), + freqs_sin_full.numel(), + from_at_scalartype(freqs_sin_full.scalar_type())); + + graph.execute(); + + at::Tensor vk_xq_out = at::empty_like(ref.first); + graph.maybe_cast_and_copy_from_staging( + staging_xq_out, + vk_xq_out.mutable_data_ptr(), + vk_xq_out.numel(), + from_at_scalartype(vk_xq_out.scalar_type())); + + at::Tensor vk_xk_out = at::empty_like(ref.second); + graph.maybe_cast_and_copy_from_staging( + staging_xk_out, + vk_xk_out.mutable_data_ptr(), + vk_xk_out.numel(), + from_at_scalartype(vk_xk_out.scalar_type())); + + EXPECT_TRUE(at::allclose(ref.first, vk_xq_out, 1e-4, 1e-4)); + EXPECT_TRUE(at::allclose(ref.second, vk_xk_out, 1e-4, 1e-4)); +} + +// Phi4 Mini-like: head_dim=128, rotary_dim=96 (partial_rotary_factor=0.75) +TEST(VulkanRotaryEmbeddingHFTest, rotary_embedding_hf_partial_rotary) { + test_reference_hf_partial_rotary(); +} + +// Partial rotary with non-zero start_pos +TEST( + VulkanRotaryEmbeddingHFTest, + rotary_embedding_hf_partial_rotary_start_pos) { + test_reference_hf_partial_rotary( + /*n_heads=*/16, + /*n_kv_heads=*/8, + /*head_dim=*/128, + /*rotary_dim=*/96, + /*seq_len=*/1, + /*start_pos=*/10, + /*max_seq_len=*/64); +}