Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def add_stimuli(
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
[external_inds["i"], data_stimuli[2].index.to_numpy()]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to .index, since edges does not have global_comp_index

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...and clamping of synapses needs the dataframe index.

)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()
external_inds["i"] = data_stimuli[2].index.to_numpy()

return externals, external_inds

Expand All @@ -148,11 +148,11 @@ def add_clamps(
if state_name in externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.global_comp_index.to_numpy()]
[external_inds[state_name], inds.index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.global_comp_index.to_numpy()
external_inds[state_name] = inds.index.to_numpy()

return externals, external_inds

Expand Down
93 changes: 60 additions & 33 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,15 @@ def add_to_group(self, group_name: str):
np.concatenate([self.base.groups[group_name], self._nodes_in_view])
)

def _get_state_names(self) -> Tuple[List, List]:
"""Collect all recordable / clampable states in the membrane and synapses.

Returns states seperated by comps and edges."""
channel_states = [name for c in self.channels for name in c.channel_states]
synapse_states = [name for s in self.synapses for name in s.synapse_states]
membrane_states = ["v", "i"] + self.membrane_current_names
return channel_states + membrane_states, synapse_states

def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""Get all trainable parameters.

Expand Down Expand Up @@ -1447,10 +1456,10 @@ def _init_morph_for_debugging(self):
self.base.debug_states["par_inds"] = self.base.par_inds

def record(self, state: str = "v", verbose=True):
in_view = None
in_view = self._edges_in_view if state in self.edges.columns else in_view
in_view = self._nodes_in_view if state in self.nodes.columns else in_view
assert in_view is not None, "State not found in nodes or edges."
comp_states, edge_states = self._get_state_names()
if state not in comp_states + edge_states:
raise KeyError(f"{state} is not a recognized state in this module.")
in_view = self._nodes_in_view if state in comp_states else self._edges_in_view

new_recs = pd.DataFrame(in_view, columns=["rec_index"])
new_recs["state"] = state
Expand Down Expand Up @@ -1514,8 +1523,6 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True)

This function sets external states for the compartments.
"""
if state_name not in self.nodes.columns:
raise KeyError(f"{state_name} is not a recognized state in this module.")
self._external_input(state_name, state_array, verbose=verbose)

def _external_input(
Expand All @@ -1524,15 +1531,16 @@ def _external_input(
values: Optional[jnp.ndarray],
verbose: bool = True,
):
comp_states, edge_states = self._get_state_names()
if key not in comp_states + edge_states:
raise KeyError(f"{key} is not a recognized state in this module.")
values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)
batch_size = values.shape[0]
num_inserted = len(self._nodes_in_view)
is_multiple = num_inserted == batch_size
values = (
values
if is_multiple
else jnp.repeat(values, len(self._nodes_in_view), axis=0)
num_inserted = (
len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)
)
is_multiple = num_inserted == batch_size
values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)
assert batch_size in [
1,
num_inserted,
Expand All @@ -1546,9 +1554,12 @@ def _external_input(
[self.base.external_inds[key], self._nodes_in_view]
)
else:
self.base.externals[key] = values
self.base.external_inds[key] = self._nodes_in_view

if key in comp_states:
self.base.externals[key] = values
self.base.external_inds[key] = self._nodes_in_view
else:
self.base.externals[key] = values
self.base.external_inds[key] = self._edges_in_view
if verbose:
print(
f"Added {num_inserted} external_states. See `.externals` for details."
Expand Down Expand Up @@ -1588,8 +1599,12 @@ def data_clamp(
verbose: Whether or not to print the number of inserted clamps. `False`
by default because this method is meant to be jitted.
"""
comp_states, edge_states = self._get_state_names()
if state_name not in comp_states + edge_states:
raise KeyError(f"{state_name} is not a recognized state in this module.")
data = self.nodes if state_name in comp_states else self.edges
return self._data_external_input(
state_name, state_array, data_clamps, self.nodes, verbose=verbose
state_name, state_array, data_clamps, data, verbose=verbose
)

def _data_external_input(
Expand All @@ -1600,16 +1615,23 @@ def _data_external_input(
view: pd.DataFrame,
verbose: bool = False,
):
comp_states, edge_states = self._get_state_names()
state_array = (
state_array
if state_array.ndim == 2
else jnp.expand_dims(state_array, axis=0)
)
batch_size = state_array.shape[0]
num_inserted = len(self._nodes_in_view)
num_inserted = (
len(self._nodes_in_view)
if state_name in comp_states
else len(self._edges_in_view)
)
is_multiple = num_inserted == batch_size
state_array = (
state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0)
state_array
if is_multiple
else jnp.repeat(state_array, num_inserted, axis=0)
)
assert batch_size in [
1,
Expand Down Expand Up @@ -1638,23 +1660,28 @@ def delete_stimuli(self):
"""Removes all stimuli from the module."""
self.delete_clamps("i")

def delete_clamps(self, state_name: str):
def delete_clamps(self, state_name: Optional[str] = None):
"""Removes all clamps of the given state from the module."""
if state_name in self.externals:
keep_inds = ~np.isin(
self.base.external_inds[state_name], self._nodes_in_view
)
base_exts = self.base.externals
base_exts_inds = self.base.external_inds
if np.all(~keep_inds):
base_exts.pop(state_name, None)
base_exts_inds.pop(state_name, None)
all_externals = list(self.externals.keys())
if "i" in all_externals:
all_externals.remove("i")
state_names = all_externals if state_name is None else [state_name]
for state_name in state_names:
Comment on lines +1665 to +1669
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took the liberty of making state_name optional. If none is supplied delete_clamps() will just remove all clamps.

if state_name in self.externals:
keep_inds = ~np.isin(
self.base.external_inds[state_name], self._nodes_in_view
)
base_exts = self.base.externals
base_exts_inds = self.base.external_inds
if np.all(~keep_inds):
base_exts.pop(state_name, None)
base_exts_inds.pop(state_name, None)
else:
base_exts[state_name] = base_exts[state_name][keep_inds]
base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]
self._update_view()
else:
base_exts[state_name] = base_exts[state_name][keep_inds]
base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]
self._update_view()
else:
pass # does not have to be deleted if not in externals
pass # does not have to be deleted if not in externals

def insert(self, channel: Channel):
"""Insert a channel into the module.
Expand Down
51 changes: 51 additions & 0 deletions tests/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import jax

from jaxley.connect import connect
from jaxley.synapses.ionotropic import IonotropicSynapse

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
from typing import Optional
Expand All @@ -25,6 +28,54 @@ def test_clamp_pointneuron():
assert np.all(v[:, 1:] == -50.0)


def test_clamp_currents():
comp = jx.Compartment()
comp.insert(HH())
comp.record("i_HH")

# test clamp
comp.clamp("i_HH", 1.0 * jnp.ones((1000,)))
i1 = jx.integrate(comp, t_max=1.0)
assert np.all(i1[:, 1:] == 1.0)

# test data clamp
data_clamps = None
ipts = 1.0 * jnp.ones((1000,))
data_clamps = comp.data_clamp("i_HH", ipts, data_clamps=data_clamps)

i2 = jx.integrate(comp, data_clamps=data_clamps, t_max=1.0)
assert np.all(i2[:, 1:] == 1.0)

assert np.all(np.isclose(i1, i2))


def test_clamp_synapse():
comp = jx.Compartment()
branch = jx.Branch(comp, 1)
cell1 = jx.Cell(branch, [-1])
cell2 = jx.Cell(branch, [-1])
net = jx.Network([cell1, cell2])
connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse())
net.record("IonotropicSynapse_s")

# test clamp
net.clamp("IonotropicSynapse_s", 1.0 * jnp.ones((1000,)))
s1 = jx.integrate(net, t_max=1.0)
assert np.all(s1[:, 1:] == 1.0)

net.delete_clamps()

# test data clamp
data_clamps = None
ipts = 1.0 * jnp.ones((1000,))
data_clamps = net.data_clamp("IonotropicSynapse_s", ipts, data_clamps=data_clamps)

s2 = jx.integrate(net, data_clamps=data_clamps, t_max=1.0)
assert np.all(s2[:, 1:] == 1.0)

assert np.all(np.isclose(s1, s2))


def test_clamp_multicompartment():
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
Expand Down