Skip to content

Rmsnorm backward fusing sum#101

Open
AaronWang04 wants to merge 4 commits into
Dao-AILab:mainfrom
AaronWang04:rmsnorm_bwd_bulk_reduce
Open

Rmsnorm backward fusing sum#101
AaronWang04 wants to merge 4 commits into
Dao-AILab:mainfrom
AaronWang04:rmsnorm_bwd_bulk_reduce

Conversation

@AaronWang04
Copy link
Copy Markdown

@AaronWang04 AaronWang04 commented Apr 8, 2026

Eliminates the extra .sum(dim=0) kernel for dw_partial reduction in the RMSNorm backward pass. Each CTA writes its partial to dw_partial as before, loads it back into smem contiguously, then thread 0 issues a bulk async reduce-add to a dw tensor. Only for N ≤ 8192 (no intra-cta reduction)

Benchmark sweeped between M=4k to 65k and N=256 to 8k, no performance regressions in any case and on average 1.2x speedup on gb200.

Rough results with triton.testing.do_bench

UPDATE:
Added determinism via a semaphore with a 2 stage reduction

new results (note I disabled this fusion for N>2048)

GPU: NVIDIA GB200
==========================================================================================
dtype: torch.bfloat16
==========================================================================================
M=  4096, N=  256 | fused=0.0294ms  separate=0.0415ms  speedup=1.41x  FUSED
M=  8192, N=  256 | fused=0.0334ms  separate=0.0412ms  speedup=1.24x  FUSED
M= 16384, N=  256 | fused=0.0330ms  separate=0.0391ms  speedup=1.19x  FUSED
M= 32768, N=  256 | fused=0.0376ms  separate=0.0408ms  speedup=1.09x  FUSED
M= 65536, N=  256 | fused=0.0471ms  separate=0.0450ms  speedup=0.95x  SEPARATE
M=  4096, N=  512 | fused=0.0260ms  separate=0.0416ms  speedup=1.60x  FUSED
M=  8192, N=  512 | fused=0.0291ms  separate=0.0411ms  speedup=1.41x  FUSED
M= 16384, N=  512 | fused=0.0346ms  separate=0.0419ms  speedup=1.21x  FUSED
M= 32768, N=  512 | fused=0.0468ms  separate=0.0472ms  speedup=1.01x  FUSED
M= 65536, N=  512 | fused=0.0714ms  separate=0.0719ms  speedup=1.01x  FUSED
M=  4096, N= 1024 | fused=0.0307ms  separate=0.0432ms  speedup=1.40x  FUSED
M=  8192, N= 1024 | fused=0.0354ms  separate=0.0409ms  speedup=1.15x  FUSED
M= 16384, N= 1024 | fused=0.0471ms  separate=0.0456ms  speedup=0.97x  SEPARATE
M= 32768, N= 1024 | fused=0.0714ms  separate=0.0702ms  speedup=0.98x  SEPARATE
M= 65536, N= 1024 | fused=0.1187ms  separate=0.1175ms  speedup=0.99x  SEPARATE
M=  4096, N= 2048 | fused=0.0369ms  separate=0.0395ms  speedup=1.07x  FUSED
M=  8192, N= 2048 | fused=0.0477ms  separate=0.0490ms  speedup=1.03x  FUSED
M= 16384, N= 2048 | fused=0.0695ms  separate=0.0708ms  speedup=1.02x  FUSED
M= 32768, N= 2048 | fused=0.1112ms  separate=0.1133ms  speedup=1.02x  FUSED
M= 65536, N= 2048 | fused=0.1944ms  separate=0.1975ms  speedup=1.02x  FUSED
M=  4096, N= 3072 | fused=0.0453ms  separate=0.0406ms  speedup=0.90x  SEPARATE
M=  8192, N= 3072 | fused=0.0676ms  separate=0.0613ms  speedup=0.91x  SEPARATE
M= 16384, N= 3072 | fused=0.1105ms  separate=0.1013ms  speedup=0.92x  SEPARATE
M= 32768, N= 3072 | fused=0.1965ms  separate=0.1825ms  speedup=0.93x  SEPARATE
M= 65536, N= 3072 | fused=0.3677ms  separate=0.3443ms  speedup=0.94x  SEPARATE
M=  4096, N= 4096 | fused=0.0536ms  separate=0.0469ms  speedup=0.87x  SEPARATE
M=  8192, N= 4096 | fused=0.0775ms  separate=0.0637ms  speedup=0.82x  SEPARATE
M= 16384, N= 4096 | fused=0.1244ms  separate=0.1064ms  speedup=0.85x  SEPARATE
M= 32768, N= 4096 | fused=0.2196ms  separate=0.1912ms  speedup=0.87x  SEPARATE
M= 65536, N= 4096 | fused=0.4088ms  separate=0.3614ms  speedup=0.88x  SEPARATE

