Skip to content

Solver n_comp scaling fix#795

Merged
michaeldeistler merged 12 commits into
mainfrom
solver_fix
Apr 10, 2026
Merged

Solver n_comp scaling fix#795
michaeldeistler merged 12 commits into
mainfrom
solver_fix

Conversation

@manuelgloeckler
Copy link
Copy Markdown
Contributor

Problem

The backward pass through jx.integrate scales O(n^2) with the number of compartments on CPU, despite O(n) FLOPs. The root cause is that JAX's automatic VJP of fori_loop must carry and copy the full O(n) state tensor on every iteration, resulting in O(n^2) memory traffic. This makes gradient computation prohibitively slow for large morphologies (e.g. 700ms backward for 10,000 compartments vs 0.89ms forward — a 788x ratio).

Solution

Add a custom_jvp to the tridiagonal solve. The key mathematical insight is that the gradient of a linear solve A x = b is itself a linear solve with the same matrix:

dx = A^{-1} (db - dA x)

This means the backward pass is a single O(n) tridiagonal solve rather than differentiating through the O(n)-step fori_loop. JAX automatically derives the VJP (reverse-mode) from the custom JVP rule.

Changes

jaxley/solver_voltage.py

  • Add _make_dhs_solve(): factory that creates a solve function with @jax.custom_jvp. The JVP computes dA @ x using precomputed child/parent index arrays, then solves A^{-1}(db - dA x) by calling the same solver.
  • Add _pad_comp_edges(): pads ragged grouped edge arrays to dense numpy arrays (needed because levels can have different numbers of edges).
  • Add _get_dhs_solve(): caches the solve function on the solve_indexer dict to avoid recreating it on every call. (Note: caching is currently semi-optimal and may be removable.)
  • Simplify step_voltage_implicit_with_dhs_solve() to delegate to the cached solve function.
  • Simplify _comp_based_backsub_recursive_doubling(): compute num_recursive_steps via ceil(log2), use JAX array for parent_jump (enabling parent_jump[parent_jump] doubling), remove n_nodes parameter.

jaxley/modules/base.py

  • In _init_solver_jaxley_dhs_solve(): precompute flat all_children and all_parents index arrays from node_order_grouped for the custom JVP's dA @ x computation.

jaxley/modules/network.py

  • Same all_children/all_parents precomputation for the network-level solver init.

tests/test_solver.py

  • Add test_dhs_solve_handles_ragged_grouped_edges: verifies correctness of the solve and its gradient against np.linalg.solve for both CPU and GPU paths, including ragged edge groupings.

tests/helpers.py, tests/test_graph.py

  • Unrelated fixes: NaN-safe DataFrame comparison, dtype handling in graph tests, formatter-driven assert style changes.

Benchmark results

ncomp scaling (single branch)

ncomp Bwd main (ms) Bwd this branch (ms) Speedup
10 0.04 0.04 1x
1,000 9.01 0.48 19x
10,000 700.40 3.19 220x

The bwd/fwd ratio drops from 789x to 3.8x at 10,000 compartments.

Morphology (SWC) benchmarks — selected highlights

Morphology ncomp Bwd main (ms) Bwd this branch (ms) Speedup
morph_ca1_n120.swc 4 5.71 0.84 7x
morph_ca1_n120.swc 16 39.49 1.52 26x
morph_allen_485574832.swc 16 19.34 1.10 18x
morph_retina_20161028_1.swc 16 39.68 1.49 27x

Compilation times are unchanged. No impact on forward pass performance.

@manuelgloeckler
Copy link
Copy Markdown
Contributor Author

Here the full benchmarks

Runtimes on main

ncomp scaling (single branch)

Label Fwd (ms) Bwd (ms) Ratio Fwd Compile (ms) Bwd Compile (ms) Bwd FLOPs Bwd Mem
ncomp=10 0.01 0.04 4.5x 60 231 1.3e+03 0.0 MB
ncomp=1000 0.14 9.01 65.7x 69 229 5.6e+04 0.6 MB
ncomp=10000 0.89 700.40 788.5x 104 270 5.5e+05 5.6 MB

morphology (SWC) benchmarks

