
Add Gluon flat-2D all-gather kernel #475

Open
mawad-amd wants to merge 16 commits into main from muhaawad/gluon-all-gather-v2

Conversation

@mawad-amd
Collaborator

Summary

  • Add a Gluon-based all-gather kernel using flat-2D tiling with traffic-shaped writes, as an alternative to the existing Triton persistent kernel
  • Extend Config with threads_per_warp field for Gluon's BlockedLayout construction
  • Add test suite comparing Gluon all-gather output against PyTorch's all_gather_into_tensor

Design

The kernel uses a 1D BlockedLayout over BLOCK_SIZE_M * BLOCK_SIZE_N elements, recovering 2D row/col via integer div/mod. This produces one vectorized load and world_size vectorized stores per tile while staying within Gluon's 1D layout framework.
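The flat-2D indexing described above can be sketched in plain Python (shapes and the row stride `N` are made-up values, not from the kernel):

```python
# Sketch: recover 2D row/col indices from a single flat 1D range,
# mirroring the kernel's flat-2D tiling via integer div/mod.
BLOCK_SIZE_M, BLOCK_SIZE_N = 8, 256
N = 4096  # hypothetical row stride of the output tensor

flat = list(range(BLOCK_SIZE_M * BLOCK_SIZE_N))   # one 1D arange per tile
rows = [i // BLOCK_SIZE_N for i in flat]          # integer div -> row
cols = [i % BLOCK_SIZE_N for i in flat]           # mod -> col
offsets = [r * N + c for r, c in zip(rows, cols)] # linear addresses

# Within each row the addresses are contiguous, which is what allows the
# compiler to emit one vectorized load/store per tile.
assert offsets[:3] == [0, 1, 2]
assert offsets[BLOCK_SIZE_N] == N  # next row starts one stride later
```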

Key optimizations:

  • Flat-2D tiling: eliminates the inner row loop from the earlier row-by-row approach
  • Hoisted pointer translation: local_base loaded once outside the tile loop (avoids 2x gl.load(heap_bases) per ctx.store() call)
  • Traffic shaping: staggered write order (group_rank + rank_idx) % world_size so ranks write to different targets simultaneously, reducing memory controller contention
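The traffic-shaping property of the staggered write order can be checked with a few lines of Python (a minimal model, not kernel code):

```python
# Sketch: with target = (group_rank + rank_idx) % world_size, at every
# loop step the per-rank targets form a permutation of the ranks, so no
# two ranks write to the same destination simultaneously.
world_size = 8
for rank_idx in range(world_size):
    targets = [(group_rank + rank_idx) % world_size
               for group_rank in range(world_size)]
    assert sorted(targets) == list(range(world_size))
```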

Dispatched via Config(use_gluon=True) through the existing shmem.ccl.all_gather() path.

Benchmark results (MI325X, 8 GPUs, fp16)

| Shape | RCCL | Triton (best CUs) | Gluon (best CUs) |
|---|---|---|---|
| 1024x1024 (2 MB) | 138 GB/s | 91 GB/s (96) | 110 GB/s (96) |
| 4096x4096 (32 MB) | 272 GB/s | 249 GB/s (96) | 266 GB/s (64) |
| 8192x8192 (128 MB) | 286 GB/s | 289 GB/s (96) | 285 GB/s (96) |
| 16384x16384 (512 MB) | 293 GB/s | 295 GB/s (96) | 292 GB/s (96) |

Gluon matches RCCL at large shapes and outperforms Triton at small shapes (80% vs 66% of RCCL at 2 MB).

Test plan

  • CI passes on MI325X (8-GPU torchrun with pytest)
  • test_all_gather_gluon.py validates correctness across fp16/fp32/bf16 and three tile sizes
  • Verify no regression on existing Triton all-gather tests

🤖 Generated with Claude Code

mawad-amd and others added 10 commits March 23, 2026 11:58
…ation

Gluon all-gather kernel that uses explicit BlockedLayout for column-dimension
vectorization. Each row is loaded once and broadcast to all ranks via
ctx.store(), avoiding redundant loads and enabling dwordx4 memory ops.

Key design:
- Row-by-row iteration: load row once, write to all ranks (1 load, W stores)
- Explicit BlockedLayout([SPT], [64], [4], [0]) on column dimension
  where SPT = block_size_n / 256, controlling vector width:
    SPT=1 -> scalar, SPT=2 -> dword, SPT=4 -> dwordx4 (optimal)
- Uses ctx.store() for remote writes (compiler-optimized pointer translation)
- Optimal tile: bm=32, bn=1024 (SPT=4) for dwordx4 on AMD GFX9+
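The SPT-to-vector-width mapping from the list above can be written out as a small helper (the function name and the 256-lane assumption are illustrative, taken from the SPT = block_size_n / 256 relation in this message):

```python
# Sketch: size-per-thread (SPT) on the column dimension, and the store
# width the commit message associates with each SPT value.
def spt(block_size_n, lanes=256):
    # elements of the column dimension handled per thread
    return block_size_n // lanes

WIDTH = {1: "scalar", 2: "dword", 4: "dwordx4"}

assert spt(256) == 1 and WIDTH[spt(256)] == "scalar"
assert spt(512) == 2 and WIDTH[spt(512)] == "dword"
assert spt(1024) == 4 and WIDTH[spt(1024)] == "dwordx4"  # bn=1024 optimal
```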

Also includes:
- threads_per_warp config field for BlockedLayout construction
- Simplified IrisDeviceCtx (tracing removed for cleaner codegen)
- Parameterized correctness tests for multiple tile sizes and dtypes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…hints

The gluon kernel was using ctx.store() which calls _translate() on every
remote store, causing two problems visible in the assembly:

1. global_store_short (2-byte scalar) instead of global_store_dwordx4 (16-byte)
   because _translate() pointer arithmetic breaks contiguity attributes
2. Two global_load_dwordx2 for heap_bases per remote rank per row (14 loads/row
   in 8-rank case) because heap_bases[from_rank] and heap_bases[to_rank] are
   reloaded every call

Fix: bypass ctx.store() and perform pointer translation inline:
- Hoist local_base = gl.load(heap_bases + iris_rank) before all loops
- Compute ptr_delta = target_base - local_base manually
- Re-apply gl.max_contiguous/gl.multiple_of to translated pointer
- Use gl.store() directly, preserving vectorization hints for dwordx4
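The hoisted pointer-translation arithmetic amounts to one subtraction computed once plus one addition per store; a toy model (made-up addresses, not real heap bases):

```python
# Sketch: heap_bases[r] is the base address of rank r's symmetric heap.
heap_bases = [0x1000, 0x9000, 0x11000]  # hypothetical bases for 3 ranks
iris_rank = 0
local_base = heap_bases[iris_rank]       # hoisted: loaded once, before all loops

def translate(local_ptr, to_rank):
    # Same offset within the heap, different base: the delta is reused
    # for every element in the tile, so it is one add per store.
    ptr_delta = heap_bases[to_rank] - local_base
    return local_ptr + ptr_delta

p = 0x1040  # a pointer into rank 0's heap at offset 0x40
assert translate(p, 1) == 0x9040
assert translate(p, 2) == 0x11040
```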

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The row-by-row kernel iterated per-row with 1D loads, producing 32x more
instructions than Triton's 2D tile loads and underperforming by 6-38%.

The flat-2D kernel uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N
elements with div/mod for 2D indexing, producing one load + world_size
stores per tile (matching Triton's instruction structure). Additional
optimizations:

- Hoisted pointer translation: local_base loaded once outside tile loop
- Traffic-shaped writes: staggered (group_rank + rank_idx) % world_size
  write order avoids memory controller contention on the receiver side
- Auto-default tile size: 8x256 (2048 elements, 8/thread) when user
  doesn't override Config defaults

Also adds GluonDeviceTracing to iris_gluon.py for optional device-side
event recording (TRACING=False by default, zero overhead when disabled).

Benchmarked on MI308X (gfx942), 8192x8192 fp16, ROCm 7.2.0:

  8 ranks, 32 CUs: Gluon 293 GB/s vs Triton 287 GB/s (+2%)
  4 ranks, 48 CUs: Gluon 134 GB/s vs Triton 125 GB/s (+7%)
  4 ranks, 32 CUs: Gluon 133 GB/s vs Triton 128 GB/s (+3%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Self-contained benchmark script that compares RCCL (default channels),
Iris Triton persistent, and Iris Gluon flat-2D all-gather across 6
tensor shapes (2 MB to 512 MB) and 5 CU counts (8-96). Designed to
run under torchrun in a single invocation with formatted table output
and optional CSV export.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pass n_warmup and n_repeat as function parameters instead of using a
global statement, which caused a SyntaxError when the names appeared
in argparse defaults before the global declaration.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The symmetric heap uses a bump allocator with no free. Allocating
input/output tensors per-iteration exhausted the heap at larger shapes.
Now all buffers for every shape are allocated once before the benchmark
loop starts (~8.5 GB per shmem context).
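A minimal model of why per-iteration allocation exhausts the heap (a hypothetical bump allocator, not the actual Iris implementation):

```python
# Sketch: a bump allocator only advances an offset and never reclaims,
# so allocating fresh buffers every iteration eventually runs out.
class BumpHeap:
    def __init__(self, size):
        self.size, self.offset = size, 0

    def alloc(self, nbytes):
        if self.offset + nbytes > self.size:
            raise MemoryError("symmetric heap exhausted")
        ptr, self.offset = self.offset, self.offset + nbytes
        return ptr

heap = BumpHeap(1 << 20)
heap.alloc(512 << 10)  # first buffer: ok
heap.alloc(512 << 10)  # second buffer: ok, heap now full
try:
    heap.alloc(1)      # third allocation fails: nothing was freed
    exhausted = False
except MemoryError:
    exhausted = True
assert exhausted
```

Allocating all buffers up front sidesteps this by making the total footprint known before the benchmark loop runs.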

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Drop benchmark_shapes.py (sweep script), remove tracing
instrumentation from the gluon all-gather kernel, and revert
iris_gluon.py tracing class changes back to main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings March 24, 2026 00:14
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners March 24, 2026 00:14
@github-actions github-actions bot added in-progress We are working on it iris Iris project issue labels Mar 24, 2026
Contributor

Copilot AI left a comment


Pull request overview

Adds a Gluon-based flat-2D tiled all-gather kernel (with traffic-shaped writes) as an alternative to the existing Triton implementation, updates Config to support Gluon layout construction, and introduces a correctness test comparing against PyTorch’s all_gather_into_tensor.

Changes:

  • Add a Gluon persistent all-gather kernel using a flat-2D tiling strategy and RMA stores.
  • Extend Config with threads_per_warp validation to parameterize Gluon BlockedLayout.
  • Add a distributed test validating Gluon all-gather correctness against PyTorch.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| tests/ccl/test_all_gather_gluon.py | Adds a correctness test for Gluon all-gather vs PyTorch reference. |
| iris/ccl/config.py | Adds threads_per_warp configuration + validation and documents Gluon-related tuning knobs. |
| iris/ccl/all_gather.py | Adds Gluon kernel + dispatch path behind Config(use_gluon=True) and input validation for tiling constraints. |

@mawad-amd
Collaborator Author

@copilot Run benchmarks comparing all three backends (RCCL, Iris Triton, Iris Gluon) and produce throughput (GB/s) vs CU count plots. Do NOT commit any scripts or files to the repo.

You have AMD MI325X GPUs and PyTorch with ROCm available. Run pip install -e . from this branch first.

Shapes to test (fp16, 8 GPUs):

  • Small: 2048x4096 (16 MB)
  • Medium: 8192x8192 (128 MB)
  • Large: 16384x16384 (512 MB)

CU counts: 4, 8, 16, 24, 32, 48, 64, 80, 96, 112, 128

How to benchmark each backend at each CU count:

  1. Iris Triton: shmem.ccl.all_gather(out, inp, config=Config(comm_sms=CU_COUNT)) — uses iris.iris() context
  2. Iris Gluon: shmem.ccl.all_gather(out, inp, config=Config(comm_sms=CU_COUNT, use_gluon=True)) — uses iris_gluon.iris() context
  3. RCCL: dist.all_gather_into_tensor(out, inp) with NCCL_MIN_NCHANNELS=CU_COUNT and NCCL_MAX_NCHANNELS=CU_COUNT set before launch. In RCCL, channels = CTAs = CUs (1:1). Each CU count needs a separate torchrun invocation since these are env vars read at init.

Use iris.do_bench() for timing. Bandwidth: (world_size - 1) * M * N * 2 / time_seconds / 1024^3 GB/s.
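The bandwidth formula above, as a small helper (the example timing value is made up, purely to exercise the arithmetic):

```python
# Sketch: all-gather bandwidth for fp16 (2 bytes/element). Each rank
# receives (world_size - 1) tiles of M*N elements from its peers.
def all_gather_gbps(world_size, M, N, time_seconds):
    traffic_bytes = (world_size - 1) * M * N * 2
    return traffic_bytes / time_seconds / 1024**3

# 8 ranks, 8192x8192 fp16, hypothetical 3.0 ms measurement
bw = all_gather_gbps(8, 8192, 8192, 3.0e-3)
assert round(bw) == 292  # ~292 GB/s, in the range of the tables above
```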

Output: 3 plots (one per shape), each with 3 lines (RCCL, Triton, Gluon), X-axis = CU count, Y-axis = bandwidth (GB/s). Reply with the plots inline. Do not commit anything.

Contributor

Copilot AI commented Mar 24, 2026

@mawad-amd I've opened a new pull request, #476, to work on those changes. Once the pull request is ready, I'll request review from you.

mawad-amd and others added 6 commits March 24, 2026 16:09
- Reject non-persistent variants when use_gluon=True instead of silently
  ignoring all_gather_variant
- Import GLUON_AVAILABLE from all_gather module in test to check both
  iris_gluon and triton.experimental.gluon availability
- Reduce test heap size from 8GB to 1GB
- Fix docstring to reflect flat-2D tile size constraint

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Switches to iris_gluon context and passes use_gluon=True to Config
so the example can exercise both Triton and Gluon kernel backends.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>