
TracerArrayConversionError in data_set #657

@jnsbck


The following example worked in 0.8.2 but is broken in 0.9.0, presumably because of #606.

```python
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import connect
import jaxley as jx
from jax import jit
import jax.numpy as jnp

net = jx.Network([jx.Cell()]*2)
net.record()
connect(net.cell(0), net.cell(1), IonotropicSynapse())

@jit
def simulate():
    pstate = net.select(edges=0).data_set("IonotropicSynapse_gS", jnp.array([1.0]), None)
    v = jx.integrate(net, param_state=pstate, t_max=1.0)
    return v

v = simulate()
```

This now raises a TracerArrayConversionError. My guess is that it comes from this line:

```python
synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()
```
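For context, here is the generic pattern behind this error, independent of Jaxley. Inside a `jit`-traced function, arrays are abstract tracers, and asking NumPy to materialize one (for example via `np.asarray`) raises `jax.errors.TracerArrayConversionError`. This is a minimal sketch. The function names `to_numpy_inside_jit` and `raises_tracer_conversion_error` are made up for illustration, and it does not show exactly where Jaxley performs the conversion.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def to_numpy_inside_jit(x):
    # Under jit, `x` is a tracer, not a concrete array. Converting it to a
    # NumPy array requires a concrete value, so JAX raises an error here.
    return np.asarray(x)


def raises_tracer_conversion_error() -> bool:
    """Return True if calling the jitted function raises the tracer error."""
    try:
        to_numpy_inside_jit(jnp.array([1.0, 2.0]))
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(raises_tracer_conversion_error())  # True
```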

EDIT:
I ran a git bisect, and it identifies the commit from #606 as the first bad commit:

```
commit 08d68b04f9b64f143f517bf3960d9da1d1d98526 (refs/bisect/bad)
Author: Michael Deistler <michael.deistler@uni-tuebingen.de>
Date:   Mon May 19 13:32:22 2025 +0200

    Fix data set indices (#606)
    
    * Make indices be (N, 1) instead of (1, N)
    
    * Make default value of param_state = None
    
    * Add test based off ntolley:data_set_vector
    
    ---------
    
    Co-authored-by: Chase King <iamchaseking@gmail.com>
```

```diff
diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py
index 7f41da4..66a84a0 100644
--- a/jaxley/modules/base.py
+++ b/jaxley/modules/base.py
@@ -1095,7 +1095,7 @@ class Module(ABC):
         self,
         key: str,
         val: Union[float, jnp.ndarray],
-        param_state: Optional[List[Dict]],
+        param_state: Optional[List[Dict]] = None,
     ):
         """Set parameter of module (or its view) to a new value within `jit`.
 
@@ -1112,11 +1112,13 @@ class Module(ABC):
         viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view
         if key in data.columns:
             not_nan = ~data[key].isna()
+            indices = jnp.asarray(viewed_inds[not_nan]).reshape(-1, 1)  # shape (n_comp, 1)
+            val = jnp.broadcast_to(jnp.asarray(val), (indices.shape[0],))  # shape (n_comp,)
             added_param_state = [
                 {
-                    "indices": np.atleast_2d(viewed_inds[not_nan]),
+                    "indices": indices,
                     "key": key,
-                    "val": jnp.atleast_1d(jnp.asarray(val)),
+                    "val": val,
                 }
             ]
             if param_state is not None:
```
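The diff changes two things: the index layout (from `(1, N)` via `np.atleast_2d` to `(N, 1)` via `reshape(-1, 1)`) and the array library used to build the indices (`np` before, `jnp` after). Presumably the second change is what matters here: inside `jit`, `jnp` operations are staged out, so the indices may become traced values instead of concrete NumPy arrays. A possible direction is to keep the `(N, 1)` layout from #606 but build the indices with NumPy, as the pre-#606 code did, and use `jnp` only for `val`. The sketch below assumes that downstream code needs the indices as concrete host values. `make_param_state_entry` and `build_and_use` are hypothetical helpers that mirror the diff, not Jaxley's actual code, and whether this resolves the reported error depends on code paths not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp

viewed = np.array([4, 7, 9])

# Pre-#606 layout: np.atleast_2d turns N indices into shape (1, N).
assert np.atleast_2d(viewed).shape == (1, 3)
# Post-#606 layout: reshape(-1, 1) gives shape (N, 1).
assert viewed.reshape(-1, 1).shape == (3, 1)


def make_param_state_entry(viewed_inds, not_nan, key, val):
    """Hypothetical helper mirroring the #606 diff, with NumPy indices."""
    # Indices depend only on the module's (static) data, so build them with
    # NumPy: they stay concrete even when this runs inside `jit`.
    indices = np.asarray(viewed_inds[not_nan]).reshape(-1, 1)  # shape (n, 1)
    # `val` may be a tracer under `jit`, so it is handled with jax.numpy.
    val = jnp.broadcast_to(jnp.asarray(val), (indices.shape[0],))  # shape (n,)
    return {"indices": indices, "key": key, "val": val}


@jax.jit
def build_and_use(val):
    entry = make_param_state_entry(
        np.array([4, 7, 9]), np.array([True, False, True]), "IonotropicSynapse_gS", val
    )
    # Converting the indices to NumPy works here because they were never traced.
    flat_inds = np.asarray(entry["indices"]).ravel()
    return entry["val"] * len(flat_inds)


out = build_and_use(jnp.array([1.0]))
```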


Labels: bug
