diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 3017e30d..8241fa9b 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -114,6 +114,10 @@ def _is_none(x: Any) -> bool: return x is None +class TermAndSolverIncompatible(ValueError): + pass + + def _assert_term_compatible( t: FloatScalarLike, y: PyTree[ArrayLike], @@ -137,7 +141,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): ): _assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg) else: - raise ValueError( + raise TermAndSolverIncompatible( f"Term {term} is not a MultiTerm but is expected to be." ) else: @@ -147,7 +151,9 @@ def _check(term_cls, term, term_contr_kwargs, yi): if origin_cls is None: origin_cls = term_cls if not isinstance(term, origin_cls): - raise ValueError(f"Term {term} is not an instance of {origin_cls}.") + raise TermAndSolverIncompatible( + f"Term {term} is not an instance of {origin_cls}." + ) # Now check the generic parametrization of `term_cls`; can be one of: # ----------------------------------------- @@ -167,7 +173,9 @@ def _check(term_cls, term, term_contr_kwargs, yi): better_isinstance, vf_type, vf_type_expected ) if not vf_type_compatible: - raise ValueError(f"Vector field term {term} is incompatible.") + raise TermAndSolverIncompatible( + f"Vector field term {term} is incompatible." + ) contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 @@ -176,7 +184,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): better_isinstance, control_type, control_type_expected ) if not control_type_compatible: - raise ValueError( + raise TermAndSolverIncompatible( "Control term is incompatible: the returned control (e.g. " f"Brownian motion for an SDE) was {control_type}, but this " f"solver expected {control_type_expected}." @@ -185,14 +193,23 @@ def _check(term_cls, term, term_contr_kwargs, yi): assert False, "Malformed term structure" # If we've got to this point then the term is compatible + try: # check for JAX pytree mismatches first + jtu.tree_map(lambda *a: None, term_structure, terms, contr_kwargs, y) + except ValueError as e: + pretty_term = wl.pformat(terms) + pretty_expected = wl.pformat(term_structure) + raise TermAndSolverIncompatible( + f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:" + f"\n{pretty_expected}\nNote that terms are checked recursively: if you " + "scroll up you may find a root-cause error that is more specific." + ) from e try: with jax.numpy_dtype_promotion("standard"): jtu.tree_map(_check, term_structure, terms, contr_kwargs, y) - except ValueError as e: - # ValueError may also arise from mismatched tree structures + except TermAndSolverIncompatible as e: pretty_term = wl.pformat(terms) pretty_expected = wl.pformat(term_structure) - raise ValueError( + raise TermAndSolverIncompatible( f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:" f"\n{pretty_expected}\nNote that terms are checked recursively: if you " "scroll up you may find a root-cause error that is more specific." @@ -408,12 +425,6 @@ def body_fun_aux(state): made_jump = static_select(keep_step, made_jump, state.made_jump) solver_result = RESULTS.where(keep_step, solver_result, RESULTS.successful) - # TODO: if we ever support non-terminating events, then they should go in here. - # In particular the thing to be careful about is in the `if saveat.steps` - # branch below, where we want to make sure that it is the value of `y` at - # `tprev` that is actually saved. (And not just the value of `y` at the - # previous step's `tnext`, i.e. immediately before the jump.) - # Store the first unsuccessful result we get whilst iterating (if any). result = RESULTS.where(is_okay(state.result), solver_result, state.result) result = RESULTS.where(is_okay(result), stepsize_controller_result, result) @@ -1333,11 +1344,21 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState: # Note that we're threading a needle here! What if we terminate on the very # first step? Our dense-info (and thus a subsequent root find) will be # completely wrong! - # Fortunately, this can't quite happen: + # + # There are two things we need to consider: + # 1. doing the root find to locate the time of our event; + # 2. determining the value of `y` at that time. + # + # For 1: # - A boolean event never uses dense-info (the interpolation is unused and we go # to the end of the interval). # - A floating event can't terminate on the first step (it requires a sign # change). + # + # For 2: + # We explicitly have a `lax.cond` statement to determine whether we are on the + # first step or not. + # # c.f. https://github.com/patrick-kidger/diffrax/issues/720 event_dense_info = jtu.tree_map( lambda x: jnp.zeros(x.shape, x.dtype), diff --git a/test/test_integrate.py b/test/test_integrate.py index cfcaadfd..fcf3a836 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -864,3 +864,15 @@ def grad_fn(params: jnp.ndarray) -> jnp.ndarray: assert not jnp.isnan(grad).any(), "Gradient should not be NaN." assert not jnp.isinf(grad).any(), "Gradient should not be infinite." + + +def test_nice_errors(): + def vf1(t, y, args): + raise ValueError("Oh no!") + + def vf2(t, y, args): + raise TypeError("Oh no!") + + for vf, etype in ((vf1, ValueError), (vf2, TypeError)): + with pytest.raises(etype, match="Oh no!"): + diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0, 1, 0.1, 0)