Add LayerNorm Backward #136

Open. ighoshsubho wants to merge 1 commit into Dao-AILab:main from ighoshsubho:main

ighoshsubho commented May 9, 2026

This PR adds LayerNorm backward to QuACK. Instead of forking a new kernel, I extended the existing RMSNormBackward with an is_layernorm flag, since the two ops share the same memory access pattern and differ only in the math at the row-reduction step.

RMSNorm bwd needs one row reduction (mean(x_hat * wdy)) and writes only dw. LayerNorm needs a second one (mean(wdy)), uses x_hat = (x - mean) * rstd instead of x * rstd, and additionally writes db = sum_rows(dout). The dx formula becomes (wdy - mean(wdy) - x_hat * mean(x_hat * wdy)) * rstd, where wdy = weight * dout. db reuses the same (sm_count, N) fp32 workspace + host-reduce pattern that dw already uses, so we still avoid atomics.

To carry two row reductions per row through the cluster pipeline, the reduction buffer goes from stage=2 to stage=4 (two slots × double-buffered for row pipelining). The rmsnorm path passes mMean=None and keeps the original 2-stage layout, so existing rmsnorm tests are untouched.
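For reviewers who want to check the math independently of the kernel, here is a minimal plain-Python sketch of the LayerNorm backward formulas described above. It is not the QuACK implementation: the function names `layernorm_bwd_reference` and `layernorm_loss` are hypothetical, `wdy` is assumed to mean `weight * dout`, and `db` is accumulated from `dout` directly, which is the bias gradient that autograd produces for `y = x_hat * weight + bias`. The second function only exists so the gradients can be compared against finite differences.

```python
import math


def layernorm_bwd_reference(x, weight, dout, eps=1e-6):
    """Return (dx, dw, db) for y = x_hat * weight + bias, processed row by row."""
    n = len(weight)
    dx = []
    dw = [0.0] * n
    db = [0.0] * n
    for x_row, dy_row in zip(x, dout):
        mean = sum(x_row) / n
        var = sum((v - mean) ** 2 for v in x_row) / n
        rstd = 1.0 / math.sqrt(var + eps)
        x_hat = [(v - mean) * rstd for v in x_row]
        wdy = [w * g for w, g in zip(weight, dy_row)]
        # The two row reductions the LayerNorm path needs.
        mean_wdy = sum(wdy) / n
        mean_xhat_wdy = sum(h * g for h, g in zip(x_hat, wdy)) / n
        dx.append(
            [(g - mean_wdy - h * mean_xhat_wdy) * rstd for g, h in zip(wdy, x_hat)]
        )
        # Column-wise accumulation across rows for the parameter gradients.
        for j in range(n):
            dw[j] += dy_row[j] * x_hat[j]
            db[j] += dy_row[j]
    return dx, dw, db


def layernorm_loss(x, weight, bias, dout, eps=1e-6):
    """Scalar sum(y * dout); its gradients w.r.t. x, weight, bias are dx, dw, db."""
    n = len(weight)
    total = 0.0
    for x_row, dy_row in zip(x, dout):
        mean = sum(x_row) / n
        var = sum((v - mean) ** 2 for v in x_row) / n
        rstd = 1.0 / math.sqrt(var + eps)
        for j in range(n):
            total += ((x_row[j] - mean) * rstd * weight[j] + bias[j]) * dy_row[j]
    return total


# Small made-up inputs, for illustration only.
x = [[0.5, -1.0, 2.0, 0.25], [1.5, 0.0, -0.5, 3.0]]
weight = [1.0, 0.5, -2.0, 1.5]
bias = [0.0, 0.1, -0.1, 0.2]
dout = [[0.1, -0.2, 0.3, 0.4], [-0.5, 0.6, 0.7, -0.8]]
dx, dw, db = layernorm_bwd_reference(x, weight, dout)
```

Because LayerNorm is invariant to adding a constant to a row, each row of `dx` should sum to approximately zero, which makes a cheap sanity check alongside a finite-difference comparison.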

What's Added

  • A public layernorm_bwd(x, weight, dout, rstd, mean) API that mirrors rmsnorm_bwd, returning (dx, dw, db). It wraps a torch.library.custom_op (quack::_layernorm_bwd) and requires fp32 weight (same constraint as the rest of the bwd path).
  • Changed RMSNormBackward to accept is_layernorm, take an optional mMean tensor, and switch on the LN dx formula and second row reduction when set.
  • Separated the reduction-buffer stage from the smem prefetch stage in the kernel so the LN path can run with stage=4 reductions and stage=2 prefetch without OOB writes.
  • Added backward tests in tests/test_layernorm.py next to the forward tests, with the same (M, N, dtype, eps) parametrization (M ∈ {1, 37, 199}, N up to 262144, bf16/fp16/fp32). They validate against F.layer_norm autograd and skip the N > 128k && fp32 combo that RMSNormBackward already rejects on smem grounds.
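The (sm_count, N) workspace + host-reduce pattern that dw and db share can be sketched in plain Python as below. This is an illustration under assumptions, not the kernel's code: `num_workers` stands in for sm_count, the grid-stride assignment of rows to workers is assumed rather than taken from the PR, and the names `partial_sums_per_worker` and `host_reduce` are hypothetical. The point it shows is that each worker writes only its own workspace row, so no atomics are needed, and a final sum over the worker axis yields the length-N gradients.

```python
def partial_sums_per_worker(dout, x_hat, num_workers):
    """Accumulate per-worker partial dw/db into (num_workers, N) workspaces."""
    n = len(dout[0])
    dw_ws = [[0.0] * n for _ in range(num_workers)]
    db_ws = [[0.0] * n for _ in range(num_workers)]
    for worker in range(num_workers):
        # Assumed grid-stride loop: rows worker, worker + num_workers, ...
        for row in range(worker, len(dout), num_workers):
            for j in range(n):
                # Each worker touches only its own workspace row.
                dw_ws[worker][j] += dout[row][j] * x_hat[row][j]
                db_ws[worker][j] += dout[row][j]
    return dw_ws, db_ws


def host_reduce(workspace):
    """Sum the partials over the worker axis, giving a length-N result."""
    return [sum(col) for col in zip(*workspace)]


# Made-up values: 5 rows, N = 3, 2 workers.
dout = [
    [1.0, 2.0, 0.5],
    [0.5, -1.0, 2.0],
    [2.0, 0.0, 1.0],
    [-1.0, 1.5, 0.5],
    [0.5, 0.5, -2.0],
]
x_hat = [
    [0.5, -1.0, 0.5],
    [1.0, 0.0, -1.0],
    [-0.5, 1.0, -0.5],
    [0.0, 1.0, -1.0],
    [1.0, -0.5, -0.5],
]
dw_ws, db_ws = partial_sums_per_worker(dout, x_hat, num_workers=2)
dw = host_reduce(dw_ws)
db = host_reduce(db_ws)
```

The workspace costs num_workers × N extra fp32 values per gradient, in exchange for writes that never contend across workers.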

Benchmark (B200)

| M | N | dtype | quack ms | quack GB/s | torch ms | torch GB/s | speedup |
|---|---|---|---|---|---|---|---|
| 64 | 256 | bfloat16 | 0.0057 | 71 | 0.0792 | 1 | 13.87x |
| 64 | 256 | float16 | 0.0060 | 67 | 0.0807 | 1 | 13.49x |
| 64 | 256 | float32 | 0.0056 | 90 | 0.0779 | 3 | 13.99x |
| 128 | 1024 | bfloat16 | 0.0054 | 371 | 0.0789 | 10 | 14.59x |
| 128 | 1024 | float16 | 0.0059 | 341 | 0.0877 | 9 | 14.92x |
| 128 | 1024 | float32 | 0.0054 | 520 | 0.0790 | 20 | 14.72x |
| 1024 | 4096 | bfloat16 | 0.0101 | 2981 | 0.0843 | 299 | 8.37x |
| 1024 | 4096 | float16 | 0.0104 | 2880 | 0.0832 | 303 | 7.98x |
| 1024 | 4096 | float32 | 0.0125 | 4419 | 0.0867 | 581 | 6.94x |
| 4096 | 8192 | bfloat16 | 0.0533 | 3957 | 0.1058 | 1904 | 1.98x |
| 4096 | 8192 | float16 | 0.0544 | 3878 | 0.1059 | 1902 | 1.94x |
| 4096 | 8192 | float32 | 0.0783 | 5269 | 0.1552 | 2594 | 1.98x |
| 4096 | 16384 | bfloat16 | 0.1549 | 2725 | 0.2062 | 1953 | 1.33x |
| 4096 | 16384 | float16 | 0.1551 | 2722 | 0.2062 | 1954 | 1.33x |
| 4096 | 16384 | float32 | 0.1892 | 4360 | 0.3257 | 2473 | 1.72x |
| 8192 | 16384 | bfloat16 | 0.2992 | 2757 | 0.3927 | 2051 | 1.31x |
| 8192 | 16384 | float16 | 0.2988 | 2760 | 0.3917 | 2056 | 1.31x |
| 8192 | 16384 | float32 | 0.3648 | 4468 | 0.5986 | 2691 | 1.64x |
| 32768 | 4096 | bfloat16 | 0.2562 | 3163 | 0.4761 | 1692 | 1.86x |
| 32768 | 4096 | float16 | 0.2605 | 3111 | 0.4715 | 1708 | 1.81x |
| 32768 | 4096 | float32 | 0.3912 | 4130 | 0.6431 | 2505 | 1.64x |
| 32768 | 8192 | bfloat16 | 0.3741 | 4332 | 0.6942 | 2320 | 1.86x |
| 32768 | 8192 | float16 | 0.3820 | 4242 | 0.6917 | 2329 | 1.81x |
| 32768 | 8192 | float32 | 0.5600 | 5770 | 1.1174 | 2883 | 2.00x |

cc: @tridao @GarlGuo @JackCharlesZhang would appreciate your eyes on this when you get a chance
