diff --git a/examples/32_ring_attention/README.md b/examples/32_ring_attention/README.md new file mode 100644 index 00000000..ba0166b1 --- /dev/null +++ b/examples/32_ring_attention/README.md @@ -0,0 +1,122 @@ + + +# Ring Attention + +An implementation of **Ring Attention with Blockwise Transformers** for +near-infinite context on AMD GPUs using [Iris](../../README.md). + +> Liu, H., Zaharia, M., & Abbeel, P. (2023). +> *Ring Attention with Blockwise Transformers for Near-Infinite Context.* +> arXiv:2310.01889. + +--- + +## Algorithm + +Standard self-attention requires O(n²) memory in the sequence length n. +Ring Attention enables sequences far longer than what fits on a single device +by distributing them across a *ring* of GPUs: + +1. The full sequence is split evenly across **N GPUs** along the sequence + dimension. Each device holds a chunk of Q, K, and V of length + `seq_total / N`. +2. **Q stays local**. K and V rotate around the ring one step at a time. +3. At each of the **N steps**, every device runs a local + [Flash Attention](https://arxiv.org/abs/2205.14135) pass and accumulates + the result using **online softmax**. +4. After all N steps the accumulator is normalised to yield the final output. + +For **causal (autoregressive) attention** only the steps where the KV chunk +precedes or coincides with the Q chunk contribute, allowing early termination +for some ranks and reducing total compute. + +``` +Step 0: rank r processes its own K_r, V_r (causal block diagonal) +Step 1: rank r receives K_{r-1}, V_{r-1} (full attention, past) +... 
+Step r: rank r receives K_0, V_0 (full attention, past) +Step r+1..N-1: all-future chunks – skipped (causal mode only) +``` + +--- + +## Files + +| File | Description | +|------|-------------| +| `ring_attention_kernels.py` | Triton flash-attention kernel + Python ring-rotation helper | +| `ring_attention_layer.py` | `RingAttention` – a `torch.nn.Module` wrapper | +| `example_run.py` | End-to-end demo with timing | + +--- + +## Usage + +### Quick demo + +```bash +# 2 GPUs, causal attention (default) +python examples/32_ring_attention/example_run.py + +# 4 GPUs, bidirectional +python examples/32_ring_attention/example_run.py --num_ranks 4 --no_causal + +# Custom sizes +python examples/32_ring_attention/example_run.py \ + --num_ranks 8 \ + --total_seq_len 131072 \ + --num_heads 32 \ + --head_dim 128 +``` + +### Validation + +```bash +python tests/run_tests_distributed.py tests/examples/test_ring_attention.py --num_ranks 2 -v +``` + +--- + +## Python API + +```python +import iris +# run from within examples/32_ring_attention/ +from ring_attention_layer import RingAttention + +shmem = iris.iris() + +# Each rank holds its local chunk +layer = RingAttention( + shmem, + num_heads=16, + head_dim=64, + causal=True, # autoregressive masking +) + +# q, k, v: [seq_local, num_heads, head_dim] (float16 or bfloat16) +output = layer(q, k, v) # [seq_local, num_heads, head_dim] +``` + +--- + +## Design Notes + +* **Communication**: KV rotation uses `iris.put` Triton kernels — each rank + pushes its K/V chunk directly to the next rank's symmetric heap buffer. + A `shmem.barrier()` after each push ensures all ranks have received the + data before the next attention step proceeds. No `torch.distributed` APIs + are used. +* **Ping-pong buffers**: Two symmetric buffer pairs (`k_ping`/`k_pong` and + `v_ping`/`v_pong`) alternate as source and destination on each step. 
This + guarantees the source being read and the destination being written are + always different allocations, avoiding any read-after-write hazard. +* **Online softmax**: The kernel maintains running max (`M`) and sum (`L`) + accumulators in float32 for numerical stability. The final output is + `O / L` after all ring steps. +* **Causal masking**: Handled entirely at the granularity of KV *chunks* – + full attention, diagonal block attention, or skip – so the per-element mask + is applied only in the same-block diagonal case. All ranks still + participate in the rotation (required for the barrier to be well-defined). diff --git a/examples/32_ring_attention/benchmark.py b/examples/32_ring_attention/benchmark.py new file mode 100644 index 00000000..5e473f8c --- /dev/null +++ b/examples/32_ring_attention/benchmark.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Ring Attention benchmark: performance sweep and roofline analysis. + +Measures ring attention throughput across a range of sequence lengths, compares +against a single-device PyTorch ``scaled_dot_product_attention`` reference, and +generates a roofline plot with a performance table. + +Usage:: + + # 2-GPU sweep (default) + python examples/32_ring_attention/benchmark.py + + # 4-GPU sweep + python examples/32_ring_attention/benchmark.py --num_ranks 4 + + # Save plots to a file instead of showing interactively + python examples/32_ring_attention/benchmark.py --save_fig bench.png + +Hardware targets (auto-detected from ``rocminfo`` / ``hipGetDeviceProperties``): + + * AMD Instinct MI300X (gfx942): FP16 peak ≈ 1307 TFLOPS, BW ≈ 5300 GB/s + * Falls back to conservative estimates when hardware info is unavailable. 
+""" + +import argparse +import json +import os +import sys +import tempfile +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import iris + +project_root = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from ring_attention_layer import RingAttention # noqa: E402 + + +# --------------------------------------------------------------------------- +# Hardware peak specs (MI300X / gfx942 defaults) +# --------------------------------------------------------------------------- + +# FP16 matrix peak (TFLOPS) and memory bandwidth (GB/s) for MI300X. +# Source: https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html +_MI300X_FP16_TFLOPS = 1307.4 +_MI300X_MEMBW_GBS = 5300.0 + +# MI300X has exactly 304 compute units (used as a fingerprint when the device name +# does not contain an explicit architecture string). +_MI300X_CU_COUNT = 304 + +# Fallback conservative estimates for unknown hardware +_FALLBACK_FP16_TFLOPS = 100.0 +_FALLBACK_MEMBW_GBS = 500.0 + +# Unit conversion: 1 TB/s = 1000 GB/s +_GB_TO_TB = 1e3 + + +def _get_hw_specs(device: torch.device) -> tuple[float, float]: + """ + Return (peak_fp16_tflops, peak_membw_gbs) for the given device. + + Detects MI300X by GFX version; falls back to conservative defaults + for unknown hardware. 
+ """ + try: + props = torch.cuda.get_device_properties(device) + name = props.name.lower() + # gfx942 = MI300X / MI300A family; 304 CUs is the MI300X fingerprint + if "gfx942" in name or "mi300" in name or (props.multi_processor_count == _MI300X_CU_COUNT): + return _MI300X_FP16_TFLOPS, _MI300X_MEMBW_GBS + except Exception: + pass + return _FALLBACK_FP16_TFLOPS, _FALLBACK_MEMBW_GBS + + +# --------------------------------------------------------------------------- +# FLOPs / bytes helpers +# --------------------------------------------------------------------------- + + +def _attn_flops(seq_q: int, seq_kv: int, num_heads: int, head_dim: int, causal: bool) -> int: + """ + Theoretical FLOPs for one attention forward pass (QK^T + softmax + AV). + + Flash-attention FLOPs (no materialised S×S matrix): + QK^T : 2 * seq_q * seq_kv * head_dim per head + AV : 2 * seq_q * seq_kv * head_dim per head + Total : 4 * seq_q * seq_kv * head_dim * num_heads + + For causal attention roughly half the token-pairs are skipped, so we + apply a 0.5 factor (exact only for the diagonal block; used as an + approximation for the whole pass). 
+ """ + flops = 4 * seq_q * seq_kv * head_dim * num_heads + if causal: + flops = flops // 2 + return flops + + +def _attn_bytes(seq_q: int, seq_kv: int, num_heads: int, head_dim: int, elem_bytes: int = 2) -> int: + """ + Bytes accessed by a tiled flash-attention kernel (no S×S HBM spill): + Reads : Q [seq_q × H × D] + K [seq_kv × H × D] + V [seq_kv × H × D] + Writes: O [seq_q × H × D] + """ + return elem_bytes * num_heads * head_dim * (2 * seq_q + 2 * seq_kv) + + +# --------------------------------------------------------------------------- +# Timing utilities +# --------------------------------------------------------------------------- + + +def _time_ms(fn, warmup: int = 3, iters: int = 10) -> float: + """Return median latency in ms over *iters* timed calls after *warmup*.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + times.sort() + return times[len(times) // 2] # median + + +# --------------------------------------------------------------------------- +# Benchmark worker (runs inside each spawned process) +# --------------------------------------------------------------------------- + + +def _benchmark_worker( + rank: int, + world_size: int, + init_url: str, + configs: list[dict[str, Any]], + results_file: str, + causal: bool, + num_warmup: int, + num_iters: int, +): + """ + Worker function executed by each GPU rank. + + Rank 0 also runs the single-device SDPA reference and writes results to + *results_file* as JSON. 
+ """ + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + torch.set_default_device(f"cuda:{rank}") + + shmem = iris.iris() + device = torch.device(f"cuda:{rank}") + peak_tflops, peak_bw = _get_hw_specs(device) + + results = [] + + for cfg in configs: + total_seq = cfg["total_seq"] + num_heads = cfg["num_heads"] + head_dim = cfg["head_dim"] + dtype = getattr(torch, cfg["dtype"]) + elem_bytes = 2 # fp16 / bf16 + + seq_local = total_seq // world_size + scale = head_dim**-0.5 + + torch.manual_seed(42 + rank) + q = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + k = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + v = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + + layer = RingAttention(shmem, num_heads=num_heads, head_dim=head_dim, causal=causal, scale=scale) + + shmem.barrier() + + # ---- Ring attention timing ---- + ring_ms = _time_ms(lambda: layer(q, k, v), warmup=num_warmup, iters=num_iters) + + # All ranks need to sync before SDPA + shmem.barrier() + + # ---- Reference SDPA on rank 0 (full sequence, single GPU) ---- + ref_ms = None + if rank == 0: + try: + q_full = torch.randn(total_seq, num_heads, head_dim, dtype=dtype) + k_full = torch.randn_like(q_full) + v_full = torch.randn_like(q_full) + + # [S, H, D] → [H, S, D] for SDPA + q_f = q_full.permute(1, 0, 2) + k_f = k_full.permute(1, 0, 2) + v_f = v_full.permute(1, 0, 2) + + ref_ms = _time_ms( + lambda: torch.nn.functional.scaled_dot_product_attention( + q_f, k_f, v_f, scale=scale, is_causal=causal + ), + warmup=num_warmup, + iters=num_iters, + ) + except torch.OutOfMemoryError: + print(f"[WARN] SDPA reference OOM at total_seq={total_seq}, skipping") + ref_ms = float("nan") + torch.cuda.empty_cache() + + # ---- FLOPs (per rank) ---- + # Ring attention: seq_q × total_seq attention per rank 
+ ring_flops = _attn_flops(seq_local, total_seq, num_heads, head_dim, causal) + # Reference: total_seq × total_seq on a single device + ref_flops = _attn_flops(total_seq, total_seq, num_heads, head_dim, causal) + + # ---- Arithmetic intensity (flash-attn, per rank) ---- + ring_bytes = 0 + for _step in range(world_size): + ring_bytes += _attn_bytes(seq_local, seq_local, num_heads, head_dim, elem_bytes) + ring_ai = ring_flops / ring_bytes # FLOPs/byte + + ref_bytes = _attn_bytes(total_seq, total_seq, num_heads, head_dim, elem_bytes) + ref_ai = ref_flops / ref_bytes + + ring_tflops = ring_flops / (ring_ms * 1e-3) / 1e12 + ref_tflops = ref_flops / (ref_ms * 1e-3) / 1e12 + + results.append( + { + "total_seq": total_seq, + "num_heads": num_heads, + "head_dim": head_dim, + "world_size": world_size, + "causal": causal, + "dtype": cfg["dtype"], + # timings + "ring_ms": ring_ms, + "ref_ms": ref_ms, + "speedup": ref_ms / ring_ms, + # TFLOPS + "ring_tflops": ring_tflops, + "ref_tflops": ref_tflops, + # Arithmetic intensity + "ring_ai": ring_ai, + "ref_ai": ref_ai, + # Hardware peaks + "peak_tflops": peak_tflops, + "peak_bw_gbs": peak_bw, + } + ) + + shmem.barrier() + + del shmem + dist.destroy_process_group() + + # Write results from rank 0 to the shared temp file + if rank == 0: + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + + +def _print_table(results: list[dict[str, Any]]): + """Print a performance summary table to stdout.""" + if not results: + print("No results.") + return + peak_tflops = results[0]["peak_tflops"] + hdr = ( + f"{'seq':>8} {'H':>4} {'D':>4} " + f"{'ring ms':>9} {'ref ms':>9} {'speedup':>8} " + f"{'ring TFLOPS':>12} {'ref TFLOPS':>12} " + f"{'ring eff%':>10} {'ref eff%':>10}" + ) + print() + print("=" * len(hdr)) + print(hdr) + print("=" * len(hdr)) + for r in 
results: + ring_eff = 100.0 * r["ring_tflops"] / peak_tflops + ref_eff = 100.0 * r["ref_tflops"] / peak_tflops + print( + f"{r['total_seq']:>8} {r['num_heads']:>4} {r['head_dim']:>4} " + f"{r['ring_ms']:>9.3f} {r['ref_ms']:>9.3f} {r['speedup']:>8.2f}x " + f"{r['ring_tflops']:>12.2f} {r['ref_tflops']:>12.2f} " + f"{ring_eff:>9.1f}% {ref_eff:>9.1f}%" + ) + print("=" * len(hdr)) + + +def _make_plots(results: list[dict[str, Any]], save_fig: str | None): + """Generate performance table + roofline plot.""" + import matplotlib + import matplotlib.pyplot as plt + + if save_fig: + matplotlib.use("Agg") + + if not results: + print("No results to plot.") + return + + _print_table(results) + + peak_tflops = results[0]["peak_tflops"] + peak_bw = results[0]["peak_bw_gbs"] + world_size = results[0]["world_size"] + + # ---- Roofline plot ---- + fig, axes = plt.subplots(1, 2, figsize=(16, 6)) + + # Left: Roofline + ax = axes[0] + ai_vals = [r["ring_ai"] for r in results] + [r["ref_ai"] for r in results] + ai_min = min(ai_vals) * 0.5 + ai_max = max(ai_vals) * 2.0 + ai_range = [ai_min, ai_max] + + # Roofline ceiling: ridge point converts BW from GB/s to TB/s for TFLOPS units + ridge = peak_tflops / peak_bw * _GB_TO_TB # ridge point (FLOPs/byte) + ai_plot = [ai_min, ridge, ai_max] + roof = [min(peak_tflops, a * peak_bw / _GB_TO_TB) for a in ai_plot] + ax.loglog(ai_plot, roof, "k--", linewidth=2, label="Roofline (MI300X)") + ax.axhline(peak_tflops, color="gray", linestyle=":", alpha=0.6, label=f"Peak FP16 ({peak_tflops:.0f} TFLOPS)") + ax.axvline(ridge, color="gray", linestyle=":", alpha=0.6, label=f"Ridge ({ridge:.1f} FLOP/B)") + + # Ring attention points + for r in results: + ax.scatter(r["ring_ai"], r["ring_tflops"], marker="o", s=80, zorder=5) + ax.annotate( + f"S={r['total_seq']}", + (r["ring_ai"], r["ring_tflops"]), + textcoords="offset points", + xytext=(4, 4), + fontsize=7, + ) + + # Reference points + for r in results: + ax.scatter(r["ref_ai"], r["ref_tflops"], marker="^", 
s=80, zorder=5, color="tab:orange") + + import matplotlib.lines as mlines + + ring_handle = mlines.Line2D( + [], [], color="tab:blue", marker="o", linestyle="None", markersize=8, label="Ring attn (per rank)" + ) + ref_handle = mlines.Line2D( + [], [], color="tab:orange", marker="^", linestyle="None", markersize=8, label="SDPA reference (single GPU)" + ) + ax.legend(handles=[ring_handle, ref_handle] + ax.get_legend_handles_labels()[0][:3], fontsize=8) + + ax.set_xlabel("Arithmetic Intensity (FLOP/Byte)") + ax.set_ylabel("Performance (TFLOPS)") + ax.set_title(f"Roofline — AMD MI300X (gfx942)\n{world_size} GPUs, causal={results[0]['causal']}") + ax.set_xlim(ai_range) + ax.grid(True, which="both", alpha=0.3) + + # Right: Latency comparison bar chart + ax2 = axes[1] + seqs = [r["total_seq"] for r in results] + ring_ms = [r["ring_ms"] for r in results] + ref_ms = [r["ref_ms"] for r in results] + + x = range(len(seqs)) + width = 0.35 + bars1 = ax2.bar( + [i - width / 2 for i in x], ring_ms, width, label=f"Ring attn ({world_size} GPUs)", color="tab:blue", alpha=0.8 + ) + bars2 = ax2.bar([i + width / 2 for i in x], ref_ms, width, label="SDPA ref (1 GPU)", color="tab:orange", alpha=0.8) + + # Add speedup annotations + for i, r in enumerate(results): + ax2.text(i, max(ring_ms[i], ref_ms[i]) * 1.05, f"{r['speedup']:.1f}x", ha="center", fontsize=8, color="green") + + ax2.set_xticks(list(x)) + ax2.set_xticklabels([f"S={s}" for s in seqs], rotation=30) + ax2.set_ylabel("Latency (ms)") + ax2.set_title( + f"Latency: Ring Attention vs SDPA Reference\nH={results[0]['num_heads']}, D={results[0]['head_dim']}, causal={results[0]['causal']}" + ) + ax2.legend() + ax2.grid(axis="y", alpha=0.3) + + plt.tight_layout() + + if save_fig: + plt.savefig(save_fig, dpi=150, bbox_inches="tight") + print(f"\nSaved figure to: {save_fig}") + else: + plt.show() + + +# --------------------------------------------------------------------------- +# CLI +# 
--------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser(description="Ring Attention benchmark + roofline") + p.add_argument("--num_ranks", type=int, default=2, help="Number of GPUs") + p.add_argument("--num_heads", type=int, default=16, help="Number of attention heads") + p.add_argument("--head_dim", type=int, default=64, help="Head dimension") + p.add_argument( + "--total_seq_lens", + nargs="+", + type=int, + default=[512, 1024, 2048, 4096, 8192], + help="Total sequence lengths to sweep", + ) + p.add_argument( + "--no_causal", dest="causal", action="store_false", default=True, help="Non-causal (bidirectional) attention" + ) + p.add_argument("--dtype", choices=["float16", "bfloat16"], default="float16") + p.add_argument("--warmup", type=int, default=5, help="Warm-up iterations") + p.add_argument("--iters", type=int, default=20, help="Timed iterations") + p.add_argument("--save_fig", type=str, default=None, help="Save figure to this path (e.g. 
bench.png)") + p.add_argument("--no_plot", action="store_true", help="Skip plotting") + return p.parse_args() + + +def main(): + args = parse_args() + world_size = args.num_ranks + + # Filter configs to ensure seq_len divisible by 64*world_size + min_seq = 64 * world_size + configs = [] + for seq in args.total_seq_lens: + if seq % min_seq != 0: + print(f"[skip] total_seq={seq} not divisible by {min_seq} (64 * world_size), skipping") + continue + if seq % world_size != 0: + print(f"[skip] total_seq={seq} not divisible by world_size={world_size}, skipping") + continue + configs.append( + { + "total_seq": seq, + "num_heads": args.num_heads, + "head_dim": args.head_dim, + "dtype": args.dtype, # string, converted in worker + } + ) + + if not configs: + print("No valid configurations to benchmark.") + return + + print("Ring Attention Benchmark") + print(f" GPUs : {world_size}") + print(f" num_heads : {args.num_heads}") + print(f" head_dim : {args.head_dim}") + print(f" causal : {args.causal}") + print(f" dtype : {args.dtype}") + print(f" seq lengths : {[c['total_seq'] for c in configs]}") + print(f" warmup/iters: {args.warmup}/{args.iters}") + + # Use a temp file for results (safer than mp.Queue with mp.spawn) + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + results_file = f.name + + try: + import socket + + _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _sock.bind(("", 0)) + init_url = f"tcp://127.0.0.1:{_sock.getsockname()[1]}" + _sock.close() + mp.spawn( + fn=_benchmark_worker, + args=(world_size, init_url, configs, results_file, args.causal, args.warmup, args.iters), + nprocs=world_size, + join=True, + ) + + with open(results_file) as f: + results = json.load(f) + finally: + os.unlink(results_file) + + if not args.no_plot: + _make_plots(results, save_fig=args.save_fig) + else: + _print_table(results) + + +if __name__ == "__main__": + main() diff --git a/examples/32_ring_attention/benchmark_results.png 
b/examples/32_ring_attention/benchmark_results.png new file mode 100644 index 00000000..2e6ea46c Binary files /dev/null and b/examples/32_ring_attention/benchmark_results.png differ diff --git a/examples/32_ring_attention/example_run.py b/examples/32_ring_attention/example_run.py new file mode 100644 index 00000000..c33edfa6 --- /dev/null +++ b/examples/32_ring_attention/example_run.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Minimal example demonstrating ring attention using the RingAttention layer. + +The sequence is split evenly across GPUs along the sequence dimension. +Each rank computes its share of the attention output. As K and V rotate +around the ring, partial results are accumulated via online-softmax, yielding +the same result as a single device running full attention on the entire sequence. + +Usage:: + + # Run on 2 GPUs (default) + python examples/32_ring_attention/example_run.py + + # Run on 4 GPUs + python examples/32_ring_attention/example_run.py --num_ranks 4 + + # Non-causal (bidirectional) attention + python examples/32_ring_attention/example_run.py --no_causal +""" + +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import iris +from ring_attention_layer import RingAttention + + +def parse_args(): + parser = argparse.ArgumentParser(description="Ring Attention example") + parser.add_argument("--total_seq_len", type=int, default=4096, help="Total sequence length (split across GPUs)") + parser.add_argument("--num_heads", type=int, default=16, help="Number of attention heads") + parser.add_argument("--head_dim", type=int, default=64, help="Head dimension") + parser.add_argument("--num_ranks", type=int, default=2, help="Number of GPUs") + parser.add_argument("--no_causal", action="store_true", help="Use bidirectional (non-causal) attention") + parser.add_argument( + "--dtype", + type=str, + 
default="float16", + choices=["float16", "bfloat16"], + help="Input tensor dtype", + ) + return parser.parse_args() + + +def run(rank: int, world_size: int, init_url: str, args): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + + shmem = iris.iris() + torch.manual_seed(42) + torch.set_default_device("cuda") + + dtype = getattr(torch, args.dtype) + causal = not args.no_causal + + seq_local = args.total_seq_len // world_size + num_heads = args.num_heads + head_dim = args.head_dim + + if rank == 0: + attn_type = "causal" if causal else "bidirectional" + print(f"--- Ring Attention Example ({attn_type}) ---") + print(f" GPUs : {world_size}") + print(f" Total seq len : {args.total_seq_len}") + print(f" Seq per GPU : {seq_local}") + print(f" Heads × dim : {num_heads} × {head_dim}") + print(f" dtype : {dtype}") + + # Each rank creates its local Q, K, V chunk + q = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + k = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + v = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + + shmem.barrier() + + layer = RingAttention(shmem, num_heads=num_heads, head_dim=head_dim, causal=causal) + + # Warm-up pass + _ = layer(q, k, v) + torch.cuda.synchronize() + shmem.barrier() + + # Timed pass + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + output = layer(q, k, v) + end.record() + + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + + if rank == 0: + print(f"\nOutput shape : {output.shape}") + print(f"Output dtype : {output.dtype}") + print(f"Elapsed time : {elapsed_ms:.2f} ms") + print(f"Output[0, 0, :4] = {output[0, 0, :4].float()}") + + shmem.barrier() + dist.destroy_process_group() + + +def _find_free_port() -> int: + import socket + + with socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def main(): + args = parse_args() + init_url = f"tcp://127.0.0.1:{_find_free_port()}" + mp.spawn( + fn=run, + args=(args.num_ranks, init_url, args), + nprocs=args.num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/32_ring_attention/profile_ring_attn.py b/examples/32_ring_attention/profile_ring_attn.py new file mode 100644 index 00000000..9dc57684 --- /dev/null +++ b/examples/32_ring_attention/profile_ring_attn.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Ring Attention profiler: per-step timing breakdown. + +Instruments the ring attention loop to measure where time is spent: + - Kernel launch + compute time + - torch.cuda.synchronize() time + - shmem.barrier() time + +Usage:: + + python examples/32_ring_attention/profile_ring_attn.py + python examples/32_ring_attention/profile_ring_attn.py --num_ranks 4 + python examples/32_ring_attention/profile_ring_attn.py --num_ranks 8 --total_seq_len 16384 +""" + +import argparse +import json +import os +import sys +import tempfile +import time +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton + +import iris + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from ring_attention_kernels import _ring_attn_fwd_kernel # noqa: E402 + + +def _profiled_ring_attn_fwd(q, k, v, shmem, causal=True, scale=None, _ping_pong_bufs=None): + """ + Instrumented ring_attn_fwd that collects per-step timing. + + Returns (output, timing_data). 
+ """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + seq_q, num_heads, head_dim = q.shape + seq_kv = k.shape[0] + + if scale is None: + scale = head_dim**-0.5 + + input_dtype = q.dtype + + O = torch.zeros(seq_q, num_heads, head_dim, dtype=torch.float32, device=q.device) + M = torch.full((num_heads, seq_q), fill_value=-float("inf"), dtype=torch.float32, device=q.device) + L = torch.zeros(num_heads, seq_q, dtype=torch.float32, device=q.device) + + BLOCK_Q = 64 + BLOCK_KV = 64 + HEAD_DIM = head_dim + + if _ping_pong_bufs is not None: + k_ping, k_pong, v_ping, v_pong = _ping_pong_bufs + else: + k_ping = shmem.empty(k.shape, dtype=k.dtype) + k_pong = shmem.empty(k.shape, dtype=k.dtype) + v_ping = shmem.empty(v.shape, dtype=v.dtype) + v_pong = shmem.empty(v.shape, dtype=v.dtype) + + k_ping.copy_(k.contiguous()) + v_ping.copy_(v.contiguous()) + shmem.barrier() + + k_cur, k_recv = k_ping, k_pong + v_cur, v_recv = v_ping, v_pong + + next_rank = (rank + 1) % world_size + + FUSED_PUT_BLOCK = BLOCK_Q * HEAD_DIM + n_k = k_cur.numel() + heap_bases = shmem.get_heap_bases() + + step_timings = [] + + for step in range(world_size): + kv_rank = (rank - step) % world_size + do_put = step < world_size - 1 + + # --- Time the kernel launch + execution --- + kernel_start = torch.cuda.Event(enable_timing=True) + kernel_end = torch.cuda.Event(enable_timing=True) + + kernel_start.record() + + q_rank_start = rank * seq_q + kv_rank_start = kv_rank * seq_kv + grid = (num_heads, triton.cdiv(seq_q, BLOCK_Q)) + _ring_attn_fwd_kernel[grid]( + q, + k_cur, + v_cur, + O, + M, + L, + q.stride(0), + q.stride(1), + q.stride(2), + k_cur.stride(0), + k_cur.stride(1), + k_cur.stride(2), + v_cur.stride(0), + v_cur.stride(1), + v_cur.stride(2), + O.stride(0), + O.stride(1), + O.stride(2), + M.stride(0), + M.stride(1), + L.stride(0), + L.stride(1), + seq_q, + seq_kv, + q_rank_start, + kv_rank_start, + scale, + # fused put params + k_cur.view(-1), + k_recv.view(-1), + v_cur.view(-1), + 
v_recv.view(-1), + n_k, + put_rank=rank, + put_next_rank=next_rank, + heap_bases=heap_bases, + CAUSAL=causal, + BLOCK_Q=BLOCK_Q, + BLOCK_KV=BLOCK_KV, + HEAD_DIM=HEAD_DIM, + DO_PUT=do_put, + PUT_BLOCK=FUSED_PUT_BLOCK, + num_warps=4, + num_stages=2, + ) + + kernel_end.record() + + # --- Time sync + barrier --- + if do_put: + sync_start = torch.cuda.Event(enable_timing=True) + sync_end = torch.cuda.Event(enable_timing=True) + + sync_start.record() + torch.cuda.synchronize() + sync_end.record() + torch.cuda.synchronize() # need to sync to read sync timing + + sync_ms = sync_start.elapsed_time(sync_end) + + barrier_wall_start = time.perf_counter() + shmem.barrier() + barrier_wall_end = time.perf_counter() + barrier_ms = (barrier_wall_end - barrier_wall_start) * 1000.0 + + k_cur, k_recv = k_recv, k_cur + v_cur, v_recv = v_recv, v_cur + else: + torch.cuda.synchronize() + sync_ms = 0.0 + barrier_ms = 0.0 + + kernel_ms = kernel_start.elapsed_time(kernel_end) + + step_timings.append( + { + "step": step, + "kv_rank": kv_rank, + "do_put": do_put, + "kernel_ms": kernel_ms, + "sync_ms": sync_ms, + "barrier_ms": barrier_ms, + "total_ms": kernel_ms + sync_ms + barrier_ms, + } + ) + + L_expanded = L.permute(1, 0).unsqueeze(-1) + output = O / L_expanded + + return output.to(input_dtype), step_timings + + +def _profile_worker(rank, world_size, init_url, cfg, results_file): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + torch.set_default_device(f"cuda:{rank}") + + shmem = iris.iris() + + total_seq = cfg["total_seq"] + num_heads = cfg["num_heads"] + head_dim = cfg["head_dim"] + dtype = getattr(torch, cfg["dtype"]) + causal = cfg["causal"] + num_warmup = cfg["warmup"] + num_iters = cfg["iters"] + + seq_local = total_seq // world_size + scale = head_dim**-0.5 + + torch.manual_seed(42 + rank) 
+ q = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + k = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + v = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + + # Pre-allocate ping-pong buffers + k_ping = shmem.empty(k.shape, dtype=k.dtype) + k_pong = shmem.empty(k.shape, dtype=k.dtype) + v_ping = shmem.empty(v.shape, dtype=v.dtype) + v_pong = shmem.empty(v.shape, dtype=v.dtype) + bufs = (k_ping, k_pong, v_ping, v_pong) + + shmem.barrier() + + # Warmup + for _ in range(num_warmup): + out, _ = _profiled_ring_attn_fwd(q, k, v, shmem, causal=causal, scale=scale, _ping_pong_bufs=bufs) + torch.cuda.synchronize() + shmem.barrier() + + # Timed iterations — collect per-step timings + all_iter_timings = [] + for it in range(num_iters): + out, step_timings = _profiled_ring_attn_fwd(q, k, v, shmem, causal=causal, scale=scale, _ping_pong_bufs=bufs) + all_iter_timings.append(step_timings) + torch.cuda.synchronize() + shmem.barrier() + + # Aggregate: average each step's timings across iterations + num_steps = world_size + avg_timings = [] + for s in range(num_steps): + kernel_vals = [all_iter_timings[it][s]["kernel_ms"] for it in range(num_iters)] + sync_vals = [all_iter_timings[it][s]["sync_ms"] for it in range(num_iters)] + barrier_vals = [all_iter_timings[it][s]["barrier_ms"] for it in range(num_iters)] + total_vals = [all_iter_timings[it][s]["total_ms"] for it in range(num_iters)] + avg_timings.append( + { + "step": s, + "kv_rank": all_iter_timings[0][s]["kv_rank"], + "do_put": all_iter_timings[0][s]["do_put"], + "kernel_ms": sum(kernel_vals) / len(kernel_vals), + "sync_ms": sum(sync_vals) / len(sync_vals), + "barrier_ms": sum(barrier_vals) / len(barrier_vals), + "total_ms": sum(total_vals) / len(total_vals), + } + ) + + del shmem + dist.destroy_process_group() + + if rank == 0: + result = { + "config": cfg, + "world_size": world_size, + "rank": rank, + "per_step": avg_timings, + "totals": { + "kernel_ms": sum(s["kernel_ms"] for s in 
avg_timings), + "sync_ms": sum(s["sync_ms"] for s in avg_timings), + "barrier_ms": sum(s["barrier_ms"] for s in avg_timings), + "total_ms": sum(s["total_ms"] for s in avg_timings), + }, + } + with open(results_file, "w") as f: + json.dump(result, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Ring Attention profiler") + parser.add_argument("--num_ranks", type=int, default=2) + parser.add_argument("--total_seq_len", type=int, default=8192) + parser.add_argument("--num_heads", type=int, default=16) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--no_causal", dest="causal", action="store_false", default=True) + parser.add_argument("--dtype", default="float16") + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--iters", type=int, default=10) + args = parser.parse_args() + + cfg = { + "total_seq": args.total_seq_len, + "num_heads": args.num_heads, + "head_dim": args.head_dim, + "dtype": args.dtype, + "causal": args.causal, + "warmup": args.warmup, + "iters": args.iters, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + results_file = f.name + + try: + import socket + + _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _sock.bind(("", 0)) + init_url = f"tcp://127.0.0.1:{_sock.getsockname()[1]}" + _sock.close() + + mp.spawn( + fn=_profile_worker, + args=(args.num_ranks, init_url, cfg, results_file), + nprocs=args.num_ranks, + join=True, + ) + + with open(results_file) as f: + result = json.load(f) + + # Print results + world_size = result["world_size"] + totals = result["totals"] + print(f"\n{'=' * 80}") + print( + f"Ring Attention Profiling — {world_size} GPUs, seq={cfg['total_seq']}, " + f"H={cfg['num_heads']}, D={cfg['head_dim']}, causal={cfg['causal']}" + ) + print(f"{'=' * 80}") + + print(f"\n{'step':>4} {'kv_rank':>7} {'put':>4} {'kernel':>9} {'sync':>9} {'barrier':>9} {'total':>9}") + print("-" * 65) + for s in 
result["per_step"]: + print( + f"{s['step']:>4} {s['kv_rank']:>7} {str(s['do_put']):>4} " + f"{s['kernel_ms']:>8.3f}ms {s['sync_ms']:>8.3f}ms {s['barrier_ms']:>8.3f}ms {s['total_ms']:>8.3f}ms" + ) + + print("\n--- Totals (rank 0) ---") + print( + f" Kernel compute : {totals['kernel_ms']:>8.3f} ms ({100 * totals['kernel_ms'] / totals['total_ms']:>5.1f}%)" + ) + print( + f" CUDA sync : {totals['sync_ms']:>8.3f} ms ({100 * totals['sync_ms'] / totals['total_ms']:>5.1f}%)" + ) + print( + f" Barrier : {totals['barrier_ms']:>8.3f} ms ({100 * totals['barrier_ms'] / totals['total_ms']:>5.1f}%)" + ) + print(f" TOTAL : {totals['total_ms']:>8.3f} ms") + + # Compute efficiency + seq_local = cfg["total_seq"] // world_size + flops = 4 * seq_local * cfg["total_seq"] * cfg["head_dim"] * cfg["num_heads"] + if cfg["causal"]: + flops //= 2 + tflops = flops / (totals["total_ms"] * 1e-3) / 1e12 + print(f" TFLOPS : {tflops:>8.2f}") + print(f" MFU (vs 1307) : {100 * tflops / 1307.4:>7.1f}%") + print() + + finally: + os.unlink(results_file) + + +if __name__ == "__main__": + main() diff --git a/examples/32_ring_attention/ring_attention_kernels.py b/examples/32_ring_attention/ring_attention_kernels.py new file mode 100644 index 00000000..9c2391af --- /dev/null +++ b/examples/32_ring_attention/ring_attention_kernels.py @@ -0,0 +1,400 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+#
+#
+# Ring Attention implementation based on:
+# "Ring Attention with Blockwise Transformers for Near-Infinite Context"
+# Liu et al., 2023 (https://arxiv.org/pdf/2310.01889)
+#
+################################################################################
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra import libdevice
+import iris
+
+
+@triton.jit
+def _put_kv_kernel(
+    k_src,
+    k_dst,
+    v_src,
+    v_dst,
+    n_elem,
+    cur_rank: tl.constexpr,
+    next_rank: tl.constexpr,
+    heap_bases,
+    BLOCK: tl.constexpr,
+):
+    """
+    Fused K+V put: copy K and V to the next rank in a single kernel launch.
+
+    Both K and V tensors must be flat (same number of elements) and reside on
+    the Iris symmetric heap so that their addresses can be translated to
+    ``next_rank``'s address space.
+
+    Each program instance copies ``BLOCK`` elements of K **and** ``BLOCK``
+    elements of V, halving kernel-launch overhead compared to launching two
+    separate single-tensor put kernels.
+
+    Args:
+        k_src: Source K pointer (must be on the symmetric heap).
+        k_dst: Destination K pointer (must be on the symmetric heap).
+        v_src: Source V pointer (must be on the symmetric heap).
+        v_dst: Destination V pointer (must be on the symmetric heap).
+        n_elem: Total number of elements in K (same as V).
+        cur_rank: This rank's ID.
+        next_rank: Destination rank ID.
+        heap_bases: Iris heap base address table.
+        BLOCK: Number of elements each program instance handles. 
+ """ + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < n_elem + iris.put(k_src + offs, k_dst + offs, cur_rank, next_rank, heap_bases, mask=mask) + iris.put(v_src + offs, v_dst + offs, cur_rank, next_rank, heap_bases, mask=mask) + + +@triton.jit +def _ring_attn_fwd_kernel( + Q, + K, + V, + O, + M, + L, + # strides for Q, K, V, O: [seq, num_heads, head_dim] + stride_qs, + stride_qh, + stride_qd, + stride_ks, + stride_kh, + stride_kd, + stride_vs, + stride_vh, + stride_vd, + stride_os, + stride_oh, + stride_od, + # strides for M, L: [num_heads, seq] + stride_mh, + stride_ms, + stride_lh, + stride_ls, + # sizes + seq_q, + seq_kv, + # global offsets for causal masking + q_rank_start, + kv_rank_start, + scale, + # fused KV put parameters + k_put_src, + k_put_dst, + v_put_src, + v_put_dst, + n_put_elem, + put_rank: tl.constexpr, + put_next_rank: tl.constexpr, + heap_bases, + # compile-time constants + CAUSAL: tl.constexpr, + BLOCK_Q: tl.constexpr, + BLOCK_KV: tl.constexpr, + HEAD_DIM: tl.constexpr, + DO_PUT: tl.constexpr, + PUT_BLOCK: tl.constexpr, +): + """ + Flash attention kernel for one ring step. + + Each program instance handles one attention head and one block of Q tokens. + Iterates over all KV blocks and accumulates using online softmax. + + Accumulates into O (unnormalized), M (running log-sum-exp), L (running sum). + The final output is O / L, computed after all ring steps complete. + """ + h = tl.program_id(0) + q_blk = tl.program_id(1) + + q_off = q_blk * BLOCK_Q + q_idx = q_off + tl.arange(0, BLOCK_Q) + q_mask = q_idx < seq_q + + # --- Causal early exit: skip attention if entire KV chunk is masked --- + if CAUSAL: + q_global_max = q_rank_start + q_off + BLOCK_Q - 1 + if kv_rank_start > q_global_max: + # All KV positions are in the future — no useful attention. + # Just do the fused KV rotation and return. 
+ if DO_PUT: + num_q_blks = tl.cdiv(seq_q, BLOCK_Q) + pid_flat = h * num_q_blks + q_blk + put_offs = pid_flat * PUT_BLOCK + tl.arange(0, PUT_BLOCK) + put_mask = put_offs < n_put_elem + iris.put( + k_put_src + put_offs, + k_put_dst + put_offs, + put_rank, + put_next_rank, + heap_bases, + mask=put_mask, + ) + iris.put( + v_put_src + put_offs, + v_put_dst + put_offs, + put_rank, + put_next_rank, + heap_bases, + mask=put_mask, + ) + return + + # Load Q block in native dtype (fp16/bf16) for efficient MFMA matrix multiply. + # Keeping inputs in fp16 uses the FP16 MFMA path (1307 TFLOPS on MI300X) + # instead of the FP32 path (~163 TFLOPS). + q_ptrs = Q + h * stride_qh + q_idx[:, None] * stride_qs + tl.arange(0, HEAD_DIM)[None, :] * stride_qd + q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0) + + # Load running statistics for this head and Q block + m_ptrs = M + h * stride_mh + q_idx * stride_ms + l_ptrs = L + h * stride_lh + q_idx * stride_ls + o_ptrs = O + h * stride_oh + q_idx[:, None] * stride_os + tl.arange(0, HEAD_DIM)[None, :] * stride_od + + m = tl.load(m_ptrs, mask=q_mask, other=-float("inf")) + l = tl.load(l_ptrs, mask=q_mask, other=0.0) + o = tl.load(o_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32) + + # Global Q positions for causal masking. + # Triton loads q_rank_start (a Python int) and q_idx (int32 arange) as int32. + # The maximum value is world_size * seq_q which fits comfortably in int32. + q_global = q_rank_start + q_idx + + # Iterate over KV blocks + d_idx = tl.arange(0, HEAD_DIM) + for kv_off in range(0, seq_kv, BLOCK_KV): + # Causal inner-loop skip: avoid loading K/V for fully-masked blocks. + # Once kv_rank_start + kv_off > q_global_max, all subsequent blocks + # are also masked (KV positions are monotonically increasing). 
+ if CAUSAL: + do_kv_block = kv_rank_start + kv_off <= q_global_max + else: + do_kv_block = True + + if do_kv_block: + kv_idx = kv_off + tl.arange(0, BLOCK_KV) + kv_mask = kv_idx < seq_kv + + # Load K transposed [HEAD_DIM, BLOCK_KV] and V [BLOCK_KV, HEAD_DIM] + # in native dtype (fp16/bf16) for efficient MFMA matrix multiply + k_ptrs = K + h * stride_kh + d_idx[:, None] * stride_kd + kv_idx[None, :] * stride_ks + v_ptrs = V + h * stride_vh + kv_idx[:, None] * stride_vs + d_idx[None, :] * stride_vd + + k = tl.load(k_ptrs, mask=kv_mask[None, :], other=0.0) + v = tl.load(v_ptrs, mask=kv_mask[:, None], other=0.0) + + # QK^T: fp16/bf16 matmul with fp32 accumulation via MFMA + qk = tl.dot(q, k) * scale + + # Apply padding mask and optional causal mask + if CAUSAL: + kv_global = kv_rank_start + kv_idx + causal_mask = kv_global[None, :] <= q_global[:, None] + qk = tl.where(causal_mask & kv_mask[None, :], qk, -float("inf")) + else: + qk = tl.where(kv_mask[None, :], qk, -float("inf")) + + # Online softmax accumulation (fp32) + m_new = tl.maximum(m, tl.max(qk, axis=1)) + alpha = libdevice.fast_expf(m - m_new) + p = libdevice.fast_expf(qk - m_new[:, None]) + l = alpha * l + tl.sum(p, axis=1) + + # AV: cast softmax probs to native dtype for efficient MFMA + o = alpha[:, None] * o + tl.dot(p.to(v.dtype), v) + m = m_new + + # Write back updated statistics and output + tl.store(m_ptrs, m, mask=q_mask) + tl.store(l_ptrs, l, mask=q_mask) + tl.store(o_ptrs, o, mask=q_mask[:, None]) + + # --- Fused KV rotation: each thread block transfers a slice of K and V --- + # The attention grid has (num_heads * cdiv(seq_q, BLOCK_Q)) blocks. + # Each block transfers PUT_BLOCK = BLOCK_Q * HEAD_DIM elements, so the + # total coverage = num_heads * seq_q * HEAD_DIM = n_put_elem exactly. 
+    if DO_PUT:
+        num_q_blks = tl.cdiv(seq_q, BLOCK_Q)
+        pid_flat = h * num_q_blks + q_blk
+        put_offs = pid_flat * PUT_BLOCK + tl.arange(0, PUT_BLOCK)
+        put_mask = put_offs < n_put_elem
+        iris.put(k_put_src + put_offs, k_put_dst + put_offs, put_rank, put_next_rank, heap_bases, mask=put_mask)
+        iris.put(v_put_src + put_offs, v_put_dst + put_offs, put_rank, put_next_rank, heap_bases, mask=put_mask)
+
+
+def ring_attn_fwd(q, k, v, shmem, causal=True, scale=None, _ping_pong_bufs=None):
+    """
+    Ring Attention forward pass.
+
+    Each device holds a contiguous chunk of the sequence (Q, K, V). K and V
+    are rotated around the ring of devices using Iris ``put`` operations fused
+    into the attention kernel, while Q remains local. At each step the local
+    flash-attention kernel accumulates partial results into O, M, L using
+    online softmax.
+
+    After all ``world_size`` steps, O is normalised by L to produce the output.
+
+    Communication uses two ping-pong symmetric buffers per tensor (K and V),
+    allocated on the Iris heap. After each push, ``shmem.barrier()`` ensures
+    all ranks have received the new data before proceeding to the next step.
+
+    Args:
+        q (torch.Tensor): Query tensor, shape ``[seq_q, num_heads, head_dim]``.
+            Lives on the local device's CUDA memory.
+        k (torch.Tensor): Key tensor, same shape as ``q``.
+        v (torch.Tensor): Value tensor, same shape as ``q``.
+        shmem: Iris shmem context (provides ``get_rank()`` / ``get_num_ranks()``,
+            ``get_heap_bases()`` and ``barrier()``).
+        causal (bool): If ``True``, apply a causal (lower-triangular) mask so
+            that position ``i`` only attends to positions ``j <= i``.
+        scale (float | None): Softmax scale factor. Defaults to
+            ``head_dim ** -0.5``.
+        _ping_pong_bufs (tuple | None): Optional pre-allocated ping-pong buffers
+            ``(k_ping, k_pong, v_ping, v_pong)`` from the symmetric heap. When
+            provided, no new heap allocation is performed (avoids heap churn on
+            repeated calls with the same tensor shape). 
+ + Returns: + torch.Tensor: Attention output, shape ``[seq_q, num_heads, head_dim]``, + same dtype as ``q``. + """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + seq_q, num_heads, head_dim = q.shape + seq_kv = k.shape[0] + + assert (head_dim & (head_dim - 1)) == 0, f"head_dim must be a power of 2, got {head_dim}" + assert seq_q % 64 == 0, f"seq_q ({seq_q}) must be divisible by BLOCK_Q (64)" + assert seq_kv % 64 == 0, f"seq_kv ({seq_kv}) must be divisible by BLOCK_KV (64)" + + if scale is None: + scale = head_dim**-0.5 + + input_dtype = q.dtype + + # Running accumulators in float32 for numerical stability + # O is the *unnormalized* weighted value sum + O = torch.zeros(seq_q, num_heads, head_dim, dtype=torch.float32, device=q.device) + # M: running row-max (log domain), L: running normalisation denominator + M = torch.full((num_heads, seq_q), fill_value=-float("inf"), dtype=torch.float32, device=q.device) + L = torch.zeros(num_heads, seq_q, dtype=torch.float32, device=q.device) + + # Choose block sizes; keep them as powers of 2 + BLOCK_Q = 64 + BLOCK_KV = 64 + HEAD_DIM = head_dim # already validated as power of 2 + + # Allocate two symmetric ping-pong buffers per tensor on the Iris heap. + # The destination buffer of each iris.put must be on the symmetric heap so + # that the pointer can be translated to the remote rank's address space. + # If the caller supplies pre-allocated buffers (e.g. from RingAttention), + # reuse them to avoid heap churn on repeated forward passes. + if _ping_pong_bufs is not None: + k_ping, k_pong, v_ping, v_pong = _ping_pong_bufs + else: + k_ping = shmem.empty(k.shape, dtype=k.dtype) + k_pong = shmem.empty(k.shape, dtype=k.dtype) + v_ping = shmem.empty(v.shape, dtype=v.dtype) + v_pong = shmem.empty(v.shape, dtype=v.dtype) + + # Copy initial K/V into the ping buffers, then sync so every rank has its + # own initial chunk ready before the first rotation. 
+ k_ping.copy_(k.contiguous()) + v_ping.copy_(v.contiguous()) + shmem.barrier() + + k_cur, k_recv = k_ping, k_pong + v_cur, v_recv = v_ping, v_pong + + next_rank = (rank + 1) % world_size + + FUSED_PUT_BLOCK = BLOCK_Q * HEAD_DIM + n_k = k_cur.numel() + heap_bases = shmem.get_heap_bases() + + for step in range(world_size): + kv_rank = (rank - step) % world_size + do_put = step < world_size - 1 + + # The kernel handles causal masking internally with two optimizations: + # 1. Program-level early exit: when all KV positions are beyond the + # Q block's range, skip attention entirely (just do the put). + # 2. Inner-loop skip: stop iterating KV blocks once positions exceed + # the Q range, avoiding useless loads and masked matmuls. + # All ranks still launch the kernel at every step (no barrier imbalance). + q_rank_start = rank * seq_q + kv_rank_start = kv_rank * seq_kv + grid = (num_heads, triton.cdiv(seq_q, BLOCK_Q)) + _ring_attn_fwd_kernel[grid]( + q, + k_cur, + v_cur, + O, + M, + L, + q.stride(0), + q.stride(1), + q.stride(2), + k_cur.stride(0), + k_cur.stride(1), + k_cur.stride(2), + v_cur.stride(0), + v_cur.stride(1), + v_cur.stride(2), + O.stride(0), + O.stride(1), + O.stride(2), + M.stride(0), + M.stride(1), + L.stride(0), + L.stride(1), + seq_q, + seq_kv, + q_rank_start, + kv_rank_start, + scale, + # fused put params + k_cur.view(-1), + k_recv.view(-1), + v_cur.view(-1), + v_recv.view(-1), + n_k, + put_rank=rank, + put_next_rank=next_rank, + heap_bases=heap_bases, + CAUSAL=causal, + BLOCK_Q=BLOCK_Q, + BLOCK_KV=BLOCK_KV, + HEAD_DIM=HEAD_DIM, + DO_PUT=do_put, + PUT_BLOCK=FUSED_PUT_BLOCK, + num_warps=4, + num_stages=2, + ) + + # Global barrier ensures all ranks have received data, then swap. 
+ if do_put: + torch.cuda.synchronize() + shmem.barrier() + k_cur, k_recv = k_recv, k_cur + v_cur, v_recv = v_recv, v_cur + + # Normalize: output = O / L, where L is the softmax denominator + # L: [num_heads, seq_q] → [seq_q, num_heads, 1] for broadcasting + L_expanded = L.permute(1, 0).unsqueeze(-1) # [seq_q, num_heads, 1] + output = O / L_expanded + + return output.to(input_dtype) diff --git a/examples/32_ring_attention/ring_attention_layer.py b/examples/32_ring_attention/ring_attention_layer.py new file mode 100644 index 00000000..7828c2d0 --- /dev/null +++ b/examples/32_ring_attention/ring_attention_layer.py @@ -0,0 +1,89 @@ +################################################################################ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# +# Ring Attention layer based on: +# "Ring Attention with Blockwise Transformers for Near-Infinite Context" +# Liu et al., 2023 (https://arxiv.org/pdf/2310.01889) +# +################################################################################ + +import torch +import torch.nn as nn + +from ring_attention_kernels import ring_attn_fwd + + +class RingAttention(nn.Module): + """ + Ring Attention layer for sequence-parallel attention over very long sequences. + + The sequence is assumed to be **already split** across devices along the + sequence dimension before calling ``forward``. Each device receives a + contiguous chunk of Q, K, and V of shape ``[seq_local, num_heads, head_dim]``. + + Internally the layer implements the ring attention algorithm from Liu et al. + (2023): K and V rotate around the device ring while Q stays local, with + online softmax accumulation at every step. + + Args: + shmem: Iris shmem context used for ``barrier()`` and rank queries. + num_heads (int): Number of attention heads. + head_dim (int): Dimension of each attention head. + causal (bool): Whether to apply a causal (lower-triangular) attention + mask. Default: ``True``. 
+ scale (float | None): Softmax scale. Defaults to + ``head_dim ** -0.5``. + + Example:: + + shmem = iris.iris() + layer = RingAttention(shmem, num_heads=16, head_dim=64) + q = torch.randn(seq_local, 16, 64, device="cuda", dtype=torch.float16) + k = torch.randn_like(q) + v = torch.randn_like(q) + output = layer(q, k, v) # [seq_local, 16, 64] + """ + + def __init__(self, shmem, num_heads: int, head_dim: int, causal: bool = True, scale: float | None = None): + super().__init__() + self.shmem = shmem + self.num_heads = num_heads + self.head_dim = head_dim + self.causal = causal + self.scale = scale if scale is not None else head_dim**-0.5 + # Ping-pong buffer cache: keyed by (shape, dtype) to avoid re-allocating + # the symmetric heap buffers on every forward pass. + self._buf_cache: dict[ + tuple[torch.Size, torch.dtype], tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ] = {} + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Compute ring attention. + + Args: + q: Query tensor ``[seq_local, num_heads, head_dim]``. + k: Key tensor ``[seq_local, num_heads, head_dim]``. + v: Value tensor ``[seq_local, num_heads, head_dim]``. + + Returns: + Attention output tensor ``[seq_local, num_heads, head_dim]``. + """ + assert q.shape == k.shape == v.shape, "Q, K, V must have the same shape" + assert q.shape[1] == self.num_heads, f"Expected {self.num_heads} heads, got {q.shape[1]}" + assert q.shape[2] == self.head_dim, f"Expected head_dim {self.head_dim}, got {q.shape[2]}" + + # Lazily allocate (or reuse) ping-pong symmetric heap buffers for this shape. 
+ buf_key = (k.shape, k.dtype) + if buf_key not in self._buf_cache: + self._buf_cache[buf_key] = ( + self.shmem.empty(k.shape, dtype=k.dtype), + self.shmem.empty(k.shape, dtype=k.dtype), + self.shmem.empty(v.shape, dtype=v.dtype), + self.shmem.empty(v.shape, dtype=v.dtype), + ) + ping_pong = self._buf_cache[buf_key] + + return ring_attn_fwd(q, k, v, self.shmem, causal=self.causal, scale=self.scale, _ping_pong_bufs=ping_pong) diff --git a/examples/32_ring_attention/scaling_benchmark.py b/examples/32_ring_attention/scaling_benchmark.py new file mode 100644 index 00000000..530be1f4 --- /dev/null +++ b/examples/32_ring_attention/scaling_benchmark.py @@ -0,0 +1,594 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Ring Attention Scaling Benchmark. + +Evaluates strong and weak scaling of ring attention on AMD MI300X GPUs. + +**Strong scaling** (fixed total problem size, increasing world_size): + Total sequence length is held constant; adding more GPUs should reduce + latency proportionally. Ideal strong-scaling speedup = world_size. + +**Weak scaling** (fixed per-GPU problem size, increasing world_size): + Each GPU always processes seq_local tokens; adding more GPUs increases + the total sequence while keeping per-GPU work constant. Ideal weak-scaling + efficiency = 100% (flat latency). + +The reference is PyTorch ``scaled_dot_product_attention`` running the *full* +sequence on a *single* GPU, which is the baseline both scaling analyses are +measured against. 
+ +Usage:: + + # Full sweep: world_size in [1, 2, 4, 8], save plots + python examples/32_ring_attention/scaling_benchmark.py --save_fig scaling.png + + # Quick test with 2 and 4 GPUs only + python examples/32_ring_attention/scaling_benchmark.py --world_sizes 2 4 + + # Show table only (no plotting) + python examples/32_ring_attention/scaling_benchmark.py --no_plot +""" + +import argparse +import json +import os +import sys +import tempfile +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import iris + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from ring_attention_layer import RingAttention # noqa: E402 + + +# --------------------------------------------------------------------------- +# Hardware peak specs (MI300X / gfx942) +# --------------------------------------------------------------------------- + +_MI300X_FP16_TFLOPS = 1307.4 +_MI300X_MEMBW_GBS = 5300.0 +_MI300X_CU_COUNT = 304 +_FALLBACK_FP16_TFLOPS = 100.0 +_FALLBACK_MEMBW_GBS = 500.0 +_GB_TO_TB = 1e3 + + +def _get_hw_specs(device: torch.device) -> tuple[float, float]: + try: + props = torch.cuda.get_device_properties(device) + name = props.name.lower() + if "gfx942" in name or "mi300" in name or props.multi_processor_count == _MI300X_CU_COUNT: + return _MI300X_FP16_TFLOPS, _MI300X_MEMBW_GBS + except Exception: + pass + return _FALLBACK_FP16_TFLOPS, _FALLBACK_MEMBW_GBS + + +# --------------------------------------------------------------------------- +# Timing +# --------------------------------------------------------------------------- + + +def _time_ms(fn, warmup: int = 3, iters: int = 10) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + times = [] + for _ in range(iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + 
times.sort() + return times[len(times) // 2] + + +# --------------------------------------------------------------------------- +# FLOPs helpers +# --------------------------------------------------------------------------- + + +def _attn_flops(seq_q: int, seq_kv: int, num_heads: int, head_dim: int, causal: bool) -> int: + flops = 4 * seq_q * seq_kv * head_dim * num_heads + return flops // 2 if causal else flops + + +# --------------------------------------------------------------------------- +# Worker +# --------------------------------------------------------------------------- + + +def _scaling_worker( + rank: int, + world_size: int, + init_url: str, + num_heads: int, + head_dim: int, + dtype_str: str, + causal: bool, + # Strong scaling: fixed total_seq list + strong_total_seqs: list[int], + # Weak scaling: fixed seq_local list + weak_seq_locals: list[int], + num_warmup: int, + num_iters: int, + results_file: str, +): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + torch.set_default_device(f"cuda:{rank}") + + shmem = iris.iris() + device = torch.device(f"cuda:{rank}") + peak_tflops, peak_bw = _get_hw_specs(device) + dtype = getattr(torch, dtype_str) + scale = head_dim**-0.5 + + strong_results = [] + weak_results = [] + + # ------------------------------------------------------------------ + # Helper: time ring attention for given seq_local + # ------------------------------------------------------------------ + def _run_ring(seq_local: int, _shmem) -> float: + torch.manual_seed(42 + rank) + q = torch.randn(seq_local, num_heads, head_dim, dtype=dtype) + k = torch.randn_like(q) + v = torch.randn_like(q) + layer = RingAttention(_shmem, num_heads=num_heads, head_dim=head_dim, causal=causal, scale=scale) + _shmem.barrier() + ms = _time_ms(lambda: layer(q, k, v), 
warmup=num_warmup, iters=num_iters) + _shmem.barrier() + return ms + + # ------------------------------------------------------------------ + # Helper: time single-GPU SDPA (rank 0 only, full sequence) + # ------------------------------------------------------------------ + def _run_sdpa(total_seq: int) -> float | None: + if rank != 0: + return None + try: + q_f = torch.randn(num_heads, total_seq, head_dim, dtype=dtype) + k_f = torch.randn_like(q_f) + v_f = torch.randn_like(q_f) + ms = _time_ms( + lambda: torch.nn.functional.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale, is_causal=causal), + warmup=num_warmup, + iters=num_iters, + ) + return ms + except torch.OutOfMemoryError: + print(f"[WARN] SDPA reference OOM at total_seq={total_seq}, skipping") + torch.cuda.empty_cache() + return float("nan") + + # ------------------------------------------------------------------ + # STRONG SCALING: fixed total_seq, world_size GPUs + # ------------------------------------------------------------------ + for total_seq in strong_total_seqs: + if total_seq % (64 * world_size) != 0: + continue + seq_local = total_seq // world_size + ring_ms = _run_ring(seq_local, shmem) + ref_ms = _run_sdpa(total_seq) + + if rank == 0: + ring_flops = _attn_flops(seq_local, total_seq, num_heads, head_dim, causal) + ring_tflops = ring_flops / (ring_ms * 1e-3) / 1e12 + + ref_flops = _attn_flops(total_seq, total_seq, num_heads, head_dim, causal) + ref_tflops = ref_flops / (ref_ms * 1e-3) / 1e12 + + strong_results.append( + { + "total_seq": total_seq, + "world_size": world_size, + "seq_local": seq_local, + "ring_ms": ring_ms, + "ref_ms": ref_ms, + "speedup": ref_ms / ring_ms, + "ideal_speedup": float(world_size), + "scaling_efficiency": ref_ms / (ring_ms * world_size), + "ring_tflops": ring_tflops, + "ref_tflops": ref_tflops, + "peak_tflops": peak_tflops, + "peak_bw_gbs": peak_bw, + } + ) + + shmem.barrier() + + # ------------------------------------------------------------------ + # WEAK 
SCALING: fixed seq_local per GPU, world_size GPUs + # ------------------------------------------------------------------ + for seq_local in weak_seq_locals: + if seq_local % 64 != 0: + continue + total_seq = seq_local * world_size + ring_ms = _run_ring(seq_local, shmem) + # Reference = single-GPU SDPA on the *full* sequence (total_seq) + ref_ms = _run_sdpa(total_seq) + + if rank == 0: + ring_flops = _attn_flops(seq_local, total_seq, num_heads, head_dim, causal) + ring_tflops = ring_flops / (ring_ms * 1e-3) / 1e12 + + ref_flops = _attn_flops(total_seq, total_seq, num_heads, head_dim, causal) + ref_tflops = ref_flops / (ref_ms * 1e-3) / 1e12 + + weak_results.append( + { + "seq_local": seq_local, + "total_seq": total_seq, + "world_size": world_size, + "ring_ms": ring_ms, + "ref_ms": ref_ms, + "speedup": ref_ms / ring_ms, + "ring_tflops": ring_tflops, + "ref_tflops": ref_tflops, + "peak_tflops": peak_tflops, + "peak_bw_gbs": peak_bw, + } + ) + + shmem.barrier() + del shmem + dist.destroy_process_group() + + if rank == 0: + with open(results_file, "w") as f: + json.dump({"strong": strong_results, "weak": weak_results}, f, indent=2) + + +# --------------------------------------------------------------------------- +# Print tables +# --------------------------------------------------------------------------- + + +def _print_strong_table(strong: list[dict[str, Any]]): + print("\n=== STRONG SCALING (fixed total_seq, increasing world_size) ===") + hdr = f"{'total_seq':>10} {'GPUs':>5} {'ring ms':>9} {'ref ms':>9} {'speedup':>8} {'ideal':>6} {'eff%':>7} {'ring TF':>9} {'ref TF':>9}" + print("=" * len(hdr)) + print(hdr) + print("=" * len(hdr)) + for r in sorted(strong, key=lambda x: (x["total_seq"], x["world_size"])): + eff = 100.0 * r["scaling_efficiency"] + print( + f"{r['total_seq']:>10} {r['world_size']:>5} {r['ring_ms']:>9.3f} {r['ref_ms']:>9.3f} " + f"{r['speedup']:>8.2f}x {r['ideal_speedup']:>6.1f}x {eff:>6.1f}% " + f"{r['ring_tflops']:>9.2f} {r['ref_tflops']:>9.2f}" + 
        )
    print("=" * len(hdr))


def _print_weak_table(weak: list[dict[str, Any]]) -> None:
    """Print the weak-scaling results as a fixed-width ASCII table.

    Args:
        weak: Result records; each dict must carry the keys referenced in the
            row format below (``seq_local``, ``total_seq``, ``world_size``,
            ``ring_ms``, ``ref_ms``, ``speedup``, ``ring_tflops``,
            ``ref_tflops``).
    """
    print("\n=== WEAK SCALING (fixed seq_local per GPU, increasing world_size) ===")
    hdr = f"{'seq_local':>10} {'total_seq':>10} {'GPUs':>5} {'ring ms':>9} {'ref ms':>9} {'speedup':>8} {'ring TF':>9} {'ref TF':>9}"
    print("=" * len(hdr))
    print(hdr)
    print("=" * len(hdr))
    # Sort rows so each seq_local group is contiguous and ordered by GPU count.
    for r in sorted(weak, key=lambda x: (x["seq_local"], x["world_size"])):
        print(
            f"{r['seq_local']:>10} {r['total_seq']:>10} {r['world_size']:>5} "
            f"{r['ring_ms']:>9.3f} {r['ref_ms']:>9.3f} {r['speedup']:>8.2f}x "
            f"{r['ring_tflops']:>9.2f} {r['ref_tflops']:>9.2f}"
        )
    print("=" * len(hdr))