Label Fwd (ms) Bwd (ms) Ratio Fwd Compile (ms) Bwd Compile (ms) Bwd FLOPs Bwd Mem
morph_single_branch.swc@ncomp=1 0.01 0.02 3.2x 31 139 5.7e+02 0.0 MB
morph_single_branch.swc@ncomp=4 0.01 0.03 3.2x 59 216 9.4e+02 0.0 MB
morph_single_branch.swc@ncomp=16 0.01 0.15 11.9x 64 223 1.6e+03 0.0 MB
morph_ca1_n120.swc@ncomp=1 0.09 1.65 17.7x 92 415 1.3e+05 0.6 MB
morph_ca1_n120.swc@ncomp=4 0.25 5.71 22.5x 230 545 2.2e+05 0.8 MB
morph_ca1_n120.swc@ncomp=16 0.44 39.49 89.9x 221 563 6.1e+05 1.8 MB
morph_allen_485574832.swc@ncomp=1 0.08 1.08 13.0x 102 382 7.3e+04 0.4 MB
morph_allen_485574832.swc@ncomp=4 0.18 3.21 18.1x 176 460 1.2e+05 0.5 MB
morph_allen_485574832.swc@ncomp=16 0.31 19.34 61.5x 172 458 3.0e+05 1.2 MB
morph_retina_20161028_1.swc@ncomp=1 0.09 1.67 18.0x 96 417 1.3e+05 0.6 MB
morph_retina_20161028_1.swc@ncomp=4 0.24 5.82 24.4x 206 562 2.2e+05 0.8 MB
morph_retina_20161028_1.swc@ncomp=16 0.43 39.68 92.5x 208 547 6.1e+05 1.8 MB

Runtimes on this branch

ncomp scaling (single branch)

Label Fwd (ms) Bwd (ms) Ratio Fwd Compile (ms) Bwd Compile (ms) Bwd FLOPs Bwd Mem
ncomp=10 0.01 0.04 4.5x 62 223 1.2e+03 0.0 MB
ncomp=1000 0.13 0.48 3.8x 74 226 5.0e+04 0.3 MB
ncomp=10000 0.85 3.19 3.8x 102 271 4.9e+05 3.2 MB

morphology (SWC) benchmarks

Label Fwd (ms) Bwd (ms) Ratio Fwd Compile (ms) Bwd Compile (ms) Bwd FLOPs Bwd Mem
morph_single_branch.swc@ncomp=1 0.01 0.02 3.1x 31 139 5.7e+02 0.0 MB
morph_single_branch.swc@ncomp=4 0.01 0.03 3.0x 59 207 8.7e+02 0.0 MB
morph_single_branch.swc@ncomp=16 0.01 0.14 11.8x 63 221 1.5e+03 0.0 MB
morph_ca1_n120.swc@ncomp=1 0.09 0.39 4.2x 90 402 1.3e+05 0.5 MB
morph_ca1_n120.swc@ncomp=4 0.25 0.84 3.4x 209 531 2.2e+05 0.6 MB
morph_ca1_n120.swc@ncomp=16 0.45 1.52 3.4x 214 561 5.9e+05 1.2 MB
morph_allen_485574832.swc@ncomp=1 0.08 0.27 3.3x 101 368 7.2e+04 0.3 MB
morph_allen_485574832.swc@ncomp=4 0.18 0.50 2.8x 165 448 1.1e+05 0.4 MB
morph_allen_485574832.swc@ncomp=16 0.32 1.10 3.5x 168 454 2.9e+05 0.8 MB
morph_retina_20161028_1.swc@ncomp=1 0.09 0.38 4.2x 116 401 1.3e+05 0.5 MB
morph_retina_20161028_1.swc@ncomp=4 0.25 0.81 3.2x 210 537 2.2e+05 0.6 MB
morph_retina_20161028_1.swc@ncomp=16 0.40 1.49 3.7x 214 547 5.9e+05 1.2 MB

@Matthijspals
Copy link
Copy Markdown
Contributor

This is great @manuelgloeckler! For our filtering setup, with 10_000 compartments the backwards pass goes from 160s to .3s.

Copy link
Copy Markdown
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, that is really awesome! Thank you so much, and these speed-ups are super impressive!

black is currently failing, and please add a line to the CHANGELOG.md.

@michaeldeistler michaeldeistler merged commit ba9772b into main Apr 10, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants