Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ from quack import rmsnorm, softmax, cross_entropy
[blog post](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
to speed-of-light, right in the comfort of Python, thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).


## Performance

<div align="center">
Expand Down
3 changes: 3 additions & 0 deletions quack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from quack.softmax import softmax
from quack.cross_entropy import cross_entropy
from quack.rounding import RoundingMode
from quack.gemm_interface import gemm, gemm_grouped


if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
Expand All @@ -19,5 +20,7 @@
"rmsnorm",
"softmax",
"cross_entropy",
"gemm",
"gemm_grouped",
"RoundingMode",
]
130 changes: 119 additions & 11 deletions quack/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@
)


from cutlass import Float32, Int32
from quack.cute_dsl_utils import (
get_device_capacity,
get_max_active_clusters,
torch2cute_dtype_map,
)
from quack.gemm_default_epi import (
GemmDefaultEpiMixin,
GemmDefaultSm100,
GemmDefaultSm120,
GemmDefaultSm90,
)
from quack.gemm_problem_adapter import (
GroupedProblemAdapterMixin,
GroupedProblemArguments,
)
from quack.gemm_tvm_ffi_utils import (
compile_gemm_kernel,
get_dtypes,
get_majors,
make_fake_gemm_tensors,
make_fake_scheduler_args,
make_fake_varlen_args,
make_scheduler_args,
make_varlen_args,
perm3d,
)

