Skip to content

Rmsnorm bwd deterministic fused reduce#109

Closed
santoshmo wants to merge 5 commits into
Dao-AILab:mainfrom
santoshmo:rmsnorm-bwd-deterministic-fused-reduce
Closed

Rmsnorm bwd deterministic fused reduce#109
santoshmo wants to merge 5 commits into
Dao-AILab:mainfrom
santoshmo:rmsnorm-bwd-deterministic-fused-reduce

Conversation

@santoshmo
Copy link
Copy Markdown
Contributor

Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by
fusing a deterministic last-CTA-reduces pattern into the backward kernel.

Each CTA writes its partial to dw_partial[bidx, :] as before, then does
a threadfence + atomic increment of a global counter. The last CTA to
arrive loads all partials in fixed order 0..sm_count-1 and accumulates
into dw_final, ensuring deterministic results across runs.

Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back
to the existing host-side .sum(dim=0) reduction.

Based on the approach discussed in #101.

Co-authored-by: Aaron Wang aaronwang04@users.noreply.github.com"

santoshmo and others added 5 commits April 7, 2026 12:34
Co-authored-by: Vijay Thakkar <vijaythakkar@me.com>
Eliminates the separate .sum(dim=0) kernel for dw_partial reduction by
fusing a deterministic last-CTA-reduces pattern into the backward kernel.

Each CTA writes its partial to dw_partial[bidx, :] as before, then does
a threadfence + atomic increment of a global counter. The last CTA to
arrive loads all partials in fixed order 0..sm_count-1 and accumulates
into dw_final, ensuring deterministic results across runs.

Only enabled for N <= 8192 (cluster_n == 1). For larger N, falls back
to the existing host-side .sum(dim=0) reduction.

Based on the approach discussed in Dao-AILab#101.

Co-authored-by: Aaron Wang <aaronwang04@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant