diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 38b90fdd..8fa07ce1 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, cast +from typing import cast import jax import jax.core @@ -13,22 +13,17 @@ from ._custom_types import BoolScalarLike, RealScalarLike -_itemsize_kind_type: dict[tuple[int, str], Any] = { - (1, "i"): jnp.int8, - (2, "i"): jnp.int16, - (4, "i"): jnp.int32, - (8, "i"): jnp.int64, - (2, "f"): jnp.float16, - (4, "f"): jnp.float32, - (8, "f"): jnp.float64, -} - - def force_bitcast_convert_type(val, new_type): val = jnp.asarray(val) - intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind] - val = val.astype(intermediate_type) - return lax.bitcast_convert_type(val, new_type) + result = lax.bitcast_convert_type(val, new_type) + + # If downcasting (larger -> smaller type), bitcast returns multiple values. + # Combine them via XOR to ensure nearby input values map to different outputs. + if result.shape != val.shape: + result = jnp.bitwise_xor.reduce(result, axis=-1) + assert val.shape == result.shape + assert result.dtype == new_type + return result def _fill_forward( diff --git a/test/test_misc.py b/test/test_misc.py index bb61c813..cb341684 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -9,3 +9,22 @@ def test_fill_forward(): out_ = jnp.array([jnp.nan, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]) fill_in = diffrax._misc.fill_forward(in_[:, None]) assert tree_allclose(fill_in, out_[:, None], equal_nan=True) + + +def test_force_bitcast_convert_type(): + val_1 = jnp.float64(1e6) + val_2 = jnp.float64(1e6 + 1e-4) + + # Val_1 and val_2 are different as float64, + # but would be the same if naively downcast to float32. + assert val_1 != val_2 + assert val_1.astype(jnp.int32) == val_2.astype(jnp.int32) + + val_1_cast = diffrax._misc.force_bitcast_convert_type(val_1, jnp.int32) + val_2_cast = diffrax._misc.force_bitcast_convert_type(val_2, jnp.int32) + + assert val_1_cast.dtype == jnp.int32 + assert val_2_cast.dtype == jnp.int32 + + # Bitcasted values should be different in the smaller type + assert val_1_cast != val_2_cast