From 6ac09f506cc3caa1d8944eee562c3c1127703d4f Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 21 Apr 2026 13:54:14 -0700 Subject: [PATCH 1/4] test --- quack/rmsnorm.py | 159 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 137 insertions(+), 22 deletions(-) diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index bf09b8c9..9d44f394 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -2,7 +2,7 @@ import math from typing import Optional, Tuple, Type -from functools import partial +from functools import cache, partial import cuda.bindings.driver as cuda @@ -539,7 +539,7 @@ class RMSNormBackward(ReductionBase): def __init__(self, dtype: cutlass.Numeric, N: int): # 2 stages for double buffering when computing mean of x_hat * wdy super().__init__(dtype, N, stage=2, reduction_dtype=Float32) - self.reload_wdy = None if N <= 16 * 1024 else "smem" + self.reload_wdy = "smem" if self.N > 128 * 1024 and self.dtype.width >= 32: # Not enough smem raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits") @@ -584,9 +584,12 @@ def __call__( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], mdB: Optional[cute.Tensor], + mSemaphore: Optional[cute.Tensor], sm_count: Int32, + group_size: Int32, stream: cuda.CUstream, ): assert mX.element_type == self.dtype @@ -600,10 +603,13 @@ def __call__( mW = ( layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None ) + if const_expr(mdW_final is not None): + mdW_final = layout_utils.expand(mdW_final, dim=0, size=1) num_blocks = sm_count num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 self.kernel( - mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdW_final, mdB, mdRes, mSemaphore, + group_size, tiler_mn, tiled_copy, threads_per_row, ).launch( grid=[num_blocks, self.cluster_n, num_heads], block=[num_threads, 1, 1], @@ -621,8 +627,11 @@ def kernel( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdB: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], + mSemaphore: Optional[cute.Tensor], + group_size: Int32, tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy, threads_per_row: cutlass.Constexpr[int], @@ -893,6 +902,81 @@ def kernel( if const_expr(mdB is not None): copy(tXrdB, tXgdB) + # Two-level grouped reduction: reduce dw_partial across CTAs into mdW_final. + # Level 1: each group of group_size CTAs reduces to dw_partial[group_leader]. + # Level 2: the last group-reducer reduces G group sums into mdW_final. + # Only supported for cluster_n == 1; for cluster_n > 1 the caller + # must reduce dw_partial on the host. + if const_expr(mdW_final is not None and self.cluster_n == 1): + cute.arch.fence_acq_rel_gpu() + + my_group = Int32(bidx_start / group_size) + group_base = Int32(my_group * group_size) + group_count = Int32(group_size) + if group_base + group_size > gdim: + group_count = Int32(gdim - group_base) + num_groups = Int32((gdim + group_size - Int32(1)) / group_size) + + sFlag = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=Int32), + cute.make_layout((1,)), + ) + + gdW_all = cute.local_tile(mdW, (1, tiler_mn[1]), (None, cluster_y)) + tXgdW_all = thr_copy_X.partition_S(gdW_all) + tXrdW_accum = cute.make_fragment_like( + tXgdW_all[None, None, None, 0], Float32, + ) + tXrdW_row = cute.make_fragment_like(tXgdW_all[None, None, None, 0]) + + # --- Level 1: intra-group reduction --- + if tidx == 0: + is_last_in_group = Int32(0) + old = utils.atomic_add_i32(Int32(1), mSemaphore.iterator + my_group) + if old == group_count - Int32(1): + is_last_in_group = Int32(1) + sFlag[0] = is_last_in_group + cute.arch.barrier() + + if sFlag[0]: + cute.arch.fence_acq_rel_gpu() + tXrdW_accum.fill(0.0) + for i in cutlass.range(group_base, group_base + group_count): + copy(tXgdW_all[None, None, None, i], tXrdW_row) + tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load()) + gdW_leader = cute.local_tile( + mdW, (1, tiler_mn[1]), (group_base, cluster_y), + ) + tXgdW_leader = thr_copy_X.partition_D(gdW_leader) + copy(tXrdW_accum, tXgdW_leader) + + # --- Level 2: cross-group reduction --- + cute.arch.fence_acq_rel_gpu() + if tidx == 0: + is_last_group = Int32(0) + old = utils.atomic_add_i32( + Int32(1), mSemaphore.iterator + num_groups, + ) + if old == num_groups - Int32(1): + is_last_group = Int32(1) + sFlag[0] = is_last_group + cute.arch.barrier() + + if sFlag[0]: + cute.arch.fence_acq_rel_gpu() + tXrdW_accum.fill(0.0) + for g in cutlass.range(0, num_groups): + leader_row = g * group_size + copy(tXgdW_all[None, None, None, leader_row], tXrdW_row) + tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load()) + gdW_final = cute.local_tile( + mdW_final, (1, tiler_mn[1]), (0, cluster_y), + ) + tXgdW_final = thr_copy_X.partition_D(gdW_final) + tXrdW_out = cute.make_fragment_like(tXgdW_final) + tXrdW_out.store(tXrdW_accum.load().to(tXrdW_out.element_type)) + copy(tXrdW_out, tXgdW_final) + if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early # Assume state contains that next useful buffer # So we only need to advance to num_stages - 1 times to last used buffer @@ -919,12 +1003,20 @@ def _get_sm_count(N: int, device: torch.device) -> int: return sm_count +@cache +def _get_semaphore(device: torch.device) -> torch.Tensor: + """Reuse same semaphore to avoid repeated torch.zero calls. + num_groups + 1 slots needed; ceil(sqrt(max_sm_count)) + 1 fits in 64 for any current GPU. + """ + return torch.zeros(64, device=device, dtype=torch.int32) + + @torch.library.custom_op( "quack::_rmsnorm_bwd", - mutates_args={"dx", "dw_partial", "db_partial", "dresidual"}, + mutates_args={"dx", "dw_partial", "db_partial", "dresidual", "dw"}, device_types="cuda", # We need to specify the schema manually since we're mutating an optional tensor - schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()", + schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count, Tensor(a10!)? dw, Tensor? semaphore, int? group_size) -> ()", ) def _rmsnorm_bwd( x: Tensor, @@ -937,18 +1029,11 @@ def _rmsnorm_bwd( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + semaphore: Optional[Tensor] = None, + group_size: Optional[int] = None, ) -> None: - """RMSNorm backward pass. - Args: - x: Input tensor of shape (M, N) or (M, H, N) for per-head - weight: Optional weight tensor of shape (N,) or (H, N) for per-head - dout: Upstream gradients tensor of shape (M, N) or (M, H, N) - rstd: Reciprocal standard deviation tensor of shape (M,) or (M, H) - Returns: - Tuple of (dx, dw) where: - - dx: Input gradients tensor of same shape as x - - dw: Weight gradients tensor of same shape as weight (or None if weight is None) - """ + """RMSNorm backward pass (mutates dx, dw_partial, dw in-place).""" assert x.dim() in (2, 3), "Input must be 2D or 3D" assert x.is_cuda, "Input tensor must be on CUDA device" supported_types = {torch.float16, torch.bfloat16, torch.float32} @@ -977,6 +1062,7 @@ def _rmsnorm_bwd( torch2cute_dtype_map[t.dtype] if t is not None else None for t in [x, dout, dx, weight, dresidual, dresidual_out] ] + dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None _compile_rmsnorm_bwd( N, dtype, @@ -988,7 +1074,12 @@ def _rmsnorm_bwd( dres_out_dtype, dw_partial is not None, per_head, - )(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count) + dw_dtype, + )( + x, weight, dout, dresidual_out, rstd, dx, dw_partial, dw, + dresidual, db_partial, semaphore, sm_count, + group_size if group_size is not None else 0, + ) @_rmsnorm_bwd.register_fake @@ -1003,6 +1094,9 @@ def _rmsnorm_bwd_fake( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + semaphore: Optional[Tensor] = None, + group_size: Optional[int] = None, ) -> None: # See softmax.py _softmax_fwd_fake for why register_fake is needed. from quack.cache_utils import COMPILE_ONLY @@ -1016,6 +1110,7 @@ def _rmsnorm_bwd_fake( torch2cute_dtype_map[t.dtype] if t is not None else None for t in [x, dout, dx, weight, dresidual, dresidual_out] ] + dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None _compile_rmsnorm_bwd( N, dtype, @@ -1027,6 +1122,7 @@ def _rmsnorm_bwd_fake( dres_out_dtype, dw_partial is not None, per_head, + dw_dtype, ) @@ -1042,6 +1138,7 @@ def _compile_rmsnorm_bwd( dres_out_dtype, has_dw_partial, per_head=False, + dw_dtype=None, ): batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int() head_sym = cute.sym_int() if per_head else None @@ -1058,6 +1155,8 @@ def _compile_rmsnorm_bwd( dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N) dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None + dw_cute = fake_tensor(dw_dtype, (N,), div) if dw_dtype is not None else None + semaphore_cute = fake_tensor(Int32, (64,)) if dw_dtype is not None else None return cute.compile( RMSNormBackward(dtype, N), x_cute, @@ -1067,9 +1166,12 @@ def _compile_rmsnorm_bwd( rstd_cute, dx_cute, dw_partial_cute, + dw_cute, dres_cute, db_partial_cute, + semaphore_cute, 0, # sm_count, just for compilation + 0, # group_size, just for compilation cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -1098,21 +1200,34 @@ def rmsnorm_bwd( sm_count = max(round(sm_count / H), 1) else: H = None + dw_partial: Optional[Tensor] = None + dw: Optional[Tensor] = None + semaphore: Optional[Tensor] = None + group_size: Optional[int] = None + # In-kernel cross-CTA dw reduction via two-level tree. Only supported for + # cluster_n == 1 (N <= 8192) and non-per-head. For larger N or per-head the + # kernel ignores dw/semaphore and we fall back to host-side dw_partial.sum(). + use_in_kernel_dw_reduction = N <= 8192 and weight is not None and not per_head if weight is not None: # Always store partial gradients in fp32 for numerical accuracy dw_shape = (sm_count, H, N) if per_head else (sm_count, N) dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32) - else: - dw_partial = None + if use_in_kernel_dw_reduction: + dw = torch.empty(N, device=device, dtype=weight.dtype) + semaphore = _get_semaphore(device) + semaphore.zero_() + G = math.ceil(math.sqrt(sm_count)) + group_size = math.ceil(sm_count / G) db_shape = (sm_count, H, N) if per_head else (sm_count, N) db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None _rmsnorm_bwd( - x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count + x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, + sm_count, dw, semaphore, group_size, ) - # we have summed the partial gradients in fp32, now we convert back to the weight dtype - dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None + if weight is not None and not use_in_kernel_dw_reduction: + dw = dw_partial.sum(dim=0).to(weight.dtype) db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None # dresidual is the same as dx in this case if has_residual and dresidual is None: From 78cb5124cf0e79c358cd0320cbd9b9977aa57431 Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 21 Apr 2026 13:56:50 -0700 Subject: [PATCH 2/4] ops --- quack/rmsnorm.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index 9d44f394..62e2daa3 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -1033,7 +1033,17 @@ def _rmsnorm_bwd( semaphore: Optional[Tensor] = None, group_size: Optional[int] = None, ) -> None: - """RMSNorm backward pass (mutates dx, dw_partial, dw in-place).""" + """RMSNorm backward pass. + Args: + x: Input tensor of shape (M, N) or (M, H, N) for per-head + weight: Optional weight tensor of shape (N,) or (H, N) for per-head + dout: Upstream gradients tensor of shape (M, N) or (M, H, N) + rstd: Reciprocal standard deviation tensor of shape (M,) or (M, H) + Returns: + Tuple of (dx, dw) where: + - dx: Input gradients tensor of same shape as x + - dw: Weight gradients tensor of same shape as weight (or None if weight is None) + """ assert x.dim() in (2, 3), "Input must be 2D or 3D" assert x.is_cuda, "Input tensor must be on CUDA device" supported_types = {torch.float16, torch.bfloat16, torch.float32} From 21aca9a8f7572fd4eeac79ac9c7969daee70eb2f Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 21 Apr 2026 13:57:37 -0700 Subject: [PATCH 3/4] format --- quack/rmsnorm.py | 60 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index 62e2daa3..b2a87bf8 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -608,8 +608,21 @@ def __call__( num_blocks = sm_count num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 self.kernel( - mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdW_final, mdB, mdRes, mSemaphore, - group_size, tiler_mn, tiled_copy, threads_per_row, + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdW_final, + mdB, + mdRes, + mSemaphore, + group_size, + tiler_mn, + tiled_copy, + threads_per_row, ).launch( grid=[num_blocks, self.cluster_n, num_heads], block=[num_threads, 1, 1], @@ -925,7 +938,8 @@ def kernel( gdW_all = cute.local_tile(mdW, (1, tiler_mn[1]), (None, cluster_y)) tXgdW_all = thr_copy_X.partition_S(gdW_all) tXrdW_accum = cute.make_fragment_like( - tXgdW_all[None, None, None, 0], Float32, + tXgdW_all[None, None, None, 0], + Float32, ) tXrdW_row = cute.make_fragment_like(tXgdW_all[None, None, None, 0]) @@ -945,7 +959,9 @@ def kernel( copy(tXgdW_all[None, None, None, i], tXrdW_row) tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load()) gdW_leader = cute.local_tile( - mdW, (1, tiler_mn[1]), (group_base, cluster_y), + mdW, + (1, tiler_mn[1]), + (group_base, cluster_y), ) tXgdW_leader = thr_copy_X.partition_D(gdW_leader) copy(tXrdW_accum, tXgdW_leader) @@ -955,7 +971,8 @@ def kernel( if tidx == 0: is_last_group = Int32(0) old = utils.atomic_add_i32( - Int32(1), mSemaphore.iterator + num_groups, + Int32(1), + mSemaphore.iterator + num_groups, ) if old == num_groups - Int32(1): is_last_group = Int32(1) @@ -970,7 +987,9 @@ def kernel( copy(tXgdW_all[None, None, None, leader_row], tXrdW_row) tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load()) gdW_final = cute.local_tile( - mdW_final, (1, tiler_mn[1]), (0, cluster_y), + mdW_final, + (1, tiler_mn[1]), + (0, cluster_y), ) tXgdW_final = thr_copy_X.partition_D(gdW_final) tXrdW_out = cute.make_fragment_like(tXgdW_final) @@ -1086,8 +1105,18 @@ def _rmsnorm_bwd( per_head, dw_dtype, )( - x, weight, dout, dresidual_out, rstd, dx, dw_partial, dw, - dresidual, db_partial, semaphore, sm_count, + x, + weight, + dout, + dresidual_out, + rstd, + dx, + dw_partial, + dw, + dresidual, + db_partial, + semaphore, + sm_count, group_size if group_size is not None else 0, ) @@ -1232,8 +1261,19 @@ def rmsnorm_bwd( db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None _rmsnorm_bwd( - x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, - sm_count, dw, semaphore, group_size, + x, + weight, + dout, + rstd, + dx, + dw_partial, + db_partial, + dresidual_out, + dresidual, + sm_count, + dw, + semaphore, + group_size, ) if weight is not None and not use_in_kernel_dw_reduction: From 7ba98f49a4e1dba5c39a9a619a15108637bc266d Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 21 Apr 2026 14:25:09 -0700 Subject: [PATCH 4/4] condition --- quack/rmsnorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index b2a87bf8..8869bed1 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -1246,7 +1246,7 @@ def rmsnorm_bwd( # In-kernel cross-CTA dw reduction via two-level tree. Only supported for # cluster_n == 1 (N <= 8192) and non-per-head. For larger N or per-head the # kernel ignores dw/semaphore and we fall back to host-side dw_partial.sum(). - use_in_kernel_dw_reduction = N <= 8192 and weight is not None and not per_head + use_in_kernel_dw_reduction = N <= 2048 and weight is not None and not per_head if weight is not None: # Always store partial gradients in fp32 for numerical accuracy dw_shape = (sm_count, H, N) if per_head else (sm_count, N)