diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 1a06f284..24ddd014 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr @gluon.jit - def load(self, pointer, from_rank, mask=None, other=None): + def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False): """ Loads a value from the specified rank's memory location to the current rank. @@ -153,6 +153,17 @@ def load(self, pointer, from_rank, mask=None, other=None): from_rank: The rank ID from which to read the data mask: Optional mask for conditional loading other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Defaults to False. Returns: The loaded value from the target memory location @@ -162,11 +173,11 @@ def load(self, pointer, from_rank, mask=None, other=None): >>> data = ctx.load(buffer + offsets, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, from_rank) - result = gl.load(translated_ptr, mask=mask, other=other) + result = gl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @gluon.jit - def store(self, pointer, value, to_rank, mask=None): + def store(self, pointer, value, to_rank, mask=None, cache_modifier=None): """ Writes data from the current rank to the specified rank's memory location. @@ -175,16 +186,25 @@ def store(self, pointer, value, to_rank, mask=None): value: The value to store to_rank: The rank ID to which the data will be written mask: Optional mask for conditional storing + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Store from current rank to rank 1 >>> ctx.store(buffer + offsets, values, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, to_rank) - gl.store(translated_ptr, value, mask=mask) + gl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @gluon.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): + def get( + self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None + ): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -194,17 +214,31 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): from_rank: The rank ID from which to read the data mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from rank 1 to current rank's local memory >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank) - data = gl.load(translated_from_ptr, mask=mask, other=other) - gl.store(to_ptr, data, mask=mask) + data = gl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): + def put( + self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None + ): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -214,17 +248,39 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): to_rank: The rank ID to which the data will be written mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from current rank's local memory to rank 1 >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank) - data = gl.load(from_ptr, mask=mask, other=other) - gl.store(translated_to_ptr, data, mask=mask) + data = gl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): + def copy( + self, + src_ptr, + dst_ptr, + from_rank, + to_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + ): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -241,6 +297,18 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): to_rank: The rank ID that will receive the data (destination rank) mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) @@ -262,8 +330,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = gl.load(translated_src, mask=mask, other=other) - gl.store(translated_dst, data, mask=mask) + data = gl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): diff --git a/iris/iris.py b/iris/iris.py index fe7e2d4d..c03edd5e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1523,7 +1523,16 @@ def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) @triton.jit - def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): + def load( + self, + pointer, + from_rank, + mask=None, + other=None, + cache_modifier=None, + volatile=False, + hint: tl.constexpr = None, + ): """ Loads a value from the specified rank's memory location. @@ -1536,6 +1545,18 @@ def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Defaults to False. hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. Returns: @@ -1545,11 +1566,11 @@ def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ translated_ptr = self._translate(pointer, self.rank, from_rank, hint) - result = tl.load(translated_ptr, mask=mask) + result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @triton.jit - def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): + def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1563,6 +1584,13 @@ def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1571,10 +1599,20 @@ def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ translated_ptr = self._translate(pointer, self.rank, to_rank, hint) - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): + def get( + self, + from_ptr, + to_ptr, + from_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from the specified rank's memory into current rank's local memory. @@ -1588,6 +1626,19 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None) to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1596,11 +1647,21 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None) >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) - data = tl.load(translated_from_ptr, mask=mask) - tl.store(to_ptr, data, mask=mask) + data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): + def put( + self, + from_ptr, + to_ptr, + to_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from current rank's local memory to the specified rank's memory. @@ -1614,6 +1675,19 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1622,11 +1696,22 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) - data = tl.load(from_ptr, mask=mask) - tl.store(translated_to_ptr, data, mask=mask) + data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): + def copy( + self, + src_ptr, + dst_ptr, + from_rank, + to_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from one rank's memory to another rank's memory. @@ -1643,6 +1728,19 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constex from_rank (int): The rank ID that owns `src_ptr` (source rank). to_rank (int): The rank ID that will receive the data (destination rank). mask (Block of triton.int1, optional): If mask[idx] is false, do not load from src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1670,8 +1768,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constex translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) + data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): diff --git a/tests/unittests/test_device_context_cache_modifiers.py b/tests/unittests/test_device_context_cache_modifiers.py new file mode 100644 index 00000000..216e7f18 --- /dev/null +++ b/tests/unittests/test_device_context_cache_modifiers.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from iris import DeviceContext +from itertools import product + + +# === Kernel Definitions === + + +@triton.jit +def device_context_load_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, + volatile: tl.constexpr, +): + """Test DeviceContext.load() with cache_modifier and volatile.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + partner = int((cur_rank + num_ranks // 2) % num_ranks) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + result = ctx.load( + data + offsets, + from_rank=partner, + mask=mask, + cache_modifier=cache_modifier, + volatile=volatile, + ) + tl.store(results + offsets, result, mask=mask) + + +@triton.jit +def device_context_store_cache_modifier_kernel( + context_tensor, + source, + target, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + to_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, +): + """Test DeviceContext.store() with cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + data = tl.load(source + offsets, mask=mask) + ctx.store(target + offsets, data, to_rank=to_rank, mask=mask, cache_modifier=cache_modifier) + + +@triton.jit +def device_context_get_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.get() with load_cache_modifier and store_cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + for target_rank in range(num_ranks): + ctx.get( + data + offsets, + results + offsets, + from_rank=target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc += tl.load(results + offsets, mask=mask) + + tl.store(results + offsets, acc, mask=mask) + + +@triton.jit +def device_context_put_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + to_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.put() with load_cache_modifier and store_cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + ctx.put( + data + offsets, + results + offsets, + to_rank=to_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def device_context_copy_local_read_remote_write_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.copy() with cache modifiers (local read, remote write).""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + from_rank=cur_rank, + to_rank=target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def device_context_copy_remote_read_local_write_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.copy() with cache modifiers (remote read, local write).""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + from_rank=source_rank, + to_rank=cur_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# === Cache modifier lists === + +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] +VOLATILE_OPTIONS = [False, True] + + +# === Test Functions === + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_device_context_load_cache_modifiers(cache_modifier, volatile): + """Test DeviceContext.load() with various cache modifiers and volatile settings.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + partner = int((cur_rank + num_ranks // 2) % num_ranks) + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.full((BLOCK_SIZE,), cur_rank, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = lambda meta: (1,) + device_context_load_cache_modifier_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier, volatile + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOAD test failed with cache_modifier={cache_modifier}, volatile={volatile}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_device_context_store_cache_modifiers_local(cache_modifier): + """Test DeviceContext.store() local (from_rank == to_rank) with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + # For local store, we need partner == cur_rank; use a different kernel approach. + # We'll test with partner = cur_rank by calling the kernel but verifying store to self. + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + source = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + target = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) + + ctx.barrier() + + # We override the kernel to store to itself (to_rank == cur_rank). + @triton.jit + def local_store_kernel( + context_tensor, + source, + target, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, + ): + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + pid = tl.program_id(0) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + data = tl.load(source + offsets, mask=mask) + ctx.store(target + offsets, data, to_rank=cur_rank, mask=mask, cache_modifier=cache_modifier) + + grid = lambda meta: (1,) + local_store_kernel[grid](context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(target, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOCAL STORE test failed with cache_modifier={cache_modifier}") + print(e) + raise + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_device_context_store_cache_modifiers_remote(cache_modifier): + """Test DeviceContext.store() remote (from_rank != to_rank) with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + if num_ranks < 2: + pytest.skip("Remote store test requires at least 2 ranks") + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + source = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + target = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) + + ctx.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + device_context_store_cache_modifier_kernel[grid]( + context_tensor, source, target, cur_rank, num_ranks, remote_rank, BLOCK_SIZE, cache_modifier + ) + + ctx.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(target, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"REMOTE STORE test failed with cache_modifier={cache_modifier}") + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.get() with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = lambda meta: (1,) + device_context_get_cache_modifier_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.put() local (from_rank == to_rank) with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = lambda meta: (1,) + device_context_put_cache_modifier_kernel[grid]( + context_tensor, + data, + results, + cur_rank, + num_ranks, + cur_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.put() remote (from_rank != to_rank) with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + if num_ranks < 2: + pytest.skip("Remote put test requires at least 2 ranks") + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) + + ctx.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + device_context_put_cache_modifier_kernel[grid]( + context_tensor, + data, + results, + cur_rank, + num_ranks, + remote_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + ) + + ctx.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.copy() local read → remote write with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + ctx.barrier() + + grid = lambda meta: (1,) + device_context_copy_local_read_remote_write_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + + ctx.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), +) +def test_device_context_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.copy() remote read → local write with various cache modifiers.""" + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + ctx.barrier() + + grid = lambda meta: (1,) + device_context_copy_remote_read_local_write_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + + ctx.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_gluon_cache_modifiers.py b/tests/unittests/test_gluon_cache_modifiers.py new file mode 100644 index 00000000..818b6d7f --- /dev/null +++ b/tests/unittests/test_gluon_cache_modifiers.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import pytest +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +import iris.experimental.iris_gluon as iris_gl +from itertools import product + + +# === Kernel Definitions === + + +@gluon.jit +def load_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + source_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + cache_modifier: gl.constexpr, + volatile: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + result = ctx.load(data + offsets, partner, mask=mask, cache_modifier=cache_modifier, volatile=volatile) + gl.store(results + offsets, result, mask=mask) + + +@gluon.jit +def store_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + destination_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + value = gl.load(data + offsets, mask=mask) + + for dst_rank in range(num_ranks): + ctx.store(results + offsets, value, dst_rank, mask=mask, cache_modifier=cache_modifier) + + +@gluon.jit +def get_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + acc = gl.zeros([BLOCK_SIZE], dtype=gl.float32, layout=layout) + + for target_rank in range(num_ranks): + ctx.get( + data + offsets, + results + offsets, + target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc = acc + gl.load(results + offsets, mask=mask) + + gl.store(results + offsets, acc, mask=mask) + + +@gluon.jit +def put_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + to_rank: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + ctx.put( + data + offsets, + results + offsets, + to_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@gluon.jit +def copy_local_read_remote_write_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@gluon.jit +def copy_remote_read_local_write_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# === Cache modifier lists === + +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] +VOLATILE_OPTIONS = [False, True] + + +# === Test Functions === + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_gluon_load_cache_modifiers(cache_modifier, volatile): + """Test IrisDeviceCtx.load() with various cache modifiers and volatile settings.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + source_rank = ctx.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + BLOCK_SIZE = 16 + data = ctx.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = (1,) + load_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + source_rank, + num_ranks, + BLOCK_SIZE, + cache_modifier, + volatile, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOAD test failed with cache_modifier={cache_modifier}, volatile={volatile}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_gluon_store_cache_modifiers(cache_modifier): + """Test IrisDeviceCtx.store() with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + destination_rank = ctx.get_rank() + + BLOCK_SIZE = 16 + src = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(src) + + ctx.barrier() + + grid = (1,) + store_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + src, + results, + destination_rank, + num_ranks, + BLOCK_SIZE, + cache_modifier, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"STORE test failed with cache_modifier={cache_modifier}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.get() with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = (1,) + get_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.put() local (to_rank == cur_rank) with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + cur_rank = ctx.get_rank() + context_tensor = ctx.get_device_context() + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) + + ctx.barrier() + + grid = (1,) + put_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + cur_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + ctx.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + finally: + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.put() remote (to_rank != cur_rank) with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + context_tensor = ctx.get_device_context() + + if num_ranks < 2: + pytest.skip("Remote put test requires at least 2 ranks") + + BLOCK_SIZE = 16 + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) + + ctx.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = (1,) + if cur_rank == 0: + put_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + remote_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + ctx.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.copy() local read → remote write with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() + + BLOCK_SIZE = 16 + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + ctx.barrier() + + grid = (1,) + copy_local_read_remote_write_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + ctx.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + ctx.barrier() + del ctx + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), +) +def test_gluon_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.copy() remote read → local write with various cache modifiers.""" + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() + + BLOCK_SIZE = 16 + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + ctx.barrier() + + grid = (1,) + copy_remote_read_local_write_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + ctx.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + ctx.barrier() + del ctx + import gc + + gc.collect()