Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/25_ccl_all_gather/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Each rank contributes an (M, N) tensor; every rank receives the concatenated (world_size*M, N) result.

Run with:
torchrun --nproc_per_node=<num_gpus> --standalone example.py [--validate]
torchrun --nproc_per_node=<num_gpus> --standalone example.py [--validate] [--use_gluon]
"""

import argparse
Expand Down Expand Up @@ -37,6 +37,7 @@ def parse_args():
parser.add_argument("--num_stages", type=int, default=1, help="Number of stages")
parser.add_argument("--num_warps", type=int, default=4, help="Number of warps")
parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU")
parser.add_argument("--use_gluon", action="store_true", help="Use Gluon kernel backend")
return vars(parser.parse_args())


Expand All @@ -47,7 +48,12 @@ def main():
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="gloo")

ctx = iris.iris(heap_size=args["heap_size"])
if args["use_gluon"]:
import iris.experimental.iris_gluon as iris_gluon

ctx = iris_gluon.iris(heap_size=args["heap_size"])
else:
ctx = iris.iris(heap_size=args["heap_size"])
rank = ctx.get_rank()
world_size = ctx.get_num_ranks()

Expand All @@ -67,6 +73,7 @@ def main():
"num_stages": args["num_stages"],
"num_warps": args["num_warps"],
"waves_per_eu": args["waves_per_eu"],
"use_gluon": args["use_gluon"],
}
config = Config(**config_kwargs)

Expand Down
321 changes: 274 additions & 47 deletions iris/ccl/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
from .config import Config
from .utils import extract_group_info

# Conditional import for Gluon
try:
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from iris.experimental.iris_gluon import IrisDeviceCtx

GLUON_AVAILABLE = True
except ImportError:
GLUON_AVAILABLE = False


@triton.jit()
def persistent_all_gather(
Expand Down Expand Up @@ -279,6 +289,158 @@ def persistent_all_gather_partitioned(
)


# Gluon implementation: flat-2D tiling approach
#
# Uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N elements with
# div/mod to compute 2D row/col indices. This gives one load + world_size
# stores per tile (matching Triton's 2D load/store structure) while staying
# within gluon's 1D BlockedLayout framework.
#
# Key optimizations:
# - Flat-2D tiling: eliminates the inner BLOCK_SIZE_M row loop
# - Hoisted pointer translation: local_base loaded once outside tile loop
# - Traffic shaping: staggered write order avoids memory controller contention
if GLUON_AVAILABLE:

    @gluon.jit
    def persistent_all_gather_gluon(
        IrisDeviceCtx: gl.constexpr,
        context_tensor,
        input_ptr,
        output_ptr,
        M,
        N,
        stride_in_m,
        stride_in_n,
        stride_out_m,
        stride_out_n,
        group_rank: gl.constexpr,
        iris_rank: gl.constexpr,
        world_size: gl.constexpr,
        rank_start: gl.constexpr,
        rank_stride: gl.constexpr,
        BLOCK_SIZE_M: gl.constexpr,
        BLOCK_SIZE_N: gl.constexpr,
        GROUP_SIZE_M: gl.constexpr,
        COMM_SMS: gl.constexpr,
        THREADS_PER_WARP: gl.constexpr,
        WARPS_PER_CTA: gl.constexpr,
    ):
        """
        Persistent all-gather kernel using Gluon with flat-2D tiling.

        Uses a flat 1D index space of BLOCK_SIZE_M * BLOCK_SIZE_N elements,
        computing 2D row/col via integer div/mod. This produces one vectorized
        load and world_size vectorized stores per tile, matching Triton's 2D
        load/store instruction structure while staying within gluon's 1D
        BlockedLayout framework.

        Memory layout (BlockedLayout):
        A 1D BlockedLayout distributes TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N
        elements across the thread hierarchy:
        ELEMS_PER_THREAD = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA)

        Each thread handles ELEMS_PER_THREAD contiguous elements in the
        flattened row-major order. Row/col are recovered via:
        row = flat_idx // BLOCK_SIZE_N
        col = flat_idx % BLOCK_SIZE_N

        Constraints:
        - BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of
          (THREADS_PER_WARP * WARPS_PER_CTA).
        - Optimal tile: 2048-4096 total elements (8-16 per thread).
          Larger tiles cause register spilling and performance collapse.
        - Recommended: BLOCK_SIZE_M=8, BLOCK_SIZE_N=256 (2048 elems, 8/thread).

        Args:
            IrisDeviceCtx: Gluon device context class for remote memory operations.
            context_tensor: Opaque tensor holding IrisDeviceCtx state.
            input_ptr: Pointer to local input tensor of shape (M, N).
            output_ptr: Pointer to output tensor of shape (world_size * M, N).
            M: Number of rows in the input tensor (per rank).
            N: Number of columns.
            stride_in_m, stride_in_n: Row and column strides for input tensor.
            stride_out_m, stride_out_n: Row and column strides for output tensor.
            group_rank: This rank's index within the ProcessGroup (0..world_size-1).
            iris_rank: This rank's global index in the iris context (for RMA addressing).
            world_size: Total number of ranks in the group.
            rank_start: First iris rank in the group (for RMA target computation).
            rank_stride: Stride between consecutive iris ranks in the group.
            BLOCK_SIZE_M: Number of rows per tile.
            BLOCK_SIZE_N: Number of columns per tile.
            GROUP_SIZE_M: Swizzle group size for M-dimension tiling.
            COMM_SMS: Number of CUs used for persistent scheduling.
            THREADS_PER_WARP: Threads per warp/wavefront (64 for AMD, 32 for NVIDIA).
            WARPS_PER_CTA: Number of warps per workgroup. Must match num_warps.
        """
        ctx = IrisDeviceCtx.initialize(context_tensor, tracing=False)

        pid = gl.program_id(0)

        # Persistent scheduling: COMM_SMS programs stride over all tiles.
        num_pid_m = gl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = gl.cdiv(N, BLOCK_SIZE_N)
        total_tiles = num_pid_m * num_pid_n

        # Flat 1D layout covering BLOCK_SIZE_M * BLOCK_SIZE_N elements
        TOTAL_ELEMS: gl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_N
        ELEMS_PER_THREAD: gl.constexpr = TOTAL_ELEMS // (THREADS_PER_WARP * WARPS_PER_CTA)
        flat_layout: gl.constexpr = gl.BlockedLayout([ELEMS_PER_THREAD], [THREADS_PER_WARP], [WARPS_PER_CTA], [0])

        # Hoist local heap base outside the tile loop: eliminates redundant
        # gl.load(heap_bases) calls in the inner store loop.
        # NOTE(review): assumes ctx.heap_bases is a per-rank array of heap base
        # addresses indexed by global iris rank — confirm against IrisDeviceCtx.
        local_base = gl.load(ctx.heap_bases + iris_rank)

        for tile_id in range(pid, total_tiles, COMM_SMS):
            # Swizzled tile index computation for better L2 locality
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            # Flat index -> 2D row/col within tile
            flat_idx = gl.arange(0, TOTAL_ELEMS, layout=flat_layout)
            row_local = flat_idx // BLOCK_SIZE_N
            col_local = flat_idx % BLOCK_SIZE_N

            # Global row/col
            row = pid_m * BLOCK_SIZE_M + row_local
            col = pid_n * BLOCK_SIZE_N + col_local

            # Mask out-of-bounds elements on ragged edge tiles.
            mask = (row < M) & (col < N)

            # Single flat load of the entire tile
            input_offsets = row * stride_in_m + col * stride_in_n
            input_addr = input_ptr + input_offsets
            data = gl.load(input_addr, mask=mask, other=0.0)

            # Output: this rank's data goes to output[group_rank * M + row, col]
            output_row = group_rank * M + row
            output_offsets = output_row * stride_out_m + col * stride_out_n

            # Traffic-shaped stores to all ranks: stagger write order per rank
            # so each rank writes to a different target at any given moment,
            # avoiding memory controller contention on the receiver side.
            for rank_idx in range(world_size):
                dest_idx = (group_rank + rank_idx) % world_size
                target_iris_rank = rank_start + dest_idx * rank_stride
                output_ptrs = output_ptr + output_offsets

                if dest_idx == group_rank:
                    # Local copy into our own output slot.
                    # NOTE(review): ".wt" (write-through) cache modifier —
                    # presumably so the data is visible to peers reading this
                    # heap without an extra flush; confirm against the iris
                    # memory-consistency model.
                    gl.store(output_ptrs, data, mask=mask, cache_modifier=".wt")
                else:
                    # Hoisted translation: compute ptr_delta from pre-loaded
                    # local_base rather than calling ctx.store() which would
                    # do 2x gl.load(heap_bases) per call.
                    # Remote address = local address + (remote heap base -
                    # local heap base), computed via a uint64 round-trip since
                    # pointer arithmetic with a raw delta needs integer form.
                    target_base = gl.load(ctx.heap_bases + target_iris_rank)
                    ptr_delta = target_base - local_base
                    # NOTE(review): tl.cast inside a @gluon.jit kernel — this
                    # likely works because the builtin dispatches through the
                    # active (Gluon) semantic, but every other op here uses the
                    # gl namespace; confirm and prefer the gl equivalent for
                    # consistency.
                    output_ptrs_int = tl.cast(output_ptrs, gl.uint64)
                    remote_ptrs_int = output_ptrs_int + ptr_delta
                    remote_ptrs = tl.cast(remote_ptrs_int, output_ptrs.dtype)
                    gl.store(remote_ptrs, data, mask=mask)


def all_gather(
output_tensor,
input_tensor,
Expand Down Expand Up @@ -314,26 +476,11 @@ def all_gather(
if config is None:
config = Config(block_size_m=32, block_size_n=64)

# Check for unsupported options
if config.use_gluon:
raise ValueError(
"all_gather does not support use_gluon=True. "
"Gluon implementation is not available for all_gather. "
"Use default config (use_gluon=False)."
)

# Extract group information
# rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel
# rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)

# Validate COMM_SMS divisibility for partitioned variant
if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0:
raise ValueError(
f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). "
f"Please adjust config.comm_sms to be a multiple of {world_size}."
)

M, N = input_tensor.shape[:2]
expected_output_shape = (world_size * M, N)

Expand All @@ -346,41 +493,121 @@ def all_gather(
stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1)
stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1)

heap_bases = shmem.get_heap_bases()
# Choose between Triton and Gluon implementation
if config.use_gluon and GLUON_AVAILABLE:
# Check if shmem is Iris Gluon (has get_device_context method)
if not hasattr(shmem, "get_device_context"):
raise ValueError("use_gluon=True requires Iris Gluon context. Use iris.experimental.iris_gluon.iris()")

# Dispatch to the appropriate kernel based on variant
if config.all_gather_variant == "persistent":
kernel_fn = persistent_all_gather
elif config.all_gather_variant == "partitioned":
kernel_fn = persistent_all_gather_partitioned
# Gluon only supports the persistent variant
if config.all_gather_variant != "persistent":
raise ValueError(
f"Gluon all_gather only supports all_gather_variant='persistent', got '{config.all_gather_variant}'."
)

# Apply optimal defaults for gluon flat-2D kernel when user hasn't
# overridden block sizes from the Config defaults (32x64).
block_size_m = config.block_size_m
block_size_n = config.block_size_n
if block_size_m == 32 and block_size_n == 64:
# User didn't override — use optimal flat-2D tile: 8x256
block_size_m = 8
block_size_n = 256

# Validate flat-2D layout constraints.
# TOTAL_ELEMS = BLOCK_SIZE_M * BLOCK_SIZE_N must be a multiple of
# THREADS_PER_WARP * WARPS_PER_CTA so each thread gets a whole
# number of elements.
total_elems = block_size_m * block_size_n
threads_per_cta = config.threads_per_warp * config.num_warps
if total_elems < threads_per_cta:
raise ValueError(
f"Gluon all-gather requires block_size_m * block_size_n >= "
f"threads_per_warp * num_warps ({threads_per_cta}), "
f"got {block_size_m} * {block_size_n} = {total_elems}."
)
if total_elems % threads_per_cta != 0:
raise ValueError(
f"Gluon all-gather requires block_size_m * block_size_n to be a "
f"multiple of threads_per_warp * num_warps ({threads_per_cta}), "
f"got {block_size_m} * {block_size_n} = {total_elems}. "
f"Recommended: block_size_m=8, block_size_n=256."
)

context_tensor = shmem.get_device_context()

persistent_all_gather_gluon[(config.comm_sms,)](
IrisDeviceCtx,
context_tensor,
input_tensor,
output_tensor,
M,
N,
stride_in_m,
stride_in_n,
stride_out_m,
stride_out_n,
rank_in_group,
rank_global,
world_size,
rank_start,
rank_stride,
block_size_m,
block_size_n,
config.swizzle_size,
config.comm_sms,
config.threads_per_warp,
config.num_warps,
num_stages=config.num_stages,
num_warps=config.num_warps,
waves_per_eu=config.waves_per_eu,
)
else:
raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}")
if config.use_gluon and not GLUON_AVAILABLE:
raise ValueError("Gluon is not available. Install Triton with Gluon support or set use_gluon=False")

# Validate COMM_SMS divisibility for partitioned variant
if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0:
raise ValueError(
f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). "
f"Please adjust config.comm_sms to be a multiple of {world_size}."
)

kernel_fn[(config.comm_sms,)](
input_tensor,
output_tensor,
M,
N,
stride_in_m,
stride_in_n,
stride_out_m,
stride_out_n,
heap_bases,
rank_in_group,
rank_global,
world_size,
rank_start,
rank_stride,
config.block_size_m,
config.block_size_n,
config.swizzle_size,
config.comm_sms,
config.num_xcds,
config.chunk_size,
num_stages=config.num_stages,
num_warps=config.num_warps,
waves_per_eu=config.waves_per_eu,
)
heap_bases = shmem.get_heap_bases()

# Dispatch to the appropriate kernel based on variant
if config.all_gather_variant == "persistent":
kernel_fn = persistent_all_gather
elif config.all_gather_variant == "partitioned":
kernel_fn = persistent_all_gather_partitioned
else:
raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}")

kernel_fn[(config.comm_sms,)](
input_tensor,
output_tensor,
M,
N,
stride_in_m,
stride_in_n,
stride_out_m,
stride_out_n,
heap_bases,
rank_in_group,
rank_global,
world_size,
rank_start,
rank_stride,
config.block_size_m,
config.block_size_n,
config.swizzle_size,
config.comm_sms,
config.num_xcds,
config.chunk_size,
num_stages=config.num_stages,
num_warps=config.num_warps,
waves_per_eu=config.waves_per_eu,
)

if not async_op:
shmem.barrier()
Loading
Loading