Solver n_comp scaling fix#795
Merged
Merged
Conversation
Contributor
Author
|
Here the full benchmarks Runtimes on mainncomp scaling (single branch)
morphology (SWC) benchmarks
Runtimes on this branchncomp scaling (single branch)
morphology (SWC) benchmarks
|
Contributor
|
This is great @manuelgloeckler! For our filtering setup, with 10_000 compartments the backwards pass goes from 160s to .3s. |
michaeldeistler
approved these changes
Apr 6, 2026
Contributor
michaeldeistler
left a comment
There was a problem hiding this comment.
Wow, that is really awesome! Thank you so much, and these speed-ups are super impressive!
black is currently failing, and please add a line to the CHANGELOG.md.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The backward pass through
jx.integratescales O(n^2) with the number of compartments on CPU, despite O(n) FLOPs. The root cause is that JAX's automatic VJP offori_loopmust carry and copy the full O(n) state tensor on every iteration, resulting in O(n^2) memory traffic. This makes gradient computation prohibitively slow for large morphologies (e.g. 700ms backward for 10,000 compartments vs 0.89ms forward — a 788x ratio).Solution
Add a
custom_jvpto the tridiagonal solve. The key mathematical insight is that the gradient of a linear solveA x = bis itself a linear solve with the same matrix:This means the backward pass is a single O(n) tridiagonal solve rather than differentiating through the O(n)-step
fori_loop. JAX automatically derives the VJP (reverse-mode) from the custom JVP rule.Changes
jaxley/solver_voltage.py_make_dhs_solve(): factory that creates a solve function with@jax.custom_jvp. The JVP computesdA @ xusing precomputed child/parent index arrays, then solvesA^{-1}(db - dA x)by calling the same solver._pad_comp_edges(): pads ragged grouped edge arrays to dense numpy arrays (needed because levels can have different numbers of edges)._get_dhs_solve(): caches the solve function on thesolve_indexerdict to avoid recreating it on every call. (Note: caching is currently semi-optimal and may be removable.)step_voltage_implicit_with_dhs_solve()to delegate to the cached solve function._comp_based_backsub_recursive_doubling(): computenum_recursive_stepsviaceil(log2), use JAX array forparent_jump(enablingparent_jump[parent_jump]doubling), removen_nodesparameter.jaxley/modules/base.py_init_solver_jaxley_dhs_solve(): precompute flatall_childrenandall_parentsindex arrays fromnode_order_groupedfor the custom JVP'sdA @ xcomputation.jaxley/modules/network.pyall_children/all_parentsprecomputation for the network-level solver init.tests/test_solver.pytest_dhs_solve_handles_ragged_grouped_edges: verifies correctness of the solve and its gradient againstnp.linalg.solvefor both CPU and GPU paths, including ragged edge groupings.tests/helpers.py,tests/test_graph.pyBenchmark results
ncomp scaling (single branch)
The bwd/fwd ratio drops from 789x to 3.8x at 10,000 compartments.
Morphology (SWC) benchmarks — selected highlights
Compilation times are unchanged. No impact on forward pass performance.