From 7f1d49618ce5b137efd23cf5ff2d799e1f24e089 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:20:01 +0000 Subject: [PATCH 1/2] Initial plan From 2b903ad4c8a7ba51a0312f29d7eba738cee243da Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:41:07 +0000 Subject: [PATCH 2/2] Modernize flash decode example: replace example_run.py with torchrun-based example.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/13_flash_decode/README.md | 8 +- examples/13_flash_decode/example.py | 157 +++++++++++++++++++++++ examples/13_flash_decode/example_run.py | 163 ------------------------ examples/README.md | 2 +- 4 files changed, 164 insertions(+), 166 deletions(-) create mode 100644 examples/13_flash_decode/example.py delete mode 100644 examples/13_flash_decode/example_run.py diff --git a/examples/13_flash_decode/README.md b/examples/13_flash_decode/README.md index e71c69cbf..930fb7521 100644 --- a/examples/13_flash_decode/README.md +++ b/examples/13_flash_decode/README.md @@ -21,9 +21,13 @@ We perform comparisons against the RCCL baseline. To simply do a test run of the code, run: ```terminal -python examples/13_flash_decode/example_run.py +torchrun --nproc_per_node=<num_gpus> --standalone examples/13_flash_decode/example.py +``` + +Pass `--validate` to verify the output against a PyTorch reference: +```terminal +torchrun --nproc_per_node=<num_gpus> --standalone examples/13_flash_decode/example.py --validate ``` -This example will run by default on 8 GPUs. Use the `--num_ranks` flag to select the number of GPUs. 
### Validation diff --git a/examples/13_flash_decode/example.py b/examples/13_flash_decode/example.py new file mode 100644 index 000000000..12992d2a4 --- /dev/null +++ b/examples/13_flash_decode/example.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: Flash Decode Fused Attention + +A distributed Flash Decode kernel for accelerating LLM inference. The KV cache +is sharded across all ranks; each rank computes local attention scores and +participates in a fused global reduce to produce the final output. + +Run with: + torchrun --nproc_per_node=<num_gpus> --standalone example.py [--validate] +""" + +import argparse +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist + +import iris + +# The flash_decode_fused_layer module lives alongside this file +sys.path.insert(0, str(Path(__file__).parent)) +from flash_decode_fused_layer import flash_decode_fused_layer # noqa: E402 + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Flash Decode fused attention example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--kv_len_per_rank", type=int, default=32768, help="KV sequence length per rank") + parser.add_argument("--num_heads", type=int, default=96, help="Number of attention heads") + parser.add_argument("--head_dim", type=int, default=128, help="Dimension of each attention head") + parser.add_argument("--num_seqs", type=int, default=4, help="Number of sequences in the batch") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against PyTorch reference") + return vars(parser.parse_args()) + + +def ref_paged_attn(query, key_cache, value_cache, 
kv_lens, block_tables, scale): + """Compute reference paged attention output using PyTorch.""" + num_seqs = query.shape[0] + _, block_size, num_kv_heads, head_size = key_cache.shape + outputs = [] + for i in range(num_seqs): + kv_len = kv_lens[i] + q = query[i : i + 1] * scale # (1, num_q_heads, head_dim) + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks].cpu().numpy() + k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + gqa_ratio = q.shape[1] // k.shape[1] + if gqa_ratio > 1: + k = torch.repeat_interleave(k, gqa_ratio, dim=1) + v = torch.repeat_interleave(v, gqa_ratio, dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + outputs.append(out) + return torch.cat(outputs, dim=0) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + + torch.manual_seed(42) + torch.set_default_device("cuda") + + kv_len = args["kv_len_per_rank"] + num_heads = args["num_heads"] + head_dim = args["head_dim"] + num_seqs = args["num_seqs"] + num_kv_heads = max(1, num_heads // 8) + block_size = 1 + scale = head_dim**-0.5 + num_blocks_per_rank = (kv_len + block_size - 1) // block_size + + # Build input tensors + query = torch.randn(num_seqs, num_heads, head_dim, dtype=dtype) + key_cache = torch.randn(num_blocks_per_rank, block_size, num_kv_heads, head_dim, dtype=dtype) + value_cache = torch.randn(num_blocks_per_rank, block_size, num_kv_heads, head_dim, dtype=dtype) + block_table = torch.arange(num_blocks_per_rank, 
dtype=torch.int32).repeat(num_seqs, 1) + kv_lens_tensor = torch.tensor([kv_len] * num_seqs, dtype=torch.int32) + global_kv_lens = kv_lens_tensor.unsqueeze(0).repeat(world_size, 1) + + ctx.barrier() + + fd_layer = flash_decode_fused_layer( + ctx, + rank, + rank, + world_size, + world_size, + num_q_heads=num_heads, + num_kv_heads=num_kv_heads, + q_head_dim=head_dim, + v_head_dim=head_dim, + page_size=block_size, + scale=scale, + soft_cap=0.0, + max_allowed_batch=num_seqs, + ) + + output = fd_layer(query, key_cache, value_cache, global_kv_lens, block_table) + torch.cuda.synchronize() + + if rank == 0: + ctx.info( + f"flash_decode: world_size={world_size}, num_heads={num_heads}, " + f"head_dim={head_dim}, num_seqs={num_seqs}, kv_len_per_rank={kv_len}, dtype={dtype}" + ) + + if args["validate"]: + # Gather all rank KV caches for a full reference computation + all_key = torch.zeros(world_size * num_blocks_per_rank, block_size, num_kv_heads, head_dim, dtype=dtype) + all_val = torch.zeros(world_size * num_blocks_per_rank, block_size, num_kv_heads, head_dim, dtype=dtype) + dist.all_gather_into_tensor(all_key, key_cache) + dist.all_gather_into_tensor(all_val, value_cache) + + ref_block_table = torch.cat([block_table + r * num_blocks_per_rank for r in range(world_size)], dim=-1) + global_kv_len = kv_len * world_size + ref_kv_lens = [global_kv_len] * num_seqs + + ref_output = ref_paged_attn(query, all_key, all_val, ref_kv_lens, ref_block_table, scale) + + try: + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) + if rank == 0: + ctx.info(f"Validation passed: output[0,0,:3] = {output[0, 0, :3].tolist()}") + except AssertionError as e: + if rank == 0: + ctx.info(f"Validation FAILED: {e}") + raise + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/13_flash_decode/example_run.py b/examples/13_flash_decode/example_run.py deleted file mode 100644 index 34ab1fc37..000000000 --- 
a/examples/13_flash_decode/example_run.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -""" -A simple, minimal example demonstrating how to use the flash_decode_fused_layer. - -This script initializes the necessary distributed components with Iris, -creates sample input tensors, instantiates the layer, and calls its -forward pass once. It then prints the shape and a slice of the output -tensor to show that the operation completed successfully. - -The layer is defined in the flash_decode_fused_layer.py file. -All the triton kernels are defined in decode_kernels.py -""" - -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -import iris -import argparse -from flash_decode_fused_layer import flash_decode_fused_layer - - -def parse_args(): - """Parses command-line arguments for the example.""" - parser = argparse.ArgumentParser(description="A minimal example for flash_decode_fused_layer.") - parser.add_argument("--kv_len_per_rank", type=int, default=32768, help="KV sequence length per rank.") - parser.add_argument("--num_heads", type=int, default=96, help="Number of attention heads.") - parser.add_argument("--head_dim", type=int, default=128, help="Dimension of each attention head.") - parser.add_argument("--num_seqs", type=int, default=4, help="Number of sequences in the batch.") - parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.") - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "bfloat16"], help="PyTorch data type to use." 
- ) - return parser.parse_args() - - -def setup_example_data(rank, world_size, args, dtype): - """Creates a set of random tensors to serve as inputs for the layer.""" - - num_query_heads = args.num_heads - # Assume an 8:1 Grouped-Query Attention ratio for this example - num_kv_heads = max(1, args.num_heads // 8) - block_size = 1 # PagedAttention works with blocks of tokens - - # Number of blocks needed on this rank to store the KV cache for all sequences - num_blocks_per_rank = (args.kv_len_per_rank + block_size - 1) // block_size - - print(f"[Rank {rank}] Creating example tensors...") - - # 1. Query tensor: The new tokens for which we are calculating attention. - query = torch.randn(args.num_seqs, num_query_heads, args.head_dim, dtype=dtype).cuda() - - # 2. Key/Value Caches: Tensors representing the keys and values - # The KV is split across ranks - key_cache_this_rank = torch.randn(num_blocks_per_rank, block_size, num_kv_heads, args.head_dim, dtype=dtype).cuda() - value_cache_this_rank = torch.randn( - num_blocks_per_rank, block_size, num_kv_heads, args.head_dim, dtype=dtype - ).cuda() - - # 3. Block Tables: A mapping that tells the kernel where to find the blocks for each sequence in the KV cache. - # Here, we create a simple identity mapping for demonstration. - block_tables_this_rank = torch.arange(num_blocks_per_rank, dtype=torch.int32).repeat(args.num_seqs, 1).cuda() - - # 4. Global KV Lengths Tensor: The layer needs to know the sequence length on all ranks. - # Create a list of lengths for each sequence in the batch on this rank. - kv_lens_per_rank = [args.kv_len_per_rank] * args.num_seqs - # Create a 1D tensor from this list. 
Shape: (NUM_SEQS,) - kv_lens_tensor_this_rank = torch.tensor(kv_lens_per_rank, dtype=torch.int32).cuda() - # Reshape to (1, NUM_SEQS) and repeat for all ranks to get shape (world_size, NUM_SEQS) - global_kv_lens_tensor = kv_lens_tensor_this_rank.unsqueeze(0).repeat(world_size, 1) - - return { - "query": query, - "key_cache_this_rank": key_cache_this_rank, - "value_cache_this_rank": value_cache_this_rank, - "block_tables_this_rank": block_tables_this_rank, - "global_kv_lens_tensor": global_kv_lens_tensor, - } - - -def example_run(rank: int, world_size: int, init_url: str, args: dict): - backend = "nccl" if torch.cuda.is_available() else "gloo" - dist.init_process_group( - backend=backend, init_method=init_url, world_size=world_size, rank=rank, device_id=torch.device(f"cuda:{rank}") - ) - - # 1. Initialize Iris for distributed communication - shmem = iris.iris() - - torch.manual_seed(42) - torch.set_default_device("cuda") - dtype = getattr(torch, args.dtype) - - if rank == 0: - print("--- flash_decode_fused_layer Minimal Example ---") - print(f"Running with {world_size} rank(s).") - - # 2. Set up the example input tensors - tensor_data = setup_example_data(rank, world_size, args, dtype) - shmem.barrier() - - # 3. Define the layer's parameters - num_kv_heads = max(1, args.num_heads // 8) - scale = args.head_dim**-0.5 - common_params = { - "num_q_heads": args.num_heads, - "num_kv_heads": num_kv_heads, - "q_head_dim": args.head_dim, - "v_head_dim": args.head_dim, - "page_size": 1, - "scale": scale, - "soft_cap": 0.0, - "max_allowed_batch": args.num_seqs, - } - - # 4. Instantiate the layer - if rank == 0: - print("\nInstantiating flash_decode_fused_layer...") - fd_layer = flash_decode_fused_layer(shmem, rank, rank, world_size, world_size, **common_params) - - # 5. 
Call the forward pass of the layer - if rank == 0: - print("Calling the forward pass...") - output = fd_layer( - tensor_data["query"], - tensor_data["key_cache_this_rank"], - tensor_data["value_cache_this_rank"], - tensor_data["global_kv_lens_tensor"], - tensor_data["block_tables_this_rank"], - ) - - # Ensure the computation is finished before printing - torch.cuda.synchronize() - shmem.barrier() - - # 6. Print a summary of the output tensor on the main rank - if rank == 0: - print("\n--- Example Run Finished ---") - print(f"Output tensor shape: {output.shape}") - print("Output tensor values (first 5 elements of the first sequence):") - print(output[0, 0, :5]) - print("--------------------------") - - shmem.barrier() - dist.destroy_process_group() - - -def main(): - args = parse_args() - num_ranks = args.num_ranks - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=example_run, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) - - -if __name__ == "__main__": - main() diff --git a/examples/README.md b/examples/README.md index d361da826..c4ea656a1 100644 --- a/examples/README.md +++ b/examples/README.md @@ -80,7 +80,7 @@ python examples/11_gemm_all_scatter_producer_consumer/benchmark.py --benchmark - python examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py --benchmark --validate --num_ranks 8 # Flash Decode Attention - simple example run -python examples/13_flash_decode/example_run.py --num_ranks 8 +torchrun --nproc_per_node=8 --standalone examples/13_flash_decode/example.py # All-Gather + GEMM - Pull model python examples/14_all_gather_gemm/example_run_pull.py --num_ranks 8