From 487c1ee4b8f8a4c749d08d831e97bcc74bd6603b Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:43:12 -0700 Subject: [PATCH 01/15] Add Gluon all-gather kernel with row-wise tiling and dwordx4 vectorization Gluon all-gather kernel that uses explicit BlockedLayout for column-dimension vectorization. Each row is loaded once and broadcast to all ranks via ctx.store(), avoiding redundant loads and enabling dwordx4 memory ops. Key design: - Row-by-row iteration: load row once, write to all ranks (1 load, W stores) - Explicit BlockedLayout([SPT], [64], [4], [0]) on column dimension where SPT = block_size_n / 256, controlling vector width: SPT=1 -> scalar, SPT=2 -> dword, SPT=4 -> dwordx4 (optimal) - Uses ctx.store() for remote writes (compiler-optimized pointer translation) - Optimal tile: bm=32, bn=1024 (SPT=4) for dwordx4 on AMD GFX9+ Also includes: - threads_per_warp config field for BlockedLayout construction - Simplified IrisDeviceCtx (tracing removed for cleaner codegen) - Parameterized correctness tests for multiple tile sizes and dtypes Co-Authored-By: Claude Opus 4.6 --- iris/ccl/all_gather.py | 298 ++++++++++++++++++++++++----- iris/ccl/config.py | 15 ++ iris/experimental/iris_gluon.py | 279 +-------------------------- tests/ccl/test_all_gather_gluon.py | 99 ++++++++++ 4 files changed, 371 insertions(+), 320 deletions(-) create mode 100644 tests/ccl/test_all_gather_gluon.py diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 2093fb3b..36ed2cba 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -12,6 +12,16 @@ from .config import Config from .utils import extract_group_info +# Conditional import for Gluon +try: + from triton.experimental import gluon + from triton.experimental.gluon import language as gl + from iris.experimental.iris_gluon import IrisDeviceCtx + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + @triton.jit() def persistent_all_gather( @@ -279,6 +289,147 @@ def persistent_all_gather_partitioned( ) +# Gluon implementation +if GLUON_AVAILABLE: + + @gluon.jit + def persistent_all_gather_gluon( + IrisDeviceCtx: gl.constexpr, + context_tensor, + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + group_rank: gl.constexpr, + iris_rank: gl.constexpr, + world_size: gl.constexpr, + rank_start: gl.constexpr, + rank_stride: gl.constexpr, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + GROUP_SIZE_M: gl.constexpr, + COMM_SMS: gl.constexpr, + NUM_XCDS: gl.constexpr, + CHUNK_SIZE: gl.constexpr, + THREADS_PER_WARP: gl.constexpr, + WARPS_PER_CTA: gl.constexpr, + ): + """ + Persistent all-gather kernel using Gluon with explicit memory layout control. + + Each rank loads its local input once per row and writes it to the + corresponding output slice on ALL ranks (local + remote), avoiding + redundant loads. Column indices use an explicit BlockedLayout to + control vectorization width. + + Memory layout (BlockedLayout): + The column dimension is distributed across the GPU thread hierarchy + using gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [order]). + + - ELEMS_PER_THREAD: number of contiguous elements each thread loads/stores. + Controls the vector width of memory instructions. For fp16: + 1 -> 2-byte scalar load + 2 -> 4-byte dword load + 4 -> 8-byte dwordx4 load (optimal on AMD GFX9+) + - THREADS_PER_WARP: threads per warp/wavefront (64 on AMD, 32 on NVIDIA). + - WARPS_PER_CTA: number of warps in the cooperative thread array (workgroup). + + The product ELEMS_PER_THREAD * THREADS_PER_WARP * WARPS_PER_CTA must + equal BLOCK_SIZE_N. ELEMS_PER_THREAD is derived as: + ELEMS_PER_THREAD = BLOCK_SIZE_N // (THREADS_PER_WARP * WARPS_PER_CTA) + + Constraints (validated by host wrapper before launch): + - BLOCK_SIZE_N must be a multiple of (THREADS_PER_WARP * WARPS_PER_CTA). + - BLOCK_SIZE_N must be >= (THREADS_PER_WARP * WARPS_PER_CTA) so that + ELEMS_PER_THREAD >= 1. + - WARPS_PER_CTA must match the num_warps kernel launch parameter. + - THREADS_PER_WARP must match the hardware wavefront size (64 for AMD). + + Args: + IrisDeviceCtx: Gluon device context class for remote memory operations. + context_tensor: Opaque tensor holding IrisDeviceCtx state. + input_ptr: Pointer to local input tensor of shape (M, N). + output_ptr: Pointer to output tensor of shape (world_size * M, N). + M: Number of rows in the input tensor (per rank). + N: Number of columns. + stride_in_m, stride_in_n: Row and column strides for input tensor. + stride_out_m, stride_out_n: Row and column strides for output tensor. + group_rank: This rank's index within the ProcessGroup (0..world_size-1). + iris_rank: This rank's global index in the iris context (for RMA addressing). + world_size: Total number of ranks in the group. + rank_start: First iris rank in the group (for RMA target computation). + rank_stride: Stride between consecutive iris ranks in the group. + BLOCK_SIZE_M: Number of rows per tile. + BLOCK_SIZE_N: Number of columns per tile. Must be a multiple of + (THREADS_PER_WARP * WARPS_PER_CTA). + GROUP_SIZE_M: Swizzle group size for M-dimension tiling. + COMM_SMS: Number of SMs used for persistent scheduling. + NUM_XCDS: Number of XCDs (chiplet count). + CHUNK_SIZE: Chunk size for XCD-aware tile mapping. + THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA). + WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps. + """ + ctx = IrisDeviceCtx.initialize(context_tensor) + + pid = gl.program_id(0) + + num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = gl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + # Build the 1D BlockedLayout for the column dimension. + # ELEMS_PER_THREAD controls how many contiguous elements each thread + # handles, which directly maps to the vector load/store width: + # elems=1 -> scalar, elems=2 -> dword, elems=4 -> dwordx4 (optimal) + ELEMS_PER_THREAD: gl.constexpr = BLOCK_SIZE_N // (THREADS_PER_WARP * WARPS_PER_CTA) + col_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0]) + + for tile_id in range(pid, total_tiles, COMM_SMS): + # Swizzled tile index computation for better L2 locality + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Build column index vector with explicit layout for vectorized access + rn = (pid_n * BLOCK_SIZE_N + gl.arange(0, BLOCK_SIZE_N, layout=col_layout)) % N + rn = gl.max_contiguous(gl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + col_offsets_in = rn * stride_in_n + col_offsets_out = rn * stride_out_n + col_mask = rn < N + + rm_base = pid_m * BLOCK_SIZE_M + + # Iterate row-by-row: load each row once, then write to all ranks. + # This avoids reloading the same data for each destination rank. + for i in range(BLOCK_SIZE_M): + row_idx = (rm_base + i) % M + + if row_idx < M: + # Single load from local input + data = gl.load(input_ptr + row_idx * stride_in_m + col_offsets_in, mask=col_mask) + + # Output row position: this rank's slice starts at group_rank * M + output_offset = (group_rank * M + row_idx) * stride_out_m + col_offsets_out + + # Broadcast to every rank (local store + remote RMA writes) + for rank_idx in range(world_size): + target_rank = rank_start + rank_idx * rank_stride + output_ptr_target = output_ptr + output_offset + + if rank_idx == group_rank: + gl.store(output_ptr_target, data, mask=col_mask, cache_modifier=".wt") + else: + ctx.store(output_ptr_target, data, target_rank, mask=col_mask) + + def all_gather( output_tensor, input_tensor, @@ -314,26 +465,11 @@ def all_gather( if config is None: config = Config(block_size_m=32, block_size_n=64) - # Check for unsupported options - if config.use_gluon: - raise ValueError( - "all_gather does not support use_gluon=True. " - "Gluon implementation is not available for all_gather. " - "Use default config (use_gluon=False)." - ) - # Extract group information # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem) - # Validate COMM_SMS divisibility for partitioned variant - if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0: - raise ValueError( - f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). " - f"Please adjust config.comm_sms to be a multiple of {world_size}." - ) - M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) @@ -346,41 +482,109 @@ def all_gather( stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) - heap_bases = shmem.get_heap_bases() + # Choose between Triton and Gluon implementation + if config.use_gluon and GLUON_AVAILABLE: + # Check if shmem is Iris Gluon (has get_device_context method) + if not hasattr(shmem, "get_device_context"): + raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()") + + # Validate BlockedLayout constraints. + # The gluon kernel distributes BLOCK_SIZE_N elements across the thread + # hierarchy: ELEMS_PER_THREAD * THREADS_PER_WARP * WARPS_PER_CTA = BLOCK_SIZE_N. + # ELEMS_PER_THREAD controls vector load width (4 = dwordx4 for fp16, optimal). + threads_per_cta = config.threads_per_warp * config.num_warps + if config.block_size_n < threads_per_cta: + raise ValueError( + f"Gluon all-gather requires block_size_n >= threads_per_warp * num_warps " + f"({config.threads_per_warp} * {config.num_warps} = {threads_per_cta}), " + f"got block_size_n={config.block_size_n}." + ) + if config.block_size_n % threads_per_cta != 0: + raise ValueError( + f"Gluon all-gather requires block_size_n to be a multiple of " + f"threads_per_warp * num_warps ({threads_per_cta}), " + f"got block_size_n={config.block_size_n}. " + f"This ensures each thread handles a whole number of elements. " + f"Recommended: block_size_n=1024 with threads_per_warp=64, num_warps=4 " + f"for dwordx4 vectorization (elems_per_thread=4)." + ) - # Dispatch to the appropriate kernel based on variant - if config.all_gather_variant == "persistent": - kernel_fn = persistent_all_gather - elif config.all_gather_variant == "partitioned": - kernel_fn = persistent_all_gather_partitioned + context_tensor = shmem.get_device_context() + + persistent_all_gather_gluon[(config.comm_sms,)]( + IrisDeviceCtx, + context_tensor, + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + config.threads_per_warp, + config.num_warps, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) else: - raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + if config.use_gluon and not GLUON_AVAILABLE: + raise ValueError("Gluon is not available. Install Triton with Gluon support or set use_gluon=False") + + # Validate COMM_SMS divisibility for partitioned variant + if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0: + raise ValueError( + f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). " + f"Please adjust config.comm_sms to be a multiple of {world_size}." + ) - kernel_fn[(config.comm_sms,)]( - input_tensor, - output_tensor, - M, - N, - stride_in_m, - stride_in_n, - stride_out_m, - stride_out_n, - heap_bases, - rank_in_group, - rank_global, - world_size, - rank_start, - rank_stride, - config.block_size_m, - config.block_size_n, - config.swizzle_size, - config.comm_sms, - config.num_xcds, - config.chunk_size, - num_stages=config.num_stages, - num_warps=config.num_warps, - waves_per_eu=config.waves_per_eu, - ) + heap_bases = shmem.get_heap_bases() + + # Dispatch to the appropriate kernel based on variant + if config.all_gather_variant == "persistent": + kernel_fn = persistent_all_gather + elif config.all_gather_variant == "partitioned": + kernel_fn = persistent_all_gather_partitioned + else: + raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + + kernel_fn[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank_in_group, + rank_global, + world_size, + rank_start, + rank_stride, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, + ) if not async_op: shmem.barrier() diff --git a/iris/ccl/config.py b/iris/ccl/config.py index 5f1f8b9c..9dd35cef 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -44,6 +44,15 @@ class Config: (default: auto-set to block_size_n // world_size at runtime) reduce_scatter_variant: Variant for reduce-scatter operation (default: "two_shot") Only "two_shot" is supported + num_stages: Number of pipeline stages for the kernel (default: 1) + num_warps: Number of warps per workgroup (default: 4). For gluon kernels, + this also sets WARPS_PER_CTA in the BlockedLayout. The product + threads_per_warp * num_warps determines the minimum block_size_n. + threads_per_warp: Threads per warp/wavefront (default: 64). Must match the + hardware wavefront size: 64 for AMD GPUs, 32 for NVIDIA. + Used by gluon kernels to construct BlockedLayout for + vectorized memory access. + waves_per_eu: Waves per execution unit hint for occupancy (default: 0, auto) Example: >>> import iris @@ -82,6 +91,7 @@ class Config: reduce_scatter_variant: str = "two_shot" num_stages: int = 1 num_warps: int = 4 + threads_per_warp: int = 64 waves_per_eu: int = 0 def __post_init__(self): @@ -132,3 +142,8 @@ def __post_init__(self): # Validate reduce_scatter_variant if self.reduce_scatter_variant != "two_shot": raise ValueError(f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'") + + if self.threads_per_warp not in (32, 64): + raise ValueError(f"threads_per_warp must be 32 (NVIDIA) or 64 (AMD), got {self.threads_per_warp}") + if self.num_warps <= 0: + raise ValueError(f"num_warps must be positive, got {self.num_warps}") diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 97add62f..4ef4aea1 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -51,8 +51,6 @@ count_devices, ) from iris.symmetric_heap import SymmetricHeap -from iris import device_utils -from iris.tracing.core import Tracing import numpy as np import torch import logging @@ -64,151 +62,6 @@ from .. import tensor_creation -class _GluonDeviceTracingCls: - """ - Gluon-native device-side tracing: records events into SoA buffers from inside Gluon kernels. - - Created by IrisDeviceCtx.initialize() when tracing=True. Use record_event_start - / record_event_end to bracket operations; events are exported via Tracing.export(). - """ - - enabled: tl.constexpr - rank: gl.tensor - max_events: gl.tensor - counter: gl.tensor - op_index_counter: gl.tensor - buf_event_id: gl.tensor - buf_pid: gl.tensor - buf_pid_m: gl.tensor - buf_pid_n: gl.tensor - buf_cur_rank: gl.tensor - buf_target_rank: gl.tensor - buf_xcc_id: gl.tensor - buf_cu_id: gl.tensor - buf_timestamp: gl.tensor - buf_address: gl.tensor - buf_duration_cycles: gl.tensor - buf_op_index: gl.tensor - buf_payload_size: gl.tensor - - def __init__( - self, - enabled, - rank, - max_events, - counter, - op_index_counter, - buf_event_id, - buf_pid, - buf_pid_m, - buf_pid_n, - buf_cur_rank, - buf_target_rank, - buf_xcc_id, - buf_cu_id, - buf_timestamp, - buf_address, - buf_duration_cycles, - buf_op_index, - buf_payload_size, - ): - """Construct GluonDeviceTracing (called from IrisDeviceCtx.initialize).""" - self.enabled = enabled - self.rank = rank - self.max_events = max_events - self.counter = counter - self.op_index_counter = op_index_counter - self.buf_event_id = buf_event_id - self.buf_pid = buf_pid - self.buf_pid_m = buf_pid_m - self.buf_pid_n = buf_pid_n - self.buf_cur_rank = buf_cur_rank - self.buf_target_rank = buf_target_rank - self.buf_xcc_id = buf_xcc_id - self.buf_cu_id = buf_cu_id - self.buf_timestamp = buf_timestamp - self.buf_address = buf_address - self.buf_duration_cycles = buf_duration_cycles - self.buf_op_index = buf_op_index - self.buf_payload_size = buf_payload_size - - @gluon.jit - def record_event_start( - self, - event_id: tl.constexpr, - target_rank, - address, - pid_m, - pid_n, - mask=None, - ): - """ - Record start of a traced operation. Returns a handle for record_event_end. - - Only stores when event_idx < max_events (bounds check). - cur_rank is taken from the tracing context (self.rank). - - Args: - event_id: Event type ID (constexpr) - target_rank: Target rank for the operation - address: Memory address(es) - can be 1D or 2D block of pointers. - pid_m: Program ID in M dimension - pid_n: Program ID in N dimension - mask: Optional mask tensor indicating valid elements. - """ - if not self.enabled: - return tl.cast(0, tl.int32) - - event_idx = tl.atomic_add(self.counter, 1) - op_index = tl.atomic_add(self.op_index_counter, 1) - - # Calculate payload_size from mask and datatype - if mask is not None: - mask_i32 = tl.cast(mask, tl.int32) - num_elements = gl.sum(mask_i32, axis=0) - elem_type = address.dtype.element_ty - bitwidth = elem_type.primitive_bitwidth - elem_size_bytes = bitwidth // 8 - payload_size = num_elements * tl.cast(elem_size_bytes, tl.int32) - else: - payload_size = tl.cast(0, tl.int32) - - if event_idx < self.max_events: - tl.store(self.buf_event_id + event_idx, tl.cast(event_id, tl.int32)) - tl.store(self.buf_pid + event_idx, tl.cast(gl.program_id(0), tl.int32)) - tl.store(self.buf_pid_m + event_idx, tl.cast(pid_m, tl.int32)) - tl.store(self.buf_pid_n + event_idx, tl.cast(pid_n, tl.int32)) - tl.store(self.buf_cur_rank + event_idx, tl.cast(self.rank, tl.int32)) - tl.store(self.buf_target_rank + event_idx, tl.cast(target_rank, tl.int32)) - tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) - tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) - tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) - addr_i64 = tl.cast(address, tl.int64) - tl.store(self.buf_address + event_idx, gl.min(addr_i64, axis=0)) - tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) - tl.store(self.buf_op_index + event_idx, op_index) - tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32)) - return event_idx - - @gluon.jit - def record_event_end(self, handle): - """ - Record end timestamp for the event started with record_event_start(handle). - - Only stores when handle < max_events (bounds check). - """ - if not self.enabled: - return - - end_ts = device_utils.read_realtime() - if handle < self.max_events: - tl.store(self.buf_duration_cycles + handle, end_ts) - - -_GluonDeviceTracingCls.__init__.__triton_builtin__ = True -GluonDeviceTracing = aggregate(_GluonDeviceTracingCls) - - @aggregate class IrisDeviceCtx: """ @@ -221,36 +74,28 @@ class IrisDeviceCtx: cur_rank: Current rank ID num_ranks: Total number of ranks heap_bases: Pointer to array of heap base addresses for all ranks - tracing: GluonDeviceTracing instance (active when tracing=True) """ cur_rank: gl.tensor num_ranks: gl.tensor heap_bases: gl.tensor - tracing: GluonDeviceTracing @gluon.constexpr_function - def __init__(self, cur_rank, num_ranks, heap_bases, tracing): + def __init__(self, cur_rank, num_ranks, heap_bases): self.cur_rank = cur_rank self.num_ranks = num_ranks self.heap_bases = heap_bases - self.tracing = tracing @staticmethod @gluon.jit - def initialize(context_tensor, tracing: gl.constexpr = False): + def initialize(context_tensor): """ Initialize `IrisDeviceCtx` from the encoded tensor. - The context tensor has the format: - ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ..., trace_info...]`` - - If tracing is enabled on the host (via ``shmem.tracing.enable()``), the - context tensor also contains tracing buffer pointers after the heap bases. + The context tensor has the format: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` Args: context_tensor: Pointer to encoded context data - tracing: Enable event tracing (constexpr, default: False) Returns: `IrisDeviceCtx`: Initialized device context @@ -262,82 +107,7 @@ def initialize(context_tensor, tracing: gl.constexpr = False): # Extract heap bases (from index 2 onwards) heap_bases = context_tensor + 2 # Offset pointer to start at heap bases - if tracing: - # Extract tracing info: starts after heap_bases, then skip trace_enabled flag - # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events, - # trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)] - trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled - max_events = tl.cast(gl.load(context_tensor + trace_info_base + 0), tl.int32) - trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1) - op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2) - - # Cast counter pointers - trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32)) - op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32)) - - # Extract trace buffer pointers (13 buffers, same order as Iris._build_device_context) - buf_base = trace_info_base + 3 - buf_event_id = tl.cast(gl.load(context_tensor + buf_base + 0), tl.pointer_type(tl.int32)) - buf_pid = tl.cast(gl.load(context_tensor + buf_base + 1), tl.pointer_type(tl.int32)) - buf_pid_m = tl.cast(gl.load(context_tensor + buf_base + 2), tl.pointer_type(tl.int32)) - buf_pid_n = tl.cast(gl.load(context_tensor + buf_base + 3), tl.pointer_type(tl.int32)) - buf_cur_rank = tl.cast(gl.load(context_tensor + buf_base + 4), tl.pointer_type(tl.int32)) - buf_target_rank = tl.cast(gl.load(context_tensor + buf_base + 5), tl.pointer_type(tl.int32)) - buf_xcc_id = tl.cast(gl.load(context_tensor + buf_base + 6), tl.pointer_type(tl.int32)) - buf_cu_id = tl.cast(gl.load(context_tensor + buf_base + 7), tl.pointer_type(tl.int32)) - buf_timestamp = tl.cast(gl.load(context_tensor + buf_base + 8), tl.pointer_type(tl.int64)) - buf_address = tl.cast(gl.load(context_tensor + buf_base + 9), tl.pointer_type(tl.int64)) - buf_duration_cycles = tl.cast(gl.load(context_tensor + buf_base + 10), tl.pointer_type(tl.int64)) - buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32)) - buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32)) - - device_tracing = GluonDeviceTracing( - enabled=tracing, - rank=cur_rank, - max_events=max_events, - counter=trace_counter, - op_index_counter=op_index_counter, - buf_event_id=buf_event_id, - buf_pid=buf_pid, - buf_pid_m=buf_pid_m, - buf_pid_n=buf_pid_n, - buf_cur_rank=buf_cur_rank, - buf_target_rank=buf_target_rank, - buf_xcc_id=buf_xcc_id, - buf_cu_id=buf_cu_id, - buf_timestamp=buf_timestamp, - buf_address=buf_address, - buf_duration_cycles=buf_duration_cycles, - buf_op_index=buf_op_index, - buf_payload_size=buf_payload_size, - ) - else: - # When tracing disabled, use dummy pointers (never dereferenced) - dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) - dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) - max_events_zero = tl.cast(0, tl.int32) - device_tracing = GluonDeviceTracing( - enabled=tracing, - rank=cur_rank, - max_events=max_events_zero, - counter=dummy_ptr_i32, - op_index_counter=dummy_ptr_i32, - buf_event_id=dummy_ptr_i32, - buf_pid=dummy_ptr_i32, - buf_pid_m=dummy_ptr_i32, - buf_pid_n=dummy_ptr_i32, - buf_cur_rank=dummy_ptr_i32, - buf_target_rank=dummy_ptr_i32, - buf_xcc_id=dummy_ptr_i32, - buf_cu_id=dummy_ptr_i32, - buf_timestamp=dummy_ptr_i64, - buf_address=dummy_ptr_i64, - buf_duration_cycles=dummy_ptr_i64, - buf_op_index=dummy_ptr_i32, - buf_payload_size=dummy_ptr_i32, - ) - - return IrisDeviceCtx(cur_rank, num_ranks, heap_bases, device_tracing) + return IrisDeviceCtx(cur_rank, num_ranks, heap_bases) @gluon.jit def _translate(self, ptr, from_rank, to_rank): @@ -365,12 +135,6 @@ def _translate(self, ptr, from_rank, to_rank): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - # Optimization to vectorize the load/store - similar to iris.py - # This enables the compiler to generate dwordx4 or wider loads - # Note: Gluon uses scalar multiples, not 2D tuples like Triton - # ptr = gl.max_contiguous(gl.multiple_of(ptr, 64), 64) - # translated_ptr = gl.max_contiguous(gl.multiple_of(translated_ptr, 64), 64) - return translated_ptr @gluon.jit @@ -743,9 +507,6 @@ def __init__(self, heap_size=1 << 30): distributed_barrier() - # Initialize tracing manager (disabled by default) - self.tracing = Tracing(self) - # Initialize CCL interface self.ccl = self.CCL(self) @@ -901,48 +662,20 @@ def _build_device_context(self): """ Build and cache the device context tensor. - Called during __init__ and again after tracing.enable() to include tracing fields. + Called during __init__ to pre-build the tensor once. """ # Convert heap_bases to a list for concatenation heap_bases_list = self.heap_bases.tolist() # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] context_data = [self.cur_rank, self.num_ranks] + heap_bases_list - - # Add tracing info if enabled (same layout as Iris._build_device_context) - if self.tracing.enabled: - trace_buffer_ptrs = [ - self.tracing.trace_buffers["event_id"].data_ptr(), - self.tracing.trace_buffers["pid"].data_ptr(), - self.tracing.trace_buffers["pid_m"].data_ptr(), - self.tracing.trace_buffers["pid_n"].data_ptr(), - self.tracing.trace_buffers["cur_rank"].data_ptr(), - self.tracing.trace_buffers["target_rank"].data_ptr(), - self.tracing.trace_buffers["xcc_id"].data_ptr(), - self.tracing.trace_buffers["cu_id"].data_ptr(), - self.tracing.trace_buffers["timestamp"].data_ptr(), - self.tracing.trace_buffers["address"].data_ptr(), - self.tracing.trace_buffers["duration_cycles"].data_ptr(), - self.tracing.trace_buffers["op_index"].data_ptr(), - self.tracing.trace_buffers["payload_size"].data_ptr(), - ] - context_data += [ - 1, # trace_enabled = 1 (true) - self.tracing.max_events, - self.tracing.trace_counter.data_ptr(), - self.tracing.op_index_counter.data_ptr(), - ] + trace_buffer_ptrs - else: - context_data += [0] # trace_enabled = 0 (false) - self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) def get_device_context(self): """ Get the device context tensor for Gluon kernels. - Returns a tensor encoding: ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`` - If tracing is enabled, also includes: ``[trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...]`` + Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` Returns: torch.Tensor: Encoded context data as int64 tensor on device diff --git a/tests/ccl/test_all_gather_gluon.py b/tests/ccl/test_all_gather_gluon.py new file mode 100644 index 00000000..a5b65abd --- /dev/null +++ b/tests/ccl/test_all_gather_gluon.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for all-gather collective operation using Gluon. +""" + +import pytest +import torch +import torch.distributed as dist + +# Try to import Gluon, skip tests if not available +try: + import iris.experimental.iris_gluon as iris_gluon + from iris.ccl import Config + from iris.ccl.all_gather import all_gather + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + + +@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available") +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "M, N, block_size_m, block_size_n", + [ + # block_size_n must be a multiple of (threads_per_warp * num_warps). + # With defaults (threads_per_warp=64, num_warps=4), minimum is 256. + # elems_per_thread = block_size_n / 256: higher = wider vector loads. + (256, 256, 32, 256), # Small: elems_per_thread=1 (scalar loads) + (1024, 512, 32, 512), # Medium: elems_per_thread=2 (dword loads) + (8192, 8192, 32, 1024), # Large: elems_per_thread=4 (dwordx4, optimal) + ], +) +def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): + """Test all-gather functionality using Gluon by comparing against PyTorch's implementation.""" + # Ensure torch.distributed is initialized (should be done by test runner) + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + shmem = iris_gluon.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Each rank has an M x N input tensor + # Output is (world_size * M, N) - concatenated along dimension 0 + pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") + # Fill with deterministic values for easier debugging + pytorch_input_tensor.fill_(float(rank + 1)) + + # Create output tensor for PyTorch: (world_size * M, N) + pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") + + # Run PyTorch's all_gather_into_tensor to get reference output + shmem.barrier() + dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) + torch.cuda.synchronize() + + # Now set up Iris Gluon all_gather + iris_input_tensor = shmem.zeros((M, N), dtype=dtype) + iris_input_tensor.copy_(pytorch_input_tensor) + + iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype) + + # Run Iris Gluon all_gather + shmem.barrier() + config = Config(use_gluon=True, block_size_m=block_size_m, block_size_n=block_size_n) + all_gather(iris_output_tensor, iris_input_tensor, shmem, config=config) + torch.cuda.synchronize() + + # Compare results + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() + + try: + assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_gather_into_tensor" + ) + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() From e3f634964d6f5d15ef4118eb30b8759c4d838e82 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 12:09:50 -0700 Subject: [PATCH 02/15] Optimize gluon all-gather: hoist heap_bases + preserve vectorization hints The gluon kernel was using ctx.store() which calls _translate() on every remote store, causing two problems visible in the assembly: 1. global_store_short (2-byte scalar) instead of global_store_dwordx4 (16-byte) because _translate() pointer arithmetic breaks contiguity attributes 2. Two global_load_dwordx2 for heap_bases per remote rank per row (14 loads/row in 8-rank case) because heap_bases[from_rank] and heap_bases[to_rank] are reloaded every call Fix: bypass ctx.store() and perform pointer translation inline: - Hoist local_base = gl.load(heap_bases + iris_rank) before all loops - Compute ptr_delta = target_base - local_base manually - Re-apply gl.max_contiguous/gl.multiple_of to translated pointer - Use gl.store() directly, preserving vectorization hints for dwordx4 Co-Authored-By: Claude Opus 4.6 --- iris/ccl/all_gather.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 36ed2cba..4949adc3 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -321,10 +321,10 @@ def persistent_all_gather_gluon( """ Persistent all-gather kernel using Gluon with explicit memory layout control. - Each rank loads its local input once per row and writes it to the - corresponding output slice on ALL ranks (local + remote), avoiding - redundant loads. Column indices use an explicit BlockedLayout to - control vectorization width. + Uses hoisted pointer translation: heap_bases are loaded once before the + tile loops, and pointer deltas (target_base - local_base) are precomputed. + Remote stores use gl.store() directly with contiguity hints preserved, + enabling vectorized global_store_dwordx4 instead of scalar stores. Memory layout (BlockedLayout): The column dimension is distributed across the GPU thread hierarchy @@ -388,6 +388,12 @@ def persistent_all_gather_gluon( ELEMS_PER_THREAD: gl.constexpr = BLOCK_SIZE_N // (THREADS_PER_WARP * WARPS_PER_CTA) col_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0]) + # Hoist heap_bases loads: load local_base once, then compute a signed + # byte-delta for each remote rank. The delta is added to the local + # output pointer at store time so that gl.store() sees a pointer with + # the same contiguity attributes as the original — enabling dwordx4. + local_base = gl.load(ctx.heap_bases + iris_rank) + for tile_id in range(pid, total_tiles, COMM_SMS): # Swizzled tile index computation for better L2 locality num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -427,7 +433,19 @@ def persistent_all_gather_gluon( if rank_idx == group_rank: gl.store(output_ptr_target, data, mask=col_mask, cache_modifier=".wt") else: - ctx.store(output_ptr_target, data, target_rank, mask=col_mask) + # Hoisted pointer translation: compute target address + # by adding the byte delta between target and local + # heap bases to the local pointer. This avoids + # reloading heap_bases every iteration and preserves + # the contiguity hints for vectorized stores. + target_base = gl.load(ctx.heap_bases + target_rank) + ptr_int = tl.cast(output_ptr_target, gl.uint64) + offset = ptr_int - local_base + target_base_byte = tl.cast(target_base, gl.pointer_type(gl.int8)) + translated_byte = target_base_byte + offset + translated_ptr = tl.cast(translated_byte, output_ptr_target.dtype) + translated_ptr = gl.max_contiguous(gl.multiple_of(translated_ptr, BLOCK_SIZE_N), BLOCK_SIZE_N) + gl.store(translated_ptr, data, mask=col_mask) def all_gather( From 9ec992d7a8c5835e464e53f16a66d135037c935d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 23 Mar 2026 19:10:46 +0000 Subject: [PATCH 03/15] Apply Ruff auto-fixes --- iris/ccl/all_gather.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 4949adc3..ae756b3a 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -444,7 +444,9 @@ def persistent_all_gather_gluon( target_base_byte = tl.cast(target_base, gl.pointer_type(gl.int8)) translated_byte = target_base_byte + offset translated_ptr = tl.cast(translated_byte, output_ptr_target.dtype) - translated_ptr = gl.max_contiguous(gl.multiple_of(translated_ptr, BLOCK_SIZE_N), BLOCK_SIZE_N) + translated_ptr = gl.max_contiguous( + gl.multiple_of(translated_ptr, BLOCK_SIZE_N), BLOCK_SIZE_N + ) gl.store(translated_ptr, data, mask=col_mask) From 47ff319e9f7161c3e143f01712bcb646cb84b958 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:45:02 -0700 Subject: [PATCH 04/15] Replace row-by-row Gluon kernel with flat-2D tiling + traffic shaping The row-by-row kernel iterated per-row with 1D loads, producing 32x more instructions than Triton's 2D tile loads and underperforming by 6-38%. The flat-2D kernel uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N elements with div/mod for 2D indexing, producing one load + world_size stores per tile (matching Triton's instruction structure). Additional optimizations: - Hoisted pointer translation: local_base loaded once outside tile loop - Traffic-shaped writes: staggered (group_rank + rank_idx) % world_size write order avoids memory controller contention on the receiver side - Auto-default tile size: 8x256 (2048 elements, 8/thread) when user doesn't override Config defaults Also adds GluonDeviceTracing to iris_gluon.py for optional device-side event recording (TRACING=False by default, zero overhead when disabled). Benchmarked on MI308X (gfx942), 8192x8192 fp16, ROCm 7.2.0: 8 ranks, 32 CUs: Gluon 293 GB/s vs Triton 287 GB/s (+2%) 4 ranks, 48 CUs: Gluon 134 GB/s vs Triton 125 GB/s (+7%) 4 ranks, 32 CUs: Gluon 133 GB/s vs Triton 128 GB/s (+3%) Co-Authored-By: Claude Opus 4.6 --- iris/ccl/all_gather.py | 249 +++++++++++++++------------- iris/experimental/iris_gluon.py | 278 +++++++++++++++++++++++++++++++- 2 files changed, 411 insertions(+), 116 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index ae756b3a..3911e583 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -11,6 +11,7 @@ import iris from .config import Config from .utils import extract_group_info +from iris.tracing.events import TraceEvent # Conditional import for Gluon try: @@ -289,7 +290,17 @@ def persistent_all_gather_partitioned( ) -# Gluon implementation +# Gluon implementation: flat-2D tiling approach +# +# Uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N elements with +# div/mod to compute 2D row/col indices. This gives one load + world_size +# stores per tile (matching Triton's 2D load/store structure) while staying +# within gluon's 1D BlockedLayout framework. +# +# Key optimizations: +# - Flat-2D tiling: eliminates the inner BLOCK_SIZE_M row loop +# - Hoisted pointer translation: local_base loaded once outside tile loop +# - Traffic shaping: staggered write order avoids memory controller contention if GLUON_AVAILABLE: @gluon.jit @@ -313,41 +324,35 @@ def persistent_all_gather_gluon( BLOCK_SIZE_N: gl.constexpr, GROUP_SIZE_M: gl.constexpr, COMM_SMS: gl.constexpr, - NUM_XCDS: gl.constexpr, - CHUNK_SIZE: gl.constexpr, THREADS_PER_WARP: gl.constexpr, WARPS_PER_CTA: gl.constexpr, + TRACING: gl.constexpr = False, ): """ - Persistent all-gather kernel using Gluon with explicit memory layout control. + Persistent all-gather kernel using Gluon with flat-2D tiling. - Uses hoisted pointer translation: heap_bases are loaded once before the - tile loops, and pointer deltas (target_base - local_base) are precomputed. - Remote stores use gl.store() directly with contiguity hints preserved, - enabling vectorized global_store_dwordx4 instead of scalar stores. + Uses a flat 1D index space of BLOCK_SIZE_M * BLOCK_SIZE_N elements, + computing 2D row/col via integer div/mod. This produces one vectorized + load and world_size vectorized stores per tile, matching Triton's 2D + load/store instruction structure while staying within gluon's 1D + BlockedLayout framework. Memory layout (BlockedLayout): - The column dimension is distributed across the GPU thread hierarchy - using gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [order]). - - - ELEMS_PER_THREAD: number of contiguous elements each thread loads/stores. - Controls the vector width of memory instructions. For fp16: - 1 -> 2-byte scalar load - 2 -> 4-byte dword load - 4 -> 8-byte dwordx4 load (optimal on AMD GFX9+) - - THREADS_PER_WARP: threads per warp/wavefront (64 on AMD, 32 on NVIDIA). - - WARPS_PER_CTA: number of warps in the cooperative thread array (workgroup). - - The product ELEMS_PER_THREAD * THREADS_PER_WARP * WARPS_PER_CTA must - equal BLOCK_SIZE_N. ELEMS_PER_THREAD is derived as: - ELEMS_PER_THREAD = BLOCK_SIZE_N // (THREADS_PER_WARP * WARPS_PER_CTA) - - Constraints (validated by host wrapper before launch): - - BLOCK_SIZE_N must be a multiple of (THREADS_PER_WARP * WARPS_PER_CTA). - - BLOCK_SIZE_N must be >= (THREADS_PER_WARP * WARPS_PER_CTA) so that - ELEMS_PER_THREAD >= 1. - - WARPS_PER_CTA must match the num_warps kernel launch parameter. - - THREADS_PER_WARP must match the hardware wavefront size (64 for AMD). + A 1D BlockedLayout distributes TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N + elements across the thread hierarchy: + ELEMS_PER_THREAD = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA) + + Each thread handles ELEMS_PER_THREAD contiguous elements in the + flattened row-major order. Row/col are recovered via: + row = flat_idx // BLOCK_SIZE_N + col = flat_idx % BLOCK_SIZE_N + + Constraints: + - BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of + (THREADS_PER_WARP * WARPS_PER_CTA). + - Optimal tile: 2048-4096 total elements (8-16 per thread). + Larger tiles cause register spilling and performance collapse. + - Recommended: BLOCK_SIZE_M=8, BLOCK_SIZE_N=256 (2048 elems, 8/thread). Args: IrisDeviceCtx: Gluon device context class for remote memory operations. @@ -364,16 +369,15 @@ def persistent_all_gather_gluon( rank_start: First iris rank in the group (for RMA target computation). rank_stride: Stride between consecutive iris ranks in the group. BLOCK_SIZE_M: Number of rows per tile. - BLOCK_SIZE_N: Number of columns per tile. Must be a multiple of - (THREADS_PER_WARP * WARPS_PER_CTA). + BLOCK_SIZE_N: Number of columns per tile. GROUP_SIZE_M: Swizzle group size for M-dimension tiling. - COMM_SMS: Number of SMs used for persistent scheduling. - NUM_XCDS: Number of XCDs (chiplet count). - CHUNK_SIZE: Chunk size for XCD-aware tile mapping. + COMM_SMS: Number of CUs used for persistent scheduling. THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA). WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps. + TRACING: If True, record load/store events into trace buffers. """ - ctx = IrisDeviceCtx.initialize(context_tensor) + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING) + events = TraceEvent() pid = gl.program_id(0) @@ -381,17 +385,15 @@ def persistent_all_gather_gluon( num_pid_n = gl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n - # Build the 1D BlockedLayout for the column dimension. - # ELEMS_PER_THREAD controls how many contiguous elements each thread - # handles, which directly maps to the vector load/store width: - # elems=1 -> scalar, elems=2 -> dword, elems=4 -> dwordx4 (optimal) - ELEMS_PER_THREAD: gl.constexpr = BLOCK_SIZE_N // (THREADS_PER_WARP * WARPS_PER_CTA) - col_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0]) - - # Hoist heap_bases loads: load local_base once, then compute a signed - # byte-delta for each remote rank. The delta is added to the local - # output pointer at store time so that gl.store() sees a pointer with - # the same contiguity attributes as the original — enabling dwordx4. + # Flat 1D layout covering BLOCK_SIZE_M * BLOCK_SIZE_N elements + TOTAL_ELEMS: gl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_N + ELEMS_PER_THREAD: gl.constexpr = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA) + flat_layout: gl.constexpr = gl.BlockedLayout( + [ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0] + ) + + # Hoist local heap base outside the tile loop: eliminates redundant + # gl.load(heap_bases) calls in the inner store loop. local_base = gl.load(ctx.heap_bases + iris_rank) for tile_id in range(pid, total_tiles, COMM_SMS): @@ -403,51 +405,70 @@ def persistent_all_gather_gluon( pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m - # Build column index vector with explicit layout for vectorized access - rn = (pid_n * BLOCK_SIZE_N + gl.arange(0, BLOCK_SIZE_N, layout=col_layout)) % N - rn = gl.max_contiguous(gl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - - col_offsets_in = rn * stride_in_n - col_offsets_out = rn * stride_out_n - col_mask = rn < N - - rm_base = pid_m * BLOCK_SIZE_M - - # Iterate row-by-row: load each row once, then write to all ranks. - # This avoids reloading the same data for each destination rank. - for i in range(BLOCK_SIZE_M): - row_idx = (rm_base + i) % M - - if row_idx < M: - # Single load from local input - data = gl.load(input_ptr + row_idx * stride_in_m + col_offsets_in, mask=col_mask) - - # Output row position: this rank's slice starts at group_rank * M - output_offset = (group_rank * M + row_idx) * stride_out_m + col_offsets_out - - # Broadcast to every rank (local store + remote RMA writes) - for rank_idx in range(world_size): - target_rank = rank_start + rank_idx * rank_stride - output_ptr_target = output_ptr + output_offset - - if rank_idx == group_rank: - gl.store(output_ptr_target, data, mask=col_mask, cache_modifier=".wt") - else: - # Hoisted pointer translation: compute target address - # by adding the byte delta between target and local - # heap bases to the local pointer. This avoids - # reloading heap_bases every iteration and preserves - # the contiguity hints for vectorized stores. - target_base = gl.load(ctx.heap_bases + target_rank) - ptr_int = tl.cast(output_ptr_target, gl.uint64) - offset = ptr_int - local_base - target_base_byte = tl.cast(target_base, gl.pointer_type(gl.int8)) - translated_byte = target_base_byte + offset - translated_ptr = tl.cast(translated_byte, output_ptr_target.dtype) - translated_ptr = gl.max_contiguous( - gl.multiple_of(translated_ptr, BLOCK_SIZE_N), BLOCK_SIZE_N - ) - gl.store(translated_ptr, data, mask=col_mask) + # Flat index -> 2D row/col within tile + flat_idx = gl.arange(0, TOTAL_ELEMS, layout=flat_layout) + row_local = flat_idx // BLOCK_SIZE_N + col_local = flat_idx % BLOCK_SIZE_N + + # Global row/col + row = pid_m * BLOCK_SIZE_M + row_local + col = pid_n * BLOCK_SIZE_N + col_local + + mask = (row < M) & (col < N) + + # Single flat load of the entire tile + input_offsets = row * stride_in_m + col * stride_in_n + input_addr = input_ptr + input_offsets + if TRACING: + h_load = ctx.tracing.record_event_start( + event_id=events.load, + target_rank=group_rank, + address=input_addr, + pid_m=pid_m, + pid_n=pid_n, + mask=mask, + ) + data = gl.load(input_addr, mask=mask, other=0.0) + if TRACING: + ctx.tracing.record_event_end(h_load) + + # Output: this rank's data goes to output[group_rank * M + row, col] + output_row = group_rank * M + row + output_offsets = output_row * stride_out_m + col * stride_out_n + + # Traffic-shaped stores to all ranks: stagger write order per rank + # so each rank writes to a different target at any given moment, + # avoiding memory controller contention on the receiver side. + for rank_idx in range(world_size): + dest_idx = (group_rank + rank_idx) % world_size + target_iris_rank = rank_start + dest_idx * rank_stride + output_ptrs = output_ptr + output_offsets + + if TRACING: + h_store = ctx.tracing.record_event_start( + event_id=events.store, + target_rank=target_iris_rank, + address=output_ptrs, + pid_m=pid_m, + pid_n=pid_n, + mask=mask, + ) + + if dest_idx == group_rank: + gl.store(output_ptrs, data, mask=mask, cache_modifier=".wt") + else: + # Hoisted translation: compute ptr_delta from pre-loaded + # local_base rather than calling ctx.store() which would + # do 2x gl.load(heap_bases) per call. + target_base = gl.load(ctx.heap_bases + target_iris_rank) + ptr_delta = target_base - local_base + output_ptrs_int = tl.cast(output_ptrs, gl.uint64) + remote_ptrs_int = output_ptrs_int + ptr_delta + remote_ptrs = tl.cast(remote_ptrs_int, output_ptrs.dtype) + gl.store(remote_ptrs, data, mask=mask) + + if TRACING: + ctx.tracing.record_event_end(h_store) def all_gather( @@ -508,27 +529,36 @@ def all_gather( if not hasattr(shmem, "get_device_context"): raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()") - # Validate BlockedLayout constraints. - # The gluon kernel distributes BLOCK_SIZE_N elements across the thread - # hierarchy: ELEMS_PER_THREAD * THREADS_PER_WARP * WARPS_PER_CTA = BLOCK_SIZE_N. - # ELEMS_PER_THREAD controls vector load width (4 = dwordx4 for fp16, optimal). + # Apply optimal defaults for gluon flat-2D kernel when user hasn't + # overridden block sizes from the Config defaults (32x64). + block_size_m = config.block_size_m + block_size_n = config.block_size_n + if block_size_m == 32 and block_size_n == 64: + # User didn't override — use optimal flat-2D tile: 8x256 + block_size_m = 8 + block_size_n = 256 + + # Validate flat-2D layout constraints. + # TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of + # THREADS_PER_WARP * WARPS_PER_CTA so each thread gets a whole + # number of elements. + total_elems = block_size_m * block_size_n threads_per_cta = config.threads_per_warp * config.num_warps - if config.block_size_n < threads_per_cta: + if total_elems < threads_per_cta: raise ValueError( - f"Gluon all-gather requires block_size_n >= threads_per_warp * num_warps " - f"({config.threads_per_warp} * {config.num_warps} = {threads_per_cta}), " - f"got block_size_n={config.block_size_n}." + f"Gluon all-gather requires block_size_m * block_size_n >= " + f"threads_per_warp * num_warps ({threads_per_cta}), " + f"got {block_size_m} * {block_size_n} = {total_elems}." ) - if config.block_size_n % threads_per_cta != 0: + if total_elems % threads_per_cta != 0: raise ValueError( - f"Gluon all-gather requires block_size_n to be a multiple of " - f"threads_per_warp * num_warps ({threads_per_cta}), " - f"got block_size_n={config.block_size_n}. " - f"This ensures each thread handles a whole number of elements. " - f"Recommended: block_size_n=1024 with threads_per_warp=64, num_warps=4 " - f"for dwordx4 vectorization (elems_per_thread=4)." + f"Gluon all-gather requires block_size_m * block_size_n to be a " + f"multiple of threads_per_warp * num_warps ({threads_per_cta}), " + f"got {block_size_m} * {block_size_n} = {total_elems}. " + f"Recommended: block_size_m=8, block_size_n=256." ) + tracing_enabled = hasattr(shmem, "tracing") and shmem.tracing.enabled context_tensor = shmem.get_device_context() persistent_all_gather_gluon[(config.comm_sms,)]( @@ -547,14 +577,13 @@ def all_gather( world_size, rank_start, rank_stride, - config.block_size_m, - config.block_size_n, + block_size_m, + block_size_n, config.swizzle_size, config.comm_sms, - config.num_xcds, - config.chunk_size, config.threads_per_warp, config.num_warps, + tracing_enabled, num_stages=config.num_stages, num_warps=config.num_warps, waves_per_eu=config.waves_per_eu, diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 4ef4aea1..2f56137e 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -51,6 +51,8 @@ count_devices, ) from iris.symmetric_heap import SymmetricHeap +from iris import device_utils +from iris.tracing.core import Tracing import numpy as np import torch import logging @@ -62,6 +64,152 @@ from .. import tensor_creation +class _GluonDeviceTracingCls: + """ + Gluon-native device-side tracing: records events into SoA buffers from inside Gluon kernels. + + Created by IrisDeviceCtx.initialize() when tracing=True. Use record_event_start + / record_event_end to bracket operations; events are exported via Tracing.export(). + """ + + enabled: gl.tensor + rank: gl.tensor + max_events: gl.tensor + counter: gl.tensor + op_index_counter: gl.tensor + buf_event_id: gl.tensor + buf_pid: gl.tensor + buf_pid_m: gl.tensor + buf_pid_n: gl.tensor + buf_cur_rank: gl.tensor + buf_target_rank: gl.tensor + buf_xcc_id: gl.tensor + buf_cu_id: gl.tensor + buf_timestamp: gl.tensor + buf_address: gl.tensor + buf_duration_cycles: gl.tensor + buf_op_index: gl.tensor + buf_payload_size: gl.tensor + + @gluon.constexpr_function + def __init__( + self, + enabled, + rank, + max_events, + counter, + op_index_counter, + buf_event_id, + buf_pid, + buf_pid_m, + buf_pid_n, + buf_cur_rank, + buf_target_rank, + buf_xcc_id, + buf_cu_id, + buf_timestamp, + buf_address, + buf_duration_cycles, + buf_op_index, + buf_payload_size, + ): + """Construct GluonDeviceTracing (called from IrisDeviceCtx.initialize).""" + self.enabled = enabled + self.rank = rank + self.max_events = max_events + self.counter = counter + self.op_index_counter = op_index_counter + self.buf_event_id = buf_event_id + self.buf_pid = buf_pid + self.buf_pid_m = buf_pid_m + self.buf_pid_n = buf_pid_n + self.buf_cur_rank = buf_cur_rank + self.buf_target_rank = buf_target_rank + self.buf_xcc_id = buf_xcc_id + self.buf_cu_id = buf_cu_id + self.buf_timestamp = buf_timestamp + self.buf_address = buf_address + self.buf_duration_cycles = buf_duration_cycles + self.buf_op_index = buf_op_index + self.buf_payload_size = buf_payload_size + + @gluon.jit + def record_event_start( + self, + event_id: tl.constexpr, + target_rank, + address, + pid_m, + pid_n, + mask=None, + ): + """ + Record start of a traced operation. Returns a handle for record_event_end. + + Only stores when event_idx < max_events (bounds check). + cur_rank is taken from the tracing context (self.rank). + + Args: + event_id: Event type ID (constexpr) + target_rank: Target rank for the operation + address: Memory address(es) - can be 1D or 2D block of pointers. + pid_m: Program ID in M dimension + pid_n: Program ID in N dimension + mask: Optional mask tensor indicating valid elements. + """ + if self.enabled == 0: + return tl.cast(self.enabled, tl.int32) + + event_idx = tl.atomic_add(self.counter, 1) + op_index = tl.atomic_add(self.op_index_counter, 1) + + # Calculate payload_size from mask and datatype + if mask is not None: + mask_i32 = tl.cast(mask, tl.int32) + num_elements = tl.sum(mask_i32) + elem_type = address.dtype.element_ty + bitwidth = elem_type.primitive_bitwidth + elem_size_bytes = bitwidth // 8 + payload_size = num_elements * elem_size_bytes + else: + payload_size = self.enabled * 0 # scalar 0 without tl.full + + if event_idx < self.max_events: + tl.store(self.buf_event_id + event_idx, event_id) + tl.store(self.buf_pid + event_idx, gl.program_id(0)) + tl.store(self.buf_pid_m + event_idx, pid_m) + tl.store(self.buf_pid_n + event_idx, pid_n) + tl.store(self.buf_cur_rank + event_idx, self.rank) + tl.store(self.buf_target_rank + event_idx, target_rank) + tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) + tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) + tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) + addr_i64 = tl.cast(address, tl.int64) + tl.store(self.buf_address + event_idx, tl.min(addr_i64)) + tl.store(self.buf_duration_cycles + event_idx, tl.cast(self.enabled * 0, tl.int64)) + tl.store(self.buf_op_index + event_idx, op_index) + tl.store(self.buf_payload_size + event_idx, payload_size) + return event_idx + + @gluon.jit + def record_event_end(self, handle): + """ + Record end timestamp for the event started with record_event_start(handle). + + Only stores when handle < max_events (bounds check). + """ + if self.enabled == 0: + return + + end_ts = device_utils.read_realtime() + if handle < self.max_events: + tl.store(self.buf_duration_cycles + handle, end_ts) + + +_GluonDeviceTracingCls.__init__.__triton_builtin__ = True +GluonDeviceTracing = aggregate(_GluonDeviceTracingCls) + + @aggregate class IrisDeviceCtx: """ @@ -74,28 +222,36 @@ class IrisDeviceCtx: cur_rank: Current rank ID num_ranks: Total number of ranks heap_bases: Pointer to array of heap base addresses for all ranks + tracing: GluonDeviceTracing instance (active when tracing=True) """ cur_rank: gl.tensor num_ranks: gl.tensor heap_bases: gl.tensor + tracing: GluonDeviceTracing @gluon.constexpr_function - def __init__(self, cur_rank, num_ranks, heap_bases): + def __init__(self, cur_rank, num_ranks, heap_bases, tracing): self.cur_rank = cur_rank self.num_ranks = num_ranks self.heap_bases = heap_bases + self.tracing = tracing @staticmethod @gluon.jit - def initialize(context_tensor): + def initialize(context_tensor, tracing: gl.constexpr = False): """ Initialize `IrisDeviceCtx` from the encoded tensor. - The context tensor has the format: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + The context tensor has the format: + ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ..., trace_info...]`` + + If tracing is enabled on the host (via ``shmem.tracing.enable()``), the + context tensor also contains tracing buffer pointers after the heap bases. Args: context_tensor: Pointer to encoded context data + tracing: Enable event tracing (constexpr, default: False) Returns: `IrisDeviceCtx`: Initialized device context @@ -107,7 +263,86 @@ def initialize(context_tensor): # Extract heap bases (from index 2 onwards) heap_bases = context_tensor + 2 # Offset pointer to start at heap bases - return IrisDeviceCtx(cur_rank, num_ranks, heap_bases) + if tracing: + # Extract tracing info: starts after heap_bases, then skip trace_enabled flag + # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events, + # trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)] + trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled + max_events = gl.load(context_tensor + trace_info_base + 0) + trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1) + op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2) + + # Cast counter pointers + trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32)) + op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32)) + + # Extract trace buffer pointers (13 buffers, same order as Iris._build_device_context) + buf_base = trace_info_base + 3 + buf_event_id = tl.cast(gl.load(context_tensor + buf_base + 0), tl.pointer_type(tl.int32)) + buf_pid = tl.cast(gl.load(context_tensor + buf_base + 1), tl.pointer_type(tl.int32)) + buf_pid_m = tl.cast(gl.load(context_tensor + buf_base + 2), tl.pointer_type(tl.int32)) + buf_pid_n = tl.cast(gl.load(context_tensor + buf_base + 3), tl.pointer_type(tl.int32)) + buf_cur_rank = tl.cast(gl.load(context_tensor + buf_base + 4), tl.pointer_type(tl.int32)) + buf_target_rank = tl.cast(gl.load(context_tensor + buf_base + 5), tl.pointer_type(tl.int32)) + buf_xcc_id = tl.cast(gl.load(context_tensor + buf_base + 6), tl.pointer_type(tl.int32)) + buf_cu_id = tl.cast(gl.load(context_tensor + buf_base + 7), tl.pointer_type(tl.int32)) + buf_timestamp = tl.cast(gl.load(context_tensor + buf_base + 8), tl.pointer_type(tl.int64)) + buf_address = tl.cast(gl.load(context_tensor + buf_base + 9), tl.pointer_type(tl.int64)) + buf_duration_cycles = tl.cast(gl.load(context_tensor + buf_base + 10), tl.pointer_type(tl.int64)) + buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32)) + buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32)) + + # Read trace_enabled flag from context tensor (at index 2 + num_ranks) + trace_enabled_val = gl.load(context_tensor + 2 + num_ranks) + device_tracing = GluonDeviceTracing( + enabled=trace_enabled_val, + rank=cur_rank, + max_events=max_events, + counter=trace_counter, + op_index_counter=op_index_counter, + buf_event_id=buf_event_id, + buf_pid=buf_pid, + buf_pid_m=buf_pid_m, + buf_pid_n=buf_pid_n, + buf_cur_rank=buf_cur_rank, + buf_target_rank=buf_target_rank, + buf_xcc_id=buf_xcc_id, + buf_cu_id=buf_cu_id, + buf_timestamp=buf_timestamp, + buf_address=buf_address, + buf_duration_cycles=buf_duration_cycles, + buf_op_index=buf_op_index, + buf_payload_size=buf_payload_size, + ) + else: + # When tracing disabled, use dummy pointers (never dereferenced) + dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) + dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) + # Read trace_enabled flag from context tensor (0 = disabled) + trace_enabled_val = gl.load(context_tensor + 2 + num_ranks) + max_events_zero = trace_enabled_val # 0 when tracing disabled + device_tracing = GluonDeviceTracing( + enabled=trace_enabled_val, + rank=cur_rank, + max_events=max_events_zero, + counter=dummy_ptr_i32, + op_index_counter=dummy_ptr_i32, + buf_event_id=dummy_ptr_i32, + buf_pid=dummy_ptr_i32, + buf_pid_m=dummy_ptr_i32, + buf_pid_n=dummy_ptr_i32, + buf_cur_rank=dummy_ptr_i32, + buf_target_rank=dummy_ptr_i32, + buf_xcc_id=dummy_ptr_i32, + buf_cu_id=dummy_ptr_i32, + buf_timestamp=dummy_ptr_i64, + buf_address=dummy_ptr_i64, + buf_duration_cycles=dummy_ptr_i64, + buf_op_index=dummy_ptr_i32, + buf_payload_size=dummy_ptr_i32, + ) + + return IrisDeviceCtx(cur_rank, num_ranks, heap_bases, device_tracing) @gluon.jit def _translate(self, ptr, from_rank, to_rank): @@ -507,6 +742,9 @@ def __init__(self, heap_size=1 << 30): distributed_barrier() + # Initialize tracing manager (disabled by default) + self.tracing = Tracing(self) + # Initialize CCL interface self.ccl = self.CCL(self) @@ -662,20 +900,48 @@ def _build_device_context(self): """ Build and cache the device context tensor. - Called during __init__ to pre-build the tensor once. + Called during __init__ and again after tracing.enable() to include tracing fields. """ # Convert heap_bases to a list for concatenation heap_bases_list = self.heap_bases.tolist() # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + + # Add tracing info if enabled (same layout as Iris._build_device_context) + if self.tracing.enabled: + trace_buffer_ptrs = [ + self.tracing.trace_buffers["event_id"].data_ptr(), + self.tracing.trace_buffers["pid"].data_ptr(), + self.tracing.trace_buffers["pid_m"].data_ptr(), + self.tracing.trace_buffers["pid_n"].data_ptr(), + self.tracing.trace_buffers["cur_rank"].data_ptr(), + self.tracing.trace_buffers["target_rank"].data_ptr(), + self.tracing.trace_buffers["xcc_id"].data_ptr(), + self.tracing.trace_buffers["cu_id"].data_ptr(), + self.tracing.trace_buffers["timestamp"].data_ptr(), + self.tracing.trace_buffers["address"].data_ptr(), + self.tracing.trace_buffers["duration_cycles"].data_ptr(), + self.tracing.trace_buffers["op_index"].data_ptr(), + self.tracing.trace_buffers["payload_size"].data_ptr(), + ] + context_data += [ + 1, # trace_enabled = 1 (true) + self.tracing.max_events, + self.tracing.trace_counter.data_ptr(), + self.tracing.op_index_counter.data_ptr(), + ] + trace_buffer_ptrs + else: + context_data += [0] # trace_enabled = 0 (false) + self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) def get_device_context(self): """ Get the device context tensor for Gluon kernels. - Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + Returns a tensor encoding: ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`` + If tracing is enabled, also includes: ``[trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...]`` Returns: torch.Tensor: Encoded context data as int64 tensor on device From 349e9cf66212661861221f10590427451d27b81e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 23 Mar 2026 20:45:43 +0000 Subject: [PATCH 05/15] Apply Ruff auto-fixes --- iris/ccl/all_gather.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 3911e583..2f4a190e 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -388,9 +388,7 @@ def persistent_all_gather_gluon( # Flat 1D layout covering BLOCK_SIZE_M * BLOCK_SIZE_N elements TOTAL_ELEMS: gl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_N ELEMS_PER_THREAD: gl.constexpr = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA) - flat_layout: gl.constexpr = gl.BlockedLayout( - [ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0] - ) + flat_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0]) # Hoist local heap base outside the tile loop: eliminates redundant # gl.load(heap_bases) calls in the inner store loop. From e63a5fe066cf1fa7125ce05d94e2c5c11cd6871d Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:56:10 -0700 Subject: [PATCH 06/15] Add comprehensive all-gather shape + CU sweep benchmark Self-contained benchmark script that compares RCCL (default channels), Iris Triton persistent, and Iris Gluon flat-2D all-gather across 6 tensor shapes (2 MB to 512 MB) and 5 CU counts (8-96). Designed to run under torchrun in a single invocation with formatted table output and optional CSV export. Co-Authored-By: Claude Opus 4.6 --- benchmark/ccl/all_gather/benchmark_shapes.py | 278 +++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 benchmark/ccl/all_gather/benchmark_shapes.py diff --git a/benchmark/ccl/all_gather/benchmark_shapes.py b/benchmark/ccl/all_gather/benchmark_shapes.py new file mode 100644 index 00000000..92c8e9f3 --- /dev/null +++ b/benchmark/ccl/all_gather/benchmark_shapes.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Comprehensive all-gather benchmark: shape sweep x CU sweep. + +Compares RCCL (default channels), Iris Triton persistent, and Iris Gluon flat-2D +across multiple tensor shapes and CU counts. + +Usage: + torchrun --nproc_per_node=8 benchmark/ccl/all_gather/benchmark_shapes.py [--csv results.csv] +""" + +import argparse +import csv +import io +import os + +import torch +import torch.distributed as dist + +import iris +from iris.ccl import Config +import iris.experimental.iris_gluon as iris_gluon + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +SHAPES = [ + (1024, 1024), # 2 MB - small activations + (2048, 4096), # 16 MB - medium MLP + (4096, 4096), # 32 MB - GPT-scale hidden + (8192, 8192), # 128 MB - large MLP / standard bench + (16384, 8192), # 256 MB - long sequences + (16384, 16384), # 512 MB - large model partitions +] + +CU_COUNTS = [8, 16, 32, 64, 96] + +DTYPE = torch.float16 +DTYPE_STR = "fp16" +ELEMENT_SIZE = 2 # bytes per fp16 element + +# Benchmark parameters +N_WARMUP = 25 +N_REPEAT = 100 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def calc_bandwidth_gbps(M, N, world_size, ms): + """Calculate all-gather bandwidth in GB/s.""" + total_bytes = (world_size - 1) * M * N * ELEMENT_SIZE + total_gb = total_bytes / (1024**3) + return total_gb / (ms * 1e-3) if ms > 0 else 0.0 + + +def bench_rccl(M, N, rank, world_size): + """Benchmark RCCL all_gather_into_tensor at default channel config.""" + inp = torch.zeros(M, N, dtype=DTYPE, device=f"cuda:{rank}") + inp.fill_(float(rank + 1)) + out = torch.zeros(world_size * M, N, dtype=DTYPE, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(out, inp) + torch.cuda.synchronize() + dist.barrier() + + out.zero_() + inp.fill_(float(rank + 1)) + dist.barrier() + + def fn(): + dist.all_gather_into_tensor(out, inp) + + ms = iris.do_bench(fn, dist.barrier, n_warmup=N_WARMUP, n_repeat=N_REPEAT) + return ms + + +def bench_iris(M, N, shmem, config): + """Benchmark Iris all-gather (Triton or Gluon depending on config).""" + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + inp = shmem.zeros((M, N), dtype=DTYPE) + out = shmem.zeros((world_size * M, N), dtype=DTYPE) + + inp.fill_(float(rank + 1)) + shmem.barrier() + + def fn(): + shmem.ccl.all_gather(out, inp, config=config, async_op=False) + + ms = iris.do_bench(fn, shmem.barrier, n_warmup=N_WARMUP, n_repeat=N_REPEAT) + + # Free symmetric heap memory for next iteration + del inp, out + torch.cuda.empty_cache() + shmem.barrier() + + return ms + + +def validate_iris(M, N, shmem, config): + """Quick correctness check for an Iris config. Returns True if correct.""" + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + inp = shmem.zeros((M, N), dtype=DTYPE) + out = shmem.zeros((world_size * M, N), dtype=DTYPE) + + inp.fill_(float(rank + 1)) + out.zero_() + shmem.barrier() + + shmem.ccl.all_gather(out, inp, config=config, async_op=False) + shmem.barrier() + torch.cuda.synchronize() + + ok = True + for r in range(world_size): + expected = float(r + 1) + chunk = out[r * M : (r + 1) * M, :] + if not torch.allclose(chunk, torch.full_like(chunk, expected), atol=1e-3): + ok = False + break + + del inp, out + torch.cuda.empty_cache() + shmem.barrier() + return ok + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="All-gather shape + CU sweep benchmark") + parser.add_argument("--csv", type=str, default=None, help="Output CSV file path") + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size (default 16 GB)") + parser.add_argument("--validate", action="store_true", help="Validate correctness before benchmarking") + parser.add_argument("--n_warmup", type=int, default=N_WARMUP, help="Warmup iterations") + parser.add_argument("--n_repeat", type=int, default=N_REPEAT, help="Benchmark iterations") + args = parser.parse_args() + + global N_WARMUP, N_REPEAT + N_WARMUP = args.n_warmup + N_REPEAT = args.n_repeat + + # torchrun sets these env vars + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) + + rank = dist.get_rank() + is_root = rank == 0 + + # Initialize both Triton and Gluon shmem contexts. + # They share the same underlying symmetric heap / distributed state, + # but dispatch to different kernel backends. + shmem_triton = iris.iris(args.heap_size) + shmem_gluon = iris_gluon.iris(args.heap_size) + + # Collect results: list of dicts + results = [] + + # Table header + if is_root: + hdr = ( + f"{'Shape':>14s} {'Size':>7s} {'Backend':>10s} {'CUs':>4s} " + f"{'Time(ms)':>9s} {'BW(GB/s)':>9s} {'vs RCCL':>7s}" + ) + print("=" * len(hdr)) + print(hdr) + print("=" * len(hdr)) + + for M, N in SHAPES: + data_mb = M * N * ELEMENT_SIZE / (1024**2) + shape_str = f"{M}x{N}" + + # --- Optional validation --- + if args.validate: + for cu in [CU_COUNTS[-1]]: # validate at highest CU count only + triton_cfg = Config(comm_sms=cu) + gluon_cfg = Config(comm_sms=cu, use_gluon=True) + + ok_t = validate_iris(M, N, shmem_triton, triton_cfg) + ok_g = validate_iris(M, N, shmem_gluon, gluon_cfg) + + if is_root: + if not ok_t: + print(f"WARNING: Triton validation FAILED for {shape_str} cu={cu}") + if not ok_g: + print(f"WARNING: Gluon validation FAILED for {shape_str} cu={cu}") + + # --- RCCL baseline --- + rccl_ms = bench_rccl(M, N, rank, world_size) + rccl_bw = calc_bandwidth_gbps(M, N, world_size, rccl_ms) + + row = { + "shape": shape_str, "M": M, "N": N, "data_mb": data_mb, + "backend": "RCCL", "cus": "-", "time_ms": rccl_ms, + "bw_gbps": rccl_bw, "vs_rccl_pct": 100.0, + } + results.append(row) + + if is_root: + print( + f"{shape_str:>14s} {data_mb:6.0f}M {'RCCL':>10s} {'-':>4s} " + f"{rccl_ms:9.3f} {rccl_bw:9.1f} {100.0:6.1f}%" + ) + + # --- Iris Triton + Gluon at each CU count --- + for cu in CU_COUNTS: + for backend_name, shmem, use_gluon in [ + ("Triton", shmem_triton, False), + ("Gluon", shmem_gluon, True), + ]: + cfg = Config(comm_sms=cu, use_gluon=use_gluon) + ms = bench_iris(M, N, shmem, cfg) + bw = calc_bandwidth_gbps(M, N, world_size, ms) + vs_rccl = (bw / rccl_bw * 100) if rccl_bw > 0 else 0.0 + + row = { + "shape": shape_str, "M": M, "N": N, "data_mb": data_mb, + "backend": backend_name, "cus": cu, "time_ms": ms, + "bw_gbps": bw, "vs_rccl_pct": vs_rccl, + } + results.append(row) + + if is_root: + print( + f"{shape_str:>14s} {data_mb:6.0f}M {backend_name:>10s} {cu:4d} " + f"{ms:9.3f} {bw:9.1f} {vs_rccl:6.1f}%" + ) + + # Separator between shapes + if is_root: + print("-" * 72) + + # --- Summary CSV --- + if is_root: + buf = io.StringIO() + writer = csv.DictWriter( + buf, + fieldnames=["shape", "M", "N", "data_mb", "backend", "cus", "time_ms", "bw_gbps", "vs_rccl_pct"], + ) + writer.writeheader() + writer.writerows(results) + + csv_text = buf.getvalue() + + if args.csv: + with open(args.csv, "w") as f: + f.write(csv_text) + print(f"\nResults written to {args.csv}") + else: + print("\n--- CSV ---") + print(csv_text) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From dd7d5bf24c880ca18ba117c224e0512a843f385e Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:50:14 -0700 Subject: [PATCH 07/15] Fix global variable SyntaxError in benchmark_shapes.py Pass n_warmup and n_repeat as function parameters instead of using global statement, which caused SyntaxError when the names appeared in argparse defaults before the global declaration. Co-Authored-By: Claude Opus 4.6 --- benchmark/ccl/all_gather/benchmark_shapes.py | 27 ++++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/benchmark/ccl/all_gather/benchmark_shapes.py b/benchmark/ccl/all_gather/benchmark_shapes.py index 92c8e9f3..652bcf73 100644 --- a/benchmark/ccl/all_gather/benchmark_shapes.py +++ b/benchmark/ccl/all_gather/benchmark_shapes.py @@ -44,9 +44,9 @@ DTYPE_STR = "fp16" ELEMENT_SIZE = 2 # bytes per fp16 element -# Benchmark parameters -N_WARMUP = 25 -N_REPEAT = 100 +# Default benchmark parameters +DEFAULT_N_WARMUP = 25 +DEFAULT_N_REPEAT = 100 # --------------------------------------------------------------------------- @@ -60,7 +60,7 @@ def calc_bandwidth_gbps(M, N, world_size, ms): return total_gb / (ms * 1e-3) if ms > 0 else 0.0 -def bench_rccl(M, N, rank, world_size): +def bench_rccl(M, N, rank, world_size, n_warmup, n_repeat): """Benchmark RCCL all_gather_into_tensor at default channel config.""" inp = torch.zeros(M, N, dtype=DTYPE, device=f"cuda:{rank}") inp.fill_(float(rank + 1)) @@ -79,11 +79,11 @@ def bench_rccl(M, N, rank, world_size): def fn(): dist.all_gather_into_tensor(out, inp) - ms = iris.do_bench(fn, dist.barrier, n_warmup=N_WARMUP, n_repeat=N_REPEAT) + ms = iris.do_bench(fn, dist.barrier, n_warmup=n_warmup, n_repeat=n_repeat) return ms -def bench_iris(M, N, shmem, config): +def bench_iris(M, N, shmem, config, n_warmup, n_repeat): """Benchmark Iris all-gather (Triton or Gluon depending on config).""" world_size = shmem.get_num_ranks() rank = shmem.get_rank() @@ -97,7 +97,7 @@ def bench_iris(M, N, shmem, config): def fn(): shmem.ccl.all_gather(out, inp, config=config, async_op=False) - ms = iris.do_bench(fn, shmem.barrier, n_warmup=N_WARMUP, n_repeat=N_REPEAT) + ms = iris.do_bench(fn, shmem.barrier, n_warmup=n_warmup, n_repeat=n_repeat) # Free symmetric heap memory for next iteration del inp, out @@ -146,13 +146,12 @@ def main(): parser.add_argument("--csv", type=str, default=None, help="Output CSV file path") parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size (default 16 GB)") parser.add_argument("--validate", action="store_true", help="Validate correctness before benchmarking") - parser.add_argument("--n_warmup", type=int, default=N_WARMUP, help="Warmup iterations") - parser.add_argument("--n_repeat", type=int, default=N_REPEAT, help="Benchmark iterations") + parser.add_argument("--n_warmup", type=int, default=DEFAULT_N_WARMUP, help="Warmup iterations") + parser.add_argument("--n_repeat", type=int, default=DEFAULT_N_REPEAT, help="Benchmark iterations") args = parser.parse_args() - global N_WARMUP, N_REPEAT - N_WARMUP = args.n_warmup - N_REPEAT = args.n_repeat + n_warmup = args.n_warmup + n_repeat = args.n_repeat # torchrun sets these env vars local_rank = int(os.environ["LOCAL_RANK"]) @@ -206,7 +205,7 @@ def main(): print(f"WARNING: Gluon validation FAILED for {shape_str} cu={cu}") # --- RCCL baseline --- - rccl_ms = bench_rccl(M, N, rank, world_size) + rccl_ms = bench_rccl(M, N, rank, world_size, n_warmup, n_repeat) rccl_bw = calc_bandwidth_gbps(M, N, world_size, rccl_ms) row = { @@ -229,7 +228,7 @@ def main(): ("Gluon", shmem_gluon, True), ]: cfg = Config(comm_sms=cu, use_gluon=use_gluon) - ms = bench_iris(M, N, shmem, cfg) + ms = bench_iris(M, N, shmem, cfg, n_warmup, n_repeat) bw = calc_bandwidth_gbps(M, N, world_size, ms) vs_rccl = (bw / rccl_bw * 100) if rccl_bw > 0 else 0.0 From 669487232a3d918adcb314a783e2eb0b2c0c1b8d Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:54:08 -0700 Subject: [PATCH 08/15] Pre-allocate all heap buffers upfront to avoid OOM The symmetric heap uses a bump allocator with no free. Allocating input/output tensors per-iteration exhausted the heap at larger shapes. Now all buffers for every shape are allocated once before the benchmark loop starts (~8.5 GB per shmem context). Co-Authored-By: Claude Opus 4.6 --- benchmark/ccl/all_gather/benchmark_shapes.py | 61 +++++++++++--------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/benchmark/ccl/all_gather/benchmark_shapes.py b/benchmark/ccl/all_gather/benchmark_shapes.py index 652bcf73..33ce8b2f 100644 --- a/benchmark/ccl/all_gather/benchmark_shapes.py +++ b/benchmark/ccl/all_gather/benchmark_shapes.py @@ -83,14 +83,10 @@ def fn(): return ms -def bench_iris(M, N, shmem, config, n_warmup, n_repeat): - """Benchmark Iris all-gather (Triton or Gluon depending on config).""" - world_size = shmem.get_num_ranks() +def bench_iris(inp, out, shmem, config, n_warmup, n_repeat): + """Benchmark Iris all-gather with pre-allocated heap tensors.""" rank = shmem.get_rank() - inp = shmem.zeros((M, N), dtype=DTYPE) - out = shmem.zeros((world_size * M, N), dtype=DTYPE) - inp.fill_(float(rank + 1)) shmem.barrier() @@ -98,22 +94,15 @@ def fn(): shmem.ccl.all_gather(out, inp, config=config, async_op=False) ms = iris.do_bench(fn, shmem.barrier, n_warmup=n_warmup, n_repeat=n_repeat) - - # Free symmetric heap memory for next iteration - del inp, out - torch.cuda.empty_cache() shmem.barrier() - return ms -def validate_iris(M, N, shmem, config): - """Quick correctness check for an Iris config. Returns True if correct.""" +def validate_iris(inp, out, shmem, config): + """Quick correctness check with pre-allocated heap tensors.""" world_size = shmem.get_num_ranks() rank = shmem.get_rank() - - inp = shmem.zeros((M, N), dtype=DTYPE) - out = shmem.zeros((world_size * M, N), dtype=DTYPE) + M = inp.shape[0] inp.fill_(float(rank + 1)) out.zero_() @@ -131,8 +120,6 @@ def validate_iris(M, N, shmem, config): ok = False break - del inp, out - torch.cuda.empty_cache() shmem.barrier() return ok @@ -167,11 +154,30 @@ def main(): is_root = rank == 0 # Initialize both Triton and Gluon shmem contexts. - # They share the same underlying symmetric heap / distributed state, - # but dispatch to different kernel backends. shmem_triton = iris.iris(args.heap_size) shmem_gluon = iris_gluon.iris(args.heap_size) + # Pre-allocate all input/output tensor pairs on each heap upfront. + # The symmetric heap uses a bump allocator (no free), so we must + # allocate everything before the benchmark loop. + triton_bufs = {} # (M, N) -> (inp, out) + gluon_bufs = {} + for M, N in SHAPES: + triton_bufs[(M, N)] = ( + shmem_triton.zeros((M, N), dtype=DTYPE), + shmem_triton.zeros((world_size * M, N), dtype=DTYPE), + ) + gluon_bufs[(M, N)] = ( + shmem_gluon.zeros((M, N), dtype=DTYPE), + shmem_gluon.zeros((world_size * M, N), dtype=DTYPE), + ) + + if is_root: + total_heap_per_ctx = sum( + (M * N + world_size * M * N) * ELEMENT_SIZE for M, N in SHAPES + ) + print(f"Heap usage per context: {total_heap_per_ctx / (1024**3):.2f} GB") + # Collect results: list of dicts results = [] @@ -189,14 +195,17 @@ def main(): data_mb = M * N * ELEMENT_SIZE / (1024**2) shape_str = f"{M}x{N}" + t_inp, t_out = triton_bufs[(M, N)] + g_inp, g_out = gluon_bufs[(M, N)] + # --- Optional validation --- if args.validate: for cu in [CU_COUNTS[-1]]: # validate at highest CU count only triton_cfg = Config(comm_sms=cu) gluon_cfg = Config(comm_sms=cu, use_gluon=True) - ok_t = validate_iris(M, N, shmem_triton, triton_cfg) - ok_g = validate_iris(M, N, shmem_gluon, gluon_cfg) + ok_t = validate_iris(t_inp, t_out, shmem_triton, triton_cfg) + ok_g = validate_iris(g_inp, g_out, shmem_gluon, gluon_cfg) if is_root: if not ok_t: @@ -223,12 +232,12 @@ def main(): # --- Iris Triton + Gluon at each CU count --- for cu in CU_COUNTS: - for backend_name, shmem, use_gluon in [ - ("Triton", shmem_triton, False), - ("Gluon", shmem_gluon, True), + for backend_name, shmem, inp, out, use_gluon in [ + ("Triton", shmem_triton, t_inp, t_out, False), + ("Gluon", shmem_gluon, g_inp, g_out, True), ]: cfg = Config(comm_sms=cu, use_gluon=use_gluon) - ms = bench_iris(M, N, shmem, cfg, n_warmup, n_repeat) + ms = bench_iris(inp, out, shmem, cfg, n_warmup, n_repeat) bw = calc_bandwidth_gbps(M, N, world_size, ms) vs_rccl = (bw / rccl_bw * 100) if rccl_bw > 0 else 0.0 From 1752080a3bfbf29f02e9f2b646808c897614751a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 23 Mar 2026 23:55:08 +0000 Subject: [PATCH 09/15] Apply Ruff auto-fixes --- benchmark/ccl/all_gather/benchmark_shapes.py | 40 +++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/benchmark/ccl/all_gather/benchmark_shapes.py b/benchmark/ccl/all_gather/benchmark_shapes.py index 33ce8b2f..34e9f1d5 100644 --- a/benchmark/ccl/all_gather/benchmark_shapes.py +++ b/benchmark/ccl/all_gather/benchmark_shapes.py @@ -30,11 +30,11 @@ # --------------------------------------------------------------------------- SHAPES = [ - (1024, 1024), # 2 MB - small activations - (2048, 4096), # 16 MB - medium MLP - (4096, 4096), # 32 MB - GPT-scale hidden - (8192, 8192), # 128 MB - large MLP / standard bench - (16384, 8192), # 256 MB - long sequences + (1024, 1024), # 2 MB - small activations + (2048, 4096), # 16 MB - medium MLP + (4096, 4096), # 32 MB - GPT-scale hidden + (8192, 8192), # 128 MB - large MLP / standard bench + (16384, 8192), # 256 MB - long sequences (16384, 16384), # 512 MB - large model partitions ] @@ -53,6 +53,7 @@ # Helpers # --------------------------------------------------------------------------- + def calc_bandwidth_gbps(M, N, world_size, ms): """Calculate all-gather bandwidth in GB/s.""" total_bytes = (world_size - 1) * M * N * ELEMENT_SIZE @@ -128,6 +129,7 @@ def validate_iris(inp, out, shmem, config): # Main # --------------------------------------------------------------------------- + def main(): parser = argparse.ArgumentParser(description="All-gather shape + CU sweep benchmark") parser.add_argument("--csv", type=str, default=None, help="Output CSV file path") @@ -173,9 +175,7 @@ def main(): ) if is_root: - total_heap_per_ctx = sum( - (M * N + world_size * M * N) * ELEMENT_SIZE for M, N in SHAPES - ) + total_heap_per_ctx = sum((M * N + world_size * M * N) * ELEMENT_SIZE for M, N in SHAPES) print(f"Heap usage per context: {total_heap_per_ctx / (1024**3):.2f} GB") # Collect results: list of dicts @@ -218,9 +218,15 @@ def main(): rccl_bw = calc_bandwidth_gbps(M, N, world_size, rccl_ms) row = { - "shape": shape_str, "M": M, "N": N, "data_mb": data_mb, - "backend": "RCCL", "cus": "-", "time_ms": rccl_ms, - "bw_gbps": rccl_bw, "vs_rccl_pct": 100.0, + "shape": shape_str, + "M": M, + "N": N, + "data_mb": data_mb, + "backend": "RCCL", + "cus": "-", + "time_ms": rccl_ms, + "bw_gbps": rccl_bw, + "vs_rccl_pct": 100.0, } results.append(row) @@ -242,9 +248,15 @@ def main(): vs_rccl = (bw / rccl_bw * 100) if rccl_bw > 0 else 0.0 row = { - "shape": shape_str, "M": M, "N": N, "data_mb": data_mb, - "backend": backend_name, "cus": cu, "time_ms": ms, - "bw_gbps": bw, "vs_rccl_pct": vs_rccl, + "shape": shape_str, + "M": M, + "N": N, + "data_mb": data_mb, + "backend": backend_name, + "cus": cu, + "time_ms": ms, + "bw_gbps": bw, + "vs_rccl_pct": vs_rccl, } results.append(row) From be08ec8dcb574284f1402e57127392289bbb3f10 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:10:31 -0700 Subject: [PATCH 10/15] Remove benchmark script and tracing from gluon all-gather Drop benchmark_shapes.py (sweep script), remove tracing instrumentation from the gluon all-gather kernel, and revert iris_gluon.py tracing class changes back to main. Co-Authored-By: Claude Opus 4.6 --- benchmark/ccl/all_gather/benchmark_shapes.py | 298 ------------------- iris/ccl/all_gather.py | 32 +- iris/experimental/iris_gluon.py | 51 ++-- 3 files changed, 27 insertions(+), 354 deletions(-) delete mode 100644 benchmark/ccl/all_gather/benchmark_shapes.py diff --git a/benchmark/ccl/all_gather/benchmark_shapes.py b/benchmark/ccl/all_gather/benchmark_shapes.py deleted file mode 100644 index 34e9f1d5..00000000 --- a/benchmark/ccl/all_gather/benchmark_shapes.py +++ /dev/null @@ -1,298 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -""" -Comprehensive all-gather benchmark: shape sweep x CU sweep. - -Compares RCCL (default channels), Iris Triton persistent, and Iris Gluon flat-2D -across multiple tensor shapes and CU counts. - -Usage: - torchrun --nproc_per_node=8 benchmark/ccl/all_gather/benchmark_shapes.py [--csv results.csv] -""" - -import argparse -import csv -import io -import os - -import torch -import torch.distributed as dist - -import iris -from iris.ccl import Config -import iris.experimental.iris_gluon as iris_gluon - - -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- - -SHAPES = [ - (1024, 1024), # 2 MB - small activations - (2048, 4096), # 16 MB - medium MLP - (4096, 4096), # 32 MB - GPT-scale hidden - (8192, 8192), # 128 MB - large MLP / standard bench - (16384, 8192), # 256 MB - long sequences - (16384, 16384), # 512 MB - large model partitions -] - -CU_COUNTS = [8, 16, 32, 64, 96] - -DTYPE = torch.float16 -DTYPE_STR = "fp16" -ELEMENT_SIZE = 2 # bytes per fp16 element - -# Default benchmark parameters -DEFAULT_N_WARMUP = 25 -DEFAULT_N_REPEAT = 100 - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def calc_bandwidth_gbps(M, N, world_size, ms): - """Calculate all-gather bandwidth in GB/s.""" - total_bytes = (world_size - 1) * M * N * ELEMENT_SIZE - total_gb = total_bytes / (1024**3) - return total_gb / (ms * 1e-3) if ms > 0 else 0.0 - - -def bench_rccl(M, N, rank, world_size, n_warmup, n_repeat): - """Benchmark RCCL all_gather_into_tensor at default channel config.""" - inp = torch.zeros(M, N, dtype=DTYPE, device=f"cuda:{rank}") - inp.fill_(float(rank + 1)) - out = torch.zeros(world_size * M, N, dtype=DTYPE, device=f"cuda:{rank}") - - # Warmup - for _ in range(10): - dist.all_gather_into_tensor(out, inp) - torch.cuda.synchronize() - dist.barrier() - - out.zero_() - inp.fill_(float(rank + 1)) - dist.barrier() - - def fn(): - dist.all_gather_into_tensor(out, inp) - - ms = iris.do_bench(fn, dist.barrier, n_warmup=n_warmup, n_repeat=n_repeat) - return ms - - -def bench_iris(inp, out, shmem, config, n_warmup, n_repeat): - """Benchmark Iris all-gather with pre-allocated heap tensors.""" - rank = shmem.get_rank() - - inp.fill_(float(rank + 1)) - shmem.barrier() - - def fn(): - shmem.ccl.all_gather(out, inp, config=config, async_op=False) - - ms = iris.do_bench(fn, shmem.barrier, n_warmup=n_warmup, n_repeat=n_repeat) - shmem.barrier() - return ms - - -def validate_iris(inp, out, shmem, config): - """Quick correctness check with pre-allocated heap tensors.""" - world_size = shmem.get_num_ranks() - rank = shmem.get_rank() - M = inp.shape[0] - - inp.fill_(float(rank + 1)) - out.zero_() - shmem.barrier() - - shmem.ccl.all_gather(out, inp, config=config, async_op=False) - shmem.barrier() - torch.cuda.synchronize() - - ok = True - for r in range(world_size): - expected = float(r + 1) - chunk = out[r * M : (r + 1) * M, :] - if not torch.allclose(chunk, torch.full_like(chunk, expected), atol=1e-3): - ok = False - break - - shmem.barrier() - return ok - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main(): - parser = argparse.ArgumentParser(description="All-gather shape + CU sweep benchmark") - parser.add_argument("--csv", type=str, default=None, help="Output CSV file path") - parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size (default 16 GB)") - parser.add_argument("--validate", action="store_true", help="Validate correctness before benchmarking") - parser.add_argument("--n_warmup", type=int, default=DEFAULT_N_WARMUP, help="Warmup iterations") - parser.add_argument("--n_repeat", type=int, default=DEFAULT_N_REPEAT, help="Benchmark iterations") - args = parser.parse_args() - - n_warmup = args.n_warmup - n_repeat = args.n_repeat - - # torchrun sets these env vars - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", - device_id=torch.device(f"cuda:{local_rank}"), - ) - - rank = dist.get_rank() - is_root = rank == 0 - - # Initialize both Triton and Gluon shmem contexts. - shmem_triton = iris.iris(args.heap_size) - shmem_gluon = iris_gluon.iris(args.heap_size) - - # Pre-allocate all input/output tensor pairs on each heap upfront. - # The symmetric heap uses a bump allocator (no free), so we must - # allocate everything before the benchmark loop. - triton_bufs = {} # (M, N) -> (inp, out) - gluon_bufs = {} - for M, N in SHAPES: - triton_bufs[(M, N)] = ( - shmem_triton.zeros((M, N), dtype=DTYPE), - shmem_triton.zeros((world_size * M, N), dtype=DTYPE), - ) - gluon_bufs[(M, N)] = ( - shmem_gluon.zeros((M, N), dtype=DTYPE), - shmem_gluon.zeros((world_size * M, N), dtype=DTYPE), - ) - - if is_root: - total_heap_per_ctx = sum((M * N + world_size * M * N) * ELEMENT_SIZE for M, N in SHAPES) - print(f"Heap usage per context: {total_heap_per_ctx / (1024**3):.2f} GB") - - # Collect results: list of dicts - results = [] - - # Table header - if is_root: - hdr = ( - f"{'Shape':>14s} {'Size':>7s} {'Backend':>10s} {'CUs':>4s} " - f"{'Time(ms)':>9s} {'BW(GB/s)':>9s} {'vs RCCL':>7s}" - ) - print("=" * len(hdr)) - print(hdr) - print("=" * len(hdr)) - - for M, N in SHAPES: - data_mb = M * N * ELEMENT_SIZE / (1024**2) - shape_str = f"{M}x{N}" - - t_inp, t_out = triton_bufs[(M, N)] - g_inp, g_out = gluon_bufs[(M, N)] - - # --- Optional validation --- - if args.validate: - for cu in [CU_COUNTS[-1]]: # validate at highest CU count only - triton_cfg = Config(comm_sms=cu) - gluon_cfg = Config(comm_sms=cu, use_gluon=True) - - ok_t = validate_iris(t_inp, t_out, shmem_triton, triton_cfg) - ok_g = validate_iris(g_inp, g_out, shmem_gluon, gluon_cfg) - - if is_root: - if not ok_t: - print(f"WARNING: Triton validation FAILED for {shape_str} cu={cu}") - if not ok_g: - print(f"WARNING: Gluon validation FAILED for {shape_str} cu={cu}") - - # --- RCCL baseline --- - rccl_ms = bench_rccl(M, N, rank, world_size, n_warmup, n_repeat) - rccl_bw = calc_bandwidth_gbps(M, N, world_size, rccl_ms) - - row = { - "shape": shape_str, - "M": M, - "N": N, - "data_mb": data_mb, - "backend": "RCCL", - "cus": "-", - "time_ms": rccl_ms, - "bw_gbps": rccl_bw, - "vs_rccl_pct": 100.0, - } - results.append(row) - - if is_root: - print( - f"{shape_str:>14s} {data_mb:6.0f}M {'RCCL':>10s} {'-':>4s} " - f"{rccl_ms:9.3f} {rccl_bw:9.1f} {100.0:6.1f}%" - ) - - # --- Iris Triton + Gluon at each CU count --- - for cu in CU_COUNTS: - for backend_name, shmem, inp, out, use_gluon in [ - ("Triton", shmem_triton, t_inp, t_out, False), - ("Gluon", shmem_gluon, g_inp, g_out, True), - ]: - cfg = Config(comm_sms=cu, use_gluon=use_gluon) - ms = bench_iris(inp, out, shmem, cfg, n_warmup, n_repeat) - bw = calc_bandwidth_gbps(M, N, world_size, ms) - vs_rccl = (bw / rccl_bw * 100) if rccl_bw > 0 else 0.0 - - row = { - "shape": shape_str, - "M": M, - "N": N, - "data_mb": data_mb, - "backend": backend_name, - "cus": cu, - "time_ms": ms, - "bw_gbps": bw, - "vs_rccl_pct": vs_rccl, - } - results.append(row) - - if is_root: - print( - f"{shape_str:>14s} {data_mb:6.0f}M {backend_name:>10s} {cu:4d} " - f"{ms:9.3f} {bw:9.1f} {vs_rccl:6.1f}%" - ) - - # Separator between shapes - if is_root: - print("-" * 72) - - # --- Summary CSV --- - if is_root: - buf = io.StringIO() - writer = csv.DictWriter( - buf, - fieldnames=["shape", "M", "N", "data_mb", "backend", "cus", "time_ms", "bw_gbps", "vs_rccl_pct"], - ) - writer.writeheader() - writer.writerows(results) - - csv_text = buf.getvalue() - - if args.csv: - with open(args.csv, "w") as f: - f.write(csv_text) - print(f"\nResults written to {args.csv}") - else: - print("\n--- CSV ---") - print(csv_text) - - dist.barrier() - dist.destroy_process_group() - - -if __name__ == "__main__": - main() diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 2f4a190e..fedb25ba 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -11,7 +11,6 @@ import iris from .config import Config from .utils import extract_group_info -from iris.tracing.events import TraceEvent # Conditional import for Gluon try: @@ -326,7 +325,6 @@ def persistent_all_gather_gluon( COMM_SMS: gl.constexpr, THREADS_PER_WARP: gl.constexpr, WARPS_PER_CTA: gl.constexpr, - TRACING: gl.constexpr = False, ): """ Persistent all-gather kernel using Gluon with flat-2D tiling. @@ -374,10 +372,8 @@ def persistent_all_gather_gluon( COMM_SMS: Number of CUs used for persistent scheduling. THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA). WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps. - TRACING: If True, record load/store events into trace buffers. """ - ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING) - events = TraceEvent() + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=False) pid = gl.program_id(0) @@ -417,18 +413,7 @@ def persistent_all_gather_gluon( # Single flat load of the entire tile input_offsets = row * stride_in_m + col * stride_in_n input_addr = input_ptr + input_offsets - if TRACING: - h_load = ctx.tracing.record_event_start( - event_id=events.load, - target_rank=group_rank, - address=input_addr, - pid_m=pid_m, - pid_n=pid_n, - mask=mask, - ) data = gl.load(input_addr, mask=mask, other=0.0) - if TRACING: - ctx.tracing.record_event_end(h_load) # Output: this rank's data goes to output[group_rank * M + row, col] output_row = group_rank * M + row @@ -442,16 +427,6 @@ def persistent_all_gather_gluon( target_iris_rank = rank_start + dest_idx * rank_stride output_ptrs = output_ptr + output_offsets - if TRACING: - h_store = ctx.tracing.record_event_start( - event_id=events.store, - target_rank=target_iris_rank, - address=output_ptrs, - pid_m=pid_m, - pid_n=pid_n, - mask=mask, - ) - if dest_idx == group_rank: gl.store(output_ptrs, data, mask=mask, cache_modifier=".wt") else: @@ -465,9 +440,6 @@ def persistent_all_gather_gluon( remote_ptrs = tl.cast(remote_ptrs_int, output_ptrs.dtype) gl.store(remote_ptrs, data, mask=mask) - if TRACING: - ctx.tracing.record_event_end(h_store) - def all_gather( output_tensor, @@ -556,7 +528,6 @@ def all_gather( f"Recommended: block_size_m=8, block_size_n=256." ) - tracing_enabled = hasattr(shmem, "tracing") and shmem.tracing.enabled context_tensor = shmem.get_device_context() persistent_all_gather_gluon[(config.comm_sms,)]( @@ -581,7 +552,6 @@ def all_gather( config.comm_sms, config.threads_per_warp, config.num_warps, - tracing_enabled, num_stages=config.num_stages, num_warps=config.num_warps, waves_per_eu=config.waves_per_eu, diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 2f56137e..97add62f 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -72,7 +72,7 @@ class _GluonDeviceTracingCls: / record_event_end to bracket operations; events are exported via Tracing.export(). """ - enabled: gl.tensor + enabled: tl.constexpr rank: gl.tensor max_events: gl.tensor counter: gl.tensor @@ -91,7 +91,6 @@ class _GluonDeviceTracingCls: buf_op_index: gl.tensor buf_payload_size: gl.tensor - @gluon.constexpr_function def __init__( self, enabled, @@ -157,8 +156,8 @@ def record_event_start( pid_n: Program ID in N dimension mask: Optional mask tensor indicating valid elements. """ - if self.enabled == 0: - return tl.cast(self.enabled, tl.int32) + if not self.enabled: + return tl.cast(0, tl.int32) event_idx = tl.atomic_add(self.counter, 1) op_index = tl.atomic_add(self.op_index_counter, 1) @@ -166,29 +165,29 @@ def record_event_start( # Calculate payload_size from mask and datatype if mask is not None: mask_i32 = tl.cast(mask, tl.int32) - num_elements = tl.sum(mask_i32) + num_elements = gl.sum(mask_i32, axis=0) elem_type = address.dtype.element_ty bitwidth = elem_type.primitive_bitwidth elem_size_bytes = bitwidth // 8 - payload_size = num_elements * elem_size_bytes + payload_size = num_elements * tl.cast(elem_size_bytes, tl.int32) else: - payload_size = self.enabled * 0 # scalar 0 without tl.full + payload_size = tl.cast(0, tl.int32) if event_idx < self.max_events: - tl.store(self.buf_event_id + event_idx, event_id) - tl.store(self.buf_pid + event_idx, gl.program_id(0)) - tl.store(self.buf_pid_m + event_idx, pid_m) - tl.store(self.buf_pid_n + event_idx, pid_n) - tl.store(self.buf_cur_rank + event_idx, self.rank) - tl.store(self.buf_target_rank + event_idx, target_rank) + tl.store(self.buf_event_id + event_idx, tl.cast(event_id, tl.int32)) + tl.store(self.buf_pid + event_idx, tl.cast(gl.program_id(0), tl.int32)) + tl.store(self.buf_pid_m + event_idx, tl.cast(pid_m, tl.int32)) + tl.store(self.buf_pid_n + event_idx, tl.cast(pid_n, tl.int32)) + tl.store(self.buf_cur_rank + event_idx, tl.cast(self.rank, tl.int32)) + tl.store(self.buf_target_rank + event_idx, tl.cast(target_rank, tl.int32)) tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) addr_i64 = tl.cast(address, tl.int64) - tl.store(self.buf_address + event_idx, tl.min(addr_i64)) - tl.store(self.buf_duration_cycles + event_idx, tl.cast(self.enabled * 0, tl.int64)) + tl.store(self.buf_address + event_idx, gl.min(addr_i64, axis=0)) + tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_op_index + event_idx, op_index) - tl.store(self.buf_payload_size + event_idx, payload_size) + tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32)) return event_idx @gluon.jit @@ -198,7 +197,7 @@ def record_event_end(self, handle): Only stores when handle < max_events (bounds check). """ - if self.enabled == 0: + if not self.enabled: return end_ts = device_utils.read_realtime() @@ -268,7 +267,7 @@ def initialize(context_tensor, tracing: gl.constexpr = False): # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events, # trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)] trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled - max_events = gl.load(context_tensor + trace_info_base + 0) + max_events = tl.cast(gl.load(context_tensor + trace_info_base + 0), tl.int32) trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1) op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2) @@ -292,10 +291,8 @@ def initialize(context_tensor, tracing: gl.constexpr = False): buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32)) buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32)) - # Read trace_enabled flag from context tensor (at index 2 + num_ranks) - trace_enabled_val = gl.load(context_tensor + 2 + num_ranks) device_tracing = GluonDeviceTracing( - enabled=trace_enabled_val, + enabled=tracing, rank=cur_rank, max_events=max_events, counter=trace_counter, @@ -318,11 +315,9 @@ def initialize(context_tensor, tracing: gl.constexpr = False): # When tracing disabled, use dummy pointers (never dereferenced) dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) - # Read trace_enabled flag from context tensor (0 = disabled) - trace_enabled_val = gl.load(context_tensor + 2 + num_ranks) - max_events_zero = trace_enabled_val # 0 when tracing disabled + max_events_zero = tl.cast(0, tl.int32) device_tracing = GluonDeviceTracing( - enabled=trace_enabled_val, + enabled=tracing, rank=cur_rank, max_events=max_events_zero, counter=dummy_ptr_i32, @@ -370,6 +365,12 @@ def _translate(self, ptr, from_rank, to_rank): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) + # Optimization to vectorize the load/store - similar to iris.py + # This enables the compiler to generate dwordx4 or wider loads + # Note: Gluon uses scalar multiples, not 2D tuples like Triton + # ptr = gl.max_contiguous(gl.multiple_of(ptr, 64), 64) + # translated_ptr = gl.max_contiguous(gl.multiple_of(translated_ptr, 64), 64) + return translated_ptr @gluon.jit From 25a26945a9a0fbf679da59343aaaa50355b85d02 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:12:11 -0700 Subject: [PATCH 11/15] Address review comments: variant validation, test fixes, docstring - Reject non-persistent variants when use_gluon=True instead of silently ignoring all_gather_variant - Import GLUON_AVAILABLE from all_gather module in test to check both iris_gluon and triton.experimental.gluon availability - Reduce test heap size from 8GB to 1GB - Fix docstring to reflect flat-2D tile size constraint Co-Authored-By: Claude Opus 4.6 --- iris/ccl/all_gather.py | 7 +++++++ iris/ccl/config.py | 3 ++- tests/ccl/test_all_gather_gluon.py | 6 ++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index fedb25ba..2cacb5f9 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -499,6 +499,13 @@ def all_gather( if not hasattr(shmem, "get_device_context"): raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()") + # Gluon only supports the persistent variant + if config.all_gather_variant != "persistent": + raise ValueError( + f"Gluon all_gather only supports all_gather_variant='persistent', " + f"got '{config.all_gather_variant}'." + ) + # Apply optimal defaults for gluon flat-2D kernel when user hasn't # overridden block sizes from the Config defaults (32x64). block_size_m = config.block_size_m diff --git a/iris/ccl/config.py b/iris/ccl/config.py index 9dd35cef..1084de06 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -47,7 +47,8 @@ class Config: num_stages: Number of pipeline stages for the kernel (default: 1) num_warps: Number of warps per workgroup (default: 4). For gluon kernels, this also sets WARPS_PER_CTA in the BlockedLayout. The product - threads_per_warp * num_warps determines the minimum block_size_n. + threads_per_warp * num_warps determines the minimum tile size + (block_size_m * block_size_n for flat-2D, or block_size_n for 1D). threads_per_warp: Threads per warp/wavefront (default: 64). Must match the hardware wavefront size: 64 for AMD GPUs, 32 for NVIDIA. Used by gluon kernels to construct BlockedLayout for diff --git a/tests/ccl/test_all_gather_gluon.py b/tests/ccl/test_all_gather_gluon.py index a5b65abd..55410b6a 100644 --- a/tests/ccl/test_all_gather_gluon.py +++ b/tests/ccl/test_all_gather_gluon.py @@ -13,9 +13,7 @@ try: import iris.experimental.iris_gluon as iris_gluon from iris.ccl import Config - from iris.ccl.all_gather import all_gather - - GLUON_AVAILABLE = True + from iris.ccl.all_gather import all_gather, GLUON_AVAILABLE except ImportError: GLUON_AVAILABLE = False @@ -46,7 +44,7 @@ def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") - heap_size = 2**33 # 8GB + heap_size = 2**30 # 1GB shmem = iris_gluon.iris(heap_size) rank = shmem.get_rank() world_size = shmem.get_num_ranks() From 3900b4b934aac0827b5fc4e41a2d92789115824e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Mar 2026 23:13:32 +0000 Subject: [PATCH 12/15] Apply Ruff auto-fixes --- iris/ccl/all_gather.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 2cacb5f9..6b6c53db 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -502,8 +502,7 @@ def all_gather( # Gluon only supports the persistent variant if config.all_gather_variant != "persistent": raise ValueError( - f"Gluon all_gather only supports all_gather_variant='persistent', " - f"got '{config.all_gather_variant}'." + f"Gluon all_gather only supports all_gather_variant='persistent', got '{config.all_gather_variant}'." ) # Apply optimal defaults for gluon flat-2D kernel when user hasn't From 9c326a01fbaa808e78dba6c5932f9bf04bbccb79 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:18:02 -0700 Subject: [PATCH 13/15] Add --use_gluon flag to ccl all_gather example Switches to iris_gluon context and passes use_gluon=True to Config so the example can exercise both Triton and Gluon kernel backends. Co-Authored-By: Claude Opus 4.6 --- examples/25_ccl_all_gather/example.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/25_ccl_all_gather/example.py b/examples/25_ccl_all_gather/example.py index 18266d90..8365163d 100644 --- a/examples/25_ccl_all_gather/example.py +++ b/examples/25_ccl_all_gather/example.py @@ -8,7 +8,7 @@ Each rank contributes an (M, N) tensor; every rank receives the concatenated (world_size*M, N) result. Run with: - torchrun --nproc_per_node= --standalone example.py [--validate] + torchrun --nproc_per_node= --standalone example.py [--validate] [--use_gluon] """ import argparse @@ -37,6 +37,7 @@ def parse_args(): parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") parser.add_argument("--num_warps", type=int, default=4, help="Number of warps") parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon kernel backend") return vars(parser.parse_args()) @@ -47,7 +48,11 @@ def main(): torch.cuda.set_device(local_rank) dist.init_process_group(backend="gloo") - ctx = iris.iris(heap_size=args["heap_size"]) + if args["use_gluon"]: + import iris.experimental.iris_gluon as iris_gluon + ctx = iris_gluon.iris(heap_size=args["heap_size"]) + else: + ctx = iris.iris(heap_size=args["heap_size"]) rank = ctx.get_rank() world_size = ctx.get_num_ranks() @@ -67,6 +72,7 @@ def main(): "num_stages": args["num_stages"], "num_warps": args["num_warps"], "waves_per_eu": args["waves_per_eu"], + "use_gluon": args["use_gluon"], } config = Config(**config_kwargs) From e8c2bdbbfea7b517e592667f44dca5bdb0582c52 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 24 Mar 2026 23:19:08 +0000 Subject: [PATCH 14/15] Apply Ruff auto-fixes --- examples/25_ccl_all_gather/example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/25_ccl_all_gather/example.py b/examples/25_ccl_all_gather/example.py index 8365163d..ec5aa8fb 100644 --- a/examples/25_ccl_all_gather/example.py +++ b/examples/25_ccl_all_gather/example.py @@ -50,6 +50,7 @@ def main(): if args["use_gluon"]: import iris.experimental.iris_gluon as iris_gluon + ctx = iris_gluon.iris(heap_size=args["heap_size"]) else: ctx = iris.iris(heap_size=args["heap_size"]) From a4a73451aa7ab7113566babace00c4c662c94b38 Mon Sep 17 00:00:00 2001 From: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com> Date: Wed, 25 Mar 2026 00:01:21 -0700 Subject: [PATCH 15/15] Fix gluon all_gather test OOM: size heap based on test parameters The 8192x8192 test case with 8 ranks needs ~2.3GB per rank for float32. The previous 1GB fixed heap was too small. Now computes heap size from M, N, dtype, and world_size with 2x headroom (minimum 1GB). Co-Authored-By: Claude Opus 4.6 --- tests/ccl/test_all_gather_gluon.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/ccl/test_all_gather_gluon.py b/tests/ccl/test_all_gather_gluon.py index 55410b6a..a912bfe1 100644 --- a/tests/ccl/test_all_gather_gluon.py +++ b/tests/ccl/test_all_gather_gluon.py @@ -5,6 +5,8 @@ Test suite for all-gather collective operation using Gluon. """ +import os + import pytest import torch import torch.distributed as dist @@ -44,7 +46,11 @@ def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") - heap_size = 2**30 # 1GB + # Size heap to fit input (M*N) + output (max_ranks*M*N) with headroom + max_ranks = int(os.environ.get("WORLD_SIZE", 8)) + elem_size = torch.tensor([], dtype=dtype).element_size() + needed = (1 + max_ranks) * M * N * elem_size + heap_size = max(2**30, int(needed * 2)) # 2x headroom, minimum 1GB shmem = iris_gluon.iris(heap_size) rank = shmem.get_rank() world_size = shmem.get_num_ranks()