Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 152 additions & 3 deletions benchmarks/benchmark_gemm_autotuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
gemm,
gemm_act,
gemm_act_tuned,
gemm_dact,
gemm_dact_tuned,
gemm_dgated,
gemm_dgated_tuned,
gemm_tuned,
Expand Down Expand Up @@ -130,6 +132,43 @@ def _torch_dgated_act(dact_fn, x, w, preact):
return dx, postact


def _drelu(x):
return (x > 0).to(x.dtype), F.relu(x)


def _drelu_sq(x):
relu_x = F.relu(x)
return 2.0 * relu_x, relu_x.square()


def _dgelu_tanh_approx(x):
c1 = math.sqrt(2.0 / math.pi)
c2 = 0.044715
z = c1 * (x + c2 * x * x * x)
tanh_z = torch.tanh(z)
sech2_z = 1.0 - tanh_z * tanh_z
dz_dx = c1 * (1.0 + 3.0 * c2 * x * x)
postact = 0.5 * x * (1.0 + tanh_z)
dact = 0.5 * (1.0 + tanh_z) + 0.5 * x * sech2_z * dz_dx
return dact, postact


def _torch_dact_act(dact_fn, x, w, preact):
"""Reference: GEMM + activation backward over preact."""
dout = F.linear(x, w)
dact, postact = dact_fn(preact)
return dout * dact, postact


_dact_act_fns = {
"silu": _dsilu_exp,
"silu-tanh": _dsilu_tanh,
"relu": _drelu,
"relu_sq": _drelu_sq,
"gelu_tanh_approx": _dgelu_tanh_approx,
}