# ---------------------------------------------------------------------------
# Plot
# ---------------------------------------------------------------------------


def _make_scaling_plots(
    strong: list[dict[str, Any]],
    weak: list[dict[str, Any]],
    num_heads: int,
    head_dim: int,
    causal: bool,
    save_fig: str | None,
) -> None:
    """Print both result tables and render a 2x2 scaling figure.

    Panels: (1) strong-scaling latency, (2) strong-scaling efficiency,
    (3) weak-scaling latency, (4) per-rank throughput.

    Args:
        strong: Strong-scaling records (fixed total sequence length).
        weak: Weak-scaling records (fixed per-GPU sequence length).
        num_heads: Attention head count (title annotation only).
        head_dim: Head dimension (title annotation only).
        causal: Whether the benchmarked attention was causal (title only).
        save_fig: Path to save the figure; if ``None`` the figure is shown
            interactively instead.
    """
    # Imported lazily so the benchmark itself has no hard matplotlib dependency.
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np

    if save_fig:
        # Headless backend: must be selected before any figure is created.
        matplotlib.use("Agg")

    _print_strong_table(strong)
    _print_weak_table(weak)

    # ---- Layout: 2×2 ----
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(
        f"Ring Attention Scaling — AMD MI300X (gfx942), FP16, causal={causal}\nH={num_heads}, D={head_dim}",
        fontsize=13,
        fontweight="bold",
    )

    # --- 1. Strong scaling: latency vs world_size per total_seq ---
    ax = axes[0, 0]
    total_seqs_ss = sorted(set(r["total_seq"] for r in strong))
    # One tab10 color per total-seq series, shared with panels 2 and 4.
    colors_ss = plt.cm.tab10(np.linspace(0, 0.9, len(total_seqs_ss)))
    for ts, col in zip(total_seqs_ss, colors_ss):
        pts = sorted([r for r in strong if r["total_seq"] == ts], key=lambda x: x["world_size"])
        if not pts:
            continue
        ws_vals = [p["world_size"] for p in pts]
        ring_ms = [p["ring_ms"] for p in pts]
        ref_ms = pts[0]["ref_ms"]  # single-GPU reference is constant

        ax.plot(ws_vals, ring_ms, "o-", color=col, linewidth=2, markersize=8, label=f"Ring S={ts}")
        # Ideal scaling: ref_ms / world_size
        ideal = [ref_ms / ws for ws in ws_vals]
        ax.plot(ws_vals, ideal, "--", color=col, linewidth=1.2, alpha=0.5)
    ax.set_xlabel("Number of GPUs")
    ax.set_ylabel("Latency (ms)")
    ax.set_title("Strong Scaling: Latency vs. GPU Count\n(dashed = ideal 1/N scaling)")
    ax.set_xticks(sorted(set(r["world_size"] for r in strong)))
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- 2. Strong scaling: scaling efficiency % ---
    ax = axes[0, 1]
    for ts, col in zip(total_seqs_ss, colors_ss):
        pts = sorted([r for r in strong if r["total_seq"] == ts], key=lambda x: x["world_size"])
        if not pts:
            continue
        ws_vals = [p["world_size"] for p in pts]
        # scaling_efficiency is stored as a fraction; plot as percent.
        eff = [100.0 * p["scaling_efficiency"] for p in pts]
        ax.plot(ws_vals, eff, "s-", color=col, linewidth=2, markersize=8, label=f"S={ts}")
    ax.axhline(100, color="gray", linestyle="--", alpha=0.6, linewidth=1.5, label="Ideal (100%)")
    ax.set_xlabel("Number of GPUs")
    ax.set_ylabel("Strong Scaling Efficiency (%)")
    ax.set_title("Strong Scaling Efficiency\n(100% = perfect linear speedup)")
    ax.set_xticks(sorted(set(r["world_size"] for r in strong)))
    ax.set_ylim(0, 130)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- 3. Weak scaling: latency vs world_size per seq_local ---
    ax = axes[1, 0]
    seq_locals_ws = sorted(set(r["seq_local"] for r in weak))
    colors_ws = plt.cm.tab10(np.linspace(0, 0.9, len(seq_locals_ws)))
    for sl, col in zip(seq_locals_ws, colors_ws):
        pts = sorted([r for r in weak if r["seq_local"] == sl], key=lambda x: x["world_size"])
        if not pts:
            continue
        ws_vals = [p["world_size"] for p in pts]
        ring_ms = [p["ring_ms"] for p in pts]
        # Baseline = single-GPU ring at world_size=1 (first point if available,
        # else we just plot relative to the first measured point)
        ax.plot(ws_vals, ring_ms, "o-", color=col, linewidth=2, markersize=8, label=f"Ring S_local={sl}")
        # Ideal weak scaling = flat (constant latency)
        ax.axhline(ring_ms[0], color=col, linestyle="--", linewidth=1.2, alpha=0.4)
    ax.set_xlabel("Number of GPUs")
    ax.set_ylabel("Latency (ms)")
    ax.set_title("Weak Scaling: Latency vs. GPU Count\n(dashed = ideal flat latency)")
    ax.set_xticks(sorted(set(r["world_size"] for r in weak)))
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- 4. Throughput (TFLOPS per GPU) vs world_size for both strong & weak ---
    ax = axes[1, 1]
    # Strong scaling TFLOPS
    for ts, col in zip(total_seqs_ss, colors_ss):
        pts = sorted([r for r in strong if r["total_seq"] == ts], key=lambda x: x["world_size"])
        if not pts:
            continue
        ws_vals = [p["world_size"] for p in pts]
        tfl = [p["ring_tflops"] for p in pts]
        ax.plot(ws_vals, tfl, "o-", color=col, linewidth=2, markersize=8, label=f"Strong S={ts}")

    ax.set_xlabel("Number of GPUs")
    ax.set_ylabel("TFLOPS (per rank)")
    ax.set_title("Per-Rank Throughput vs. GPU Count\n(strong scaling)")
    ax.set_xticks(sorted(set(r["world_size"] for r in strong)))
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_fig:
        plt.savefig(save_fig, dpi=150, bbox_inches="tight")
        print(f"\nSaved scaling figure to: {save_fig}")
    else:
        plt.show()


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def parse_args() -> "argparse.Namespace":
    """Parse the benchmark's command-line arguments.

    Returns:
        Parsed namespace with world sizes, model dims, sweep lengths,
        timing controls, and output options.
    """
    p = argparse.ArgumentParser(description="Ring Attention strong/weak scaling benchmark")
    p.add_argument(
        "--world_sizes",
        nargs="+",
        type=int,
        default=[1, 2, 4, 8],
        help="GPU counts to benchmark (default: 1 2 4 8)",
    )
    p.add_argument("--num_heads", type=int, default=16)
    p.add_argument("--head_dim", type=int, default=64)
    p.add_argument(
        "--strong_seqs",
        nargs="+",
        type=int,
        default=[4096, 8192, 16384],
        help="Fixed total sequence lengths for strong-scaling sweep",
    )
    p.add_argument(
        "--weak_seq_locals",
        nargs="+",
        type=int,
        default=[1024, 2048, 4096],
        help="Fixed per-GPU sequence lengths for weak-scaling sweep",
    )
    # --no_causal clears the default-True causal flag (store_false on dest).
    p.add_argument("--no_causal", dest="causal", action="store_false", default=True)
    p.add_argument("--dtype", choices=["float16", "bfloat16"], default="float16")
    p.add_argument("--warmup", type=int, default=5)
    p.add_argument("--iters", type=int, default=20)
    p.add_argument("--save_fig", type=str, default=None)
    p.add_argument("--no_plot", action="store_true")
    return p.parse_args()


