diff --git a/.gitignore b/.gitignore index b1e13a0..283e4d0 100644 --- a/.gitignore +++ b/.gitignore @@ -73,7 +73,7 @@ docs/_build/ # PyBuilder .pybuilder/ -target/ +^target/ # Jupyter Notebook .ipynb_checkpoints diff --git a/src/flepimop2/abcs/__init__.py b/src/flepimop2/abcs/__init__.py index 7d6b4b2..0220c32 100644 --- a/src/flepimop2/abcs/__init__.py +++ b/src/flepimop2/abcs/__init__.py @@ -11,14 +11,20 @@ "BackendABC", "EngineABC", "EngineProtocol", +# "OptimizerABC", +# "OptimizerProtocol", "ParameterABC", "ProcessABC", "SystemABC", "SystemProtocol", + "TargetABC", + "TargetProtocol", ] from flepimop2.backend.abc import BackendABC from flepimop2.engine.abc import EngineABC, EngineProtocol +# from flepimop2.optimizer.abc import OptimizerABC, OptimizerProtocol from flepimop2.parameter.abc import ParameterABC from flepimop2.process.abc import ProcessABC from flepimop2.system.abc import SystemABC, SystemProtocol +from flepimop2.target.abc import TargetABC, TargetProtocol diff --git a/src/flepimop2/engine/abc/__init__.py b/src/flepimop2/engine/abc/__init__.py index dec9567..e4b7264 100644 --- a/src/flepimop2/engine/abc/__init__.py +++ b/src/flepimop2/engine/abc/__init__.py @@ -1,6 +1,8 @@ """Abstract class for Engines to evolve Dynamic Systems.""" -__all__ = ["EngineABC", "EngineProtocol", "build"] +__all__ = ["EngineABC", "EngineProtocol", "build", "GeneratorProtocol"] + +import functools from typing import Any, Protocol, runtime_checkable @@ -38,6 +40,17 @@ def __call__( """Protocol for engine runner functions.""" ... +class GeneratorProtocol(Protocol): + """Type-definition (Protocol) for generator functions.""" + + def __call__( + self, + times: Float64NDArray, + state: Float64NDArray, + params: dict[IdentifierString, Any], + ) -> Float64NDArray: + """Protocol for engine generator functions.""" + ... class EngineABC(ModuleABC): """Abstract class for Engines to evolve Dynamic Systems.""" @@ -57,6 +70,35 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 """ self._runner = _no_run_func + def bind( + self, + system: SystemABC, + times: Float64NDArray, + params: dict[IdentifierString, Any], + **kwargs: Any, + ) -> GeneratorProtocol: + """ + Bind a System and other Engine settings to create a GeneratorProtocol. + + This method uses the Engine to translate a SystemProtocol into a GeneratorProtocol. + + Args: + system: A system that can be bound to generate parameters or configurations. + times: Array of time points for evaluation. + params: static parameters for the stepper. + **kwargs: any additional arguments for the engine. + """ + + # bind any fixed parameters to create a targetted stepper + bound_stepper = system.bind(params) + + generator = functools.partial( + func = self._runner, + stepper = bound_stepper, + ) + + return generator + def run( self, system: SystemABC, @@ -86,6 +128,8 @@ def run( **kwargs, ) + + def validate_system( # noqa: PLR6301 self, system: SystemABC, # noqa: ARG002 diff --git a/src/flepimop2/system/abc/__init__.py b/src/flepimop2/system/abc/__init__.py index c1473db..a08250f 100644 --- a/src/flepimop2/system/abc/__init__.py +++ b/src/flepimop2/system/abc/__init__.py @@ -2,11 +2,13 @@ __all__ = ["SystemABC", "SystemProtocol", "build"] +import functools import inspect from typing import Any, Protocol, runtime_checkable import numpy as np +from flepimop2.exceptions import ValidationIssue, Flepimop2ValidationError from flepimop2._utils._module import _build from flepimop2.configuration import ModuleModel from flepimop2.module import ModuleABC @@ -87,6 +89,62 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 """ self._stepper = _no_step_function + def bind(self, params: dict[str, Any] | None = None, **kwargs: dict[str, Any]) -> SystemProtocol: + """ + Bind static parameters to the system's stepper function. + + Args: + params: A dictionary of parameters to statically define for the System. + **kwargs: Additional parameters to statically define for the System. + + Returns: + A stepper from this System with static parameters defined. + + Raises: + ValueError: If params contains "time" or "state" keys or if value types + are incompatible with stepper signature. + + """ + if params is None: + params = {} + # Combine params and kwargs, with kwargs taking precedence + combined_params = {**params, **kwargs} + + forbidden_keys = {"time", "state"} + offered_keys = set(combined_params.keys()) + validation_errors = [] + + # Validate that forbidden keys are not offered + if forbidden_keys.intersection(offered_keys): + msg = f"Cannot bind 'time' or 'state' keys; offered keys: {offered_keys}." + validation_errors.append(ValidationIssue(msg, "binding_values")) + + # Validate that offered keys are in the stepper signature + signature_keys = set(inspect.signature(self._stepper).parameters.keys()) + if invalid_keys := offered_keys - signature_keys: + msg = f"Offered keys are not in stepper signature: {invalid_keys}. Eligible system parameters are: {signature_keys - forbidden_keys}." + validation_errors.append(ValidationIssue(msg, "binding_values")) + + # Validate parameter value types against signature annotations + annotations = inspect.get_annotations(self._stepper) + for key, value in combined_params.items(): + if key in annotations: + expected_type = annotations[key] + try: + casted_value = expected_type(value) + combined_params[key] = casted_value + except (ValueError, TypeError) as e: + msg = ( + f"Parameter '{key}' (type {type(value).__name__}) could not be " + f"cast to {expected_type.__name__}. Error: {str(e)}" + ) + validation_errors.append(ValidationIssue(msg, "binding_values")) + + if validation_errors: + raise Flepimop2ValidationError(validation_errors) + + return functools.partial(self._stepper, **combined_params) + def step( self, time: np.float64, state: Float64NDArray, **params: Any ) -> Float64NDArray: diff --git a/src/flepimop2/system/wrapper/__init__.py b/src/flepimop2/system/wrapper/__init__.py index 098001f..a0b5a25 100644 --- a/src/flepimop2/system/wrapper/__init__.py +++ b/src/flepimop2/system/wrapper/__init__.py @@ -19,7 +19,6 @@ class WrapperSystem(ModuleModel, SystemABC): module: Literal["flepimop2.system.wrapper"] = "flepimop2.system.wrapper" state_change: StateChangeEnum script: Path - options: dict[str, Any] | None = None @model_validator(mode="after") def _validate_stepper(self) -> Self: diff --git a/src/flepimop2/target/abc/__init__.py b/src/flepimop2/target/abc/__init__.py new file mode 100644 index 0000000..4c3b195 --- /dev/null +++ b/src/flepimop2/target/abc/__init__.py @@ -0,0 +1,94 @@ +"""Abstract class for Optimization Targets.""" + +__all__ = ["TargetABC", "TargetProtocol", "build"] + +import inspect +from typing import Any, Protocol, runtime_checkable + +import numpy as np + +from flepimop2._utils._module import _build +from flepimop2.configuration import ModuleModel +from flepimop2.module import ModuleABC +from flepimop2.typing import Float64NDArray + + +@runtime_checkable +class TargetProtocol(Protocol): + """Type-definition (Protocol) for target functions.""" + + def __call__( + self, simulated: Float64NDArray, **kwargs: Any + ) -> Float64NDArray: + """Protocol for target functions.""" + ... + + +def _no_target_function( + simulated: Float64NDArray, + **kwargs: Any, +) -> Float64NDArray: + msg = "TargetABC::target must be provided by a concrete implementation." + raise NotImplementedError(msg) + + +class TargetABC(ModuleABC): + """ + Abstract class for Optimization Targets. + + Attributes: + module: The module name for the target. + options: Optional dictionary of additional options the target exposes for + `flepimop2` to take advantage of. + """ + + _evaluator: TargetProtocol + + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + """ + Initialize the TargetABC. + + The default initialization sets the evaluator to a no-op function. + Concrete implementations should override this with a valid evaluator + function. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + """ + self._evaluator = _no_target_function + + def evaluate( + self, simulated: Float64NDArray, **kwargs: Any + ) -> Float64NDArray: + """ + Evaluate the target function. + + Args: + simulated: The simulated observations. + standard: The standard comparison. + **kwargs: Additional keyword arguments for the evaluator. + + Returns: + The next state array after one step. + """ + return self._evaluator(simulated, **kwargs) + + +def build(config: dict[str, Any] | ModuleModel) -> TargetABC: + """ + Build a `TargetABC` from a configuration dictionary. + + Args: + config: Configuration dictionary or a `ModuleModel` instance. + + Returns: + The constructed target instance. + + """ + return _build( + config, + "target", + "flepimop2.target.wrapper", + TargetABC, + ) diff --git a/src/flepimop2/target/wrapper/__init__.py b/src/flepimop2/target/wrapper/__init__.py new file mode 100644 index 0000000..c70c37c --- /dev/null +++ b/src/flepimop2/target/wrapper/__init__.py @@ -0,0 +1,37 @@ +"""A `TargetABC` which wraps a user-defined script file.""" + +__all__ = ["WrapperTarget"] + +from pathlib import Path +from typing import Any, Literal, Self + +from pydantic import model_validator + +from flepimop2._utils._module import _load_module, _validate_function +from flepimop2.configuration import ModuleModel +from flepimop2.target.abc import TargetABC + +class WrapperTarget(ModuleModel, TargetABC): + """A `TargetABC` which wraps a user-defined script file.""" + + module: Literal["flepimop2.target.wrapper"] = "flepimop2.target.wrapper" + script: Path + options: dict[str, Any] | None = None + + @model_validator(mode="after") + def _validate_evaluator(self) -> Self: + """ + Validator to load and validate the evaluator function from the script file. + + Returns: + The validated `WrapperTarget` instance. + + Raises: + AttributeError: If the module does not have a valid 'evaluator' function. + """ + mod = _load_module(self.script, "flepimop2.target.wrapped") + if not _validate_function(mod, "evaluator"): + msg = f"Module at {self.script} does not have a valid 'evaluator' function." + raise AttributeError(msg) + self._evaluator = mod.evaluator + return self diff --git a/tests/system/test_system_wrapper/dummy_system.py b/tests/system/dummy_system.py similarity index 100% rename from tests/system/test_system_wrapper/dummy_system.py rename to tests/system/dummy_system.py diff --git a/tests/system/test_system_bind.py b/tests/system/test_system_bind.py new file mode 100644 index 0000000..fc6de69 --- /dev/null +++ b/tests/system/test_system_bind.py @@ -0,0 +1,48 @@ +"""Tests `SystemABC` ability to bind static parameters.""" + +from pathlib import Path + +import pytest + +from flepimop2.system.abc import SystemABC, build +from flepimop2.typing import StateChangeEnum +from flepimop2.exceptions import Flepimop2ValidationError + +TEST_SCRIPT = Path(__file__).parent / "dummy_system.py" +system = build({"script": TEST_SCRIPT, "state_change": StateChangeEnum.DELTA}) + +@pytest.mark.parametrize("test_system", [system]) +def test_set_valid_static_parameters(test_system: SystemABC): + """Confirm no errors when setting all valid parameters.""" + test_system.bind(offset = 5.0) + +@pytest.mark.parametrize("test_system", [system]) +def test_set_valid_static_parameters_dict_version(test_system: SystemABC): + """Confirm no errors when setting all valid parameters.""" + test_system.bind(params={"offset": 5.0}) + + +@pytest.mark.parametrize("test_system", [system]) +def test_set_static_parameter_throws_error_on_fixed_time(test_system: SystemABC): + """Confirm error thrown when attempting to fix time parameter.""" + with pytest.raises(Flepimop2ValidationError): + test_system.bind(time = 100) + +@pytest.mark.parametrize("test_system", [system]) +def test_set_static_parameter_throws_error_on_fixed_state(test_system: SystemABC): + """Confirm error thrown when attempting to fix state parameter.""" + with pytest.raises(Flepimop2ValidationError): + test_system.bind(state = [1.0, 2.0, 3.0]) + +@pytest.mark.parametrize("test_system", [system]) +def test_set_nonexistent_parameter_throws_error(test_system: SystemABC): + """Confirm error thrown when setting a parameter that does not exist.""" + with pytest.raises(Flepimop2ValidationError): + test_system.bind(nonexistent_param = 5.0) + + +@pytest.mark.parametrize("test_system", [system]) +def test_set_parameter_with_invalid_type_throws_error(test_system: SystemABC): + """Confirm error thrown when setting parameter with invalid type.""" + with pytest.raises(Flepimop2ValidationError): + test_system.bind(offset = "invalid_string") \ No newline at end of file diff --git a/tests/system/test_system_wrapper.py b/tests/system/test_system_wrapper.py index a7c896b..b3d7fd8 100644 --- a/tests/system/test_system_wrapper.py +++ b/tests/system/test_system_wrapper.py @@ -6,12 +6,13 @@ import numpy as np import pytest +from flepimop2.typing import StateChangeEnum from flepimop2.system.abc import build -TEST_SCRIPT = Path(__file__).with_suffix("") / "dummy_system.py" +TEST_SCRIPT = Path(__file__).parent / "dummy_system.py" -@pytest.mark.parametrize("config", [{"script": TEST_SCRIPT}]) +@pytest.mark.parametrize("config", [{"script": TEST_SCRIPT, "state_change": StateChangeEnum.DELTA}]) def test_wrapper_system(config: dict[str, Any]) -> None: """Test `WrapperSystem` loads a script and uses its `stepper` function.""" system = build(config) diff --git a/tests/target/test_target_abc.py b/tests/target/test_target_abc.py new file mode 100644 index 0000000..5dd46ba --- /dev/null +++ b/tests/target/test_target_abc.py @@ -0,0 +1,22 @@ +"""Tests for `TargetABC` and default `WrapperTarget`.""" + +import numpy as np +import pytest + +from flepimop2.target.abc import TargetABC + + +class DummyTarget(TargetABC): + """A dummy target for testing purposes.""" + + module = "dummy" + + +@pytest.mark.parametrize("target", [DummyTarget()]) +def test_abstraction_error(target: TargetABC) -> None: + """Test default evaluator raises `NotImplementedError` when not overridden.""" + with pytest.raises(NotImplementedError): + target.evaluate( + np.array([1.0, 2.0, 3.0], dtype=np.float64), + standard=np.array([1.0, 2.0, 3.0], dtype=np.float64) + ) diff --git a/tests/target/test_target_wrapper.py b/tests/target/test_target_wrapper.py new file mode 100644 index 0000000..5739e38 --- /dev/null +++ b/tests/target/test_target_wrapper.py @@ -0,0 +1,23 @@ +"""Tests for `TargetABC` and default `WrapperTarget`.""" + +from pathlib import Path +from typing import Any + +import numpy as np +import pytest + +from flepimop2.system.abc import build + +TEST_SCRIPT = Path(__file__).with_suffix("") / "dummy_target.py" + + +@pytest.mark.parametrize("config", [{"script": TEST_SCRIPT}]) +def test_wrapper_target(config: dict[str, Any]) -> None: + """Test `WrapperTarget` loads a script and uses its `evaluator` function.""" + target = build(config) + result = target.evaluate( + np.array([1.0, 2.0, 3.0], dtype=np.float64), + standard=np.array([2.0, 3.0, 4.0], dtype=np.float64) + ) + expected = np.array([1.0], dtype=np.float64) # RMSE of off-by [1,1,1] = 1.0 + np.testing.assert_array_equal(result, expected) diff --git a/tests/target/test_target_wrapper/dummy_target.py b/tests/target/test_target_wrapper/dummy_target.py new file mode 100644 index 0000000..75edc90 --- /dev/null +++ b/tests/target/test_target_wrapper/dummy_target.py @@ -0,0 +1,21 @@ +"""A dummy stepper function for testing `WrapperSystem`.""" + +import numpy as np + +from flepimop2.typing import Float64NDArray + + +def evaluator(simulated: Float64NDArray, standard: Float64NDArray) -> Float64NDArray: + """ + A dummy evaluator function for testing purposes: RMSE(simulated, standard). + + Args: + simulated: The simulated state array. + standard: The standard state array. + state: The current state array. + offset: An offset value to be added to the state. + + Returns: + The updated state array after applying the evaluator logic. + """ + return np.sqrt(np.mean((simulated - standard) ** 2))