diff --git a/README.md b/README.md
index 24b4cadb..b5dd0f44 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ from quack import rmsnorm, softmax, cross_entropy
[blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
+
## Performance
diff --git a/quack/__init__.py b/quack/__init__.py
index 34a56ab4..6dcfae33 100644
--- a/quack/__init__.py
+++ b/quack/__init__.py
@@ -6,6 +6,7 @@
from quack.softmax import softmax
from quack.cross_entropy import cross_entropy
from quack.rounding import RoundingMode
+from quack.gemm_interface import gemm, gemm_grouped
if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
@@ -19,5 +20,7 @@
"rmsnorm",
"softmax",
"cross_entropy",
+ "gemm",
+ "gemm_grouped",
"RoundingMode",
]
diff --git a/quack/gemm.py b/quack/gemm.py
index 1dcb7096..edb7e792 100644
--- a/quack/gemm.py
+++ b/quack/gemm.py
@@ -32,6 +32,34 @@
)
+from cutlass import Float32, Int32
+from quack.cute_dsl_utils import (
+ get_device_capacity,
+ get_max_active_clusters,
+ torch2cute_dtype_map,
+)
+from quack.gemm_default_epi import (
+ GemmDefaultEpiMixin,
+ GemmDefaultSm100,
+ GemmDefaultSm120,
+ GemmDefaultSm90,
+)
+from quack.gemm_problem_adapter import (
+ GroupedProblemAdapterMixin,
+ GroupedProblemArguments,
+)
+from quack.gemm_tvm_ffi_utils import (
+ compile_gemm_kernel,
+ get_dtypes,
+ get_majors,
+ make_fake_gemm_tensors,
+ make_fake_scheduler_args,
+ make_fake_varlen_args,
+ make_scheduler_args,
+ make_varlen_args,
+ perm3d,
+)
+
@jit_cache
def _compile_gemm(
a_dtype,
@@ -63,6 +91,7 @@ def _compile_gemm(
rounding_mode,
sr_seed_mode,
has_trace_ptr,
+ grouped,
):
sm_to_cls = {
9: GemmDefaultSm90,
@@ -71,6 +100,10 @@ def _compile_gemm(
12: GemmDefaultSm120,
}
GemmCls = sm_to_cls[device_capacity[0]]
+ if grouped:
+ GemmCls = type(
+ f"Grouped{GemmCls.__name__}", (GroupedProblemAdapterMixin, GemmCls), {}
+ )
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
a_dtype,
b_dtype,
@@ -115,6 +148,14 @@ def fake_scalar(mode, dtype=Float32):
)
aidx_len = m if varlen_m else (k if varlen_k else None)
varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len)
+ problem_args = (
+ GroupedProblemArguments(
+ mProblemIndex=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4),
+ mProblemK=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4),
+ )
+ if grouped
+ else None
+ )
return compile_gemm_kernel(
GemmCls,
a_dtype,
@@ -135,6 +176,7 @@ def fake_scalar(mode, dtype=Float32):
has_trace_ptr=has_trace_ptr,
use_tma_gather=use_tma_gather,
concat_layout=concat_layout or None,
+ problem_args=problem_args,
)
@@ -157,21 +199,48 @@ def gemm(
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
alpha: float | Tensor = 1.0,
beta: float | Tensor = 1.0,
- cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
- cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length
- A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
- batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
+ cu_seqlens_m: Optional[
+ Tensor
+ ] = None, # (l+1,) cumulative sum of m values for variable length
+ cu_seqlens_k: Optional[
+ Tensor
+ ] = None, # (l+1,) cumulative sum of k values for variable length
+ A_idx: Optional[
+ Tensor
+ ] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
+ batch_idx_permute: Optional[
+ Tensor
+ ] = None, # (l,) permutation of batch indices for scheduler
add_to_output: bool = False,
rounding_mode: int = RoundingMode.RN,
sr_seed: int | Tensor = 0,
use_tma_gather: bool = False,
concat_layout: dict | None = None,
+ grouped: bool = False,
+ grouped_problem_index: Optional[Tensor] = None,
+ grouped_problem_k: Optional[Tensor] = None,
trace_ptr=None, # Optional Int64 from TraceSession.ptr
) -> None:
varlen_m = cu_seqlens_m is not None
varlen_k = cu_seqlens_k is not None
varlen = varlen_m or varlen_k
gather_A = A_idx is not None
+ if grouped:
+ assert A.ndim == 3 and B.ndim == 3 and D.ndim == 3, (
+ "grouped GEMM expects dense packed batched tensors"
+ )
+ assert not varlen and not gather_A, (
+ "grouped GEMM does not combine with varlen/gather"
+ )
+ batch_size = A.shape[0]
+ if grouped_problem_index is not None:
+ assert grouped_problem_index.numel() == batch_size, (
+ "grouped_problem_index must have one entry per problem"
+ )
+ if grouped_problem_k is not None:
+ assert grouped_problem_k.numel() == batch_size, (
+ "grouped_problem_k must have one entry per problem"
+ )
assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k"
if gather_A:
assert varlen, "gather_A requires varlen"
@@ -188,11 +257,17 @@ def gemm(
assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
device_capacity = get_device_capacity(A.device)
- assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
+ assert device_capacity[0] in [9, 10, 11, 12], (
+ "Only SM90, SM100, SM110, and SM120 are supported"
+ )
if use_tma_gather:
- assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110"
+ assert device_capacity[0] in [10, 11], (
+ "TMA gather currently requires SM100/SM110"
+ )
if rounding_mode == RoundingMode.RS:
- assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
+ assert device_capacity[0] == 10, (
+ "Stochastic rounding (RoundingMode.RS) requires SM100"
+ )
if is_dynamic_persistent and device_capacity[0] == 9:
assert tile_count_semaphore is not None, (
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
@@ -208,7 +283,9 @@ def gemm(
concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()
sr_seed_mode = (
- 2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
+ 2
+ if isinstance(sr_seed, Tensor)
+ else (1 if rounding_mode == RoundingMode.RS else 0)
)
compiled_fn = _compile_gemm(
a_dtype,
@@ -240,6 +317,7 @@ def gemm(
rounding_mode,
sr_seed_mode,
trace_ptr is not None,
+ grouped,
)
from quack.cache_utils import COMPILE_ONLY
@@ -255,7 +333,9 @@ def scalar_arg(scalar, mode, dtype=Float32):
else:
return scalar.data_ptr()
- max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+ max_active_clusters = (
+ get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
+ )
epi_args = GemmDefaultEpiMixin.EpilogueArguments(
alpha=scalar_arg(alpha, alpha_mode),
@@ -273,10 +353,38 @@ def scalar_arg(scalar, mode, dtype=Float32):
batch_idx_permute,
)
varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx)
+ problem_args = (
+ GroupedProblemArguments(
+ mProblemIndex=grouped_problem_index,
+ mProblemK=grouped_problem_k,
+ )
+ if grouped
+ else None
+ )
if device_capacity[0] in [10, 11]:
compiled_fn(
- A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr
+ A_p,
+ B_p,
+ D_p,
+ C_p,
+ epi_args,
+ scheduler_args,
+ varlen_args,
+ None,
+ None,
+ trace_ptr,
+ problem_args,
)
else:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr)
+ compiled_fn(
+ A_p,
+ B_p,
+ D_p,
+ C_p,
+ epi_args,
+ scheduler_args,
+ varlen_args,
+ trace_ptr,
+ problem_args,
+ )
diff --git a/quack/gemm_act.py b/quack/gemm_act.py
index 92672333..8f3f48be 100644
--- a/quack/gemm_act.py
+++ b/quack/gemm_act.py
@@ -504,9 +504,9 @@ def scalar_arg(scalar, mode, dtype=Int32):
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
if device_capacity[0] in [10, 11]:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None)
else:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
gemm_gated = gemm_act
diff --git a/quack/gemm_dact.py b/quack/gemm_dact.py
index 6a3d3489..a6afa6d2 100644
--- a/quack/gemm_dact.py
+++ b/quack/gemm_dact.py
@@ -499,10 +499,10 @@ def gemm_dact(
if device_capacity[0] in [10, 11]:
compiled_fn(
- A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None
+ A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None, None
)
else:
- compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None)
+ compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None)
gemm_dgated = gemm_dact
diff --git a/quack/gemm_epilogue_plan.py b/quack/gemm_epilogue_plan.py
new file mode 100644
index 00000000..4058c14c
--- /dev/null
+++ b/quack/gemm_epilogue_plan.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import Optional, Tuple, Callable
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Boolean, const_expr
+
+import quack.copy_utils as copy_utils
+from quack.rounding import RoundingMode
+from quack.varlen_utils import VarlenManager
+
+
+def default_epi_tile_layout(self, epi_tile_shape):
+    """Return the default ordered layout over the epilogue subtile grid.
+
+    Calls ``cute.make_ordered_layout`` on ``epi_tile_shape`` with
+    ``order=(1, 0)``. ``self`` is accepted but not read.
+    """
+    subtile_order = (1, 0)
+    return cute.make_ordered_layout(epi_tile_shape, order=subtile_order)
+
+
+@cute.jit
+def default_epi_commit(
+ self,
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+):
+ """Issue the D / post-activation copies for one epilogue subtile and commit.
+
+ When ``is_tma_warp`` holds, ``copy_D`` is called if it is not ``None`` and
+ ``copy_postact`` is called if ``postact_ctx`` is not ``None``; both copy from
+ ``epi_buffer`` (``src_idx``) to ``gmem_coord`` (``dst_idx``).
+ ``epi_store_pipeline.producer_commit()`` is then called.
+
+ NOTE(review): indentation is collapsed in this patch text, so whether
+ ``producer_commit()`` sits inside ``if is_tma_warp:`` cannot be read from
+ here -- confirm against the real file.
+ """
+ # Part of the shared commit signature, but not used by the default commit.
+ del tile_coord_mnkl
+ has_D = const_expr(copy_D is not None)
+ if is_tma_warp:
+ if const_expr(has_D):
+ copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
+ # Gated on postact_ctx (not on copy_postact itself).
+ if const_expr(postact_ctx is not None):
+ copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
+ epi_store_pipeline.producer_commit()
+
+
+@cute.jit
+def symmetric_epi_commit(
+ self,
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+):
+ """Commit variant that skips the post-activation copy when the M and N tile indices match.
+
+ Same call sequence as ``default_epi_commit``, except that the
+ ``copy_postact`` call is additionally conditioned on
+ ``square_tile_m != square_tile_n``, where each value is the tile coordinate
+ floor-divided by the matching entry of ``self.cluster_shape_mnk``.
+
+ NOTE(review): presumably this avoids writing diagonal tiles of a symmetric
+ output twice -- confirm with the caller. Indentation is collapsed in this
+ patch text, so statement nesting must be confirmed against the real file.
+ """
+ has_D = const_expr(copy_D is not None)
+ if is_tma_warp:
+ # Tile coordinates in units of the cluster shape along M and N.
+ square_tile_m = tile_coord_mnkl[0] // self.cluster_shape_mnk[0]
+ square_tile_n = tile_coord_mnkl[1] // self.cluster_shape_mnk[1]
+ if const_expr(has_D):
+ copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
+ # Combines a const_expr check with a comparison of the two tile indices.
+ if const_expr(postact_ctx is not None) and square_tile_m != square_tile_n:
+ copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
+ epi_store_pipeline.producer_commit()
+
+
+@cute.jit
+def run_epilogue_plan(
+ self,
+ params,
+ epi_smem_tensors: Tuple[cute.Tensor, ...],
+ epi_pipeline: cutlass.pipeline.PipelineAsync,
+ epi_store_pipeline: cutlass.pipeline.PipelineAsync,
+ epi_read_state: cutlass.pipeline.PipelineState,
+ epi_producer_state: cutlass.pipeline.PipelineState,
+ epi_tile: cute.Tile,
+ load_acc_subtile: Callable,
+ tRS_rD: cute.Tensor,
+ tRS_rC: Optional[cute.Tensor],
+ tiled_copy_t2r: Optional[cute.TiledCopy],
+ tiled_copy_r2s: cute.TiledCopy,
+ tRS_sD: cute.Tensor,
+ tiled_copy_s2r: Optional[cute.TiledCopy],
+ tSR_rC: Optional[cute.Tensor],
+ tSR_sC: Optional[cute.Tensor],
+ copy_D: Optional[Callable],
+ copy_C: Optional[Callable],
+ tile_coord_mnkl: cute.Coord,
+ varlen_manager: VarlenManager,
+ epilogue_barrier: cutlass.pipeline.NamedBarrier,
+ tile_scheduler,
+ tidx: Int32,
+ is_tma_warp: Boolean,
+) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
+ """Run the epilogue for one output tile, one epilogue subtile at a time.
+
+ NOTE(review): leading indentation is collapsed in this patch text, so the
+ nesting of statements cannot be read from here; the outline below follows
+ source order only -- confirm nesting against the real file.
+
+ Outline, in source order:
+ - ``self.epi_setup_postact`` produces ``postact_ctx`` (checked against ``None`` below).
+ - ``epi_tile_num`` is the size of the subtile grid obtained by dividing
+ ``self.cta_tile_shape_mnk[:2]`` by ``epi_tile``; the grid layout comes
+ from ``self.epi_plan_make_tile_layout``.
+ - ``self.epi_begin`` produces ``epi_tensors``.
+ - If ``copy_C`` is given, up to ``min(epi_tile_num, self.epi_c_stage)`` C
+ loads are issued before the main loop.
+ - For each subtile: load the accumulator with ``load_acc_subtile``, call
+ ``self.epi_begin_loop``, optionally consume a staged C subtile, optionally
+ issue the C load ``self.epi_c_stage`` subtiles ahead, call
+ ``self.epi_visit_subtile``, optionally ``self.epi_convert_postact``,
+ write D and the post-activation value into stage ``epi_buffer``, then
+ call ``self.epi_plan_commit``.
+ - ``self.epi_end``.
+
+ Returns:
+ ``(epi_read_state, epi_producer_state)`` after the ``advance()`` calls made here.
+ """
+ has_C = const_expr(tRS_rC is not None)
+ has_D = const_expr(copy_D is not None)
+
+ postact_ctx = self.epi_setup_postact(
+ params,
+ epi_smem_tensors,
+ tiled_copy_r2s,
+ tiled_copy_t2r,
+ tile_coord_mnkl,
+ varlen_manager,
+ tidx,
+ )
+
+ epi_tile_shape = cute.zipped_divide(cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile).shape[1]
+ epi_tile_layout = self.epi_plan_make_tile_layout(epi_tile_shape)
+ epi_tile_num = cute.size(epi_tile_shape)
+ # Running subtile count over previously executed tiles; used below for the
+ # epi_buffer stage index and for the stochastic-rounding seed.
+ num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
+
+ epi_tensors = self.epi_begin(
+ params,
+ epi_smem_tensors,
+ epi_tile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tile_coord_mnkl,
+ varlen_manager,
+ epilogue_barrier,
+ tidx,
+ )
+
+ # Issue the first C loads up front, bounded by the number of C stages.
+ if const_expr(copy_C is not None):
+ for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
+ if is_tma_warp:
+ epi_pipeline.producer_acquire(epi_producer_state)
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+ epi_pipeline.producer_commit(epi_producer_state)
+ epi_producer_state.advance()
+
+ for epi_idx in cutlass.range_constexpr(epi_tile_num):
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
+ load_acc_subtile(tRS_rD, epi_idx)
+ epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
+ if const_expr(has_C):
+ epi_pipeline.consumer_wait(epi_read_state)
+ cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
+ cute.arch.fence_view_async_shared()
+ cute.arch.sync_warp()
+ with cute.arch.elect_one():
+ epi_pipeline.consumer_release(epi_read_state)
+ epi_read_state.advance()
+ # Issue the C load that is epi_c_stage subtiles ahead, if one remains.
+ if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
+ gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
+ if is_tma_warp:
+ epi_pipeline.producer_acquire(epi_producer_state)
+ copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
+ epi_pipeline.producer_commit(epi_producer_state)
+ epi_producer_state.advance()
+ tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
+ if const_expr(postact_ctx is not None):
+ tRS_rPostAct_out = self.epi_convert_postact(
+ tRS_rPostAct,
+ epi_loop_tensors["sr_seed"],
+ tidx,
+ tile_coord_mnkl,
+ num_prev_subtiles,
+ epi_idx,
+ )
+ if is_tma_warp:
+ epi_store_pipeline.producer_acquire()
+ epilogue_barrier.arrive_and_wait()
+ # Stage index rotates over self.epi_stage across subtiles and tiles.
+ epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
+ if const_expr(has_D):
+ # The stochastic-rounding copy is used only for RoundingMode.RS with a
+ # Float32 accumulator and a BFloat16 output; otherwise cvt_copy is used.
+ if const_expr(
+ self.rounding_mode == RoundingMode.RS
+ and self.acc_dtype == cutlass.Float32
+ and self.d_dtype == cutlass.BFloat16
+ ):
+ # Seed = base seed plus a mix of the M, N, L tile coordinates and the subtile counter.
+ seed = epi_loop_tensors["sr_seed"] + (
+ tile_coord_mnkl[0] * 65537
+ + tile_coord_mnkl[1] * 257
+ + tile_coord_mnkl[3] * 17
+ + (num_prev_subtiles + epi_idx) * 7
+ )
+ copy_utils.sr_cvt_copy(
+ tiled_copy_r2s,
+ tRS_rD,
+ tRS_sD[None, None, None, epi_buffer],
+ seed,
+ tidx,
+ )
+ else:
+ copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer])
+ copy_postact = None
+ if const_expr(postact_ctx is not None):
+ tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx
+ cute.copy(
+ tiled_copy_postact_r2s,
+ tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
+ tRS_sPostAct[None, None, None, epi_buffer],
+ )
+ cute.arch.fence_view_async_shared()
+ epilogue_barrier.arrive_and_wait()
+ # The commit step is pluggable: see default_epi_commit / symmetric_epi_commit.
+ self.epi_plan_commit(
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+ )
+
+ self.epi_end(
+ params,
+ epi_tensors,
+ epi_tile,
+ tiled_copy_t2r,
+ tiled_copy_r2s,
+ tile_coord_mnkl,
+ varlen_manager,
+ tidx,
+ )
+
+ return epi_read_state, epi_producer_state
diff --git a/quack/gemm_interface.py b/quack/gemm_interface.py
index 64ec9b6c..bd1bdb3f 100644
--- a/quack/gemm_interface.py
+++ b/quack/gemm_interface.py
@@ -437,6 +437,96 @@ def gemm(
return out
+def gemm_grouped(
+    A_list: list[Tensor],
+    B_list: list[Tensor],
+    out_list: Optional[list[Tensor]] = None,
+    alpha: float | Tensor = 1.0,
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+    rounding_mode: int = RoundingMode.RN,
+    sr_seed: int | Tensor = 0,
+):
+    """Grouped GEMM for heterogeneous `(M_i, K_i) x (K_i, N_i)` problems.
+
+    This host wrapper packs problems into padded batched tensors, dispatches through
+    the grouped-capable core GEMM path, and slices outputs back to per-problem shapes.
+
+    Args:
+        A_list: Rank-2 left operands, one `(M_i, K_i)` tensor per problem.
+        B_list: Rank-2 right operands, one `(K_i, N_i)` tensor per problem.
+        out_list: Optional destinations, one per problem. Each receives its
+            `(M_i, N_i)` result through `Tensor.copy_`. Allocated here when omitted.
+        alpha: Forwarded to the dispatch call as `alpha`.
+        out_dtype: dtype of the packed result and of outputs allocated here;
+            defaults to the dtype of `A_list[0]`.
+        dynamic_scheduler: Request the dynamic persistent scheduler; it is also
+            enabled when the default config asks for it.
+        tuned: Currently not read by this wrapper.
+        rounding_mode: Forwarded to the dispatch call.
+        sr_seed: Forwarded to the dispatch call.
+
+    Returns:
+        The list of per-problem outputs (`out_list` itself when it was supplied).
+
+    Raises:
+        AssertionError: If the lists are empty or of different lengths, an operand
+            is not rank-2, devices/dtypes differ, K dims disagree, or `out_list`
+            does not have one entry per problem.
+    """
+    assert len(A_list) == len(B_list) and len(A_list) > 0, "A_list/B_list must be same non-zero length"
+    # Validate ranks before touching shape[1], so a rank-1 operand reports the
+    # intended message instead of an IndexError from the shape lookups below.
+    assert all(a.ndim == 2 for a in A_list), "Each A must be rank-2"
+    assert all(b.ndim == 2 for b in B_list), "Each B must be rank-2"
+    device = A_list[0].device
+    dtype = A_list[0].dtype
+    assert all(a.device == device and b.device == device for a, b in zip(A_list, B_list)), (
+        "All grouped tensors must be on the same device"
+    )
+    assert all(a.dtype == dtype and b.dtype == dtype for a, b in zip(A_list, B_list)), (
+        "All grouped tensors must have the same dtype"
+    )
+    assert all(a.shape[1] == b.shape[0] for a, b in zip(A_list, B_list)), "Incompatible K dims"
+
+    out_dtype = dtype if out_dtype is None else out_dtype
+    problem_count = len(A_list)
+    # zip() below stops at the shortest input, so a short or long out_list would
+    # otherwise silently drop results or leave destinations unwritten.
+    if out_list is not None:
+        assert len(out_list) == problem_count, "out_list must have one tensor per problem"
+
+    ms = [a.shape[0] for a in A_list]
+    ks = [a.shape[1] for a in A_list]
+    ns = [b.shape[1] for b in B_list]
+    max_m = max(ms)
+    max_k = max(ks)
+    max_n = max(ns)
+
+    # Inputs are zero-filled so the padding region of each problem holds zeros.
+    A_packed = torch.zeros((problem_count, max_m, max_k), dtype=dtype, device=device)
+    B_packed = torch.zeros((problem_count, max_k, max_n), dtype=dtype, device=device)
+    D_packed = torch.empty((problem_count, max_m, max_n), dtype=out_dtype, device=device)
+
+    for idx, (A_i, B_i) in enumerate(zip(A_list, B_list)):
+        m_i, k_i = A_i.shape
+        _, n_i = B_i.shape
+        A_packed[idx, :m_i, :k_i].copy_(A_i)
+        B_packed[idx, :k_i, :n_i].copy_(B_i)
+
+    # Identity problem mapping plus the K extent of every problem.
+    grouped_problem_index = torch.arange(problem_count, device=device, dtype=torch.int32)
+    grouped_problem_k = torch.tensor(ks, device=device, dtype=torch.int32)
+
+    config = default_config(device)
+    dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent
+    # A semaphore tensor is only created for the dynamic scheduler on capability major 9.
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=device)
+        if dynamic_scheduler and get_device_capacity(device)[0] == 9
+        else None
+    )
+
+    gemm_dispatch(
+        A_packed,
+        B_packed,
+        D_packed,
+        None,
+        tile_count_semaphore,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        config.pingpong,
+        persistent=True,
+        is_dynamic_persistent=dynamic_scheduler,
+        max_swizzle_size=config.max_swizzle_size,
+        alpha=alpha,
+        rounding_mode=rounding_mode,
+        sr_seed=sr_seed,
+        grouped=True,
+        grouped_problem_index=grouped_problem_index,
+        grouped_problem_k=grouped_problem_k,
+    )
+
+    if out_list is None:
+        out_list = [
+            torch.empty((m_i, n_i), dtype=out_dtype, device=device) for m_i, n_i in zip(ms, ns)
+        ]
+    # Copy the valid top-left (M_i, N_i) region of each packed result out.
+    for out_i, m_i, n_i, packed_i in zip(out_list, ms, ns, D_packed):
+        out_i.copy_(packed_i[:m_i, :n_i])
+    return out_list
+
+
@torch.library.custom_op(
"quack::gemm_out",
mutates_args=("out",),
diff --git a/quack/gemm_norm_act.py b/quack/gemm_norm_act.py
index f363dfa2..c46394c8 100644
--- a/quack/gemm_norm_act.py
+++ b/quack/gemm_norm_act.py
@@ -395,6 +395,6 @@ def scalar_arg(scalar, mode, dtype=Int32):
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)
if device_capacity[0] in [10, 11]:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None)
else:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
diff --git a/quack/gemm_problem_adapter.py b/quack/gemm_problem_adapter.py
new file mode 100644
index 00000000..2b2c5d4a
--- /dev/null
+++ b/quack/gemm_problem_adapter.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import NamedTuple, Optional
+from dataclasses import dataclass
+
+import cutlass.cute as cute
+from cutlass import Int32, const_expr
+
+from quack.cute_dsl_utils import ParamsBase, mlir_namedtuple
+
+
+@mlir_namedtuple
+class ProblemArguments(NamedTuple):
+    """Field-less argument bundle for the default (non-grouped) problem path."""
+
+
+@dataclass
+class ProblemParams(ParamsBase):
+    """Field-less parameter bundle returned by the default problem adapter."""
+
+
+@mlir_namedtuple
+class GroupedProblemArguments(NamedTuple):
+ # Host-side arguments for the grouped problem adapter. Both fields are
+ # optional and are copied into GroupedProblemParams by its create().
+ # Indexed with work.problem_idx by grouped_problem_idx to obtain the problem index.
+ mProblemIndex: Optional[cute.Tensor] = None
+ # Indexed with the resolved problem index by grouped_problem_len_k to obtain K.
+ mProblemK: Optional[cute.Tensor] = None
+
+
+@dataclass
+class GroupedProblemParams(ParamsBase):
+    """Parameters read by the ``grouped_problem_*`` helpers in this module."""
+
+    # Optional remapping tensor, indexed with ``work.problem_idx``.
+    problem_index: Optional[cute.Tensor] = None
+    # Optional per-problem K tensor, indexed with the resolved problem index.
+    problem_k: Optional[cute.Tensor] = None
+
+    @staticmethod
+    def create(args: GroupedProblemArguments, *, loc=None, ip=None) -> "GroupedProblemParams":
+        """Copy the two tensors of ``args`` into a new params object.
+
+        ``loc`` and ``ip`` are accepted but not read.
+        """
+        index_tensor, k_tensor = args.mProblemIndex, args.mProblemK
+        return GroupedProblemParams(problem_index=index_tensor, problem_k=k_tensor)
+
+
+def default_problem_to_underlying_arguments(
+    args: ProblemArguments | None, *, loc=None, ip=None
+) -> ProblemParams:
+    """Return a fresh, empty ``ProblemParams``; none of the inputs are read."""
+    del args, loc, ip
+    return ProblemParams()
+
+
+def grouped_problem_to_underlying_arguments(
+    args: GroupedProblemArguments | None, *, loc=None, ip=None
+) -> GroupedProblemParams:
+    """Convert grouped arguments to params, treating ``None`` as all-default arguments."""
+    effective_args = GroupedProblemArguments() if args is None else args
+    return GroupedProblemParams.create(effective_args, loc=loc, ip=ip)
+
+
+class GroupedProblemAdapterMixin:
+ """Mixin that routes the ``problem_*`` hooks to the ``grouped_problem_*`` helpers.
+
+ Every method is a thin delegation to the module-level function of the
+ matching name; ``self`` is not read by any of them.
+ """
+
+ # Argument / parameter types used by this adapter.
+ ProblemArguments = GroupedProblemArguments
+ ProblemParams = GroupedProblemParams
+
+ def problem_to_underlying_arguments(
+ self, args: GroupedProblemArguments | None = None, *, loc=None, ip=None
+ ) -> GroupedProblemParams:
+ """Convert ``args`` (or ``None``) into ``GroupedProblemParams``."""
+ return grouped_problem_to_underlying_arguments(args, loc=loc, ip=ip)
+
+ @cute.jit
+ def problem_get_problem_idx(self, params: GroupedProblemParams, work):
+ """Return the problem index for ``work`` via ``grouped_problem_idx``."""
+ return grouped_problem_idx(params, work)
+
+ @cute.jit
+ def problem_get_len_k(self, params: GroupedProblemParams, varlen_manager, work):
+ """Return the K length for ``work`` via ``grouped_problem_len_k``."""
+ # Note the helper takes (params, work, varlen_manager), unlike the batch helpers.
+ return grouped_problem_len_k(params, work, varlen_manager)
+
+ @cute.jit
+ def problem_get_batch_A(self, params: GroupedProblemParams, mA_mkl, varlen_manager, work):
+ """Return the A operand offset for ``work`` via ``grouped_problem_batch_a``."""
+ return grouped_problem_batch_a(params, mA_mkl, varlen_manager, work)
+
+ @cute.jit
+ def problem_get_batch_B(self, params: GroupedProblemParams, mB_nkl, varlen_manager, work):
+ """Return the B operand offset for ``work`` via ``grouped_problem_batch_b``."""
+ return grouped_problem_batch_b(params, mB_nkl, varlen_manager, work)
+
+ @cute.jit
+ def problem_get_batch_epi(self, params: GroupedProblemParams, mX_mnl, varlen_manager, work):
+ """Return the epilogue tensor offset for ``work`` via ``grouped_problem_batch_epi``."""
+ return grouped_problem_batch_epi(params, mX_mnl, varlen_manager, work)
+
+
+@cute.jit
+def default_problem_idx(params: ProblemParams, work, *, loc=None, ip=None) -> Int32:
+    """Return ``work.problem_idx`` unchanged; ``params`` is not consulted."""
+    del params, loc, ip
+    assigned_idx = work.problem_idx
+    return assigned_idx
+
+
+@cute.jit
+def grouped_problem_idx(params: GroupedProblemParams, work, *, loc=None, ip=None) -> Int32:
+    """Resolve the problem index for ``work``.
+
+    Looks ``work.problem_idx`` up in ``params.problem_index`` when that tensor
+    is present; otherwise returns ``work.problem_idx`` itself.
+    """
+    del loc, ip
+    if const_expr(params.problem_index is not None):
+        return params.problem_index[work.problem_idx]
+    return work.problem_idx
+
+
+@cute.jit
+def default_problem_len_k(params: ProblemParams, work, varlen_manager, *, loc=None, ip=None) -> Int32:
+    """Return ``varlen_manager.len_k`` for ``work.problem_idx``."""
+    del params, loc, ip
+    assigned_idx = work.problem_idx
+    return varlen_manager.len_k(assigned_idx)
+
+
+@cute.jit
+def grouped_problem_len_k(
+    params: GroupedProblemParams, work, varlen_manager, *, loc=None, ip=None
+) -> Int32:
+    """Return the K length for the problem that ``work`` resolves to.
+
+    Reads ``params.problem_k`` at the resolved index when that tensor is
+    present; otherwise asks ``varlen_manager.len_k`` for the resolved index.
+    """
+    del loc, ip
+    resolved_idx = grouped_problem_idx(params, work)
+    if const_expr(params.problem_k is None):
+        return varlen_manager.len_k(resolved_idx)
+    return params.problem_k[resolved_idx]
+
+
+@cute.jit
+def default_problem_batch_a(params: ProblemParams, mA_mkl, varlen_manager, work, *, loc=None, ip=None):
+    """Return ``varlen_manager.offset_batch_A`` of ``mA_mkl`` at ``work.problem_idx``."""
+    del params, loc, ip
+    assigned_idx = work.problem_idx
+    return varlen_manager.offset_batch_A(mA_mkl, assigned_idx)
+
+
+@cute.jit
+def grouped_problem_batch_a(
+    params: GroupedProblemParams, mA_mkl, varlen_manager, work, *, loc=None, ip=None
+):
+    """Return ``varlen_manager.offset_batch_A`` of ``mA_mkl`` at the resolved problem index."""
+    del loc, ip
+    resolved_idx = grouped_problem_idx(params, work)
+    return varlen_manager.offset_batch_A(mA_mkl, resolved_idx)
+
+
+@cute.jit
+def default_problem_batch_b(params: ProblemParams, mB_nkl, varlen_manager, work, *, loc=None, ip=None):
+    """Return ``varlen_manager.offset_batch_B`` of ``mB_nkl`` at ``work.problem_idx``."""
+    del params, loc, ip
+    assigned_idx = work.problem_idx
+    return varlen_manager.offset_batch_B(mB_nkl, assigned_idx)
+
+
+@cute.jit
+def grouped_problem_batch_b(
+    params: GroupedProblemParams, mB_nkl, varlen_manager, work, *, loc=None, ip=None
+):
+    """Return ``varlen_manager.offset_batch_B`` of ``mB_nkl`` at the resolved problem index."""
+    del loc, ip
+    resolved_idx = grouped_problem_idx(params, work)
+    return varlen_manager.offset_batch_B(mB_nkl, resolved_idx)
+
+
+@cute.jit
+def default_problem_batch_epi(
+    params: ProblemParams, mX_mnl, varlen_manager, work, *, loc=None, ip=None
+):
+    """Return ``varlen_manager.offset_batch_epi`` of ``mX_mnl`` at ``work.problem_idx``.
+
+    Returns ``None`` when ``mX_mnl`` is ``None``.
+    """
+    del params, loc, ip
+    if const_expr(mX_mnl is None):
+        return None
+    return varlen_manager.offset_batch_epi(mX_mnl, work.problem_idx)
+
+
+@cute.jit
+def grouped_problem_batch_epi(
+    params: GroupedProblemParams, mX_mnl, varlen_manager, work, *, loc=None, ip=None
+):
+    """Return ``varlen_manager.offset_batch_epi`` of ``mX_mnl`` at the resolved problem index.
+
+    Returns ``None`` when ``mX_mnl`` is ``None``; the problem index is only
+    resolved when a tensor is present.
+    """
+    del loc, ip
+    if const_expr(mX_mnl is None):
+        return None
+    return varlen_manager.offset_batch_epi(mX_mnl, grouped_problem_idx(params, work))
diff --git a/quack/gemm_sm100.py b/quack/gemm_sm100.py
index 312f9903..01549876 100644
--- a/quack/gemm_sm100.py
+++ b/quack/gemm_sm100.py
@@ -34,7 +34,15 @@
from quack.layout_utils import tile_atom_to_shape_SF_strided
# return PipelineStateWAdvance instead of PipelineState
-
+from cutlass import Boolean, const_expr, Float32, Int32
+from cutlass.cute.nvgpu.warp import (
+ LdMatrix16x16x8bOp,
+ LdMatrix8x8x16bOp,
+ StMatrix16x8x8bOp,
+ StMatrix8x8x16bOp,
+)
+from quack.pipeline import PipelineTmaCpAsyncUmma, PipelineTmaUmma
+from typing import Callable, Literal, Optional, Tuple, Type, Union
"""
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
using CUTE DSL.
@@ -210,7 +218,9 @@ def __init__(
if use_tma_gather:
assert gather_A, "TMA gather requires gather_A=True"
- self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+ self.cta_group = (
+ tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+ )
self.num_ab_load_warps = 1 if not self.gather_A else 4
self.occupancy = 1
@@ -250,7 +260,9 @@ def __init__(
# Multiple of 4 warps to increase/decrease number of registers
assert self.threads_per_cta % 128 == 0
- def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments):
+ def _setup_attributes(
+ self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments
+ ):
"""Set up configurations that are dependent on GEMM inputs
This method configures various attributes based on the input tensor properties
@@ -267,7 +279,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
# Compute mma instruction shapes
mma_inst_bits_k = 256
# (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K)
- mma_inst_shape_n = self.mma_tiler[1] if self.mma_tiler[1] <= 256 else self.mma_tiler[1] // 2
+ mma_inst_shape_n = (
+ self.mma_tiler[1] if self.mma_tiler[1] <= 256 else self.mma_tiler[1] // 2
+ )
self.mma_inst_shape_mnk = (
self.mma_tiler[0],
mma_inst_shape_n,
@@ -380,10 +394,15 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
# There's a bug w compute_epilogue_tile_shape (as of cutlass-dsl 4.4.2) where if
# tile_n = 224 and there's C, it will set epi_tile to (128, 64).
if const_expr(self.cta_tile_shape_mnk[1] % cute.size(self.epi_tile[1]) != 0):
- warp_n = 2 if (self.cta_tile_shape_mnk[0] == 64 and self.use_2cta_instrs) else 1
- epi_tile_n = math.gcd(self.cta_tile_shape_mnk[1], cute.size(self.epi_tile[1]))
+ warp_n = (
+ 2 if (self.cta_tile_shape_mnk[0] == 64 and self.use_2cta_instrs) else 1
+ )
+ epi_tile_n = math.gcd(
+ self.cta_tile_shape_mnk[1], cute.size(self.epi_tile[1])
+ )
epi_tile_n_layout = cute.make_layout(
- (epi_tile_n // warp_n, warp_n), stride=(1, self.cta_tile_shape_mnk[1] // warp_n)
+ (epi_tile_n // warp_n, warp_n),
+ stride=(1, self.cta_tile_shape_mnk[1] // warp_n),
)
self.epi_tile = (self.epi_tile[0], cute.coalesce(epi_tile_n_layout))
@@ -413,7 +432,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
self.c_layout,
epilogue_args,
prefetch_A_idx,
- cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
+ cutlass.utils.get_smem_capacity_in_bytes(
+ f"sm_{self.arch}"
+ ), # smem_capacity
self.occupancy,
)
self.sched_stage = 1
@@ -430,12 +451,16 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
self.a_smem_load_layout_staged = self.a_smem_layout_staged
if const_expr(self.gather_A):
if const_expr(self.use_tma_gather):
- self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_tma_gather_a(
- self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ self.a_smem_load_layout_staged = (
+ quack_sm100_utils.make_smem_layout_tma_gather_a(
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ )
)
else:
- self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a(
- self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ self.a_smem_load_layout_staged = (
+ quack_sm100_utils.make_smem_layout_cpasync_a(
+ self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage
+ )
)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage
@@ -495,7 +520,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
* 4 # 4 cols per stage
* (self.mma_inst_shape_mnk[2] // self.sf_vec_size)
)
- self.iter_acc_early_release = num_sf_tmem_cols // cute.size(self.epi_tile[1])
+ self.iter_acc_early_release = num_sf_tmem_cols // cute.size(
+ self.epi_tile[1]
+ )
else:
self.iter_acc_early_release = -1
@@ -513,6 +540,7 @@ def __call__(
mSFA: Optional[cute.Tensor] = None,
mSFB: Optional[cute.Tensor] = None,
trace_ptr: Optional[cutlass.Int64] = None,
+ problem_args: Optional[GemmSm90.ProblemArguments] = None,
):
"""Execute the GEMM operation in steps:
- Setup static attributes before smem/grid/tma computation
@@ -605,14 +633,18 @@ def __call__(
if const_expr(not self.gather_A):
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op,
- copy_utils.create_ragged_tensor_for_tma(mA, ragged_dim=1, ptr_shift=False)
+ copy_utils.create_ragged_tensor_for_tma(
+ mA, ragged_dim=1, ptr_shift=False
+ )
if varlen_k and not self.gather_A
else mA,
a_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk.shape,
- internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
+ internal_type=(
+ cutlass.TFloat32 if mA.element_type is Float32 else None
+ ),
)
elif const_expr(self.use_tma_gather):
# gather4 descriptor: box has 1 in the gathered dim, tile size in the contiguous dim.
@@ -626,14 +658,18 @@ def __call__(
mA,
tma_smem_layout,
tma_smem_layout.shape,
- internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None),
+ internal_type=(
+ cutlass.TFloat32 if mA.element_type is Float32 else None
+ ),
)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(
self.cluster_shape_mnk, self.tiled_mma.thr_id
)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op,
- copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB,
+ copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1)
+ if varlen_k
+ else mB,
b_smem_layout,
self.mma_tiler,
self.tiled_mma,
@@ -648,7 +684,9 @@ def __call__(
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mnk, self.tiled_mma.thr_id
)
- sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
+ sfa_smem_layout = cute.slice_(
+ self.sfa_smem_layout_staged, (None, None, None, 0)
+ )
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
sfa_op,
mSFA,
@@ -662,7 +700,9 @@ def __call__(
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
self.cluster_shape_mnk, self.tiled_mma.thr_id
)
- sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
+ sfb_smem_layout = cute.slice_(
+ self.sfb_smem_layout_staged, (None, None, None, 0)
+ )
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op,
mSFB,
@@ -673,7 +713,8 @@ def __call__(
internal_type=cutlass.Int16,
)
if const_expr(
- self.cta_tile_shape_mnk[1] == 192 and self.sf_dtype is cutlass.Float8E8M0FNU
+ self.cta_tile_shape_mnk[1] == 192
+ and self.sf_dtype is cutlass.Float8E8M0FNU
):
x = tma_tensor_sfb.stride[0][1]
y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4)
@@ -706,13 +747,18 @@ def __call__(
tma_atom_d, tma_tensor_d = None, None
if const_expr(mD is not None):
tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors(
- copy_utils.create_ragged_tensor_for_tma(mD, ragged_dim=0, ptr_shift=True)
+ copy_utils.create_ragged_tensor_for_tma(
+ mD, ragged_dim=0, ptr_shift=True
+ )
if varlen_m
else mD,
self.epi_smem_layout_staged,
self.epi_tile,
op_type="store"
- if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+ if not (
+ hasattr(epilogue_args, "add_to_output")
+ and epilogue_args.add_to_output
+ )
else "add",
)
tma_atom_c, tma_tensor_c = None, None
@@ -723,6 +769,7 @@ def __call__(
epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
+ problem_params = self.problem_to_underlying_arguments(problem_args)
TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m)
tile_sched_args = self.get_scheduler_arguments(
@@ -735,14 +782,24 @@ def __call__(
self.buffer_align_bytes = 1024
- epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
- epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
- sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
+ epi_smem_size = (
+ cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
+ )
+ epi_c_smem_size = (
+ cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
+ )
+ sf_dtype = (
+ self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU
+ )
sfa_smem_size = (
- cute.cosize(self.sfa_smem_layout_staged) if const_expr(self.blockscaled) else 0
+ cute.cosize(self.sfa_smem_layout_staged)
+ if const_expr(self.blockscaled)
+ else 0
)
sfb_smem_size = (
- cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0
+ cute.cosize(self.sfb_smem_layout_staged)
+ if const_expr(self.blockscaled)
+ else 0
)
a_idx_smem_size = 0
if const_expr(self.gather_A):
@@ -753,10 +810,18 @@ def __call__(
# Define shared storage for kernel
@cute.struct
class SharedStorage:
- ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
- epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
- acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
- sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+ ab_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.ab_stage * 2
+ ]
+ epi_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.epi_c_stage * 2
+ ]
+ acc_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.num_acc_stage * 2
+ ]
+ sched_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.sched_stage * 2
+ ]
a_prefetch_pipeline_array_ptr: cute.struct.MemRange[
cutlass.Int64, self.a_prefetch_stage * 2
]
@@ -780,12 +845,16 @@ class SharedStorage:
epi: self.epi_get_smem_struct(epilogue_params)
# (MMA, MMA_M, MMA_K, STAGE)
sA: cute.struct.Align[
- cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)],
+ cute.struct.MemRange[
+ self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)
+ ],
self.buffer_align_bytes,
]
# (MMA, MMA_N, MMA_K, STAGE)
sB: cute.struct.Align[
- cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)],
+ cute.struct.MemRange[
+ self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)
+ ],
self.buffer_align_bytes,
]
# (MMA, MMA_M, MMA_K, STAGE)
@@ -806,7 +875,9 @@ class SharedStorage:
self.tiled_mma,
self.tiled_mma_sfb,
tma_atom_a,
- tma_tensor_a if const_expr(not self.gather_A or self.use_tma_gather) else mA,
+ tma_tensor_a
+ if const_expr(not self.gather_A or self.use_tma_gather)
+ else mA,
tma_atom_b,
tma_tensor_b,
tma_atom_sfa,
@@ -831,6 +902,7 @@ class SharedStorage:
self.epi_tile,
tile_sched_params,
TileSchedulerCls,
+ problem_params,
trace_ptr,
).launch(
grid=grid,
@@ -874,6 +946,7 @@ def kernel(
epi_tile: cute.Tile,
tile_sched_params,
TileSchedulerCls: cutlass.Constexpr[Callable],
+ problem_params,
trace_ptr: Optional[cutlass.Int64] = None,
):
"""
@@ -914,7 +987,9 @@ def kernel(
bidx, _, _ = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
- cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(
+ cute.arch.block_idx_in_cluster()
+ )
# Coord inside cta
tidx, _, _ = cute.arch.thread_idx()
@@ -956,7 +1031,8 @@ def kernel(
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=int(NamedBarrierGemm.TmemPtr),
- num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)),
+ num_threads=cute.arch.WARP_SIZE
+ * len((self.mma_warp_id, *self.epilog_warp_id)),
)
# Tensor memory dealloc barrier init
tmem = cutlass.utils.TmemAllocator(
@@ -973,13 +1049,19 @@ def kernel(
# Setup smem tensor A/B/D
# (MMA, MMA_M, MMA_K, STAGE)
sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
- sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner)
+ sA = storage.sA.get_tensor(
+ a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner
+ )
# (MMA, MMA_N, MMA_K, STAGE)
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
sAIdx = None
if const_expr(self.gather_A):
- a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
- a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage))
+ a_idx_smem_dim = (
+ self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2]
+ )
+ a_idx_smem_layout = cute.make_layout(
+ (a_idx_smem_dim, self.a_prefetch_stage)
+ )
sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout)
sSFA, sSFB = None, None
if const_expr(self.blockscaled):
@@ -990,22 +1072,30 @@ def kernel(
sD = None
if const_expr(has_D):
# (EPI_TILE_M, EPI_TILE_N, STAGE)
- sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+ sD = storage.sD.get_tensor(
+ epi_smem_layout.outer, swizzle=epi_smem_layout.inner
+ )
sC = None
if const_expr(has_C):
- sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+ sC = storage.sC.get_tensor(
+ epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner
+ )
epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
thr_mma_sfb = (
- tiled_mma_sfb.get_slice(mma_tile_coord_v) if const_expr(self.blockscaled) else None
+ tiled_mma_sfb.get_slice(mma_tile_coord_v)
+ if const_expr(self.blockscaled)
+ else None
)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(
- cute.append(acc_shape, self.num_acc_stage if not self.overlap_accum_sf else 2)
+ cute.append(
+ acc_shape, self.num_acc_stage if not self.overlap_accum_sf else 2
+ )
)
varlen_manager = VarlenManager.create(
@@ -1026,7 +1116,8 @@ def kernel(
epi_load_barrier = None
if const_expr(has_C):
epi_load_barrier = pipeline.NamedBarrier(
- barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE
+ barrier_id=int(NamedBarrierGemm.EpilogueLoad),
+ num_threads=2 * cute.arch.WARP_SIZE,
)
# Cluster wait before tensor memory alloc
@@ -1043,11 +1134,13 @@ def kernel(
if const_expr(self.gather_A):
cute.arch.setmaxregister_decrease(self.num_regs_other)
# Compute multicast mask for A/B buffer full
- block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+ block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
+ cta_rank_in_cluster
+ )
block_in_cluster_coord_sfb_vmnk = None
if const_expr(self.blockscaled):
- block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord(
- cta_rank_in_cluster
+ block_in_cluster_coord_sfb_vmnk = (
+ cluster_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster)
)
a_mcast_mask, b_mcast_mask = None, None
sfa_mcast_mask, sfb_mcast_mask = None, None
@@ -1063,7 +1156,9 @@ def kernel(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
sfb_mcast_mask = cpasync.create_tma_multicast_mask(
- cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1
+ cluster_layout_sfb_vmnk,
+ block_in_cluster_coord_sfb_vmnk,
+ mcast_mode=1,
)
# Persistent tile scheduling loop
@@ -1077,8 +1172,8 @@ def kernel(
)
do_epi_load_barrier_arrive = Boolean(True)
while work_tile.is_valid_tile:
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
# Local_tile partition global tensors
mma_tile_coord_mnl = (
tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape),
@@ -1087,7 +1182,9 @@ def kernel(
)
gA_mk = None
if const_expr(not self.gather_A):
- mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+ mA_mk = self.problem_get_batch_A(
+ problem_params, mA_mkl, varlen_manager, work_tile
+ )
# (bM, bK, RestK)
gA_mk = cute.local_tile(
mA_mk,
@@ -1096,7 +1193,9 @@ def kernel(
)
# (bN, bK, RestK)
gB_nk = cute.local_tile(
- varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ self.problem_get_batch_B(
+ problem_params, mB_nkl, varlen_manager, work_tile
+ ),
cute.select(self.mma_tiler, [1, 2]),
(mma_tile_coord_mnl[1], None),
)
@@ -1127,7 +1226,9 @@ def kernel(
# Partition global tensor for TiledMMA_A/B/D
# Then partition global/shared tensor for TMA load A/B
- len_k = varlen_manager.len_k(batch_idx)
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
# TMA load A partition_S/D
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
@@ -1164,7 +1265,9 @@ def kernel(
if const_expr(varlen_m):
cute.arch.sync_warp()
with cute.arch.elect_one():
- a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state)
+ a_prefetch_pipeline.consumer_release(
+ a_prefetch_consumer_state
+ )
a_prefetch_consumer_state.advance()
if const_expr(prefetch_A is not None):
prefetch_A = partial(prefetch_A, a_prefetch_pipeline)
@@ -1224,24 +1327,28 @@ def kernel(
copy_SFB,
)
elif const_expr(self.use_tma_gather):
- ab_producer_state, a_prefetch_consumer_state = self.load_AB_tma_gather(
- ab_pipeline,
- ab_producer_state,
- a_prefetch_consumer_state,
- copy_A,
- prefetch_A,
- copy_B,
- k_tile_cnt,
+ ab_producer_state, a_prefetch_consumer_state = (
+ self.load_AB_tma_gather(
+ ab_pipeline,
+ ab_producer_state,
+ a_prefetch_consumer_state,
+ copy_A,
+ prefetch_A,
+ copy_B,
+ k_tile_cnt,
+ )
)
else:
- ab_producer_state, a_prefetch_consumer_state = self.load_AB_gather_A(
- ab_pipeline,
- ab_producer_state,
- a_prefetch_consumer_state,
- copy_A,
- prefetch_A,
- copy_B,
- k_tile_cnt,
+ ab_producer_state, a_prefetch_consumer_state = (
+ self.load_AB_gather_A(
+ ab_pipeline,
+ ab_producer_state,
+ a_prefetch_consumer_state,
+ copy_A,
+ prefetch_A,
+ copy_B,
+ k_tile_cnt,
+ )
)
tctx.e("tma_load")
if const_expr(epi_load_barrier is not None):
@@ -1274,7 +1381,9 @@ def kernel(
work_tile = tile_scheduler.initial_work_tile_info()
while work_tile.is_valid_tile:
# Advance to next tile
- tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler.advance_to_next_work(
+ is_scheduler_warp=is_scheduler_warp
+ )
work_tile = tile_scheduler.get_current_work()
# End of persistent scheduler loop
if is_scheduler_warp:
@@ -1286,7 +1395,9 @@ def kernel(
cute.arch.setmaxregister_decrease(self.num_regs_other)
tile_M = self.cta_tile_shape_mnk[0]
tile_K = self.cta_tile_shape_mnk[2]
- tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True)
+ tiled_copy_AIdx = copy_utils.tiled_copy_1d(
+ Int32, num_threads=32, is_async=True
+ )
thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx())
tAsAIdx = thr_copy_AIdx.partition_D(sAIdx)
tAcAIdx = thr_copy_AIdx.partition_S(
@@ -1299,16 +1410,20 @@ def kernel(
pipeline.PipelineUserType.Producer, self.a_prefetch_stage
)
while work_tile.is_valid_tile:
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
if const_expr(varlen_m):
# (tile_M,)
- gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],))
+ gAIdx = cute.local_tile(
+ mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],)
+ )
tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
len_m = varlen_manager.len_m(batch_idx)
m_limit = len_m - tile_coord_mnkl[0] * tile_M
- tApAIdx_m = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
+ tApAIdx_m = cute.make_rmem_tensor(
+ (1, tAsAIdx.shape[1]), Boolean
+ )
for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit
a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
@@ -1324,31 +1439,43 @@ def kernel(
# (tile_K, RestK)
gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,))
tAgAIdx = thr_copy_AIdx.partition_S(gAIdx)
- len_k = varlen_manager.len_k(batch_idx)
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
k_tile_cnt = cute.ceil_div(len_k, tile_K)
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
- a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ a_prefetch_pipeline.producer_acquire(
+ a_prefetch_producer_state
+ )
cute.copy(
thr_copy_AIdx,
tAgAIdx[None, None, k_tile],
tAsAIdx[None, None, a_prefetch_producer_state.index],
)
- a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_pipeline.producer_commit(
+ a_prefetch_producer_state
+ )
a_prefetch_producer_state.advance()
if 0 < k_tile_cnt:
k_tile = k_tile_cnt - 1
k_limit = len_k - k_tile * tile_K
- tApAIdx_k = cute.make_rmem_tensor((1, tAsAIdx.shape[1]), Boolean)
+ tApAIdx_k = cute.make_rmem_tensor(
+ (1, tAsAIdx.shape[1]), Boolean
+ )
for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True):
tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit
- a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state)
+ a_prefetch_pipeline.producer_acquire(
+ a_prefetch_producer_state
+ )
cute.copy(
tiled_copy_AIdx,
tAgAIdx[None, None, k_tile],
tAsAIdx[None, None, a_prefetch_producer_state.index],
pred=tApAIdx_k,
)
- a_prefetch_pipeline.producer_commit(a_prefetch_producer_state)
+ a_prefetch_pipeline.producer_commit(
+ a_prefetch_producer_state
+ )
a_prefetch_producer_state.advance()
# Advance to next tile
tile_scheduler.advance_to_next_work()
@@ -1371,11 +1498,13 @@ def kernel(
work_tile = tile_scheduler.initial_work_tile_info()
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition(
tma_atom_c,
- varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+ self.problem_get_batch_epi(
+ problem_params, mC_mnl, varlen_manager, work_tile
+ ),
self.cta_tile_shape_mnk[:2],
epi_tile,
sC,
@@ -1459,8 +1588,16 @@ def kernel(
) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB)
else:
tCtSFA, tCtSFB = None, None
- tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None
- tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None
+ tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = (
+ None,
+ None,
+ None,
+ )
+ tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = (
+ None,
+ None,
+ None,
+ )
# Persistent tile scheduling loop
tile_scheduler = TileSchedulerCls()
@@ -1473,9 +1610,11 @@ def kernel(
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
- k_len = varlen_manager.len_k(batch_idx)
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
+ k_len = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2])
# Set tensor memory buffer for current tile
# (MMA, MMA_M, MMA_N)
@@ -1486,7 +1625,9 @@ def kernel(
)
tCtAcc = tCtAcc_base[None, None, None, acc_stage_idx]
tCtSFB_mma = tCtSFB
- if const_expr(self.blockscaled and self.mma_inst_shape_mnk[1] in (64, 192)):
+ if const_expr(
+ self.blockscaled and self.mma_inst_shape_mnk[1] in (64, 192)
+ ):
tCtSFB_mma = cute.make_tensor(
cute.recast_ptr(
sfb_tmem_base_ptr + Int32((tile_coord_mnkl[1] % 2) * 2),
@@ -1525,14 +1666,20 @@ def kernel(
# Doing tmem ptr arithmetic requires 32-bit type, wrong otherwise
cute.recast_ptr(mT.iterator, dtype=Float32)
+ cute.assume(
- acc_tmem_col_offset * (acc_producer_state.phase * 2 - 1),
+ acc_tmem_col_offset
+ * (acc_producer_state.phase * 2 - 1),
divby=acc_tmem_col_offset,
),
dtype=self.sf_dtype,
),
mT.layout,
)
- for mT in [tCtSFA, tCtSFB, tCtSFA_compact_s2t, tCtSFB_compact_s2t]
+ for mT in [
+ tCtSFA,
+ tCtSFB,
+ tCtSFA_compact_s2t,
+ tCtSFB_compact_s2t,
+ ]
]
tctx.e("mma")
# Advance to next tile
@@ -1566,8 +1713,10 @@ def kernel(
# Partition for epilogue
epi_tidx = tidx
- tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
- epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
+ tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
+ self.epilog_tmem_copy_and_partition(
+ epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs
+ )
)
tTR_rD = cute.make_rmem_tensor(tTR_rAcc.shape, self.acc_dtype)
@@ -1577,8 +1726,15 @@ def kernel(
tRS_rC, tSR_rC, tSR_sC = None, None, None
tiled_copy_s2r = None
if const_expr(mC_mnl is not None):
- tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
- tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = (
+ self.epilog_smem_load_and_partition(
+ tiled_copy_t2r,
+ self.c_layout,
+ self.c_dtype,
+ sC,
+ tRS_rD.layout,
+ epi_tidx,
+ )
)
# Persistent tile scheduling loop
@@ -1593,8 +1749,8 @@ def kernel(
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
epi_acc_stage = (
@@ -1610,7 +1766,9 @@ def kernel(
if const_expr(has_D):
copy_D, _, _ = self.epilog_gmem_copy_and_partition(
tma_atom_d,
- varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+ self.problem_get_batch_epi(
+ problem_params, mD_mnl, varlen_manager, work_tile
+ ),
self.cta_tile_shape_mnk[:2],
epi_tile,
sD,
@@ -1619,9 +1777,13 @@ def kernel(
copy_C = None # We're using a separate warp to load C
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
- k_len = varlen_manager.len_k(batch_idx)
+ k_len = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
epi_tile_num = cute.size(
- cute.zipped_divide(cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile),
+ cute.zipped_divide(
+ cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
+ ),
mode=[1],
)
load_acc_subtile = partial(
@@ -1773,24 +1935,36 @@ def load_AB_gather_A(
copy_B: Callable,
k_tile_cnt: Int32,
varlen_m: bool = True,
- ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
+ ) -> Tuple[
+ cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]
+ ]:
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
# Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt
peek_ab_empty_status = Boolean(True)
if 0 < k_tile_cnt:
peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
# TMA load on B and cp.async on A
- for k_tile in cutlass.range(k_tile_cnt - 1, unroll=2 if const_expr(varlen_m) else 1):
+ for k_tile in cutlass.range(
+ k_tile_cnt - 1, unroll=2 if const_expr(varlen_m) else 1
+ ):
smem_idx = ab_producer_state.index
prefetch_out = ()
- if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
- prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
+ if const_expr(
+ prefetch_A is not None
+ ): # Prefetch early, even before smem is free
+ prefetch_out = (
+ prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),
+ )
a_prefetch_consumer_state.advance()
# Wait for A/B buffers to be empty before loading into them
# Also sets the transaction barrier for the A/B buffers
# A tiny bit faster to rotate the warp that does TMA
- is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps)
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
+ k_tile % self.num_ab_load_warps
+ )
+ ab_pipeline.producer_acquire(
+ ab_producer_state, peek_ab_empty_status, is_tma_warp
+ )
# A bit faster to load B first while we calculate the indices for A
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
if is_tma_warp:
@@ -1801,17 +1975,27 @@ def load_AB_gather_A(
ab_producer_state.advance()
peek_ab_empty_status = Boolean(True)
if k_tile + 1 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(
+ ab_producer_state
+ )
# bound checking in the K dimension on the last k_tile
if 0 < k_tile_cnt:
k_tile = k_tile_cnt - 1
smem_idx = ab_producer_state.index
prefetch_out = ()
- if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
- prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),)
+ if const_expr(
+ prefetch_A is not None
+ ): # Prefetch early, even before smem is free
+ prefetch_out = (
+ prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),
+ )
a_prefetch_consumer_state.advance()
- is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ is_tma_warp = (
+ warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
+ )
+ ab_pipeline.producer_acquire(
+ ab_producer_state, peek_ab_empty_status, is_tma_warp
+ )
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
if is_tma_warp:
copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
@@ -1830,7 +2014,9 @@ def load_AB_tma_gather(
prefetch_A: Optional[Callable],
copy_B: Callable,
k_tile_cnt: Int32,
- ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]:
+ ) -> Tuple[
+ cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]
+ ]:
"""Unified TMA gather loading loop for both varlen_m and varlen_k.
For varlen_m: a_prefetch_pipeline is None, copy_A receives k_tile as src_idx.
@@ -1844,11 +2030,19 @@ def load_AB_tma_gather(
for k_tile in cutlass.range(k_tile_cnt, unroll=1):
smem_idx = ab_producer_state.index
prefetch_out = ()
- if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
- prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),)
+ if const_expr(
+ prefetch_A is not None
+ ): # Prefetch early, even before smem is free
+ prefetch_out = (
+ prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),
+ )
a_prefetch_consumer_state.advance()
- is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps)
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
+ k_tile % self.num_ab_load_warps
+ )
+ ab_pipeline.producer_acquire(
+ ab_producer_state, peek_ab_empty_status, is_tma_warp
+ )
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
if is_tma_warp:
copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr)
@@ -1857,7 +2051,9 @@ def load_AB_tma_gather(
ab_producer_state.advance()
peek_ab_empty_status = Boolean(True)
if k_tile + 1 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(
+ ab_producer_state
+ )
return ab_producer_state, a_prefetch_consumer_state
@cute.jit
@@ -1882,7 +2078,9 @@ def mma(
tCsSFB_compact_s2t: Optional[cute.Tensor] = None,
tCtSFA_compact_s2t: Optional[cute.Tensor] = None,
tCtSFB_compact_s2t: Optional[cute.Tensor] = None,
- ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]:
+ ) -> Tuple[
+ cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma
+ ]:
blockscaled = const_expr(tiled_copy_s2t_sfa is not None)
if const_expr(blockscaled):
assert all(x is not None for x in (tCtSFA, tCtSFB))
@@ -1923,15 +2121,27 @@ def mma(
s2t_stage_coord = (None, None, None, None, ab_consumer_state.index)
tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord]
tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord]
- cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t)
- cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t)
+ cute.copy(
+ tiled_copy_s2t_sfa,
+ tCsSFA_compact_s2t_staged,
+ tCtSFA_compact_s2t,
+ )
+ cute.copy(
+ tiled_copy_s2t_sfb,
+ tCsSFB_compact_s2t_staged,
+ tCtSFB_compact_s2t,
+ )
for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True):
k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index)
if const_expr(blockscaled):
# Set SFA/SFB tensor to tiled_mma
sf_kblock_coord = (None, None, k_blk_idx)
- tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
- tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
+ tiled_mma.set(
+ tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator
+ )
+ tiled_mma.set(
+ tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator
+ )
cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Async arrive AB buffer empty
@@ -1999,13 +2209,17 @@ def mainloop_s2t_copy_and_partition(
# (MMA, MMA_MN, MMA_K)
tCtSF_compact = cute.filter_zeros(tSF)
# Make S2T CopyAtom and tiledCopy
- copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
+ copy_atom_s2t = cute.make_copy_atom(
+ tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype
+ )
tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
thr_copy_s2t = tiled_copy_s2t.get_slice(0)
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE)
- tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
+ tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
+ tiled_copy_s2t, tCsSF_compact_s2t_
+ )
# ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K)
tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
@@ -2047,19 +2261,25 @@ def epilog_tmem_copy_and_partition(
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE)
tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile)
# (EPI_TILE_M, EPI_TILE_N)
- tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
+ tiled_copy_t2r = tcgen05.make_tmem_copy(
+ copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]
+ )
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE)
tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
- cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
+ cAcc = cute.make_identity_tensor(
+ (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])
+ )
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
cAcc_epi = cute.flat_divide(cAcc, epi_tile)
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi)
# (T2R, T2R_M, T2R_N)
- tTR_rAcc = cute.make_rmem_tensor(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype)
+ tTR_rAcc = cute.make_rmem_tensor(
+ tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype
+ )
return tiled_copy_t2r, tTR_tAcc, tTR_rAcc
def epilog_smem_store_and_partition(
@@ -2120,9 +2340,14 @@ def epilog_smem_load_and_partition(
store_op = copy_atom_r2s.op
# m8n8 16-bit path
if isinstance(store_op, StMatrix8x8x16bOp):
- op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose)
+ op = LdMatrix8x8x16bOp(
+ num_matrices=store_op.num_matrices, transpose=store_op.transpose
+ )
# m16n8 8-bit store -> m16n16 8-bit load
- elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]:
+ elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [
+ 2,
+ 4,
+ ]:
# transpose=True is enforced by the class
op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2)
else:
@@ -2158,7 +2383,9 @@ def make_ab_pipeline(
producer_cnt = self.num_ab_load_warps * 32 + (
1 if const_expr(not self.use_2cta_instrs) else 2
)
- ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, producer_cnt
+ )
# Each warp will contribute to the arrive count with the number of mcast size
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
consumer_arrive_cnt = mcast_size
@@ -2204,7 +2431,9 @@ def make_acc_pipeline(
self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer
) -> pipeline.PipelineAsync:
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
- num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1)
+ num_acc_consumer_threads = self.num_epi_warps * (
+ 2 if self.use_2cta_instrs else 1
+ )
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
@@ -2229,7 +2458,12 @@ def make_sched_pipeline(
# Each warp will contribute 1 to the arrive count
extra_warp_ids = (self.a_prefetch_warp_id,) if self.gather_A else ()
warps_per_cta = self.num_ab_load_warps + len(
- (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id, *extra_warp_ids)
+ (
+ self.mma_warp_id,
+ *self.epilog_warp_id,
+ self.scheduler_warp_id,
+ *extra_warp_ids,
+ )
)
if has_C:
warps_per_cta += 1
@@ -2252,7 +2486,9 @@ def make_a_prefetch_pipeline(
self, a_prefetch_pipeline_mbar_ptr: cute.Pointer
) -> pipeline.PipelineAsync:
producer_cnt = 32
- a_prefetch_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+ a_prefetch_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, producer_cnt
+ )
consumer_arrive_cnt = self.num_ab_load_warps
a_prefetch_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
@@ -2319,7 +2555,9 @@ def _compute_stages(
# Default D stages
epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2
- epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)
+ epi_c_stage = (
+ 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2)
+ )
# Calculate smem layout and size for one stage of A, B, and C
a_smem_layout_staged_one = sm100_utils.make_smem_layout_a(
@@ -2371,7 +2609,9 @@ def _compute_stages(
if const_expr(prefetch_A_idx == "varlen_m"):
mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2
d_bytes_per_stage = (
- cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0
+ cute.size_in_bytes(d_dtype, d_smem_layout_staged_one)
+ if d_dtype is not None
+ else 0
)
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
epilogue_args, cta_tile_shape_mnk, epi_tile
@@ -2391,7 +2631,9 @@ def _compute_stages(
# Refine epilogue stages:
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
- epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage)
+ epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (
+ epi_bytes_per_stage
+ )
return num_acc_stage, ab_stage, epi_stage, epi_c_stage
@staticmethod
@@ -2457,7 +2699,8 @@ def is_valid_dtypes(
if (
acc_dtype not in {Float32, cutlass.Float16, Int32}
or acc_dtype == cutlass.Float16
- and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
+ and ab_dtype
+ not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}
or acc_dtype == Int32
and ab_dtype not in {cutlass.Uint8, cutlass.Int8}
):
@@ -2522,7 +2765,11 @@ def is_valid_dtypes_and_scale_factor_vec_size(
is_valid = True
# Check valid ab_dtype
- if ab_dtype not in {cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN}:
+ if ab_dtype not in {
+ cutlass.Float4E2M1FN,
+ cutlass.Float8E5M2,
+ cutlass.Float8E4M3FN,
+ }:
is_valid = False
# Check valid sf_vec_size
@@ -2741,7 +2988,9 @@ def can_implement(
"""
can_implement = True
# Skip unsupported types
- if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major):
+ if not GemmSm100.is_valid_dtypes(
+ ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major
+ ):
can_implement = False
# Skip invalid mma tile shape and cluster shape
if not GemmSm100.is_valid_mma_tiler_and_cluster_shape(
diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py
index 34e3e383..a8d06289 100644
--- a/quack/gemm_sm120.py
+++ b/quack/gemm_sm120.py
@@ -161,6 +161,7 @@ def kernel(
epi_c_smem_layout: cute.ComposedLayout,
tile_sched_params,
TileSchedulerCls: cutlass.Constexpr[Callable],
+ problem_params,
trace_ptr: Optional[cutlass.Int64] = None,
):
from quack.trace import TraceContext
@@ -267,12 +268,12 @@ def kernel(
)
while work_tile.is_valid_tile:
tctx.b("tma_load")
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
# Local_tile partition global tensors
copy_A, prefetch_A = None, None
if const_expr(not self.gather_A):
- mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+ mA_mk = self.problem_get_batch_A(problem_params, mA_mkl, varlen_manager, work_tile)
# (bM, bK, RestK)
gA_mk = cute.local_tile(
mA_mk,
@@ -296,7 +297,7 @@ def kernel(
)
# (bN, bK, RestK)
gB_nk = cute.local_tile(
- varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ self.problem_get_batch_B(problem_params, mB_nkl, varlen_manager, work_tile),
cute.select(self.cta_tile_shape_mnk, [1, 2]),
(tile_coord_mnkl[1], None),
)
@@ -311,7 +312,7 @@ def kernel(
dst_tensor=sB,
mcast_mask=b_mcast_mask,
)
- len_k = varlen_manager.len_k(batch_idx)
+ len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile)
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
if const_expr(not self.gather_A):
ab_producer_state = self.load_AB(
@@ -407,15 +408,15 @@ def kernel(
if const_expr(not varlen_k):
ab_read_state.advance_iters(k_tile_cnt_static)
else:
- len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+ len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile)
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
ab_read_state.advance_iters(k_tile_cnt)
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
while work_tile.is_valid_tile:
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
- len_k = varlen_manager.len_k(batch_idx)
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
+ len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile)
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
acc.fill(0.0)
if const_expr(self.pingpong):
@@ -450,7 +451,7 @@ def kernel(
if const_expr(has_D):
copy_D, _, _ = self.epilog_gmem_copy_and_partition(
tma_atom_d,
- varlen_manager.offset_batch_epi(mD_mnl, tile_coord_mnkl[3]),
+ self.problem_get_batch_epi(problem_params, mD_mnl, varlen_manager, work_tile),
self.cta_tile_shape_mnk[:2],
self.epi_tile,
sD,
@@ -460,7 +461,7 @@ def kernel(
if const_expr(has_C):
copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
tma_atom_c,
- varlen_manager.offset_batch_epi(mC_mnl, tile_coord_mnkl[3]),
+ self.problem_get_batch_epi(problem_params, mC_mnl, varlen_manager, work_tile),
self.cta_tile_shape_mnk[:2],
self.epi_tile,
sC,
@@ -535,7 +536,7 @@ def kernel(
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
if work_tile.is_valid_tile:
- len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+ len_k = self.problem_get_len_k(problem_params, varlen_manager, work_tile)
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
ab_read_state.advance_iters(k_tile_cnt)
tile_scheduler.advance_to_next_work()
diff --git a/quack/gemm_sm90.py b/quack/gemm_sm90.py
index 7ddb280e..de32301d 100644
--- a/quack/gemm_sm90.py
+++ b/quack/gemm_sm90.py
@@ -8,38 +8,46 @@
import math
-import cuda.bindings.driver as cuda
-
-import cutlass
-import cutlass.cute as cute
-import cutlass.pipeline as pipeline
-from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
+from cutlass import Boolean, const_expr, Float16, Float32, Int32
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
-import cutlass.utils.hopper_helpers as sm90_utils
-from cutlass import Int32, Float32, Float16, Boolean, const_expr
+from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
from cutlass.utils import LayoutEnum
-
-
from dataclasses import dataclass
-
-from quack.cute_dsl_utils import ParamsBase
from quack import layout_utils
+from quack.cute_dsl_utils import ParamsBase
+from quack.gemm_epilogue_plan import (
+ default_epi_commit,
+ default_epi_tile_layout,
+ run_epilogue_plan,
+)
+from quack.gemm_problem_adapter import (
+ default_problem_batch_a,
+ default_problem_batch_b,
+ default_problem_batch_epi,
+ default_problem_idx,
+ default_problem_len_k,
+ default_problem_to_underlying_arguments,
+ ProblemArguments as DefaultProblemArguments,
+ ProblemParams as DefaultProblemParams,
+)
+from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
from quack.tile_scheduler import (
- TileSchedulerOptions,
- TileSchedulerArguments,
+ PersistenceMode,
TileScheduler,
- VarlenMTileSchedulerArguments,
+ TileSchedulerArguments,
+ TileSchedulerOptions,
VarlenMTileScheduler,
- PersistenceMode,
+ VarlenMTileSchedulerArguments,
)
from quack.varlen_utils import VarlenArguments, VarlenManager
-
-# return PipelineStateWAdvance instead of PipelineState
-from quack.pipeline import make_pipeline_state, PipelineTmaCpAsync
+from typing import Callable, Literal, Optional, Tuple, Type, Union
+import cuda.bindings.driver as cuda
+import cutlass
+import cutlass.cute as cute
+import cutlass.pipeline as pipeline
+import cutlass.utils.hopper_helpers as sm90_utils
import quack.copy_utils as copy_utils
import quack.sm90_utils as quack_sm90_utils
-from quack.rounding import RoundingMode
-
"""
A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture
using CUTE DSL.
@@ -134,6 +142,80 @@ class EpilogueArguments:
pass
EpilogueParams = ParamsBase
+ ProblemArguments = DefaultProblemArguments  # class-level alias of gemm_problem_adapter.ProblemArguments (see the import alias above)
+ ProblemParams = DefaultProblemParams  # class-level alias of gemm_problem_adapter.ProblemParams; used in the problem_* hook annotations
+
+ def problem_to_underlying_arguments(  # host-side hook: __call__ stores the result as problem_params and passes it to the kernel
+ self, args: Optional[ProblemArguments] = None, *, loc=None, ip=None  # args may be None; loc/ip are keyword-only and forwarded as-is
+ ) -> ProblemParams:
+ return default_problem_to_underlying_arguments(args, loc=loc, ip=ip)  # base class only forwards to the gemm_problem_adapter default
+
+ @cute.jit
+ def problem_get_problem_idx(self, params: ProblemParams, work):  # hook: the kernel assigns the result to batch_idx (formerly tile_coord_mnkl[3])
+ return default_problem_idx(params, work)  # base class only forwards to the gemm_problem_adapter default
+
+ @cute.jit
+ def problem_get_len_k(  # hook: replaces varlen_manager.len_k(batch_idx); callers ceil-divide the result by cta_tile_shape_mnk[2]
+ self, params: ProblemParams, varlen_manager: VarlenManager, work
+ ):
+ return default_problem_len_k(params, work, varlen_manager)  # NOTE(review): work precedes varlen_manager here, unlike the default_problem_batch_* calls; confirm against the helper's signature
+
+ @cute.jit
+ def problem_get_batch_A(  # hook: replaces varlen_manager.offset_batch_A(mA_mkl, batch_idx) on the non-gather_A load path
+ self,
+ params: ProblemParams,
+ mA_mkl: cute.Tensor,
+ varlen_manager: VarlenManager,
+ work,  # the scheduler work tile; the kernel passes work_tile
+ ):
+ return default_problem_batch_a(params, mA_mkl, varlen_manager, work)  # base class only forwards; the kernel feeds the result to cute.local_tile
+
+ @cute.jit
+ def problem_get_batch_B(  # hook: replaces varlen_manager.offset_batch_B(mB_nkl, batch_idx) in the load path
+ self,
+ params: ProblemParams,
+ mB_nkl: cute.Tensor,
+ varlen_manager: VarlenManager,
+ work,  # the scheduler work tile; the kernel passes work_tile
+ ):
+ return default_problem_batch_b(params, mB_nkl, varlen_manager, work)  # base class only forwards; the kernel feeds the result to cute.local_tile
+
+ @cute.jit
+ def problem_get_batch_epi(  # hook: replaces varlen_manager.offset_batch_epi(...) in the epilogue setup
+ self,
+ params: ProblemParams,
+ mX_mnl: Optional[cute.Tensor],  # the kernel passes mD_mnl or mC_mnl here
+ varlen_manager: VarlenManager,
+ work,  # the scheduler work tile; the kernel passes work_tile
+ ):
+ return default_problem_batch_epi(params, mX_mnl, varlen_manager, work)  # base class only forwards to the gemm_problem_adapter default
+
+ def epi_plan_make_tile_layout(self, epi_tile_shape):  # epilogue-plan hook; not called in the visible hunks, presumably used by run_epilogue_plan (TODO confirm)
+ return default_epi_tile_layout(self, epi_tile_shape)  # the GEMM instance is passed explicitly to the helper imported from quack.gemm_epilogue_plan
+
+ @cute.jit
+ def epi_plan_commit(  # epilogue-plan hook; not called in the visible hunks, presumably used by run_epilogue_plan (TODO confirm)
+ self,
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+ ):
+ default_epi_commit(  # called for effect only: this method has no return statement
+ self,  # the GEMM instance is passed explicitly to the helper imported from quack.gemm_epilogue_plan
+ gmem_coord,  # remaining arguments are forwarded positionally in the same order as the signature
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+ )
def __init__(
self,
@@ -194,7 +276,8 @@ def __init__(
)
else:
if not (
- (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)
+ (tile_N % 16 == 0 and tile_N <= 256)
+ or (tile_N % 32 == 0 and tile_N <= 512)
):
raise ValueError(
"CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512"
@@ -204,7 +287,9 @@ def __init__(
raise ValueError("CTA tile shape M must be 64/128/192 if pingpong")
tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128)
if not (tile_N % 16 == 0 and tile_N <= tile_N_max):
- raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}")
+ raise ValueError(
+ f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}"
+ )
if not self.pingpong:
if tile_M == 320: # tile_M / 64 is not even so we have to split along N
@@ -216,7 +301,9 @@ def __init__(
atom_layout_m, atom_layout_n = 1, 2
else:
atom_layout_m = (
- self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2
+ self.cta_tile_shape_mnk[0] // 64
+ if self.cta_tile_shape_mnk[0] < 256
+ else 2
)
atom_layout_n = 1
assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2]
@@ -232,12 +319,16 @@ def __init__(
self.is_b_mcast = self.num_mcast_ctas_b > 1
self.occupancy = 1
- self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2)
+ self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (
+ 1 if not self.pingpong else 2
+ )
if self.pingpong:
assert self.mma_warp_groups == 2
assert self.mma_warp_groups in [1, 2, 3]
self.num_threads_per_warp_group = 128
- self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group
+ self.threads_per_cta = (
+ self.mma_warp_groups + 1
+ ) * self.num_threads_per_warp_group
self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90")
self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4
self.epilogue_barrier = pipeline.NamedBarrier(
@@ -343,7 +434,9 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments):
self.d_dtype,
self.c_dtype,
epilogue_args,
- cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity
+ cutlass.utils.get_smem_capacity_in_bytes(
+ f"sm_{self.arch}"
+ ), # smem_capacity
self.occupancy,
)
self.sched_stage = 2 if self.pingpong else 1
@@ -381,6 +474,7 @@ def __call__(
varlen_args: Optional[VarlenArguments],
stream: cuda.CUstream,
trace_ptr: Optional[cutlass.Int64] = None,
+ problem_args: Optional[ProblemArguments] = None,
):
"""Execute the GEMM operation in steps:
- Setup static attributes
@@ -420,7 +514,9 @@ def __call__(
if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype):
raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}")
if const_expr(self.a_dtype.width != self.b_dtype.width):
- raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}")
+ raise TypeError(
+ f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}"
+ )
if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8):
raise TypeError("a_dtype should be float16 or float8")
@@ -445,7 +541,9 @@ def __call__(
self.cluster_shape_mnk[1],
)
tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors(
- copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1) if varlen_k else mB,
+ copy_utils.create_ragged_tensor_for_tma(mB, ragged_dim=1)
+ if varlen_k
+ else mB,
b_smem_layout,
(self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]),
self.cluster_shape_mnk[0],
@@ -468,7 +566,10 @@ def __call__(
self.epi_smem_layout_staged,
self.epi_tile,
op_type="store"
- if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output)
+ if not (
+ hasattr(epilogue_args, "add_to_output")
+ and epilogue_args.add_to_output
+ )
else "add",
)
tma_atom_c, tma_tensor_c = None, None
@@ -479,6 +580,7 @@ def __call__(
epilogue_params = self.epi_to_underlying_arguments(epilogue_args)
varlen_params = VarlenManager.to_underlying_arguments(varlen_args)
+ problem_params = self.problem_to_underlying_arguments(problem_args)
TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_m)
tile_sched_args = self.get_scheduler_arguments(
@@ -489,14 +591,24 @@ def __call__(
tile_sched_params, scheduler_args.max_active_clusters
)
- epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
- epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
+ epi_smem_size = (
+ cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0
+ )
+ epi_c_smem_size = (
+ cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0
+ )
@cute.struct
class SharedStorage:
- ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2]
- epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2]
- sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2]
+ ab_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.ab_stage * 2
+ ]
+ epi_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.epi_c_stage * 2
+ ]
+ sched_pipeline_array_ptr: cute.struct.MemRange[
+ cutlass.Int64, self.sched_stage * 2
+ ]
sched_data: cute.struct.MemRange[Int32, self.sched_stage * 4]
sD: cute.struct.Align[
cute.struct.MemRange[
@@ -512,11 +624,15 @@ class SharedStorage:
]
epi: self.epi_get_smem_struct(epilogue_params)
sA: cute.struct.Align[
- cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)],
+ cute.struct.MemRange[
+ self.a_dtype, cute.cosize(self.a_smem_layout_staged)
+ ],
self.buffer_align_bytes,
]
sB: cute.struct.Align[
- cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)],
+ cute.struct.MemRange[
+ self.b_dtype, cute.cosize(self.b_smem_layout_staged)
+ ],
self.buffer_align_bytes,
]
@@ -542,6 +658,7 @@ class SharedStorage:
self.epi_c_smem_layout_staged,
tile_sched_params,
TileSchedulerCls,
+ problem_params,
trace_ptr,
).launch(
grid=grid,
@@ -575,6 +692,7 @@ def kernel(
epi_c_smem_layout: cute.ComposedLayout,
tile_sched_params,
TileSchedulerCls: cutlass.Constexpr[Callable],
+ problem_params,
trace_ptr: Optional[cutlass.Int64] = None,
):
"""
@@ -650,17 +768,23 @@ def kernel(
sched_data = storage.sched_data.get_tensor((4, self.sched_stage))
# Cluster arrive after barrier init
- pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True)
+ pipeline_init_arrive(
+ cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True
+ )
# Generate smem tensor A/B
sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner)
sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner)
sD = None
if const_expr(has_D):
- sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner)
+ sD = storage.sD.get_tensor(
+ epi_smem_layout.outer, swizzle=epi_smem_layout.inner
+ )
sC = None
if const_expr(has_C):
- sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner)
+ sC = storage.sC.get_tensor(
+ epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner
+ )
epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage)
varlen_manager = VarlenManager.create(
@@ -691,8 +815,12 @@ def kernel(
if const_expr(self.use_pdl):
cute.arch.griddepcontrol_wait()
# Get mcast mask
- cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
- block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster)
+ cta_rank_in_cluster = cute.arch.make_warp_uniform(
+ cute.arch.block_idx_in_cluster()
+ )
+ block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(
+ cta_rank_in_cluster
+ )
a_mcast_mask = cute.make_layout_image_mask(
cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1
)
@@ -703,9 +831,13 @@ def kernel(
b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0
# Persistent tile scheduling loop
- is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+ is_scheduler_warp = (
+ self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id
+ )
if const_expr(cute.size(cluster_layout_mnk) > 1):
- is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
+ is_scheduler_warp = (
+ is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0
+ )
tile_scheduler = TileSchedulerCls()
work_tile = tile_scheduler.initial_work_tile_info()
ab_producer_state = make_pipeline_state(
@@ -713,12 +845,14 @@ def kernel(
)
while work_tile.is_valid_tile:
tctx.b("tma_load")
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
# Local_tile partition global tensors
copy_A, prefetch_A = None, None
if const_expr(not self.gather_A):
- mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx)
+ mA_mk = self.problem_get_batch_A(
+ problem_params, mA_mkl, varlen_manager, work_tile
+ )
# (bM, bK, RestK)
gA_mk = cute.local_tile(
mA_mk,
@@ -742,7 +876,9 @@ def kernel(
)
# (bN, bK, RestK)
gB_nk = cute.local_tile(
- varlen_manager.offset_batch_B(mB_nkl, batch_idx),
+ self.problem_get_batch_B(
+ problem_params, mB_nkl, varlen_manager, work_tile
+ ),
cute.select(self.cta_tile_shape_mnk, [1, 2]),
(tile_coord_mnkl[1], None),
)
@@ -757,7 +893,9 @@ def kernel(
dst_tensor=sB,
mcast_mask=b_mcast_mask,
)
- len_k = varlen_manager.len_k(batch_idx)
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
if const_expr(not self.gather_A):
ab_producer_state = self.load_AB(
@@ -774,7 +912,9 @@ def kernel(
varlen_m=varlen_m,
)
tctx.e("tma_load")
- tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
+ tile_scheduler.advance_to_next_work(
+ is_scheduler_warp=is_scheduler_warp
+ )
work_tile = tile_scheduler.get_current_work()
# End of persistent scheduler loop
if const_expr(self.pingpong and not varlen_k):
@@ -795,7 +935,9 @@ def kernel(
)
# Partition global tensor for TiledMMA_A/B/C
tidx, _, _ = cute.arch.thread_idx()
- warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
+ warp_group_idx = cute.arch.make_warp_uniform(
+ tidx // self.num_threads_per_warp_group
+ )
if const_expr(self.pingpong):
tidx = tidx % self.num_threads_per_warp_group
warp_group_thread_layout = cute.make_layout(
@@ -824,9 +966,13 @@ def kernel(
k_tile_cnt_static = cute.ceil_div(
cute.size(mA_mkl, mode=[1]), self.cta_tile_shape_mnk[2]
)
- c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile))
+ c_tile_cnt = cute.size(
+ cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile)
+ )
- ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage)
+ ab_read_state = make_pipeline_state(
+ pipeline.PipelineUserType.Consumer, self.ab_stage
+ )
epi_store_pipeline = self.make_epi_store_pipeline()
epi_read_state = make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.epi_c_stage
@@ -844,22 +990,32 @@ def kernel(
if const_expr(not varlen_k):
ab_read_state.advance_iters(k_tile_cnt_static)
else:
- len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
ab_read_state.advance_iters(k_tile_cnt)
# TODO: do we need to check if work_tile is valid?
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
while work_tile.is_valid_tile:
- tile_coord_mnkl = work_tile.tile_idx
- batch_idx = tile_coord_mnkl[3]
- len_k = varlen_manager.len_k(batch_idx)
+ tile_coord_mnkl = work_tile.tile_coord_mnkl
+ batch_idx = self.problem_get_problem_idx(problem_params, work_tile)
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
if const_expr(self.pingpong):
self.pingpong_barrier_sync(warp_group_idx, stage="mma")
tctx.b("mma")
ab_read_state = self.mma(
- ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx
+ ab_pipeline,
+ ab_read_state,
+ mma_fn,
+ acc,
+ acc_slow,
+ k_tile_cnt,
+ warp_group_idx,
)
if const_expr(varlen_k):
if k_tile_cnt == 0:
@@ -875,7 +1031,9 @@ def kernel(
if const_expr(has_D):
copy_D, _, _ = self.epilog_gmem_copy_and_partition(
tma_atom_d,
- varlen_manager.offset_batch_epi(mD_mnl, batch_idx),
+ self.problem_get_batch_epi(
+ problem_params, mD_mnl, varlen_manager, work_tile
+ ),
self.cta_tile_shape_mnk[:2],
self.epi_tile,
sD,
@@ -885,7 +1043,9 @@ def kernel(
if const_expr(has_C):
copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition(
tma_atom_c,
- varlen_manager.offset_batch_epi(mC_mnl, batch_idx),
+ self.problem_get_batch_epi(
+ problem_params, mC_mnl, varlen_manager, work_tile
+ ),
self.cta_tile_shape_mnk[:2],
self.epi_tile,
sC,
@@ -893,7 +1053,9 @@ def kernel(
)
copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline)
- d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
+ d_dtype_for_layout = (
+ self.d_dtype if self.d_dtype is not None else cutlass.BFloat16
+ )
tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition(
tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx
)
@@ -901,13 +1063,22 @@ def kernel(
tRS_rAcc = self.epi_retile_acc(acc, tRS_rD, tiled_copy_r2s)
load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc)
if const_expr(has_C):
- tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition(
- tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx
+ tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = (
+ self.epilog_smem_load_and_partition(
+ tiled_mma,
+ self.c_layout,
+ self.c_dtype,
+ sC,
+ tRS_rD.layout,
+ tidx,
+ )
)
else:
tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None
- self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx)
+ self.epi_visit_acc(
+ epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx
+ )
epi_read_state, epi_producer_state = self.epilogue(
epilogue_params,
@@ -955,14 +1126,20 @@ def kernel(
# Update starting mainloop pipeline state for the next tile
if const_expr(not varlen_k):
ab_read_state.advance_iters(k_tile_cnt_static)
- tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups)
+ tile_scheduler.advance_to_next_work(
+ advance_count=self.mma_warp_groups
+ )
work_tile = tile_scheduler.get_current_work()
else:
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
if work_tile.is_valid_tile:
- len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3])
- k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2])
+ len_k = self.problem_get_len_k(
+ problem_params, varlen_manager, work_tile
+ )
+ k_tile_cnt = cute.ceil_div(
+ len_k, self.cta_tile_shape_mnk[2]
+ )
ab_read_state.advance_iters(k_tile_cnt)
tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
@@ -1016,7 +1193,9 @@ def load_AB(
ab_producer_state.advance()
peek_ab_empty_status = Boolean(True)
if k_tile + 1 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(
+ ab_producer_state
+ )
return ab_producer_state
@cute.jit
@@ -1038,13 +1217,19 @@ def load_AB_gather_A(
# TMA load on B and cp.async on A
for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1):
prefetch_out = ()
- if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ if const_expr(
+ prefetch_A is not None
+ ): # Prefetch early, even before smem is free
prefetch_out = (prefetch_A(k_tile),)
# Wait for A/B buffers to be empty before loading into them
# Also sets the transaction barrier for the A/B buffers
# A tiny bit faster to rotate the warp that does TMA
- is_tma_warp = warp_idx == self.ab_load_warp_id + (k_tile % self.num_ab_load_warps)
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ is_tma_warp = warp_idx == self.ab_load_warp_id + (
+ k_tile % self.num_ab_load_warps
+ )
+ ab_pipeline.producer_acquire(
+ ab_producer_state, peek_ab_empty_status, is_tma_warp
+ )
smem_idx = ab_producer_state.index
# A bit faster to load B first while we calculate the indices for A
if is_tma_warp:
@@ -1056,15 +1241,23 @@ def load_AB_gather_A(
ab_producer_state.advance()
peek_ab_empty_status = Boolean(True)
if k_tile + 1 < k_tile_cnt:
- peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state)
+ peek_ab_empty_status = ab_pipeline.producer_try_acquire(
+ ab_producer_state
+ )
# bound checking in the K dimension on the last k_tile
if 0 < k_tile_cnt:
k_tile = k_tile_cnt - 1
prefetch_out = ()
- if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free
+ if const_expr(
+ prefetch_A is not None
+ ): # Prefetch early, even before smem is free
prefetch_out = (prefetch_A(k_tile, pred=True),)
- is_tma_warp = warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
- ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp)
+ is_tma_warp = (
+ warp_idx == self.ab_load_warp_id + k_tile % self.num_ab_load_warps
+ )
+ ab_pipeline.producer_acquire(
+ ab_producer_state, peek_ab_empty_status, is_tma_warp
+ )
smem_idx = ab_producer_state.index
if is_tma_warp:
tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state)
@@ -1087,7 +1280,9 @@ def _make_gather_A_copy(
varlen_m = varlen_manager.varlen_m
mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx)
if const_expr(varlen_m):
- gAIdx = cute.local_tile(mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],))
+ gAIdx = cute.local_tile(
+ mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],)
+ )
mA_mk = mA_mkl
else:
gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],))
@@ -1099,7 +1294,9 @@ def _make_gather_A_copy(
tiled_copy_A = self._make_gmem_tiled_copy_A(
mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32
)
- dma_tidx = cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
+ dma_tidx = (
+ cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id
+ )
thr_copy_A = tiled_copy_A.get_slice(dma_tidx)
copy_A, prefetch_A = None, None
if const_expr(varlen_m):
@@ -1144,7 +1341,11 @@ def mma(
for k_tile in cutlass.range(num_prologue_mma):
# Wait for A/B buffer to be ready
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
- mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
+ mma_fn(
+ A_idx=ab_read_state.index,
+ B_idx=ab_read_state.index,
+ zero_init=zero_init,
+ )
zero_init = Boolean(False)
ab_read_state.advance()
peek_ab_full_status = Boolean(True)
@@ -1162,7 +1363,11 @@ def mma(
ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status)
if const_expr(self.fp8_slow_accum):
zero_init = Boolean(True)
- mma_fn(A_idx=ab_read_state.index, B_idx=ab_read_state.index, zero_init=zero_init)
+ mma_fn(
+ A_idx=ab_read_state.index,
+ B_idx=ab_read_state.index,
+ zero_init=zero_init,
+ )
zero_init = Boolean(False)
# Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete
if const_expr(not self.fp8_slow_accum):
@@ -1217,142 +1422,34 @@ def epilogue(
tidx: Int32,
is_tma_warp: Boolean,
) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
- has_C = const_expr(tRS_rC is not None)
- has_D = const_expr(copy_D is not None)
-
- # Setup postact output (returns None for default epilogue, context tuple for Act)
- postact_ctx = self.epi_setup_postact(
- params,
- epi_smem_tensors,
- tiled_copy_r2s,
- tiled_copy_t2r,
- tile_coord_mnkl,
- varlen_manager,
- tidx,
- )
-
- epi_tile_shape = cute.zipped_divide(
- cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
- ).shape[1]
- # We iterate over epi tiles in the N dimension first before the M dimension
- epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0))
- epi_tile_num = cute.size(epi_tile_shape)
- num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
-
- epi_tensors = self.epi_begin(
+ return run_epilogue_plan(
+ self,
params,
epi_smem_tensors,
+ epi_pipeline,
+ epi_store_pipeline,
+ epi_read_state,
+ epi_producer_state,
epi_tile,
+ load_acc_subtile,
+ tRS_rD,
+ tRS_rC,
tiled_copy_t2r,
tiled_copy_r2s,
+ tRS_sD,
+ tiled_copy_s2r,
+ tSR_rC,
+ tSR_sC,
+ copy_D,
+ copy_C,
tile_coord_mnkl,
varlen_manager,
epilogue_barrier,
+ tile_scheduler,
tidx,
+ is_tma_warp,
)
- if const_expr(copy_C is not None):
- for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
- if is_tma_warp:
- epi_pipeline.producer_acquire(epi_producer_state)
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
- epi_pipeline.producer_commit(epi_producer_state)
- epi_producer_state.advance()
-
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
- # The global memory coordinate for the current epi tile
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
- # Copy from acc to D registers
- load_acc_subtile(tRS_rD, epi_idx)
- epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
- if const_expr(has_C):
- epi_pipeline.consumer_wait(epi_read_state)
- cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
- # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_view_async_shared()
- cute.arch.sync_warp()
- with cute.arch.elect_one():
- epi_pipeline.consumer_release(epi_read_state)
- epi_read_state.advance()
- if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
- if is_tma_warp:
- epi_pipeline.producer_acquire(epi_producer_state)
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
- epi_pipeline.producer_commit(epi_producer_state)
- epi_producer_state.advance()
- tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
- # Convert and store postact if this epilogue produces one
- if const_expr(postact_ctx is not None):
- tRS_rPostAct_out = self.epi_convert_postact(
- tRS_rPostAct,
- epi_loop_tensors["sr_seed"],
- tidx,
- tile_coord_mnkl,
- num_prev_subtiles,
- epi_idx,
- )
- if is_tma_warp:
- epi_store_pipeline.producer_acquire()
- epilogue_barrier.arrive_and_wait()
- # Copy from D registers to shared memory
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
- if const_expr(has_D):
- if const_expr(
- self.rounding_mode == RoundingMode.RS
- and self.acc_dtype == cutlass.Float32
- and self.d_dtype == cutlass.BFloat16
- ):
- seed = epi_loop_tensors["sr_seed"] + (
- tile_coord_mnkl[0] * 65537
- + tile_coord_mnkl[1] * 257
- + tile_coord_mnkl[3] * 17
- + (num_prev_subtiles + epi_idx) * 7
- )
- copy_utils.sr_cvt_copy(
- tiled_copy_r2s,
- tRS_rD,
- tRS_sD[None, None, None, epi_buffer],
- seed,
- tidx,
- )
- else:
- copy_utils.cvt_copy(
- tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
- )
- # Copy postact from registers to shared memory
- if const_expr(postact_ctx is not None):
- tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = postact_ctx
- cute.copy(
- tiled_copy_postact_r2s,
- tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
- tRS_sPostAct[None, None, None, epi_buffer],
- )
- # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_view_async_shared()
- epilogue_barrier.arrive_and_wait()
- # Copy from shared memory to global memory
- if is_tma_warp:
- if const_expr(has_D):
- copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
- if const_expr(postact_ctx is not None):
- copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
- epi_store_pipeline.producer_commit()
-
- self.epi_end(
- params,
- epi_tensors,
- epi_tile,
- tiled_copy_t2r,
- tiled_copy_r2s,
- tile_coord_mnkl,
- varlen_manager,
- tidx,
- )
-
- return epi_read_state, epi_producer_state
-
def get_scheduler_class(self, varlen_m: bool = False):
"""Return the scheduler class to use. Override in subclasses for custom schedulers."""
return TileScheduler if not varlen_m else VarlenMTileScheduler
@@ -1401,7 +1498,11 @@ def get_scheduler_arguments(
persistence_mode=persistence_mode,
)
else:
- assert (mD is not None) or (epilogue_args.mPostAct is not None) or (not self.gather_A)
+ assert (
+ (mD is not None)
+ or (epilogue_args.mPostAct is not None)
+ or (not self.gather_A)
+ )
problem_shape_ntile_mnl = (
None,
cute.ceil_div(cute.size(mB, mode=[0]), self.cta_tile_shape_mnk[1]),
@@ -1425,7 +1526,9 @@ def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s):
return cute.flat_divide(acc, tRS_rD.layout)
@cute.jit
- def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int):
+ def epi_load_acc_subtile(
+ self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int
+ ):
cute.autovec_copy(tRS_rAcc[None, None, None, epi_idx], tRS_rD)
@cute.jit
@@ -1444,7 +1547,10 @@ def epi_begin(
return ()
def epi_begin_loop(
- self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord
+ self,
+ params: EpilogueParams,
+ epi_tensors: Tuple[cute.Tensor, ...],
+ epi_coord: cute.Coord,
) -> Tuple[cute.Tensor, ...]:
return ()
@@ -1503,10 +1609,14 @@ def epi_smem_bytes_per_stage(
def epi_get_smem_struct(self, params: EpilogueParams):
return cute.struct.MemRange[Int32, 0] # Dummy struct
- def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]:
+ def epi_get_smem_tensors(
+ self, params: EpilogueParams, storage
+ ) -> Tuple[cute.Tensor, ...]:
return tuple()
- def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
+ def pingpong_barrier_sync(
+ self, warp_group_idx: Int32, stage: Literal["mma", "epi"]
+ ):
assert stage in ["mma", "epi"]
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
cute.arch.barrier(
@@ -1514,7 +1624,9 @@ def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "ep
number_of_threads=2 * self.num_threads_per_warp_group,
)
- def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]):
+ def pingpong_barrier_arrive(
+ self, warp_group_idx: Int32, stage: Literal["mma", "epi"]
+ ):
assert stage in ["mma", "epi"]
barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0
cute.arch.barrier_arrive(
@@ -1554,7 +1666,9 @@ def epilog_smem_store_and_partition(
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None
sD_shape = sD.shape[:2] if sD is not None else self.epi_tile
- tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape
+ tRS_rD_shape = thr_copy_r2s.partition_S(
+ cute.make_identity_tensor(sD_shape)
+ ).shape
tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, self.acc_dtype)
return tiled_copy_r2s, tRS_rD, tRS_sD
@@ -1589,7 +1703,8 @@ def epilog_gmem_copy_and_partition(
gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2])
tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile)
is_s2g = isinstance(
- atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp)
+ atom.op,
+ (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp),
)
src_tensor, dst_tensor = (
(sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD)
@@ -1609,15 +1724,21 @@ def make_ab_pipeline(
ab_pipeline_mbar_ptr: cute.Pointer,
):
# Threads/warps participating in this pipeline
- producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
- ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt)
+ producer_cnt = (
+ 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32
+ )
+ ab_pipeline_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, producer_cnt
+ )
# Each warp will contribute to the arrive count with the number of mcast size
mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, consumer_arrive_cnt
)
- pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
+ pipeline_cls = (
+ pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync
+ )
return pipeline_cls.create(
barrier_storage=ab_pipeline_mbar_ptr,
num_stages=self.ab_stage,
@@ -1629,7 +1750,9 @@ def make_ab_pipeline(
)
def make_epi_pipeline(
- self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer
+ self,
+ c_smem_layout: cute.Layout | cute.ComposedLayout,
+ epi_pipeline_mbar_ptr: cute.Pointer,
):
# Threads/warps participating in this pipeline
epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
@@ -1651,13 +1774,18 @@ def make_epi_pipeline(
def make_epi_store_pipeline(self):
# Threads/warps participating in tma store pipeline
num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE
- epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads)
+ epi_store_producer_group = pipeline.CooperativeGroup(
+ pipeline.Agent.Thread, num_epi_threads
+ )
return pipeline.PipelineTmaStore.create(
num_stages=self.epi_stage, producer_group=epi_store_producer_group
)
def make_sched_pipeline(
- self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool
+ self,
+ cluster_layout_mnk: cute.Layout,
+ sched_pipeline_mbar_ptr: cute.Pointer,
+ varlen_k: bool,
):
# Threads/warps participating in this pipeline
sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
@@ -1714,7 +1842,9 @@ def _compute_stages(
"""
epi_stage = 4 if epi_tile[1] <= 16 else 2
- d_bytes_per_stage = cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
+ d_bytes_per_stage = (
+ cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0
+ )
epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage(
epilogue_args, cta_tile_shape_mnk, epi_tile
)
@@ -1726,7 +1856,8 @@ def _compute_stages(
a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None))
b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None))
ab_bytes_per_stage = (
- cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8
+ cute.size(a_shape) * a_dtype.width // 8
+ + cute.size(b_shape) * b_dtype.width // 8
)
mbar_helpers_bytes = 1024
@@ -1737,7 +1868,9 @@ def _compute_stages(
# Calculate remaining smem after allocating for A/B stages and reserved bytes
# Add remaining unused smem to epilogue
if epi_bytes_per_stage > 0:
- epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage
+ epi_stage += (
+ remaining_bytes - ab_bytes_per_stage * ab_stage
+ ) // epi_bytes_per_stage
return ab_stage, epi_stage, epi_c_stage
@staticmethod
@@ -1797,7 +1930,10 @@ def _make_smem_layouts(
c_layout: Optional[LayoutEnum],
epi_c_stage: int,
) -> Tuple[
- cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout]
+ cute.ComposedLayout,
+ cute.ComposedLayout,
+ cute.ComposedLayout,
+ Optional[cute.ComposedLayout],
]:
"""Create shared memory layouts for A, B, and C tensors.
@@ -1894,7 +2030,9 @@ def _make_tma_epi_atoms_and_tensors(
"""
assert op_type in ["load", "store", "add"]
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
- d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile)
+ d_cta_v_layout = cute.composition(
+ cute.make_identity_layout(tensor_d.shape), epi_tile
+ )
op = (
cpasync.CopyBulkTensorTileG2SOp()
if op_type == "load"
@@ -2001,10 +2139,20 @@ def is_valid_dtypes(
:rtype: bool
"""
is_valid = True
- if a_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
+ if a_dtype not in {
+ Float16,
+ cutlass.BFloat16,
+ cutlass.Float8E4M3FN,
+ cutlass.Float8E5M2,
+ }:
is_valid = False
# tested b_dtype
- if b_dtype not in {Float16, cutlass.BFloat16, cutlass.Float8E4M3FN, cutlass.Float8E5M2}:
+ if b_dtype not in {
+ Float16,
+ cutlass.BFloat16,
+ cutlass.Float8E4M3FN,
+ cutlass.Float8E5M2,
+ }:
is_valid = False
if acc_dtype not in {Float32, Float16}:
is_valid = False
@@ -2026,6 +2174,8 @@ def is_valid_dtypes(
is_valid = False
# for Float8 types, this implementation only supports k-major layout
- if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"):
+ if (a_dtype.width == 8 and a_major != "k") or (
+ b_dtype.width == 8 and b_major != "k"
+ ):
is_valid = False
return is_valid
diff --git a/quack/gemm_sq_reduce.py b/quack/gemm_sq_reduce.py
index d3ee4ddb..96c8e32c 100644
--- a/quack/gemm_sq_reduce.py
+++ b/quack/gemm_sq_reduce.py
@@ -254,6 +254,6 @@ def gemm_sq_reduce(
varlen_args = make_varlen_args(None, None, None)
if device_capacity[0] in [10, 11]:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None)
else:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
diff --git a/quack/gemm_symmetric.py b/quack/gemm_symmetric.py
index 812bb5df..39822368 100644
--- a/quack/gemm_symmetric.py
+++ b/quack/gemm_symmetric.py
@@ -1,3 +1,4 @@
+
from typing import Tuple, Optional, Callable
from torch import Tensor
@@ -30,172 +31,36 @@
from quack.rounding import RoundingMode
+from quack.gemm_epilogue_plan import symmetric_epi_commit
+
class GemmSymmetricMixin(GemmActMixin):
def get_scheduler_class(self, varlen_m: bool = False):
return TriangularTileScheduler
@cute.jit
- def epilogue(
+ def epi_plan_commit(
self,
- params: GemmActMixin.EpilogueParams,
- epi_smem_tensors: Tuple[cute.Tensor, ...],
- epi_pipeline: cutlass.pipeline.PipelineAsync,
- epi_store_pipeline: cutlass.pipeline.PipelineAsync,
- epi_read_state: cutlass.pipeline.PipelineState,
- epi_producer_state: cutlass.pipeline.PipelineState,
- epi_tile: cute.Tile,
- load_acc_subtile: Callable,
- tRS_rD: cute.Tensor,
- tRS_rC: Optional[cute.Tensor],
- tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100
- tiled_copy_r2s: cute.TiledCopy,
- tRS_sD: cute.Tensor,
- tiled_copy_s2r: Optional[cute.TiledCopy],
- tSR_rC: Optional[cute.Tensor],
- tSR_sC: Optional[cute.Tensor],
- copy_D: Optional[Callable],
- copy_C: Optional[Callable],
- tile_coord_mnkl: cute.Coord,
- varlen_manager: VarlenManager,
- epilogue_barrier: cutlass.pipeline.NamedBarrier,
- tile_scheduler,
- tidx: Int32,
- is_tma_warp: Boolean,
- ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]:
- has_C = const_expr(tRS_rC is not None)
- has_D = const_expr(copy_D is not None)
-
- tiled_copy_postact_r2s, tRS_sPostAct, copy_postact = self.epi_setup_postact(
- params,
- epi_smem_tensors,
- tiled_copy_r2s,
- tiled_copy_t2r,
- tile_coord_mnkl,
- varlen_manager,
- tidx,
- )
-
- # We iterate over epi tiles in the N dimension first before the M dimension
- epi_tile_shape = cute.zipped_divide(
- cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile
- ).shape[1]
- epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1))
- epi_tile_num = cute.size(epi_tile_shape)
- num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num
-
- epi_tensors = self.epi_begin(
- params,
- epi_smem_tensors,
- epi_tile,
- tiled_copy_t2r,
- tiled_copy_r2s,
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
+ tile_coord_mnkl,
+ is_tma_warp,
+ epi_store_pipeline,
+ ):
+ symmetric_epi_commit(
+ self,
+ gmem_coord,
+ epi_buffer,
+ copy_D,
+ copy_postact,
+ postact_ctx,
tile_coord_mnkl,
- varlen_manager,
- epilogue_barrier,
- tidx,
+ is_tma_warp,
+ epi_store_pipeline,
)
- if const_expr(copy_C is not None):
- for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1):
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx)
- if is_tma_warp:
- epi_pipeline.producer_acquire(epi_producer_state)
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
- epi_pipeline.producer_commit(epi_producer_state)
- epi_producer_state.advance()
-
- for epi_idx in cutlass.range_constexpr(epi_tile_num):
- # The global memory coordinate for the current epi tile
- gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
- # Copy from acc to D registers
- load_acc_subtile(tRS_rD, epi_idx)
- epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord)
- if const_expr(has_C):
- epi_pipeline.consumer_wait(epi_read_state)
- cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC)
- # Fence to make sure shared memory read is visible to TMA load
- cute.arch.fence_view_async_shared()
- cute.arch.sync_warp()
- with cute.arch.elect_one():
- epi_pipeline.consumer_release(epi_read_state)
- epi_read_state.advance()
- if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num):
- gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage)
- if is_tma_warp:
- epi_pipeline.producer_acquire(epi_producer_state)
- copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state)
- epi_pipeline.producer_commit(epi_producer_state)
- epi_producer_state.advance()
- tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC)
- tRS_rPostAct_out = self.epi_convert_postact(
- tRS_rPostAct,
- epi_loop_tensors["sr_seed"],
- tidx,
- tile_coord_mnkl,
- num_prev_subtiles,
- epi_idx,
- )
- if is_tma_warp:
- epi_store_pipeline.producer_acquire()
- epilogue_barrier.arrive_and_wait()
- # Copy from D registers to shared memory
- epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage
- if const_expr(has_D):
- if const_expr(
- self.rounding_mode == RoundingMode.RS
- and self.acc_dtype == cutlass.Float32
- and self.d_dtype == cutlass.BFloat16
- ):
- seed = epi_loop_tensors["sr_seed"] + (
- tile_coord_mnkl[0] * 65537
- + tile_coord_mnkl[1] * 257
- + tile_coord_mnkl[3] * 17
- + (num_prev_subtiles + epi_idx) * 7
- )
- copy_utils.sr_cvt_copy(
- tiled_copy_r2s,
- tRS_rD,
- tRS_sD[None, None, None, epi_buffer],
- seed,
- tidx,
- )
- else:
- copy_utils.cvt_copy(
- tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]
- )
- cute.copy(
- tiled_copy_postact_r2s,
- tiled_copy_postact_r2s.retile(tRS_rPostAct_out),
- tRS_sPostAct[None, None, None, epi_buffer],
- )
- pid_m = tile_coord_mnkl[0]
- pid_n = tile_coord_mnkl[1]
- # Fence and barrier to make sure shared memory store is visible to TMA store
- cute.arch.fence_view_async_shared()
- epilogue_barrier.arrive_and_wait()
- # Copy from shared memory to global memory
- if is_tma_warp:
- square_tile_m = pid_m // self.cluster_shape_mnk[0]
- square_tile_n = pid_n // self.cluster_shape_mnk[1]
- if const_expr(has_D):
- copy_D(src_idx=epi_buffer, dst_idx=gmem_coord)
- if square_tile_m != square_tile_n: # don't write twice to the same tile
- copy_postact(src_idx=epi_buffer, dst_idx=gmem_coord)
- epi_store_pipeline.producer_commit()
-
- self.epi_end(
- params,
- epi_tensors,
- epi_tile,
- tiled_copy_t2r,
- tiled_copy_r2s,
- tile_coord_mnkl,
- varlen_manager,
- tidx,
- )
-
- return epi_read_state, epi_producer_state
-
class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90):
pass
@@ -389,6 +254,6 @@ def scalar_arg(scalar, mode):
varlen_args = None
if device_capacity[0] in [10, 11]:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None)
else:
- compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
+ compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)
diff --git a/quack/gemm_tvm_ffi_utils.py b/quack/gemm_tvm_ffi_utils.py
index 18e8ec56..c2e9ab2b 100644
--- a/quack/gemm_tvm_ffi_utils.py
+++ b/quack/gemm_tvm_ffi_utils.py
@@ -4,15 +4,13 @@
from functools import partial
-import cutlass.cute as cute
-from cutlass import Int32, Int64, Float32
+from cutlass import Float32, Int32, Int64
from cutlass.cute.runtime import make_ptr
-
from quack.compile_utils import make_fake_tensor as fake_tensor
from quack.cute_dsl_utils import torch2cute_dtype_map
from quack.tile_scheduler import TileSchedulerOptions
from quack.varlen_utils import VarlenArguments
-
+import cutlass.cute as cute
def div_for_dtype(dtype):
"""16-byte alignment: divisibility in elements = 128 // dtype_width_bits."""
@@ -66,7 +64,9 @@ def make_scheduler_args(
raster_order=None,
max_swizzle_size=max_swizzle_size,
tile_count_semaphore=(
- tile_count_semaphore.data_ptr() if tile_count_semaphore is not None else None
+ tile_count_semaphore.data_ptr()
+ if tile_count_semaphore is not None
+ else None
),
batch_idx_permute=batch_idx_permute,
)
@@ -77,7 +77,9 @@ def make_fake_scheduler_args(has_semaphore, has_batch_idx_permute, l_sym):
max_active_clusters=Int32(1),
max_swizzle_size=Int32(8),
tile_count_semaphore=(
- make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) if has_semaphore else None
+ make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4)
+ if has_semaphore
+ else None
),
batch_idx_permute=(
fake_tensor(Int32, (l_sym,), leading_dim=0, divisibility=4)
@@ -103,13 +105,19 @@ def make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len):
num_seqlens = cute.sym_int()
return VarlenArguments(
mCuSeqlensM=(
- fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_m else None
+ fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4)
+ if varlen_m
+ else None
),
mCuSeqlensK=(
- fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4) if varlen_k else None
+ fake_tensor(Int32, (num_seqlens,), leading_dim=0, divisibility=4)
+ if varlen_k
+ else None
),
mAIdx=(
- fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4) if gather_A else None
+ fake_tensor(Int32, (aidx_len,), leading_dim=0, divisibility=4)
+ if gather_A
+ else None
),
)
@@ -188,6 +196,7 @@ def compile_gemm_kernel(
has_trace_ptr=False,
use_tma_gather=False,
concat_layout=None,
+ problem_args=None,
):
"""Build GemmCls instance, apply SM90 partial, and cute.compile with TVM-FFI."""
if device_capacity[0] in [9, 12]:
@@ -225,5 +234,6 @@ def compile_gemm_kernel(
stream,
*sf_args,
trace_ptr,
+ problem_args,
options="--enable-tvm-ffi",
)
diff --git a/quack/gemm_work.py b/quack/gemm_work.py
new file mode 100644
index 00000000..39a6e0e4
--- /dev/null
+++ b/quack/gemm_work.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import NamedTuple, Optional
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Boolean, Int32
+from quack.cute_dsl_utils import mlir_namedtuple, StaticTypes
+
+
+@mlir_namedtuple
+class WorkDesc(NamedTuple):
+ """Descriptor of one unit of work returned by the tile schedulers.
+
+ Only ``tile_coord_mnkl`` and ``problem_idx`` are required; every other
+ field has a default. ``make_work_desc`` fills in just the coordinate,
+ ``problem_idx`` and ``is_valid_tile``.
+ """
+
+ # Tile coordinate; the schedulers in tile_scheduler.py pass
+ # (pid_m, pid_n, None, batch_idx).
+ tile_coord_mnkl: cute.Coord
+ # make_work_desc sets this to tile_coord_mnkl[3].
+ problem_idx: Int32
+ # NOTE(review): the K-range / split-K fields below are never set by
+ # make_work_desc, so they always carry these defaults on that path;
+ # their exact semantics are presumably defined by the consumer -- confirm.
+ k_tile_begin: Int32 = Int32(0)
+ k_tile_count: Optional[Int32] = None
+ split_k_idx: Int32 = Int32(0)
+ split_k_parts: Int32 = Int32(1)
+ is_final_split: Boolean = Boolean(True)
+ # Defaults to False, so a WorkDesc is invalid unless explicitly marked valid.
+ is_valid_tile: Boolean = Boolean(False)
+
+ def __extract_mlir_values__(self):
+ """Return the MLIR values of all fields as one flat list.
+
+ Fields are visited in declaration order. A field that is ``None`` or an
+ instance of ``StaticTypes`` contributes nothing; every other field is
+ expanded with ``cutlass.extract_mlir_values``.
+ """
+ values = []
+ for field_val in self:
+ if field_val is None or isinstance(field_val, StaticTypes):
+ continue
+ values.extend(cutlass.extract_mlir_values(field_val))
+ return values
+
+ @property
+ def tile_idx(self):
+ """Alias for ``tile_coord_mnkl``.
+
+ NOTE(review): presumably kept so code written against
+ ``cutlass.utils.WorkTileInfo`` keeps working -- confirm.
+ """
+ return self.tile_coord_mnkl
+
+ @property
+ def batch_idx(self):
+ """Return element 3 (the last entry) of ``tile_coord_mnkl``."""
+ return self.tile_coord_mnkl[3]
+
+
+def make_work_desc(tile_coord_mnkl: cute.Coord, is_valid_tile: Boolean) -> WorkDesc:
+ """Build a ``WorkDesc`` from a tile coordinate and a validity flag.
+
+ ``problem_idx`` is taken from element 3 of ``tile_coord_mnkl``; all
+ remaining ``WorkDesc`` fields keep their declared defaults.
+ """
+ return WorkDesc(
+ tile_coord_mnkl=tile_coord_mnkl,
+ # Element 3 is the same entry WorkDesc.batch_idx reads.
+ problem_idx=tile_coord_mnkl[3],
+ is_valid_tile=is_valid_tile,
+ )
diff --git a/quack/tile_scheduler.py b/quack/tile_scheduler.py
index 5cc6ac4f..e798363b 100644
--- a/quack/tile_scheduler.py
+++ b/quack/tile_scheduler.py
@@ -10,6 +10,7 @@
import quack.utils as utils
from quack.fast_math import FastDivmod
+from quack.gemm_work import WorkDesc, make_work_desc
from quack.pipeline import PipelineStateWAdvance
from quack.cute_dsl_utils import mlir_namedtuple
@@ -308,7 +309,7 @@ def _delinearize_work_idx(
block_zero_only: bool = False,
loc=None,
ip=None,
- ) -> cutlass.utils.WorkTileInfo:
+ ) -> WorkDesc:
params = self.params
if const_expr(is_valid is None):
if const_expr(params.persistence_mode == PersistenceMode.NONE):
@@ -336,10 +337,10 @@ def _delinearize_work_idx(
else params.batch_idx_permute[bidz_]
)
tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
- return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+ return make_work_desc(tile_coord_mnkl, is_valid)
@cute.jit
- def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+ def get_current_work(self, *, loc=None, ip=None) -> WorkDesc:
params = self.params
pid_m, pid_n, batch_idx, is_valid = Int32(0), Int32(0), Int32(0), Boolean(False)
if const_expr(params.persistence_mode == PersistenceMode.NONE):
@@ -361,10 +362,10 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
self._pipeline_state.advance()
is_valid = Boolean(is_valid_i32)
tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
- return cutlass.utils.WorkTileInfo(tile_coord_mnkl, Boolean(is_valid))
+ return make_work_desc(tile_coord_mnkl, Boolean(is_valid))
# @cute.jit
- def initial_work_tile_info(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
+ def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkDesc:
return self._delinearize_work_idx(self._current_work_idx, loc=loc, ip=ip)
# if is_scheduler_warp:
# work_tile_info = self._delinearize_work_idx(block_zero_only=True, loc=loc, ip=ip)
@@ -428,7 +429,7 @@ def _fetch_next_work_idx(self, *, loc=None, ip=None) -> Int32 | Tuple[Int32, Int
@cute.jit
def write_work_tile_to_smem(
- self, work_tile_info: cutlass.utils.WorkTileInfo, *, loc=None, ip=None
+ self, work_tile_info: WorkDesc, *, loc=None, ip=None
):
params = self.params
if const_expr(self._sched_smem is not None):
@@ -733,7 +734,7 @@ def _delinearize_work_idx(
block_zero_only: bool = False,
loc=None,
ip=None,
- ) -> cutlass.utils.WorkTileInfo:
+ ) -> WorkDesc:
params = self.params
if const_expr(is_valid is None):
if const_expr(params.persistence_mode == PersistenceMode.NONE):
@@ -764,7 +765,7 @@ def _delinearize_work_idx(
# if tidx == 0:
# cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}",
# bidx, bidy, group_id, id_in_group, group_size_actual, group_col, group_remainder, cid_n_in_group, cid_m_in_group, cid_m, cid_n, is_valid)
- return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+ return make_work_desc(tile_coord_mnkl, is_valid)
@dataclass
@@ -1025,7 +1026,7 @@ def _delinearize_work_idx(
block_zero_only: bool = False,
loc=None,
ip=None,
- ) -> cutlass.utils.WorkTileInfo:
+ ) -> WorkDesc:
assert bidz is None
params = self.params
lane_idx = cute.arch.lane_idx()
@@ -1091,4 +1092,4 @@ def _delinearize_work_idx(
tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
self._current_batch_idx = batch_idx
self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
- return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
+ return make_work_desc(tile_coord_mnkl, is_valid)