Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 54 additions & 7 deletions flepimop2-op_engine/src/flepimop2/engine/op_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
)
from op_engine.model_core import ModelCore, ModelCoreOptions

from .config import OpEngineEngineConfig, _coerce_operator_specs, _has_operator_specs
from .config import (
OpEngineEngineConfig,
SolverMethod,
_coerce_operator_specs,
_has_operator_specs,
)

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -106,19 +111,55 @@ class OpEngineFlepimop2Engine(ModuleModel, EngineABC):
config: OpEngineEngineConfig = Field(default_factory=OpEngineEngineConfig)

def validate_system(self, system: SystemABC) -> list[ValidationIssue] | None:
"""Validate system compatibility against the engine state-change mode."""
"""Validate system compatibility with engine config."""
issues: list[ValidationIssue] = []

if system.state_change != self.state_change:
return [
issues.append(
ValidationIssue(
msg=(
f"Engine state change type, '{self.state_change}', is not "
"compatible with system state change type "
f"'{system.state_change}'."
),
kind="incompatible_system",
),
)

method = self.config.method
is_imex = method.is_imex

if is_imex and not _has_operator_specs(
_coerce_operator_specs(self.config.operators),
):
sys_ops = system.option("operators", None)
if not _has_operator_specs(_coerce_operator_specs(sys_ops)):
issues.append(
ValidationIssue(
msg=(
f"IMEX method '{method}' requires operator matrices, "
"but neither the engine config nor "
"system.option('operators') provides them."
),
kind="missing_operators",
),
)
]
return None

if method.is_implicit:
jac = system.option("jacobian", None)
if jac is None:
issues.append(
ValidationIssue(
msg=(
f"Implicit/Rosenbrock method '{method}' requires a "
"Jacobian callable, but system.option('jacobian') "
"is not provided."
),
kind="missing_jacobian",
),
)

return issues or None

