diff --git a/examples/25_ccl_all_gather/example.py b/examples/25_ccl_all_gather/example.py index 18266d90..ec5aa8fb 100644 --- a/examples/25_ccl_all_gather/example.py +++ b/examples/25_ccl_all_gather/example.py @@ -8,7 +8,7 @@ Each rank contributes an (M, N) tensor; every rank receives the concatenated (world_size*M, N) result. Run with: - torchrun --nproc_per_node= --standalone example.py [--validate] + torchrun --nproc_per_node= --standalone example.py [--validate] [--use_gluon] """ import argparse @@ -37,6 +37,7 @@ def parse_args(): parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") parser.add_argument("--num_warps", type=int, default=4, help="Number of warps") parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon kernel backend") return vars(parser.parse_args()) @@ -47,7 +48,12 @@ def main(): torch.cuda.set_device(local_rank) dist.init_process_group(backend="gloo") - ctx = iris.iris(heap_size=args["heap_size"]) + if args["use_gluon"]: + import iris.experimental.iris_gluon as iris_gluon + + ctx = iris_gluon.iris(heap_size=args["heap_size"]) + else: + ctx = iris.iris(heap_size=args["heap_size"]) rank = ctx.get_rank() world_size = ctx.get_num_ranks() @@ -67,6 +73,7 @@ def main(): "num_stages": args["num_stages"], "num_warps": args["num_warps"], "waves_per_eu": args["waves_per_eu"], + "use_gluon": args["use_gluon"], } config = Config(**config_kwargs) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 2093fb3b..6b6c53db 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -12,6 +12,16 @@ from .config import Config from .utils import extract_group_info +# Conditional import for Gluon +try: + from triton.experimental import gluon + from triton.experimental.gluon import language as gl + from iris.experimental.iris_gluon import IrisDeviceCtx + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + @triton.jit() def persistent_all_gather( @@ -279,6 +289,158 @@ def persistent_all_gather_partitioned( ) +# Gluon implementation: flat-2D tiling approach +# +# Uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N elements with +# div/mod to compute 2D row/col indices. This gives one load + world_size +# stores per tile (matching Triton's 2D load/store structure) while staying +# within gluon's 1D BlockedLayout framework. +# +# Key optimizations: +# - Flat-2D tiling: eliminates the inner BLOCK_SIZE_M row loop +# - Hoisted pointer translation: local_base loaded once outside tile loop +# - Traffic shaping: staggered write order avoids memory controller contention +if GLUON_AVAILABLE: + + @gluon.jit + def persistent_all_gather_gluon( + IrisDeviceCtx: gl.constexpr, + context_tensor, + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + group_rank: gl.constexpr, + iris_rank: gl.constexpr, + world_size: gl.constexpr, + rank_start: gl.constexpr, + rank_stride: gl.constexpr, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + GROUP_SIZE_M: gl.constexpr, + COMM_SMS: gl.constexpr, + THREADS_PER_WARP: gl.constexpr, + WARPS_PER_CTA: gl.constexpr, + ): + """ + Persistent all-gather kernel using Gluon with flat-2D tiling. + + Uses a flat 1D index space of BLOCK_SIZE_M * BLOCK_SIZE_N elements, + computing 2D row/col via integer div/mod. This produces one vectorized + load and world_size vectorized stores per tile, matching Triton's 2D + load/store instruction structure while staying within gluon's 1D + BlockedLayout framework. + + Memory layout (BlockedLayout): + A 1D BlockedLayout distributes TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N + elements across the thread hierarchy: + ELEMS_PER_THREAD = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA) + + Each thread handles ELEMS_PER_THREAD contiguous elements in the + flattened row-major order. Row/col are recovered via: + row = flat_idx // BLOCK_SIZE_N + col = flat_idx % BLOCK_SIZE_N + + Constraints: + - BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of + (THREADS_PER_WARP * WARPS_PER_CTA). + - Optimal tile: 2048-4096 total elements (8-16 per thread). + Larger tiles cause register spilling and performance collapse. + - Recommended: BLOCK_SIZE_M=8, BLOCK_SIZE_N=256 (2048 elems, 8/thread). + + Args: + IrisDeviceCtx: Gluon device context class for remote memory operations. + context_tensor: Opaque tensor holding IrisDeviceCtx state. + input_ptr: Pointer to local input tensor of shape (M, N). + output_ptr: Pointer to output tensor of shape (world_size * M, N). + M: Number of rows in the input tensor (per rank). + N: Number of columns. + stride_in_m, stride_in_n: Row and column strides for input tensor. + stride_out_m, stride_out_n: Row and column strides for output tensor. + group_rank: This rank's index within the ProcessGroup (0..world_size-1). + iris_rank: This rank's global index in the iris context (for RMA addressing). + world_size: Total number of ranks in the group. + rank_start: First iris rank in the group (for RMA target computation). + rank_stride: Stride between consecutive iris ranks in the group. + BLOCK_SIZE_M: Number of rows per tile. + BLOCK_SIZE_N: Number of columns per tile. + GROUP_SIZE_M: Swizzle group size for M-dimension tiling. + COMM_SMS: Number of CUs used for persistent scheduling. + THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA). + WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps. + """ + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=False) + + pid = gl.program_id(0) + + num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = gl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + # Flat 1D layout covering BLOCK_SIZE_M * BLOCK_SIZE_N elements + TOTAL_ELEMS: gl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_N + ELEMS_PER_THREAD: gl.constexpr = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA) + flat_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0]) + + # Hoist local heap base outside the tile loop: eliminates redundant + # gl.load(heap_bases) calls in the inner store loop. + local_base = gl.load(ctx.heap_bases + iris_rank) + + for tile_id in range(pid, total_tiles, COMM_SMS): + # Swizzled tile index computation for better L2 locality + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Flat index -> 2D row/col within tile + flat_idx = gl.arange(0, TOTAL_ELEMS, layout=flat_layout) + row_local = flat_idx // BLOCK_SIZE_N + col_local = flat_idx % BLOCK_SIZE_N + + # Global row/col + row = pid_m * BLOCK_SIZE_M + row_local + col = pid_n * BLOCK_SIZE_N + col_local + + mask = (row < M) & (col < N) + + # Single flat load of the entire tile + input_offsets = row * stride_in_m + col * stride_in_n + input_addr = input_ptr + input_offsets + data = gl.load(input_addr, mask=mask, other=0.0) + + # Output: this rank's data goes to output[group_rank * M + row, col] + output_row = group_rank * M + row + output_offsets = output_row * stride_out_m + col * stride_out_n + + # Traffic-shaped stores to all ranks: stagger write order per rank + # so each rank writes to a different target at any given moment, + # avoiding memory controller contention on the receiver side. + for rank_idx in range(world_size): + dest_idx = (group_rank + rank_idx) % world_size + target_iris_rank = rank_start + dest_idx * rank_stride + output_ptrs = output_ptr + output_offsets + + if dest_idx == group_rank: + gl.store(output_ptrs, data, mask=mask, cache_modifier=".wt") + else: + # Hoisted translation: compute ptr_delta from pre-loaded + # local_base rather than calling ctx.store() which would + # do 2x gl.load(heap_bases) per call. + target_base = gl.load(ctx.heap_bases + target_iris_rank) + ptr_delta = target_base - local_base + output_ptrs_int = tl.cast(output_ptrs, gl.uint64) + remote_ptrs_int = output_ptrs_int + ptr_delta + remote_ptrs = tl.cast(remote_ptrs_int, output_ptrs.dtype) + gl.store(remote_ptrs, data, mask=mask) + + def all_gather( output_tensor, input_tensor, @@ -314,26 +476,11 @@ def all_gather( if config is None: config = Config(block_size_m=32, block_size_n=64) - # Check for unsupported options - if config.use_gluon: - raise ValueError( - "all_gather does not support use_gluon=True. " - "Gluon implementation is not available for all_gather. " - "Use default config (use_gluon=False)." - ) - # Extract group information # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem) - # Validate COMM_SMS divisibility for partitioned variant - if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0: - raise ValueError( - f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). " - f"Please adjust config.comm_sms to be a multiple of {world_size}." - ) - M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) @@ -346,41 +493,121 @@ def all_gather( stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) - heap_bases = shmem.get_heap_bases() + # Choose between Triton and Gluon implementation + if config.use_gluon and GLUON_AVAILABLE: + # Check if shmem is Iris Gluon (has get_device_context method) + if not hasattr(shmem, "get_device_context"): + raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()") - # Dispatch to the appropriate kernel based on variant - if config.all_gather_variant == "persistent": - kernel_fn = persistent_all_gather - elif config.all_gather_variant == "partitioned": - kernel_fn = persistent_all_gather_partitioned + # Gluon only supports the persistent variant + if config.all_gather_variant != "persistent": + raise ValueError( + f"Gluon all_gather only supports all_gather_variant='persistent', got '{config.all_gather_variant}'." + ) + + # Apply optimal defaults for gluon flat-2D kernel when user hasn't + # overridden block sizes from the Config defaults (32x64). + block_size_m = config.block_size_m + block_size_n = config.block_size_n + if block_size_m == 32 and block_size_n == 64: + # User didn't override — use optimal flat-2D tile: 8x256 + block_size_m = 8 + block_size_n = 256 + + # Validate flat-2D layout constraints. + # TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of + # THREADS_PER_WARP * WARPS_PER_CTA so each thread gets a whole + # number of elements. + total_elems = block_size_m * block_size_n + threads_per_cta = config.threads_per_warp * config.num_warps + if total_elems < threads_per_cta: + raise ValueError( + f"Gluon all-gather requires block_size_m * block_size_n >= " + f"threads_per_warp * num_warps ({threads_per_cta}), " + f"got {block_size_m} * {block_size_n} = {total_elems}." + ) + if total_elems % threads_per_cta != 0: + raise ValueError( + f"Gluon all-gather requires block_size_m * block_size_n to be a " + f"multiple of threads_per_warp * num_warps ({threads_per_cta}), " + f"got {block_size_m} * {block_size_n} = {total_elems}. " + f"Recommended: block_size_m=8, block_size_n=256." + ) + + context_tensor = shmem.get_device_context() + + persistent_all_gather_gluon[(config.comm_sms,)]( + IrisDeviceCtx, + context_tensor, + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + block_size_m, + block_size_n, + config.swizzle_size, + config.comm_sms, + config.threads_per_warp, + config.num_warps, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) else: - raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + if config.use_gluon and not GLUON_AVAILABLE: + raise ValueError("Gluon is not available. Install Triton with Gluon support or set use_gluon=False") + + # Validate COMM_SMS divisibility for partitioned variant + if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0: + raise ValueError( + f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). " + f"Please adjust config.comm_sms to be a multiple of {world_size}." + ) - kernel_fn[(config.comm_sms,)]( - input_tensor, - output_tensor, - M, - N, - stride_in_m, - stride_in_n, - stride_out_m, - stride_out_n, - heap_bases, - rank_in_group, - rank_global, - world_size, - rank_start, - rank_stride, - config.block_size_m, - config.block_size_n, - config.swizzle_size, - config.comm_sms, - config.num_xcds, - config.chunk_size, - num_stages=config.num_stages, - num_warps=config.num_warps, - waves_per_eu=config.waves_per_eu, - ) + heap_bases = shmem.get_heap_bases() + + # Dispatch to the appropriate kernel based on variant + if config.all_gather_variant == "persistent": + kernel_fn = persistent_all_gather + elif config.all_gather_variant == "partitioned": + kernel_fn = persistent_all_gather_partitioned + else: + raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + + kernel_fn[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) if not async_op: shmem.barrier() diff --git a/iris/ccl/config.py b/iris/ccl/config.py index 5f1f8b9c..1084de06 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -44,6 +44,16 @@ class Config: (default: auto-set to block_size_n // world_size at runtime) reduce_scatter_variant: Variant for reduce-scatter operation (default: "two_shot") Only "two_shot" is supported + num_stages: Number of pipeline stages for the kernel (default: 1) + num_warps: Number of warps per workgroup (default: 4). For gluon kernels, + this also sets WARPS_PER_CTA in the BlockedLayout. The product + threads_per_warp * num_warps determines the minimum tile size + (block_size_m * block_size_n for flat-2D, or block_size_n for 1D). + threads_per_warp: Threads per warp/wavefront (default: 64). Must match the + hardware wavefront size: 64 for AMD GPUs, 32 for NVIDIA. + Used by gluon kernels to construct BlockedLayout for + vectorized memory access. + waves_per_eu: Waves per execution unit hint for occupancy (default: 0, auto) Example: >>> import iris @@ -82,6 +92,7 @@ class Config: reduce_scatter_variant: str = "two_shot" num_stages: int = 1 num_warps: int = 4 + threads_per_warp: int = 64 waves_per_eu: int = 0 def __post_init__(self): @@ -132,3 +143,8 @@ def __post_init__(self): # Validate reduce_scatter_variant if self.reduce_scatter_variant != "two_shot": raise ValueError(f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'") + + if self.threads_per_warp not in (32, 64): + raise ValueError(f"threads_per_warp must be 32 (NVIDIA) or 64 (AMD), got {self.threads_per_warp}") + if self.num_warps <= 0: + raise ValueError(f"num_warps must be positive, got {self.num_warps}") diff --git a/tests/ccl/test_all_gather_gluon.py b/tests/ccl/test_all_gather_gluon.py new file mode 100644 index 00000000..a912bfe1 --- /dev/null +++ b/tests/ccl/test_all_gather_gluon.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for all-gather collective operation using Gluon. +""" + +import os + +import pytest +import torch +import torch.distributed as dist + +# Try to import Gluon, skip tests if not available +try: + import iris.experimental.iris_gluon as iris_gluon + from iris.ccl import Config + from iris.ccl.all_gather import all_gather, GLUON_AVAILABLE +except ImportError: + GLUON_AVAILABLE = False + + +@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available") +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "M, N, block_size_m, block_size_n", + [ + # block_size_n must be a multiple of (threads_per_warp * num_warps). + # With defaults (threads_per_warp=64, num_warps=4), minimum is 256. + # elems_per_thread = block_size_n / 256: higher = wider vector loads. + (256, 256, 32, 256), # Small: elems_per_thread=1 (scalar loads) + (1024, 512, 32, 512), # Medium: elems_per_thread=2 (dword loads) + (8192, 8192, 32, 1024), # Large: elems_per_thread=4 (dwordx4, optimal) + ], +) +def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): + """Test all-gather functionality using Gluon by comparing against PyTorch's implementation.""" + # Ensure torch.distributed is initialized (should be done by test runner) + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + # Size heap to fit input (M*N) + output (max_ranks*M*N) with headroom + max_ranks = int(os.environ.get("WORLD_SIZE", 8)) + elem_size = torch.tensor([], dtype=dtype).element_size() + needed = (1 + max_ranks) * M * N * elem_size + heap_size = max(2**30, int(needed * 2)) # 2x headroom, minimum 1GB + shmem = iris_gluon.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Each rank has an M x N input tensor + # Output is (world_size * M, N) - concatenated along dimension 0 + pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") + # Fill with deterministic values for easier debugging + pytorch_input_tensor.fill_(float(rank + 1)) + + # Create output tensor for PyTorch: (world_size * M, N) + pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") + + # Run PyTorch's all_gather_into_tensor to get reference output + shmem.barrier() + dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) + torch.cuda.synchronize() + + # Now set up Iris Gluon all_gather + iris_input_tensor = shmem.zeros((M, N), dtype=dtype) + iris_input_tensor.copy_(pytorch_input_tensor) + + iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype) + + # Run Iris Gluon all_gather + shmem.barrier() + config = Config(use_gluon=True, block_size_m=block_size_m, block_size_n=block_size_n) + all_gather(iris_output_tensor, iris_input_tensor, shmem, config=config) + torch.cuda.synchronize() + + # Compare results + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() + + try: + assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_gather_into_tensor" + ) + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect()