From f128bf23f6a918870f8eb8b4f85d6ed311b73d1b Mon Sep 17 00:00:00 2001 From: Santosh Mohan Date: Sat, 18 Apr 2026 19:00:37 -0700 Subject: [PATCH] Deterministic fused cross-CTA dW reduction in RMSNorm backward Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by fusing a deterministic last-CTA-reduces pattern into the backward kernel. Each CTA writes its partial to dw_partial[bidx, :] as before, then does a threadfence + atomic increment of a global counter. The last CTA to arrive loads all partials in fixed order 0..sm_count-1 and accumulates into dw_final, ensuring deterministic results across runs. Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back to the existing host-side .sum(dim=0) reduction. Based on the approach discussed in Dao-AILab/quack#101. Co-authored-by: Aaron Wang --- quack/rmsnorm.py | 108 ++++++++++++++++++++++++++++++++++++++++++++--- quack/utils.py | 12 ++++++ 2 files changed, 115 insertions(+), 5 deletions(-) diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index bf09b8c9..c45a9957 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -584,8 +584,10 @@ def __call__( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], mdB: Optional[cute.Tensor], + reduce_counter: Optional[cute.Tensor], sm_count: Int32, stream: cuda.CUstream, ): @@ -603,7 +605,8 @@ def __call__( num_blocks = sm_count num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 self.kernel( - mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdW_final, mdB, mdRes, + reduce_counter, tiler_mn, tiled_copy, threads_per_row ).launch( grid=[num_blocks, self.cluster_n, num_heads], block=[num_threads, 1, 1], @@ -621,8 +624,10 @@ def kernel( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdB: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], + reduce_counter: Optional[cute.Tensor], tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy, threads_per_row: cutlass.Constexpr[int], @@ -893,6 +898,82 @@ def kernel( if const_expr(mdB is not None): copy(tXrdB, tXgdB) + + # Cross-CTA deterministic dW reduction (last-CTA-reduces pattern). + # Each CTA has already written its partial to dw_partial[bidx_start, :]. + # After a threadfence + atomic counter, the last CTA to arrive loads + # all partials in fixed order and accumulates into dw_final. + if const_expr(mdW_final is not None and self.cluster_n == 1): + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + utils.threadfence() + + smem_is_last = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=Int32), + cute.make_layout((1,)), + ) + if tidx == 0: + old = utils.atomic_add_i32(Int32(1), reduce_counter.iterator) + smem_is_last[0] = old + cute.arch.barrier() + + if smem_is_last[0] == gdim - Int32(1): + sdW_buf = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_layout((tiler_mn[1],)), + ) + num_thr = cute.size(tiled_copy) + vecsize_f32 = const_expr(min(tiler_mn[1], 128 // cute.Float32.width)) + thr_copy_dw = copy_utils.tiled_copy_1d( + cute.Float32, num_thr, vecsize_f32, is_async=True, + ) + thr_dw = thr_copy_dw.get_slice(tidx) + + gdW_all = cute.make_tensor(mdW.iterator, mdW.layout) + gdW_final_1d = cute.make_tensor( + mdW_final.iterator, cute.make_layout((tiler_mn[1],)) + ) + + # Accumulate first row directly into smem + gdW_row0 = cute.make_tensor( + gdW_all.iterator, cute.make_layout((tiler_mn[1],)) + ) + copy_utils.copy( + thr_dw.partition_S(gdW_row0), thr_dw.partition_D(sdW_buf), + is_async=True, + ) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Load remaining rows and accumulate in smem + for i in range(1, gdim): + sdW_tmp = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32) + tiler_mn[1], + cute.make_layout((tiler_mn[1],)), + ) + gdW_row_i = cute.make_tensor( + gdW_all.iterator + i * tiler_mn[1], + cute.make_layout((tiler_mn[1],)), + ) + copy_utils.copy( + thr_dw.partition_S(gdW_row_i), thr_dw.partition_D(sdW_tmp), + is_async=True, + ) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + thr_buf = thr_dw.partition_D(sdW_buf) + thr_tmp = thr_dw.partition_S(sdW_tmp) + for j in range(cute.size(thr_buf)): + thr_buf[j] = thr_buf[j] + thr_tmp[j] + cute.arch.barrier() + + # Store final accumulated result to dw_final + copy_utils.copy( + thr_dw.partition_S(sdW_buf), thr_dw.partition_D(gdW_final_1d), + ) + if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early # Assume state contains that next useful buffer # So we only need to advance to num_stages - 1 times to last used buffer @@ -921,10 +1002,10 @@ def _get_sm_count(N: int, device: torch.device) -> int: @torch.library.custom_op( "quack::_rmsnorm_bwd", - mutates_args={"dx", "dw_partial", "db_partial", "dresidual"}, + mutates_args={"dx", "dw_partial", "db_partial", "dresidual", "dw", "reduce_counter"}, device_types="cuda", # We need to specify the schema manually since we're mutating an optional tensor - schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()", + schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count, Tensor(a10!)? dw, Tensor(a11!)? reduce_counter) -> ()", ) def _rmsnorm_bwd( x: Tensor, @@ -937,6 +1018,8 @@ def _rmsnorm_bwd( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + reduce_counter: Optional[Tensor] = None, ) -> None: """RMSNorm backward pass. Args: @@ -1003,6 +1086,8 @@ def _rmsnorm_bwd_fake( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + reduce_counter: Optional[Tensor] = None, ) -> None: # See softmax.py _softmax_fwd_fake for why register_fake is needed. from quack.cache_utils import COMPILE_ONLY @@ -1016,6 +1101,7 @@ def _rmsnorm_bwd_fake( torch2cute_dtype_map[t.dtype] if t is not None else None for t in [x, dout, dx, weight, dresidual, dresidual_out] ] + dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None _compile_rmsnorm_bwd( N, dtype, @@ -1027,6 +1113,7 @@ def _rmsnorm_bwd_fake( dres_out_dtype, dw_partial is not None, per_head, + dw_dtype, ) @@ -1042,6 +1129,7 @@ def _compile_rmsnorm_bwd( dres_out_dtype, has_dw_partial, per_head=False, + dw_dtype=None, ): batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int() head_sym = cute.sym_int() if per_head else None @@ -1058,6 +1146,8 @@ def _compile_rmsnorm_bwd( dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N) dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None + dw_cute = fake_tensor(dw_dtype, (N,), div) if dw_dtype is not None else None + reduce_counter_cute = fake_tensor(Int32, (1,)) if dw_dtype is not None else None return cute.compile( RMSNormBackward(dtype, N), x_cute, @@ -1067,8 +1157,10 @@ def _compile_rmsnorm_bwd( rstd_cute, dx_cute, dw_partial_cute, + dw_cute, dres_cute, db_partial_cute, + reduce_counter_cute, 0, # sm_count, just for compilation cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", @@ -1098,12 +1190,18 @@ def rmsnorm_bwd( sm_count = max(round(sm_count / H), 1) else: H = None + dw_partial: Optional[Tensor] = None + dw_final: Optional[Tensor] = None + reduce_counter: Optional[Tensor] = None + # Fused cross-CTA dW reduction (last-CTA-reduces) for cluster_n == 1 and non-per-head. + use_fused_dw_reduce = N <= 8192 and weight is not None and not per_head if weight is not None: # Always store partial gradients in fp32 for numerical accuracy dw_shape = (sm_count, H, N) if per_head else (sm_count, N) dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32) - else: - dw_partial = None + if use_fused_dw_reduce: + dw_final = torch.empty(N, device=device, dtype=torch.float32) + reduce_counter = torch.zeros(1, device=device, dtype=torch.int32) db_shape = (sm_count, H, N) if per_head else (sm_count, N) db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None diff --git a/quack/utils.py b/quack/utils.py index 7039d8ae..a80b375f 100644 --- a/quack/utils.py +++ b/quack/utils.py @@ -275,6 +275,18 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) @dsl_user_op + + +@dsl_user_op +def threadfence(*, loc=None, ip=None) -> None: + llvm.inline_asm( + None, + [], + "membar.gl;", + "", + has_side_effects=True, + is_align_stack=False, + ) def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: from cutlass import CUDA_VERSION