From 9691d470b81eba7e429e136ab114c127e6030fe3 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 12 Jun 2025 11:20:24 +0200 Subject: [PATCH 1/3] fix: hotfix for jax array indexing. closes #654 --- jaxley/modules/base.py | 4 ++++ tests/test_viewing.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e62e871ce..2eee341a0 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -430,6 +430,10 @@ def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray: if is_str_all(idx): # also asserts that the only allowed str == "all" return idx + # pass through for jax arrays. hotfix for #654, until #657 is fixed + if isinstance(idx, jnp.ndarray): + return idx + if isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.number): np_dtype = idx.dtype.type else: diff --git a/tests/test_viewing.py b/tests/test_viewing.py index e31567b75..f4e509204 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -286,7 +286,7 @@ def test_view_attrs(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): """Check if different ways to index into Modules/Views work correctly.""" - # test int, range, slice, list, np.array, pd.Index + # test int, range, slice, list, np.array, pd.Index, jnp.array for module in [ SimpleComp(), @@ -305,6 +305,7 @@ def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, Simple pd.Index([0, 1, 2]), pd.Index([0, 1, 2]).to_numpy(), np.array([True, False, True, False] * 100)[: len(module.nodes)], + jnp.array([0, 1, 2]), ] # comp.comp is not allowed From 5000683f14f46208ceeb1641f1ed90db05eb960a Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 12 Jun 2025 11:24:11 +0200 Subject: [PATCH 2/3] fix: replace passthrough by numpy conversion --- jaxley/modules/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2eee341a0..32c35e570 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -430,9 +430,9 @@ def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray: if is_str_all(idx): # also asserts that the only allowed str == "all" return idx - # pass through for jax arrays. hotfix for #654, until #657 is fixed + # temporary fix for #654, until #657 is fixed if isinstance(idx, jnp.ndarray): - return idx + idx = np.array(idx) if isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.number): np_dtype = idx.dtype.type From 016b2e68d2aea4936b8a5903e56fe63a6a0797ea Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 12 Jun 2025 11:25:59 +0200 Subject: [PATCH 3/3] chore: update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5dc932b0..40e41d5fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 0.11.0 (Prerelease + + +### 🐛 Bug fixes +- temporary fix to allow indexing modules with jax arrays until #657 is fixed (#654, @jnsbck). + # 0.10.0 ### 🧩 New features