Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 145 additions & 2 deletions benchmarks/benchmark_gemm_autotuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@

from quack.autotuner import default_cache_dir
from quack.cache_utils import get_cache_path
from quack.gemm_blockscaled_interface import (
mxfp8_gemm_act,
quantize_act_sm90,
quantize_weight_sm90,
)
from quack.gemm_config import GemmConfig
from quack.gemm_interface import (
act_to_pytorch_fn_map,
Expand Down Expand Up @@ -245,6 +250,90 @@ def benchmark_gemm_dgated(
return ms, tf


def benchmark_mxfp8_gemm_act(
m,
n,
k,
activation="swiglu",
dtype=torch.bfloat16,
repeats=30,
trace_path=None,
):
"""Benchmark fused MXFP8 GEMM + gated activation (SM90 blockscaled path).

Quantizes A (bf16 -> fp8_e4m3fn + 1x128 scales) and W (bf16 -> fp8_e4m3fn +
128x128 scales) once outside the timed loop, then measures the fused
GEMM+gated-activation kernel.

Baseline matches benchmark_gemm_act: torch.compile(F.linear + gated_activation)
on bf16. The reported speedup conflates fusion gains with the lower-precision
MMA throughput, so it overstates pure fusion benefit relative to a hypothetical
bf16 fused kernel.
"""
is_gated = activation in gated_to_pytorch_fn_map
if not is_gated:
raise ValueError(f"benchmark_mxfp8_gemm_act expects a gated activation; got {activation!r}")

a_bf16 = torch.randn(m, k, device="cuda", dtype=dtype)
# W: (2*N, K) for gated; quantize then build a (K, 2*N) K-contig view for B.
b_n = 2 * n
w_bf16 = torch.randn(b_n, k, device="cuda", dtype=dtype) / math.sqrt(k)

a_q, a_sc = quantize_act_sm90(a_bf16)
w_q, w_sc = quantize_weight_sm90(w_bf16)
b_q, b_sc = w_q.mT, w_sc.mT

nflops = 2 * m * b_n * k

fn = lambda: mxfp8_gemm_act(
a_q,
b_q,
a_sc,
b_sc,
activation=None,
out_dtype=dtype,
postact_dtype=dtype,
store_preact=False,
tuned=False,
)
fn() # warmup

if trace_path is not None:
for _ in range(3):
fn()
torch.cuda.synchronize()
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
) as prof:
for _ in range(5):
fn()
torch.cuda.synchronize()
prof.export_chrome_trace(trace_path)
print(f" saved kineto trace: {trace_path}")

time.sleep(0.5)
ms = do_bench(fn, warmup=5, rep=repeats)
tf = tflops(nflops, ms)

# Baseline: torch.compile(GEMM + gated activation) on bf16.
ref_fn = torch.compile(
lambda: _torch_gated_act(gated_to_pytorch_fn_map[activation], a_bf16, w_bf16)
)
ref_fn() # compile warmup
ref_fn()
time.sleep(0.5)
ms_pt = do_bench(ref_fn, warmup=5, rep=repeats)
tf_pt = tflops(nflops, ms_pt)

print(f" quack mxfp8: {ms:.3f}ms {tf:.1f} TFLOPS")
print(f" cuBLAS bf16 + torch.compile: {ms_pt:.3f}ms {tf_pt:.1f} TFLOPS")
print(f" speedup vs bf16 baseline: {ms_pt / ms:.2f}x")
return ms, tf


def forced_config_from_args(args):
if args.config_tile_n is None:
return None
Expand Down Expand Up @@ -293,6 +382,17 @@ def main():
default=None,
help="Restrict the FFN gated backward benchmark to one activation",
)
parser.add_argument(
"--only-mxfp8-gated",
action="store_true",
help="Only run the SM90 MXFP8 FFN gated GEMM benchmark",
)
parser.add_argument(
"--mxfp8-gated-activation",
choices=sorted(gated_to_pytorch_fn_map),
default=None,
help="Restrict the MXFP8 FFN gated benchmark to one activation",
)
parser.add_argument(
"--untuned",
action="store_true",
Expand All @@ -307,6 +407,12 @@ def main():
parser.add_argument("--config-pingpong", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--config-swap-ab", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--cold", action="store_true", help="Clear .so and autotuning caches first")
parser.add_argument(
"--trace",
type=str,
default=None,
help="Export a kineto Chrome trace of the quack kernel to this path (mxfp8-gated only)",
)
args = parser.parse_args()

if args.cold:
Expand All @@ -332,11 +438,22 @@ def main():
"swiglu-tanh",
]
)
mxfp8_gated_activations = (
[args.mxfp8_gated_activation]
if args.mxfp8_gated_activation
else [
"swiglu",
"geglu",
]
)
forced_config = forced_config_from_args(args)
ffn = int(args.dim * 3.5) # Llama-3 ratio

if args.only_gated and args.only_dgated:
raise ValueError("--only-gated and --only-dgated are mutually exclusive")
only_flags = [args.only_gated, args.only_dgated, args.only_mxfp8_gated]
if sum(only_flags) > 1:
raise ValueError(
"--only-gated, --only-dgated, and --only-mxfp8-gated are mutually exclusive"
)

if args.only_gated:
print(
Expand Down Expand Up @@ -384,6 +501,32 @@ def main():
)
return

if args.only_mxfp8_gated:
if torch.cuda.get_device_properties(0).major != 9:
raise RuntimeError("--only-mxfp8-gated requires SM90")
print(
f"MXFP8 GEMM gated activation benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})"
)
print(f" batch={args.batch}, dim={args.dim}, ffn={ffn}, dtype={args.dtype}")
for activation in mxfp8_gated_activations:
print(
f"\n FFN up + {activation} (mxfp8): ({args.batch}, {args.dim}) x ({args.dim}, {2 * ffn})"
)
trace_path = args.trace
if trace_path is not None and len(mxfp8_gated_activations) > 1:
root, ext = os.path.splitext(trace_path)
trace_path = f"{root}.{activation}{ext or '.json'}"
benchmark_mxfp8_gemm_act(
args.batch,
ffn,
args.dim,
activation,
dtype,
repeats=args.repeats,
trace_path=trace_path,
)
return

print(f"GEMM autotuning demo (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})")
print(f" M={M}, N={N}, K={K}, dtype={args.dtype}")
if forced_config is not None:
Expand Down
Loading
Loading