def main() -> None:
    """Benchmark driver: measure world_size=1 in-process, then spawn workers.

    The world_size=1 case degenerates to plain SDPA and is timed directly on
    GPU 0; multi-GPU cases are handled further below via ``mp.spawn``.
    """
    args = parse_args()

    all_strong: list[dict] = []
    all_weak: list[dict] = []

    # Special case: world_size=1 means single-GPU ring (= SDPA, measured directly)
    if 1 in args.world_sizes:
        print("Measuring world_size=1 (single GPU, ring degenerates to SDPA)...")
        device = torch.device("cuda:0")
        torch.cuda.set_device(0)
        # Pick peak numbers by detecting MI300X via its CU count; otherwise
        # fall back to generic constants. _MI300X_* / _FALLBACK_* are
        # module-level constants defined earlier in this file (not shown here).
        peak_tflops, peak_bw = (
            _MI300X_FP16_TFLOPS
            if torch.cuda.get_device_properties(0).multi_processor_count == _MI300X_CU_COUNT
            else _FALLBACK_FP16_TFLOPS,
            _MI300X_MEMBW_GBS
            if torch.cuda.get_device_properties(0).multi_processor_count == _MI300X_CU_COUNT
            else _FALLBACK_MEMBW_GBS,
        )
        dtype = getattr(torch, args.dtype)
        scale = args.head_dim**-0.5  # standard 1/sqrt(d) softmax scale

        # --- Strong-scaling baselines: full sequence on one GPU ---
        for total_seq in args.strong_seqs:
            if total_seq % 64 != 0:
                # Skip lengths that don't divide evenly into attention tiles.
                continue
            q_f = torch.randn(args.num_heads, total_seq, args.head_dim, dtype=dtype, device=device)
            k_f = torch.randn_like(q_f)
            v_f = torch.randn_like(q_f)
            ms = _time_ms(
                lambda: torch.nn.functional.scaled_dot_product_attention(
                    q_f, k_f, v_f, scale=scale, is_causal=args.causal
                ),
                warmup=args.warmup,
                iters=args.iters,
            )
            ref_flops = _attn_flops(total_seq, total_seq, args.num_heads, args.head_dim, args.causal)
            ref_tflops = ref_flops / (ms * 1e-3) / 1e12
            # world_size=1 record: ring == reference, so speedup/efficiency are 1.
            all_strong.append(
                {
                    "total_seq": total_seq,
                    "world_size": 1,
                    "seq_local": total_seq,
                    "ring_ms": ms,
                    "ref_ms": ms,
                    "speedup": 1.0,
                    "ideal_speedup": 1.0,
                    "scaling_efficiency": 1.0,
                    "ring_tflops": ref_tflops,
                    "ref_tflops": ref_tflops,
                    "peak_tflops": peak_tflops,
                    "peak_bw_gbs": peak_bw,
                }
            )

        # --- Weak-scaling baselines: per-GPU chunk length on one GPU ---
        for seq_local in args.weak_seq_locals:
            if seq_local % 64 != 0:
                continue
            q_f = torch.randn(args.num_heads, seq_local, args.head_dim, dtype=dtype, device=device)
            k_f = torch.randn_like(q_f)
            v_f = torch.randn_like(q_f)
            ms = _time_ms(
                lambda: torch.nn.functional.scaled_dot_product_attention(
                    q_f, k_f, v_f, scale=scale, is_causal=args.causal
                ),
                warmup=args.warmup,
                iters=args.iters,
            )
            ref_flops = _attn_flops(seq_local, seq_local, args.num_heads, args.head_dim, args.causal)
            ref_tflops = ref_flops / (ms * 1e-3) / 1e12
            all_weak.append(
                {
                    "seq_local": seq_local,
                    "total_seq": seq_local,
                    "world_size": 1,
                    "ring_ms": ms,
                    "ref_ms": ms,
                    "speedup": 1.0,
                    "ring_tflops": ref_tflops,
                    "ref_tflops": ref_tflops,
                    "peak_tflops": peak_tflops,
                    "peak_bw_gbs": peak_bw,
                }
            )

        # Free the single-GPU benchmark tensors before spawning workers.
        # NOTE(review): original indentation of this line is ambiguous in the
        # rendering; placing it inside the world_size==1 branch — confirm.
        torch.cuda.empty_cache()

    # --- Multi-GPU cases: one fresh process group per world_size ---
    for world_size in sorted(args.world_sizes):
        if world_size == 1:
            continue  # handled in-process above
        if world_size > torch.cuda.device_count():
            print(f"[skip] world_size={world_size} > available GPUs ({torch.cuda.device_count()})")
            continue

        print(f"\nRunning world_size={world_size}...")

        # Workers write their results as JSON; only the file name is needed
        # here, so the handle is closed immediately (delete=False keeps it).
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            results_file = f.name

        port = 29500 + world_size  # unique port per world_size to avoid conflicts
        init_url = f"tcp://127.0.0.1:{port}"

        try:
            # Blocks until all world_size worker processes finish (join=True).
            mp.spawn(
                fn=_scaling_worker,
                args=(
                    world_size,
                    init_url,
                    args.num_heads,
                    args.head_dim,
                    args.dtype,
                    args.causal,
                    args.strong_seqs,
                    args.weak_seq_locals,
                    args.warmup,
                    args.iters,
                    results_file,
                ),
                nprocs=world_size,
                join=True,
            )
            # Collect the records the workers wrote to the temp file.
            with open(results_file) as f:
                data = json.load(f)
            all_strong.extend(data["strong"])
            all_weak.extend(data["weak"])
        finally:
            # Always remove the temp file, even if spawn/parsing failed.
            os.unlink(results_file)

    if not args.no_plot:
        _make_scaling_plots(all_strong, all_weak, args.num_heads, args.head_dim, args.causal, save_fig=args.save_fig)
    else:
        _print_strong_table(all_strong)
        _print_weak_table(all_weak)


