diff --git a/README.md b/README.md index 9d4c2e2..79d9b6b 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,36 @@ spec = { compiled = compile_spec(spec) dydt = compiled.eval_fn(0.0, [999.0, 1.0, 0.0], beta=0.3, gamma=0.1) ``` + +## Optional JAX Backend + +Install the optional dependency group if you want to compile RHS callables +for JAX-native tracing and integration workflows: + +```shell +pip install "op_system[jax]" +``` + +Then pass the backend namespace when compiling: + +```python +import jax.numpy as jnp +from op_system import compile_spec + +compiled = compile_spec(spec, xp=jnp) +dydt = compiled.eval_fn(0.0, jnp.asarray([999.0, 1.0, 0.0]), beta=0.3, gamma=0.1) +``` + +If you do not pass `xp`, `op_system` defaults to NumPy behavior. + +For diffrax-native solves and NUTS/HMC workflows, install the inference extra: + +```shell +pip install "op_system[jax-inference]" +``` + +This includes `diffrax` and `blackjax` for end-to-end tracing workflows. + ## YAML examples (organized and API-current) The example set below is intentionally small but complete: each core modeling pattern is shown for both `expr` and `transitions` pathways where applicable. diff --git a/flepimop2-op_system/pyproject.toml b/flepimop2-op_system/pyproject.toml index feefce1..b704e94 100644 --- a/flepimop2-op_system/pyproject.toml +++ b/flepimop2-op_system/pyproject.toml @@ -12,7 +12,7 @@ authors = [ license = { file = "LICENSE" } dependencies = [ "op-system>=0.1.0", - "flepimop2>=0.1.0", + "flepimop2>=0.2.0", "pydantic>=2.0,<3", "numpy>=1.26", "PyYAML>=6.0", @@ -78,7 +78,7 @@ convention = "google" dev = [ "pytest>=8.4.2", "ruff>=0.13.3", - "mypy>=1.18.2", + "mypy>=1.20.2", ] [[tool.mypy.overrides]] diff --git a/flepimop2-op_system/src/flepimop2/system/op_system/__init__.py b/flepimop2-op_system/src/flepimop2/system/op_system/__init__.py index bf073af..8bacf7d 100644 --- a/flepimop2-op_system/src/flepimop2/system/op_system/__init__.py +++ b/flepimop2-op_system/src/flepimop2/system/op_system/__init__.py @@ -27,6 +27,7 @@ from __future__ import annotations import functools +import importlib import sys from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, NamedTuple @@ -65,6 +66,7 @@ class _AxesMeta(NamedTuple): class OpSystemSystem(ModuleModel, SystemABC): # noqa: D101 module: Literal["flepimop2.system.op_system"] = "flepimop2.system.op_system" state_change: StateChangeEnum = StateChangeEnum.FLOW + backend: Literal["numpy", "jax"] = "numpy" spec: dict[str, object] = Field( default=..., description="Inline op_system RHS specification (already loaded)" @@ -77,7 +79,8 @@ def model_post_init(self, context: Any) -> None: # noqa: ANN401 del context spec_obj = self.spec - compiled = compile_spec(spec_obj) + xp = self._get_backend_namespace(self.backend) + compiled = compile_spec(spec_obj, xp=xp) n_state = len(compiled.state_names) axes_meta = self._extract_axes_meta(compiled) @@ -107,7 +110,10 @@ def _stepper( state: Float64NDArray, **kwargs: Any, # noqa: ANN401 ) -> Float64NDArray: - state_arr = np.asarray(state, dtype=np.float64) + if self.backend == "numpy": + state_arr = np.asarray(state, dtype=np.float64) + else: + state_arr = xp.asarray(state) if state_arr.ndim != 1 or state_arr.size != n_state: msg = ( "state must be a 1D array matching the spec state length; " @@ -116,10 +122,15 @@ def _stepper( raise ValueError(msg) params = dict(mixing_kernels) params.update(kwargs) - return np.asarray( - compiled.eval_fn(np.float64(time), state_arr, **params), - dtype=np.float64, - ) + if self.backend == "numpy": + return np.asarray( + compiled.eval_fn(np.float64(time), state_arr, **params), + dtype=np.float64, + ) + result = compiled.eval_fn(time, state_arr, **params) + if TYPE_CHECKING: + return np.asarray(result, dtype=np.float64) + return result self._stepper = _stepper self._compiled_rhs = compiled # handy for debugging/adapters @@ -158,6 +169,17 @@ def _extract_axes_meta(compiled: CompiledRhs) -> _AxesMeta: axis_order=tuple(axis_order), axis_sizes=axis_sizes, axis_coords=axis_coords ) + @staticmethod + def _get_backend_namespace(backend: Literal["numpy", "jax"]) -> Any: # noqa: ANN401 + if backend == "numpy": + return np + try: + jnp = importlib.import_module("jax.numpy") + except ImportError as exc: + msg = "backend='jax' requires jax to be installed" + raise ImportError(msg) from exc + return jnp + @staticmethod def _extract_operators( compiled: CompiledRhs, diff --git a/flepimop2-op_system/tests/test_system.py b/flepimop2-op_system/tests/test_system.py index 831b511..e62462f 100644 --- a/flepimop2-op_system/tests/test_system.py +++ b/flepimop2-op_system/tests/test_system.py @@ -133,6 +133,21 @@ def test_bind_delegates_to_step(sir_spec: dict[str, object]) -> None: np.testing.assert_allclose(via_bind, via_step, rtol=0.0, atol=0.0) +def test_jax_backend_stepper_jittable(sir_spec: dict[str, object]) -> None: + """JAX backend returns tracer-compatible outputs under jit.""" + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + + sys = OpSystemSystem(spec=sir_spec, backend="jax") + stepper = sys.bind(params={"beta": 0.3, "gamma": 0.1}) + y0 = jnp.asarray([0.999, 0.001, 0.0]) + + out = jax.jit(lambda y: stepper(time=0.0, state=y))(y0) + + expected = np.array([-0.0002997, 0.0001997, 0.0001], dtype=np.float64) + np.testing.assert_allclose(np.asarray(out), expected, rtol=1e-6, atol=1e-12) + + # -- Integration: full SystemABC.bind() contract ----------------------------------- diff --git a/pyproject.toml b/pyproject.toml index cbdd466..3488d66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,14 @@ data = [ "pandas>=2.2", "pyarrow>=16.0", ] +jax = [ + "jax>=0.4", +] +jax-inference = [ + "jax>=0.4", + "diffrax", + "blackjax", +] [build-system] requires = ["hatchling"] diff --git a/src/op_system/__init__.py b/src/op_system/__init__.py index 99b9bde..89362f1 100644 --- a/src/op_system/__init__.py +++ b/src/op_system/__init__.py @@ -50,7 +50,7 @@ # ----------------------------------------------------------------------------- -def compile_spec(spec: dict[str, object]) -> CompiledRhs: # noqa: RUF067 +def compile_spec(spec: dict[str, object], *, xp: object | None = None) -> CompiledRhs: # noqa: RUF067 """ Validate, normalize, and compile a RHS specification in one call. @@ -58,12 +58,15 @@ def compile_spec(spec: dict[str, object]) -> CompiledRhs: # noqa: RUF067 Args: spec: Raw RHS specification mapping (YAML/JSON friendly). + xp: Optional array backend namespace; defaults to NumPy behavior. Returns: CompiledRhs: Runnable RHS callable container. """ rhs = normalize_rhs(spec) - return compile_rhs(rhs) + if xp is None: + return compile_rhs(rhs) + return compile_rhs(rhs, xp=xp) # ----------------------------------------------------------------------------- diff --git a/src/op_system/compile.py b/src/op_system/compile.py index c684c7f..4a0eb4b 100644 --- a/src/op_system/compile.py +++ b/src/op_system/compile.py @@ -22,7 +22,15 @@ import ast from dataclasses import dataclass, field from types import MappingProxyType -from typing import TYPE_CHECKING, Any, NoReturn, Protocol, cast +from typing import ( + TYPE_CHECKING, + Any, + NoReturn, + Protocol, + SupportsFloat, + SupportsIndex, + cast, +) import numpy as np from numpy.typing import NDArray @@ -34,9 +42,26 @@ from .specs import NormalizedRhs Float64Array = NDArray[np.float64] +ScalarLike = SupportsFloat | SupportsIndex | str | bytes | None _SAFE_BUILTINS = {"__import__": __import__} +class _BackendNamespace(Protocol): + def asarray(self, obj: object, dtype: object | None = None) -> object: ... + + def stack(self, arrays: list[object]) -> object: ... + + def sum(self, arr: object) -> object: ... + + +class _Indexable(Protocol): + def __getitem__(self, idx: int) -> object: ... + + +def _is_numpy_backend(xp: object) -> bool: + return xp is np or getattr(xp, "__name__", "") == "numpy" + + # ----------------------------------------------------------------------------- # Error message constants # ----------------------------------------------------------------------------- @@ -128,7 +153,7 @@ class EvalFn(Protocol): """Callable RHS evaluator supporting runtime parameter kwargs.""" def __call__( # noqa: D102 - self, t: np.float64, y: Float64Array, **params: object + self, t: object, y: object, **params: object ) -> Float64Array: ... @@ -143,7 +168,7 @@ class CompiledRhs: def bind( self, params: Mapping[str, object] - ) -> Callable[[np.float64, Float64Array], Float64Array]: + ) -> Callable[[object, object], Float64Array]: """Bind parameter values and return a 2-arg RHS: rhs(t, y) -> dydt. Args: @@ -154,7 +179,7 @@ def bind( """ params_dict = dict(params) - def rhs(t: np.float64, y: Float64Array) -> Float64Array: + def rhs(t: object, y: object) -> Float64Array: return self.eval_fn(t, y, **params_dict) return rhs @@ -379,15 +404,17 @@ def _resolve_aliases( ) -def _validate_state_vector(y_arr: np.ndarray, *, n_state: int) -> np.ndarray: +def _validate_state_vector(y_arr: object, *, n_state: int) -> object: """Validate shape of the state vector. Returns: State vector coerced to shape (n_state,). """ expected_shape = (n_state,) - if tuple(y_arr.shape) != expected_shape: - _raise_state_shape_error(expected=f"(n_state={n_state},)", got=y_arr.shape) + shape = getattr(y_arr, "shape", None) + shape_tuple = tuple(shape) if shape is not None else None + if shape_tuple != expected_shape: + _raise_state_shape_error(expected=f"(n_state={n_state},)", got=shape) return y_arr @@ -395,72 +422,90 @@ def _evaluate_equations( *, eq_code: list[CodeType], env: Mapping[str, object], - n_state: int, + xp: object, ) -> Float64Array: """Evaluate equation code objects against an environment. Returns: Derivative vector aligned to the provided state ordering. """ - out = np.empty((n_state,), dtype=np.float64) - for i, codeobj in enumerate(eq_code): + out_vals: list[object] = [] + for codeobj in eq_code: try: val = eval(codeobj, {"__builtins__": _SAFE_BUILTINS}, env) # noqa: S307 except NameError as exc: _raise_parameter_error(detail=f"unknown symbol in equation: {exc!s}") except (ValueError, TypeError, ArithmeticError) as exc: _raise_invalid_expression(detail=f"equation evaluation failed: {exc!r}") - out[i] = np.float64(val) - return out + out_vals.append(val) + + if _is_numpy_backend(xp): + return np.asarray(out_vals, dtype=np.float64) + return cast("Float64Array", cast("_BackendNamespace", xp).asarray(out_vals)) -def _make_eval_fn( +def _make_eval_fn( # noqa: C901 *, state_names: tuple[str, ...], aliases: Mapping[str, str], equations: tuple[str, ...], + xp: object, ) -> EvalFn: n_state = len(state_names) name_to_idx = {s: i for i, s in enumerate(state_names)} alias_code = _collect_alias_code(aliases) eq_code = _collect_eq_code(equations) - def _sum_state(env: Mapping[str, object]) -> np.float64: - values = [ - np.float64(float(cast("Any", v))) - for k, v in env.items() - if k in name_to_idx - ] - return np.float64(sum(values)) + def _sum_state(env: Mapping[str, object]) -> object: + values = [v for k, v in env.items() if k in name_to_idx] + if not values: + return cast("_BackendNamespace", xp).asarray(0.0) + if _is_numpy_backend(xp): + return np.float64(np.sum(np.asarray(values, dtype=np.float64))) + xp_ns = cast("_BackendNamespace", xp) + return xp_ns.sum(xp_ns.stack(values)) - def _sum_prefix(prefix: str, env: Mapping[str, object]) -> np.float64: + def _sum_prefix(prefix: str, env: Mapping[str, object]) -> object: values = [ - np.float64(float(cast("Any", v))) - for k, v in env.items() - if k.startswith(prefix) and k in name_to_idx + v for k, v in env.items() if k.startswith(prefix) and k in name_to_idx ] - return np.float64(sum(values)) + if not values: + return cast("_BackendNamespace", xp).asarray(0.0) + if _is_numpy_backend(xp): + return np.float64(np.sum(np.asarray(values, dtype=np.float64))) + xp_ns = cast("_BackendNamespace", xp) + return xp_ns.sum(xp_ns.stack(values)) def _build_env( - t: np.float64, y_arr: Float64Array, params: Mapping[str, object] + t: object, y_arr: object, params: Mapping[str, object] ) -> dict[str, object]: - env: dict[str, object] = {"np": np, "t": np.float64(t)} + t_val = ( + np.float64(cast("ScalarLike", t)) + if _is_numpy_backend(xp) + else cast("_BackendNamespace", xp).asarray(t) + ) + env: dict[str, object] = {"np": xp, "t": t_val} for s, i in name_to_idx.items(): - env[s] = np.float64(y_arr[i]) + env[s] = cast("_Indexable", y_arr)[i] env.update(params) env["sum_state"] = lambda: _sum_state(env) env["sum_prefix"] = lambda prefix: _sum_prefix(str(prefix), env) return env - def eval_fn(t: np.float64, y: Float64Array, **params: object) -> Float64Array: - y_arr = _validate_state_vector(np.asarray(y, dtype=np.float64), n_state=n_state) + def eval_fn(t: object, y: object, **params: object) -> Float64Array: + y_in: object + if _is_numpy_backend(xp): + y_in = np.asarray(y, dtype=np.float64) + else: + y_in = cast("_BackendNamespace", xp).asarray(y) + y_arr = _validate_state_vector(y_in, n_state=n_state) - env = _build_env(np.float64(t), y_arr, params) + env = _build_env(t, y_arr, params) if alias_code: env.update(_resolve_aliases(alias_code, base_env=env)) - return _evaluate_equations(eq_code=eq_code, env=env, n_state=n_state) + return _evaluate_equations(eq_code=eq_code, env=env, xp=xp) return eval_fn @@ -470,11 +515,12 @@ def eval_fn(t: np.float64, y: Float64Array, **params: object) -> Float64Array: # ----------------------------------------------------------------------------- -def compile_rhs(rhs: NormalizedRhs) -> CompiledRhs: +def compile_rhs(rhs: NormalizedRhs, *, xp: object = np) -> CompiledRhs: """Compile a normalized RHS into a runnable evaluation function. Args: rhs: Normalized RHS produced by `op_system.specs.normalize_rhs`. + xp: Array backend namespace (default: NumPy). Returns: A `CompiledRhs` containing an `eval_fn(t, y, **params) -> dydt`. @@ -489,6 +535,7 @@ def compile_rhs(rhs: NormalizedRhs) -> CompiledRhs: state_names=rhs.state_names, aliases=rhs.aliases, equations=rhs.equations, + xp=xp, ) return CompiledRhs( diff --git a/tests/op_system/test_op_system_compile.py b/tests/op_system/test_op_system_compile.py index 5b53a52..5d1f4f3 100644 --- a/tests/op_system/test_op_system_compile.py +++ b/tests/op_system/test_op_system_compile.py @@ -398,3 +398,99 @@ def test_sum_over_in_filter_evaluates_correctly() -> None: state = np.array([10.0, 3.0, 7.0], dtype=np.float64) derivs = compiled.eval_fn(np.float64(0.0), state) assert np.allclose(derivs, -state) + + +def test_compile_rhs_with_jax_backend_is_jittable() -> None: + """JAX backend preserves tracers and can be jitted/differentiated.""" + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + + spec = { + "kind": "expr", + "state": ["x"], + "equations": {"x": "beta * x"}, + } + rhs = normalize_rhs(spec) + compiled = compile_rhs(rhs, xp=jnp) + + y0 = jnp.asarray([2.0]) + eval_jit = jax.jit(lambda beta: compiled.eval_fn(0.0, y0, beta=beta)) + out = eval_jit(1.5) + assert np.allclose(np.asarray(out), np.asarray([3.0])) + + grad_fn = jax.grad( + lambda beta: compiled.eval_fn(0.0, y0, beta=beta)[0], + ) + assert np.isclose(float(grad_fn(1.5)), 2.0) + + +def test_compile_rhs_with_jax_backend_diffrax_smoke() -> None: + """Compiled JAX RHS can be consumed directly by diffrax.""" + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + diffrax = pytest.importorskip("diffrax") + + spec = { + "kind": "expr", + "state": ["x"], + "equations": {"x": "-beta * x"}, + } + rhs = normalize_rhs(spec) + compiled = compile_rhs(rhs, xp=jnp) + + term = diffrax.ODETerm(lambda t, y, args: compiled.eval_fn(t, y, **args)) + solver = diffrax.Tsit5() + + def solve(beta: float) -> object: + return diffrax.diffeqsolve( + term, + solver, + t0=0.0, + t1=1.0, + dt0=0.1, + y0=jnp.asarray([1.0]), + args={"beta": beta}, + saveat=diffrax.SaveAt(t1=True), + ).ys[0] + + out = jax.jit(solve)(0.5) + expected = np.exp(-0.5) + assert np.isclose(float(out), expected, rtol=1e-3) + + +def test_compile_rhs_traces_through_blackjax_nuts() -> None: + """A representative NUTS step can trace gradients through RHS eval.""" + jax = pytest.importorskip("jax") + jnp = pytest.importorskip("jax.numpy") + jr = pytest.importorskip("jax.random") + blackjax = pytest.importorskip("blackjax") + + spec = { + "kind": "expr", + "state": ["x"], + "equations": {"x": "-beta * x"}, + } + rhs = normalize_rhs(spec) + compiled = compile_rhs(rhs, xp=jnp) + + y0 = jnp.asarray([1.0]) + observed_dydt = -0.7 + + def logdensity(theta: np.ndarray) -> object: + beta = jnp.exp(theta[0]) + pred = compiled.eval_fn(0.0, y0, beta=beta)[0] + residual = pred - observed_dydt + prior = -0.5 * theta[0] ** 2 + likelihood = -40.0 * residual**2 + return prior + likelihood + + nuts = blackjax.nuts( + logdensity, + step_size=0.1, + inverse_mass_matrix=jnp.ones((1,)), + ) + state = nuts.init(jnp.asarray([0.0])) + key = jr.PRNGKey(0) + + next_state, _info = jax.jit(nuts.step)(key, state) + assert np.isfinite(float(next_state.logdensity))