Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flepimop2-op_engine/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"op-engine @ git+https://github.com/ACCIDDA/op_engine.git@main",

# flepimop2 is not in a registry, so keep it as a direct reference.
"flepimop2 @ git+https://github.com/ACCIDDA/flepimop2.git@43143a652480f3db4480884a436dba9d2ffb31d3",
"flepimop2 @ git+https://github.com/ACCIDDA/flepimop2.git@main",

"pydantic>=2.0,<3",
"numpy>=1.26",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
from flepimop2.configuration import IdentifierString, ModuleModel
from flepimop2.configuration import ModuleModel
from flepimop2.engine.abc import EngineABC
from flepimop2.exceptions import ValidationIssue
from flepimop2.typing import StateChangeEnum # noqa: TC002
from flepimop2.typing import IdentifierString, StateChangeEnum # noqa: TC002
from pydantic import Field

from op_engine.core_solver import (
Expand Down Expand Up @@ -159,7 +159,7 @@ def run(
if isinstance(system_axis, str | int):
operator_axis = system_axis

stepper: SystemProtocol = system._stepper # noqa: SLF001
stepper: SystemProtocol = system.bind()

mixing_kernels = system.option("mixing_kernels", None)
merged_params = {
Expand Down
40 changes: 37 additions & 3 deletions flepimop2-op_engine/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING
import functools
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
Expand All @@ -11,7 +12,7 @@
from flepimop2.engine.op_engine import OpEngineFlepimop2Engine

if TYPE_CHECKING:
from flepimop2.system.abc import SystemProtocol
from flepimop2.typing import IdentifierString, SystemProtocol

# -----------------------------------------------------------------------------
# Test helpers
Expand All @@ -33,7 +34,7 @@ def __call__(


class _GoodSystem(SystemABC):
"""SystemABC implementation exposing a valid stepper via _stepper."""
"""SystemABC implementation exposing a valid stepper via bind()."""

module = "flepimop2.system.test_good"
state_change = "flow"
Expand All @@ -47,6 +48,11 @@ def __init__(self) -> None:
}
}

def _bind_impl(
self, params: dict[IdentifierString, Any] | None = None
) -> SystemProtocol:
return functools.partial(self._stepper, **(params or {}))


class _DeltaSystem(_GoodSystem):
"""SystemABC implementation with incompatible state_change."""
Expand Down Expand Up @@ -154,3 +160,31 @@ def test_validate_system_checks_state_change() -> None:
issues = engine.validate_system(bad)
assert issues is not None
assert issues[0].kind == "incompatible_system"


# -----------------------------------------------------------------------------
# Bind API integration
# -----------------------------------------------------------------------------


def test_engine_uses_bind_not_stepper() -> None:
"""Engine calls system.bind() rather than accessing system._stepper."""
engine = OpEngineFlepimop2Engine(state_change="flow")
system = _GoodSystem()
bind_called = False
original_bind = system.bind

def tracking_bind(
params: dict[IdentifierString, Any] | None = None, **kwargs: object
) -> SystemProtocol:
nonlocal bind_called
bind_called = True
return original_bind(params, **kwargs)

system.bind = tracking_bind # type: ignore[assignment]

times = np.array([0.0, 0.1], dtype=np.float64)
y0 = np.array([1.0], dtype=np.float64)
engine.run(system, times, y0, {})

assert bind_called, "Engine should call system.bind()"