diff --git a/flepimop2-op_engine/pyproject.toml b/flepimop2-op_engine/pyproject.toml index 7a8613b..c62eac4 100644 --- a/flepimop2-op_engine/pyproject.toml +++ b/flepimop2-op_engine/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "op-engine @ git+https://github.com/ACCIDDA/op_engine.git@main", # flepimop2 is not in a registry, so keep it as a direct reference. - "flepimop2 @ git+https://github.com/ACCIDDA/flepimop2.git@43143a652480f3db4480884a436dba9d2ffb31d3", + "flepimop2 @ git+https://github.com/ACCIDDA/flepimop2.git@main", "pydantic>=2.0,<3", "numpy>=1.26", diff --git a/flepimop2-op_engine/src/flepimop2/engine/op_engine/__init__.py b/flepimop2-op_engine/src/flepimop2/engine/op_engine/__init__.py index 1e4eda2..f7817b4 100644 --- a/flepimop2-op_engine/src/flepimop2/engine/op_engine/__init__.py +++ b/flepimop2-op_engine/src/flepimop2/engine/op_engine/__init__.py @@ -6,10 +6,10 @@ from typing import TYPE_CHECKING, Literal import numpy as np -from flepimop2.configuration import IdentifierString, ModuleModel +from flepimop2.configuration import ModuleModel from flepimop2.engine.abc import EngineABC from flepimop2.exceptions import ValidationIssue -from flepimop2.typing import StateChangeEnum # noqa: TC002 +from flepimop2.typing import IdentifierString, StateChangeEnum # noqa: TC002 from pydantic import Field from op_engine.core_solver import ( @@ -159,7 +159,7 @@ def run( if isinstance(system_axis, str | int): operator_axis = system_axis - stepper: SystemProtocol = system._stepper # noqa: SLF001 + stepper: SystemProtocol = system.bind() mixing_kernels = system.option("mixing_kernels", None) merged_params = { diff --git a/flepimop2-op_engine/tests/test_engine.py b/flepimop2-op_engine/tests/test_engine.py index dc24dbd..ddb3a48 100644 --- a/flepimop2-op_engine/tests/test_engine.py +++ b/flepimop2-op_engine/tests/test_engine.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import functools +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -11,7 +12,7 @@ from flepimop2.engine.op_engine import OpEngineFlepimop2Engine if TYPE_CHECKING: - from flepimop2.system.abc import SystemProtocol + from flepimop2.typing import IdentifierString, SystemProtocol # ----------------------------------------------------------------------------- # Test helpers @@ -33,7 +34,7 @@ def __call__( class _GoodSystem(SystemABC): - """SystemABC implementation exposing a valid stepper via _stepper.""" + """SystemABC implementation exposing a valid stepper via bind().""" module = "flepimop2.system.test_good" state_change = "flow" @@ -47,6 +48,11 @@ def __init__(self) -> None: } } + def _bind_impl( + self, params: dict[IdentifierString, Any] | None = None + ) -> SystemProtocol: + return functools.partial(self._stepper, **(params or {})) + class _DeltaSystem(_GoodSystem): """SystemABC implementation with incompatible state_change.""" @@ -154,3 +160,31 @@ def test_validate_system_checks_state_change() -> None: issues = engine.validate_system(bad) assert issues is not None assert issues[0].kind == "incompatible_system" + + +# ----------------------------------------------------------------------------- +# Bind API integration +# ----------------------------------------------------------------------------- + + +def test_engine_uses_bind_not_stepper() -> None: + """Engine calls system.bind() rather than accessing system._stepper.""" + engine = OpEngineFlepimop2Engine(state_change="flow") + system = _GoodSystem() + bind_called = False + original_bind = system.bind + + def tracking_bind( + params: dict[IdentifierString, Any] | None = None, **kwargs: object + ) -> SystemProtocol: + nonlocal bind_called + bind_called = True + return original_bind(params, **kwargs) + + system.bind = tracking_bind # type: ignore[assignment] + + times = np.array([0.0, 0.1], dtype=np.float64) + y0 = np.array([1.0], dtype=np.float64) + engine.run(system, times, y0, {}) + + assert bind_called, "Engine should call system.bind()"