diff --git a/docs/api/chromatic.rst b/docs/api/chromatic.rst new file mode 100644 index 00000000..402224e7 --- /dev/null +++ b/docs/api/chromatic.rst @@ -0,0 +1,83 @@ +Chromatic Profiles +================== + +.. currentmodule:: jax_galsim + +JAX-GalSim supports a JAX-native subset of GalSim chromatic rendering. The +core use case is a wavelength-dependent PSF convolved with a source whose SED +is a traced JAX array: + +.. code-block:: python + + import jax + import jax.numpy as jnp + import jax_galsim as galsim + + wave = jnp.linspace(400.0, 900.0, 256) + sed = galsim.SED(wave, jnp.ones_like(wave)) + bandpass = galsim.Bandpass.tophat(550.0, 750.0) + + gal = galsim.Gaussian(half_light_radius=0.5) * sed + psf = galsim.ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + final = galsim.ChromaticConvolution([gal, psf]) + image = final.drawImage(bandpass, nx=64, ny=64, scale=0.2, n_waves=32) + +Separable chromatic objects, such as ``GSObject * SED``, are rendered by +integrating only the scalar spectral weight, then drawing one unit-flux spatial +profile. Non-separable chromatic objects, such as a chromatic PSF whose size +changes with wavelength, are rendered by integrating their Fourier-space values +over a fixed wavelength grid. + +The wavelength grid size is static, while SED flux values are traced. This +keeps the rendering path compatible with ``jax.jit`` and ``jax.grad``. + +Spectral objects +---------------- + +.. autoclass:: SED + :members: + :show-inheritance: + +.. autoclass:: Bandpass + :members: + :show-inheritance: + +Chromatic objects +----------------- + +.. autoclass:: ChromaticObject + :members: + :show-inheritance: + +.. autoclass:: Chromatic + :members: + :show-inheritance: + +.. autoclass:: ChromaticAtmosphere + :members: + :show-inheritance: + +.. autoclass:: ChromaticConvolution + :members: + :show-inheritance: + +.. autoclass:: ChromaticSum + :members: + :show-inheritance: + +Compatibility notes +------------------- + +``galsim.Convolve`` dispatches to ``ChromaticConvolution`` when any input is +chromatic. Multiplication follows GalSim's common pattern: + +.. code-block:: python + + chromatic_gal = galsim.Gaussian(half_light_radius=0.5) * sed + same_object = sed * galsim.Gaussian(half_light_radius=0.5) + +The current implementation focuses on differentiable array-backed SEDs, +array-backed bandpasses, separable chromatic sources, and non-separable +Gaussian/Moffat atmospheric PSFs. Full GalSim spectral I/O, lookup-table +metadata, Airy/OpticalPSF chromatic optics, and photon shooting are still +outside this JAX-native subset. diff --git a/docs/api/index.rst b/docs/api/index.rst index 95c5ff72..f61ff2d3 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -10,6 +10,7 @@ API Reference weak-lensing wcs noise + chromatic photon_shooting interpolation fits diff --git a/jax_galsim/__init__.py b/jax_galsim/__init__.py index 80cbe04b..1b681ae5 100644 --- a/jax_galsim/__init__.py +++ b/jax_galsim/__init__.py @@ -105,3 +105,14 @@ # this one is specific to jax_galsim from . import core + +# Chromatic profiles +from .sed import SED +from .bandpass import Bandpass +from .chromatic import ( + ChromaticObject, + Chromatic, + ChromaticAtmosphere, + ChromaticConvolution, + ChromaticSum, +) diff --git a/jax_galsim/bandpass.py b/jax_galsim/bandpass.py new file mode 100644 index 00000000..a6046e28 --- /dev/null +++ b/jax_galsim/bandpass.py @@ -0,0 +1,196 @@ +"""Bandpass filter representation for chromatic simulations.""" + +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + + +@register_pytree_node_class +class Bandpass: + """Wavelength-dependent throughput of an observing bandpass. + + Represents the combined throughput of optics, filter, and detector + as a function of wavelength. The throughput array is a JAX-traced + parameter, so gradients can flow through it if needed (e.g. for + filter-design optimisation). + + Parameters + ---------- + wave : array_like + Wavelength array **in nanometers**, strictly increasing. + Treated as static (not traced by JAX). + throughput : array_like + Dimensionless throughput ∈ [0, 1] at each wavelength. + Treated as a JAX-traced parameter. + blue_limit : float, optional + Override the short-wavelength cut-off in nm. Defaults to + ``wave[0]``. + red_limit : float, optional + Override the long-wavelength cut-off in nm. Defaults to + ``wave[-1]``. + + Examples + -------- + Construct from arrays:: + + >>> import jax.numpy as jnp + >>> from jax_galsim.bandpass import Bandpass + >>> wave = jnp.linspace(550, 700, 100) + >>> bp = Bandpass(wave, jnp.ones(100)) + >>> float(bp(625.0)) + 1.0 + >>> float(bp.effective_wavelength) + 625.0 + + Top-hat filter between 600 nm and 700 nm:: + + >>> wave = jnp.array([500., 600., 600.001, 700., 700.001, 800.]) + >>> thru = jnp.array([0., 0., 1., 1., 0., 0.]) + >>> bp = Bandpass(wave, thru) + """ + + def __init__(self, wave, throughput, blue_limit=None, red_limit=None): + self._wave = jnp.asarray(wave, dtype=float) + self._throughput = jnp.asarray(throughput) + + if self._wave.ndim != 1 or len(self._wave) < 2: + raise ValueError("wave must be a 1-D array with at least 2 elements.") + if len(self._throughput) != len(self._wave): + raise ValueError("throughput must have the same length as wave.") + + self._blue_limit = ( + float(blue_limit) if blue_limit is not None else float(self._wave[0]) + ) + self._red_limit = ( + float(red_limit) if red_limit is not None else float(self._wave[-1]) + ) + + # Precompute effective wavelength at construction time so it is a + # concrete Python float and can be used as a static value under JIT. + _w = jnp.linspace(self._blue_limit, self._red_limit, 512) + _t = jnp.interp(_w, self._wave, self._throughput) + _norm = jnp.trapezoid(_t, _w) + self._effective_wavelength_val = ( + float(jnp.trapezoid(_w * _t, _w) / _norm) + if float(_norm) > 0 + else float(0.5 * (self._blue_limit + self._red_limit)) + ) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def wave(self): + """Wavelength grid in nm (JAX array, static).""" + return self._wave + + @property + def throughput(self): + """Throughput array (JAX array, traced).""" + return self._throughput + + @property + def blue_limit(self): + """Short-wavelength cut-off in nm.""" + return self._blue_limit + + @property + def red_limit(self): + """Long-wavelength cut-off in nm.""" + return self._red_limit + + @property + def effective_wavelength(self): + """Flux-weighted mean wavelength (concrete Python float, safe under JIT).""" + return self._effective_wavelength_val + + # ------------------------------------------------------------------ + # Evaluation + # ------------------------------------------------------------------ + + def __call__(self, wave): + """Evaluate throughput at wavelength(s) in nm. + + Returns 0 outside the defined range. + + Parameters + ---------- + wave : float or array_like + Wavelength(s) in nm. + + Returns + ------- + jnp.ndarray + Throughput at the requested wavelengths. + """ + wave = jnp.asarray(wave, dtype=float) + return jnp.interp(wave, self._wave, self._throughput, left=0.0, right=0.0) + + # ------------------------------------------------------------------ + # Arithmetic + # ------------------------------------------------------------------ + + def __mul__(self, other): + """Multiply throughput by a scalar or another Bandpass.""" + if isinstance(other, Bandpass): + # Union wavelength grid + wave = jnp.unique(jnp.concatenate([self._wave, other._wave])) + t = jnp.interp(wave, self._wave, self._throughput, left=0.0, right=0.0) + t2 = jnp.interp(wave, other._wave, other._throughput, left=0.0, right=0.0) + blue = max(self._blue_limit, other._blue_limit) + red = min(self._red_limit, other._red_limit) + return Bandpass(wave, t * t2, blue_limit=blue, red_limit=red) + return Bandpass( + self._wave, self._throughput * other, self._blue_limit, self._red_limit + ) + + def __rmul__(self, other): + return self.__mul__(other) + + def truncate(self, blue_limit=None, red_limit=None): + """Return a new Bandpass with tighter wavelength limits.""" + blue = blue_limit if blue_limit is not None else self._blue_limit + red = red_limit if red_limit is not None else self._red_limit + return Bandpass(self._wave, self._throughput, blue_limit=blue, red_limit=red) + + # ------------------------------------------------------------------ + # Convenience constructors + # ------------------------------------------------------------------ + + @classmethod + def tophat(cls, blue_limit, red_limit, n_wave=100): + """Uniform throughput = 1 between blue_limit and red_limit.""" + wave = jnp.linspace(blue_limit, red_limit, n_wave) + return cls(wave, jnp.ones(n_wave)) + + # ------------------------------------------------------------------ + # JAX pytree interface + # ------------------------------------------------------------------ + + def tree_flatten(self): + children = (self._throughput,) + aux_data = { + "wave": tuple(self._wave.tolist()), + "blue_limit": self._blue_limit, + "red_limit": self._red_limit, + } + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + wave=jnp.asarray(aux_data["wave"], dtype=float), + throughput=children[0], + blue_limit=aux_data["blue_limit"], + red_limit=aux_data["red_limit"], + ) + + # ------------------------------------------------------------------ + # Misc + # ------------------------------------------------------------------ + + def __repr__(self): + return ( + f"Bandpass(wave=[{self._blue_limit:.1f}, {self._red_limit:.1f}] nm, " + f"lam_eff={self.effective_wavelength:.1f} nm)" + ) diff --git a/jax_galsim/chromatic.py b/jax_galsim/chromatic.py new file mode 100644 index 00000000..aaba0e81 --- /dev/null +++ b/jax_galsim/chromatic.py @@ -0,0 +1,992 @@ +"""Chromatic (wavelength-dependent) profiles for jax_galsim. + +Architecture overview +--------------------- +Every chromatic object exposes: + +* ``evaluateAtWavelength(wave) -> GSObject`` + Returns the monochromatic profile at wavelength *wave* (nm). The + returned GSObject may carry wavelength-dependent parameters as JAX + traced values, so the function is vmappable. + +* ``drawImage(bandpass, n_waves=64, **kwargs) -> Image`` + Integrates the profile over *bandpass* using a static wavelength grid + of *n_waves* points (trapezoid rule). ``n_waves`` is static at + JIT-compile time; everything else may be traced. + +Hierarchy +--------- +:: + + ChromaticObject base class, non-separable draw by default + ├── Chromatic GSObject × SED (separable, fast path) + ├── ChromaticAtmosphere Gaussian PSF with FWHM ∝ λ^alpha + └── ChromaticConvolution convolution of any chromatic objects + +Separable vs non-separable +-------------------------- +*Separable* means ``I(x, y, λ) = g(x, y) × h(λ)``. For a separable +object the integration reduces to a single monochromatic draw: + + flux = ∫ SED(λ) × BP(λ) dλ + image = g(x, y) drawn with total flux + +*Non-separable* objects (e.g. ChromaticAtmosphere) must evaluate the +k-space image at every wavelength sample and integrate: + + K_eff(k) = ∫ K(k, λ) × BP(λ) dλ + image = IFFT[ K_eff ] + +JAX compatibility +----------------- +* All arithmetic uses ``jax.numpy``. +* Wavelength grids are **static** (fixed-size) to allow ``jax.jit``. +* ``jax.vmap`` vectorises the per-wavelength k-value computation. +* ``jax.grad`` flows through SED flux arrays (e.g. DSPS outputs). +""" + +import jax +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + +from jax_galsim.gsparams import GSParams +from jax_galsim.position import PositionD + + +def _pixel_scale_from_kwargs(kwargs): + """Return pixel scale in world units (arcsec/pixel) as a concrete float. + + Reads ``scale`` from kwargs; falls back to reading the WCS object if + the ``wcs`` kwarg is provided, or 1.0 otherwise. The value from + ``kwargs['scale']`` is always a Python float (user-supplied literal), + so this is safe to call inside ``jax.jit``. + """ + if "scale" in kwargs: + return float(kwargs["scale"]) + if "wcs" in kwargs: + wcs = kwargs["wcs"] + if hasattr(wcs, "_scale"): + return float(wcs._scale) + return 1.0 + + +def _make_setup_image(profile, kwargs): + """Build draw target without running full ``drawImage`` setup when size is fixed.""" + from jax_galsim.image import Image + + image = kwargs.get("image", None) + if image is not None: + return Image(image=image) + + image_kwargs = { + "dtype": kwargs.get("dtype", None), + "scale": kwargs.get("scale", None), + "wcs": kwargs.get("wcs", None), + } + image_kwargs = {k: v for k, v in image_kwargs.items() if v is not None} + + bounds = kwargs.get("bounds", None) + if bounds is not None: + return Image(bounds=bounds, **image_kwargs) + + nx = kwargs.get("nx", None) + ny = kwargs.get("ny", None) + if nx is not None and ny is not None: + return Image(nx, ny, **image_kwargs) + + return profile.drawImage(setup_only=True, **kwargs) + + +def _next_pow2(n): + n = int(n) + out = 1 + while out < n: + out *= 2 + return out + + +def _fix_fft_size_for_image(profile, image): + """Pin FFT size for JIT-safe setup when output bounds are already fixed.""" + nrow, ncol = image.array.shape + n = max(128, _next_pow2(2 * max(nrow, ncol))) + n = min(n, profile.gsparams.maximum_fft_size) + return profile.withGSParams(minimum_fft_size=n, maximum_fft_size=n) + + +def _static_kcoords(kimage, wrap_size, pixel_scale): + """Return k-space pixel centers as a JAX array with static shape.""" + nrow, ncol = kimage.array.shape + x = jnp.arange(ncol, dtype=float) + y = jnp.arange(nrow, dtype=float) - wrap_size // 2 + kx, ky = jnp.meshgrid(x, y) + dk = 2.0 * jnp.pi / (wrap_size * pixel_scale) + return jnp.stack([kx.ravel(), ky.ravel()], axis=-1) * dk + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + + +class ChromaticObject: + """Base class for wavelength-dependent profiles. + + Subclasses must override :meth:`evaluateAtWavelength`. The default + :meth:`drawImage` uses a k-space trapezoid integration that works for + any subclass; separable subclasses override it with a faster path. + """ + + #: Set True in separable subclasses. + _separable: bool = False + + def __init__(self, obj=None): + if obj is None: + self._base_obj = None + return + + from jax_galsim.gsobject import GSObject + + if not isinstance(obj, GSObject): + raise TypeError("ChromaticObject requires a GSObject.") + self._base_obj = obj + self._separable = True + + @property + def separable(self): + """True if the profile factors as g(x,y) × h(λ).""" + return self._separable + + # ------------------------------------------------------------------ + # Interface that subclasses must implement + # ------------------------------------------------------------------ + + def evaluateAtWavelength(self, wave): + """Return the monochromatic GSObject at wavelength *wave* (nm). + + This method must be JAX-traceable: all internal computations + should use ``jax.numpy``, and the returned GSObject's parameters + may be abstract JAX tracers. + + Parameters + ---------- + wave : float or jax scalar + Wavelength in nm. + + Returns + ------- + GSObject + """ + if getattr(self, "_base_obj", None) is not None: + return self._base_obj + raise NotImplementedError( + f"{self.__class__.__name__} must implement evaluateAtWavelength." + ) + + # ------------------------------------------------------------------ + # Drawing + # ------------------------------------------------------------------ + + def drawImage(self, bandpass, n_waves=64, **kwargs): + """Draw the bandpass-integrated image. + + Parameters + ---------- + bandpass : Bandpass + Observing bandpass. + n_waves : int + Number of wavelength samples for numerical integration. + Must be a **static** integer (fixed at JIT compile time). + **kwargs + Forwarded to the underlying ``GSObject.drawImage`` calls. + Typical keys: ``nx``, ``ny``, ``scale``, ``method``, etc. + + Returns + ------- + Image + """ + if self._separable: + return self._drawImage_separable(bandpass, n_waves, **kwargs) + return self._drawImage_nonseparable(bandpass, n_waves, **kwargs) + + # ------------------------------------------------------------------ + # Separable fast path + # ------------------------------------------------------------------ + + def _drawImage_separable(self, bandpass, n_waves, **kwargs): + """FFT draw, scaling by total SED×BP flux (traced-safe under JIT). + + Design: split into two phases. + + **Phase 1 - static setup**: + Build image bounds, k-grid, and base k-values using the unit-flux + spatial profile. All shape parameters must be concrete Python + scalars at this stage (true for ``Chromatic`` where the spatial + profile has static params); the SED flux is NOT evaluated here. + + **Phase 2 — traced computation**: + Compute total_flux = ∫ SED(λ)×BP(λ) dλ (may be a JAX traced + value, e.g. DSPS output), multiply into k-values, IFFT. + """ + from jax_galsim.box import Pixel + from jax_galsim.convolve import Convolve + from jax_galsim.image import Image + + wave_eff = bandpass.effective_wavelength # concrete Python float + pixel_scale = _pixel_scale_from_kwargs(kwargs) + + # ------------------------------------------------------------------ + # Phase 1: concrete setup. Under jit this runs during tracing. + # ------------------------------------------------------------------ + with jax.disable_jit(): + # Unit-flux spatial profile — shape params are concrete Python floats + spatial_prof = self._static_spatial_profile(wave_eff) + image = _make_setup_image(spatial_prof, kwargs) + spatial_prof = _fix_fft_size_for_image(spatial_prof, image) + original_center = image.center + original_wcs = image.wcs + image.setCenter(0, 0) + + pixel = Pixel(scale=pixel_scale, gsparams=spatial_prof.gsparams) + prof_conv = Convolve([spatial_prof, pixel], gsparams=spatial_prof.gsparams) + kimage, wrap_size = prof_conv.drawFFT_makeKImage(image) + + kcoords = _static_kcoords(kimage, wrap_size, pixel_scale) + + # Static k-values (unit flux, no SED scaling). + kvals_static = jax.vmap(lambda k: prof_conv._kValue(PositionD(k[0], k[1])))( + kcoords + ) + + # Apply the same -0.5 pixel centering correction that gsobject._adjust_offset + # uses for even-sized images (avoids 0.5-pixel offset vs non-chromatic draws). + img_shape = image.array.shape # (ny, nx); unchanged by setCenter + dx_corr = -0.5 * pixel_scale * ((img_shape[1] + 1) % 2) + dy_corr = -0.5 * pixel_scale * ((img_shape[0] + 1) % 2) + phase_corr = jnp.exp( + -1j * (kcoords[:, 0] * dx_corr + kcoords[:, 1] * dy_corr) + ) + kvals_static = kvals_static * phase_corr + + kshape = kimage.array.shape + kbounds = kimage.bounds + kwcs = kimage.wcs + + # ------------------------------------------------------------------ + # Phase 2: traced computation — integrate SED × bandpass + # ------------------------------------------------------------------ + waves = jnp.linspace(bandpass.blue_limit, bandpass.red_limit, n_waves) + weights = jax.vmap(lambda w: self._sed_value(w) * bandpass(w))(waves) + total_flux = jnp.trapezoid(weights, waves) + + # Scale pre-computed k-values by traced total flux + kvals = kvals_static * total_flux + karray = kvals.reshape(kshape).astype(kimage.dtype) + eff_kimage = Image(array=karray, bounds=kbounds, wcs=kwcs, _check_bounds=False) + prof_conv.drawFFT_finish(image, eff_kimage, wrap_size, add_to_image=False) + + image.shift(original_center) + image.wcs = original_wcs + return image + + def _static_spatial_profile(self, wave_eff): + """Return unit-flux spatial profile at *wave_eff* with static params. + + *wave_eff* must be a concrete Python float (use + ``bandpass.effective_wavelength``). Subclasses override this when + they have a dedicated static spatial object (e.g. ``Chromatic`` has + ``self.obj``). The default calls ``evaluateAtWavelength`` with a + Python float — works when all shape params are Python scalars. + """ + return self.evaluateAtWavelength(float(wave_eff)).withFlux(1.0) + + def _sed_value(self, wave): + """SED flux density at *wave*. Subclasses override if needed.""" + if getattr(self, "_base_obj", None) is not None: + return jnp.ones((), dtype=float) + raise NotImplementedError + + # ------------------------------------------------------------------ + # Non-separable k-space integration + # ------------------------------------------------------------------ + + def _drawImage_nonseparable(self, bandpass, n_waves, **kwargs): + """Integrate in k-space: K_eff = ∫ K(k,λ) × BP(λ) dλ, then IFFT. + + Mirrors the drawFFT pipeline of GSObject.drawImage exactly, split + into a concrete setup phase and a traced computation phase. + """ + from jax_galsim.box import Pixel + from jax_galsim.convolve import Convolve + from jax_galsim.image import Image + + wave_eff = bandpass.effective_wavelength # static Python float + pixel_scale = _pixel_scale_from_kwargs(kwargs) + + # ------------------------------------------------------------------ + # Phase 1: concrete setup. Under jit this runs during tracing. + # ------------------------------------------------------------------ + with jax.disable_jit(): + prof0 = self._static_spatial_profile(wave_eff) + image = _make_setup_image(prof0, kwargs) + prof0 = _fix_fft_size_for_image(prof0, image) + original_center = image.center + original_wcs = image.wcs + image.setCenter(0, 0) + + pixel = Pixel(scale=pixel_scale, gsparams=prof0.gsparams) + prof0_conv = Convolve([prof0, pixel], gsparams=prof0.gsparams) + kimage, wrap_size = prof0_conv.drawFFT_makeKImage(image) + + kcoords = _static_kcoords(kimage, wrap_size, pixel_scale) + + pixel_kvals = jax.vmap(lambda k: pixel._kValue(PositionD(k[0], k[1])))( + kcoords + ) + + # Apply the -0.5 pixel centering correction for even-sized images. + img_shape = image.array.shape + dx_corr = -0.5 * pixel_scale * ((img_shape[1] + 1) % 2) + dy_corr = -0.5 * pixel_scale * ((img_shape[0] + 1) % 2) + phase_corr = jnp.exp( + -1j * (kcoords[:, 0] * dx_corr + kcoords[:, 1] * dy_corr) + ) + pixel_kvals = pixel_kvals * phase_corr + + kshape = kimage.array.shape + kbounds = kimage.bounds + kwcs = kimage.wcs + + # ------------------------------------------------------------------ + # Phase 2: traced computation + # ------------------------------------------------------------------ + waves = jnp.linspace(bandpass.blue_limit, bandpass.red_limit, n_waves) + + def kvals_at_wave(wave): + prof = self.evaluateAtWavelength(wave) + kv = jax.vmap(lambda k: prof._kValue(PositionD(k[0], k[1])))(kcoords) + return kv * bandpass(wave) + + all_kvals = jax.vmap(kvals_at_wave)(waves) # (n_waves, n_k) + eff_kvals = jnp.trapezoid(all_kvals, waves, axis=0) + eff_kvals = eff_kvals * pixel_kvals + + eff_karray = eff_kvals.reshape(kshape).astype(kimage.dtype) + eff_kimage = Image( + array=eff_karray, + bounds=kbounds, + wcs=kwcs, + _check_bounds=False, + ) + + # IFFT back to pixel space + prof0_conv.drawFFT_finish(image, eff_kimage, wrap_size, add_to_image=False) + + # Restore original center and WCS (same as drawImage does after drawFFT) + image.shift(original_center) + image.wcs = original_wcs + + return image + + # ------------------------------------------------------------------ + # Operator overloads + # ------------------------------------------------------------------ + + def __add__(self, other): + return ChromaticSum([self, other]) + + def __radd__(self, other): + return ChromaticSum([other, self]) + + def __mul__(self, other): + from jax_galsim.sed import SED + + if isinstance(other, SED): + base_obj = getattr(self, "_base_obj", None) + if base_obj is None: + raise TypeError( + "Only achromatic ChromaticObject wrappers can be multiplied by SED." + ) + return Chromatic(base_obj, other) + return _ScaledChromaticObject(self, other) + + def __rmul__(self, other): + return self.__mul__(other) + + +class _ScaledChromaticObject(ChromaticObject): + """Chromatic object with wavelength-independent flux scaling.""" + + def __init__(self, obj, scale): + self.obj = obj + self.scale = scale + self._separable = obj._separable + + def evaluateAtWavelength(self, wave): + return self.obj.evaluateAtWavelength(wave).withScaledFlux(self.scale) + + def _static_spatial_profile(self, wave_eff): + return self.obj._static_spatial_profile(wave_eff) + + def _sed_value(self, wave): + return self.obj._sed_value(wave) * self.scale + + +# --------------------------------------------------------------------------- +# ChromaticSum — sum of chromatic objects +# --------------------------------------------------------------------------- + + +class ChromaticSum(ChromaticObject): + """Sum of two or more chromatic profiles. + + The combined SED is the sum of all component SEDs. Drawing + evaluates each component at every wavelength and sums. + + Parameters + ---------- + obj_list : list of ChromaticObject + """ + + _separable = False # conservative; optimised later if all are separable + + def __init__(self, *args): + if len(args) == 0: + raise TypeError("At least one object must be provided.") + if len(args) == 1 and isinstance(args[0], (list, tuple)): + self.obj_list = list(args[0]) + else: + self.obj_list = list(args) + self._separable = all(o._separable for o in self.obj_list) + + def evaluateAtWavelength(self, wave): + from jax_galsim.sum import Sum + + return Sum([o.evaluateAtWavelength(wave) for o in self.obj_list]) + + def drawImage(self, bandpass, n_waves=64, **kwargs): + # Draw each component and sum + images = [ + obj.drawImage(bandpass, n_waves=n_waves, **kwargs) for obj in self.obj_list + ] + result = images[0] + for img in images[1:]: + result._array = result._array + img._array + return result + + +# --------------------------------------------------------------------------- +# Chromatic — separable GSObject × SED +# --------------------------------------------------------------------------- + + +@register_pytree_node_class +class Chromatic(ChromaticObject): + """Separable chromatic profile: a GSObject multiplied by an SED. + + The spatial profile is fixed; wavelength dependence enters only + through the SED flux scaling. + + ``I(x, y, λ) = g(x, y) × SED(λ)`` + + Drawing a ``Chromatic`` through a bandpass reduces to a single + monochromatic draw at the effective wavelength with total flux + ``∫ SED(λ) × BP(λ) dλ``. + + Parameters + ---------- + obj : GSObject + Normalised spatial profile (flux = 1 by convention, though + any flux is allowed and will be multiplied by the SED). + sed : SED + Spectral energy distribution. + + Examples + -------- + :: + + >>> from jax_galsim import Gaussian + >>> from jax_galsim.sed import SED + >>> from jax_galsim.bandpass import Bandpass + >>> import jax.numpy as jnp + + >>> wave = jnp.linspace(500, 900, 256) + >>> sed = SED(wave, jnp.ones(256)) + >>> bp = Bandpass.tophat(550, 750) + >>> gal = Gaussian(half_light_radius=0.5) * sed + >>> img = gal.drawImage(bp, scale=0.2, nx=64, ny=64) + """ + + _separable = True + + def __init__(self, obj, sed): + self.obj = obj + self.sed = sed + + def evaluateAtWavelength(self, wave): + """Return the GSObject scaled to SED(wave).""" + return self.obj.withScaledFlux(self.sed(wave)) + + def _sed_value(self, wave): + return self.sed(wave) + + def _static_spatial_profile(self, wave_eff): + """Return self.obj (unit flux) — shape params are always static.""" + return self.obj.withFlux(1.0) + + # JAX pytree: obj and sed are both pytrees + def tree_flatten(self): + children = (self.obj, self.sed) + aux_data = {} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(obj=children[0], sed=children[1]) + + def __repr__(self): + return f"Chromatic({self.obj!r}, {self.sed!r})" + + +# --------------------------------------------------------------------------- +# ChromaticAtmosphere — seeing PSF with wavelength-dependent FWHM +# --------------------------------------------------------------------------- + + +@register_pytree_node_class +class ChromaticAtmosphere(ChromaticObject): + """Atmospheric PSF with a power-law wavelength-dependent FWHM. + + The PSF profile at wavelength λ is a Gaussian (or Moffat) with: + + FWHM(λ) = fwhm_ref × (λ / lam_ref)^alpha + + For Kolmogorov turbulence, the expected scaling is α ≈ −0.2. + + This profile carries a **flat (dimensionless) SED**: ``SED(λ) = 1``. + The physical SED is typically attached to the galaxy component via + :class:`Chromatic`, and passed to :class:`ChromaticConvolution`. + + Parameters + ---------- + fwhm_ref : float + FWHM in arcseconds at the reference wavelength. + lam_ref : float + Reference wavelength in nm. + alpha : float, optional + Power-law index. Default −0.2 (Kolmogorov). + profile : {'gaussian', 'moffat'} + Profile type. Default ``'gaussian'``. + moffat_beta : float, optional + Moffat β parameter (only used when ``profile='moffat'``). + Default 4.765 (typical for atmospheric seeing). + gsparams : GSParams, optional + + Examples + -------- + :: + + >>> from jax_galsim.chromatic import ChromaticAtmosphere + >>> psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + >>> prof = psf.evaluateAtWavelength(550.0) # Gaussian at 550 nm + """ + + _separable = False + + def __init__( + self, + fwhm_ref, + lam_ref, + alpha=-0.2, + profile="gaussian", + moffat_beta=4.765, + gsparams=None, + ): + # All shape params stored as Python floats (static, not JAX-traced). + # This enables jax.jit compatibility without pinning the FFT size. + # To differentiate through fwhm_ref, use gsparams with fixed fft_size: + # GSParams(minimum_fft_size=N, maximum_fft_size=N) + self._fwhm_ref = float(fwhm_ref) + self._lam_ref = float(lam_ref) + self._alpha = float(alpha) + self._profile = profile + self._moffat_beta = float(moffat_beta) + self._gsparams = GSParams.check(gsparams) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def fwhm_ref(self): + return self._fwhm_ref + + @property + def lam_ref(self): + return self._lam_ref + + @property + def alpha(self): + return self._alpha + + # ------------------------------------------------------------------ + # Core method + # ------------------------------------------------------------------ + + def evaluateAtWavelength(self, wave): + """Return the PSF profile (unit flux) at wavelength *wave* (nm). + + FWHM is scaled as ``fwhm_ref × (wave / lam_ref)^alpha``. + When *wave* is a JAX tracer (inside vmap/jit), fwhm is also traced. + """ + fwhm = self._fwhm_ref * (wave / self._lam_ref) ** self._alpha + + if self._profile == "gaussian": + from jax_galsim.gaussian import Gaussian + + return Gaussian(fwhm=fwhm, flux=1.0, gsparams=self._gsparams) + + elif self._profile == "moffat": + from jax_galsim.moffat import Moffat + + return Moffat( + fwhm=fwhm, + beta=self._moffat_beta, + flux=1.0, + gsparams=self._gsparams, + ) + else: + raise ValueError( + f"Unknown profile type '{self._profile}'. " + "Expected 'gaussian' or 'moffat'." + ) + + def _sed_value(self, wave): + """Flat SED: unit flux at all wavelengths.""" + return jnp.ones((), dtype=float) + + # ------------------------------------------------------------------ + # JAX pytree interface + # ------------------------------------------------------------------ + + def tree_flatten(self): + # No traced children — all params are static Python scalars. + children = () + aux_data = { + "fwhm_ref": self._fwhm_ref, + "lam_ref": self._lam_ref, + "alpha": self._alpha, + "profile": self._profile, + "moffat_beta": self._moffat_beta, + "gsparams": self._gsparams, + } + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + fwhm_ref=aux_data["fwhm_ref"], + lam_ref=aux_data["lam_ref"], + alpha=aux_data["alpha"], + profile=aux_data["profile"], + moffat_beta=aux_data["moffat_beta"], + gsparams=aux_data["gsparams"], + ) + + def __repr__(self): + return ( + f"ChromaticAtmosphere(fwhm_ref={float(self._fwhm_ref):.3f}, " + f"lam_ref={self._lam_ref:.0f} nm, alpha={self._alpha:.2f}, " + f"profile={self._profile!r})" + ) + + +# --------------------------------------------------------------------------- +# ChromaticConvolution — convolution of chromatic objects +# --------------------------------------------------------------------------- + + +class ChromaticConvolution(ChromaticObject): + """Convolution of multiple chromatic profiles. + + Computes the bandpass-integrated image of the convolution: + + K_eff(k) = ∫ ∏_i K_i(k, λ) × BP(λ) dλ + + where each ``K_i(k, λ)`` is the k-space value of the i-th component + at wavelength λ. + + **Optimised case**: when all components except one are separable + (e.g. galaxy × SED convolved with a chromatic PSF), the separable + profiles are extracted and convolved with the wavelength-integrated + effective PSF. This avoids multiplying the same galaxy k-image at + every wavelength sample. + + Parameters + ---------- + obj_list : list of ChromaticObject or GSObject + Components to convolve. Plain ``GSObject`` instances are wrapped + automatically in a flat-SED :class:`Chromatic`. + + Examples + -------- + :: + + >>> from jax_galsim import Gaussian, Convolve + >>> from jax_galsim.chromatic import Chromatic, ChromaticAtmosphere, ChromaticConvolution + >>> from jax_galsim.sed import SED + >>> from jax_galsim.bandpass import Bandpass + >>> import jax.numpy as jnp + + >>> wave = jnp.linspace(300, 1100, 512) + >>> flux = jnp.exp(-0.5 * ((wave - 700) / 150) ** 2) # Gaussian SED + >>> sed = SED(wave, flux) + >>> bp = Bandpass.tophat(550, 800) + + >>> gal = Gaussian(half_light_radius=0.5) * sed + >>> psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + >>> final = ChromaticConvolution([gal, psf]) + >>> img = final.drawImage(bp, scale=0.2, nx=64, ny=64) + """ + + _separable = False + + def __init__(self, *args, **kwargs): + from jax_galsim.gsobject import GSObject + + if len(args) == 0: + raise TypeError("At least one object must be provided.") + if len(args) == 1 and isinstance(args[0], (list, tuple)): + obj_list = list(args[0]) + else: + obj_list = list(args) + + real_space = kwargs.pop("real_space", None) + self._gsparams = GSParams.check(kwargs.pop("gsparams", None)) + self._propagate_gsparams = kwargs.pop("propagate_gsparams", True) + if kwargs: + raise TypeError( + "ChromaticConvolution constructor got unexpected keyword argument(s): %s" + % kwargs.keys() + ) + if real_space: + raise NotImplementedError( + "Real-space chromatic convolutions are not implemented" + ) + + # Wrap plain GSObjects with a flat SED so they fit the interface + processed = [] + for obj in obj_list: + if isinstance(obj, ChromaticConvolution): + processed.extend(obj.obj_list) + continue + if isinstance(obj, GSObject): + from jax_galsim.sed import SED + + wave_stub = jnp.array([100.0, 2000.0]) + flat_sed = SED(wave_stub, jnp.ones(2)) + processed.append(Chromatic(obj, flat_sed)) + else: + processed.append(obj) + self.obj_list = processed + + @property + def gsparams(self): + return self._gsparams + + def withGSParams(self, gsparams=None, **kwargs): + ret = self.__class__.__new__(self.__class__) + ret.obj_list = self.obj_list + ret._gsparams = GSParams.check(gsparams, self._gsparams, **kwargs) + ret._propagate_gsparams = self._propagate_gsparams + return ret + + def evaluateAtWavelength(self, wave): + """Return the convolved monochromatic profile at *wave* (nm).""" + from jax_galsim.convolve import Convolve + + return Convolve([o.evaluateAtWavelength(wave) for o in self.obj_list]) + + # ------------------------------------------------------------------ + # Drawing with the optimised separable split + # ------------------------------------------------------------------ + + def drawImage(self, bandpass, n_waves=64, **kwargs): + """Draw the integrated image, exploiting separability where possible. + + Algorithm: + + 1. Separate components into *separable* (galaxy × SED) and + *non-separable* (chromatic PSF). + 2. Build the combined weight function: + ``w(λ) = ∏_sep SED_i(λ) × BP(λ)`` + 3. In k-space, integrate non-separable components: + ``K_nonsep_eff(k) = ∫ ∏_nonsep K_i(k,λ) × w(λ) dλ`` + 4. Multiply by separable component k-values (λ-independent after + extracting their SED into the weight). + 5. Include pixel convolution, IFFT to real space. + + All-separable case falls back to a single monochromatic draw. + """ + from jax_galsim.box import Pixel + from jax_galsim.convolve import Convolve + from jax_galsim.image import Image + + sep_objs = [o for o in self.obj_list if o._separable] + nonsep_objs = [o for o in self.obj_list if not o._separable] + + wave_eff = bandpass.effective_wavelength # static Python float + waves = jnp.linspace(bandpass.blue_limit, bandpass.red_limit, n_waves) + pixel_scale = _pixel_scale_from_kwargs(kwargs) + + # ------------------------------------------------------------------- + # Phase 1: concrete setup. Under jit this runs during tracing. + # + # Shape parameters (sigma, fwhm) must NOT be JIT-traced inputs here. + # If they are, pin the FFT size via gsparams(min/max_fft_size=N). + # ------------------------------------------------------------------- + with jax.disable_jit(): + if not nonsep_objs: + # All-separable: use static spatial profiles (avoids traced SED) + spatial_profs = [o._static_spatial_profile(wave_eff) for o in sep_objs] + grid_prof = Convolve(spatial_profs) + else: + # Mixed: sep objects use static profiles; nonsep objects are evaluated + # at wave_eff (their shape params must be concrete at this point). + fiducial_profs = [ + o._static_spatial_profile(wave_eff) + if o._separable + else o.evaluateAtWavelength(wave_eff) + for o in self.obj_list + ] + grid_prof = Convolve(fiducial_profs) + + image = _make_setup_image(grid_prof, kwargs) + grid_prof = _fix_fft_size_for_image(grid_prof, image) + original_center = image.center + original_wcs = image.wcs + image.setCenter(0, 0) + + pixel = Pixel(scale=pixel_scale) + grid_prof_conv = Convolve([grid_prof, pixel], gsparams=grid_prof.gsparams) + kimage, wrap_size = grid_prof_conv.drawFFT_makeKImage(image) + + # k-space coordinates with static shape. + kcoords = _static_kcoords(kimage, wrap_size, pixel_scale) + n_k = kcoords.shape[0] + + # Pixel k-values. + pixel_kvals = jax.vmap(lambda k: pixel._kValue(PositionD(k[0], k[1])))( + kcoords + ) + + # Match gsobject._adjust_offset: even-sized images need a -0.5 pixel + # true-center correction before the FFT draw. Apply it as a k-space + # phase so chromatic draws align with monochromatic drawImage. + img_shape = image.array.shape # (ny, nx); unchanged by setCenter + dx_corr = -0.5 * pixel_scale * ((img_shape[1] + 1) % 2) + dy_corr = -0.5 * pixel_scale * ((img_shape[0] + 1) % 2) + phase_corr = jnp.exp( + -1j * (kcoords[:, 0] * dx_corr + kcoords[:, 1] * dy_corr) + ) + + if not nonsep_objs: + # Pre-compute k-values of the full (spatial+pixel) convolution + base_kvals = ( + jax.vmap(lambda k: grid_prof_conv._kValue(PositionD(k[0], k[1])))( + kcoords + ) + * phase_corr + ) + else: + pixel_kvals = pixel_kvals * phase_corr + + # Pre-compute k-values of separable components (unit flux each) + sep_kvals = jnp.ones(n_k, dtype=complex) + for o in sep_objs: + prof_sep = o._static_spatial_profile(wave_eff) + sep_kvals = sep_kvals * jax.vmap( + lambda k: prof_sep._kValue(PositionD(k[0], k[1])) + )(kcoords) + + kshape = kimage.array.shape + kbounds = kimage.bounds + kwcs = kimage.wcs + + # ------------------------------------------------------------------- + # Phase 2: traced computation (JAX-traced values allowed here) + # ------------------------------------------------------------------- + if not nonsep_objs: + # All-separable: integrate SED × bandpass → total flux, then scale + def combined_weight(wave): + w = bandpass(wave) + for o in sep_objs: + w = w * o._sed_value(wave) + return w + + total_flux = jnp.trapezoid(jax.vmap(combined_weight)(waves), waves) + kvals = base_kvals * total_flux + + else: + # Mixed sep + nonsep: integrate nonsep k-values weighted by sep SED × BP + def sep_weight(wave): + w = bandpass(wave) + for o in sep_objs: + w = w * o._sed_value(wave) + return w + + def kvals_nonsep_at_wave(wave): + kv = jnp.ones(n_k, dtype=complex) + for o in nonsep_objs: + prof = o.evaluateAtWavelength(wave) + kv = kv * jax.vmap(lambda k: prof._kValue(PositionD(k[0], k[1])))( + kcoords + ) + return kv * sep_weight(wave) + + all_kvals = jax.vmap(kvals_nonsep_at_wave)(waves) # (n_waves, n_k) + eff_kvals = jnp.trapezoid(all_kvals, waves, axis=0) # (n_k,) + + # Multiply by separable spatial k-values and pixel convolution + kvals = eff_kvals * sep_kvals * pixel_kvals + + karray = kvals.reshape(kshape).astype(kimage.dtype) + eff_kimage = Image(array=karray, bounds=kbounds, wcs=kwcs, _check_bounds=False) + grid_prof_conv.drawFFT_finish(image, eff_kimage, wrap_size, add_to_image=False) + + image.shift(original_center) + image.wcs = original_wcs + return image + + def __repr__(self): + inner = ", ".join(repr(o) for o in self.obj_list) + return f"ChromaticConvolution([{inner}])" + + +# --------------------------------------------------------------------------- +# Monkey-patch GSObject.__mul__ to return Chromatic when multiplied by SED +# --------------------------------------------------------------------------- + + +def _gsobject_mul_sed(self, other): + """Allow ``gsobject * sed → Chromatic(gsobject, sed)``.""" + from jax_galsim.sed import SED + + if isinstance(other, SED): + return Chromatic(self, other) + # Fall through to original implementation (flux scaling) + return self.withScaledFlux(other) + + +def _gsobject_rmul_sed(self, other): + return _gsobject_mul_sed(self, other) + + +# Apply the patch once at import time +def _patch_gsobject(): + from jax_galsim.gsobject import GSObject + + GSObject.__mul__ = _gsobject_mul_sed + GSObject.__rmul__ = _gsobject_rmul_sed + + +_patch_gsobject() diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index c8c4c1ae..ab0258c7 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -11,7 +11,7 @@ @implements( _galsim.Convolve, - lax_description="""Does not support ChromaticConvolutions""", + lax_description="""Supports ChromaticConvolutions for FFT drawing only.""", ) def Convolve(*args, **kwargs): if len(args) == 0: @@ -28,6 +28,11 @@ def Convolve(*args, **kwargs): ) # else args is already the list of objects + from jax_galsim.chromatic import ChromaticConvolution, ChromaticObject + + if any(isinstance(obj, ChromaticObject) for obj in args): + return ChromaticConvolution(args, **kwargs) + return Convolution(*args, **kwargs) diff --git a/jax_galsim/sed.py b/jax_galsim/sed.py new file mode 100644 index 00000000..bf9b816d --- /dev/null +++ b/jax_galsim/sed.py @@ -0,0 +1,240 @@ +"""Spectral Energy Distribution (SED) for chromatic profiles. + +Designed for JAX compatibility: the flux array is a traced parameter, +so gradients flow through SED values (e.g. from DSPS outputs). +""" + +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class + + +@register_pytree_node_class +class SED: + """Spectral Energy Distribution. + + Represents flux density as a function of wavelength, designed for + full JAX compatibility. The ``flux`` array is a JAX-traced parameter, + enabling gradients through SED parameters (e.g. outputs of DSPS). + + The wavelength grid (``wave``) is treated as static auxiliary data: + it defines the interpolation structure and is not differentiated. + + Parameters + ---------- + wave : array_like + Wavelength array **in nanometers**. Must be strictly increasing. + Treated as static (not traced by JAX). + flux : array_like + Flux density at each wavelength. Treated as a JAX-traced parameter. + Units are arbitrary but must be consistent across the simulation + (typically photons / nm / cm² / s for spectral SEDs, or + dimensionless for shape-only profiles). + redshift : float, optional + Cosmological redshift applied to the SED. The observed wavelength + grid is shifted to ``wave_obs = wave_rest * (1 + redshift)``. + Default 0. + + Examples + -------- + Basic construction from arrays:: + + >>> import jax.numpy as jnp + >>> from jax_galsim.sed import SED + >>> wave = jnp.linspace(300, 1100, 512) # nm + >>> flux = jnp.ones(512) + >>> sed = SED(wave, flux) + >>> float(sed(550.0)) + 1.0 + + DSPS workflow — flux is a traced JAX array:: + + >>> flux = dsps_model(params) # JAX array + >>> sed = SED(dsps_wave_nm, flux) + >>> image = chromatic_galaxy.drawImage(bandpass) # differentiable + + Redshifted SED:: + + >>> sed_z = SED(wave, flux, redshift=0.5) + >>> sed_z(825.0) # queries rest-frame 550 nm + """ + + def __init__(self, wave, flux, redshift=0.0): + self._wave = jnp.asarray(wave, dtype=float) # static, not traced + self._flux = jnp.asarray(flux) # traced + self._redshift = jnp.asarray(redshift, dtype=float) + + if self._wave.ndim != 1 or len(self._wave) < 2: + raise ValueError("wave must be a 1-D array with at least 2 elements.") + if len(self._flux) != len(self._wave): + raise ValueError("flux must have the same length as wave.") + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def wave(self): + """Rest-frame wavelength grid in nm (JAX array, static).""" + return self._wave + + @property + def flux(self): + """Flux density array (JAX array, traced).""" + return self._flux + + @property + def redshift(self): + """Redshift (JAX scalar).""" + return self._redshift + + @property + def blue_limit(self): + """Shortest observed wavelength in nm.""" + return float(self._wave[0]) * (1.0 + float(self._redshift)) + + @property + def red_limit(self): + """Longest observed wavelength in nm.""" + return float(self._wave[-1]) * (1.0 + float(self._redshift)) + + # ------------------------------------------------------------------ + # Evaluation + # ------------------------------------------------------------------ + + def __call__(self, wave): + """Evaluate flux density at observed wavelength(s) in nm. + + Uses linear interpolation; returns 0 outside the defined range. + + Parameters + ---------- + wave : float or array_like + Observed wavelength(s) in nm. + + Returns + ------- + jnp.ndarray + Flux density at the requested wavelengths. + """ + wave = jnp.asarray(wave, dtype=float) + # Convert observed wavelength to rest-frame before interpolating + wave_rest = wave / (1.0 + self._redshift) + return jnp.interp(wave_rest, self._wave, self._flux, left=0.0, right=0.0) + + # ------------------------------------------------------------------ + # Flux through a bandpass + # ------------------------------------------------------------------ + + def calculateFlux(self, bandpass, n_waves=512): + """Integrate SED through a bandpass: ``∫ SED(λ) × BP(λ) dλ``. + + Parameters + ---------- + bandpass : Bandpass + Observing bandpass. + n_waves : int, optional + Number of quadrature points. Default 512. + + Returns + ------- + jnp.ndarray + Scalar flux value. + """ + waves = jnp.linspace(bandpass.blue_limit, bandpass.red_limit, n_waves) + return jnp.trapezoid(self(waves) * bandpass(waves), waves) + + # ------------------------------------------------------------------ + # Arithmetic + # ------------------------------------------------------------------ + + def withRedshift(self, redshift): + """Return a copy of this SED at a new redshift.""" + return SED(self._wave, self._flux, redshift) + + def __mul__(self, other): + """Multiply SED by a scalar or another SED. + + SED × scalar scales all flux values. + SED × SED multiplies flux densities (both evaluated on ``self``'s grid). + """ + if isinstance(other, SED): + # Evaluate other SED on self's rest-frame grid and multiply + other_flux = jnp.interp( + self._wave * (1.0 + self._redshift), + other._wave * (1.0 + other._redshift), + other._flux, + left=0.0, + right=0.0, + ) + return SED(self._wave, self._flux * other_flux, self._redshift) + from jax_galsim.gsobject import GSObject + + if isinstance(other, GSObject): + from jax_galsim.chromatic import Chromatic + + return Chromatic(other, self) + from jax_galsim.chromatic import ChromaticObject + + if isinstance(other, ChromaticObject): + return other * self + return SED(self._wave, self._flux * other, self._redshift) + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + return SED(self._wave, self._flux / other, self._redshift) + + def __add__(self, other): + if isinstance(other, SED): + other_flux = jnp.interp( + self._wave * (1.0 + self._redshift), + other._wave * (1.0 + other._redshift), + other._flux, + left=0.0, + right=0.0, + ) + return SED(self._wave, self._flux + other_flux, self._redshift) + return SED(self._wave, self._flux + other, self._redshift) + + # ------------------------------------------------------------------ + # JAX pytree interface + # ------------------------------------------------------------------ + + def tree_flatten(self): + """Flatten for JAX tracing. + + ``flux`` and ``redshift`` are traced children. + ``wave`` is static auxiliary data (the interpolation grid). + """ + children = (self._flux, self._redshift) + # wave must be hashable for JAX cache keys; store as tuple of floats. + aux_data = {"wave": tuple(self._wave.tolist())} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + wave=jnp.asarray(aux_data["wave"], dtype=float), + flux=children[0], + redshift=children[1], + ) + + # ------------------------------------------------------------------ + # Misc + # ------------------------------------------------------------------ + + def __repr__(self): + return ( + f"SED(wave=[{self._wave[0]:.1f}, ..., {self._wave[-1]:.1f}] nm, " + f"n_wave={len(self._wave)}, redshift={float(self._redshift):.4f})" + ) + + def __eq__(self, other): + if not isinstance(other, SED): + return False + return ( + jnp.array_equal(self._wave, other._wave) + and jnp.array_equal(self._flux, other._flux) + and jnp.array_equal(self._redshift, other._redshift) + ) diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 428572a4..1391e880 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -21,6 +21,9 @@ enabled_tests: - test_noise.py - test_image.py - test_photon_array.py + - test_sed.py + - test_bandpass.py + - test_chromatic.py - "*" # means all tests from galsim coord: - test_angle.py @@ -86,7 +89,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'Dict'" - "module 'jax_galsim' has no attribute 'OutputCatalog'" - "module 'jax_galsim' has no attribute 'cdmodel'" - - "module 'jax_galsim' has no attribute 'ChromaticObject'" - "module 'jax_galsim' has no attribute 'ChromaticAiry'" - "module 'jax_galsim' has no attribute 'config'" - "module 'jax_galsim' has no attribute 'RealGalaxyCatalog'" @@ -101,13 +103,20 @@ allowed_failures: - "pad_image not implemented in jax_galsim." - "InterpolatedImages do not support noise padding in jax_galsim." - "module 'jax_galsim' has no attribute 'FittedSIPWCS'" - - "module 'jax_galsim' has no attribute 'Bandpass'" - "module 'jax_galsim' has no attribute 'Refraction'" - "module 'jax_galsim' has no attribute 'FRatioAngles'" - "module 'jax_galsim' has no attribute 'PupilAnnulusSampler'" - "module 'jax_galsim' has no attribute 'TimeSampler'" - "object has no attribute 'noise'" - - "module 'jax_galsim' has no attribute 'SED'" + # JAX-native chromatic subset: array-backed SED/Bandpass only. + # Full GalSim spectral I/O, expression parsing, LookupTable metadata, + # transforms, thinning, photon shooting, and units are not implemented yet. + - "SED.__init__() got an unexpected keyword argument 'spec'" + - "SED.__init__() got an unexpected keyword argument 'wave_type'" + - "Bandpass.__init__() got an unexpected keyword argument 'wave_type'" + - "could not convert string to float" + - "is not a valid JAX array type" + - "'ChromaticObject' object has no attribute 'dilate'" - "module 'jax_galsim' has no attribute 'getCOSMOSNoise'" - "GSParams.__init__() got an unexpected keyword argument 'allowed_flux_variation'" - "module 'jax_galsim' has no attribute 'Atmosphere'" @@ -135,7 +144,6 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'Aperture'" - "module 'jax_galsim' has no attribute 'AtmosphericScreen'" - "module 'jax_galsim' has no attribute 'OpticalScreen'" - - "module 'jax_galsim' has no attribute 'ChromaticConvolution'" - "module 'jax_galsim' has no attribute 'phase_screens'" - "module 'jax_galsim' has no attribute 'DistDeviate'" - "module 'jax_galsim' has no attribute 'roman'" diff --git a/tests/jax/test_chromatic_jax.py b/tests/jax/test_chromatic_jax.py new file mode 100644 index 00000000..f8644721 --- /dev/null +++ b/tests/jax/test_chromatic_jax.py @@ -0,0 +1,411 @@ +"""Tests for jax_galsim chromatic PSF support. + +Verifies: +- SED and Bandpass construction and evaluation +- Chromatic (separable) drawImage correctness +- ChromaticAtmosphere (non-separable) drawImage correctness +- ChromaticConvolution (galaxy × SED ⊗ PSF) correctness +- jax.jit compatibility for all paths +- jax.grad compatibility for SED-flux differentiation +- Pytree round-trip (tree_flatten / tree_unflatten) +- Numerical agreement with analytic expectations +""" + +# ruff: noqa: E402,I001 + +import jax +import jax.numpy as jnp +import pytest + +# Enable float64 for accuracy +jax.config.update("jax_enable_x64", True) + +import jax_galsim as jgal +from jax_galsim.chromatic import ChromaticAtmosphere, ChromaticConvolution + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +WAVE = jnp.linspace(400.0, 900.0, 256) # nm +BP = jgal.Bandpass.tophat(550.0, 750.0) # 200 nm wide, throughput = 1 +BP_NARROW = jgal.Bandpass.tophat(600.0, 700.0) # 100 nm wide + + +def flat_sed(scale=1.0): + return jgal.SED(WAVE, jnp.ones(256) * scale) + + +# --------------------------------------------------------------------------- +# SED tests +# --------------------------------------------------------------------------- + + +def test_sed_evaluation(): + sed = flat_sed() + assert float(sed(600.0)) == pytest.approx(1.0, rel=1e-5) + assert float(sed(300.0)) == pytest.approx(0.0) # outside range + + +def test_sed_redshift(): + sed = jgal.SED(WAVE, jnp.ones(256), redshift=1.0) + # observed 800 nm → rest-frame 400 nm → flux = 1.0 + assert float(sed(800.0)) == pytest.approx(1.0, rel=1e-4) + # observed 400 nm → rest-frame 200 nm → outside range → 0 + assert float(sed(400.0)) == pytest.approx(0.0, abs=1e-6) + + +def test_sed_calculate_flux(): + sed = flat_sed() + flux = float(sed.calculateFlux(BP)) + # ∫_550^750 1 dλ = 200 nm + assert flux == pytest.approx(200.0, rel=1e-3) + + +def test_sed_pytree_roundtrip(): + sed = flat_sed(2.0) + leaves, treedef = jax.tree_util.tree_flatten(sed) + sed2 = jax.tree_util.tree_unflatten(treedef, leaves) + assert float(sed2(600.0)) == pytest.approx(2.0, rel=1e-5) + + +def test_sed_arithmetic(): + sed1 = flat_sed(2.0) + sed2 = sed1 * 3.0 + assert float(sed2(600.0)) == pytest.approx(6.0, rel=1e-5) + + sed3 = sed1 + flat_sed(1.0) + assert float(sed3(600.0)) == pytest.approx(3.0, rel=1e-5) + + +# --------------------------------------------------------------------------- +# Bandpass tests +# --------------------------------------------------------------------------- + + +def test_bandpass_evaluation(): + bp = jgal.Bandpass.tophat(550.0, 750.0) + assert float(bp(625.0)) == pytest.approx(1.0) + assert float(bp(500.0)) == pytest.approx(0.0) + assert float(bp(800.0)) == pytest.approx(0.0) + + +def test_bandpass_effective_wavelength(): + bp = jgal.Bandpass.tophat(550.0, 750.0) + lam_eff = bp.effective_wavelength + assert isinstance(lam_eff, float) + assert lam_eff == pytest.approx(650.0, rel=1e-4) + + +def test_bandpass_effective_wavelength_concrete(): + """effective_wavelength must be a concrete Python float (safe under JIT).""" + bp = jgal.Bandpass.tophat(550.0, 750.0) + lam_eff = bp.effective_wavelength + # If this were a JAX tracer, float() would raise ConcretizationTypeError + assert isinstance(lam_eff, float) + + +def test_bandpass_mul(): + bp1 = jgal.Bandpass.tophat(500.0, 700.0) + bp2 = jgal.Bandpass.tophat(600.0, 800.0) + bp = bp1 * bp2 + assert float(bp(650.0)) == pytest.approx(1.0) + assert float(bp(550.0)) == pytest.approx(0.0) + assert float(bp(750.0)) == pytest.approx(0.0) + + +def test_bandpass_pytree_roundtrip(): + bp = jgal.Bandpass.tophat(550.0, 750.0) + leaves, treedef = jax.tree_util.tree_flatten(bp) + bp2 = jax.tree_util.tree_unflatten(treedef, leaves) + assert float(bp2(625.0)) == pytest.approx(1.0) + assert bp2.effective_wavelength == pytest.approx(650.0, rel=1e-4) + + +# --------------------------------------------------------------------------- +# Chromatic (separable) tests +# --------------------------------------------------------------------------- + + +def test_chromatic_construction(): + sed = flat_sed() + gal = jgal.Gaussian(half_light_radius=0.5) * sed + assert gal._separable + + +def test_chromatic_drawImage_flux(): + """Image pixel sum should equal ∫ SED(λ) × BP(λ) dλ.""" + sed = flat_sed() + gal = jgal.Gaussian(half_light_radius=0.5) * sed + img = gal.drawImage(BP, scale=0.2, nx=32, ny=32) + # ∫_550^750 1 dλ = 200 + assert float(img.array.sum()) == pytest.approx(200.0, rel=5e-3) + + +def test_chromatic_drawImage_narrow_bandpass(): + """Narrower bandpass → smaller total flux.""" + sed = flat_sed() + gal = jgal.Gaussian(half_light_radius=0.5) * sed + img = gal.drawImage(BP_NARROW, scale=0.2, nx=32, ny=32) + assert float(img.array.sum()) == pytest.approx(100.0, rel=5e-3) + + +def test_chromatic_jit(): + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + return gal.drawImage(BP, scale=0.2, nx=32, ny=32).array.sum() + + result = render(jnp.ones(256)) + assert float(result) == pytest.approx(200.0, rel=5e-3) + + +def test_chromatic_grad(): + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + return gal.drawImage(BP, scale=0.2, nx=32, ny=32).array.sum() + + grad = jax.grad(render)(jnp.ones(256)) + grad_arr = jnp.asarray(grad) + + # Gradient sums to total bandpass flux ≈ 200 + assert float(grad.sum()) == pytest.approx(200.0, rel=5e-2) + # Outside bandpass → zero gradient + idx_out = int((420 - 400) / (900 - 400) * 255) + assert float(grad[idx_out]) == pytest.approx(0.0, abs=1e-8) + # Inside bandpass region has positive total contribution + mask_in = (WAVE >= 550) & (WAVE <= 750) + assert float(grad_arr[mask_in].sum()) > 0.0 + + +def test_chromatic_jit_recompile(): + """JIT should reuse compiled code when called twice with different flux.""" + + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + return gal.drawImage(BP, scale=0.2, nx=32, ny=32).array.sum() + + r1 = float(render(jnp.ones(256))) + r2 = float(render(jnp.ones(256) * 2.0)) + assert r2 == pytest.approx(r1 * 2.0, rel=1e-4) + + +# --------------------------------------------------------------------------- +# ChromaticAtmosphere tests +# --------------------------------------------------------------------------- + + +def test_chromatic_atmosphere_evaluate(): + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + prof = psf.evaluateAtWavelength(700.0) + # At reference wavelength, FWHM should be exactly fwhm_ref + # jax_galsim Gaussian exposes sigma; FWHM = sigma * fwhm_factor + assert isinstance(prof, jgal.Gaussian) + fwhm = float(prof.sigma) * jgal.Gaussian._fwhm_factor + assert fwhm == pytest.approx(0.7, rel=1e-4) + + +def test_chromatic_atmosphere_scaling(): + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + prof_blue = psf.evaluateAtWavelength(350.0) + prof_red = psf.evaluateAtWavelength(700.0) + # FWHM ∝ λ^alpha = λ^(-0.2) → bluer is wider (alpha < 0) + assert float(prof_blue.sigma) > float(prof_red.sigma) + + +def test_chromatic_atmosphere_drawImage_flux(): + """Total flux = ∫ BP(λ) × 1 dλ = 200 (flat SED, unit PSF flux).""" + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + img = psf.drawImage(BP, scale=0.2, nx=32, ny=32) + assert float(img.array.sum()) == pytest.approx(200.0, rel=5e-3) + + +def test_chromatic_atmosphere_moffat(): + psf = ChromaticAtmosphere( + fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2, profile="moffat", moffat_beta=4.765 + ) + prof = psf.evaluateAtWavelength(700.0) + assert isinstance(prof, jgal.Moffat) + img = psf.drawImage(BP, scale=0.2, nx=32, ny=32) + assert float(img.array.sum()) == pytest.approx(200.0, rel=5e-2) + + +def test_chromatic_atmosphere_pytree(): + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + leaves, treedef = jax.tree_util.tree_flatten(psf) + psf2 = jax.tree_util.tree_unflatten(treedef, leaves) + assert psf2._fwhm_ref == pytest.approx(0.7) + assert psf2._alpha == pytest.approx(-0.2) + + +# --------------------------------------------------------------------------- +# ChromaticConvolution tests +# --------------------------------------------------------------------------- + + +def test_chromatic_convolution_flux(): + """Galaxy × SED ⊗ ChromaticAtmosphere: flux = ∫ SED(λ) × BP(λ) dλ.""" + sed = flat_sed() + gal = jgal.Gaussian(half_light_radius=0.5) * sed + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + final = ChromaticConvolution([gal, psf]) + img = final.drawImage(BP, scale=0.2, nx=64, ny=64, n_waves=32) + assert float(img.array.sum()) == pytest.approx(200.0, rel=5e-2) + + +def test_chromatic_convolution_jit(): + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + return ( + ChromaticConvolution([gal, psf]) + .drawImage(BP, scale=0.2, nx=64, ny=64, n_waves=32) + .array.sum() + ) + + result = render(jnp.ones(256)) + assert float(result) == pytest.approx(200.0, rel=5e-2) + + +def test_chromatic_convolution_grad(): + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + return ( + ChromaticConvolution([gal, psf]) + .drawImage(BP, scale=0.2, nx=64, ny=64, n_waves=32) + .array.sum() + ) + + grad = jax.grad(render)(jnp.ones(256)) + grad_arr = jnp.asarray(grad) + + # Gradient is nonzero only at wavelengths that fall inside the bandpass + # (the 32 quadrature points lie in [550, 750] nm). + # - sum of all gradients ≈ total bandpass flux = 200 + assert float(grad.sum()) == pytest.approx(200.0, rel=5e-2) + # - max gradient is positive + assert float(grad.max()) > 0.0 + # - indices corresponding to wavelengths outside bandpass have zero gradient + # ~420 nm → outside [550, 750] bandpass + idx_out = int((420 - 400) / (900 - 400) * 255) + assert float(grad[idx_out]) == pytest.approx(0.0, abs=1e-8) + # - indices well inside bandpass region have nonzero total contribution + mask_in = (WAVE >= 550) & (WAVE <= 750) + assert float(grad_arr[mask_in].sum()) > 0.0 + + +def test_chromatic_convolution_linearity(): + """Doubling SED flux doubles image sum.""" + + @jax.jit + def render(flux): + sed = jgal.SED(WAVE, flux) + gal = jgal.Gaussian(half_light_radius=0.5) * sed + psf = ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) + return ( + ChromaticConvolution([gal, psf]) + .drawImage(BP, scale=0.2, nx=64, ny=64, n_waves=32) + .array.sum() + ) + + s1 = float(render(jnp.ones(256))) + s2 = float(render(jnp.ones(256) * 2.0)) + assert s2 == pytest.approx(s1 * 2.0, rel=1e-4) + + +# --------------------------------------------------------------------------- +# ChromaticSum tests +# --------------------------------------------------------------------------- + + +def test_chromatic_sum_flux(): + sed1 = flat_sed(1.0) + sed2 = flat_sed(2.0) + gal1 = jgal.Gaussian(half_light_radius=0.3) * sed1 + gal2 = jgal.Gaussian(half_light_radius=0.8) * sed2 + combined = gal1 + gal2 + img = combined.drawImage(BP, scale=0.2, nx=32, ny=32) + # Total flux = (1 + 2) × 200 = 600 + assert float(img.array.sum()) == pytest.approx(600.0, rel=1e-2) + + +# --------------------------------------------------------------------------- +# Monkey-patch (GSObject * SED) tests +# --------------------------------------------------------------------------- + + +def test_gsobject_mul_sed(): + sed = flat_sed() + gal = jgal.Gaussian(half_light_radius=0.5) + from jax_galsim.chromatic import Chromatic + + result = gal * sed + assert isinstance(result, Chromatic) + + +def test_gsobject_mul_scalar(): + gal = jgal.Gaussian(half_light_radius=0.5, flux=1.0) + scaled = gal * 5.0 + assert float(scaled.flux) == pytest.approx(5.0) + + +if __name__ == "__main__": + # Run all tests inline for quick check + import sys + + tests = [ + test_sed_evaluation, + test_sed_redshift, + test_sed_calculate_flux, + test_sed_pytree_roundtrip, + test_sed_arithmetic, + test_bandpass_evaluation, + test_bandpass_effective_wavelength, + test_bandpass_effective_wavelength_concrete, + test_bandpass_mul, + test_bandpass_pytree_roundtrip, + test_chromatic_construction, + test_chromatic_drawImage_flux, + test_chromatic_drawImage_narrow_bandpass, + test_chromatic_jit, + test_chromatic_grad, + test_chromatic_jit_recompile, + test_chromatic_atmosphere_evaluate, + test_chromatic_atmosphere_scaling, + test_chromatic_atmosphere_drawImage_flux, + test_chromatic_atmosphere_moffat, + test_chromatic_atmosphere_pytree, + test_chromatic_convolution_flux, + test_chromatic_convolution_jit, + test_chromatic_convolution_grad, + test_chromatic_convolution_linearity, + test_chromatic_sum_flux, + test_gsobject_mul_sed, + test_gsobject_mul_scalar, + ] + + failed = [] + for t in tests: + try: + t() + print(f" PASS {t.__name__}") + except Exception as e: + print(f" FAIL {t.__name__}: {e}") + failed.append(t.__name__) + + print() + print(f"{len(tests) - len(failed)}/{len(tests)} passed") + if failed: + print("Failed:", failed) + sys.exit(1)