Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,36 @@ spec = {
compiled = compile_spec(spec)
dydt = compiled.eval_fn(0.0, [999.0, 1.0, 0.0], beta=0.3, gamma=0.1)
```

## Optional JAX Backend

Install the optional dependency group if you want to compile RHS callables
for JAX-native tracing and integration workflows:

```shell
pip install "op_system[jax]"
```

Then pass the backend namespace when compiling:

```python
import jax.numpy as jnp
from op_system import compile_spec

compiled = compile_spec(spec, xp=jnp)
dydt = compiled.eval_fn(0.0, jnp.asarray([999.0, 1.0, 0.0]), beta=0.3, gamma=0.1)
```

If you do not pass `xp`, `op_system` defaults to NumPy behavior.

For diffrax-native solves and NUTS/HMC workflows, install the inference extra:

```shell
pip install "op_system[jax-inference]"
```

This includes `diffrax` and `blackjax` for end-to-end tracing workflows.

## YAML examples (organized and API-current)

The example set below is intentionally small but complete: each core modeling pattern is shown for both `expr` and `transitions` pathways where applicable.
Expand Down
4 changes: 2 additions & 2 deletions flepimop2-op_system/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = [
license = { file = "LICENSE" }
dependencies = [
"op-system>=0.1.0",
"flepimop2>=0.1.0",
"flepimop2>=0.2.0",
"pydantic>=2.0,<3",
"numpy>=1.26",
"PyYAML>=6.0",
Expand Down Expand Up @@ -78,7 +78,7 @@ convention = "google"
dev = [
"pytest>=8.4.2",
"ruff>=0.13.3",
"mypy>=1.18.2",
"mypy>=1.20.2",
]

[[tool.mypy.overrides]]
Expand Down
34 changes: 28 additions & 6 deletions flepimop2-op_system/src/flepimop2/system/op_system/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from __future__ import annotations

import functools
import importlib
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, NamedTuple
Expand Down Expand Up @@ -65,6 +66,7 @@ class _AxesMeta(NamedTuple):
class OpSystemSystem(ModuleModel, SystemABC): # noqa: D101
module: Literal["flepimop2.system.op_system"] = "flepimop2.system.op_system"
state_change: StateChangeEnum = StateChangeEnum.FLOW
backend: Literal["numpy", "jax"] = "numpy"

spec: dict[str, object] = Field(
default=..., description="Inline op_system RHS specification (already loaded)"
Expand All @@ -77,7 +79,8 @@ def model_post_init(self, context: Any) -> None: # noqa: ANN401
del context

spec_obj = self.spec
compiled = compile_spec(spec_obj)
xp = self._get_backend_namespace(self.backend)
compiled = compile_spec(spec_obj, xp=xp)
n_state = len(compiled.state_names)

axes_meta = self._extract_axes_meta(compiled)
Expand Down Expand Up @@ -107,7 +110,10 @@ def _stepper(
state: Float64NDArray,
**kwargs: Any, # noqa: ANN401
) -> Float64NDArray:
state_arr = np.asarray(state, dtype=np.float64)
if self.backend == "numpy":
state_arr = np.asarray(state, dtype=np.float64)
else:
state_arr = xp.asarray(state)
if state_arr.ndim != 1 or state_arr.size != n_state:
msg = (
"state must be a 1D array matching the spec state length; "
Expand All @@ -116,10 +122,15 @@ def _stepper(
raise ValueError(msg)
params = dict(mixing_kernels)
params.update(kwargs)
return np.asarray(
compiled.eval_fn(np.float64(time), state_arr, **params),
dtype=np.float64,
)
if self.backend == "numpy":
return np.asarray(
compiled.eval_fn(np.float64(time), state_arr, **params),
dtype=np.float64,
)
result = compiled.eval_fn(time, state_arr, **params)
if TYPE_CHECKING:
return np.asarray(result, dtype=np.float64)
return result

self._stepper = _stepper
self._compiled_rhs = compiled # handy for debugging/adapters
Expand Down Expand Up @@ -158,6 +169,17 @@ def _extract_axes_meta(compiled: CompiledRhs) -> _AxesMeta:
axis_order=tuple(axis_order), axis_sizes=axis_sizes, axis_coords=axis_coords
)

@staticmethod
def _get_backend_namespace(backend: Literal["numpy", "jax"]) -> Any: # noqa: ANN401
if backend == "numpy":
return np
try:
jnp = importlib.import_module("jax.numpy")
except ImportError as exc:
msg = "backend='jax' requires jax to be installed"
raise ImportError(msg) from exc
return jnp

@staticmethod
def _extract_operators(
compiled: CompiledRhs,
Expand Down
15 changes: 15 additions & 0 deletions flepimop2-op_system/tests/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,21 @@ def test_bind_delegates_to_step(sir_spec: dict[str, object]) -> None:
np.testing.assert_allclose(via_bind, via_step, rtol=0.0, atol=0.0)


def test_jax_backend_stepper_jittable(sir_spec: dict[str, object]) -> None:
"""JAX backend returns tracer-compatible outputs under jit."""
jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")

sys = OpSystemSystem(spec=sir_spec, backend="jax")
stepper = sys.bind(params={"beta": 0.3, "gamma": 0.1})
y0 = jnp.asarray([0.999, 0.001, 0.0])

out = jax.jit(lambda y: stepper(time=0.0, state=y))(y0)

expected = np.array([-0.0002997, 0.0001997, 0.0001], dtype=np.float64)
np.testing.assert_allclose(np.asarray(out), expected, rtol=1e-6, atol=1e-12)


# -- Integration: full SystemABC.bind() contract -----------------------------------


Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ data = [
"pandas>=2.2",
"pyarrow>=16.0",
]
jax = [
"jax>=0.4",
]
jax-inference = [
"jax>=0.4",
"diffrax",
"blackjax",
]

[build-system]
requires = ["hatchling"]
Expand Down
7 changes: 5 additions & 2 deletions src/op_system/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,23 @@
# -----------------------------------------------------------------------------


def compile_spec(spec: dict[str, object]) -> CompiledRhs: # noqa: RUF067
def compile_spec(spec: dict[str, object], *, xp: object | None = None) -> CompiledRhs: # noqa: RUF067
"""
Validate, normalize, and compile a RHS specification in one call.

This is the recommended public entrypoint for most users and adapters.

Args:
spec: Raw RHS specification mapping (YAML/JSON friendly).
xp: Optional array backend namespace; defaults to NumPy behavior.

Returns:
CompiledRhs: Runnable RHS callable container.
"""
rhs = normalize_rhs(spec)
return compile_rhs(rhs)
if xp is None:
return compile_rhs(rhs)
return compile_rhs(rhs, xp=xp)


# -----------------------------------------------------------------------------
Expand Down
Loading
Loading