From fd2a3aa21a7a9bfb4f747903f4e6b1699479df3e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 02:50:09 +0000 Subject: [PATCH 1/6] Initial plan From 4e1fa9b95a08eda866eba96dd2e5c76fb16425a1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 02:58:35 +0000 Subject: [PATCH 2/6] Port cache_modifier, volatile, and other to DeviceContext and Gluon APIs - DeviceContext.load(): add other, cache_modifier, volatile params - DeviceContext.store(): add cache_modifier param - DeviceContext.get(): add other, load_cache_modifier, store_cache_modifier params - DeviceContext.put(): add other, load_cache_modifier, store_cache_modifier params - DeviceContext.copy(): add other, load_cache_modifier, store_cache_modifier params - IrisDeviceCtx.load(): add cache_modifier, volatile params - IrisDeviceCtx.store(): add cache_modifier param - IrisDeviceCtx.get(): add load_cache_modifier, store_cache_modifier params - IrisDeviceCtx.put(): add load_cache_modifier, store_cache_modifier params - IrisDeviceCtx.copy(): add load_cache_modifier, store_cache_modifier params - Add tests: test_device_context_cache_modifiers.py - Add tests: test_gluon_cache_modifiers.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/8ce4f35c-fe72-4506-89d2-e79687a16a09 --- iris/experimental/iris_gluon.py | 45 +- iris/iris.py | 79 ++- .../test_device_context_cache_modifiers.py | 533 +++++++++++++++++ tests/unittests/test_gluon_cache_modifiers.py | 558 ++++++++++++++++++ 4 files changed, 1189 insertions(+), 26 deletions(-) create mode 100644 tests/unittests/test_device_context_cache_modifiers.py create mode 100644 tests/unittests/test_gluon_cache_modifiers.py diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 1a06f284..706f92f2 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr @gluon.jit - def load(self, pointer, from_rank, mask=None, other=None): + def load(self, pointer, from_rank, mask=None, other=None, cache_modifier="", volatile=False): """ Loads a value from the specified rank's memory location to the current rank. @@ -153,6 +153,8 @@ def load(self, pointer, from_rank, mask=None, other=None): from_rank: The rank ID from which to read the data mask: Optional mask for conditional loading other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. + cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". + volatile: If True, disables compiler optimizations that could reorder or eliminate the load. Defaults to False. Returns: The loaded value from the target memory location @@ -162,11 +164,11 @@ def load(self, pointer, from_rank, mask=None, other=None): >>> data = ctx.load(buffer + offsets, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, from_rank) - result = gl.load(translated_ptr, mask=mask, other=other) + result = gl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @gluon.jit - def store(self, pointer, value, to_rank, mask=None): + def store(self, pointer, value, to_rank, mask=None, cache_modifier=""): """ Writes data from the current rank to the specified rank's memory location. @@ -175,16 +177,17 @@ def store(self, pointer, value, to_rank, mask=None): value: The value to store to_rank: The rank ID to which the data will be written mask: Optional mask for conditional storing + cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". Example: >>> # Store from current rank to rank 1 >>> ctx.store(buffer + offsets, values, 1, mask=mask) """ translated_ptr = self._translate(pointer, self.cur_rank, to_rank) - gl.store(translated_ptr, value, mask=mask) + gl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @gluon.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): + def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier="", store_cache_modifier=""): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -194,17 +197,19 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None): from_rank: The rank ID from which to read the data mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". + store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". Example: >>> # Copy from rank 1 to current rank's local memory >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank) - data = gl.load(translated_from_ptr, mask=mask, other=other) - gl.store(to_ptr, data, mask=mask) + data = gl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): + def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier="", store_cache_modifier=""): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -214,17 +219,29 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None): to_rank: The rank ID to which the data will be written mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". + store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". Example: >>> # Copy from current rank's local memory to rank 1 >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank) - data = gl.load(from_ptr, mask=mask, other=other) - gl.store(translated_to_ptr, data, mask=mask) + data = gl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): + def copy( + self, + src_ptr, + dst_ptr, + from_rank, + to_rank, + mask=None, + other=None, + load_cache_modifier="", + store_cache_modifier="", + ): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -241,6 +258,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): to_rank: The rank ID that will receive the data (destination rank) mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. + load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". + store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". Example: >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) @@ -262,8 +281,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = gl.load(translated_src, mask=mask, other=other) - gl.store(translated_dst, data, mask=mask) + data = gl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + gl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None): diff --git a/iris/iris.py b/iris/iris.py index fe7e2d4d..3c66bbda 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1523,7 +1523,16 @@ def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) @triton.jit - def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): + def load( + self, + pointer, + from_rank, + mask=None, + other=None, + cache_modifier=None, + volatile=False, + hint: tl.constexpr = None, + ): """ Loads a value from the specified rank's memory location. @@ -1536,6 +1545,9 @@ def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. + volatile (bool, optional): If True, disables compiler optimizations that could reorder or eliminate the load. Defaults to False. hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. Returns: @@ -1545,11 +1557,11 @@ def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) """ translated_ptr = self._translate(pointer, self.rank, from_rank, hint) - result = tl.load(translated_ptr, mask=mask) + result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @triton.jit - def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): + def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: tl.constexpr = None): """ Writes data to the specified rank's memory location. @@ -1563,6 +1575,7 @@ def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. Returns: None @@ -1571,10 +1584,20 @@ def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) """ translated_ptr = self._translate(pointer, self.rank, to_rank, hint) - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): + def get( + self, + from_ptr, + to_ptr, + from_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from the specified rank's memory into current rank's local memory. @@ -1588,6 +1611,9 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None) to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. Returns: None @@ -1596,11 +1622,21 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None) >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) """ translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) - data = tl.load(translated_from_ptr, mask=mask) - tl.store(to_ptr, data, mask=mask) + data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): + def put( + self, + from_ptr, + to_ptr, + to_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from current rank's local memory to the specified rank's memory. @@ -1614,6 +1650,9 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. Returns: None @@ -1622,11 +1661,22 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) """ translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) - data = tl.load(from_ptr, mask=mask) - tl.store(translated_to_ptr, data, mask=mask) + data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit - def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): + def copy( + self, + src_ptr, + dst_ptr, + from_rank, + to_rank, + mask=None, + other=None, + load_cache_modifier=None, + store_cache_modifier=None, + hint: tl.constexpr = None, + ): """ Copies data from one rank's memory to another rank's memory. @@ -1643,6 +1693,9 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constex from_rank (int): The rank ID that owns `src_ptr` (source rank). to_rank (int): The rank ID that will receive the data (destination rank). mask (Block of triton.int1, optional): If mask[idx] is false, do not load from src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. Returns: None @@ -1670,8 +1723,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constex translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) + data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): diff --git a/tests/unittests/test_device_context_cache_modifiers.py b/tests/unittests/test_device_context_cache_modifiers.py new file mode 100644 index 00000000..6568d8a1 --- /dev/null +++ b/tests/unittests/test_device_context_cache_modifiers.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from iris import DeviceContext +from itertools import product + + +# === Kernel Definitions === + + +@triton.jit +def device_context_load_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, + volatile: tl.constexpr, +): + """Test DeviceContext.load() with cache_modifier and volatile.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + partner = int((cur_rank + num_ranks // 2) % num_ranks) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + result = ctx.load( + data + offsets, + from_rank=partner, + mask=mask, + cache_modifier=cache_modifier, + volatile=volatile, + ) + tl.store(results + offsets, result, mask=mask) + + +@triton.jit +def device_context_store_cache_modifier_kernel( + context_tensor, + source, + target, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, +): + """Test DeviceContext.store() with cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + partner = int((cur_rank + num_ranks // 2) % num_ranks) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + data = tl.load(source + offsets, mask=mask) + ctx.store(target + offsets, data, to_rank=partner, mask=mask, cache_modifier=cache_modifier) + + +@triton.jit +def device_context_get_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.get() with load_cache_modifier and store_cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + for target_rank in range(num_ranks): + ctx.get( + data + offsets, + results + offsets, + from_rank=target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc += tl.load(results + offsets, mask=mask) + + tl.store(results + offsets, acc, mask=mask) + + +@triton.jit +def device_context_put_cache_modifier_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + to_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.put() with load_cache_modifier and store_cache_modifier.""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + ctx.put( + data + offsets, + results + offsets, + to_rank=to_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def device_context_copy_local_read_remote_write_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.copy() with cache modifiers (local read, remote write).""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + from_rank=cur_rank, + to_rank=target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def device_context_copy_remote_read_local_write_kernel( + context_tensor, + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Test DeviceContext.copy() with cache modifiers (remote read, local write).""" + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + from_rank=source_rank, + to_rank=cur_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# === Cache modifier lists === + +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] +VOLATILE_OPTIONS = [False, True] + + +# === Test Functions === + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_device_context_load_cache_modifiers(cache_modifier, volatile): + """Test DeviceContext.load() with various cache modifiers and volatile settings.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + partner = int((cur_rank + num_ranks // 2) % num_ranks) + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.full((BLOCK_SIZE,), cur_rank, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + device_context_load_cache_modifier_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier, volatile + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOAD test failed with cache_modifier={cache_modifier}, volatile={volatile}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_device_context_store_cache_modifiers_local(cache_modifier): + """Test DeviceContext.store() local (from_rank == to_rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + # For local store, we need partner == cur_rank; use a different kernel approach. + # We'll test with partner = cur_rank by calling the kernel but verifying store to self. + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + source = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + target = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + # We override the kernel to store to itself (to_rank == cur_rank). + @triton.jit + def local_store_kernel( + context_tensor, + source, + target, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + cache_modifier: tl.constexpr, + ): + ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) + pid = tl.program_id(0) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + data = tl.load(source + offsets, mask=mask) + ctx.store(target + offsets, data, to_rank=cur_rank, mask=mask, cache_modifier=cache_modifier) + + grid = lambda meta: (1,) + local_store_kernel[grid](context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(target, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOCAL STORE test failed with cache_modifier={cache_modifier}") + print(e) + raise + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_device_context_store_cache_modifiers_remote(cache_modifier): + """Test DeviceContext.store() remote (from_rank != to_rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + if num_ranks < 2: + pytest.skip("Remote store test requires at least 2 ranks") + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + source = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + target = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + device_context_store_cache_modifier_kernel[grid]( + context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier + ) + + shmem.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(target, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"REMOTE STORE test failed with cache_modifier={cache_modifier}") + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.get() with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + device_context_get_cache_modifier_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.put() local (from_rank == to_rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + device_context_put_cache_modifier_kernel[grid]( + context_tensor, + data, + results, + cur_rank, + num_ranks, + cur_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.put() remote (from_rank != to_rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + if num_ranks < 2: + pytest.skip("Remote put test requires at least 2 ranks") + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + device_context_put_cache_modifier_kernel[grid]( + context_tensor, + data, + results, + cur_rank, + num_ranks, + remote_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + ) + + shmem.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_device_context_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.copy() local read → remote write with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + shmem.barrier() + + grid = lambda meta: (1,) + device_context_copy_local_read_remote_write_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), +) +def test_device_context_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test DeviceContext.copy() remote read → local write with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + shmem.barrier() + + grid = lambda meta: (1,) + device_context_copy_remote_read_local_write_kernel[grid]( + context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_gluon_cache_modifiers.py b/tests/unittests/test_gluon_cache_modifiers.py new file mode 100644 index 00000000..b97db768 --- /dev/null +++ b/tests/unittests/test_gluon_cache_modifiers.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import pytest +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +import iris.experimental.iris_gluon as iris_gl +from itertools import product + + +# === Kernel Definitions === + + +@gluon.jit +def load_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + source_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + cache_modifier: gl.constexpr, + volatile: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + result = ctx.load(data + offsets, partner, mask=mask, cache_modifier=cache_modifier, volatile=volatile) + gl.store(results + offsets, result, mask=mask) + + +@gluon.jit +def store_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + destination_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + value = gl.load(data + offsets, mask=mask) + + for dst_rank in range(num_ranks): + ctx.store(results + offsets, value, dst_rank, mask=mask, cache_modifier=cache_modifier) + + +@gluon.jit +def get_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + acc = gl.zeros([BLOCK_SIZE], dtype=gl.float32, layout=layout) + + for target_rank in range(num_ranks): + ctx.get( + data + offsets, + results + offsets, + target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc = acc + gl.load(results + offsets, mask=mask) + + gl.store(results + offsets, acc, mask=mask) + + +@gluon.jit +def put_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + to_rank: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + ctx.put( + data + offsets, + results + offsets, + to_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@gluon.jit +def copy_local_read_remote_write_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@gluon.jit +def copy_remote_read_local_write_cache_modifier_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + data, + results, + cur_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + load_cache_modifier: gl.constexpr, + store_cache_modifier: gl.constexpr, +): + ctx = IrisDeviceCtx.initialize(context_tensor) + pid = gl.program_id(0) + + block_start = pid * BLOCK_SIZE + layout: gl.constexpr = gl.BlockedLayout([1], [64], [1], [0]) + offsets = block_start + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < BLOCK_SIZE + + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + ctx.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# === Cache modifier lists === + +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] +VOLATILE_OPTIONS = [False, True] + + +# === Test Functions === + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_gluon_load_cache_modifiers(cache_modifier, volatile): + """Test IrisDeviceCtx.load() with various cache modifiers and volatile settings.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + source_rank = shmem.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + BLOCK_SIZE = 16 + data = shmem.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = (1,) + load_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + source_rank, + num_ranks, + BLOCK_SIZE, + cache_modifier, + volatile, + num_warps=1, + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"LOAD test failed with cache_modifier={cache_modifier}, volatile={volatile}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) +def test_gluon_store_cache_modifiers(cache_modifier): + """Test IrisDeviceCtx.store() with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + destination_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(src) + + shmem.barrier() + + grid = (1,) + store_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + src, + results, + destination_rank, + num_ranks, + BLOCK_SIZE, + cache_modifier, + num_warps=1, + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"STORE test failed with cache_modifier={cache_modifier}") + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.get() with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = (1,) + get_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise + finally: + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.put() local (to_rank == cur_rank) with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + cur_rank = shmem.get_rank() + context_tensor = shmem.get_device_context() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = (1,) + put_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + cur_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + shmem.barrier() + + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + finally: + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.put() remote (to_rank != cur_rank) with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + context_tensor = shmem.get_device_context() + + if num_ranks < 2: + pytest.skip("Remote put test requires at least 2 ranks") + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + remote_rank = (cur_rank + 1) % num_ranks + grid = (1,) + if cur_rank == 0: + put_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + remote_rank, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + shmem.barrier() + + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise + + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.copy() local read → remote write with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + shmem.barrier() + + grid = (1,) + copy_local_read_remote_write_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + shmem.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + shmem.barrier() + del shmem + import gc + + gc.collect() + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), +) +def test_gluon_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test IrisDeviceCtx.copy() remote read → local write with various cache modifiers.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + context_tensor = shmem.get_device_context() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + shmem.barrier() + + grid = (1,) + copy_remote_read_local_write_cache_modifier_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + data, + results, + cur_rank, + num_ranks, + BLOCK_SIZE, + load_cache_modifier, + store_cache_modifier, + num_warps=1, + ) + + shmem.barrier() + + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], + torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device), + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + shmem.barrier() + del shmem + import gc + + gc.collect() From 791e62477df0941334fbdee1bc098cc710ba94ee Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 03:38:50 +0000 Subject: [PATCH 3/6] Rename shmem to ctx in test_device_context_cache_modifiers.py and test_gluon_cache_modifiers.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/ee387d87-c4e9-4207-b9d8-de35aba33605 --- .../test_device_context_cache_modifiers.py | 128 ++++++++-------- tests/unittests/test_gluon_cache_modifiers.py | 138 +++++++++--------- 2 files changed, 133 insertions(+), 133 deletions(-) diff --git a/tests/unittests/test_device_context_cache_modifiers.py b/tests/unittests/test_device_context_cache_modifiers.py index 6568d8a1..232ce9be 100644 --- a/tests/unittests/test_device_context_cache_modifiers.py +++ b/tests/unittests/test_device_context_cache_modifiers.py @@ -212,24 +212,24 @@ def device_context_copy_remote_read_local_write_kernel( @pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) def test_device_context_load_cache_modifiers(cache_modifier, volatile): """Test DeviceContext.load() with various cache modifiers and volatile settings.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() partner = int((cur_rank + num_ranks // 2) % num_ranks) - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.full((BLOCK_SIZE,), cur_rank, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.full((BLOCK_SIZE,), cur_rank, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = lambda meta: (1,) device_context_load_cache_modifier_kernel[grid]( context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier, volatile ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner @@ -246,19 +246,19 @@ def test_device_context_load_cache_modifiers(cache_modifier, volatile): @pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) def test_device_context_store_cache_modifiers_local(cache_modifier): """Test DeviceContext.store() local (from_rank == to_rank) with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() # For local store, we need partner == cur_rank; use a different kernel approach. # We'll test with partner = cur_rank by calling the kernel but verifying store to self. - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - source = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - target = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + source = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + target = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) - shmem.barrier() + ctx.barrier() # We override the kernel to store to itself (to_rank == cur_rank). @triton.jit @@ -280,7 +280,7 @@ def local_store_kernel( grid = lambda meta: (1,) local_store_kernel[grid](context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") try: @@ -294,20 +294,20 @@ def local_store_kernel( @pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) def test_device_context_store_cache_modifiers_remote(cache_modifier): """Test DeviceContext.store() remote (from_rank != to_rank) with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() if num_ranks < 2: pytest.skip("Remote store test requires at least 2 ranks") - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - source = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - target = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + source = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + target = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) - shmem.barrier() + ctx.barrier() remote_rank = (cur_rank + 1) % num_ranks grid = lambda meta: (1,) @@ -316,7 +316,7 @@ def test_device_context_store_cache_modifiers_remote(cache_modifier): context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier ) - shmem.barrier() + ctx.barrier() if cur_rank == 1: expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") @@ -333,23 +333,23 @@ def test_device_context_store_cache_modifiers_remote(cache_modifier): ) def test_device_context_get_cache_modifiers(load_cache_modifier, store_cache_modifier): """Test DeviceContext.get() with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = lambda meta: (1,) device_context_get_cache_modifier_kernel[grid]( context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks @@ -370,17 +370,17 @@ def test_device_context_get_cache_modifiers(load_cache_modifier, store_cache_mod ) def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): """Test DeviceContext.put() local (from_rank == to_rank) with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = lambda meta: (1,) device_context_put_cache_modifier_kernel[grid]( @@ -394,7 +394,7 @@ def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cac load_cache_modifier, store_cache_modifier, ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") try: @@ -412,20 +412,20 @@ def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cac ) def test_device_context_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): """Test DeviceContext.put() remote (from_rank != to_rank) with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() if num_ranks < 2: pytest.skip("Remote put test requires at least 2 ranks") - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) - shmem.barrier() + ctx.barrier() remote_rank = (cur_rank + 1) % num_ranks grid = lambda meta: (1,) @@ -442,7 +442,7 @@ def test_device_context_put_cache_modifiers_remote(load_cache_modifier, store_ca store_cache_modifier, ) - shmem.barrier() + ctx.barrier() if cur_rank == 1: expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") @@ -461,28 +461,28 @@ def test_device_context_put_cache_modifiers_remote(load_cache_modifier, store_ca ) def test_device_context_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): """Test DeviceContext.copy() local read → remote write with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) base = cur_rank + num_ranks for i in range(num_ranks): data[i, :] = base * (i + 1) - results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) - shmem.barrier() + ctx.barrier() grid = lambda meta: (1,) device_context_copy_local_read_remote_write_kernel[grid]( context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier ) - shmem.barrier() + ctx.barrier() for rank_id in range(num_ranks): expected_value = (rank_id + num_ranks) * (rank_id + 1) @@ -500,28 +500,28 @@ def test_device_context_copy_local_read_remote_write(load_cache_modifier, store_ ) def test_device_context_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): """Test DeviceContext.copy() remote read → local write with various cache modifiers.""" - shmem = iris.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() + ctx = iris.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() - context_tensor = shmem.get_device_context() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) base = cur_rank + num_ranks for i in range(num_ranks): data[i, :] = base * (i + 1) - results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) - shmem.barrier() + ctx.barrier() grid = lambda meta: (1,) device_context_copy_remote_read_local_write_kernel[grid]( context_tensor, data, results, cur_rank, num_ranks, BLOCK_SIZE, load_cache_modifier, store_cache_modifier ) - shmem.barrier() + ctx.barrier() for rank_id in range(num_ranks): expected_value = (rank_id + num_ranks) * (rank_id + 1) diff --git a/tests/unittests/test_gluon_cache_modifiers.py b/tests/unittests/test_gluon_cache_modifiers.py index b97db768..818b6d7f 100644 --- a/tests/unittests/test_gluon_cache_modifiers.py +++ b/tests/unittests/test_gluon_cache_modifiers.py @@ -209,17 +209,17 @@ def copy_remote_read_local_write_cache_modifier_kernel( @pytest.mark.parametrize("cache_modifier,volatile", list(product(LOAD_CACHE_MODIFIERS, VOLATILE_OPTIONS))) def test_gluon_load_cache_modifiers(cache_modifier, volatile): """Test IrisDeviceCtx.load() with various cache modifiers and volatile settings.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - context_tensor = shmem.get_device_context() - source_rank = shmem.get_rank() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + source_rank = ctx.get_rank() partner = int((source_rank + num_ranks // 2) % num_ranks) BLOCK_SIZE = 16 - data = shmem.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = (1,) load_cache_modifier_kernel[grid]( @@ -234,7 +234,7 @@ def test_gluon_load_cache_modifiers(cache_modifier, volatile): volatile, num_warps=1, ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner @@ -247,8 +247,8 @@ def test_gluon_load_cache_modifiers(cache_modifier, volatile): print("Actual:", results) raise finally: - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -257,16 +257,16 @@ def test_gluon_load_cache_modifiers(cache_modifier, volatile): @pytest.mark.parametrize("cache_modifier", STORE_CACHE_MODIFIERS) def test_gluon_store_cache_modifiers(cache_modifier): """Test IrisDeviceCtx.store() with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - context_tensor = shmem.get_device_context() - destination_rank = shmem.get_rank() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + destination_rank = ctx.get_rank() BLOCK_SIZE = 16 - src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros_like(src) + src = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(src) - shmem.barrier() + ctx.barrier() grid = (1,) store_cache_modifier_kernel[grid]( @@ -280,7 +280,7 @@ def test_gluon_store_cache_modifiers(cache_modifier): cache_modifier, num_warps=1, ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") @@ -293,8 +293,8 @@ def test_gluon_store_cache_modifiers(cache_modifier): print("Actual:", results) raise finally: - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -305,16 +305,16 @@ def test_gluon_store_cache_modifiers(cache_modifier): ) def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): """Test IrisDeviceCtx.get() with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - context_tensor = shmem.get_device_context() - cur_rank = shmem.get_rank() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = (1,) get_cache_modifier_kernel[grid]( @@ -329,7 +329,7 @@ def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): store_cache_modifier, num_warps=1, ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks @@ -344,8 +344,8 @@ def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): print("Actual:", results) raise finally: - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -356,15 +356,15 @@ def test_gluon_get_cache_modifiers(load_cache_modifier, store_cache_modifier): ) def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): """Test IrisDeviceCtx.put() local (to_rank == cur_rank) with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - cur_rank = shmem.get_rank() - context_tensor = shmem.get_device_context() + ctx = iris_gl.iris(1 << 20) + cur_rank = ctx.get_rank() + context_tensor = ctx.get_device_context() BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros_like(data) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros_like(data) - shmem.barrier() + ctx.barrier() grid = (1,) put_cache_modifier_kernel[grid]( @@ -379,7 +379,7 @@ def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifi store_cache_modifier, num_warps=1, ) - shmem.barrier() + ctx.barrier() expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") try: @@ -391,8 +391,8 @@ def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifi print(e) raise finally: - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -403,19 +403,19 @@ def test_gluon_put_cache_modifiers_local(load_cache_modifier, store_cache_modifi ) def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): """Test IrisDeviceCtx.put() remote (to_rank != cur_rank) with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - cur_rank = shmem.get_rank() - context_tensor = shmem.get_device_context() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + cur_rank = ctx.get_rank() + context_tensor = ctx.get_device_context() if num_ranks < 2: pytest.skip("Remote put test requires at least 2 ranks") BLOCK_SIZE = 16 - data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) - results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + data = ctx.ones(BLOCK_SIZE, dtype=torch.float32) + results = ctx.zeros(BLOCK_SIZE, dtype=torch.float32) - shmem.barrier() + ctx.barrier() remote_rank = (cur_rank + 1) % num_ranks grid = (1,) @@ -433,7 +433,7 @@ def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modif num_warps=1, ) - shmem.barrier() + ctx.barrier() if cur_rank == 1: expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") @@ -446,8 +446,8 @@ def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modif print(e) raise - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -458,20 +458,20 @@ def test_gluon_put_cache_modifiers_remote(load_cache_modifier, store_cache_modif ) def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): """Test IrisDeviceCtx.copy() local read → remote write with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - context_tensor = shmem.get_device_context() - cur_rank = shmem.get_rank() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() BLOCK_SIZE = 16 - data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) base = cur_rank + num_ranks for i in range(num_ranks): data[i, :] = base * (i + 1) - results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) - shmem.barrier() + ctx.barrier() grid = (1,) copy_local_read_remote_write_cache_modifier_kernel[grid]( @@ -487,7 +487,7 @@ def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_mod num_warps=1, ) - shmem.barrier() + ctx.barrier() for rank_id in range(num_ranks): expected_value = (rank_id + num_ranks) * (rank_id + 1) @@ -498,8 +498,8 @@ def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_mod f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" ) - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() @@ -511,20 +511,20 @@ def test_gluon_copy_local_read_remote_write(load_cache_modifier, store_cache_mod ) def test_gluon_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): """Test IrisDeviceCtx.copy() remote read → local write with various cache modifiers.""" - shmem = iris_gl.iris(1 << 20) - num_ranks = shmem.get_num_ranks() - context_tensor = shmem.get_device_context() - cur_rank = shmem.get_rank() + ctx = iris_gl.iris(1 << 20) + num_ranks = ctx.get_num_ranks() + context_tensor = ctx.get_device_context() + cur_rank = ctx.get_rank() BLOCK_SIZE = 16 - data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + data = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) base = cur_rank + num_ranks for i in range(num_ranks): data[i, :] = base * (i + 1) - results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + results = ctx.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) - shmem.barrier() + ctx.barrier() grid = (1,) copy_remote_read_local_write_cache_modifier_kernel[grid]( @@ -540,7 +540,7 @@ def test_gluon_copy_remote_read_local_write(load_cache_modifier, store_cache_mod num_warps=1, ) - shmem.barrier() + ctx.barrier() for rank_id in range(num_ranks): expected_value = (rank_id + num_ranks) * (rank_id + 1) @@ -551,8 +551,8 @@ def test_gluon_copy_remote_read_local_write(load_cache_modifier, store_cache_mod f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" ) - shmem.barrier() - del shmem + ctx.barrier() + del ctx import gc gc.collect() From b46c8a79ec75d9040b9701d8b5886af6a172f09a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 04:03:45 +0000 Subject: [PATCH 4/6] Expand cache_modifier/volatile docstrings in DeviceContext and Gluon to match freestanding API format Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/639da603-607c-4d81-b5af-88a77a38db97 --- iris/experimental/iris_gluon.py | 63 ++++++++++++++++++++++++++++----- iris/iris.py | 63 ++++++++++++++++++++++++++++----- 2 files changed, 108 insertions(+), 18 deletions(-) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 706f92f2..48925634 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -153,8 +153,17 @@ def load(self, pointer, from_rank, mask=None, other=None, cache_modifier="", vol from_rank: The rank ID from which to read the data mask: Optional mask for conditional loading other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. - cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". - volatile: If True, disables compiler optimizations that could reorder or eliminate the load. Defaults to False. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Defaults to False. Returns: The loaded value from the target memory location @@ -177,7 +186,13 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=""): value: The value to store to_rank: The rank ID to which the data will be written mask: Optional mask for conditional storing - cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Store from current rank to rank 1 @@ -197,8 +212,18 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_mod from_rank: The rank ID from which to read the data mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. - load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". - store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from rank 1 to current rank's local memory @@ -219,8 +244,18 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modif to_rank: The rank ID to which the data will be written mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. - load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". - store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from current rank's local memory to rank 1 @@ -258,8 +293,18 @@ def copy( to_rank: The rank ID that will receive the data (destination rank) mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. - load_cache_modifier: Controls cache behavior of the load. Supported values: "" (default), ".ca", ".cg", ".cv". - store_cache_modifier: Controls cache behavior of the store. Supported values: "" (default), ".wb", ".cg", ".cs", ".wt". + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Example: >>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0) diff --git a/iris/iris.py b/iris/iris.py index 3c66bbda..c03edd5e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1546,8 +1546,17 @@ def load( from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. - cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. - volatile (bool, optional): If True, disables compiler optimizations that could reorder or eliminate the load. Defaults to False. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Defaults to False. hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. Returns: @@ -1575,7 +1584,13 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: t value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. - cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1612,8 +1627,18 @@ def get( from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1651,8 +1676,18 @@ def put( to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1694,8 +1729,18 @@ def copy( to_rank (int): The rank ID that will receive the data (destination rank). mask (Block of triton.int1, optional): If mask[idx] is false, do not load from src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values: None (default, same as ".ca"), ".ca", ".cg", ".cv". Defaults to None. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values: None (default, same as ".wb"), ".wb", ".cg", ".cs", ".wt". Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None From f5cfba226924fef29086682672564dc01efdec86 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 04:49:33 +0000 Subject: [PATCH 5/6] Change Gluon cache_modifier defaults from "" to None for consistency with DeviceContext and freestanding API Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/5e77e456-b46f-4bef-b94a-8b8bfff9a473 --- iris/experimental/iris_gluon.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 48925634..24ddd014 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank): return translated_ptr @gluon.jit - def load(self, pointer, from_rank, mask=None, other=None, cache_modifier="", volatile=False): + def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False): """ Loads a value from the specified rank's memory location to the current rank. @@ -156,7 +156,7 @@ def load(self, pointer, from_rank, mask=None, other=None, cache_modifier="", vol cache_modifier (str, optional): Controls cache behavior of the load. Supported values: - - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. @@ -177,7 +177,7 @@ def load(self, pointer, from_rank, mask=None, other=None, cache_modifier="", vol return result @gluon.jit - def store(self, pointer, value, to_rank, mask=None, cache_modifier=""): + def store(self, pointer, value, to_rank, mask=None, cache_modifier=None): """ Writes data from the current rank to the specified rank's memory location. @@ -188,7 +188,7 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=""): mask: Optional mask for conditional storing cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. @@ -202,7 +202,9 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=""): gl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @gluon.jit - def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier="", store_cache_modifier=""): + def get( + self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None + ): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -213,13 +215,13 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_mod mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. @@ -234,7 +236,9 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_mod gl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @gluon.jit - def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier="", store_cache_modifier=""): + def put( + self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None + ): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -245,13 +249,13 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modif mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. @@ -274,8 +278,8 @@ def copy( to_rank, mask=None, other=None, - load_cache_modifier="", - store_cache_modifier="", + load_cache_modifier=None, + store_cache_modifier=None, ): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -294,13 +298,13 @@ def copy( mask: Optional mask for conditional operations other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - - "": *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - - "": *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. From 59f07b89ba4bd8eb17d98b25a51bebf4891a82a6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 01:11:55 +0000 Subject: [PATCH 6/6] Fix test_device_context_store_cache_modifiers_remote: pass to_rank to kernel instead of hardcoding partner Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/3e67c439-de61-437e-97f4-318c4f05d925 --- tests/unittests/test_device_context_cache_modifiers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittests/test_device_context_cache_modifiers.py b/tests/unittests/test_device_context_cache_modifiers.py index 232ce9be..216e7f18 100644 --- a/tests/unittests/test_device_context_cache_modifiers.py +++ b/tests/unittests/test_device_context_cache_modifiers.py @@ -51,6 +51,7 @@ def device_context_store_cache_modifier_kernel( target, cur_rank: tl.constexpr, num_ranks: tl.constexpr, + to_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, cache_modifier: tl.constexpr, ): @@ -58,14 +59,13 @@ def device_context_store_cache_modifier_kernel( ctx = DeviceContext.initialize(context_tensor, cur_rank, num_ranks) pid = tl.program_id(0) - partner = int((cur_rank + num_ranks // 2) % num_ranks) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < BLOCK_SIZE data = tl.load(source + offsets, mask=mask) - ctx.store(target + offsets, data, to_rank=partner, mask=mask, cache_modifier=cache_modifier) + ctx.store(target + offsets, data, to_rank=to_rank, mask=mask, cache_modifier=cache_modifier) @triton.jit @@ -313,7 +313,7 @@ def test_device_context_store_cache_modifiers_remote(cache_modifier): grid = lambda meta: (1,) if cur_rank == 0: device_context_store_cache_modifier_kernel[grid]( - context_tensor, source, target, cur_rank, num_ranks, BLOCK_SIZE, cache_modifier + context_tensor, source, target, cur_rank, num_ranks, remote_rank, BLOCK_SIZE, cache_modifier ) ctx.barrier()