diff --git a/CHANGELOG.md b/CHANGELOG.md index f5dc932b0..40e41d5fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# 0.11.0 (Prerelease + + +### 🐛 Bug fixes +- temporary fix to allow indexing modules with jax arrays until #657 is fixed (#654, @jnsbck). + # 0.10.0 ### 🧩 New features diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e62e871ce..32c35e570 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -430,6 +430,10 @@ def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray: if is_str_all(idx): # also asserts that the only allowed str == "all" return idx + # temporary fix for #654, until #657 is fixed + if isinstance(idx, jnp.ndarray): + idx = np.array(idx) + if isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.number): np_dtype = idx.dtype.type else: diff --git a/tests/test_viewing.py b/tests/test_viewing.py index e31567b75..f4e509204 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -286,7 +286,7 @@ def test_view_attrs(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, SimpleNet): """Check if different ways to index into Modules/Views work correctly.""" - # test int, range, slice, list, np.array, pd.Index + # test int, range, slice, list, np.array, pd.Index, jnp.array for module in [ SimpleComp(), @@ -305,6 +305,7 @@ def test_view_supported_index_types(SimpleComp, SimpleBranch, SimpleCell, Simple pd.Index([0, 1, 2]), pd.Index([0, 1, 2]).to_numpy(), np.array([True, False, True, False] * 100)[: len(module.nodes)], + jnp.array([0, 1, 2]), ] # comp.comp is not allowed