From fbd9075229db0db9c059c96c3caa0b2ec2fb86f1 Mon Sep 17 00:00:00 2001 From: Subho Ghosh Date: Sat, 9 May 2026 10:47:08 +0000 Subject: [PATCH] Add layernorm backward --- README.md | 2 +- quack/rmsnorm.py | 231 ++++++++++++++++++++++++++++++++++------ tests/test_layernorm.py | 68 +++++++++++- 3 files changed, 267 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 0002e9af..903152ba 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install 'quack-kernels[heuristics]' - 🦆 RMSNorm forward + backward - 🦆 Softmax forward + backward - 🦆 Cross entropy forward + backward -- 🦆 Layernorm forward +- 🦆 Layernorm forward + backward - 🦆 Hopper gemm + epilogue - 🦆 Blackwell gemm + epilogue - 🦆 Blackwell GeForce gemm + epilogue diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py index 08dcc6e0..1145b3f7 100644 --- a/quack/rmsnorm.py +++ b/quack/rmsnorm.py @@ -538,9 +538,11 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6): class RMSNormBackward(ReductionBase): - def __init__(self, dtype: cutlass.Numeric, N: int): - # 2 stages for double buffering when computing mean of x_hat * wdy - super().__init__(dtype, N, stage=2, reduction_dtype=Float32) + def __init__(self, dtype: cutlass.Numeric, N: int, is_layernorm: bool = False): + # LN bwd needs 2 row reductions (mean(wdy), mean(x_hat*wdy)) so + # 2 reduction slots per row, doubled for row pipelining => stage=4. + self.is_layernorm = is_layernorm + super().__init__(dtype, N, stage=4 if is_layernorm else 2, reduction_dtype=Float32) self.reload_wdy = None if N <= 16 * 1024 else "smem" if self.N > 128 * 1024 and self.dtype.width >= 32: # Not enough smem @@ -584,6 +586,7 @@ def __call__( mdO: cute.Tensor, mdResO: Optional[cute.Tensor], mRstd: cute.Tensor, + mMean: Optional[cute.Tensor], mdX: cute.Tensor, mdW: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], @@ -592,6 +595,8 @@ def __call__( stream: cuda.CUstream, ): assert mX.element_type == self.dtype + if const_expr(self.is_layernorm): + assert mMean is not None self._set_cluster_n() largest_dtype_width = const_expr( max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None)) @@ -605,7 +610,19 @@ def __call__( num_blocks = sm_count num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1 self.kernel( - mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row + mX, + mW, + mdO, + mdResO, + mRstd, + mMean, + mdX, + mdW, + mdB, + mdRes, + tiler_mn, + tiled_copy, + threads_per_row, ).launch( grid=[num_blocks, self.cluster_n, num_heads], block=[num_threads, 1, 1], @@ -621,6 +638,7 @@ def kernel( mdO: cute.Tensor, mdResO: Optional[cute.Tensor], mRstd: cute.Tensor, + mMean: Optional[cute.Tensor], mdX: cute.Tensor, mdW: Optional[cute.Tensor], mdB: Optional[cute.Tensor], @@ -642,6 +660,8 @@ def kernel( for mT in (mX, mW, mdO, mdResO, mdX, mdW, mdB, mdRes) ] mRstd = mRstd[None, bidz] + if const_expr(mMean is not None): + mMean = mMean[None, bidz] shape = mX.shape M, N = shape[0], shape[1] @@ -657,7 +677,7 @@ def kernel( smem, tv_layout, is_persistent=True ) if const_expr(mbar_ptr is not None): - mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2 + mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + self.stage else: mbar_full_ptr, mbar_empty_ptr = None, None @@ -749,7 +769,10 @@ def kernel( tXrdW.fill(0.0) if const_expr(mdB is not None): tXrdB.fill(0.0) - stage = Int32(0) + # smem prefetch is always 2-stage; reduction stage advances by + # K_REDUCE per row (1 for RMS, 2 for LN). + red_stage = Int32(0) + smem_stage = Int32(0) producer_phase = Int32(1) consumer_phase = Int32(0) for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): @@ -757,49 +780,61 @@ def kernel( if row + gdim * tiler_mn[0] < M: # Prefetch the next batch copy( tXgX[None, None, None, bidx + gdim], - tXsX[None, None, None, stage ^ 1], + tXsX[None, None, None, smem_stage ^ 1], is_async=True, ) copy( tXgdO[None, None, None, bidx + gdim], - tXsdO[None, None, None, stage ^ 1], + tXsdO[None, None, None, smem_stage ^ 1], is_async=True, ) else: if const_expr(tiler_mn[0] > 1): utils.fill_oob( - tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero + tXsX[None, None, None, smem_stage ^ 1], + None, + fill_value=mX.element_type.zero, ) utils.fill_oob( - tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero + tXsdO[None, None, None, smem_stage ^ 1], + None, + fill_value=mdO.element_type.zero, ) cute.arch.cp_async_commit_group() rstd = cutlass.Float.zero + mean_v = cutlass.Float.zero if row < M or tiler_mn[0] == 1: rstd = mRstd[row] + if const_expr(self.is_layernorm): + mean_v = mMean[row] if const_expr(mdResO is not None): if row < M or tiler_mn[0] == 1: copy(tXgdResO[None, None, None, bidx], tXrdResO) elif tiler_mn[0] > 1: tXrdResO.fill(0.0) cute.arch.cp_async_wait_group(1) - cute.autovec_copy(tXsX[None, None, None, stage], tXrX) + cute.autovec_copy(tXsX[None, None, None, smem_stage], tXrX) x = tXrX.load().to(cute.Float32) - cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + cute.autovec_copy(tXsdO[None, None, None, smem_stage], tXrdO) dout = tXrdO.load().to(cute.Float32) - x_hat = x * rstd + if const_expr(self.is_layernorm): + x_hat = (x - mean_v) * rstd + else: + x_hat = x * rstd wdy = dout if const_expr(mW is not None): wdy *= tXrW.load().to(Float32) + + stage_a = red_stage if const_expr(self.cluster_n > 1): - cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + cute.arch.mbarrier_wait(mbar_empty_ptr + stage_a, producer_phase) mean_xhat_wdy = ( row_reduce( x_hat * wdy, cute.ReductionOp.ADD, threads_per_row, - reduction_buffer[None, None, stage], - (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None), + reduction_buffer[None, None, stage_a], + (mbar_full_ptr + stage_a if const_expr(self.cluster_n > 1) else None), phase=consumer_phase, init_val=0.0, ) @@ -807,25 +842,51 @@ def kernel( ) if const_expr(self.cluster_n > 1): - # Need this fence since the STAS from the producer is using the async proxy. cute.arch.fence_view_async_shared() - # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes - # Requires adjusting the thread_count when initializing the mbar cute.arch.sync_warp() lane_idx = cute.arch.lane_idx() if lane_idx < self.cluster_n: cute.arch.mbarrier_arrive( - mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx + mbar_empty_ptr + stage_a, peer_cta_rank_in_cluster=lane_idx ) + mean_wdy = cutlass.Float.zero + if const_expr(self.is_layernorm): + stage_b = red_stage + 1 + if const_expr(self.cluster_n > 1): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage_b, producer_phase) + mean_wdy = ( + row_reduce( + wdy, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, stage_b], + (mbar_full_ptr + stage_b if const_expr(self.cluster_n > 1) else None), + phase=consumer_phase, + init_val=0.0, + ) + / shape[1] + ) + if const_expr(self.cluster_n > 1): + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + lane_idx = cute.arch.lane_idx() + if lane_idx < self.cluster_n: + cute.arch.mbarrier_arrive( + mbar_empty_ptr + stage_b, peer_cta_rank_in_cluster=lane_idx + ) + if const_expr(self.reload_wdy == "smem"): - cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + cute.autovec_copy(tXsdO[None, None, None, smem_stage], tXrdO) dout = tXrdO.load().to(cute.Float32) wdy = dout if const_expr(mW is not None): wdy *= tXrW.load().to(Float32) - dx = (wdy - x_hat * mean_xhat_wdy) * rstd + if const_expr(self.is_layernorm): + dx = (wdy - mean_wdy - x_hat * mean_xhat_wdy) * rstd + else: + dx = (wdy - x_hat * mean_xhat_wdy) * rstd if const_expr(mdResO is not None): dx += tXrdResO.load().to(cute.Float32) tXrdX.store(dx.to(tXrdX.element_type)) @@ -840,8 +901,12 @@ def kernel( if const_expr(mdB is not None): tXrdB.store(tXrdB.load() + dout) - stage ^= 1 - if stage == 0: + smem_stage ^= 1 + if const_expr(self.is_layernorm): + red_stage = (red_stage + 2) & 3 # 0 -> 2 -> 0 -> 2 + else: + red_stage ^= 1 + if red_stage == 0: consumer_phase ^= 1 producer_phase ^= 1 @@ -896,12 +961,17 @@ def kernel( copy(tXrdB, tXgdB) if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early - # Assume state contains that next useful buffer - # So we only need to advance to num_stages - 1 times to last used buffer - stage ^= 1 - if stage == 0: - producer_phase ^= 1 - cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + if const_expr(self.is_layernorm): + red_stage = (red_stage + 2) & 3 + if red_stage == 0: + producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + red_stage, producer_phase) + cute.arch.mbarrier_wait(mbar_empty_ptr + red_stage + 1, producer_phase) + else: + red_stage ^= 1 + if red_stage == 0: + producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + red_stage, producer_phase) def _get_sm_count(N: int, device: torch.device) -> int: @@ -988,7 +1058,7 @@ def _rmsnorm_bwd( dres_out_dtype, dw_partial is not None, per_head, - )(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count) + )(x, weight, dout, dresidual_out, rstd, None, dx, dw_partial, dresidual, db_partial, sm_count) @_rmsnorm_bwd.register_fake @@ -1042,6 +1112,7 @@ def _compile_rmsnorm_bwd( dres_out_dtype, has_dw_partial, per_head=False, + is_layernorm=False, ): batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int() head_sym = cute.sym_int() if per_head else None @@ -1055,16 +1126,18 @@ def _compile_rmsnorm_bwd( weight_shape = (head_sym, N) if per_head else (N,) weight_cute = fake_tensor(weight_dtype, weight_shape, div) rstd_cute = fake_tensor(Float32, batch_shape) + mean_cute = fake_tensor(Float32, batch_shape) if is_layernorm else None dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N) dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None return cute.compile( - RMSNormBackward(dtype, N), + RMSNormBackward(dtype, N, is_layernorm=is_layernorm), x_cute, weight_cute, dout_cute, dres_out_cute, rstd_cute, + mean_cute, dx_cute, dw_partial_cute, dres_cute, @@ -1123,6 +1196,100 @@ def rmsnorm_bwd( return dx, dw, db, dresidual +@torch.library.custom_op( + "quack::_layernorm_bwd", + mutates_args={"dx", "dw_partial", "db_partial"}, + device_types="cuda", + schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor mean, " + "Tensor(a5!) dx, Tensor(a6!) dw_partial, Tensor(a7!) db_partial, " + "int? sm_count) -> ()", +) +def _layernorm_bwd( + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + mean: Tensor, + dx: Tensor, + dw_partial: Tensor, + db_partial: Tensor, + sm_count: Optional[int] = None, +) -> None: + assert x.dim() in (2, 3), "Input must be 2D or 3D" + supported_types = {torch.float16, torch.bfloat16, torch.float32} + assert x.dtype in supported_types + assert weight.dtype == torch.float32 + per_head = x.dim() == 3 + N = x.size(-1) + if sm_count is None: + sm_count = dw_partial.shape[0] + dtype, dout_dtype, dx_dtype, weight_dtype = [ + torch2cute_dtype_map[t.dtype] for t in [x, dout, dx, weight] + ] + _compile_rmsnorm_bwd( + N, + dtype, + dout_dtype, + dx_dtype, + weight_dtype, + True, + None, + None, + True, + per_head, + is_layernorm=True, + )(x, weight, dout, None, rstd, mean, dx, dw_partial, None, db_partial, sm_count) + + +@_layernorm_bwd.register_fake +def _layernorm_bwd_fake(x, weight, dout, rstd, mean, dx, dw_partial, db_partial, sm_count=None): + return None + + +def layernorm_bwd( + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + mean: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """LayerNorm backward. + + Args: + x: Input tensor of shape (M, N) or (M, H, N). + weight: Weight of shape (N,) or (H, N). Must be float32. + dout: Upstream gradient of same shape as x. + rstd: Reciprocal stddev of shape (M,) or (M, H), float32. + mean: Mean of shape (M,) or (M, H), float32. + + Returns: + (dx, dw, db): same shapes as x, weight, weight respectively; + dw and db are cast to weight.dtype. + """ + assert weight.dtype == torch.float32, "Weight must be float32" + device = x.device + N = x.size(-1) + per_head = x.dim() == 3 + sm_count = _get_sm_count(N, device) + if per_head: + H = x.size(1) + sm_count = max(round(sm_count / H), 1) + partial_shape = (sm_count, H, N) + else: + partial_shape = (sm_count, N) + dx = torch.empty_like(x) + dw_partial = torch.empty(partial_shape, device=device, dtype=torch.float32) + db_partial = torch.empty(partial_shape, device=device, dtype=torch.float32) + if x.numel() > 0: + _layernorm_bwd(x, weight, dout, rstd, mean, dx, dw_partial, db_partial, sm_count) + dw = dw_partial.sum(dim=0).to(weight.dtype) + db = db_partial.sum(dim=0).to(weight.dtype) + else: + dw = torch.zeros_like(weight) + db = torch.zeros_like(weight) + return dx, dw, db + + class RMSNormFunction(torch.autograd.Function): """Autograd wrapper for rmsnorm. diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 91963a31..f678252d 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -3,7 +3,13 @@ import pytest import torch -from quack.rmsnorm import layernorm_fwd, layernorm_ref, layernorm_rstd_ref, layernorm_mean_ref +from quack.rmsnorm import ( + layernorm_bwd, + layernorm_fwd, + layernorm_mean_ref, + layernorm_ref, + layernorm_rstd_ref, +) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @@ -58,6 +64,66 @@ def test_layernorm_forward(M, N, input_dtype, eps): torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("M", [1, 37, 199]) +@pytest.mark.parametrize( + "N", [256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144] +) +def test_layernorm_backward(M, N, input_dtype, eps): + """Test LayerNorm backward against torch autograd.""" + device = "cuda" + # RMSNormBackward rejects N > 128k with >=32-bit dtype (smem limit). + if N > 128 * 1024 and input_dtype == torch.float32: + pytest.skip("RMSNormBackward: N > 128k unsupported for fp32") + major, _ = torch.cuda.get_device_capability() + if major == 12: + smem_n_limit = 131072 if input_dtype == torch.float32 else 262144 + if N > smem_n_limit: + pytest.skip("SM12x: exceeds 99 KB SMEM") + + if input_dtype == torch.bfloat16: + atol, rtol = 2e-2, 2e-2 + elif input_dtype == torch.float16: + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 1e-4, 1e-4 + + torch.random.manual_seed(0) + x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True) + # F.layer_norm requires weight/bias dtype to match input; QuACK requires fp32 weight. + w_ref = torch.randn(N, device=device, dtype=input_dtype, requires_grad=True) + b_ref = torch.randn(N, device=device, dtype=input_dtype, requires_grad=True) + w_f32 = w_ref.detach().to(torch.float32).contiguous() + + y_ref = torch.nn.functional.layer_norm(x, (N,), w_ref, b_ref, eps=eps) + dy = torch.randn_like(y_ref) + y_ref.backward(dy) + dx_ref = x.grad.detach() + dw_ref = w_ref.grad.detach() + db_ref = b_ref.grad.detach() + + mean = x.detach().to(torch.float32).mean(dim=-1) + rstd = (x.detach().to(torch.float32).var(dim=-1, unbiased=False) + eps).rsqrt() + + dx, dw, db = layernorm_bwd(x.detach(), w_f32, dy, rstd, mean) + + assert dx.shape == x.shape and dx.dtype == input_dtype + assert dw.shape == (N,) and dw.dtype == torch.float32 + assert db.shape == (N,) and db.dtype == torch.float32 + + torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol) + # dw/db accumulate across rows so absolute error scales with M; use relative. + rel_dw = (dw.float() - dw_ref.float()).abs().max().item() / max( + dw_ref.float().abs().mean().item(), 1e-6 + ) + rel_db = (db.float() - db_ref.float()).abs().max().item() / max( + db_ref.float().abs().mean().item(), 1e-6 + ) + assert rel_dw < 0.05, f"dw rel err {rel_dw}" + assert rel_db < 0.05, f"db rel err {rel_db}" + + @pytest.mark.parametrize("return_rstd", [True, False]) @pytest.mark.parametrize("return_mean", [True, False]) def test_layernormnorm_return_rstd_option(return_rstd, return_mean):