if __name__ == "__main__":
    main()
diff --git a/tests/examples/test_ring_attention.py b/tests/examples/test_ring_attention.py
new file mode 100644
index 00000000..fdd8a0e6
--- /dev/null
+++ b/tests/examples/test_ring_attention.py
@@ -0,0 +1,187 @@
################################################################################
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
#
#
# Correctness tests for Ring Attention.
#
# Each test validates the distributed ring-attention output against a
# single-device PyTorch reference implementation.
#
################################################################################

import gc
import sys
from pathlib import Path

import pytest
import torch
import iris

# Walk upward from this file until a directory containing both tests/ and
# examples/ is found — that is the project root.
project_root = Path(__file__).resolve()
while not (project_root / "tests").is_dir() or not (project_root / "examples").is_dir():
    if project_root == project_root.parent:
        # Reached the filesystem root without finding the project layout.
        raise FileNotFoundError("Could not find project root")
    project_root = project_root.parent

# Make the example module importable without packaging it.
module_dir = project_root / "examples" / "32_ring_attention"
if module_dir.exists():
    sys.path.insert(0, str(module_dir))

from ring_attention_layer import RingAttention  # noqa: E402


# ---------------------------------------------------------------------------
# Reference (single-device) implementation
# ---------------------------------------------------------------------------


