RMSNorm backward fusing sum #101
copy_utils.cpasync_reduce_bulk_add_f32(
    sdW_buf.iterator, mdW_final.iterator, store_bytes,
)
without a semaphore guarding this, this will result in non-deterministic reductions
Yes, but if we add a semaphore I don't think this fusion becomes worth it.
should I add a bool flag for determinism which forces the unfused .sum() path instead?
fwiw i think the performance gains are worth it (I still haven't tuned sm_count yet)
I don't think a backward kernel without deterministic reduction is worth adding, tbh. I don't see anyone using it.
We could try with a semaphore to see if there's still any perf win.
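To make the determinism concern concrete: floating-point addition is not associative, so if partial sums are accumulated in whatever order CTAs happen to arrive, the result can differ from run to run. The sketch below is plain Python (float64, not the kernel's f32 path), and `reduce_in_order` and the sample `partials` are hypothetical names and values chosen only for illustration.

```python
from itertools import permutations


def reduce_in_order(partials, order):
    """Accumulate partial sums in the given arrival order."""
    acc = 0.0
    for i in order:
        acc += partials[i]
    return acc


# Hypothetical per-CTA partials with very different magnitudes.
partials = [1e16, 1.0, -1e16, 1.0]

# Every possible arrival order, as an unordered atomic reduction could produce.
results = {
    reduce_in_order(partials, order)
    for order in permutations(range(len(partials)))
}

# A fixed order 0..n-1 always yields the same value.
fixed = reduce_in_order(partials, range(len(partials)))

print(sorted(results))  # more than one distinct sum
print(fixed)
```

The same inputs give several different sums depending on order, which is exactly what a fixed reduction order rules out.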
ffacb0e to 521174c
I don't think you need to load dW from gmem -> smem. Just do cp.reduce.async.bulk from smem -> gmem. With a semaphore this would be deterministic.
Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by fusing a deterministic last-CTA-reduces pattern into the backward kernel. Each CTA writes its partial to dw_partial[bidx, :] as before, then does a threadfence + atomic increment of a global counter. The last CTA to arrive loads all partials in fixed order 0..sm_count-1 and accumulates into dw_final, ensuring deterministic results across runs. Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back to the existing host-side .sum(dim=0) reduction. Based on the approach discussed in Dao-AILab#101. Co-authored-by: Aaron Wang <aaronwang04@users.noreply.github.com>
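The last-CTA-reduces pattern described in this commit message can be sketched on the host with one Python thread per CTA. This is only a simulation under stated assumptions: a `threading.Lock` stands in for the threadfence plus atomic increment, and `fused_last_cta_reduce`, `ticket`, and the sample data are hypothetical names, not code from the PR.

```python
import threading


def fused_last_cta_reduce(partials):
    """Simulate the last-CTA-reduces pattern with one thread per CTA."""
    num_ctas = len(partials)
    n = len(partials[0])
    dw_partial = [None] * num_ctas  # stand-in for the global partial buffer
    dw_final = [0.0] * n            # stand-in for the final dW output
    counter = [0]                   # stand-in for the global arrival counter
    lock = threading.Lock()         # models threadfence + atomic increment

    def cta(bidx):
        # 1. Publish this CTA's partial row.
        dw_partial[bidx] = list(partials[bidx])
        # 2. Take a ticket from the arrival counter.
        with lock:
            ticket = counter[0]
            counter[0] += 1
        # 3. Only the last CTA to arrive reduces, always in order 0..num_ctas-1.
        if ticket == num_ctas - 1:
            for row in range(num_ctas):
                for col in range(n):
                    dw_final[col] += dw_partial[row][col]

    threads = [threading.Thread(target=cta, args=(b,)) for b in range(num_ctas)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dw_final


# Hypothetical partials: 4 CTAs, N = 2.
partials = [[1e16, 0.5], [1.0, 0.25], [-1e16, 0.125], [1.0, 1.0]]
runs = [fused_last_cta_reduce(partials) for _ in range(20)]
print(runs[0])
```

Which thread ends up doing the reduction varies between runs, but the summation order does not, so every run returns the same values.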
Some updates: it seems the change bumped register usage from 64 to 66, which killed occupancy for certain shapes (the ones that were OK had decent perf wins). I tried passing in maxrregcount, but either it wasn't being passed down to PTX properly or the compiler isn't respecting the argument.
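A rough back-of-the-envelope model shows why two extra registers can cost so much occupancy. The numbers here are assumptions, not taken from the PR: 65,536 32-bit registers per SM, per-thread register counts rounded up to a multiple of 8 at allocation time, and a hypothetical 512-thread block. `blocks_per_sm` is an illustrative helper that ignores every other occupancy limit (shared memory, max warps per SM).

```python
def blocks_per_sm(regs_per_thread, threads_per_block,
                  regs_per_sm=65536, alloc_unit=8, warp_size=32):
    """Register-limited resident blocks per SM (simplified model)."""
    # Round the per-thread register count up to the allocation granularity.
    regs_alloc = -(-regs_per_thread // alloc_unit) * alloc_unit
    # Registers are handed out per warp, so round the block up to whole warps.
    warps_per_block = -(-threads_per_block // warp_size)
    regs_per_block = regs_alloc * warp_size * warps_per_block
    return regs_per_sm // regs_per_block


before = blocks_per_sm(64, 512)  # 64 regs/thread
after = blocks_per_sm(66, 512)   # 66 regs/thread, allocated as 72
print(before, after)
```

Under these assumptions, going from 64 to 66 registers halves the number of resident blocks, because 66 is allocated as 72 and two such blocks no longer fit in the register file.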
Sorry, I'm super new and still learning, but if it's 2 regs, couldn't we reduce the size of constants, e.g. change producer* = int32(0) to producer* = int8(0)? Or is the backend int32 (if so, there are probably other places where you could squeeze registers too)? I'm used to C++ CUDA, so idk. Also, you could tweak cluster sizes too, right (worst case)? Thank you, super humbling to learn from your work! Sorry again, I'm still learning 🙏
you could try |
47786ed to 6ac09f5
@Pranshu-Bahadur You can do that, but registers are reused throughout the kernel for different variables. 66 is how many registers the compiler determines the kernel will need throughout its lifetime, so setting producer=Int8(0) when it's a variable that will be recaptured later on anyway may not change the total number of regs the kernel needs.
updated top comment with new perf #s |
hey, yeah tyty, makes sense from the codebase's design pov! Your latest commit looks super cool. Will prolly learn to use the CuTe Python API too for tcgen05 stuff -- the intra-groupwise reduce and leader-based movement is hella cool 😁 -- also noticed you lowered reg pressure in multiple ways
Eliminates the extra .sum(dim=0) kernel for the dw_partial reduction in the RMSNorm backward pass. Each CTA writes its partial to dw_partial as before, loads it back into smem contiguously, then thread 0 issues a bulk async reduce-add to a dw tensor. Only for N ≤ 8192 (no intra-CTA reduction).
Benchmark swept M from 4k to 65k and N from 256 to 8k: no performance regressions in any case, and on average a 1.2x speedup on GB200.
Rough results with triton.testing.do_bench
UPDATE:
Added determinism via a semaphore with a 2 stage reduction
new results (note I disabled this fusion for N>2048)
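One common way a semaphore can make an accumulating reduction deterministic is to serialize it: each CTA waits until the semaphore equals its own index, adds its partial, then bumps the semaphore. The sketch below simulates that idea with Python threads. It is an assumption about the general technique, not a description of the 2-stage scheme in this PR, and `semaphore_ordered_reduce`, `turn`, and `start_order` are hypothetical names.

```python
import threading


def semaphore_ordered_reduce(partials, start_order):
    """Accumulate partials strictly in index order, whatever the start order."""
    dw = [0.0] * len(partials[0])
    turn = [0]                    # stand-in for the global semaphore value
    cond = threading.Condition()

    def cta(bidx):
        with cond:
            # Wait until every lower-indexed CTA has added its partial.
            cond.wait_for(lambda: turn[0] == bidx)
            for col, value in enumerate(partials[bidx]):
                dw[col] += value
            # Release the next CTA in line.
            turn[0] += 1
            cond.notify_all()

    threads = [threading.Thread(target=cta, args=(b,)) for b in start_order]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dw


# Hypothetical partials: 4 CTAs, N = 2.
partials = [[1e16, 0.5], [1.0, 0.25], [-1e16, 0.125], [1.0, 1.0]]
forward = semaphore_ordered_reduce(partials, [0, 1, 2, 3])
backward = semaphore_ordered_reduce(partials, [3, 2, 1, 0])
print(forward, backward)
```

The trade-off is the one raised earlier in the thread: the ordering guarantee comes from making CTAs wait on each other, which is why the perf win has to be re-measured with the semaphore in place.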