diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 41b23c881..1a06f284e 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr @gluon.jit - def load(self, pointer, from_rank, mask=None): + def load(self, pointer, from_rank, mask=None, other=None): """ Loads a value from the specified rank's memory location to the current rank. @@ -152,6 +152,7 @@ def load(self, pointer, from_rank, mask=None): pointer: Pointer in the `from_rank`'s address space from_rank: The rank ID from which to read the data mask: Optional mask for conditional loading + other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Returns: The loaded value from the target memory location @@ -161,7 +162,7 @@ def load(self, pointer, from_rank, mask=None): >>> data = ctx.load(buffer + offsets, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, from_rank) - result = gl.load(translated_ptr, mask=mask) + result = gl.load(translated_ptr, mask=mask, other=other) return result @gluon.jit @@ -183,7 +184,7 @@ def store(self, pointer, value, to_rank, mask=None): gl.store(translated_ptr, value, mask=mask) @gluon.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -192,17 +193,18 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None): to_ptr: Pointer to local memory in current rank from_rank: The rank ID from which to read the data mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. 
Example: >>> # Copy from rank 1 to current rank's local memory >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank) - data = gl.load(translated_from_ptr, mask=mask) + data = gl.load(translated_from_ptr, mask=mask, other=other) gl.store(to_ptr, data, mask=mask) @gluon.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -211,17 +213,18 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None): to_ptr: Pointer to remote memory in `to_rank`'s address space to_rank: The rank ID to which the data will be written mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Example: >>> # Copy from current rank's local memory to rank 1 >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank) - data = gl.load(from_ptr, mask=mask) + data = gl.load(from_ptr, mask=mask, other=other) gl.store(translated_to_ptr, data, mask=mask) @gluon.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -237,6 +240,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): from_rank: The rank ID that owns `src_ptr` (source rank) to_rank: The rank ID that will receive the data (destination rank) mask: Optional mask for conditional operations + other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. 
Example: >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) @@ -258,7 +262,7 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = gl.load(translated_src, mask=mask) + data = gl.load(translated_src, mask=mask, other=other) gl.store(translated_dst, data, mask=mask) @gluon.jit diff --git a/iris/iris.py b/iris/iris.py index 923109795..fe7e2d4da 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1895,14 +1895,29 @@ def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hin @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr = None): +def load( + pointer, + to_rank, + from_rank, + heap_bases, + mask=None, + other=None, + cache_modifier=None, + volatile=False, + hint: tl.constexpr = None, +): """ Loads a value from the specified rank's memory location. This function performs a memory read operation by translating the pointer from the `from_rank`'s address space to the `to_rank`'s address space and loading - data from the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local load operation. + data from the target memory location. The load is **local** when + ``to_rank == from_rank``, and **remote** (cross-GPU) otherwise. + + The ``cache_modifier`` is passed through to the underlying ``tl.load()`` + call. Cache modifiers control instruction-level cache behavior by setting + the appropriate scope (``SC0``, ``SC1``) and non-temporal (``NT``) bits + in the load instruction, following the CDNA ISA. Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. 
@@ -1910,6 +1925,18 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr from_rank (int): The rank ID from which to read the data. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: @@ -1925,7 +1952,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None, hint: tl.constexpr >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases, hint) - result = tl.load(translated_ptr, mask=mask) + result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @@ -1938,14 +1965,20 @@ def store( heap_bases, mask=None, hint: tl.constexpr = None, + cache_modifier=None, ): """ Writes data to the specified rank's memory location. 
This function performs a memory write operation by translating the pointer from the `from_rank`'s address space to the `to_rank`'s address space and storing - the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local store operation. + the provided data to the target memory location. The store is **local** when + ``from_rank == to_rank``, and **remote** (cross-GPU) otherwise. + + The ``cache_modifier`` is passed through to the underlying ``tl.store()`` + call. Cache modifiers control instruction-level cache behavior by setting + the appropriate scope (``SC0``, ``SC1``) and non-temporal (``NT``) bits + in the store instruction, following the CDNA ISA. Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. @@ -1955,6 +1988,13 @@ def store( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. 
Returns: None @@ -1969,19 +2009,33 @@ def store( >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, hint: tl.constexpr = None): +def copy( + src_ptr, + dst_ptr, + from_rank, + to_rank, + cur_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, +): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address space to the `to_rank`'s address space, performing a masked load from the translated source, and storing the loaded data to `dst_ptr` in the `to_rank` memory location. - If `from_rank` and `to_rank` are the same, this function performs a local copy operation. It is undefined behaviour if neither `from_rank` nor `to_rank` is the `cur_rank`. + The load is from ``from_rank`` (remote if ``from_rank != cur_rank``) and the store is to + ``to_rank`` (remote if ``to_rank != cur_rank``). + Args: src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s local memory from which to read data. dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `to_rank`'s local memory where the data will be written. @@ -1990,6 +2044,19 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. 
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointers. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). 
Returns: @@ -2024,19 +2091,32 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None, translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) + data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): +def get( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, +): """ Copies data from the specified rank's memory to the current rank's local memory. This function performs a memory read operation by translating the `from_ptr` from the current rank's address space to the `from_rank`'s address space, loading data - from the `from_rank` memory location, and storing it to the local `to_ptr`. - If the `from_rank` is the same as the current rank, this function performs a local copy operation. + from the `from_rank`'s memory location, and storing it to the local `to_ptr`. + + The load is **remote** when ``from_rank != to_rank`` (reading from a peer GPU), while the + store is **always local** (writing to `to_ptr` in the current rank's own memory). Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. Must be the current rank where the pointer is local. @@ -2045,6 +2125,19 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.co to_rank (int): The current rank ID where the data will be stored. 
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load (remote when ``from_rank != to_rank``). Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. The store is always to local memory (``to_ptr``). Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). 
Returns: @@ -2059,19 +2152,32 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.co """ translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases, hint) - data = tl.load(translated_from_ptr, mask=mask) + data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(to_ptr, data, mask=mask) + tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.constexpr = None): +def put( + from_ptr, + to_ptr, + from_rank, + to_rank, + heap_bases, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, +): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current rank's `from_ptr`, translating the `to_ptr` from the current rank's address space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + The load is **always local** (reading from the current rank's own ``from_ptr``), while the + store is **remote** when ``from_rank != to_rank`` (writing to a peer GPU). Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. @@ -2080,6 +2186,20 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.co to_rank (int): The `to_rank` ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. 
+ other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + + load_cache_modifier (str, optional): Controls cache behavior of the load (always local). Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store (remote when ``from_rank != to_rank``). Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). 
Returns: @@ -2094,9 +2214,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, hint: tl.co """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) - data = tl.load(from_ptr, mask=mask) + data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(translated_to_ptr, data, mask=mask) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py new file mode 100644 index 000000000..ff1be762b --- /dev/null +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def copy_kernel_local_read_remote_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from local memory to remote memory (local read, remote write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from current rank to other ranks. + # Both load and store cache modifiers are supported on local and remote ops. 
+ for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def copy_kernel_remote_read_local_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from remote memory to local memory (remote read, local write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from other ranks to current rank. + # Both load and store cache modifiers are supported on local and remote ops. + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations. 
+LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test copy: local read → remote write + + Direction: from_rank=cur_rank (local), to_rank=other (remote) + - Load: from LOCAL memory + - Store: to REMOTE memory + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + # Barrier to ensure all ranks have initialized their data before any rank launches + # the kernel (which reads remote data in the remote-read case). 
+ shmem.barrier() + + grid = lambda meta: (1,) + copy_kernel_local_read_remote_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank copies its data to all other ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), +) +def test_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test copy: remote read → local write + + Direction: from_rank=other (remote), to_rank=cur_rank (local) + - Load: from REMOTE memory + - Store: to LOCAL memory + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + # Barrier to ensure all ranks have initialized their data before any rank launches + # the kernel (which reads remote data in the remote-read case). 
+ shmem.barrier() + + grid = lambda meta: (1,) + copy_kernel_remote_read_local_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank pulls data from all ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_get_cache_modifiers.py b/tests/unittests/test_get_cache_modifiers.py new file mode 100644 index 000000000..cb91ddc43 --- /dev/null +++ b/tests/unittests/test_get_cache_modifiers.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def get_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + # Loop over all ranks and get data with cache modifiers. + # The load is remote when from_rank != cur_rank; the store to results is always local. 
+ for target_rank in range(num_ranks): + iris.get( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc += tl.load(results + offsets, mask=mask) + + # Store the accumulated value back to the output + tl.store(results + offsets, acc, mask=mask) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test get (copy from other rank) with various cache modifiers. + + load_cache_modifier applies to the remote load when from_rank != to_rank. + store_cache_modifier applies to the always-local store to to_ptr. + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + get_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + # Verify the result - should get data from all ranks (including self) + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_get_other_triton.py b/tests/unittests/test_get_other_triton.py new 
file mode 100644 index 000000000..412d9710a --- /dev/null +++ b/tests/unittests/test_get_other_triton.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def get_with_other_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + other_value: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Create a mask that is False for half the elements + mask = offsets < BLOCK_SIZE // 2 + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + # Loop over all ranks, get the stored data. + # load to local register, accumulate. + for target_rank in range(num_ranks): + iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask, other=other_value) + acc += tl.load(results + offsets) + + # Store the accumulated value back to the output. + tl.store(results + offsets, acc) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "BLOCK_SIZE", + [ + 8, + 16, + 32, + ], +) +def test_get_other_api(dtype, BLOCK_SIZE): + # TODO: Adjust heap size. 
+ shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + data = shmem.ones(BLOCK_SIZE, dtype=dtype) + results = shmem.zeros_like(data) + + # Use -1 as the "other" value for masked-out elements + other_value = -1.0 + + shmem.barrier() + + grid = lambda meta: (1,) + get_with_other_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, other_value) + shmem.barrier() + + # Verify the results + # First half should contain loaded values accumulated from all ranks (num_ranks * 1.0) + # Second half stays at 0.0 because iris.get stores with mask, so masked-out positions + # in `results` are never written; tl.load(results + offsets) reads 0.0 from them. + expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") + expected[: BLOCK_SIZE // 2] = num_ranks * 1.0 + expected[BLOCK_SIZE // 2 :] = 0.0 + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect() diff --git a/tests/unittests/test_load_cache_modifiers.py b/tests/unittests/test_load_cache_modifiers.py new file mode 100644 index 000000000..c6d30deed --- /dev/null +++ b/tests/unittests/test_load_cache_modifiers.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def load_kernel(
    data,
    results,
    source_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    cache_modifier: tl.constexpr,
    volatile: tl.constexpr,
):
    """Read one block from the partner rank using the given cache policy."""
    prog = tl.program_id(0)

    # The partner sits halfway around the ring from this rank.
    partner = int((source_rank + num_ranks // 2) % num_ranks)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < BLOCK_SIZE

    fetched = iris.load(
        data + idx,
        source_rank,
        partner,
        heap_bases,
        mask=in_range,
        cache_modifier=cache_modifier,
        volatile=volatile,
    )

    tl.store(results + idx, fetched, mask=in_range)


# Define cache modifiers and volatile options
CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
VOLATILE_OPTIONS = [False, True]


@pytest.mark.parametrize("cache_modifier,volatile", list(product(CACHE_MODIFIERS, VOLATILE_OPTIONS)))
def test_load_cache_modifiers(cache_modifier, volatile):
    """Test load with various cache modifiers and volatile settings."""
    shmem = iris.iris(1 << 20)
    world = shmem.get_num_ranks()
    bases = shmem.get_heap_bases()
    me = shmem.get_rank()
    partner = int((me + world // 2) % world)

    BLOCK_SIZE = 16
    src = shmem.full((BLOCK_SIZE,), me, dtype=torch.float32)
    out = shmem.zeros_like(src)

    shmem.barrier()
    load_kernel[lambda meta: (1,)](src, out, me, world, BLOCK_SIZE, bases, cache_modifier, volatile)
    shmem.barrier()

    # Every rank publishes its own id, so a remote load must yield the
    # partner's id regardless of the cache modifier in use.
    expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner

    try:
        torch.testing.assert_close(out, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(e)
        print("Expected:", expected)
        print("Actual:", out)
        raise


# ---------------------------------------------------------------------------
# tests/unittests/test_load_other_triton.py
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris


@triton.jit
def load_with_other_kernel(
    data,
    results,
    source_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    other_value: tl.constexpr,
):
    """Masked remote load; masked-out lanes receive `other_value`."""
    prog = tl.program_id(0)
    partner = int((source_rank + num_ranks // 2) % num_ranks)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Only the lower half of the block is live; the rest comes from `other`.
    lower_half = idx < BLOCK_SIZE // 2

    fetched = iris.load(data + idx, source_rank, partner, heap_bases, mask=lower_half, other=other_value)
    tl.store(results + idx, fetched)


@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
@pytest.mark.parametrize(
    "BLOCK_SIZE",
    [
        8,
        16,
        32,
    ],
)
def test_load_other_api(dtype, BLOCK_SIZE):
    """Masked iris.load must substitute `other` for masked-out elements."""
    # TODO: Adjust heap size.
    shmem = iris.iris(1 << 20)
    world = shmem.get_num_ranks()
    bases = shmem.get_heap_bases()
    me = shmem.get_rank()
    partner = int((me + world // 2) % world)

    # Every rank publishes its own id so remote reads are distinguishable:
    # loading from the partner must yield the partner's rank value.
    src = shmem.full((BLOCK_SIZE,), me, dtype=dtype)
    out = shmem.zeros_like(src)

    # Use -1 as the "other" value for masked-out elements
    fill = -1.0

    shmem.barrier()
    load_with_other_kernel[lambda meta: (1,)](src, out, me, world, BLOCK_SIZE, bases, fill)
    shmem.barrier()

    # Lower half: values fetched from the partner; upper half: the fill value.
    expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda")
    expected[: BLOCK_SIZE // 2] = partner
    expected[BLOCK_SIZE // 2 :] = fill

    try:
        torch.testing.assert_close(out, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(e)
        print("Expected:", expected)
        print("Actual:", out)
        raise
    finally:
        # Keep ranks in lock-step before tearing the heap down; barrier()
        # already synchronizes the device. Dropping the instance and forcing
        # a GC pass releases the IPC handles promptly for test isolation.
        shmem.barrier()
        del shmem
        import gc

        gc.collect()


# ---------------------------------------------------------------------------
# tests/unittests/test_put_cache_modifiers.py
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def put_kernel(
    data,
    results,
    from_rank: tl.constexpr,
    to_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    load_cache_modifier: tl.constexpr,
    store_cache_modifier: tl.constexpr,
):
    """Copy one block from `from_rank` to `to_rank` with explicit cache policies."""
    prog = tl.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < BLOCK_SIZE
    iris.put(
        data + idx,
        results + idx,
        from_rank,
        to_rank,
        heap_bases,
        mask=in_range,
        load_cache_modifier=load_cache_modifier,
        store_cache_modifier=store_cache_modifier,
    )


# Define cache modifiers for load and store operations
LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"]


@pytest.mark.parametrize(
    "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
)
def test_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier):
    """Test local put (from_rank == to_rank) with various cache modifiers."""
    shmem = iris.iris(1 << 20)
    bases = shmem.get_heap_bases()
    me = shmem.get_rank()

    BLOCK_SIZE = 16
    src = shmem.ones(BLOCK_SIZE, dtype=torch.float32)
    out = shmem.zeros_like(src)

    shmem.barrier()
    put_kernel[lambda meta: (1,)](src, out, me, me, BLOCK_SIZE, bases, load_cache_modifier, store_cache_modifier)
    shmem.barrier()

    expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda")
    try:
        torch.testing.assert_close(out, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(
            f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
        )
        print(e)
        raise


@pytest.mark.parametrize(
    "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
)
def test_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier):
    """Test remote put (from_rank != to_rank) with various cache modifiers."""
    shmem = iris.iris(1 << 20)
    bases = shmem.get_heap_bases()
    world = shmem.get_num_ranks()
    me = shmem.get_rank()

    if world < 2:
        pytest.skip("Remote put test requires at least 2 ranks")

    BLOCK_SIZE = 16
    src = shmem.ones(BLOCK_SIZE, dtype=torch.float32)
    out = shmem.zeros(BLOCK_SIZE, dtype=torch.float32)

    shmem.barrier()

    # rank 0 puts to rank 1
    neighbor = (me + 1) % world
    if me == 0:
        put_kernel[lambda meta: (1,)](
            src, out, me, neighbor, BLOCK_SIZE, bases, load_cache_modifier, store_cache_modifier
        )

    shmem.barrier()

    # rank 1 checks the data it received from rank 0
    if me == 1:
        expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda")
        try:
            torch.testing.assert_close(out, expected, rtol=0, atol=0)
        except AssertionError as e:
            print(
                f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
            )
            print(e)
            raise


# ---------------------------------------------------------------------------
# tests/unittests/test_put_other_triton.py
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import torch
import triton
import triton.language as tl
import pytest
import iris


@triton.jit
def put_with_other_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    other_value: tl.constexpr,
):
    """Masked put of the local block into every rank's buffer."""
    prog = tl.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Only the lower half of the block participates in the transfer.
    lower_half = idx < BLOCK_SIZE // 2

    # Push the local data to each rank in turn. `other` only influences the
    # intermediate load of `data`; the store side still honours the mask.
    for rank_id in range(num_ranks):
        iris.put(data + idx, results + idx, cur_rank, rank_id, heap_bases, mask=lower_half, other=other_value)


@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
@pytest.mark.parametrize(
    "BLOCK_SIZE",
    [
        8,
        16,
        32,
    ],
)
def test_put_other_api(dtype, BLOCK_SIZE):
    """Masked iris.put must leave masked-out destination elements untouched."""
    # TODO: Adjust heap size.
    shmem = iris.iris(1 << 20)
    world = shmem.get_num_ranks()
    bases = shmem.get_heap_bases()
    me = shmem.get_rank()

    # Source is all ones; destination starts at zero.
    src = shmem.ones(BLOCK_SIZE, dtype=dtype)
    out = shmem.zeros_like(src)

    # Use -1 as the "other" value for masked-out elements
    fill = -1.0

    shmem.barrier()
    put_with_other_kernel[lambda meta: (1,)](src, out, me, world, BLOCK_SIZE, bases, fill)
    shmem.barrier()

    # Lower half receives 1.0 from the masked put; the upper half is never
    # written (the store obeys the mask) and therefore stays at zero.
    expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda")
    expected[: BLOCK_SIZE // 2] = 1.0

    try:
        torch.testing.assert_close(out, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(e)
        print("Expected:", expected)
        print("Actual:", out)
        raise
    finally:
        # Keep ranks in lock-step before tearing the heap down; barrier()
        # already synchronizes the device. Dropping the instance and forcing
        # a GC pass releases the IPC handles promptly for test isolation.
        shmem.barrier()
        del shmem
        import gc

        gc.collect()


# ---------------------------------------------------------------------------
# tests/unittests/test_store_cache_modifiers.py
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import torch
import triton
import triton.language as tl
import pytest
import iris


@triton.jit
def local_store_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    cache_modifier: tl.constexpr,
):
    """Store a block into this rank's own heap (from_rank == to_rank)."""
    prog = tl.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < BLOCK_SIZE
    block = tl.load(data + idx, mask=in_range)
    # Local store: from_rank == to_rank == cur_rank
    iris.store(results + idx, block, cur_rank, cur_rank, heap_bases, mask=in_range, cache_modifier=cache_modifier)


@triton.jit
def remote_store_kernel(
    data,
    results,
    from_rank: tl.constexpr,
    to_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    cache_modifier: tl.constexpr,
):
    """Store a block into another rank's heap (from_rank != to_rank)."""
    prog = tl.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < BLOCK_SIZE
    block = tl.load(data + idx, mask=in_range)
    # Remote store: from_rank != to_rank
    iris.store(results + idx, block, from_rank, to_rank, heap_bases, mask=in_range, cache_modifier=cache_modifier)


# Define cache modifiers for store operations
CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"]


@pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS)
def test_store_cache_modifiers_local(cache_modifier):
    """Test local store (from_rank == to_rank) with various cache modifiers."""
    shmem = iris.iris(1 << 20)
    bases = shmem.get_heap_bases()
    me = shmem.get_rank()

    BLOCK_SIZE = 16
    src = shmem.ones(BLOCK_SIZE, dtype=torch.float32)
    out = shmem.zeros_like(src)

    shmem.barrier()
    local_store_kernel[lambda meta: (1,)](src, out, me, BLOCK_SIZE, bases, cache_modifier)
    shmem.barrier()

    expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda")
    try:
        torch.testing.assert_close(out, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(f"LOCAL STORE test failed with cache_modifier={cache_modifier}")
        print(e)
        raise


@pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS)
def test_store_cache_modifiers_remote(cache_modifier):
    """Test remote store (from_rank != to_rank) with various cache modifiers."""
    shmem = iris.iris(1 << 20)
    bases = shmem.get_heap_bases()
    world = shmem.get_num_ranks()
    me = shmem.get_rank()

    if world < 2:
        pytest.skip("Remote store test requires at least 2 ranks")

    BLOCK_SIZE = 16
    src = shmem.ones(BLOCK_SIZE, dtype=torch.float32)
    out = shmem.zeros(BLOCK_SIZE, dtype=torch.float32)

    shmem.barrier()

    # rank 0 stores to rank 1
    neighbor = (me + 1) % world
    if me == 0:
        remote_store_kernel[lambda meta: (1,)](src, out, me, neighbor, BLOCK_SIZE, bases, cache_modifier)

    shmem.barrier()

    # rank 1 checks the data it received from rank 0
    if me == 1:
        expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda")
        try:
            torch.testing.assert_close(out, expected, rtol=0, atol=0)
        except AssertionError as e:
            print(f"REMOTE STORE test failed with cache_modifier={cache_modifier}")
            print(e)
            raise