Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 103 additions & 5 deletions quack/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,10 @@ def __call__(
mRstd: cute.Tensor,
mdX: cute.Tensor,
mdW: Optional[cute.Tensor],
mdW_final: Optional[cute.Tensor],
mdRes: Optional[cute.Tensor],
mdB: Optional[cute.Tensor],
reduce_counter: Optional[cute.Tensor],
sm_count: Int32,
stream: cuda.CUstream,
):
Expand All @@ -603,7 +605,8 @@ def __call__(
num_blocks = sm_count
num_heads = mX.shape[1] if const_expr(cute.rank(mX) == 3) else 1
self.kernel(
mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdW_final, mdB, mdRes,
reduce_counter, tiler_mn, tiled_copy, threads_per_row
).launch(
grid=[num_blocks, self.cluster_n, num_heads],
block=[num_threads, 1, 1],
Expand All @@ -621,8 +624,10 @@ def kernel(
mRstd: cute.Tensor,
mdX: cute.Tensor,
mdW: Optional[cute.Tensor],
mdW_final: Optional[cute.Tensor],
mdB: Optional[cute.Tensor],
mdRes: Optional[cute.Tensor],
reduce_counter: Optional[cute.Tensor],
tiler_mn: cute.Shape,
tiled_copy: cute.TiledCopy,
threads_per_row: cutlass.Constexpr[int],
Expand Down Expand Up @@ -893,6 +898,82 @@ def kernel(
if const_expr(mdB is not None):
copy(tXrdB, tXgdB)


# Cross-CTA deterministic dW reduction (last-CTA-reduces pattern).
# Each CTA has already written its partial to dw_partial[bidx_start, :].
# After a threadfence + atomic counter, the last CTA to arrive loads
# all partials in fixed order and accumulates into dw_final.
if const_expr(mdW_final is not None and self.cluster_n == 1):
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
utils.threadfence()

smem_is_last = cute.make_tensor(
cute.recast_ptr(sX.iterator, dtype=Int32),
cute.make_layout((1,)),
)
if tidx == 0:
old = utils.atomic_add_i32(Int32(1), reduce_counter.iterator)
smem_is_last[0] = old
cute.arch.barrier()

if smem_is_last[0] == gdim - Int32(1):
sdW_buf = cute.make_tensor(
cute.recast_ptr(sX.iterator, dtype=cute.Float32),
cute.make_layout((tiler_mn[1],)),
)
num_thr = cute.size(tiled_copy)
vecsize_f32 = const_expr(min(tiler_mn[1], 128 // cute.Float32.width))
thr_copy_dw = copy_utils.tiled_copy_1d(
cute.Float32, num_thr, vecsize_f32, is_async=True,
)
thr_dw = thr_copy_dw.get_slice(tidx)

gdW_all = cute.make_tensor(mdW.iterator, mdW.layout)
gdW_final_1d = cute.make_tensor(
mdW_final.iterator, cute.make_layout((tiler_mn[1],))
)

# Accumulate first row directly into smem
gdW_row0 = cute.make_tensor(
gdW_all.iterator, cute.make_layout((tiler_mn[1],))
)
copy_utils.copy(
thr_dw.partition_S(gdW_row0), thr_dw.partition_D(sdW_buf),
is_async=True,
)
cute.arch.cp_async_commit_group()
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()

# Load remaining rows and accumulate in smem
for i in range(1, gdim):
sdW_tmp = cute.make_tensor(
cute.recast_ptr(sX.iterator, dtype=cute.Float32) + tiler_mn[1],
cute.make_layout((tiler_mn[1],)),
)
gdW_row_i = cute.make_tensor(
gdW_all.iterator + i * tiler_mn[1],
cute.make_layout((tiler_mn[1],)),
)
copy_utils.copy(
thr_dw.partition_S(gdW_row_i), thr_dw.partition_D(sdW_tmp),
is_async=True,
)
cute.arch.cp_async_commit_group()
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
thr_buf = thr_dw.partition_D(sdW_buf)
thr_tmp = thr_dw.partition_S(sdW_tmp)
for j in range(cute.size(thr_buf)):
thr_buf[j] = thr_buf[j] + thr_tmp[j]
cute.arch.barrier()

# Store final accumulated result to dw_final
copy_utils.copy(
thr_dw.partition_S(sdW_buf), thr_dw.partition_D(gdW_final_1d),
)

if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
# Assume state contains that next useful buffer
# So we only need to advance to num_stages - 1 times to last used buffer
Expand Down Expand Up @@ -921,10 +1002,10 @@ def _get_sm_count(N: int, device: torch.device) -> int:

@torch.library.custom_op(
"quack::_rmsnorm_bwd",
mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
mutates_args={"dx", "dw_partial", "db_partial", "dresidual", "dw", "reduce_counter"},
device_types="cuda",
# We need to specify the schema manually since we're mutating an optional tensor
schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count, Tensor(a10!)? dw, Tensor(a11!)? reduce_counter) -> ()",
)
def _rmsnorm_bwd(
x: Tensor,
Expand All @@ -937,6 +1018,8 @@ def _rmsnorm_bwd(
dresidual_out: Optional[Tensor] = None,
dresidual: Optional[Tensor] = None,
sm_count: Optional[int] = None,
dw: Optional[Tensor] = None,
reduce_counter: Optional[Tensor] = None,
) -> None:
"""RMSNorm backward pass.
Args:
Expand Down Expand Up @@ -1003,6 +1086,8 @@ def _rmsnorm_bwd_fake(
dresidual_out: Optional[Tensor] = None,
dresidual: Optional[Tensor] = None,
sm_count: Optional[int] = None,
dw: Optional[Tensor] = None,
reduce_counter: Optional[Tensor] = None,
) -> None:
# See softmax.py _softmax_fwd_fake for why register_fake is needed.
from quack.cache_utils import COMPILE_ONLY
Expand All @@ -1016,6 +1101,7 @@ def _rmsnorm_bwd_fake(
torch2cute_dtype_map[t.dtype] if t is not None else None
for t in [x, dout, dx, weight, dresidual, dresidual_out]
]
dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None
_compile_rmsnorm_bwd(
N,
dtype,
Expand All @@ -1027,6 +1113,7 @@ def _rmsnorm_bwd_fake(
dres_out_dtype,
dw_partial is not None,
per_head,
dw_dtype,
)


Expand All @@ -1042,6 +1129,7 @@ def _compile_rmsnorm_bwd(
dres_out_dtype,
has_dw_partial,
per_head=False,
dw_dtype=None,
):
batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
head_sym = cute.sym_int() if per_head else None
Expand All @@ -1058,6 +1146,8 @@ def _compile_rmsnorm_bwd(
dw_shape = (batch_partial_sym, head_sym, N) if per_head else (batch_partial_sym, N)
dw_partial_cute = fake_tensor(Float32, dw_shape, div) if has_dw_partial else None
db_partial_cute = fake_tensor(Float32, dw_shape, div) if has_db_partial else None
dw_cute = fake_tensor(dw_dtype, (N,), div) if dw_dtype is not None else None
reduce_counter_cute = fake_tensor(Int32, (1,)) if dw_dtype is not None else None
return cute.compile(
RMSNormBackward(dtype, N),
x_cute,
Expand All @@ -1067,8 +1157,10 @@ def _compile_rmsnorm_bwd(
rstd_cute,
dx_cute,
dw_partial_cute,
dw_cute,
dres_cute,
db_partial_cute,
reduce_counter_cute,
0, # sm_count, just for compilation
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
options="--enable-tvm-ffi",
Expand Down Expand Up @@ -1098,12 +1190,18 @@ def rmsnorm_bwd(
sm_count = max(round(sm_count / H), 1)
else:
H = None
dw_partial: Optional[Tensor] = None
dw_final: Optional[Tensor] = None
reduce_counter: Optional[Tensor] = None
# Fused cross-CTA dW reduction (last-CTA-reduces) for cluster_n == 1 and non-per-head.
use_fused_dw_reduce = N <= 8192 and weight is not None and not per_head
if weight is not None:
# Always store partial gradients in fp32 for numerical accuracy
dw_shape = (sm_count, H, N) if per_head else (sm_count, N)
dw_partial = torch.empty(dw_shape, device=device, dtype=torch.float32)
else:
dw_partial = None
if use_fused_dw_reduce:
dw_final = torch.empty(N, device=device, dtype=torch.float32)
reduce_counter = torch.zeros(1, device=device, dtype=torch.int32)
db_shape = (sm_count, H, N) if per_head else (sm_count, N)
db_partial = torch.empty(db_shape, device=device, dtype=torch.float32) if has_bias else None

Expand Down
12 changes: 12 additions & 0 deletions quack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,18 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None)


@dsl_user_op


@dsl_user_op
def threadfence(*, loc=None, ip=None) -> None:
llvm.inline_asm(
None,
[],
"membar.gl;",
"",
has_side_effects=True,
is_align_stack=False,
)
def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
from cutlass import CUDA_VERSION

Expand Down
Loading