add lagrange params return and save#2
Conversation
Summary of ChangesHello @aadityacs, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the "mmapy" optimization library by exposing and validating the Lagrange multipliers and slack variables, which are crucial for understanding the sensitivity and optimality conditions of the optimization problem. The core change involves the introduction of a dedicated "LagrangeMultipliers" dataclass and a modification to the "update_mma" function to optionally return these dual variables. This allows users to gain deeper insights into the optimization process beyond just the primal solution. Extensive unit tests have been added to ensure the correctness and adherence to KKT conditions for these newly exposed parameters. Additionally, the project's metadata and code style configurations have been updated for improved maintainability and consistency. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a valuable feature to return Lagrange multipliers from the MMA update step, which is very useful for optimization analysis. A new LagrangeMultipliers dataclass has been added for this purpose, and the update_mma function now conditionally returns these multipliers. The accompanying tests are excellent, verifying not just the functionality but also key theoretical properties like non-negativity and complementary slackness. The codebase has also been auto-formatted for consistency, with the configuration saved in ruff.toml. My review includes a high-severity suggestion to improve the robustness of the MMAState serialization and a few medium-severity comments to remove leftover debugging statements from the tests.
| def from_array( | ||
| cls, | ||
| state_array: np.ndarray, | ||
| num_design_var: int, | ||
| ) -> 'MMAState': | ||
| cls, | ||
| state_array: np.ndarray, | ||
| num_design_var: int, | ||
| ) -> "MMAState": | ||
| """Reconstructs an `MMAState` from an array.""" | ||
| empty = MMAState.new(num_design_var) | ||
| if empty.to_array().shape != state_array.shape: | ||
| raise ValueError( | ||
| f'`state_array` shape is incompatible with `num_design_var`, got a ' | ||
| f'shape of {state_array.shape} but expected {empty.to_array().shape}' | ||
| f'when `num_design_var` is {num_design_var}.') | ||
| f"`state_array` shape is incompatible with `num_design_var`, got a " | ||
| f"shape of {state_array.shape} but expected {empty.to_array().shape}" | ||
| f"when `num_design_var` is {num_design_var}." | ||
| ) | ||
| n = num_design_var | ||
| return MMAState(x=state_array[0:n].reshape((-1, 1)), | ||
| x_old_1=state_array[n:2*n].reshape((-1, 1)), | ||
| x_old_2=state_array[2*n:3*n].reshape((-1, 1)), | ||
| low=state_array[3*n:4*n].reshape((-1, 1)), | ||
| upp=state_array[4*n:5*n].reshape((-1, 1)), | ||
| is_converged=bool(state_array[5*n]), | ||
| epoch=int(state_array[5*n+1]), | ||
| kkt_norm=state_array[5*n+2], | ||
| change_design_var=state_array[5*n+3],) | ||
| return MMAState( | ||
| x=state_array[0:n].reshape((-1, 1)), | ||
| x_old_1=state_array[n : 2 * n].reshape((-1, 1)), | ||
| x_old_2=state_array[2 * n : 3 * n].reshape((-1, 1)), | ||
| low=state_array[3 * n : 4 * n].reshape((-1, 1)), | ||
| upp=state_array[4 * n : 5 * n].reshape((-1, 1)), | ||
| is_converged=bool(state_array[5 * n]), | ||
| epoch=int(state_array[5 * n + 1]), | ||
| kkt_norm=state_array[5 * n + 2], | ||
| change_design_var=state_array[5 * n + 3], | ||
| ) | ||
|
|
||
| def to_array(self) -> np.ndarray: | ||
| """Converts the `MMAState` into a rank-1 array.""" | ||
| return np.concatenate( | ||
| [np.array(field).flatten() for field in dataclasses.astuple(self)]) | ||
| [np.array(field).flatten() for field in dataclasses.astuple(self)] | ||
| ) |
There was a problem hiding this comment.
The current implementation of to_array and from_array for serializing and deserializing MMAState is brittle. to_array uses dataclasses.astuple, which relies on the declaration order of fields. from_array then uses hardcoded slices that are implicitly dependent on this order. If a developer were to reorder or add fields to the MMAState dataclass, these methods would likely fail silently, leading to subtle bugs from incorrect state reconstruction.
To improve robustness and maintainability, I recommend making the serialization and deserialization logic explicit and independent of the dataclass field order. For example, to_array could explicitly concatenate each field, and from_array would perform the inverse operation. This would make the code safer against future refactoring.
| while not mma_state.is_converged: | ||
| objective, grad_obj = objective_fn(mma_state.x) | ||
| print(f'{mma_state.epoch} obj, {objective}') | ||
| print(f"{mma_state.epoch} obj, {objective}") |
| ) | ||
|
|
||
| print(f'epoch {mma_state.epoch} , obj {objective[0]:.2E}') | ||
| print(f"epoch {mma_state.epoch} , obj {objective[0]:.2E}") |
| ) | ||
|
|
||
| print(f'epoch {mma_state.epoch} , obj {objective[0]:.2E}') | ||
| print(f"epoch {mma_state.epoch} , obj {objective[0]:.2E}") |
No description provided.