From 2e6b173f585eca33275b78b402e651850d9aa5d0 Mon Sep 17 00:00:00 2001 From: Alexander de Ranitz Date: Fri, 20 Feb 2026 15:33:33 +0100 Subject: [PATCH 1/3] Combine bits via xor when bitcasting from larger to smaller type Previously, when casting e.g. a float64 to an int32, numbers that were close in float64 could be mapped to identical int32's. Since these int's are used as keys to generate random sequences, this is problematic, as it results in identical noise being generated in subsequent timesteps. This commit fixes this by not throwing away bits when the input type is larger than the requested output type. Instead, the larger number is bitcast to multiple values in the smaller type, which are then combined using xor. --- diffrax/_misc.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 38b90fdd..4426afcc 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -12,23 +12,17 @@ from ._custom_types import BoolScalarLike, RealScalarLike - -_itemsize_kind_type: dict[tuple[int, str], Any] = { - (1, "i"): jnp.int8, - (2, "i"): jnp.int16, - (4, "i"): jnp.int32, - (8, "i"): jnp.int64, - (2, "f"): jnp.float16, - (4, "f"): jnp.float32, - (8, "f"): jnp.float64, -} - - def force_bitcast_convert_type(val, new_type): val = jnp.asarray(val) - intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind] - val = val.astype(intermediate_type) - return lax.bitcast_convert_type(val, new_type) + result = lax.bitcast_convert_type(val, new_type) + + # If downcasting (larger -> smaller type), bitcast returns multiple values. + # Combine them via XOR to ensure nearby input values map to different outputs. + if result.shape != val.shape: + result = jnp.bitwise_xor.reduce(result, axis=-1) + + return result + def _fill_forward( From e76d53d4a0ee0003b969fbefdecd92ee13dbabb6 Mon Sep 17 00:00:00 2001 From: Alexander de Ranitz Date: Sun, 22 Feb 2026 11:35:19 +0100 Subject: [PATCH 2/3] Add assertion to check for shape and dtype --- diffrax/_misc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 4426afcc..8fa07ce1 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Any, cast +from typing import cast import jax import jax.core @@ -12,19 +12,20 @@ from ._custom_types import BoolScalarLike, RealScalarLike + def force_bitcast_convert_type(val, new_type): val = jnp.asarray(val) result = lax.bitcast_convert_type(val, new_type) - + # If downcasting (larger -> smaller type), bitcast returns multiple values. # Combine them via XOR to ensure nearby input values map to different outputs. if result.shape != val.shape: result = jnp.bitwise_xor.reduce(result, axis=-1) - + assert val.shape == result.shape + assert result.dtype == new_type return result - def _fill_forward( last_observed_yi: Shaped[Array, " *channels"], yi: Shaped[Array, " *channels"] ) -> tuple[Shaped[Array, " *channels"], Shaped[Array, " *channels"]]: From 0a2eebfc7fd965b699f8982d434df7f3c1c15c2a Mon Sep 17 00:00:00 2001 From: Alexander de Ranitz Date: Sun, 22 Feb 2026 11:39:54 +0100 Subject: [PATCH 3/3] Add pytest for bitcasting to smaller type This checks the intended behaviour of mapping nearby numbers to distinct values when downcasting to a smaller dtype. --- test/test_misc.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/test_misc.py b/test/test_misc.py index bb61c813..cb341684 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -9,3 +9,22 @@ def test_fill_forward(): out_ = jnp.array([jnp.nan, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]) fill_in = diffrax._misc.fill_forward(in_[:, None]) assert tree_allclose(fill_in, out_[:, None], equal_nan=True) + + +def test_force_bitcast_convert_type(): + val_1 = jnp.float64(1e6) + val_2 = jnp.float64(1e6 + 1e-4) + + # Val_1 and val_2 are different as float64, + # but would be the same if naively downcast to float32. + assert val_1 != val_2 + assert val_1.astype(jnp.int32) == val_2.astype(jnp.int32) + + val_1_cast = diffrax._misc.force_bitcast_convert_type(val_1, jnp.int32) + val_2_cast = diffrax._misc.force_bitcast_convert_type(val_2, jnp.int32) + + assert val_1_cast.dtype == jnp.int32 + assert val_2_cast.dtype == jnp.int32 + + # Bitcasted values should be different in the smaller type + assert val_1_cast != val_2_cast