Comment thread quack/rmsnorm.py Outdated
Comment thread quack/rmsnorm.py Outdated
Comment on lines +861 to +863
copy_utils.cpasync_reduce_bulk_add_f32(
sdW_buf.iterator, mdW_final.iterator, store_bytes,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without a semaphore guarding this, this will result in non-deterministic reductions

Copy link
Copy Markdown
Author

@AaronWang04 AaronWang04 Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but if we add a semaphore I don't think this fusion becomes worth it.

should I add a bool flag for determinism which forces the unfused .sum() path instead?

fwiw i think the performance gains are worth it (I still haven't tuned sm_count yet)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think a backward kernel without deterministic reduction is worth adding tbh. I do not see anyone using it.

@tridao
Copy link
Copy Markdown
Member

tridao commented Apr 13, 2026

We could try w semaphore to see if there's still any perf win.

@tridao
Copy link
Copy Markdown
Member

tridao commented Apr 13, 2026

i don't think you need to load dW from gmem -> smem. Just do cp.async.bulk.reduce from smem -> gmem. With a semaphore this would be deterministic.

santoshmo pushed a commit to santoshmo/quack that referenced this pull request Apr 19, 2026
Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by
fusing a deterministic last-CTA-reduces pattern into the backward kernel.

Each CTA writes its partial to dw_partial[bidx, :] as before, then does
a threadfence + atomic increment of a global counter. The last CTA to
arrive loads all partials in fixed order 0..sm_count-1 and accumulates
into dw_final, ensuring deterministic results across runs.

Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back
to the existing host-side .sum(dim=0) reduction.

Based on the approach discussed in Dao-AILab#101.

Co-authored-by: Aaron Wang <aaronwang04@users.noreply.github.com>
santoshmo pushed a commit to santoshmo/quack that referenced this pull request Apr 19, 2026
Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by
fusing a deterministic last-CTA-reduces pattern into the backward kernel.

Each CTA writes its partial to dw_partial[bidx, :] as before, then does
a threadfence + atomic increment of a global counter. The last CTA to
arrive loads all partials in fixed order 0..sm_count-1 and accumulates
into dw_final, ensuring deterministic results across runs.

Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back
to the existing host-side .sum(dim=0) reduction.

Based on the approach discussed in Dao-AILab#101.

Co-authored-by: Aaron Wang <aaronwang04@users.noreply.github.com>
@AaronWang04
Copy link
Copy Markdown
Author

Some updates:
I messed around with adding a semaphore and got varied perf wins/regressions

it seems like it bumped to registers from 64 to 66 which killed occupancy for certain shapes (the ones that were ok had decent perf wins)

I tried passing in maxrregcount which either wasn't being properly passed down to ptx or the compiler isn't respecting the argument

@Pranshu-Bahadur
Copy link
Copy Markdown

Some updates:
I messed around with adding a semaphore and got varied perf wins/regressions

it seems like it bumped to registers from 64 to 66 which killed occupancy for certain shapes (the ones that were ok had decent perf wins)

I tried passing in maxrregcount which either wasn't being properly passed down to ptx or the compiler isn't respecting the argument

Sorry, I'm super new and still learning, but if it's 2 regs, couldn't we reduce size of constants like producer* = int32(0) usage for instance to producer* = int8(0)? Or is the backend int32 (if so there is probably other places where you could squeeze registers too maybe) -- I'm used to C++ CUDA so idk. Also you could tweak cluster sizes too right (worst case)? Thank you, super humbling to learn from your work! Sorry again, I'm still learning 🙏

@tridao
Copy link
Copy Markdown
Member

tridao commented Apr 21, 2026

you could try min_blocks_per_mp=blah in kernel launch to hint to the compiler.
There's some upcoming change to reduce num regs for rmsnorm backward as well that will help.

@AaronWang04
Copy link
Copy Markdown
Author

@Pranshu-Bahadur You can do that but registers are reused throughout the kernel for various diff variables. 66 is how many registers the compiler determines the kernel will need throughput its lifetime, so setting producer=Int8(0) when its a variable that will be recaptured later on anyways may not change the number of total regs the kernel needs

@AaronWang04
Copy link
Copy Markdown
Author

AaronWang04 commented Apr 21, 2026

min_blocks_per_mp didn't work but I just set self.reload_wdy = "smem" to be true for all N and got pretty good perf with the semaphore path after freeing up the regs

updated top comment with new perf #s

@Pranshu-Bahadur
Copy link
Copy Markdown

@Pranshu-Bahadur You can do that but registers are reused throughout the kernel for various diff variables. 66 is how many registers the compiler determines the kernel will need throughput its lifetime, so setting producer=Int8(0) when its a variable that will be recaptured later on anyways may not change the number of total regs the kernel needs

hey, yeah tyty makes sense from the code base's design pov! Your latest commit looks super cool. Will prolly learn to use cute api too python for tcgen05 stuff -- the intra-groupwise reduce and leader based movement is hella cool 😁 -- also noticed you lowered reg pressure in multiple ways

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants