Port cache_modifier, volatile, and other to DeviceContext and Gluon APIs #471
- DeviceContext.load(): add other, cache_modifier, volatile params
- DeviceContext.store(): add cache_modifier param
- DeviceContext.get(): add other, load_cache_modifier, store_cache_modifier params
- DeviceContext.put(): add other, load_cache_modifier, store_cache_modifier params
- DeviceContext.copy(): add other, load_cache_modifier, store_cache_modifier params
- IrisDeviceCtx.load(): add cache_modifier, volatile params
- IrisDeviceCtx.store(): add cache_modifier param
- IrisDeviceCtx.get(): add load_cache_modifier, store_cache_modifier params
- IrisDeviceCtx.put(): add load_cache_modifier, store_cache_modifier params
- IrisDeviceCtx.copy(): add load_cache_modifier, store_cache_modifier params
- Add tests: test_device_context_cache_modifiers.py
- Add tests: test_gluon_cache_modifiers.py
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Agent-Logs-Url: https://github.com/ROCm/iris/sessions/8ce4f35c-fe72-4506-89d2-e79687a16a09
| ) | ||
| def test_device_context_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): | ||
| """Test DeviceContext.put() local (from_rank == to_rank) with various cache modifiers.""" | ||
| shmem = iris.iris(1 << 20) |
@copilot use ctx instead of shmem all over the files you added or updated.
Done in 791e624. Replaced all shmem references with ctx in both test_device_context_cache_modifiers.py and test_gluon_cache_modifiers.py.
…t_gluon_cache_modifiers.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/ee387d87-c4e9-4207-b9d8-de35aba33605
@copilot The docstrings for cache_modifier params in both DeviceContext and Gluon are too short. Expand them to match the format used in the freestanding API (see For example, you wrote: But the freestanding API has the full expanded format: And for store: Apply this expanded format to all |
…to match freestanding API format Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/639da603-607c-4d81-b5af-88a77a38db97
Done in b46c8a7. Expanded all |
@copilot The Gluon API uses |
…with DeviceContext and freestanding API Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/5e77e456-b46f-4bef-b94a-8b8bfff9a473
Done in f5cfba2. Changed all Gluon |
Pull request overview
Adds missing cache_modifier, volatile, and other parameters to DeviceContext and Gluon IrisDeviceCtx methods so they match the freestanding Iris API introduced in #158, and adds unit tests to validate the new passthrough behavior.
Changes:
- Extend DeviceContext (load/store/get/put/copy) to accept and forward cache-modifier and masking-related parameters to tl.load/tl.store.
- Extend Gluon IrisDeviceCtx methods to accept and forward cache_modifier/volatile where supported by gl.load/gl.store.
- Add dedicated unit tests covering cache modifiers across all 5 methods for both APIs.
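For readers without a GPU at hand, the passthrough described above can be modelled in plain Python. This is only a rough sketch: `fake_load`, `fake_store`, and `ModelDeviceContext` are hypothetical stand-ins (not Iris or Triton APIs), and the assumption that `get()` is a masked load followed by a masked store is inferred from the parameter names, not taken from the diff.

```python
# Pure-Python model of the parameter passthrough (no Triton, no GPU).
# All names here are hypothetical stand-ins used for illustration only.
calls = []  # records which modifier reached which primitive


def fake_load(ptr, mask=None, other=None, cache_modifier=None, volatile=False):
    """Stand-in for a masked load: masked-off lanes receive `other`."""
    calls.append(("load", cache_modifier, volatile))
    if mask is None:
        return list(ptr)
    return [v if m else other for v, m in zip(ptr, mask)]


def fake_store(ptr, value, mask=None, cache_modifier=None):
    """Stand-in for a masked store: masked-off lanes are left untouched."""
    calls.append(("store", cache_modifier))
    for i, v in enumerate(value):
        if mask is None or mask[i]:
            ptr[i] = v


class ModelDeviceContext:
    """Assumed shape of get(): load with one modifier, store with another."""

    def get(self, from_ptr, to_ptr, mask=None, other=None,
            load_cache_modifier=None, store_cache_modifier=None):
        data = fake_load(from_ptr, mask=mask, other=other,
                         cache_modifier=load_cache_modifier)
        fake_store(to_ptr, data, mask=mask,
                   cache_modifier=store_cache_modifier)


src = [1, 2, 3, 4]
dst = [0, 0, 0, 0]
ModelDeviceContext().get(
    src, dst,
    mask=[True, True, False, True],
    other=-1,
    load_cache_modifier=".cg",
    store_cache_modifier=".wt",
)
print(dst)    # [1, 2, 0, 4]
print(calls)  # [('load', '.cg', False), ('store', '.wt')]
```

The point of the model is that each modifier reaches exactly one primitive, and that `other` only affects the loaded value, since the store uses the same mask and skips those lanes.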
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 27 comments.
| File | Description |
|---|---|
| iris/iris.py | Extends DeviceContext method signatures and forwards args to Triton loads/stores. |
| iris/experimental/iris_gluon.py | Extends Gluon IrisDeviceCtx method signatures and forwards args to gl.load/gl.store. |
| tests/unittests/test_device_context_cache_modifiers.py | Adds coverage for DeviceContext cache modifiers / volatile across all 5 methods. |
| tests/unittests/test_gluon_cache_modifiers.py | Adds equivalent coverage for Gluon IrisDeviceCtx cache modifiers / volatile across all 5 methods. |
| cache_modifier=None, | ||
| volatile=False, |
tl.load(..., cache_modifier=..., volatile=...) expects compile-time constants in Triton, but cache_modifier and volatile are not marked tl.constexpr in this @triton.jit function signature. This can cause compilation failures when callers pass string/bool constants (which Triton generally requires as constexpr). Declare them as cache_modifier: tl.constexpr = None and volatile: tl.constexpr = False (and do the same for the load/store cache modifier args in get/put/copy, and for cache_modifier in store).
| - cache_modifier=None, | |
| - volatile=False, | |
| + cache_modifier: tl.constexpr = None, | |
| + volatile: tl.constexpr = False, |
| """ | ||
| translated_ptr = self._translate(pointer, self.rank, from_rank, hint) | ||
| - result = tl.load(translated_ptr, mask=mask) | ||
| + result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) |
tl.load(..., cache_modifier=..., volatile=...) expects compile-time constants in Triton, but cache_modifier and volatile are not marked tl.constexpr in this @triton.jit function signature. This can cause compilation failures when callers pass string/bool constants (which Triton generally requires as constexpr). Declare them as cache_modifier: tl.constexpr = None and volatile: tl.constexpr = False (and do the same for the load/store cache modifier args in get/put/copy, and for cache_modifier in store).
| other=None, | ||
| cache_modifier=None, | ||
| volatile=False, | ||
| hint: tl.constexpr = None, |
The new parameters are inserted before hint, which changes the positional argument order compared to the previous API. Any existing callers passing hint positionally (rather than as a keyword) will now accidentally bind it to other/cache_modifier/etc. To avoid a breaking change, keep hint in its previous position (immediately after mask) and append new parameters after it, or make the new parameters keyword-only (if compatible with @triton.jit usage patterns in this codebase).
| - other=None, | |
| - cache_modifier=None, | |
| - volatile=False, | |
| - hint: tl.constexpr = None, | |
| + hint: tl.constexpr = None, | |
| + other=None, | |
| + cache_modifier=None, | |
| + volatile=False, |
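The positional-binding hazard described in this comment can be reproduced in plain Python, without Triton. The three functions below are hypothetical stand-ins for the old signature, the signature as submitted, and the reordered suggestion; the parameter list (`pointer, from_rank, mask, hint`) is assumed from the surrounding diff, and the hint value `(1, 16)` is made up for illustration.

```python
# Hypothetical stand-ins for three layouts of DeviceContext.load()'s signature.
# Plain Python only; @triton.jit and tl.constexpr are deliberately left out.

def load_old(pointer, from_rank, mask=None, hint=None):
    """Assumed pre-PR layout: hint directly follows mask."""
    return {"hint": hint}


def load_inserted(pointer, from_rank, mask=None, other=None,
                  cache_modifier=None, volatile=False, hint=None):
    """Layout as submitted: new parameters inserted before hint."""
    return {"hint": hint, "other": other}


def load_appended(pointer, from_rank, mask=None, hint=None, other=None,
                  cache_modifier=None, volatile=False):
    """Suggested layout: hint keeps its slot, new parameters are appended."""
    return {"hint": hint, "other": other}


# An existing caller that passes hint positionally, as the fourth argument.
args = ("ptr", 1, None, (1, 16))

print(load_old(*args))       # {'hint': (1, 16)}
print(load_inserted(*args))  # {'hint': None, 'other': (1, 16)}
print(load_appended(*args))  # {'hint': (1, 16), 'other': None}
```

With the inserted layout the hint is silently bound to `other` and no error is raised, which is why the reordering (or keyword-only parameters, if the JIT accepts them) is the safer choice.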
| @triton.jit | ||
| - def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): | ||
| + def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: tl.constexpr = None): |
The new parameters are inserted before hint, which changes the positional argument order compared to the previous API. Any existing callers passing hint positionally (rather than as a keyword) will now accidentally bind it to other/cache_modifier/etc. To avoid a breaking change, keep hint in its previous position (immediately after mask) and append new parameters after it, or make the new parameters keyword-only (if compatible with @triton.jit usage patterns in this codebase).
| def get( | ||
| self, | ||
| from_ptr, | ||
| to_ptr, | ||
| from_rank, | ||
| mask=None, | ||
| other=None, | ||
| load_cache_modifier=None, | ||
| store_cache_modifier=None, | ||
| hint: tl.constexpr = None, | ||
| ): |
The new parameters are inserted before hint, which changes the positional argument order compared to the previous API. Any existing callers passing hint positionally (rather than as a keyword) will now accidentally bind it to other/cache_modifier/etc. To avoid a breaking change, keep hint in its previous position (immediately after mask) and append new parameters after it, or make the new parameters keyword-only (if compatible with @triton.jit usage patterns in this codebase).
| @pytest.mark.parametrize( | ||
| "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) |
The Cartesian-product parameterization across modifiers creates a large number of JIT-compiled test cases (especially for put and both copy directions), which can significantly increase CI time. If runtime becomes an issue, consider sampling representative modifier pairs (e.g., None, one mid-level like .cg, and one extreme like .cv/.wt) or marking the full matrix as a slower test tier.
| - @pytest.mark.parametrize( | |
| -     "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) | |
| + def _sample_cache_modifiers(modifiers): | |
| +     """Select a small representative subset of cache modifiers to limit test explosion.""" | |
| +     if len(modifiers) <= 3: | |
| +         return modifiers | |
| +     mid = len(modifiers) // 2 | |
| +     indices = {0, mid, len(modifiers) - 1} | |
| +     return [modifiers[i] for i in sorted(indices)] | |
| + _SAMPLED_LOAD_CACHE_MODIFIERS = _sample_cache_modifiers(LOAD_CACHE_MODIFIERS) | |
| + _SAMPLED_STORE_CACHE_MODIFIERS = _sample_cache_modifiers(STORE_CACHE_MODIFIERS) | |
| + _SAMPLED_GET_CACHE_MODIFIER_PAIRS = list( | |
| +     product(_SAMPLED_LOAD_CACHE_MODIFIERS, _SAMPLED_STORE_CACHE_MODIFIERS) | |
| + ) | |
| + @pytest.mark.parametrize( | |
| +     "load_cache_modifier,store_cache_modifier", _SAMPLED_GET_CACHE_MODIFIER_PAIRS |
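To get a feel for how much the sampling in the suggestion above would save, the case counts can be computed with the standard library alone. The modifier lists below are an assumption (the values Triton documents for `tl.load` and `tl.store`, plus `None`); the real `LOAD_CACHE_MODIFIERS` and `STORE_CACHE_MODIFIERS` in the test files may differ, so treat the numbers as illustrative.

```python
from itertools import product

# Assumed modifier lists; the actual constants in the test files may differ.
LOAD_CACHE_MODIFIERS = [None, ".ca", ".cg", ".cv"]
STORE_CACHE_MODIFIERS = [None, ".wb", ".cg", ".cs", ".wt"]


def sample_cache_modifiers(modifiers):
    """Keep the first, middle, and last entries of a modifier list."""
    if len(modifiers) <= 3:
        return list(modifiers)
    mid = len(modifiers) // 2
    return [modifiers[i] for i in sorted({0, mid, len(modifiers) - 1})]


sampled_loads = sample_cache_modifiers(LOAD_CACHE_MODIFIERS)
sampled_stores = sample_cache_modifiers(STORE_CACHE_MODIFIERS)

full_pairs = list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
sampled_pairs = list(product(sampled_loads, sampled_stores))

print(sampled_loads)   # [None, '.cg', '.cv']
print(sampled_stores)  # [None, '.cg', '.wt']
print(len(full_pairs), len(sampled_pairs))  # 20 9
```

Under these assumed lists, each parametrized test drops from 20 JIT-compiled cases to 9 while still covering the default, a mid-level modifier, and the last entry of each list.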
| @pytest.mark.parametrize( | ||
| "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) | ||
| ) |
The Cartesian-product parameterization across modifiers creates a large number of JIT-compiled test cases (especially for put and both copy directions), which can significantly increase CI time. If runtime becomes an issue, consider sampling representative modifier pairs (e.g., None, one mid-level like .cg, and one extreme like .cv/.wt) or marking the full matrix as a slower test tier.
| @pytest.mark.parametrize( | ||
| "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) | ||
| ) |
The Cartesian-product parameterization across modifiers creates a large number of JIT-compiled test cases (especially for put and both copy directions), which can significantly increase CI time. If runtime becomes an issue, consider sampling representative modifier pairs (e.g., None, one mid-level like .cg, and one extreme like .cv/.wt) or marking the full matrix as a slower test tier.
| @pytest.mark.parametrize( | ||
| "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) | ||
| ) |
The Cartesian-product parameterization across modifiers creates a large number of JIT-compiled test cases (especially for put and both copy directions), which can significantly increase CI time. If runtime becomes an issue, consider sampling representative modifier pairs (e.g., None, one mid-level like .cg, and one extreme like .cv/.wt) or marking the full matrix as a slower test tier.
| @pytest.mark.parametrize( | ||
| "load_cache_modifier,store_cache_modifier", | ||
| list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), | ||
| ) |
The Cartesian-product parameterization across modifiers creates a large number of JIT-compiled test cases (especially for put and both copy directions), which can significantly increase CI time. If runtime becomes an issue, consider sampling representative modifier pairs (e.g., None, one mid-level like .cg, and one extreme like .cv/.wt) or marking the full matrix as a slower test tier.
DeviceContext and IrisDeviceCtx (Gluon) were missing cache_modifier, volatile, and other parameters that existed on the freestanding API (iris.load, iris.store, etc.) since #158.
DeviceContext (iris/iris.py)
Added missing parameters to all 5 methods, passed through to tl.load()/tl.store():
| Method | New parameters |
|---|---|
| load | other, cache_modifier, volatile |
| store | cache_modifier |
| get | other, load_cache_modifier, store_cache_modifier |
| put | other, load_cache_modifier, store_cache_modifier |
| copy | other, load_cache_modifier, store_cache_modifier |
Gluon (iris/experimental/iris_gluon.py)
gl.load() and gl.store() support cache_modifier and volatile, so all 5 IrisDeviceCtx methods are updated with the same parameters. Defaults use None (consistent with the freestanding API and DeviceContext).
Usage
Tests
- tests/unittests/test_device_context_cache_modifiers.py: mirrors the existing freestanding cache modifier tests for all 5 DeviceContext methods
- tests/unittests/test_gluon_cache_modifiers.py: equivalent coverage for IrisDeviceCtx