Anderson acceleration for fixed point iteration #217
aidancrilly wants to merge 23 commits into patrick-kidger:main
Conversation
* first pass at refactoring cauchy_point function, so far untested
* add a pretty drawing
* minor tweaks
* refactored cauchy point finding function
* add clarifying comment
* adding test cases
* bugfix: pick correct next intercept in the presence of infinite values
* limit step length to a full gradient step
* add expected results to cauchy point test cases
* add clarifying comment: Hessian operator is assumed to be positive definite.
* add explanations for test cases

Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>

* Drop support for Python 3.10
* Update readme.md and index.md
* CI: tests on 3.11 and 3.13

* Implement pytest-benchmark based setup for systematic performance evaluation of Optimistix' solvers.
* version bump for sif2jax requirements
* add semi-recent matplotlib version to specify a minimum
* no more monkeypatching
* set EQX_ON_ERROR with os.environ
* give a reason for skipping compilation tests
* Add L-BFGS solvers to benchmark suite
* clarify what --benchmark-autosave will do.
* remove strict dtype promotion rules - benchmarks are not tests, so we don't need them here. We would otherwise have to use context management for any comparison to Optax minimisers.
* state purpose of --scipy flag more clearly.
* improve contribution guidelines, inline decorator, specify pyright errata
* pyproject.toml from main
* add sif2jax
* move benchmark dependencies to tests group
* add benchmark-skip option
* add example to contributing guidelines, document OrderedDict workaround, adapt to sif2jax usage of properties.

Co-authored-by: Johanna Haffner <johanna.haffner@bsse.ethz.ch>
This has been a subtle, long-standing bug! It mainly manifested as BestSoFar seeming squiffy. Fixes patrick-kidger#33.
Remarkably, this bug seems to have existed since Optimistix day 1: if a solver returns a non-successful result halfway through the solve, it does not cause termination. Instead, the solve keeps running, and the result is only checked once it has completed.
Just a note that I now use this on a problem with ~1000 parameters, which previously required aggressive damping with fixed point iteration (damp = 0.95) and 100s of iterations; with Anderson acceleration this problem runs undamped and converges in 10s of iterations!
patrick-kidger left a comment
Okay, this looks really clean to me! I have just some minor nits.
Can you also add this to the documentation?
optimistix/_solver/fixed_point.py (Outdated)

```python
# First iterate must be treated differently
def _first_iterate():
    new_y = jtu.tree_map(
        lambda y_old, y_next: self.damp * y_old + (1 - self.damp) * y_next,
```
Nit: if isinstance(self.damp, (int, float)) and self.damp == 0 – which is the fairly common default, and is something known at compile time – then we can skip this operation and probably save some effort.
This is the kind of thing that XLA will typically not compile out because in principle 0 * nan = nan, so it can't apply the 'mathematical' optimization that 0 * x = 0.
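The suggested check can be sketched in plain Python (this is a minimal illustration of the idea, not the actual Optimistix code; `damped_update` is a hypothetical helper name):

```python
import math

def damped_update(y_old, y_next, damp):
    # If damp is a static Python number equal to zero, skip the blend
    # entirely. The compiler cannot fold `0 * y_old + 1 * y_next` away
    # on its own, because 0 * nan = nan in IEEE arithmetic, so the
    # "mathematical" simplification 0 * x = 0 is not valid in general.
    if isinstance(damp, (int, float)) and damp == 0:
        return y_next
    return damp * y_old + (1 - damp) * y_next

print(damped_update(math.nan, 2.0, 0))  # 2.0: the NaN never touches the result
print(0 * math.nan + 1 * 2.0)           # nan: why the shortcut matters
```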
Ah interesting! I did not know it might not be compiled out; I've made the change here (as well as in the damped fixed point iteration scheme).
@patrick-kidger thanks, I believe all the nits are now resolved and the docs updated.
Continuation of the closed #215; apologies if this is not the right way to re-open it.
I have implemented a form of Anderson acceleration for fixed point problems, which should improve the rate of convergence over standard fixed point iteration. Mentioned in #3.
Here is a simple example for cos(x) = x:

A few queries: