Conversation
Gluon all-gather kernel that uses explicit BlockedLayout for column-dimension
vectorization. Each row is loaded once and broadcast to all ranks via
ctx.store(), avoiding redundant loads and enabling dwordx4 memory ops.
Key design:
- Row-by-row iteration: load row once, write to all ranks (1 load, W stores)
- Explicit BlockedLayout([SPT], [64], [4], [0]) on column dimension
where SPT = block_size_n / 256, controlling vector width:
SPT=1 -> scalar, SPT=2 -> dword, SPT=4 -> dwordx4 (optimal)
- Uses ctx.store() for remote writes (compiler-optimized pointer translation)
- Optimal tile: bm=32, bn=1024 (SPT=4) for dwordx4 on AMD GFX9+
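The SPT arithmetic above can be sketched in plain Python. This is an illustrative sketch, not the kernel code; it assumes 256 lanes along the column dimension (64 threads/warp x 4 warps, matching the BlockedLayout parameters listed above), and the SPT-to-instruction mapping is the one stated in this commit message:

```python
# Sketch: how the tile width determines scalars-per-thread (SPT),
# which in turn controls the vector width of each memory op.
# Assumption: 64 threads/warp x 4 warps = 256 lanes on the column dim.

THREADS_PER_WARP = 64  # AMD wavefront size
NUM_WARPS = 4


def scalars_per_thread(block_size_n: int) -> int:
    lanes = THREADS_PER_WARP * NUM_WARPS  # 256
    assert block_size_n % lanes == 0, "tile width must be a multiple of 256"
    return block_size_n // lanes


# Mapping stated above: SPT=1 -> scalar, SPT=2 -> dword, SPT=4 -> dwordx4
for bn in (256, 512, 1024):
    print(bn, scalars_per_thread(bn))
```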
Also includes:
- threads_per_warp config field for BlockedLayout construction
- Simplified IrisDeviceCtx (tracing removed for cleaner codegen)
- Parameterized correctness tests for multiple tile sizes and dtypes
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…hints

The gluon kernel was using ctx.store(), which calls _translate() on every remote store, causing two problems visible in the assembly:

1. global_store_short (2-byte scalar) instead of global_store_dwordx4 (16-byte), because _translate() pointer arithmetic breaks contiguity attributes
2. Two global_load_dwordx2 for heap_bases per remote rank per row (14 loads/row in the 8-rank case), because heap_bases[from_rank] and heap_bases[to_rank] are reloaded on every call

Fix: bypass ctx.store() and perform pointer translation inline:
- Hoist local_base = gl.load(heap_bases + iris_rank) before all loops
- Compute ptr_delta = target_base - local_base manually
- Re-apply gl.max_contiguous/gl.multiple_of to the translated pointer
- Use gl.store() directly, preserving vectorization hints for dwordx4

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
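The pointer-translation arithmetic being hoisted can be modeled with plain integers. A minimal sketch, not the device code: `heap_bases` here is a hypothetical list of per-rank symmetric-heap base addresses standing in for the real device-side array:

```python
# Sketch of inline symmetric-heap pointer translation: a local pointer
# maps to rank r's heap by shifting it by (heap_bases[r] - local_base).

heap_bases = [0x1000, 0x5000, 0x9000]  # hypothetical per-rank base addresses
iris_rank = 0
local_base = heap_bases[iris_rank]     # hoisted once, outside all loops


def translate(local_ptr: int, to_rank: int) -> int:
    ptr_delta = heap_bases[to_rank] - local_base  # one subtract per target rank
    return local_ptr + ptr_delta


# An address 0x20 bytes into rank 0's heap maps to the same offset
# inside rank 2's heap:
print(hex(translate(local_base + 0x20, to_rank=2)))
```

Because the delta is a single scalar added to the whole pointer tensor, re-applying the contiguity hints afterwards lets the compiler keep the wide stores.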
The row-by-row kernel iterated per-row with 1D loads, producing 32x more instructions than Triton's 2D tile loads and underperforming by 6-38%. The flat-2D kernel uses a single 1D arange over BLOCK_SIZE_M * BLOCK_SIZE_N elements with div/mod for 2D indexing, producing one load + world_size stores per tile (matching Triton's instruction structure).

Additional optimizations:
- Hoisted pointer translation: local_base is loaded once outside the tile loop
- Traffic-shaped writes: a staggered (group_rank + rank_idx) % world_size write order avoids memory-controller contention on the receiver side
- Auto-default tile size: 8x256 (2048 elements, 8 per thread) when the user doesn't override the Config defaults

Also adds GluonDeviceTracing to iris_gluon.py for optional device-side event recording (TRACING=False by default, zero overhead when disabled).

Benchmarked on MI308X (gfx942), 8192x8192 fp16, ROCm 7.2.0:
- 8 ranks, 32 CUs: Gluon 293 GB/s vs Triton 287 GB/s (+2%)
- 4 ranks, 48 CUs: Gluon 134 GB/s vs Triton 125 GB/s (+7%)
- 4 ranks, 32 CUs: Gluon 133 GB/s vs Triton 128 GB/s (+3%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
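The flat-2D indexing and the traffic-shaped write order can be sketched in ordinary Python (an illustration of the arithmetic, not the Gluon kernel itself):

```python
# Sketch of flat-2D tile indexing: one 1D index range over the whole
# tile, with row/col recovered by integer div/mod, plus the staggered
# per-rank write order described above.

BLOCK_SIZE_M, BLOCK_SIZE_N = 8, 256


def tile_coords(offsets, bn=BLOCK_SIZE_N):
    # recover 2D (row, col) from a flat index via div/mod
    return [(o // bn, o % bn) for o in offsets]


def write_order(group_rank: int, world_size: int):
    # staggered targets: rank r starts its writes at rank r, so at any
    # given step every rank is writing to a distinct receiver
    return [(group_rank + i) % world_size for i in range(world_size)]


print(tile_coords([0, 255, 256]))  # wraps to the next row at bn
print(write_order(2, 4))
```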
Self-contained benchmark script that compares RCCL (default channels), Iris Triton persistent, and Iris Gluon flat-2D all-gather across 6 tensor shapes (2 MB to 512 MB) and 5 CU counts (8-96). Designed to run under torchrun in a single invocation with formatted table output and optional CSV export. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pass n_warmup and n_repeat as function parameters instead of using a global statement, which caused a SyntaxError when the names appeared in argparse defaults before the global declaration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
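A minimal reproduction of the error class being fixed (a sketch with hypothetical names; the real script's argparse wiring is not shown): using a name in a function body before a `global` declaration for that same name is a compile-time SyntaxError, so the values are passed as plain parameters instead.

```python
import textwrap

# Broken pattern: the name is used before its `global` declaration.
bad = """
n_warmup = 10
def run(args_warmup=n_warmup):
    print(n_warmup)      # name used here...
    global n_warmup      # ...declared global afterwards -> SyntaxError
"""
try:
    compile(textwrap.dedent(bad), "<bench>", "exec")
    raised = False
except SyntaxError:
    raised = True
print("SyntaxError raised:", raised)


# The fix: plain function parameters, no global statement at all.
def run_benchmark(n_warmup: int = 10, n_repeat: int = 100) -> tuple:
    return n_warmup, n_repeat
```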
The symmetric heap uses a bump allocator with no free operation. Allocating input/output tensors per iteration exhausted the heap at larger shapes. Now all buffers for every shape are allocated once, before the benchmark loop starts (~8.5 GB per shmem context). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
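The failure mode can be sketched with a toy bump allocator (an illustration of the behavior, not Iris's allocator): the cursor only moves forward, so repeated per-iteration allocation eventually exhausts capacity, while a one-time preallocation reused across iterations stays within bounds.

```python
# Toy bump allocator: alloc() advances a cursor and there is no free(),
# mirroring why per-iteration allocation exhausts a symmetric heap.

class BumpAllocator:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.offset = 0

    def alloc(self, nbytes: int) -> int:
        if self.offset + nbytes > self.capacity:
            raise MemoryError("symmetric heap exhausted")
        start = self.offset
        self.offset += nbytes  # bump forward; never reclaimed
        return start


heap = BumpAllocator(capacity=1 << 20)
# Fix: preallocate every buffer once, up front, then reuse each iteration.
buffers = [heap.alloc(1 << 16) for _ in range(8)]
print(len(buffers), heap.offset)
```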
Drop benchmark_shapes.py (sweep script), remove tracing instrumentation from the gluon all-gather kernel, and revert iris_gluon.py tracing class changes back to main. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
Adds a Gluon-based flat-2D tiled all-gather kernel (with traffic-shaped writes) as an alternative to the existing Triton implementation, updates Config to support Gluon layout construction, and introduces a correctness test comparing against PyTorch’s all_gather_into_tensor.
Changes:
- Add a Gluon persistent all-gather kernel using a flat-2D tiling strategy and RMA stores.
- Extend `Config` with `threads_per_warp` validation to parameterize Gluon's `BlockedLayout`.
- Add a distributed test validating Gluon all-gather correctness against PyTorch.
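The validation described above might look roughly like the following. This is a hypothetical sketch, not the actual `iris/ccl/config.py`; the field name comes from the review, but the allowed values (32 for NVIDIA warps, 64 for AMD wavefronts) and the dataclass shape are assumptions:

```python
from dataclasses import dataclass


@dataclass
class Config:
    use_gluon: bool = False
    threads_per_warp: int = 64  # assumption: 64 = AMD wavefront, 32 = NVIDIA warp

    def __post_init__(self):
        # BlockedLayout construction needs a valid hardware warp width.
        if self.threads_per_warp not in (32, 64):
            raise ValueError(
                f"threads_per_warp must be 32 or 64, got {self.threads_per_warp}"
            )


cfg = Config(use_gluon=True, threads_per_warp=64)  # accepted
print(cfg.threads_per_warp)
```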
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/ccl/test_all_gather_gluon.py | Adds a correctness test for Gluon all-gather vs PyTorch reference. |
| iris/ccl/config.py | Adds threads_per_warp configuration + validation and documents Gluon-related tuning knobs. |
| iris/ccl/all_gather.py | Adds Gluon kernel + dispatch path behind Config(use_gluon=True) and input validation for tiling constraints. |
@copilot Run benchmarks comparing all three backends (RCCL, Iris Triton, Iris Gluon) and produce throughput (GB/s) vs CU count plots. Do NOT commit any scripts or files to the repo. You have AMD MI325X GPUs and PyTorch with ROCm available.

Shapes to test (fp16, 8 GPUs):
CU counts: 4, 8, 16, 24, 32, 48, 64, 80, 96, 112, 128

How to benchmark each backend at each CU count:

Output: 3 plots (one per shape), each with 3 lines (RCCL, Triton, Gluon), X-axis = CU count, Y-axis = bandwidth (GB/s). Reply with the plots inline. Do not commit anything.
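The requested GB/s figure can be computed with one common accounting for all-gather. A hedged sketch (an assumption, since the benchmark script is not shown here): bandwidth is taken as bytes of the fully gathered output divided by elapsed time; other conventions subtract the local shard.

```python
# Sketch: one common all-gather bandwidth accounting.
# Assumption: GB/s = bytes of the gathered output / elapsed seconds.

def allgather_gbps(m: int, n: int, world_size: int,
                   elem_bytes: int, seconds: float) -> float:
    output_bytes = m * n * world_size * elem_bytes  # gathered tensor size
    return output_bytes / seconds / 1e9


# Example: an 8192x8192 fp16 shard gathered across 8 ranks in 3.66 ms.
print(round(allgather_gbps(8192, 8192, 8, 2, 3.66e-3), 1))
```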
@mawad-amd I've opened a new pull request, #476, to work on those changes. Once the pull request is ready, I'll request review from you. |
- Reject non-persistent variants when use_gluon=True instead of silently ignoring all_gather_variant
- Import GLUON_AVAILABLE from the all_gather module in the test to check both iris_gluon and triton.experimental.gluon availability
- Reduce test heap size from 8 GB to 1 GB
- Fix docstring to reflect the flat-2D tile size constraint

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Switches to iris_gluon context and passes use_gluon=True to Config so the example can exercise both Triton and Gluon kernel backends. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- `Config` with `threads_per_warp` field for Gluon's `BlockedLayout` construction
- Correctness test against `all_gather_into_tensor`

Design
The kernel uses a 1D `BlockedLayout` over `BLOCK_SIZE_M * BLOCK_SIZE_N` elements, recovering 2D row/col via integer div/mod. This produces one vectorized load and `world_size` vectorized stores per tile while staying within Gluon's 1D layout framework.

Key optimizations:

- `local_base` loaded once outside the tile loop (avoids 2x `gl.load(heap_bases)` per `ctx.store()` call)
- Staggered write order `(group_rank + rank_idx) % world_size` so ranks write to different targets simultaneously, reducing memory-controller contention

Dispatched via `Config(use_gluon=True)` through the existing `shmem.ccl.all_gather()` path.

Benchmark results (MI325X, 8 GPUs, fp16)
Gluon matches RCCL at large shapes and outperforms Triton at small shapes (80% vs 66% of RCCL at 2 MB).
Test plan
- `test_all_gather_gluon.py` (run under `torchrun` with pytest) validates correctness across fp16/fp32/bf16 and three tile sizes

🤖 Generated with Claude Code