Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Alexander2 as Alexander2,
ALIGN as ALIGN,
Bosh3 as Bosh3,
ButcherTableau as ButcherTableau,
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .alexander2 import Alexander2 as Alexander2
from .align import ALIGN as ALIGN
from .base import (
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
Expand Down
68 changes: 68 additions & 0 deletions diffrax/_solver/alexander2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from collections.abc import Callable
from typing import ClassVar

import equinox.internal as eqxi
import numpy as np
import optimistix as optx

from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
from .._root_finder import VeryChord, with_stepsize_controller_tols
from .runge_kutta import AbstractSDIRK, ButcherTableau


gamma = 1 - 0.5 * np.sqrt(2)

_alexander2_tableau = ButcherTableau(
a_lower=(np.array([1 - gamma]),),
b_sol=np.array([1 - gamma, gamma]),
b_error=np.array([1 - gamma, gamma - 1]),
c=np.array([1.0]),
a_diagonal=np.array([gamma, gamma]),
a_predictor=(np.array([1.0]),),
)


class Alexander2(AbstractSDIRK):
r"""Alexander's 2/1 method.

A-L-stable stiffly accurate 2nd order SDIRK method. Has an embedded 1st
order method for adaptive step sizing. Uses 2 stages. Uses 3rd order
Hermite interpolation for dense/ts output.

??? cite "Reference"

```bibtex
@article{alexander1977diagonally,
title={Diagonally Implicit Runge--Kutta Methods for Stiff O.D.E.'s},
author={Alexander, Roger},
year={1977},
journal={SIAM Journal on Numerical Analysis},
volume={14},
number={6},
pages = {1006--1021}
}
```
"""

tableau: ClassVar[ButcherTableau] = _alexander2_tableau
interpolation_cls: ClassVar[
Callable[..., ThirdOrderHermitePolynomialInterpolation]
] = ThirdOrderHermitePolynomialInterpolation.from_k

root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)()
root_find_max_steps: int = 10

def order(self, terms):
del terms
return 2


eqxi.doc_remove_args("scan_kind")(Alexander2.__init__)
Alexander2.__init__.__doc__ = """**Arguments:**

- `root_finder`: an [Optimistix](https://github.com/patrick-kidger/optimistix) root
finder to solve the implicit problem at each stage.
- `root_find_max_steps`: the maximum number of steps that the root finder is allowed to
make before unconditionally rejecting the step. (And trying again with whatever
smaller step that adaptive stepsize controller proposes.)
"""
5 changes: 5 additions & 0 deletions docs/api/solvers/ode_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ Each of these takes a `root_finder` argument at initialisation, defaulting to a
members:
- __init__

::: diffrax.Alexander2
options:
members:
- __init__

::: diffrax.Kvaerno3
options:
members:
Expand Down
1 change: 1 addition & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
diffrax.ReversibleHeun(),
diffrax.Tsit5(),
diffrax.ImplicitEuler(),
diffrax.Alexander2(),
diffrax.Kvaerno3(),
diffrax.Kvaerno4(),
diffrax.Kvaerno5(),
Expand Down
2 changes: 2 additions & 0 deletions test/test_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def test_dense_interpolation(solver):
diffrax.Euler: 1e-3,
diffrax.ImplicitEuler: 1e-3,
diffrax.Ralston: 1e-3,
diffrax.Alexander2: 1e-3,
}.get(type(solver), 1e-6)
assert tree_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol)

Expand Down Expand Up @@ -386,5 +387,6 @@ def test_dense_interpolation_vmap(solver, getkey):
diffrax.Euler: 1e-3,
diffrax.ImplicitEuler: 1e-3,
diffrax.Ralston: 1e-3,
diffrax.Alexander2: 1e-3,
}.get(type(solver), 1e-6)
assert tree_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol)
1 change: 1 addition & 0 deletions test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def f2(t, y, args):
@pytest.mark.parametrize(
"solver",
(
diffrax.Alexander2(),
diffrax.Kvaerno3(),
diffrax.Kvaerno4(),
diffrax.Kvaerno5(),
Expand Down
Loading