Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model
- Allow data_clamp to clamp multiple different states without silently only clamping the last state (#773, @kyralianaka) and fixed checkpointing for this case (#786, @kyralianaka)
- Fix issue causing some `View`s to take too long to create (#791, @alexpejovic)
- Fix single point branch plotting with type="comp" (#797, @jnsbck)
- jx.integrate took O(n^2) time with n compartments on the backwards pass. Instead of backpropagating through the forward solve, we now use a custom_jvp (another tridiagonal solve, which is O(n)) (#795 @manuelgloeckler)

# 0.13.0

Expand Down
18 changes: 15 additions & 3 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,18 @@ def _init_solver_jaxley_dhs_solve(
parents[nodes[:, 0]] = nodes[:, 1]
self._dhs_solve_indexer["parent_lookup"] = parents.astype(int)

# Precompute flat child/parent index arrays for the custom JVP of the solve.
# These are used to efficiently compute dA @ x (the matrix-vector product of
# the tangent matrix with the primal solution).
node_order_grouped = self._dhs_solve_indexer["node_order_grouped"]
if len(node_order_grouped) > 0:
all_edges = np.concatenate(node_order_grouped, axis=0)
self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int)
self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int)
else:
self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int)
self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int)

def set(self, key: str, val: float | ArrayLike):
"""Set parameter of module (or its view) to a new value.

Expand Down Expand Up @@ -2416,7 +2428,7 @@ def record(self, state: str = "v", verbose=True):
self.base.recordings = self.base.recordings.loc[~has_duplicates]
if verbose:
print(
f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details."
f"Added {len(in_view) - sum(has_duplicates)} recordings. See `.recordings` for details."
)

def _update_view(self):
Expand Down Expand Up @@ -3400,7 +3412,7 @@ def compute_xyz(self):
num_children_of_parent = num_children[parents[b]]
if num_children_of_parent > 1:
y_offset = (
((index_of_child[b] / (num_children_of_parent - 1))) - 0.5
(index_of_child[b] / (num_children_of_parent - 1)) - 0.5
) * y_offset_multiplier[levels[b]]
else:
y_offset = 0.0
Expand Down Expand Up @@ -3500,7 +3512,7 @@ def move_to(
# can only iterate over cells for networks
# lambda makes sure that generator can be created multiple times
base_is_net = self.base._current_view == "network"
cells = lambda: (self.cells if base_is_net else [self])
cells = lambda: self.cells if base_is_net else [self]

root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])
root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells
Expand Down
11 changes: 11 additions & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@ def _init_solver_jaxley_dhs_solve(self, *args, **kwargs) -> None:
self._dhs_solve_indexer["node_order_grouped"] = dhs_group_comps_into_levels(
self._dhs_solve_indexer["node_order"]
)

# Precompute flat child/parent index arrays for the custom JVP of the solve.
node_order_grouped = self._dhs_solve_indexer["node_order_grouped"]
if len(node_order_grouped) > 0:
all_edges = np.concatenate(node_order_grouped, axis=0)
self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int)
self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int)
else:
self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int)
self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int)

self._init_view()

def _step_synapse(
Expand Down
194 changes: 143 additions & 51 deletions jaxley/solver_voltage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,132 @@
from jaxley.solver_gate import exponential_euler


def _pad_comp_edges(comp_edges) -> np.ndarray:
"""Convert grouped DHS edges to a dense integer array padded with `-1`.

`node_order_grouped` is often stored as a ragged list of arrays, one per depth
level. Padding with `-1` is safe because the voltage solve appends a spurious
compartment at the end of every solve vector, and `-1` indexes that no-op slot.
"""
if isinstance(comp_edges, np.ndarray) and comp_edges.dtype != object:
return comp_edges.astype(np.int32, copy=False)

comp_edges = list(comp_edges)
if len(comp_edges) == 0:
return np.empty((0, 0, 2), dtype=np.int32)

level_arrays = [np.asarray(level, dtype=np.int32) for level in comp_edges]
max_width = max(level.shape[0] for level in level_arrays)
padded = np.full((len(level_arrays), max_width, 2), -1, dtype=np.int32)

for idx, level in enumerate(level_arrays):
if level.ndim != 2 or level.shape[1] != 2:
raise ValueError(
"Expected each DHS level to have shape (num_edges_in_level, 2)."
)
padded[idx, : level.shape[0], :] = level

return padded


def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes):
"""Create a DHS solve function with custom JVP for efficient differentiation.

The tridiagonal solve A x = b has JVP: dx = A^{-1} (db - dA x). This means the
tangent is itself a solve with the same matrix A but a different RHS. By using a
custom_jvp, we avoid JAX having to differentiate through the O(n)-step fori_loop,
which causes O(n^2) memory traffic in the backward pass.

JAX automatically derives the transpose (VJP) from the custom JVP rule, so this
works for both forward-mode and reverse-mode differentiation.
"""
ordered_comp_edges = solve_indexer["node_order_grouped"]
flipped_comp_edges = list(reversed(ordered_comp_edges))
all_children = np.asarray(solve_indexer["all_children"], dtype=np.int32)
all_parents = np.asarray(solve_indexer["all_parents"], dtype=np.int32)

steps = len(flipped_comp_edges)

ordered_comp_edges_np = _pad_comp_edges(ordered_comp_edges)
flipped_comp_edges_np = _pad_comp_edges(flipped_comp_edges)

def _raw_solve(diags, lowers, uppers, solves):
"""Solve the tree-structured linear system (no custom JVP)."""
if not optimize_for_gpu:
init = (diags, solves, lowers, uppers, flipped_comp_edges_np)
diags_out, solves_out, _, _, _ = fori_loop(
0, steps, _comp_based_triang, init
)

lowers_norm = lowers / diags_out
solves_norm = solves_out / diags_out
diags_out = jnp.ones_like(solves_norm)
init = (solves_norm, lowers_norm, ordered_comp_edges_np)
solves_out, _, _ = fori_loop(0, steps, _comp_based_backsub, init)

return solves_out / diags_out
else:
d, s = diags, solves
for i in range(steps):
d, s, _, _, _ = _comp_based_triang(
i, (d, s, lowers, uppers, flipped_comp_edges_np)
)

d, s = _comp_based_backsub_recursive_doubling(
d, s, lowers, steps, solve_indexer["parent_lookup"]
)
return s / d

@jax.custom_jvp
def _solve(diags, lowers, uppers, solves):
return _raw_solve(diags, lowers, uppers, solves)

@_solve.defjvp
def _solve_jvp(primals, tangents):
diags, lowers, uppers, solves = primals
d_diags, d_lowers, d_uppers, d_solves = tangents

# Primal output: x = A^{-1} b
x = _raw_solve(diags, lowers, uppers, solves)

# Compute dA @ x. For each edge (child, parent), the matrix entries are:
# A[parent, child] = uppers[child]
# A[child, parent] = lowers[child]
# (this is consistent with the triangulation multiplier using `uppers`).
#
# Therefore:
# (dA @ x)[parent] += d_uppers[child] * x[child]
# (dA @ x)[child] += d_lowers[child] * x[parent]
dA_x = d_diags * x
dA_x = dA_x.at[all_parents].add(d_uppers[all_children] * x[all_children])
dA_x = dA_x.at[all_children].add(d_lowers[all_children] * x[all_parents])

# JVP: dx = A^{-1} (db - dA @ x)
new_rhs = d_solves - dA_x
dx = _raw_solve(diags, lowers, uppers, new_rhs)

return x, dx

return _solve


# Cache key for storing the compiled solve function directly on the solve_indexer
# dict. Using a tuple key avoids collisions with the string keys it already uses.
# The cached function is automatically discarded when the owning Module (and its
# _dhs_solve_indexer dict) is garbage-collected, and naturally invalidated when
# _init_solver_jaxley_dhs_solve() creates a fresh dict.
_SOLVE_FN_KEY = "_cached_solve_fn"


def _get_dhs_solve(solve_indexer: dict, optimize_for_gpu: bool, n_nodes: int):
cache_key = (_SOLVE_FN_KEY, optimize_for_gpu, n_nodes)
solve_fn = solve_indexer.get(cache_key)
if solve_fn is None:
solve_fn = _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes)
solve_indexer[cache_key] = solve_fn
return solve_fn


def step_voltage_implicit_with_dhs_solve(
voltages,
voltage_terms,
Expand Down Expand Up @@ -79,55 +205,24 @@ def step_voltage_implicit_with_dhs_solve(
# Reorder the lower and upper values.
lowers = lowers_and_uppers[solve_indexer["map_to_solve_order_lower"]]
uppers = lowers_and_uppers[solve_indexer["map_to_solve_order_upper"]]
ordered_comp_edges = solve_indexer["node_order_grouped"]
flipped_comp_edges = list(reversed(ordered_comp_edges))

# Add a spurious compartment that is modified by the masking.
diags = jnp.concatenate([diags, jnp.asarray([1.0])])
solves = jnp.concatenate([solves, jnp.asarray([0.0])])
uppers = jnp.concatenate([uppers, jnp.asarray([0.0])])
lowers = jnp.concatenate([lowers, jnp.asarray([0.0])])

# Solve the voltage equations.
#
steps = len(flipped_comp_edges)
if not optimize_for_gpu:
# Cast from a list to a np.array.
# `ordered_comp_edges` has shape `(num_levels, num_comps_per_level, 2)`,
# and `num_comps_per_level=1` for CPU.
ordered_comp_edges = np.asarray(ordered_comp_edges)
flipped_comp_edges = np.asarray(flipped_comp_edges)

# Triangulate.
steps = len(flipped_comp_edges)
init = (diags, solves, lowers, uppers, flipped_comp_edges)
diags, solves, _, _, _ = fori_loop(0, steps, _comp_based_triang, init)

# Backsubstitute.
lowers /= diags
solves /= diags
diags = jnp.ones_like(solves)
init = (solves, lowers, ordered_comp_edges)
solves, _, _ = fori_loop(0, steps, _comp_based_backsub, init)
else:
# Triangulate by unrolling the loop of the levels.
for i in range(steps):
diags, solves, _, _, _ = _comp_based_triang(
i, (diags, solves, lowers, uppers, flipped_comp_edges)
)
# Get or create the solve function with custom JVP.
dhs_solve = _get_dhs_solve(solve_indexer, optimize_for_gpu, int(n_nodes))

# Backsubstitute with recursive doubling.
diags, solves = _comp_based_backsub_recursive_doubling(
diags, solves, lowers, steps, n_nodes, solve_indexer["parent_lookup"]
)
# Solve the voltage equations with efficient custom JVP.
solution = dhs_solve(diags, lowers, uppers, solves)

# Remove the spurious compartment. This compartment got modified by masking of
# compartments in certain levels.
diags = diags[:-1]
solves = solves[:-1]
# Remove the spurious compartment.
solution = solution[:-1]
else:
solution = solves / diags

# Get inverse of the diagonalized matrix.
solution = solves / diags
solution = solution[solve_indexer["inv_map_to_solve_order"]]

return solution
Expand Down Expand Up @@ -180,7 +275,6 @@ def _comp_based_backsub_recursive_doubling(
solves: ArrayLike,
lowers: ArrayLike,
steps: int,
n_nodes: int,
parent_lookup: np.ndarray,
) -> tuple[Array, Array]:
"""Backsubstitute with recursive doubling.
Expand Down Expand Up @@ -238,17 +332,15 @@ def _comp_based_backsub_recursive_doubling(
lower_effect = -lowers / diags
solve_effect = solves / diags

step = 1
while step <= steps:
# For each node, get its k-step parent, where k=`step`.
k_step_parent = np.arange(n_nodes + 1)
for _ in range(step):
k_step_parent = parent_lookup[k_step_parent]

# Update.
solve_effect = lower_effect * solve_effect[k_step_parent] + solve_effect
lower_effect *= lower_effect[k_step_parent]
step *= 2
num_recursive_steps = int(np.ceil(np.log2(steps + 1))) if steps > 0 else 0
parent_jump = jnp.asarray(parent_lookup, dtype=jnp.int32)

# Only O(log2(steps)) iterations; unrolling these often recovers GPU runtime
# without significantly increasing compile time.
for _ in range(num_recursive_steps):
solve_effect = lower_effect * solve_effect[parent_jump] + solve_effect
lower_effect = lower_effect * lower_effect[parent_jump]
parent_jump = parent_jump[parent_jump]

# We have to return a `diags` because the final solution is computed as
# `solves/diags` (see `step_voltage_implicit_with_dhs_solve`). For recursive
Expand Down
6 changes: 3 additions & 3 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ def equal_both_nan_or_empty_df(a: pd.DataFrame, b: pd.DataFrame) -> bool:
b = b.drop(columns="xyzr")
if a.empty and b.empty:
return True
a[a.isna()] = -1
b[b.isna()] = -1
if set(a.columns) != set(b.columns):
return False
else:
a = a[b.columns]
return (a == b).all()

equal = a.eq(b) | (a.isna() & b.isna())
return bool(equal.to_numpy().all())
7 changes: 4 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,10 @@ def test_from_graph_vs_NEURON(file):
errors = neuron_df["neuron_idx"].to_frame()
errors["jx_idx"] = jx_df["jx_idx"]
errors[["x", "y", "z"]] = neuron_df[["x", "y", "z"]] - jx_df[["x", "y", "z"]]
errors["xyz"] = np.sqrt((errors[["x", "y", "z"]] ** 2).sum(axis=1))
errors["radius"] = neuron_df["radius"] - jx_df["radius"]
errors["length"] = neuron_df["length"] - jx_df["length"]
xyz_vals = errors[["x", "y", "z"]].to_numpy(dtype=float)
errors["xyz"] = np.sqrt((xyz_vals**2).sum(axis=1))
errors["radius"] = (neuron_df["radius"] - jx_df["radius"]).astype(float)
errors["length"] = (neuron_df["length"] - jx_df["length"]).astype(float)

assert sum(errors.groupby("jx_idx")["xyz"].max() > 1e-3) == 0
assert sum(errors.groupby("jx_idx")["radius"].max() > 1e-3) == 0
Expand Down
42 changes: 42 additions & 0 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import jax
import jax.numpy as jnp
import numpy as np
import pytest

import jaxley as jx
from jaxley.channels import HH
from jaxley.connect import connect
from jaxley.solver_gate import exponential_euler
from jaxley.solver_voltage import _make_dhs_solve
from jaxley.synapses import IonotropicSynapse


Expand Down Expand Up @@ -70,3 +73,42 @@ def test_exp_euler_solver_customization(SimpleCell):
)
v = jx.integrate(cell, solver="exp_euler")
assert np.invert(np.any(np.isnan(v)))


@pytest.mark.parametrize("optimize_for_gpu", [False, True])
def test_dhs_solve_handles_ragged_grouped_edges(optimize_for_gpu):
solve_indexer = {
"node_order_grouped": [
np.asarray([[1, 0], [2, 0]], dtype=int),
np.asarray([[3, 1]], dtype=int),
],
"all_children": np.asarray([1, 2, 3], dtype=int),
"all_parents": np.asarray([0, 0, 1], dtype=int),
"parent_lookup": np.asarray([-1, 0, 0, 1, -1], dtype=int),
}
solve = _make_dhs_solve(
solve_indexer=solve_indexer,
optimize_for_gpu=optimize_for_gpu,
n_nodes=4,
)

diags = jnp.asarray([4.0, 5.0, 6.0, 7.0, 1.0])
lowers = jnp.asarray([0.0, -0.4, -0.3, -0.2, 0.0])
uppers = jnp.asarray([0.0, -0.5, -0.1, -0.6, 0.0])
rhs = jnp.asarray([1.0, 2.0, 3.0, 4.0, 0.0])

matrix = np.diag(np.asarray(diags))
for child, parent in zip(
solve_indexer["all_children"], solve_indexer["all_parents"]
):
matrix[parent, child] = float(uppers[child])
matrix[child, parent] = float(lowers[child])

expected = np.linalg.solve(matrix, np.asarray(rhs))
actual = np.asarray(jax.jit(solve)(diags, lowers, uppers, rhs))
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

grad_fn = jax.jit(jax.grad(lambda b: jnp.sum(solve(diags, lowers, uppers, b))))
actual_grad = np.asarray(grad_fn(rhs))
expected_grad = np.linalg.solve(matrix.T, np.ones_like(expected))
np.testing.assert_allclose(actual_grad, expected_grad, rtol=1e-6, atol=1e-6)
Loading