diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index bbcba658..0c6951f0 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -74,10 +74,10 @@ EXIT_CODE=0 # shellcheck disable=SC2086 "$SCRIPT_DIR/container_exec.sh" $GPU_ARG " set -e - + echo \"Installing iris using method: $INSTALL_METHOD\" $INSTALL_CMD - + # Run tests in the specified directory for test_file in tests/$TEST_DIR/test_*.py; do if [ -f \"\$test_file\" ]; then @@ -88,4 +88,4 @@ EXIT_CODE=0 " || { EXIT_CODE=$?; } # GPU cleanup is now handled by workflow-level release_gpus.sh step -exit $EXIT_CODE \ No newline at end of file +exit $EXIT_CODE diff --git a/.gitignore b/.gitignore index d8f9754f..7c603be9 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ omni*.pdf slurm*.out *.egg-info +*.backup +*.with_chunked examples/gemm/results/* asm/ diff --git a/apptainer/iris.def b/apptainer/iris.def index 7a1f3984..37066944 100644 --- a/apptainer/iris.def +++ b/apptainer/iris.def @@ -30,7 +30,7 @@ From: rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 cd /opt git clone https://github.com/triton-lang/triton.git \$TRITON_PATH cd \$TRITON_PATH - git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e + git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e pip3 install -e . # Make the venv writable by all diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py new file mode 100644 index 00000000..b9d40118 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops all_gather_matmul fused operation. + +This benchmark showcases the fused All-Gather + GEMM operation where each rank +has a sharded A matrix that gets gathered, then multiplied with B. 
+""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops.all_gather_matmul import all_gather_matmul_preamble +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all_gather_matmul fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension total (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="all_gather_matmul.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + 
parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--variant", + type=str, + default="pull", + choices=["pull", "chunked", "push", "pipelined_pull"], + help="All-gather matmul variant", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" + ) + parser.add_argument( + "--b_col_major", + action="store_true", + help="Store B matrix in column-major order (K-contiguous) to reduce LDS transpose overhead", + ) + parser.add_argument( + "--a_col_major", + action="store_true", + help="Store A matrix in column-major order (M-contiguous). Default is row-major (K-contiguous).", + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size # Sharded K dimension + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": 
args["block_size_k"], + "group_size_m": args["group_size_m"], + "all_gather_matmul_variant": args["variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "all_gather_matmul") + json_writer.add_field("k_local", K_local) + json_writer.add_field("k_total", K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_sharded is M x K_local, B is K x N, output is M x N + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Create A_sharded matrix with optional column-major layout + # When a_col_major=True, M becomes the contiguous dimension + # Default (row-major): K is contiguous (stride_ak=1, stride_am=K_local) + if args["a_col_major"]: + # Allocate storage as (K_local, M) row-major, then transpose to get (M, K_local) with M-contiguous + # This means stride_am=1 and stride_ak=M + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T # View as (M, K_local) with M-contiguous strides + shmem.info(f"Using column-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (M-contiguous)") + else: + # Standard row-major (M, K_local) - K is contiguous + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + shmem.info(f"Using row-major A: shape={A_sharded.shape}, 
strides={A_sharded.stride()} (K-contiguous)") + + json_writer.add_field("a_col_major", args["a_col_major"]) + json_writer.add_field("a_stride_m", A_sharded.stride()[0]) + json_writer.add_field("a_stride_k", A_sharded.stride()[1]) + + # Create B matrix with optional column-major layout for K-contiguous access + # When b_col_major=True, we store B such that K is the contiguous dimension + # This reduces LDS transpose overhead when loading B tiles along the K dimension + if args["b_col_major"]: + # Allocate storage as (N, K) row-major, then transpose to get (K, N) with K-contiguous + # This means stride_bk=1 and stride_bn=K + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T # View as (K, N) with K-contiguous strides + shmem.info(f"Using column-major B: shape={B.shape}, strides={B.stride()} (K-contiguous)") + else: + # Standard row-major (K, N) - N is contiguous + B = shmem.zeros((K, N), dtype=datatype) + shmem.info(f"Using row-major B: shape={B.shape}, strides={B.stride()} (N-contiguous)") + + json_writer.add_field("b_col_major", args["b_col_major"]) + json_writer.add_field("b_stride_k", B.stride()[0]) + json_writer.add_field("b_stride_n", B.stride()[1]) + + # Fill inputs with deterministic values + # Each rank has different A_sharded, same B + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_sharded_data) + + torch.manual_seed(456) # Same B for all ranks + # Generate B data in standard (K, N) layout for consistency + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + # Copy to B (handles both row-major and column-major storage) + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_sharded matrices and compute expected result + A_sharded_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_sharded_list, A_sharded_data) + + # 
Concatenate along K dimension: A_gathered = [A_0 | A_1 | ... | A_n] + A_gathered = torch.cat(A_sharded_list, dim=1) # (M, K) + + # Expected: A_gathered @ B + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_result = torch.matmul(A_gathered, B_data) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Pre-allocate workspace once (important for push variant which needs large buffers) + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) + + def run_experiment(): + nonlocal kernel_timing + + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather-Matmul") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather_matmul"]["start_event"].record() + shmem.ops.all_gather_matmul( + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["all_gather_matmul"]["end_event"].record() + kernel_timing["all_gather_matmul"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather_matmul"]["start_event"].elapsed_time( + kernel_timing["all_gather_matmul"]["end_event"] + ) + kernel_timing["all_gather_matmul"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather-matmul validation 
passed!") + else: + shmem.error("All-gather-matmul validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M * K_local * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather-matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_matmul_ms", + 
kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"], + ) + json_writer.add_field("all_gather_matmul_experiments", kernel_timing["all_gather_matmul"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_sharded = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", 
pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py new file mode 100644 index 00000000..19079998 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for the HBM-buffered all_gather_matmul variant. + +This variant cooperatively gathers A into a local HBM buffer with per-tile +ready flags, then runs GEMM from local memory. No global barriers -- CUs +that finish gathering early start GEMM immediately, spinning on flags for +any tile not yet available. 
+ +Usage with torchrun: + torchrun --nproc_per_node=8 benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py \\ + -m 2048 -n 16384 -k 131072 --benchmark + + torchrun --nproc_per_node=8 benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py \\ + -m 2048 -n 16384 -k 131072 --benchmark --benchmark_pytorch --b_col_major +""" + +import os +import time +import torch +import torch.distributed as dist +import random +import argparse +import numpy as np + +import iris +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops import FusedConfig + +_DERIVE_AVAILABLE = False +try: + import sys as _sys + + _script_dir = os.path.dirname(os.path.abspath(__file__)) + if _script_dir not in _sys.path: + _sys.path.insert(0, _script_dir) + from derive_params import ( + derive as _derive_params, + DEFAULT_NUM_CUS, + DEFAULT_PEAK_TFLOPS_FP16, + DEFAULT_HBM_BW_GBPS, + DEFAULT_L2_SIZE_BYTES, + DEFAULT_SCHEDULING_FACTOR, + ) + + _DERIVE_AVAILABLE = True +except Exception: + pass + +_MODEL_PARAMS = ( + "block_size_m", + "block_size_n", + "block_size_k", + "group_size_m", + "num_fetch_sms", + "k_per_flag", + "num_warps", + "num_fetch_stages", + "first_stage_fetch_sms", +) + +_FALLBACK_DEFAULTS = { + "block_size_m": 256, + "block_size_n": 64, + "block_size_k": 64, + "group_size_m": 1, + "k_per_flag": 1, + "num_fetch_stages": 1, +} + +torch.manual_seed(123) +random.seed(123) + +TICKS_PER_US = 100 # s_memrealtime runs at 100 MHz: 1 tick = 10 ns = 0.01 us + + +def _plot_trace(trace_data, output_path, rank, M, N, K, num_fetch_sms_cfg): + """Generate a tall Gantt chart showing per-workgroup activity over time. 
+ + Y-axis: workgroup (sorted by start time) + X-axis: time in microseconds + Colors: fetcher stages (blue shades), GEMM wait (red), GEMM compute (green) + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.lines import Line2D + + starts = trace_data["start"].numpy().astype(np.int64) + ends = trace_data["end"].numpy().astype(np.int64) + waits = trace_data["wait"].numpy().astype(np.int64) + xcds = trace_data["xcd"].numpy().astype(np.int32) + grid_size = trace_data["grid_size"] + n_fetch_per_stage = trace_data["num_fetch_sms"] + n_stages = trace_data.get("num_fetch_stages", 1) + total_fetch = trace_data.get("total_fetch_wgs", n_fetch_per_stage) + first_stage_fetch = trace_data.get("first_stage_fetch_sms", n_fetch_per_stage) + first_stage_size = trace_data.get("first_stage_size", grid_size) + rest_stage_size = trace_data.get("rest_stage_size", grid_size) + + # Convert to microseconds relative to earliest start + t_min = starts.min() + starts_us = (starts - t_min) / TICKS_PER_US + ends_us = (ends - t_min) / TICKS_PER_US + waits_us = waits / TICKS_PER_US + + # Build role array: stage index for fetchers (0..S-1), S for GEMM + # Asymmetric layout: [fetch0 (P)] [gemm0] [fetch1 (F)] [gemm1] ... 
+ roles = np.empty(grid_size, dtype=np.int32) + for i in range(grid_size): + if i < first_stage_size: + stage = 0 + local = i + fetch_thresh = first_stage_fetch + else: + adjusted = i - first_stage_size + stage = 1 + adjusted // rest_stage_size + local = adjusted % rest_stage_size + fetch_thresh = n_fetch_per_stage + if local < fetch_thresh: + roles[i] = stage # fetcher for this stage + else: + roles[i] = n_stages # GEMM + + # Sort by start time + order = np.argsort(starts_us) + + # Compute figure height: ~0.012 inches per row, min 12 inches + row_h = 0.012 + fig_h = max(12, grid_size * row_h + 2) + fig, ax = plt.subplots(figsize=(18, fig_h)) + + # One color per fetch stage (blue palette), plus GEMM colors + fetch_blues = ["#1565C0", "#42A5F5", "#90CAF9", "#BBDEFB"] + wait_color = "#F44336" # red + compute_color = "#4CAF50" # green + + for y_idx, wg_idx in enumerate(order): + s = starts_us[wg_idx] + e = ends_us[wg_idx] + dur = e - s + role = roles[wg_idx] + + if role < n_stages: + # Fetcher: color by stage + c = fetch_blues[role % len(fetch_blues)] + ax.barh(y_idx, dur, left=s, height=0.8, color=c, edgecolor="none", linewidth=0) + else: + # GEMM: split into wait (red) and compute (green) + w = waits_us[wg_idx] + comp = max(0, dur - w) + ax.barh(y_idx, w, left=s, height=0.8, color=wait_color, edgecolor="none", linewidth=0) + ax.barh(y_idx, comp, left=s + w, height=0.8, color=compute_color, edgecolor="none", linewidth=0) + + # XCD annotations on the right margin + xcd_set = sorted(set(xcds.tolist())) + xcd_cmap = {} + if len(xcd_set) > 1: + cmap = matplotlib.colormaps.get_cmap("tab10").resampled(len(xcd_set)) + for i, x in enumerate(xcd_set): + xcd_cmap[x] = cmap(i) + + x_max = ends_us.max() * 1.02 + for y_idx, wg_idx in enumerate(order): + xcd_id = xcds[wg_idx] + if xcd_id in xcd_cmap: + ax.plot(x_max, y_idx, marker="s", markersize=1.5, color=xcd_cmap[xcd_id], clip_on=False) + + n_gemm = grid_size - total_fetch + if n_stages > 1 and first_stage_fetch != 
n_fetch_per_stage: + stage_info = f"{first_stage_fetch}+{n_stages - 1}x{n_fetch_per_stage}" + elif n_stages > 1: + stage_info = f"{n_stages}x{n_fetch_per_stage}" + else: + stage_info = str(first_stage_fetch) + ax.set_xlabel("Time (us)", fontsize=12) + ax.set_ylabel("Workgroup (sorted by start time)", fontsize=12) + ax.set_title( + f"Rank {rank} | All-Gather GEMM Trace | " + f"M={M} N={N} K={K} | " + f"{stage_info} fetchers + {n_gemm} GEMM workgroups", + fontsize=13, + ) + ax.set_ylim(-1, grid_size + 1) + ax.set_xlim(0, x_max) + + # Invert y so earliest-starting workgroups are at top + ax.invert_yaxis() + + # Legend + legend_elements = [] + for s_idx in range(min(n_stages, len(fetch_blues))): + legend_elements.append(Line2D([0], [0], color=fetch_blues[s_idx], lw=6, label=f"Fetch stage {s_idx}")) + legend_elements.append(Line2D([0], [0], color=wait_color, lw=6, label="GEMM: waiting on data")) + legend_elements.append(Line2D([0], [0], color=compute_color, lw=6, label="GEMM: compute")) + ax.legend(handles=legend_elements, loc="upper right", fontsize=10) + + # Summary stats + fetch_mask = roles < n_stages + gemm_mask = roles == n_stages + fetch_dur = (ends_us - starts_us)[fetch_mask] + gemm_dur = (ends_us - starts_us)[gemm_mask] + gemm_wait = waits_us[gemm_mask] + gemm_compute = gemm_dur - gemm_wait + + stats_lines = [] + for s_idx in range(n_stages): + s_mask = roles == s_idx + s_dur = (ends_us - starts_us)[s_mask] + s_start = starts_us[s_mask] + if len(s_dur) > 0: + stats_lines.append( + f"Fetch stg{s_idx}: {s_dur.mean():.1f} us avg " + f"({s_dur.min():.1f}-{s_dur.max():.1f}) " + f"first@{s_start.min():.0f}us" + ) + stats_lines += [ + f"GEMM total: {gemm_dur.mean():.1f} us avg ({gemm_dur.min():.1f}-{gemm_dur.max():.1f})", + f" wait: {gemm_wait.mean():.1f} us avg ({gemm_wait.min():.1f}-{gemm_wait.max():.1f})", + f" compute: {gemm_compute.mean():.1f} us avg ({gemm_compute.min():.1f}-{gemm_compute.max():.1f})", + f" wait%: {100 * gemm_wait.sum() / gemm_dur.sum():.1f}%", 
+ f"Wall time: {ends_us.max():.1f} us", + ] + stats_text = "\n".join(stats_lines) + ax.text( + 0.01, + 0.99, + stats_text, + transform=ax.transAxes, + fontsize=9, + verticalalignment="top", + fontfamily="monospace", + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.85), + ) + + plt.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" [Rank {rank}] Trace plot saved to: {output_path}") + print(f" {stats_text}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark HBM-buffered all_gather_matmul (per-tile flags).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=2048, help="M dimension") + parser.add_argument("-n", type=int, default=16384, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension (total)") + parser.add_argument("-v", "--validate", action="store_true", help="Validate correctness") + parser.add_argument("-b", "--benchmark", action="store_true", help="Run benchmark") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Tensor datatype", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs (auto if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul)", + ) + parser.add_argument("--block_size_m", type=int, default=None, help="Block size M (model-derived if omitted)") + parser.add_argument("--block_size_n", type=int, default=None, help="Block size N (model-derived if omitted)") + parser.add_argument("--block_size_k", type=int, default=None, help="Block size K (model-derived if omitted)") + parser.add_argument("--group_size_m", type=int, default=None, help="Group size M (model-derived if omitted)") + 
parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto if None)") + parser.add_argument("--b_col_major", action="store_true", help="B col-major (K-contiguous)") + parser.add_argument("--a_col_major", action="store_true", help="A col-major (M-contiguous)") + parser.add_argument("--single-run", action="store_true", help="1 iteration (for profiling)") + parser.add_argument("--num_fetch_sms", type=int, default=None, help="Fetcher SMs (auto if None)") + parser.add_argument( + "--k_per_flag", type=int, default=None, help="K-blocks per ready flag (model-derived if omitted)" + ) + parser.add_argument("--num_warps", type=int, default=None, help="Triton num_warps (auto if None)") + parser.add_argument("--num_stages", type=int, default=None, help="Triton num_stages (auto if None)") + parser.add_argument( + "--num_fetch_stages", + type=int, + default=None, + help="Number of fetch stages (model-derived if omitted)", + ) + parser.add_argument( + "--first_stage_fetch_sms", + type=int, + default=None, + help="Fetcher WGs for stage 0 (fills first GPU wave; defaults to num_fetch_sms)", + ) + parser.add_argument( + "--trace", + action=argparse.BooleanOptionalAction, + default=True, + help="Collect per-workgroup trace and save Gantt chart PNG", + ) + parser.add_argument("--trace_output", type=str, default="trace.png", help="Output path for trace plot") + return vars(parser.parse_args()) + + +def _apply_model_defaults(args, world_size, dtype_bytes=2): + """Fill None-valued kernel parameters with model-derived predictions. + + Returns a list of parameter names that were set by the model. 
+ """ + applied = [] + if _DERIVE_AVAILABLE: + try: + p = _derive_params( + args["m"], + args["n"], + args["k"], + world_size, + link_bw=50.0, + num_cus=DEFAULT_NUM_CUS, + peak_tflops=DEFAULT_PEAK_TFLOPS_FP16, + hbm_bw_gbps=DEFAULT_HBM_BW_GBPS, + l2_size=DEFAULT_L2_SIZE_BYTES, + scheduling_factor=DEFAULT_SCHEDULING_FACTOR, + dtype_bytes=dtype_bytes, + ) + for name in _MODEL_PARAMS: + if args.get(name) is None and name in p: + args[name] = p[name] + applied.append(name) + except Exception: + pass + + for name, fallback in _FALLBACK_DEFAULTS.items(): + if args.get(name) is None: + args[name] = fallback + + return applied + + +def _worker(args): + """Worker function for torchrun.""" + local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) + world_size_env = int(os.environ.get("WORLD_SIZE", 1)) + + t0 = time.perf_counter() + + backend = "nccl" if torch.cuda.is_available() else "gloo" + + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + dist.init_process_group( + backend=backend, + init_method="env://", + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + else: + dist.init_process_group( + backend=backend, + init_method="tcp://127.0.0.1:29530", + world_size=world_size_env, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + + t1 = time.perf_counter() + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + t2 = time.perf_counter() + shmem.info(f"Startup: dist.init={t1 - t0:.1f}s, iris.init={t2 - t1:.1f}s, total={t2 - t0:.1f}s") + + datatype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + datatype = datatype_map.get(args["datatype"], torch.float16) + dtype_bytes = torch.tensor([], dtype=datatype).element_size() + + model_applied = _apply_model_defaults(args, world_size, dtype_bytes) + if rank == 0 and model_applied: + shmem.info(f"Model-derived defaults: {', 
'.join(model_applied)}") + if rank == 0: + param_summary = " ".join(f"{k}={args[k]}" for k in _MODEL_PARAMS) + shmem.info(f"Kernel params: {param_summary}") + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size + + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + config = FusedConfig(**config_kwargs) + + buffer_mb = M * K * torch.tensor([], dtype=datatype).element_size() / (1024**2) + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + shmem.info( + f"HBM-Buffer variant: M={M} N={N} K={K} K_local={K_local} " + f"block=({config.block_size_m},{config.block_size_n},{config.block_size_k}) " + f"buffer={buffer_mb:.0f}MB flags={num_m_tiles}x{num_k_blocks}" + ) + + # ── Allocate tensors ───────────────────────────────────────────────── + C = shmem.zeros((M, N), dtype=datatype) + + if args["a_col_major"]: + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T + else: + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + + if args["b_col_major"]: + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T + else: + B = shmem.zeros((K, N), dtype=datatype) + + shmem.info(f"A strides={A_sharded.stride()}, B strides={B.stride()}") + + # Fill + torch.manual_seed(123 + rank) + A_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_data) + + torch.manual_seed(456) + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # Expected + expected_tensor = None + if args["validate"]: + A_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_list, A_data) + 
A_gathered = torch.cat(A_list, dim=1) + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_tensor.copy_(torch.matmul(A_gathered, B_data)) + + # Pre-allocate workspace + k_per_flag = args["k_per_flag"] + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config, k_per_flag=k_per_flag) + + # ── Timing ─────────────────────────────────────────────────────────── + comm_stream = torch.cuda.Stream() + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + total_ms = 0.0 + num_experiments = 0 + + num_fetch_sms = args["num_fetch_sms"] + num_warps = args["num_warps"] + num_stages = args["num_stages"] + num_fetch_stages = args["num_fetch_stages"] + first_stage_fetch_sms = args["first_stage_fetch_sms"] + + def run_experiment(): + nonlocal total_ms, num_experiments + shmem.barrier() + with torch.cuda.stream(comm_stream): + start_ev.record() + all_gather_matmul_hbm_buffer( + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, + ) + end_ev.record() + num_experiments += 1 + shmem.barrier() + total_ms += start_ev.elapsed_time(end_ev) + + shmem.barrier() + + # ── Validate ───────────────────────────────────────────────────────── + if args["validate"]: + shmem.info("Validating...") + C.zero_() + shmem.barrier() + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + rtol = 1e-2 if datatype == torch.float16 else 1e-5 + success = torch.allclose(C, expected_tensor, atol=atol, rtol=rtol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation FAILED, max diff: {max_diff}") + else: + shmem.info("Validation PASSED!") + shmem.barrier() + + # ── Benchmark 
──────────────────────────────────────────────────────── + if args["benchmark"]: + if args.get("single_run"): + n_warmup, n_repeat = 0, 1 + else: + n_warmup, n_repeat = 25, 100 + + # Warmup + total_ms = 0.0 + num_experiments = 0 + if n_warmup > 0: + iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=1) + + total_ms = 0.0 + num_experiments = 0 + C.zero_() + shmem.barrier() + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=0, n_repeat=n_repeat) + avg_ms = total_ms / num_experiments if num_experiments > 0 else 0 + + total_flops = 2 * M * N * K + tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * K_local * element_size * (world_size - 1) + bw_gbps = (total_bytes / (1024**3)) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + + shmem.info( + f"HBM-Buffer (M={M}, K_local={K_local}, K={K}, N={N}, " + f"ws={world_size}, dtype={args['datatype']}): " + f"{avg_ms:.3f} ms, {tflops:.3f} TFLOPS, {bw_gbps:.3f} GB/s" + ) + shmem.barrier() + + # ── Per-rank finish time measurement ───────────────────────────── + # Run a single iteration and record wall-clock finish time per rank + # to see if ranks complete at different times (load imbalance). 
+ shmem.barrier() + torch.cuda.synchronize() + dist.barrier() + + # Synchronized start + dist.barrier() + t_start = time.perf_counter() + + all_gather_matmul_hbm_buffer( + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, + ) + torch.cuda.synchronize() + t_end = time.perf_counter() + + finish_ms = (t_end - t_start) * 1000.0 + + # Gather all finish times to rank 0 for display + finish_tensor = torch.tensor([finish_ms], dtype=torch.float64, device=f"cuda:{rank}") + all_finish = [torch.zeros(1, dtype=torch.float64, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(all_finish, finish_tensor) + + if rank == 0: + times = [t.item() for t in all_finish] + min_t = min(times) + max_t = max(times) + print("\n Per-rank finish times (single run):") + print(f" {'Rank':>6} {'Finish ms':>10} {'Delta ms':>10}") + print(f" {'-' * 30}") + for r, t in enumerate(times): + delta = t - min_t + print(f" {r:>6} {t:>10.3f} {delta:>+10.3f}") + print(f" {'-' * 30}") + print(f" Spread (max - min): {max_t - min_t:.3f} ms") + print() + + shmem.barrier() + + # ── Trace ──────────────────────────────────────────────────────────── + if args["trace"]: + # Warmup: compile the TRACE=True kernel variant before the real run + shmem.info("Trace warmup (compiling traced kernel variant)...") + C.zero_() + workspace.locks.zero_() + shmem.barrier() + all_gather_matmul_hbm_buffer( + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, + trace=True, + ) + torch.cuda.synchronize() + shmem.barrier() + + # Actual traced run (post-compilation, clean 
state) + shmem.info("Running single traced iteration...") + C.zero_() + workspace.locks.zero_() + shmem.barrier() + + all_gather_matmul_hbm_buffer( + shmem, + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + num_fetch_sms=num_fetch_sms, + k_per_flag=k_per_flag, + num_warps=num_warps, + num_stages=num_stages, + num_fetch_stages=num_fetch_stages, + first_stage_fetch_sms=first_stage_fetch_sms, + trace=True, + ) + torch.cuda.synchronize() + shmem.barrier() + + if rank == 0 and hasattr(workspace, "trace_data"): + trace_out = args.get("trace_output", "trace_gantt.png") + try: + _plot_trace(workspace.trace_data, trace_out, rank, M, N, K, num_fetch_sms) + except ImportError: + print(" (matplotlib not available -- skipping trace plot)") + except Exception as e: + print(f" (Trace plot failed: {e})") + shmem.barrier() + + # ── PyTorch baseline ───────────────────────────────────────────────── + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + pt_A = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pt_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pt_Ag = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + + for _ in range(10): + dist.all_gather_into_tensor(pt_Ag, pt_A) + _ = torch.matmul(pt_Ag, pt_B) + torch.cuda.synchronize() + dist.barrier() + + def run_pt(): + dist.all_gather_into_tensor(pt_Ag, pt_A) + _ = torch.matmul(pt_Ag, pt_B) + + total_flops = 2 * M * N * K + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * K_local * element_size * (world_size - 1) + + pt_ms = iris.do_bench(run_pt, dist.barrier) + pt_tflops = (total_flops * 1e-12) / (pt_ms * 1e-3) if pt_ms > 0 else 0 + pt_bw = (total_bytes / (1024**3)) / (pt_ms * 1e-3) if pt_ms > 0 else 0 + + shmem.info( + f"PyTorch (M={M}, K_local={K_local}, K={K}, N={N}, ws={world_size}, " + f"dtype={args['datatype']}): " + f"{pt_ms:.3f} ms, {pt_tflops:.3f} TFLOPS, 
{pt_bw:.3f} GB/s" + ) + + if args["benchmark"]: + avg_ms = total_ms / num_experiments if num_experiments > 0 else 0 + iris_tflops = (total_flops * 1e-12) / (avg_ms * 1e-3) if avg_ms > 0 else 0 + speedup = iris_tflops / pt_tflops if pt_tflops > 0 else 0 + shmem.info(f"Speedup (HBM-Buffer / PyTorch): {speedup:.2f}x") + + shmem.barrier() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + print("Starting HBM-buffer all_gather_matmul benchmark...") + args = parse_args() + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + _worker(args) + else: + print( + "Please run with torchrun:\n" + " torchrun --nproc_per_node=N " + "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py [OPTIONS]" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/benchmark_torchrun.py b/benchmark/ops/all_gather_matmul/benchmark_torchrun.py new file mode 100755 index 00000000..f4526410 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark_torchrun.py @@ -0,0 +1,487 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops all_gather_matmul fused operation. + +This benchmark showcases the fused All-Gather + GEMM operation where each rank +has a sharded A matrix that gets gathered, then multiplied with B. + +This version is compatible with torchrun for use with profiling tools like rocprofv3/att. 
+ +Usage with torchrun: + torchrun --nproc_per_node=8 benchmark_torchrun.py -m 16384 -n 2048 -k 131072 --benchmark + +Usage with rocprofv3: + torchrun --nproc_per_node=8 rocprofv3 --att benchmark_torchrun.py -m 16384 -n 2048 -k 131072 --benchmark +""" + +import os +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops.all_gather_matmul import all_gather_matmul_preamble +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all_gather_matmul fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension total (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="all_gather_matmul.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + parser.add_argument("--block_size_m", type=int, 
default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--variant", + type=str, + default="pull", + choices=["pull", "chunked", "push", "pipelined_pull"], + help="All-gather matmul variant", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" + ) + parser.add_argument( + "--single-run", + action="store_true", + help="Run only one iteration (no warmup, 1 repeat) - useful for profiling", + ) + parser.add_argument( + "--b_col_major", + action="store_true", + help="Store B matrix in column-major order (K-contiguous) to reduce LDS transpose overhead", + ) + parser.add_argument( + "--a_col_major", + action="store_true", + help="Store A matrix in column-major order (M-contiguous). 
Default is row-major (K-contiguous).", + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int = None, world_size: int = None, init_url: str = None, args: dict = None): + """Worker function for PyTorch distributed execution.""" + # Support torchrun: read from environment variables if available + if local_rank is None: + local_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) + if world_size is None: + world_size = int(os.environ.get("WORLD_SIZE", 1)) + if init_url is None: + # torchrun sets MASTER_ADDR and MASTER_PORT + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", "29500") + init_url = f"tcp://{master_addr}:{master_port}" + + # Use nccl backend - gloo doesn't support uint64 tensors used by Iris + backend = "nccl" if torch.cuda.is_available() else "gloo" + print(f"Rank {local_rank}: Using backend: {backend}") + + # Use environment-based initialization if torchrun is detected + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + # For torchrun, use env:// initialization with device_id for nccl + dist.init_process_group( + backend=backend, + init_method="env://", + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + else: + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None, + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size # Sharded K dimension + + # Create config with 
parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + "all_gather_matmul_variant": args["variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "all_gather_matmul") + json_writer.add_field("k_local", K_local) + json_writer.add_field("k_total", K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_sharded is M x K_local, B is K x N, output is M x N + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Create A_sharded matrix with optional column-major layout + # When a_col_major=True, M becomes the contiguous dimension + # Default (row-major): K is contiguous (stride_ak=1, stride_am=K_local) + if args["a_col_major"]: + # Allocate storage as (K_local, M) row-major, then transpose to get (M, K_local) with M-contiguous + # This means stride_am=1 and stride_ak=M + A_storage = shmem.zeros((K_local, M), dtype=datatype) + A_sharded = A_storage.T # View as (M, K_local) with M-contiguous strides + shmem.info(f"Using column-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (M-contiguous)") + else: + # Standard row-major (M, K_local) - K is 
contiguous + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + shmem.info(f"Using row-major A: shape={A_sharded.shape}, strides={A_sharded.stride()} (K-contiguous)") + + json_writer.add_field("a_col_major", args["a_col_major"]) + json_writer.add_field("a_stride_m", A_sharded.stride()[0]) + json_writer.add_field("a_stride_k", A_sharded.stride()[1]) + + # Create B matrix with optional column-major layout for K-contiguous access + # When b_col_major=True, we store B such that K is the contiguous dimension + # This reduces LDS transpose overhead when loading B tiles along the K dimension + if args["b_col_major"]: + # Allocate storage as (N, K) row-major, then transpose to get (K, N) with K-contiguous + # This means stride_bk=1 and stride_bn=K + B_storage = shmem.zeros((N, K), dtype=datatype) + B = B_storage.T # View as (K, N) with K-contiguous strides + shmem.info(f"Using column-major B: shape={B.shape}, strides={B.stride()} (K-contiguous)") + else: + # Standard row-major (K, N) - N is contiguous + B = shmem.zeros((K, N), dtype=datatype) + shmem.info(f"Using row-major B: shape={B.shape}, strides={B.stride()} (N-contiguous)") + + json_writer.add_field("b_col_major", args["b_col_major"]) + json_writer.add_field("b_stride_k", B.stride()[0]) + json_writer.add_field("b_stride_n", B.stride()[1]) + + # Fill inputs with deterministic values + # Each rank has different A_sharded, same B + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_sharded_data) + + torch.manual_seed(456) # Same B for all ranks + # Generate B data in standard (K, N) layout for consistency + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + # Copy to B (handles both row-major and column-major storage) + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_sharded matrices and compute expected result + A_sharded_list = [torch.zeros((M, K_local), 
dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_sharded_list, A_sharded_data) + + # Concatenate along K dimension: A_gathered = [A_0 | A_1 | ... | A_n] + A_gathered = torch.cat(A_sharded_list, dim=1) # (M, K) + + # Expected: A_gathered @ B + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_result = torch.matmul(A_gathered, B_data) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Pre-allocate workspace once (important for push variant which needs large buffers) + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) + + def run_experiment(): + nonlocal kernel_timing + + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather-Matmul") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather_matmul"]["start_event"].record() + shmem.ops.all_gather_matmul( + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["all_gather_matmul"]["end_event"].record() + kernel_timing["all_gather_matmul"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather_matmul"]["start_event"].elapsed_time( + kernel_timing["all_gather_matmul"]["end_event"] + ) + kernel_timing["all_gather_matmul"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + 
shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather-matmul validation passed!") + else: + shmem.error("All-gather-matmul validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Determine warmup and repeat counts + if args.get("single_run", False): + n_warmup = 0 + n_repeat = 1 + shmem.info("Single-run mode: no warmup, 1 repeat") + else: + n_warmup = 25 + n_repeat = 100 # default from iris.do_bench + + # Warmup for benchmarking (skip if single-run) + if not args.get("single_run", False): + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=1) + + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier, n_warmup=n_warmup, n_repeat=n_repeat) + tflops = total_tflops_unit / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M * K_local * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather-matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + 
f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_matmul_ms", + kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"], + ) + json_writer.add_field("all_gather_matmul_experiments", kernel_timing["all_gather_matmul"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_sharded = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + # Calculate bandwidth for all-gather part + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + def run_pytorch_experiment(): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + + 
pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + print("Starting all_gather_matmul benchmark...") + args = parse_args() + + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ or "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + _worker(args=args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = args["init_url"] + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/derive_params.py b/benchmark/ops/all_gather_matmul/derive_params.py new file mode 100644 index 00000000..cf4acd9f --- /dev/null +++ b/benchmark/ops/all_gather_matmul/derive_params.py @@ -0,0 +1,721 @@ +#!/usr/bin/env python3 +# 
SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter derivation for the HBM-buffered all_gather_matmul kernel. + +Given a problem size (M, N, K), world size, and per-link XGMI bandwidth, +derives kernel parameters that balance communication and computation in +the device-level pipeline. + +The kernel fuses all-gather with GEMM using two workgroup roles: + - Fetcher WGs: gather remote A tiles into an HBM staging buffer, + setting per-tile ready flags as data arrives. + - GEMM WGs: poll flags, then compute C += A_staged @ B tile-by-tile. + +The M dimension is split into `num_fetch_stages` pipeline stages. Each +stage's fetchers and GEMM WGs are interleaved in the launch grid so that +stage N+1's fetch overlaps with stage N's compute. + +Pipeline timeline (S stages): + |-- fetch stage 0 --|-- max(fetch, compute) * (S-1) --|-- compute last --| + +Usage: + python derive_params.py -m 131072 -n 2048 -k 16384 + python derive_params.py -m 196608 -n 2304 -k 16384 --link_bw 50 + python derive_params.py -m 196608 -n 2304 -k 16384 -v -b --trace + +When --link_bw is omitted the script automatically profiles the XGMI +link bandwidth by timing GPU-to-GPU copies across all peer pairs visible +from GPU 0. +""" + +import argparse +import math +import time + +# ── MI300X hardware defaults ────────────────────────────────────────────── +DEFAULT_NUM_CUS = 304 +DEFAULT_PEAK_TFLOPS_FP16 = 1300.0 +DEFAULT_HBM_BW_GBPS = 5300.0 +DEFAULT_L2_SIZE_BYTES = 256 * 1024 * 1024 +DEFAULT_NUM_XCDS = 8 +DEFAULT_WORLD_SIZE = 8 + +# Calibrated from MI300X trace data: the ratio of measured wall time to +# the CU-work-queue lower bound. Captures WG dispatch overhead, +# cross-XCD coherence latency, and pipeline bubble effects. +DEFAULT_SCHEDULING_FACTOR = 4.5 + + +def profile_link_bandwidth(world_size=DEFAULT_WORLD_SIZE): + """Measure per-link unidirectional XGMI bandwidth. 
+ + Copies a 256 MB fp16 tensor from GPU 0 to every other visible GPU, + times the transfers with host-side timing after explicit device syncs, + and returns the conservative (min) per-link bandwidth. + """ + import torch + + n_gpus = torch.cuda.device_count() + if n_gpus < 2: + raise RuntimeError( + f"Need >= 2 visible GPUs for bandwidth profiling, found {n_gpus}. Pass --link_bw explicitly instead." + ) + + n_peers = min(world_size, n_gpus) - 1 + size_bytes = 256 * 1024 * 1024 + numel = size_bytes // 2 + warmup_iters = 10 + timed_iters = 40 + + print(f"\n── Link Bandwidth Profiling {'─' * 43}") + print(f" GPUs visible: {n_gpus}") + print(f" Testing: GPU 0 → GPUs 1..{n_peers}") + print(f" Transfer size: {size_bytes // (1024**2)} MB × {timed_iters} iterations\n") + + src = torch.empty(numel, dtype=torch.float16, device="cuda:0").normal_() + bandwidths = [] + + for peer in range(1, n_peers + 1): + dst = torch.empty(numel, dtype=torch.float16, device=f"cuda:{peer}") + + for _ in range(warmup_iters): + dst.copy_(src) + torch.cuda.synchronize(0) + torch.cuda.synchronize(peer) + + t_start = time.perf_counter() + for _ in range(timed_iters): + dst.copy_(src) + torch.cuda.synchronize(peer) + elapsed_s = time.perf_counter() - t_start + + bw = size_bytes * timed_iters / elapsed_s / 1e9 + bandwidths.append(bw) + print(f" GPU 0 → GPU {peer}: {bw:6.1f} GB/s") + + del dst + + del src + torch.cuda.empty_cache() + + bw_min = min(bandwidths) + bw_max = max(bandwidths) + bw_avg = sum(bandwidths) / len(bandwidths) + print(f"\n min = {bw_min:.1f} avg = {bw_avg:.1f} max = {bw_max:.1f} GB/s") + print(f" Using conservative (min): {bw_min:.1f} GB/s per link") + + return bw_min + + +# ── Tile / block size heuristics ────────────────────────────────────────── + + +def _choose_block_sizes(M, N, K, K_local): + """Heuristic tile-size selection for MI300X MFMA.""" + bk = 64 + + bm = 256 if M >= 8192 else 128 + while M % bm != 0 and bm > 64: + bm //= 2 + + if N >= 512: + bn = 256 + elif N >= 
256: + bn = 256 if N % 256 == 0 else 128 + else: + bn = 128 + while N % bn != 0 and bn > 32: + bn //= 2 + + while K % bk != 0 and bk > 16: + bk //= 2 + while K_local % bk != 0 and bk > 16: + bk //= 2 + + nw = 8 if bm * bn >= 256 * 256 else 4 + return bm, bn, bk, nw + + +def _choose_k_per_flag(num_k_blocks, num_k_blocks_local, target_groups=8): + """Pick k_per_flag so that flag groups align to rank boundaries when + possible, falling back to the largest divisor near the target.""" + if num_k_blocks % num_k_blocks_local == 0: + candidate = num_k_blocks_local + groups = num_k_blocks // candidate + if groups >= 4: + return candidate + + kpf = max(1, num_k_blocks // target_groups) + while num_k_blocks % kpf != 0 and kpf > 1: + kpf -= 1 + return kpf + + +# ── Per-tile roofline model ────────────────────────────────────────────── + + +def _tile_roofline(bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size): + """Compute achievable per-CU TFLOPS from tile arithmetic intensity. + + staged_a is always >> L2, so A tiles come from HBM. B may fit in L2 + only when staged_a is small enough that reads don't evict B. + Returns (roofline_tflops, tile_intensity, ridge_point, b_in_l2). + """ + tile_flops = 2 * bm * bn * bk + a_bytes = bm * bk * dtype_bytes + b_bytes = bk * bn * dtype_bytes + + b_total = K * N * dtype_bytes + staged_a_total = M * K * dtype_bytes + # When staged_a exceeds L2, streaming GEMM reads evict B regardless + # of B's absolute size. 
+ b_in_l2 = (staged_a_total <= l2_size) and (b_total <= l2_size) + + hbm_bytes = a_bytes + (0 if b_in_l2 else b_bytes) + intensity = tile_flops / max(hbm_bytes, 1) + + ridge = peak_tflops * 1e3 / hbm_bw_gbps + if intensity >= ridge: + roofline = peak_tflops + else: + roofline = hbm_bw_gbps * intensity / 1e3 + + return roofline, intensity, ridge, b_in_l2 + + +# ── Per-WG execution time models ──────────────────────────────────────── + + +def _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups, roofline_tflops, num_cus): + """Estimate per-WG GEMM execution time in microseconds. + + Uses the per-tile roofline to get the per-CU throughput, then applies + a calibrated overhead for memory-latency hiding and instruction + scheduling at single-WG occupancy (large tiles). + """ + total_flops = 2 * bm * bn * K + per_cu_tflops = roofline_tflops / num_cus + + # Roofline-ideal per-WG time + ideal_us = total_flops / (per_cu_tflops * 1e6) + + # Single-occupancy overhead: imperfect latency hiding, instruction + # scheduling gaps, cross-XCD coherence on staged_a reads. + # Calibrated from MI300X traces: actual/ideal ≈ 1.2-1.3. + occupancy_factor = 1.25 if bm * bn >= 256 * 256 else 1.10 + + # Flag polling: acquire-semantics atomic per flag group + flag_us = num_flag_groups * 2.5 + + return ideal_us * occupancy_factor + flag_us + + +def _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, num_fgs_per_wg): + """Estimate per-fetcher-WG execution time in microseconds. + + Each flag group fetches kpf K-blocks (each BM × BK) from one rank. + Remote data traverses XGMI; local data uses HBM. 
+ """ + bytes_per_fg = bm * kpf * bk * dtype_bytes + remote_frac = (world_size - 1) / world_size + + # XGMI gather: raw transfer + iris.x.gather software overhead + remote_bytes = bytes_per_fg * remote_frac + gather_overhead = 1.5 + xgmi_us = remote_bytes / (link_bw * 1e3) * gather_overhead + + # HBM write to staged_a (.cg → L2/HBM, per-WG share of bandwidth) + write_bw = 15.0 # GB/s effective per fetcher WG (calibrated from traces) + write_us = bytes_per_fg / (write_bw * 1e3) + + # Read and write overlap within each tile; dominant cost + flag-store + per_fg_us = max(xgmi_us, write_us) + 5.0 + + return num_fgs_per_wg * per_fg_us + + +# ── Kernel time estimation ─────────────────────────────────────────────── + + +def _estimate_kernel_time(total_gemm_wgs, gemm_wg_us, total_fetch_wgs, fetch_wg_us, num_cus, scheduling_factor): + """Estimate kernel wall-clock time from the CU work queue model. + + total_CU_work / num_CUs gives the ideal (work-conserving) lower + bound. The scheduling_factor captures GPU dispatch overhead, + cross-XCD coherence, and pipeline bubble effects measured on MI300X. 
+ """ + total_cu_work_us = total_gemm_wgs * gemm_wg_us + total_fetch_wgs * fetch_wg_us + + ideal_ms = total_cu_work_us / num_cus / 1e3 + estimated_ms = ideal_ms * scheduling_factor + return estimated_ms, ideal_ms + + +# ── Pipeline stage selection ───────────────────────────────────────────── + + +def _choose_fetch_stages(num_m_tiles, num_tiles_n, group_size_m, comm_time_ms, compute_time_ms, num_cus): + """Choose num_fetch_stages for good pipeline efficiency while keeping + m_per_stage divisible by group_size_m.""" + ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 + + if ratio > 1.5: + ideal_stages = 32 + elif ratio > 0.8: + ideal_stages = 16 + elif ratio > 0.3: + ideal_stages = 8 + else: + ideal_stages = 4 + + min_gemm_tiles = max(num_cus // 4, 32) + min_m_per_stage = max(group_size_m, math.ceil(min_gemm_tiles / max(num_tiles_n, 1))) + max_stages = max(1, num_m_tiles // min_m_per_stage) + num_stages = min(ideal_stages, max_stages) + num_stages = max(1, num_stages) + + m_per_stage = math.ceil(num_m_tiles / num_stages) + if m_per_stage % group_size_m != 0: + m_per_stage = ((m_per_stage + group_size_m - 1) // group_size_m) * group_size_m + num_stages = max(1, math.ceil(num_m_tiles / m_per_stage)) + + m_per_stage = math.ceil(num_m_tiles / num_stages) + return num_stages, m_per_stage + + +# ── num_fetch_sms optimisation ─────────────────────────────────────────── + + +def _choose_num_fetch_sms( + m_per_stage, + group_size_m, + num_flag_groups_k, + link_bw, + world_size, + num_cus, + bm, + bk, + kpf, + dtype_bytes, + gemm_wg_us, + gemm_tiles_per_stage, +): + """Choose num_fetch_sms for good pipeline overlap. + + Balances three constraints: + 1. Flag delivery parallelism: ≥ m_per_stage so every M-tile gets + a fetcher early (good for reducing GEMM flag-poll stalls). + 2. Link saturation: enough concurrent fetchers to use the XGMI + aggregate bandwidth. + 3. CU budget: leave enough CUs for GEMM in the first dispatch wave. 
+ + Returns (num_fetch_sms, per-WG timing info dict). + """ + total_fg_per_stage = num_flag_groups_k * m_per_stage + + # Constraint 1: one fetcher per M-group for broad flag delivery + parallel_min = max(1, m_per_stage // group_size_m) + + # Constraint 2: enough fetchers to keep XGMI links busy + per_fg_bytes = bm * kpf * bk * dtype_bytes + per_fg_remote = per_fg_bytes * (world_size - 1) / world_size + per_fg_xgmi_us = per_fg_remote / (link_bw * 1e3) * 1.5 + per_fg_write_us = per_fg_bytes / (15.0 * 1e3) + per_fg_us = max(per_fg_xgmi_us, per_fg_write_us) + 5.0 + + # Total flag groups per stage should finish within the stage GEMM time + gemm_waves = math.ceil(gemm_tiles_per_stage / num_cus) + stage_gemm_us = gemm_waves * gemm_wg_us + if per_fg_us > 0: + balance_min = max(1, math.ceil(total_fg_per_stage * per_fg_us / stage_gemm_us)) + else: + balance_min = 1 + + nf = max(parallel_min, balance_min, 64) + nf = min(nf, num_cus // 2) + nf = max(1, nf) + + return nf + + +# ── Main derivation ────────────────────────────────────────────────────── + + +def derive(M, N, K, world_size, link_bw, num_cus, peak_tflops, hbm_bw_gbps, l2_size, scheduling_factor, dtype_bytes): + K_local = K // world_size + + # 1. Tile sizes + bm, bn, bk, nw = _choose_block_sizes(M, N, K, K_local) + gm = 4 + num_m_tiles = M // bm + num_tiles_n = math.ceil(N / bn) + num_k_blocks = K // bk + num_k_blocks_local = K_local // bk + + # 2. Per-tile roofline + roofline_tflops, intensity, ridge, b_in_l2 = _tile_roofline( + bm, bn, bk, M, K, N, dtype_bytes, peak_tflops, hbm_bw_gbps, l2_size + ) + + # 3. Communication model (link-limited) + total_remote_bytes = M * K_local * (world_size - 1) * dtype_bytes + total_link_bw = link_bw * (world_size - 1) + comm_time_ms = total_remote_bytes / (total_link_bw * 1e9) * 1e3 + + # 4. 
Compute model (roofline-limited) + total_flops = 2 * M * N * K + compute_time_ms = total_flops / (roofline_tflops * 1e12) * 1e3 + + ratio = comm_time_ms / compute_time_ms if compute_time_ms > 0 else 999 + + # 5. k_per_flag + kpf = _choose_k_per_flag(num_k_blocks, num_k_blocks_local) + num_flag_groups_k = num_k_blocks // kpf + + # 6. Pipeline stages + num_stages, m_per_stage = _choose_fetch_stages(num_m_tiles, num_tiles_n, gm, comm_time_ms, compute_time_ms, num_cus) + gemm_tiles_per_stage = m_per_stage * num_tiles_n + + # 7. first_stage_fetch_sms: use all CUs to fill the pipeline ASAP + fsf = num_cus + + # 8. Per-WG timing + gemm_wg_us_val = _gemm_wg_time_us(bm, bn, bk, K, num_flag_groups_k, roofline_tflops, num_cus) + + # 9. Choose num_fetch_sms + nf = _choose_num_fetch_sms( + m_per_stage, + gm, + num_flag_groups_k, + link_bw, + world_size, + num_cus, + bm, + bk, + kpf, + dtype_bytes, + gemm_wg_us_val, + gemm_tiles_per_stage, + ) + + # 10. Compute per-WG fetch times + total_fg_per_stage = num_flag_groups_k * m_per_stage + fgs_per_wg_stg0 = max(1, math.ceil(total_fg_per_stage / fsf)) + fgs_per_wg_rest = max(1, math.ceil(total_fg_per_stage / nf)) + fetch_us_stg0 = _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, fgs_per_wg_stg0) + fetch_us_rest = _fetch_wg_time_us(bm, bk, kpf, world_size, link_bw, dtype_bytes, fgs_per_wg_rest) + + # 11. Grid geometry + first_stage_size = fsf + gemm_tiles_per_stage + rest_stage_size = nf + gemm_tiles_per_stage + grid_size = first_stage_size + rest_stage_size * max(0, num_stages - 1) + total_fetch_wgs = fsf + nf * max(0, num_stages - 1) + total_gemm_wgs = gemm_tiles_per_stage * num_stages + + # 12. Kernel time estimate (CU-work model) + avg_fetch_us = fsf * fetch_us_stg0 + nf * max(0, num_stages - 1) * fetch_us_rest + avg_fetch_us /= max(total_fetch_wgs, 1) + est_kernel_ms, est_ideal_ms = _estimate_kernel_time( + total_gemm_wgs, gemm_wg_us_val, total_fetch_wgs, avg_fetch_us, num_cus, scheduling_factor + ) + + # 13. 
Link-limited pipeline estimate (simple model for comparison) + stage_m = m_per_stage * bm + stage_comm_ms = stage_m * K_local * (world_size - 1) * dtype_bytes / (total_link_bw * 1e9) * 1e3 + stage_compute_ms = 2 * stage_m * N * K / (roofline_tflops * 1e12) * 1e3 + startup_ms = stage_comm_ms + steady_ms = max(stage_comm_ms, stage_compute_ms) * max(0, num_stages - 1) + drain_ms = stage_compute_ms + pipeline_ms = startup_ms + steady_ms + drain_ms + sequential_ms = comm_time_ms + compute_time_ms + + # 14. Standalone GEMM estimate (rocBLAS-class efficiency for comparison) + standalone_gemm_eff = 0.30 + standalone_tflops = roofline_tflops * standalone_gemm_eff + standalone_gemm_ms = total_flops / (standalone_tflops * 1e12) * 1e3 + pytorch_est_ms = comm_time_ms + standalone_gemm_ms + + staged_a_gb = M * K * dtype_bytes / (1024**3) + + return dict( + block_size_m=bm, + block_size_n=bn, + block_size_k=bk, + group_size_m=gm, + num_warps=nw, + num_fetch_sms=nf, + k_per_flag=kpf, + num_fetch_stages=num_stages, + first_stage_fetch_sms=fsf, + # derived + K_local=K_local, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + num_k_blocks=num_k_blocks, + num_flag_groups_k=num_flag_groups_k, + m_per_stage=m_per_stage, + gemm_tiles_per_stage=gemm_tiles_per_stage, + grid_size=grid_size, + total_fetch_wgs=total_fetch_wgs, + total_gemm_wgs=total_gemm_wgs, + # roofline + roofline_tflops=roofline_tflops, + tile_intensity=intensity, + ridge_point=ridge, + b_in_l2=b_in_l2, + # per-WG timing + gemm_wg_us=gemm_wg_us_val, + fetch_wg_us_stg0=fetch_us_stg0, + fetch_wg_us_rest=fetch_us_rest, + # estimates + total_remote_bytes=total_remote_bytes, + total_link_bw=total_link_bw, + comm_time_ms=comm_time_ms, + total_flops=total_flops, + compute_time_ms=compute_time_ms, + ratio=ratio, + stage_comm_ms=stage_comm_ms, + stage_compute_ms=stage_compute_ms, + pipeline_ms=pipeline_ms, + sequential_ms=sequential_ms, + est_kernel_ms=est_kernel_ms, + est_ideal_ms=est_ideal_ms, + 
standalone_gemm_ms=standalone_gemm_ms, + pytorch_est_ms=pytorch_est_ms, + staged_a_gb=staged_a_gb, + scheduling_factor=scheduling_factor, + ) + + +# ── Formatting helpers ─────────────────────────────────────────────────── + + +def _fmt_bytes(n): + if n >= 1024**3: + return f"{n / 1024**3:.2f} GB" + if n >= 1024**2: + return f"{n / 1024**2:.1f} MB" + return f"{n / 1024:.1f} KB" + + +def _fmt_flops(n): + if n >= 1e15: + return f"{n / 1e15:.2f} PFLOPs" + return f"{n / 1e12:.2f} TFLOPs" + + +def _fmt_tflops(t): + return f"{t:.0f} TFLOPS" + + +# ── Analysis output ────────────────────────────────────────────────────── + + +def print_analysis(M, N, K, world_size, link_bw, p, passthrough_args, bw_profiled=False): + K_local = p["K_local"] + dtype_bytes = 2 + bound = "COMM-BOUND" if p["ratio"] > 1.0 else "COMPUTE-BOUND" + + print("=" * 72) + print(" All-Gather Matmul HBM-Buffer — Parameter Derivation") + print("=" * 72) + + # ── Problem ─────────────────────────────────────────────────────── + print(f"\n{'Problem':>14}: C({M}, {N}) = all_gather(A_shard({M}, {K_local})) @ B({K}, {N})") + print(f"{'World size':>14}: {world_size} GPUs") + print(f"{'Dtype':>14}: fp16 ({dtype_bytes}B)") + + # ── Data sizes ──────────────────────────────────────────────────── + a_shard = M * K_local * dtype_bytes + b_size = K * N * dtype_bytes + c_size = M * N * dtype_bytes + staged = M * K * dtype_bytes + print(f"\n{'A_shard':>14}: ({M}, {K_local}) {_fmt_bytes(a_shard)}") + print(f"{'B':>14}: ({K}, {N}) {_fmt_bytes(b_size)}") + print(f"{'C':>14}: ({M}, {N}) {_fmt_bytes(c_size)}") + print(f"{'staged_a':>14}: ({M}, {K}) {_fmt_bytes(staged)}") + if staged > 4 * 1024**3: + print(f"{'':>14} *** > 4 GB: requires int64 pointer arithmetic ***") + + # ── Per-tile roofline ───────────────────────────────────────────── + print(f"\n── Roofline {'─' * 59}") + print(f"{'Tile':>14}: ({p['block_size_m']}, {p['block_size_n']}, {p['block_size_k']})") + print(f"{'Intensity':>14}: {p['tile_intensity']:.0f} 
FLOPs/byte {'(B in L2)' if p['b_in_l2'] else '(B from HBM)'}") + print(f"{'Ridge point':>14}: {p['ridge_point']:.0f} FLOPs/byte") + region = "COMPUTE" if p["tile_intensity"] >= p["ridge_point"] else "MEMORY" + print(f"{'Roofline':>14}: {_fmt_tflops(p['roofline_tflops'])} ({region}-bound tiles)") + + # ── Communication ───────────────────────────────────────────────── + print(f"\n── Communication {'─' * 54}") + print(f"{'Remote bytes':>14}: {_fmt_bytes(p['total_remote_bytes'])} (from {world_size - 1} peers)") + bw_src = "profiled" if bw_profiled else "user" + print( + f"{'Link BW':>14}: {link_bw:.1f} GB/s/link × {world_size - 1} links " + f"= {p['total_link_bw']:.0f} GB/s aggregate ({bw_src})" + ) + print(f"{'Comm time':>14}: {p['comm_time_ms']:.3f} ms (link-limited)") + + # ── Compute ─────────────────────────────────────────────────────── + print(f"\n── Compute {'─' * 60}") + print(f"{'Total FLOPs':>14}: {_fmt_flops(p['total_flops'])}") + print(f"{'Roofline time':>14}: {p['compute_time_ms']:.3f} ms (at {_fmt_tflops(p['roofline_tflops'])})") + print(f"{'Comm/Compute':>14}: {p['ratio']:.2f}x → {bound}") + + # ── Per-WG timing ───────────────────────────────────────────────── + print(f"\n── Per-WG Model {'─' * 55}") + print(f"{'GEMM WG':>14}: {p['gemm_wg_us']:.0f} us ({p['total_flops'] / p['total_gemm_wgs'] / 1e9:.2f} GFLOPs/WG)") + print(f"{'Fetch WG stg0':>14}: {p['fetch_wg_us_stg0']:.0f} us") + if p["num_fetch_stages"] > 1: + print(f"{'Fetch WG rest':>14}: {p['fetch_wg_us_rest']:.0f} us") + + # ── Pipeline ────────────────────────────────────────────────────── + S = p["num_fetch_stages"] + print(f"\n── Pipeline {'─' * 59}") + print(f"{'Stages (S)':>14}: {S}") + print(f"{'M tiles/stage':>14}: {p['m_per_stage']} ({p['m_per_stage'] * p['block_size_m']} rows)") + print( + f"{'GEMM WGs/stg':>14}: {p['gemm_tiles_per_stage']} ({p['m_per_stage']} m-tiles × {p['num_tiles_n']} n-tiles)" + ) + print(f"{'K flag groups':>14}: {p['num_flag_groups_k']} 
(k_per_flag={p['k_per_flag']})") + print(f"{'Stage comm':>14}: {p['stage_comm_ms']:.3f} ms") + print(f"{'Stage compute':>14}: {p['stage_compute_ms']:.3f} ms") + + # ── Grid ────────────────────────────────────────────────────────── + print(f"\n── Grid Layout {'─' * 56}") + print( + f"{'Stage 0':>14}: {p['first_stage_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['first_stage_fetch_sms'] + p['gemm_tiles_per_stage']} WGs" + ) + if S > 1: + print( + f"{'Stages 1..{}'.format(S - 1):>14}: {p['num_fetch_sms']} fetchers + " + f"{p['gemm_tiles_per_stage']} GEMM = " + f"{p['num_fetch_sms'] + p['gemm_tiles_per_stage']} WGs (×{S - 1})" + ) + print(f"{'Total grid':>14}: {p['grid_size']} WGs ({p['total_fetch_wgs']} fetch + {p['total_gemm_wgs']} GEMM)") + + # ── Time estimates ──────────────────────────────────────────────── + print(f"\n── Time Estimates {'─' * 53}") + print(f"{'CU-work lower':>14}: {p['est_ideal_ms']:.1f} ms (total WG time / {DEFAULT_NUM_CUS} CUs)") + print(f"{'Fused kernel':>14}: {p['est_kernel_ms']:.1f} ms (×{p['scheduling_factor']:.1f} scheduling overhead)") + est_tflops = p["total_flops"] / (p["est_kernel_ms"] * 1e-3) / 1e12 + print( + f"{'Est. 
TFLOPS':>14}: {est_tflops:.0f} TFLOPS ({est_tflops / p['roofline_tflops'] * 100:.0f}% of roofline)" + ) + print(f"{'':>14}") + print( + f"{'PyTorch est.':>14}: {p['pytorch_est_ms']:.1f} ms " + f"(all_gather {p['comm_time_ms']:.1f} + matmul {p['standalone_gemm_ms']:.1f})" + ) + if p["est_kernel_ms"] < p["pytorch_est_ms"]: + speedup = p["pytorch_est_ms"] / p["est_kernel_ms"] + print(f"{'Fused speedup':>14}: {speedup:.2f}x over sequential PyTorch") + else: + slowdown = p["est_kernel_ms"] / p["pytorch_est_ms"] + print(f"{'Fused speedup':>14}: {1 / slowdown:.2f}x (slower than sequential by {slowdown:.2f}x)") + + # ── Recommended parameters ──────────────────────────────────────── + print(f"\n── Recommended Kernel Parameters {'─' * 38}") + params = [ + ("block_size_m", p["block_size_m"]), + ("block_size_n", p["block_size_n"]), + ("block_size_k", p["block_size_k"]), + ("group_size_m", p["group_size_m"]), + ("num_fetch_sms", p["num_fetch_sms"]), + ("k_per_flag", p["k_per_flag"]), + ("num_warps", p["num_warps"]), + ("num_fetch_stages", p["num_fetch_stages"]), + ("first_stage_fetch_sms", p["first_stage_fetch_sms"]), + ] + for name, val in params: + print(f" --{name:30s} {val}") + + # ── Command line ────────────────────────────────────────────────── + extra = " ".join(passthrough_args) + if extra: + extra = " " + extra + cmd = ( + f"HSA_NO_SCRATCH_RECLAIM=1 torchrun --nproc_per_node {world_size} " + f"benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py " + f"-m {M} -n {N} -k {K} " + f"--block_size_m {p['block_size_m']} " + f"--block_size_n {p['block_size_n']} " + f"--block_size_k {p['block_size_k']} " + f"--group_size_m {p['group_size_m']} " + f"--num_fetch_sms {p['num_fetch_sms']} " + f"--k_per_flag {p['k_per_flag']} " + f"--num_warps {p['num_warps']} " + f"--num_fetch_stages {p['num_fetch_stages']} " + f"--first_stage_fetch_sms {p['first_stage_fetch_sms']}" + f"{extra}" + ) + print(f"\n── Command {'─' * 60}") + print(f" {cmd}") + print() + + +def main(): + parser = 
argparse.ArgumentParser( + description="Derive parameters for HBM-buffered all_gather_matmul.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("-m", type=int, required=True, help="M dimension (rows of output)") + parser.add_argument("-n", type=int, required=True, help="N dimension (cols of output)") + parser.add_argument("-k", type=int, required=True, help="K dimension (total reduction dim)") + parser.add_argument("--world_size", type=int, default=DEFAULT_WORLD_SIZE, help="Number of GPUs") + parser.add_argument( + "--link_bw", + type=float, + default=None, + help="Per-link XGMI bandwidth in GB/s (one direction). Omit to auto-profile via GPU-to-GPU copies.", + ) + parser.add_argument("--num_cus", type=int, default=DEFAULT_NUM_CUS, help="Number of compute units") + parser.add_argument("--peak_tflops", type=float, default=DEFAULT_PEAK_TFLOPS_FP16, help="Peak fp16 TFLOPS") + parser.add_argument("--hbm_bw", type=float, default=DEFAULT_HBM_BW_GBPS, help="HBM bandwidth in GB/s") + parser.add_argument( + "--scheduling_factor", + type=float, + default=DEFAULT_SCHEDULING_FACTOR, + help="CU scheduling overhead factor (calibrated from traces)", + ) + + args, passthrough = parser.parse_known_args() + + if args.k % args.world_size != 0: + parser.error(f"K ({args.k}) must be divisible by world_size ({args.world_size})") + + link_bw = args.link_bw + bw_profiled = False + if link_bw is None: + try: + link_bw = profile_link_bandwidth(args.world_size) + bw_profiled = True + except Exception as e: + print(f"\n Auto-profiling failed: {e}") + print(" Falling back to --link_bw 50 (MI300X default)\n") + link_bw = 50.0 + + p = derive( + args.m, + args.n, + args.k, + args.world_size, + link_bw, + args.num_cus, + args.peak_tflops, + args.hbm_bw, + DEFAULT_L2_SIZE_BYTES, + args.scheduling_factor, + dtype_bytes=2, + ) + + print_analysis(args.m, args.n, args.k, args.world_size, link_bw, p, passthrough, bw_profiled=bw_profiled) + + +if 
__name__ == "__main__": + main() diff --git a/benchmark/ops/all_gather_matmul/profile_att.sh b/benchmark/ops/all_gather_matmul/profile_att.sh new file mode 100755 index 00000000..21f6f21f --- /dev/null +++ b/benchmark/ops/all_gather_matmul/profile_att.sh @@ -0,0 +1,344 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +# ATT (Advanced Thread Trace) Profiling Script for all_gather_matmul benchmark +# Uses rocprofv3 with thread trace to profile the benchmark at ISA instruction level. +# +# Usage: +# ./profile_att.sh [OPTIONS] +# +# Options: +# -r, --ranks NUM_RANKS Number of ranks/GPUs (default: 8) +# -m, --m-dim M M dimension (default: 2048) +# -n, --n-dim N N dimension (default: 16384) +# -k, --k-dim K K dimension (default: 131072) +# -v, --variant VARIANT Variant: pull, chunked, push, pipelined_pull (default: pull) +# --block-m SIZE Block size for M dimension (default: 256) +# --block-n SIZE Block size for N dimension (default: 256) +# --block-k SIZE Block size for K dimension (default: 64) +# --group-m SIZE Group size for M dimension tiling (default: 1) +# --num-xcds NUM Number of XCDs (default: 8) +# --validate Enable validation mode +# --benchmark-pytorch Also benchmark PyTorch for comparison +# -o, --output-dir DIR Base output directory (default: ./att_profiles) +# --att-target-cu CU Target CU for thread trace (default: 1) +# --att-buffer-size SIZE Trace buffer size in hex (default: 0x6000000 = 96MB) +# --att-activity LEVEL Perfcounter streaming level 1-16 (default: 8) +# --kernel-regex REGEX Kernel name regex filter (optional) +# --single-run Run only one iteration (no warmup, no repeat) +# --k-contiguous Use K-contiguous layout for both A and B matrices +# (default A is row-major/K-contiguous, adds --b_col_major) +# --a-col-major Store A matrix in column-major order (M-contiguous) +# --b-col-major Store B matrix in column-major order (K-contiguous) +# -h, --help Show this help message + 
+set -e + +# Default values +NUM_RANKS=8 +M_DIM=2048 +N_DIM=16384 +K_DIM=131072 +VARIANT="pull" +BASE_OUTPUT_DIR="./att_profiles" +ATT_TARGET_CU=1 +ATT_BUFFER_SIZE="0x6000000" # 96MB +ATT_ACTIVITY=8 +KERNEL_REGEX="" +SINGLE_RUN=true +K_CONTIGUOUS=true # Default to K-contiguous layout for both matrices +A_COL_MAJOR=false +B_COL_MAJOR=false +BLOCK_M=256 +BLOCK_N=256 +BLOCK_K=64 +GROUP_M=1 +NUM_XCDS=8 +VALIDATE=true +BENCHMARK_PYTORCH=true + +# Script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCHMARK_SCRIPT="${SCRIPT_DIR}/benchmark_torchrun.py" + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + -r|--ranks) + NUM_RANKS="$2" + shift 2 + ;; + -m|--m-dim) + M_DIM="$2" + shift 2 + ;; + -n|--n-dim) + N_DIM="$2" + shift 2 + ;; + -k|--k-dim) + K_DIM="$2" + shift 2 + ;; + -v|--variant) + VARIANT="$2" + shift 2 + ;; + -o|--output-dir) + BASE_OUTPUT_DIR="$2" + shift 2 + ;; + --att-target-cu) + ATT_TARGET_CU="$2" + shift 2 + ;; + --att-buffer-size) + ATT_BUFFER_SIZE="$2" + shift 2 + ;; + --att-activity) + ATT_ACTIVITY="$2" + shift 2 + ;; + --kernel-regex) + KERNEL_REGEX="$2" + shift 2 + ;; + --single-run) + SINGLE_RUN=true + shift + ;; + --k-contiguous) + K_CONTIGUOUS=true + shift + ;; + --a-col-major) + A_COL_MAJOR=true + shift + ;; + --b-col-major) + B_COL_MAJOR=true + shift + ;; + --block-m) + BLOCK_M="$2" + shift 2 + ;; + --block-n) + BLOCK_N="$2" + shift 2 + ;; + --block-k) + BLOCK_K="$2" + shift 2 + ;; + --group-m) + GROUP_M="$2" + shift 2 + ;; + --num-xcds) + NUM_XCDS="$2" + shift 2 + ;; + --validate) + VALIDATE=true + shift + ;; + --no-validate) + VALIDATE=false + shift + ;; + --benchmark-pytorch) + BENCHMARK_PYTORCH=true + shift + ;; + --no-benchmark-pytorch) + BENCHMARK_PYTORCH=false + shift + ;; + -h|--help) + head -30 "$0" | tail -n +2 | sed 's/^# //' | sed 's/^#//' + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Generate timestamp for output directory +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") 
+OUTPUT_DIR="${BASE_OUTPUT_DIR}/att_${VARIANT}_m${M_DIM}_n${N_DIM}_k${K_DIM}_${TIMESTAMP}" + +# Create output directory +mkdir -p "${OUTPUT_DIR}" + +# Log file with timestamp +LOG_FILE="${OUTPUT_DIR}/profile_${TIMESTAMP}.log" + +echo "==============================================" | tee "${LOG_FILE}" +echo "ATT Profiling for all_gather_matmul benchmark" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "Timestamp: $(date)" | tee -a "${LOG_FILE}" +echo "Output directory: ${OUTPUT_DIR}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Configuration:" | tee -a "${LOG_FILE}" +echo " NUM_RANKS: ${NUM_RANKS}" | tee -a "${LOG_FILE}" +echo " M: ${M_DIM}" | tee -a "${LOG_FILE}" +echo " N: ${N_DIM}" | tee -a "${LOG_FILE}" +echo " K: ${K_DIM}" | tee -a "${LOG_FILE}" +echo " Variant: ${VARIANT}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "ATT Parameters:" | tee -a "${LOG_FILE}" +echo " att-target-cu: ${ATT_TARGET_CU}" | tee -a "${LOG_FILE}" +echo " att-buffer-size: ${ATT_BUFFER_SIZE}" | tee -a "${LOG_FILE}" +echo " att-activity: ${ATT_ACTIVITY}" | tee -a "${LOG_FILE}" +if [[ -n "${KERNEL_REGEX}" ]]; then + echo " kernel-include-regex: ${KERNEL_REGEX}" | tee -a "${LOG_FILE}" +fi +echo " single-run: ${SINGLE_RUN}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Matrix Layout:" | tee -a "${LOG_FILE}" +echo " k-contiguous: ${K_CONTIGUOUS}" | tee -a "${LOG_FILE}" +echo " a-col-major: ${A_COL_MAJOR}" | tee -a "${LOG_FILE}" +echo " b-col-major: ${B_COL_MAJOR}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Block Sizes:" | tee -a "${LOG_FILE}" +echo " block-m: ${BLOCK_M}" | tee -a "${LOG_FILE}" +echo " block-n: ${BLOCK_N}" | tee -a "${LOG_FILE}" +echo " block-k: ${BLOCK_K}" | tee -a "${LOG_FILE}" +echo " group-m: ${GROUP_M}" | tee -a "${LOG_FILE}" +echo " num-xcds: ${NUM_XCDS}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" +echo "Benchmark Options:" | tee -a 
"${LOG_FILE}" +echo " validate: ${VALIDATE}" | tee -a "${LOG_FILE}" +echo " benchmark-pytorch: ${BENCHMARK_PYTORCH}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Build rocprofv3 ATT options +ROCPROF_OPTS="--att" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-target-cu ${ATT_TARGET_CU}" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-buffer-size ${ATT_BUFFER_SIZE}" +ROCPROF_OPTS="${ROCPROF_OPTS} --att-activity ${ATT_ACTIVITY}" + +if [[ -n "${KERNEL_REGEX}" ]]; then + ROCPROF_OPTS="${ROCPROF_OPTS} --kernel-include-regex \"${KERNEL_REGEX}\"" +fi + +# Build benchmark args +BENCH_ARGS="-m ${M_DIM} -n ${N_DIM} -k ${K_DIM} --variant ${VARIANT} --benchmark -r ${NUM_RANKS}" +BENCH_ARGS="${BENCH_ARGS} --block_size_m ${BLOCK_M} --block_size_n ${BLOCK_N} --block_size_k ${BLOCK_K}" +BENCH_ARGS="${BENCH_ARGS} --group_size_m ${GROUP_M} --num_xcds ${NUM_XCDS}" + +if [[ "${SINGLE_RUN}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --single-run" +fi + +if [[ "${VALIDATE}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} -v" +fi + +if [[ "${BENCHMARK_PYTORCH}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --benchmark_pytorch" +fi + +# Add K-contiguous layout options +# --k-contiguous: Both A and B become K-contiguous +# - A is already K-contiguous in default row-major layout +# - B needs --b_col_major to become K-contiguous +if [[ "${K_CONTIGUOUS}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --b_col_major" +fi + +# Individual matrix layout overrides +if [[ "${A_COL_MAJOR}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --a_col_major" +fi +if [[ "${B_COL_MAJOR}" == "true" ]]; then + BENCH_ARGS="${BENCH_ARGS} --b_col_major" +fi + +# Full command +# rocprofv3 wraps the entire torchrun command, not the other way around +# HSA_NO_SCRATCH_RECLAIM=1 prevents scratch memory reclaim issues +FULL_CMD="HSA_NO_SCRATCH_RECLAIM=1 rocprofv3 ${ROCPROF_OPTS} -d ${OUTPUT_DIR} -- torchrun --nproc_per_node=${NUM_RANKS} ${BENCHMARK_SCRIPT} ${BENCH_ARGS}" + +echo "Command:" | tee -a "${LOG_FILE}" +echo 
"${FULL_CMD}" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Save configuration to JSON for reference +cat > "${OUTPUT_DIR}/config.json" << EOF +{ + "timestamp": "${TIMESTAMP}", + "num_ranks": ${NUM_RANKS}, + "m_dim": ${M_DIM}, + "n_dim": ${N_DIM}, + "k_dim": ${K_DIM}, + "variant": "${VARIANT}", + "att_target_cu": ${ATT_TARGET_CU}, + "att_buffer_size": "${ATT_BUFFER_SIZE}", + "att_activity": ${ATT_ACTIVITY}, + "kernel_regex": "${KERNEL_REGEX}", + "single_run": ${SINGLE_RUN}, + "k_contiguous": ${K_CONTIGUOUS}, + "a_col_major": ${A_COL_MAJOR}, + "b_col_major": ${B_COL_MAJOR}, + "block_m": ${BLOCK_M}, + "block_n": ${BLOCK_N}, + "block_k": ${BLOCK_K}, + "group_m": ${GROUP_M}, + "num_xcds": ${NUM_XCDS}, + "validate": ${VALIDATE}, + "benchmark_pytorch": ${BENCHMARK_PYTORCH}, + "command": "${FULL_CMD}" +} +EOF + +echo "Starting profiling..." | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# Run the profiling command +START_TIME=$(date +%s) + +# Execute the command and capture output +eval "${FULL_CMD}" 2>&1 | tee -a "${LOG_FILE}" +EXIT_CODE=${PIPESTATUS[0]} + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) + +echo "" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "Profiling completed" | tee -a "${LOG_FILE}" +echo "Exit code: ${EXIT_CODE}" | tee -a "${LOG_FILE}" +echo "Duration: ${DURATION} seconds" | tee -a "${LOG_FILE}" +echo "End time: $(date)" | tee -a "${LOG_FILE}" +echo "==============================================" | tee -a "${LOG_FILE}" +echo "" | tee -a "${LOG_FILE}" + +# List output files +echo "Output files:" | tee -a "${LOG_FILE}" +ls -la "${OUTPUT_DIR}" 2>&1 | tee -a "${LOG_FILE}" + +# Check for stats CSV files +if ls "${OUTPUT_DIR}"/stats_*.csv 1> /dev/null 2>&1; then + echo "" | tee -a "${LOG_FILE}" + echo "Stats CSV files found:" | tee -a "${LOG_FILE}" + ls -la "${OUTPUT_DIR}"/stats_*.csv 2>&1 | tee -a "${LOG_FILE}" +fi + +# Check for ui_output directories 
(ROCprof Compute Viewer compatible) +if ls -d "${OUTPUT_DIR}"/ui_output_* 1> /dev/null 2>&1; then + echo "" | tee -a "${LOG_FILE}" + echo "UI output directories (for ROCprof Compute Viewer):" | tee -a "${LOG_FILE}" + ls -d "${OUTPUT_DIR}"/ui_output_* 2>&1 | tee -a "${LOG_FILE}" +fi + +echo "" | tee -a "${LOG_FILE}" +echo "Profile output saved to: ${OUTPUT_DIR}" | tee -a "${LOG_FILE}" +echo "Log file: ${LOG_FILE}" | tee -a "${LOG_FILE}" + +exit ${EXIT_CODE} diff --git a/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py new file mode 100644 index 00000000..db9cc56f --- /dev/null +++ b/benchmark/ops/all_gather_matmul/tune_hbm_buffer.py @@ -0,0 +1,634 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Parameter tuning script for HBM-buffered all_gather_matmul. + +Sweeps parameters around a baseline configuration, collecting traces, TFLOPs, +PyTorch baseline, and validation for every configuration. + +This script does NOT modify benchmark_hbm_buffer.py — it invokes it via +``torchrun`` as a subprocess for each parameter set. 
+ +Usage: + # Default one-at-a-time sweep (each param varied independently): + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py + + # Custom matrix size: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py -m 8192 -n 4096 -k 131072 + + # Only sweep specific parameters: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --params num_fetch_sms k_per_flag + + # Full cartesian product (warning: combinatorial explosion): + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --mode full + + # Dry run — just print what would be tested: + python benchmark/ops/all_gather_matmul/tune_hbm_buffer.py --dry_run +""" + +import argparse +import json +import os +import re +import subprocess +import time +from datetime import datetime +from itertools import product +from pathlib import Path + +# ───────────────────────────────────────────────────────────────────────────── +# Baseline configuration — the centre point of every sweep. +# Edit these to match your current best-known config. +# ───────────────────────────────────────────────────────────────────────────── +BASELINE = { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_fetch_sms": 64, + "k_per_flag": 64, + "num_warps": 8, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 304, +} + +# ───────────────────────────────────────────────────────────────────────────── +# Sweep ranges — values to try for each parameter. +# In ``oneatatime`` mode only one parameter deviates from the baseline at a +# time; in ``full`` mode the cartesian product is taken (use with care). 
# ─────────────────────────────────────────────────────────────────────────────
SWEEP_RANGES = {
    "block_size_m": [64, 128, 256],
    "block_size_n": [64, 128, 256],
    "block_size_k": [64],
    "group_size_m": [1, 2, 4, 8],
    "num_fetch_sms": [64, 128, 192, 256],
    "k_per_flag": [16, 32, 64, 128],
    "num_warps": [4, 8],
    "num_fetch_stages": [2, 4, 8],
    "first_stage_fetch_sms": [128, 192, 256, 304],
}

# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────


def make_label(cfg):
    """Short human-readable label for a config (also used in file names)."""
    fields = (
        ("bm", "block_size_m"),
        ("bn", "block_size_n"),
        ("bk", "block_size_k"),
        ("gm", "group_size_m"),
        ("nf", "num_fetch_sms"),
        ("kpf", "k_per_flag"),
        ("nw", "num_warps"),
        ("fs", "num_fetch_stages"),
    )
    pieces = [f"{prefix}{cfg[key]}" for prefix, key in fields]
    # first_stage_fetch_sms only influences multi-stage fetch, so only
    # include it in the label when it actually matters.
    if cfg["num_fetch_stages"] > 1:
        pieces.append(f"fsf{cfg['first_stage_fetch_sms']}")
    return "_".join(pieces)


def validate_config(cfg, M, N, K, world_size=8):
    """Return a list of error strings; empty list means valid.

    Checks divisibility of the problem shape by the tile sizes (including the
    per-rank K shard) plus the k_per_flag and num_warps constraints.
    """
    problems = []
    K_local = K // world_size
    bm = cfg["block_size_m"]
    bn = cfg["block_size_n"]
    bk = cfg["block_size_k"]
    kpf = cfg["k_per_flag"]

    # (value, divisor, message) triples — order matches report order.
    divisibility_checks = (
        (M, bm, f"M={M} not divisible by block_size_m={bm}"),
        (N, bn, f"N={N} not divisible by block_size_n={bn}"),
        (K, bk, f"K={K} not divisible by block_size_k={bk}"),
        (K_local, bk, f"K_local={K_local} not divisible by block_size_k={bk}"),
    )
    for value, divisor, message in divisibility_checks:
        if value % divisor != 0:
            problems.append(message)

    num_k_blocks = K // bk
    if num_k_blocks % kpf != 0:
        problems.append(f"num_k_blocks={num_k_blocks} not divisible by k_per_flag={kpf}")

    if cfg["num_warps"] not in (1, 2, 4, 8, 16):
        problems.append(f"num_warps={cfg['num_warps']} must be a power of 2 in [1..16]")

    return problems


def build_command(cfg, M, N, K, trace_path, nproc=8, validate=True, benchmark=True, benchmark_pytorch=False):
    """Build the ``torchrun`` CLI for one configuration."""
    cmd = [
        "torchrun",
        "--nproc_per_node",
        str(nproc),
        "benchmark/ops/all_gather_matmul/benchmark_hbm_buffer.py",
        "-m",
        str(M),
        "-n",
        str(N),
        "-k",
        str(K),
    ]

    # Tunable kernel parameters, forwarded verbatim as --<name> <value>.
    for name in (
        "block_size_m",
        "block_size_n",
        "block_size_k",
        "group_size_m",
        "num_fetch_sms",
        "k_per_flag",
        "num_warps",
        "num_fetch_stages",
    ):
        cmd.extend([f"--{name}", str(cfg[name])])

    # first_stage_fetch_sms is only meaningful with multi-stage fetch.
    if cfg["num_fetch_stages"] > 1 and cfg.get("first_stage_fetch_sms") is not None:
        cmd.extend(["--first_stage_fetch_sms", str(cfg["first_stage_fetch_sms"])])

    if validate:
        cmd.append("-v")
    if benchmark:
        cmd.append("-b")
    if benchmark_pytorch:
        cmd.append("--benchmark_pytorch")

    cmd.extend(["--trace", "--trace_output", trace_path])
    return cmd


# ── Output parsing ────────────────────────────────────────────────────────────

_RE_IRIS = re.compile(r"HBM-Buffer\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s")
_RE_PYTORCH = re.compile(r"PyTorch\s*\([^)]*\):\s*([\d.]+)\s*ms,\s*([\d.]+)\s*TFLOPS,\s*([\d.]+)\s*GB/s")
_RE_SPEEDUP = re.compile(r"Speedup.*?:\s*([\d.]+)x")
_RE_VALID_FAIL = re.compile(r"Validation FAILED.*?max diff:\s*([\d.eE+-]+)")


def parse_output(output):
    """Extract metrics from benchmark stdout+stderr.

    Returns a dict with iris_*/pytorch_* timing fields, a validation verdict,
    and the reported speedup; fields absent from the output stay ``None``.
    """
    result = dict.fromkeys(
        (
            "iris_ms",
            "iris_tflops",
            "iris_bw_gbps",
            "pytorch_ms",
            "pytorch_tflops",
            "pytorch_bw_gbps",
            "validation",
            "speedup",
        )
    )

    # Both timing lines share the "ms, TFLOPS, GB/s" shape; fill whichever match.
    for regex, keys in (
        (_RE_IRIS, ("iris_ms", "iris_tflops", "iris_bw_gbps")),
        (_RE_PYTORCH, ("pytorch_ms", "pytorch_tflops", "pytorch_bw_gbps")),
    ):
        match = regex.search(output)
        if match:
            for key, text in zip(keys, match.groups()):
                result[key] = float(text)

    if "Validation PASSED" in output:
        result["validation"] = "PASSED"
    elif "Validation FAILED" in output:
        fail = _RE_VALID_FAIL.search(output)
        result["validation"] = f"FAILED (diff={fail.group(1)})" if fail else "FAILED"

    match = _RE_SPEEDUP.search(output)
    if match:
        result["speedup"] = float(match.group(1))

    return result


# ── Sweep generation ──────────────────────────────────────────────────────────


def generate_configs(baseline, sweep_ranges, mode="oneatatime", params=None):
    """
    Generate the list of configs to evaluate.

    Args:
        baseline: dict of default values
        sweep_ranges: dict mapping param name -> list of values
        mode: "oneatatime" or "full"
        params: optional list of param names to sweep (None = all)
    """
    configs = []
    seen_labels = set()

    def _register(candidate):
        # Deduplicate by label so the baseline isn't repeated by the sweep.
        label = make_label(candidate)
        if label not in seen_labels:
            seen_labels.add(label)
            configs.append(dict(candidate))

    def _normalize(candidate):
        # With a single fetch stage, first_stage_fetch_sms is irrelevant;
        # pin it to num_fetch_sms so equivalent configs collapse together.
        if candidate["num_fetch_stages"] == 1:
            candidate["first_stage_fetch_sms"] = candidate["num_fetch_sms"]
        return candidate

    # The baseline is always evaluated first.
    _register(baseline)

    active = list(params) if params else list(sweep_ranges.keys())

    if mode == "oneatatime":
        for name in active:
            if name not in sweep_ranges:
                print(f" WARNING: unknown param '{name}', skipping")
                continue
            for value in sweep_ranges[name]:
                candidate = dict(baseline)
                candidate[name] = value
                _register(_normalize(candidate))

    elif mode == "full":
        swept = [(name, sweep_ranges[name]) for name in active if name in sweep_ranges]
        for combo in product(*(values for _, values in swept)):
            candidate = dict(baseline)
            candidate.update(zip((name for name, _ in swept), combo))
            _register(_normalize(candidate))

    return configs
────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Parameter tuning for HBM-buffered all_gather_matmul.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + # ── Matrix dimensions ──────────────────────────────────────────────── + parser.add_argument("-m", type=int, default=16384, help="M dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=131072, help="K dimension (total)") + parser.add_argument("--nproc", type=int, default=8, help="Number of GPUs") + + # ── Baseline overrides (non-swept params use these values) ──────── + parser.add_argument( + "--block_size_m", type=int, default=None, help=f"Baseline block_size_m (default: {BASELINE['block_size_m']})" + ) + parser.add_argument( + "--block_size_n", type=int, default=None, help=f"Baseline block_size_n (default: {BASELINE['block_size_n']})" + ) + parser.add_argument( + "--block_size_k", type=int, default=None, help=f"Baseline block_size_k (default: {BASELINE['block_size_k']})" + ) + parser.add_argument( + "--group_size_m", type=int, default=None, help=f"Baseline group_size_m (default: {BASELINE['group_size_m']})" + ) + parser.add_argument( + "--num_fetch_sms", type=int, default=None, help=f"Baseline num_fetch_sms (default: {BASELINE['num_fetch_sms']})" + ) + parser.add_argument( + "--k_per_flag", type=int, default=None, help=f"Baseline k_per_flag (default: {BASELINE['k_per_flag']})" + ) + parser.add_argument( + "--num_warps", type=int, default=None, help=f"Baseline num_warps (default: {BASELINE['num_warps']})" + ) + parser.add_argument( + "--num_fetch_stages", + type=int, + default=None, + help=f"Baseline num_fetch_stages (default: {BASELINE['num_fetch_stages']})", + ) + parser.add_argument( + "--first_stage_fetch_sms", + type=int, + default=None, + help=f"Baseline first_stage_fetch_sms (default: {BASELINE['first_stage_fetch_sms']})", + ) + 
+ # ── Sweep control ───────────────────────────────────────────────── + parser.add_argument( + "--mode", + choices=["oneatatime", "full"], + default="oneatatime", + help="'oneatatime' varies one param at a time; 'full' = cartesian product", + ) + parser.add_argument( + "--params", + nargs="+", + default=None, + help="Only sweep these parameters (default: all). Choices: " + ", ".join(SWEEP_RANGES.keys()), + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory (auto-generated if unset)") + parser.add_argument("--dry_run", action="store_true", help="Print configs and exit without running") + parser.add_argument("--skip_validation", action="store_true", help="Skip validation (faster, no correctness check)") + parser.add_argument("--timeout", type=int, default=600, help="Per-config timeout in seconds (default: 600)") + + args = parser.parse_args() + M, N, K = args.m, args.n, args.k + + # Apply any CLI baseline overrides + baseline = dict(BASELINE) + for key in baseline: + cli_val = getattr(args, key, None) + if cli_val is not None: + baseline[key] = cli_val + + # Output directory + if args.output_dir: + output_dir = Path(args.output_dir) + else: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = Path(f"benchmark/ops/all_gather_matmul/tune_results_{ts}") + output_dir.mkdir(parents=True, exist_ok=True) + trace_dir = output_dir / "traces" + trace_dir.mkdir(exist_ok=True) + + # Generate configs + configs = generate_configs(baseline, SWEEP_RANGES, mode=args.mode, params=args.params) + + # Pre-validate all configs + valid_configs = [] + skipped = [] + for cfg in configs: + errs = validate_config(cfg, M, N, K, world_size=args.nproc) + if errs: + skipped.append((cfg, errs)) + else: + valid_configs.append(cfg) + + # Banner + print(f"\n{'=' * 100}") + print(" HBM-Buffer All-Gather MatMul — Parameter Tuning") + print(f" M={M} N={N} K={K} nproc={args.nproc} mode={args.mode}") + print(f" Baseline: {make_label(baseline)}") + print(f" 
Configs to run: {len(valid_configs)} (skipped: {len(skipped)})") + print(f" Output dir: {output_dir}") + print(f" Validation: {'OFF' if args.skip_validation else 'ON'}") + print(f"{'=' * 100}") + + if skipped: + print(f"\n Skipped (invalid for M={M}, N={N}, K={K}):") + for cfg, errs in skipped: + print(f" {make_label(cfg)}: {'; '.join(errs)}") + + if args.dry_run: + print("\n Configs that would be run:") + for i, cfg in enumerate(valid_configs): + label = make_label(cfg) + is_baseline = cfg == baseline + tag = " [BASELINE]" if is_baseline else "" + print(f" [{i + 1:>3}] {label}{tag}") + print(f"\n Total: {len(valid_configs)} configs") + return + + # ── Run sweep ───────────────────────────────────────────────────────── + results = [] + pytorch_baseline = None + env = os.environ.copy() + env["HSA_NO_SCRATCH_RECLAIM"] = "1" + + total_start = time.time() + + for i, cfg in enumerate(valid_configs): + label = make_label(cfg) + trace_path = str(trace_dir / f"trace_{label}.png") + is_first = i == 0 + + sep = "-" * 80 + print(f"\n{sep}") + print(f"[{i + 1}/{len(valid_configs)}] {label}") + if is_first: + print(" (includes PyTorch baseline benchmark)") + print(sep) + + cmd = build_command( + cfg, + M, + N, + K, + trace_path, + nproc=args.nproc, + validate=not args.skip_validation, + benchmark=True, + benchmark_pytorch=is_first, + ) + cmd_str = " ".join(cmd) + print(f" $ HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}") + + t0 = time.time() + try: + proc = subprocess.run( + cmd, + env=env, + capture_output=True, + text=True, + timeout=args.timeout, + ) + elapsed = time.time() - t0 + full_output = proc.stdout + "\n" + proc.stderr + + parsed = parse_output(full_output) + + # Capture PyTorch baseline on first run + if is_first and parsed["pytorch_tflops"] is not None: + pytorch_baseline = { + "ms": parsed["pytorch_ms"], + "tflops": parsed["pytorch_tflops"], + "bw_gbps": parsed["pytorch_bw_gbps"], + } + + trace_exists = os.path.exists(trace_path) + results.append( + { + "label": label, + 
"config": cfg, + "iris_ms": parsed["iris_ms"], + "iris_tflops": parsed["iris_tflops"], + "iris_bw_gbps": parsed["iris_bw_gbps"], + "validation": parsed["validation"], + "trace_path": trace_path if trace_exists else None, + "elapsed_s": round(elapsed, 1), + "returncode": proc.returncode, + } + ) + + # Print summary line + parts = [] + if parsed["iris_tflops"] is not None: + parts.append(f"{parsed['iris_tflops']:.2f} TFLOPS") + parts.append(f"{parsed['iris_ms']:.3f} ms") + if parsed["iris_bw_gbps"] is not None: + parts.append(f"{parsed['iris_bw_gbps']:.1f} GB/s") + if parsed["validation"]: + parts.append(f"valid={parsed['validation']}") + if trace_exists: + parts.append("trace=OK") + else: + parts.append("trace=MISSING") + if proc.returncode != 0: + parts.append(f"EXIT={proc.returncode}") + print(f" => {' | '.join(parts)} ({elapsed:.0f}s)") + + if is_first and pytorch_baseline: + print( + f" => PyTorch baseline: {pytorch_baseline['tflops']:.2f} TFLOPS {pytorch_baseline['ms']:.3f} ms" + ) + + # Save full log for debugging + log_path = output_dir / f"log_{label}.txt" + with open(log_path, "w") as f: + f.write(f"COMMAND: HSA_NO_SCRATCH_RECLAIM=1 {cmd_str}\n") + f.write(f"EXIT CODE: {proc.returncode}\n") + f.write(f"ELAPSED: {elapsed:.1f}s\n\n") + f.write("=== STDOUT ===\n") + f.write(proc.stdout) + f.write("\n=== STDERR ===\n") + f.write(proc.stderr) + + except subprocess.TimeoutExpired: + elapsed = time.time() - t0 + results.append( + { + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": "TIMEOUT", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": -1, + } + ) + print(f" => TIMEOUT after {args.timeout}s") + + except Exception as e: + elapsed = time.time() - t0 + results.append( + { + "label": label, + "config": cfg, + "iris_ms": None, + "iris_tflops": None, + "iris_bw_gbps": None, + "validation": f"ERROR: {e}", + "trace_path": None, + "elapsed_s": round(elapsed, 1), + "returncode": 
-1, + } + ) + print(f" => ERROR: {e}") + + total_elapsed = time.time() - total_start + + # ── Summary table ───────────────────────────────────────────────────── + W = 130 + print(f"\n\n{'=' * W}") + print( + f" TUNING RESULTS | M={M} N={N} K={K} | nproc={args.nproc} | " + f"{len(valid_configs)} configs in {total_elapsed:.0f}s" + ) + if pytorch_baseline: + print( + f" PyTorch baseline: {pytorch_baseline['ms']:.3f} ms | " + f"{pytorch_baseline['tflops']:.2f} TFLOPS | " + f"{pytorch_baseline['bw_gbps']:.1f} GB/s" + ) + print(f"{'=' * W}") + + col_label_w = 65 + print( + f" {'#':>3} {'Configuration':<{col_label_w}} {'ms':>8} {'TFLOPS':>8} " + f"{'vs PT':>7} {'Valid':>8} {'Trace':>5}" + ) + print(f" {'-' * (W - 4)}") + + for i, r in enumerate(results): + ms_s = f"{r['iris_ms']:.3f}" if r["iris_ms"] is not None else "--" + tf_s = f"{r['iris_tflops']:.2f}" if r["iris_tflops"] is not None else "--" + + if pytorch_baseline and r["iris_tflops"] is not None and pytorch_baseline["tflops"] > 0: + vs_pt = f"{r['iris_tflops'] / pytorch_baseline['tflops']:.2f}x" + else: + vs_pt = "--" + + valid_s = (r["validation"] or "--")[:8] + trace_s = "Y" if r.get("trace_path") else "N" + + tag = ( + " *" + if ( + r["iris_tflops"] is not None + and r["iris_tflops"] + == max((x["iris_tflops"] for x in results if x["iris_tflops"] is not None), default=0) + ) + else "" + ) + + print( + f" {i + 1:>3} {r['label']:<{col_label_w}} {ms_s:>8} {tf_s:>8} " + f"{vs_pt:>7} {valid_s:>8} {trace_s:>5}{tag}" + ) + + # Best config + valid_results = [r for r in results if r["iris_tflops"] is not None] + if valid_results: + best = max(valid_results, key=lambda r: r["iris_tflops"]) + worst = min(valid_results, key=lambda r: r["iris_tflops"]) + print(f"\n {'BEST':>6}: {best['label']}") + print(f" {best['iris_ms']:.3f} ms | {best['iris_tflops']:.2f} TFLOPS | valid={best['validation']}") + if pytorch_baseline and pytorch_baseline["tflops"] > 0: + print(f" {best['iris_tflops'] / pytorch_baseline['tflops']:.2f}x vs 
PyTorch") + if best.get("trace_path"): + print(f" trace: {best['trace_path']}") + print(f" {'WORST':>6}: {worst['label']}") + print(f" {worst['iris_ms']:.3f} ms | {worst['iris_tflops']:.2f} TFLOPS") + if best["iris_tflops"] > 0 and worst["iris_tflops"] > 0: + print( + f" SPREAD: {best['iris_tflops'] / worst['iris_tflops']:.2f}x " + f"({worst['iris_tflops']:.2f} → {best['iris_tflops']:.2f} TFLOPS)" + ) + + print(f"{'=' * W}") + + # ── Save results JSON ───────────────────────────────────────────────── + results_path = output_dir / "results.json" + with open(results_path, "w") as f: + json.dump( + { + "meta": { + "M": M, + "N": N, + "K": K, + "nproc": args.nproc, + "mode": args.mode, + "baseline": baseline, + "sweep_ranges": SWEEP_RANGES, + "timestamp": datetime.now().isoformat(), + "total_elapsed_s": round(total_elapsed, 1), + "pytorch_baseline": pytorch_baseline, + }, + "results": results, + }, + f, + indent=2, + default=str, + ) + + print(f"\n Results JSON : {results_path}") + print(f" Trace PNGs : {trace_dir}/") + print(f" Per-run logs : {output_dir}/log_*.txt") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/benchmark.py b/benchmark/ops/matmul_all_gather/benchmark.py new file mode 100644 index 00000000..22c914e8 --- /dev/null +++ b/benchmark/ops/matmul_all_gather/benchmark.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_gather fused operation. + +This benchmark showcases the fused GEMM + All-Gather operation where each rank +computes a local matmul and then gathers results along M dimension. 
+""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_gather fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows per rank in matrix A (M_local)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_gather.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_gather_into_tensor) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for 
M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29529", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M_local = args["m"] # Local M dimension + M = M_local * world_size # Total M after gather + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_gather") + json_writer.add_field("m_local", M_local) + json_writer.add_field("m_total", M) + + for key, value in args.items(): + 
json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_local is M_local x K, output is M x N (gathered) + A_local = shmem.zeros((M_local, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A_local, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M_local, K), dtype=datatype, device=f"cuda:{rank}") + A_local.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_local matrices and compute expected result + A_local_list = [torch.zeros((M_local, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_local_list, A_local_data) + + # Expected: [A_0 @ B; A_1 @ B; ...; A_n @ B] stacked along M + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_parts = [] + for i, A_rank_local in enumerate(A_local_list): + C_rank_local = torch.matmul(A_rank_local, B_data) + expected_parts.append(C_rank_local) + expected_result = torch.cat(expected_parts, dim=0) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def 
run_experiment(): + nonlocal kernel_timing, workspace + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Gather") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_gather"]["start_event"].record() + shmem.ops.matmul_all_gather( + C, + A_local, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_gather"]["end_event"].record() + kernel_timing["matmul_all_gather"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_gather"]["start_event"].elapsed_time( + kernel_timing["matmul_all_gather"]["end_event"] + ) + kernel_timing["matmul_all_gather"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-gather validation passed!") + else: + shmem.error("Matmul-all-gather validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M_local*N*K flops per rank (but total is same across all 
ranks) + total_flops = 2 * M_local * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M_local * N * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M_local * N * element_size + total_bytes = output_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-gather (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_gather_ms", + kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"], + ) + json_writer.add_field("matmul_all_gather_experiments", kernel_timing["matmul_all_gather"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_gather_into_tensor) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_gather_into_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_local = torch.randn(M_local, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C_local = 
torch.zeros(M_local, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_gather_into_tensor (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_reduce/benchmark.py b/benchmark/ops/matmul_all_reduce/benchmark.py new file mode 100644 index 
00000000..fd923e05 --- /dev/null +++ b/benchmark/ops/matmul_all_reduce/benchmark.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_reduce fused operation. + +This benchmark showcases the fused GEMM + All-Reduce operation and reports +achieved TFLOPS and communication bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_reduce fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_reduce.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) 
for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--all_reduce_variant", + type=str, + default="two_shot", + choices=["atomic", "ring", "two_shot", "one_shot", "spinlock"], + help="All-reduce variant to use", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29528", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], 
+ "all_reduce_variant": args["all_reduce_variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_reduce") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("all_reduce_variant", config.all_reduce_variant) + + # Create input and output tensors + # Must use shmem.zeros() to allocate on Iris symmetric heap + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + # Reference: each rank computes local C = A @ B, then all_reduce + if args["validate"]: + expected_tensor = shmem.zeros((M, N), dtype=datatype) + C_local_ref = torch.matmul(A_local_data, B_data) + pytorch_output = C_local_ref.clone() + shmem.barrier() + dist.all_reduce(pytorch_output, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + 
expected_tensor.copy_(pytorch_output) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_reduce": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_all_reduce_preamble"): + workspace = shmem.ops.matmul_all_reduce_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Reduce") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_reduce"]["start_event"].record() + shmem.ops.matmul_all_reduce( + C, + A, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_reduce"]["end_event"].record() + kernel_timing["matmul_all_reduce"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_reduce"]["start_event"].elapsed_time( + kernel_timing["matmul_all_reduce"]["end_event"] + ) + kernel_timing["matmul_all_reduce"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 0.2 if datatype == torch.float16 else 0.3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-reduce validation passed!") + else: + shmem.error("Matmul-all-reduce validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # 
Warmup for benchmarking + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-reduce part + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}, variant={args['all_reduce_variant']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_reduce_ms", + kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"], + ) + json_writer.add_field("matmul_all_reduce_experiments", kernel_timing["matmul_all_reduce"]["experiments"]) + + 
# Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + 
mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_reduce_scatter/benchmark.py b/benchmark/ops/matmul_reduce_scatter/benchmark.py new file mode 100644 index 00000000..301444f2 --- /dev/null +++ b/benchmark/ops/matmul_reduce_scatter/benchmark.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_reduce_scatter fused operation. + +This benchmark showcases the fused GEMM + Reduce-Scatter operation where each rank +computes a local matmul, reduces across all ranks, and scatters tiles to ranks. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_reduce_scatter fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_reduce_scatter.json", + help="Output 
file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29531", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + 
"block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_reduce_scatter") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Calculate tile distribution + num_pid_m = (M + config.block_size_m - 1) // config.block_size_m + num_pid_n = (N + config.block_size_n - 1) // config.block_size_n + total_tiles = num_pid_m * num_pid_n + tiles_per_rank = total_tiles // world_size + start_tile = rank * tiles_per_rank + if rank == world_size - 1: + tiles_per_rank = total_tiles - start_tile + + json_writer.add_field("total_tiles", total_tiles) + json_writer.add_field("tiles_per_rank", tiles_per_rank) + + # Create input and output tensors + # Each rank computes full A @ B, but only keeps its assigned tiles + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tiles = [] + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all 
ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result for this rank's tiles + if args["validate"]: + # Gather all A matrices to compute expected result + A_list = [torch.zeros((M, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_list, A_local_data) + + # Expected: sum of all (A_i @ B) for each rank i, but only for this rank's tiles + expected_full = torch.zeros((M, N), dtype=datatype, device=f"cuda:{rank}") + for A_rank in A_list: + expected_full += torch.matmul(A_rank, B_data) + + # Extract only this rank's tiles + for local_tile_idx in range(tiles_per_rank): + tile_id = start_tile + local_tile_idx + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + m_start = pid_m * config.block_size_m + m_end = min(m_start + config.block_size_m, M) + n_start = pid_n * config.block_size_n + n_end = min(n_start + config.block_size_n, N) + + expected_tiles.append( + { + "tile_id": tile_id, + "pid_m": pid_m, + "pid_n": pid_n, + "m_start": m_start, + "m_end": m_end, + "n_start": n_start, + "n_end": n_end, + "data": expected_full[m_start:m_end, n_start:n_end].clone(), + } + ) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_reduce_scatter": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_reduce_scatter_preamble"): + workspace = shmem.ops.matmul_reduce_scatter_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-Reduce-Scatter") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_reduce_scatter"]["start_event"].record() + shmem.ops.matmul_reduce_scatter( + C, + A, + B, + async_op=False, + 
config=config, + workspace=workspace, + ) + kernel_timing["matmul_reduce_scatter"]["end_event"].record() + kernel_timing["matmul_reduce_scatter"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_reduce_scatter"]["start_event"].elapsed_time( + kernel_timing["matmul_reduce_scatter"]["end_event"] + ) + kernel_timing["matmul_reduce_scatter"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 2e-1 if datatype == torch.float16 else 1e-1 + success = True + + # Validate each tile assigned to this rank + for tile_info in expected_tiles: + C_tile = C[tile_info["m_start"] : tile_info["m_end"], tile_info["n_start"] : tile_info["n_end"]] + expected_tile = tile_info["data"] + + tile_match = torch.allclose(C_tile, expected_tile, atol=atol) + if not tile_match: + max_diff = torch.abs(C_tile - expected_tile).max().item() + shmem.error( + f"Rank {rank}, tile {tile_info['tile_id']} ({tile_info['pid_m']},{tile_info['pid_n']}): " + f"Validation failed, max diff: {max_diff}" + ) + success = False + + if success: + shmem.info("Matmul-reduce-scatter validation passed!") + else: + shmem.error("Matmul-reduce-scatter validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + 
shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + # Calculate bandwidth for reduce-scatter part + # Similar to all-reduce: 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + shmem.info( + f"Matmul-reduce-scatter (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_reduce_scatter_ms", + kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"], + ) + json_writer.add_field( + "matmul_reduce_scatter_experiments", kernel_timing["matmul_reduce_scatter"]["experiments"] + ) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + # Note: We use all_reduce since PyTorch's reduce_scatter has different semantics + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A 
= torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/iris/iris.py b/iris/iris.py index f0effbb2..8242fa18 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1517,6 +1517,7 @@ def 
store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: None @@ -1542,6 +1543,7 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None) to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. Returns: None @@ -1568,6 +1570,7 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. 
Returns: None diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba5..c96fa32e 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -141,17 +141,16 @@ def matmul_all_gather(self, output_tensor, A, B, bias=None, async_op=False, conf """ return matmul_all_gather(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) - def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None): + def matmul_reduce_scatter(self, output_tensor, A, B, async_op=False, config=None, workspace=None): """ Fused matrix multiplication and reduce-scatter. - Computes: output = reduce_scatter(A @ B + bias) along N dimension + Computes: output = reduce_scatter(A @ B) where each rank keeps assigned tiles Args: - output_tensor: Output tensor (M, N_local) where N_local = N / world_size + output_tensor: Output tensor (M, N) - will contain reduced tiles for this rank A: Input matrix A (M, K) B: Input matrix B (K, N) - bias: Optional bias vector (M,) or (N,) async_op: If False, performs barrier at end config: Optional FusedConfig for tuning workspace: Optional pre-allocated workspace @@ -160,11 +159,10 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, workspace: Updated workspace object Example: - >>> N_local = N // world_size - >>> output = shmem.zeros((M, N_local), dtype=torch.float16) + >>> output = shmem.zeros((M, N), dtype=torch.float16) >>> shmem.ops.matmul_reduce_scatter(output, A, B) """ - return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) + return matmul_reduce_scatter(self._shmem, output_tensor, A, B, async_op, config, workspace) # Export public API @@ -175,7 +173,6 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, # Namespace "OpsNamespace", # Operations - "matmul", # Simple single-GPU GEMM "matmul_all_reduce", "matmul_all_reduce_preamble", "all_gather_matmul", diff --git 
a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 5d700206..4f272825 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -164,7 +164,7 @@ def all_gather_matmul_preamble( B: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (none needed for pull pattern).""" + """Allocate workspace for all_gather_matmul.""" if config is None: config = FusedConfig() @@ -175,14 +175,25 @@ def all_gather_matmul_preamble( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - return FusedWorkspace( + ws = FusedWorkspace( operation="all_gather_matmul", shape=(M, N, K), dtype=A_sharded.dtype, world_size=world_size, + variant=config.all_gather_matmul_variant, prepared=True, ) + # Allocate push variant workspace + if config.all_gather_matmul_variant == "push": + num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m + num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k + ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) + ws.signal_flags = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) + shmem.barrier() + + return ws + def all_gather_matmul( shmem, @@ -208,17 +219,6 @@ def all_gather_matmul( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M >= config.block_size_m, ( - f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." - ) - assert K_local >= config.block_size_k, ( - f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " - f"Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." 
- ) - if workspace is None: workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) @@ -245,38 +245,44 @@ def all_gather_matmul( even_k = K_local % config.block_size_k == 0 num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k - # Launch single fused kernel - grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid]( - A_sharded, - B, - output_tensor, - bias_ptr, - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.get_device_context(), - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - config.num_xcds, - num_k_blocks_local, - use_bias, - even_k, - config.allow_tf32, - ) + variant = config.all_gather_matmul_variant + + if variant == "pull": + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + # grid = (num_tiles,) + grid = (num_sms,) + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + num_k_blocks_local, + use_bias, + even_k, + config.allow_tf32, + ) if not async_op: shmem.barrier() diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py new file mode 100644 index 00000000..2db1b6ed --- /dev/null +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -0,0 +1,520 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ +""" +Fused All-Gather + GEMM using a local HBM staging buffer with dedicated +fetcher and GEMM workgroups, launched data-parallel. + +Supports configurable staged_a buffer layout (M-contiguous or K-contiguous) +and B layout to match optimal tritonblas conventions (TN, TT, NT, NN). +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from iris.device_utils import read_realtime +from iris.tracing.events import TraceEvent +from .config import FusedConfig +from .workspace import FusedWorkspace + + +@triton.jit +def _hbm_buffer_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + staged_a, + flags_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, # staged_a stride in M dim + stride_sa_k, # staged_a stride in K dim + stride_bias, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_FETCH_SMS: tl.constexpr, + NUM_M_TILES: tl.constexpr, + NUM_TILES_N: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, + NUM_K_BLOCKS_LOCAL: tl.constexpr, + K_PER_FLAG: tl.constexpr, + NUM_FLAG_GROUPS_K: tl.constexpr, + TOTAL_GATHER_TILES: tl.constexpr, + BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + NUM_FETCH_STAGES: tl.constexpr, + GEMM_TILES_PER_STAGE: tl.constexpr, + FIRST_STAGE_FETCH_SMS: tl.constexpr, + TRACE: tl.constexpr, +): + pid = tl.program_id(0) + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + zero = tl.program_id(0) * 0 + + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACE) + + # Interleaved layout with asymmetric first stage: + # [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... 
+ # P = FIRST_STAGE_FETCH_SMS, F = NUM_FETCH_SMS, G = GEMM_TILES_PER_STAGE + FIRST_STAGE_SIZE: tl.constexpr = FIRST_STAGE_FETCH_SMS + GEMM_TILES_PER_STAGE + REST_STAGE_SIZE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE + M_PER_STAGE: tl.constexpr = (NUM_M_TILES + NUM_FETCH_STAGES - 1) // NUM_FETCH_STAGES + + # Two-phase decode: stage 0 has a different size than subsequent stages + if pid < FIRST_STAGE_SIZE: + my_stage = zero + local_pid = pid + fetch_threshold = zero + FIRST_STAGE_FETCH_SMS + else: + adjusted = pid - FIRST_STAGE_SIZE + my_stage = 1 + adjusted // REST_STAGE_SIZE + local_pid = adjusted % REST_STAGE_SIZE + fetch_threshold = zero + NUM_FETCH_SMS + + if local_pid < fetch_threshold: + # ============================================================== + # FETCHER — stage 0 uses FIRST_STAGE_FETCH_SMS WGs, + # later stages use NUM_FETCH_SMS WGs + # ============================================================== + stage_pid = local_pid + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_fetch, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) + + tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M + + for const_stage in range(NUM_FETCH_STAGES): + if my_stage == const_stage: + stage_fetch_sms = FIRST_STAGE_FETCH_SMS if const_stage == 0 else NUM_FETCH_SMS + stage_m_start = const_stage * M_PER_STAGE + stage_m_count = min(M_PER_STAGE, NUM_M_TILES - stage_m_start) + total_fg_stage = NUM_FLAG_GROUPS_K * stage_m_count + + for fg_idx in range(stage_pid, total_fg_stage, stage_fetch_sms): + m_group = fg_idx // tiles_per_m_group + within_group = fg_idx % tiles_per_m_group + k_flag_group = within_group // GROUP_SIZE_M + m_in_group = within_group % GROUP_SIZE_M + m_tile = stage_m_start + m_group * GROUP_SIZE_M + m_in_group + m_tile = min(m_tile, NUM_M_TILES - 1) + k_block_start = k_flag_group * 
K_PER_FLAG + + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + + for k_off in range(K_PER_FLAG): + k_block_global = k_block_start + k_off + + src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL + k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL + + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K)) + tl.store(staged_ptrs, a_tile, cache_modifier=".cg") + + flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + else: + # ============================================================== + # GEMM — gemm_local_id indexes into this stage's M-tile range + # ============================================================== + gemm_local_id = local_pid - fetch_threshold + stage_m_start = my_stage * M_PER_STAGE + + num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N + group_id = gemm_local_id // num_pid_in_group + first_pid_m = stage_m_start + group_id * GROUP_SIZE_M + first_pid_m = min(first_pid_m, NUM_M_TILES - 1) + group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((gemm_local_id % num_pid_in_group) % group_sz) + pid_n = (gemm_local_id % num_pid_in_group) // group_sz + pid_m = min(pid_m, NUM_M_TILES - 1) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + 
tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wg_gemm, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + _wt = zero.to(tl.int64) + + for k_fg in range(NUM_FLAG_GROUPS_K): + if TRACE: + _ws = read_realtime() + + flag_idx = pid_m * NUM_FLAG_GROUPS_K + k_fg + while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") == 0: + pass + + if TRACE: + _wt = _wt + (read_realtime() - _ws) + + k_block_base = k_fg * K_PER_FLAG + for k_off in range(K_PER_FLAG): + k_block = k_block_base + k_off + rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + + a_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + a = tl.load(a_ptrs) + + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + if BIAS: + bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_val[:, None] + + c = acc.to(C.type.element_ty) + C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + tl.store(C_ptrs, c, mask=c_mask, cache_modifier=".wt") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + ctx.tracing.record_event_start( + event_id=TraceEvent().wg_gemm_wait, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=_wt.to(tl.int32), + ) + + +# ========================================================================== +# Python API +# ========================================================================== + + +def all_gather_matmul_hbm_buffer_preamble( + shmem, + 
A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, + k_per_flag: int = 1, + staged_a_layout: str = "k_contiguous", +) -> FusedWorkspace: + """ + Allocate workspace. + + Args: + staged_a_layout: "k_contiguous" (default, row-major (M,K)) or + "m_contiguous" (col-major, stored as (K,M) transposed). + """ + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + + assert world_size * K_local == K + assert K_local % config.block_size_k == 0 + assert K % config.block_size_k == 0 + assert M % config.block_size_m == 0 + + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + num_flag_groups_k = num_k_blocks // k_per_flag + + ws = FusedWorkspace( + operation="all_gather_matmul_hbm_buffer", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + variant=f"hbm_buffer_{staged_a_layout}", + prepared=True, + ) + + if staged_a_layout == "m_contiguous": + # Allocate (K, M) row-major, .T gives (M, K) with stride_m=1, stride_k=M + storage = shmem.zeros((K, M), dtype=A_sharded.dtype) + ws.aux_buffer = storage.T # (M, K) view, M-contiguous + else: + # Default: (M, K) row-major, stride_m=K, stride_k=1 + ws.aux_buffer = shmem.zeros((M, K), dtype=A_sharded.dtype) + + ws.locks = shmem.zeros((num_m_tiles * num_flag_groups_k,), dtype=torch.int32) + + buffer_mb = M * K * A_sharded.element_size() / (1024**2) + sa_stride_m, sa_stride_k = ws.aux_buffer.stride() + shmem.info( + f"HBM buffer: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " + f"layout={staged_a_layout} strides=({sa_stride_m},{sa_stride_k}), " + f"flags={num_m_tiles}x{num_flag_groups_k}, k_per_flag={k_per_flag}" + ) + + shmem.barrier() + return ws + + +_WG_FETCH = 14 +_WG_GEMM = 15 +_WG_GEMM_WAIT = 16 + + +def _extract_wg_trace(shmem, grid_size, **metadata): + """Reconstruct per-workgroup trace arrays from DeviceTracing events.""" + import numpy 
as np + + bufs = shmem.tracing.trace_buffers + n = min(shmem.tracing.trace_counter.item(), shmem.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pids = bufs["pid"][:n].cpu().numpy() + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + end_ts = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + pid_ns = bufs["pid_n"][:n].cpu().numpy() + + starts = torch.zeros(grid_size, dtype=torch.int64) + ends = torch.zeros(grid_size, dtype=torch.int64) + waits = torch.zeros(grid_size, dtype=torch.int64) + xcds = torch.zeros(grid_size, dtype=torch.int32) + + for i in range(n): + eid = int(event_ids[i]) + wg = int(pids[i]) + if wg >= grid_size: + continue + if eid == _WG_FETCH or eid == _WG_GEMM: + starts[wg] = int(timestamps[i]) + ends[wg] = int(end_ts[i]) + xcds[wg] = int(xcc_ids[i]) + elif eid == _WG_GEMM_WAIT: + waits[wg] = int(pid_ns[i]) + + return {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": grid_size, **metadata} + + +def all_gather_matmul_hbm_buffer( + shmem, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, + num_fetch_sms: Optional[int] = None, + k_per_flag: int = 1, + fetch_block_m: Optional[int] = None, + fetch_block_k: Optional[int] = None, + staged_a_layout: str = "k_contiguous", + num_warps: Optional[int] = None, + num_stages: Optional[int] = None, + num_fetch_stages: int = 1, + first_stage_fetch_sms: Optional[int] = None, + trace: bool = False, +) -> FusedWorkspace: + """ + All-gather + matmul with dedicated fetcher/GEMM workgroups. + + Args: + staged_a_layout: Buffer layout for gathered A. + "k_contiguous" — (M,K) row-major, K is fast dim. Matches NN convention. + "m_contiguous" — (M,K) with M as fast dim. Matches TN convention (best for tritonblas). 
+ """ + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + assert world_size * K_local == K + assert output_tensor.shape == (M, N) + assert M % config.block_size_m == 0 + assert K % config.block_size_k == 0 + assert K_local % config.block_size_k == 0 + + if fetch_block_m is None: + fetch_block_m = config.block_size_m + if fetch_block_k is None: + fetch_block_k = config.block_size_k + + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + + if workspace is None: + workspace = all_gather_matmul_hbm_buffer_preamble(shmem, A_sharded, B, config, k_per_flag, staged_a_layout) + + workspace.locks.zero_() + + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + stride_sa_m, stride_sa_k = workspace.aux_buffer.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + device = A_sharded.device + num_sms = config.num_sms + if num_sms is None: + props = torch.cuda.get_device_properties(device) + num_sms = props.multi_processor_count + + num_m_tiles = M // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + total_gemm_tiles = num_m_tiles * num_tiles_n + num_k_blocks_local = K_local // config.block_size_k + num_flag_groups_k = num_k_blocks // k_per_flag + total_gather_tiles = num_m_tiles * num_k_blocks + + if num_fetch_sms is None: + num_fetch_sms = max(1, num_sms // 10) + assert 0 < num_fetch_sms + assert num_fetch_stages >= 1 + + # First stage can use more fetcher WGs to fill the first GPU wave + if first_stage_fetch_sms is None: + first_stage_fetch_sms = num_fetch_sms + + # Interleaved layout: [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... 
+ m_per_stage = (num_m_tiles + num_fetch_stages - 1) // num_fetch_stages + gemm_tiles_per_stage = m_per_stage * num_tiles_n + first_stage_size = first_stage_fetch_sms + gemm_tiles_per_stage + rest_stage_size = num_fetch_sms + gemm_tiles_per_stage + total_fetch_wgs = first_stage_fetch_sms + num_fetch_sms * max(0, num_fetch_stages - 1) + grid_size = first_stage_size + rest_stage_size * max(0, num_fetch_stages - 1) + + if trace: + max_trace_events = grid_size * 4 + if not shmem.tracing.enabled: + shmem.tracing.enable(max_events=max_trace_events) + else: + shmem.tracing.reset() + + launch_kwargs = {"matrix_instr_nonkdim": 16} + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + + _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, + workspace.locks, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, + stride_sa_k, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_fetch_sms, + num_m_tiles, + num_tiles_n, + num_k_blocks, + num_k_blocks_local, + k_per_flag, + num_flag_groups_k, + total_gather_tiles, + use_bias, + config.allow_tf32, + num_fetch_stages, + gemm_tiles_per_stage, + first_stage_fetch_sms, + trace, + **launch_kwargs, + ) + + if not async_op: + shmem.barrier() + + if trace: + torch.cuda.synchronize() + workspace.trace_data = _extract_wg_trace( + shmem, + grid_size, + num_fetch_sms=num_fetch_sms, + num_fetch_stages=num_fetch_stages, + total_fetch_wgs=total_fetch_wgs, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + first_stage_fetch_sms=first_stage_fetch_sms, + first_stage_size=first_stage_size, + rest_stage_size=rest_stage_size, + gemm_tiles_per_stage=gemm_tiles_per_stage, + ) + + return workspace diff --git 
a/iris/ops/config.py b/iris/ops/config.py index 3ca085c3..c5d15349 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -19,10 +19,10 @@ class FusedConfig: but users can override specific settings for performance tuning. GEMM Parameters: - block_size_m: Block size for M dimension (rows). Default: 256. - block_size_n: Block size for N dimension (columns). Default: 64. + block_size_m: Block size for M dimension (rows). Default: 128. + block_size_n: Block size for N dimension (columns). Default: 256. block_size_k: Block size for K dimension (reduction). Default: 64. - group_size_m: Group size for M dimension tiling. Default: 1. + group_size_m: Group size for M dimension tiling. Default: 4. num_sms: Number of SMs to use. If None, auto-detects from device. Default: None. num_xcds: Number of XCDs (chiplets). Default: 1. chunk_size: Chunk size for chiplet transform. Default: 1. @@ -32,8 +32,11 @@ class FusedConfig: CCL Parameters (for operations that need collective communication): all_reduce_variant: All-reduce algorithm variant. Options: "atomic", "ring", - "one_shot", "two_shot", "spinlock". Default: "one_shot". + "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. + all_gather_matmul_variant: All-gather + matmul algorithm variant. Options: + "pull" (on-demand pull from remote ranks). + Default: "pull". 
Example: >>> # Use defaults @@ -47,12 +50,12 @@ class FusedConfig: """ # GEMM parameters - block_size_m: int = 256 - block_size_n: int = 64 + block_size_m: int = 128 + block_size_n: int = 256 block_size_k: int = 64 group_size_m: int = 1 num_sms: Optional[int] = None # Auto-detect if None - num_xcds: int = 1 + num_xcds: int = 8 chunk_size: int = 1 cache_modifier_a: str = ".ca" cache_modifier_b: str = ".ca" @@ -61,6 +64,7 @@ class FusedConfig: # CCL-specific parameters all_reduce_variant: str = "two_shot" # atomic, ring, one_shot, two_shot, spinlock all_reduce_num_rings: int = 1 + all_gather_matmul_variant: str = "pull" # pull, chunked def validate(self, world_size: Optional[int] = None): """ @@ -102,3 +106,10 @@ def validate(self, world_size: Optional[int] = None): if self.all_reduce_num_rings <= 0: raise ValueError(f"all_reduce_num_rings must be positive, got {self.all_reduce_num_rings}") + + # Validate all_gather_matmul_variant + valid_ag_variants = ["pull"] + if self.all_gather_matmul_variant not in valid_ag_variants: + raise ValueError( + f"all_gather_matmul_variant must be one of {valid_ag_variants}, got {self.all_gather_matmul_variant}" + ) diff --git a/iris/ops/matmul_all_gather.py b/iris/ops/matmul_all_gather.py index ad42ac04..6b19caea 100644 --- a/iris/ops/matmul_all_gather.py +++ b/iris/ops/matmul_all_gather.py @@ -180,17 +180,6 @@ def matmul_all_gather( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M_local >= config.block_size_m, ( - f"M_local ({M_local}) must be >= block_size_m ({config.block_size_m}). " - f"Use smaller block sizes for small problems." - ) - assert K >= config.block_size_k, ( - f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." 
- ) - # Allocate workspace if not provided if workspace is None: workspace = matmul_all_gather_preamble(shmem, A, B, config) diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 73bea92c..ceded705 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -272,11 +272,6 @@ def matmul_all_reduce( if A.dtype != B.dtype or A.dtype != C.dtype: raise ValueError(f"All tensors must have same dtype, got A:{A.dtype}, B:{B.dtype}, C:{C.dtype}") - # Validate block sizes match problem dimensions - assert M >= config.block_size_m, f"M={M} too small for block_size_m={config.block_size_m}" - assert K >= config.block_size_k, f"K={K} too small for block_size_k={config.block_size_k}" - assert N >= config.block_size_n, f"N={N} too small for block_size_n={config.block_size_n}" - # Extract strides stride_am, stride_ak = A.stride() stride_bk, stride_bn = B.stride() diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index a9c7cb61..e519f082 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -38,10 +38,18 @@ class FusedWorkspace: world_size: int = 1 variant: str = "" + # Hardware configuration (detected in preamble) + num_sms: Optional[int] = None # Number of streaming multiprocessors + num_xcds: int = 1 # Number of XCDs/chiplets + # Temporary buffers (allocated as needed) aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives + # Push variant workspace + a_inbox: Optional[torch.Tensor] = None # (world_size, M, K_local) inbox buffer + signal_flags: Optional[torch.Tensor] = None # (world_size, world_size, m_tiles, k_tiles) + prepared: bool = False def matches( @@ -82,4 +90,6 @@ def clear(self): """Free all allocated buffers.""" self.aux_buffer = None self.locks = None + self.a_inbox = None + self.signal_flags = None self.prepared = False diff --git a/iris/tracing/events.py b/iris/tracing/events.py index 4838c09d..62d7cf8d 
100644 --- a/iris/tracing/events.py +++ b/iris/tracing/events.py @@ -26,6 +26,9 @@ 11: "atomic_or", 12: "atomic_min", 13: "atomic_max", + 14: "wg_fetch", + 15: "wg_gemm", + 16: "wg_gemm_wait", } @@ -75,6 +78,11 @@ class TraceEvent: atomic_min: tl.constexpr atomic_max: tl.constexpr + # Workgroup-level profiling events + wg_fetch: tl.constexpr + wg_gemm: tl.constexpr + wg_gemm_wait: tl.constexpr + @triton.constexpr_function def __init__(self): # Data movement @@ -94,3 +102,8 @@ def __init__(self): self.atomic_or = tl.constexpr(11) self.atomic_min = tl.constexpr(12) self.atomic_max = tl.constexpr(13) + + # Workgroup-level profiling + self.wg_fetch = tl.constexpr(14) + self.wg_gemm = tl.constexpr(15) + self.wg_gemm_wait = tl.constexpr(16) diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9..4e2b10cc 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -24,6 +24,7 @@ def gather( src_view: TensorView, source_rank: tl.constexpr, ctx: DeviceContext, + hint: tl.constexpr = None, ): """ Tile-level gather from a specific rank. @@ -37,6 +38,9 @@ def gather( src_view: TensorView for source tensor on source_rank. source_rank: Specific rank to load from (constexpr). ctx: DeviceContext with rank, world_size, and heap_bases. + hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on + the translated pointer. Use a scalar (e.g. 16) or a tuple + (e.g. (1, 16)) to indicate alignment. Defaults to None (no hint). Returns: Loaded tile data as a tensor. 
@@ -61,6 +65,7 @@ def gather( source_rank, # from_rank (source rank) ctx.heap_bases, mask=mask, + hint=hint, ) return tile_data diff --git a/pyproject.toml b/pyproject.toml index 18e71bad..02533764 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "numpy", "requests", "ruff", - "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@df58476a4520b72495a3f03f911368a184126568", + "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@cd119279f3df543a558aa6d2cd4a3daed0b1ec7a", ] diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 19350501..db4b2125 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -28,7 +28,13 @@ (256, 64, 128), ], ) -def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): +@pytest.mark.parametrize( + "variant", + [ + "pull", + ], +) +def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N, variant): """Test all_gather_matmul against torch all_gather + matmul.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") @@ -77,12 +83,20 @@ def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): # Run fused all_gather + matmul using shmem.ops API from iris.ops.config import FusedConfig + if rank == 0: + print(f"\n[Test] Testing variant={variant}, M={M}, K_local={K_local}, N={N}, dtype={dtype}") + # Use appropriate block sizes based on problem size # For small problems, use smaller blocks if M <= 256 or K_local <= 64 or N <= 128: - config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + all_gather_matmul_variant=variant, + ) else: - config = FusedConfig() + config = FusedConfig(all_gather_matmul_variant=variant) # Validate config against problem size assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" diff --git a/tests/ops/test_all_gather_matmul_hbm_buffer.py 
# SPDX-License-Identifier: MIT
# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Tests for fused all_gather + matmul using the HBM staging buffer implementation.

Each rank has A_sharded (M x K_local), B is replicated.
The operation gathers A from all ranks into a local HBM buffer and computes C = A_gathered @ B.
"""

import pytest
import torch
import torch.distributed as dist

import iris
from iris.ops.all_gather_matmul_hbm_buffer import (
    all_gather_matmul_hbm_buffer,
    all_gather_matmul_hbm_buffer_preamble,
)
from iris.ops.config import FusedConfig


@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-2, 1e-2),
    ],
)
@pytest.mark.parametrize(
    "M,K_local,N",
    [
        (128, 32, 64),
        (256, 64, 128),
    ],
)
@pytest.mark.parametrize(
    "staged_a_layout",
    [
        "k_contiguous",
        "m_contiguous",
    ],
)
def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout):
    """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul."""
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")

    heap_size = 2**33
    shmem = iris.iris(heap_size)
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    K = K_local * world_size  # Full K dimension

    # Seed for reproducibility - different seed per rank for A_sharded
    torch.manual_seed(42 + rank)
    A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}")

    # B must be identical on all ranks
    torch.manual_seed(123)
    B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}")

    # Reference: torch all_gather + matmul
    A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)]
    dist.all_gather(A_gathered_list, A_sharded)
    A_gathered_ref = torch.cat(A_gathered_list, dim=1)  # (M, K)
    ref_output = torch.matmul(A_gathered_ref, B)
    torch.cuda.synchronize()

    # Create shmem tensors
    A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype)
    A_sharded_shmem.copy_(A_sharded)
    B_shmem = shmem.zeros((K, N), dtype=dtype)
    B_shmem.copy_(B)
    output = shmem.zeros((M, N), dtype=dtype)

    shmem.barrier()

    # Use small block sizes for small test problems
    config = FusedConfig(
        block_size_m=64,
        block_size_n=64,
        block_size_k=32,
    )

    workspace = all_gather_matmul_hbm_buffer_preamble(
        shmem, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout
    )

    all_gather_matmul_hbm_buffer(
        shmem,
        output,
        A_sharded_shmem,
        B_shmem,
        config=config,
        workspace=workspace,
        staged_a_layout=staged_a_layout,
        trace=False,
    )

    torch.cuda.synchronize()
    shmem.barrier()

    max_diff = (output - ref_output).abs().max().item()

    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), (
        f"Rank {rank}: Max diff {max_diff}, expected < {atol} "
        f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})"
    )


@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-2, 1e-2),
    ],
)
@pytest.mark.parametrize(
    "M,K_local,N",
    [
        (128, 32, 64),
    ],
)
def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N):
    """Test all_gather_matmul_hbm_buffer with a bias vector."""
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")

    heap_size = 2**33
    shmem = iris.iris(heap_size)
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    K = K_local * world_size

    torch.manual_seed(42 + rank)
    A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}")

    torch.manual_seed(123)
    B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}")

    torch.manual_seed(77)
    bias = torch.randn(M, dtype=dtype, device=f"cuda:{rank}")

    # Reference: torch all_gather + matmul + bias
    A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)]
    dist.all_gather(A_gathered_list, A_sharded)
    A_gathered_ref = torch.cat(A_gathered_list, dim=1)
    ref_output = torch.matmul(A_gathered_ref, B) + bias[:, None]
    torch.cuda.synchronize()

    # Create shmem tensors
    A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype)
    A_sharded_shmem.copy_(A_sharded)
    B_shmem = shmem.zeros((K, N), dtype=dtype)
    B_shmem.copy_(B)
    bias_shmem = shmem.zeros((M,), dtype=dtype)
    bias_shmem.copy_(bias)
    output = shmem.zeros((M, N), dtype=dtype)

    shmem.barrier()

    config = FusedConfig(
        block_size_m=64,
        block_size_n=64,
        block_size_k=32,
    )

    all_gather_matmul_hbm_buffer(
        shmem,
        output,
        A_sharded_shmem,
        B_shmem,
        bias=bias_shmem,
        config=config,
        trace=False,
    )

    torch.cuda.synchronize()
    shmem.barrier()

    max_diff = (output - ref_output).abs().max().item()

    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), (
        f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)"
    )


if __name__ == "__main__":
    # For quick debugging
    import sys

    if not dist.is_initialized():
        print("Run with: torchrun --nproc_per_node=2 tests/ops/test_all_gather_matmul_hbm_buffer.py")
        sys.exit(1)

    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    print(f"[Rank {rank}] Testing all_gather_matmul_hbm_buffer...")
    test_all_gather_matmul_hbm_buffer(torch.float16, 1e-2, 1e-2, 128, 32, 64, "k_contiguous")
    print(f"[Rank {rank}] ✓ Test passed!")
# NOTE(review): the collapsed diff continued past this file with a one-line hunk to
# tests/ops/test_matmul_all_reduce.py changing `M, N, K = 256, 128, 64` to
# `M, N, K = 256, 256, 64` (context: `A = shmem.randn((M, K), dtype=dtype)`).
# That change belongs to a different file and cannot be expressed here.