Deterministic fused cross-CTA dW reduction in RMSNorm backward#110

Closed
santoshmo wants to merge 1 commit into Dao-AILab:main from santoshmo:main
santoshmo commented Apr 19, 2026

Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by fusing a deterministic last-CTA-reduces pattern into the backward kernel.

Each CTA writes its partial to dw_partial[bidx, :] as before, then issues a threadfence followed by an atomic increment of a global counter. The last CTA to arrive loads all partials in the fixed order 0..sm_count-1 and accumulates them into dw_final. Because the summation order never depends on which CTA finishes last, the floating-point result is identical across runs.
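To illustrate why the fixed accumulation order gives run-to-run determinism, here is a minimal pure-Python simulation of the last-CTA-reduces pattern. It is a sketch, not the kernel code from this PR: `fused_dw_reduce` and `arrival_order` are hypothetical names, the "CTAs" run sequentially in a shuffled order, and a plain integer stands in for the global atomic counter (the sequential loop plays the role of the threadfence plus atomic increment).

```python
import random


def fused_dw_reduce(partials, arrival_order):
    """Simulate the last-CTA-reduces pattern.

    partials[b] is the dW partial computed by CTA b.
    arrival_order is the (non-deterministic) order in which CTAs finish.
    """
    num_ctas = len(partials)
    n = len(partials[0])
    dw_partial = [None] * num_ctas  # stand-in for the global dw_partial buffer
    dw_final = [0.0] * n
    counter = 0  # stand-in for the global atomic counter

    for bidx in arrival_order:
        # Step 1: each CTA publishes its own partial row.
        dw_partial[bidx] = list(partials[bidx])
        # Step 2: fence + atomic increment; the old value is this CTA's ticket.
        ticket = counter
        counter += 1
        # Step 3: only the last CTA to arrive performs the reduction,
        # always walking rows in the fixed order 0..num_ctas-1.
        if ticket == num_ctas - 1:
            for b in range(num_ctas):
                for j in range(n):
                    dw_final[j] += dw_partial[b][j]
    return dw_final


if __name__ == "__main__":
    # Values chosen so that floating-point summation order matters.
    partials = [[1e16, 0.1], [1.0, 0.2], [-1e16, 0.3], [1.0, 0.4]]
    order = list(range(len(partials)))
    reference = fused_dw_reduce(partials, order)
    for _ in range(5):
        random.shuffle(order)
        assert fused_dw_reduce(partials, order) == reference
    print("bit-identical across arrival orders:", reference)
```

Whichever CTA happens to arrive last, the reduction loop visits rows in the same order, so the result does not depend on scheduling. Accumulating in arrival order instead (for example via atomic adds into dw_final) would not have this property.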

Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back to the existing host-side .sum(dim=0) reduction.
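The dispatch rule above can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `FUSED_MAX_N`, `choose_dw_reduction`, and `column_sum` are hypothetical names, and `column_sum` is a pure-Python stand-in for the existing `.sum(dim=0)` reduction over dw_partial.

```python
FUSED_MAX_N = 8192  # hypothetical constant for the threshold quoted above


def choose_dw_reduction(n):
    """Pick the dW reduction path implied by the description for hidden size n."""
    # N <= 8192 corresponds to cluster_n == 1, where the fused path is enabled.
    return "fused_last_cta" if n <= FUSED_MAX_N else "host_sum"


def column_sum(dw_partial):
    """Stand-in for dw_partial.sum(dim=0): sum each column over all rows."""
    return [sum(col) for col in zip(*dw_partial)]


if __name__ == "__main__":
    for n in (4096, 8192, 16384):
        print(n, choose_dw_reduction(n))
```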

Based on the approach discussed in #101 by @AaronWang04.

Co-authored-by: Aaron Wang <aaronwang04@users.noreply.github.com>