diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..41cf4672 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ext/shader_sdma"] + path = ext/shader_sdma + url = https://github.com/AARInternal/shader_sdma.git diff --git a/examples/06_message_passing/message_passing_copy_engine.py b/examples/06_message_passing/message_passing_copy_engine.py new file mode 100644 index 00000000..677b6846 --- /dev/null +++ b/examples/06_message_passing/message_passing_copy_engine.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import argparse + +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import random + +from mpi4py import MPI + +import iris + + +@triton.jit +def producer_kernel( + source_buffer, # tl.tensor: pointer to source data + target_buffer, # tl.tensor: pointer to target data + flag, # tl.tensor: pointer to flags + buffer_size, # int32: total number of elements + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, +): + pid = tl.program_id(0) + + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < buffer_size + + # Put chunk into remote buffer + iris.put( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + USE_COPY_ENGINE=True, + ) + + # Set flag to signal completion + iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr) + + +@triton.jit +def consumer_kernel( + buffer, # tl.tensor: pointer to shared buffer (read from target_rank) + flag, # tl.tensor: sync flag per block + buffer_size, # int32: total number of elements + 
consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Spin-wait until writer sets flag[pid] = 1 + # zero_u64 = tl.zeros((1,), tl.uint64) + # one_u64 = tl.full((1,), 1, tl.uint64) + done = 0 # zero_u64 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + + # Read from the target buffer (written by producer) + values = tl.load(buffer + offsets, mask=mask) + + # Do something with values... + # (Here you might write to output, do computation, etc.) + values = values * 2 + + # Store chunk to target buffer + tl.store( + buffer + offsets, + values, + mask=mask, + ) + + # Optionally reset the flag for next iteration + tl.store(flag + pid, 0) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + print(f"Unknown datatype: {datatype}") + exit(1) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse Message Passing configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") + + parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, 
args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + shmem = iris.iris(args["heap_size"]) + dtype = torch_dtype_from_str(args["datatype"]) + cur_rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Allocate source and destination buffers on the symmetric heap + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + if world_size != 2: + raise ValueError("This example requires exactly two processes.") + + producer_rank = 0 + consumer_rank = 1 + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + + # Allocate flags on the symmetric heap + flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + if cur_rank == producer_rank: + shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") + kk = producer_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + producer_rank, + consumer_rank, + args["block_size"], + shmem.get_heap_bases(), + shmem.get_copy_engine_handle(consumer_rank), + ) + else: + shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") + kk = consumer_kernel[grid]( + destination_buffer, flags, n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases() + ) + shmem.barrier() + shmem.info(f"Rank {cur_rank} has finished sending/receiving data.") + shmem.info("Validating output...") + + success = True + if cur_rank 
== consumer_rank: + expected = source_buffer * 2 + diff_mask = ~torch.isclose(destination_buffer, expected, atol=1) + breaking_indices = torch.nonzero(diff_mask, as_tuple=False) + + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + shmem.info(f"Max absolute difference: {max_diff}") + for idx in breaking_indices: + idx = tuple(idx.tolist()) + computed_val = destination_buffer[idx] + expected_val = expected[idx] + shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}") + success = False + break + + if success: + shmem.info("Validation successful.") + else: + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") + + shmem.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + comm = MPI.COMM_WORLD # Communicator for all processes + rank = comm.Get_rank() # Get the rank of the current process + num_ranks = comm.Get_size() # Total number of processes + # TODO local_rank + torch.cuda.set_device(rank) + + # Synchronize all processes + comm.barrier() + + init_url = "tcp://127.0.0.1:29500" + + _worker(rank, num_ranks, init_url, args) + + +if __name__ == "__main__": + main() diff --git a/examples/06_message_passing/message_passing_put.py b/examples/06_message_passing/message_passing_put.py index 4f726969..c0c4d7b5 100755 --- a/examples/06_message_passing/message_passing_put.py +++ b/examples/06_message_passing/message_passing_put.py @@ -23,6 +23,8 @@ def producer_kernel( consumer_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers + copy_engine_handle_ptr, + USE_COPY_ENGINE: tl.constexpr, ): pid = tl.program_id(0) @@ -34,10 +36,30 @@ def producer_kernel( mask = offsets < buffer_size # Put chunk into remote buffer - iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, 
mask=mask) + iris.put( + source_buffer + offsets, + target_buffer + offsets, + producer_rank, + consumer_rank, + heap_bases_ptr, + copy_engine_handle_ptr, + mask=mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) # Set flag to signal completion - iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + # iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys") + iris.atomic_add( + flag + pid, + 1, + producer_rank, + consumer_rank, + heap_bases_ptr, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_handle_ptr, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) @triton.jit @@ -113,9 +135,11 @@ def parse_args(): ) parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size") parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size") - parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -138,12 +162,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): world_size = shmem.get_num_ranks() # Allocate source and destination buffers on the symmetric heap - source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) + destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype) if dtype.is_floating_point: - destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) + source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype) else: ii = torch.iinfo(dtype) - destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + source_buffer = shmem.randint(ii.min, ii.max, 
(args["buffer_size"],), device="cuda", dtype=dtype) if world_size != 2: raise ValueError("This example requires exactly two processes.") @@ -158,6 +182,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate flags on the symmetric heap flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32) + # Get copy engine context + # copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None + copy_engine_ctx = shmem.get_copy_engine_ctx() + if cur_rank == producer_rank: shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") kk = producer_kernel[grid]( @@ -169,6 +197,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): consumer_rank, args["block_size"], shmem.get_heap_bases(), + copy_engine_ctx, + USE_COPY_ENGINE=args["use_copy_engine"], ) else: shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") @@ -199,7 +229,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): if success: shmem.info("Validation successful.") else: - shmem.info("Validation failed.") + shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}") shmem.barrier() diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 994c10ca..c515df52 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -58,6 +58,9 @@ def parse_args(): help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", ) parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -124,11 +127,15 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = 
triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() + bias = None gemm_stream = torch.cuda.Stream() json_writer.add_field("gemm_sms", args["gemm_sms"]) + json_writer.add_field("total_tiles", total_tiles) kernel_timing = { "gemm": { @@ -142,11 +149,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + # Allocate flags for synchronization (one flag per SM per rank) + flags = shmem.zeros((args["gemm_sms"] * world_size,), device="cuda", dtype=torch.int32) + def run_experiment(): nonlocal local_C nonlocal global_C nonlocal kernel_timing + # Reset flags to zero before each experiment + flags.zero_() + shmem.barrier() if args["trace_tiles"]: @@ -163,6 +176,7 @@ def run_experiment(): local_C, global_C, bias, + flags, rank, world_size, args["gemm_sms"], @@ -174,6 +188,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index 937835d6..78d4fba6 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -9,6 +9,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + @triton.jit() def persistent_gemm_all_scatter( A, @@ -16,6 +21,7 @@ def persistent_gemm_all_scatter( C, c_global, bias_ptr, + flags, M, N, K, @@ -40,13 +46,15 @@ def persistent_gemm_all_scatter( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, 
mm_end_timestamp_ptr: tl.tensor = None, ): pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + # if NUM_XCDS != 1: + # pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -60,6 +68,7 @@ def persistent_gemm_all_scatter( acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + # Process all tiles for this SM for tile_id in range(pid, total_tiles, NUM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -132,17 +141,66 @@ def persistent_gemm_all_scatter( timestamp = read_realtime() tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) - # Store data to the global result using puts - for remote_rank in range(world_size): - if remote_rank == cur_rank: - # For the current rank, we can use store - tl.store(c_global + global_offset, c, mask=sub_mask) - else: - iris.store( - c_global + global_offset, - c, - cur_rank, - remote_rank, - heap_bases, - mask=sub_mask, - ) + if USE_COPY_ENGINE: + # Store locally first + tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() + tl.debug_barrier() + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.put( + c_global + global_offset, + c_global + global_offset, + cur_rank, + remote_rank, + heap_bases, + copy_engine_ctx, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, + mask=sub_mask, + USE_COPY_ENGINE=True, + IS_2D_COPY=True, + from_base_ptr=c_global, + to_base_ptr=c_global, + ) + + else: + # Store data to the global result using puts + for remote_rank in range(world_size): + if remote_rank == cur_rank: + # For the current rank, we can use store + tl.store(c_global + global_offset, c, mask=sub_mask) + else: + iris.store( + c_global + global_offset, + c, + cur_rank, + remote_rank, + heap_bases, + 
mask=sub_mask, + ) + + # After all tiles are processed, signal and wait once per SM + tl.debug_barrier() + # Signal other ranks that all our puts/stores are complete + for remote_rank in range(world_size): + if remote_rank != cur_rank: + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass diff --git a/examples/07_gemm_all_scatter/matmul_wrapper.py b/examples/07_gemm_all_scatter/matmul_wrapper.py index 5d8adb58..3f6d3e0d 100644 --- a/examples/07_gemm_all_scatter/matmul_wrapper.py +++ b/examples/07_gemm_all_scatter/matmul_wrapper.py @@ -44,6 +44,7 @@ def _call( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -55,6 +56,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -86,6 +89,7 @@ def _call( c, c_global, bias, + flags, M, N, K, @@ -115,6 +119,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -133,6 +139,7 @@ def forward( c: torch.Tensor, c_global: torch.Tensor, bias: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, num_sms: int, @@ -144,6 +151,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + 
copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -153,6 +162,7 @@ def forward( c=c, c_global=c_global, bias=bias, + flags=flags, rank=rank, world_size=world_size, num_sms=num_sms, @@ -164,6 +174,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 655c892f..a4fe220c 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -65,6 +65,9 @@ def parse_args(): ) parser.add_argument("--num_stages", type=int, default=2, help="Number of stages") parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + parser.add_argument( + "-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies" + ) return vars(parser.parse_args()) @@ -133,6 +136,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_tiles = total_blocks_M * total_blocks_N locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + comm_sms = args["num_sms"] - args["gemm_sms"] + flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32) + + # Get copy engine context + copy_engine_ctx = shmem.get_copy_engine_ctx() bias = None @@ -175,6 +183,7 @@ def run_experiment(): global_C, bias, locks, + flags, rank, world_size, args["gemm_sms"], @@ -187,6 +196,8 @@ def run_experiment(): shmem.get_heap_bases(), "gfx942", args["trace_tiles"], + args["use_copy_engine"], + copy_engine_ctx, timestamps.mm_begin_timestamp, timestamps.mm_end_timestamp, ) @@ -224,7 +235,6 @@ def run_experiment(): # Wait for all to 
finish validation shmem.barrier() - shmem.info("Validating local C...") json_writer.add_field("success", success) diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 4d9c2825..a1fdb8e6 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -9,6 +9,11 @@ import iris +@triton.jit +def wait_cnt(): + tl.inline_asm_elementwise("s_waitcnt vmcnt(0)", "=r", [], dtype=tl.int32, is_pure=False, pack=1) + + @triton.jit() def persistent_gemm_all_scatter_wg_specialization( A, @@ -17,6 +22,7 @@ def persistent_gemm_all_scatter_wg_specialization( c_global, bias_ptr, locks, + flags, M, N, K, @@ -24,8 +30,8 @@ def persistent_gemm_all_scatter_wg_specialization( stride_ak, stride_bk, stride_bn, - stride_cm, - stride_cn, + stride_cm, # unused + stride_cn, # unused stride_cm_global, stride_cn_global, stride_bias, @@ -42,6 +48,8 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank: tl.constexpr, world_size: tl.constexpr, COLLECT_TIMESTAMPS: tl.constexpr = False, + USE_COPY_ENGINE: tl.constexpr = False, + copy_engine_ctx: tl.tensor = None, mm_begin_timestamp_ptr: tl.tensor = None, mm_end_timestamp_ptr: tl.tensor = None, ): @@ -67,6 +75,9 @@ def persistent_gemm_all_scatter_wg_specialization( # and another that performs the communication. Uses persistent- # kernel. 
if pid < GEMM_SMS: + # tl.device_print("GEMM_SMS: ", GEMM_SMS) + # tl.device_print("GEMM pid: ", pid) + for tile_id in range(pid, total_tiles, GEMM_SMS): if COLLECT_TIMESTAMPS: timestamp = read_realtime() @@ -140,12 +151,15 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") + wait_cnt() tl.debug_barrier() tl.store(locks + tile_id, 1, cache_modifier=".wt") else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS pid = pid - GEMM_SMS + # tl.device_print("COMM_SMS: ", COMM_SMS) + # tl.device_print("COMM pid: ", pid) for tile_id in range(pid, total_tiles, COMM_SMS): num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = tile_id // num_pid_in_group @@ -174,5 +188,37 @@ def persistent_gemm_all_scatter_wg_specialization( cur_rank, remote_rank, heap_bases, + copy_engine_ctx, + stride_tm=stride_cm_global, + stride_tn=stride_cn_global, + stride_fm=stride_cm_global, + stride_fn=stride_cn_global, mask=sub_mask, + USE_COPY_ENGINE=USE_COPY_ENGINE, + IS_2D_COPY=True, + from_base_ptr=c_global, + to_base_ptr=c_global, ) + tl.debug_barrier() + # Signal other ranks + for remote_rank in range(world_size): + if remote_rank != cur_rank: + # print("Issue atomic_add") + iris.atomic_add( + flags + (pid * world_size) + cur_rank, + 1, + cur_rank, + remote_rank, + heap_bases, + sem="release", + scope="sys", + copy_engine_ctx=copy_engine_ctx, + USE_COPY_ENGINE=USE_COPY_ENGINE, + ) + # print("Start waiting") + # Wait for other ranks to signal us + for remote_rank in range(world_size): + if remote_rank != cur_rank: + while tl.load(flags + (pid * world_size) + remote_rank, cache_modifier=".cv", volatile=True) != 1: + pass + # print("done waiting") diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index 1d46297a..135313fb 100644 --- 
a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -47,6 +47,7 @@ def _call( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -59,6 +60,8 @@ def _call( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -91,6 +94,7 @@ def _call( c_global, bias, locks, + flags, M, N, K, @@ -121,6 +125,8 @@ def _call( cur_rank=rank, world_size=world_size, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp_ptr=mm_begin_timestamp, mm_end_timestamp_ptr=mm_end_timestamp, ) @@ -140,6 +146,7 @@ def forward( c_global: torch.Tensor, bias: torch.Tensor, locks: torch.Tensor, + flags: torch.Tensor, rank: int, world_size: int, gemm_sms: int, @@ -152,6 +159,8 @@ def forward( heap_bases_ptr: torch.Tensor = None, arch: str = "gfx942", COLLECT_TIMESTAMPS: bool = False, + USE_COPY_ENGINE: bool = False, + copy_engine_ctx: torch.Tensor = None, mm_begin_timestamp: torch.Tensor = None, mm_end_timestamp: torch.Tensor = None, ): @@ -162,6 +171,7 @@ def forward( c_global=c_global, bias=bias, locks=locks, + flags=flags, rank=rank, world_size=world_size, gemm_sms=gemm_sms, @@ -174,6 +184,8 @@ def forward( heap_bases_ptr=heap_bases_ptr, arch=arch, COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + USE_COPY_ENGINE=USE_COPY_ENGINE, + copy_engine_ctx=copy_engine_ctx, mm_begin_timestamp=mm_begin_timestamp, mm_end_timestamp=mm_end_timestamp, ) diff --git a/ext/shader_sdma b/ext/shader_sdma new file mode 160000 index 00000000..24fd095e --- /dev/null +++ b/ext/shader_sdma @@ -0,0 +1 @@ +Subproject commit 24fd095ef9a299936d21d1106c5597a1ca5f31f9 diff --git a/iris/iris.py 
b/iris/iris.py index 0f1073b3..2a0ceeaa 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -53,6 +53,8 @@ get_cu_count, count_devices, ) + +import anvil from iris.symmetric_heap import SymmetricHeap import numpy as np import torch @@ -112,6 +114,35 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): distributed_barrier() + # initialize copy engines + self.copy_engines = anvil.AnvilLib.get_instance() + self.copy_engines.init() + + # connect to all peers + # TODO only connect local ranks + # TODO get size + context_size = 6 + self.copy_engines_device_ctx = torch.zeros((num_ranks, context_size), dtype=torch.uint64, device=self.device) + + for rank in range(num_ranks): + if rank != cur_rank: + self.copy_engines.connect(cur_rank, rank, 1) + queue = self.copy_engines.get_sdma_queue(cur_rank, rank, 0) + handle = queue.device_ctx() + self.info(f"---- Queue {rank} ------------") + self.info(f"queue_buf {handle.queue_buf:#x} at {id(handle.queue_buf):#x}") + self.info(f"rptr {handle.rptr:#x} at {id(handle.rptr):#x}") + self.info(f"wptr {handle.wptr:#x} at {id(handle.wptr):#x}") + self.info(f"doorbell {handle.doorbell:#x} at {id(handle.doorbell):#x}") + self.info(f"cached_write_ptr {handle.cached_wptr:#x} at {id(handle.cached_wptr):#x}") + self.info(f"committed_write_ptr {handle.committed_wptr:#x} at {id(handle.committed_wptr):#x}") + + self.copy_engines_device_ctx[rank][0] = handle.queue_buf + self.copy_engines_device_ctx[rank][1] = handle.rptr + self.copy_engines_device_ctx[rank][2] = handle.wptr + self.copy_engines_device_ctx[rank][3] = handle.doorbell + self.copy_engines_device_ctx[rank][4] = handle.cached_wptr + self.copy_engines_device_ctx[rank][5] = handle.committed_wptr # Initialize CCL interface self.ccl = self.CCL(self) @@ -882,6 +913,9 @@ def get_heap_bases(self): """ return self.heap_bases + def get_copy_engine_ctx(self): + return self.copy_engines_device_ctx + def get_device_context(self): """ Get the device context tensor for DeviceContext 
initialization. @@ -1962,43 +1996,310 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.co @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): +def put( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + copy_engine_ctx: tl.tensor, + stride_tm: tl.constexpr = 0, + stride_tn: tl.constexpr = 0, + stride_fm: tl.constexpr = 0, + stride_fn: tl.constexpr = 0, + mask=None, + hint: tl.constexpr = None, + USE_COPY_ENGINE: tl.constexpr = False, + IS_2D_COPY: tl.constexpr = False, + from_base_ptr=None, + to_base_ptr=None, +): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current rank's `from_ptr`, translating the `to_ptr` from the current rank's address space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Supports both 1D (flat/linear) and 2D (tiled) copies: + - 1D copies: Used when stride_tm == 0 and stride_fm == 0 (default), uses linear SDMA packets + - 2D copies: Used when strides are non-zero, uses sub-window SDMA packets for better performance Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. - to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. from_rank (int): The current rank ID from which to read the data. - to_rank (int): The `to_rank` ID to which the data will be written. 
+ to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. - mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. - hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). + copy_engine_ctx (tl.tensor): Copy engine context for SDMA operations. + stride_tm (int, optional): Stride in M dimension for destination buffer (in elements). Default: 0 (flat copy). + stride_tn (int, optional): Stride in N dimension for destination buffer (in elements). Default: 0. + stride_fm (int, optional): Stride in M dimension for source buffer (in elements). Default: 0 (flat copy). + stride_fn (int, optional): Stride in N dimension for source buffer (in elements). Default: 0. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load/copy data at that index. Defaults to None. + USE_COPY_ENGINE (tl.constexpr, optional): Whether to use SDMA copy engine. Defaults to False (uses regular load/store). + from_base_ptr (triton.PointerType, optional): Base pointer of the source buffer. Required for 2D copies when USE_COPY_ENGINE is True. + to_base_ptr (triton.PointerType, optional): Base pointer of the destination buffer. Required for 2D copies when USE_COPY_ENGINE is True. 
Returns: None - Example: + Examples: + 1D (flat) copy: >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx): >>> from_rank = 0 >>> to_rank = 1 - >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + >>> offsets = tl.arange(0, 256) + >>> iris.put(local_ptr + offsets, remote_ptr + offsets, + >>> from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> mask=offsets < 256, USE_COPY_ENGINE=True) + + 2D (tiled) copy: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases, copy_engine_ctx, base_ptr): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases, copy_engine_ctx, + >>> stride_tm=1024, stride_fm=1024, + >>> mask=mask, USE_COPY_ENGINE=True, + >>> from_base_ptr=base_ptr, to_base_ptr=base_ptr) """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) - data = tl.load(from_ptr, mask=mask) + if not USE_COPY_ENGINE: + data = tl.load(from_ptr, mask=mask) + + tl.store(translated_to_ptr, data, mask=mask) + else: + ctx = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(ctx + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(ctx + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(ctx + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(ctx + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(ctx + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(ctx + 5).to(tl.pointer_type(tl.uint64)) + + # dst_ptr_val = tl.min(translated_to_ptr.to(tl.uint64), axis=-1) + dst_ptr_val0 = tl.min(translated_to_ptr.to(tl.uint64)) + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = from_ptr.to(tl.uint64) + # src_ptr_val = tl.min(src_ptr_u64, axis=-1) + src_ptr_val0 = tl.min(src_ptr_u64) + # max_src_ptr = tl.max(src_ptr_u64, axis=0) + + # Infer element size from pointer type + # src_ptr is a block of pointers with a 
specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte sizes + ptr_dtype = from_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Determine packet size based on copy type + # Linear copy packet: 32 bytes for 1D, Sub-window copy packet: 80 bytes for 2D + # IS_2D_COPY is a compile-time constant for proper branch elimination + mask_int = mask.to(tl.int32) + command_in_bytes_u32 = 80 if IS_2D_COPY else 32 + command_in_bytes = command_in_bytes_u32.to(tl.uint64) + + # Acquire space in the queue + base, offset = anvil.acquire( + queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes + ) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + # Place the appropriate packet type + packet_offset_bytes = base + offset - tl.store(translated_to_ptr, data, mask=mask) + if not IS_2D_COPY: + # For 1D copies, mask is 1D, so just sum all elements + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + # Place linear copy packet for 1D/flat copies + anvil.place_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + size_bytes, + src_ptr_val0, + dst_ptr_val0, + ) + else: + # For 2D copies, 
mask is 2D [M, N], use axis operations + num_elements_per_stride = tl.max(tl.sum(mask_int, axis=-1)) + num_strides = tl.max(tl.sum(mask_int, axis=0)) + size_bytes = (num_elements_per_stride * element_size_bytes).to(tl.uint32) + src_stride = (stride_fm * element_size_bytes).to(tl.uint32) + dst_stride = (stride_tm * element_size_bytes).to(tl.uint32) + + # Place sub-window copy packet for 2D tiled copies + # Calculate base addresses and offsets for sub-window copy + src_base = from_base_ptr.to(tl.uint64) + dst_base = __translate(to_base_ptr, from_rank, to_rank, heap_bases).to(tl.uint64) + + # Calculate tile offset from base + tile_offset_bytes = src_ptr_val0 - src_base + src_y_val = (tile_offset_bytes // src_stride).to(tl.uint32) + src_x_val = (tile_offset_bytes % src_stride).to(tl.uint32) + + tile_offset_bytes_dst = dst_ptr_val0 - dst_base + dst_y_val = (tile_offset_bytes_dst // dst_stride).to(tl.uint32) + dst_x_val = (tile_offset_bytes_dst % dst_stride).to(tl.uint32) + + anvil.place_sub_window_copy_packet( + queue_ptr_u32, + packet_offset_bytes, + src_base, + dst_base, + tile_width=size_bytes, + tile_height=num_strides, + src_buffer_pitch=src_stride, + dst_buffer_pitch=dst_stride, + src_x=src_x_val, + src_y=src_y_val, + dst_x=dst_x_val, + dst_y=dst_y_val, + ) + + # Submit the command to the queue + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) + + +@triton.jit +def nontemporal_store(addr, value): + tl.inline_asm_elementwise( + asm="""flat_store_dwordx2 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", + constraints=("=r,v,v"), # =r used for dummy return to satisfy compiler requirement + args=[addr, value], + dtype=tl.int32, # return not used + is_pure=False, + pack=1, + ) + + +# TODO rename or add nt +@triton.jit +def nontemporal_load(addr): + val = tl.inline_asm_elementwise( + asm="""flat_load_dwordx2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v"), + args=[addr], + dtype=tl.uint64, 
+ is_pure=False, + pack=1, + ) + return val + + +@triton.jit +def nontemporal_atomic_add(addr, value): + old = tl.inline_asm_elementwise( + asm="""flat_atomic_add_x2 $0 $1 sc0 sc1; s_waitcnt vmcnt(0)""", + constraints=("=v,v,v"), + args=[addr, value], + dtype=tl.uint64, + is_pure=False, + pack=1, + ) + return old + + +# @triton.jit +# def nontemporal_compare_exchange(addr, cmp_low, cmp_high, val_low, val_high): +# # data_128bit = tl.cat([cmp_low, cmp_high, val_low, val_high]) +# data_128bit = tl.make_vector([cmp_low, cmp_high, val_low, val_high], type=tl.uint32) +# old = tl.inline_asm_elementwise( +# asm="""flat_atomic_cmpswap_x2 $0 $1 $2 sc0 nt; s_waitcnt vmcnt(0)""", +# constraints=("=v,v,v"), +# args=[addr, data_128bit], +# dtype=tl.uint64, +# is_pure=False, +# pack=1, +# ) +# return True # TODO if old == cmp else False + + +# @triton.jit +# def signal_ce(to_ptr, from_rank, to_rank, heap_bases, ce_handle, mask=None): +# """ +# Copies data from the current rank's local memory to the specified rank's memory. +# This function performs a memory write operation by loading data from the current +# rank's `from_ptr`, translating the `to_ptr` from the current rank's address +# space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. +# If the `to_rank` is the same as the current rank, this function performs a local copy operation. + +# Args: +# from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. +# to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. +# from_rank (int): The current rank ID from which to read the data. +# to_rank (int): The `to_rank` ID to which the data will be written. +# heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. 
+# mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + +# Returns: +# None + +# Example: +# >>> @triton.jit +# >>> def kernel(local_ptr, remote_ptr, heap_bases): +# >>> from_rank = 0 +# >>> to_rank = 1 +# >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) +# """ + +# handle = ce_handle # iris.get_copy_engine_handle(to_rank) +# queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) +# read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) +# write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) +# doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) +# cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) +# committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + +# translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) +# dst_ptr_val = translated_to_ptr.to(tl.uint64) + +# command_in_bytes = 32 +# # Acquire space +# base = anvil.acquire(queue_ptr_u32, read_ptr, write_ptr, doorbell_ptr, cached_write_ptr, committed_write_ptr, command_in_bytes) + +# # Place command packet +# slot_ptr_u32 = queue_ptr_u32 + (base // 4) +# anvil.place_atomic_packet(slot_ptr_u32, dst_ptr_val) + +# # Submit command +# anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, command_in_bytes) @triton.jit def atomic_add( - pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None + pointer, + val, + from_rank, + to_rank, + heap_bases, + mask=None, + sem=None, + scope=None, + hint: tl.constexpr = None, + copy_engine_ctx=None, + USE_COPY_ENGINE: tl.constexpr = False, ): """ Performs an atomic add at the specified rank's memory location. 
@@ -2032,7 +2333,45 @@ def atomic_add( >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) - return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + if not USE_COPY_ENGINE: + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + else: + handle = copy_engine_ctx + (6 * to_rank) + queue_ptr_u32 = tl.load(handle + 0).to(tl.pointer_type(tl.uint32)) + read_ptr = tl.load(handle + 1).to(tl.pointer_type(tl.uint64)) + write_ptr = tl.load(handle + 2).to(tl.pointer_type(tl.uint64)) + doorbell_ptr = tl.load(handle + 3).to(tl.pointer_type(tl.uint64)) + cached_write_ptr = tl.load(handle + 4).to(tl.pointer_type(tl.uint64)) + committed_write_ptr = tl.load(handle + 5).to(tl.pointer_type(tl.uint64)) + + dst_ptr_val = translated_ptr.to(tl.uint64) + + command_in_bytes = 32 + # Acquire space (returns base index and wraparound offset) + base, offset = anvil.acquire( + # base = anvil.acquire( + queue_ptr_u32, + read_ptr, + write_ptr, + doorbell_ptr, + cached_write_ptr, + committed_write_ptr, + command_in_bytes, + ) + # tl.device_print("offset ", offset) + + # Write padding NOPs if we wrapped around + anvil.place_nop_packet(queue_ptr_u32, base, offset) + + # Calculate packet position (base + offset for wraparound) + packet_offset_bytes = base + offset + + # Place command packet + anvil.place_atomic_packet(queue_ptr_u32, packet_offset_bytes, dst_ptr_val) + + # Submit command + pending_wptr = base + offset + command_in_bytes + anvil.submit(write_ptr, doorbell_ptr, committed_write_ptr, base, pending_wptr) @triton.jit diff --git a/test.py b/test.py new file mode 100644 index 00000000..dab152af --- /dev/null +++ b/test.py @@ -0,0 +1,21 @@ +import sys + +sys.path.append("./iris/experimental") + +import my_module as anvil + +print("Get instance") + +instance = anvil.AnvilLib.get_instance() +print("initialize") +instance.init() + +print("Connect 
0 to 1") + +instance.connect(0, 1, 1) + +queue = instance.get_sdma_queue(0, 1, 0) + +# handle = queue.device_handle() + +handle = anvil.get_handle_as_tensor(queue)