_dgated_act_fns = {
"swiglu": _dsilu_exp,
"swiglu-tanh": _dsilu_tanh,
Expand Down Expand Up @@ -191,6 +230,60 @@ def benchmark_gemm_act(
return ms, tf


def benchmark_gemm_dact(
m,
n,
k,
activation="gelu_tanh_approx",
dtype=torch.bfloat16,
repeats=30,
tuned=True,
config=None,
):
"""Benchmark fused GEMM + activation backward."""
a = torch.randn(m, k, device="cuda", dtype=dtype)
b = torch.randn(k, n, device="cuda", dtype=dtype) / math.sqrt(k)
preact = torch.randn(m, n, device="cuda", dtype=dtype)
nflops = 2 * m * n * k

if config is None:
fn = lambda: gemm_dact(a, b, preact, activation=activation, out_dtype=dtype, tuned=tuned)
else:
dx_out = torch.empty(m, n, device="cuda", dtype=dtype)
postact_out = torch.empty(m, n, device="cuda", dtype=preact.dtype)
fn = lambda: gemm_dact_tuned.fn(
a,
b,
preact,
dx_out,
postact_out,
None,
activation,
False,
None,
None,
True,
config=config,
)
fn() # warmup / autotune
time.sleep(0.5)
ms = do_bench(fn, warmup=5, rep=repeats)
tf = tflops(nflops, ms)

w = b.T.contiguous()
ref_fn = torch.compile(lambda: _torch_dact_act(_dact_act_fns[activation], a, w, preact))
ref_fn()
ref_fn()
time.sleep(0.5)
ms_pt = do_bench(ref_fn, warmup=5, rep=repeats)
tf_pt = tflops(nflops, ms_pt)

print(f" quack: {ms:.3f}ms {tf:.1f} TFLOPS")
print(f" cuBLAS + torch.compile: {ms_pt:.3f}ms {tf_pt:.1f} TFLOPS")
print(f" speedup: {ms_pt / ms:.2f}x")
return ms, tf


def benchmark_gemm_dgated(
m,
n,
Expand Down Expand Up @@ -287,6 +380,17 @@ def main():
action="store_true",
help="Only run the transformer FFN gated backward GEMM benchmark",
)
parser.add_argument(
"--only-dact",
action="store_true",
help="Only run the GEMM activation backward benchmark",
)
parser.add_argument(
"--dact-activation",
choices=sorted(_dact_act_fns),
default=None,
help="Restrict the activation backward benchmark to one activation",
)
parser.add_argument(
"--dgated-activation",
choices=sorted(_dgated_act_fns),
Expand Down Expand Up @@ -332,11 +436,19 @@ def main():
"swiglu-tanh",
]
)
dact_activations = (
[args.dact_activation]
if args.dact_activation
else [
"gelu_tanh_approx",
"relu",
]
)
forced_config = forced_config_from_args(args)
ffn = int(args.dim * 3.5) # Llama-3 ratio

if args.only_gated and args.only_dgated:
raise ValueError("--only-gated and --only-dgated are mutually exclusive")
if sum([args.only_gated, args.only_dgated, args.only_dact]) > 1:
raise ValueError("--only-gated, --only-dgated, and --only-dact are mutually exclusive")

if args.only_gated:
print(
Expand All @@ -361,6 +473,27 @@ def main():
)
return

if args.only_dact:
print(
f"GEMM activation backward benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})"
)
print(f" M={M}, N={N}, K={K}, dtype={args.dtype}")
if forced_config is not None:
print(f" forced config: {forced_config}")
for activation in dact_activations:
print(f"\n d{activation}: ({M}, {K}) x ({K}, {N})")
benchmark_gemm_dact(
M,
N,
K,
activation,
dtype,
repeats=args.repeats,
tuned=not args.untuned and forced_config is None,
config=forced_config,
)
return

if args.only_dgated:
print(
f"GEMM gated backward benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})"
Expand Down Expand Up @@ -413,7 +546,23 @@ def main():
)
print()

# --- 3. Transformer-relevant shapes ---
# --- 3. GEMM + activation backward ---
print("=" * 60)
print(f"GEMM + dGeLU: ({M}, {K}) x ({K}, {N})")
print("=" * 60)
benchmark_gemm_dact(
M,
N,
K,
"gelu_tanh_approx",
dtype,
repeats=args.repeats,
tuned=not args.untuned and forced_config is None,
config=forced_config,
)
print()

# --- 4. Transformer-relevant shapes ---
batch = args.batch
dim = args.dim
head_dim = 128
Expand Down
110 changes: 83 additions & 27 deletions quack/gemm_dact.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,22 @@
class GemmDActMixin(GemmActMixin):
# Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
# and return 2 arguments (dx, out)
EpilogueArguments = GemmActMixin.EpilogueArguments
_epi_ops = (
Scalar("sr_seed", dtype=Int32),
ColVecLoad("mColVecBroadcast"),
TileStore("mAuxOut"),
ColVecReduce("mColVecReduce"),
)
_extra_param_fields = (("act_fn", cutlass.Constexpr, None),)

@mlir_namedtuple
class EpilogueArguments(NamedTuple):
mAuxOut: cute.Tensor
act_fn: cutlass.Constexpr[Optional[Callable]] = None
mColVecBroadcast: Optional[cute.Tensor] = None
mColVecReduce: Optional[cute.Tensor] = None
rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN
sr_seed: Optional[Int32 | cute.Tensor] = None

@cute.jit
def epi_visit_subtile(
Expand All @@ -53,6 +68,8 @@ def epi_visit_subtile(
tRS_rC: Optional[cute.Tensor] = None,
) -> Optional[cute.Tensor]:
assert tRS_rC is not None
tDrColVec = epi_loop_tensors.get("mColVecBroadcast")
tDrColVecReduce = epi_loop_tensors.get("mColVecReduce")
# We don't add C to the accumulator
GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
tRS_rC_acc = cute.make_rmem_tensor_like(tRS_rC, self.acc_dtype)
Expand All @@ -74,6 +91,40 @@ def epi_visit_subtile(
)
else:
tRS_rAuxOut = tRS_rC_acc
if const_expr(tDrColVecReduce is not None):
# Accumulate unscaled postact * GEMM output for router-score gradients.
colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rAuxOut, rScale=tRS_rD)
if const_expr(tDrColVec is not None):
if const_expr(self.arch != 100):
tRS_rD.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type))
tRS_rAuxOut.store(
tRS_rAuxOut.load() * tDrColVec.load().to(tRS_rAuxOut.element_type)
)
else:
tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout)
tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout)
tRS_rAuxOut_mn = layout_utils.convert_layout_zero_stride(
tRS_rAuxOut, tDrColVec.layout
)
for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True):
for n in cutlass.range(
cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True
):
scale = (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0])
tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1] = (
cute.arch.mul_packed_f32x2(
(tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), scale
)
)
tRS_rAuxOut_mn[m, 2 * n], tRS_rAuxOut_mn[m, 2 * n + 1] = (
cute.arch.mul_packed_f32x2(
(
tRS_rAuxOut_mn[m, 2 * n],
tRS_rAuxOut_mn[m, 2 * n + 1],
),
scale,
)
)
return tRS_rAuxOut


Expand Down Expand Up @@ -301,30 +352,30 @@ def _compile_gemm_dact(
pa_shape = (m, n) if varlen_m else (m, n, l)
mAuxOut = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading, divisibility=div_pa)

mColVec = None
if colvec_scale_ndim == 2:
mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4)
elif colvec_scale_ndim == 1:
mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4)
mColVecReduce = None
n_tiles = cute.sym_int()
if colvec_reduce_ndim == 3:
mColVecReduce = fake_tensor(
colvec_reduce_dtype,
(l, m, n_tiles),
leading_dim=2,
divisibility=1,
)
elif colvec_reduce_ndim == 2:
mColVecReduce = fake_tensor(
colvec_reduce_dtype,
(m, n_tiles),
leading_dim=1,
divisibility=1,
)

if is_dgated:
act_fn = dgate_fn_map[activation]

mColVec = None
if colvec_scale_ndim == 2:
mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4)
elif colvec_scale_ndim == 1:
mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4)
mColVecReduce = None
n_tiles = cute.sym_int()
if colvec_reduce_ndim == 3:
mColVecReduce = fake_tensor(
colvec_reduce_dtype,
(l, m, n_tiles),
leading_dim=2,
divisibility=1,
)
elif colvec_reduce_ndim == 2:
mColVecReduce = fake_tensor(
colvec_reduce_dtype,
(m, n_tiles),
leading_dim=1,
divisibility=1,
)
epi_args = GemmCls.EpilogueArguments(
mAuxOut,
act_fn,
Expand All @@ -338,7 +389,12 @@ def _set_implicit_dtype(gemm_obj):
post_init = _set_implicit_dtype
else:
act_fn = dact_fn_map[activation]
epi_args = GemmCls.EpilogueArguments(mAuxOut, act_fn)
epi_args = GemmCls.EpilogueArguments(
mAuxOut,
act_fn,
mColVecBroadcast=mColVec,
mColVecReduce=mColVecReduce,
)
post_init = None

scheduler_args = make_fake_scheduler_args(
Expand Down Expand Up @@ -384,7 +440,7 @@ def gemm_dact(
persistent: bool = True,
is_dynamic_persistent: bool = False,
max_swizzle_size: int = 8,
colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m (dgated only)
colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
# (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m (dgated only)
colvec_reduce: Optional[Tensor] = None,
cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
Expand All @@ -394,8 +450,6 @@ def gemm_dact(
is_dgated = activation in dgate_fn_map
if not is_dgated:
assert activation in dact_fn_map, f"Unsupported activation {activation}"
assert colvec_scale is None, "colvec_scale is only supported for gated activations"
assert colvec_reduce is None, "colvec_reduce is only supported for gated activations"
gemm_cls_name = "dgated" if is_dgated else "dact"

varlen_m = cu_seqlens_m is not None
Expand Down Expand Up @@ -500,6 +554,8 @@ def gemm_dact(
epi_args = GemmDActMixin.EpilogueArguments(
PostAct_p,
None,
mColVecBroadcast=colvec_scale,
mColVecReduce=colvec_reduce,
rounding_mode=None,
sr_seed=None,
)
Expand Down
Loading
Loading