Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions quack/gemm_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ class GemmSm100(GemmSm90):

:param acc_dtype: Data type for accumulation during computation
:type acc_dtype: type[cutlass.Numeric]
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
:type mma_tiler_mn: Tuple[int, int]
:param mma_tiler_mn: Shape of the MMA tile. Pass (M, N) to let K default to
4 * the MMA instruction K, or (M, N, K) to set the K tile size explicitly.
:type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: Tuple[int, int]

Expand Down Expand Up @@ -154,7 +155,7 @@ def __init__(
self,
acc_dtype: Type[cutlass.Numeric],
a_dtype: Type[cutlass.Numeric], # ignored for now
mma_tiler_mn: Tuple[int, int],
mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
cluster_shape_mnk: Tuple[int, int, int],
sf_vec_size: Optional[int] = None,
gather_A: bool = False,
Expand All @@ -176,8 +177,9 @@ def __init__(

:param acc_dtype: Data type of the accumulator.
:type acc_dtype: type[cutlass.Numeric]
:param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
:type mma_tiler_mn: Tuple[int, int]
:param mma_tiler_mn: (M, N) or (M, N, K) shape of the MMA tile.
If only (M, N) is given, K defaults to 4 * instruction K.
:type mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]]
:param cluster_shape_mnk: Tuple (ClusterM, ClusterN, ClusterK) shape of the cluster.
ClusterK must be 1.
:type cluster_shape_mnk: Tuple[int, int, int]
"""
Expand All @@ -186,8 +188,11 @@ def __init__(
self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,)
self.cluster_shape_mnk = cluster_shape_mnk
assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1"
# K dimension is deferred in _setup_attributes
self.mma_tiler = (*mma_tiler_mn, 1)
# K dimension: if user provides 3 values, use their K; otherwise default in _setup_attributes
if len(mma_tiler_mn) == 3:
self.mma_tiler = tuple(mma_tiler_mn)
else:
self.mma_tiler = (*mma_tiler_mn, 0)
self.sf_vec_size = sf_vec_size
self.blockscaled = sf_vec_size is not None
self.is_persistent = True
Expand Down Expand Up @@ -302,7 +307,10 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: Varle
)

# Compute mma/cluster/tile shapes
mma_inst_tile_k = 4
if self.mma_tiler[2] > 0:
mma_inst_tile_k = self.mma_tiler[2] // self.mma_inst_shape_mnk[2]
else:
mma_inst_tile_k = 4
self.mma_tiler = (
self.mma_tiler[0],
self.mma_tiler[1],
Expand Down Expand Up @@ -2427,7 +2435,7 @@ def is_valid_dtypes_and_scale_factor_vec_size(

@staticmethod
def is_valid_mma_tiler_and_cluster_shape(
mma_tiler_mn: Tuple[int, int],
mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
cluster_shape_mn: Tuple[int, int],
blockscaled: bool,
) -> bool:
Expand Down Expand Up @@ -2536,7 +2544,7 @@ def can_implement_blockscaled(
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
d_dtype: Type[cutlass.Numeric],
mma_tiler_mn: Tuple[int, int],
mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
cluster_shape_mn: Tuple[int, int],
m: int,
n: int,
Expand Down Expand Up @@ -2572,7 +2580,7 @@ def can_implement(
ab_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
d_dtype: Type[cutlass.Numeric],
mma_tiler_mn: Tuple[int, int],
mma_tiler_mn: Union[Tuple[int, int], Tuple[int, int, int]],
cluster_shape_mn: Tuple[int, int],
m: int,
n: int,
Expand Down
Loading
Loading