Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions iris/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
- all_gather_matmul: All-Gather + GEMM
- matmul_all_gather: GEMM + All-Gather
- matmul_reduce_scatter: GEMM + Reduce-Scatter
- matmul_all_scatter: GEMM + All-Scatter
"""

from .config import FusedConfig
Expand All @@ -38,6 +39,7 @@
from .all_gather_matmul import all_gather_matmul, all_gather_matmul_preamble
from .matmul_all_gather import matmul_all_gather
from .matmul_reduce_scatter import matmul_reduce_scatter, matmul_reduce_scatter_preamble
from .matmul_all_scatter import matmul_all_scatter, matmul_all_scatter_preamble


class OpsNamespace:
Expand Down Expand Up @@ -166,6 +168,36 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False,
"""
return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace)

def matmul_all_scatter(self, output, A, B_shard, bias=None, async_op=False, config=None, workspace=None):
"""
Fused matrix multiplication and all-scatter.

Computes: output = all_scatter(A @ B_shard) along N dimension

Each rank has B_shard of shape (K, N_shard) where N_shard = N / world_size.
The operation computes C_shard = A @ B_shard on each rank and scatters the
column stripe to all ranks so that every rank ends up with the full C (M, N).

Args:
output: Output tensor (M, N) where N = N_shard * world_size
A: Input matrix A (M, K) — replicated across ranks
B_shard: Column-sharded weight matrix (K, N_shard)
bias: Optional bias vector (M,)
async_op: If False, performs barrier at end
config: Optional FusedConfig for tuning
workspace: Optional pre-allocated workspace

Returns:
workspace: Updated workspace object

Example:
>>> N_shard = N // world_size
>>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.matmul_all_scatter(output, A, B_shard)
"""
return matmul_all_scatter(self._shmem, output, A, B_shard, bias, async_op, config, workspace)


# Export public API
__all__ = [
Expand All @@ -183,4 +215,6 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False,
"matmul_all_gather",
"matmul_reduce_scatter",
"matmul_reduce_scatter_preamble",
"matmul_all_scatter",
"matmul_all_scatter_preamble",
]
261 changes: 261 additions & 0 deletions iris/ops/matmul_all_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
Fused GEMM + All-Scatter operation.

Each rank has a column-sharded weight B_shard (K x N_shard) and a replicated
input A (M x K). Each rank computes C_shard = A @ B_shard, then scatters its
column stripe to all other ranks so that every rank ends up with the full
output C (M x N) where N = world_size * N_shard.

This is useful for tensor-parallel workloads where weights are column-sharded
and the full output is needed on all ranks.
"""

from typing import Optional
import torch
import triton
import triton.language as tl
import iris
import iris.x

from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view

from .config import FusedConfig
from .workspace import FusedWorkspace


@triton.jit()
def _fused_matmul_all_scatter_kernel(
    A,  # (M, K) - replicated across ranks
    B_shard,  # (K, N_shard) - each rank's column shard of weight matrix B
    C,  # (M, N) - output, N = N_shard * world_size
    bias_ptr,  # bias base pointer; only dereferenced when BIAS is True
    M,
    N,
    K,
    N_shard,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bias,
    context_tensor: tl.tensor,  # serialized iris device context (see DeviceContext.initialize)
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    BIAS: tl.constexpr,
    EVEN_K: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
):
    """
    Fused GEMM + all-scatter kernel.

    Each program instance walks the scheduler's persistent tile range,
    computes one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of A @ B_shard, adds the
    optional bias, and immediately pushes the finished tile to every rank's
    copy of C via ``iris.x.all_scatter`` — no intermediate buffer is used.
    """
    # ═══════════════════════════════════════════════════════════════════════
    # Create tritonblas views, context, and scheduler for GEMM
    # ═══════════════════════════════════════════════════════════════════════
    view_A = make_tensor_view(A, M, K, stride_am, stride_ak)
    view_B = make_tensor_view(B_shard, K, N_shard, stride_bk, stride_bn)
    gemm_ctx = GemmContext(
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        BLOCK_SIZE_K,
        num_sms=NUM_SMS,
        num_xcds=NUM_XCDS,
        group_size_m=GROUP_SIZE_M,
        even_k=EVEN_K,
        allow_tf32=ALLOW_TF32,
    )
    # Scheduler covers the LOCAL problem (M x N_shard); the full N extent is
    # only used to build the destination view that receives scattered tiles.
    sched = ScheduleContext(M, N_shard, K, gemm_ctx)
    ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size)
    dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn)

    # Persistent loop over local tiles using scheduler
    start, total, stride = sched.persistent_tile_range()
    for tile_id in range(start, total, stride):
        # Get tile coordinates with swizzling from scheduler
        out_tile = sched.get_tile_from_idx(tile_id)

        # GEMM over the K axis for this tile, using tritonblas stages
        acc = gemm_ctx.reduce_axis(view_A, view_B, out_tile)

        # Add bias if provided (per-row vector of length M, broadcast over N)
        if BIAS:
            rm, _ = out_tile.indices()
            bias_vec = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0)
            acc = acc + bias_vec[:, None]

        # Convert the fp32 accumulator to the output dtype
        c_tile = acc.to(C.type.element_ty)

        # Wrap result in a Tile object and scatter to all ranks.
        # NOTE(review): pid_n here is in the local N_shard frame — presumably
        # iris.x.all_scatter offsets it by rank into the full-N dst_view; confirm.
        tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c_tile)
        iris.x.all_scatter(tile_obj, dst_view, ctx)


def matmul_all_scatter_preamble(
    shmem,
    A: torch.Tensor,
    B_shard: torch.Tensor,
    config: Optional[FusedConfig] = None,
) -> FusedWorkspace:
    """Build the workspace descriptor for ``matmul_all_scatter``.

    The scatter pattern writes GEMM tiles straight into the remote output,
    so no intermediate buffer is allocated: the returned workspace only
    records the logical problem shape and rank count.

    Args:
        shmem: Iris shmem context (queried for the world size).
        A: Replicated input matrix of shape (M, K).
        B_shard: Column-sharded weight of shape (K, N_shard).
        config: Optional FusedConfig; defaulted if omitted.

    Returns:
        A prepared FusedWorkspace describing the (M, N, K) problem.
    """
    config = config if config is not None else FusedConfig()

    m_dim, k_a = A.shape
    k_b, n_shard = B_shard.shape
    assert k_a == k_b, f"Inner dimensions must match: A has K={k_a}, B_shard has K={k_b}"

    ranks = shmem.get_num_ranks()

    return FusedWorkspace(
        operation="matmul_all_scatter",
        shape=(m_dim, n_shard * ranks, k_a),
        dtype=A.dtype,
        world_size=ranks,
        prepared=True,
    )


def matmul_all_scatter(
    shmem,
    output: torch.Tensor,
    A: torch.Tensor,
    B_shard: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    async_op: bool = False,
    config: Optional[FusedConfig] = None,
    workspace: Optional[FusedWorkspace] = None,
) -> FusedWorkspace:
    """
    Fused matrix multiplication and all-scatter.

    Computes: output = all_scatter(A @ B_shard) along N dimension

    Each rank holds B_shard of shape (K, N_shard), N_shard = N / world_size.
    The kernel computes C_shard = A @ B_shard locally and pushes each finished
    tile to every rank via ``iris.x.all_scatter``, so all ranks end up with
    the full C (M, N). Single-kernel implementation — no intermediate buffer.

    Args:
        shmem: Iris shmem context
        output: Output tensor C of shape (M, N) where N = N_shard * world_size
        A: Input matrix A of shape (M, K) — replicated across ranks
        B_shard: Column-sharded weight matrix of shape (K, N_shard)
        bias: Optional bias vector of shape (M,)
        async_op: If False, performs barrier at end
        config: Optional FusedConfig for tuning
        workspace: Optional pre-allocated workspace

    Returns:
        FusedWorkspace object

    Example:
        >>> N_shard = N // world_size
        >>> B_shard = shmem.randn((K, N_shard), dtype=torch.float16)
        >>> output = shmem.zeros((M, N), dtype=torch.float16)
        >>> shmem.ops.matmul_all_scatter(output, A, B_shard)
    """
    config = config if config is not None else FusedConfig()

    M, K = A.shape
    k_b, N_shard = B_shard.shape
    assert K == k_b, f"Inner dimensions must match: A has K={K}, B_shard has K={k_b}"

    world_size = shmem.get_num_ranks()
    rank = shmem.get_rank()

    N = N_shard * world_size
    assert output.shape == (M, N), f"Output must be ({M}, {N}), got {output.shape}"

    # Reject problems smaller than one tile in any dimension.
    assert M >= config.block_size_m, (
        f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems."
    )
    assert K >= config.block_size_k, (
        f"K ({K}) must be >= block_size_k ({config.block_size_k}). Use smaller block sizes for small problems."
    )
    assert N_shard >= config.block_size_n, (
        f"N_shard ({N_shard}) must be >= block_size_n ({config.block_size_n}). "
        f"Use smaller block sizes for small problems."
    )

    # Lazily build a workspace descriptor when the caller didn't supply one.
    if workspace is None:
        workspace = matmul_all_scatter_preamble(shmem, A, B_shard, config)

    stride_am, stride_ak = A.stride()
    stride_bk, stride_bn = B_shard.stride()
    stride_cm, stride_cn = output.stride()

    has_bias = bias is not None
    if has_bias:
        assert bias.shape[0] == M
        bias_arg = bias
        bias_stride = bias.stride()[0] if bias.dim() > 0 else 1
    else:
        # Triton still needs a valid pointer argument even when BIAS is
        # False; the kernel never dereferences it in that case.
        bias_arg = output
        bias_stride = 1

    # NOTE(review): when config.num_sms is unset, the grid size comes from
    # the local device, which could differ across ranks (heterogeneous GPUs,
    # partitioning). The in-kernel collective assumes matching tile schedules
    # on all ranks — consider requiring an explicit, agreed num_sms.
    sm_count = config.num_sms
    if sm_count is None:
        sm_count = torch.cuda.get_device_properties(A.device).multi_processor_count

    even_k = K % config.block_size_k == 0

    # Launch single fused kernel
    _fused_matmul_all_scatter_kernel[(sm_count,)](
        A,
        B_shard,
        output,
        bias_arg,
        M,
        N,
        K,
        N_shard,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        bias_stride,
        shmem.get_device_context(),
        rank,
        world_size,
        config.block_size_m,
        config.block_size_n,
        config.block_size_k,
        config.group_size_m,
        sm_count,
        config.num_xcds,
        has_bias,
        even_k,
        config.allow_tf32,
    )

    if not async_op:
        shmem.barrier()

    return workspace
6 changes: 6 additions & 0 deletions iris/x/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
>>> ctx.all_gather(tile, src_view, dst_view, dim=0)
>>> ctx.all_to_all(tile, src_view, dst_view, N_per_rank)
>>> ctx.reduce_scatter(tile, src_view, dst_view)
>>>
>>> # Standalone all_scatter (each rank pushes its tile to all ranks)
>>> tile_with_data = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, computed_data)
>>> iris.x.all_scatter(tile_with_data, dst_view, ctx)

Example (API with AllReduceConfig for algorithm selection):
>>> @triton.jit
Expand Down Expand Up @@ -70,6 +74,7 @@
)
from .gather import gather
from .all_gather import all_gather
from .all_scatter import all_scatter
from .all_to_all import all_to_all
from .reduce_scatter import reduce_scatter

Expand All @@ -91,6 +96,7 @@
"all_reduce_spinlock",
"gather",
"all_gather",
"all_scatter",
"all_to_all",
"reduce_scatter",
]
Loading