Summary
op_system currently compiles a callable right-hand side (RHS), eval_fn, that coerces values to NumPy at several points. As a result, the callable is not JAX-traceable, which blocks truly diffrax-native and gradient-based inference workflows (e.g., NUTS/HMC through the ODE solve).
Problem
The current callable path concretizes values to NumPy arrays or Python scalars in both the compiler and the system-wrapper layers:
src/op_system/compile.py
- env binding uses np + np.float64(...)
- helper reductions cast via float(...) / np.float64(...)
- equation evaluation builds outputs with np.empty(..., dtype=np.float64) + scalar assignment
- eval_fn coercions use np.asarray(..., dtype=np.float64)
flepimop2-op_system/src/flepimop2/system/op_system/__init__.py
- _stepper coercions force NumPy for time, state, and return value
These coercions break tracer propagation and prevent end-to-end JAX differentiation.
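A minimal illustration of why these coercions matter, using two hypothetical RHS functions (rhs_coercing and rhs_preserving are not taken from compile.py): the first mimics the np.asarray(..., dtype=np.float64) pattern listed above, the second leaves the arithmetic in the caller's array type. With NumPy inputs they agree; under jax.jit the coercing version is expected to raise a JAX tracer-conversion error, while the preserving one traces. The JAX part runs only if JAX is installed.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is optional; the tracing demo is skipped without it
    jax = None


def rhs_coercing(t, y):
    """Hypothetical RHS that mimics the current float64 coercion pattern."""
    y = np.asarray(y, dtype=np.float64)  # concretizes: fails on a JAX tracer
    return np.asarray(-0.5 * y, dtype=np.float64)


def rhs_preserving(t, y):
    """Same model, but the arithmetic stays in the caller's array type."""
    return -0.5 * y


# With NumPy inputs both versions agree, so the NumPy path is unaffected.
y0 = np.array([1.0, 2.0])
numpy_parity = bool(np.allclose(rhs_coercing(0.0, y0), rhs_preserving(0.0, y0)))

coercing_traces = None
preserving_out = None
if jax is not None:
    try:
        jax.jit(rhs_coercing)(0.0, jnp.ones(2))
        coercing_traces = True
    except (jax.errors.TracerArrayConversionError, jax.errors.ConcretizationTypeError):
        coercing_traces = False  # the tracer hit a NumPy conversion
    preserving_out = jax.jit(rhs_preserving)(0.0, jnp.ones(2))
```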
Proposed Solution
Introduce a backend-aware compile/eval path that preserves array types:
- Add a backend namespace parameter to the compile path
  - Extend compile_rhs / _make_eval_fn with a backend namespace argument (e.g., xp) that defaults to NumPy.
  - Bind the expression environment so that user expressions written with np.* still work, by mapping np -> xp.
- Remove scalar concretization from the compiler runtime
  - In alias/equation evaluation, avoid float(...) and np.float64(...) casts.
  - Keep values as backend array/scalar types.
  - Replace Python sum(float(...)) reductions with backend ops (xp.sum(xp.stack(...))).
- Make equation output backend-preserving
  - Replace np.empty(..., dtype=np.float64) + indexed assignment with list collection + xp.asarray(...).
- Relax eval_fn coercions
  - Avoid unconditional np.asarray(..., dtype=np.float64) and np.float64(t) in JAX mode.
  - Keep validation, but do not force conversions that concretize tracers.
- Update the flepimop2-op_system _stepper
  - Keep shape checks.
  - Remove forced NumPy casts around input/output in JAX mode.
  - Return the compiled callable's output without downcasting to NumPy.
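The steps above could fit together roughly as follows. This is a sketch under assumptions: compile_rhs_sketch, make_stepper_sketch, and the total helper are hypothetical stand-ins for compile_rhs / _make_eval_fn and _stepper, and equations are assumed to be Python expression strings evaluated against an environment of states, aliases, and parameters. The real signatures in op_system may differ.

```python
from typing import Any, Callable, Mapping, Optional, Sequence

import numpy as np


def compile_rhs_sketch(
    state_names: Sequence[str],
    equations: Mapping[str, str],
    aliases: Optional[Mapping[str, str]] = None,
    xp: Any = np,
) -> Callable[[Any, Any, Mapping[str, Any]], Any]:
    """Hypothetical backend-aware compiler returning eval_fn(t, y, params)."""
    alias_code = {k: compile(v, f"<alias {k}>", "eval") for k, v in (aliases or {}).items()}
    eq_code = [compile(equations[name], f"<eq {name}>", "eval") for name in state_names]

    def total(*terms):
        # Backend reduction instead of Python sum(float(...)).
        return xp.sum(xp.stack([xp.asarray(v) for v in terms]))

    def eval_fn(t, y, params):
        # No np.float64(t) or np.asarray(y, dtype=np.float64): tracers pass through.
        # "np" is bound to xp so user expressions written with np.* use the backend.
        env = {"__builtins__": {}, "np": xp, "xp": xp, "total": total, "t": t, **params}
        env.update({name: y[i] for i, name in enumerate(state_names)})
        for name, code in alias_code.items():
            env[name] = eval(code, env)  # kept as a backend scalar/array, no float()
        # Collect outputs in a list, then build the array with the backend.
        return xp.asarray([eval(code, env) for code in eq_code])

    return eval_fn


def make_stepper_sketch(eval_fn, n_states, params, xp=np):
    """Hypothetical _stepper: shape checks stay, NumPy casts are removed."""

    def _stepper(time, state):
        if xp.shape(state) != (n_states,):  # shapes are static, even for tracers
            raise ValueError(f"expected state shape ({n_states},), got {xp.shape(state)}")
        out = eval_fn(time, state, params)
        if xp.shape(out) != (n_states,):
            raise ValueError(f"RHS returned shape {xp.shape(out)}")
        return out  # returned as-is, no np.asarray(...) downcast

    return _stepper


# Usage with the default NumPy backend (an SIR model, purely as an illustration).
sir_rhs = compile_rhs_sketch(
    ["S", "I", "R"],
    {"S": "-beta * S * I / N", "I": "beta * S * I / N - gamma * I", "R": "gamma * I"},
    aliases={"N": "total(S, I, R)"},
)
step = make_stepper_sketch(sir_rhs, 3, {"beta": 0.3, "gamma": 0.1})
dydt = step(0.0, np.array([0.9, 0.1, 0.0]))
```

Because nothing in eval_fn or _stepper converts its inputs, passing xp=jax.numpy to both functions is expected to keep the path traceable; the sketch itself only exercises the NumPy default.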
Scope / Compatibility
- Preserve existing behavior by default (xp=np) so current NumPy users are unaffected.
- Add an optional JAX-native path (e.g., xp=jax.numpy) for diffrax-native integration.
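One way to keep NumPy as the default while making JAX opt-in is to resolve the backend lazily, so NumPy-only environments never import JAX. resolve_backend is a hypothetical helper, and the "numpy"/"jax" string specs are assumptions rather than an existing op_system option.

```python
import importlib
from types import ModuleType
from typing import Optional, Union

import numpy as np


def resolve_backend(xp: Optional[Union[ModuleType, str]] = None) -> ModuleType:
    """Hypothetical helper mapping a backend spec to an array namespace.

    None or "numpy" keeps today's behavior; "jax" opts into jax.numpy.
    """
    if xp is None or xp == "numpy":
        return np
    if xp == "jax":
        # Imported lazily so NumPy-only installs never need JAX.
        return importlib.import_module("jax.numpy")
    if isinstance(xp, ModuleType):
        return xp  # e.g. numpy or jax.numpy passed directly
    raise ValueError(f"unknown array backend: {xp!r}")
```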
Acceptance Criteria
- Existing NumPy workflows remain unchanged and pass current tests.
- A JAX-enabled compiled callable can be evaluated under jax.jit without concretization errors.
- The diffrax integration can consume the callable directly, without wrapping it in jax.pure_callback.
- Gradient-based inference (NUTS/HMC) can trace through RHS evaluation for a representative model.
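A sketch of how the jit and gradient criteria might be checked for a representative model. sir_rhs is an illustrative hand-written RHS in the backend-preserving style, not output from op_system, and check_jit_and_grad is a hypothetical helper; the JAX checks are skipped when JAX is not installed. For this model, d(dI/dt)/d(beta) = S*I/N, which is 0.09 at the chosen state.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is optional; the checks below are skipped without it
    jax = None


def sir_rhs(t, y, params, xp=np):
    """Illustrative backend-preserving SIR RHS (not generated by op_system)."""
    s, i, r = y[0], y[1], y[2]
    n = xp.sum(xp.stack([s, i, r]))  # backend reduction, no float() cast
    infection = params["beta"] * s * i / n
    recovery = params["gamma"] * i
    return xp.asarray([-infection, infection - recovery, recovery])


def check_jit_and_grad(rhs, y0, params):
    """Evaluate rhs under jax.jit and differentiate dI/dt with respect to beta."""
    jitted = jax.jit(lambda t, y, p: rhs(t, y, p, xp=jnp))
    out = jitted(0.0, y0, params)  # raises if anything concretizes a tracer

    def d_infected(beta):
        return rhs(0.0, y0, {**params, "beta": beta}, xp=jnp)[1]

    return out, jax.grad(d_infected)(params["beta"])


params = {"beta": 0.3, "gamma": 0.1}
numpy_out = sir_rhs(0.0, np.array([0.9, 0.1, 0.0]), params)

jit_out, grad_beta = None, None
if jax is not None:
    jit_out, grad_beta = check_jit_and_grad(sir_rhs, jnp.array([0.9, 0.1, 0.0]), params)
```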
Suggested Follow-up
- Add focused tests for both backends:
  - NumPy parity tests (old vs. new behavior)
  - JAX tracer tests (jit, grad where applicable)
  - Minimal diffrax solve smoke test using the compiled callable
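A possible shape for the parity and diffrax smoke tests, assuming a trivial decay model dy/dt = -k*y whose solution at t = 1 is exp(-k). decay_rhs_legacy and decay_rhs are hypothetical stand-ins for the old and new paths; diffrax.ODETerm, diffrax.Tsit5, and diffrax.diffeqsolve are diffrax's public API, and the smoke test is skipped when diffrax is not installed.

```python
import numpy as np

try:
    import diffrax
    import jax.numpy as jnp
except ImportError:  # diffrax/JAX are optional; the smoke test is skipped without them
    diffrax = None


def decay_rhs_legacy(t, y, args):
    """Hypothetical stand-in for the old path: coerces everything to float64."""
    y = np.asarray(y, dtype=np.float64)
    return np.asarray(-float(args["k"]) * y, dtype=np.float64)


def decay_rhs(t, y, args, xp=np):
    """Backend-preserving version of the same RHS: dy/dt = -k * y."""
    return -args["k"] * xp.asarray(y)


def diffrax_smoke_test(rhs):
    """Solve dy/dt = -k*y on [0, 1], passing the RHS straight to ODETerm."""
    term = diffrax.ODETerm(lambda t, y, args: rhs(t, y, args, xp=jnp))  # no pure_callback
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(), 0.0, 1.0, 0.01, jnp.array([1.0]), args={"k": 0.5}
    )
    return sol.ys[-1]  # default SaveAt(t1=True) keeps only the final state


# NumPy parity: old and new paths must agree on plain NumPy inputs.
y_np = np.array([1.0, 2.0])
parity = bool(np.allclose(decay_rhs_legacy(0.0, y_np, {"k": 0.5}), decay_rhs(0.0, y_np, {"k": 0.5})))

final_state = None
if diffrax is not None:
    final_state = diffrax_smoke_test(decay_rhs)
```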