Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ docs/_build/

# PyBuilder
.pybuilder/
target/
^target/

# Jupyter Notebook
.ipynb_checkpoints
Expand Down
6 changes: 6 additions & 0 deletions src/flepimop2/abcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
"BackendABC",
"EngineABC",
"EngineProtocol",
# "OptimizerABC",
# "OptimizerProtocol",
"ParameterABC",
"ProcessABC",
"SystemABC",
"SystemProtocol",
"TargetABC",
"TargetProtocol",
]

from flepimop2.backend.abc import BackendABC
from flepimop2.engine.abc import EngineABC, EngineProtocol
# from flepimop2.optimizer.abc import OptimizerABC, OptimizerProtocol
from flepimop2.parameter.abc import ParameterABC
from flepimop2.process.abc import ProcessABC
from flepimop2.system.abc import SystemABC, SystemProtocol
from flepimop2.target.abc import TargetABC, TargetProtocol
46 changes: 45 additions & 1 deletion src/flepimop2/engine/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Abstract class for Engines to evolve Dynamic Systems."""

__all__ = ["EngineABC", "EngineProtocol", "build"]
__all__ = ["EngineABC", "EngineProtocol", "build", "GeneratorProtocol"]

import functools

from typing import Any, Protocol, runtime_checkable

Expand Down Expand Up @@ -38,6 +40,17 @@ def __call__(
"""Protocol for engine runner functions."""
...

class GeneratorProtocol(Protocol):
"""Type-definition (Protocol) for generator functions."""

def __call__(
self,
times: Float64NDArray,
state: Float64NDArray,
params: dict[IdentifierString, Any],
) -> Float64NDArray:
"""Protocol for engine generator functions."""
...

class EngineABC(ModuleABC):
"""Abstract class for Engines to evolve Dynamic Systems."""
Expand All @@ -57,6 +70,35 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
"""
self._runner = _no_run_func

def bind(
self,
system: SystemABC,
times: Float64NDArray,
params: dict[IdentifierString, Any],
**kwargs: Any,
) -> GeneratorProtocol:
"""
Bind a System and other Engine settings to create a GeneratorProtocol.

This method uses the Engine to translate a SystemProtocol into a GeneratorProtocol.

Args:
system: A system that can be bound to generate parameters or configurations.
times: Array of time points for evaluation.
params: static parameters for the stepper.
**kwargs: any additional arguments for the engine.
"""

# bind any fixed parameters to create a targetted stepper
bound_stepper = system.bind(params)

generator = functools.partial(
func = self._runner,
stepper = bound_stepper,
)

return generator

