From e92cacff2200f07144871d228eb5a29d2309c0a7 Mon Sep 17 00:00:00 2001 From: Simon Veitner Date: Sat, 23 May 2026 22:02:49 +0000 Subject: [PATCH] Fuse the copy for rotary. --- benchmarks/benchmark_rotary.py | 406 +++++++++++++++++++++++++++++++++ quack/rotary.py | 102 ++++++++- tests/test_rotary.py | 33 +++ 3 files changed, 537 insertions(+), 4 deletions(-) create mode 100644 benchmarks/benchmark_rotary.py diff --git a/benchmarks/benchmark_rotary.py b/benchmarks/benchmark_rotary.py new file mode 100644 index 00000000..56cfa8c5 --- /dev/null +++ b/benchmarks/benchmark_rotary.py @@ -0,0 +1,406 @@ +import argparse +import math +import os + +os.environ.setdefault("TORCH_COMPILE_DYNAMIC", "0") + +import torch +from triton.testing import Benchmark, do_bench, perf_report + +from quack.bench.bench_utils import run_and_print +from quack.rotary import apply_rotary_emb + + +# Keep the flattened row count (B * S * H) +# fixed while increasing D, then shrink rows for large D to keep the +# benchmark tensor size bounded. B/S/H stay explicit in the table so +# sequence/head-layout effects remain visible. +SHAPE_CASES = { + "rows1m-d32-full": (8, 4096, 32, 32, 32), + "rows1m-d64-full": (8, 4096, 32, 64, 64), + "rows1m-d96-full": (8, 4096, 32, 96, 96), + "rows1m-d128-full": (8, 4096, 32, 128, 128), + "rows1m-d128-half": (8, 4096, 32, 128, 64), + "rows512k-d256-full": (4, 4096, 32, 256, 256), + "rows512k-d256-half": (4, 4096, 32, 256, 128), + "rows256k-d512-full": (2, 4096, 32, 512, 512), + "rows256k-d512-half": (2, 4096, 32, 512, 256), +} + +DTYPE_MAP = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + + +def generate_cos_sin(seqlen: int, rotary_dim: int, device: str, dtype: torch.dtype): + assert rotary_dim % 2 == 0 + angle = torch.rand(seqlen, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + return cos, sin + + +def generate_offsets(use_offsets: bool, batch: int, seqlen: int, device: str): + if not use_offsets: + return None + return torch.randint(0, seqlen + 1, (batch,), dtype=torch.int32, device=device) + + +def rotate_half(x: torch.Tensor, interleaved: bool): + if not interleaved: + x0, x1 = x.chunk(2, dim=-1) + return torch.cat((-x1, x0), dim=-1) + x0, x1 = x[..., ::2], x[..., 1::2] + return torch.stack((-x1, x0), dim=-1).reshape_as(x) + + +def rotary_ref( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: torch.Tensor | None, + interleaved: bool, +): + seqlen = x.shape[1] + rotary_dim = cos.shape[-1] * 2 + if seqlen_offsets is None: + cos = cos[:seqlen] + sin = sin[:seqlen] + else: + arange = torch.arange(seqlen, device=x.device).view(1, seqlen) + idx = seqlen_offsets.view(-1, 1) + arange + cos = cos[idx] + sin = sin[idx] + + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + if not interleaved: + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + cos = cos.unsqueeze(-1).expand(*cos.shape, 2).reshape(*cos.shape[:-1], rotary_dim) + sin = sin.unsqueeze(-1).expand(*sin.shape, 2).reshape(*sin.shape[:-1], rotary_dim) + if cos.dim() == 3: + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + + x_ro = x[..., :rotary_dim] + out_ro = x_ro * cos + rotate_half(x_ro, interleaved) * sin + return torch.cat((out_ro, x[..., rotary_dim:]), dim=-1) + + +def rotary_ref_inplace( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: torch.Tensor | None, + interleaved: bool, +): + rotary_dim = cos.shape[-1] * 2 + x[..., :rotary_dim].copy_( + rotary_ref(x[..., :rotary_dim], cos, sin, seqlen_offsets, interleaved) + ) + return x + + +def _providers(include_torch_eager: bool): + providers = [("quack", "quack"), ("torch_compile", "torch.compile")] + if include_torch_eager: + providers.append(("torch_eager", "torch eager")) + return providers + + +def _result(num_bytes: int, ms: float) -> dict: + gbps = num_bytes / (ms / 1000) / 1e9 + return {"ms": round(ms, 4), "GB/s": round(gbps)} + + +def _logical_bytes( + batch: int, + seqlen: int, + nheads: int, + headdim: int, + rotary_dim: int, + dtype: torch.dtype, + cossin_dtype: torch.dtype, + inplace: bool, + use_offsets: bool, +) -> int: + # HBM-style lower-bound bytes: x read + output write, plus the unique + # cos/sin table footprint once. Counting cos/sin per head is an effective + # reuse metric and can exceed physical HBM bandwidth when the table is hot + # in cache. + x_elems = batch * seqlen * nheads * (rotary_dim if inplace else headdim) + cossin_elems = (2 * seqlen if use_offsets else seqlen) * rotary_dim + return 2 * x_elems * dtype.itemsize + cossin_elems * cossin_dtype.itemsize + + +def make_benchmark( + dtype_name: str, + cossin_dtype_name: str, + interleaved: bool, + use_offsets: bool, + backward: bool, + inplace: bool, + include_torch_eager: bool, + warmup: int, + rep: int, + x_vals=None, +) -> Benchmark: + line_vals, line_names = zip(*_providers(include_torch_eager)) + direction = "bwd" if backward else "fwd" + suffix = [ + direction, + dtype_name, + f"cossin-{cossin_dtype_name}", + "interleaved" if interleaved else "contiguous-pair", + "offsets" if use_offsets else "no-offsets", + ] + if inplace: + suffix.append("inplace") + return Benchmark( + x_names=["B", "S", "H", "D", "rotary_dim"], + x_vals=x_vals if x_vals is not None else list(SHAPE_CASES.values()), + line_arg="provider", + line_vals=list(line_vals), + line_names=list(line_names), + plot_name="rotary-" + "-".join(suffix), + args={ + "dtype_name": dtype_name, + "cossin_dtype_name": cossin_dtype_name, + "interleaved": interleaved, + "use_offsets": use_offsets, + "backward": backward, + "inplace": inplace, + "warmup": warmup, + "rep": rep, + }, + xlabel="(B, S, H, D, rotary_dim)", + ylabel="GB/s", + ) + + +def _make_inputs(B, S, H, D, rotary_dim, dtype_name, cossin_dtype_name, use_offsets): + dtype = DTYPE_MAP[dtype_name] + cossin_dtype = DTYPE_MAP[cossin_dtype_name] + x = torch.randn(B, S, H, D, device="cuda", dtype=dtype) + cos_len = 2 * S if use_offsets else S + cos, sin = generate_cos_sin(cos_len, rotary_dim, "cuda", cossin_dtype) + seqlen_offsets = generate_offsets(use_offsets, B, S, "cuda") + return x, cos, sin, seqlen_offsets + + +def _check_correctness( + B, + S, + H, + D, + rotary_dim, + dtype_name, + cossin_dtype_name, + interleaved, + use_offsets, + backward, + inplace, +): + x, cos, sin, seqlen_offsets = _make_inputs( + B, S, H, D, rotary_dim, dtype_name, cossin_dtype_name, use_offsets + ) + x_ref = x.detach().clone() + if backward: + x = x.requires_grad_() + x_ref = x_ref.requires_grad_() + out = apply_rotary_emb(x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved) + out_ref = rotary_ref( + x_ref.float(), cos.float(), sin.float(), seqlen_offsets, interleaved + ).to(dtype=x.dtype) + grad = torch.randn_like(out) + (dx,) = torch.autograd.grad(out, x, grad) + (dx_ref,) = torch.autograd.grad(out_ref, x_ref, grad) + torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-3) + torch.testing.assert_close(dx, dx_ref, atol=1e-2, rtol=1e-3) + return + + out = apply_rotary_emb(x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved) + out_ref = rotary_ref(x_ref.float(), cos.float(), sin.float(), seqlen_offsets, interleaved).to( + dtype=x.dtype + ) + torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-3) + if inplace: + x_inplace = x.detach().clone() + out_inplace = apply_rotary_emb( + x_inplace, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=interleaved, + inplace=True, + ) + assert out_inplace.data_ptr() == x_inplace.data_ptr() + torch.testing.assert_close(out_inplace, out_ref, atol=1e-2, rtol=1e-3) + + +def rotary_runner( + B, + S, + H, + D, + rotary_dim, + provider, + dtype_name, + cossin_dtype_name, + interleaved, + use_offsets, + backward, + inplace, + warmup, + rep, +): + x, cos, sin, seqlen_offsets = _make_inputs( + B, S, H, D, rotary_dim, dtype_name, cossin_dtype_name, use_offsets + ) + bytes_moved = _logical_bytes( + B, S, H, D, rotary_dim, x.dtype, cos.dtype, inplace and not backward, use_offsets + ) + + if provider == "quack": + if backward: + x = x.requires_grad_() + y = apply_rotary_emb( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=interleaved, + inplace=False, + ) + dy = torch.randn_like(y) + fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + grad_to_none = (x,) + else: + fn = lambda: apply_rotary_emb( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=interleaved, + inplace=inplace, + ) + grad_to_none = None + elif provider == "torch_compile": + ref_fn = rotary_ref_inplace if inplace else rotary_ref + compiled = torch.compile( + lambda x, cos, sin, offsets: ref_fn(x, cos, sin, offsets, interleaved) + ) + if backward: + x = x.requires_grad_() + y = compiled(x, cos, sin, seqlen_offsets) + dy = torch.randn_like(y) + fn = torch.compile( + lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + ) + grad_to_none = (x,) + else: + fn = lambda: compiled(x, cos, sin, seqlen_offsets) + grad_to_none = None + elif provider == "torch_eager": + if backward: + x = x.requires_grad_() + y = rotary_ref(x, cos, sin, seqlen_offsets, interleaved) + dy = torch.randn_like(y) + fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + grad_to_none = (x,) + else: + ref_fn = rotary_ref_inplace if inplace else rotary_ref + fn = lambda: ref_fn(x, cos, sin, seqlen_offsets, interleaved) + grad_to_none = None + else: + raise ValueError(provider) + + ms = do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none) + return _result(bytes_moved, ms) + + +def _parse_shape_args(args): + shape_values = [args.B, args.S, args.H, args.D, args.rotary_dim] + if any(v is not None for v in shape_values): + if any(v is None for v in shape_values): + raise SystemExit("--B, --S, --H, --D, and --rotary_dim must be provided together") + return [tuple(shape_values)] + if args.case: + return [SHAPE_CASES[name] for name in args.case] + return None + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark rotary embedding fwd / bwd") + parser.add_argument("--dtype", default="bfloat16", choices=list(DTYPE_MAP)) + parser.add_argument("--cossin_dtype", default=None, choices=list(DTYPE_MAP)) + parser.add_argument("--case", action="append", choices=sorted(SHAPE_CASES)) + parser.add_argument("--B", type=int, default=None, help="Batch for a single custom shape") + parser.add_argument( + "--S", type=int, default=None, help="Sequence length for a single custom shape" + ) + parser.add_argument( + "--H", type=int, default=None, help="Number of heads for a single custom shape" + ) + parser.add_argument( + "--D", type=int, default=None, help="Head dimension for a single custom shape" + ) + parser.add_argument( + "--rotary_dim", type=int, default=None, help="Rotary dimension for a custom shape" + ) + parser.add_argument("--interleaved", action="store_true") + parser.add_argument("--offsets", action="store_true", help="Benchmark tensor seqlen_offsets") + parser.add_argument("--backward", action="store_true") + parser.add_argument( + "--inplace", action="store_true", help="Only supported for forward benchmarks" + ) + parser.add_argument("--include_torch_eager", action="store_true") + parser.add_argument( + "--check", action="store_true", help="Run a correctness check before timing" + ) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--rep", type=int, default=100) + parser.add_argument("--save_path", default=None) + args = parser.parse_args() + + if args.backward and args.inplace: + parser.error("--inplace is only supported for forward benchmarks") + cossin_dtype = args.cossin_dtype or args.dtype + x_vals = _parse_shape_args(args) + torch.manual_seed(0) + if args.check: + for shape in x_vals if x_vals is not None else SHAPE_CASES.values(): + _check_correctness( + *shape, + args.dtype, + cossin_dtype, + args.interleaved, + args.offsets, + args.backward, + args.inplace, + ) + + torch.manual_seed(0) + bench = perf_report( + make_benchmark( + args.dtype, + cossin_dtype, + args.interleaved, + args.offsets, + args.backward, + args.inplace, + args.include_torch_eager, + args.warmup, + args.rep, + x_vals, + ) + )(rotary_runner) + run_and_print(bench, save_path=args.save_path) + + +if __name__ == "__main__": + main() diff --git a/quack/rotary.py b/quack/rotary.py index 6e109ce8..11c1a5d0 100644 --- a/quack/rotary.py +++ b/quack/rotary.py @@ -27,18 +27,40 @@ def _ensure_last_dim_contiguous(t: Tensor) -> Tensor: return t if t.stride(-1) == 1 else t.contiguous() +def _copy_vecsize_for_rows( + element_width: int, + logical_dim: int, + tile_dim: int, + rows: int, + num_threads: int, +) -> int: + vecsize = math.gcd(128 // element_width, logical_dim) + while vecsize > 1: + vecs_per_row = tile_dim // vecsize + threads_per_row = math.gcd(32, vecs_per_row) + rows_per_copy = num_threads // threads_per_row + if rows % rows_per_copy == 0: + return vecsize + vecsize //= 2 + return 1 + + class RotaryKernel: def __init__( self, dtype: type[cutlass.Numeric], dim: int, + headdim: int, interleaved: bool = False, conjugate: bool = False, + copy_tail: bool = False, ): self.dtype = dtype self.dim = dim + self.headdim = headdim self.interleaved = interleaved self.conjugate = conjugate + self.copy_tail = copy_tail self.num_threads = 128 self.tile_h = 2 if self.dim <= 96 else 1 multiple = 32 if dim <= 128 else 64 @@ -61,6 +83,7 @@ def __call__( assert mCos.element_type == mSin.element_type assert mCos.shape[1] == mSin.shape[1] assert mCos.shape[1] * 2 == self.dim + assert self.dim <= self.headdim self.is_varlen = const_expr(mCuSeqlens is not None) @@ -91,6 +114,26 @@ def __call__( mCos.element_type, threads_per_row_cs, self.num_threads, vecsize_cs ) assert tiler_mn[0] % (self.num_threads // threads_per_row_cs) == 0 + tail_tiler_mn = tiler_mn + tiled_copy_tail = tiled_copy + if const_expr(self.copy_tail): + tail_dim = self.headdim - self.dim + tail_multiple = 32 if tail_dim <= 128 else 64 + tail_tile_d = (tail_dim + tail_multiple - 1) // tail_multiple * tail_multiple + tail_vecsize = _copy_vecsize_for_rows( + mX.element_type.width, + tail_dim, + tail_tile_d, + tiler_mn[0], + self.num_threads, + ) + tail_vecs_per_row = tail_tile_d // tail_vecsize + tail_threads_per_row = math.gcd(32, tail_vecs_per_row) + tail_tiler_mn = (tiler_mn[0], tail_tile_d) + tiled_copy_tail = copy_utils.tiled_copy_2d( + mX.element_type, tail_threads_per_row, self.num_threads, tail_vecsize + ) + assert tiler_mn[0] % (self.num_threads // tail_threads_per_row) == 0 # (b, s, h, d) -> (s, d, h, b); (s, h, d) -> (s, d, h) x_layout_transpose = [0, 2, 1] if const_expr(self.is_varlen) else [1, 3, 2, 0] @@ -116,8 +159,10 @@ def __call__( mO, max_seqlen, tiler_mn, + tail_tiler_mn, tiled_copy, tiled_copy_cs, + tiled_copy_tail, ).launch( grid=[ cute.ceil_div(nheads, self.tile_h), @@ -139,8 +184,10 @@ def kernel( mO: cute.Tensor, max_seqlen: Int32, tiler_mn: cute.Shape, + tail_tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy, tiled_copy_cs: cute.TiledCopy, + tiled_copy_tail: cute.TiledCopy, ): tidx, _, _ = cute.arch.thread_idx() head_idx, m_idx, batch_idx = cute.arch.block_idx() @@ -285,6 +332,41 @@ def kernel( if tXcX[0, m, 0][0] < seq_len: copy(tXrX[None, m, None], tXgO[None, m, None, h]) + if const_expr(self.copy_tail): + tail_dim = const_expr(self.headdim - self.dim) + tail_tiler_mnh = (tail_tiler_mn[0], tail_tiler_mn[1], self.tile_h) + if const_expr(self.is_varlen): + cTail_shape = (max_seqlen, tail_tiler_mn[1]) + else: + cTail_shape = (mX.shape[0], tail_tiler_mn[1]) + mX_tail = cute.domain_offset((None, self.dim, None), mX_batch) + mO_tail = cute.domain_offset((None, self.dim, None), mO_batch) + gTailX = cute.local_tile(mX_tail, tail_tiler_mnh, (m_idx, 0, head_idx)) + gTailO = cute.local_tile(mO_tail, tail_tiler_mnh, (m_idx, 0, head_idx)) + cTail = cute.local_tile( + cute.make_identity_tensor(cTail_shape), tail_tiler_mn, (m_idx, 0) + ) + thr_copy_tail = tiled_copy_tail.get_slice(tidx) + tTailcTail_full = thr_copy_tail.partition_S(cTail) + tTailcTail = tTailcTail_full[(0, None), None, None] + tTailgX = thr_copy_tail.partition_S(gTailX) + tTailgO = thr_copy_tail.partition_D(gTailO) + is_even_tail_dim = const_expr(tail_tiler_mn[1] == tail_dim) + pred_tail = None + if const_expr(not is_even_tail_dim): + pred_tail = copy_utils.predicate_k(tTailcTail_full, limit=tail_dim) + copy_tail = partial( + copy_utils.copy, + pred=pred_tail[None, 0, None] if not is_even_tail_dim else None, + ) + for h in cutlass.range_constexpr(self.tile_h): + if self.tile_h == 1 or h < nheads - head_idx * self.tile_h: + for m in cutlass.range(cute.size(tTailgX, mode=[1]), unroll_full=True): + if tTailcTail[0, m, 0][0] < seq_len: + tTailrX = cute.make_rmem_tensor_like(tTailgX[None, m, None, h]) + copy_tail(tTailgX[None, m, None, h], tTailrX) + copy_tail(tTailrX, tTailgO[None, m, None, h]) + @staticmethod @jit_cache def compile( @@ -293,8 +375,10 @@ def compile( seqlen_offsets_dtype, cu_seqlens_dtype, dim, + headdim, interleaved, conjugate, + copy_tail, ): is_varlen = cu_seqlens_dtype is not None has_seqlen_offsets = seqlen_offsets_dtype is not None @@ -303,7 +387,7 @@ def compile( seqlen_sym = cute.sym_int() total_seqlen_sym = cute.sym_int() nheads_sym = cute.sym_int() - x_dim_sym = cute.sym_int() + x_dim_sym = headdim if copy_tail else cute.sym_int() seqlen_ro_sym = cute.sym_int() x_shape = ( (total_seqlen_sym, nheads_sym, x_dim_sym) @@ -325,7 +409,14 @@ def compile( ) cu_seqlens_cute = fake_tensor(cu_seqlens_dtype, (batch_p1_sym,)) if is_varlen else None return cute.compile( - RotaryKernel(dtype, dim, interleaved=interleaved, conjugate=conjugate), + RotaryKernel( + dtype, + dim, + headdim, + interleaved=interleaved, + conjugate=conjugate, + copy_tail=copy_tail, + ), x_cute, cos_cute, sin_cute, @@ -357,8 +448,11 @@ def _launch_rotary( return dtype = torch2cute_dtype_map[x.dtype] cossin_dtype = torch2cute_dtype_map[cos.dtype] + headdim = x.size(-1) dim_half = cos.size(1) dim = dim_half * 2 + copy_tail = out is not x and dim < headdim + compile_headdim = headdim if copy_tail else dim seqlen_offsets_dtype = ( torch2cute_dtype_map[seqlen_offsets.dtype] if seqlen_offsets is not None else None ) @@ -369,8 +463,10 @@ def _launch_rotary( seqlen_offsets_dtype, cu_seqlens_dtype, dim, + compile_headdim, interleaved, conjugate, + copy_tail, )(x, cos, sin, seqlen_offsets, cu_seqlens, out, max_seqlen) @@ -506,8 +602,6 @@ def apply_rotary( cos, sin = _ensure_last_dim_contiguous(cos), _ensure_last_dim_contiguous(sin) out = x if inplace else torch.empty_like(x) - if rotary_dim < headdim and not inplace: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) if inplace: _rotary_fwd_inplace( x, diff --git a/tests/test_rotary.py b/tests/test_rotary.py index fa53f17f..ff2f4f7e 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -328,6 +328,39 @@ def test_rotary_emb_vector_width_selection(headdim, rotary_dim, x_offset, dtype) torch.testing.assert_close(x.grad, x_pt.grad, atol=1e-2, rtol=1e-3) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("seqlen_offsets_type", [None, torch.Tensor]) +@pytest.mark.parametrize("interleaved", [False, True]) +@pytest.mark.parametrize(("headdim", "rotary_dim"), [(256, 128), (512, 256)]) +def test_rotary_emb_large_partial_tail_copy( + headdim, rotary_dim, interleaved, seqlen_offsets_type, dtype +): + torch.manual_seed(42) + device = "cuda" + batch_size, seqlen, nheads = 2, 23, 3 + x = torch.randn( + batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True + ) + x_pt = x.detach().clone().requires_grad_() + cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) + seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) + + out = apply_rotary_emb(x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved) + cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen) + out_pt = apply_rotary_emb_torch( + x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved + ).to(dtype=dtype) + + grad = torch.randn_like(out) + grad_pt = grad.clone() + out.backward(grad) + out_pt.backward(grad_pt) + + assert torch.equal(x, x_pt) + torch.testing.assert_close(out, out_pt, atol=1e-2, rtol=1e-3) + torch.testing.assert_close(x.grad, x_pt.grad, atol=1e-2, rtol=1e-3) + + @pytest.mark.parametrize("x_dtype", [torch.bfloat16, torch.float32]) # @pytest.mark.parametrize("x_dtype", [torch.bfloat16]) @pytest.mark.parametrize("cossin_dtype", [torch.bfloat16, torch.float32])