Skip to content

perf(_numba_random): defer numba import; widen jax/numpy bounds, drop absl-py#171

Merged
chaoming0625 merged 1 commit into
mainfrom
worktree-lazy-numba-import
Jun 23, 2026
Merged

perf(_numba_random): defer numba import; widen jax/numpy bounds, drop absl-py#171
chaoming0625 merged 1 commit into
mainfrom
worktree-lazy-numba-import

Conversation

@chaoming0625

Copy link
Copy Markdown
Collaborator

Summary

  • Lazy numba import: import brainevent no longer imports numba. The LFSR primitives in _numba_random.py are now plain-Python functions, JIT-compiled lazily on the first get_numba_lfsr_* call. That call rebinds the module globals to njit dispatchers so the inline='always' cross-references (e.g. lfsr88_randlfsr88_next_key) still resolve when an enclosing kernel compiles. Numba-absent behaviour is preserved.
  • Dependencies:
    • Drop absl-py — never imported (only a C++ absl::StatusCode doc reference); not in brainevent's dependency closure (jax/jaxlib don't require it).
    • numpy>=1.15numpy>=2.0.
    • jax: drop the hard <0.11 upper bound, raise the floor to >=0.8.0 (core + cpu/cuda12/cuda13/tpu extras).
  • Runtime FFI ABI check replaces the removed jax install pin: _warn_if_untested_jax() warns once, at FFI target registration, when the installed jax is newer than the validated XLA FFI ABI (XLA_FFI_API_MINOR / ffi.h layout). Wired into the CPU and both CUDA register_ffi_target sites.

Why

Importing numba is expensive (~1–2s); it was pulled in eagerly on every import brainevent via _numba_random. The ABI version is already auto-detected from the installed jaxlib header, so the hard <0.11 pin is better expressed as a runtime warning.

Tests

  • import brainevent / import _numba_random do not import numba (subprocess).
  • Lazy-compilation lifecycle: plain before, dispatcher + populated tables after.
  • Inline inside a real @numba.njit kernel (chained normal→randn→rand→next_key).
  • FFI version-check helper: no warn on validated, warn on newer, warn-once, unparseable-version silent.

97 passed locally (incl. real numba-backed _jit_scalar end-to-end variants). mypy: no new errors in changed source files.

Importing brainevent no longer imports numba. The LFSR primitives are now
plain-Python functions, JIT-compiled lazily on the first get_numba_lfsr_*
call; that call rebinds the module globals to njit dispatchers so the
inline='always' cross-references (e.g. lfsr88_rand -> lfsr88_next_key)
still resolve when an enclosing kernel is compiled. Behaviour with numba
absent is preserved (primitives stay plain Python).

Dependencies:
- drop absl-py: never imported (only a C++ absl::StatusCode doc reference);
  not required by jax/jaxlib or brainevent's dependency closure.
- numpy>=1.15 -> numpy>=2.0
- jax: drop the hard <0.11 upper bound, raise the floor to >=0.8.0 (core
  and cpu/cuda12/cuda13/tpu extras).

Replace the removed jax upper-bound install pin with a runtime check in the
numba FFI bridge: _warn_if_untested_jax() warns once, at FFI target
registration, when the installed jax is newer than the validated XLA FFI
ABI (XLA_FFI_API_MINOR / ffi.h layout). Wired into the CPU and CUDA
register_ffi_target sites.

Tests: import-does-not-import-numba, lazy-compilation lifecycle,
inline-inside-njit-kernel, and the FFI version-check helper.
@chaoming0625 chaoming0625 merged commit 1e95441 into main Jun 23, 2026
4 checks passed
@chaoming0625 chaoming0625 deleted the worktree-lazy-numba-import branch June 23, 2026 13:04
@codecov

codecov Bot commented Jun 23, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 92.15686% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
brainevent/_numba_random.py 94.11% 2 Missing ⚠️
brainevent/_op/numba_cuda_ffi.py 0.00% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant