Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, cast
from typing import cast

import jax
import jax.core
Expand All @@ -13,22 +13,17 @@
from ._custom_types import BoolScalarLike, RealScalarLike


_itemsize_kind_type: dict[tuple[int, str], Any] = {
(1, "i"): jnp.int8,
(2, "i"): jnp.int16,
(4, "i"): jnp.int32,
(8, "i"): jnp.int64,
(2, "f"): jnp.float16,
(4, "f"): jnp.float32,
(8, "f"): jnp.float64,
}


def force_bitcast_convert_type(val, new_type):
val = jnp.asarray(val)
intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind]
val = val.astype(intermediate_type)
return lax.bitcast_convert_type(val, new_type)
result = lax.bitcast_convert_type(val, new_type)

# If downcasting (larger -> smaller type), bitcast returns multiple values.
# Combine them via XOR to ensure nearby input values map to different outputs.
if result.shape != val.shape:
result = jnp.bitwise_xor.reduce(result, axis=-1)
assert val.shape == result.shape
assert result.dtype == new_type
return result
Comment on lines +22 to +26
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an assert result.shape == val.shape afterwards to be sure that this xor has done its job?



def _fill_forward(
Expand Down
19 changes: 19 additions & 0 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,22 @@ def test_fill_forward():
out_ = jnp.array([jnp.nan, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0])
fill_in = diffrax._misc.fill_forward(in_[:, None])
assert tree_allclose(fill_in, out_[:, None], equal_nan=True)


def test_force_bitcast_convert_type():
val_1 = jnp.float64(1e6)
val_2 = jnp.float64(1e6 + 1e-4)

# Val_1 and val_2 are different as float64,
# but would be the same if naively downcast to float32.
assert val_1 != val_2
assert val_1.astype(jnp.int32) == val_2.astype(jnp.int32)

val_1_cast = diffrax._misc.force_bitcast_convert_type(val_1, jnp.int32)
val_2_cast = diffrax._misc.force_bitcast_convert_type(val_2, jnp.int32)

assert val_1_cast.dtype == jnp.int32
assert val_2_cast.dtype == jnp.int32

# Bitcasted values should be different in the smaller type
assert val_1_cast != val_2_cast