perf(_numba_random): defer numba import; widen jax/numpy bounds, drop absl-py#171
Merged
Conversation
Importing brainevent no longer imports numba. The LFSR primitives are now plain-Python functions, JIT-compiled lazily on the first get_numba_lfsr_* call; that call rebinds the module globals to njit dispatchers so the inline='always' cross-references (e.g. lfsr88_rand -> lfsr88_next_key) still resolve when an enclosing kernel is compiled. Behaviour with numba absent is preserved (primitives stay plain Python). Dependencies: - drop absl-py: never imported (only a C++ absl::StatusCode doc reference); not required by jax/jaxlib or brainevent's dependency closure. - numpy>=1.15 -> numpy>=2.0 - jax: drop the hard <0.11 upper bound, raise the floor to >=0.8.0 (core and cpu/cuda12/cuda13/tpu extras). Replace the removed jax upper-bound install pin with a runtime check in the numba FFI bridge: _warn_if_untested_jax() warns once, at FFI target registration, when the installed jax is newer than the validated XLA FFI ABI (XLA_FFI_API_MINOR / ffi.h layout). Wired into the CPU and CUDA register_ffi_target sites. Tests: import-does-not-import-numba, lazy-compilation lifecycle, inline-inside-njit-kernel, and the FFI version-check helper.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
import braineventno longer importsnumba. The LFSR primitives in_numba_random.pyare now plain-Python functions, JIT-compiled lazily on the firstget_numba_lfsr_*call. That call rebinds the module globals tonjitdispatchers so theinline='always'cross-references (e.g.lfsr88_rand→lfsr88_next_key) still resolve when an enclosing kernel compiles. Numba-absent behaviour is preserved.absl-py— never imported (only a C++absl::StatusCodedoc reference); not in brainevent's dependency closure (jax/jaxlib don't require it).numpy>=1.15→numpy>=2.0.jax: drop the hard<0.11upper bound, raise the floor to>=0.8.0(core + cpu/cuda12/cuda13/tpu extras)._warn_if_untested_jax()warns once, at FFI target registration, when the installed jax is newer than the validated XLA FFI ABI (XLA_FFI_API_MINOR/ffi.hlayout). Wired into the CPU and both CUDAregister_ffi_targetsites.Why
Importing
numbais expensive (~1–2s); it was pulled in eagerly on everyimport braineventvia_numba_random. The ABI version is already auto-detected from the installedjaxlibheader, so the hard<0.11pin is better expressed as a runtime warning.Tests
import brainevent/import _numba_randomdo not import numba (subprocess).@numba.njitkernel (chainednormal→randn→rand→next_key).97 passed locally (incl. real numba-backed
_jit_scalarend-to-end variants). mypy: no new errors in changed source files.