Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _is_none(x: Any) -> bool:
return x is None


class TermAndSolverIncompatible(ValueError):
pass


def _assert_term_compatible(
t: FloatScalarLike,
y: PyTree[ArrayLike],
Expand All @@ -137,7 +141,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
):
_assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg)
else:
raise ValueError(
raise TermAndSolverIncompatible(
f"Term {term} is not a MultiTerm but is expected to be."
)
else:
Expand All @@ -147,7 +151,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
if origin_cls is None:
origin_cls = term_cls
if not isinstance(term, origin_cls):
raise ValueError(f"Term {term} is not an instance of {origin_cls}.")
raise TermAndSolverIncompatible(
f"Term {term} is not an instance of {origin_cls}."
)

# Now check the generic parametrization of `term_cls`; can be one of:
# -----------------------------------------
Expand All @@ -167,7 +173,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
better_isinstance, vf_type, vf_type_expected
)
if not vf_type_compatible:
raise ValueError(f"Vector field term {term} is incompatible.")
raise TermAndSolverIncompatible(
f"Vector field term {term} is incompatible."
)

contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
Expand All @@ -176,7 +184,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError(
raise TermAndSolverIncompatible(
"Control term is incompatible: the returned control (e.g. "
f"Brownian motion for an SDE) was {control_type}, but this "
f"solver expected {control_type_expected}."
Expand All @@ -185,14 +193,23 @@ def _check(term_cls, term, term_contr_kwargs, yi):
assert False, "Malformed term structure"
# If we've got to this point then the term is compatible

try: # check for JAX pytree mismatches first
jtu.tree_map(lambda *a: None, term_structure, terms, contr_kwargs, y)
except ValueError as e:
pretty_term = wl.pformat(terms)
pretty_expected = wl.pformat(term_structure)
raise TermAndSolverIncompatible(
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
"scroll up you may find a root-cause error that is more specific."
) from e
try:
with jax.numpy_dtype_promotion("standard"):
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
except ValueError as e:
# ValueError may also arise from mismatched tree structures
except TermAndSolverIncompatible as e:
pretty_term = wl.pformat(terms)
pretty_expected = wl.pformat(term_structure)
raise ValueError(
raise TermAndSolverIncompatible(
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
"scroll up you may find a root-cause error that is more specific."
Expand Down Expand Up @@ -408,12 +425,6 @@ def body_fun_aux(state):
made_jump = static_select(keep_step, made_jump, state.made_jump)
solver_result = RESULTS.where(keep_step, solver_result, RESULTS.successful)

# TODO: if we ever support non-terminating events, then they should go in here.
# In particular the thing to be careful about is in the `if saveat.steps`
# branch below, where we want to make sure that it is the value of `y` at
# `tprev` that is actually saved. (And not just the value of `y` at the
# previous step's `tnext`, i.e. immediately before the jump.)

# Store the first unsuccessful result we get whilst iterating (if any).
result = RESULTS.where(is_okay(state.result), solver_result, state.result)
result = RESULTS.where(is_okay(result), stepsize_controller_result, result)
Expand Down Expand Up @@ -1333,11 +1344,21 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
# Note that we're threading a needle here! What if we terminate on the very
# first step? Our dense-info (and thus a subsequent root find) will be
# completely wrong!
# Fortunately, this can't quite happen:
#
# There are two things we need to consider:
# 1. doing the root find to locate the time of our event;
# 2. determining the value of `y` at that time.
#
# For 1:
# - A boolean event never uses dense-info (the interpolation is unused and we go
# to the end of the interval).
# - A floating event can't terminate on the first step (it requires a sign
# change).
#
# For 2:
# We explicitly have a `lax.cond` statement to determine whether we are on the
# first step or not.
#
# c.f. https://github.com/patrick-kidger/diffrax/issues/720
event_dense_info = jtu.tree_map(
lambda x: jnp.zeros(x.shape, x.dtype),
Expand Down
12 changes: 12 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,15 @@ def grad_fn(params: jnp.ndarray) -> jnp.ndarray:

assert not jnp.isnan(grad).any(), "Gradient should not be NaN."
assert not jnp.isinf(grad).any(), "Gradient should not be infinite."


def test_nice_errors():
def vf1(t, y, args):
raise ValueError("Oh no!")

def vf2(t, y, args):
raise TypeError("Oh no!")

for vf, etype in ((vf1, ValueError), (vf2, TypeError)):
with pytest.raises(etype, match="Oh no!"):
diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0, 1, 0.1, 0)