
Make op_system callable JAX-native for diffrax/NUTS workflows #74

@jc-macdonald

Summary

op_system currently compiles the model right-hand side (RHS) into a callable (eval_fn) that coerces values to NumPy at several points. As a result, the callable is not JAX-traceable, which blocks diffrax-native integration and gradient-based inference (e.g. NUTS/HMC through the ODE solve).

Problem

The current callable path concretizes values to NumPy arrays or Python scalars in both the compiler and the system-wrapper layers:

  • src/op_system/compile.py
    • env binding uses np + np.float64(...)
    • helper reductions cast via float(...) / np.float64(...)
    • equation evaluation builds outputs with np.empty(..., dtype=np.float64) + scalar assignment
    • eval_fn coercions use np.asarray(..., dtype=np.float64)
  • flepimop2-op_system/src/flepimop2/system/op_system/__init__.py
    • _stepper coercions force NumPy for time, state, and return value

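Each of these coercions turns a JAX tracer into a concrete NumPy or Python value. That fails under jax.jit and breaks tracer propagation, so end-to-end JAX differentiation through the RHS is not possible.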

Proposed Solution

Introduce a backend-aware compile/eval path that preserves array types:

  1. Add a backend namespace parameter to the compile path
    • Extend compile_rhs / _make_eval_fn with a backend namespace argument (e.g., xp) that defaults to NumPy.
    • Bind the expression env so that user expressions written with np.* still work, by mapping np -> xp.
  2. Remove scalar concretization in the compiler runtime
    • In alias/equation evaluation, avoid float(...) and np.float64(...) casts.
    • Keep values as backend array/scalar types.
    • Replace reductions with backend ops (xp.sum(xp.stack(...))) instead of Python sum(float(...)).
  3. Make equation output backend-preserving
    • Replace np.empty(..., dtype=np.float64) + indexed assignment with list collection + xp.asarray(...).
  4. Relax eval_fn coercions
    • Avoid unconditional np.asarray(..., dtype=np.float64) and np.float64(t) in JAX mode.
    • Keep validation, but do not force conversions that concretize tracers.
  5. Update the flepimop2-op_system _stepper
    • Keep shape checks.
    • Remove forced NumPy casts around inputs and outputs in JAX mode.
    • Return the compiled callable's output without downcasting it to NumPy.
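Steps 1 to 3 can be sketched with a toy expression compiler. Everything here is an assumption for illustration: `make_eval_fn`, the string-equation format, and the `total` helper are hypothetical and do not reflect op_system's real compile_rhs / _make_eval_fn signatures. The sketch binds `np` to the chosen backend, binds states without casts, reduces with `xp.sum(xp.stack(...))`, and builds the output with a single `xp.asarray(...)` call.

```python
from typing import Any, Callable, Mapping, Sequence

import numpy as np
import jax
import jax.numpy as jnp


def make_eval_fn(
    state_names: Sequence[str],
    equations: Mapping[str, str],
    params: Mapping[str, Any],
    xp: Any = np,
) -> Callable[[Any, Any], Any]:
    """Hypothetical backend-aware compiler; not op_system's real compile_rhs."""
    # Compile each expression string once, up front.
    code = {name: compile(src, f"<eq:{name}>", "eval") for name, src in equations.items()}

    def total(*terms):
        # Reduction via backend ops instead of Python sum(float(...)).
        return xp.sum(xp.stack(terms))

    def eval_fn(t, y):
        # np.* in user expressions resolves to the chosen backend (np -> xp).
        env = {"np": xp, "total": total, "t": t, **params}
        # Bind state values as-is: no float(...) or np.float64(...) casts.
        for k, name in enumerate(state_names):
            env[name] = y[k]
        # Collect outputs in a list instead of filling np.empty(...).
        derivs = [eval(code[name], {"__builtins__": {}}, env) for name in state_names]
        # One backend call builds the output, so tracers survive under JAX.
        return xp.asarray(derivs)

    return eval_fn


STATES = ["S", "I", "R"]
EQUATIONS = {
    "S": "-beta * (1 + 0.1 * np.sin(t)) * S * I / total(S, I, R)",
    "I": "beta * (1 + 0.1 * np.sin(t)) * S * I / total(S, I, R) - gamma * I",
    "R": "gamma * I",
}
PARAMS = {"beta": 0.3, "gamma": 0.1}

# Default backend: NumPy in, float64 NumPy out.
f_np = make_eval_fn(STATES, EQUATIONS, PARAMS)
out_np = f_np(0.0, np.array([0.99, 0.01, 0.0]))

# JAX backend: the same expressions trace under jit and grad.
f_jax = make_eval_fn(STATES, EQUATIONS, PARAMS, xp=jnp)
y0 = jnp.array([0.99, 0.01, 0.0])
out_jax = jax.jit(f_jax)(0.0, y0)
dI_dbeta = jax.grad(
    lambda b: make_eval_fn(STATES, EQUATIONS, {**PARAMS, "beta": b}, xp=jnp)(0.0, y0)[1]
)(0.3)
```

Because the expressions use `np.sin`, the same equation strings run on both backends unchanged, which is the point of mapping np -> xp rather than asking users to rewrite models.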

Scope / Compatibility

  • Preserve existing behavior by default (xp=np) so current NumPy users are unaffected.
  • Add an optional JAX-native path (e.g., xp=jax.numpy) for diffrax-native integration.
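A `_stepper`-style wrapper can keep the default NumPy behavior byte-for-byte while letting JAX values pass through. This is a minimal sketch assuming a hypothetical `make_stepper` factory and a toy `decay` RHS; it is not flepimop2's actual wrapper. Shape checks stay in both modes, because array shapes are static under jax.jit and checking them never concretizes values.

```python
from typing import Any, Callable

import numpy as np
import jax
import jax.numpy as jnp


def make_stepper(eval_fn: Callable[[Any, Any], Any], n_states: int, xp: Any = np):
    """Hypothetical _stepper-style wrapper; not flepimop2's actual code."""

    def stepper(t, y):
        if xp is np:
            # Default NumPy mode: keep the float64 coercions unchanged.
            t = np.float64(t)
            y = np.asarray(y, dtype=np.float64)
        else:
            # JAX mode: no dtype forcing, so tracers pass through untouched.
            y = xp.asarray(y)
        # Shapes are static under jax.jit, so these checks never concretize values.
        if y.shape != (n_states,):
            raise ValueError(f"expected state shape ({n_states},), got {y.shape}")
        dydt = eval_fn(t, y)
        if dydt.shape != (n_states,):
            raise ValueError(f"expected RHS shape ({n_states},), got {dydt.shape}")
        # Only the NumPy path downcasts the result; JAX output is returned as-is.
        return np.asarray(dydt, dtype=np.float64) if xp is np else dydt

    return stepper


def decay(t, y):
    # Toy RHS (hypothetical): exponential decay dy/dt = -0.5 * y.
    return -0.5 * y


step_np = make_stepper(decay, n_states=2)
step_jax = make_stepper(decay, n_states=2, xp=jnp)

out_np = step_np(0.0, [1, 2])  # list of ints is coerced to float64 NumPy
out_jit = jax.jit(step_jax)(0.0, jnp.array([1.0, 2.0]))
grad_y = jax.grad(lambda y: step_jax(0.0, y).sum())(jnp.array([1.0, 2.0]))
```

Branching on `xp is np` happens once at trace time, so the JAX path carries no runtime cost for the compatibility branch.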

Acceptance Criteria

  • Existing NumPy workflows remain unchanged and pass current tests.
  • A JAX-enabled compiled callable can be evaluated under jax.jit without concretization errors.
  • diffrax integration can consume the callable without jax.pure_callback.
  • Gradient-based inference (NUTS/HMC) can trace through RHS evaluation for a representative model.
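The last two criteria amount to "jit and grad work through an ODE solve that calls the RHS". The sketch below checks that with a fixed-step RK4 written with jax.lax.scan, used here as a dependency-free stand-in for a diffrax solve. The `sir_rhs` model, step size, and horizon are illustrative assumptions. With diffrax itself, a backend-preserving RHS would be wrapped as `diffrax.ODETerm(lambda t, y, args: sir_rhs(t, y, *args))`, since ODETerm expects a vector field with signature `(t, y, args)`.

```python
import jax
import jax.numpy as jnp


def sir_rhs(t, y, beta, gamma):
    # Backend-preserving toy RHS (hypothetical): only jax.numpy ops, no casts.
    s, i = y[0], y[1]
    n = jnp.sum(y)
    infection = beta * s * i / n
    return jnp.stack([-infection, infection - gamma * i, gamma * i])


def rk4_solve(rhs, y0, t0, dt, n_steps, *args):
    """Fixed-step RK4 via lax.scan; a stand-in for a diffrax solve."""
    t0 = jnp.asarray(t0, dtype=y0.dtype)

    def step(carry, _):
        t, y = carry
        k1 = rhs(t, y, *args)
        k2 = rhs(t + dt / 2, y + dt / 2 * k1, *args)
        k3 = rhs(t + dt / 2, y + dt / 2 * k2, *args)
        k4 = rhs(t + dt, y + dt * k3, *args)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), y_next

    _, ys = jax.lax.scan(step, (t0, y0), None, length=n_steps)
    return ys


Y0 = jnp.array([0.99, 0.01, 0.0])


def infected_at_t10(beta):
    # 100 steps of dt = 0.1 integrate from t = 0 to t = 10.
    ys = rk4_solve(sir_rhs, Y0, 0.0, 0.1, 100, beta, 0.1)
    return ys[-1, 1]


value, dvalue_dbeta = jax.jit(jax.value_and_grad(infected_at_t10))(0.3)
```

A gradient like `dvalue_dbeta` is what an HMC/NUTS sampler needs from the log density; if the compiled RHS concretized values, this `value_and_grad` call would fail at trace time instead.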

Suggested Follow-up

  • Add focused tests for both backends:
    • NumPy parity tests (old vs new behavior)
    • JAX tracer tests (jit, grad where applicable)
    • Minimal diffrax solve smoke test using compiled callable
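The parity and tracer tests could look roughly like this pytest-style sketch. It assumes two hypothetical stand-ins: `legacy_rhs` for the current NumPy-coercing callable and `make_rhs(xp)` for a backend-aware one; a real suite would call op_system's compiled callable for both. The diffrax smoke test is left out so the sketch has no diffrax dependency.

```python
import numpy as np
import jax
import jax.numpy as jnp


def legacy_rhs(t, y, beta=0.3, gamma=0.1):
    # Stand-in for the current NumPy-coercing path (hypothetical).
    s, i = float(y[0]), float(y[1])
    out = np.empty(3, dtype=np.float64)
    out[0] = -beta * s * i
    out[1] = beta * s * i - gamma * i
    out[2] = gamma * i
    return out


def make_rhs(xp):
    # Stand-in for a backend-aware compiled callable (hypothetical).
    def rhs(t, y, beta=0.3, gamma=0.1):
        s, i = y[0], y[1]
        return xp.asarray([-beta * s * i, beta * s * i - gamma * i, gamma * i])

    return rhs


Y0 = [0.99, 0.01, 0.0]


def test_numpy_parity():
    # NumPy backend must reproduce the old float64 output.
    old = legacy_rhs(0.0, np.asarray(Y0))
    new = make_rhs(np)(0.0, np.asarray(Y0))
    assert new.dtype == np.float64
    np.testing.assert_allclose(new, old, rtol=1e-12, atol=0)


def test_jax_jit_traces():
    # The JAX backend compiles under jit and matches NumPy within float32 error.
    out = jax.jit(make_rhs(jnp))(0.0, jnp.asarray(Y0))
    np.testing.assert_allclose(np.asarray(out), legacy_rhs(0.0, np.asarray(Y0)), atol=1e-6)


def test_jax_grad_through_rhs():
    # d(dI/dt)/d(beta) = S * I for this toy model.
    g = jax.grad(lambda b: make_rhs(jnp)(0.0, jnp.asarray(Y0), beta=b)[1])(0.3)
    np.testing.assert_allclose(float(g), 0.99 * 0.01, rtol=1e-5)
```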
