Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 176 additions & 11 deletions quack/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
from typing import Optional, Tuple, Type
from functools import partial
from functools import cache, partial

import cuda.bindings.driver as cuda

Expand Down Expand Up @@ -539,7 +539,7 @@ class RMSNormBackward(ReductionBase):
def __init__(self, dtype: cutlass.Numeric, N: int):
# 2 stages for double buffering when computing mean of x_hat * wdy
super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
self.reload_wdy = None if N <= 16 * 1024 else "smem"
self.reload_wdy = "smem"
if self.N > 128 * 1024 and self.dtype.width >= 32:
# Not enough smem
raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
Expand Down Expand Up @@ -584,9 +584,12 @@ def __call__(
mRstd: cute.Tensor,
mdX: cute.Tensor,
mdW: Optional[cute.Tensor],
mdW_final: Optional[cute.Tensor],
mdRes: Optional[cute.Tensor],
mdB: Optional[cute.Tensor],
mSemaphore: Optional[cute.Tensor],
sm_count: Int32,
group_size: Int32,
stream: cuda.CUstream,
):
assert mX.element_type == self.dtype
Expand All @@ -600,10 +603,26 @@ def __call__(
mW = (
layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
)
if const_expr(mdW_final is not None):
mdW_final = layout_utils.expand(mdW_final, dim=0, size=1)
num_blocks = sm_count
num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1
self.kernel(
mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
mX,
mW,
mdO,
mdResO,
mRstd,
mdX,
mdW,
mdW_final,
mdB,
mdRes,
mSemaphore,
group_size,
tiler_mn,
tiled_copy,
threads_per_row,
).launch(
grid=[num_blocks, self.cluster_n, num_heads],
block=[num_threads, 1, 1],
Expand All @@ -621,8 +640,11 @@ def kernel(
mRstd: cute.Tensor,
mdX: cute.Tensor,
mdW: Optional[cute.Tensor],
mdW_final: Optional[cute.Tensor],
mdB: Optional[cute.Tensor],
mdRes: Optional[cute.Tensor],
mSemaphore: Optional[cute.Tensor],
group_size: Int32,
tiler_mn: cute.Shape,
tiled_copy: cute.TiledCopy,
threads_per_row: cutlass.Constexpr[int],
Expand Down Expand Up @@ -893,6 +915,87 @@ def kernel(
if const_expr(mdB is not None):
copy(tXrdB, tXgdB)

# Two-level grouped reduction: reduce dw_partial across CTAs into mdW_final.
# Level 1: each group of group_size CTAs reduces to dw_partial[group_leader].
# Level 2: the last group-reducer reduces G group sums into mdW_final.
# Only supported for cluster_n == 1; for cluster_n > 1 the caller
# must reduce dw_partial on the host.
if const_expr(mdW_final is not None and self.cluster_n == 1):
cute.arch.fence_acq_rel_gpu()

my_group = Int32(bidx_start / group_size)
group_base = Int32(my_group * group_size)
group_count = Int32(group_size)
if group_base + group_size > gdim:
group_count = Int32(gdim - group_base)
num_groups = Int32((gdim + group_size - Int32(1)) / group_size)

sFlag = cute.make_tensor(
cute.recast_ptr(sX.iterator, dtype=Int32),
cute.make_layout((1,)),
)

gdW_all = cute.local_tile(mdW, (1, tiler_mn[1]), (None, cluster_y))
tXgdW_all = thr_copy_X.partition_S(gdW_all)
tXrdW_accum = cute.make_fragment_like(
tXgdW_all[None, None, None, 0],
Float32,
)
tXrdW_row = cute.make_fragment_like(tXgdW_all[None, None, None, 0])

# --- Level 1: intra-group reduction ---
if tidx == 0:
is_last_in_group = Int32(0)
old = utils.atomic_add_i32(Int32(1), mSemaphore.iterator + my_group)
if old == group_count - Int32(1):
is_last_in_group = Int32(1)
sFlag[0] = is_last_in_group
cute.arch.barrier()

if sFlag[0]:
cute.arch.fence_acq_rel_gpu()
tXrdW_accum.fill(0.0)
for i in cutlass.range(group_base, group_base + group_count):
copy(tXgdW_all[None, None, None, i], tXrdW_row)
tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load())
gdW_leader = cute.local_tile(
mdW,
(1, tiler_mn[1]),
(group_base, cluster_y),
)
tXgdW_leader = thr_copy_X.partition_D(gdW_leader)
copy(tXrdW_accum, tXgdW_leader)

# --- Level 2: cross-group reduction ---
cute.arch.fence_acq_rel_gpu()
if tidx == 0:
is_last_group = Int32(0)
old = utils.atomic_add_i32(
Int32(1),
mSemaphore.iterator + num_groups,
)
if old == num_groups - Int32(1):
is_last_group = Int32(1)
sFlag[0] = is_last_group
cute.arch.barrier()

if sFlag[0]:
cute.arch.fence_acq_rel_gpu()
tXrdW_accum.fill(0.0)
for g in cutlass.range(0, num_groups):
leader_row = g * group_size
copy(tXgdW_all[None, None, None, leader_row], tXrdW_row)
tXrdW_accum.store(tXrdW_accum.load() + tXrdW_row.load())
gdW_final = cute.local_tile(
mdW_final,
(1, tiler_mn[1]),
(0, cluster_y),
)
tXgdW_final = thr_copy_X.partition_D(gdW_final)
tXrdW_out = cute.make_fragment_like(tXgdW_final)
tXrdW_out.store(tXrdW_accum.load().to(tXrdW_out.element_type))
copy(tXrdW_out, tXgdW_final)

if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
# Assume state contains that next useful buffer
# So we only need to advance to num_stages - 1 times to last used buffer
Expand All @@ -919,12 +1022,20 @@ def _get_sm_count(N: int, device: torch.device) -> int:
return sm_count


@cache
def _get_semaphore(device: torch.device) -> torch.Tensor:
"""Reuse same semaphore to avoid repeated torch.zero calls.
num_groups + 1 slots needed; ceil(sqrt(max_sm_count)) + 1 fits in 64 for any current GPU.
"""
return torch.zeros(64, device=device, dtype=torch.int32)


@torch.library.custom_op(
"quack::_rmsnorm_bwd",
mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
mutates_args={"dx", "dw_partial", "db_partial", "dresidual", "dw"},
device_types="cuda",
# We need to specify the schema manually since we're mutating an optional tensor
schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count, Tensor(a10!)? dw, Tensor? semaphore, int? group_size) -> ()",
)
def _rmsnorm_bwd(
x: Tensor,
Expand All @@ -937,6 +1048,9 @@ def _rmsnorm_bwd(
dresidual_out: Optional[Tensor] = None,
dresidual: Optional[Tensor] = None,
sm_count: Optional[int] = None,
dw: Optional[Tensor] = None,
semaphore: Optional[Tensor] = None,
group_size: Optional[int] = None,
) -> None:
"""RMSNorm backward pass.
Args:
Expand Down Expand Up @@ -977,6 +1091,7 @@ def _rmsnorm_bwd(
torch2cute_dtype_map[t.dtype] if t is not None else None
for t in [x, dout, dx, weight, dresidual, dresidual_out]
]
dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None
_compile_rmsnorm_bwd(
N,
dtype,
Expand All @@ -988,7 +1103,22 @@ def _rmsnorm_bwd(
dres_out_dtype,
dw_partial is not None,
per_head,
)(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count)
dw_dtype,
)(
x,
weight,
dout,
dresidual_out,
rstd,
dx,
dw_partial,
dw,
dresidual,
db_partial,
semaphore,
sm_count,
group_size if group_size is not None else 0,
)


@_rmsnorm_bwd.register_fake
Expand All @@ -1003,6 +1133,9 @@ def _rmsnorm_bwd_fake(
dresidual_out: Optional[Tensor] = None,
dresidual: Optional[Tensor] = None,
sm_count: Optional[int] = None,
dw: Optional[Tensor] = None,
semaphore: Optional[Tensor] = None,
group_size: Optional[int] = None,
) -> None:
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
from quack.cache_utils import COMPILE_ONLY
Expand All @@ -1016,6 +1149,7 @@ def _rmsnorm_bwd_fake(
torch2cute_dtype_map[t.dtype] if t is not None else None
for t in [x, dout, dx, weight, dresidual, dresidual_out]
]
dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None
_compile_rmsnorm_bwd(
N,
dtype,
Expand All @@ -1027,6 +1161,7 @@ def _rmsnorm_bwd_fake(
dres_out_dtype,
dw_partial is not None,
per_head,
dw_dtype,
)


Expand All @@ -1042,6 +1177,7 @@ def _compile_rmsnorm_bwd(
dres_out_dtype,
has_dw_partial,
per_head=False,
dw_dtype=None,
):
batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
head_sym = cute.sym_int() if per_head else None
Expand All @@ -1058,6 +1194,8 @@ def _compile_rmsnorm_bwd(
dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N)
dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None
db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None
dw_cute = fake_tensor(dw_dtype, (N,), div) if dw_dtype is not None else None
semaphore_cute = fake_tensor(Int32, (64,)) if dw_dtype is not None else None
return cute.compile(
RMSNormBackward(dtype, N),
x_cute,
Expand All @@ -1067,9 +1205,12 @@ def _compile_rmsnorm_bwd(
rstd_cute,
dx_cute,
dw_partial_cute,
dw_cute,
dres_cute,
db_partial_cute,
semaphore_cute,
0, # sm_count, just for compilation
0, # group_size, just for compilation
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
options="--enable-tvm-ffi",
)
Expand Down Expand Up @@ -1098,21 +1239,45 @@ def rmsnorm_bwd(
sm_count = max(round(sm_count / H), 1)
else:
H = None
dw_partial: Optional[Tensor] = None
dw: Optional[Tensor] = None
semaphore: Optional[Tensor] = None
group_size: Optional[int] = None
# In-kernel cross-CTA dw reduction via a two-level tree. The kernel only supports
# it for cluster_n == 1 (N <= 8192) and non-per-head inputs; the gate below is
# stricter and enables it only for N <= 2048. Otherwise dw/semaphore stay None
# and we fall back to the host-side dw_partial.sum().
use_in_kernel_dw_reduction = N <= 2048 and weight is not None and not per_head
if weight is not None:
# Always store partial gradients in fp32 for numerical accuracy
dw_shape = (sm_count, H, N) if per_head else (sm_count, N)
dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32)
else:
dw_partial = None
if use_in_kernel_dw_reduction:
dw = torch.empty(N, device=device, dtype=weight.dtype)
semaphore = _get_semaphore(device)
semaphore.zero_()
G = math.ceil(math.sqrt(sm_count))
group_size = math.ceil(sm_count / G)
db_shape = (sm_count, H, N) if per_head else (sm_count, N)
db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None

_rmsnorm_bwd(
x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
x,
weight,
dout,
rstd,
dx,
dw_partial,
db_partial,
dresidual_out,
dresidual,
sm_count,
dw,
semaphore,
group_size,
)

# we have summed the partial gradients in fp32, now we convert back to the weight dtype
dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
if weight is not None and not use_in_kernel_dw_reduction:
dw = dw_partial.sum(dim=0).to(weight.dtype)
db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
# dresidual is the same as dx in this case
if has_residual and dresidual is None:
Expand Down