From c6a7120835553545e39e1f48a0c65d398733f211 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 30 Mar 2026 17:59:12 +0200 Subject: [PATCH 01/10] fixing solver backward pass memory acess pattern --- jaxley/modules/base.py | 48 ++++++++----- jaxley/modules/network.py | 29 +++++--- jaxley/solver_voltage.py | 137 +++++++++++++++++++++++++++----------- 3 files changed, 148 insertions(+), 66 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index b148595d..c943df65 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -284,9 +284,9 @@ def __getitem__(self, index): supported_parents = ["network", "cell", "branch"] # cannot index into comp not_group_view = self._current_view not in self.group_names - assert ( - self._current_view in supported_parents or not_group_view - ), "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." + assert self._current_view in supported_parents or not_group_view, ( + "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." + ) index = index if isinstance(index, tuple) else (index,) child_views = self._childviews() @@ -1193,6 +1193,18 @@ def _init_solver_jaxley_dhs_solve( parents[nodes[:, 0]] = nodes[:, 1] self._dhs_solve_indexer["parent_lookup"] = parents.astype(int) + # Precompute flat child/parent index arrays for the custom JVP of the solve. + # These are used to efficiently compute dA @ x (the matrix-vector product of + # the tangent matrix with the primal solution). + node_order_grouped = self._dhs_solve_indexer["node_order_grouped"] + if len(node_order_grouped) > 0: + all_edges = np.concatenate(node_order_grouped, axis=0) + self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int) + self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int) + else: + self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int) + self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int) + def set(self, key: str, val: float | ArrayLike): """Set parameter of module (or its view) to a new value. @@ -1589,9 +1601,9 @@ def make_trainable( assert data is not None, f"Key '{key}' not found in nodes or edges" not_nan = ~data[key].isna() data = data.loc[not_nan].copy() - assert ( - len(data) > 0 - ), "No settable parameters found in the selected compartments." + assert len(data) > 0, ( + "No settable parameters found in the selected compartments." + ) grouped_view = data.groupby("controlled_by_param") # Because of this `x.index.values` we cannot support `make_trainable()` on @@ -2030,9 +2042,9 @@ def customize_solver_exp_euler( v = jx.integrate(cell, delta_t=delta_t, t_max=100.0) """ if exp_euler_transition is not None: - self.solver_customizers["exp_euler"][ - "exp_euler_transition" - ] = exp_euler_transition + self.solver_customizers["exp_euler"]["exp_euler_transition"] = ( + exp_euler_transition + ) @only_allow_module def _compute_transition_matrix( @@ -2416,7 +2428,7 @@ def record(self, state: str = "v", verbose=True): self.base.recordings = self.base.recordings.loc[~has_duplicates] if verbose: print( - f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." + f"Added {len(in_view) - sum(has_duplicates)} recordings. See `.recordings` for details." ) def _update_view(self): @@ -2703,9 +2715,9 @@ def diffuse(self, state: str) -> None: simulated_concentrations = jx.integrate(cell, t_max=5.0) """ - assert not isinstance( - self, View - ), "You can only diffuse ions in the entire module." + assert not isinstance(self, View), ( + "You can only diffuse ions in the entire module." + ) self.base.diffusion_states.append(state) self.base.nodes.loc[self._nodes_in_view, f"axial_diffusion_{state}"] = 1.0 @@ -2726,9 +2738,9 @@ def delete_diffusion(self, state: str) -> None: Args: state: Name of the state that should no longer be diffused. """ - assert ( - state in self.base.diffusion_states - ), f"State {state} is not part of `self.diffusion_states`." + assert state in self.base.diffusion_states, ( + f"State {state} is not part of `self.diffusion_states`." + ) self.base.diffusion_states.remove(state) self.base.nodes.drop(columns=[f"axial_diffusion_{state}"], inplace=True) @@ -3400,7 +3412,7 @@ def compute_xyz(self): num_children_of_parent = num_children[parents[b]] if num_children_of_parent > 1: y_offset = ( - ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5 + (index_of_child[b] / (num_children_of_parent - 1)) - 0.5 ) * y_offset_multiplier[levels[b]] else: y_offset = 0.0 @@ -3500,7 +3512,7 @@ def move_to( # can only iterate over cells for networks # lambda makes sure that generator can be created multiple times base_is_net = self.base._current_view == "network" - cells = lambda: (self.cells if base_is_net else [self]) + cells = lambda: self.cells if base_is_net else [self] root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()]) root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index fc22b671..23f3609d 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -262,6 +262,17 @@ def _init_solver_jaxley_dhs_solve(self, *args, **kwargs) -> None: self._dhs_solve_indexer["node_order_grouped"] = dhs_group_comps_into_levels( self._dhs_solve_indexer["node_order"] ) + + # Precompute flat child/parent index arrays for the custom JVP of the solve. + node_order_grouped = self._dhs_solve_indexer["node_order_grouped"] + if len(node_order_grouped) > 0: + all_edges = np.concatenate(node_order_grouped, axis=0) + self._dhs_solve_indexer["all_children"] = all_edges[:, 0].astype(int) + self._dhs_solve_indexer["all_parents"] = all_edges[:, 1].astype(int) + else: + self._dhs_solve_indexer["all_children"] = np.asarray([], dtype=int) + self._dhs_solve_indexer["all_parents"] = np.asarray([], dtype=int) + self._init_view() def _step_synapse( @@ -295,9 +306,9 @@ def _step_synapse_state( synapse_names = list(grouped_syns.indices.keys()) for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." + assert synapse_names[i] == synapse_type._name, ( + "Mixup in the ordering of synapses. Please create an issue on Github." + ) synapse_param_names = list(synapse_type.synapse_params.keys()) synapse_state_names = list(synapse_type.synapse_states.keys()) @@ -347,9 +358,9 @@ def _synapse_currents( # offset. diff = 1e-3 for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." + assert synapse_names[i] == synapse_type._name, ( + "Mixup in the ordering of synapses. Please create an issue on Github." + ) synapse_param_names = list(synapse_type.synapse_params.keys()) synapse_state_names = list(synapse_type.synapse_states.keys()) @@ -439,9 +450,9 @@ def arrange_in_layers( plt.show() """ - assert ( - np.sum(layers) == self.shape[0] - ), "The number of cells in the layers must match the number of cells in the network." + assert np.sum(layers) == self.shape[0], ( + "The number of cells in the layers must match the number of cells in the network." + ) cells_in_layers = [ list(range(sum(layers[:i]), sum(layers[: i + 1]))) for i in range(len(layers)) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index f1bc38e4..fdd7b2ff 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -14,6 +14,91 @@ from jaxley.solver_gate import exponential_euler +def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes): + """Create a DHS solve function with custom JVP for efficient differentiation. + + The tridiagonal solve A x = b has JVP: dx = A^{-1} (db - dA x). This means the + tangent is itself a solve with the same matrix A but a different RHS. By using a + custom_jvp, we avoid JAX having to differentiate through the O(n)-step fori_loop, + which causes O(n^2) memory traffic in the backward pass. + + JAX automatically derives the transpose (VJP) from the custom JVP rule, so this + works for both forward-mode and reverse-mode differentiation. + """ + ordered_comp_edges = solve_indexer["node_order_grouped"] + flipped_comp_edges = list(reversed(ordered_comp_edges)) + all_children = solve_indexer["all_children"] + all_parents = solve_indexer["all_parents"] + + steps = len(flipped_comp_edges) + + ordered_comp_edges_np = np.asarray(ordered_comp_edges) + flipped_comp_edges_np = np.asarray(flipped_comp_edges) + + def _raw_solve(diags, lowers, uppers, solves): + """Solve the tree-structured linear system (no custom JVP).""" + if not optimize_for_gpu: + init = (diags, solves, lowers, uppers, flipped_comp_edges_np) + diags_out, solves_out, _, _, _ = fori_loop( + 0, steps, _comp_based_triang, init + ) + + lowers_norm = lowers / diags_out + solves_norm = solves_out / diags_out + diags_out = jnp.ones_like(solves_norm) + init = (solves_norm, lowers_norm, ordered_comp_edges_np) + solves_out, _, _ = fori_loop(0, steps, _comp_based_backsub, init) + + return solves_out / diags_out + else: + d, s = diags, solves + for i in range(steps): + d, s, _, _, _ = _comp_based_triang( + i, (d, s, lowers, uppers, flipped_comp_edges) + ) + + d, s = _comp_based_backsub_recursive_doubling( + d, s, lowers, steps, n_nodes, solve_indexer["parent_lookup"] + ) + return s / d + + @jax.custom_jvp + def _solve(diags, lowers, uppers, solves): + return _raw_solve(diags, lowers, uppers, solves) + + @_solve.defjvp + def _solve_jvp(primals, tangents): + diags, lowers, uppers, solves = primals + d_diags, d_lowers, d_uppers, d_solves = tangents + + # Primal output: x = A^{-1} b + x = _raw_solve(diags, lowers, uppers, solves) + + # Compute dA @ x. For each edge (child, parent), the matrix entries are: + # A[parent, child] = uppers[child] + # A[child, parent] = lowers[child] + # (this is consistent with the triangulation multiplier using `uppers`). + # + # Therefore: + # (dA @ x)[parent] += d_uppers[child] * x[child] + # (dA @ x)[child] += d_lowers[child] * x[parent] + dA_x = d_diags * x + dA_x = dA_x.at[all_parents].add(d_uppers[all_children] * x[all_children]) + dA_x = dA_x.at[all_children].add(d_lowers[all_children] * x[all_parents]) + + # JVP: dx = A^{-1} (db - dA @ x) + new_rhs = d_solves - dA_x + dx = _raw_solve(diags, lowers, uppers, new_rhs) + + return x, dx + + return _solve + + +# Cache to avoid re-creating the solve function on every call. +_dhs_solve_cache: Dict[tuple[int, bool, int], Any] = {} + + def step_voltage_implicit_with_dhs_solve( voltages, voltage_terms, @@ -79,8 +164,6 @@ def step_voltage_implicit_with_dhs_solve( # Reorder the lower and upper values. lowers = lowers_and_uppers[solve_indexer["map_to_solve_order_lower"]] uppers = lowers_and_uppers[solve_indexer["map_to_solve_order_upper"]] - ordered_comp_edges = solve_indexer["node_order_grouped"] - flipped_comp_edges = list(reversed(ordered_comp_edges)) # Add a spurious compartment that is modified by the masking. diags = jnp.concatenate([diags, jnp.asarray([1.0])]) @@ -88,46 +171,22 @@ def step_voltage_implicit_with_dhs_solve( uppers = jnp.concatenate([uppers, jnp.asarray([0.0])]) lowers = jnp.concatenate([lowers, jnp.asarray([0.0])]) - # Solve the voltage equations. - # - steps = len(flipped_comp_edges) - if not optimize_for_gpu: - # Cast from a list to a np.array. - # `ordered_comp_edges` has shape `(num_levels, num_comps_per_level, 2)`, - # and `num_comps_per_level=1` for CPU. - ordered_comp_edges = np.asarray(ordered_comp_edges) - flipped_comp_edges = np.asarray(flipped_comp_edges) - - # Triangulate. - steps = len(flipped_comp_edges) - init = (diags, solves, lowers, uppers, flipped_comp_edges) - diags, solves, _, _, _ = fori_loop(0, steps, _comp_based_triang, init) - - # Backsubstitute. - lowers /= diags - solves /= diags - diags = jnp.ones_like(solves) - init = (solves, lowers, ordered_comp_edges) - solves, _, _ = fori_loop(0, steps, _comp_based_backsub, init) - else: - # Triangulate by unrolling the loop of the levels. - for i in range(steps): - diags, solves, _, _, _ = _comp_based_triang( - i, (diags, solves, lowers, uppers, flipped_comp_edges) - ) - - # Backsubstitute with recursive doubling. - diags, solves = _comp_based_backsub_recursive_doubling( - diags, solves, lowers, steps, n_nodes, solve_indexer["parent_lookup"] + # Get or create the solve function with custom JVP. + cache_key = (id(solve_indexer), optimize_for_gpu, int(n_nodes)) + if cache_key not in _dhs_solve_cache: + _dhs_solve_cache[cache_key] = _make_dhs_solve( + solve_indexer, optimize_for_gpu, n_nodes ) + dhs_solve = _dhs_solve_cache[cache_key] + + # Solve the voltage equations with efficient custom JVP. + solution = dhs_solve(diags, lowers, uppers, solves) - # Remove the spurious compartment. This compartment got modified by masking of - # compartments in certain levels. - diags = diags[:-1] - solves = solves[:-1] + # Remove the spurious compartment. + solution = solution[:-1] + else: + solution = solves / diags - # Get inverse of the diagonalized matrix. - solution = solves / diags solution = solution[solve_indexer["inv_map_to_solve_order"]] return solution From 96ce8cd4a1ea7faf1f9c1aa5a15a52afd99d4965 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 13:49:56 +0200 Subject: [PATCH 02/10] slighly better cache --- jaxley/solver_voltage.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index fdd7b2ff..9c026c7b 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -95,8 +95,21 @@ def _solve_jvp(primals, tangents): return _solve -# Cache to avoid re-creating the solve function on every call. -_dhs_solve_cache: Dict[tuple[int, bool, int], Any] = {} +# Cache key for storing the compiled solve function directly on the solve_indexer +# dict. Using a tuple key avoids collisions with the string keys it already uses. +# The cached function is automatically discarded when the owning Module (and its +# _dhs_solve_indexer dict) is garbage-collected, and naturally invalidated when +# _init_solver_jaxley_dhs_solve() creates a fresh dict. +_SOLVE_FN_KEY = "_cached_solve_fn" + + +def _get_dhs_solve(solve_indexer: dict, optimize_for_gpu: bool, n_nodes: int): + cache_key = (_SOLVE_FN_KEY, optimize_for_gpu, n_nodes) + solve_fn = solve_indexer.get(cache_key) + if solve_fn is None: + solve_fn = _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes) + solve_indexer[cache_key] = solve_fn + return solve_fn def step_voltage_implicit_with_dhs_solve( @@ -172,12 +185,7 @@ def step_voltage_implicit_with_dhs_solve( lowers = jnp.concatenate([lowers, jnp.asarray([0.0])]) # Get or create the solve function with custom JVP. - cache_key = (id(solve_indexer), optimize_for_gpu, int(n_nodes)) - if cache_key not in _dhs_solve_cache: - _dhs_solve_cache[cache_key] = _make_dhs_solve( - solve_indexer, optimize_for_gpu, n_nodes - ) - dhs_solve = _dhs_solve_cache[cache_key] + dhs_solve = _get_dhs_solve(solve_indexer, optimize_for_gpu, int(n_nodes)) # Solve the voltage equations with efficient custom JVP. solution = dhs_solve(diags, lowers, uppers, solves) From fc20b86b3d02ef2be738fd1299abc5d7fc91ed90 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 13:50:27 +0200 Subject: [PATCH 03/10] unrelated test fixes --- tests/helpers.py | 18 +++++++++--------- tests/test_graph.py | 7 ++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 8b98cfc7..c1d2a1d8 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,12 +6,12 @@ def get_segment_xyzrL(section, comp_idx=None, loc=None, ncomp=8): - assert ( - comp_idx is not None or loc is not None - ), "Either comp_idx or loc must be provided." - assert not ( - comp_idx is not None and loc is not None - ), "Only one of comp_idx or loc can be provided." + assert comp_idx is not None or loc is not None, ( + "Either comp_idx or loc must be provided." + ) + assert not (comp_idx is not None and loc is not None), ( + "Only one of comp_idx or loc can be provided." + ) comp_len = 1 / ncomp loc = comp_len / 2 + comp_idx * comp_len if loc is None else loc @@ -154,10 +154,10 @@ def equal_both_nan_or_empty_df(a: pd.DataFrame, b: pd.DataFrame) -> bool: b = b.drop(columns="xyzr") if a.empty and b.empty: return True - a[a.isna()] = -1 - b[b.isna()] = -1 if set(a.columns) != set(b.columns): return False else: a = a[b.columns] - return (a == b).all() + + equal = a.eq(b) | (a.isna() & b.isna()) + return bool(equal.to_numpy().all()) diff --git a/tests/test_graph.py b/tests/test_graph.py index 4ddd6e55..d05e7864 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -266,9 +266,10 @@ def test_from_graph_vs_NEURON(file): errors = neuron_df["neuron_idx"].to_frame() errors["jx_idx"] = jx_df["jx_idx"] errors[["x", "y", "z"]] = neuron_df[["x", "y", "z"]] - jx_df[["x", "y", "z"]] - errors["xyz"] = np.sqrt((errors[["x", "y", "z"]] ** 2).sum(axis=1)) - errors["radius"] = neuron_df["radius"] - jx_df["radius"] - errors["length"] = neuron_df["length"] - jx_df["length"] + xyz_vals = errors[["x", "y", "z"]].to_numpy(dtype=float) + errors["xyz"] = np.sqrt((xyz_vals**2).sum(axis=1)) + errors["radius"] = (neuron_df["radius"] - jx_df["radius"]).astype(float) + errors["length"] = (neuron_df["length"] - jx_df["length"]).astype(float) assert sum(errors.groupby("jx_idx")["xyz"].max() > 1e-3) == 0 assert sum(errors.groupby("jx_idx")["radius"].max() > 1e-3) == 0 From efb0f55a7fe32a02389caa5067f9dbe84a68c7da Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 14:36:42 +0200 Subject: [PATCH 04/10] fixes for ragged_edges --- jaxley/solver_voltage.py | 34 +++++++++++++++++++++++++++++--- tests/test_solver.py | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 9c026c7b..96e374e0 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -14,6 +14,34 @@ from jaxley.solver_gate import exponential_euler +def _pad_comp_edges(comp_edges) -> np.ndarray: + """Convert grouped DHS edges to a dense integer array padded with `-1`. + + `node_order_grouped` is often stored as a ragged list of arrays, one per depth + level. Padding with `-1` is safe because the voltage solve appends a spurious + compartment at the end of every solve vector, and `-1` indexes that no-op slot. + """ + if isinstance(comp_edges, np.ndarray) and comp_edges.dtype != object: + return comp_edges.astype(int, copy=False) + + comp_edges = list(comp_edges) + if len(comp_edges) == 0: + return np.empty((0, 0, 2), dtype=int) + + level_arrays = [np.asarray(level, dtype=int) for level in comp_edges] + max_width = max(level.shape[0] for level in level_arrays) + padded = np.full((len(level_arrays), max_width, 2), -1, dtype=int) + + for idx, level in enumerate(level_arrays): + if level.ndim != 2 or level.shape[1] != 2: + raise ValueError( + "Expected each DHS level to have shape (num_edges_in_level, 2)." + ) + padded[idx, : level.shape[0], :] = level + + return padded + + def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes): """Create a DHS solve function with custom JVP for efficient differentiation. @@ -32,8 +60,8 @@ def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes): steps = len(flipped_comp_edges) - ordered_comp_edges_np = np.asarray(ordered_comp_edges) - flipped_comp_edges_np = np.asarray(flipped_comp_edges) + ordered_comp_edges_np = _pad_comp_edges(ordered_comp_edges) + flipped_comp_edges_np = _pad_comp_edges(flipped_comp_edges) def _raw_solve(diags, lowers, uppers, solves): """Solve the tree-structured linear system (no custom JVP).""" @@ -54,7 +82,7 @@ def _raw_solve(diags, lowers, uppers, solves): d, s = diags, solves for i in range(steps): d, s, _, _, _ = _comp_based_triang( - i, (d, s, lowers, uppers, flipped_comp_edges) + i, (d, s, lowers, uppers, flipped_comp_edges_np) ) d, s = _comp_based_backsub_recursive_doubling( diff --git a/tests/test_solver.py b/tests/test_solver.py index 464881c1..04703471 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -1,6 +1,8 @@ # This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is # licensed under the Apache License Version 2.0, see +import jax +import jax.numpy as jnp import numpy as np import pytest @@ -8,6 +10,7 @@ from jaxley.channels import HH from jaxley.connect import connect from jaxley.solver_gate import exponential_euler +from jaxley.solver_voltage import _make_dhs_solve from jaxley.synapses import IonotropicSynapse @@ -70,3 +73,42 @@ def test_exp_euler_solver_customization(SimpleCell): ) v = jx.integrate(cell, solver="exp_euler") assert np.invert(np.any(np.isnan(v))) + + +@pytest.mark.parametrize("optimize_for_gpu", [False, True]) +def test_dhs_solve_handles_ragged_grouped_edges(optimize_for_gpu): + solve_indexer = { + "node_order_grouped": [ + np.asarray([[1, 0], [2, 0]], dtype=int), + np.asarray([[3, 1]], dtype=int), + ], + "all_children": np.asarray([1, 2, 3], dtype=int), + "all_parents": np.asarray([0, 0, 1], dtype=int), + "parent_lookup": np.asarray([-1, 0, 0, 1, -1], dtype=int), + } + solve = _make_dhs_solve( + solve_indexer=solve_indexer, + optimize_for_gpu=optimize_for_gpu, + n_nodes=4, + ) + + diags = jnp.asarray([4.0, 5.0, 6.0, 7.0, 1.0]) + lowers = jnp.asarray([0.0, -0.4, -0.3, -0.2, 0.0]) + uppers = jnp.asarray([0.0, -0.5, -0.1, -0.6, 0.0]) + rhs = jnp.asarray([1.0, 2.0, 3.0, 4.0, 0.0]) + + matrix = np.diag(np.asarray(diags)) + for child, parent in zip( + solve_indexer["all_children"], solve_indexer["all_parents"] + ): + matrix[parent, child] = float(uppers[child]) + matrix[child, parent] = float(lowers[child]) + + expected = np.linalg.solve(matrix, np.asarray(rhs)) + actual = np.asarray(jax.jit(solve)(diags, lowers, uppers, rhs)) + np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + + grad_fn = jax.jit(jax.grad(lambda b: jnp.sum(solve(diags, lowers, uppers, b)))) + actual_grad = np.asarray(grad_fn(rhs)) + expected_grad = np.linalg.solve(matrix.T, np.ones_like(expected)) + np.testing.assert_allclose(actual_grad, expected_grad, rtol=1e-6, atol=1e-6) From 9b75a6ac1acb846bd4d9dd5079b297f5818de3d3 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 14:43:52 +0200 Subject: [PATCH 05/10] merge GPU fixes with GPU improvements --- jaxley/solver_voltage.py | 48 +++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 96e374e0..280b70ed 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -79,14 +79,11 @@ def _raw_solve(diags, lowers, uppers, solves): return solves_out / diags_out else: - d, s = diags, solves - for i in range(steps): - d, s, _, _, _ = _comp_based_triang( - i, (d, s, lowers, uppers, flipped_comp_edges_np) - ) + init = (diags, solves, lowers, uppers, flipped_comp_edges_np) + d, s, _, _, _ = fori_loop(0, steps, _comp_based_triang, init) d, s = _comp_based_backsub_recursive_doubling( - d, s, lowers, steps, n_nodes, solve_indexer["parent_lookup"] + d, s, lowers, steps, solve_indexer["parent_lookup"] ) return s / d @@ -169,11 +166,10 @@ def step_voltage_implicit_with_dhs_solve( map_to_solve_order_lower_and_upper: An array of indices that permutes the concatenation of lowers and uppers into the order of the solve: `lowers_and_uppers = lowers_and_uppers[map_to_solve_order_lower_and_upper]`. - optimize_for_gpu: If True, it does two things: (1) it unrolls the for-loop - for the triangularization stage. (2) It uses recursive doubling (also - unrolled) for the backsubstitution stage. Setting this to `True` will - largely speed up runs on GPU, but it will slow down compilation time and - run time on CPU. + optimize_for_gpu: If True, it uses a compilation-friendly DHS variant with + level-wise triangularization and recursive-doubling backsubstitution. + This mode prioritizes lower compilation time for large morphologies while + retaining high runtime performance on accelerators. """ axial_conductances = delta_t * axial_conductances @@ -275,7 +271,6 @@ def _comp_based_backsub_recursive_doubling( solves: ArrayLike, lowers: ArrayLike, steps: int, - n_nodes: int, parent_lookup: np.ndarray, ) -> tuple[Array, Array]: """Backsubstitute with recursive doubling. @@ -333,17 +328,24 @@ def _comp_based_backsub_recursive_doubling( lower_effect = -lowers / diags solve_effect = solves / diags - step = 1 - while step <= steps: - # For each node, get its k-step parent, where k=`step`. - k_step_parent = np.arange(n_nodes + 1) - for _ in range(step): - k_step_parent = parent_lookup[k_step_parent] - - # Update. - solve_effect = lower_effect * solve_effect[k_step_parent] + solve_effect - lower_effect *= lower_effect[k_step_parent] - step *= 2 + num_recursive_steps = int(np.ceil(np.log2(steps + 1))) if steps > 0 else 0 + parent_jump = jnp.asarray(parent_lookup) + + def _recursive_doubling_body(_, carry): + solve_effect, lower_effect, parent_jump = carry + + solve_effect = lower_effect * solve_effect[parent_jump] + solve_effect + lower_effect = lower_effect * lower_effect[parent_jump] + parent_jump = parent_jump[parent_jump] + + return solve_effect, lower_effect, parent_jump + + solve_effect, _, _ = fori_loop( + 0, + num_recursive_steps, + _recursive_doubling_body, + (solve_effect, lower_effect, parent_jump), + ) # We have to return a `diags` because the final solution is computed as # `solves/diags` (see `step_voltage_implicit_with_dhs_solve`). For recursive From a6fc1818db77ab44858432976c105426a943ec04 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 15:14:11 +0200 Subject: [PATCH 06/10] unroll logn loop agian for testing --- jaxley/solver_voltage.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 280b70ed..38177fe8 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -22,15 +22,15 @@ def _pad_comp_edges(comp_edges) -> np.ndarray: compartment at the end of every solve vector, and `-1` indexes that no-op slot. """ if isinstance(comp_edges, np.ndarray) and comp_edges.dtype != object: - return comp_edges.astype(int, copy=False) + return comp_edges.astype(np.int32, copy=False) comp_edges = list(comp_edges) if len(comp_edges) == 0: - return np.empty((0, 0, 2), dtype=int) + return np.empty((0, 0, 2), dtype=np.int32) - level_arrays = [np.asarray(level, dtype=int) for level in comp_edges] + level_arrays = [np.asarray(level, dtype=np.int32) for level in comp_edges] max_width = max(level.shape[0] for level in level_arrays) - padded = np.full((len(level_arrays), max_width, 2), -1, dtype=int) + padded = np.full((len(level_arrays), max_width, 2), -1, dtype=np.int32) for idx, level in enumerate(level_arrays): if level.ndim != 2 or level.shape[1] != 2: @@ -55,8 +55,8 @@ def _make_dhs_solve(solve_indexer, optimize_for_gpu, n_nodes): """ ordered_comp_edges = solve_indexer["node_order_grouped"] flipped_comp_edges = list(reversed(ordered_comp_edges)) - all_children = solve_indexer["all_children"] - all_parents = solve_indexer["all_parents"] + all_children = np.asarray(solve_indexer["all_children"], dtype=np.int32) + all_parents = np.asarray(solve_indexer["all_parents"], dtype=np.int32) steps = len(flipped_comp_edges) @@ -329,24 +329,15 @@ def _comp_based_backsub_recursive_doubling( solve_effect = solves / diags num_recursive_steps = int(np.ceil(np.log2(steps + 1))) if steps > 0 else 0 - parent_jump = jnp.asarray(parent_lookup) - - def _recursive_doubling_body(_, carry): - solve_effect, lower_effect, parent_jump = carry + parent_jump = jnp.asarray(parent_lookup, dtype=jnp.int32) + # Only O(log2(steps)) iterations; unrolling these often recovers GPU runtime + # without significantly increasing compile time. + for _ in range(num_recursive_steps): solve_effect = lower_effect * solve_effect[parent_jump] + solve_effect lower_effect = lower_effect * lower_effect[parent_jump] parent_jump = parent_jump[parent_jump] - return solve_effect, lower_effect, parent_jump - - solve_effect, _, _ = fori_loop( - 0, - num_recursive_steps, - _recursive_doubling_body, - (solve_effect, lower_effect, parent_jump), - ) - # We have to return a `diags` because the final solution is computed as # `solves/diags` (see `step_voltage_implicit_with_dhs_solve`). For recursive # doubling, the solution should just be `solve_effect`, so we define diags as From 3eac2054c85064e46f502586b0cb31dbbedc20a3 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 31 Mar 2026 17:10:52 +0200 Subject: [PATCH 07/10] unroll for GPU for now (hurst performance) --- jaxley/solver_voltage.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 38177fe8..4b7eb01c 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -79,8 +79,11 @@ def _raw_solve(diags, lowers, uppers, solves): return solves_out / diags_out else: - init = (diags, solves, lowers, uppers, flipped_comp_edges_np) - d, s, _, _, _ = fori_loop(0, steps, _comp_based_triang, init) + d, s = diags, solves + for i in range(steps): + d, s, _, _, _ = _comp_based_triang( + i, (d, s, lowers, uppers, flipped_comp_edges_np) + ) d, s = _comp_based_backsub_recursive_doubling( d, s, lowers, steps, solve_indexer["parent_lookup"] @@ -166,10 +169,11 @@ def step_voltage_implicit_with_dhs_solve( map_to_solve_order_lower_and_upper: An array of indices that permutes the concatenation of lowers and uppers into the order of the solve: `lowers_and_uppers = lowers_and_uppers[map_to_solve_order_lower_and_upper]`. - optimize_for_gpu: If True, it uses a compilation-friendly DHS variant with - level-wise triangularization and recursive-doubling backsubstitution. - This mode prioritizes lower compilation time for large morphologies while - retaining high runtime performance on accelerators. + optimize_for_gpu: If True, it does two things: (1) it unrolls the for-loop + for the triangularization stage. (2) It uses recursive doubling (also + unrolled) for the backsubstitution stage. Setting this to `True` will + largely speed up runs on GPU, but it will slow down compilation time and + run time on CPU. """ axial_conductances = delta_t * axial_conductances From 5974866118dffb0f38c146a36e96c19fcbd7de2e Mon Sep 17 00:00:00 2001 From: Matthijspals Date: Fri, 10 Apr 2026 18:08:12 +0200 Subject: [PATCH 08/10] black and changelog --- CHANGELOG.md | 105 +++++++++++++++++++------------------- jaxley/modules/base.py | 30 +++++------ jaxley/modules/network.py | 18 +++---- 3 files changed, 77 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b816363f..1a40053e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # 0.14.0 -### 🧩 New features +### ???? New features - Add AdEx simplified neuron model, similar to the implementation of Izhikevich channel. Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model as an effective description of neuronal activity.' @@ -8,15 +8,16 @@ Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model - Add handling of inhomogeneous branches for import and export of morphologies. (#779,@NicolasRR) - Add an logistic transformation (`jaxley.optimize.transforms.LogisticTransform`) (#788, @jnsbck) -### 🐛 Bug fixes +### ???? Bug fixes - Fix issue where `build_dynamic_state_utils` `remove_observables` performed in-place deletions on full state dict (#775, @chaseking) - Allow data_clamp to clamp multiple different states without silently only clamping the last state (#773, @kyralianaka) and fixed checkpointing for this case (#786, @kyralianaka) - Fix issue causing some `View`s to take too long to create (#791, @alexpejovic) +- jx.integrate took O(n^2) time with n compartments on the backwards pass. Instead of backpropagating through the forward solve, we use a custom_jvp (another tridiagonal solve) (#795 @manuelgloeckler) # 0.13.0 -### 🧩 New features +### ???? New features - Add utilities to process the ``all_states`` dictionary such that it only contains "true" ODE states (i.e., it removes branchpoint states, membrane states that are NaN @@ -39,7 +40,7 @@ dynamic_states = flatten(remove_observables(all_states)) recovered_all_states = add_observables(unflatten(dynamic_states), all_params, delta_t=0.025) ``` -### 📚 Documentation +### ???? Documentation - Improved documentation for ``build_init_and_step_fn`` (#719, @matthijspals, @michaeldeistler) @@ -49,7 +50,7 @@ recovered_all_states = add_observables(unflatten(dynamic_states), all_params, de # 0.12.0 -### 🧩 New features +### ???? New features - Exponential Euler solver (#743, @michaeldeistler): ```python @@ -67,7 +68,7 @@ jx.integrate(cell, solver="exp_euler") ``` - Forward Euler solver for branched morphologies (#743, @michaeldeistler). -### 🛠️ Internal updates +### ??????? Internal updates - separate getting the currents from `get_all_states()` (#727, @michaeldeistler). To restore the previous behaviour, do: @@ -76,12 +77,12 @@ states = cell.get_all_states(pstate) states = cell.append_channel_currents_to_states(states, all_params, delta_t) ``` -### 🐛 Bug fixes +### ???? Bug fixes - bugfix for `cell.vis(..., type="morph")` (#725, @michaeldeistler, thanks to Elisabeth Galyo for reporting) -### 📚 Documentation +### ???? Documentation - Added example usage to many user-facing Module functions (#716, @alexpejovic) - Update GPU installation instructions to use CUDA 13 (#732, @michaeldeistler) @@ -92,14 +93,14 @@ Elisabeth Galyo for reporting) # 0.11.5 -### 🐛 Bug fixes +### ???? Bug fixes - bugfix for `.delete()` when multiple channels have the same `current_name` or a shared parameter/state (#713, @michaeldeistler) - safe softplus, use linear function above certain threshold. This avoids an unwanted clipping operation due to the save_exp (#714 @matthijspals) -### 📚 Documentation +### ???? Documentation - typo fixes for several tutorial notebooks (#721, @michaeldeistler, thanks @martricks for reporting) @@ -109,12 +110,12 @@ recommendations (#723, @michaeldeistler) # 0.11.4 -### 🐛 Bug fixes +### ???? Bug fixes - bugfix for indexing when `init_states()` is run on a `jx.Network` (#711, @michaeldeistler) -### 📚 Documentation +### ???? Documentation - add an example on fitting a morphologically detailed cell with gradient descent (#705, @michaeldeistler) @@ -122,11 +123,11 @@ recommendations (#723, @michaeldeistler) # 0.11.3 -### 🛠️ Internal updates +### ??????? Internal updates - follow jax typing practices with Array and ArrayLike (#693, @alexpejovic) -### 🐛 Bug fixes +### ???? Bug fixes - fix for networks that mix point neurons and morphologically detailed neurons (#702, @michaeldeistler) @@ -136,43 +137,43 @@ over to `jx.Network`) (#703, @michaeldeistler) # 0.11.2 -### 🐛 Bug fixes +### ???? Bug fixes - Bugfix for `Network`s on `GPU`: since `v0.9.0`, networks had been very slow on GPU because the voltage equations of cells had been processed in sequence, not in parallel. This is now solved, giving a large speed-up for networks consisting of many cells (#691, @michaeldeistler, thanks to @VENOM314 for reporting) -### 📚 Documentation +### ???? Documentation - Remove all content from the old mkdocs documentation website (#689, @michaeldeistler) # 0.11.1 -### 🐛 Bug fixes +### ???? Bug fixes - bugfix for `set_ncomp()` when the cell consists of a single branch (#686, @michaeldeistler) -### 🛠️ Internal updates +### ??????? Internal updates - fix all typos in the codebase by using the `typos` project (#682, @alexpejovic) # 0.11.0 -### 🧩 New features +### ???? New features - simple conductance synapse added (#659, @kyralianaka) -### 📚 Documentation +### ???? Documentation - add a how-to guide on converting `NMODL` files to `Jaxley`, see [here](https://jaxley.readthedocs.io/en/latest/how_to_guide/import_channels_from_neuron.html) (#669, @michaeldeistler, special thanks to @r-makarov for building the tool) -### 🛠️ Internal updates +### ??????? Internal updates - changes to how the membrane area from SWC files is computed when the radius within a compartment is not constant. This fix can have an impact on simulation results. The @@ -186,7 +187,7 @@ updated computation of membrane area matches that of the NEURON simulator (#662, # 0.10.0 -### 🧩 New features +### ???? New features - functionality to compute the pathwise distance between compartments (#648, @michaeldeistler): @@ -196,12 +197,12 @@ path_dists = distance_pathwise(cell.soma.branch(0).comp(0), cell) cell.nodes["path_dist_from_soma"] = path_dists ``` -### 🐛 Bug fixes +### ???? Bug fixes - fixed synapse recording indices to be within type (#643, @kyralianaka) - Fix inheriting from a Module #590 (#642, @jnsbck) -### 🛠️ Internal updates +### ??????? Internal updates - `module.distance()` is now deprecated in favor of `jx.morphology.distance_direct()` (#648, @michaeldeistler) @@ -209,7 +210,7 @@ cell.nodes["path_dist_from_soma"] = path_dists # 0.9.0 -### ✨ Highlights +### ??? Highlights - This PR implements a new solver, which is now used by default (#625, @michaeldeistler). The new solver has the following advantages: @@ -228,7 +229,7 @@ from jaxley.morphology import morph_connect cell = morph_connect(cell1.branch(1).loc(0.0), cell2.branch(2).loc(1.0)) ``` -### 🧩 New features +### ???? New features - the default SWC reader has changed. To use the previous SWC reader, run `jx.read_swc(..., backend="custom")`. However, note that we will remove this reader @@ -242,13 +243,13 @@ runtime, see [here](https://github.com/jax-ml/jax/issues/26145) (#623, @michaeld [the how-to guide](https://jaxley.readthedocs.io/en/latest/how_to_guide/set_ncomp.html) (#625, @michaeldeistler) -### 📚 Documentation +### ???? Documentation - Introduce the `how-to guide` on the website (#612, @michaeldeistler) - reorganize the advanced tutorials into subgroups (#612, @michaeldeistler) - split the morphology handling tutorials into two notebooks (#612, @michaeldeistler) -### 🛠️ Internal updates +### ??????? Internal updates - improvements to graph-backend for more flexibility in modifying morphologies (#613, @michaeldeistler) @@ -264,12 +265,12 @@ with flywire, which highjacks `type_id > 4` to indicate synaptic contacts (#612, by default. To get them, do `net.copy_node_property_to_edges("global_comp_index")` (#625, @michaeldeistler) -### 🐛 Bug fixes +### ???? Bug fixes - `ChainTransform` forward now working as mirror to inverse (#628, @kyralianaka) - allow `data_set` with vectors of values (#606, @chaseking) -### 🎉 New Contributors +### ???? New Contributors - @chaseking made their first contribution in #606 @@ -288,18 +289,18 @@ by default. To get them, do `net.copy_node_property_to_edges("global_comp_index" # 0.8.0 -### 🧩 New features +### ???? New features - add leaky integrate-and-fire neurons (#564, @jnsbck), Izhikevich neurons, and rate-based neurons (#601, @michaeldeistler) -### 🛠️ Minor updates +### ??????? Minor updates - make `delta` and `v_th` in `IonotropicSynapse` trainable parameters (#599, @jnsbck) - make random postsnaptic compartment selection optional in connectivity functions (#489, @kyralianaka) -### 🐛 Bug fixes +### ???? Bug fixes - Fix bug for `groups` when `.set_ncomp` was run (#587, @michaeldeistler) - allow `.distance` to be jitted (#603, @michaeldeistler) @@ -307,7 +308,7 @@ rate-based neurons (#601, @michaeldeistler) # 0.7.0 -### 🧩 New Features +### ???? New Features - Allow ion diffusion with `cell.diffuse()` and add tutorials (#438, @michaeldeistler): ```python @@ -321,11 +322,11 @@ cell.set("axial_diffusion_CaCon_i", 1.0) ``` - Introduce ion pumps (#438, @michaeldeistler) -### 🛠️ Minor changes +### ??????? Minor changes - rename `delete_channel()` to `delete()` (#438, @michaeldeistler) -### 🐛 Bug fixes +### ???? Bug fixes - Fix for simulation of morphologies with inhomogeneous numbers of compartments (#438, @michaeldeistler) @@ -339,7 +340,7 @@ cell.set("axial_diffusion_CaCon_i", 1.0) - make random post compartment selection optional in connectivity functions (#489, @kyralianaka) -### 🎉 New Contributors +### ???? New Contributors - @Kartik-Sama made their first contribution in #582 @@ -366,7 +367,7 @@ versions of `JAX` can be made equally fast as older versions by setting `os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'` at the beginning of your jupyter notebook (#570, @michaeldeistler). -### 🧩 New Features +### ???? New Features - Add ability to record synaptic currents (#523, @ntolley). Recordings can be turned on with @@ -407,7 +408,7 @@ different numbers of post-synaptic partners. (#514, @jnsbck) [tutorial](https://jaxley.readthedocs.io/en/latest/tutorials/08_importing_morphologies.html) for more details. -### 🛠️ Code Health +### ??????? Code Health - changelog added to CI (#537, #558, @jnsbck) @@ -422,19 +423,19 @@ different numbers of post-synaptic partners. (#514, @jnsbck) - Allow inspecting the version via `import jaxley as jx; print(jx.__version__)` (#577, @michaeldeistler). -### 🐛 Bug fixes +### ???? Bug fixes - Fixed inconsistency with *type* assertions arising due to `numpy` functions returning different `dtypes` on platforms like Windows (#567, @Kartik-Sama) -### 🎉 New Contributors +### ???? New Contributors - @ntolley made their first contribution in #523 # 0.5.0 -### 🛠️ API changes +### ??????? API changes - Synapse views no longer exist (#447, #453, @jnsbck). Previous code such as ```python @@ -474,7 +475,7 @@ transforms = [ tf = jt.ParamTransform(transforms) ``` -### 🧩 New features +### ???? New features - Added a new `delete_channel()` method (#521, @jnsbck) - Allow to write trainables to the module (#470, @michaeldeistler): @@ -493,14 +494,14 @@ net[r_greater_1].nodes.vis() - check if recordings are empty (#460, @deezer257) - enable `clamp` to be jitted and vmapped with `data_clamp()` (#374, @kyralianaka) -### 🐛 Bug fixes +### ???? Bug fixes - allow for cells that were read from swc to be pickled (#525, @jnsbck) - fix units of `compute_current()` in channels (#461, @michaeldeistler) - fix issues with plotting when the morphology has a different number of compartments (#513, @jnsbck) -### 📚 Documentation +### ???? Documentation - new tutorial on synapse indexing (#464, @michaeldeistler, @zinaStef) - new tutorial on parameter sharing (#464, @michaeldeistler, @zinaStef) @@ -519,7 +520,7 @@ net[r_greater_1].nodes.vis() - automated tests to check if tutorials can be run (#480, @jnsbck) - add helpers to deprecate functions and kwargs (#516, @jnsbck) -### 🎉 New Contributors +### ???? New Contributors - @simoneeb made their first contribution in #473 - @zinaStef made their first contribution in #464 @@ -529,7 +530,7 @@ net[r_greater_1].nodes.vis() # 0.4.0 -### 🧩 New features +### ???? New features - Changing the number of compartments: `cell.branch(0).set_ncomp(4)` (#436, #440, #445, @michaeldeistler, @jnsbck) @@ -538,7 +539,7 @@ net[r_greater_1].nodes.vis() - Speed optimization for `jx.integrate(..., voltage_solver="jaxley.stone")` (#442, @michaeldeistler) -### 📚 Documentation +### ???? Documentation - new website powered by sphinx: [`jaxley.readthedocs.io`](https://jaxley.readthedocs.io/) (#434, #435, @michaeldeistler) @@ -546,7 +547,7 @@ net[r_greater_1].nodes.vis() # v0.3.0 -### 🧩 New features +### ???? New features - New solver: `jx.integrate(..., voltage_solver="jax.sparse")` which has very low compile time (#418, @michaeldeistler) @@ -554,7 +555,7 @@ compile time (#418, @michaeldeistler) the number of compartments after initialization is not yet supported, #418, #426, @michaeldeistler) -### 🐛 Bug fixes +### ???? Bug fixes - Bugfix for capacitances and their interplay with axial conductances (Thanks @Tunenip, #426, @michaeldeistler) @@ -570,13 +571,13 @@ the number of compartments after initialization is not yet supported, #418, #426 # v0.2.0 -### 🧩 New features +### ???? New features - Cranck-Nicolson solver (#413, @michaeldeistler) - Forward Euler solver for compartments and branches (#413, @michaeldeistler) - Add option to access `states` in `channel.init_state` (#416, @michaeldeistler) -### 🐛 Bug fixes +### ???? Bug fixes - Bugfix for interpolation of x, y, z values (#411, @jnsbck) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c943df65..050b9283 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -284,9 +284,9 @@ def __getitem__(self, index): supported_parents = ["network", "cell", "branch"] # cannot index into comp not_group_view = self._current_view not in self.group_names - assert self._current_view in supported_parents or not_group_view, ( - "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." - ) + assert ( + self._current_view in supported_parents or not_group_view + ), "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." index = index if isinstance(index, tuple) else (index,) child_views = self._childviews() @@ -1601,9 +1601,9 @@ def make_trainable( assert data is not None, f"Key '{key}' not found in nodes or edges" not_nan = ~data[key].isna() data = data.loc[not_nan].copy() - assert len(data) > 0, ( - "No settable parameters found in the selected compartments." - ) + assert ( + len(data) > 0 + ), "No settable parameters found in the selected compartments." grouped_view = data.groupby("controlled_by_param") # Because of this `x.index.values` we cannot support `make_trainable()` on @@ -2042,9 +2042,9 @@ def customize_solver_exp_euler( v = jx.integrate(cell, delta_t=delta_t, t_max=100.0) """ if exp_euler_transition is not None: - self.solver_customizers["exp_euler"]["exp_euler_transition"] = ( - exp_euler_transition - ) + self.solver_customizers["exp_euler"][ + "exp_euler_transition" + ] = exp_euler_transition @only_allow_module def _compute_transition_matrix( @@ -2715,9 +2715,9 @@ def diffuse(self, state: str) -> None: simulated_concentrations = jx.integrate(cell, t_max=5.0) """ - assert not isinstance(self, View), ( - "You can only diffuse ions in the entire module." - ) + assert not isinstance( + self, View + ), "You can only diffuse ions in the entire module." self.base.diffusion_states.append(state) self.base.nodes.loc[self._nodes_in_view, f"axial_diffusion_{state}"] = 1.0 @@ -2738,9 +2738,9 @@ def delete_diffusion(self, state: str) -> None: Args: state: Name of the state that should no longer be diffused. """ - assert state in self.base.diffusion_states, ( - f"State {state} is not part of `self.diffusion_states`." - ) + assert ( + state in self.base.diffusion_states + ), f"State {state} is not part of `self.diffusion_states`." self.base.diffusion_states.remove(state) self.base.nodes.drop(columns=[f"axial_diffusion_{state}"], inplace=True) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 23f3609d..5a397ed8 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -306,9 +306,9 @@ def _step_synapse_state( synapse_names = list(grouped_syns.indices.keys()) for i, synapse_type in enumerate(syn_channels): - assert synapse_names[i] == synapse_type._name, ( - "Mixup in the ordering of synapses. Please create an issue on Github." - ) + assert ( + synapse_names[i] == synapse_type._name + ), "Mixup in the ordering of synapses. Please create an issue on Github." synapse_param_names = list(synapse_type.synapse_params.keys()) synapse_state_names = list(synapse_type.synapse_states.keys()) @@ -358,9 +358,9 @@ def _synapse_currents( # offset. diff = 1e-3 for i, synapse_type in enumerate(syn_channels): - assert synapse_names[i] == synapse_type._name, ( - "Mixup in the ordering of synapses. Please create an issue on Github." - ) + assert ( + synapse_names[i] == synapse_type._name + ), "Mixup in the ordering of synapses. Please create an issue on Github." synapse_param_names = list(synapse_type.synapse_params.keys()) synapse_state_names = list(synapse_type.synapse_states.keys()) @@ -450,9 +450,9 @@ def arrange_in_layers( plt.show() """ - assert np.sum(layers) == self.shape[0], ( - "The number of cells in the layers must match the number of cells in the network." - ) + assert ( + np.sum(layers) == self.shape[0] + ), "The number of cells in the layers must match the number of cells in the network." cells_in_layers = [ list(range(sum(layers[:i]), sum(layers[: i + 1]))) for i in range(len(layers)) From 543fd86042cc17968d8e245a90d56107571cea31 Mon Sep 17 00:00:00 2001 From: Matthijspals Date: Fri, 10 Apr 2026 18:19:08 +0200 Subject: [PATCH 09/10] undo changelog nano changes --- CHANGELOG.md | 107 ++++++++++++++++++++++++++------------------------- 1 file changed, 54 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a40053e..fe59cedb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # 0.14.0 -### ???? New features +### 🧩 New features - Add AdEx simplified neuron model, similar to the implementation of Izhikevich channel. Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model as an effective description of neuronal activity.' @@ -8,16 +8,17 @@ Implements Brette et al. (2005), 'Adaptive exponential integrate-and-fire model - Add handling of inhomogeneous branches for import and export of morphologies. (#779,@NicolasRR) - Add an logistic transformation (`jaxley.optimize.transforms.LogisticTransform`) (#788, @jnsbck) -### ???? Bug fixes +### 🐛 Bug fixes - Fix issue where `build_dynamic_state_utils` `remove_observables` performed in-place deletions on full state dict (#775, @chaseking) - Allow data_clamp to clamp multiple different states without silently only clamping the last state (#773, @kyralianaka) and fixed checkpointing for this case (#786, @kyralianaka) - Fix issue causing some `View`s to take too long to create (#791, @alexpejovic) -- jx.integrate took O(n^2) time with n compartments on the backwards pass. Instead of backpropagating through the forward solve, we use a custom_jvp (another tridiagonal solve) (#795 @manuelgloeckler) +- Fix single point branch plotting with type="comp" (#797, @jnsbck) +- jx.integrate took O(n^2) time with n compartments on the backwards pass. Instead of backpropagating through the forward solve, we now use a custom_jvp (another tridiagonal solve, which is O(n)) (#795 @manuelgloeckler) # 0.13.0 -### ???? New features +### 🧩 New features - Add utilities to process the ``all_states`` dictionary such that it only contains "true" ODE states (i.e., it removes branchpoint states, membrane states that are NaN @@ -40,7 +41,7 @@ dynamic_states = flatten(remove_observables(all_states)) recovered_all_states = add_observables(unflatten(dynamic_states), all_params, delta_t=0.025) ``` -### ???? Documentation +### 📚 Documentation - Improved documentation for ``build_init_and_step_fn`` (#719, @matthijspals, @michaeldeistler) @@ -50,7 +51,7 @@ recovered_all_states = add_observables(unflatten(dynamic_states), all_params, de # 0.12.0 -### ???? New features +### 🧩 New features - Exponential Euler solver (#743, @michaeldeistler): ```python @@ -68,7 +69,7 @@ jx.integrate(cell, solver="exp_euler") ``` - Forward Euler solver for branched morphologies (#743, @michaeldeistler). -### ??????? Internal updates +### 🛠️ Internal updates - separate getting the currents from `get_all_states()` (#727, @michaeldeistler). To restore the previous behaviour, do: @@ -77,12 +78,12 @@ states = cell.get_all_states(pstate) states = cell.append_channel_currents_to_states(states, all_params, delta_t) ``` -### ???? Bug fixes +### 🐛 Bug fixes - bugfix for `cell.vis(..., type="morph")` (#725, @michaeldeistler, thanks to Elisabeth Galyo for reporting) -### ???? Documentation +### 📚 Documentation - Added example usage to many user-facing Module functions (#716, @alexpejovic) - Update GPU installation instructions to use CUDA 13 (#732, @michaeldeistler) @@ -93,14 +94,14 @@ Elisabeth Galyo for reporting) # 0.11.5 -### ???? Bug fixes +### 🐛 Bug fixes - bugfix for `.delete()` when multiple channels have the same `current_name` or a shared parameter/state (#713, @michaeldeistler) - safe softplus, use linear function above certain threshold. This avoids an unwanted clipping operation due to the save_exp (#714 @matthijspals) -### ???? Documentation +### 📚 Documentation - typo fixes for several tutorial notebooks (#721, @michaeldeistler, thanks @martricks for reporting) @@ -110,12 +111,12 @@ recommendations (#723, @michaeldeistler) # 0.11.4 -### ???? Bug fixes +### 🐛 Bug fixes - bugfix for indexing when `init_states()` is run on a `jx.Network` (#711, @michaeldeistler) -### ???? Documentation +### 📚 Documentation - add an example on fitting a morphologically detailed cell with gradient descent (#705, @michaeldeistler) @@ -123,11 +124,11 @@ recommendations (#723, @michaeldeistler) # 0.11.3 -### ??????? Internal updates +### 🛠️ Internal updates - follow jax typing practices with Array and ArrayLike (#693, @alexpejovic) -### ???? Bug fixes +### 🐛 Bug fixes - fix for networks that mix point neurons and morphologically detailed neurons (#702, @michaeldeistler) @@ -137,43 +138,43 @@ over to `jx.Network`) (#703, @michaeldeistler) # 0.11.2 -### ???? Bug fixes +### 🐛 Bug fixes - Bugfix for `Network`s on `GPU`: since `v0.9.0`, networks had been very slow on GPU because the voltage equations of cells had been processed in sequence, not in parallel. This is now solved, giving a large speed-up for networks consisting of many cells (#691, @michaeldeistler, thanks to @VENOM314 for reporting) -### ???? Documentation +### 📚 Documentation - Remove all content from the old mkdocs documentation website (#689, @michaeldeistler) # 0.11.1 -### ???? Bug fixes +### 🐛 Bug fixes - bugfix for `set_ncomp()` when the cell consists of a single branch (#686, @michaeldeistler) -### ??????? Internal updates +### 🛠️ Internal updates - fix all typos in the codebase by using the `typos` project (#682, @alexpejovic) # 0.11.0 -### ???? New features +### 🧩 New features - simple conductance synapse added (#659, @kyralianaka) -### ???? Documentation +### 📚 Documentation - add a how-to guide on converting `NMODL` files to `Jaxley`, see [here](https://jaxley.readthedocs.io/en/latest/how_to_guide/import_channels_from_neuron.html) (#669, @michaeldeistler, special thanks to @r-makarov for building the tool) -### ??????? Internal updates +### 🛠️ Internal updates - changes to how the membrane area from SWC files is computed when the radius within a compartment is not constant. This fix can have an impact on simulation results. The @@ -187,7 +188,7 @@ updated computation of membrane area matches that of the NEURON simulator (#662, # 0.10.0 -### ???? New features +### 🧩 New features - functionality to compute the pathwise distance between compartments (#648, @michaeldeistler): @@ -197,12 +198,12 @@ path_dists = distance_pathwise(cell.soma.branch(0).comp(0), cell) cell.nodes["path_dist_from_soma"] = path_dists ``` -### ???? Bug fixes +### 🐛 Bug fixes - fixed synapse recording indices to be within type (#643, @kyralianaka) - Fix inheriting from a Module #590 (#642, @jnsbck) -### ??????? Internal updates +### 🛠️ Internal updates - `module.distance()` is now deprecated in favor of `jx.morphology.distance_direct()` (#648, @michaeldeistler) @@ -210,7 +211,7 @@ cell.nodes["path_dist_from_soma"] = path_dists # 0.9.0 -### ??? Highlights +### ✨ Highlights - This PR implements a new solver, which is now used by default (#625, @michaeldeistler). The new solver has the following advantages: @@ -229,7 +230,7 @@ from jaxley.morphology import morph_connect cell = morph_connect(cell1.branch(1).loc(0.0), cell2.branch(2).loc(1.0)) ``` -### ???? New features +### 🧩 New features - the default SWC reader has changed. To use the previous SWC reader, run `jx.read_swc(..., backend="custom")`. However, note that we will remove this reader @@ -243,13 +244,13 @@ runtime, see [here](https://github.com/jax-ml/jax/issues/26145) (#623, @michaeld [the how-to guide](https://jaxley.readthedocs.io/en/latest/how_to_guide/set_ncomp.html) (#625, @michaeldeistler) -### ???? Documentation +### 📚 Documentation - Introduce the `how-to guide` on the website (#612, @michaeldeistler) - reorganize the advanced tutorials into subgroups (#612, @michaeldeistler) - split the morphology handling tutorials into two notebooks (#612, @michaeldeistler) -### ??????? Internal updates +### 🛠️ Internal updates - improvements to graph-backend for more flexibility in modifying morphologies (#613, @michaeldeistler) @@ -265,12 +266,12 @@ with flywire, which highjacks `type_id > 4` to indicate synaptic contacts (#612, by default. To get them, do `net.copy_node_property_to_edges("global_comp_index")` (#625, @michaeldeistler) -### ???? Bug fixes +### 🐛 Bug fixes - `ChainTransform` forward now working as mirror to inverse (#628, @kyralianaka) - allow `data_set` with vectors of values (#606, @chaseking) -### ???? New Contributors +### 🎉 New Contributors - @chaseking made their first contribution in #606 @@ -289,18 +290,18 @@ by default. To get them, do `net.copy_node_property_to_edges("global_comp_index" # 0.8.0 -### ???? New features +### 🧩 New features - add leaky integrate-and-fire neurons (#564, @jnsbck), Izhikevich neurons, and rate-based neurons (#601, @michaeldeistler) -### ??????? Minor updates +### 🛠️ Minor updates - make `delta` and `v_th` in `IonotropicSynapse` trainable parameters (#599, @jnsbck) - make random postsnaptic compartment selection optional in connectivity functions (#489, @kyralianaka) -### ???? Bug fixes +### 🐛 Bug fixes - Fix bug for `groups` when `.set_ncomp` was run (#587, @michaeldeistler) - allow `.distance` to be jitted (#603, @michaeldeistler) @@ -308,7 +309,7 @@ rate-based neurons (#601, @michaeldeistler) # 0.7.0 -### ???? New Features +### 🧩 New Features - Allow ion diffusion with `cell.diffuse()` and add tutorials (#438, @michaeldeistler): ```python @@ -322,11 +323,11 @@ cell.set("axial_diffusion_CaCon_i", 1.0) ``` - Introduce ion pumps (#438, @michaeldeistler) -### ??????? Minor changes +### 🛠️ Minor changes - rename `delete_channel()` to `delete()` (#438, @michaeldeistler) -### ???? Bug fixes +### 🐛 Bug fixes - Fix for simulation of morphologies with inhomogeneous numbers of compartments (#438, @michaeldeistler) @@ -340,7 +341,7 @@ cell.set("axial_diffusion_CaCon_i", 1.0) - make random post compartment selection optional in connectivity functions (#489, @kyralianaka) -### ???? New Contributors +### 🎉 New Contributors - @Kartik-Sama made their first contribution in #582 @@ -367,7 +368,7 @@ versions of `JAX` can be made equally fast as older versions by setting `os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'` at the beginning of your jupyter notebook (#570, @michaeldeistler). -### ???? New Features +### 🧩 New Features - Add ability to record synaptic currents (#523, @ntolley). Recordings can be turned on with @@ -408,7 +409,7 @@ different numbers of post-synaptic partners. (#514, @jnsbck) [tutorial](https://jaxley.readthedocs.io/en/latest/tutorials/08_importing_morphologies.html) for more details. -### ??????? Code Health +### 🛠️ Code Health - changelog added to CI (#537, #558, @jnsbck) @@ -423,19 +424,19 @@ different numbers of post-synaptic partners. (#514, @jnsbck) - Allow inspecting the version via `import jaxley as jx; print(jx.__version__)` (#577, @michaeldeistler). -### ???? Bug fixes +### 🐛 Bug fixes - Fixed inconsistency with *type* assertions arising due to `numpy` functions returning different `dtypes` on platforms like Windows (#567, @Kartik-Sama) -### ???? New Contributors +### 🎉 New Contributors - @ntolley made their first contribution in #523 # 0.5.0 -### ??????? API changes +### 🛠️ API changes - Synapse views no longer exist (#447, #453, @jnsbck). Previous code such as ```python @@ -475,7 +476,7 @@ transforms = [ tf = jt.ParamTransform(transforms) ``` -### ???? New features +### 🧩 New features - Added a new `delete_channel()` method (#521, @jnsbck) - Allow to write trainables to the module (#470, @michaeldeistler): @@ -494,14 +495,14 @@ net[r_greater_1].nodes.vis() - check if recordings are empty (#460, @deezer257) - enable `clamp` to be jitted and vmapped with `data_clamp()` (#374, @kyralianaka) -### ???? Bug fixes +### 🐛 Bug fixes - allow for cells that were read from swc to be pickled (#525, @jnsbck) - fix units of `compute_current()` in channels (#461, @michaeldeistler) - fix issues with plotting when the morphology has a different number of compartments (#513, @jnsbck) -### ???? Documentation +### 📚 Documentation - new tutorial on synapse indexing (#464, @michaeldeistler, @zinaStef) - new tutorial on parameter sharing (#464, @michaeldeistler, @zinaStef) @@ -520,7 +521,7 @@ net[r_greater_1].nodes.vis() - automated tests to check if tutorials can be run (#480, @jnsbck) - add helpers to deprecate functions and kwargs (#516, @jnsbck) -### ???? New Contributors +### 🎉 New Contributors - @simoneeb made their first contribution in #473 - @zinaStef made their first contribution in #464 @@ -530,7 +531,7 @@ net[r_greater_1].nodes.vis() # 0.4.0 -### ???? New features +### 🧩 New features - Changing the number of compartments: `cell.branch(0).set_ncomp(4)` (#436, #440, #445, @michaeldeistler, @jnsbck) @@ -539,7 +540,7 @@ net[r_greater_1].nodes.vis() - Speed optimization for `jx.integrate(..., voltage_solver="jaxley.stone")` (#442, @michaeldeistler) -### ???? Documentation +### 📚 Documentation - new website powered by sphinx: [`jaxley.readthedocs.io`](https://jaxley.readthedocs.io/) (#434, #435, @michaeldeistler) @@ -547,7 +548,7 @@ net[r_greater_1].nodes.vis() # v0.3.0 -### ???? New features +### 🧩 New features - New solver: `jx.integrate(..., voltage_solver="jax.sparse")` which has very low compile time (#418, @michaeldeistler) @@ -555,7 +556,7 @@ compile time (#418, @michaeldeistler) the number of compartments after initialization is not yet supported, #418, #426, @michaeldeistler) -### ???? Bug fixes +### 🐛 Bug fixes - Bugfix for capacitances and their interplay with axial conductances (Thanks @Tunenip, #426, @michaeldeistler) @@ -571,13 +572,13 @@ the number of compartments after initialization is not yet supported, #418, #426 # v0.2.0 -### ???? New features +### 🧩 New features - Cranck-Nicolson solver (#413, @michaeldeistler) - Forward Euler solver for compartments and branches (#413, @michaeldeistler) - Add option to access `states` in `channel.init_state` (#416, @michaeldeistler) -### ???? Bug fixes +### 🐛 Bug fixes - Bugfix for interpolation of x, y, z values (#411, @jnsbck) From 10481fc2d89a9c675d28082a9053e970832f9a38 Mon Sep 17 00:00:00 2001 From: Matthijspals Date: Fri, 10 Apr 2026 18:23:43 +0200 Subject: [PATCH 10/10] format helpers with black --- tests/helpers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index c1d2a1d8..efe95a65 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,12 +6,12 @@ def get_segment_xyzrL(section, comp_idx=None, loc=None, ncomp=8): - assert comp_idx is not None or loc is not None, ( - "Either comp_idx or loc must be provided." - ) - assert not (comp_idx is not None and loc is not None), ( - "Only one of comp_idx or loc can be provided." - ) + assert ( + comp_idx is not None or loc is not None + ), "Either comp_idx or loc must be provided." + assert not ( + comp_idx is not None and loc is not None + ), "Only one of comp_idx or loc can be provided." comp_len = 1 / ncomp loc = comp_len / 2 + comp_idx * comp_len if loc is None else loc