From 3d11630787dac29abd91c4d45c8dd58c8254831e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:57:52 -0500 Subject: [PATCH 01/51] fix: raise errors for invalid shears and PixelScale WCS inits --- jax_galsim/shear.py | 46 +++++++++++++++++++++++++++++++++++++-------- pyproject.toml | 2 +- tests/GalSim | 2 +- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 074a762e..55b639a5 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError @@ -45,15 +46,24 @@ def __init__(self, *args, **kwargs): # g1,g2 elif "g1" in kwargs or "g2" in kwargs: - g1 = kwargs.pop("g1", 0.0) - g2 = kwargs.pop("g2", 0.0) + g1 = jnp.array(kwargs.pop("g1", 0.0)) + g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 + self._g = equinox.error_if( + self._g, jnp.abs(self._g) > 1., + "Requested shear exceeds 1.", + ) # e1,e2 elif "e1" in kwargs or "e2" in kwargs: - e1 = kwargs.pop("e1", 0.0) - e2 = kwargs.pop("e2", 0.0) + e1 = jnp.array(kwargs.pop("e1", 0.0)) + e2 = jnp.array(kwargs.pop("e2", 0.0)) absesq = e1**2 + e2**2 + absesq = equinox.error_if( + absesq, + absesq > 1., + "Requested distortion exceeds 1.", + ) self._g = (e1 + 1j * e2) * self._e2g(absesq) # eta1,eta2 @@ -75,7 +85,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - g = kwargs.pop("g") + g = jnp.array(kwargs.pop("g")) + g = equinox.error_if( + g, + g > 1 or g < 0, + "Requested |shear| is outside [0,1].", + ) self._g = g * jnp.exp(2j * beta.rad) # e,beta @@ -89,7 +104,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - e = kwargs.pop("e") + e = jnp.array(kwargs.pop("e")) + e = equinox.error_if( + e, + (e > 1) | (e < 0), + "Requested distortion is outside [0,1].", + ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) # eta,beta @@ -103,7 +123,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - eta = kwargs.pop("eta") + eta = jnp.array(kwargs.pop("eta")) + eta = equinox.error_if( + eta, + eta < 0, + "Requested eta is below 0.", + ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) # q,beta @@ -117,7 +142,12 @@ def __init__(self, *args, **kwargs): beta = kwargs.pop("beta") if not isinstance(beta, Angle): raise TypeError("beta must be an Angle instance.") - q = kwargs.pop("q") + q = jnp.array(kwargs.pop("q")) + q = equinox.error_if( + q, + (q <= 0) | (q > 1), + "Cannot use requested axis ratio.", + ) eta = -jnp.log(q) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) diff --git a/pyproject.toml b/pyproject.toml index 2ad2d9aa..f8ca4441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "The modular galaxy image simulation toolkit, but in JAX" dynamic = ["version"] license = { file = "LICENSE" } readme = "README.md" -dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax", "equinox"] [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] diff --git a/tests/GalSim b/tests/GalSim index a5afbf51..11c473b4 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit a5afbf510dc747f5667f61c742b9dd3630643988 +Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 From bd0e282c71dd6de8bfc3e502bf0477099c83224e Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 05:58:40 -0500 Subject: [PATCH 02/51] please the dog --- jax_galsim/shear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 55b639a5..92d57693 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -50,7 +50,8 @@ def __init__(self, *args, **kwargs): g2 = jnp.array(kwargs.pop("g2", 0.0)) self._g = g1 + 1j * g2 self._g = equinox.error_if( - self._g, jnp.abs(self._g) > 1., + self._g, + jnp.abs(self._g) > 1.0, "Requested shear exceeds 1.", ) @@ -61,7 +62,7 @@ def __init__(self, *args, **kwargs): absesq = e1**2 + e2**2 absesq = equinox.error_if( absesq, - absesq > 1., + absesq > 1.0, "Requested distortion exceeds 1.", ) self._g = (e1 + 1j * e2) * self._e2g(absesq) From a3b7ba4c68c4e806cfc95c5420e3af844de1642a Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:31:39 -0500 Subject: [PATCH 03/51] fix: mock up equinox --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 11c473b4..062c9ed0 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 11c473b4fde8b8b730af654a47b96e7894862d57 +Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac From 5a43c922a80dd1d234e44c9e55055e1a7262991b Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:35:42 -0500 Subject: [PATCH 04/51] test: more array equals --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 062c9ed0..e5ee4016 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 062c9ed06ae309b1a47885ee8abee3b7860760ac +Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94 From ff189007dfc0e6bd0a4872eeb89f882d134a7714 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:49:07 -0500 Subject: [PATCH 05/51] doc: update docs for shears --- jax_galsim/shear.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 92d57693..60e5b306 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -11,9 +11,11 @@ @register_pytree_node_class @implements( _galsim.Shear, - lax_description="""\ -The jax_galsim implementation of ``Shear`` does not perform range checking of the \ -shear (e.g., ``|g| <= 1``) upon construction.""", + lax_description=( + "While the JAX-GalSim implementation of ``Shear`` will do range checking of " + "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " + "instead of ``galsim.GalSimRangeError`` exceptions." + ), ) class Shear(object): def __init__(self, *args, **kwargs): From 85569db75d31cb70ad11b1d4290b4ddabcdc91ff Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 06:51:39 -0500 Subject: [PATCH 06/51] fix: clarify docs --- jax_galsim/shear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 60e5b306..59ef2cca 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -12,9 +12,9 @@ @implements( _galsim.Shear, lax_description=( - "While the JAX-GalSim implementation of ``Shear`` will do range checking of " - "the shear upon construction, it raises ``equinox.EquinoxRuntimeError`` exceptions " - "instead of ``galsim.GalSimRangeError`` exceptions." + "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " + "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " + "exceptions instead of ``galsim.GalSimRangeError`` exceptions." ), ) class Shear(object): From d8e24ae1f0486eaf76d243b8c9d43511aa7dbf74 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 08:49:57 -0500 Subject: [PATCH 07/51] fix: raise erorr on failed integrations --- jax_galsim/integ.py | 19 +++++++++++-------- tests/GalSim | 2 +- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 19ad5c4b..aaef6396 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax.lax import jax.numpy as jnp @@ -7,7 +8,10 @@ from jax_galsim.core.utils import implements +# @partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) + +@equinox.filter_jit @implements( _galsim.integ.int1d, lax_description=( @@ -17,12 +21,11 @@ - This implementation is different than the one in GalSim and lacks some features that greatly enhance galsim's accuracy. -- The JAX-GalSim implementation returns NaN on error/non-convergence instead of - rasing an exception. +- The JAX-GalSim implementation raises a ``equinox.EquinoxRuntimeError`` on error/non-convergence + instead of rasing a ``galsim.GalSimError`` exception. """ ), ) -@partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) def int1d( func, min, @@ -37,7 +40,7 @@ def int1d( # can be used with jax if _wrap_as_callback: - @jax.jit + @equinox.filter_jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.pure_callback(func, rdt, x, vmap_method="sequential") @@ -72,8 +75,8 @@ def _base_integration(): _base_integration, ) - return jax.lax.cond( - status == 0, - lambda: val, - lambda: jnp.nan, + return equinox.error_if( + val, + status != 0, + "`jax_galsim.int1d` failed to converge!", ) diff --git a/tests/GalSim b/tests/GalSim index e5ee4016..22013ee3 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit e5ee401606efcc43b6a8f6ca5a204f5d95befc94 +Subproject commit 22013ee3c4fe1659814e5bfc147779fac22dd8de From 81976e63362a93816b2d1f5b7a16a54904127b60 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 08:50:45 -0500 Subject: [PATCH 08/51] style: please the dog --- jax_galsim/integ.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index aaef6396..0404271d 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,5 +1,3 @@ -from functools import partial - import equinox import galsim as _galsim import jax.lax From d21e3830bf00b6dfccbb7ba09e6e5611b6393b81 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:03:56 -0500 Subject: [PATCH 09/51] fix: try code with equinox filter_jit --- jax_galsim/integ.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 0404271d..208e36d5 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -1,3 +1,5 @@ +from functools import partial + import equinox import galsim as _galsim import jax.lax @@ -6,8 +8,6 @@ from jax_galsim.core.utils import implements -# @partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) - @equinox.filter_jit @implements( From 8805c8c9d79f901d4c1ccbf7a0aed6f2f4143ba2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:38:19 -0500 Subject: [PATCH 10/51] fix: use standard JIT --- jax_galsim/integ.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 208e36d5..277e57f6 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -9,7 +9,6 @@ from jax_galsim.core.utils import implements -@equinox.filter_jit @implements( _galsim.integ.int1d, lax_description=( @@ -24,6 +23,7 @@ """ ), ) +@partial(jax.jit, static_argnames=("func", "_wrap_as_callback")) def int1d( func, min, @@ -38,7 +38,7 @@ def int1d( # can be used with jax if _wrap_as_callback: - @equinox.filter_jit + @jax.jit def _func(x): rdt = jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.pure_callback(func, rdt, x, vmap_method="sequential") From 4d661a28efdff90fb5774baa9242c230e6a91ceb Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 09:38:46 -0500 Subject: [PATCH 11/51] doc: update doc strings --- jax_galsim/integ.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index 277e57f6..de1c26ce 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -18,7 +18,7 @@ - This implementation is different than the one in GalSim and lacks some features that greatly enhance galsim's accuracy. -- The JAX-GalSim implementation raises a ``equinox.EquinoxRuntimeError`` on error/non-convergence +- The JAX-GalSim implementation raises a generic ``Exception`` on error/non-convergence instead of rasing a ``galsim.GalSimError`` exception. """ ), From 1561dc06955d5601358aca2a68164e4fd1cb633d Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 10:01:45 -0500 Subject: [PATCH 12/51] fix: only use generic Exception --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 22013ee3..95ff6fd9 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 22013ee3c4fe1659814e5bfc147779fac22dd8de +Subproject commit 95ff6fd945cec5056a33276af3333fb70f5cb879 From 8c69aac025e2dc0842b17091a38c788ce6584e8f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 10:02:31 -0500 Subject: [PATCH 13/51] doc: update docs --- jax_galsim/shear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index 59ef2cca..adfb25f4 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -13,8 +13,8 @@ _galsim.Shear, lax_description=( "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " - "invalid shear values (e.g., |g| > 1), it raises ``equinox.EquinoxRuntimeError`` " - "exceptions instead of ``galsim.GalSimRangeError`` exceptions." + "invalid shear values (e.g., |g| > 1), it raises a generic ``Exception`` " + "instead of ``galsim.GalSimRangeError`` exceptions." ), ) class Shear(object): From 6d0a66b921d924c3d6c5902867b5ff471ce25f5f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 16:23:00 -0500 Subject: [PATCH 14/51] refactor: centralize logic for int checking --- jax_galsim/bounds.py | 48 ++++++++++++---------------------- jax_galsim/core/utils.py | 31 ++++++++++++++++++++++ jax_galsim/position.py | 12 ++++++--- jax_galsim/shear.py | 2 +- tests/jax/test_position_jax.py | 21 +++++++++++++++ 5 files changed, 78 insertions(+), 36 deletions(-) create mode 100644 tests/jax/test_position_jax.py diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 10442ad4..21b6ba8a 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,34 +1,24 @@ import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + CONST_TYPES, cast_to_float, cast_to_int, + cast_to_python_float, + check_is_int_then_cast, ensure_hashable, has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) -CONST_TYPES_WITH_JAX = CONST_TYPES + ( - jax.Array, - jnp.array, - jnp.int32, - jnp.int64, - jnp.float32, - jnp.float64, -) - -# TODO: write extra docs for JAX changes BOUNDS_LAX_DESCR = """\ The JAX implementation - will not always test whether the bounds are valid -- will not always test whether BoundsI is initialized with integers Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance @@ -525,31 +515,27 @@ def __init__(self, *args, **kwargs): f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) + self.deltax = cast_to_python_float(self.deltax) + self.deltay = cast_to_python_float(self.deltay) + if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): + raise TypeError("BoundsI must be initialized with integer values") self.deltax = int(cast_to_int(self.deltax)) self.deltay = int(cast_to_int(self.deltay)) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") + if has_tracers(self._xmin) or has_tracers(self._ymin): + self._isstatic = False + + # validate inputs are ints + self._xmin = check_is_int_then_cast( + self._xmin, "BoundsI must be initialized with integer values" + ) + self._ymin = check_is_int_then_cast( + self._ymin, "BoundsI must be initialized with integer values" + ) if self.deltax < 1 and self.deltay < 1: self._isdefined = False - # for simple inputs, we can check if the bounds are valid ints - if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): - raise TypeError("BoundsI must be initialized with integer values") - - if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): - raise TypeError("BoundsI must be initialized with integer values") - - if not has_tracers(self._xmin) and not has_tracers(self._ymin): - self._isstatic = True - self._xmin = int(np.trunc(self._xmin)) - self._ymin = int(np.trunc(self._ymin)) - else: - self._isstatic = False - self._xmin = cast_to_float(jnp.trunc(self._xmin)) - self._ymin = cast_to_float(jnp.trunc(self._ymin)) - if force_static and not self._isstatic: raise RuntimeError( "BoundsI initialized with non-static " diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 3fcbf46d..e4d51b18 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -3,11 +3,42 @@ from functools import partial from typing import NamedTuple +import equinox import jax import jax.numpy as jnp import numpy as np from jax.tree_util import tree_flatten +CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) +CONST_TYPES_WITH_JAX = CONST_TYPES + ( + jax.Array, + jnp.array, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, +) + + +def check_is_int_then_cast(val, msg): + """Cast to integer and raise if value is not int.""" + # for simple inputs, we can check if the bounds are valid ints + if isinstance(val, CONST_TYPES) and not has_tracers(val): + val = cast_to_python_float(val) + if val != int(val): + raise TypeError(msg) + val = int(val) + else: + # otherwise we use more opaque checking upon jit + val = equinox.error_if( + val, + val != jnp.trunc(val), + msg, + ) + val = val.astype(int) + + return val + def cast_numpy_array_to_native_byte_order(arr): """Cast an array to native byte order.""" diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 822797b8..b3af5844 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -5,7 +5,7 @@ from jax_galsim.core.utils import ( cast_to_float, - cast_to_int, + check_is_int_then_cast, ensure_hashable, implements, ) @@ -214,9 +214,13 @@ class PositionI(Position): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - # inputs must be ints - self.x = cast_to_int(self.x) - self.y = cast_to_int(self.y) + # validate input is int + self.x = check_is_int_then_cast( + self.x, "PositionI must be initialized with integer values" + ) + self.y = check_is_int_then_cast( + self.y, "PositionI must be initialized with integer values" + ) def _check_scalar(self, other, op): try: diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index adfb25f4..dd88f424 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -14,7 +14,7 @@ lax_description=( "While the JAX-GalSim implementation of ``Shear`` will raise exceptions for " "invalid shear values (e.g., |g| > 1), it raises a generic ``Exception`` " - "instead of ``galsim.GalSimRangeError`` exceptions." + "instead of a ``galsim.GalSimRangeError`` exception." ), ) class Shear(object): diff --git a/tests/jax/test_position_jax.py b/tests/jax/test_position_jax.py new file mode 100644 index 00000000..e73937e7 --- /dev/null +++ b/tests/jax/test_position_jax.py @@ -0,0 +1,21 @@ +import jax +import pytest + +import jax_galsim + + +def test_position_jax_int_raises_in_jit(): + + @jax.jit + def _make_pos(x, y): + return jax_galsim.PositionI(x, y) + + with pytest.raises(Exception): + _make_pos(1.2, 23) + + with pytest.raises(Exception): + _make_pos(12, 2.3) + + pos = _make_pos(1, 2) + assert pos.x == 1 + assert pos.y == 2 From 260984000f5cbb96af27c6cff1f19fdbf923139f Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 14 May 2026 16:26:09 -0500 Subject: [PATCH 15/51] doc: ensure doc string is accurate --- jax_galsim/core/utils.py | 6 +++--- tests/GalSim | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index e4d51b18..18916073 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -21,15 +21,15 @@ def check_is_int_then_cast(val, msg): - """Cast to integer and raise if value is not int.""" - # for simple inputs, we can check if the bounds are valid ints + """Check if `val` is an integer, raise if not, otherwise cast to int.""" + # for simple inputs, we can check direct in python if isinstance(val, CONST_TYPES) and not has_tracers(val): val = cast_to_python_float(val) if val != int(val): raise TypeError(msg) val = int(val) else: - # otherwise we use more opaque checking upon jit + # otherwise we use more opaque checking upon jit via equinox val = equinox.error_if( val, val != jnp.trunc(val), diff --git a/tests/GalSim b/tests/GalSim index 95ff6fd9..0fe6d90b 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 95ff6fd945cec5056a33276af3333fb70f5cb879 +Subproject commit 0fe6d90bd7df4f660c923dc03da8aa44b347afda From ed7c5083e0fe994fbc13093c0fb5792a28ccc37f Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:44:57 -0500 Subject: [PATCH 16/51] fix: enable tests for image gain, area, exptime, and max_extra_noise --- jax_galsim/gsobject.py | 23 ++++++++++++++++++++--- tests/GalSim | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 3d72dbab..5907e5b8 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -601,7 +602,7 @@ def drawImage( offset=None, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), @@ -626,6 +627,13 @@ def drawImage( if image is not None and not isinstance(image, Image): raise TypeError("image is not an Image instance", image) + # Make sure (gain, area, exptime) have valid values: + gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") + area = equinox.error_if(jnp.array(area), area <= 0.0, "Invalid area <= 0.") + exptime = equinox.error_if( + jnp.array(exptime), exptime <= 0.0, "Invalid exptime <= 0." + ) + if method == "phot" and save_photons and maxN is not None: raise GalSimIncompatibleValuesError( "Setting maxN is incompatible with save_photons=True" @@ -659,6 +667,13 @@ def drawImage( sensor=sensor, n_photons=n_photons, ) + if max_extra_noise is not None: + raise GalSimIncompatibleValuesError( + "max_extra_noise is only relevant for method='phot'", + method=method, + sensor=sensor, + max_extra_noise=max_extra_noise, + ) if poisson_flux is not None: raise GalSimIncompatibleValuesError( "poisson_flux is only relevant for method='phot'", @@ -1078,6 +1093,8 @@ def _drawKImage( @implements(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): + if max_extra_noise is None: + max_extra_noise = 0.0 n_photons, g, _rng = calculate_n_photons( self.flux, self._flux_per_photon, @@ -1106,7 +1123,7 @@ def makePhot( self, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, photon_ops=(), local_wcs=None, @@ -1178,7 +1195,7 @@ def drawPhot( add_to_image=False, n_photons=None, rng=None, - max_extra_noise=0.0, + max_extra_noise=None, poisson_flux=None, sensor=None, photon_ops=(), diff --git a/tests/GalSim b/tests/GalSim index 0fe6d90b..200c2cd2 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0fe6d90bd7df4f660c923dc03da8aa44b347afda +Subproject commit 200c2cd2bad9f8f93936290cdca9d87ee10ebaa1 From 831990a64e266b8bc1ed8256996561f67431cc8d Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:47:02 -0500 Subject: [PATCH 17/51] doc: update doc strings --- jax_galsim/gsobject.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 5907e5b8..0ea14865 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -575,10 +575,11 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): lax_description="""\ The JAX-GalSim version of ``drawImage`` -- does not do extensive (any?) checking of the input settings. - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. """, From 42fb804703ddaed724850f410f15d06d4a2930d3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:48:00 -0500 Subject: [PATCH 18/51] doc: update doc string --- jax_galsim/gsobject.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 0ea14865..1a94cf1f 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -582,6 +582,7 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs """, ) def drawImage( From 2ef7195df3502c9286890f7c872db669bb4ed289 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:51:32 -0500 Subject: [PATCH 19/51] doc: add doc string for position exceptions --- jax_galsim/position.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index b3af5844..cf36dba8 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -208,7 +208,14 @@ def _check_scalar(self, other, op): raise TypeError("Can only %s a PositionD by float values" % op) -@implements(_galsim.PositionI) +@implements( + _galsim.PositionI, + lax_description=( + "The ``jax_galsim.PositionI`` class will raise generic " + "``Exception``s instead of a more specific exception for invalid " + "inputs." + ), +) @register_pytree_node_class class PositionI(Position): def __init__(self, *args, **kwargs): From 3ee5f33e0776230c89339f974f28bf97f92ab5a8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 06:01:15 -0500 Subject: [PATCH 20/51] fix+doc: do more error checking and more docs --- jax_galsim/gsobject.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 1a94cf1f..ce69ec4b 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1115,10 +1115,10 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): lax_description="""\ The JAX-GalSim version of ``makePhot`` -- does little to no error checking on the inputs - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs """, ) def makePhot( @@ -1187,6 +1187,9 @@ def makePhot( - uses a default of ``n_photons=None`` instead of ``n_photons=0`` to indicate that the number of photons should be determined from the flux and gain +- uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` + to indicate no limit on the extra noise +- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs - requires that the ``maxN`` option must be a constant """, ) @@ -1227,6 +1230,8 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") + gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") + if n_photons is not None: # n_photons is the length of an array so it is a python int and # and thus a constant wrt to JIT From 48c2569eeaed9f177f1eea5416c57e961d28cd7c Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 06:20:03 -0500 Subject: [PATCH 21/51] doc: fix doc string formatting --- jax_galsim/gsobject.py | 6 +++--- tests/GalSim | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index ce69ec4b..c6896fec 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -582,7 +582,7 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def drawImage( @@ -1118,7 +1118,7 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` to indicate no limit on the extra noise -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs """, ) def makePhot( @@ -1189,7 +1189,7 @@ def makePhot( from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` to indicate no limit on the extra noise -- raises generic ``Exception``s instead of more specific exceptions for some invalid inputs +- raises a generic ``Exception`` instead of a more specific exception for some invalid inputs - requires that the ``maxN`` option must be a constant """, ) diff --git a/tests/GalSim b/tests/GalSim index 200c2cd2..0dabbf46 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 200c2cd2bad9f8f93936290cdca9d87ee10ebaa1 +Subproject commit 0dabbf463b4af7f689074c8373a936d511e4b836 From 6e9c4debe9d5ec037b97edbe3be126b9f48e3b44 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 07:42:15 -0500 Subject: [PATCH 22/51] Apply suggestion from @beckermr --- jax_galsim/gsobject.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index c6896fec..b1543f54 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -579,7 +579,6 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): to indicate that the number of photons should be determined from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` - to indicate no limit on the extra noise - requires that the ``maxN`` option be a constant since PhotonArrays are allocated with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. - raises a generic ``Exception`` instead of a more specific exception for some invalid inputs From 37310d93d7753f19437ef4be3d7f796d37507068 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:16:32 -0500 Subject: [PATCH 23/51] fix: add the rest of the types --- jax_galsim/core/utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 18916073..91007e39 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -9,14 +9,31 @@ import numpy as np from jax.tree_util import tree_flatten -CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) +CONST_TYPES = ( + float, + int, + np.ndarray, + np.int8, + np.int16, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, +) CONST_TYPES_WITH_JAX = CONST_TYPES + ( jax.Array, jnp.array, + jnp.int8, + jnp.int16, jnp.int32, jnp.int64, jnp.float32, jnp.float64, + jnp.complex64, + jnp.complex128, ) From 88c1c7b546eeb1ff2d6a4da450d08cd9c466df31 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:17:50 -0500 Subject: [PATCH 24/51] fix: use proper array ref --- jax_galsim/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 91007e39..71ff2735 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -25,7 +25,7 @@ ) CONST_TYPES_WITH_JAX = CONST_TYPES + ( jax.Array, - jnp.array, + jnp.ndarray, jnp.int8, jnp.int16, jnp.int32, From 70e4c3f2406876ca5a08201f3612f46082af7ad6 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 08:18:35 -0500 Subject: [PATCH 25/51] Apply suggestion from @beckermr --- jax_galsim/gsobject.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index b1543f54..db59db43 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -1187,7 +1187,6 @@ def makePhot( to indicate that the number of photons should be determined from the flux and gain - uses a default of ``max_extra_noise=None`` instead of ``max_extra_noise=0`` - to indicate no limit on the extra noise - raises a generic ``Exception`` instead of a more specific exception for some invalid inputs - requires that the ``maxN`` option must be a constant """, From 164c2c626c015ddd5be7e38008dc074b2ed0abf9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:21:50 -0500 Subject: [PATCH 26/51] fix: docs done wrong --- docs/conf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index edb02218..f96b604b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,9 +29,11 @@ copyright = "2026, GalSim Developers" try: - from jax_galsim._version import version as release + from jax_galsim._version import version except ImportError: - release = "0.0.1.dev0" + version = "0.0.1.dev0" + +release = version # --------------------------------------------------------------------------- # General configuration From aafb848227446f3d6bea8fed601af7c9198620e0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 08:37:31 -0500 Subject: [PATCH 27/51] test: update to latest submodule --- tests/GalSim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/GalSim b/tests/GalSim index 0dabbf46..5cd4c1ec 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0dabbf463b4af7f689074c8373a936d511e4b836 +Subproject commit 5cd4c1ecc8b856790558e39677900cc43e0ce67f From 4ed8c8edee08ba07f9f352d85bd5272690b72448 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 09:10:23 -0500 Subject: [PATCH 28/51] fix: raise for interpolated image init problems --- jax_galsim/interpolatedimage.py | 14 ++++++++++++++ tests/GalSim | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 0424e9d8..bfb5c5d7 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -2,6 +2,7 @@ import math from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -118,6 +119,11 @@ def __init__( elif not isinstance(image, Image): raise TypeError("Supplied image must be an Image or file name") + if not (image.dtype == jnp.float32 or image.dtype == jnp.float64): + raise GalSimValueError( + "Interpolated images must use a float-type image.", image.dtype + ) + self._jax_children = ( image, dict( @@ -506,6 +512,14 @@ def __init__( image=self._jax_children[0], ) + if calculate_stepk or calculate_maxk or flux is not None: + image = equinox.error_if( + image, + image.array.sum() == 0.0, + "This input image has zero total flux. It does not define a " + "valid surface brightness profile.", + ) + @doc_inherit def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: diff --git a/tests/GalSim b/tests/GalSim index 5cd4c1ec..09ded8ab 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 5cd4c1ecc8b856790558e39677900cc43e0ce67f +Subproject commit 09ded8abfa570f836084ef9cf8d53c210203f825 From 7963c1f66cce1ddb3d210fda921e781032339ab7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 09:15:49 -0500 Subject: [PATCH 29/51] doc: add docs for exceptions --- jax_galsim/interpolatedimage.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index bfb5c5d7..f2d0e11c 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -60,6 +60,8 @@ def __dir__(cls): - the pad_image options - depixelize - most of the bounds checks, type checks, and dtype casts done by galsim +- raises a generic ``Exception`` instead of a more specific one for some + initialization errors """ From 2e2aaca68758c24da9ac30106d4698e99f69d5f8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:00:35 -0500 Subject: [PATCH 30/51] fix: use array in transform, not image --- jax_galsim/interpolatedimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index f2d0e11c..e11ba44a 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -515,8 +515,8 @@ def __init__( ) if calculate_stepk or calculate_maxk or flux is not None: - image = equinox.error_if( - image, + image.array = equinox.error_if( + image.array, image.array.sum() == 0.0, "This input image has zero total flux. It does not define a " "valid surface brightness profile.", From 58196ca0f18275c3c0c0cc51d45691beff02a72b Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:32:56 -0500 Subject: [PATCH 31/51] fix: ensure repr of image prints even with tracers --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 3e8c0e69..e964380b 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -326,7 +326,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined(): + if self.bounds.isDefined() and not has_tracers(self.array): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: From 58efd99e639cf2b1a5393bc55a5f5f7249d1c068 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:41:39 -0500 Subject: [PATCH 32/51] fix: raise for invalid beta --- jax_galsim/moffat.py | 15 ++++++++++++++- tests/jax/test_moffat_jax.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/jax/test_moffat_jax.py diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 2a9b312b..5d994fe1 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,5 +1,6 @@ from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -31,7 +32,7 @@ def _Knu(nu, x): lax_description="""\ The JAX-GalSim version of the Moffat profile -- does not support truncation or beta < 1.1 +- does not support truncation or beta <= 1.1 - does not support gsparams.maxk_thresholds > 0.1 - does not support autodiff with respect to the `beta` parameter for Fourier-space evaluations @@ -67,6 +68,18 @@ def __init__( f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" ) + if isinstance(beta, (float, int)): + if beta <= self._beta_thr: + raise ValueError( + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." + ) + else: + beta = equinox.error_if( + jnp.array(beta), + beta <= self._beta_thr, + f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}.", + ) + # Parse the radius options if half_light_radius is not None: if scale_radius is not None or fwhm is not None: diff --git a/tests/jax/test_moffat_jax.py b/tests/jax/test_moffat_jax.py new file mode 100644 index 00000000..4810a105 --- /dev/null +++ b/tests/jax/test_moffat_jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp +import pytest + +import jax_galsim + + +def test_moffat_jax_beta_raises(): + + @jax.jit + def make_moffat(beta): + return jax_galsim.Moffat(beta, fwhm=1.0) + + with pytest.raises(Exception): + make_moffat(jnp.array(1.1)) + + with pytest.raises(Exception): + make_moffat(0.9) From 21ffe091054c3bc353b4e9be906d7b2412ade28e Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:49:45 -0500 Subject: [PATCH 33/51] fix: need to ensure images always hold jax arrays --- jax_galsim/image.py | 2 +- jax_galsim/interpolatedimage.py | 2 +- tests/GalSim | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index e964380b..6a1a18df 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -102,7 +102,7 @@ def __init__(self, *args, **kwargs): ) else: if "array" in kwargs: - array = kwargs.pop("array") + array = jnp.array(kwargs.pop("array")) array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds ) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index e11ba44a..722422d1 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -516,7 +516,7 @@ def __init__( if calculate_stepk or calculate_maxk or flux is not None: image.array = equinox.error_if( - image.array, + jnp.array(image.array), image.array.sum() == 0.0, "This input image has zero total flux. It does not define a " "valid surface brightness profile.", diff --git a/tests/GalSim b/tests/GalSim index 09ded8ab..d8ec29cb 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 09ded8abfa570f836084ef9cf8d53c210203f825 +Subproject commit d8ec29cbc70a8d4e92bd6bc2f3db2ec248ba3e06 From c892e04446ea82acf32506affa044da375eb7ed7 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 10:58:27 -0500 Subject: [PATCH 34/51] fix: raise if we do not have array as kwarg --- jax_galsim/image.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 6a1a18df..d7eac192 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -102,7 +102,14 @@ def __init__(self, *args, **kwargs): ) else: if "array" in kwargs: - array = jnp.array(kwargs.pop("array")) + array = kwargs.pop("array") + if has_tracers(array) or isinstance(array, jnp.ndarray): + pass + elif isinstance(array, np.ndarray): + array = jnp.array(cast_numpy_array_to_native_byte_order(array)) + else: + raise TypeError("Unable to parse %s as an array." % array) + array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds ) From 0b366def4acf31d271626305c67fa92241b6f420 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 11:45:51 -0500 Subject: [PATCH 35/51] fix: raise errors for RNG inits --- jax_galsim/random.py | 42 +++++++++++++++++++++++++++++++++++++++++- tests/GalSim | 2 +- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index db027ed7..c3c149c9 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -1,10 +1,12 @@ import secrets from functools import partial +import equinox import galsim as _galsim import jax import jax.numpy as jnp import jax.random as jrandom +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import implements @@ -122,9 +124,16 @@ def reset(self, seed=None): self._state = _DeviateState( wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) ) - else: + elif ( + isinstance(seed, (int, jnp.ndarray, jax.Array, np.ndarray)) or seed is None + ): _initial_seed = seed or secrets.randbelow(2**31) self._state = _DeviateState(jrandom.key(_initial_seed)) + else: + raise TypeError( + "Seeds for BaseDeviate must be an int-like, str, tuple, or another BaseDeviate." + f"Got seed {seed!r}." + ) @property def _key(self): @@ -295,6 +304,19 @@ def __str__(self): class GaussianDeviate(BaseDeviate): def __init__(self, seed=None, mean=0.0, sigma=1.0): super().__init__(seed=seed) + + if isinstance(sigma, (int, float)): + if sigma <= 0: + raise ValueError( + f"Gaussian deviates must have a positive sigma. Got {sigma!r}." + ) + else: + sigma = equinox.error_if( + jnp.array(sigma), + sigma <= 0, + f"Gaussian deviates must have a positive sigma. Got {sigma!r}.", + ) + self._params["mean"] = mean self._params["sigma"] = sigma @@ -435,6 +457,19 @@ def __str__(self): class PoissonDeviate(BaseDeviate): def __init__(self, seed=None, mean=1.0): super().__init__(seed=seed) + + if isinstance(mean, (int, float)): + if mean < 0: + raise ValueError( + f"Poisson deviates must have a non-negative mean. Got {mean!r}." + ) + else: + mean = equinox.error_if( + jnp.array(mean), + mean < 0, + f"Poisson deviates must have a non-negative mean. Got {mean!r}.", + ) + self._params["mean"] = mean @property @@ -484,6 +519,11 @@ def _generate_one(key, mean): @implements(_galsim.PoissonDeviate.generate_from_expectation) def generate_from_expectation(self, array): + array = equinox.error_if( + jnp.array(array), + jnp.any(jnp.array(array) < 0), + "Poission deviates must have a non-negative mean.", + ) self._key, _array = self.__class__._generate_from_exp(self._key, array) return _array diff --git a/tests/GalSim b/tests/GalSim index d8ec29cb..18639cf8 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit d8ec29cbc70a8d4e92bd6bc2f3db2ec248ba3e06 +Subproject commit 18639cf834bc3d7d60f78c626bd1af3f55232e92 From 68118356f3e10bd386d75b528a9075f7f90e8929 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 11:54:31 -0500 Subject: [PATCH 36/51] fix: ensure errors are raised for random permutations --- jax_galsim/random.py | 4 ++++ tests/GalSim | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index c3c149c9..ae09a8e2 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -746,6 +746,10 @@ def __str__(self): ) def permute(rng, *args): rng = BaseDeviate(rng) + if len(args) == 0: + raise TypeError( + f"`galsim.random.permute` must be called with at least one array. Got {args!r}" + ) arrs = [] for arr in args: arrs.append(jrandom.permutation(rng._key, arr)) diff --git a/tests/GalSim b/tests/GalSim index 18639cf8..9aca22b7 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 18639cf834bc3d7d60f78c626bd1af3f55232e92 +Subproject commit 9aca22b740d34b9e446487405018a19a50ecedf9 From 15e1398805f49e5ea6f4829097d8fcd2221f2475 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:16:51 -0500 Subject: [PATCH 37/51] fix: accept integer scalars too --- jax_galsim/random.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index ae09a8e2..b3ef792c 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -125,13 +125,16 @@ def reset(self, seed=None): wrap_key_data(jnp.array(seed, dtype=jnp.uint32)) ) elif ( - isinstance(seed, (int, jnp.ndarray, jax.Array, np.ndarray)) or seed is None + isinstance( + seed, (int, jnp.ndarray, jax.Array, np.ndarray, np.integer, jnp.integer) + ) + or seed is None ): _initial_seed = seed or secrets.randbelow(2**31) self._state = _DeviateState(jrandom.key(_initial_seed)) else: raise TypeError( - "Seeds for BaseDeviate must be an int-like, str, tuple, or another BaseDeviate." + "Seeds for BaseDeviate must be an int-like, str, tuple, or another BaseDeviate. " f"Got seed {seed!r}." ) From 3d645a289ff1fa2f497e55a1e026e7ae63160674 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:30:02 -0500 Subject: [PATCH 38/51] fix: apparently this does not work on tracers --- jax_galsim/random.py | 6 +++--- tests/jax/test_random_jax.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tests/jax/test_random_jax.py diff --git a/jax_galsim/random.py b/jax_galsim/random.py index b3ef792c..4cb7be6d 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import implements +from jax_galsim.core.utils import has_tracers, implements try: from jax.extend.random import wrap_key_data @@ -313,7 +313,7 @@ def __init__(self, seed=None, mean=0.0, sigma=1.0): raise ValueError( f"Gaussian deviates must have a positive sigma. Got {sigma!r}." ) - else: + elif not has_tracers(sigma): sigma = equinox.error_if( jnp.array(sigma), sigma <= 0, @@ -466,7 +466,7 @@ def __init__(self, seed=None, mean=1.0): raise ValueError( f"Poisson deviates must have a non-negative mean. Got {mean!r}." ) - else: + elif not has_tracers(mean): mean = equinox.error_if( jnp.array(mean), mean < 0, diff --git a/tests/jax/test_random_jax.py b/tests/jax/test_random_jax.py new file mode 100644 index 00000000..b01ba9b9 --- /dev/null +++ b/tests/jax/test_random_jax.py @@ -0,0 +1,22 @@ +import jax +import pytest + +import jax_galsim + + +def test_random_jax_gaussian_pos_sigma_jit(): + @jax.jit + def _make_gauss(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + with pytest.raises(Exception): + _make_gauss(-1.0) + + @jax.jit + def _make_gauss(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + _make_gauss(1.0) + + with pytest.raises(Exception): + _make_gauss(-1) From a9f5171ef015884b6e1208c8ca3b5aaae6ebba5f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 12:33:43 -0500 Subject: [PATCH 39/51] Apply suggestions from code review Co-authored-by: Matthew R. Becker --- jax_galsim/core/utils.py | 2 +- jax_galsim/moffat.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 71ff2735..abbc1aca 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -45,7 +45,7 @@ def check_is_int_then_cast(val, msg): if val != int(val): raise TypeError(msg) val = int(val) - else: + elif not has_tracers(val): # otherwise we use more opaque checking upon jit via equinox val = equinox.error_if( val, diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 5d994fe1..8235f309 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -73,7 +73,7 @@ def __init__( raise ValueError( f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." ) - else: + elif not has_tracers(beta): beta = equinox.error_if( jnp.array(beta), beta <= self._beta_thr, From c559408a293a365e1365c6f31501fdd861e7ffc0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:35:10 -0500 Subject: [PATCH 40/51] fix: really on this one --- jax_galsim/random.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 4cb7be6d..10d78b5c 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -309,15 +309,15 @@ def __init__(self, seed=None, mean=0.0, sigma=1.0): super().__init__(seed=seed) if isinstance(sigma, (int, float)): - if sigma <= 0: + if sigma < 0: raise ValueError( - f"Gaussian deviates must have a positive sigma. Got {sigma!r}." + f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}." ) elif not has_tracers(sigma): sigma = equinox.error_if( jnp.array(sigma), - sigma <= 0, - f"Gaussian deviates must have a positive sigma. Got {sigma!r}.", + sigma < 0, + f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}.", ) self._params["mean"] = mean From df2a3d4a31a077edda88cc652a5f9ce798ef7641 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:39:21 -0500 Subject: [PATCH 41/51] fix: more tests --- tests/jax/test_random_jax.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_random_jax.py b/tests/jax/test_random_jax.py index b01ba9b9..c9309eb6 100644 --- a/tests/jax/test_random_jax.py +++ b/tests/jax/test_random_jax.py @@ -13,10 +13,18 @@ def _make_gauss(sigma): _make_gauss(-1.0) @jax.jit - def _make_gauss(sigma): + def _make_gauss_again(sigma): + return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) + + _make_gauss_again(1.0) + + with pytest.raises(Exception): + _make_gauss_again(-1) + + def _make_gauss_again_again(sigma): return jax_galsim.GaussianDeviate(seed=10, sigma=sigma) - _make_gauss(1.0) + _make_gauss_again_again(1.0) with pytest.raises(Exception): - _make_gauss(-1) + jax.jit(_make_gauss_again_again)(-1) From 6d40f5a7f9b2672f2e6ab81289607d6c5026d8a5 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:43:53 -0500 Subject: [PATCH 42/51] fix: ensure we use any or all in calls for errors --- jax_galsim/core/utils.py | 5 +++-- jax_galsim/moffat.py | 5 +++-- jax_galsim/random.py | 16 +++++++++------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index abbc1aca..f6785ceb 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -45,11 +45,12 @@ def check_is_int_then_cast(val, msg): if val != int(val): raise TypeError(msg) val = int(val) - elif not has_tracers(val): + else: # otherwise we use more opaque checking upon jit via equinox + val = jnp.array(val) val = equinox.error_if( val, - val != jnp.trunc(val), + np.any(val != jnp.trunc(val)), msg, ) val = val.astype(int) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 8235f309..387fc4e3 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -74,9 +74,10 @@ def __init__( f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." ) elif not has_tracers(beta): + beta = jnp.array(beta) beta = equinox.error_if( - jnp.array(beta), - beta <= self._beta_thr, + beta, + jnp.any(beta <= self._beta_thr), f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}.", ) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 10d78b5c..789ec692 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import has_tracers, implements +from jax_galsim.core.utils import implements try: from jax.extend.random import wrap_key_data @@ -313,10 +313,11 @@ def __init__(self, seed=None, mean=0.0, sigma=1.0): raise ValueError( f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}." ) - elif not has_tracers(sigma): + else: + sigma = jnp.array(sigma) sigma = equinox.error_if( - jnp.array(sigma), - sigma < 0, + sigma, + jnp.any(sigma < 0), f"Gaussian deviates must have a non-negative sigma. Got {sigma!r}.", ) @@ -466,10 +467,11 @@ def __init__(self, seed=None, mean=1.0): raise ValueError( f"Poisson deviates must have a non-negative mean. Got {mean!r}." ) - elif not has_tracers(mean): + else: + mean = jnp.array(mean) mean = equinox.error_if( - jnp.array(mean), - mean < 0, + mean, + jnp.any(mean < 0), f"Poisson deviates must have a non-negative mean. Got {mean!r}.", ) From 1f0d1da0f49a33e0f20b46f70101456d216f231a Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 12:47:57 -0500 Subject: [PATCH 43/51] fix: use any everywhere --- jax_galsim/gsobject.py | 12 ++++++++---- jax_galsim/integ.py | 2 +- jax_galsim/interpolatedimage.py | 5 +++-- jax_galsim/shear.py | 12 ++++++------ 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index db59db43..a5cf51b7 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -629,10 +629,13 @@ def drawImage( raise TypeError("image is not an Image instance", image) # Make sure (gain, area, exptime) have valid values: - gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") - area = equinox.error_if(jnp.array(area), area <= 0.0, "Invalid area <= 0.") + gain = jnp.array(gain) + gain = equinox.error_if(gain, jnp.any(gain <= 0.0), "Invalid gain <= 0.") + area = jnp.array(area) + area = equinox.error_if(area, jnp.any(area <= 0.0), "Invalid area <= 0.") + exptime = jnp.array(exptime) exptime = equinox.error_if( - jnp.array(exptime), exptime <= 0.0, "Invalid exptime <= 0." + exptime, jnp.any(exptime <= 0.0), "Invalid exptime <= 0." ) if method == "phot" and save_photons and maxN is not None: @@ -1228,7 +1231,8 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") - gain = equinox.error_if(jnp.array(gain), gain <= 0.0, "Invalid gain <= 0.") + gain = jnp.array(gain) + gain = equinox.error_if(gain, jnp.any(gain <= 0.0), "Invalid gain <= 0.") if n_photons is not None: # n_photons is the length of an array so it is a python int and diff --git a/jax_galsim/integ.py b/jax_galsim/integ.py index de1c26ce..20397be2 100644 --- a/jax_galsim/integ.py +++ b/jax_galsim/integ.py @@ -75,6 +75,6 @@ def _base_integration(): return equinox.error_if( val, - status != 0, + jnp.any(status != 0), "`jax_galsim.int1d` failed to converge!", ) diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 722422d1..9564bd36 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -515,9 +515,10 @@ def __init__( ) if calculate_stepk or calculate_maxk or flux is not None: + image.array = jnp.array(image.array) image.array = equinox.error_if( - jnp.array(image.array), - image.array.sum() == 0.0, + image.array, + jnp.any(image.array.sum() == 0.0), "This input image has zero total flux. It does not define a " "valid surface brightness profile.", ) diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index dd88f424..b0663890 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs): self._g = g1 + 1j * g2 self._g = equinox.error_if( self._g, - jnp.abs(self._g) > 1.0, + jnp.any(jnp.abs(self._g) > 1.0), "Requested shear exceeds 1.", ) @@ -64,7 +64,7 @@ def __init__(self, *args, **kwargs): absesq = e1**2 + e2**2 absesq = equinox.error_if( absesq, - absesq > 1.0, + jnp.any(absesq > 1.0), "Requested distortion exceeds 1.", ) self._g = (e1 + 1j * e2) * self._e2g(absesq) @@ -91,7 +91,7 @@ def __init__(self, *args, **kwargs): g = jnp.array(kwargs.pop("g")) g = equinox.error_if( g, - g > 1 or g < 0, + jnp.any((g > 1) | (g < 0)), "Requested |shear| is outside [0,1].", ) self._g = g * jnp.exp(2j * beta.rad) @@ -110,7 +110,7 @@ def __init__(self, *args, **kwargs): e = jnp.array(kwargs.pop("e")) e = equinox.error_if( e, - (e > 1) | (e < 0), + jnp.any((e > 1) | (e < 0)), "Requested distortion is outside [0,1].", ) self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad) @@ -129,7 +129,7 @@ def __init__(self, *args, **kwargs): eta = jnp.array(kwargs.pop("eta")) eta = equinox.error_if( eta, - eta < 0, + jnp.any(eta < 0), "Requested eta is below 0.", ) self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad) @@ -148,7 +148,7 @@ def __init__(self, *args, **kwargs): q = jnp.array(kwargs.pop("q")) q = equinox.error_if( q, - (q <= 0) | (q > 1), + jnp.any((q <= 0) | (q > 1)), "Cannot use requested axis ratio.", ) eta = -jnp.log(q) From a7e2300411d4b2f24eb4d92ca4f126b3620a08fe Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 13:37:12 -0500 Subject: [PATCH 44/51] fix: enable index check for photon arrays --- jax_galsim/photon_array.py | 18 ++++++++++++++++++ tests/GalSim | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 336c28a7..967f55c3 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -561,6 +561,24 @@ def copyFrom( do_flux=True, do_other=True, ): + # jax naturally checks the other error cases in the test suite with the `.at` + # syntax, but it does not check out of bounds inds like ints so we do that here + if isinstance(target_indices, int) and ( + target_indices < -self._nokeep.shape[0] + or target_indices >= self._nokeep.shape[0] + ): + raise ValueError( + f"target_indices is invalid for the target PhotonArray. Got {target_indices!r}" + ) + + if isinstance(source_indices, int) and ( + source_indices < -rhs._nokeep.shape[0] + or source_indices >= rhs._nokeep.shape[0] + ): + raise ValueError( + f"source_indices is invalid for the source PhotonArray. Got {source_indices!r}" + ) + return self._copyFrom( rhs, target_indices, source_indices, do_xy, do_flux, do_other ) diff --git a/tests/GalSim b/tests/GalSim index 9aca22b7..0b4ab35c 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 9aca22b740d34b9e446487405018a19a50ecedf9 +Subproject commit 0b4ab35c1ef5f42ff48d4d082369563fdb440cf7 From 6d970a301f1977de88d08d411997796436b9117e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Fri, 15 May 2026 13:39:24 -0500 Subject: [PATCH 45/51] Apply suggestion from @beckermr --- jax_galsim/moffat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 387fc4e3..932c48b8 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -73,7 +73,7 @@ def __init__( raise ValueError( f"JAX-GalSim does not support Moffat beta values <= {self._beta_thr}." ) - elif not has_tracers(beta): + else: beta = jnp.array(beta) beta = equinox.error_if( beta, From f1498e294203ff1b3b971eca62005bff098cd770 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 15:16:07 -0500 Subject: [PATCH 46/51] fix: enable exceptions for WCS --- jax_galsim/fitswcs.py | 11 +++++------ tests/GalSim | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 232801d6..d64564ab 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1,6 +1,7 @@ import copy import os +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -1094,12 +1095,10 @@ def _step(i, args): unroll=True, )[0:4] - x, y = jax.lax.cond( - jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, - lambda x, y: (x * jnp.nan, y * jnp.nan), - lambda x, y: (x, y), - x, - y, + x, y = equinox.error_if( + (x, y), + jnp.any(jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12), + "Unable to solve for image_pos (max iter reached).", ) return x, y diff --git a/tests/GalSim b/tests/GalSim index 0b4ab35c..ccd0f55e 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 0b4ab35c1ef5f42ff48d4d082369563fdb440cf7 +Subproject commit ccd0f55e7f1952c1e36680786a36169ba26ec19e From 6c526d44c7dbf5b0856fc5372ec80116cbbc1f54 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 17:10:22 -0500 Subject: [PATCH 47/51] fix: raise errors for celestial coords --- jax_galsim/angle.py | 33 +++++++++++++++++++++++++-- jax_galsim/celestial.py | 40 ++++++++++++++++++++------------- tests/GalSim | 2 +- tests/conftest.py | 3 +++ tests/galsim_tests_config.yaml | 12 +++++++--- tests/jax/test_celestial_jax.py | 20 +++++++++++++++++ 6 files changed, 88 insertions(+), 22 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index fad56976..aec36d90 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -21,9 +21,30 @@ # SOFTWARE. import galsim as _galsim import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements +from jax_galsim.core.utils import ( + cast_to_float, + ensure_hashable, + has_tracers, + implements, +) + +NON_COMPLEX_TYPES = ( + float, + int, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, +) @implements(_galsim.AngleUnit) @@ -178,6 +199,10 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): + if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)): + raise TypeError( + "Cannot multiply Angle by %s of type %s" % (other, type(other)) + ) return _Angle(self._rad * other) __rmul__ = __mul__ @@ -185,8 +210,12 @@ def __mul__(self, other): def __div__(self, other): if isinstance(other, AngleUnit): return self._rad / other.value - else: + elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES): return _Angle(self._rad / other) + else: + raise TypeError( + "Cannot divide Angle by %s of type %s" % (other, type(other)) + ) __truediv__ = __div__ diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index 1b6e992f..5645eda3 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -23,9 +23,11 @@ from functools import partial import coord as _coord +import equinox import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians @@ -74,6 +76,16 @@ def __init__(self, ra, dec=None): elif not isinstance(dec, Angle): raise TypeError("dec must be a galsim.Angle") else: + if isinstance(dec._rad, (float, int)): + if dec._rad < -np.pi / 2 or dec._rad > np.pi / 2: + raise ValueError("dec must be between -90 deg and +90 deg.") + else: + dec._rad = equinox.error_if( + jnp.array(dec._rad), + jnp.any((dec._rad < -jnp.pi / 2) | (dec._rad > jnp.pi / 2)), + "dec must be between -90 deg and +90 deg.", + ) + # Normal case self._ra = ra self._dec = dec @@ -121,15 +133,14 @@ def get_xyz(self): @staticmethod @jax.jit - @implements( - _galsim.celestial.CelestialCoord.from_xyz, - lax_description=( - "The JAX version of this static method does not check that the norm of the input " - "vector is non-zero." - ), - ) + @implements(_galsim.celestial.CelestialCoord.from_xyz) def from_xyz(x, y, z): norm = jnp.sqrt(x * x + y * y + z * z) + norm = equinox.error_if( + norm, + jnp.any(norm == 0), + "CelestialCoord for position (0,0,0) is undefined.", + ) ret = CelestialCoord.__new__(CelestialCoord) ret._x = x / norm ret._y = y / norm @@ -236,13 +247,7 @@ def distanceTo(self, coord2): return _Angle(theta) - @implements( - _galsim.celestial.CelestialCoord.greatCirclePoint, - lax_description=( - "The JAX version of this method does not check that coord2 defines a unique great " - "circle with the current coord at angle theta." - ), - ) + @implements(_galsim.celestial.CelestialCoord.greatCirclePoint) @jax.jit def greatCirclePoint(self, coord2, theta): aux = self._get_aux() @@ -280,8 +285,11 @@ def greatCirclePoint(self, coord2, theta): # Normalize wr = (wx**2 + wy**2 + wz**2) ** 0.5 - # if wr == 0.: - # raise ValueError("coord2 does not define a unique great circle with self.") + wr = equinox.error_if( + wr, + jnp.any(wr == 0), + "coord2 does not define a unique great circle with self.", + ) wx /= wr wy /= wr wz /= wr diff --git a/tests/GalSim b/tests/GalSim index ccd0f55e..549616e8 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit ccd0f55e7f1952c1e36680786a36169ba26ec19e +Subproject commit 549616e8ca4bb84142fae6cdb0a006669f92454b diff --git a/tests/conftest.py b/tests/conftest.py index 447c1678..53d2e1c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,6 +98,9 @@ def pytest_collection_modifyitems(config, items): ): item.add_marker(skip) + if any([t in item.nodeid for t in test_config["skipped_tests"]["coord"]]): + item.add_marker(skip) + @lru_cache(maxsize=128) def _infile(val, fname): diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 46a64c18..6595c44f 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -28,6 +28,13 @@ enabled_tests: - test_astropy.py - test_celestial.py +# tests to explicitly skip +# applied on top of the enabled set above +skipped_tests: + coord: + - "tests/Coord/tests/test_celestial.py::test_xyz_raises" + - "tests/Coord/tests/test_celestial.py::test_greatcircle_raises" + # This documents which error messages will be allowed # without being reported as an error. These typically # correspond to features that are not implemented yet @@ -88,9 +95,8 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - - "ValueError not raised by greatCirclePoint" - - "TypeError not raised by __mul__" - - "ValueError not raised by CelestialCoord" + # - "ValueError not raised by greatCirclePoint" + # - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" - "module 'jax_galsim' has no attribute 'fft'" - "Transform does not support callable arguments." diff --git a/tests/jax/test_celestial_jax.py b/tests/jax/test_celestial_jax.py index d4210ac9..8a9cc63e 100644 --- a/tests/jax/test_celestial_jax.py +++ b/tests/jax/test_celestial_jax.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax.numpy as jnp import numpy as np import pytest @@ -118,3 +119,22 @@ def test_celestial_jax_ecliptic_obliquity(): ecliptic_obliquity(epoch).rad, _ecliptic_obliquity(epoch).rad, ) + + +def test_celestial_jax_xyz_raises(): + np.testing.assert_raises( + Exception, jax_galsim.CelestialCoord.from_xyz, 0.0, 0.0, 0.0 + ) + + +def test_celestial_jax_greatcircle_raises(): + theta = 50 * jax_galsim.radians + eq1 = jax_galsim.CelestialCoord( + 0 * jax_galsim.radians, 0 * jax_galsim.radians + ) # point on the equator + eq2 = jax_galsim.CelestialCoord( + jnp.array(1) * jax_galsim.radians, 0 * jax_galsim.radians + ) # 1 radian along equator + + np.testing.assert_raises(Exception, eq1.greatCirclePoint, eq1, theta) + np.testing.assert_raises(Exception, eq2.greatCirclePoint, eq2, theta) From 97ed8d803c323215fa0ebe70453ae594956bfbb9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 17:11:07 -0500 Subject: [PATCH 48/51] fix: remove skip statements from celestial coord --- tests/galsim_tests_config.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 6595c44f..428572a4 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -94,9 +94,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - - "ValueError not raised by from_xyz" - # - "ValueError not raised by greatCirclePoint" - # - "ValueError not raised by CelestialCoord" - "module 'jax_galsim' has no attribute 'BaseCorrelatedNoise'" - "module 'jax_galsim' has no attribute 'fft'" - "Transform does not support callable arguments." From 7486ffb66d261aa7d992eda57ffb6e7bd9736a68 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 18:32:59 -0500 Subject: [PATCH 49/51] test: use ab matrix that is more stable for benchmarks --- tests/jax/test_benchmarks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 303631bd..5c2f7dc6 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -280,9 +280,9 @@ def _run_benchmark_invert_ab_noraise(u, v, ab): @pytest.mark.parametrize("kind", ["run"]) def test_benchmark_invert_ab_noraise(benchmark, kind): - u = jnp.arange(1000).astype(jnp.float64) - v = jnp.arange(1000).astype(jnp.float64) - ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]]) + u = jnp.arange(1000).astype(jnp.float64) / 1000.0 + v = jnp.arange(1000).astype(jnp.float64) / 1000.0 + ab = jnp.array([[[0.6, 0.04], [-0.03, 0.5]], [[0.4, -0.02], [0.01, 0.7]]]) dt = _run_benchmarks( benchmark, kind, From 2610daadacaec019c304c787829e2e90695fb14f Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 11:52:18 -0500 Subject: [PATCH 50/51] Apply suggestions from code review Co-authored-by: Matthew R. Becker --- jax_galsim/random.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 789ec692..8ce333bb 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -105,7 +105,6 @@ def _seed(self, seed=None): @implements( _galsim.BaseDeviate.reset, - lax_description=("The JAX version of this method does no type checking."), ) def reset(self, seed=None): if isinstance(seed, _DeviateState): From 29f83f3cedfa5c4606fae1a8cfbd0ccacaf19cb6 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 12:02:19 -0500 Subject: [PATCH 51/51] fix: add type checking to seed method Add type checking for seed parameter in the seed method. --- jax_galsim/random.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 8ce333bb..2cf2db27 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -91,12 +91,22 @@ def has_reliable_discard(self): def generates_in_pairs(self): return False - @implements( - _galsim.BaseDeviate.seed, - lax_description="The JAX version of this method does no type checking.", - ) + @implements(_galsim.BaseDeviate.seed) def seed(self, seed=None): - self._seed(seed=seed) + if seed is None: + self._seed(seed=seed) + elif isinstance(seed, (int, float, np.integer, np.floating)): + if seed == int(seed): + self._seed(seed=int(seed)) + else: + raise TypeError(f"BaseDeviate seed must be an integer. Got {seed!r}.") + else: + seed = equinox.error_if( + seed, + jnp.any(seed != jnp.trunc(seed)), + "BaseDeviate seed must be an integer.", + ) + self._seed(seed=seed) @implements(_galsim.BaseDeviate._seed) def _seed(self, seed=None):