-
Notifications
You must be signed in to change notification settings - Fork 1
Definite Targets / Optimizers #175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,7 +73,7 @@ docs/_build/ | |
|
|
||
| # PyBuilder | ||
| .pybuilder/ | ||
| target/ | ||
| ^target/ | ||
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,11 +2,13 @@ | |
|
|
||
| __all__ = ["SystemABC", "SystemProtocol", "build"] | ||
|
|
||
| import functools | ||
| import inspect | ||
| from typing import Any, Protocol, runtime_checkable | ||
|
|
||
| import numpy as np | ||
|
|
||
| from flepimop2.exceptions import ValidationIssue, Flepimop2ValidationError | ||
| from flepimop2._utils._module import _build | ||
| from flepimop2.configuration import ModuleModel | ||
| from flepimop2.module import ModuleABC | ||
|
|
@@ -87,6 +89,62 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 | |
| """ | ||
| self._stepper = _no_step_function | ||
|
|
||
| def bind(self, params: dict[str, Any] | None = None, **kwargs: dict[str, Any]) -> SystemProtocol: | ||
| """ | ||
| Bind static parameters to the system's stepper function. | ||
|
|
||
| Args: | ||
| params: A dictionary of parameters to statically define for the System. | ||
| **kwargs: Additional parameters to statically define for the System. | ||
|
|
||
| Returns: | ||
| A stepper from this System with static parameters defined. | ||
|
|
||
| Raises: | ||
| ValueError: If params contains "time" or "state" keys or if value types | ||
| are incompatible with stepper signature. | ||
|
|
||
| """ | ||
| if params is None: | ||
| params = {} | ||
| # Combine params and kwargs, with kwargs taking precedence | ||
| combined_params = {**params, **kwargs} | ||
|
pearsonca marked this conversation as resolved.
|
||
|
|
||
| forbidden_keys = {"time", "state"} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be good to extract out to a constant? |
||
| offered_keys = set(combined_params.keys()) | ||
| validation_errors = [] | ||
|
|
||
| # Validate that forbidden keys are not offered | ||
| if forbidden_keys.intersection(offered_keys): | ||
| msg = f"Cannot bind 'time' or 'state' keys; offered keys: {offered_keys}." | ||
| validation_errors.append(ValidationIssue(msg, "binding_values")) | ||
|
|
||
| # Validate that offered keys are in the stepper signature | ||
| signature_keys = set(inspect.signature(self._stepper).parameters.keys()) | ||
| if invalid_keys := offered_keys - signature_keys: | ||
| msg = f"Offered keys are not in stepper signature: {invalid_keys}. Eligible system parameters are: {signature_keys - forbidden_keys}." | ||
| validation_errors.append(ValidationIssue(msg, "binding_values")) | ||
|
|
||
| # Validate parameter value types against signature annotations | ||
| annotations = inspect.get_annotations(self._stepper) | ||
|
pearsonca marked this conversation as resolved.
|
||
| for key, value in combined_params.items(): | ||
| if key in annotations: | ||
| expected_type = annotations[key] | ||
| try: | ||
| casted_value = expected_type(value) | ||
| combined_params[key] = casted_value | ||
| except (ValueError, TypeError) as e: | ||
| msg = ( | ||
| f"Parameter '{key}' (type {type(value).__name__}) could not be " | ||
| f"cast to {expected_type.__name__}. Error: {str(e)}" | ||
| ) | ||
| validation_errors.append(ValidationIssue(msg, "binding_values")) | ||
|
Comment on lines
+132
to
+141
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems fine for now, but will quickly be obsoleted if we want to support parameters with shape. This gets a lot trickier to manage without a dedicated
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think we can treat validating conforming shape as a next level check, that we want to do but don't yet |
||
|
|
||
| if validation_errors: | ||
| raise Flepimop2ValidationError(validation_errors) | ||
|
|
||
| return functools.partial(self._stepper, **combined_params) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a bit of a pro/con trade off returning
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i anticipate we'll need to figure out returning Self, but I suspect that would also impose that as a module developer requirement vs being able to do it generically here. |
||
|
|
||
| def step( | ||
| self, time: np.float64, state: Float64NDArray, **params: Any | ||
| ) -> Float64NDArray: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,6 @@ class WrapperSystem(ModuleModel, SystemABC): | |
| module: Literal["flepimop2.system.wrapper"] = "flepimop2.system.wrapper" | ||
| state_change: StateChangeEnum | ||
| script: Path | ||
| options: dict[str, Any] | None = None | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. goes because inherited and not actually overridden |
||
|
|
||
| @model_validator(mode="after") | ||
| def _validate_stepper(self) -> Self: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| """Abstract class for Optimization Targets.""" | ||
|
|
||
| __all__ = ["TargetABC", "TargetProtocol", "build"] | ||
|
|
||
| import inspect | ||
| from typing import Any, Protocol, runtime_checkable | ||
|
|
||
| import numpy as np | ||
|
|
||
| from flepimop2._utils._module import _build | ||
| from flepimop2.configuration import ModuleModel | ||
| from flepimop2.module import ModuleABC | ||
| from flepimop2.typing import Float64NDArray | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| class TargetProtocol(Protocol): | ||
| """Type-definition (Protocol) for target functions.""" | ||
|
|
||
| def __call__( | ||
| self, simulated: Float64NDArray, **kwargs: Any | ||
| ) -> Float64NDArray: | ||
| """Protocol for target functions.""" | ||
| ... | ||
|
|
||
|
|
||
| def _no_target_function( | ||
| simulated: Float64NDArray, | ||
| **kwargs: Any, | ||
| ) -> Float64NDArray: | ||
| msg = "TargetABC::target must be provided by a concrete implementation." | ||
| raise NotImplementedError(msg) | ||
|
|
||
|
|
||
| class TargetABC(ModuleABC): | ||
| """ | ||
| Abstract class for Optimization Targets. | ||
|
|
||
| Attributes: | ||
| module: The module name for the target. | ||
| options: Optional dictionary of additional options the target exposes for | ||
| `flepimop2` to take advantage of. | ||
| """ | ||
|
|
||
| _evaluator: TargetProtocol | ||
|
|
||
| def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 | ||
| """ | ||
| Initialize the TargetABC. | ||
|
|
||
| The default initialization sets the evaluator to a no-op function. | ||
| Concrete implementations should override this with a valid evaluator | ||
| function. | ||
|
|
||
| Args: | ||
| *args: Positional arguments. | ||
| **kwargs: Keyword arguments. | ||
| """ | ||
| self._evaluator = _no_target_function | ||
|
|
||
| def evaluate( | ||
| self, simulated: Float64NDArray, **kwargs: Any | ||
| ) -> Float64NDArray: | ||
| """ | ||
| Evaluate the target function. | ||
|
|
||
| Args: | ||
| simulated: The simulated observations. | ||
| standard: The standard comparison. | ||
| **kwargs: Additional keyword arguments for the evaluator. | ||
|
|
||
| Returns: | ||
| The next state array after one step. | ||
| """ | ||
| return self._evaluator(simulated, **kwargs) | ||
|
|
||
|
|
||
| def build(config: dict[str, Any] | ModuleModel) -> TargetABC: | ||
| """ | ||
| Build a `TargetABC` from a configuration dictionary. | ||
|
|
||
| Args: | ||
| config: Configuration dictionary or a `ModuleModel` instance. | ||
|
|
||
| Returns: | ||
| The constructed target instance. | ||
|
|
||
| """ | ||
| return _build( | ||
| config, | ||
| "target", | ||
| "flepimop2.target.wrapper", | ||
| TargetABC, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| """A `TargetABC` which wraps a user-defined script file.""" | ||
|
|
||
| __all__ = ["WrapperTarget"] | ||
|
|
||
| from pathlib import Path | ||
| from typing import Any, Literal, Self | ||
|
|
||
| from pydantic import model_validator | ||
|
|
||
| from flepimop2._utils._module import _load_module, _validate_function | ||
| from flepimop2.configuration import ModuleModel | ||
| from flepimop2.target.abc import TargetABC | ||
|
|
||
| class WrapperTarget(ModuleModel, TargetABC): | ||
| """A `TargetABC` which wraps a user-defined script file.""" | ||
|
|
||
| module: Literal["flepimop2.target.wrapper"] = "flepimop2.target.wrapper" | ||
| script: Path | ||
| options: dict[str, Any] | None = None | ||
|
|
||
| @model_validator(mode="after") | ||
| def _validate_evaluator(self) -> Self: | ||
| """ | ||
| Validator to load and validate the evaluator function from the script file. | ||
|
|
||
| Returns: | ||
| The validated `WrapperTarget` instance. | ||
|
|
||
| Raises: | ||
| AttributeError: If the module does not have a valid 'evaluator' function. | ||
| """ | ||
| mod = _load_module(self.script, "flepimop2.target.wrapped") | ||
| if not _validate_function(mod, "evaluator"): | ||
| msg = f"Module at {self.script} does not have a valid 'evaluator' function." | ||
| raise AttributeError(msg) | ||
| self._evaluator = mod.evaluator | ||
| return self |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| """Tests `SystemABC` ability to bind static parameters.""" | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| from flepimop2.system.abc import SystemABC, build | ||
| from flepimop2.typing import StateChangeEnum | ||
| from flepimop2.exceptions import Flepimop2ValidationError | ||
|
|
||
| TEST_SCRIPT = Path(__file__).parent / "dummy_system.py" | ||
| system = build({"script": TEST_SCRIPT, "state_change": StateChangeEnum.DELTA}) | ||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_valid_static_parameters(test_system: SystemABC): | ||
| """Confirm no errors when setting all valid parameters.""" | ||
| test_system.bind(offset = 5.0) | ||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_valid_static_parameters_dict_version(test_system: SystemABC): | ||
| """Confirm no errors when setting all valid parameters.""" | ||
| test_system.bind(params={"offset": 5.0}) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_static_parameter_throws_error_on_fixed_time(test_system: SystemABC): | ||
| """Confirm error thrown when attempting to fix time parameter.""" | ||
| with pytest.raises(Flepimop2ValidationError): | ||
| test_system.bind(time = 100) | ||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_static_parameter_throws_error_on_fixed_state(test_system: SystemABC): | ||
| """Confirm error thrown when attempting to fix state parameter.""" | ||
| with pytest.raises(Flepimop2ValidationError): | ||
| test_system.bind(state = [1.0, 2.0, 3.0]) | ||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_nonexistent_parameter_throws_error(test_system: SystemABC): | ||
| """Confirm error thrown when setting a parameter that does not exist.""" | ||
| with pytest.raises(Flepimop2ValidationError): | ||
| test_system.bind(nonexistent_param = 5.0) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("test_system", [system]) | ||
| def test_set_parameter_with_invalid_type_throws_error(test_system: SystemABC): | ||
| """Confirm error thrown when setting parameter with invalid type.""" | ||
| with pytest.raises(Flepimop2ValidationError): | ||
| test_system.bind(offset = "invalid_string") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flepimop2ValidationError?