-
Notifications
You must be signed in to change notification settings - Fork 9
FEAT - Add chromatic psf and the corresponding tests (Issue #251) #253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b890474
3a2a14f
67dcbae
c0cd2de
539826b
19e960d
0d0f51f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| Chromatic Profiles | ||
| ================== | ||
|
|
||
| .. currentmodule:: jax_galsim | ||
|
|
||
| JAX-GalSim supports a JAX-native subset of GalSim chromatic rendering. The | ||
| core use case is a wavelength-dependent PSF convolved with a source whose SED | ||
| is a traced JAX array: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import jax_galsim as galsim | ||
|
|
||
| wave = jnp.linspace(400.0, 900.0, 256) | ||
| sed = galsim.SED(wave, jnp.ones_like(wave)) | ||
| bandpass = galsim.Bandpass.tophat(550.0, 750.0) | ||
|
|
||
| gal = galsim.Gaussian(half_light_radius=0.5) * sed | ||
| psf = galsim.ChromaticAtmosphere(fwhm_ref=0.7, lam_ref=700.0, alpha=-0.2) | ||
| final = galsim.ChromaticConvolution([gal, psf]) | ||
| image = final.drawImage(bandpass, nx=64, ny=64, scale=0.2, n_waves=32) | ||
|
|
||
| Separable chromatic objects, such as ``GSObject * SED``, are rendered by | ||
| integrating only the scalar spectral weight, then drawing one unit-flux spatial | ||
| profile. Non-separable chromatic objects, such as a chromatic PSF whose size | ||
| changes with wavelength, are rendered by integrating their Fourier-space values | ||
| over a fixed wavelength grid. | ||
|
|
||
| The wavelength grid size is static, while SED flux values are traced. This | ||
| keeps the rendering path compatible with ``jax.jit`` and ``jax.grad``. | ||
|
|
||
| Spectral objects | ||
| ---------------- | ||
|
|
||
| .. autoclass:: SED | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: Bandpass | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| Chromatic objects | ||
| ----------------- | ||
|
|
||
| .. autoclass:: ChromaticObject | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: Chromatic | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: ChromaticAtmosphere | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: ChromaticConvolution | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| .. autoclass:: ChromaticSum | ||
| :members: | ||
| :show-inheritance: | ||
|
|
||
| Compatibility notes | ||
| ------------------- | ||
|
|
||
| ``galsim.Convolve`` dispatches to ``ChromaticConvolution`` when any input is | ||
| chromatic. Multiplication follows GalSim's common pattern: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| chromatic_gal = galsim.Gaussian(half_light_radius=0.5) * sed | ||
| same_object = sed * galsim.Gaussian(half_light_radius=0.5) | ||
|
|
||
| The current implementation focuses on differentiable array-backed SEDs, | ||
| array-backed bandpasses, separable chromatic sources, and non-separable | ||
| Gaussian/Moffat atmospheric PSFs. Full GalSim spectral I/O, lookup-table | ||
| metadata, Airy/OpticalPSF chromatic optics, and photon shooting are still | ||
| outside this JAX-native subset. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ API Reference | |
| weak-lensing | ||
| wcs | ||
| noise | ||
| chromatic | ||
| photon_shooting | ||
| interpolation | ||
| fits | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,196 @@ | ||||||
| """Bandpass filter representation for chromatic simulations.""" | ||||||
|
|
||||||
| import jax.numpy as jnp | ||||||
| from jax.tree_util import register_pytree_node_class | ||||||
|
|
||||||
|
|
||||||
| @register_pytree_node_class | ||||||
| class Bandpass: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not repeat doc strings from upstream and instead use the |
||||||
| """Wavelength-dependent throughput of an observing bandpass. | ||||||
|
|
||||||
| Represents the combined throughput of optics, filter, and detector | ||||||
| as a function of wavelength. The throughput array is a JAX-traced | ||||||
| parameter, so gradients can flow through it if needed (e.g. for | ||||||
| filter-design optimisation). | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| wave : array_like | ||||||
| Wavelength array **in nanometers**, strictly increasing. | ||||||
| Treated as static (not traced by JAX). | ||||||
| throughput : array_like | ||||||
| Dimensionless throughput ∈ [0, 1] at each wavelength. | ||||||
| Treated as a JAX-traced parameter. | ||||||
| blue_limit : float, optional | ||||||
| Override the short-wavelength cut-off in nm. Defaults to | ||||||
| ``wave[0]``. | ||||||
| red_limit : float, optional | ||||||
| Override the long-wavelength cut-off in nm. Defaults to | ||||||
| ``wave[-1]``. | ||||||
|
|
||||||
| Examples | ||||||
| -------- | ||||||
| Construct from arrays:: | ||||||
|
|
||||||
| >>> import jax.numpy as jnp | ||||||
| >>> from jax_galsim.bandpass import Bandpass | ||||||
| >>> wave = jnp.linspace(550, 700, 100) | ||||||
| >>> bp = Bandpass(wave, jnp.ones(100)) | ||||||
| >>> float(bp(625.0)) | ||||||
| 1.0 | ||||||
| >>> float(bp.effective_wavelength) | ||||||
| 625.0 | ||||||
|
|
||||||
| Top-hat filter between 600 nm and 700 nm:: | ||||||
|
|
||||||
| >>> wave = jnp.array([500., 600., 600.001, 700., 700.001, 800.]) | ||||||
| >>> thru = jnp.array([0., 0., 1., 1., 0., 0.]) | ||||||
| >>> bp = Bandpass(wave, thru) | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, wave, throughput, blue_limit=None, red_limit=None): | ||||||
| self._wave = jnp.asarray(wave, dtype=float) | ||||||
| self._throughput = jnp.asarray(throughput) | ||||||
|
|
||||||
| if self._wave.ndim != 1 or len(self._wave) < 2: | ||||||
| raise ValueError("wave must be a 1-D array with at least 2 elements.") | ||||||
| if len(self._throughput) != len(self._wave): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It may be more robust to check the shape directly? |
||||||
| raise ValueError("throughput must have the same length as wave.") | ||||||
|
|
||||||
| self._blue_limit = ( | ||||||
| float(blue_limit) if blue_limit is not None else float(self._wave[0]) | ||||||
| ) | ||||||
| self._red_limit = ( | ||||||
| float(red_limit) if red_limit is not None else float(self._wave[-1]) | ||||||
| ) | ||||||
|
|
||||||
| # Precompute effective wavelength at construction time so it is a | ||||||
| # concrete Python float and can be used as a static value under JIT. | ||||||
| _w = jnp.linspace(self._blue_limit, self._red_limit, 512) | ||||||
| _t = jnp.interp(_w, self._wave, self._throughput) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This interpolation is linear. There is a slightly better akima interpolant that is fast to compute on-the-fly. See |
||||||
| _norm = jnp.trapezoid(_t, _w) | ||||||
| self._effective_wavelength_val = ( | ||||||
| float(jnp.trapezoid(_w * _t, _w) / _norm) | ||||||
| if float(_norm) > 0 | ||||||
| else float(0.5 * (self._blue_limit + self._red_limit)) | ||||||
| ) | ||||||
|
Comment on lines
+67
to
+76
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here |
||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # Properties | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| @property | ||||||
| def wave(self): | ||||||
| """Wavelength grid in nm (JAX array, static).""" | ||||||
| return self._wave | ||||||
|
|
||||||
| @property | ||||||
| def throughput(self): | ||||||
| """Throughput array (JAX array, traced).""" | ||||||
| return self._throughput | ||||||
|
|
||||||
| @property | ||||||
| def blue_limit(self): | ||||||
| """Short-wavelength cut-off in nm.""" | ||||||
| return self._blue_limit | ||||||
|
|
||||||
| @property | ||||||
| def red_limit(self): | ||||||
| """Long-wavelength cut-off in nm.""" | ||||||
| return self._red_limit | ||||||
|
|
||||||
| @property | ||||||
| def effective_wavelength(self): | ||||||
| """Flux-weighted mean wavelength (concrete Python float, safe under JIT).""" | ||||||
| return self._effective_wavelength_val | ||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # Evaluation | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| def __call__(self, wave): | ||||||
| """Evaluate throughput at wavelength(s) in nm. | ||||||
|
|
||||||
| Returns 0 outside the defined range. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| wave : float or array_like | ||||||
| Wavelength(s) in nm. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| jnp.ndarray | ||||||
| Throughput at the requested wavelengths. | ||||||
| """ | ||||||
| wave = jnp.asarray(wave, dtype=float) | ||||||
| return jnp.interp(wave, self._wave, self._throughput, left=0.0, right=0.0) | ||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # Arithmetic | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| def __mul__(self, other): | ||||||
| """Multiply throughput by a scalar or another Bandpass.""" | ||||||
| if isinstance(other, Bandpass): | ||||||
| # Union wavelength grid | ||||||
| wave = jnp.unique(jnp.concatenate([self._wave, other._wave])) | ||||||
| t = jnp.interp(wave, self._wave, self._throughput, left=0.0, right=0.0) | ||||||
| t2 = jnp.interp(wave, other._wave, other._throughput, left=0.0, right=0.0) | ||||||
| blue = max(self._blue_limit, other._blue_limit) | ||||||
| red = min(self._red_limit, other._red_limit) | ||||||
| return Bandpass(wave, t * t2, blue_limit=blue, red_limit=red) | ||||||
| return Bandpass( | ||||||
| self._wave, self._throughput * other, self._blue_limit, self._red_limit | ||||||
| ) | ||||||
|
|
||||||
| def __rmul__(self, other): | ||||||
| return self.__mul__(other) | ||||||
|
|
||||||
| def truncate(self, blue_limit=None, red_limit=None): | ||||||
| """Return a new Bandpass with tighter wavelength limits.""" | ||||||
| blue = blue_limit if blue_limit is not None else self._blue_limit | ||||||
| red = red_limit if red_limit is not None else self._red_limit | ||||||
| return Bandpass(self._wave, self._throughput, blue_limit=blue, red_limit=red) | ||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # Convenience constructors | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| @classmethod | ||||||
| def tophat(cls, blue_limit, red_limit, n_wave=100): | ||||||
| """Uniform throughput = 1 between blue_limit and red_limit.""" | ||||||
| wave = jnp.linspace(blue_limit, red_limit, n_wave) | ||||||
| return cls(wave, jnp.ones(n_wave)) | ||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # JAX pytree interface | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| def tree_flatten(self): | ||||||
| children = (self._throughput,) | ||||||
| aux_data = { | ||||||
| "wave": tuple(self._wave.tolist()), | ||||||
| "blue_limit": self._blue_limit, | ||||||
| "red_limit": self._red_limit, | ||||||
| } | ||||||
| return (children, aux_data) | ||||||
|
|
||||||
| @classmethod | ||||||
| def tree_unflatten(cls, aux_data, children): | ||||||
| return cls( | ||||||
| wave=jnp.asarray(aux_data["wave"], dtype=float), | ||||||
| throughput=children[0], | ||||||
| blue_limit=aux_data["blue_limit"], | ||||||
| red_limit=aux_data["red_limit"], | ||||||
| ) | ||||||
|
|
||||||
| # ------------------------------------------------------------------ | ||||||
| # Misc | ||||||
| # ------------------------------------------------------------------ | ||||||
|
|
||||||
| def __repr__(self): | ||||||
| return ( | ||||||
| f"Bandpass(wave=[{self._blue_limit:.1f}, {self._red_limit:.1f}] nm, " | ||||||
| f"lam_eff={self.effective_wavelength:.1f} nm)" | ||||||
| ) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use this title to match the upstream documentation. The galsim docs further use separate pages for the SEDS, Bandpasses, and Chromatic objects, but we can skip that bit.