@jit_cache
def _compile_gemm(
a_dtype,
Expand Down Expand Up @@ -63,6 +91,7 @@ def _compile_gemm(
rounding_mode,
sr_seed_mode,
has_trace_ptr,
grouped,
):
sm_to_cls = {
9: GemmDefaultSm90,
Expand All @@ -71,6 +100,10 @@ def _compile_gemm(
12: GemmDefaultSm120,
}
GemmCls = sm_to_cls[device_capacity[0]]
if grouped:
GemmCls = type(
f"Grouped{GemmCls.__name__}", (GroupedProblemAdapterMixin, GemmCls), {}
)
mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors(
a_dtype,
b_dtype,
Expand Down Expand Up @@ -115,6 +148,14 @@ def fake_scalar(mode, dtype=Float32):
)
aidx_len = m if varlen_m else (k if varlen_k else None)
varlen_args = make_fake_varlen_args(varlen_m, varlen_k, gather_A, aidx_len)
problem_args = (
GroupedProblemArguments(
mProblemIndex=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4),
mProblemK=fake_tensor(Int32, (l,), leading_dim=0, divisibility=4),
)
if grouped
else None
)
return compile_gemm_kernel(
GemmCls,
a_dtype,
Expand All @@ -135,6 +176,7 @@ def fake_scalar(mode, dtype=Float32):
has_trace_ptr=has_trace_ptr,
use_tma_gather=use_tma_gather,
concat_layout=concat_layout or None,
problem_args=problem_args,
)


Expand All @@ -157,21 +199,48 @@ def gemm(
colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
alpha: float | Tensor = 1.0,
beta: float | Tensor = 1.0,
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length
A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
cu_seqlens_m: Optional[
Tensor
] = None, # (l+1,) cumulative sum of m values for variable length
cu_seqlens_k: Optional[
Tensor
] = None, # (l+1,) cumulative sum of k values for variable length
A_idx: Optional[
Tensor
] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
batch_idx_permute: Optional[
Tensor
] = None, # (l,) permutation of batch indices for scheduler
add_to_output: bool = False,
rounding_mode: int = RoundingMode.RN,
sr_seed: int | Tensor = 0,
use_tma_gather: bool = False,
concat_layout: dict | None = None,
grouped: bool = False,
grouped_problem_index: Optional[Tensor] = None,
grouped_problem_k: Optional[Tensor] = None,
trace_ptr=None, # Optional Int64 from TraceSession.ptr
) -> None:
varlen_m = cu_seqlens_m is not None
varlen_k = cu_seqlens_k is not None
varlen = varlen_m or varlen_k
gather_A = A_idx is not None
if grouped:
assert A.ndim == 3 and B.ndim == 3 and D.ndim == 3, (
"grouped GEMM expects dense packed batched tensors"
)
assert not varlen and not gather_A, (
"grouped GEMM does not combine with varlen/gather"
)
batch_size = A.shape[0]
if grouped_problem_index is not None:
assert grouped_problem_index.numel() == batch_size, (
"grouped_problem_index must have one entry per problem"
)
if grouped_problem_k is not None:
assert grouped_problem_k.numel() == batch_size, (
"grouped_problem_k must have one entry per problem"
)
assert not (varlen_m and varlen_k), "Only one of cu_seqlens_m and cu_seqlens_k"
if gather_A:
assert varlen, "gather_A requires varlen"
Expand All @@ -188,11 +257,17 @@ def gemm(
assert B.stride(-2) == 1, "varlen_k requires B to be n-major"

device_capacity = get_device_capacity(A.device)
assert device_capacity[0] in [9, 10, 11, 12], "Only SM90, SM100, SM110, and SM120 are supported"
assert device_capacity[0] in [9, 10, 11, 12], (
"Only SM90, SM100, SM110, and SM120 are supported"
)
if use_tma_gather:
assert device_capacity[0] in [10, 11], "TMA gather currently requires SM100/SM110"
assert device_capacity[0] in [10, 11], (
"TMA gather currently requires SM100/SM110"
)
if rounding_mode == RoundingMode.RS:
assert device_capacity[0] == 10, "Stochastic rounding (RoundingMode.RS) requires SM100"
assert device_capacity[0] == 10, (
"Stochastic rounding (RoundingMode.RS) requires SM100"
)
if is_dynamic_persistent and device_capacity[0] == 9:
assert tile_count_semaphore is not None, (
"Dynamic persistent tile scheduler in SM90 requires a semaphore in GMEM"
Expand All @@ -208,7 +283,9 @@ def gemm(
concat_layout = tuple(sorted(concat_layout)) if concat_layout else ()

sr_seed_mode = (
2 if isinstance(sr_seed, Tensor) else (1 if rounding_mode == RoundingMode.RS else 0)
2
if isinstance(sr_seed, Tensor)
else (1 if rounding_mode == RoundingMode.RS else 0)
)
compiled_fn = _compile_gemm(
a_dtype,
Expand Down Expand Up @@ -240,6 +317,7 @@ def gemm(
rounding_mode,
sr_seed_mode,
trace_ptr is not None,
grouped,
)

from quack.cache_utils import COMPILE_ONLY
Expand All @@ -255,7 +333,9 @@ def scalar_arg(scalar, mode, dtype=Float32):
else:
return scalar.data_ptr()

max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
max_active_clusters = (
get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
)

epi_args = GemmDefaultEpiMixin.EpilogueArguments(
alpha=scalar_arg(alpha, alpha_mode),
Expand All @@ -273,10 +353,38 @@ def scalar_arg(scalar, mode, dtype=Float32):
batch_idx_permute,
)
varlen_args = make_varlen_args(cu_seqlens_m, cu_seqlens_k, A_idx)
problem_args = (
GroupedProblemArguments(
mProblemIndex=grouped_problem_index,
mProblemK=grouped_problem_k,
)
if grouped
else None
)

if device_capacity[0] in [10, 11]:
compiled_fn(
A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, trace_ptr
A_p,
B_p,
D_p,
C_p,
epi_args,
scheduler_args,
varlen_args,
None,
None,
trace_ptr,
problem_args,
)
else:
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, trace_ptr)
compiled_fn(
A_p,
B_p,
D_p,
C_p,
epi_args,
scheduler_args,
varlen_args,
trace_ptr,
problem_args,
)
4 changes: 2 additions & 2 deletions quack/gemm_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ def scalar_arg(scalar, mode, dtype=Int32):
varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx)

if device_capacity[0] in [10, 11]:
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None)
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None, None, None)
else:
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None)
compiled_fn(A_p, B_p, D_p, C_p, epi_args, scheduler_args, varlen_args, None, None)


gemm_gated = gemm_act
4 changes: 2 additions & 2 deletions quack/gemm_dact.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,10 @@ def gemm_dact(

if device_capacity[0] in [10, 11]:
compiled_fn(
A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None
A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None, None, None
)
else:
compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None)
compiled_fn(A_p, B_p, Out_p, PreAct_p, epi_args, scheduler_args, varlen_args, None, None)


gemm_dgated = gemm_dact
Loading
Loading