94 changes: 81 additions & 13 deletions iris/experimental/iris_gluon.py
@@ -144,7 +144,7 @@ def _translate(self, ptr, from_rank, to_rank):
return translated_ptr

@gluon.jit
def load(self, pointer, from_rank, mask=None, other=None):
def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False):
Copilot AI Mar 22, 2026

As in Triton, the cache-control arguments of Gluon's load/store operations are typically compile-time constants. The new parameters (cache_modifier, volatile, load_cache_modifier, store_cache_modifier) are not annotated as gl.constexpr in these @gluon.jit methods, which can lead to compilation or type errors when strings or bools are passed. Consider annotating them as gl.constexpr while keeping the current None/False defaults.

Suggested change
def load(self, pointer, from_rank, mask=None, other=None, cache_modifier=None, volatile=False):
def load(self, pointer, from_rank, mask=None, other=None, cache_modifier: gl.constexpr = None, volatile: gl.constexpr = False):

Copilot uses AI. Check for mistakes.
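
To illustrate why the annotation matters, here is a toy model in plain Python of what "compile-time constant" means for a JIT: arguments marked as constants become part of the specialization key, so each distinct value gets its own compiled variant. The `constexpr` class and `toy_jit` decorator are hypothetical stand-ins written for this sketch; they are not Gluon's or Triton's implementation.

```python
import inspect


class constexpr:
    """Hypothetical marker standing in for gl.constexpr (not the real class)."""


def toy_jit(fn):
    """Keep one 'compiled' variant per distinct tuple of constexpr arguments."""
    sig = inspect.signature(fn)
    const_names = [n for n, p in sig.parameters.items() if p.annotation is constexpr]
    variants = {}

    def launcher(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Only the constexpr-annotated arguments form the specialization key.
        key = tuple(bound.arguments[n] for n in const_names)
        if key not in variants:
            # A real compiler would generate code here with the constants baked in.
            variants[key] = fn
        return variants[key](*bound.args, **bound.kwargs)

    launcher.variants = variants
    return launcher


@toy_jit
def load(values, index, cache_modifier: constexpr = None, volatile: constexpr = False):
    # Toy body: the modifiers do not change the value that is read.
    return values[index]


load([10, 20, 30], 0)
load([10, 20, 30], 1)                         # same constants, variant reused
load([10, 20, 30], 2, cache_modifier=".cg")   # new constant, second variant
```

In this model the runtime arguments (`values`, `index`) never create new variants, while each new `cache_modifier` or `volatile` value does, which is the behaviour the suggested annotation asks the compiler to apply.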
"""
Loads a value from the specified rank's memory location to the current rank.

@@ -153,6 +153,17 @@ def load(self, pointer, from_rank, mask=None, other=None):
from_rank: The rank ID from which to read the data
mask: Optional mask for conditional loading
other: Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined.
cache_modifier (str, optional): Controls cache behavior of the load.

Supported values:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.
Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load. Defaults to False.

Returns:
The loaded value from the target memory location
@@ -162,11 +173,11 @@ def load(self, pointer, from_rank, mask=None, other=None):
>>> data = ctx.load(buffer + offsets, 1, mask=mask)
"""
translated_ptr = self._translate(pointer, self.cur_rank, from_rank)
result = gl.load(translated_ptr, mask=mask, other=other)
result = gl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile)
return result

@gluon.jit
def store(self, pointer, value, to_rank, mask=None):
def store(self, pointer, value, to_rank, mask=None, cache_modifier=None):
"""
Writes data from the current rank to the specified rank's memory location.

@@ -175,16 +186,25 @@ def store(self, pointer, value, to_rank, mask=None):
value: The value to store
to_rank: The rank ID to which the data will be written
mask: Optional mask for conditional storing
cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:

- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Example:
>>> # Store from current rank to rank 1
>>> ctx.store(buffer + offsets, values, 1, mask=mask)
"""
translated_ptr = self._translate(pointer, self.cur_rank, to_rank)
gl.store(translated_ptr, value, mask=mask)
gl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier)
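
The load and store docstrings list different sets of modifier strings: ".cg" appears in both, ".ca" and ".cv" only for loads, and ".wb", ".cs" and ".wt" only for stores. A minimal host-side sketch of checking a modifier against those lists follows. `check_cache_modifier` is a hypothetical helper for illustration and is not part of Iris or Gluon.

```python
# Values copied from the docstrings in this diff; None means "use the default".
LOAD_CACHE_MODIFIERS = (None, ".ca", ".cg", ".cv")
STORE_CACHE_MODIFIERS = (None, ".wb", ".cg", ".cs", ".wt")


def check_cache_modifier(op, modifier):
    """Return `modifier` if the docstrings list it for `op`, else raise ValueError.

    Hypothetical helper for illustration; it is not part of Iris or Gluon.
    """
    allowed = {"load": LOAD_CACHE_MODIFIERS, "store": STORE_CACHE_MODIFIERS}
    if op not in allowed:
        raise ValueError(f"op must be 'load' or 'store', got {op!r}")
    if modifier not in allowed[op]:
        raise ValueError(f"{modifier!r} is not a documented {op} cache modifier")
    return modifier


check_cache_modifier("load", ".cv")   # accepted: documented for loads
check_cache_modifier("store", ".cg")  # accepted: documented for loads and stores
```

Passing a load-only value such as ".cv" with `op="store"` raises `ValueError` in this sketch. Whether Gluon itself rejects such a combination is not stated in the diff.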

@gluon.jit
def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None):
def get(
self, from_ptr, to_ptr, from_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the specified rank's memory to the current rank's local memory.

@@ -194,17 +214,31 @@ def get(self, from_ptr, to_ptr, from_rank, mask=None, other=None):
from_rank: The rank ID from which to read the data
mask: Optional mask for conditional operations
other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined.
load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Example:
>>> # Copy from rank 1 to current rank's local memory
>>> ctx.get(remote_ptr + offsets, local_ptr + offsets, 1, mask=mask)
"""
translated_from_ptr = self._translate(from_ptr, self.cur_rank, from_rank)
data = gl.load(translated_from_ptr, mask=mask, other=other)
gl.store(to_ptr, data, mask=mask)
data = gl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier)
gl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)

@gluon.jit
def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None):
def put(
self, from_ptr, to_ptr, to_rank, mask=None, other=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the current rank's local memory to the specified rank's memory.

@@ -214,17 +248,39 @@ def put(self, from_ptr, to_ptr, to_rank, mask=None, other=None):
to_rank: The rank ID to which the data will be written
mask: Optional mask for conditional operations
other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined.
load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Example:
>>> # Copy from current rank's local memory to rank 1
>>> ctx.put(local_ptr + offsets, remote_ptr + offsets, 1, mask=mask)
"""
translated_to_ptr = self._translate(to_ptr, self.cur_rank, to_rank)
data = gl.load(from_ptr, mask=mask, other=other)
gl.store(translated_to_ptr, data, mask=mask)
data = gl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier)
gl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)
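
Both `get` and `put` are a masked load followed by a masked store; they differ only in which side is remote. The sketch below models that data movement in plain Python, with each rank's memory as a list. `ToyContext`, `masked_load` and `masked_store` are hypothetical names for this illustration. Cache modifiers are left out on the assumption that they affect cache behaviour only, not the values copied.

```python
def masked_load(memory, offsets, mask=None, other=None):
    """Read memory[o] for each offset; masked-out lanes yield `other`."""
    if mask is None:
        mask = [True] * len(offsets)
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


def masked_store(memory, offsets, values, mask=None):
    """Write values[i] to memory[offsets[i]] only where the mask is True."""
    if mask is None:
        mask = [True] * len(offsets)
    for o, v, m in zip(offsets, values, mask):
        if m:
            memory[o] = v


class ToyContext:
    """Hypothetical stand-in for the Iris context; heaps[r] is rank r's memory."""

    def __init__(self, heaps, cur_rank):
        self.heaps = heaps
        self.cur_rank = cur_rank

    def get(self, from_offsets, to_offsets, from_rank, mask=None, other=None):
        # Remote load, local store.
        data = masked_load(self.heaps[from_rank], from_offsets, mask, other)
        masked_store(self.heaps[self.cur_rank], to_offsets, data, mask)

    def put(self, from_offsets, to_offsets, to_rank, mask=None, other=None):
        # Local load, remote store.
        data = masked_load(self.heaps[self.cur_rank], from_offsets, mask, other)
        masked_store(self.heaps[to_rank], to_offsets, data, mask)


heaps = [[0, 0, 0, 0], [1, 2, 3, 4]]
ctx = ToyContext(heaps, cur_rank=0)

# Pull rank 1's values into rank 0, skipping lane 2.
ctx.get([0, 1, 2, 3], [0, 1, 2, 3], from_rank=1, mask=[True, True, False, True])
```

Because the same mask guards the store, the `other` value substituted for masked-out lanes during the load is never written to the destination.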

@gluon.jit
def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None):
def copy(
self,
src_ptr,
dst_ptr,
from_rank,
to_rank,
mask=None,
other=None,
load_cache_modifier=None,
store_cache_modifier=None,
):
"""
Copies data from the specified rank's memory into the destination rank's memory.

@@ -241,6 +297,18 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None):
to_rank: The rank ID that will receive the data (destination rank)
mask: Optional mask for conditional operations
other: Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined.
load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Example:
>>> # Copy from rank 1 to rank 0 (current rank must be either 1 or 0)
@@ -262,8 +330,8 @@ def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, other=None):
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)

data = gl.load(translated_src, mask=mask, other=other)
gl.store(translated_dst, data, mask=mask)
data = gl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier)
gl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier)
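
The visible tail of `copy` rebases each pointer as a heap base plus an offset. The sketch below models that arithmetic with plain integers. It assumes that the lines elided from this diff compute each offset relative to the current rank's heap base; those lines are not shown here. `translate`, `copy_addresses` and `heap_bases` are hypothetical names for the illustration.

```python
def translate(ptr, heap_bases, from_rank, to_rank):
    """Rebase a byte address from from_rank's heap onto to_rank's heap."""
    offset = ptr - heap_bases[from_rank]
    return heap_bases[to_rank] + offset


def copy_addresses(src_ptr, dst_ptr, heap_bases, cur_rank, from_rank, to_rank):
    """Addresses a copy issued on cur_rank would load from and store to.

    Assumption: both pointers are expressed in cur_rank's address space.
    """
    translated_src = translate(src_ptr, heap_bases, cur_rank, from_rank)
    translated_dst = translate(dst_ptr, heap_bases, cur_rank, to_rank)
    return translated_src, translated_dst


# Two ranks whose heaps start at different (made-up) base addresses.
heap_bases = [0x1000, 0x8000]

# Rank 0 copies from rank 1 into its own memory.
src, dst = copy_addresses(0x1040, 0x1080, heap_bases, cur_rank=0, from_rank=1, to_rank=0)
```

Under these assumptions the source address moves onto rank 1's heap at the same offset (0x40), while the destination address is unchanged because the destination rank is the current rank.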

@gluon.jit
def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None):