def ref_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    causal: bool,
) -> torch.Tensor:
    """
    Reference causal/non-causal self-attention on a single device.

    Args:
        q: ``[total_seq, num_heads, head_dim]``
        k: ``[total_seq, num_heads, head_dim]``
        v: ``[total_seq, num_heads, head_dim]``
        scale: Softmax scale factor.
        causal: Whether to apply causal masking.

    Returns:
        Attention output ``[total_seq, num_heads, head_dim]``.
    """
    total_seq, num_heads, head_dim = q.shape
    # Work in float32 for reference accuracy
    q_f = q.float()
    k_f = k.float()
    v_f = v.float()

    # [num_heads, total_seq, head_dim]
    q_h = q_f.permute(1, 0, 2)
    k_h = k_f.permute(1, 0, 2)
    v_h = v_f.permute(1, 0, 2)

    # Attention scores: [num_heads, total_seq, total_seq]
    attn = torch.bmm(q_h, k_h.transpose(-1, -2)) * scale

    if causal:
        # diagonal=1 marks strictly-future positions; -inf removes them
        # from the softmax (the diagonal itself always stays unmasked).
        mask = torch.triu(torch.ones(total_seq, total_seq, device=q.device, dtype=torch.bool), diagonal=1)
        attn = attn.masked_fill(mask.unsqueeze(0), float("-inf"))

    attn = torch.softmax(attn, dim=-1)
    out = torch.bmm(attn, v_h)  # [num_heads, total_seq, head_dim]
    return out.permute(1, 0, 2).to(q.dtype)  # [total_seq, num_heads, head_dim]


# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------


def _run_ring_attn_test(
    total_seq_len: int,
    num_heads: int,
    head_dim: int,
    causal: bool,
    dtype: torch.dtype,
) -> None:
    """Run one correctness check; called from test functions.

    Runs distributed ring attention on this rank's sequence chunk and compares
    it to the matching slice of a full single-device reference. Raises
    AssertionError on mismatch; always tears down the iris context.
    """
    shmem = None
    try:
        shmem = iris.iris()
        rank = shmem.get_rank()
        num_ranks = shmem.get_num_ranks()

        # NOTE(review): set_default_device is not restored afterwards — it
        # leaks into later tests in the same process; confirm this is intended.
        torch.set_default_device("cuda")
        # Same seed on every rank; only rank 0's tensors are actually used
        # (they are broadcast below).
        torch.manual_seed(0)

        scale = head_dim**-0.5
        seq_local = total_seq_len // num_ranks

        # Rank 0 creates the full Q, K, V and broadcasts to all ranks so that
        # the reference and distributed implementations see the same data.
        if rank == 0:
            # Scaled by 0.1 — presumably to keep fp16 magnitudes small for
            # numerical stability; TODO confirm.
            q_full = torch.randn(total_seq_len, num_heads, head_dim, dtype=dtype) * 0.1
            k_full = torch.randn(total_seq_len, num_heads, head_dim, dtype=dtype) * 0.1
            v_full = torch.randn(total_seq_len, num_heads, head_dim, dtype=dtype) * 0.1
        else:
            q_full = torch.empty(total_seq_len, num_heads, head_dim, dtype=dtype)
            k_full = torch.empty(total_seq_len, num_heads, head_dim, dtype=dtype)
            v_full = torch.empty(total_seq_len, num_heads, head_dim, dtype=dtype)

        # Broadcast via host numpy arrays, then copy back to the GPU tensors'
        # device. Assumes shmem.broadcast consumes/returns numpy arrays — per
        # usage here; verify against the iris API.
        q_full = torch.from_numpy(shmem.broadcast(q_full.cpu().numpy(), source_rank=0)).to(q_full.device)
        k_full = torch.from_numpy(shmem.broadcast(k_full.cpu().numpy(), source_rank=0)).to(k_full.device)
        v_full = torch.from_numpy(shmem.broadcast(v_full.cpu().numpy(), source_rank=0)).to(v_full.device)

        # Local chunks for this rank
        q_local = q_full[rank * seq_local : (rank + 1) * seq_local].contiguous()
        k_local = k_full[rank * seq_local : (rank + 1) * seq_local].contiguous()
        v_local = v_full[rank * seq_local : (rank + 1) * seq_local].contiguous()

        shmem.barrier()

        # --- Distributed ring attention ---
        layer = RingAttention(shmem, num_heads=num_heads, head_dim=head_dim, causal=causal, scale=scale)
        output_local = layer(q_local, k_local, v_local)
        torch.cuda.synchronize()

        # --- Single-device reference ---
        ref_full = ref_attention(q_full, k_full, v_full, scale=scale, causal=causal)
        ref_local = ref_full[rank * seq_local : (rank + 1) * seq_local]

        shmem.barrier()

        # Compare with relatively tight tolerances
        atol, rtol = (2e-2, 2e-2) if dtype == torch.float16 else (1e-2, 1e-2)
        # Capture the failure instead of raising immediately so every rank
        # reaches the barriers below (avoids a hang on mismatch).
        error = None
        try:
            torch.testing.assert_close(output_local.float(), ref_local.float(), atol=atol, rtol=rtol)
        except AssertionError as e:
            error = e

        # Print a brief report from rank 0
        if rank == 0:
            max_diff = (output_local.float() - ref_local.float()).abs().max().item()
            status = "PASSED" if error is None else "FAILED"
            print(
                f"[Rank 0] Ring Attention test {status} | "
                f"seq={total_seq_len} h={num_heads} d={head_dim} "
                f"causal={causal} dtype={dtype} | max_diff={max_diff:.6f}"
            )

        shmem.barrier()

        # Re-raise the deferred comparison failure, if any.
        if error is not None:
            raise error

    finally:
        # Best-effort teardown: synchronize if possible, then drop the iris
        # context so the next test can create a fresh one.
        if shmem is not None:
            try:
                shmem.barrier()
            except Exception:
                pass
        del shmem
        gc.collect()


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("total_seq_len", [512, 2048])
@pytest.mark.parametrize("causal", [True, False])
def test_ring_attention_correctness(total_seq_len, num_heads, head_dim, causal):
    """
    Validate ring attention output against a single-device PyTorch reference
    for both causal and bidirectional modes.
    """
    _run_ring_attn_test(
        total_seq_len=total_seq_len,
        num_heads=num_heads,
        head_dim=head_dim,
        causal=causal,
        dtype=torch.float16,
    )