From a86dc0400de2fff3e8b4d27cd637b2c41fc5dffc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 11 Mar 2026 16:11:23 +0000 Subject: [PATCH 1/2] Initial plan From 445b25cb095078a51c3db9fb5cbf5c5b6eec80ee Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 12 Mar 2026 19:48:25 +0000 Subject: [PATCH 2/2] Add vectorization hints and tests for HBM buffer all-gather matmul Co-authored-by: ryanswann-amd <109695074+ryanswann-amd@users.noreply.github.com> --- iris/ops/all_gather_matmul_hbm_buffer.py | 4 +- iris/x/gather.py | 5 + .../ops/test_all_gather_matmul_hbm_buffer.py | 202 ++++++++++++++++++ 3 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 tests/ops/test_all_gather_matmul_hbm_buffer.py diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py index abe3b393..2db1b6ed 100644 --- a/iris/ops/all_gather_matmul_hbm_buffer.py +++ b/iris/ops/all_gather_matmul_hbm_buffer.py @@ -126,6 +126,7 @@ def _hbm_buffer_all_gather_matmul_kernel( k_block_start = k_flag_group * K_PER_FLAG rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) for k_off in range(K_PER_FLAG): k_block_global = k_block_start + k_off @@ -138,11 +139,12 @@ def _hbm_buffer_all_gather_matmul_kernel( k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k for compile_rank in range(world_size): if src_rank_idx == compile_rank: - a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx) + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K)) tl.store(staged_ptrs, a_tile, cache_modifier=".cg") flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9..4e2b10cc 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -24,6 +24,7 @@ def gather( src_view: TensorView, source_rank: tl.constexpr, ctx: DeviceContext, + hint: tl.constexpr = None, ): """ Tile-level gather from a specific rank. @@ -37,6 +38,9 @@ def gather( src_view: TensorView for source tensor on source_rank. source_rank: Specific rank to load from (constexpr). ctx: DeviceContext with rank, world_size, and heap_bases. + hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on + the translated pointer. Use a scalar (e.g. 16) or a tuple + (e.g. (1, 16)) to indicate alignment. Defaults to None (no hint). Returns: Loaded tile data as a tensor. @@ -61,6 +65,7 @@ def gather( source_rank, # from_rank (source rank) ctx.heap_bases, mask=mask, + hint=hint, ) return tile_data diff --git a/tests/ops/test_all_gather_matmul_hbm_buffer.py b/tests/ops/test_all_gather_matmul_hbm_buffer.py new file mode 100644 index 00000000..af173ea8 --- /dev/null +++ b/tests/ops/test_all_gather_matmul_hbm_buffer.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for fused all_gather + matmul using the HBM staging buffer implementation. + +Each rank has A_sharded (M x K_local), B is replicated. +The operation gathers A from all ranks into a local HBM buffer and computes C = A_gathered @ B. +""" + +import pytest +import torch +import torch.distributed as dist + +import iris +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops.config import FusedConfig + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + (256, 64, 128), + ], +) +@pytest.mark.parametrize( + "staged_a_layout", + [ + "k_contiguous", + "m_contiguous", + ], +) +def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout): + """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + K = K_local * world_size # Full K dimension + + # Seed for reproducibility - different seed per rank for A_sharded + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + + # B must be identical on all ranks + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + + # Reference: torch all_gather + matmul + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) # (M, K) + ref_output = torch.matmul(A_gathered_ref, B) + torch.cuda.synchronize() + + # Create shmem tensors + A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + output = shmem.zeros((M, N), dtype=dtype) + + shmem.barrier() + + # Use small block sizes for small test problems + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + ) + + workspace = all_gather_matmul_hbm_buffer_preamble( + shmem, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout + ) + + all_gather_matmul_hbm_buffer( + shmem, + output, + A_sharded_shmem, + B_shmem, + config=config, + workspace=workspace, + staged_a_layout=staged_a_layout, + trace=False, + ) + + torch.cuda.synchronize() + shmem.barrier() + + max_diff = (output - ref_output).abs().max().item() + + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} " + f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})" + ) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + ], +) +def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N): + """Test all_gather_matmul_hbm_buffer with a bias vector.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + + torch.manual_seed(77) + bias = torch.randn(M, dtype=dtype, device=f"cuda:{rank}") + + # Reference: torch all_gather + matmul + bias + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) + ref_output = torch.matmul(A_gathered_ref, B) + bias[:, None] + torch.cuda.synchronize() + + # Create shmem tensors + A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + bias_shmem = shmem.zeros((M,), dtype=dtype) + bias_shmem.copy_(bias) + output = shmem.zeros((M, N), dtype=dtype) + + shmem.barrier() + + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + ) + + all_gather_matmul_hbm_buffer( + shmem, + output, + A_sharded_shmem, + B_shmem, + bias=bias_shmem, + config=config, + trace=False, + ) + + torch.cuda.synchronize() + shmem.barrier() + + max_diff = (output - ref_output).abs().max().item() + + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)" + ) + + +if __name__ == "__main__": + # For quick debugging + import sys + + if not dist.is_initialized(): + print("Run with: torchrun --nproc_per_node=2 tests/ops/test_all_gather_matmul_hbm_buffer.py") + sys.exit(1) + + rank = dist.get_rank() + torch.cuda.set_device(rank) + + print(f"[Rank {rank}] Testing all_gather_matmul_hbm_buffer...") + test_all_gather_matmul_hbm_buffer(torch.float16, 1e-2, 1e-2, 128, 32, 64, "k_contiguous") + print(f"[Rank {rank}] ✓ Test passed!")