def run(
self,
Expand All @@ -137,7 +178,8 @@ def run(
n_state = int(y0.size)

run_cfg = self.config.to_run_config()
is_imex = run_cfg.method.startswith("imex-")
method = self.config.method
is_imex = method.is_imex
operators = run_cfg.operators

if is_imex and not _has_operator_specs(operators):
Expand All @@ -153,6 +195,11 @@ def run(
)
raise ValueError(msg)

if method.is_implicit:
jacobian = system.option("jacobian", None)
if callable(jacobian):
run_cfg = replace(run_cfg, jacobian=jacobian)

operator_axis = self.config.operator_axis
if operator_axis == "state":
system_axis = system.option("operator_axis", None)
Expand Down Expand Up @@ -180,4 +227,4 @@ def run(
return np.asarray(np.column_stack((times, states)), dtype=np.float64)


__all__ = ["OpEngineEngineConfig", "OpEngineFlepimop2Engine"]
__all__ = ["OpEngineEngineConfig", "OpEngineFlepimop2Engine", "SolverMethod"]
48 changes: 43 additions & 5 deletions flepimop2-op_engine/src/flepimop2/engine/op_engine/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Literal
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, model_validator

Expand Down Expand Up @@ -32,14 +33,46 @@ def _coerce_operator_specs(specs: object) -> OperatorSpecs | None:
return None


class SolverMethod(StrEnum):
"""Solver method identifiers for op_engine integration."""

EULER = "euler"
HEUN = "heun"
IMEX_EULER = "imex-euler"
IMEX_HEUN_TR = "imex-heun-tr"
IMEX_TRBDF2 = "imex-trbdf2"
IMPLICIT_EULER = "implicit-euler"
TRAPEZOIDAL = "trapezoidal"
BDF2 = "bdf2"
ROS2 = "ros2"

@property
def is_imex(self) -> bool:
"""Whether this method is an IMEX method."""
return self.value.startswith("imex-")

@property
def is_implicit(self) -> bool:
"""Whether this method requires a Jacobian."""
return self in _IMPLICIT_METHODS


_IMPLICIT_METHODS: frozenset[SolverMethod] = frozenset(
{
SolverMethod.IMPLICIT_EULER,
SolverMethod.TRAPEZOIDAL,
SolverMethod.BDF2,
SolverMethod.ROS2,
},
)


class OpEngineEngineConfig(BaseModel):
"""Configuration schema for op_engine when used as a flepimop2 engine."""

model_config = ConfigDict(extra="allow")

method: Literal["euler", "heun", "imex-euler", "imex-heun-tr", "imex-trbdf2"] = (
"heun"
)
method: SolverMethod = SolverMethod.HEUN
adaptive: bool = False
strict: bool = True
rtol: float = Field(default=1e-6, ge=0.0)
Expand Down Expand Up @@ -91,4 +124,9 @@ def to_run_config(self) -> RunConfig:
)


__all__ = ["OpEngineEngineConfig", "_coerce_operator_specs", "_has_operator_specs"]
__all__ = [
"OpEngineEngineConfig",
"SolverMethod",
"_coerce_operator_specs",
"_has_operator_specs",
]
Empty file.
22 changes: 11 additions & 11 deletions flepimop2-op_engine/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from op_engine.core_solver import OperatorSpecs, RunConfig # noqa: E402
from pydantic import ValidationError # noqa: E402

from flepimop2.engine.op_engine import OpEngineEngineConfig # noqa: E402
from flepimop2.engine.op_engine import OpEngineEngineConfig, SolverMethod # noqa: E402


def _has_any_operator_specs(specs: OperatorSpecs) -> bool:
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_engine_config_defaults_to_run_config() -> None:
def test_engine_config_round_trips_selected_fields() -> None:
"""Engine config round-trips selected fields correctly."""
cfg = OpEngineEngineConfig(
method="euler",
method=SolverMethod.EULER,
adaptive=True,
strict=False,
rtol=1e-4,
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_engine_config_round_trips_selected_fields() -> None:
def test_engine_config_allows_unknown_fields() -> None:
"""Engine config should allow unknown fields without error."""
cfg = OpEngineEngineConfig( # type: ignore[call-arg]
method="heun",
method=SolverMethod.HEUN,
adaptive=False,
some_unknown_key=123,
nested_unknown={"a": 1},
Expand All @@ -97,7 +97,7 @@ def test_engine_config_gamma_bounds_validation() -> None:
"""Engine config validates gamma bounds for imex-trbdf2 method."""
# IMEX requires operators at parse-time.
cfg = OpEngineEngineConfig(
method="imex-trbdf2",
method=SolverMethod.IMEX_TRBDF2,
gamma=0.6,
operators={"default": "sentinel"},
)
Expand All @@ -108,36 +108,36 @@ def test_engine_config_gamma_bounds_validation() -> None:
# invalid: gamma must be in (0, 1)
with pytest.raises(ValidationError):
OpEngineEngineConfig(
method="imex-trbdf2",
method=SolverMethod.IMEX_TRBDF2,
gamma=0.0,
operators={"default": "sentinel"},
)

with pytest.raises(ValidationError):
OpEngineEngineConfig(
method="imex-trbdf2",
method=SolverMethod.IMEX_TRBDF2,
gamma=1.0,
operators={"default": "sentinel"},
)

with pytest.raises(ValidationError):
OpEngineEngineConfig(
method="imex-trbdf2",
method=SolverMethod.IMEX_TRBDF2,
gamma=-0.1,
operators={"default": "sentinel"},
)

with pytest.raises(ValidationError):
OpEngineEngineConfig(
method="imex-trbdf2",
method=SolverMethod.IMEX_TRBDF2,
gamma=1.1,
operators={"default": "sentinel"},
)


def test_engine_config_imex_allows_deferred_operators() -> None:
"""IMEX methods may omit operators to defer to system options at runtime."""
cfg = OpEngineEngineConfig(method="imex-euler")
cfg = OpEngineEngineConfig(method=SolverMethod.IMEX_EULER)
run = cfg.to_run_config()
assert run.method == "imex-euler"
assert isinstance(run.operators, OperatorSpecs)
Expand All @@ -147,13 +147,13 @@ def test_engine_config_imex_allows_deferred_operators() -> None:
def test_engine_config_imex_rejects_explicitly_empty_operator_block() -> None:
"""Providing an empty operator block should raise validation errors."""
with pytest.raises(ValidationError):
OpEngineEngineConfig(method="imex-heun-tr", operators={})
OpEngineEngineConfig(method=SolverMethod.IMEX_HEUN_TR, operators={})


def test_engine_config_imex_with_operators_still_valid() -> None:
"""Providing IMEX operators explicitly should still validate."""
cfg = OpEngineEngineConfig(
method="imex-euler",
method=SolverMethod.IMEX_EULER,
operators={"default": "sentinel"},
)
run = cfg.to_run_config()
Expand Down
Loading