def run(
self,
system: SystemABC,
Expand Down Expand Up @@ -86,6 +128,8 @@ def run(
**kwargs,
)



def validate_system( # noqa: PLR6301
self,
system: SystemABC, # noqa: ARG002
Expand Down
58 changes: 58 additions & 0 deletions src/flepimop2/system/abc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__ = ["SystemABC", "SystemProtocol", "build"]

import functools
import inspect
from typing import Any, Protocol, runtime_checkable

import numpy as np

from flepimop2.exceptions import ValidationIssue, Flepimop2ValidationError
from flepimop2._utils._module import _build
from flepimop2.configuration import ModuleModel
from flepimop2.module import ModuleABC
Expand Down Expand Up @@ -87,6 +89,62 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
"""
self._stepper = _no_step_function

def bind(self, params: dict[str, Any] | None = None, **kwargs: dict[str, Any]) -> SystemProtocol:
"""
Bind static parameters to the system's stepper function.

Args:
params: A dictionary of parameters to statically define for the System.
**kwargs: Additional parameters to statically define for the System.

Returns:
A stepper from this System with static parameters defined.

Raises:
ValueError: If params contains "time" or "state" keys or if value types
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flepimop2ValidationError?

are incompatible with stepper signature.

"""
if params is None:
params = {}
# Combine params and kwargs, with kwargs taking precedence
combined_params = {**params, **kwargs}
Comment thread
pearsonca marked this conversation as resolved.

forbidden_keys = {"time", "state"}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to extract out to a constant?

offered_keys = set(combined_params.keys())
validation_errors = []

# Validate that forbidden keys are not offered
if forbidden_keys.intersection(offered_keys):
msg = f"Cannot bind 'time' or 'state' keys; offered keys: {offered_keys}."
validation_errors.append(ValidationIssue(msg, "binding_values"))

# Validate that offered keys are in the stepper signature
signature_keys = set(inspect.signature(self._stepper).parameters.keys())
if invalid_keys := offered_keys - signature_keys:
msg = f"Offered keys are not in stepper signature: {invalid_keys}. Eligible system parameters are: {signature_keys - forbidden_keys}."
validation_errors.append(ValidationIssue(msg, "binding_values"))

# Validate parameter value types against signature annotations
annotations = inspect.get_annotations(self._stepper)
Comment thread
pearsonca marked this conversation as resolved.
for key, value in combined_params.items():
if key in annotations:
expected_type = annotations[key]
try:
casted_value = expected_type(value)
combined_params[key] = casted_value
except (ValueError, TypeError) as e:
msg = (
f"Parameter '{key}' (type {type(value).__name__}) could not be "
f"cast to {expected_type.__name__}. Error: {str(e)}"
)
validation_errors.append(ValidationIssue(msg, "binding_values"))
Comment on lines +132 to +141
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems fine for now, but will quickly be obsoleted if we want to support parameters with shape. This gets a lot trickier to manage without a dedicated RealizedParameter class that contains information about the values and shape. Perhaps should expedite the axes development.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we can treat validating conforming shape as a next level check, that we want to do but don't yet


if validation_errors:
raise Flepimop2ValidationError(validation_errors)

return functools.partial(self._stepper, **combined_params)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bit of a pro/con trade off returning SystemProtocol vs Self. Pro is this is more simple, but the con is that you loose the extra information that the system provides, so even with binding you still need to keep the system object around on the user side.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i anticipate we'll need to figure out returning Self, but I suspect that would also impose that as a module developer requirement vs being able to do it generically here.


def step(
self, time: np.float64, state: Float64NDArray, **params: Any
) -> Float64NDArray:
Expand Down
1 change: 0 additions & 1 deletion src/flepimop2/system/wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class WrapperSystem(ModuleModel, SystemABC):
module: Literal["flepimop2.system.wrapper"] = "flepimop2.system.wrapper"
state_change: StateChangeEnum
script: Path
options: dict[str, Any] | None = None
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goes because inherited and not actually overridden


@model_validator(mode="after")
def _validate_stepper(self) -> Self:
Expand Down
94 changes: 94 additions & 0 deletions src/flepimop2/target/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Abstract class for Optimization Targets."""

__all__ = ["TargetABC", "TargetProtocol", "build"]

import inspect
from typing import Any, Protocol, runtime_checkable

import numpy as np

from flepimop2._utils._module import _build
from flepimop2.configuration import ModuleModel
from flepimop2.module import ModuleABC
from flepimop2.typing import Float64NDArray


@runtime_checkable
class TargetProtocol(Protocol):
"""Type-definition (Protocol) for target functions."""

def __call__(
self, simulated: Float64NDArray, **kwargs: Any
) -> Float64NDArray:
"""Protocol for target functions."""
...


def _no_target_function(
simulated: Float64NDArray,
**kwargs: Any,
) -> Float64NDArray:
msg = "TargetABC::target must be provided by a concrete implementation."
raise NotImplementedError(msg)


class TargetABC(ModuleABC):
"""
Abstract class for Optimization Targets.

Attributes:
module: The module name for the target.
options: Optional dictionary of additional options the target exposes for
`flepimop2` to take advantage of.
"""

_evaluator: TargetProtocol

def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
"""
Initialize the TargetABC.

The default initialization sets the evaluator to a no-op function.
Concrete implementations should override this with a valid evaluator
function.

Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
"""
self._evaluator = _no_target_function

def evaluate(
self, simulated: Float64NDArray, **kwargs: Any
) -> Float64NDArray:
"""
Evaluate the target function.

Args:
simulated: The simulated observations.
standard: The standard comparison.
**kwargs: Additional keyword arguments for the evaluator.

Returns:
The next state array after one step.
"""
return self._evaluator(simulated, **kwargs)


def build(config: dict[str, Any] | ModuleModel) -> TargetABC:
"""
Build a `TargetABC` from a configuration dictionary.

Args:
config: Configuration dictionary or a `ModuleModel` instance.

Returns:
The constructed target instance.

"""
return _build(
config,
"target",
"flepimop2.target.wrapper",
TargetABC,
)
37 changes: 37 additions & 0 deletions src/flepimop2/target/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""A `TargetABC` which wraps a user-defined script file."""

__all__ = ["WrapperTarget"]

from pathlib import Path
from typing import Any, Literal, Self

from pydantic import model_validator

from flepimop2._utils._module import _load_module, _validate_function
from flepimop2.configuration import ModuleModel
from flepimop2.target.abc import TargetABC

class WrapperTarget(ModuleModel, TargetABC):
"""A `TargetABC` which wraps a user-defined script file."""

module: Literal["flepimop2.target.wrapper"] = "flepimop2.target.wrapper"
script: Path
options: dict[str, Any] | None = None

@model_validator(mode="after")
def _validate_evaluator(self) -> Self:
"""
Validator to load and validate the evaluator function from the script file.

Returns:
The validated `WrapperTarget` instance.

Raises:
AttributeError: If the module does not have a valid 'evaluator' function.
"""
mod = _load_module(self.script, "flepimop2.target.wrapped")
if not _validate_function(mod, "evaluator"):
msg = f"Module at {self.script} does not have a valid 'evaluator' function."
raise AttributeError(msg)
self._evaluator = mod.evaluator
return self
48 changes: 48 additions & 0 deletions tests/system/test_system_bind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Tests `SystemABC` ability to bind static parameters."""

from pathlib import Path

import pytest

from flepimop2.system.abc import SystemABC, build
from flepimop2.typing import StateChangeEnum
from flepimop2.exceptions import Flepimop2ValidationError

TEST_SCRIPT = Path(__file__).parent / "dummy_system.py"
system = build({"script": TEST_SCRIPT, "state_change": StateChangeEnum.DELTA})

@pytest.mark.parametrize("test_system", [system])
def test_set_valid_static_parameters(test_system: SystemABC):
"""Confirm no errors when setting all valid parameters."""
test_system.bind(offset = 5.0)

@pytest.mark.parametrize("test_system", [system])
def test_set_valid_static_parameters_dict_version(test_system: SystemABC):
"""Confirm no errors when setting all valid parameters."""
test_system.bind(params={"offset": 5.0})


@pytest.mark.parametrize("test_system", [system])
def test_set_static_parameter_throws_error_on_fixed_time(test_system: SystemABC):
"""Confirm error thrown when attempting to fix time parameter."""
with pytest.raises(Flepimop2ValidationError):
test_system.bind(time = 100)

@pytest.mark.parametrize("test_system", [system])
def test_set_static_parameter_throws_error_on_fixed_state(test_system: SystemABC):
"""Confirm error thrown when attempting to fix state parameter."""
with pytest.raises(Flepimop2ValidationError):
test_system.bind(state = [1.0, 2.0, 3.0])

@pytest.mark.parametrize("test_system", [system])
def test_set_nonexistent_parameter_throws_error(test_system: SystemABC):
"""Confirm error thrown when setting a parameter that does not exist."""
with pytest.raises(Flepimop2ValidationError):
test_system.bind(nonexistent_param = 5.0)


@pytest.mark.parametrize("test_system", [system])
def test_set_parameter_with_invalid_type_throws_error(test_system: SystemABC):
"""Confirm error thrown when setting parameter with invalid type."""
with pytest.raises(Flepimop2ValidationError):
test_system.bind(offset = "invalid_string")
5 changes: 3 additions & 2 deletions tests/system/test_system_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import numpy as np
import pytest

from flepimop2.typing import StateChangeEnum
from flepimop2.system.abc import build

TEST_SCRIPT = Path(__file__).with_suffix("") / "dummy_system.py"
TEST_SCRIPT = Path(__file__).parent / "dummy_system.py"


@pytest.mark.parametrize("config", [{"script": TEST_SCRIPT}])
@pytest.mark.parametrize("config", [{"script": TEST_SCRIPT, "state_change": StateChangeEnum.DELTA}])
def test_wrapper_system(config: dict[str, Any]) -> None:
"""Test `WrapperSystem` loads a script and uses its `stepper` function."""
system = build(config)
Expand Down
Loading
Loading