diff --git a/CHANGELOG.md b/CHANGELOG.md index 78315285..fe59cedb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model - Allow data_clamp to clamp multiple different states without silently only clamping the last state (#773, @kyralianaka) and fixed checkpointing for this case (#786, @kyralianaka) - Fix issue causing some `View`s to take too long to create (#791, @alexpejovic) - Fix single point branch plotting with type="comp" (#797, @jnsbck) +- jx.integrate took O(n^2) time with n compartments on the backwards pass. Instead of backpropagating through the forward solve, we now use a custom_jvp (another tridiagonal solve, which is O(n)) (#795 @manuelgloeckler) # 0.13.0 diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index b148595d..050b9283 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1193,6 +1193,18 @@ def _init_solver_jaxley_dhs_solve( parents[nodes[:, 0]] = nodes[:, 1] self._dhs_solve_indexer["parent_lookup"] = parents.astype(int) + # Precompute flat child/parent index arrays for the custom JVP of the solve. + # These are used to efficiently compute dA @ x (the matrix-vector product of + # the tangent matrix with the primal solution). + node_order_grouped = self._dhs_solve_indexer["node_order_grouped"] + if len(node_order_grouped) > 0: + all_edges = np.concatenate(node_order_grouped, axis=0) + self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int) + self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int) + else: + self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int) + self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int) + def set(self, key: str, val: float | ArrayLike): """Set parameter of module (or its view) to a new value. @@ -2416,7 +2428,7 @@ def record(self, state: str = "v", verbose=True): self.base.recordings = self.base.recordings.loc[~has_duplicates] if verbose: print( - f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." + f"Added {len(in_view) - sum(has_duplicates)} recordings. See `.recordings` for details." ) def _update_view(self): @@ -3400,7 +3412,7 @@ def compute_xyz(self): num_children_of_parent = num_children[parents[b]] if num_children_of_parent > 1: y_offset = ( - ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5 + (index_of_child[b] / (num_children_of_parent - 1)) - 0.5 ) * y_offset_multiplier[levels[b]] else: y_offset = 0.0 @@ -3500,7 +3512,7 @@ def move_to( # can only iterate over cells for networks # lambda makes sure that generator can be created multiple times base_is_net = self.base._current_view == "network" - cells = lambda: (self.cells if base_is_net else [self]) + cells = lambda: self.cells if base_is_net else [self] root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()]) root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index fc22b671..5a397ed8 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -262,6 +262,17 @@ def _init_solver_jaxley_dhs_solve(self, *args, **kwargs) -> None: self._dhs_solve_indexer["node_order_grouped"] = dhs_group_comps_into_levels( self._dhs_solve_indexer["node_order"] ) + + # Precompute flat child/parent index arrays for the custom JVP of the solve. + node_order_grouped = self._dhs_solve_indexer["node_order_grouped"] + if len(node_order_grouped) > 0: + all_edges = np.concatenate(node_order_grouped, axis=0) + self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int) + self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int) + else: + self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int) + self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int) + self._init_view() def _step_synapse( diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index f1bc38e4..4b7eb01c 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -14,6 +14,132 @@ from jaxley.solver_gate import exponential_euler +def _pad_comp_edges(comp_edges) -> np.ndarray: + """Convert grouped DHS edges to a dense integer array padded with `-1`. + + `node_order_grouped` is often stored as a ragged list of arrays, one per depth + level. Padding with `-1` is safe because the voltage solve appends a spurious + compartment at the end of every solve vector, and `-1` indexes that no-op slot. + """ + if isinstance(comp_edges, np.ndarray) and comp_edges.dtype != object: + return comp_edges.astype(np.int32, copy=False) + + comp_edges = list(comp_edges) + if len(comp_edges) == 0: + return np.empty((0, 0, 2), dtype=np.int32) + + level_arrays = [np.asarray(level, dtype=np.int32) for level in comp_edges] + max_width = max(level.shape[0] for level in level_arrays) + padded = np.full((len(level_arrays), max_width, 2), -1, dtype=np.int32) + + for idx, level in enumerate(level_arrays): + if level.ndim != 2 or level.shape[1] != 2: + raise ValueError( + "Expected each DHS level to have shape (num_edges_in_level, 2)." + ) + padded[idx, : level.shape[0], :] = level + + return padded + + +def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes): + """Create a DHS solve function with custom JVP for efficient differentiation. + + The tridiagonal solve A x = b has JVP: dx = A^{-1} (db - dA x). This means the + tangent is itself a solve with the same matrix A but a different RHS. By using a + custom_jvp, we avoid JAX having to differentiate through the O(n)-step fori_loop, + which causes O(n^2) memory traffic in the backward pass. + + JAX automatically derives the transpose (VJP) from the custom JVP rule, so this + works for both forward-mode and reverse-mode differentiation. + """ + ordered_comp_edges = solve_indexer["node_order_grouped"] + flipped_comp_edges = list(reversed(ordered_comp_edges)) + all_children = np.asarray(solve_indexer["all_children"], dtype=np.int32) + all_parents = np.asarray(solve_indexer["all_parents"], dtype=np.int32) + + steps = len(flipped_comp_edges) + + ordered_comp_edges_np = _pad_comp_edges(ordered_comp_edges) + flipped_comp_edges_np = _pad_comp_edges(flipped_comp_edges) + + def _raw_solve(diags, lowers, uppers, solves): + """Solve the tree-structured linear system (no custom JVP).""" + if not optimize_for_gpu: + init = (diags, solves, lowers, uppers, flipped_comp_edges_np) + diags_out, solves_out, _, _, _ = fori_loop( + 0, steps, _comp_based_triang, init + ) + + lowers_norm = lowers / diags_out + solves_norm = solves_out / diags_out + diags_out = jnp.ones_like(solves_norm) + init = (solves_norm, lowers_norm, ordered_comp_edges_np) + solves_out, _, _ = fori_loop(0, steps, _comp_based_backsub, init) + + return solves_out / diags_out + else: + d, s = diags, solves + for i in range(steps): + d, s, _, _, _ = _comp_based_triang( + i, (d, s, lowers, uppers, flipped_comp_edges_np) + ) + + d, s = _comp_based_backsub_recursive_doubling( + d, s, lowers, steps, solve_indexer["parent_lookup"] + ) + return s / d + + @jax.custom_jvp + def _solve(diags, lowers, uppers, solves): + return _raw_solve(diags, lowers, uppers, solves) + + @_solve.defjvp + def _solve_jvp(primals, tangents): + diags, lowers, uppers, solves = primals + d_diags, d_lowers, d_uppers, d_solves = tangents + + # Primal output: x = A^{-1} b + x = _raw_solve(diags, lowers, uppers, solves) + + # Compute dA @ x. For each edge (child, parent), the matrix entries are: + # A[parent, child] = uppers[child] + # A[child, parent] = lowers[child] + # (this is consistent with the triangulation multiplier using `uppers`). + # + # Therefore: + # (dA @ x)[parent] += d_uppers[child] * x[child] + # (dA @ x)[child] += d_lowers[child] * x[parent] + dA_x = d_diags * x + dA_x = dA_x.at[all_parents].add(d_uppers[all_children] * x[all_children]) + dA_x = dA_x.at[all_children].add(d_lowers[all_children] * x[all_parents]) + + # JVP: dx = A^{-1} (db - dA @ x) + new_rhs = d_solves - dA_x + dx = _raw_solve(diags, lowers, uppers, new_rhs) + + return x, dx + + return _solve + + +# Cache key for storing the compiled solve function directly on the solve_indexer +# dict. Using a tuple key avoids collisions with the string keys it already uses. +# The cached function is automatically discarded when the owning Module (and its +# _dhs_solve_indexer dict) is garbage-collected, and naturally invalidated when +# _init_solver_jaxley_dhs_solve() creates a fresh dict. +_SOLVE_FN_KEY = "_cached_solve_fn" + + +def _get_dhs_solve(solve_indexer: dict, optimize_for_gpu: bool, n_nodes: int): + cache_key = (_SOLVE_FN_KEY, optimize_for_gpu, n_nodes) + solve_fn = solve_indexer.get(cache_key) + if solve_fn is None: + solve_fn = _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes) + solve_indexer[cache_key] = solve_fn + return solve_fn + + def step_voltage_implicit_with_dhs_solve( voltages, voltage_terms, @@ -79,8 +205,6 @@ def step_voltage_implicit_with_dhs_solve( # Reorder the lower and upper values. lowers = lowers_and_uppers[solve_indexer["map_to_solve_order_lower"]] uppers = lowers_and_uppers[solve_indexer["map_to_solve_order_upper"]] - ordered_comp_edges = solve_indexer["node_order_grouped"] - flipped_comp_edges = list(reversed(ordered_comp_edges)) # Add a spurious compartment that is modified by the masking. diags = jnp.concatenate([diags, jnp.asarray([1.0])]) @@ -88,46 +212,17 @@ def step_voltage_implicit_with_dhs_solve( uppers = jnp.concatenate([uppers, jnp.asarray([0.0])]) lowers = jnp.concatenate([lowers, jnp.asarray([0.0])]) - # Solve the voltage equations. - # - steps = len(flipped_comp_edges) - if not optimize_for_gpu: - # Cast from a list to a np.array. - # `ordered_comp_edges` has shape `(num_levels, num_comps_per_level, 2)`, - # and `num_comps_per_level=1` for CPU. - ordered_comp_edges = np.asarray(ordered_comp_edges) - flipped_comp_edges = np.asarray(flipped_comp_edges) - - # Triangulate. - steps = len(flipped_comp_edges) - init = (diags, solves, lowers, uppers, flipped_comp_edges) - diags, solves, _, _, _ = fori_loop(0, steps, _comp_based_triang, init) - - # Backsubstitute. - lowers /= diags - solves /= diags - diags = jnp.ones_like(solves) - init = (solves, lowers, ordered_comp_edges) - solves, _, _ = fori_loop(0, steps, _comp_based_backsub, init) - else: - # Triangulate by unrolling the loop of the levels. - for i in range(steps): - diags, solves, _, _, _ = _comp_based_triang( - i, (diags, solves, lowers, uppers, flipped_comp_edges) - ) + # Get or create the solve function with custom JVP. + dhs_solve = _get_dhs_solve(solve_indexer, optimize_for_gpu, int(n_nodes)) - # Backsubstitute with recursive doubling. - diags, solves = _comp_based_backsub_recursive_doubling( - diags, solves, lowers, steps, n_nodes, solve_indexer["parent_lookup"] - ) + # Solve the voltage equations with efficient custom JVP. + solution = dhs_solve(diags, lowers, uppers, solves) - # Remove the spurious compartment. This compartment got modified by masking of - # compartments in certain levels. - diags = diags[:-1] - solves = solves[:-1] + # Remove the spurious compartment. + solution = solution[:-1] + else: + solution = solves / diags - # Get inverse of the diagonalized matrix. - solution = solves / diags solution = solution[solve_indexer["inv_map_to_solve_order"]] return solution @@ -180,7 +275,6 @@ def _comp_based_backsub_recursive_doubling( solves: ArrayLike, lowers: ArrayLike, steps: int, - n_nodes: int, parent_lookup: np.ndarray, ) -> tuple[Array, Array]: """Backsubstitute with recursive doubling. @@ -238,17 +332,15 @@ def _comp_based_backsub_recursive_doubling( lower_effect = -lowers / diags solve_effect = solves / diags - step = 1 - while step <= steps: - # For each node, get its k-step parent, where k=`step`. - k_step_parent = np.arange(n_nodes + 1) - for _ in range(step): - k_step_parent = parent_lookup[k_step_parent] - - # Update. - solve_effect = lower_effect * solve_effect[k_step_parent] + solve_effect - lower_effect *= lower_effect[k_step_parent] - step *= 2 + num_recursive_steps = int(np.ceil(np.log2(steps + 1))) if steps > 0 else 0 + parent_jump = jnp.asarray(parent_lookup, dtype=jnp.int32) + + # Only O(log2(steps)) iterations; unrolling these often recovers GPU runtime + # without significantly increasing compile time. + for _ in range(num_recursive_steps): + solve_effect = lower_effect * solve_effect[parent_jump] + solve_effect + lower_effect = lower_effect * lower_effect[parent_jump] + parent_jump = parent_jump[parent_jump] # We have to return a `diags` because the final solution is computed as # `solves/diags` (see `step_voltage_implicit_with_dhs_solve`). For recursive diff --git a/tests/helpers.py b/tests/helpers.py index 8b98cfc7..efe95a65 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -154,10 +154,10 @@ def equal_both_nan_or_empty_df(a: pd.DataFrame, b: pd.DataFrame) -> bool: b = b.drop(columns="xyzr") if a.empty and b.empty: return True - a[a.isna()] = -1 - b[b.isna()] = -1 if set(a.columns) != set(b.columns): return False else: a = a[b.columns] - return (a == b).all() + + equal = a.eq(b) | (a.isna() & b.isna()) + return bool(equal.to_numpy().all()) diff --git a/tests/test_graph.py b/tests/test_graph.py index 4ddd6e55..d05e7864 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -266,9 +266,10 @@ def test_from_graph_vs_NEURON(file): errors = neuron_df["neuron_idx"].to_frame() errors["jx_idx"] = jx_df["jx_idx"] errors[["x", "y", "z"]] = neuron_df[["x", "y", "z"]] - jx_df[["x", "y", "z"]] - errors["xyz"] = np.sqrt((errors[["x", "y", "z"]] ** 2).sum(axis=1)) - errors["radius"] = neuron_df["radius"] - jx_df["radius"] - errors["length"] = neuron_df["length"] - jx_df["length"] + xyz_vals = errors[["x", "y", "z"]].to_numpy(dtype=float) + errors["xyz"] = np.sqrt((xyz_vals**2).sum(axis=1)) + errors["radius"] = (neuron_df["radius"] - jx_df["radius"]).astype(float) + errors["length"] = (neuron_df["length"] - jx_df["length"]).astype(float) assert sum(errors.groupby("jx_idx")["xyz"].max() > 1e-3) == 0 assert sum(errors.groupby("jx_idx")["radius"].max() > 1e-3) == 0 diff --git a/tests/test_solver.py b/tests/test_solver.py index 464881c1..04703471 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -1,6 +1,8 @@ # This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is # licensed under the Apache License Version 2.0, see +import jax +import jax.numpy as jnp import numpy as np import pytest @@ -8,6 +10,7 @@ from jaxley.channels import HH from jaxley.connect import connect from jaxley.solver_gate import exponential_euler +from jaxley.solver_voltage import _make_dhs_solve from jaxley.synapses import IonotropicSynapse @@ -70,3 +73,42 @@ def test_exp_euler_solver_customization(SimpleCell): ) v = jx.integrate(cell, solver="exp_euler") assert np.invert(np.any(np.isnan(v))) + + +@pytest.mark.parametrize("optimize_for_gpu", [False, True]) +def test_dhs_solve_handles_ragged_grouped_edges(optimize_for_gpu): + solve_indexer = { + "node_order_grouped": [ + np.asarray([[1, 0], [2, 0]], dtype=int), + np.asarray([[3, 1]], dtype=int), + ], + "all_children": np.asarray([1, 2, 3], dtype=int), + "all_parents": np.asarray([0, 0, 1], dtype=int), + "parent_lookup": np.asarray([-1, 0, 0, 1, -1], dtype=int), + } + solve = _make_dhs_solve( + solve_indexer=solve_indexer, + optimize_for_gpu=optimize_for_gpu, + n_nodes=4, + ) + + diags = jnp.asarray([4.0, 5.0, 6.0, 7.0, 1.0]) + lowers = jnp.asarray([0.0, -0.4, -0.3, -0.2, 0.0]) + uppers = jnp.asarray([0.0, -0.5, -0.1, -0.6, 0.0]) + rhs = jnp.asarray([1.0, 2.0, 3.0, 4.0, 0.0]) + + matrix = np.diag(np.asarray(diags)) + for child, parent in zip( + solve_indexer["all_children"], solve_indexer["all_parents"] + ): + matrix[parent, child] = float(uppers[child]) + matrix[child, parent] = float(lowers[child]) + + expected = np.linalg.solve(matrix, np.asarray(rhs)) + actual = np.asarray(jax.jit(solve)(diags, lowers, uppers, rhs)) + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + + grad_fn = jax.jit(jax.grad(lambda b: jnp.sum(solve(diags, lowers, uppers, b)))) + actual_grad = np.asarray(grad_fn(rhs)) + expected_grad = np.linalg.solve(matrix.T, np.ones_like(expected)) + np.testing.assert_allclose(actual_grad, expected_grad, rtol=1e-6, atol=1e-6)