diff --git a/quack/gemm_sm100.py b/quack/gemm_sm100.py index 959267a2..3c4dde45 100644 --- a/quack/gemm_sm100.py +++ b/quack/gemm_sm100.py @@ -104,8 +104,9 @@ class GemmSm100(GemmSm90): :param acc_dtype: Data type for accumulation during computation :type acc_dtype: type[cutlass.Numeric] - :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) - :type mma_tiler_mn: Tuple[int, int] + :param mma_tiler_mn: Shape of the MMA tile. Pass (M, N) to default K to + 4 MMA instructions, or (M, N, K) to set the K tile size explicitly. + :type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]] :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] @@ -154,7 +155,7 @@ def __init__( self, acc_dtype: Type[cutlass.Numeric], a_dtype: Type[cutlass.Numeric], # ignored for now - mma_tiler_mn: Tuple[int, int], + mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]], cluster_shape_mnk: Tuple[int, int, int], sf_vec_size: Optional[int] = None, gather_A: bool = False, @@ -176,8 +177,9 @@ def __init__( :param acc_dtype: Data type of the accumulator. :type acc_dtype: type[cutlass.Numeric] - :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. - :type mma_tiler_mn: Tuple[int, int] + :param mma_tiler_mn: (M, N) or (M, N, K) shape of the MMA tile. + If only (M, N) is given, K defaults to 4 * instruction K. + :type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]] :param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster. 
-        :type cluster_shape_mnk: Tuple[int, int]
+        :type cluster_shape_mnk: Tuple[int, int, int]
         """
@@ -186,8 +188,14 @@ def __init__(
         self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
         self.cluster_shape_mnk = cluster_shape_mnk
         assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
-        # K dimension is deferred in _setup_attributes
-        self.mma_tiler = (*mma_tiler_mn, 1)
+        # mma_tiler_mn is (M, N) or (M, N, K). A stored K of 0 means "not specified";
+        # _setup_attributes then falls back to 4 MMA instructions along K.
+        assert len(mma_tiler_mn) in (2, 3), "mma_tiler_mn must be (M, N) or (M, N, K)"
+        if len(mma_tiler_mn) == 3:
+            assert mma_tiler_mn[2] > 0, "MMA tile K must be positive"
+            self.mma_tiler = tuple(mma_tiler_mn)
+        else:
+            self.mma_tiler = (*mma_tiler_mn, 0)
         self.sf_vec_size = sf_vec_size
         self.blockscaled = sf_vec_size is not None
         self.is_persistent = True
@@ -302,7 +310,15 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
         )
 
         # Compute mma/cluster/tile shapes
-        mma_inst_tile_k = 4
+        if self.mma_tiler[2] > 0:
+            # Explicit K: it must be a whole number of MMA instructions, otherwise the
+            # floor division below would silently shrink (or zero out) the requested tile.
+            inst_k = self.mma_inst_shape_mnk[2]
+            assert self.mma_tiler[2] % inst_k == 0, (
+                f"MMA tile K ({self.mma_tiler[2]}) must be a multiple of the MMA instruction K ({inst_k})"
+            )
+            mma_inst_tile_k = self.mma_tiler[2] // inst_k
+        else:
+            mma_inst_tile_k = 4
         self.mma_tiler = (
             self.mma_tiler[0],
             self.mma_tiler[1],
@@ -2427,7 +2443,7 @@ def is_valid_dtypes_and_scale_factor_vec_size(
 
     @staticmethod
     def is_valid_mma_tiler_and_cluster_shape(
-        mma_tiler_mn: Tuple[int, int],
+        mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
         cluster_shape_mn: Tuple[int, int],
         blockscaled: bool,
     ) -> bool:
@@ -2536,7 +2552,7 @@ def can_implement_blockscaled(
         sf_dtype: Type[cutlass.Numeric],
         sf_vec_size: int,
         d_dtype: Type[cutlass.Numeric],
-        mma_tiler_mn: Tuple[int, int],
+        mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
         cluster_shape_mn: Tuple[int, int],
         m: int,
         n: int,
@@ -2572,7 +2588,7 @@ def can_implement(
         ab_dtype: Type[cutlass.Numeric],
         acc_dtype: Type[cutlass.Numeric],
         d_dtype: Type[cutlass.Numeric],
-        mma_tiler_mn: Tuple[int, int],
+        mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
         cluster_shape_mn: Tuple[int, int],
         m: int,
         n: int,
diff --git a/quack/rmsnorm.py b/quack/rmsnorm.py
index 7bad65ad..bc954ca6 100644
--- a/quack/rmsnorm.py
+++ b/quack/rmsnorm.py
@@ -1,26 +1,23
@@ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math -from typing import Optional, Tuple, Type from functools import partial +from typing import Optional, Tuple, Type import cuda.bindings.driver as cuda - import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr - -import torch -from torch import Tensor - -import quack.utils as utils import quack.copy_utils as copy_utils import quack.layout_utils as layout_utils +import quack.utils as utils +import torch +from cutlass import const_expr, Float32, Int32 +from quack.cache_utils import jit_cache from quack.compile_utils import make_fake_tensor as fake_tensor +from quack.cute_dsl_utils import torch2cute_dtype_map from quack.reduce import row_reduce from quack.reduction_base import ReductionBase -from quack.cache_utils import jit_cache -from quack.cute_dsl_utils import torch2cute_dtype_map +from torch import Tensor def _ensure_contiguous(t): @@ -33,7 +30,9 @@ def _ensure_contiguous(t): class RMSNorm(ReductionBase): - def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False): + def __init__( + self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False + ): super().__init__(dtype, N, stage=2 if is_layernorm else 1) self.is_layernorm = is_layernorm self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem" @@ -41,7 +40,13 @@ def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = Fa def _threads_per_row(self): N = self.N - for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]: + for limit, threads in [ + (64, 8), + (128, 16), + (3072, 32), + (6144, 64), + (16384, 128), + ]: if N <= limit: return threads return 256 @@ -51,9 +56,19 @@ def _set_cluster_n(self): # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason # Similarly cluster_n = 8 is faster for N=128k if const_expr(self.dtype.width == 16): - thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 
1024, 4), (128 * 1024, 8)] + thresholds = [ + (16 * 1024, 1), + (32 * 1024, 2), + (64 * 1024, 4), + (128 * 1024, 8), + ] else: - thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)] + thresholds = [ + (32 * 1024, 1), + (64 * 1024, 2), + (128 * 1024, 4), + (256 * 1024, 8), + ] for limit, cluster in thresholds: if N <= limit: self.cluster_n = cluster @@ -77,21 +92,42 @@ def __call__( assert mX.element_type == self.dtype self._set_cluster_n() largest_dtype_width = const_expr( - max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None)) + max( + *( + t.element_type.width + for t in [mX, mRes, mW, mB, mO, mResO] + if t is not None + ) + ) ) vecsize = math.gcd(self.N, 128 // largest_dtype_width) tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) num_threads = tiled_copy.size mW, mB = [ - layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None + layout_utils.expand(mT, dim=0, size=tiler_mn[0]) + if const_expr(mT is not None) + else None for mT in (mW, mB) ] mRstd, mMean = [ - layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None + layout_utils.expand(mT, dim=1, size=self.N) + if const_expr(mT is not None) + else None for mT in (mRstd, mMean) ] self.kernel( - mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + mMean, + eps, + tiler_mn, + tiled_copy, + threads_per_row, ).launch( grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], block=[num_threads, 1, 1], @@ -117,12 +153,18 @@ def kernel( ): tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) tv_layout = tiled_copy.layout_tv_tiled smem = cutlass.utils.SmemAllocator() sX 
= smem.allocate_tensor( - mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, ) if const_expr(mRes is not None): sRes = smem.allocate_tensor( @@ -130,7 +172,9 @@ def kernel( cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16, ) - reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) shape = mX.shape idX = cute.make_identity_tensor(shape) @@ -140,7 +184,9 @@ def kernel( for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX) ] gW, gB = [ - cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None + cute.local_tile(mT, tiler_mn, (0, cluster_y)) + if const_expr(mT is not None) + else None for mT in (mW, mB) ] @@ -156,8 +202,12 @@ def kernel( tXgO = thr_copy_X.partition_D(gO) if const_expr(mResO is not None): tXgResO = thr_copy_X.partition_D(gResO) - tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None - tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None + tXrRstd = ( + thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None + ) + tXrMean = ( + thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None + ) tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] # allocate fragments for gmem->rmem @@ -214,7 +264,9 @@ def kernel( reduction_buffer[None, None, 0], mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, init_val=0.0, - hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + hook_fn=cute.arch.cluster_wait + if const_expr(self.cluster_n > 1) + else None, ) mean = sum_x / shape[1] if const_expr(mMean is not None): @@ -256,7 +308,9 @@ def kernel( reduction_buffer[None, None, 0], mbar_ptr, init_val=0.0, - hook_fn=cute.arch.cluster_wait if 
const_expr(self.cluster_n > 1) else None, + hook_fn=cute.arch.cluster_wait + if const_expr(self.cluster_n > 1) + else None, ) rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True) if const_expr(mRstd is not None): @@ -327,9 +381,13 @@ def _rmsnorm_fwd( supported_types = {torch.float16, torch.bfloat16, torch.float32} assert x.dtype in supported_types, "Unsupported dtype" if weight is not None: - assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16" + assert weight.dtype in supported_types, ( + "Weight must be float32, float16 or bfloat16" + ) if residual is not None: - assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32" + assert residual.dtype in supported_types, ( + "Residual must be float16, bfloat16, or float32" + ) _, N = x.shape dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [ @@ -414,9 +472,12 @@ def _compile_rmsnorm_fwd( all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype] div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None)) x_cute, out_cute, res_cute, res_out_cute = [ - fake_tensor(dt, (batch_sym, N), div) for dt in [dtype, out_dtype, res_dtype, res_out_dtype] + fake_tensor(dt, (batch_sym, N), div) + for dt in [dtype, out_dtype, res_dtype, res_out_dtype] + ] + weight_cute, bias_cute = [ + fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype] ] - weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]] rstd_cute = fake_tensor(Float32, (batch_sym,)) if has_rstd else None mean_cute = fake_tensor(Float32, (batch_sym,)) if has_mean else None return cute.compile( @@ -450,10 +511,16 @@ def rmsnorm_fwd( # so that _layer_norm_fwd_impl doesn't have to return them. 
out_dtype = x.dtype if out_dtype is None else out_dtype out = torch.empty_like(x, dtype=out_dtype) - rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None + rstd = ( + torch.empty(x.shape[0], device=x.device, dtype=torch.float32) + if store_rstd + else None + ) if residual is not None: residual_dtype = residual.dtype - if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + if residual is not None or ( + residual_dtype is not None and residual_dtype != x.dtype + ): residual_out = torch.empty_like( x, dtype=residual_dtype if residual_dtype is not None else x.dtype ) @@ -471,7 +538,9 @@ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6): if residual is not None: residual_f32 = residual.float() x_f32 += residual_f32 - x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) + x_norm = x_f32 / ( + torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps) + ) out = x_norm * w if w is not None else x_norm if bias is not None: out = out + bias.float() @@ -507,7 +576,9 @@ def __init__(self, dtype: cutlass.Numeric, N: int): self.reload_wdy = None if N <= 16 * 1024 else "smem" if self.N > 128 * 1024 and self.dtype.width >= 32: # Not enough smem - raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits") + raise ValueError( + "RMSNormBackward does not support N > 128k with dtype >= 32 bits" + ) def _num_threads(self): return 128 if self.N <= 4096 else 256 @@ -521,7 +592,12 @@ def _threads_per_row(self): def _set_cluster_n(self): N = self.N - for limit, cluster in [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]: + for limit, cluster in [ + (8 * 1024, 1), + (16 * 1024, 2), + (32 * 1024, 4), + (64 * 1024, 8), + ]: if N <= limit: self.cluster_n = cluster return @@ -537,25 +613,48 @@ def __call__( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdRes: 
Optional[cute.Tensor], mdB: Optional[cute.Tensor], + reduce_counter: Optional[cute.Tensor], sm_count: Int32, stream: cuda.CUstream, ): assert mX.element_type == self.dtype self._set_cluster_n() largest_dtype_width = const_expr( - max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None)) + max( + *( + t.element_type.width + for t in [mX, mW, mdO, mdResO, mdX, mdRes] + if t is not None + ) + ) ) vecsize = math.gcd(self.N, 128 // largest_dtype_width) tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize) num_threads = tiled_copy.size mW = ( - layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None + layout_utils.expand(mW, dim=0, size=tiler_mn[0]) + if const_expr(mW is not None) + else None ) num_blocks = sm_count self.kernel( - mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdW_final, + mdB, + mdRes, + reduce_counter, + tiler_mn, + tiled_copy, + threads_per_row, ).launch( grid=[num_blocks, self.cluster_n, 1], block=[num_threads, 1, 1], @@ -573,8 +672,10 @@ def kernel( mRstd: cute.Tensor, mdX: cute.Tensor, mdW: Optional[cute.Tensor], + mdW_final: Optional[cute.Tensor], mdB: Optional[cute.Tensor], mdRes: Optional[cute.Tensor], + reduce_counter: Optional[cute.Tensor], tiler_mn: cute.Shape, tiled_copy: cute.TiledCopy, threads_per_row: cutlass.Constexpr[int], @@ -582,7 +683,11 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() bidx_start, _, _ = cute.arch.block_idx() gdim, _, _ = cute.arch.grid_dim() - cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1] + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) tv_layout = tiled_copy.layout_tv_tiled shape = mX.shape @@ -592,7 +697,9 @@ def kernel( idX = cute.make_identity_tensor(shape) smem = cutlass.utils.SmemAllocator() - smem_layout = 
cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)) + smem_layout = cute.make_ordered_layout( + (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2) + ) sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( @@ -629,7 +736,8 @@ def kernel( tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] tXrX, tXrdO, tXrdX = [ - cute.make_rmem_tensor_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX) + cute.make_rmem_tensor_like(thr[None, None, None, 0]) + for thr in (tXgX, tXgdO, tXgdX) ] tXrdResO = None if const_expr(mdResO is not None): @@ -642,7 +750,9 @@ def kernel( tXpX = ( None if is_even_N - else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1]) + else copy_utils.predicate_k( + thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1] + ) ) # Each copy will use the same number of elements as X copy = partial(copy_utils.copy, pred=tXpX) @@ -674,14 +784,26 @@ def kernel( # Prefetch the first batch row = tXcX[None, None, None, bidx_start][0][0] if row < M: - copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True) - copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True) + copy( + tXgX[None, None, None, bidx_start], + tXsX[None, None, None, 0], + is_async=True, + ) + copy( + tXgdO[None, None, None, bidx_start], + tXsdO[None, None, None, 0], + is_async=True, + ) else: if const_expr(tiler_mn[0] > 1): # Fill with zero, otherwise smem will be uninitialized, and we could read this back # later into registers, causing wrong dW. 
- utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero) - utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero) + utils.fill_oob( + tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero + ) + utils.fill_oob( + tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero + ) cute.arch.cp_async_commit_group() if const_expr(self.cluster_n > 1): @@ -710,10 +832,14 @@ def kernel( else: if const_expr(tiler_mn[0] > 1): utils.fill_oob( - tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero + tXsX[None, None, None, stage ^ 1], + None, + fill_value=mX.element_type.zero, ) utils.fill_oob( - tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero + tXsdO[None, None, None, stage ^ 1], + None, + fill_value=mdO.element_type.zero, ) cute.arch.cp_async_commit_group() rstd = cutlass.Float.zero @@ -837,6 +963,87 @@ def kernel( if const_expr(mdB is not None): copy(tXrdB, tXgdB) + # Cross-CTA deterministic dW reduction (last-CTA-reduces pattern). + # Each CTA has already written its partial to dw_partial[bidx_start, :]. + # After a threadfence + atomic counter, the last CTA to arrive loads + # all partials in fixed order and accumulates into dw_final. 
+        if const_expr(mdW_final is not None and self.cluster_n == 1):
+            cute.arch.cp_async_wait_group(0)
+            cute.arch.barrier()
+            utils.threadfence()
+
+            smem_is_last = cute.make_tensor(
+                cute.recast_ptr(sX.iterator, dtype=Int32),
+                cute.make_layout((1,)),
+            )
+            if tidx == 0:
+                old = utils.atomic_add_i32(Int32(1), reduce_counter.iterator)
+                smem_is_last[0] = old
+            cute.arch.barrier()
+            # smem_is_last aliases the first word of sX, and sdW_buf below reuses that
+            # same word. Copy the arrival index into a register and sync once more so
+            # every thread has read it before any thread's cp.async can overwrite it;
+            # otherwise threads can disagree on the branch and hang at the barriers inside it.
+            arrival_idx = smem_is_last[0]
+            cute.arch.barrier()
+
+            if arrival_idx == gdim - Int32(1):
+                # NOTE(review): sdW_buf and sdW_tmp together span 2 * tiler_mn[1] Float32
+                # starting at sX, which can exceed sX itself for 16-bit inputs when
+                # tiler_mn[0] == 1 -- confirm that running into the next allocation is intended.
+                sdW_buf = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_layout((tiler_mn[1],)),
+                )
+                num_thr = cute.size(tiled_copy)
+                vecsize_f32 = const_expr(min(tiler_mn[1], 128 // cute.Float32.width))
+                thr_copy_dw = copy_utils.tiled_copy_1d(
+                    cute.Float32,
+                    num_thr,
+                    vecsize_f32,
+                    is_async=True,
+                )
+                thr_dw = thr_copy_dw.get_slice(tidx)
+
+                # NOTE(review): rows are addressed as i * tiler_mn[1] and copied unpredicated
+                # with length tiler_mn[1]; this assumes tiler_mn[1] equals the N columns that
+                # dw_partial / dw_final are allocated with -- confirm for N where the tiler pads.
+                gdW_all = cute.make_tensor(mdW.iterator, mdW.layout)
+                gdW_final_1d = cute.make_tensor(
+                    mdW_final.iterator, cute.make_layout((tiler_mn[1],))
+                )
+
+                # Accumulate first row directly into smem
+                gdW_row0 = cute.make_tensor(
+                    gdW_all.iterator, cute.make_layout((tiler_mn[1],))
+                )
+                copy_utils.copy(
+                    thr_dw.partition_S(gdW_row0),
+                    thr_dw.partition_D(sdW_buf),
+                    is_async=True,
+                )
+                cute.arch.cp_async_commit_group()
+                cute.arch.cp_async_wait_group(0)
+                cute.arch.barrier()
+
+                # Load remaining rows and accumulate in smem
+                for i in range(1, gdim):
+                    sdW_tmp = cute.make_tensor(
+                        cute.recast_ptr(sX.iterator, dtype=cute.Float32) + tiler_mn[1],
+                        cute.make_layout((tiler_mn[1],)),
+                    )
+                    gdW_row_i = cute.make_tensor(
+                        gdW_all.iterator + i * tiler_mn[1],
+                        cute.make_layout((tiler_mn[1],)),
+                    )
+                    copy_utils.copy(
+                        thr_dw.partition_S(gdW_row_i),
+                        thr_dw.partition_D(sdW_tmp),
+                        is_async=True,
+                    )
+                    cute.arch.cp_async_commit_group()
+                    cute.arch.cp_async_wait_group(0)
+                    cute.arch.barrier()
+                    thr_buf = thr_dw.partition_D(sdW_buf)
+                    thr_tmp = thr_dw.partition_S(sdW_tmp)
+                    for j in range(cute.size(thr_buf)):
+                        thr_buf[j] = thr_buf[j] + thr_tmp[j]
+                    cute.arch.barrier()
+
+                # Store final
accumulated result to dw_final + copy_utils.copy( + thr_dw.partition_S(sdW_buf), + thr_dw.partition_D(gdW_final_1d), + ) + if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early # Assume state contains that next useful buffer # So we only need to advance to num_stages - 1 times to last used buffer @@ -849,7 +1056,9 @@ def kernel( def _get_sm_count(N: int, device: torch.device) -> int: # This should be tuned on how many CTAs can be launched on each SM sm_count_multiple = ( - 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + 16 + if N <= 256 + else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) ) sm_count = torch.cuda.get_device_properties(device).multi_processor_count # By right, if we're using cluster, this should be cluster_count not sm_count. @@ -857,7 +1066,11 @@ def _get_sm_count(N: int, device: torch.device) -> int: # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to # avoid wave quantization. sm_count = ( - sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2 + sm_count * sm_count_multiple + if N <= 8192 + else sm_count // 2 + if N <= 16384 + else sm_count * 2 ) return sm_count @@ -865,10 +1078,17 @@ def _get_sm_count(N: int, device: torch.device) -> int: @torch.library.custom_op( "quack::_rmsnorm_bwd", - mutates_args={"dx", "dw_partial", "db_partial", "dresidual"}, + mutates_args={ + "dx", + "dw_partial", + "db_partial", + "dresidual", + "dw", + "reduce_counter", + }, device_types="cuda", # We need to specify the schema manually since we're mutating an optional tensor - schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()", + schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? 
dresidual_out, Tensor(a8!)? dresidual, int? sm_count, Tensor(a10!)? dw, Tensor(a11!)? reduce_counter) -> ()", ) def _rmsnorm_bwd( x: Tensor, @@ -881,6 +1101,8 @@ def _rmsnorm_bwd( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + reduce_counter: Optional[Tensor] = None, ) -> None: """RMSNorm backward pass. Args: @@ -899,9 +1121,13 @@ def _rmsnorm_bwd( assert x.dtype in supported_types, "Unsupported dtype" if weight is not None: assert weight.dim() == 1, "Weight must be 1D" - assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension" + assert x.shape[-1] == weight.shape[0], ( + "Last dimension of input must match weight dimension" + ) assert weight.is_cuda, "Weight tensor must be on CUDA device" - assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16" + assert weight.dtype in supported_types, ( + "Weight must be float32, float16 or bfloat16" + ) if dresidual_out is not None: assert dresidual_out.shape == x.shape assert dresidual_out.is_cuda @@ -911,17 +1137,22 @@ def _rmsnorm_bwd( if dresidual is not None: assert dresidual.shape == x.shape assert dresidual.is_cuda - assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32" + assert dresidual.dtype in supported_types, ( + "Residual must be float16, bfloat16, or float32" + ) N = x.size(1) if dw_partial is None and db_partial is None: assert sm_count is not None else: - sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0] + sm_count = ( + dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0] + ) dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [ torch2cute_dtype_map[t.dtype] if t is not None else None for t in [x, dout, dx, weight, dresidual, dresidual_out] ] + dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None _compile_rmsnorm_bwd( N, dtype, @@ 
-932,7 +1163,21 @@ def _rmsnorm_bwd( dres_dtype, dres_out_dtype, dw_partial is not None, - )(x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count) + dw_dtype, + )( + x, + weight, + dout, + dresidual_out, + rstd, + dx, + dw_partial, + dw, + dresidual, + db_partial, + reduce_counter, + sm_count, + ) @_rmsnorm_bwd.register_fake @@ -947,6 +1192,8 @@ def _rmsnorm_bwd_fake( dresidual_out: Optional[Tensor] = None, dresidual: Optional[Tensor] = None, sm_count: Optional[int] = None, + dw: Optional[Tensor] = None, + reduce_counter: Optional[Tensor] = None, ) -> None: # See softmax.py _softmax_fwd_fake for why register_fake is needed. from quack.cache_utils import COMPILE_ONLY @@ -959,6 +1206,7 @@ def _rmsnorm_bwd_fake( torch2cute_dtype_map[t.dtype] if t is not None else None for t in [x, dout, dx, weight, dresidual, dresidual_out] ] + dw_dtype = torch2cute_dtype_map[dw.dtype] if dw is not None else None _compile_rmsnorm_bwd( N, dtype, @@ -969,6 +1217,7 @@ def _rmsnorm_bwd_fake( dres_dtype, dres_out_dtype, dw_partial is not None, + dw_dtype, ) @@ -983,6 +1232,7 @@ def _compile_rmsnorm_bwd( dres_dtype, dres_out_dtype, has_dw_partial, + dw_dtype=None, ): batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int() all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype] @@ -993,8 +1243,14 @@ def _compile_rmsnorm_bwd( ] weight_cute = fake_tensor(weight_dtype, (N,), div) rstd_cute = fake_tensor(Float32, (batch_sym,)) - dw_partial_cute = fake_tensor(Float32, (batch_partial_sym, N), div) if has_dw_partial else None - db_partial_cute = fake_tensor(Float32, (batch_partial_sym, N), div) if has_db_partial else None + dw_partial_cute = ( + fake_tensor(Float32, (batch_partial_sym, N), div) if has_dw_partial else None + ) + db_partial_cute = ( + fake_tensor(Float32, (batch_partial_sym, N), div) if has_db_partial else None + ) + dw_cute = fake_tensor(dw_dtype, (N,), div) if dw_dtype is not None else None + reduce_counter_cute = 
fake_tensor(Int32, (1,)) if dw_dtype is not None else None return cute.compile( RMSNormBackward(dtype, N), x_cute, @@ -1004,8 +1260,10 @@ def _compile_rmsnorm_bwd( rstd_cute, dx_cute, dw_partial_cute, + dw_cute, dres_cute, db_partial_cute, + reduce_counter_cute, 0, # sm_count, just for compilation cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", @@ -1029,19 +1287,46 @@ def rmsnorm_bwd( else: dresidual = None sm_count = _get_sm_count(N, device) + dw_partial: Optional[Tensor] = None + dw_final: Optional[Tensor] = None + reduce_counter: Optional[Tensor] = None + # Fused cross-CTA dW reduction (last-CTA-reduces pattern) is only + # supported for cluster_n == 1 (N <= 8192). For larger N the kernel + # falls back to a host-side reduction of dw_partial via .sum(dim=0). + use_fused_dw_reduce = N <= 8192 and weight is not None if weight is not None: - # Always store partial gradients in fp32 for numerical accuracy dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) - else: - dw_partial = None - db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None + if use_fused_dw_reduce: + dw_final = torch.empty(N, device=device, dtype=torch.float32) + reduce_counter = torch.zeros(1, device=device, dtype=torch.int32) + db_partial = ( + torch.empty(sm_count, N, device=device, dtype=torch.float32) + if has_bias + else None + ) _rmsnorm_bwd( - x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count + x, + weight, + dout, + rstd, + dx, + dw_partial, + db_partial, + dresidual_out, + dresidual, + sm_count, + dw_final, + reduce_counter, ) - # we have summed the partial gradients in fp32, now we convert back to the weight dtype - dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None + if weight is not None: + if use_fused_dw_reduce: + dw = dw_final.to(weight.dtype) + else: + dw = dw_partial.sum(dim=0).to(weight.dtype) + else: + dw = None db = 
db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
     # dresidual is the same as dx in this case
     if has_residual and dresidual is None:
@@ -1095,7 +1380,9 @@ def backward(ctx, dout, *args):
         has_bias = ctx.has_bias
         if ctx.prenorm and ctx.residual_dtype is not None:
             dresidual_out = args[0]
-            dresidual_out = _ensure_contiguous(dresidual_out.reshape(-1, dresidual_out.shape[-1]))
+            dresidual_out = _ensure_contiguous(
+                dresidual_out.reshape(-1, dresidual_out.shape[-1])
+            )
         else:
             dresidual_out = None
         x_shape_og = ctx.x_shape_og
@@ -1137,7 +1424,9 @@ def rmsnorm(
     Returns:
         Normalized output tensor of same shape as x
     """
-    return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
+    return RMSNormFunction.apply(
+        x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm
+    )
 
 
 class QuackRMSNorm(torch.nn.RMSNorm):
@@ -1156,7 +1445,12 @@ class QuackRMSNorm(torch.nn.RMSNorm):
     """
 
     def __init__(
-        self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None
+        self,
+        dim: int,
+        eps: float = 1e-6,
+        elementwise_affine: bool = True,
+        device=None,
+        dtype=None,
     ):
         super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype)
 
@@ -1197,7 +1491,9 @@ def layernorm_fwd(
     """
     assert x.dim() == 2, "Input must be 2D"
     assert weight.dim() == 1, "Weight must be 1D"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], (
+        "Unsupported dtype"
+    )
     assert weight.dtype == torch.float32, "Weight must be float32"
     if bias is not None:
         assert bias.dim() == 1, "Bias must be 1D"
diff --git a/quack/utils.py b/quack/utils.py
index 7039d8ae..6b43f535 100644
--- a/quack/utils.py
+++ b/quack/utils.py
@@ -274,6 +274,26 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None)
     )
 
 
+@dsl_user_op
+def threadfence(*, loc=None, ip=None) -> None:
+    """Emit a device-scope memory fence (PTX ``membar.gl``) for the calling thread.
+
+    Issue it between a thread's global-memory writes and the signal (e.g. an
+    atomic counter update) that tells other CTAs those writes may be read.
+    It is a fence, not a barrier: no thread waits for any other thread.
+    """
+    llvm.inline_asm(
+        None,
+        [],
+        "membar.gl;",
+        "",
+        has_side_effects=True,
+        is_align_stack=False,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @dsl_user_op
 def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
     from cutlass import CUDA_VERSION