From be124a520f2d6199761093f4736b88cd66bf5d3c Mon Sep 17 00:00:00 2001 From: Johanna Haffner <38662446+johannahaffner@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:25:09 +0100 Subject: [PATCH 1/2] Update optimistix dependency version Fixes https://github.com/patrick-kidger/diffrax/issues/722. Updated the version of the 'optimistix' dependency from 0.0.10 to 0.1.0. I'm proposing to merge this into main and do a small bug fix release? GPU preallocation seems annoying enough for that. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 880c50c9..01b2e8f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Mathematics" ] -dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.1.0", "wadler_lindig>=0.1.1"] description = "GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX." keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning", "equinox", "neural-differential-equations", "diffrax"] license = {file = "LICENSE"} From 2bb85e08ec87c45ce3a33691cf42fa8a280ea2fd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 17 Feb 2026 12:31:50 +0100 Subject: [PATCH 2/2] Fix for Optimistix 0.1.0 --- diffrax/_root_finder/_verychord.py | 2 +- test/test_very_chord.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/diffrax/_root_finder/_verychord.py b/diffrax/_root_finder/_verychord.py index 921a6dd4..385a2cb9 100644 --- a/diffrax/_root_finder/_verychord.py +++ b/diffrax/_root_finder/_verychord.py @@ -162,7 +162,7 @@ def terminate( converged = _converged(factor, self.kappa) terminate = at_least_two & (small | diverged | converged) terminate_result = optx.RESULTS.where( - jnp.invert(small) & (diverged | jnp.invert(converged)), + at_least_two & jnp.invert(small) & (diverged | jnp.invert(converged)), optx.RESULTS.nonlinear_divergence, optx.RESULTS.successful, ) diff --git a/test/test_very_chord.py b/test/test_very_chord.py index c0d0287a..17424630 100644 --- a/test/test_very_chord.py +++ b/test/test_very_chord.py @@ -24,6 +24,9 @@ def _fn2(x, args): @jax.jit def _fn3(x, args): mlp = eqx.nn.MLP(4, 4, 256, 2, key=jr.PRNGKey(678)) + dynamic, static = eqx.partition(mlp, eqx.is_array) + dynamic = jtu.tree_map(lambda x: x * 0.1, dynamic) + mlp = eqx.combine(dynamic, static) return mlp(x) - x