Skip to content

Make op_system callable JAX-native for diffrax/NUTS workflows#75

Open
jc-macdonald wants to merge 6 commits intomainfrom
feat/jax-native-rhs-issue-74
Open

Make op_system callable JAX-native for diffrax/NUTS workflows#75
jc-macdonald wants to merge 6 commits intomainfrom
feat/jax-native-rhs-issue-74

Conversation

@jc-macdonald
Copy link
Copy Markdown
Collaborator

Closes #74

Summary

Adds a backend-aware compile path so compile_rhs (and compile_spec) can return RHS callables that are fully traceable by JAX — no forced NumPy concretization.

Changes

Core (op_system)

  • compile_rhs(..., xp=np) / compile_spec(..., xp=None) — new optional xp parameter; defaults to numpy (no behaviour change for existing callers)
  • _is_numpy_backend(xp) guard routes array ops through xp.sum(xp.stack(...)) / xp.asarray() for JAX compatibility
  • _BackendNamespace / _Indexable internal protocols replace Any in all public callable signatures
  • eval_fn skips forced float64 cast in JAX mode (avoids x64 config requirement)

Adapter (flepimop2-op_system)

  • New backend: Literal["numpy", "jax"] field on the stepper model
  • _get_backend_namespace() lazy-imports jax.numpy via importlib (optional dep, no import-time cost)
  • Stepper coerces arrays via NumPy on numpy path; passes through on JAX path

Tests

  • test_compile_rhs_with_jax_backend_is_jittable — verifies jax.jit wrapping round-trips
  • test_compile_rhs_with_diffrax_solve — ODE solve through diffrax ODETerm
  • test_compile_rhs_traces_through_blackjax_nuts — NUTS gradient tracing smoke test
  • test_jax_backend_stepper_jittable (adapter) — stepper survives jax.jit
  • All new tests use pytest.importorskip and are skipped when JAX/diffrax/blackjax are absent

Packaging

  • jax and jax-inference optional dependency groups added to pyproject.toml
  • README updated with install instructions and usage examples

@jc-macdonald jc-macdonald marked this pull request as ready for review May 1, 2026 19:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Make op_system callable JAX-native for diffrax/NUTS workflows

1 participant