From d44394c00e45383b388aa383049f8d334a794504 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 5 Nov 2025 22:13:31 -0600 Subject: [PATCH 01/29] wip back of sdma integration --- .gitmodules | 3 + .../message_passing_copy_engine.py | 262 ++++++++++++++++ .../06_message_passing/message_passing_put.py | 19 +- ext/shader_sdma | 1 + iris/__init__.py | 4 + iris/iris.py | 294 +++++++++++++++++- setup.py | 82 ++++- test.py | 21 ++ 8 files changed, 676 insertions(+), 10 deletions(-) create mode 100644 .gitmodules create mode 100644 examples/06_message_passing/message_passing_copy_engine.py create mode 160000 ext/shader_sdma create mode 100644 test.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..41cf4672e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ext/shader_sdma"] + path = ext/shader_sdma + url = https://github.com/AARInternal/shader_sdma.git diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py new file mode 100644 index 000000000..8f9ea2267 --- /dev/null +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import triton.language as tl +import random +import os +import sys +import time +import ctypes + +from mpi4py import MPI + +import iris + + +@triton.jit +def producer_kernel( + source_buffer, # tl.tensor: pointer to source data + target_buffer, # tl.tensor: pointer to target data + flag, # tl.tensor: pointer to flags + buffer_size, # int32: total number of elements + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, +): + pid = tl.program_id(0) + + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < buffer_size + + # Put chunk into remote buffer + # iris.put_ce(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask) + + # Set flag to signal completion + iris.signal_ce(flag + pid, producer_rank, producer_rank, heap_bases_ptr, copy_engine_handle_ptr) + # iris.atomic_cas_ce() + # zero_u64 = tl.zeros((1,), tl.uint64) + # one_u64 = tl.full((1,), 1, tl.uint64) + # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + + # done = 0 #zero_u64 + # while done == 0: + # done = iris.atomic_cas( + # flag + pid, 1, 0, producer_rank, producer_rank, heap_bases_ptr, sem="acquire", scope="sys" + # ) + + +@triton.jit +def consumer_kernel( + buffer, # tl.tensor: pointer to shared buffer (read from target_rank) + flag, # tl.tensor: sync flag per block + buffer_size, # int32: total number of elements + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Spin-wait until writer sets flag[pid] = 1 + # zero_u64 = tl.zeros((1,), tl.uint64) + # one_u64 = tl.full((1,), 1, tl.uint64) + done = 0 #zero_u64 + # while done == 0: + # done = iris.atomic_cas( + # flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + # ) + + # Read from the target buffer (written by producer) + # values = tl.load(buffer + offsets, mask=mask) + + # Do something with values... + # (Here you might write to output, do computation, etc.) + # values = values * 2 + + # Store chunk to target buffer + # tl.store( + # buffer + offsets, + # values, + # mask=mask, + # ) + + # Optionally reset the flag for next iteration + tl.store(flag + pid, 0) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + print(f"Unknown datatype: {datatype}") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse Message Passing configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") + + parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + dtype = torch_dtype_from_str(args["datatype"]) + cur_rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Allocate source and destination buffers on the symmetric heap + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + if world_size != 2: + raise ValueError("This example requires exactly two processes.") + + producer_rank = 0 + consumer_rank = 1 + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + shmem.info(f"n_elements {n_elements} grid {grid} num_blocks {num_blocks}") + shmem.info(f"src buffer {id(source_buffer):#x} dst buffer {id(destination_buffer):#x}") + shmem.info(f"data ptr src buffer {source_buffer.data_ptr():#x} dst buffer {destination_buffer.data_ptr():#x}") + + # Allocate flags on the symmetric heap + flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + # copy_engine_handle = iris.get_copy_engine_handle(consumer_rank) + + # src_ptr = ctypes.c_void_p(source_buffer.data_ptr()), + if cur_rank == producer_rank: + shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") + kk = producer_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + producer_rank, + consumer_rank, + args["block_size"], + shmem.get_heap_bases(), + shmem.get_copy_engine_handle(consumer_rank), + ) + else: + # time.sleep(1) + shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") + kk = consumer_kernel[grid]( + destination_buffer, flags, n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases() + ) + # time.sleep(1) + + shmem.barrier() + # torch.cuda.synchronize() + shmem.info(f"Rank {cur_rank} has finished sending/receiving data.") + shmem.info("Validating output...") + + success = True + numErrors = 0 + if cur_rank == consumer_rank: + expected = source_buffer # * 2 + diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + shmem.info(f"Max absolute difference: {max_diff}") + for idx in breaking_indices: + idx = tuple(idx.tolist()) + computed_val = destination_buffer[idx] + expected_val = expected[idx] + if numErrors < 10: + shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") + success = False + numErrors += 1 + # break + + if success: + shmem.info("Validation successful.") + else: + shmem.info(f"Validation failed with {numErrors} errors / {destination_buffer.numel()}") + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + comm = MPI.COMM_WORLD # Communicator for all processes + rank = comm.Get_rank() # Get the rank of the current process + num_ranks = comm.Get_size() # Total number of processes + + # TODO local_rank + torch.cuda.set_device(rank) + + + # Synchronize all processes + comm.barrier() + + init_url = "tcp://127.0.0.1:29500" + + _worker(rank, num_ranks, init_url, args) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 54abe2554..fdbd36a7a 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -3,6 +3,8 @@ import argparse +from mpi4py import MPI + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -117,7 +119,7 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + # parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -212,15 +214,16 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): def main(): args = parse_args() - num_ranks = args["num_ranks"] + comm = MPI.COMM_WORLD # Communicator for all processes + rank = comm.Get_rank() # Get the rank of the current process + num_ranks = comm.Get_size() # Total number of processes + + # Synchronize all processes + comm.barrier() init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + + _worker(rank, num_ranks, init_url, args) if __name__ == "__main__": diff --git a/ext/shader_sdma b/ext/shader_sdma new file mode 160000 index 000000000..243be5f30 --- /dev/null +++ b/ext/shader_sdma @@ -0,0 +1 @@ +Subproject commit 243be5f30f96374d4231cd669e179584e714435d diff --git a/iris/__init__.py b/iris/__init__.py index 2b048d03a..d992dca4f 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -59,6 +59,8 @@ atomic_and, atomic_min, atomic_max, + put_ce, + signal_ce, ) from .util import ( @@ -98,6 +100,8 @@ "atomic_and", "atomic_min", "atomic_max", + "put_ce", + "signal_ce", "do_bench", "hip", "experimental", # Experimental features including iris_gluon diff --git a/iris/iris.py b/iris/iris.py index 9e52b4ec5..305f3bb5a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -41,6 +41,12 @@ get_wall_clock_rate, get_ipc_handle_size, ) +import sys + +# sys.path.append("/home/dasidler/iris/iris/experimental") +# import my_module as anvil +import iris.experimental.my_module as anvil + import numpy as np import math import torch @@ -87,6 +93,7 @@ def __init__(self, heap_size=1 << 30): heap_base = self.memory_pool.data_ptr() heap_base_ptr = ctypes.c_void_p(heap_base) + self.info(f"heap_base {heap_base:#x}") heap_bases = np.zeros(num_ranks, dtype=np.uint64) heap_bases[cur_rank] = heap_base @@ -112,13 +119,22 @@ def __init__(self, heap_size=1 << 30): ipc_heap_bases[rank] = heap_bases[rank] for i in range(num_ranks): - self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") + self.info(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") distributed_barrier() self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64) distributed_barrier() + # initialize copy engines + self.copy_engines = anvil.AnvilLib.get_instance() + self.copy_engines.init() + + # connect to all peers + for rank in range(num_ranks): + if rank != cur_rank: + self.copy_engines.connect(cur_rank, rank, 1) + def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" if logger.isEnabledFor(level): @@ -1216,6 +1232,35 @@ def get_num_ranks(self): """ return self.num_ranks + def get_copy_engine_handle(self, to_rank): + # TODO remove last arg + queue = self.copy_engines.get_sdma_queue(self.get_rank(), to_rank, 0) + # Wrap into numpy array + handle = queue.device_ctx() + self.info("---- Queue ------------") + # print(f"handle at {id(handle):#x}") + self.info(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") + self.info(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") + self.info(f"wptr {handle.wptr:#x} at {id(handle.wptr):#x}") + self.info(f"doorbell {handle.doorbell:#x} at {id(handle.doorbell):#x}") + self.info(f"cached_write_ptr {handle.cached_wptr:#x} at {id(handle.cached_wptr):#x}") + self.info(f"committed_write_ptr {handle.committed_wptr:#x} at {id(handle.committed_wptr):#x}") + + # TODO get size + # array = np.ctypeslib.as_array(ctypes.cast(handle, ctypes.POINTER(ctypes.c_uint64)), shape=(7, )) + context_size = 6 + device_ctx = torch.zeros(context_size, dtype=torch.uint64, device=self.device) + device_ctx[0] = handle.queue_buf + device_ctx[1] = handle.rptr + device_ctx[2] = handle.wptr + device_ctx[3] = handle.doorbell + device_ctx[4] = handle.cached_wptr + device_ctx[5] = handle.committed_wptr + # context[6] = handle. + + + return device_ctx # anvil.get_handle_as_tensor(queue) # torch.from_numpy(array) #.to(device='cuda') + def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): if not self.__tensor_on_device(tensor): raise RuntimeError( @@ -1708,6 +1753,253 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): tl.store(translated_to_ptr, data, mask=mask) +@triton.jit +def nontemporal_store(addr, value): + tl.inline_asm_elementwise( + asm="""flat_store_dwordx2 $0 $1 sc0 nt; s_waitcnt vmcnt(0)""", + constraints=("v,v"), + args=[addr, value], + dtype=tl.uint64, + is_pure=False, + pack=1, + ) + return tl.zeros_like(value) + +@triton.jit +def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): + """ + Copies data from the current rank's local memory to the specified rank's memory. + This function performs a memory write operation by loading data from the current + rank's `from_ptr`, translating the `to_ptr` from the current rank's address + space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. + If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + from_rank (int): The current rank ID from which to read the data. + to_rank (int): The `to_rank` ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + """ + + handle = ce_handle #iris.get_copy_engine_handle(to_rank) + queue_ptr = tl.load(handle + 0) #.to(tl.pointer_type(tl.uint64)) + read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=0) + + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = from_ptr.to(tl.uint64) + src_ptr_val = tl.min(src_ptr_u64, axis=0) + max_src_ptr = tl.max(src_ptr_u64, axis=0) + + # Infer element size from pointer type + # src_ptr is a block of pointers with a specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte sizes + ptr_dtype = from_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Calculate total size in bytes + # Count number of valid elements based on mask + mask_int = mask.to(tl.int32) + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + # data = tl.load(from_ptr, mask=mask) + # tl.store(translated_to_ptr, data, mask=mask) + + command_in_bytes = 28 + base = tl.zeros((), dtype=tl.uint64) + # copy_size_in_bytes = tl.sum(mask.to(tl.int8)).to(tl.uint32) + # Acquire space + run_loop = True + while run_loop: + cur_index = tl.load(cached_write_ptr) + new_index = cur_index + command_in_bytes + # Check if wrap around + # TODO + + # Check if full + # TODO + # expected = cur_index + if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem='acquire', scope='gpu') == cur_index: + base = tl.full((), cur_index, dtype=tl.uint64) + run_loop = False + + + # Place command packet + queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) + slot_ptr_u32 = queue_ptr_u32 + (base // 4) + + + # Convert to scalar value + # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) + + # offset 0: op + sub_op + tl.store(slot_ptr_u32 + 0, 1) + # offset 1: reserved + tl.store(slot_ptr_u32 + 1, 0) + # offset 2: count + tl.store(slot_ptr_u32 + 2, size_bytes - 1) + # offset 3: src address 31:0 + tl.store(slot_ptr_u32 + 3, src_ptr_val.to(tl.uint32)) + # offset 4: src address 63:32 + tl.store(slot_ptr_u32 + 4, (src_ptr_val >> 32).to(tl.uint32)) + # offset 5: dst address 31:0 + tl.store(slot_ptr_u32 + 5, dst_ptr_val.to(tl.uint32)) + # offset 6: dst address 63:32 + tl.store(slot_ptr_u32 + 6, (dst_ptr_val >> 32).to(tl.uint32)) + + + # Submit command + while tl.load(committed_write_ptr) != base: + pass + + tl.store(write_ptr, base + command_in_bytes) + + tl.debug_barrier() + + # Ring doorbell + # tl.store(doorbell_ptr, base + command_in_bytes) + tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') + tl.debug_barrier() + tl.store(committed_write_ptr, base + command_in_bytes) + +@triton.jit +def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): + """ + Copies data from the current rank's local memory to the specified rank's memory. + This function performs a memory write operation by loading data from the current + rank's `from_ptr`, translating the `to_ptr` from the current rank's address + space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. + If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + from_rank (int): The current rank ID from which to read the data. + to_rank (int): The `to_rank` ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + """ + + handle = ce_handle #iris.get_copy_engine_handle(to_rank) + queue_ptr = tl.load(handle + 0) #.to(tl.pointer_type(tl.uint64)) + read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + dst_ptr_val = translated_to_ptr.to(tl.uint64) + + command_in_bytes = 32 + base = tl.zeros((), dtype=tl.uint64) + # copy_size_in_bytes = tl.sum(mask.to(tl.int8)).to(tl.uint32) + # Acquire space + run_loop = True + while run_loop: + cur_index = tl.load(cached_write_ptr) + new_index = cur_index + command_in_bytes + # Check if wrap around + # TODO + + # Check if full + # TODO + # expected = cur_index + if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem='acquire', scope='gpu') == cur_index: + base = tl.full((), cur_index, dtype=tl.uint64) + run_loop = False + + + base_val = base.to(tl.uint64) + # Place command packet + queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) + slot_ptr_u32 = queue_ptr_u32 + (base_val // 4) + # print("queue_ptr: ", queue_ptr, " slot_ptr ", slot_ptr_u32, " base ", base) + + + # Convert to scalar value + # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) + + # offset 0: op + sub_op + # tl.store(slot_ptr_u32 + 0, 0x2F0A) # op: 10, subop: 47 atomicAdd64 + tl.store(slot_ptr_u32 + 0, 0x0F0A) # op: 10, subop: 15 atomicAdd32 + # offset 1: dst address 31:0 + tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32)) + # offset 2: dst address 63:32 + tl.store(slot_ptr_u32 + 2, (dst_ptr_val >> 32).to(tl.uint32)) + # offset 3: src data 31:0 + tl.store(slot_ptr_u32 + 3, 1) # increment by 1 + # offset 4: src data 63:32 + tl.store(slot_ptr_u32 + 4, 0) + # offset 5 - 7 unused + tl.store(slot_ptr_u32 + 5, 0) + tl.store(slot_ptr_u32 + 6, 0) + tl.store(slot_ptr_u32 + 7, 0) + + + # Submit command + while tl.load(committed_write_ptr) != base_val: + pass + + + tl.store(write_ptr, base + command_in_bytes) + + tl.debug_barrier() + + # Ring doorbell + # tl.store(doorbell_ptr, base_val + command_in_bytes) + # tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') + nontemporal_store(doorbell_ptr, base + command_in_bytes) + tl.debug_barrier() + tl.store(committed_write_ptr, base_val + command_in_bytes) + + + @triton.jit def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): """ diff --git a/setup.py b/setup.py index 698324612..478747224 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,91 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -from setuptools import setup +import os +import subprocess +import sys +from pathlib import Path +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + + +class CMakeExtension(Extension): + """Extension that uses CMake to build""" + def __init__(self, name, sourcedir=""): + super().__init__(name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + """Custom build_ext command that runs CMake""" + + def run(self): + # Check if CMake is available + try: + subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake must be installed to build RDMA extensions") + + # Build each extension + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + if not isinstance(ext, CMakeExtension): + return super().build_extension(ext) + + extdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute() + + # CMake configuration arguments + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + "-DCMAKE_CXX_COMPILER=/usr/bin/hipcc", + "-DCMAKE_BUILD_TYPE=Release", + ] + + # Build arguments + build_args = ["--config", "Release"] + + # Parallel build + if hasattr(os, "cpu_count"): + build_args += [f"-j{os.cpu_count()}"] + + # Create build directory + build_temp = Path(self.build_temp) / ext.name + build_temp.mkdir(parents=True, exist_ok=True) + + # Run CMake + subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, + cwd=build_temp + ) + + # Build + subprocess.check_call( + ["cmake", "--build", "."] + build_args, + cwd=build_temp + ) + + +ext_modules = [] + +# TODO make optional +build_copy_engine_offload = True +if build_copy_engine_offload: + print("Building Copy Engine offload library") + copy_engine_ext = CMakeExtension( + "iris.experimental.anvil", + sourcedir="ext/shader_sdma" + ) + ext_modules.append(copy_engine_ext) + # This setup.py provides backward compatibility for legacy metadata fields # that don't map directly from pyproject.toml's modern PEP 621 format. setup( url="https://rocm.github.io/iris/", author="Muhammad Awad, Muhammad Osama, Brandon Potter", + ext_modules=ext_modules, + cmdclass={"build_ext": CMakeBuild} if ext_modules else {}, ) diff --git a/test.py b/test.py new file mode 100644 index 000000000..b56505abc --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +import sys + +sys.path.append("./iris/experimental") + +import my_module as anvil + +print("Get isntance") + +instance = anvil.AnvilLib.get_instance() +print("initialize") +instance.init() + +print("Connect 0 to 1") + +instance.connect(0, 1, 1) + +queue = instance.get_sdma_queue(0, 1, 0) + +# handle = queue.device_handle() + +handle = anvil.get_handle_as_tensor(queue) \ No newline at end of file From c50e761fc11e1e124688f0da4dc232a9e85025d3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 6 Nov 2025 04:16:08 +0000 Subject: [PATCH 02/29] Apply Ruff auto-fixes --- .../message_passing_copy_engine.py | 8 ++-- iris/iris.py | 39 ++++++++----------- setup.py | 16 ++------ test.py | 2 +- 4 files changed, 25 insertions(+), 40 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 8f9ea2267..d278c8808 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -75,7 +75,7 @@ def consumer_kernel( # Spin-wait until writer sets flag[pid] = 1 # zero_u64 = tl.zeros((1,), tl.uint64) # one_u64 = tl.full((1,), 1, tl.uint64) - done = 0 #zero_u64 + done = 0 # zero_u64 # while done == 0: # done = iris.atomic_cas( # flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" @@ -212,7 +212,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): success = True numErrors = 0 if cur_rank == consumer_rank: - expected = source_buffer # * 2 + expected = source_buffer # * 2 diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) breaking_indices = torch.nonzero(diff_mask, as_tuple=False) @@ -250,7 +250,6 @@ def main(): # TODO local_rank torch.cuda.set_device(rank) - # Synchronize all processes comm.barrier() @@ -258,5 +257,6 @@ def main(): _worker(rank, num_ranks, init_url, args) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/iris/iris.py b/iris/iris.py index 305f3bb5a..a67ddb952 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1258,8 +1258,7 @@ def get_copy_engine_handle(self, to_rank): device_ctx[5] = handle.committed_wptr # context[6] = handle. - - return device_ctx # anvil.get_handle_as_tensor(queue) # torch.from_numpy(array) #.to(device='cuda') + return device_ctx # anvil.get_handle_as_tensor(queue) # torch.from_numpy(array) #.to(device='cuda') def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): if not self.__tensor_on_device(tensor): @@ -1765,6 +1764,7 @@ def nontemporal_store(addr, value): ) return tl.zeros_like(value) + @triton.jit def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): """ @@ -1793,8 +1793,8 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - handle = ce_handle #iris.get_copy_engine_handle(to_rank) - queue_ptr = tl.load(handle + 0) #.to(tl.pointer_type(tl.uint64)) + handle = ce_handle # iris.get_copy_engine_handle(to_rank) + queue_ptr = tl.load(handle + 0) # .to(tl.pointer_type(tl.uint64)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) @@ -1854,15 +1854,13 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # Check if full # TODO # expected = cur_index - if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem='acquire', scope='gpu') == cur_index: + if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem="acquire", scope="gpu") == cur_index: base = tl.full((), cur_index, dtype=tl.uint64) run_loop = False - # Place command packet queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) - slot_ptr_u32 = queue_ptr_u32 + (base // 4) - + slot_ptr_u32 = queue_ptr_u32 + (base // 4) # Convert to scalar value # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) @@ -1870,7 +1868,7 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # offset 0: op + sub_op tl.store(slot_ptr_u32 + 0, 1) # offset 1: reserved - tl.store(slot_ptr_u32 + 1, 0) + tl.store(slot_ptr_u32 + 1, 0) # offset 2: count tl.store(slot_ptr_u32 + 2, size_bytes - 1) # offset 3: src address 31:0 @@ -1882,21 +1880,21 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # offset 6: dst address 63:32 tl.store(slot_ptr_u32 + 6, (dst_ptr_val >> 32).to(tl.uint32)) - # Submit command while tl.load(committed_write_ptr) != base: pass - + tl.store(write_ptr, base + command_in_bytes) tl.debug_barrier() # Ring doorbell # tl.store(doorbell_ptr, base + command_in_bytes) - tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') + tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem="release", scope="sys") tl.debug_barrier() tl.store(committed_write_ptr, base + command_in_bytes) + @triton.jit def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): """ @@ -1925,8 +1923,8 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - handle = ce_handle #iris.get_copy_engine_handle(to_rank) - queue_ptr = tl.load(handle + 0) #.to(tl.pointer_type(tl.uint64)) + handle = ce_handle # iris.get_copy_engine_handle(to_rank) + queue_ptr = tl.load(handle + 0) # .to(tl.pointer_type(tl.uint64)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) @@ -1950,30 +1948,28 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): # Check if full # TODO # expected = cur_index - if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem='acquire', scope='gpu') == cur_index: + if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem="acquire", scope="gpu") == cur_index: base = tl.full((), cur_index, dtype=tl.uint64) run_loop = False - base_val = base.to(tl.uint64) # Place command packet queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) - slot_ptr_u32 = queue_ptr_u32 + (base_val // 4) + slot_ptr_u32 = queue_ptr_u32 + (base_val // 4) # print("queue_ptr: ", queue_ptr, " slot_ptr ", slot_ptr_u32, " base ", base) - # Convert to scalar value # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) # offset 0: op + sub_op # tl.store(slot_ptr_u32 + 0, 0x2F0A) # op: 10, subop: 47 atomicAdd64 - tl.store(slot_ptr_u32 + 0, 0x0F0A) # op: 10, subop: 15 atomicAdd32 + tl.store(slot_ptr_u32 + 0, 0x0F0A) # op: 10, subop: 15 atomicAdd32 # offset 1: dst address 31:0 tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32)) # offset 2: dst address 63:32 tl.store(slot_ptr_u32 + 2, (dst_ptr_val >> 32).to(tl.uint32)) # offset 3: src data 31:0 - tl.store(slot_ptr_u32 + 3, 1) # increment by 1 + tl.store(slot_ptr_u32 + 3, 1) # increment by 1 # offset 4: src data 63:32 tl.store(slot_ptr_u32 + 4, 0) # offset 5 - 7 unused @@ -1981,11 +1977,9 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): tl.store(slot_ptr_u32 + 6, 0) tl.store(slot_ptr_u32 + 7, 0) - # Submit command while tl.load(committed_write_ptr) != base_val: pass - tl.store(write_ptr, base + command_in_bytes) @@ -1999,7 +1993,6 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): tl.store(committed_write_ptr, base_val + command_in_bytes) - @triton.jit def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): """ diff --git a/setup.py b/setup.py index 478747224..69e4e9d4f 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ class CMakeExtension(Extension): """Extension that uses CMake to build""" + def __init__(self, name, sourcedir=""): super().__init__(name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) @@ -56,16 +57,10 @@ def build_extension(self, ext): build_temp.mkdir(parents=True, exist_ok=True) # Run CMake - subprocess.check_call( - ["cmake", ext.sourcedir] + cmake_args, - cwd=build_temp - ) + subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) # Build - subprocess.check_call( - ["cmake", "--build", "."] + build_args, - cwd=build_temp - ) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) ext_modules = [] @@ -74,10 +69,7 @@ def build_extension(self, ext): build_copy_engine_offload = True if build_copy_engine_offload: print("Building Copy Engine offload library") - copy_engine_ext = CMakeExtension( - "iris.experimental.anvil", - sourcedir="ext/shader_sdma" - ) + copy_engine_ext = CMakeExtension("iris.experimental.anvil", sourcedir="ext/shader_sdma") ext_modules.append(copy_engine_ext) diff --git a/test.py b/test.py index b56505abc..dab152afc 100644 --- a/test.py +++ b/test.py @@ -18,4 +18,4 @@ # handle = queue.device_handle() -handle = anvil.get_handle_as_tensor(queue) \ No newline at end of file +handle = anvil.get_handle_as_tensor(queue) From 2f7bc5e3ec639c4f12004c93dec5172c1ae6197e Mon Sep 17 00:00:00 2001 From: David Sidler Date: Thu, 6 Nov 2025 12:17:11 -0600 Subject: [PATCH 03/29] message passing example working --- .../message_passing_copy_engine.py | 64 +++++----------- .../06_message_passing/message_passing_put.py | 14 ++-- iris/iris.py | 74 ++++++++++++++----- 3 files changed, 84 insertions(+), 68 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 8f9ea2267..1c02f2970 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -11,8 +11,6 @@ import random import os import sys -import time -import ctypes from mpi4py import MPI @@ -41,20 +39,10 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - # iris.put_ce(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask) + iris.put_ce(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask) # Set flag to signal completion - iris.signal_ce(flag + pid, producer_rank, producer_rank, heap_bases_ptr, copy_engine_handle_ptr) - # iris.atomic_cas_ce() - # zero_u64 = tl.zeros((1,), tl.uint64) - # one_u64 = tl.full((1,), 1, tl.uint64) - # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") - - # done = 0 #zero_u64 - # while done == 0: - # done = iris.atomic_cas( - # flag + pid, 1, 0, producer_rank, producer_rank, heap_bases_ptr, sem="acquire", scope="sys" - # ) + iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr) @triton.jit @@ -76,24 +64,24 @@ def consumer_kernel( # zero_u64 = tl.zeros((1,), tl.uint64) # one_u64 = tl.full((1,), 1, tl.uint64) done = 0 #zero_u64 - # while done == 0: - # done = iris.atomic_cas( - # flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" - # ) + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) # Read from the target buffer (written by producer) - # values = tl.load(buffer + offsets, mask=mask) + values = tl.load(buffer + offsets, mask=mask) # Do something with values... # (Here you might write to output, do computation, etc.) - # values = values * 2 + values = values * 2 # Store chunk to target buffer - # tl.store( - # buffer + offsets, - # values, - # mask=mask, - # ) + tl.store( + buffer + offsets, + values, + mask=mask, + ) # Optionally reset the flag for next iteration tl.store(flag + pid, 0) @@ -134,7 +122,6 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -173,16 +160,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): n_elements = source_buffer.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) num_blocks = triton.cdiv(n_elements, args["block_size"]) - shmem.info(f"n_elements {n_elements} grid {grid} num_blocks {num_blocks}") - shmem.info(f"src buffer {id(source_buffer):#x} dst buffer {id(destination_buffer):#x}") - shmem.info(f"data ptr src buffer {source_buffer.data_ptr():#x} dst buffer {destination_buffer.data_ptr():#x}") # Allocate flags on the symmetric heap flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) - # copy_engine_handle = iris.get_copy_engine_handle(consumer_rank) - - # src_ptr = ctypes.c_void_p(source_buffer.data_ptr()), if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") kk = producer_kernel[grid]( @@ -197,22 +178,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): shmem.get_copy_engine_handle(consumer_rank), ) else: - # time.sleep(1) shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") kk = consumer_kernel[grid]( destination_buffer, flags, n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases() ) - # time.sleep(1) - shmem.barrier() - # torch.cuda.synchronize() shmem.info(f"Rank {cur_rank} has finished sending/receiving data.") shmem.info("Validating output...") success = True - numErrors = 0 if cur_rank == consumer_rank: - expected = source_buffer # * 2 + expected = source_buffer * 2 diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) breaking_indices = torch.nonzero(diff_mask, as_tuple=False) @@ -223,16 +199,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): idx = tuple(idx.tolist()) computed_val = destination_buffer[idx] expected_val = expected[idx] - if numErrors < 10: - shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") + shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") success = False - numErrors += 1 - # break + break if success: shmem.info("Validation successful.") else: - shmem.info(f"Validation failed with {numErrors} errors / {destination_buffer.numel()}") + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") shmem.barrier() @@ -246,7 +220,6 @@ def main(): comm = MPI.COMM_WORLD # Communicator for all processes rank = comm.Get_rank() # Get the rank of the current process num_ranks = comm.Get_size() # Total number of processes - # TODO local_rank torch.cuda.set_device(rank) @@ -258,5 +231,6 @@ def main(): _worker(rank, num_ranks, init_url, args) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index fdbd36a7a..21eed8fef 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -3,8 +3,6 @@ import argparse -from mpi4py import MPI - import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -14,6 +12,8 @@ import os import sys +from mpi4py import MPI + import iris @@ -119,7 +119,6 @@ def parse_args(): parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - # parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") return vars(parser.parse_args()) @@ -142,12 +141,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() # Allocate source and destination buffers on the symmetric heap - source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) if dtype.is_floating_point: - destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) else: ii = torch.iinfo(dtype) - destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) if world_size != 2: raise ValueError("This example requires exactly two processes.") @@ -217,6 +216,9 @@ def main(): comm = MPI.COMM_WORLD # Communicator for all processes rank = comm.Get_rank() # Get the rank of the current process num_ranks = comm.Get_size() # Total number of processes + # TODO local_rank + torch.cuda.set_device(rank) + # Synchronize all processes comm.barrier() diff --git a/iris/iris.py b/iris/iris.py index 305f3bb5a..c48b72a1d 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1756,14 +1756,54 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit def nontemporal_store(addr, value): tl.inline_asm_elementwise( - asm="""flat_store_dwordx2 $0 $1 sc0 nt; s_waitcnt vmcnt(0)""", - constraints=("v,v"), + asm="""flat_store_dwordx2 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", + constraints=("=r,v,v"), # =r used for dummy return to satisfy compiler requirement args=[addr, value], + dtype=tl.int32, # return not used + is_pure=False, + pack=1, + ) + +# TODO rename or add nt +@triton.jit +def nontemporal_load(addr): + val = tl.inline_asm_elementwise( + asm="""flat_load_dwordx2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v"), + args=[ addr], + dtype=tl.uint64, + is_pure=False, + pack=1, + ) + return val + +@triton.jit +def nontemporal_atomic_add(addr, value): + old = tl.inline_asm_elementwise( + asm="""flat_atomic_add_x2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v,v"), + args=[addr,value], dtype=tl.uint64, is_pure=False, pack=1, ) - return tl.zeros_like(value) + return old + + +# @triton.jit +# def nontemporal_compare_exchange(addr, cmp_low, cmp_high, val_low, val_high): +# # data_128bit = tl.cat([cmp_low, cmp_high, val_low, val_high]) +# data_128bit = tl.make_vector([cmp_low, cmp_high, val_low, val_high], type=tl.uint32) +# old = tl.inline_asm_elementwise( +# asm="""flat_atomic_cmpswap_x2 $0 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", +# constraints=("=v,v,v"), +# args=[addr, data_128bit], +# dtype=tl.uint64, +# is_pure=False, +# pack=1, +# ) +# return True # TODO if old == cmp else False + @triton.jit def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): @@ -1869,10 +1909,10 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # offset 0: op + sub_op tl.store(slot_ptr_u32 + 0, 1) - # offset 1: reserved - tl.store(slot_ptr_u32 + 1, 0) - # offset 2: count - tl.store(slot_ptr_u32 + 2, size_bytes - 1) + # offset 1: count + tl.store(slot_ptr_u32 + 1, size_bytes - 1) + # offset 2: parameters + tl.store(slot_ptr_u32 + 2, 0) # offset 3: src address 31:0 tl.store(slot_ptr_u32 + 3, src_ptr_val.to(tl.uint32)) # offset 4: src address 63:32 @@ -1887,7 +1927,8 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non while tl.load(committed_write_ptr) != base: pass - tl.store(write_ptr, base + command_in_bytes) + # tl.store(write_ptr, base + command_in_bytes) + tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') tl.debug_barrier() @@ -1955,10 +1996,9 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): run_loop = False - base_val = base.to(tl.uint64) # Place command packet queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) - slot_ptr_u32 = queue_ptr_u32 + (base_val // 4) + slot_ptr_u32 = queue_ptr_u32 + (base // 4) # print("queue_ptr: ", queue_ptr, " slot_ptr ", slot_ptr_u32, " base ", base) @@ -1966,8 +2006,8 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) # offset 0: op + sub_op - # tl.store(slot_ptr_u32 + 0, 0x2F0A) # op: 10, subop: 47 atomicAdd64 - tl.store(slot_ptr_u32 + 0, 0x0F0A) # op: 10, subop: 15 atomicAdd32 + # tl.store(slot_ptr_u32 + 0, ((0x2F & 0x7F) << 25 | (0xA & 0xFF)) # op: 10, operation: 47 atomicAdd64 + tl.store(slot_ptr_u32 + 0, ((0xF & 0x7F) << 25) | (0xA & 0xFF)) # op: 10, operation: 15 atomicAdd32 # offset 1: dst address 31:0 tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32)) # offset 2: dst address 63:32 @@ -1983,20 +2023,20 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): # Submit command - while tl.load(committed_write_ptr) != base_val: + while tl.load(committed_write_ptr) != base: pass - tl.store(write_ptr, base + command_in_bytes) + # tl.store(write_ptr, base + command_in_bytes) + tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') tl.debug_barrier() # Ring doorbell # tl.store(doorbell_ptr, base_val + command_in_bytes) - # tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') - nontemporal_store(doorbell_ptr, base + command_in_bytes) + tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') tl.debug_barrier() - tl.store(committed_write_ptr, base_val + command_in_bytes) + tl.store(committed_write_ptr, base + command_in_bytes) From 759f662001971619f585b8fc0f8fa9ed5f856760 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 6 Nov 2025 18:21:02 +0000 Subject: [PATCH 04/29] Apply Ruff auto-fixes --- .../message_passing_copy_engine.py | 12 +++++++-- .../06_message_passing/message_passing_put.py | 1 - iris/iris.py | 26 +++++++++---------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 563a2580b..0fbb886a6 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -39,7 +39,15 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put_ce(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask) + iris.put_ce( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + ) # Set flag to signal completion iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr) @@ -63,7 +71,7 @@ def consumer_kernel( # Spin-wait until writer sets flag[pid] = 1 # zero_u64 = tl.zeros((1,), tl.uint64) # one_u64 = tl.full((1,), 1, tl.uint64) - done = 0 #zero_u64 + done = 0 # zero_u64 while done == 0: done = iris.atomic_cas( flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 21eed8fef..f396c42b1 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -219,7 +219,6 @@ def main(): # TODO local_rank torch.cuda.set_device(rank) - # Synchronize all processes comm.barrier() diff --git a/iris/iris.py b/iris/iris.py index b167aaaf5..7bedd9cb2 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1756,32 +1756,34 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): def nontemporal_store(addr, value): tl.inline_asm_elementwise( asm="""flat_store_dwordx2 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", - constraints=("=r,v,v"), # =r used for dummy return to satisfy compiler requirement + constraints=("=r,v,v"), # =r used for dummy return to satisfy compiler requirement args=[addr, value], - dtype=tl.int32, # return not used + dtype=tl.int32, # return not used is_pure=False, pack=1, ) + # TODO rename or add nt @triton.jit def nontemporal_load(addr): val = tl.inline_asm_elementwise( asm="""flat_load_dwordx2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", constraints=("=v,v"), - args=[ addr], + args=[addr], dtype=tl.uint64, is_pure=False, pack=1, ) return val + @triton.jit def nontemporal_atomic_add(addr, value): old = tl.inline_asm_elementwise( asm="""flat_atomic_add_x2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", constraints=("=v,v,v"), - args=[addr,value], + args=[addr, value], dtype=tl.uint64, is_pure=False, pack=1, @@ -1804,7 +1806,6 @@ def nontemporal_atomic_add(addr, value): # return True # TODO if old == cmp else False - @triton.jit def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): """ @@ -1908,7 +1909,7 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # offset 0: op + sub_op tl.store(slot_ptr_u32 + 0, 1) # offset 1: count - tl.store(slot_ptr_u32 + 1, size_bytes - 1) + tl.store(slot_ptr_u32 + 1, size_bytes - 1) # offset 2: parameters tl.store(slot_ptr_u32 + 2, 0) # offset 3: src address 31:0 @@ -1923,9 +1924,9 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non # Submit command while tl.load(committed_write_ptr) != base: pass - + # tl.store(write_ptr, base + command_in_bytes) - tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') + tl.atomic_xchg(write_ptr, base + command_in_bytes, sem="release", scope="gpu") tl.debug_barrier() @@ -1993,10 +1994,9 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): base = tl.full((), cur_index, dtype=tl.uint64) run_loop = False - # Place command packet queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) - slot_ptr_u32 = queue_ptr_u32 + (base // 4) + slot_ptr_u32 = queue_ptr_u32 + (base // 4) # print("queue_ptr: ", queue_ptr, " slot_ptr ", slot_ptr_u32, " base ", base) # Convert to scalar value @@ -2004,7 +2004,7 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): # offset 0: op + sub_op # tl.store(slot_ptr_u32 + 0, ((0x2F & 0x7F) << 25 | (0xA & 0xFF)) # op: 10, operation: 47 atomicAdd64 - tl.store(slot_ptr_u32 + 0, ((0xF & 0x7F) << 25) | (0xA & 0xFF)) # op: 10, operation: 15 atomicAdd32 + tl.store(slot_ptr_u32 + 0, ((0xF & 0x7F) << 25) | (0xA & 0xFF)) # op: 10, operation: 15 atomicAdd32 # offset 1: dst address 31:0 tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32)) # offset 2: dst address 63:32 @@ -2023,13 +2023,13 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): pass # tl.store(write_ptr, base + command_in_bytes) - tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') + tl.atomic_xchg(write_ptr, base + command_in_bytes, sem="release", scope="gpu") tl.debug_barrier() # Ring doorbell # tl.store(doorbell_ptr, base_val + command_in_bytes) - tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') + tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem="release", scope="sys") tl.debug_barrier() tl.store(committed_write_ptr, base + command_in_bytes) From ad7769de849c38ffa1763e952a11f79802b6799f Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 7 Nov 2025 15:07:06 -0600 Subject: [PATCH 05/29] update put example to use ce --- .../06_message_passing/message_passing_put.py | 18 +- iris/iris.py | 200 ++++++++---------- setup.py | 74 +------ 3 files changed, 107 insertions(+), 185 deletions(-) diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 21eed8fef..66eedf989 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -27,6 +27,8 @@ def producer_kernel( consumer_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, + USE_COPY_ENGINE: tl.constexpr, ): pid = tl.program_id(0) @@ -38,10 +40,11 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, mask=mask) + iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask, USE_COPY_ENGINE=USE_COPY_ENGINE) # Set flag to signal completion - iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys") + iris.atomic_add(flag + pid, 1, producer_rank, consumer_rank, heap_bases_ptr, sem='release', scope='sys', copy_engine_ctx=copy_engine_handle_ptr, USE_COPY_ENGINE=USE_COPY_ENGINE) @triton.jit @@ -117,8 +120,9 @@ def parse_args(): ) parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") - parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies") + return vars(parser.parse_args()) @@ -161,6 +165,9 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate flags on the symmetric heap flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") kk = producer_kernel[grid]( @@ -172,6 +179,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): consumer_rank, args["block_size"], shmem.get_heap_bases(), + copy_engine_ctx, + USE_COPY_ENGINE=args["use_copy_engine"] ) else: shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") @@ -202,7 +211,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if success: shmem.info("Validation successful.") else: - shmem.info("Validation failed.") + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") shmem.barrier() @@ -219,7 +228,6 @@ def main(): # TODO local_rank torch.cuda.set_device(rank) - # Synchronize all processes comm.barrier() diff --git a/iris/iris.py b/iris/iris.py index b167aaaf5..8fb5188cf 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -43,9 +43,7 @@ ) import sys -# sys.path.append("/home/dasidler/iris/iris/experimental") -# import my_module as anvil -import iris.experimental.my_module as anvil +import anvil import numpy as np import math @@ -1719,7 +1717,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask=None, USE_COPY_ENGINE : tl.constexpr=False): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1747,9 +1745,65 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - data = tl.load(from_ptr, mask=mask) + if not USE_COPY_ENGINE: + data = tl.load(from_ptr, mask=mask) + + tl.store(translated_to_ptr, data, mask=mask) + else: + ctx = copy_engine_ctx + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=0) + + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = from_ptr.to(tl.uint64) + src_ptr_val = tl.min(src_ptr_u64, axis=0) + max_src_ptr = tl.max(src_ptr_u64, axis=0) + + # Infer element size from pointer type + # src_ptr is a block of pointers with a specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte sizes + ptr_dtype = from_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Calculate total size in bytes + # Count number of valid elements based on mask + mask_int = mask.to(tl.int32) + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + command_in_bytes = 28 + # Acquire space + base = anvil.acquire(cached_write_ptr, command_in_bytes) + + # Place command + slot_ptr_u32 = queue_ptr_u32 + (base // 4) + anvil.place_copy_packet(slot_ptr_u32, size_bytes, src_ptr_val, dst_ptr_val) + + # Submit command + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) - tl.store(translated_to_ptr, data, mask=mask) @triton.jit @@ -1877,63 +1931,17 @@ def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=Non num_elements = tl.sum(mask_int, axis=0) size_bytes = (num_elements * element_size_bytes).to(tl.uint32) - # data = tl.load(from_ptr, mask=mask) - # tl.store(translated_to_ptr, data, mask=mask) - command_in_bytes = 28 - base = tl.zeros((), dtype=tl.uint64) - # copy_size_in_bytes = tl.sum(mask.to(tl.int8)).to(tl.uint32) # Acquire space - run_loop = True - while run_loop: - cur_index = tl.load(cached_write_ptr) - new_index = cur_index + command_in_bytes - # Check if wrap around - # TODO - - # Check if full - # TODO - # expected = cur_index - if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem="acquire", scope="gpu") == cur_index: - base = tl.full((), cur_index, dtype=tl.uint64) - run_loop = False + base = anvil.acquire(cached_write_ptr, command_in_bytes) - # Place command packet + # Place command queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) slot_ptr_u32 = queue_ptr_u32 + (base // 4) - - # Convert to scalar value - # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) - - # offset 0: op + sub_op - tl.store(slot_ptr_u32 + 0, 1) - # offset 1: count - tl.store(slot_ptr_u32 + 1, size_bytes - 1) - # offset 2: parameters - tl.store(slot_ptr_u32 + 2, 0) - # offset 3: src address 31:0 - tl.store(slot_ptr_u32 + 3, src_ptr_val.to(tl.uint32)) - # offset 4: src address 63:32 - tl.store(slot_ptr_u32 + 4, (src_ptr_val >> 32).to(tl.uint32)) - # offset 5: dst address 31:0 - tl.store(slot_ptr_u32 + 5, dst_ptr_val.to(tl.uint32)) - # offset 6: dst address 63:32 - tl.store(slot_ptr_u32 + 6, (dst_ptr_val >> 32).to(tl.uint32)) + anvil.place_copy_packet(slot_ptr_u32, size_bytes, src_ptr_val, dst_ptr_val) # Submit command - while tl.load(committed_write_ptr) != base: - pass - - # tl.store(write_ptr, base + command_in_bytes) - tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') - - tl.debug_barrier() - - # Ring doorbell - # tl.store(doorbell_ptr, base + command_in_bytes) - tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem="release", scope="sys") - tl.debug_barrier() - tl.store(committed_write_ptr, base + command_in_bytes) + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) @triton.jit @@ -1976,66 +1984,20 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): dst_ptr_val = translated_to_ptr.to(tl.uint64) command_in_bytes = 32 - base = tl.zeros((), dtype=tl.uint64) - # copy_size_in_bytes = tl.sum(mask.to(tl.int8)).to(tl.uint32) # Acquire space - run_loop = True - while run_loop: - cur_index = tl.load(cached_write_ptr) - new_index = cur_index + command_in_bytes - # Check if wrap around - # TODO - - # Check if full - # TODO - # expected = cur_index - if tl.atomic_cas(cached_write_ptr, cur_index, new_index, sem="acquire", scope="gpu") == cur_index: - base = tl.full((), cur_index, dtype=tl.uint64) - run_loop = False - + base = anvil.acquire(cached_write_ptr, command_in_bytes) # Place command packet queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) slot_ptr_u32 = queue_ptr_u32 + (base // 4) - # print("queue_ptr: ", queue_ptr, " slot_ptr ", slot_ptr_u32, " base ", base) - - # Convert to scalar value - # from_ptr_as_u64 = tl.uint64(from_ptr) #tl.cast(from_ptr[0], tl.uint64) - - # offset 0: op + sub_op - # tl.store(slot_ptr_u32 + 0, ((0x2F & 0x7F) << 25 | (0xA & 0xFF)) # op: 10, operation: 47 atomicAdd64 - tl.store(slot_ptr_u32 + 0, ((0xF & 0x7F) << 25) | (0xA & 0xFF)) # op: 10, operation: 15 atomicAdd32 - # offset 1: dst address 31:0 - tl.store(slot_ptr_u32 + 1, dst_ptr_val.to(tl.uint32)) - # offset 2: dst address 63:32 - tl.store(slot_ptr_u32 + 2, (dst_ptr_val >> 32).to(tl.uint32)) - # offset 3: src data 31:0 - tl.store(slot_ptr_u32 + 3, 1) # increment by 1 - # offset 4: src data 63:32 - tl.store(slot_ptr_u32 + 4, 0) - # offset 5 - 7 unused - tl.store(slot_ptr_u32 + 5, 0) - tl.store(slot_ptr_u32 + 6, 0) - tl.store(slot_ptr_u32 + 7, 0) + anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) # Submit command - while tl.load(committed_write_ptr) != base: - pass - - # tl.store(write_ptr, base + command_in_bytes) - tl.atomic_xchg(write_ptr, base + command_in_bytes, sem='release', scope='gpu') - - tl.debug_barrier() - - # Ring doorbell - # tl.store(doorbell_ptr, base_val + command_in_bytes) - tl.atomic_xchg(doorbell_ptr, base + command_in_bytes, sem='release', scope='sys') - tl.debug_barrier() - tl.store(committed_write_ptr, base + command_in_bytes) + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, copy_engine_ctx=None, USE_COPY_ENGINE: tl.constexpr=False): """ Performs an atomic add at the specified rank's memory location. @@ -2067,7 +2029,31 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + if not USE_COPY_ENGINE: + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + else: + handle = copy_engine_ctx + queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + + dst_ptr_val = translated_ptr.to(tl.uint64) + + command_in_bytes = 32 + # Acquire space + base = anvil.acquire(cached_write_ptr, command_in_bytes) + + # Place command packet + slot_ptr_u32 = queue_ptr_u32 + (base // 4) + anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) + + # Submit command + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) + + @triton.jit diff --git a/setup.py b/setup.py index 69e4e9d4f..698324612 100644 --- a/setup.py +++ b/setup.py @@ -1,83 +1,11 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -import os -import subprocess -import sys -from pathlib import Path -from setuptools import setup, Extension -from setuptools.command.build_ext import build_ext - - -class CMakeExtension(Extension): - """Extension that uses CMake to build""" - - def __init__(self, name, sourcedir=""): - super().__init__(name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) - - -class CMakeBuild(build_ext): - """Custom build_ext command that runs CMake""" - - def run(self): - # Check if CMake is available - try: - subprocess.check_output(["cmake", "--version"]) - except OSError: - raise RuntimeError("CMake must be installed to build RDMA extensions") - - # Build each extension - for ext in self.extensions: - self.build_extension(ext) - - def build_extension(self, ext): - if not isinstance(ext, CMakeExtension): - return super().build_extension(ext) - - extdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute() - - # CMake configuration arguments - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - "-DCMAKE_CXX_COMPILER=/usr/bin/hipcc", - "-DCMAKE_BUILD_TYPE=Release", - ] - - # Build arguments - build_args = ["--config", "Release"] - - # Parallel build - if hasattr(os, "cpu_count"): - build_args += [f"-j{os.cpu_count()}"] - - # Create build directory - build_temp = Path(self.build_temp) / ext.name - build_temp.mkdir(parents=True, exist_ok=True) - - # Run CMake - subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) - - # Build - subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) - - -ext_modules = [] - -# TODO make optional -build_copy_engine_offload = True -if build_copy_engine_offload: - print("Building Copy Engine offload library") - copy_engine_ext = CMakeExtension("iris.experimental.anvil", sourcedir="ext/shader_sdma") - ext_modules.append(copy_engine_ext) - +from setuptools import setup # This setup.py provides backward compatibility for legacy metadata fields # that don't map directly from pyproject.toml's modern PEP 621 format. setup( url="https://rocm.github.io/iris/", author="Muhammad Awad, Muhammad Osama, Brandon Potter", - ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuild} if ext_modules else {}, ) From b8862cc7011cf0b067a530bb06c33df425a8c301 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 7 Nov 2025 17:18:22 -0600 Subject: [PATCH 06/29] update api calls --- .../message_passing_copy_engine.py | 2 +- iris/__init__.py | 2 - iris/iris.py | 94 +------------------ 3 files changed, 5 insertions(+), 93 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 563a2580b..a0ec7a472 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -39,7 +39,7 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put_ce(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask) + iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask, USE_COPY_ENGINE=True) # Set flag to signal completion iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr) diff --git a/iris/__init__.py b/iris/__init__.py index d992dca4f..58079dc53 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -59,7 +59,6 @@ atomic_and, atomic_min, atomic_max, - put_ce, signal_ce, ) @@ -100,7 +99,6 @@ "atomic_and", "atomic_min", "atomic_max", - "put_ce", "signal_ce", "do_bench", "hip", diff --git a/iris/iris.py b/iris/iris.py index 8fb5188cf..8095cea5a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1795,7 +1795,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= command_in_bytes = 28 # Acquire space - base = anvil.acquire(cached_write_ptr, command_in_bytes) + base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) # Place command slot_ptr_u32 = queue_ptr_u32 + (base // 4) @@ -1859,91 +1859,6 @@ def nontemporal_atomic_add(addr, value): -@triton.jit -def put_ce(from_ptr, to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): - """ - Copies data from the current rank's local memory to the specified rank's memory. - This function performs a memory write operation by loading data from the current - rank's `from_ptr`, translating the `to_ptr` from the current rank's address - space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. - - Args: - from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): - >>> from_rank = 0 - >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) - """ - - handle = ce_handle # iris.get_copy_engine_handle(to_rank) - queue_ptr = tl.load(handle + 0) # .to(tl.pointer_type(tl.uint64)) - read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) - write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) - doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) - cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) - committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) - - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=0) - - # Extract source address (min of pointer block where data is stored) - src_ptr_u64 = from_ptr.to(tl.uint64) - src_ptr_val = tl.min(src_ptr_u64, axis=0) - max_src_ptr = tl.max(src_ptr_u64, axis=0) - - # Infer element size from pointer type - # src_ptr is a block of pointers with a specific element type (e.g., pointer) - # The pointer dtype tells us the element type, which has a known size - # Map Triton dtypes to their byte sizes - ptr_dtype = from_ptr.dtype.element_ty # Get the element type that the pointer points to - - # Get element size in bytes from the dtype - # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. - if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: - element_size_bytes = 2 - elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: - element_size_bytes = 4 - elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: - element_size_bytes = 8 - elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: - element_size_bytes = 1 - elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: - element_size_bytes = 2 - else: - # Default to 4 bytes for unknown types - element_size_bytes = 4 - - # Calculate total size in bytes - # Count number of valid elements based on mask - mask_int = mask.to(tl.int32) - num_elements = tl.sum(mask_int, axis=0) - size_bytes = (num_elements * element_size_bytes).to(tl.uint32) - - command_in_bytes = 28 - # Acquire space - base = anvil.acquire(cached_write_ptr, command_in_bytes) - - # Place command - queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) - slot_ptr_u32 = queue_ptr_u32 + (base // 4) - anvil.place_copy_packet(slot_ptr_u32, size_bytes, src_ptr_val, dst_ptr_val) - - # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) - - @triton.jit def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): """ @@ -1973,7 +1888,7 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): """ handle = ce_handle # iris.get_copy_engine_handle(to_rank) - queue_ptr = tl.load(handle + 0) # .to(tl.pointer_type(tl.uint64)) + queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) @@ -1985,10 +1900,9 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): command_in_bytes = 32 # Acquire space - base = anvil.acquire(cached_write_ptr, command_in_bytes) + base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) # Place command packet - queue_ptr_u32 = queue_ptr.to(tl.pointer_type(tl.uint32)) slot_ptr_u32 = queue_ptr_u32 + (base // 4) anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) @@ -2044,7 +1958,7 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None command_in_bytes = 32 # Acquire space - base = anvil.acquire(cached_write_ptr, command_in_bytes) + base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) # Place command packet slot_ptr_u32 = queue_ptr_u32 + (base // 4) From 75c56263634939d357399dad674e725c8016b91b Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 7 Nov 2025 17:19:24 -0600 Subject: [PATCH 07/29] update submodule --- ext/shader_sdma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shader_sdma b/ext/shader_sdma index 243be5f30..d17774541 160000 --- a/ext/shader_sdma +++ b/ext/shader_sdma @@ -1 +1 @@ -Subproject commit 243be5f30f96374d4231cd669e179584e714435d +Subproject commit d177745411cd7c4bafb00493b1ea3bdd0390e39e From e3aef16a7cae0c24003bbc2ddf64d1234d6ff8d3 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 7 Nov 2025 17:23:06 -0600 Subject: [PATCH 08/29] fix merge --- examples/06_message_passing/message_passing_copy_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index ec7bd1795..2e96f131f 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -39,7 +39,7 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put_ce( + iris.put( source_buffer + offsets, target_buffer + offsets, producer_rank, From df04547c22734eb0db848c5e7c3fd423700f1418 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 7 Nov 2025 23:23:44 +0000 Subject: [PATCH 09/29] Apply Ruff auto-fixes --- .../message_passing_copy_engine.py | 2 +- .../06_message_passing/message_passing_put.py | 34 ++++++++++++++---- iris/iris.py | 36 +++++++++++++------ 3 files changed, 55 insertions(+), 17 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 2e96f131f..6f0acf13b 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -47,7 +47,7 @@ def producer_kernel( heap_bases_ptr, copy_engine_handle_ptr, mask=mask, - USE_COPY_ENGINE=True + USE_COPY_ENGINE=True, ) # Set flag to signal completion diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 66eedf989..42d38fe0e 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -40,11 +40,30 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask, USE_COPY_ENGINE=USE_COPY_ENGINE) + iris.put( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) # Set flag to signal completion # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys") - iris.atomic_add(flag + pid, 1, producer_rank, consumer_rank, heap_bases_ptr, sem='release', scope='sys', copy_engine_ctx=copy_engine_handle_ptr, USE_COPY_ENGINE=USE_COPY_ENGINE) + iris.atomic_add( + flag + pid, + 1, + producer_rank, + consumer_rank, + heap_bases_ptr, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_handle_ptr, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) @triton.jit @@ -121,8 +140,9 @@ def parse_args(): parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") - parser.add_argument("-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies") - + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -166,7 +186,9 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) # Get copy engine context - copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + copy_engine_ctx = ( + shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + ) if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") @@ -180,7 +202,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): args["block_size"], shmem.get_heap_bases(), copy_engine_ctx, - USE_COPY_ENGINE=args["use_copy_engine"] + USE_COPY_ENGINE=args["use_copy_engine"], ) else: shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") diff --git a/iris/iris.py b/iris/iris.py index e8bd2e68c..342795d6e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1717,7 +1717,9 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask=None, USE_COPY_ENGINE : tl.constexpr=False): +def put( + from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask=None, USE_COPY_ENGINE: tl.constexpr = False +): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1795,7 +1797,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= command_in_bytes = 28 # Acquire space - base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + base = anvil.acquire( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) # Place command slot_ptr_u32 = queue_ptr_u32 + (base // 4) @@ -1805,7 +1809,6 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) - @triton.jit def nontemporal_store(addr, value): tl.inline_asm_elementwise( @@ -1901,10 +1904,12 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): command_in_bytes = 32 # Acquire space - base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + base = anvil.acquire( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) # Place command packet - slot_ptr_u32 = queue_ptr_u32 + (base // 4) + slot_ptr_u32 = queue_ptr_u32 + (base // 4) anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) # Submit command @@ -1912,7 +1917,18 @@ def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, copy_engine_ctx=None, USE_COPY_ENGINE: tl.constexpr=False): +def atomic_add( + pointer, + val, + from_rank, + to_rank, + heap_bases, + mask=None, + sem=None, + scope=None, + copy_engine_ctx=None, + USE_COPY_ENGINE: tl.constexpr = False, +): """ Performs an atomic add at the specified rank's memory location. @@ -1959,18 +1975,18 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None command_in_bytes = 32 # Acquire space - base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + base = anvil.acquire( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) # Place command packet - slot_ptr_u32 = queue_ptr_u32 + (base // 4) + slot_ptr_u32 = queue_ptr_u32 + (base // 4) anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) # Submit command anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) - - @triton.jit def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): """ From c5e4735158c7992beb954baa8160c999a0f24e49 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 5 Dec 2025 12:38:35 -0600 Subject: [PATCH 10/29] wip fixed wrap into ring when placing --- .../06_message_passing/message_passing_put.py | 5 +- .../benchmark.py | 66 +++-- .../gemm_all_scatter_wg_specialization.py | 37 +++ .../matmul_wrapper.py | 15 ++ iris/__init__.py | 2 - iris/iris.py | 243 +++++++++++------- 6 files changed, 243 insertions(+), 125 deletions(-) diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 66eedf989..12750e4b7 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -43,7 +43,7 @@ def producer_kernel( iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, mask=mask, USE_COPY_ENGINE=USE_COPY_ENGINE) # Set flag to signal completion - # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys") + # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") iris.atomic_add(flag + pid, 1, producer_rank, consumer_rank, heap_bases_ptr, sem='release', scope='sys', copy_engine_ctx=copy_engine_handle_ptr, USE_COPY_ENGINE=USE_COPY_ENGINE) @@ -166,7 +166,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) # Get copy engine context - copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + # copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + copy_engine_ctx = shmem.get_copy_engine_ctx() if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 59d145651..7db53545e 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -67,6 +67,9 @@ def parse_args(): help="Number of total SMs for gemm + scatter kernel (default: auto-detected)", ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -87,7 +90,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() # Set default SM values if not provided - cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count if args["num_sms"] is None: args["num_sms"] = cu_count if args["gemm_sms"] is None: @@ -130,13 +133,23 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): json_writer.add_field(key, value) global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) - local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + # why on heap? + local_C = torch.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + # TODO why is this on the heap? + locks = torch.zeros((total_tiles,), device="cuda", dtype=torch.int8) #why int8?? + comm_sms = args["num_sms"] - args["gemm_sms"] + flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32) + + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() # if args["use_copy_engine"] else None + # ( + # shmem.get_copy_engine_handle() if args["use_copy_engine"] and cur_rank == producer_rank else None + # ) bias = None @@ -179,6 +192,7 @@ def run_experiment(): global_C, bias, locks, + flags, rank, world_size, args["gemm_sms"], @@ -190,6 +204,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) @@ -227,10 +243,12 @@ def run_experiment(): # Wait for all to finish validation shmem.barrier() - shmem.info("Validating local C...") + # shmem.info("Validating local C...") json_writer.add_field("success", success) + # shmem.info(flags) + if not is_triton_interpret_set(): gemm_registers = matmul.get_matmul_registers() gemm_spills = matmul.get_matmul_spills() @@ -240,26 +258,26 @@ def run_experiment(): shmem.info("Validation completed") - if args["benchmark"]: - matmul.set_debug(False) - shmem.info("Benchmarking...") - perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier) - triton_tflops = perf(triton_ms) - algo_string = "all_scatter" - shmem.info( - f"tile matmul + {algo_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" - ) - - json_writer.add_field("tflops", triton_tflops) - json_writer.add_field("total_ms", triton_ms) - - for k in ["gemm"]: - json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) - json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) - - # Wait for all to finish benchmarking - shmem.barrier() + # if args["benchmark"]: + # matmul.set_debug(False) + # shmem.info("Benchmarking...") + # perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + # triton_ms = iris.do_bench(run_experiment, shmem.barrier) + # triton_tflops = perf(triton_ms) + # algo_string = "all_scatter" + # shmem.info( + # f"tile matmul + {algo_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" + # ) + + # json_writer.add_field("tflops", triton_tflops) + # json_writer.add_field("total_ms", triton_ms) + + # for k in ["gemm"]: + # json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + # json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # # Wait for all to finish benchmarking + # shmem.barrier() if rank == 0: json_writer.flush() diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ac2d2e353..96e2370a1 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -19,6 +19,7 @@ def persistent_gemm_all_scatter_wg_specialization( c_global, bias_ptr, locks, + flags, M, N, K, @@ -44,6 +45,8 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, mm_end_timestamp_ptr: tl.tensor = None, ): @@ -69,6 +72,9 @@ def persistent_gemm_all_scatter_wg_specialization( # and another that performs the communication. Uses persistent- # kernel. if pid < GEMM_SMS: + # tl.device_print("GEMM_SMS: ", GEMM_SMS) + # tl.device_print("GEMM pid: ", pid) + for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -148,6 +154,8 @@ def persistent_gemm_all_scatter_wg_specialization( else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS + # tl.device_print("COMM_SMS: ", COMM_SMS) + # tl.device_print("COMM pid: ", pid) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group @@ -176,5 +184,34 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank, remote_rank, heap_bases, + copy_engine_ctx, + stride_cm, + stride_cn, + stride_cm_global, + stride_cn_global, + BLOCK_SIZE_M, + BLOCK_SIZE_N, mask=sub_mask, + USE_COPY_ENGINE=USE_COPY_ENGINE ) + tl.debug_barrier() + # Signal other ranks + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass + diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index ce1865618..5896ea063 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -47,6 +47,7 @@ def _call( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -58,6 +59,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -82,6 +85,9 @@ def _call( even_k = K % BLK_K == 0 use_bias = False + print("C: ", c.stride(0), " ", c.stride(1), " global ", c_global.stride(0), " ", c_global.stride(1)) + print("BLK_M ", BLK_M, " BLK_N ", BLK_N, " BLK_K ", BLK_K, " even k ", even_k, " total_tiles ", total_tiles) + # compute grid (work to do per SM on the first wave) stride_bias = bias.stride(0) if use_bias else 0 kk = gemm_kernel[(num_sms,)]( @@ -91,6 +97,7 @@ def _call( c_global, bias, locks, + flags, M, N, K, @@ -121,6 +128,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -140,6 +149,7 @@ def forward( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -151,6 +161,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -161,6 +173,7 @@ def forward( c_global=c_global, bias=bias, locks=locks, + flags=flags, rank=rank, world_size=world_size, gemm_sms=gemm_sms, @@ -172,6 +185,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) diff --git a/iris/__init__.py b/iris/__init__.py index 58079dc53..2b048d03a 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -59,7 +59,6 @@ atomic_and, atomic_min, atomic_max, - signal_ce, ) from .util import ( @@ -99,7 +98,6 @@ "atomic_and", "atomic_min", "atomic_max", - "signal_ce", "do_bench", "hip", "experimental", # Experimental features including iris_gluon diff --git a/iris/iris.py b/iris/iris.py index e8bd2e68c..9c43d9660 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -129,9 +129,30 @@ def __init__(self, heap_size=1 << 30): self.copy_engines.init() # connect to all peers + # TODO only connect local ranks + # TODO get size + context_size = 6 + self.copy_engines_device_ctx = torch.zeros((num_ranks, context_size), dtype=torch.uint64, device=self.device) + for rank in range(num_ranks): if rank != cur_rank: self.copy_engines.connect(cur_rank, rank, 1) + queue = self.copy_engines.get_sdma_queue(cur_rank, rank, 0) + handle = queue.device_ctx() + self.info(f"---- Queue {rank} ------------") + self.info(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") + self.info(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") + self.info(f"wptr {handle.wptr:#x} at {id(handle.wptr):#x}") + self.info(f"doorbell {handle.doorbell:#x} at {id(handle.doorbell):#x}") + self.info(f"cached_write_ptr {handle.cached_wptr:#x} at {id(handle.cached_wptr):#x}") + self.info(f"committed_write_ptr {handle.committed_wptr:#x} at {id(handle.committed_wptr):#x}") + + self.copy_engines_device_ctx[rank][0] = handle.queue_buf + self.copy_engines_device_ctx[rank][1] = handle.rptr + self.copy_engines_device_ctx[rank][2] = handle.wptr + self.copy_engines_device_ctx[rank][3] = handle.doorbell + self.copy_engines_device_ctx[rank][4] = handle.cached_wptr + self.copy_engines_device_ctx[rank][5] = handle.committed_wptr def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" @@ -1151,6 +1172,9 @@ def get_heap_bases(self): """ return self.heap_bases + def get_copy_engine_ctx(self): + return self.copy_engines_device_ctx + def barrier(self, stream=None): """ Synchronize all ranks and their CUDA devices. @@ -1230,34 +1254,6 @@ def get_num_ranks(self): """ return self.num_ranks - def get_copy_engine_handle(self, to_rank): - # TODO remove last arg - queue = self.copy_engines.get_sdma_queue(self.get_rank(), to_rank, 0) - # Wrap into numpy array - handle = queue.device_ctx() - self.info("---- Queue ------------") - # print(f"handle at {id(handle):#x}") - self.info(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") - self.info(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") - self.info(f"wptr {handle.wptr:#x} at {id(handle.wptr):#x}") - self.info(f"doorbell {handle.doorbell:#x} at {id(handle.doorbell):#x}") - self.info(f"cached_write_ptr {handle.cached_wptr:#x} at {id(handle.cached_wptr):#x}") - self.info(f"committed_write_ptr {handle.committed_wptr:#x} at {id(handle.committed_wptr):#x}") - - # TODO get size - # array = np.ctypeslib.as_array(ctypes.cast(handle, ctypes.POINTER(ctypes.c_uint64)), shape=(7, )) - context_size = 6 - device_ctx = torch.zeros(context_size, dtype=torch.uint64, device=self.device) - device_ctx[0] = handle.queue_buf - device_ctx[1] = handle.rptr - device_ctx[2] = handle.wptr - device_ctx[3] = handle.doorbell - device_ctx[4] = handle.cached_wptr - device_ctx[5] = handle.committed_wptr - # context[6] = handle. - - return device_ctx # anvil.get_handle_as_tensor(queue) # torch.from_numpy(array) #.to(device='cuda') - def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): if not self.__tensor_on_device(tensor): raise RuntimeError( @@ -1717,7 +1713,16 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask=None, USE_COPY_ENGINE : tl.constexpr=False): +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, + copy_engine_ctx: tl.tensor, + stride_tm, + stride_tn, + stride_fm, + stride_fn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + mask=None, + USE_COPY_ENGINE: tl.constexpr=False): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1750,7 +1755,7 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= tl.store(translated_to_ptr, data, mask=mask) else: - ctx = copy_engine_ctx + ctx = copy_engine_ctx + (6 * to_rank) queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) @@ -1758,12 +1763,14 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) - dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=0) - + # dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=-1) + dst_ptr_val0 = tl.min(translated_to_ptr.to(tl.uint64)) # Extract source address (min of pointer block where data is stored) src_ptr_u64 = from_ptr.to(tl.uint64) - src_ptr_val = tl.min(src_ptr_u64, axis=0) - max_src_ptr = tl.max(src_ptr_u64, axis=0) + # src_ptr_val = tl.min(src_ptr_u64, axis=-1) + src_ptr_val0 = tl.min(src_ptr_u64) + # max_src_ptr = tl.max(src_ptr_u64, axis=0) + # Infer element size from pointer type # src_ptr is a block of pointers with a specific element type (e.g., pointer) @@ -1789,20 +1796,62 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, mask= # Calculate total size in bytes # Count number of valid elements based on mask + # src stride: 9216 + # dst strice: 9216 mask_int = mask.to(tl.int32) - num_elements = tl.sum(mask_int, axis=0) - size_bytes = (num_elements * element_size_bytes).to(tl.uint32) - - command_in_bytes = 28 - # Acquire space - base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) + num_strides = tl.max(tl.sum(mask_int, axis=0)) + size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) + src_stride = (stride_fm * element_size_bytes).to(tl.uint32) + dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) + + # if tl.program_id(axis=0) == 230 and from_rank == 1 and to_rank == 0: + # if from_rank == 1 and to_rank == 0: + # if to_rank == 1: + # if tl.max(size_bytes) == 0: + # tl.device_print("from_ptr ", from_ptr.block_shape) + # tl.device_print("stride_tm ", stride_tm) + # tl.device_print("stride_tn ", stride_tn) + # tl.device_print("stride_fm ", stride_fm) + # tl.device_print("stride_fn ", stride_fn) + # tl.device_print("src_stride ", src_stride) + # tl.device_print("dst_stride ", dst_stride) + + # tl.device_print("queue_ptr_u32 ", queue_ptr_u32) + # tl.device_print("dst_ptr_val (all) ", translated_to_ptr.to(tl.uint64)) + # tl.device_print("dst_ptr_val ", dst_ptr_val) + # tl.device_print("dst_ptr_val (single) ", dst_ptr_val0) + # tl.device_print("src_ptr_u64", src_ptr_u64) + # tl.device_print("src_ptr_val ", src_ptr_val) + # tl.device_print("src_ptr_val (single) ", src_ptr_val0) + # tl.device_print("mask(axis=0): ", tl.sum(mask_int, axis=0)) + # tl.device_print("mask: ", tl.sum(mask_int)) + # tl.device_print("num strides: ", num_strides) + # tl.device_print("size_bytes per stride", size_bytes) - # Place command - slot_ptr_u32 = queue_ptr_u32 + (base // 4) - anvil.place_copy_packet(slot_ptr_u32, size_bytes, src_ptr_val, dst_ptr_val) - # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) + command_in_bytes = 28 + # TODO wrap-around seems broken + # Overwrite here + num_strides = 8 + required_bytes = command_in_bytes * num_strides + # queue_offsets = (command_in_bytes // 4) * tl.arange(0, num_strides) + # if tl.program_id(axis=0) == 23 and to_rank == 0: + if to_rank == 7: + # if tl.program_id(axis=0) == 230: + # tl.device_print("required_bytes", required_bytes) + # Acquire space + base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes) + + # Place command + for stride in range(0, num_strides): + # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + (stride * 7) + offset_bytes = base + (stride * command_in_bytes) + anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0 + (src_stride * stride), dst_ptr_val0 + (dst_stride * stride)) + # anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) + + # Submit command + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) @@ -1860,55 +1909,55 @@ def nontemporal_atomic_add(addr, value): # return True # TODO if old == cmp else False -@triton.jit -def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): - """ - Copies data from the current rank's local memory to the specified rank's memory. - This function performs a memory write operation by loading data from the current - rank's `from_ptr`, translating the `to_ptr` from the current rank's address - space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. - - Args: - from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. - from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. - heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - - Returns: - None - - Example: - >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): - >>> from_rank = 0 - >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) - """ - - handle = ce_handle # iris.get_copy_engine_handle(to_rank) - queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) - read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) - write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) - doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) - cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) - committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) - - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - dst_ptr_val = translated_to_ptr.to(tl.uint64) - - command_in_bytes = 32 - # Acquire space - base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) - - # Place command packet - slot_ptr_u32 = queue_ptr_u32 + (base // 4) - anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) - - # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) +# @triton.jit +# def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): +# """ +# Copies data from the current rank's local memory to the specified rank's memory. +# This function performs a memory write operation by loading data from the current +# rank's `from_ptr`, translating the `to_ptr` from the current rank's address +# space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. +# If the `to_rank` is the same as the current rank, this function performs a local copy operation. + +# Args: +# from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. +# to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. +# from_rank (int): The current rank ID from which to read the data. +# to_rank (int): The `to_rank` ID to which the data will be written. +# heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. +# mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + +# Returns: +# None + +# Example: +# >>> @triton.jit +# >>> def kernel(local_ptr, remote_ptr, heap_bases): +# >>> from_rank = 0 +# >>> to_rank = 1 +# >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) +# """ + +# handle = ce_handle # iris.get_copy_engine_handle(to_rank) +# queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) +# read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) +# write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) +# doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) +# cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) +# committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + +# translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) +# dst_ptr_val = translated_to_ptr.to(tl.uint64) + +# command_in_bytes = 32 +# # Acquire space +# base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + +# # Place command packet +# slot_ptr_u32 = queue_ptr_u32 + (base // 4) +# anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) + +# # Submit command +# anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) @triton.jit @@ -1947,7 +1996,7 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None if not USE_COPY_ENGINE: return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) else: - handle = copy_engine_ctx + handle = copy_engine_ctx + (6 * to_rank) queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) @@ -1962,8 +2011,8 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) # Place command packet - slot_ptr_u32 = queue_ptr_u32 + (base // 4) - anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) + # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + anvil.place_atomic_packet(queue_ptr_u32, base, dst_ptr_val) # Submit command anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) From 536231885429b0d1f7b95989d5e26adc7d82306b Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 5 Dec 2025 13:26:35 -0600 Subject: [PATCH 11/29] to_rank 7 working --- .../benchmark.py | 7 ++++--- .../gemm_all_scatter_wg_specialization.py | 14 ++++++++++---- iris/iris.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index f038caf16..88e1fee0b 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -133,15 +133,16 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): json_writer.add_field(key, value) global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) - # why on heap? - local_C = torch.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + # why on heap? so it is uncached?? + # TODO unused + local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N # TODO why is this on the heap? - locks = torch.zeros((total_tiles,), device="cuda", dtype=torch.int8) # why int8?? + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) # why int8?? comm_sms = args["num_sms"] - args["gemm_sms"] flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index c0faeebd0..dacd86b3c 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -10,6 +10,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise( + "s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1 + ) @triton.jit() def persistent_gemm_all_scatter_wg_specialization( @@ -27,8 +32,8 @@ def persistent_gemm_all_scatter_wg_specialization( stride_ak, stride_bk, stride_bn, - stride_cm, - stride_cn, + stride_cm, # unused + stride_cn, # unused stride_cm_global, stride_cn_global, stride_bias, @@ -148,6 +153,7 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() tl.debug_barrier() tl.store(locks + tile_id, 1, cache_modifier=".wt") @@ -185,8 +191,8 @@ def persistent_gemm_all_scatter_wg_specialization( remote_rank, heap_bases, copy_engine_ctx, - stride_cm, - stride_cn, + stride_cm_global, + stride_cn_global, stride_cm_global, stride_cn_global, BLOCK_SIZE_M, diff --git a/iris/iris.py b/iris/iris.py index 82ad2b386..489fe5876 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1837,7 +1837,7 @@ def put( command_in_bytes = 28 # TODO wrap-around seems broken # Overwrite here - num_strides = 8 + # num_strides = 8 required_bytes = command_in_bytes * num_strides # queue_offsets = (command_in_bytes // 4) * tl.arange(0, num_strides) # if tl.program_id(axis=0) == 23 and to_rank == 0: From a6b1d40b71676addde55f227b81809f54854171c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 10 Dec 2025 18:00:49 +0000 Subject: [PATCH 12/29] Apply Ruff auto-fixes --- .../gemm_all_scatter_wg_specialization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index dacd86b3c..1a3d9f8c0 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -10,11 +10,11 @@ import iris + @triton.jit def wait_cnt(): - tl.inline_asm_elementwise( - "s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1 - ) + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + @triton.jit() def persistent_gemm_all_scatter_wg_specialization( @@ -32,8 +32,8 @@ def persistent_gemm_all_scatter_wg_specialization( stride_ak, stride_bk, stride_bn, - stride_cm, # unused - stride_cn, # unused + stride_cm, # unused + stride_cn, # unused stride_cm_global, stride_cn_global, stride_bias, From 400b5b765cf2062406ce4d5d80e8a56d357bbb39 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 14 Jan 2026 17:26:42 -0600 Subject: [PATCH 13/29] use triton commit with fix --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 6524f82e3..9d7698014 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -31,7 +31,7 @@ RUN pip3 install --upgrade pip && \ # Clone and install Triton WORKDIR $TRITON_PATH RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH -RUN git checkout 715f6b1d442601436bf8d462db6ff8e17aec8cfb +RUN git checkout 32d63ac67c7cc44715ab2428ffa182f606efa012 RUN pip3 install -e . ENV PYTHONPATH=$TRITON_PATH From d06cb72b9476e8652a53c21473a6035eb641192e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 Jan 2026 23:28:01 +0000 Subject: [PATCH 14/29] Apply Ruff auto-fixes --- examples/06_message_passing/message_passing_copy_engine.py | 3 --- examples/06_message_passing/message_passing_put.py | 1 - iris/iris.py | 1 - 3 files changed, 5 deletions(-) diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py index 6f0acf13b..677b68465 100644 --- a/examples/06_message_passing/message_passing_copy_engine.py +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -5,12 +5,9 @@ import torch import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random -import os -import sys from mpi4py import MPI diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 1470d0fcb..b041ac200 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -5,7 +5,6 @@ import torch import torch.distributed as dist -import torch.multiprocessing as mp import triton import triton.language as tl import random diff --git a/iris/iris.py b/iris/iris.py index 37fb39067..1dea5191c 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -40,7 +40,6 @@ open_ipc_handle, get_ipc_handle_size, ) -import sys import anvil From b2e358b549991347067a8c9d44eed0a4ef3dde78 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Tue, 20 Jan 2026 15:24:47 -0600 Subject: [PATCH 15/29] send to all ranks but always same stride --- iris/iris.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index 37fb39067..9fa89001f 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -2028,29 +2028,29 @@ def put( required_bytes = command_in_bytes * num_strides # queue_offsets = (command_in_bytes // 4) * tl.arange(0, num_strides) # if tl.program_id(axis=0) == 23 and to_rank == 0: - if to_rank == 7: - # if tl.program_id(axis=0) == 230: - # tl.device_print("required_bytes", required_bytes) - # Acquire space - base = anvil.acquire( - queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes - ) + # if to_rank == 7: + # if tl.program_id(axis=0) == 230: + # tl.device_print("required_bytes", required_bytes) + # Acquire space + base = anvil.acquire( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes + ) + + # Place command + for stride in range(0, num_strides): + # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + (stride * 7) + offset_bytes = base + (stride * command_in_bytes) + # anvil.place_copy_packet( + # queue_ptr_u32, + # offset_bytes, + # size_bytes, + # src_ptr_val0 + (src_stride * stride), + # dst_ptr_val0 + (dst_stride * stride), + # ) + anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) - # Place command - for stride in range(0, num_strides): - # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + (stride * 7) - offset_bytes = base + (stride * command_in_bytes) - anvil.place_copy_packet( - queue_ptr_u32, - offset_bytes, - size_bytes, - src_ptr_val0 + (src_stride * stride), - dst_ptr_val0 + (dst_stride * stride), - ) - # anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) - - # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) + # Submit command + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) @triton.jit From b245899b59dff395d4cb52614cc3f09b50919b22 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Tue, 20 Jan 2026 15:24:58 -0600 Subject: [PATCH 16/29] update submodule --- ext/shader_sdma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shader_sdma b/ext/shader_sdma index d17774541..3494f0284 160000 --- a/ext/shader_sdma +++ b/ext/shader_sdma @@ -1 +1 @@ -Subproject commit d177745411cd7c4bafb00493b1ea3bdd0390e39e +Subproject commit 3494f0284f436fe17c6855fab5f8aee67740999e From 1ee4c5865aae54ac896462469b86583af0f636ef Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 30 Jan 2026 16:52:09 -0600 Subject: [PATCH 17/29] use 32B copy packets workaround --- iris/iris.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index fe8c94548..2010b42e4 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -2020,7 +2020,9 @@ def put( # tl.device_print("num strides: ", num_strides) # tl.device_print("size_bytes per stride", size_bytes) - command_in_bytes = 28 + # workaround to avoid padding + command_in_bytes_u32 = 32 + command_in_bytes = command_in_bytes_u32.to(tl.uint64) # TODO wrap-around seems broken # Overwrite here # num_strides = 8 @@ -2039,14 +2041,14 @@ def put( for stride in range(0, num_strides): # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + (stride * 7) offset_bytes = base + (stride * command_in_bytes) - # anvil.place_copy_packet( - # queue_ptr_u32, - # offset_bytes, - # size_bytes, - # src_ptr_val0 + (src_stride * stride), - # dst_ptr_val0 + (dst_stride * stride), - # ) - anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) + anvil.place_copy_packet( + queue_ptr_u32, + offset_bytes, + size_bytes, + src_ptr_val0 + (src_stride * stride), + dst_ptr_val0 + (dst_stride * stride), + ) + # anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) # Submit command anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) From 1c384c3613d1f916cf201c3635cc039827e1a409 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Fri, 30 Jan 2026 16:54:28 -0600 Subject: [PATCH 18/29] submodule update --- ext/shader_sdma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shader_sdma b/ext/shader_sdma index 3494f0284..91c66569f 160000 --- a/ext/shader_sdma +++ b/ext/shader_sdma @@ -1 +1 @@ -Subproject commit 3494f0284f436fe17c6855fab5f8aee67740999e +Subproject commit 91c66569f035e4fbb04ce92cb51f4dfa0ddef5eb From 0224866db0cbf8eba40d0fd7af96f64721b3ab84 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 4 Mar 2026 16:38:35 -0600 Subject: [PATCH 19/29] use window command --- .../gemm_all_scatter_wg_specialization.py | 39 +++--- iris/iris.py | 113 +++++++++++++----- 2 files changed, 106 insertions(+), 46 deletions(-) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 7cac48db0..ea9fc75c5 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -182,26 +182,31 @@ def persistent_gemm_all_scatter_wg_specialization( for remote_rank in range(world_size): if remote_rank != cur_rank: - iris.put( - c_global + global_offset, - c_global + global_offset, - cur_rank, - remote_rank, - heap_bases, - copy_engine_ctx, - stride_cm_global, - stride_cn_global, - stride_cm_global, - stride_cn_global, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - mask=sub_mask, - USE_COPY_ENGINE=USE_COPY_ENGINE, - ) + if tile_id < 250: + # tl.device_print("tile_id", tile_id) + iris.put( + c_global + global_offset, + c_global + global_offset, + cur_rank, + remote_rank, + heap_bases, + copy_engine_ctx, + stride_cm_global, + stride_cn_global, + stride_cm_global, + stride_cn_global, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + mask=sub_mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + from_base_ptr=c_global, + to_base_ptr=c_global, + ) tl.debug_barrier() # Signal other ranks for remote_rank in range(world_size): if remote_rank != cur_rank: + # print("Issue atomic_add") iris.atomic_add( flags + (pid * world_size) + cur_rank, 1, @@ -213,8 +218,10 @@ def persistent_gemm_all_scatter_wg_specialization( copy_engine_ctx=copy_engine_ctx, USE_COPY_ENGINE=USE_COPY_ENGINE, ) + # print("Start waiting") # Wait for other ranks to signal us for remote_rank in range(world_size): if remote_rank != cur_rank: while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: pass + # print("done waiting") diff --git a/iris/iris.py b/iris/iris.py index 2010b42e4..f616003d3 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1910,10 +1910,12 @@ def put( stride_tn, stride_fm, stride_fn, - BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, #TODO Needed?? BLOCK_SIZE_N: tl.constexpr, mask=None, USE_COPY_ENGINE: tl.constexpr = False, + from_base_ptr=None, + to_base_ptr=None, ): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -1928,17 +1930,30 @@ def put( from_rank (int): The current rank ID from which to read the data. to_rank (int): The `to_rank` ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + copy_engine_ctx (tl.tensor): Copy engine context for SDMA operations. + stride_tm (int): Stride in M dimension for destination buffer. + stride_tn (int): Stride in N dimension for destination buffer. + stride_fm (int): Stride in M dimension for source buffer. + stride_fn (int): Stride in N dimension for source buffer. + BLOCK_SIZE_M (tl.constexpr): Block size in M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size in N dimension. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + USE_COPY_ENGINE (tl.constexpr, optional): Whether to use SDMA copy engine. Defaults to False. + from_base_ptr (triton.PointerType, optional): Base pointer of the source buffer. Required when USE_COPY_ENGINE is True. + to_base_ptr (triton.PointerType, optional): Base pointer of the destination buffer. Required when USE_COPY_ENGINE is True. Returns: None Example: >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx): >>> from_rank = 0 >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases, + >>> copy_engine_ctx, stride_m, stride_n, stride_m, stride_n, + >>> BLOCK_M, BLOCK_N, mask=None, USE_COPY_ENGINE=True, + >>> from_base_ptr=base_ptr, to_base_ptr=base_ptr) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) @@ -2020,35 +2035,61 @@ def put( # tl.device_print("num strides: ", num_strides) # tl.device_print("size_bytes per stride", size_bytes) - # workaround to avoid padding - command_in_bytes_u32 = 32 + # Use sub-window copy packet to copy entire tile with a single packet + # Sub-window copy packet is 80 bytes (20 DWORDs) + # Requires from_base_ptr and to_base_ptr to be provided + command_in_bytes_u32 = 80 command_in_bytes = command_in_bytes_u32.to(tl.uint64) - # TODO wrap-around seems broken - # Overwrite here - # num_strides = 8 - required_bytes = command_in_bytes * num_strides - # queue_offsets = (command_in_bytes // 4) * tl.arange(0, num_strides) - # if tl.program_id(axis=0) == 23 and to_rank == 0: - # if to_rank == 7: - # if tl.program_id(axis=0) == 230: - # tl.device_print("required_bytes", required_bytes) - # Acquire space + required_bytes = command_in_bytes + + # Calculate base addresses and offsets for sub-window copy + src_base = from_base_ptr.to(tl.uint64) + dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) + + # Calculate tile offset from base + # offset_bytes = src_ptr_val0 - src_base + # For 2D: offset_bytes = (y * stride_bytes) + (x * element_size_bytes) + # Decompose into x and y offsets + tile_offset_bytes = src_ptr_val0 - src_base + src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) + src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) + + tile_offset_bytes_dst = dst_ptr_val0 - dst_base + dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) + dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) + + # Acquire space (returns base index and wraparound offset) base = anvil.acquire( queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes ) - # Place command - for stride in range(0, num_strides): - # slot_ptr_u32 = queue_ptr_u32 + (base // 4) + (stride * 7) - offset_bytes = base + (stride * command_in_bytes) - anvil.place_copy_packet( - queue_ptr_u32, - offset_bytes, - size_bytes, - src_ptr_val0 + (src_stride * stride), - dst_ptr_val0 + (dst_stride * stride), - ) - # anvil.place_copy_packet(queue_ptr_u32, offset_bytes, size_bytes, src_ptr_val0, dst_ptr_val0) + # Write padding NOPs if we wrapped around + # TODO move + # if offset > 0: + # num_offset_dwords = (offset // 4).to(tl.int32) + # base_ring_pos = anvil.wrap_into_ring(base) + # base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) + # for i in range(num_offset_dwords): + # tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + + # Calculate packet position (base + offset for wraparound) + packet_offset_bytes = base #+ offset + + # Place single sub-window copy command for entire tile + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_base, + dst_base, + tile_width=size_bytes, + tile_height=num_strides, + src_buffer_pitch=src_stride, + dst_buffer_pitch=dst_stride, + src_x=src_x_val, + src_y=src_y_val, + dst_x=dst_x_val, + dst_y=dst_y_val, + ) # Submit command anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) @@ -2217,14 +2258,26 @@ def atomic_add( dst_ptr_val = translated_ptr.to(tl.uint64) command_in_bytes = 32 - # Acquire space + # Acquire space (returns base index and wraparound offset) + # base, offset = anvil.acquire( base = anvil.acquire( queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes ) + # tl.device_print("offset ", offset) + + # Write padding NOPs if we wrapped around + # if offset > 0: + # num_offset_dwords = (offset // 4).to(tl.int32) + # base_ring_pos = anvil.wrap_into_ring(base) + # base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) + # for i in range(num_offset_dwords): + # tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + + # Calculate packet position (base + offset for wraparound) + packet_offset_bytes = base # + offset # Place command packet - # slot_ptr_u32 = queue_ptr_u32 + (base // 4) - anvil.place_atomic_packet(queue_ptr_u32, base, dst_ptr_val) + anvil.place_atomic_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val) # Submit command anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) From 40c228a2527748b8eafbb800093021440671fd9e Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 4 Mar 2026 23:46:28 -0600 Subject: [PATCH 20/29] use new acquire function --- .../gemm_all_scatter_wg_specialization.py | 38 ++++++------ iris/iris.py | 58 +++++++++---------- 2 files changed, 44 insertions(+), 52 deletions(-) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index ea9fc75c5..82652c8b9 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -182,26 +182,24 @@ def persistent_gemm_all_scatter_wg_specialization( for remote_rank in range(world_size): if remote_rank != cur_rank: - if tile_id < 250: - # tl.device_print("tile_id", tile_id) - iris.put( - c_global + global_offset, - c_global + global_offset, - cur_rank, - remote_rank, - heap_bases, - copy_engine_ctx, - stride_cm_global, - stride_cn_global, - stride_cm_global, - stride_cn_global, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - mask=sub_mask, - USE_COPY_ENGINE=USE_COPY_ENGINE, - from_base_ptr=c_global, - to_base_ptr=c_global, - ) + iris.put( + c_global + global_offset, + c_global + global_offset, + cur_rank, + remote_rank, + heap_bases, + copy_engine_ctx, + stride_cm_global, + stride_cn_global, + stride_cm_global, + stride_cn_global, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + mask=sub_mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + from_base_ptr=c_global, + to_base_ptr=c_global, + ) tl.debug_barrier() # Signal other ranks for remote_rank in range(world_size): diff --git a/iris/iris.py b/iris/iris.py index f616003d3..f10260ee2 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -2035,9 +2035,7 @@ def put( # tl.device_print("num strides: ", num_strides) # tl.device_print("size_bytes per stride", size_bytes) - # Use sub-window copy packet to copy entire tile with a single packet - # Sub-window copy packet is 80 bytes (20 DWORDs) - # Requires from_base_ptr and to_base_ptr to be provided + # Use sub-window copy packet (single packet for entire tile) command_in_bytes_u32 = 80 command_in_bytes = command_in_bytes_u32.to(tl.uint64) required_bytes = command_in_bytes @@ -2047,9 +2045,6 @@ def put( dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) # Calculate tile offset from base - # offset_bytes = src_ptr_val0 - src_base - # For 2D: offset_bytes = (y * stride_bytes) + (x * element_size_bytes) - # Decompose into x and y offsets tile_offset_bytes = src_ptr_val0 - src_base src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) @@ -2058,24 +2053,21 @@ def put( dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) - # Acquire space (returns base index and wraparound offset) - base = anvil.acquire( + # Acquire space + base, offset = anvil.acquire( queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes ) # Write padding NOPs if we wrapped around - # TODO move - # if offset > 0: - # num_offset_dwords = (offset // 4).to(tl.int32) - # base_ring_pos = anvil.wrap_into_ring(base) - # base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) - # for i in range(num_offset_dwords): - # tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) - - # Calculate packet position (base + offset for wraparound) - packet_offset_bytes = base #+ offset - - # Place single sub-window copy command for entire tile + if offset > 0: + num_offset_dwords = (offset // 4).to(tl.int32) + base_ring_pos = anvil.wrap_into_ring(base) + base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) + for i in range(num_offset_dwords): + tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + + # Place single sub-window copy packet + packet_offset_bytes = base + offset anvil.place_sub_window_copy_packet( queue_ptr_u32, packet_offset_bytes, @@ -2091,8 +2083,9 @@ def put( dst_y=dst_y_val, ) - # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, required_bytes) + # Submit + pending_wptr = base + offset + required_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit @@ -2259,28 +2252,29 @@ def atomic_add( command_in_bytes = 32 # Acquire space (returns base index and wraparound offset) - # base, offset = anvil.acquire( - base = anvil.acquire( + base, offset = anvil.acquire( + # base = anvil.acquire( queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes ) # tl.device_print("offset ", offset) # Write padding NOPs if we wrapped around - # if offset > 0: - # num_offset_dwords = (offset // 4).to(tl.int32) - # base_ring_pos = anvil.wrap_into_ring(base) - # base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) - # for i in range(num_offset_dwords): - # tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + if offset > 0: + num_offset_dwords = (offset // 4).to(tl.int32) + base_ring_pos = anvil.wrap_into_ring(base) + base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) + for i in range(num_offset_dwords): + tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) # Calculate packet position (base + offset for wraparound) - packet_offset_bytes = base # + offset + packet_offset_bytes = base + offset # Place command packet anvil.place_atomic_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val) # Submit command - anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit From 34d4ffc836d9fd4d943d796a75a9e4b7dab29d55 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 4 Mar 2026 23:47:46 -0600 Subject: [PATCH 21/29] update submodule --- ext/shader_sdma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shader_sdma b/ext/shader_sdma index 91c66569f..0f56ac78a 160000 --- a/ext/shader_sdma +++ b/ext/shader_sdma @@ -1 +1 @@ -Subproject commit 91c66569f035e4fbb04ce92cb51f4dfa0ddef5eb +Subproject commit 0f56ac78a1be5c535f8ed6e5a9fc595bdc410ac8 From c8d4b4646803b4571b18bdf7eed8df21fbe5b117 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Mar 2026 05:52:32 +0000 Subject: [PATCH 22/29] Apply Ruff auto-fixes --- examples/common/utils.py | 2 +- iris/iris.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/common/utils.py b/examples/common/utils.py index 0e6ea9482..f9ebba8d7 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -86,7 +86,7 @@ def reset(self): self.comm_end_timestamp.fill_(self.min_ts) def to_json(self, filename, gpu_freq): - cycles_to_us = lambda cycles: (cycles / gpu_freq) + cycles_to_us = lambda cycles: cycles / gpu_freq gemm_begin_us = cycles_to_us(self.mm_begin_timestamp.cpu().numpy()) gemm_end_us = cycles_to_us(self.mm_end_timestamp.cpu().numpy()) diff --git a/iris/iris.py b/iris/iris.py index f10260ee2..263d849dd 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1910,7 +1910,7 @@ def put( stride_tn, stride_fm, stride_fn, - BLOCK_SIZE_M: tl.constexpr, #TODO Needed?? + BLOCK_SIZE_M: tl.constexpr, # TODO Needed?? BLOCK_SIZE_N: tl.constexpr, mask=None, USE_COPY_ENGINE: tl.constexpr = False, @@ -2253,8 +2253,14 @@ def atomic_add( command_in_bytes = 32 # Acquire space (returns base index and wraparound offset) base, offset = anvil.acquire( - # base = anvil.acquire( - queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + # base = anvil.acquire( + queue_ptr_u32, + read_ptr, + write_ptr, + doorbell_ptr, + cached_write_ptr, + committed_write_ptr, + command_in_bytes, ) # tl.device_print("offset ", offset) From 53f1a209300f756a91296dae98758ceaff00ebbd Mon Sep 17 00:00:00 2001 From: David Sidler Date: Thu, 5 Mar 2026 14:17:35 -0600 Subject: [PATCH 23/29] move padding code --- iris/iris.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/iris/iris.py b/iris/iris.py index f10260ee2..0c9bcfcdc 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -2059,12 +2059,7 @@ def put( ) # Write padding NOPs if we wrapped around - if offset > 0: - num_offset_dwords = (offset // 4).to(tl.int32) - base_ring_pos = anvil.wrap_into_ring(base) - base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) - for i in range(num_offset_dwords): - tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + anvil.place_nop_packet(queue_ptr_u32, base, offset) # Place single sub-window copy packet packet_offset_bytes = base + offset @@ -2259,12 +2254,7 @@ def atomic_add( # tl.device_print("offset ", offset) # Write padding NOPs if we wrapped around - if offset > 0: - num_offset_dwords = (offset // 4).to(tl.int32) - base_ring_pos = anvil.wrap_into_ring(base) - base_index_in_dwords = (base_ring_pos // 4).to(tl.int32) - for i in range(num_offset_dwords): - tl.store(queue_ptr_u32 + base_index_in_dwords + i, 0) + anvil.place_nop_packet(queue_ptr_u32, base, offset) # Calculate packet position (base + offset for wraparound) packet_offset_bytes = base + offset From 099a84c1c0c41710e54a376c384f4f1a4355d321 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Thu, 5 Mar 2026 14:17:54 -0600 Subject: [PATCH 24/29] update submodule for nop packet --- ext/shader_sdma | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/shader_sdma b/ext/shader_sdma index 0f56ac78a..24fd095ef 160000 --- a/ext/shader_sdma +++ b/ext/shader_sdma @@ -1 +1 @@ -Subproject commit 0f56ac78a1be5c535f8ed6e5a9fc595bdc410ac8 +Subproject commit 24fd095ef9a299936d21d1106c5597a1ca5f31f9 From 75b55b2501b3c83625762387231ccc56f71cf353 Mon Sep 17 00:00:00 2001 From: David Sidler Date: Thu, 5 Mar 2026 16:26:31 -0600 Subject: [PATCH 25/29] enable flat copy --- .../06_message_passing/message_passing_put.py | 23 +-- .../gemm_all_scatter_wg_specialization.py | 11 +- iris/iris.py | 185 +++++++++--------- 3 files changed, 107 insertions(+), 112 deletions(-) diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index b041ac200..0fea309ba 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -5,12 +5,11 @@ import torch import torch.distributed as dist +import torch.multiprocessing as mp import triton import triton.language as tl import random -from mpi4py import MPI - import iris @@ -137,6 +136,7 @@ def parse_args(): parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") parser.add_argument( "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" ) @@ -236,23 +236,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): dist.barrier() dist.destroy_process_group() - def main(): args = parse_args() - comm = MPI.COMM_WORLD # Communicator for all processes - rank = comm.Get_rank() # Get the rank of the current process - num_ranks = comm.Get_size() # Total number of processes - # TODO local_rank - torch.cuda.set_device(rank) - - # Synchronize all processes - comm.barrier() + num_ranks = args["num_ranks"] init_url = "tcp://127.0.0.1:29500" - - _worker(rank, num_ranks, init_url, args) - + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": main() diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 82652c8b9..28a6b2ffd 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -189,14 +189,13 @@ def persistent_gemm_all_scatter_wg_specialization( remote_rank, heap_bases, copy_engine_ctx, - stride_cm_global, - stride_cn_global, - stride_cm_global, - stride_cn_global, - BLOCK_SIZE_M, - BLOCK_SIZE_N, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, mask=sub_mask, USE_COPY_ENGINE=USE_COPY_ENGINE, + IS_2D_COPY=True, from_base_ptr=c_global, to_base_ptr=c_global, ) diff --git a/iris/iris.py b/iris/iris.py index 0c9bcfcdc..a506ac7c3 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1906,14 +1906,13 @@ def put( to_rank, heap_bases, copy_engine_ctx: tl.tensor, - stride_tm, - stride_tn, - stride_fm, - stride_fn, - BLOCK_SIZE_M: tl.constexpr, #TODO Needed?? - BLOCK_SIZE_N: tl.constexpr, + stride_tm: tl.constexpr = 0, + stride_tn: tl.constexpr = 0, + stride_fm: tl.constexpr = 0, + stride_fn: tl.constexpr = 0, mask=None, USE_COPY_ENGINE: tl.constexpr = False, + IS_2D_COPY: tl.constexpr = False, from_base_ptr=None, to_base_ptr=None, ): @@ -1922,37 +1921,49 @@ def put( This function performs a memory write operation by loading data from the current rank's `from_ptr`, translating the `to_ptr` from the current rank's address space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Supports both 1D (flat/linear) and 2D (tiled) copies: + - 1D copies: Used when stride_tm == 0 and stride_fm == 0 (default), uses linear SDMA packets + - 2D copies: Used when strides are non-zero, uses sub-window SDMA packets for better performance Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. + to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. copy_engine_ctx (tl.tensor): Copy engine context for SDMA operations. - stride_tm (int): Stride in M dimension for destination buffer. - stride_tn (int): Stride in N dimension for destination buffer. - stride_fm (int): Stride in M dimension for source buffer. - stride_fn (int): Stride in N dimension for source buffer. - BLOCK_SIZE_M (tl.constexpr): Block size in M dimension. - BLOCK_SIZE_N (tl.constexpr): Block size in N dimension. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - USE_COPY_ENGINE (tl.constexpr, optional): Whether to use SDMA copy engine. Defaults to False. - from_base_ptr (triton.PointerType, optional): Base pointer of the source buffer. Required when USE_COPY_ENGINE is True. - to_base_ptr (triton.PointerType, optional): Base pointer of the destination buffer. Required when USE_COPY_ENGINE is True. + stride_tm (int, optional): Stride in M dimension for destination buffer (in elements). Default: 0 (flat copy). + stride_tn (int, optional): Stride in N dimension for destination buffer (in elements). Default: 0. + stride_fm (int, optional): Stride in M dimension for source buffer (in elements). Default: 0 (flat copy). + stride_fn (int, optional): Stride in N dimension for source buffer (in elements). Default: 0. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load/copy data at that index. Defaults to None. + USE_COPY_ENGINE (tl.constexpr, optional): Whether to use SDMA copy engine. Defaults to False (uses regular load/store). + from_base_ptr (triton.PointerType, optional): Base pointer of the source buffer. Required for 2D copies when USE_COPY_ENGINE is True. + to_base_ptr (triton.PointerType, optional): Base pointer of the destination buffer. Required for 2D copies when USE_COPY_ENGINE is True. Returns: None - Example: + Examples: + 1D (flat) copy: >>> @triton.jit >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx): >>> from_rank = 0 >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases, - >>> copy_engine_ctx, stride_m, stride_n, stride_m, stride_n, - >>> BLOCK_M, BLOCK_N, mask=None, USE_COPY_ENGINE=True, + >>> offsets = tl.arange(0, 256) + >>> iris.put(local_ptr + offsets, remote_ptr + offsets, + >>> from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> mask=offsets < 256, USE_COPY_ENGINE=True) + + 2D (tiled) copy: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx, base_ptr): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> stride_tm=1024, stride_fm=1024, + >>> mask=mask, USE_COPY_ENGINE=True, >>> from_base_ptr=base_ptr, to_base_ptr=base_ptr) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) @@ -2000,86 +2011,76 @@ def put( # Default to 4 bytes for unknown types element_size_bytes = 4 - # Calculate total size in bytes - # Count number of valid elements based on mask - # src stride: 9216 - # dst strice: 9216 + # Determine packet size based on copy type + # Linear copy packet: 32 bytes for 1D, Sub-window copy packet: 80 bytes for 2D + # IS_2D_COPY is a compile-time constant for proper branch elimination mask_int = mask.to(tl.int32) - num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) - num_strides = tl.max(tl.sum(mask_int, axis=0)) - size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) - src_stride = (stride_fm * element_size_bytes).to(tl.uint32) - dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) - - # if tl.program_id(axis=0) == 230 and from_rank == 1 and to_rank == 0: - # if from_rank == 1 and to_rank == 0: - # if to_rank == 1: - # if tl.max(size_bytes) == 0: - # tl.device_print("from_ptr ", from_ptr.block_shape) - # tl.device_print("stride_tm ", stride_tm) - # tl.device_print("stride_tn ", stride_tn) - # tl.device_print("stride_fm ", stride_fm) - # tl.device_print("stride_fn ", stride_fn) - # tl.device_print("src_stride ", src_stride) - # tl.device_print("dst_stride ", dst_stride) - - # tl.device_print("queue_ptr_u32 ", queue_ptr_u32) - # tl.device_print("dst_ptr_val (all) ", translated_to_ptr.to(tl.uint64)) - # tl.device_print("dst_ptr_val ", dst_ptr_val) - # tl.device_print("dst_ptr_val (single) ", dst_ptr_val0) - # tl.device_print("src_ptr_u64", src_ptr_u64) - # tl.device_print("src_ptr_val ", src_ptr_val) - # tl.device_print("src_ptr_val (single) ", src_ptr_val0) - # tl.device_print("mask(axis=0): ", tl.sum(mask_int, axis=0)) - # tl.device_print("mask: ", tl.sum(mask_int)) - # tl.device_print("num strides: ", num_strides) - # tl.device_print("size_bytes per stride", size_bytes) - - # Use sub-window copy packet (single packet for entire tile) - command_in_bytes_u32 = 80 + command_in_bytes_u32 = 80 if IS_2D_COPY else 32 command_in_bytes = command_in_bytes_u32.to(tl.uint64) - required_bytes = command_in_bytes - - # Calculate base addresses and offsets for sub-window copy - src_base = from_base_ptr.to(tl.uint64) - dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) - # Calculate tile offset from base - tile_offset_bytes = src_ptr_val0 - src_base - src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) - src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) - - tile_offset_bytes_dst = dst_ptr_val0 - dst_base - dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) - dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) - - # Acquire space + # Acquire space in the queue base, offset = anvil.acquire( - queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, required_bytes + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes ) # Write padding NOPs if we wrapped around anvil.place_nop_packet(queue_ptr_u32, base, offset) - # Place single sub-window copy packet + # Place the appropriate packet type packet_offset_bytes = base + offset - anvil.place_sub_window_copy_packet( - queue_ptr_u32, - packet_offset_bytes, - src_base, - dst_base, - tile_width=size_bytes, - tile_height=num_strides, - src_buffer_pitch=src_stride, - dst_buffer_pitch=dst_stride, - src_x=src_x_val, - src_y=src_y_val, - dst_x=dst_x_val, - dst_y=dst_y_val, - ) - # Submit - pending_wptr = base + offset + required_bytes + if not IS_2D_COPY: + # For 1D copies, mask is 1D, so just sum all elements + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + # Place linear copy packet for 1D/flat copies + anvil.place_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + size_bytes, + src_ptr_val0, + dst_ptr_val0, + ) + else: + # For 2D copies, mask is 2D [M, N], use axis operations + num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) + num_strides = tl.max(tl.sum(mask_int, axis=0)) + size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) + src_stride = (stride_fm * element_size_bytes).to(tl.uint32) + dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) + + # Place sub-window copy packet for 2D tiled copies + # Calculate base addresses and offsets for sub-window copy + src_base = from_base_ptr.to(tl.uint64) + dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) + + # Calculate tile offset from base + tile_offset_bytes = src_ptr_val0 - src_base + src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) + src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) + + tile_offset_bytes_dst = dst_ptr_val0 - dst_base + dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) + dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) + + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_base, + dst_base, + tile_width=size_bytes, + tile_height=num_strides, + src_buffer_pitch=src_stride, + dst_buffer_pitch=dst_stride, + src_x=src_x_val, + src_y=src_y_val, + dst_x=dst_x_val, + dst_y=dst_y_val, + ) + + # Submit the command to the queue + pending_wptr = base + offset + command_in_bytes anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) From e5a38ddaf55213d7aa472dc072164d364a6a7d5f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Mar 2026 22:29:34 +0000 Subject: [PATCH 26/29] Apply Ruff auto-fixes --- examples/06_message_passing/message_passing_put.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 0fea309ba..c0c4d7b51 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -236,6 +236,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): dist.barrier() dist.destroy_process_group() + def main(): args = parse_args() @@ -249,5 +250,6 @@ def main(): join=True, ) + if __name__ == "__main__": main() From 0b6ff1ad9c3fcc25e4b29e5be4894d47f5ccd83f Mon Sep 17 00:00:00 2001 From: David Sidler Date: Thu, 5 Mar 2026 17:51:38 -0600 Subject: [PATCH 27/29] clean up --- .../benchmark.py | 53 ++++++++----------- .../matmul_wrapper.py | 3 -- iris/iris.py | 2 +- 3 files changed, 23 insertions(+), 35 deletions(-) diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index ee78383a0..a4fe220c7 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -129,24 +129,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): json_writer.add_field(key, value) global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) - # why on heap? so it is uncached?? - # TODO unused local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - # TODO why is this on the heap? - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) # why int8?? + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) comm_sms = args["num_sms"] - args["gemm_sms"] flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32) # Get copy engine context - copy_engine_ctx = shmem.get_copy_engine_ctx() # if args["use_copy_engine"] else None - # ( - # shmem.get_copy_engine_handle() if args["use_copy_engine"] and cur_rank == producer_rank else None - # ) + copy_engine_ctx = shmem.get_copy_engine_ctx() bias = None @@ -241,12 +235,9 @@ def run_experiment(): # Wait for all to finish validation shmem.barrier() - # shmem.info("Validating local C...") json_writer.add_field("success", success) - # shmem.info(flags) - if not is_triton_interpret_set(): gemm_registers = matmul.get_matmul_registers() gemm_spills = matmul.get_matmul_spills() @@ -256,26 +247,26 @@ def run_experiment(): shmem.info("Validation completed") - # if args["benchmark"]: - # matmul.set_debug(False) - # shmem.info("Benchmarking...") - # perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - # triton_ms = iris.do_bench(run_experiment, shmem.barrier) - # triton_tflops = perf(triton_ms) - # algo_string = "all_scatter" - # shmem.info( - # f"tile matmul + {algo_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" - # ) - - # json_writer.add_field("tflops", triton_tflops) - # json_writer.add_field("total_ms", triton_ms) - - # for k in ["gemm"]: - # json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) - # json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) - - # # Wait for all to finish benchmarking - # shmem.barrier() + if args["benchmark"]: + matmul.set_debug(False) + shmem.info("Benchmarking...") + perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_tflops = perf(triton_ms) + algo_string = "all_scatter" + shmem.info( + f"tile matmul + {algo_string} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" + ) + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + for k in ["gemm"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() if rank == 0: json_writer.flush() diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index 28db5c389..135313fb4 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -85,9 +85,6 @@ def _call( even_k = K % BLK_K == 0 use_bias = False - print("C: ", c.stride(0), " ", c.stride(1), " global ", c_global.stride(0), " ", c_global.stride(1)) - print("BLK_M ", BLK_M, " BLK_N ", BLK_N, " BLK_K ", BLK_K, " even k ", even_k, " total_tiles ", total_tiles) - # compute grid (work to do per SM on the first wave) stride_bias = bias.stride(0) if use_bias else 0 kk = gemm_kernel[(num_sms,)]( diff --git a/iris/iris.py b/iris/iris.py index 59885eb59..2a0ceeaaa 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -55,7 +55,6 @@ ) import anvil - from iris.symmetric_heap import SymmetricHeap import numpy as np import torch @@ -2009,6 +2008,7 @@ def put( stride_fm: tl.constexpr = 0, stride_fn: tl.constexpr = 0, mask=None, + hint: tl.constexpr = None, USE_COPY_ENGINE: tl.constexpr = False, IS_2D_COPY: tl.constexpr = False, from_base_ptr=None, From bfe454808d59b3452675e02598f37fadcb1f5a6a Mon Sep 17 00:00:00 2001 From: David Sidler Date: Wed, 18 Mar 2026 16:12:14 -0500 Subject: [PATCH 28/29] add copy engine support to fused gemm-allscatter --- examples/07_gemm_all_scatter/benchmark.py | 16 ++++ .../07_gemm_all_scatter/gemm_all_scatter.py | 91 +++++++++++++++---- .../07_gemm_all_scatter/matmul_wrapper.py | 12 +++ 3 files changed, 103 insertions(+), 16 deletions(-) diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 994c10cad..c515df52c 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -58,6 +58,9 @@ def parse_args(): help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -124,11 +127,15 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() + bias = None gemm_stream = torch.cuda.Stream() json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("total_tiles", total_tiles) kernel_timing = { "gemm": { @@ -142,11 +149,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + # Allocate flags for synchronization (one flag per SM per rank) + flags = shmem.zeros((args["gemm_sms"] * world_size,), device="cuda", dtype=torch.int32) + def run_experiment(): nonlocal local_C nonlocal global_C nonlocal kernel_timing + # Reset flags to zero before each experiment + flags.zero_() + shmem.barrier() if args["trace_tiles"]: @@ -163,6 +176,7 @@ def run_experiment(): local_C, global_C, bias, + flags, rank, world_size, args["gemm_sms"], @@ -174,6 +188,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 937835d6f..9fa78ac5e 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -9,6 +9,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + @triton.jit() def persistent_gemm_all_scatter( A, @@ -16,6 +21,7 @@ def persistent_gemm_all_scatter( C, c_global, bias_ptr, + flags, M, N, K, @@ -40,13 +46,15 @@ def persistent_gemm_all_scatter( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, mm_end_timestamp_ptr: tl.tensor = None, ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + # if NUM_XCDS != 1: + # pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -60,6 +68,7 @@ def persistent_gemm_all_scatter( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + # Process all tiles for this SM for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -132,17 +141,67 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) - # Store data to the global result using puts - for remote_rank in range(world_size): - if remote_rank == cur_rank: - # For the current rank, we can use store - tl.store(c_global + global_offset, c, mask=sub_mask) - else: - iris.store( - c_global + global_offset, - c, - cur_rank, - remote_rank, - heap_bases, - mask=sub_mask, - ) + + if USE_COPY_ENGINE: + # Store locally first + tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.put( + c_global + global_offset, + c_global + global_offset, + cur_rank, + remote_rank, + heap_bases, + copy_engine_ctx, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, + mask=sub_mask, + USE_COPY_ENGINE=True, + IS_2D_COPY=True, + from_base_ptr=c_global, + to_base_ptr=c_global, + ) + + else: + # Store data to the global result using puts + for remote_rank in range(world_size): + if remote_rank == cur_rank: + # For the current rank, we can use store + tl.store(c_global + global_offset, c, mask=sub_mask) + else: + iris.store( + c_global + global_offset, + c, + cur_rank, + remote_rank, + heap_bases, + mask=sub_mask, + ) + + # After all tiles are processed, signal and wait once per SM + tl.debug_barrier() + # Signal other ranks that all our puts/stores are complete + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass diff --git a/examples/07_gemm_all_scatter/matmul_wrapper.py b/examples/07_gemm_all_scatter/matmul_wrapper.py index 5d8adb589..3f6d3e0d6 100644 --- a/examples/07_gemm_all_scatter/matmul_wrapper.py +++ b/examples/07_gemm_all_scatter/matmul_wrapper.py @@ -44,6 +44,7 @@ def _call( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -55,6 +56,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -86,6 +89,7 @@ def _call( c, c_global, bias, + flags, M, N, K, @@ -115,6 +119,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -133,6 +139,7 @@ def forward( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -144,6 +151,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -153,6 +162,7 @@ def forward( c=c, c_global=c_global, bias=bias, + flags=flags, rank=rank, world_size=world_size, num_sms=num_sms, @@ -164,6 +174,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) From 27040c8fd542b9f1ad69d8bf334212b8043c785d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 18 Mar 2026 21:13:12 +0000 Subject: [PATCH 29/29] Apply Ruff auto-fixes --- examples/07_gemm_all_scatter/gemm_all_scatter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 9fa78ac5e..78d4fba6a 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -141,7 +141,6 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) - if USE_COPY_ENGINE: # Store locally first tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt")