The following example broke between jaxley 0.8.2 and 0.9.0, presumably due to #606.
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import connect
import jaxley as jx
from jax import jit
import jax.numpy as jnp
# Minimal reproduction: a two-cell network whose cells are connected by a
# single IonotropicSynapse (this is edge 0 of the network).
net = jx.Network([jx.Cell()] * 2)
net.record()
connect(net.cell(0), net.cell(1), IonotropicSynapse())


@jit
def simulate():
    """Set the synaptic conductance of edge 0 inside `jit`, then integrate.

    Reportedly this works with jaxley 0.8.2 but raises
    ``TracerArrayConversionError`` from 0.9.0 onward (bisected to #606).

    Returns:
        Whatever ``jx.integrate`` returns for the recordings set up above.
    """
    # `data_set` is the jit-compatible setter: instead of mutating `net`, it
    # returns a param_state list that is handed to `jx.integrate`.
    pstate = net.select(edges=0).data_set(
        "IonotropicSynapse_gS", jnp.array([1.0]), None
    )
    v = jx.integrate(net, param_state=pstate, t_max=1.0)
    return v


v = simulate()
It now raises `TracerArrayConversionError`, I am guessing because of this line in `jaxley/modules/base.py` (line 1611 at commit c6d11cd):

```python
synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()
```
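EDIT:
I ran a git bisect, and #606 does indeed seem to have introduced the bug. A likely mechanism is the following. Before #606, `data_set` stored the indices as a concrete NumPy array via `np.atleast_2d(viewed_inds[not_nan])`. After #606, it builds them with `jnp.asarray(viewed_inds[not_nan]).reshape(-1, 1)`. Inside `jit`, that `jnp` computation is staged out and yields a tracer rather than a concrete array. Any later conversion of these indices back to NumPy (presumably the `.to_numpy()` call above) then fails. Building `indices` with `np.asarray(...)` instead of `jnp.asarray(...)` would likely avoid the error.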
commit 08d68b04f9b64f143f517bf3960d9da1d1d98526 (refs/bisect/bad)
Author: Michael Deistler <michael.deistler@uni-tuebingen.de>
Date: Mon May 19 13:32:22 2025 +0200
Fix data set indices (#606)
* Make indices be (N, 1) instead of (1, N)
* Make default value of param_state = None
* Add test based off ntolley:data_set_vector
---------
Co-authored-by: Chase King <iamchaseking@gmail.com>
diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py
index 7f41da4..66a84a0 100644
--- a/jaxley/modules/base.py
+++ b/jaxley/modules/base.py
@@ -1095,7 +1095,7 @@ class Module(ABC):
self,
key: str,
val: Union[float, jnp.ndarray],
- param_state: Optional[List[Dict]],
+ param_state: Optional[List[Dict]] = None,
):
"""Set parameter of module (or its view) to a new value within `jit`.
@@ -1112,11 +1112,13 @@ class Module(ABC):
viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view
if key in data.columns:
not_nan = ~data[key].isna()
+ indices = jnp.asarray(viewed_inds[not_nan]).reshape(-1, 1) # shape (n_comp, 1)
+ val = jnp.broadcast_to(jnp.asarray(val), (indices.shape[0],)) # shape (n_comp,)
added_param_state = [
{
- "indices": np.atleast_2d(viewed_inds[not_nan]),
+ "indices": indices,
"key": key,
- "val": jnp.atleast_1d(jnp.asarray(val)),
+ "val": val,
}
]
if param_state is not None:
The following example broke between jaxley 0.8.2 and 0.9.0, presumably due to #606.
It now raises `TracerArrayConversionError`, I am guessing because of `jaxley/jaxley/modules/base.py`, line 1611 in c6d11cd.
EDIT:
I ran a git bisect and #606 seems to have indeed introduced the bug.