diff --git a/README.md b/README.md index 24b4cadb..b5dd0f44 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ from quack import rmsnorm, softmax, cross_entropy [blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html). + ## Performance
diff --git a/quack/__init__.py b/quack/__init__.py index 34a56ab4..6dcfae33 100644 --- a/quack/__init__.py +++ b/quack/__init__.py @@ -6,6 +6,7 @@ from quack.softmax import softmax from quack.cross_entropy import cross_entropy from quack.rounding import RoundingMode +from quack.gemm_interface import gemm, gemm_grouped if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: @@ -19,5 +20,7 @@ "rmsnorm", "softmax", "cross_entropy", + "gemm", + "gemm_grouped", "RoundingMode", ] diff --git a/quack/gemm.py b/quack/gemm.py index 1dcb7096..edb7e792 100644 --- a/quack/gemm.py +++ b/quack/gemm.py @@ -32,6 +32,34 @@ ) +from cutlass import Float32, Int32 +from quack.cute_dsl_utils import ( + get_device_capacity, + get_max_active_clusters, + torch2cute_dtype_map, +) +from quack.gemm_default_epi import ( + GemmDefaultEpiMixin, + GemmDefaultSm100, + GemmDefaultSm120, + GemmDefaultSm90, +) +from quack.gemm_problem_adapter import ( + GroupedProblemAdapterMixin, + GroupedProblemArguments, +) +from quack.gemm_tvm_ffi_utils import ( + compile_gemm_kernel, + get_dtypes, + get_majors, + make_fake_gemm_tensors, + make_fake_scheduler_args, + make_fake_varlen_args, + make_scheduler_args, + make_varlen_args, + perm3d, +) + @jit_cache def _compile_gemm( a_dtype, @@ -63,6 +91,7 @@ def _compile_gemm( rounding_mode, sr_seed_mode, has_trace_ptr, + grouped, ): sm_to_cls = { 9: GemmDefaultSm90, @@ -71,6 +100,10 @@ def _compile_gemm( 12: GemmDefaultSm120, } GemmCls = sm_to_cls[device_capacity[0]] + if grouped: + GemmCls = type( + f"Grouped{GemmCls.__name__}", (GroupedProblemAdapterMixin, GemmCls), {} + ) mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( a_dtype, b_dtype, @@ -115,6 +148,14 @@ def fake_scalar(mode, dtype=Float32): ) aidx_len = m if varlen_m else (k if varlen_k else None) varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len) + problem_args = ( + GroupedProblemArguments( + mProblemIndex=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4), + mProblemK=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4), + ) + if grouped + else None + ) return compile_gemm_kernel( GemmCls, a_dtype, @@ -135,6 +176,7 @@ def fake_scalar(mode, dtype=Float32): has_trace_ptr=has_trace_ptr, use_tma_gather=use_tma_gather, concat_layout=concat_layout or None, + problem_args=problem_args, ) @@ -157,21 +199,48 @@ def gemm( colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m alpha: float | Tensor = 1.0, beta: float | Tensor = 1.0, - cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length - cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length - A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen - batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler + cu_seqlens_m: Optional[ + Tensor + ] = None, # (l+1,) cumulative sum of m values for variable length + cu_seqlens_k: Optional[ + Tensor + ] = None, # (l+1,) cumulative sum of k values for variable length + A_idx: Optional[ + Tensor + ] = None, # (total_m,) or (total_k,) indices for gather_A when varlen + batch_idx_permute: Optional[ + Tensor + ] = None, # (l,) permutation of batch indices for scheduler add_to_output: bool = False, rounding_mode: int = RoundingMode.RN, sr_seed: int | Tensor = 0, use_tma_gather: bool = False, concat_layout: dict | None = None, + grouped: bool = False, + grouped_problem_index: Optional[Tensor] = None, + grouped_problem_k: Optional[Tensor] = None, trace_ptr=None, # Optional Int64 from TraceSession.ptr ) -> None: varlen_m = cu_seqlens_m is not None varlen_k = cu_seqlens_k is not None varlen = varlen_m or varlen_k gather_A = A_idx is not None + if grouped: + assert A.ndim == 3 and B.ndim == 3 and D.ndim == 3, ( + "grouped GEMM expects dense packed batched tensors" + ) + assert not varlen and not gather_A, ( + "grouped GEMM does not combine with varlen/gather" + ) + batch_size = A.shape[0] + if grouped_problem_index is not None: + assert grouped_problem_index.numel() == batch_size, ( + "grouped_problem_index must have one entry per problem" + ) + if grouped_problem_k is not None: + assert grouped_problem_k.numel() == batch_size, ( + "grouped_problem_k must have one entry per problem" + ) assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k" if gather_A: assert varlen, "gather_A requires varlen" @@ -188,11 +257,17 @@ def gemm( assert B.stride(-2) == 1, "varlen_k requires B to be n-major" device_capacity = get_device_capacity(A.device) - assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported" + assert device_capacity[0] in [9, 10, 11, 12], ( + "Only SM90, SM100, SM110, and SM120 are supported" + ) if use_tma_gather: - assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110" + assert device_capacity[0] in [10, 11], ( + "TMA gather currently requires SM100/SM110" + ) if rounding_mode == RoundingMode.RS: - assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100" + assert device_capacity[0] == 10, ( + "Stochastic rounding (RoundingMode.RS) requires SM100" + ) if is_dynamic_persistent and device_capacity[0] == 9: assert tile_count_semaphore is not None, ( "Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM" @@ -208,7 +283,9 @@ def gemm( concat_layout = tuple(sorted(concat_layout)) if concat_layout else () sr_seed_mode = ( - 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0) + 2 + if isinstance(sr_seed, Tensor) + else (1 if rounding_mode == RoundingMode.RS else 0) ) compiled_fn = _compile_gemm( a_dtype, @@ -240,6 +317,7 @@ def gemm( rounding_mode, sr_seed_mode, trace_ptr is not None, + grouped, ) from quack.cache_utils import COMPILE_ONLY @@ -255,7 +333,9 @@ def scalar_arg(scalar, mode, dtype=Float32): else: return scalar.data_ptr() - max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + max_active_clusters = ( + get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + ) epi_args = GemmDefaultEpiMixin.EpilogueArguments( alpha=scalar_arg(alpha, alpha_mode), @@ -273,10 +353,38 @@ def scalar_arg(scalar, mode, dtype=Float32): batch_idx_permute, ) varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx) + problem_args = ( + GroupedProblemArguments( + mProblemIndex=grouped_problem_index, + mProblemK=grouped_problem_k, + ) + if grouped + else None + ) if device_capacity[0] in [10, 11]: compiled_fn( - A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr + A_p, + B_p, + D_p, + C_p, + epi_args, + scheduler_args, + varlen_args, + None, + None, + trace_ptr, + problem_args, ) else: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr) + compiled_fn( + A_p, + B_p, + D_p, + C_p, + epi_args, + scheduler_args, + varlen_args, + trace_ptr, + problem_args, + ) diff --git a/quack/gemm_act.py b/quack/gemm_act.py index 92672333..8f3f48be 100644 --- a/quack/gemm_act.py +++ b/quack/gemm_act.py @@ -504,9 +504,9 @@ def scalar_arg(scalar, mode, dtype=Int32): varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) if device_capacity[0] in [10, 11]: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None) else: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None) gemm_gated = gemm_act diff --git a/quack/gemm_dact.py b/quack/gemm_dact.py index 6a3d3489..a6afa6d2 100644 --- a/quack/gemm_dact.py +++ b/quack/gemm_dact.py @@ -499,10 +499,10 @@ def gemm_dact( if device_capacity[0] in [10, 11]: compiled_fn( - A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None + A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None, None ) else: - compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None) + compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None) gemm_dgated = gemm_dact diff --git a/quack/gemm_epilogue_plan.py b/quack/gemm_epilogue_plan.py new file mode 100644 index 00000000..4058c14c --- /dev/null +++ b/quack/gemm_epilogue_plan.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Optional, Tuple, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr + +import quack.copy_utils as copy_utils +from quack.rounding import RoundingMode +from quack.varlen_utils import VarlenManager + + +def default_epi_tile_layout(self, epi_tile_shape): + return cute.make_ordered_layout(epi_tile_shape, order=(1, 0)) + + +@cute.jit +def default_epi_commit( + self, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, +): + del tile_coord_mnkl + has_D = const_expr(copy_D is not None) + if is_tma_warp: + if const_expr(has_D): + copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) + if const_expr(postact_ctx is not None): + copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) + epi_store_pipeline.producer_commit() + + +@cute.jit +def symmetric_epi_commit( + self, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, +): + has_D = const_expr(copy_D is not None) + if is_tma_warp: + square_tile_m = tile_coord_mnkl[0] // self.cluster_shape_mnk[0] + square_tile_n = tile_coord_mnkl[1] // self.cluster_shape_mnk[1] + if const_expr(has_D): + copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) + if const_expr(postact_ctx is not None) and square_tile_m != square_tile_n: + copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) + epi_store_pipeline.producer_commit() + + +@cute.jit +def run_epilogue_plan( + self, + params, + epi_smem_tensors: Tuple[cute.Tensor, ...], + epi_pipeline: cutlass.pipeline.PipelineAsync, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + epi_read_state: cutlass.pipeline.PipelineState, + epi_producer_state: cutlass.pipeline.PipelineState, + epi_tile: cute.Tile, + load_acc_subtile: Callable, + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor], + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tRS_sD: cute.Tensor, + tiled_copy_s2r: Optional[cute.TiledCopy], + tSR_rC: Optional[cute.Tensor], + tSR_sC: Optional[cute.Tensor], + copy_D: Optional[Callable], + copy_C: Optional[Callable], + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tile_scheduler, + tidx: Int32, + is_tma_warp: Boolean, +) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: + has_C = const_expr(tRS_rC is not None) + has_D = const_expr(copy_D is not None) + + postact_ctx = self.epi_setup_postact( + params, + epi_smem_tensors, + tiled_copy_r2s, + tiled_copy_t2r, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + + epi_tile_shape = cute.zipped_divide(cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile).shape[1] + epi_tile_layout = self.epi_plan_make_tile_layout(epi_tile_shape) + epi_tile_num = cute.size(epi_tile_shape) + num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num + + epi_tensors = self.epi_begin( + params, + epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + + if const_expr(copy_C is not None): + for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + load_acc_subtile(tRS_rD, epi_idx) + epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) + if const_expr(has_C): + epi_pipeline.consumer_wait(epi_read_state) + cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + with cute.arch.elect_one(): + epi_pipeline.consumer_release(epi_read_state) + epi_read_state.advance() + if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) + if const_expr(postact_ctx is not None): + tRS_rPostAct_out = self.epi_convert_postact( + tRS_rPostAct, + epi_loop_tensors["sr_seed"], + tidx, + tile_coord_mnkl, + num_prev_subtiles, + epi_idx, + ) + if is_tma_warp: + epi_store_pipeline.producer_acquire() + epilogue_barrier.arrive_and_wait() + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage + if const_expr(has_D): + if const_expr( + self.rounding_mode == RoundingMode.RS + and self.acc_dtype == cutlass.Float32 + and self.d_dtype == cutlass.BFloat16 + ): + seed = epi_loop_tensors["sr_seed"] + ( + tile_coord_mnkl[0] * 65537 + + tile_coord_mnkl[1] * 257 + + tile_coord_mnkl[3] * 17 + + (num_prev_subtiles + epi_idx) * 7 + ) + copy_utils.sr_cvt_copy( + tiled_copy_r2s, + tRS_rD, + tRS_sD[None, None, None, epi_buffer], + seed, + tidx, + ) + else: + copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) + copy_postact = None + if const_expr(postact_ctx is not None): + tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx + cute.copy( + tiled_copy_postact_r2s, + tiled_copy_postact_r2s.retile(tRS_rPostAct_out), + tRS_sPostAct[None, None, None, epi_buffer], + ) + cute.arch.fence_view_async_shared() + epilogue_barrier.arrive_and_wait() + self.epi_plan_commit( + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, + ) + + self.epi_end( + params, + epi_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + + return epi_read_state, epi_producer_state diff --git a/quack/gemm_interface.py b/quack/gemm_interface.py index 64ec9b6c..bd1bdb3f 100644 --- a/quack/gemm_interface.py +++ b/quack/gemm_interface.py @@ -437,6 +437,96 @@ def gemm( return out +def gemm_grouped( + A_list: list[Tensor], + B_list: list[Tensor], + out_list: Optional[list[Tensor]] = None, + alpha: float | Tensor = 1.0, + out_dtype: Optional[torch.dtype] = None, + dynamic_scheduler: bool = False, + tuned: bool = True, + rounding_mode: int = RoundingMode.RN, + sr_seed: int | Tensor = 0, +): + """Grouped GEMM for heterogeneous `(M_i, K_i) x (K_i, N_i)` problems. + + This host wrapper packs problems into padded batched tensors, dispatches through + the grouped-capable core GEMM path, and slices outputs back to per-problem shapes. + """ + assert len(A_list) == len(B_list) and len(A_list) > 0, "A_list/B_list must be same non-zero length" + device = A_list[0].device + dtype = A_list[0].dtype + out_dtype = dtype if out_dtype is None else out_dtype + problem_count = len(A_list) + ms = [a.shape[0] for a in A_list] + ks = [a.shape[1] for a in A_list] + ns = [b.shape[1] for b in B_list] + assert all(a.ndim == 2 for a in A_list), "Each A must be rank-2" + assert all(b.ndim == 2 for b in B_list), "Each B must be rank-2" + assert all(a.device == device and b.device == device for a, b in zip(A_list, B_list)), ( + "All grouped tensors must be on the same device" + ) + assert all(a.dtype == dtype and b.dtype == dtype for a, b in zip(A_list, B_list)), ( + "All grouped tensors must have the same dtype" + ) + assert all(a.shape[1] == b.shape[0] for a, b in zip(A_list, B_list)), "Incompatible K dims" + + max_m = max(ms) + max_k = max(ks) + max_n = max(ns) + + A_packed = torch.zeros((problem_count, max_m, max_k), dtype=dtype, device=device) + B_packed = torch.zeros((problem_count, max_k, max_n), dtype=dtype, device=device) + D_packed = torch.empty((problem_count, max_m, max_n), dtype=out_dtype, device=device) + + for idx, (A_i, B_i) in enumerate(zip(A_list, B_list)): + m_i, k_i = A_i.shape + _, n_i = B_i.shape + A_packed[idx, :m_i, :k_i].copy_(A_i) + B_packed[idx, :k_i, :n_i].copy_(B_i) + + grouped_problem_index = torch.arange(problem_count, device=device, dtype=torch.int32) + grouped_problem_k = torch.tensor(ks, device=device, dtype=torch.int32) + + config = default_config(device) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=device) + if dynamic_scheduler and get_device_capacity(device)[0] == 9 + else None + ) + + gemm_dispatch( + A_packed, + B_packed, + D_packed, + None, + tile_count_semaphore, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + alpha=alpha, + rounding_mode=rounding_mode, + sr_seed=sr_seed, + grouped=True, + grouped_problem_index=grouped_problem_index, + grouped_problem_k=grouped_problem_k, + ) + + if out_list is None: + out_list = [ + torch.empty((m_i, n_i), dtype=out_dtype, device=device) for m_i, n_i in zip(ms, ns) + ] + for out_i, m_i, n_i, packed_i in zip(out_list, ms, ns, D_packed): + out_i.copy_(packed_i[:m_i, :n_i]) + return out_list + + @torch.library.custom_op( "quack::gemm_out", mutates_args=("out",), diff --git a/quack/gemm_norm_act.py b/quack/gemm_norm_act.py index f363dfa2..c46394c8 100644 --- a/quack/gemm_norm_act.py +++ b/quack/gemm_norm_act.py @@ -395,6 +395,6 @@ def scalar_arg(scalar, mode, dtype=Int32): varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) if device_capacity[0] in [10, 11]: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None) else: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None) diff --git a/quack/gemm_problem_adapter.py b/quack/gemm_problem_adapter.py new file mode 100644 index 00000000..2b2c5d4a --- /dev/null +++ b/quack/gemm_problem_adapter.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import NamedTuple, Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from quack.cute_dsl_utils import ParamsBase, mlir_namedtuple + + +@mlir_namedtuple +class ProblemArguments(NamedTuple): + pass + + +@dataclass +class ProblemParams(ParamsBase): + pass + + +@mlir_namedtuple +class GroupedProblemArguments(NamedTuple): + mProblemIndex: Optional[cute.Tensor] = None + mProblemK: Optional[cute.Tensor] = None + + +@dataclass +class GroupedProblemParams(ParamsBase): + problem_index: Optional[cute.Tensor] = None + problem_k: Optional[cute.Tensor] = None + + @staticmethod + def create(args: GroupedProblemArguments, *, loc=None, ip=None) -> "GroupedProblemParams": + return GroupedProblemParams(problem_index=args.mProblemIndex, problem_k=args.mProblemK) + + +def default_problem_to_underlying_arguments( + args: ProblemArguments | None, *, loc=None, ip=None +) -> ProblemParams: + return ProblemParams() + + +def grouped_problem_to_underlying_arguments( + args: GroupedProblemArguments | None, *, loc=None, ip=None +) -> GroupedProblemParams: + if args is None: + args = GroupedProblemArguments() + return GroupedProblemParams.create(args, loc=loc, ip=ip) + + +class GroupedProblemAdapterMixin: + ProblemArguments = GroupedProblemArguments + ProblemParams = GroupedProblemParams + + def problem_to_underlying_arguments( + self, args: GroupedProblemArguments | None = None, *, loc=None, ip=None + ) -> GroupedProblemParams: + return grouped_problem_to_underlying_arguments(args, loc=loc, ip=ip) + + @cute.jit + def problem_get_problem_idx(self, params: GroupedProblemParams, work): + return grouped_problem_idx(params, work) + + @cute.jit + def problem_get_len_k(self, params: GroupedProblemParams, varlen_manager, work): + return grouped_problem_len_k(params, work, varlen_manager) + + @cute.jit + def problem_get_batch_A(self, params: GroupedProblemParams, mA_mkl, varlen_manager, work): + return grouped_problem_batch_a(params, mA_mkl, varlen_manager, work) + + @cute.jit + def problem_get_batch_B(self, params: GroupedProblemParams, mB_nkl, varlen_manager, work): + return grouped_problem_batch_b(params, mB_nkl, varlen_manager, work) + + @cute.jit + def problem_get_batch_epi(self, params: GroupedProblemParams, mX_mnl, varlen_manager, work): + return grouped_problem_batch_epi(params, mX_mnl, varlen_manager, work) + + +@cute.jit +def default_problem_idx(params: ProblemParams, work, *, loc=None, ip=None) -> Int32: + del params, loc, ip + return work.problem_idx + + +@cute.jit +def grouped_problem_idx(params: GroupedProblemParams, work, *, loc=None, ip=None) -> Int32: + del loc, ip + if const_expr(params.problem_index is None): + return work.problem_idx + return params.problem_index[work.problem_idx] + + +@cute.jit +def default_problem_len_k(params: ProblemParams, work, varlen_manager, *, loc=None, ip=None) -> Int32: + del params, loc, ip + return varlen_manager.len_k(work.problem_idx) + + +@cute.jit +def grouped_problem_len_k( + params: GroupedProblemParams, work, varlen_manager, *, loc=None, ip=None +) -> Int32: + del loc, ip + if const_expr(params.problem_k is not None): + problem_idx = grouped_problem_idx(params, work) + return params.problem_k[problem_idx] + return varlen_manager.len_k(grouped_problem_idx(params, work)) + + +@cute.jit +def default_problem_batch_a(params: ProblemParams, mA_mkl, varlen_manager, work, *, loc=None, ip=None): + del params, loc, ip + return varlen_manager.offset_batch_A(mA_mkl, work.problem_idx) + + +@cute.jit +def grouped_problem_batch_a( + params: GroupedProblemParams, mA_mkl, varlen_manager, work, *, loc=None, ip=None +): + del loc, ip + return varlen_manager.offset_batch_A(mA_mkl, grouped_problem_idx(params, work)) + + +@cute.jit +def default_problem_batch_b(params: ProblemParams, mB_nkl, varlen_manager, work, *, loc=None, ip=None): + del params, loc, ip + return varlen_manager.offset_batch_B(mB_nkl, work.problem_idx) + + +@cute.jit +def grouped_problem_batch_b( + params: GroupedProblemParams, mB_nkl, varlen_manager, work, *, loc=None, ip=None +): + del loc, ip + return varlen_manager.offset_batch_B(mB_nkl, grouped_problem_idx(params, work)) + + +@cute.jit +def default_problem_batch_epi( + params: ProblemParams, mX_mnl, varlen_manager, work, *, loc=None, ip=None +): + del params, loc, ip + return None if const_expr(mX_mnl is None) else varlen_manager.offset_batch_epi(mX_mnl, work.problem_idx) + + +@cute.jit +def grouped_problem_batch_epi( + params: GroupedProblemParams, mX_mnl, varlen_manager, work, *, loc=None, ip=None +): + del loc, ip + return ( + None + if const_expr(mX_mnl is None) + else varlen_manager.offset_batch_epi(mX_mnl, grouped_problem_idx(params, work)) + ) diff --git a/quack/gemm_sm100.py b/quack/gemm_sm100.py index 312f9903..01549876 100644 --- a/quack/gemm_sm100.py +++ b/quack/gemm_sm100.py @@ -34,7 +34,15 @@ from quack.layout_utils import tile_atom_to_shape_SF_strided # return PipelineStateWAdvance instead of PipelineState - +from cutlass import Boolean, const_expr, Float32, Int32 +from cutlass.cute.nvgpu.warp import ( + LdMatrix16x16x8bOp, + LdMatrix8x8x16bOp, + StMatrix16x8x8bOp, + StMatrix8x8x16bOp, +) +from quack.pipeline import PipelineTmaCpAsyncUmma, PipelineTmaUmma +from typing import Callable, Literal, Optional, Tuple, Type, Union """ A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL. @@ -210,7 +218,9 @@ def __init__( if use_tma_gather: assert gather_A, "TMA gather requires gather_A=True" - self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) self.num_ab_load_warps = 1 if not self.gather_A else 4 self.occupancy = 1 @@ -250,7 +260,9 @@ def __init__( # Multiple of 4 warps to increase/decrease number of registers assert self.threads_per_cta % 128 == 0 - def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments): + def _setup_attributes( + self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments + ): """Set up configurations that are dependent on GEMM inputs This method configures various attributes based on the input tensor properties @@ -267,7 +279,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle # Compute mma instruction shapes mma_inst_bits_k = 256 # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) - mma_inst_shape_n = self.mma_tiler[1] if self.mma_tiler[1] <= 256 else self.mma_tiler[1] // 2 + mma_inst_shape_n = ( + self.mma_tiler[1] if self.mma_tiler[1] <= 256 else self.mma_tiler[1] // 2 + ) self.mma_inst_shape_mnk = ( self.mma_tiler[0], mma_inst_shape_n, @@ -380,10 +394,15 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle # There's a bug w compute_epilogue_tile_shape (as of cutlass-dsl 4.4.2) where if # tile_n = 224 and there's C, it will set epi_tile to (128, 64). if const_expr(self.cta_tile_shape_mnk[1] % cute.size(self.epi_tile[1]) != 0): - warp_n = 2 if (self.cta_tile_shape_mnk[0] == 64 and self.use_2cta_instrs) else 1 - epi_tile_n = math.gcd(self.cta_tile_shape_mnk[1], cute.size(self.epi_tile[1])) + warp_n = ( + 2 if (self.cta_tile_shape_mnk[0] == 64 and self.use_2cta_instrs) else 1 + ) + epi_tile_n = math.gcd( + self.cta_tile_shape_mnk[1], cute.size(self.epi_tile[1]) + ) epi_tile_n_layout = cute.make_layout( - (epi_tile_n // warp_n, warp_n), stride=(1, self.cta_tile_shape_mnk[1] // warp_n) + (epi_tile_n // warp_n, warp_n), + stride=(1, self.cta_tile_shape_mnk[1] // warp_n), ) self.epi_tile = (self.epi_tile[0], cute.coalesce(epi_tile_n_layout)) @@ -413,7 +432,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle self.c_layout, epilogue_args, prefetch_A_idx, - cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity + cutlass.utils.get_smem_capacity_in_bytes( + f"sm_{self.arch}" + ), # smem_capacity self.occupancy, ) self.sched_stage = 1 @@ -430,12 +451,16 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle self.a_smem_load_layout_staged = self.a_smem_layout_staged if const_expr(self.gather_A): if const_expr(self.use_tma_gather): - self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_tma_gather_a( - self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + self.a_smem_load_layout_staged = ( + quack_sm100_utils.make_smem_layout_tma_gather_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + ) ) else: - self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a( - self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + self.a_smem_load_layout_staged = ( + quack_sm100_utils.make_smem_layout_cpasync_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + ) ) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage @@ -495,7 +520,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle * 4 # 4 cols per stage * (self.mma_inst_shape_mnk[2] // self.sf_vec_size) ) - self.iter_acc_early_release = num_sf_tmem_cols // cute.size(self.epi_tile[1]) + self.iter_acc_early_release = num_sf_tmem_cols // cute.size( + self.epi_tile[1] + ) else: self.iter_acc_early_release = -1 @@ -513,6 +540,7 @@ def __call__( mSFA: Optional[cute.Tensor] = None, mSFB: Optional[cute.Tensor] = None, trace_ptr: Optional[cutlass.Int64] = None, + problem_args: Optional[GemmSm90.ProblemArguments] = None, ): """Execute the GEMM operation in steps: - Setup static attributes before smem/grid/tma computation @@ -605,14 +633,18 @@ def __call__( if const_expr(not self.gather_A): tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, - copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1, ptr_shift=False) + copy_utils.create_ragged_tensor_for_tma( + mA, ragged_dim=1, ptr_shift=False + ) if varlen_k and not self.gather_A else mA, a_smem_layout, self.mma_tiler, self.tiled_mma, self.cluster_layout_vmnk.shape, - internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None), + internal_type=( + cutlass.TFloat32 if mA.element_type is Float32 else None + ), ) elif const_expr(self.use_tma_gather): # gather4 descriptor: box has 1 in the gathered dim, tile size in the contiguous dim. @@ -626,14 +658,18 @@ def __call__( mA, tma_smem_layout, tma_smem_layout.shape, - internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None), + internal_type=( + cutlass.TFloat32 if mA.element_type is Float32 else None + ), ) b_op = sm100_utils.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma.thr_id ) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, - copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB, + copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) + if varlen_k + else mB, b_smem_layout, self.mma_tiler, self.tiled_mma, @@ -648,7 +684,9 @@ def __call__( sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mnk, self.tiled_mma.thr_id ) - sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( sfa_op, mSFA, @@ -662,7 +700,9 @@ def __call__( sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( self.cluster_shape_mnk, self.tiled_mma.thr_id ) - sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( sfb_op, mSFB, @@ -673,7 +713,8 @@ def __call__( internal_type=cutlass.Int16, ) if const_expr( - self.cta_tile_shape_mnk[1] == 192 and self.sf_dtype is cutlass.Float8E8M0FNU + self.cta_tile_shape_mnk[1] == 192 + and self.sf_dtype is cutlass.Float8E8M0FNU ): x = tma_tensor_sfb.stride[0][1] y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) @@ -706,13 +747,18 @@ def __call__( tma_atom_d, tma_tensor_d = None, None if const_expr(mD is not None): tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( - copy_utils.create_ragged_tensor_for_tma(mD, ragged_dim=0, ptr_shift=True) + copy_utils.create_ragged_tensor_for_tma( + mD, ragged_dim=0, ptr_shift=True + ) if varlen_m else mD, self.epi_smem_layout_staged, self.epi_tile, op_type="store" - if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) + if not ( + hasattr(epilogue_args, "add_to_output") + and epilogue_args.add_to_output + ) else "add", ) tma_atom_c, tma_tensor_c = None, None @@ -723,6 +769,7 @@ def __call__( epilogue_params = self.epi_to_underlying_arguments(epilogue_args) varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + problem_params = self.problem_to_underlying_arguments(problem_args) TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m) tile_sched_args = self.get_scheduler_arguments( @@ -735,14 +782,24 @@ def __call__( self.buffer_align_bytes = 1024 - epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 - epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 - sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU + epi_smem_size = ( + cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 + ) + epi_c_smem_size = ( + cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + ) + sf_dtype = ( + self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU + ) sfa_smem_size = ( - cute.cosize(self.sfa_smem_layout_staged) if const_expr(self.blockscaled) else 0 + cute.cosize(self.sfa_smem_layout_staged) + if const_expr(self.blockscaled) + else 0 ) sfb_smem_size = ( - cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0 + cute.cosize(self.sfb_smem_layout_staged) + if const_expr(self.blockscaled) + else 0 ) a_idx_smem_size = 0 if const_expr(self.gather_A): @@ -753,10 +810,18 @@ def __call__( # Define shared storage for kernel @cute.struct class SharedStorage: - ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] - epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2] - acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] - sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] + ab_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + epi_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.epi_c_stage * 2 + ] + acc_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + sched_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.sched_stage * 2 + ] a_prefetch_pipeline_array_ptr: cute.struct.MemRange[ cutlass.Int64, self.a_prefetch_stage * 2 ] @@ -780,12 +845,16 @@ class SharedStorage: epi: self.epi_get_smem_struct(epilogue_params) # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ - cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ - cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], self.buffer_align_bytes, ] # (MMA, MMA_M, MMA_K, STAGE) @@ -806,7 +875,9 @@ class SharedStorage: self.tiled_mma, self.tiled_mma_sfb, tma_atom_a, - tma_tensor_a if const_expr(not self.gather_A or self.use_tma_gather) else mA, + tma_tensor_a + if const_expr(not self.gather_A or self.use_tma_gather) + else mA, tma_atom_b, tma_tensor_b, tma_atom_sfa, @@ -831,6 +902,7 @@ class SharedStorage: self.epi_tile, tile_sched_params, TileSchedulerCls, + problem_params, trace_ptr, ).launch( grid=grid, @@ -874,6 +946,7 @@ def kernel( epi_tile: cute.Tile, tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + problem_params, trace_ptr: Optional[cutlass.Int64] = None, ): """ @@ -914,7 +987,9 @@ def kernel( bidx, _, _ = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 - cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) # Coord inside cta tidx, _, _ = cute.arch.thread_idx() @@ -956,7 +1031,8 @@ def kernel( tmem_alloc_barrier = pipeline.NamedBarrier( barrier_id=int(NamedBarrierGemm.TmemPtr), - num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)), + num_threads=cute.arch.WARP_SIZE + * len((self.mma_warp_id, *self.epilog_warp_id)), ) # Tensor memory dealloc barrier init tmem = cutlass.utils.TmemAllocator( @@ -973,13 +1049,19 @@ def kernel( # Setup smem tensor A/B/D # (MMA, MMA_M, MMA_K, STAGE) sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) - sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner) + sA = storage.sA.get_tensor( + a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner + ) # (MMA, MMA_N, MMA_K, STAGE) sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) sAIdx = None if const_expr(self.gather_A): - a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2] - a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage)) + a_idx_smem_dim = ( + self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2] + ) + a_idx_smem_layout = cute.make_layout( + (a_idx_smem_dim, self.a_prefetch_stage) + ) sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout) sSFA, sSFB = None, None if const_expr(self.blockscaled): @@ -990,22 +1072,30 @@ def kernel( sD = None if const_expr(has_D): # (EPI_TILE_M, EPI_TILE_N, STAGE) - sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sD = storage.sD.get_tensor( + epi_smem_layout.outer, swizzle=epi_smem_layout.inner + ) sC = None if const_expr(has_C): - sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) + sC = storage.sC.get_tensor( + epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner + ) epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage) thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = ( - tiled_mma_sfb.get_slice(mma_tile_coord_v) if const_expr(self.blockscaled) else None + tiled_mma_sfb.get_slice(mma_tile_coord_v) + if const_expr(self.blockscaled) + else None ) # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_fake = tiled_mma.make_fragment_C( - cute.append(acc_shape, self.num_acc_stage if not self.overlap_accum_sf else 2) + cute.append( + acc_shape, self.num_acc_stage if not self.overlap_accum_sf else 2 + ) ) varlen_manager = VarlenManager.create( @@ -1026,7 +1116,8 @@ def kernel( epi_load_barrier = None if const_expr(has_C): epi_load_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE + barrier_id=int(NamedBarrierGemm.EpilogueLoad), + num_threads=2 * cute.arch.WARP_SIZE, ) # Cluster wait before tensor memory alloc @@ -1043,11 +1134,13 @@ def kernel( if const_expr(self.gather_A): cute.arch.setmaxregister_decrease(self.num_regs_other) # Compute multicast mask for A/B buffer full - block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) block_in_cluster_coord_sfb_vmnk = None if const_expr(self.blockscaled): - block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( - cta_rank_in_cluster + block_in_cluster_coord_sfb_vmnk = ( + cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster) ) a_mcast_mask, b_mcast_mask = None, None sfa_mcast_mask, sfb_mcast_mask = None, None @@ -1063,7 +1156,9 @@ def kernel( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) sfb_mcast_mask = cpasync.create_tma_multicast_mask( - cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + cluster_layout_sfb_vmnk, + block_in_cluster_coord_sfb_vmnk, + mcast_mode=1, ) # Persistent tile scheduling loop @@ -1077,8 +1172,8 @@ def kernel( ) do_epi_load_barrier_arrive = Boolean(True) while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) # Local_tile partition global tensors mma_tile_coord_mnl = ( tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape), @@ -1087,7 +1182,9 @@ def kernel( ) gA_mk = None if const_expr(not self.gather_A): - mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + mA_mk = self.problem_get_batch_A( + problem_params, mA_mkl, varlen_manager, work_tile + ) # (bM, bK, RestK) gA_mk = cute.local_tile( mA_mk, @@ -1096,7 +1193,9 @@ def kernel( ) # (bN, bK, RestK) gB_nk = cute.local_tile( - varlen_manager.offset_batch_B(mB_nkl, batch_idx), + self.problem_get_batch_B( + problem_params, mB_nkl, varlen_manager, work_tile + ), cute.select(self.mma_tiler, [1, 2]), (mma_tile_coord_mnl[1], None), ) @@ -1127,7 +1226,9 @@ def kernel( # Partition global tensor for TiledMMA_A/B/D # Then partition global/shared tensor for TMA load A/B - len_k = varlen_manager.len_k(batch_idx) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) # TMA load A partition_S/D a_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape @@ -1164,7 +1265,9 @@ def kernel( if const_expr(varlen_m): cute.arch.sync_warp() with cute.arch.elect_one(): - a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + a_prefetch_pipeline.consumer_release( + a_prefetch_consumer_state + ) a_prefetch_consumer_state.advance() if const_expr(prefetch_A is not None): prefetch_A = partial(prefetch_A, a_prefetch_pipeline) @@ -1224,24 +1327,28 @@ def kernel( copy_SFB, ) elif const_expr(self.use_tma_gather): - ab_producer_state, a_prefetch_consumer_state = self.load_AB_tma_gather( - ab_pipeline, - ab_producer_state, - a_prefetch_consumer_state, - copy_A, - prefetch_A, - copy_B, - k_tile_cnt, + ab_producer_state, a_prefetch_consumer_state = ( + self.load_AB_tma_gather( + ab_pipeline, + ab_producer_state, + a_prefetch_consumer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + ) ) else: - ab_producer_state, a_prefetch_consumer_state = self.load_AB_gather_A( - ab_pipeline, - ab_producer_state, - a_prefetch_consumer_state, - copy_A, - prefetch_A, - copy_B, - k_tile_cnt, + ab_producer_state, a_prefetch_consumer_state = ( + self.load_AB_gather_A( + ab_pipeline, + ab_producer_state, + a_prefetch_consumer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + ) ) tctx.e("tma_load") if const_expr(epi_load_barrier is not None): @@ -1274,7 +1381,9 @@ def kernel( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # Advance to next tile - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work( + is_scheduler_warp=is_scheduler_warp + ) work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop if is_scheduler_warp: @@ -1286,7 +1395,9 @@ def kernel( cute.arch.setmaxregister_decrease(self.num_regs_other) tile_M = self.cta_tile_shape_mnk[0] tile_K = self.cta_tile_shape_mnk[2] - tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True) + tiled_copy_AIdx = copy_utils.tiled_copy_1d( + Int32, num_threads=32, is_async=True + ) thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx()) tAsAIdx = thr_copy_AIdx.partition_D(sAIdx) tAcAIdx = thr_copy_AIdx.partition_S( @@ -1299,16 +1410,20 @@ def kernel( pipeline.PipelineUserType.Producer, self.a_prefetch_stage ) while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) if const_expr(varlen_m): # (tile_M,) - gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],)) + gAIdx = cute.local_tile( + mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],) + ) tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) len_m = varlen_manager.len_m(batch_idx) m_limit = len_m - tile_coord_mnkl[0] * tile_M - tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean) + tApAIdx_m = cute.make_rmem_tensor( + (1, tAsAIdx.shape[1]), Boolean + ) for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) @@ -1324,31 +1439,43 @@ def kernel( # (tile_K, RestK) gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,)) tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) - len_k = varlen_manager.len_k(batch_idx) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) k_tile_cnt = cute.ceil_div(len_k, tile_K) for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): - a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + a_prefetch_pipeline.producer_acquire( + a_prefetch_producer_state + ) cute.copy( thr_copy_AIdx, tAgAIdx[None, None, k_tile], tAsAIdx[None, None, a_prefetch_producer_state.index], ) - a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_pipeline.producer_commit( + a_prefetch_producer_state + ) a_prefetch_producer_state.advance() if 0 < k_tile_cnt: k_tile = k_tile_cnt - 1 k_limit = len_k - k_tile * tile_K - tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean) + tApAIdx_k = cute.make_rmem_tensor( + (1, tAsAIdx.shape[1]), Boolean + ) for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit - a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + a_prefetch_pipeline.producer_acquire( + a_prefetch_producer_state + ) cute.copy( tiled_copy_AIdx, tAgAIdx[None, None, k_tile], tAsAIdx[None, None, a_prefetch_producer_state.index], pred=tApAIdx_k, ) - a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_pipeline.producer_commit( + a_prefetch_producer_state + ) a_prefetch_producer_state.advance() # Advance to next tile tile_scheduler.advance_to_next_work() @@ -1371,11 +1498,13 @@ def kernel( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # Get tile coord from tile scheduler - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition( tma_atom_c, - varlen_manager.offset_batch_epi(mC_mnl, batch_idx), + self.problem_get_batch_epi( + problem_params, mC_mnl, varlen_manager, work_tile + ), self.cta_tile_shape_mnk[:2], epi_tile, sC, @@ -1459,8 +1588,16 @@ def kernel( ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) else: tCtSFA, tCtSFB = None, None - tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None - tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + None, + None, + None, + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + None, + None, + None, + ) # Persistent tile scheduling loop tile_scheduler = TileSchedulerCls() @@ -1473,9 +1610,11 @@ def kernel( ) while work_tile.is_valid_tile: # Get tile coord from tile scheduler - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - k_len = varlen_manager.len_k(batch_idx) + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) + k_len = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2]) # Set tensor memory buffer for current tile # (MMA, MMA_M, MMA_N) @@ -1486,7 +1625,9 @@ def kernel( ) tCtAcc = tCtAcc_base[None, None, None, acc_stage_idx] tCtSFB_mma = tCtSFB - if const_expr(self.blockscaled and self.mma_inst_shape_mnk[1] in (64, 192)): + if const_expr( + self.blockscaled and self.mma_inst_shape_mnk[1] in (64, 192) + ): tCtSFB_mma = cute.make_tensor( cute.recast_ptr( sfb_tmem_base_ptr + Int32((tile_coord_mnkl[1] % 2) * 2), @@ -1525,14 +1666,20 @@ def kernel( # Doing tmem ptr arithmetic requires 32-bit type, wrong otherwise cute.recast_ptr(mT.iterator, dtype=Float32) + cute.assume( - acc_tmem_col_offset * (acc_producer_state.phase * 2 - 1), + acc_tmem_col_offset + * (acc_producer_state.phase * 2 - 1), divby=acc_tmem_col_offset, ), dtype=self.sf_dtype, ), mT.layout, ) - for mT in [tCtSFA, tCtSFB, tCtSFA_compact_s2t, tCtSFB_compact_s2t] + for mT in [ + tCtSFA, + tCtSFB, + tCtSFA_compact_s2t, + tCtSFB_compact_s2t, + ] ] tctx.e("mma") # Advance to next tile @@ -1566,8 +1713,10 @@ def kernel( # Partition for epilogue epi_tidx = tidx - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( - epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs + ) ) tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype) @@ -1577,8 +1726,15 @@ def kernel( tRS_rC, tSR_rC, tSR_sC = None, None, None tiled_copy_s2r = None if const_expr(mC_mnl is not None): - tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( - tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx + tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = ( + self.epilog_smem_load_and_partition( + tiled_copy_t2r, + self.c_layout, + self.c_dtype, + sC, + tRS_rD.layout, + epi_tidx, + ) ) # Persistent tile scheduling loop @@ -1593,8 +1749,8 @@ def kernel( ) while work_tile.is_valid_tile: # Get tile coord from tile scheduler - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) epi_acc_stage = ( @@ -1610,7 +1766,9 @@ def kernel( if const_expr(has_D): copy_D, _, _ = self.epilog_gmem_copy_and_partition( tma_atom_d, - varlen_manager.offset_batch_epi(mD_mnl, batch_idx), + self.problem_get_batch_epi( + problem_params, mD_mnl, varlen_manager, work_tile + ), self.cta_tile_shape_mnk[:2], epi_tile, sD, @@ -1619,9 +1777,13 @@ def kernel( copy_C = None # We're using a separate warp to load C tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - k_len = varlen_manager.len_k(batch_idx) + k_len = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) epi_tile_num = cute.size( - cute.zipped_divide(cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile), + cute.zipped_divide( + cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile + ), mode=[1], ) load_acc_subtile = partial( @@ -1773,24 +1935,36 @@ def load_AB_gather_A( copy_B: Callable, k_tile_cnt: Int32, varlen_m: bool = True, - ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]: + ) -> Tuple[ + cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState] + ]: warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt peek_ab_empty_status = Boolean(True) if 0 < k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) # TMA load on B and cp.async on A - for k_tile in cutlass.range(k_tile_cnt - 1, unroll=2 if const_expr(varlen_m) else 1): + for k_tile in cutlass.range( + k_tile_cnt - 1, unroll=2 if const_expr(varlen_m) else 1 + ): smem_idx = ab_producer_state.index prefetch_out = () - if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free - prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),) + if const_expr( + prefetch_A is not None + ): # Prefetch early, even before smem is free + prefetch_out = ( + prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state), + ) a_prefetch_consumer_state.advance() # Wait for A/B buffers to be empty before loading into them # Also sets the transaction barrier for the A/B buffers # A tiny bit faster to rotate the warp that does TMA - is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + is_tma_warp = warp_idx == self.ab_load_warp_id + ( + k_tile % self.num_ab_load_warps + ) + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status, is_tma_warp + ) # A bit faster to load B first while we calculate the indices for A tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) if is_tma_warp: @@ -1801,17 +1975,27 @@ def load_AB_gather_A( ab_producer_state.advance() peek_ab_empty_status = Boolean(True) if k_tile + 1 < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) # bound checking in the K dimension on the last k_tile if 0 < k_tile_cnt: k_tile = k_tile_cnt - 1 smem_idx = ab_producer_state.index prefetch_out = () - if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free - prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),) + if const_expr( + prefetch_A is not None + ): # Prefetch early, even before smem is free + prefetch_out = ( + prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True), + ) a_prefetch_consumer_state.advance() - is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + is_tma_warp = ( + warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps + ) + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status, is_tma_warp + ) tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) if is_tma_warp: copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) @@ -1830,7 +2014,9 @@ def load_AB_tma_gather( prefetch_A: Optional[Callable], copy_B: Callable, k_tile_cnt: Int32, - ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]: + ) -> Tuple[ + cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState] + ]: """Unified TMA gather loading loop for both varlen_m and varlen_k. For varlen_m: a_prefetch_pipeline is None, copy_A receives k_tile as src_idx. @@ -1844,11 +2030,19 @@ def load_AB_tma_gather( for k_tile in cutlass.range(k_tile_cnt, unroll=1): smem_idx = ab_producer_state.index prefetch_out = () - if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free - prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),) + if const_expr( + prefetch_A is not None + ): # Prefetch early, even before smem is free + prefetch_out = ( + prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state), + ) a_prefetch_consumer_state.advance() - is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + is_tma_warp = warp_idx == self.ab_load_warp_id + ( + k_tile % self.num_ab_load_warps + ) + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status, is_tma_warp + ) tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) if is_tma_warp: copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) @@ -1857,7 +2051,9 @@ def load_AB_tma_gather( ab_producer_state.advance() peek_ab_empty_status = Boolean(True) if k_tile + 1 < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) return ab_producer_state, a_prefetch_consumer_state @cute.jit @@ -1882,7 +2078,9 @@ def mma( tCsSFB_compact_s2t: Optional[cute.Tensor] = None, tCtSFA_compact_s2t: Optional[cute.Tensor] = None, tCtSFB_compact_s2t: Optional[cute.Tensor] = None, - ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]: + ) -> Tuple[ + cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma + ]: blockscaled = const_expr(tiled_copy_s2t_sfa is not None) if const_expr(blockscaled): assert all(x is not None for x in (tCtSFA, tCtSFB)) @@ -1923,15 +2121,27 @@ def mma( s2t_stage_coord = (None, None, None, None, ab_consumer_state.index) tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] - cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t) - cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index) if const_expr(blockscaled): # Set SFA/SFB tensor to tiled_mma sf_kblock_coord = (None, None, k_blk_idx) - tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator) - tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator + ) cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty @@ -1999,13 +2209,17 @@ def mainloop_s2t_copy_and_partition( # (MMA, MMA_MN, MMA_K) tCtSF_compact = cute.filter_zeros(tSF) # Make S2T CopyAtom and tiledCopy - copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype) + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype + ) tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) thr_copy_s2t = tiled_copy_s2t.get_slice(0) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) - tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t @@ -2047,19 +2261,25 @@ def epilog_tmem_copy_and_partition( # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) # (EPI_TILE_M, EPI_TILE_N) - tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) - cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])) + cAcc = cute.make_identity_tensor( + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]) + ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) cAcc_epi = cute.flat_divide(cAcc, epi_tile) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi) # (T2R, T2R_M, T2R_N) - tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype) + tTR_rAcc = cute.make_rmem_tensor( + tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype + ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc def epilog_smem_store_and_partition( @@ -2120,9 +2340,14 @@ def epilog_smem_load_and_partition( store_op = copy_atom_r2s.op # m8n8 16-bit path if isinstance(store_op, StMatrix8x8x16bOp): - op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose) + op = LdMatrix8x8x16bOp( + num_matrices=store_op.num_matrices, transpose=store_op.transpose + ) # m16n8 8-bit store -> m16n16 8-bit load - elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]: + elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [ + 2, + 4, + ]: # transpose=True is enforced by the class op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2) else: @@ -2158,7 +2383,9 @@ def make_ab_pipeline( producer_cnt = self.num_ab_load_warps * 32 + ( 1 if const_expr(not self.use_2cta_instrs) else 2 ) - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + ab_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, producer_cnt + ) # Each warp will contribute to the arrive count with the number of mcast size mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 consumer_arrive_cnt = mcast_size @@ -2204,7 +2431,9 @@ def make_acc_pipeline( self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer ) -> pipeline.PipelineAsync: acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1) + num_acc_consumer_threads = self.num_epi_warps * ( + 2 if self.use_2cta_instrs else 1 + ) acc_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_acc_consumer_threads ) @@ -2229,7 +2458,12 @@ def make_sched_pipeline( # Each warp will contribute 1 to the arrive count extra_warp_ids = (self.a_prefetch_warp_id,) if self.gather_A else () warps_per_cta = self.num_ab_load_warps + len( - (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id, *extra_warp_ids) + ( + self.mma_warp_id, + *self.epilog_warp_id, + self.scheduler_warp_id, + *extra_warp_ids, + ) ) if has_C: warps_per_cta += 1 @@ -2252,7 +2486,9 @@ def make_a_prefetch_pipeline( self, a_prefetch_pipeline_mbar_ptr: cute.Pointer ) -> pipeline.PipelineAsync: producer_cnt = 32 - a_prefetch_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + a_prefetch_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, producer_cnt + ) consumer_arrive_cnt = self.num_ab_load_warps a_prefetch_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_arrive_cnt @@ -2319,7 +2555,9 @@ def _compute_stages( # Default D stages epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2 - epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2) + epi_c_stage = ( + 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2) + ) # Calculate smem layout and size for one stage of A, B, and C a_smem_layout_staged_one = sm100_utils.make_smem_layout_a( @@ -2371,7 +2609,9 @@ def _compute_stages( if const_expr(prefetch_A_idx == "varlen_m"): mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2 d_bytes_per_stage = ( - cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0 + cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) + if d_dtype is not None + else 0 ) epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( epilogue_args, cta_tile_shape_mnk, epi_tile @@ -2391,7 +2631,9 @@ def _compute_stages( # Refine epilogue stages: # Calculate remaining smem after allocating for A/B stages and reserved bytes # Add remaining unused smem to epilogue - epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage) + epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // ( + epi_bytes_per_stage + ) return num_acc_stage, ab_stage, epi_stage, epi_c_stage @staticmethod @@ -2457,7 +2699,8 @@ def is_valid_dtypes( if ( acc_dtype not in {Float32, cutlass.Float16, Int32} or acc_dtype == cutlass.Float16 - and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} or acc_dtype == Int32 and ab_dtype not in {cutlass.Uint8, cutlass.Int8} ): @@ -2522,7 +2765,11 @@ def is_valid_dtypes_and_scale_factor_vec_size( is_valid = True # Check valid ab_dtype - if ab_dtype not in {cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: is_valid = False # Check valid sf_vec_size @@ -2741,7 +2988,9 @@ def can_implement( """ can_implement = True # Skip unsupported types - if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major): + if not GemmSm100.is_valid_dtypes( + ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major + ): can_implement = False # Skip invalid mma tile shape and cluster shape if not GemmSm100.is_valid_mma_tiler_and_cluster_shape( diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py index 34e3e383..a8d06289 100644 --- a/quack/gemm_sm120.py +++ b/quack/gemm_sm120.py @@ -161,6 +161,7 @@ def kernel( epi_c_smem_layout: cute.ComposedLayout, tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + problem_params, trace_ptr: Optional[cutlass.Int64] = None, ): from quack.trace import TraceContext @@ -267,12 +268,12 @@ def kernel( ) while work_tile.is_valid_tile: tctx.b("tma_load") - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) # Local_tile partition global tensors copy_A, prefetch_A = None, None if const_expr(not self.gather_A): - mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + mA_mk = self.problem_get_batch_A(problem_params, mA_mkl, varlen_manager, work_tile) # (bM, bK, RestK) gA_mk = cute.local_tile( mA_mk, @@ -296,7 +297,7 @@ def kernel( ) # (bN, bK, RestK) gB_nk = cute.local_tile( - varlen_manager.offset_batch_B(mB_nkl, batch_idx), + self.problem_get_batch_B(problem_params, mB_nkl, varlen_manager, work_tile), cute.select(self.cta_tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None), ) @@ -311,7 +312,7 @@ def kernel( dst_tensor=sB, mcast_mask=b_mcast_mask, ) - len_k = varlen_manager.len_k(batch_idx) + len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(not self.gather_A): ab_producer_state = self.load_AB( @@ -407,15 +408,15 @@ def kernel( if const_expr(not varlen_k): ab_read_state.advance_iters(k_tile_cnt_static) else: - len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) ab_read_state.advance_iters(k_tile_cnt) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - len_k = varlen_manager.len_k(batch_idx) + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) + len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) acc.fill(0.0) if const_expr(self.pingpong): @@ -450,7 +451,7 @@ def kernel( if const_expr(has_D): copy_D, _, _ = self.epilog_gmem_copy_and_partition( tma_atom_d, - varlen_manager.offset_batch_epi(mD_mnl, tile_coord_mnkl[3]), + self.problem_get_batch_epi(problem_params, mD_mnl, varlen_manager, work_tile), self.cta_tile_shape_mnk[:2], self.epi_tile, sD, @@ -460,7 +461,7 @@ def kernel( if const_expr(has_C): copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition( tma_atom_c, - varlen_manager.offset_batch_epi(mC_mnl, tile_coord_mnkl[3]), + self.problem_get_batch_epi(problem_params, mC_mnl, varlen_manager, work_tile), self.cta_tile_shape_mnk[:2], self.epi_tile, sC, @@ -535,7 +536,7 @@ def kernel( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() if work_tile.is_valid_tile: - len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) ab_read_state.advance_iters(k_tile_cnt) tile_scheduler.advance_to_next_work() diff --git a/quack/gemm_sm90.py b/quack/gemm_sm90.py index 7ddb280e..de32301d 100644 --- a/quack/gemm_sm90.py +++ b/quack/gemm_sm90.py @@ -8,38 +8,46 @@ import math -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute -import cutlass.pipeline as pipeline -from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass import Boolean, const_expr, Float16, Float32, Int32 from cutlass.cute.nvgpu import cpasync, warp, warpgroup -import cutlass.utils.hopper_helpers as sm90_utils -from cutlass import Int32, Float32, Float16, Boolean, const_expr +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.utils import LayoutEnum - - from dataclasses import dataclass - -from quack.cute_dsl_utils import ParamsBase from quack import layout_utils +from quack.cute_dsl_utils import ParamsBase +from quack.gemm_epilogue_plan import ( + default_epi_commit, + default_epi_tile_layout, + run_epilogue_plan, +) +from quack.gemm_problem_adapter import ( + default_problem_batch_a, + default_problem_batch_b, + default_problem_batch_epi, + default_problem_idx, + default_problem_len_k, + default_problem_to_underlying_arguments, + ProblemArguments as DefaultProblemArguments, + ProblemParams as DefaultProblemParams, +) +from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync from quack.tile_scheduler import ( - TileSchedulerOptions, - TileSchedulerArguments, + PersistenceMode, TileScheduler, - VarlenMTileSchedulerArguments, + TileSchedulerArguments, + TileSchedulerOptions, VarlenMTileScheduler, - PersistenceMode, + VarlenMTileSchedulerArguments, ) from quack.varlen_utils import VarlenArguments, VarlenManager - -# return PipelineStateWAdvance instead of PipelineState -from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync +from typing import Callable, Literal, Optional, Tuple, Type, Union +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils.hopper_helpers as sm90_utils import quack.copy_utils as copy_utils import quack.sm90_utils as quack_sm90_utils -from quack.rounding import RoundingMode - """ A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture using CUTE DSL. @@ -134,6 +142,80 @@ class EpilogueArguments: pass EpilogueParams = ParamsBase + ProblemArguments = DefaultProblemArguments + ProblemParams = DefaultProblemParams + + def problem_to_underlying_arguments( + self, args: Optional[ProblemArguments] = None, *, loc=None, ip=None + ) -> ProblemParams: + return default_problem_to_underlying_arguments(args, loc=loc, ip=ip) + + @cute.jit + def problem_get_problem_idx(self, params: ProblemParams, work): + return default_problem_idx(params, work) + + @cute.jit + def problem_get_len_k( + self, params: ProblemParams, varlen_manager: VarlenManager, work + ): + return default_problem_len_k(params, work, varlen_manager) + + @cute.jit + def problem_get_batch_A( + self, + params: ProblemParams, + mA_mkl: cute.Tensor, + varlen_manager: VarlenManager, + work, + ): + return default_problem_batch_a(params, mA_mkl, varlen_manager, work) + + @cute.jit + def problem_get_batch_B( + self, + params: ProblemParams, + mB_nkl: cute.Tensor, + varlen_manager: VarlenManager, + work, + ): + return default_problem_batch_b(params, mB_nkl, varlen_manager, work) + + @cute.jit + def problem_get_batch_epi( + self, + params: ProblemParams, + mX_mnl: Optional[cute.Tensor], + varlen_manager: VarlenManager, + work, + ): + return default_problem_batch_epi(params, mX_mnl, varlen_manager, work) + + def epi_plan_make_tile_layout(self, epi_tile_shape): + return default_epi_tile_layout(self, epi_tile_shape) + + @cute.jit + def epi_plan_commit( + self, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, + ): + default_epi_commit( + self, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, + ) def __init__( self, @@ -194,7 +276,8 @@ def __init__( ) else: if not ( - (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512) + (tile_N % 16 == 0 and tile_N <= 256) + or (tile_N % 32 == 0 and tile_N <= 512) ): raise ValueError( "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512" @@ -204,7 +287,9 @@ def __init__( raise ValueError("CTA tile shape M must be 64/128/192 if pingpong") tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128) if not (tile_N % 16 == 0 and tile_N <= tile_N_max): - raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}") + raise ValueError( + f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}" + ) if not self.pingpong: if tile_M == 320: # tile_M / 64 is not even so we have to split along N @@ -216,7 +301,9 @@ def __init__( atom_layout_m, atom_layout_n = 1, 2 else: atom_layout_m = ( - self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2 + self.cta_tile_shape_mnk[0] // 64 + if self.cta_tile_shape_mnk[0] < 256 + else 2 ) atom_layout_n = 1 assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2] @@ -232,12 +319,16 @@ def __init__( self.is_b_mcast = self.num_mcast_ctas_b > 1 self.occupancy = 1 - self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) + self.mma_warp_groups = math.prod(self.atom_layout_mnk) * ( + 1 if not self.pingpong else 2 + ) if self.pingpong: assert self.mma_warp_groups == 2 assert self.mma_warp_groups in [1, 2, 3] self.num_threads_per_warp_group = 128 - self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group + self.threads_per_cta = ( + self.mma_warp_groups + 1 + ) * self.num_threads_per_warp_group self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90") self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4 self.epilogue_barrier = pipeline.NamedBarrier( @@ -343,7 +434,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments): self.d_dtype, self.c_dtype, epilogue_args, - cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity + cutlass.utils.get_smem_capacity_in_bytes( + f"sm_{self.arch}" + ), # smem_capacity self.occupancy, ) self.sched_stage = 2 if self.pingpong else 1 @@ -381,6 +474,7 @@ def __call__( varlen_args: Optional[VarlenArguments], stream: cuda.CUstream, trace_ptr: Optional[cutlass.Int64] = None, + problem_args: Optional[ProblemArguments] = None, ): """Execute the GEMM operation in steps: - Setup static attributes @@ -420,7 +514,9 @@ def __call__( if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype): raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") if const_expr(self.a_dtype.width != self.b_dtype.width): - raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}") + raise TypeError( + f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" + ) if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): raise TypeError("a_dtype should be float16 or float8") @@ -445,7 +541,9 @@ def __call__( self.cluster_shape_mnk[1], ) tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( - copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB, + copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) + if varlen_k + else mB, b_smem_layout, (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]), self.cluster_shape_mnk[0], @@ -468,7 +566,10 @@ def __call__( self.epi_smem_layout_staged, self.epi_tile, op_type="store" - if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) + if not ( + hasattr(epilogue_args, "add_to_output") + and epilogue_args.add_to_output + ) else "add", ) tma_atom_c, tma_tensor_c = None, None @@ -479,6 +580,7 @@ def __call__( epilogue_params = self.epi_to_underlying_arguments(epilogue_args) varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + problem_params = self.problem_to_underlying_arguments(problem_args) TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m) tile_sched_args = self.get_scheduler_arguments( @@ -489,14 +591,24 @@ def __call__( tile_sched_params, scheduler_args.max_active_clusters ) - epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 - epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + epi_smem_size = ( + cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 + ) + epi_c_smem_size = ( + cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + ) @cute.struct class SharedStorage: - ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] - epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2] - sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] + ab_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + epi_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.epi_c_stage * 2 + ] + sched_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.sched_stage * 2 + ] sched_data: cute.struct.MemRange[Int32, self.sched_stage * 4] sD: cute.struct.Align[ cute.struct.MemRange[ @@ -512,11 +624,15 @@ class SharedStorage: ] epi: self.epi_get_smem_struct(epilogue_params) sA: cute.struct.Align[ - cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)], + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], self.buffer_align_bytes, ] sB: cute.struct.Align[ - cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)], + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], self.buffer_align_bytes, ] @@ -542,6 +658,7 @@ class SharedStorage: self.epi_c_smem_layout_staged, tile_sched_params, TileSchedulerCls, + problem_params, trace_ptr, ).launch( grid=grid, @@ -575,6 +692,7 @@ def kernel( epi_c_smem_layout: cute.ComposedLayout, tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + problem_params, trace_ptr: Optional[cutlass.Int64] = None, ): """ @@ -650,17 +768,23 @@ def kernel( sched_data = storage.sched_data.get_tensor((4, self.sched_stage)) # Cluster arrive after barrier init - pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True) + pipeline_init_arrive( + cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True + ) # Generate smem tensor A/B sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) sD = None if const_expr(has_D): - sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sD = storage.sD.get_tensor( + epi_smem_layout.outer, swizzle=epi_smem_layout.inner + ) sC = None if const_expr(has_C): - sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) + sC = storage.sC.get_tensor( + epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner + ) epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage) varlen_manager = VarlenManager.create( @@ -691,8 +815,12 @@ def kernel( if const_expr(self.use_pdl): cute.arch.griddepcontrol_wait() # Get mcast mask - cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord( + cta_rank_in_cluster + ) a_mcast_mask = cute.make_layout_image_mask( cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1 ) @@ -703,9 +831,13 @@ def kernel( b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 # Persistent tile scheduling loop - is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id + is_scheduler_warp = ( + self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id + ) if const_expr(cute.size(cluster_layout_mnk) > 1): - is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 + is_scheduler_warp = ( + is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 + ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() ab_producer_state = make_pipeline_state( @@ -713,12 +845,14 @@ def kernel( ) while work_tile.is_valid_tile: tctx.b("tma_load") - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) # Local_tile partition global tensors copy_A, prefetch_A = None, None if const_expr(not self.gather_A): - mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + mA_mk = self.problem_get_batch_A( + problem_params, mA_mkl, varlen_manager, work_tile + ) # (bM, bK, RestK) gA_mk = cute.local_tile( mA_mk, @@ -742,7 +876,9 @@ def kernel( ) # (bN, bK, RestK) gB_nk = cute.local_tile( - varlen_manager.offset_batch_B(mB_nkl, batch_idx), + self.problem_get_batch_B( + problem_params, mB_nkl, varlen_manager, work_tile + ), cute.select(self.cta_tile_shape_mnk, [1, 2]), (tile_coord_mnkl[1], None), ) @@ -757,7 +893,9 @@ def kernel( dst_tensor=sB, mcast_mask=b_mcast_mask, ) - len_k = varlen_manager.len_k(batch_idx) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(not self.gather_A): ab_producer_state = self.load_AB( @@ -774,7 +912,9 @@ def kernel( varlen_m=varlen_m, ) tctx.e("tma_load") - tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work( + is_scheduler_warp=is_scheduler_warp + ) work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop if const_expr(self.pingpong and not varlen_k): @@ -795,7 +935,9 @@ def kernel( ) # Partition global tensor for TiledMMA_A/B/C tidx, _, _ = cute.arch.thread_idx() - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) if const_expr(self.pingpong): tidx = tidx % self.num_threads_per_warp_group warp_group_thread_layout = cute.make_layout( @@ -824,9 +966,13 @@ def kernel( k_tile_cnt_static = cute.ceil_div( cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2] ) - c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile)) + c_tile_cnt = cute.size( + cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile) + ) - ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + ab_read_state = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) epi_store_pipeline = self.make_epi_store_pipeline() epi_read_state = make_pipeline_state( pipeline.PipelineUserType.Consumer, self.epi_c_stage @@ -844,22 +990,32 @@ def kernel( if const_expr(not varlen_k): ab_read_state.advance_iters(k_tile_cnt_static) else: - len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) ab_read_state.advance_iters(k_tile_cnt) # TODO: do we need to check if work_tile is valid? tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() while work_tile.is_valid_tile: - tile_coord_mnkl = work_tile.tile_idx - batch_idx = tile_coord_mnkl[3] - len_k = varlen_manager.len_k(batch_idx) + tile_coord_mnkl = work_tile.tile_coord_mnkl + batch_idx = self.problem_get_problem_idx(problem_params, work_tile) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(self.pingpong): self.pingpong_barrier_sync(warp_group_idx, stage="mma") tctx.b("mma") ab_read_state = self.mma( - ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx + ab_pipeline, + ab_read_state, + mma_fn, + acc, + acc_slow, + k_tile_cnt, + warp_group_idx, ) if const_expr(varlen_k): if k_tile_cnt == 0: @@ -875,7 +1031,9 @@ def kernel( if const_expr(has_D): copy_D, _, _ = self.epilog_gmem_copy_and_partition( tma_atom_d, - varlen_manager.offset_batch_epi(mD_mnl, batch_idx), + self.problem_get_batch_epi( + problem_params, mD_mnl, varlen_manager, work_tile + ), self.cta_tile_shape_mnk[:2], self.epi_tile, sD, @@ -885,7 +1043,9 @@ def kernel( if const_expr(has_C): copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition( tma_atom_c, - varlen_manager.offset_batch_epi(mC_mnl, batch_idx), + self.problem_get_batch_epi( + problem_params, mC_mnl, varlen_manager, work_tile + ), self.cta_tile_shape_mnk[:2], self.epi_tile, sC, @@ -893,7 +1053,9 @@ def kernel( ) copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline) - d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16 + d_dtype_for_layout = ( + self.d_dtype if self.d_dtype is not None else cutlass.BFloat16 + ) tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx ) @@ -901,13 +1063,22 @@ def kernel( tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s) load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc) if const_expr(has_C): - tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( - tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx + tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = ( + self.epilog_smem_load_and_partition( + tiled_mma, + self.c_layout, + self.c_dtype, + sC, + tRS_rD.layout, + tidx, + ) ) else: tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None - self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx) + self.epi_visit_acc( + epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx + ) epi_read_state, epi_producer_state = self.epilogue( epilogue_params, @@ -955,14 +1126,20 @@ def kernel( # Update starting mainloop pipeline state for the next tile if const_expr(not varlen_k): ab_read_state.advance_iters(k_tile_cnt_static) - tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups) + tile_scheduler.advance_to_next_work( + advance_count=self.mma_warp_groups + ) work_tile = tile_scheduler.get_current_work() else: tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() if work_tile.is_valid_tile: - len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) - k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + len_k = self.problem_get_len_k( + problem_params, varlen_manager, work_tile + ) + k_tile_cnt = cute.ceil_div( + len_k, self.cta_tile_shape_mnk[2] + ) ab_read_state.advance_iters(k_tile_cnt) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -1016,7 +1193,9 @@ def load_AB( ab_producer_state.advance() peek_ab_empty_status = Boolean(True) if k_tile + 1 < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) return ab_producer_state @cute.jit @@ -1038,13 +1217,19 @@ def load_AB_gather_A( # TMA load on B and cp.async on A for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): prefetch_out = () - if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + if const_expr( + prefetch_A is not None + ): # Prefetch early, even before smem is free prefetch_out = (prefetch_A(k_tile),) # Wait for A/B buffers to be empty before loading into them # Also sets the transaction barrier for the A/B buffers # A tiny bit faster to rotate the warp that does TMA - is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps) - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + is_tma_warp = warp_idx == self.ab_load_warp_id + ( + k_tile % self.num_ab_load_warps + ) + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status, is_tma_warp + ) smem_idx = ab_producer_state.index # A bit faster to load B first while we calculate the indices for A if is_tma_warp: @@ -1056,15 +1241,23 @@ def load_AB_gather_A( ab_producer_state.advance() peek_ab_empty_status = Boolean(True) if k_tile + 1 < k_tile_cnt: - peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) # bound checking in the K dimension on the last k_tile if 0 < k_tile_cnt: k_tile = k_tile_cnt - 1 prefetch_out = () - if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + if const_expr( + prefetch_A is not None + ): # Prefetch early, even before smem is free prefetch_out = (prefetch_A(k_tile, pred=True),) - is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps - ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + is_tma_warp = ( + warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps + ) + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status, is_tma_warp + ) smem_idx = ab_producer_state.index if is_tma_warp: tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) @@ -1087,7 +1280,9 @@ def _make_gather_A_copy( varlen_m = varlen_manager.varlen_m mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) if const_expr(varlen_m): - gAIdx = cute.local_tile(mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)) + gAIdx = cute.local_tile( + mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],) + ) mA_mk = mA_mkl else: gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],)) @@ -1099,7 +1294,9 @@ def _make_gather_A_copy( tiled_copy_A = self._make_gmem_tiled_copy_A( mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32 ) - dma_tidx = cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id + dma_tidx = ( + cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id + ) thr_copy_A = tiled_copy_A.get_slice(dma_tidx) copy_A, prefetch_A = None, None if const_expr(varlen_m): @@ -1144,7 +1341,11 @@ def mma( for k_tile in cutlass.range(num_prologue_mma): # Wait for A/B buffer to be ready ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) - mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init) + mma_fn( + A_idx=ab_read_state.index, + B_idx=ab_read_state.index, + zero_init=zero_init, + ) zero_init = Boolean(False) ab_read_state.advance() peek_ab_full_status = Boolean(True) @@ -1162,7 +1363,11 @@ def mma( ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) if const_expr(self.fp8_slow_accum): zero_init = Boolean(True) - mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init) + mma_fn( + A_idx=ab_read_state.index, + B_idx=ab_read_state.index, + zero_init=zero_init, + ) zero_init = Boolean(False) # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete if const_expr(not self.fp8_slow_accum): @@ -1217,142 +1422,34 @@ def epilogue( tidx: Int32, is_tma_warp: Boolean, ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: - has_C = const_expr(tRS_rC is not None) - has_D = const_expr(copy_D is not None) - - # Setup postact output (returns None for default epilogue, context tuple for Act) - postact_ctx = self.epi_setup_postact( - params, - epi_smem_tensors, - tiled_copy_r2s, - tiled_copy_t2r, - tile_coord_mnkl, - varlen_manager, - tidx, - ) - - epi_tile_shape = cute.zipped_divide( - cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile - ).shape[1] - # We iterate over epi tiles in the N dimension first before the M dimension - epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0)) - epi_tile_num = cute.size(epi_tile_shape) - num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num - - epi_tensors = self.epi_begin( + return run_epilogue_plan( + self, params, epi_smem_tensors, + epi_pipeline, + epi_store_pipeline, + epi_read_state, + epi_producer_state, epi_tile, + load_acc_subtile, + tRS_rD, + tRS_rC, tiled_copy_t2r, tiled_copy_r2s, + tRS_sD, + tiled_copy_s2r, + tSR_rC, + tSR_sC, + copy_D, + copy_C, tile_coord_mnkl, varlen_manager, epilogue_barrier, + tile_scheduler, tidx, + is_tma_warp, ) - if const_expr(copy_C is not None): - for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - - for epi_idx in cutlass.range_constexpr(epi_tile_num): - # The global memory coordinate for the current epi tile - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - # Copy from acc to D registers - load_acc_subtile(tRS_rD, epi_idx) - epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) - if const_expr(has_C): - epi_pipeline.consumer_wait(epi_read_state) - cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) - # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() - with cute.arch.elect_one(): - epi_pipeline.consumer_release(epi_read_state) - epi_read_state.advance() - if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) - # Convert and store postact if this epilogue produces one - if const_expr(postact_ctx is not None): - tRS_rPostAct_out = self.epi_convert_postact( - tRS_rPostAct, - epi_loop_tensors["sr_seed"], - tidx, - tile_coord_mnkl, - num_prev_subtiles, - epi_idx, - ) - if is_tma_warp: - epi_store_pipeline.producer_acquire() - epilogue_barrier.arrive_and_wait() - # Copy from D registers to shared memory - epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage - if const_expr(has_D): - if const_expr( - self.rounding_mode == RoundingMode.RS - and self.acc_dtype == cutlass.Float32 - and self.d_dtype == cutlass.BFloat16 - ): - seed = epi_loop_tensors["sr_seed"] + ( - tile_coord_mnkl[0] * 65537 - + tile_coord_mnkl[1] * 257 - + tile_coord_mnkl[3] * 17 - + (num_prev_subtiles + epi_idx) * 7 - ) - copy_utils.sr_cvt_copy( - tiled_copy_r2s, - tRS_rD, - tRS_sD[None, None, None, epi_buffer], - seed, - tidx, - ) - else: - copy_utils.cvt_copy( - tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer] - ) - # Copy postact from registers to shared memory - if const_expr(postact_ctx is not None): - tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx - cute.copy( - tiled_copy_postact_r2s, - tiled_copy_postact_r2s.retile(tRS_rPostAct_out), - tRS_sPostAct[None, None, None, epi_buffer], - ) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_view_async_shared() - epilogue_barrier.arrive_and_wait() - # Copy from shared memory to global memory - if is_tma_warp: - if const_expr(has_D): - copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) - if const_expr(postact_ctx is not None): - copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) - epi_store_pipeline.producer_commit() - - self.epi_end( - params, - epi_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, - tile_coord_mnkl, - varlen_manager, - tidx, - ) - - return epi_read_state, epi_producer_state - def get_scheduler_class(self, varlen_m: bool = False): """Return the scheduler class to use. Override in subclasses for custom schedulers.""" return TileScheduler if not varlen_m else VarlenMTileScheduler @@ -1401,7 +1498,11 @@ def get_scheduler_arguments( persistence_mode=persistence_mode, ) else: - assert (mD is not None) or (epilogue_args.mPostAct is not None) or (not self.gather_A) + assert ( + (mD is not None) + or (epilogue_args.mPostAct is not None) + or (not self.gather_A) + ) problem_shape_ntile_mnl = ( None, cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]), @@ -1425,7 +1526,9 @@ def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s): return cute.flat_divide(acc, tRS_rD.layout) @cute.jit - def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int): + def epi_load_acc_subtile( + self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int + ): cute.autovec_copy(tRS_rAcc[None, None, None, epi_idx], tRS_rD) @cute.jit @@ -1444,7 +1547,10 @@ def epi_begin( return () def epi_begin_loop( - self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord + self, + params: EpilogueParams, + epi_tensors: Tuple[cute.Tensor, ...], + epi_coord: cute.Coord, ) -> Tuple[cute.Tensor, ...]: return () @@ -1503,10 +1609,14 @@ def epi_smem_bytes_per_stage( def epi_get_smem_struct(self, params: EpilogueParams): return cute.struct.MemRange[Int32, 0] # Dummy struct - def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: + def epi_get_smem_tensors( + self, params: EpilogueParams, storage + ) -> Tuple[cute.Tensor, ...]: return tuple() - def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]): + def pingpong_barrier_sync( + self, warp_group_idx: Int32, stage: Literal["mma", "epi"] + ): assert stage in ["mma", "epi"] barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 cute.arch.barrier( @@ -1514,7 +1624,9 @@ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "ep number_of_threads=2 * self.num_threads_per_warp_group, ) - def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]): + def pingpong_barrier_arrive( + self, warp_group_idx: Int32, stage: Literal["mma", "epi"] + ): assert stage in ["mma", "epi"] barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 cute.arch.barrier_arrive( @@ -1554,7 +1666,9 @@ def epilog_smem_store_and_partition( thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None sD_shape = sD.shape[:2] if sD is not None else self.epi_tile - tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape + tRS_rD_shape = thr_copy_r2s.partition_S( + cute.make_identity_tensor(sD_shape) + ).shape tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype) return tiled_copy_r2s, tRS_rD, tRS_sD @@ -1589,7 +1703,8 @@ def epilog_gmem_copy_and_partition( gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2]) tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile) is_s2g = isinstance( - atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp) + atom.op, + (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp), ) src_tensor, dst_tensor = ( (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD) @@ -1609,15 +1724,21 @@ def make_ab_pipeline( ab_pipeline_mbar_ptr: cute.Pointer, ): # Threads/warps participating in this pipeline - producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32 - ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + producer_cnt = ( + 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32 + ) + ab_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, producer_cnt + ) # Each warp will contribute to the arrive count with the number of mcast size mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, consumer_arrive_cnt ) - pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync + pipeline_cls = ( + pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync + ) return pipeline_cls.create( barrier_storage=ab_pipeline_mbar_ptr, num_stages=self.ab_stage, @@ -1629,7 +1750,9 @@ def make_ab_pipeline( ) def make_epi_pipeline( - self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer + self, + c_smem_layout: cute.Layout | cute.ComposedLayout, + epi_pipeline_mbar_ptr: cute.Pointer, ): # Threads/warps participating in this pipeline epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -1651,13 +1774,18 @@ def make_epi_pipeline( def make_epi_store_pipeline(self): # Threads/warps participating in tma store pipeline num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE - epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads) + epi_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_epi_threads + ) return pipeline.PipelineTmaStore.create( num_stages=self.epi_stage, producer_group=epi_store_producer_group ) def make_sched_pipeline( - self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool + self, + cluster_layout_mnk: cute.Layout, + sched_pipeline_mbar_ptr: cute.Pointer, + varlen_k: bool, ): # Threads/warps participating in this pipeline sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -1714,7 +1842,9 @@ def _compute_stages( """ epi_stage = 4 if epi_tile[1] <= 16 else 2 - d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0 + d_bytes_per_stage = ( + cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0 + ) epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( epilogue_args, cta_tile_shape_mnk, epi_tile ) @@ -1726,7 +1856,8 @@ def _compute_stages( a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None)) b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None)) ab_bytes_per_stage = ( - cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8 + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 ) mbar_helpers_bytes = 1024 @@ -1737,7 +1868,9 @@ def _compute_stages( # Calculate remaining smem after allocating for A/B stages and reserved bytes # Add remaining unused smem to epilogue if epi_bytes_per_stage > 0: - epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage + epi_stage += ( + remaining_bytes - ab_bytes_per_stage * ab_stage + ) // epi_bytes_per_stage return ab_stage, epi_stage, epi_c_stage @staticmethod @@ -1797,7 +1930,10 @@ def _make_smem_layouts( c_layout: Optional[LayoutEnum], epi_c_stage: int, ) -> Tuple[ - cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout] + cute.ComposedLayout, + cute.ComposedLayout, + cute.ComposedLayout, + Optional[cute.ComposedLayout], ]: """Create shared memory layouts for A, B, and C tensors. @@ -1894,7 +2030,9 @@ def _make_tma_epi_atoms_and_tensors( """ assert op_type in ["load", "store", "add"] epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) - d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile) + d_cta_v_layout = cute.composition( + cute.make_identity_layout(tensor_d.shape), epi_tile + ) op = ( cpasync.CopyBulkTensorTileG2SOp() if op_type == "load" @@ -2001,10 +2139,20 @@ def is_valid_dtypes( :rtype: bool """ is_valid = True - if a_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}: + if a_dtype not in { + Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: is_valid = False # tested b_dtype - if b_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}: + if b_dtype not in { + Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: is_valid = False if acc_dtype not in {Float32, Float16}: is_valid = False @@ -2026,6 +2174,8 @@ def is_valid_dtypes( is_valid = False # for Float8 types, this implementation only supports k-major layout - if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"): + if (a_dtype.width == 8 and a_major != "k") or ( + b_dtype.width == 8 and b_major != "k" + ): is_valid = False return is_valid diff --git a/quack/gemm_sq_reduce.py b/quack/gemm_sq_reduce.py index d3ee4ddb..96c8e32c 100644 --- a/quack/gemm_sq_reduce.py +++ b/quack/gemm_sq_reduce.py @@ -254,6 +254,6 @@ def gemm_sq_reduce( varlen_args = make_varlen_args(None, None, None) if device_capacity[0] in [10, 11]: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None) else: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None) diff --git a/quack/gemm_symmetric.py b/quack/gemm_symmetric.py index 812bb5df..39822368 100644 --- a/quack/gemm_symmetric.py +++ b/quack/gemm_symmetric.py @@ -1,3 +1,4 @@ + from typing import Tuple, Optional, Callable from torch import Tensor @@ -30,172 +31,36 @@ from quack.rounding import RoundingMode +from quack.gemm_epilogue_plan import symmetric_epi_commit + class GemmSymmetricMixin(GemmActMixin): def get_scheduler_class(self, varlen_m: bool = False): return TriangularTileScheduler @cute.jit - def epilogue( + def epi_plan_commit( self, - params: GemmActMixin.EpilogueParams, - epi_smem_tensors: Tuple[cute.Tensor, ...], - epi_pipeline: cutlass.pipeline.PipelineAsync, - epi_store_pipeline: cutlass.pipeline.PipelineAsync, - epi_read_state: cutlass.pipeline.PipelineState, - epi_producer_state: cutlass.pipeline.PipelineState, - epi_tile: cute.Tile, - load_acc_subtile: Callable, - tRS_rD: cute.Tensor, - tRS_rC: Optional[cute.Tensor], - tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100 - tiled_copy_r2s: cute.TiledCopy, - tRS_sD: cute.Tensor, - tiled_copy_s2r: Optional[cute.TiledCopy], - tSR_rC: Optional[cute.Tensor], - tSR_sC: Optional[cute.Tensor], - copy_D: Optional[Callable], - copy_C: Optional[Callable], - tile_coord_mnkl: cute.Coord, - varlen_manager: VarlenManager, - epilogue_barrier: cutlass.pipeline.NamedBarrier, - tile_scheduler, - tidx: Int32, - is_tma_warp: Boolean, - ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: - has_C = const_expr(tRS_rC is not None) - has_D = const_expr(copy_D is not None) - - tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = self.epi_setup_postact( - params, - epi_smem_tensors, - tiled_copy_r2s, - tiled_copy_t2r, - tile_coord_mnkl, - varlen_manager, - tidx, - ) - - # We iterate over epi tiles in the N dimension first before the M dimension - epi_tile_shape = cute.zipped_divide( - cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile - ).shape[1] - epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) - epi_tile_num = cute.size(epi_tile_shape) - num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num - - epi_tensors = self.epi_begin( - params, - epi_smem_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, + tile_coord_mnkl, + is_tma_warp, + epi_store_pipeline, + ): + symmetric_epi_commit( + self, + gmem_coord, + epi_buffer, + copy_D, + copy_postact, + postact_ctx, tile_coord_mnkl, - varlen_manager, - epilogue_barrier, - tidx, + is_tma_warp, + epi_store_pipeline, ) - if const_expr(copy_C is not None): - for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - - for epi_idx in cutlass.range_constexpr(epi_tile_num): - # The global memory coordinate for the current epi tile - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - # Copy from acc to D registers - load_acc_subtile(tRS_rD, epi_idx) - epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) - if const_expr(has_C): - epi_pipeline.consumer_wait(epi_read_state) - cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) - # Fence to make sure shared memory read is visible to TMA load - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() - with cute.arch.elect_one(): - epi_pipeline.consumer_release(epi_read_state) - epi_read_state.advance() - if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): - gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) - if is_tma_warp: - epi_pipeline.producer_acquire(epi_producer_state) - copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) - epi_pipeline.producer_commit(epi_producer_state) - epi_producer_state.advance() - tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) - tRS_rPostAct_out = self.epi_convert_postact( - tRS_rPostAct, - epi_loop_tensors["sr_seed"], - tidx, - tile_coord_mnkl, - num_prev_subtiles, - epi_idx, - ) - if is_tma_warp: - epi_store_pipeline.producer_acquire() - epilogue_barrier.arrive_and_wait() - # Copy from D registers to shared memory - epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage - if const_expr(has_D): - if const_expr( - self.rounding_mode == RoundingMode.RS - and self.acc_dtype == cutlass.Float32 - and self.d_dtype == cutlass.BFloat16 - ): - seed = epi_loop_tensors["sr_seed"] + ( - tile_coord_mnkl[0] * 65537 - + tile_coord_mnkl[1] * 257 - + tile_coord_mnkl[3] * 17 - + (num_prev_subtiles + epi_idx) * 7 - ) - copy_utils.sr_cvt_copy( - tiled_copy_r2s, - tRS_rD, - tRS_sD[None, None, None, epi_buffer], - seed, - tidx, - ) - else: - copy_utils.cvt_copy( - tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer] - ) - cute.copy( - tiled_copy_postact_r2s, - tiled_copy_postact_r2s.retile(tRS_rPostAct_out), - tRS_sPostAct[None, None, None, epi_buffer], - ) - pid_m = tile_coord_mnkl[0] - pid_n = tile_coord_mnkl[1] - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_view_async_shared() - epilogue_barrier.arrive_and_wait() - # Copy from shared memory to global memory - if is_tma_warp: - square_tile_m = pid_m // self.cluster_shape_mnk[0] - square_tile_n = pid_n // self.cluster_shape_mnk[1] - if const_expr(has_D): - copy_D(src_idx=epi_buffer, dst_idx=gmem_coord) - if square_tile_m != square_tile_n: # don't write twice to the same tile - copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord) - epi_store_pipeline.producer_commit() - - self.epi_end( - params, - epi_tensors, - epi_tile, - tiled_copy_t2r, - tiled_copy_r2s, - tile_coord_mnkl, - varlen_manager, - tidx, - ) - - return epi_read_state, epi_producer_state - class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90): pass @@ -389,6 +254,6 @@ def scalar_arg(scalar, mode): varlen_args = None if device_capacity[0] in [10, 11]: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None) else: - compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None) + compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None) diff --git a/quack/gemm_tvm_ffi_utils.py b/quack/gemm_tvm_ffi_utils.py index 18e8ec56..c2e9ab2b 100644 --- a/quack/gemm_tvm_ffi_utils.py +++ b/quack/gemm_tvm_ffi_utils.py @@ -4,15 +4,13 @@ from functools import partial -import cutlass.cute as cute -from cutlass import Int32, Int64, Float32 +from cutlass import Float32, Int32, Int64 from cutlass.cute.runtime import make_ptr - from quack.compile_utils import make_fake_tensor as fake_tensor from quack.cute_dsl_utils import torch2cute_dtype_map from quack.tile_scheduler import TileSchedulerOptions from quack.varlen_utils import VarlenArguments - +import cutlass.cute as cute def div_for_dtype(dtype): """16-byte alignment: divisibility in elements = 128 // dtype_width_bits.""" @@ -66,7 +64,9 @@ def make_scheduler_args( raster_order=None, max_swizzle_size=max_swizzle_size, tile_count_semaphore=( - tile_count_semaphore.data_ptr() if tile_count_semaphore is not None else None + tile_count_semaphore.data_ptr() + if tile_count_semaphore is not None + else None ), batch_idx_permute=batch_idx_permute, ) @@ -77,7 +77,9 @@ def make_fake_scheduler_args(has_semaphore, has_batch_idx_permute, l_sym): max_active_clusters=Int32(1), max_swizzle_size=Int32(8), tile_count_semaphore=( - make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) if has_semaphore else None + make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) + if has_semaphore + else None ), batch_idx_permute=( fake_tensor(Int32, (l_sym,), leading_dim=0, divisibility=4) @@ -103,13 +105,19 @@ def make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len): num_seqlens = cute.sym_int() return VarlenArguments( mCuSeqlensM=( - fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_m else None + fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) + if varlen_m + else None ), mCuSeqlensK=( - fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_k else None + fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) + if varlen_k + else None ), mAIdx=( - fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) if gather_A else None + fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) + if gather_A + else None ), ) @@ -188,6 +196,7 @@ def compile_gemm_kernel( has_trace_ptr=False, use_tma_gather=False, concat_layout=None, + problem_args=None, ): """Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI.""" if device_capacity[0] in [9, 12]: @@ -225,5 +234,6 @@ def compile_gemm_kernel( stream, *sf_args, trace_ptr, + problem_args, options="--enable-tvm-ffi", ) diff --git a/quack/gemm_work.py b/quack/gemm_work.py new file mode 100644 index 00000000..39a6e0e4 --- /dev/null +++ b/quack/gemm_work.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import NamedTuple, Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Boolean, Int32 +from quack.cute_dsl_utils import mlir_namedtuple, StaticTypes + + +@mlir_namedtuple +class WorkDesc(NamedTuple): + tile_coord_mnkl: cute.Coord + problem_idx: Int32 + k_tile_begin: Int32 = Int32(0) + k_tile_count: Optional[Int32] = None + split_k_idx: Int32 = Int32(0) + split_k_parts: Int32 = Int32(1) + is_final_split: Boolean = Boolean(True) + is_valid_tile: Boolean = Boolean(False) + + def __extract_mlir_values__(self): + values = [] + for field_val in self: + if field_val is None or isinstance(field_val, StaticTypes): + continue + values.extend(cutlass.extract_mlir_values(field_val)) + return values + + @property + def tile_idx(self): + return self.tile_coord_mnkl + + @property + def batch_idx(self): + return self.tile_coord_mnkl[3] + + +def make_work_desc(tile_coord_mnkl: cute.Coord, is_valid_tile: Boolean) -> WorkDesc: + return WorkDesc( + tile_coord_mnkl=tile_coord_mnkl, + problem_idx=tile_coord_mnkl[3], + is_valid_tile=is_valid_tile, + ) diff --git a/quack/tile_scheduler.py b/quack/tile_scheduler.py index 5cc6ac4f..e798363b 100644 --- a/quack/tile_scheduler.py +++ b/quack/tile_scheduler.py @@ -10,6 +10,7 @@ import quack.utils as utils from quack.fast_math import FastDivmod +from quack.gemm_work import WorkDesc, make_work_desc from quack.pipeline import PipelineStateWAdvance from quack.cute_dsl_utils import mlir_namedtuple @@ -308,7 +309,7 @@ def _delinearize_work_idx( block_zero_only: bool = False, loc=None, ip=None, - ) -> cutlass.utils.WorkTileInfo: + ) -> WorkDesc: params = self.params if const_expr(is_valid is None): if const_expr(params.persistence_mode == PersistenceMode.NONE): @@ -336,10 +337,10 @@ def _delinearize_work_idx( else params.batch_idx_permute[bidz_] ) tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) - return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + return make_work_desc(tile_coord_mnkl, is_valid) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def get_current_work(self, *, loc=None, ip=None) -> WorkDesc: params = self.params pid_m, pid_n, batch_idx, is_valid = Int32(0), Int32(0), Int32(0), Boolean(False) if const_expr(params.persistence_mode == PersistenceMode.NONE): @@ -361,10 +362,10 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: self._pipeline_state.advance() is_valid = Boolean(is_valid_i32) tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) - return cutlass.utils.WorkTileInfo(tile_coord_mnkl, Boolean(is_valid)) + return make_work_desc(tile_coord_mnkl, Boolean(is_valid)) # @cute.jit - def initial_work_tile_info(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkDesc: return self._delinearize_work_idx(self._current_work_idx, loc=loc, ip=ip) # if is_scheduler_warp: # work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip) @@ -428,7 +429,7 @@ def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32 | Tuple[Int32, Int @cute.jit def write_work_tile_to_smem( - self, work_tile_info: cutlass.utils.WorkTileInfo, *, loc=None, ip=None + self, work_tile_info: WorkDesc, *, loc=None, ip=None ): params = self.params if const_expr(self._sched_smem is not None): @@ -733,7 +734,7 @@ def _delinearize_work_idx( block_zero_only: bool = False, loc=None, ip=None, - ) -> cutlass.utils.WorkTileInfo: + ) -> WorkDesc: params = self.params if const_expr(is_valid is None): if const_expr(params.persistence_mode == PersistenceMode.NONE): @@ -764,7 +765,7 @@ def _delinearize_work_idx( # if tidx == 0: # cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}", # bidx, bidy, group_id, id_in_group, group_size_actual, group_col, group_remainder, cid_n_in_group, cid_m_in_group, cid_m, cid_n, is_valid) - return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + return make_work_desc(tile_coord_mnkl, is_valid) @dataclass @@ -1025,7 +1026,7 @@ def _delinearize_work_idx( block_zero_only: bool = False, loc=None, ip=None, - ) -> cutlass.utils.WorkTileInfo: + ) -> WorkDesc: assert bidz is None params = self.params lane_idx = cute.arch.lane_idx() @@ -1091,4 +1092,4 @@ def _delinearize_work_idx( tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) self._current_batch_idx = batch_idx self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch - return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + return make_work_desc(tile_coord_mnkl, is_valid)