diff --git a/jaxley/integrate.py b/jaxley/integrate.py index 45645d05d..c068ec15b 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -118,11 +118,11 @@ def add_stimuli( if data_stimuli is not None: externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]]) external_inds["i"] = jnp.concatenate( - [external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()] + [external_inds["i"], data_stimuli[2].index.to_numpy()] ) else: externals["i"] = data_stimuli[1] - external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy() + external_inds["i"] = data_stimuli[2].index.to_numpy() return externals, external_inds @@ -148,11 +148,11 @@ def add_clamps( if state_name in externals.keys(): externals[state_name] = jnp.concatenate([externals[state_name], clamps]) external_inds[state_name] = jnp.concatenate( - [external_inds[state_name], inds.global_comp_index.to_numpy()] + [external_inds[state_name], inds.index.to_numpy()] ) else: externals[state_name] = clamps - external_inds[state_name] = inds.global_comp_index.to_numpy() + external_inds[state_name] = inds.index.to_numpy() return externals, external_inds diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 12bb8f0dc..92a73aed7 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1179,6 +1179,15 @@ def add_to_group(self, group_name: str): np.concatenate([self.base.groups[group_name], self._nodes_in_view]) ) + def _get_state_names(self) -> Tuple[List, List]: + """Collect all recordable / clampable states in the membrane and synapses. + + Returns states seperated by comps and edges.""" + channel_states = [name for c in self.channels for name in c.channel_states] + synapse_states = [name for s in self.synapses for name in s.synapse_states] + membrane_states = ["v", "i"] + self.membrane_current_names + return channel_states + membrane_states, synapse_states + def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -1447,10 +1456,10 @@ def _init_morph_for_debugging(self): self.base.debug_states["par_inds"] = self.base.par_inds def record(self, state: str = "v", verbose=True): - in_view = None - in_view = self._edges_in_view if state in self.edges.columns else in_view - in_view = self._nodes_in_view if state in self.nodes.columns else in_view - assert in_view is not None, "State not found in nodes or edges." + comp_states, edge_states = self._get_state_names() + if state not in comp_states + edge_states: + raise KeyError(f"{state} is not a recognized state in this module.") + in_view = self._nodes_in_view if state in comp_states else self._edges_in_view new_recs = pd.DataFrame(in_view, columns=["rec_index"]) new_recs["state"] = state @@ -1514,8 +1523,6 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True) This function sets external states for the compartments. """ - if state_name not in self.nodes.columns: - raise KeyError(f"{state_name} is not a recognized state in this module.") self._external_input(state_name, state_array, verbose=verbose) def _external_input( @@ -1524,15 +1531,16 @@ def _external_input( values: Optional[jnp.ndarray], verbose: bool = True, ): + comp_states, edge_states = self._get_state_names() + if key not in comp_states + edge_states: + raise KeyError(f"{key} is not a recognized state in this module.") values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0) batch_size = values.shape[0] - num_inserted = len(self._nodes_in_view) - is_multiple = num_inserted == batch_size - values = ( - values - if is_multiple - else jnp.repeat(values, len(self._nodes_in_view), axis=0) + num_inserted = ( + len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view) ) + is_multiple = num_inserted == batch_size + values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0) assert batch_size in [ 1, num_inserted, @@ -1546,9 +1554,12 @@ def _external_input( [self.base.external_inds[key], self._nodes_in_view] ) else: - self.base.externals[key] = values - self.base.external_inds[key] = self._nodes_in_view - + if key in comp_states: + self.base.externals[key] = values + self.base.external_inds[key] = self._nodes_in_view + else: + self.base.externals[key] = values + self.base.external_inds[key] = self._edges_in_view if verbose: print( f"Added {num_inserted} external_states. See `.externals` for details." @@ -1588,8 +1599,12 @@ def data_clamp( verbose: Whether or not to print the number of inserted clamps. `False` by default because this method is meant to be jitted. """ + comp_states, edge_states = self._get_state_names() + if state_name not in comp_states + edge_states: + raise KeyError(f"{state_name} is not a recognized state in this module.") + data = self.nodes if state_name in comp_states else self.edges return self._data_external_input( - state_name, state_array, data_clamps, self.nodes, verbose=verbose + state_name, state_array, data_clamps, data, verbose=verbose ) def _data_external_input( @@ -1600,16 +1615,23 @@ def _data_external_input( view: pd.DataFrame, verbose: bool = False, ): + comp_states, edge_states = self._get_state_names() state_array = ( state_array if state_array.ndim == 2 else jnp.expand_dims(state_array, axis=0) ) batch_size = state_array.shape[0] - num_inserted = len(self._nodes_in_view) + num_inserted = ( + len(self._nodes_in_view) + if state_name in comp_states + else len(self._edges_in_view) + ) is_multiple = num_inserted == batch_size state_array = ( - state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0) + state_array + if is_multiple + else jnp.repeat(state_array, num_inserted, axis=0) ) assert batch_size in [ 1, @@ -1638,23 +1660,28 @@ def delete_stimuli(self): """Removes all stimuli from the module.""" self.delete_clamps("i") - def delete_clamps(self, state_name: str): + def delete_clamps(self, state_name: Optional[str] = None): """Removes all clamps of the given state from the module.""" - if state_name in self.externals: - keep_inds = ~np.isin( - self.base.external_inds[state_name], self._nodes_in_view - ) - base_exts = self.base.externals - base_exts_inds = self.base.external_inds - if np.all(~keep_inds): - base_exts.pop(state_name, None) - base_exts_inds.pop(state_name, None) + all_externals = list(self.externals.keys()) + if "i" in all_externals: + all_externals.remove("i") + state_names = all_externals if state_name is None else [state_name] + for state_name in state_names: + if state_name in self.externals: + keep_inds = ~np.isin( + self.base.external_inds[state_name], self._nodes_in_view + ) + base_exts = self.base.externals + base_exts_inds = self.base.external_inds + if np.all(~keep_inds): + base_exts.pop(state_name, None) + base_exts_inds.pop(state_name, None) + else: + base_exts[state_name] = base_exts[state_name][keep_inds] + base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] + self._update_view() else: - base_exts[state_name] = base_exts[state_name][keep_inds] - base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] - self._update_view() - else: - pass # does not have to be deleted if not in externals + pass # does not have to be deleted if not in externals def insert(self, channel: Channel): """Insert a channel into the module. diff --git a/tests/test_clamp.py b/tests/test_clamp.py index c4ed7265e..8253cd5bb 100644 --- a/tests/test_clamp.py +++ b/tests/test_clamp.py @@ -3,6 +3,9 @@ import jax +from jaxley.connect import connect +from jaxley.synapses.ionotropic import IonotropicSynapse + jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") from typing import Optional @@ -25,6 +28,54 @@ def test_clamp_pointneuron(): assert np.all(v[:, 1:] == -50.0) +def test_clamp_currents(): + comp = jx.Compartment() + comp.insert(HH()) + comp.record("i_HH") + + # test clamp + comp.clamp("i_HH", 1.0 * jnp.ones((1000,))) + i1 = jx.integrate(comp, t_max=1.0) + assert np.all(i1[:, 1:] == 1.0) + + # test data clamp + data_clamps = None + ipts = 1.0 * jnp.ones((1000,)) + data_clamps = comp.data_clamp("i_HH", ipts, data_clamps=data_clamps) + + i2 = jx.integrate(comp, data_clamps=data_clamps, t_max=1.0) + assert np.all(i2[:, 1:] == 1.0) + + assert np.all(np.isclose(i1, i2)) + + +def test_clamp_synapse(): + comp = jx.Compartment() + branch = jx.Branch(comp, 1) + cell1 = jx.Cell(branch, [-1]) + cell2 = jx.Cell(branch, [-1]) + net = jx.Network([cell1, cell2]) + connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse()) + net.record("IonotropicSynapse_s") + + # test clamp + net.clamp("IonotropicSynapse_s", 1.0 * jnp.ones((1000,))) + s1 = jx.integrate(net, t_max=1.0) + assert np.all(s1[:, 1:] == 1.0) + + net.delete_clamps() + + # test data clamp + data_clamps = None + ipts = 1.0 * jnp.ones((1000,)) + data_clamps = net.data_clamp("IonotropicSynapse_s", ipts, data_clamps=data_clamps) + + s2 = jx.integrate(net, data_clamps=data_clamps, t_max=1.0) + assert np.all(s2[:, 1:] == 1.0) + + assert np.all(np.isclose(s1, s2)) + + def test_clamp_multicompartment(): comp = jx.Compartment() branch = jx.Branch(comp, 4)