Add LayerNorm Backward#136
Open
ighoshsubho wants to merge 1 commit into
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR adds LayerNorm backward to QuACK. Instead of forking a new kernel, I extended the existing
RMSNormBackwardwith anis_layernormflag, since the two ops share the same memory pattern and only differ in the math at the row-reduction step. RMSNorm bwd needs one row reduction (mean(x_hat * wdy)) and writes onlydW.LayerNorm needs a second one (
mean(wdy)), usesx_hat = (x - mean) * rstdinstead ofx * rstd, and additionally writesdb = sum_rows(wdy). The dx formula becomes(wdy - mean(wdy) - x_hat * mean(x_hat * wdy)) * rstd, anddbreuses the same(sm_count, N)fp32 workspace + host-reduce pattern thatdwalready uses, so we still avoid atomics. To carry two row reductions per row through the cluster pipeline, the reduction buffer goes fromstage=2tostage=4(two slots × double-buffered for row pipelining). The rmsnorm path passesmMean=Noneand keeps the original 2-stage layout, so existing rmsnorm tests are untouched.What's Added
layernorm_bwd(x, weight, dout, rstd, mean)API that mirrorsrmsnorm_bwd, returning(dx, dw, db). It wraps atorch.library.custom_op(quack::_layernorm_bwd) and requires fp32 weight (same constraint as the rest of the bwd path).RMSNormBackwardto acceptis_layernorm, take an optionalmMeantensor, and switch on the LN dx formula and second row reduction when set.stage=4reductions andstage=2prefetch without OOB writes.tests/test_layernorm.pynext to the forward tests, with the same(M, N, dtype, eps)parametrization (M ∈ {1, 37, 199}, N up to 262144, bf16/fp16/fp32). They validate againstF.layer_normautograd and skip theN > 128k && fp32combo thatRMSNormBackwardalready rejects on smem grounds.Benchmark (B200)
cc: @tridao @GarlGuo @JackCharlesZhang would appreciate your eyes on this when you get a chance