Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion iris/ops/all_gather_matmul_hbm_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _hbm_buffer_all_gather_matmul_kernel(
k_block_start = k_flag_group * K_PER_FLAG

rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)

for k_off in range(K_PER_FLAG):
k_block_global = k_block_start + k_off
Expand All @@ -138,11 +139,12 @@ def _hbm_buffer_all_gather_matmul_kernel(
k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K)

rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K)
staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k

for compile_rank in range(world_size):
if src_rank_idx == compile_rank:
a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx)
a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K))
tl.store(staged_ptrs, a_tile, cache_modifier=".cg")

flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group
Expand Down
5 changes: 5 additions & 0 deletions iris/x/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def gather(
src_view: TensorView,
source_rank: tl.constexpr,
ctx: DeviceContext,
hint: tl.constexpr = None,
):
"""
Tile-level gather from a specific rank.
Expand All @@ -37,6 +38,9 @@ def gather(
src_view: TensorView for source tensor on source_rank.
source_rank: Specific rank to load from (constexpr).
ctx: DeviceContext with rank, world_size, and heap_bases.
hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on
the translated pointer. Use a scalar (e.g. 16) or a tuple
(e.g. (1, 16)) to indicate alignment. Defaults to None (no hint).

Returns:
Loaded tile data as a tensor.
Expand All @@ -61,6 +65,7 @@ def gather(
source_rank, # from_rank (source rank)
ctx.heap_bases,
mask=mask,
hint=hint,
)

return tile_data
202 changes: 202 additions & 0 deletions tests/ops/test_all_gather_matmul_hbm_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Tests for fused all_gather + matmul using the HBM staging buffer implementation.

Each rank has A_sharded (M x K_local), B is replicated.
The operation gathers A from all ranks into a local HBM buffer and computes C = A_gathered @ B.
"""

import pytest
import torch
import torch.distributed as dist

import iris
from iris.ops.all_gather_matmul_hbm_buffer import (
all_gather_matmul_hbm_buffer,
all_gather_matmul_hbm_buffer_preamble,
)
from iris.ops.config import FusedConfig


@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-2, 1e-2),
    ],
)
@pytest.mark.parametrize(
    "M,K_local,N",
    [
        (128, 32, 64),
        (256, 64, 128),
    ],
)
@pytest.mark.parametrize(
    "staged_a_layout",
    [
        "k_contiguous",
        "m_contiguous",
    ],
)
def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout):
    """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul."""
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")

    shmem = iris.iris(2**33)  # 8 GiB symmetric heap
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    device = f"cuda:{rank}"

    # Full K spans all shards.
    K = K_local * world_size

    # Per-rank seed: each rank owns a distinct, reproducible shard of A.
    torch.manual_seed(42 + rank)
    A_sharded = torch.randn(M, K_local, dtype=dtype, device=device)

    # Shared seed: B must be bitwise identical on every rank.
    torch.manual_seed(123)
    B = torch.randn(K, N, dtype=dtype, device=device)

    # Reference path: torch all_gather followed by a plain matmul.
    gathered = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)]
    dist.all_gather(gathered, A_sharded)
    A_gathered_ref = torch.cat(gathered, dim=1)  # (M, K)
    ref_output = torch.matmul(A_gathered_ref, B)
    torch.cuda.synchronize()

    # Mirror the inputs into symmetric-heap tensors for the fused op.
    A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype)
    B_shmem = shmem.zeros((K, N), dtype=dtype)
    output = shmem.zeros((M, N), dtype=dtype)
    A_sharded_shmem.copy_(A_sharded)
    B_shmem.copy_(B)

    shmem.barrier()

    # Small blocks so the tiny test shapes still cover multiple tiles.
    config = FusedConfig(
        block_size_m=64,
        block_size_n=64,
        block_size_k=32,
    )

    workspace = all_gather_matmul_hbm_buffer_preamble(
        shmem, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout
    )

    all_gather_matmul_hbm_buffer(
        shmem,
        output,
        A_sharded_shmem,
        B_shmem,
        config=config,
        workspace=workspace,
        staged_a_layout=staged_a_layout,
        trace=False,
    )

    torch.cuda.synchronize()
    shmem.barrier()

    max_diff = (output - ref_output).abs().max().item()

    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), (
        f"Rank {rank}: Max diff {max_diff}, expected < {atol} "
        f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})"
    )


@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float16, 1e-2, 1e-2),
    ],
)
@pytest.mark.parametrize(
    "M,K_local,N",
    [
        (128, 32, 64),
    ],
)
def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N):
    """Test all_gather_matmul_hbm_buffer with a bias vector."""
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")

    shmem = iris.iris(2**33)  # 8 GiB symmetric heap
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    device = f"cuda:{rank}"

    K = K_local * world_size

    # Distinct shard of A per rank; B and bias replicated via shared seeds.
    torch.manual_seed(42 + rank)
    A_sharded = torch.randn(M, K_local, dtype=dtype, device=device)

    torch.manual_seed(123)
    B = torch.randn(K, N, dtype=dtype, device=device)

    torch.manual_seed(77)
    bias = torch.randn(M, dtype=dtype, device=device)

    # Reference: torch all_gather + matmul, bias broadcast over columns.
    gathered = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)]
    dist.all_gather(gathered, A_sharded)
    A_gathered_ref = torch.cat(gathered, dim=1)
    ref_output = torch.matmul(A_gathered_ref, B) + bias[:, None]
    torch.cuda.synchronize()

    # Stage inputs on the symmetric heap.
    A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype)
    B_shmem = shmem.zeros((K, N), dtype=dtype)
    bias_shmem = shmem.zeros((M,), dtype=dtype)
    output = shmem.zeros((M, N), dtype=dtype)
    A_sharded_shmem.copy_(A_sharded)
    B_shmem.copy_(B)
    bias_shmem.copy_(bias)

    shmem.barrier()

    config = FusedConfig(
        block_size_m=64,
        block_size_n=64,
        block_size_k=32,
    )

    # No explicit workspace here: the op allocates its own.
    all_gather_matmul_hbm_buffer(
        shmem,
        output,
        A_sharded_shmem,
        B_shmem,
        bias=bias_shmem,
        config=config,
        trace=False,
    )

    torch.cuda.synchronize()
    shmem.barrier()

    max_diff = (output - ref_output).abs().max().item()

    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), (
        f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)"
    )


if __name__ == "__main__":
# For quick debugging
import sys

if not dist.is_initialized():
print("Run with: torchrun --nproc_per_node=2 tests/ops/test_all_gather_matmul_hbm_buffer.py")
sys.exit(1)

rank = dist.get_rank()
torch.cuda.set_device(rank)

print(f"[Rank {rank}] Testing all_gather_matmul_hbm_buffer...")
test_all_gather_matmul_hbm_buffer(torch.float16, 1e-2, 1e-2, 128, 32, 64, "k_contiguous")
print(f"[Rank {rank}] ✓ Test passed!")
Loading