diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..e33c971e --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +# Read the Docs configuration +# https://docs.readthedocs.io/en/stable/config-file/v2.html + +version: 2 + +build: + os: ubuntu-24.04 + tools: + python: "3.12" + commands: + - pip install uv + - uv sync --extra docs + - uv run sphinx-build -b html docs $READTHEDOCS_OUTPUT/html + +sphinx: + configuration: docs/conf.py + fail_on_warning: false diff --git a/README.md b/README.md index dd20a8a2..e277bc11 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,13 @@ **JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.** -[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) [![CodSpeed Badge](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/GalSim-developers/JAX-GalSim) +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Documentation Status](https://readthedocs.org/projects/jax-galsim/badge/?version=latest)](https://jax-galsim.readthedocs.io/en/latest/) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) [![CodSpeed Badge](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/GalSim-developers/JAX-GalSim) **Disclaimer**: This project is still in an early development phase, **please use the [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) for any scientific applications.** +**You can find the most up-to-date documentation for the project [here](https://jax-galsim.readthedocs.io/en/latest).** + + ## Objective and Design The goal of this library is to reimplement GalSim functionalities in pure JAX to allow for automatic differentiation, GPU acceleration, and batched computations. @@ -44,124 +47,4 @@ about the inner workings of GalSim and how to code in JAX. ## Current GalSim API Coverage - -JAX-GalSim has implemented 22.5% of the GalSim API. See the list below for the supported APIs. - -
- -- galsim.Add -- galsim.AffineTransform -- galsim.Angle -- galsim.AngleUnit -- galsim.BaseDeviate -- galsim.BaseNoise -- galsim.BaseWCS -- galsim.BinomialDeviate -- galsim.Bounds -- galsim.BoundsD -- galsim.BoundsI -- galsim.Box -- galsim.CCDNoise -- galsim.CelestialCoord -- galsim.Chi2Deviate -- galsim.Convolution -- galsim.Convolve -- galsim.Cubic -- galsim.Deconvolution -- galsim.Deconvolve -- galsim.Delta -- galsim.DeltaFunction -- galsim.DeviateNoise -- galsim.Exponential -- galsim.FitsHeader -- galsim.FitsWCS -- galsim.GSFitsWCS -- galsim.GSObject -- galsim.GSParams -- galsim.GalSimBoundsError -- galsim.GalSimConfigError -- galsim.GalSimConfigValueError -- galsim.GalSimDeprecationWarning -- galsim.GalSimError -- galsim.GalSimFFTSizeError -- galsim.GalSimHSMError -- galsim.GalSimImmutableError -- galsim.GalSimIncompatibleValuesError -- galsim.GalSimIndexError -- galsim.GalSimKeyError -- galsim.GalSimNotImplementedError -- galsim.GalSimRangeError -- galsim.GalSimSEDError -- galsim.GalSimUndefinedBoundsError -- galsim.GalSimValueError -- galsim.GalSimWarning -- galsim.GammaDeviate -- galsim.Gaussian -- galsim.GaussianDeviate -- galsim.GaussianNoise -- galsim.Image -- galsim.ImageCD -- galsim.ImageCF -- galsim.ImageD -- galsim.ImageF -- galsim.ImageI -- galsim.ImageS -- galsim.ImageUI -- galsim.ImageUS -- galsim.Interpolant -- galsim.InterpolatedImage -- galsim.JacobianWCS -- galsim.Lanczos -- galsim.Linear -- galsim.Moffat -- galsim.Nearest -- galsim.OffsetShearWCS -- galsim.OffsetWCS -- galsim.PhotonArray -- galsim.Pixel -- galsim.PixelScale -- galsim.PoissonDeviate -- galsim.PoissonNoise -- galsim.Position -- galsim.PositionD -- galsim.PositionI -- galsim.Quintic -- galsim.Sensor -- galsim.Shear -- galsim.ShearWCS -- galsim.SincInterpolant -- galsim.Spergel -- galsim.Sum -- galsim.TanWCS -- galsim.Transform -- galsim.Transformation -- galsim.UniformDeviate -- galsim.VariableGaussianNoise -- galsim.WeibullDeviate -- galsim.bessel.j0 -- galsim.bessel.kv -- galsim.bessel.si -- galsim.fits.closeHDUList -- galsim.fits.readCube -- galsim.fits.readFile -- galsim.fits.readMulti -- galsim.fits.write -- galsim.fits.writeFile -- galsim.fitswcs.CelestialWCS -- galsim.integ.int1d -- galsim.noise.addNoise -- galsim.noise.addNoiseSNR -- galsim.random.permute -- galsim.utilities.g1g2_to_e1e2 -- galsim.utilities.horner -- galsim.utilities.printoptions -- galsim.utilities.unweighted_moments -- galsim.utilities.unweighted_shape -- galsim.wcs.EuclideanWCS -- galsim.wcs.LocalWCS -- galsim.wcs.UniformWCS - -
- - -_**Note**: The coverage list is generated automatically by the `scripts/update_api_coverage.py` script. To update it, run `python scripts/update_api_coverage.py` from the root of the repository._ +JAX-GalSim current has implemented **22.5%** of the GalSim API. See the corresponding [documentation page](https://jax-galsim.readthedocs.io/en/latest/api-coverage) for a list of what is currently implemented. diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..f9d3fc9c --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,24 @@ +# Minimal Sphinx Makefile — uses uv run so all package deps are always available. + +SPHINXBUILD = uv run sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +.PHONY: help html clean livehtml + +help: + @echo "Usage:" + @echo " make html Build the HTML documentation" + @echo " make clean Remove the build directory" + @echo " make livehtml Auto-rebuild on file changes (requires sphinx-autobuild)" + +html: + $(SPHINXBUILD) -b html $(SOURCEDIR) $(BUILDDIR)/html + +clean: + rm -rf $(BUILDDIR) + +livehtml: + uv run sphinx-autobuild $(SOURCEDIR) $(BUILDDIR)/html \ + --ignore "**/_build/**" \ + --ignore "**/.DS_Store" diff --git a/docs/_ext/galsim_docstring.py b/docs/_ext/galsim_docstring.py new file mode 100644 index 00000000..b58b2a04 --- /dev/null +++ b/docs/_ext/galsim_docstring.py @@ -0,0 +1,246 @@ +""" +Custom Sphinx extension for JAX-GalSim documentation. + +Processes docstrings produced by the ``@implements`` decorator. Each such +docstring contains a ``*Original docstring below.*`` marker that separates +the JAX-specific summary/lax_description from the upstream GalSim text. + +This extension: + +1. Extracts the ``Parameters:`` section from the original GalSim block and + re-injects it *before* the collapsible so that Sphinx / Napoleon renders + the parameters as normal field-list entries. +2. Wraps the rest of the original GalSim narrative in a ``sphinx-design`` + ``.. dropdown::`` directive so it is collapsible in the HTML output. +""" + +import re + +# The literal text injected by the ``implements`` decorator. +_MARKER = "*Original docstring below.*" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _leading_spaces(line: str) -> int: + return len(line) - len(line.lstrip()) + + +def _parse_galsim_params(lines: list[str]) -> list[tuple[str, str]]: + """Return ``[(name, description), …]`` parsed from a GalSim Parameters block. + + GalSim parameter entries look like:: + + Parameters: + sigma: The sigma of the profile. Typically in arcsec. + [One of ``sigma``, ``fwhm``, or ``half_light_radius``.] + flux: The flux of the profile. [default: 1] + + The function tolerates multi-line descriptions and mixed whitespace. + """ + params: list[tuple[str, str]] = [] + in_params = False + param_indent: int | None = None + current_name: str | None = None + current_desc: list[str] = [] + + for line in lines: + stripped = line.rstrip() + + # Detect the "Parameters:" heading + if re.match(r"^\s*Parameters\s*:", stripped): + in_params = True + param_indent = None + continue + + if not in_params: + continue + + # A blank line inside the params block is OK (continuation) but + # a new top-level heading ends the block. + if not stripped: + continue + + indent = _leading_spaces(stripped) + + # A zero-indent non-empty line means we've left the section. + if indent == 0: + break + + # First non-empty, non-zero-indent line sets the expected indent level. + if param_indent is None: + param_indent = indent + + if indent == param_indent: + # Save the previous parameter before starting a new one. + if current_name is not None: + params.append((current_name, " ".join(current_desc))) + current_desc = [] + m = re.match(r"^(\w+)\s*:\s*(.*)", stripped.strip()) + if m: + current_name = m.group(1) + first_desc = m.group(2).strip() + if first_desc: + current_desc = [first_desc] + else: + current_name = None + elif indent > param_indent and current_name is not None: + # Continuation of the previous parameter description. + current_desc.append(stripped.strip()) + + # Flush the last parameter. + if current_name is not None: + params.append((current_name, " ".join(current_desc))) + + return params + + +def _remove_params_section(lines: list[str]) -> list[str]: + """Return *lines* with the ``Parameters:`` block removed.""" + result: list[str] = [] + in_params = False + param_indent: int | None = None + + for line in lines: + stripped = line.rstrip() + + if re.match(r"^\s*Parameters\s*:", stripped) and not in_params: + in_params = True + param_indent = _leading_spaces(stripped) + continue + + if in_params: + if not stripped: + # Skip blank lines that belong to the params section. + continue + indent = _leading_spaces(stripped) + # Any line at indent ≤ "Parameters:" indent is a new section. + if indent <= param_indent: + in_params = False + result.append(line) + # else: still inside the params block → skip it. + else: + result.append(line) + + return result + + +# --------------------------------------------------------------------------- +# Main event handler +# --------------------------------------------------------------------------- + + +def _process_galsim_docstring( + app, what: str, name: str, obj, options, lines: list[str] +) -> None: + """``autodoc-process-docstring`` handler.""" + + # Find the marker line. + marker_idx: int | None = None + for i, line in enumerate(lines): + if _MARKER in line: + marker_idx = i + break + + if marker_idx is None: + return + + # --- split --- + jax_lines = lines[:marker_idx] + original_lines = lines[marker_idx + 1 :] + + # Trim trailing blank lines from the JAX section. + while jax_lines and not jax_lines[-1].strip(): + jax_lines.pop() + + # Trim leading blank lines from the original section. + while original_lines and not original_lines[0].strip(): + original_lines.pop(0) + + # --- extract parameters --- + params = _parse_galsim_params(original_lines) + original_no_params = _remove_params_section(original_lines) + + # Trim trailing blank lines from the narrative. + while original_no_params and not original_no_params[-1].strip(): + original_no_params.pop() + + # --- split jax_lines into summary+LAX-ref and lax_description --- + # The @implements decorator always injects "LAX-backend implementation of …" + # as the second paragraph. Any content after that line is lax_description. + lax_ref_idx: int | None = None + for i, line in enumerate(jax_lines): + if "LAX-backend implementation of" in line: + lax_ref_idx = i + break + + if lax_ref_idx is not None: + header_lines = jax_lines[: lax_ref_idx + 1] + desc_lines = jax_lines[lax_ref_idx + 1 :] + # Strip surrounding blank lines from the description block. + while desc_lines and not desc_lines[0].strip(): + desc_lines.pop(0) + while desc_lines and not desc_lines[-1].strip(): + desc_lines.pop() + else: + header_lines = jax_lines + desc_lines = [] + + # --- build the replacement lines --- + new_lines: list[str] = list(header_lines) + new_lines.append("") + + # Wrap lax_description in a Sharp Bits admonition when present. + if desc_lines: + new_lines.append( + ".. admonition:: \U0001f52a JAX-GalSim - The Sharp Bits \U0001f52a" + ) + new_lines.append(" :class: warning") + new_lines.append("") + for line in desc_lines: + if line.strip(): + new_lines.append(" " + line) + else: + new_lines.append("") + new_lines.append("") + + # Inject parameters in Google style so Napoleon renders them properly. + if params: + new_lines.append("Parameters:") + for pname, pdesc in params: + # Use 4-space indent, which Napoleon / Google style expects. + new_lines.append(f" {pname}: {pdesc}") + new_lines.append("") + + # Wrap the original narrative in a collapsible dropdown. + has_content = any(line.strip() for line in original_no_params) + if has_content: + new_lines.append(".. dropdown:: Original GalSim Documentation") + new_lines.append(" :class-container: sd-shadow-sm") + new_lines.append(" :color: secondary") + new_lines.append("") + for line in original_no_params: + if line.strip(): + new_lines.append(" " + line) + else: + new_lines.append("") + new_lines.append("") + + lines[:] = new_lines + + +# --------------------------------------------------------------------------- +# Extension setup +# --------------------------------------------------------------------------- + + +def setup(app): + app.connect("autodoc-process-docstring", _process_galsim_docstring) + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 00000000..d0d10190 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,42 @@ +/* + * JAX-GalSim documentation – custom CSS overrides + * + * These styles are applied on top of the Furo theme. + */ + +/* ── Collapsible "Original GalSim Documentation" dropdowns ──────────────── */ + +/* Give the dropdown a left border accent to visually separate it from the + surrounding JAX-specific content. */ +.sd-dropdown { + border-left: 3px solid var(--color-brand-primary, #4a90d9); + margin-top: 1em; + margin-bottom: 1em; +} + +/* Make the dropdown header slightly muted so the JAX content remains the + primary visual focus. */ +.sd-summary-title { + font-size: 0.9em; + font-style: italic; + color: var(--color-foreground-muted, #666); +} + +/* ── Parameter field lists ───────────────────────────────────────────────── */ + +/* Slightly tighten up the :param: / :type: field list spacing. */ +dl.field-list > dt { + font-weight: 600; +} + +dl.field-list > dd { + margin-bottom: 0.4em; +} + +/* ── General readability tweaks ─────────────────────────────────────────── */ + +/* Ensure math blocks inside dropdowns are legible. */ +.sd-card-body .math, +.sd-dropdown .math { + overflow-x: auto; +} diff --git a/docs/api-coverage.rst b/docs/api-coverage.rst new file mode 100644 index 00000000..0fe64485 --- /dev/null +++ b/docs/api-coverage.rst @@ -0,0 +1,132 @@ +API Coverage +============ + +JAX-GalSim has implemented **22.5 %** of the GalSim API. The project focuses on +the most commonly used profiles and operations, with coverage expanding over time. + +Supported APIs +-------------- + +.. dropdown:: Click to expand the full list of implemented APIs + + - ``galsim.Add`` + - ``galsim.AffineTransform`` + - ``galsim.Angle`` + - ``galsim.AngleUnit`` + - ``galsim.BaseDeviate`` + - ``galsim.BaseNoise`` + - ``galsim.BaseWCS`` + - ``galsim.BinomialDeviate`` + - ``galsim.Bounds`` + - ``galsim.BoundsD`` + - ``galsim.BoundsI`` + - ``galsim.Box`` + - ``galsim.CCDNoise`` + - ``galsim.CelestialCoord`` + - ``galsim.Chi2Deviate`` + - ``galsim.Convolution`` + - ``galsim.Convolve`` + - ``galsim.Cubic`` + - ``galsim.Deconvolution`` + - ``galsim.Deconvolve`` + - ``galsim.Delta`` + - ``galsim.DeltaFunction`` + - ``galsim.DeviateNoise`` + - ``galsim.Exponential`` + - ``galsim.FitsHeader`` + - ``galsim.FitsWCS`` + - ``galsim.GSFitsWCS`` + - ``galsim.GSObject`` + - ``galsim.GSParams`` + - ``galsim.GalSimBoundsError`` + - ``galsim.GalSimConfigError`` + - ``galsim.GalSimConfigValueError`` + - ``galsim.GalSimDeprecationWarning`` + - ``galsim.GalSimError`` + - ``galsim.GalSimFFTSizeError`` + - ``galsim.GalSimHSMError`` + - ``galsim.GalSimImmutableError`` + - ``galsim.GalSimIncompatibleValuesError`` + - ``galsim.GalSimIndexError`` + - ``galsim.GalSimKeyError`` + - ``galsim.GalSimNotImplementedError`` + - ``galsim.GalSimRangeError`` + - ``galsim.GalSimSEDError`` + - ``galsim.GalSimUndefinedBoundsError`` + - ``galsim.GalSimValueError`` + - ``galsim.GalSimWarning`` + - ``galsim.GammaDeviate`` + - ``galsim.Gaussian`` + - ``galsim.GaussianDeviate`` + - ``galsim.GaussianNoise`` + - ``galsim.Image`` + - ``galsim.ImageCD`` + - ``galsim.ImageCF`` + - ``galsim.ImageD`` + - ``galsim.ImageF`` + - ``galsim.ImageI`` + - ``galsim.ImageS`` + - ``galsim.ImageUI`` + - ``galsim.ImageUS`` + - ``galsim.Interpolant`` + - ``galsim.InterpolatedImage`` + - ``galsim.JacobianWCS`` + - ``galsim.Lanczos`` + - ``galsim.Linear`` + - ``galsim.Moffat`` + - ``galsim.Nearest`` + - ``galsim.OffsetShearWCS`` + - ``galsim.OffsetWCS`` + - ``galsim.PhotonArray`` + - ``galsim.Pixel`` + - ``galsim.PixelScale`` + - ``galsim.PoissonDeviate`` + - ``galsim.PoissonNoise`` + - ``galsim.Position`` + - ``galsim.PositionD`` + - ``galsim.PositionI`` + - ``galsim.Quintic`` + - ``galsim.Sensor`` + - ``galsim.Shear`` + - ``galsim.ShearWCS`` + - ``galsim.SincInterpolant`` + - ``galsim.Spergel`` + - ``galsim.Sum`` + - ``galsim.TanWCS`` + - ``galsim.Transform`` + - ``galsim.Transformation`` + - ``galsim.UniformDeviate`` + - ``galsim.VariableGaussianNoise`` + - ``galsim.WeibullDeviate`` + - ``galsim.bessel.j0`` + - ``galsim.bessel.kv`` + - ``galsim.bessel.si`` + - ``galsim.fits.closeHDUList`` + - ``galsim.fits.readCube`` + - ``galsim.fits.readFile`` + - ``galsim.fits.readMulti`` + - ``galsim.fits.write`` + - ``galsim.fits.writeFile`` + - ``galsim.fitswcs.CelestialWCS`` + - ``galsim.integ.int1d`` + - ``galsim.noise.addNoise`` + - ``galsim.noise.addNoiseSNR`` + - ``galsim.random.permute`` + - ``galsim.utilities.g1g2_to_e1e2`` + - ``galsim.utilities.horner`` + - ``galsim.utilities.printoptions`` + - ``galsim.utilities.unweighted_moments`` + - ``galsim.utilities.unweighted_shape`` + - ``galsim.wcs.EuclideanWCS`` + - ``galsim.wcs.LocalWCS`` + - ``galsim.wcs.UniformWCS`` + +Updating Coverage +----------------- + +.. code-block:: bash + + python scripts/update_api_coverage.py + +Compares GalSim's public API against ``jax_galsim``'s implementations and +updates the coverage percentage and list above. diff --git a/docs/api/core.rst b/docs/api/core.rst new file mode 100644 index 00000000..28c0d855 --- /dev/null +++ b/docs/api/core.rst @@ -0,0 +1,34 @@ +Core Utilities +============== + +Please note that items from the ``jax_galsim.core`` are internal APIs and should not be relied upon by external code. + + +Math +--------- + +.. automodule:: jax_galsim.core.math + :members: + :undoc-members: False + +.. automodule:: jax_galsim.core.integrate + :members: + :undoc-members: False + +.. automodule:: jax_galsim.core.interpolate + :members: + :undoc-members: False + +Drawing +--------- + +.. automodule:: jax_galsim.core.draw + :members: + :undoc-members: False + +Utilities +--------- + +.. automodule:: jax_galsim.core.utils + :members: + :undoc-members: False diff --git a/docs/api/fits.rst b/docs/api/fits.rst new file mode 100644 index 00000000..a9fe122b --- /dev/null +++ b/docs/api/fits.rst @@ -0,0 +1,9 @@ +FITS I/O +================== + +.. note:: + The FITS functionality in JAX-GalSim is just a light wrapper around the GalSim FITS functionality that attempts to convert ``Image`` objects. + +.. automodule:: jax_galsim.fits + :members: + :show-inheritance: diff --git a/docs/api/gsobjects.rst b/docs/api/gsobjects.rst new file mode 100644 index 00000000..0d7c3a13 --- /dev/null +++ b/docs/api/gsobjects.rst @@ -0,0 +1,96 @@ +Surface Brightness Profiles +================================== + +.. currentmodule:: jax_galsim + +Base class +---------- + +.. autoclass:: GSObject + :members: + :show-inheritance: + +.. autoclass:: jax_galsim.gsparams.GSParams + :members: + :show-inheritance: + +Analytic profiles +----------------- + +.. autoclass:: Gaussian + :members: + :show-inheritance: + +.. autoclass:: Moffat + :members: + :show-inheritance: + +.. autoclass:: Exponential + :members: + :show-inheritance: + +.. autoclass:: Spergel + :members: + :show-inheritance: + +Pixel / box profiles +-------------------- + +.. autoclass:: Box + :members: + :show-inheritance: + +.. autoclass:: Pixel + :members: + :show-inheritance: + +Compound profiles +----------------- + +.. autoclass:: Sum + :members: + :show-inheritance: + +.. autoclass:: Add + :members: + :show-inheritance: + +.. autoclass:: Convolution + :members: + :show-inheritance: + +.. autoclass:: Convolve + :members: + :show-inheritance: + +.. autoclass:: Deconvolution + :members: + :show-inheritance: + +.. autoclass:: Deconvolve + :members: + :show-inheritance: + +.. autoclass:: DeltaFunction + :members: + :show-inheritance: + + +Interpolated image +------------------ + +.. autoclass:: InterpolatedImage + :members: + :show-inheritance: + + +Transformations +--------------- + +.. autoclass:: Transform + :members: + :show-inheritance: + +.. autoclass:: Transformation + :members: + :show-inheritance: diff --git a/docs/api/image.rst b/docs/api/image.rst new file mode 100644 index 00000000..75bd03ab --- /dev/null +++ b/docs/api/image.rst @@ -0,0 +1,89 @@ +Images and Related Concepts +============================== + +.. currentmodule:: jax_galsim + +Image classes +------------- + +.. autoclass:: Image + :members: + :show-inheritance: + +.. autoclass:: ImageD + :members: + :show-inheritance: + +.. autoclass:: ImageF + :members: + :show-inheritance: + +.. autoclass:: ImageI + :members: + :show-inheritance: + +.. autoclass:: ImageS + :members: + :show-inheritance: + +.. autoclass:: ImageUS + :members: + :show-inheritance: + +.. autoclass:: ImageUI + :members: + :show-inheritance: + +.. autoclass:: ImageCD + :members: + :show-inheritance: + +.. autoclass:: ImageCF + :members: + :show-inheritance: + +Bounds +------ + +.. autoclass:: Bounds + :members: + :show-inheritance: + +.. autoclass:: BoundsI + :members: + :show-inheritance: + +.. autoclass:: BoundsD + :members: + :show-inheritance: + +Positions +--------- + +.. autoclass:: Position + :members: + :show-inheritance: + +.. autoclass:: PositionI + :members: + :show-inheritance: + +.. autoclass:: PositionD + :members: + :show-inheritance: + + +Coordinates +----------- + +.. autoclass:: CelestialCoord + :members: + :show-inheritance: + +.. autoclass:: Angle + :members: + :show-inheritance: + +.. autoclass:: AngleUnit + :members: + :show-inheritance: diff --git a/docs/api/index.rst b/docs/api/index.rst new file mode 100644 index 00000000..95c5ff72 --- /dev/null +++ b/docs/api/index.rst @@ -0,0 +1,21 @@ +API Reference +============= + +.. toctree:: + :maxdepth: 1 + :caption: Modules + + gsobjects + image + weak-lensing + wcs + noise + photon_shooting + interpolation + fits + core + +.. rubric:: Quick index + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/api/interpolation.rst b/docs/api/interpolation.rst new file mode 100644 index 00000000..9b77f1ea --- /dev/null +++ b/docs/api/interpolation.rst @@ -0,0 +1,41 @@ +Interpolation +============= + +.. currentmodule:: jax_galsim + + +Interpolants +------------ + +.. autoclass:: Interpolant + :members: + :show-inheritance: + +.. autoclass:: Delta + :members: + :show-inheritance: + +.. autoclass:: Nearest + :members: + :show-inheritance: + +.. autoclass:: SincInterpolant + :members: + :show-inheritance: + +.. autoclass:: Linear + :members: + :show-inheritance: + +.. autoclass:: Cubic + :members: + :show-inheritance: + +.. autoclass:: Quintic + :members: + :show-inheritance: + +.. autoclass:: Lanczos + :members: + :show-inheritance: + diff --git a/docs/api/noise.rst b/docs/api/noise.rst new file mode 100644 index 00000000..90046663 --- /dev/null +++ b/docs/api/noise.rst @@ -0,0 +1,66 @@ +Noise & Random Deviates +======================== + +.. currentmodule:: jax_galsim + +Random deviates +--------------- + +.. autoclass:: BaseDeviate + :members: + :show-inheritance: + +.. autoclass:: UniformDeviate + :members: + :show-inheritance: + +.. autoclass:: GaussianDeviate + :members: + :show-inheritance: + +.. autoclass:: PoissonDeviate + :members: + :show-inheritance: + +.. autoclass:: Chi2Deviate + :members: + :show-inheritance: + +.. autoclass:: GammaDeviate + :members: + :show-inheritance: + +.. autoclass:: WeibullDeviate + :members: + :show-inheritance: + +.. autoclass:: BinomialDeviate + :members: + :show-inheritance: + +Noise models +------------ + +.. autoclass:: BaseNoise + :members: + :show-inheritance: + +.. autoclass:: GaussianNoise + :members: + :show-inheritance: + +.. autoclass:: PoissonNoise + :members: + :show-inheritance: + +.. autoclass:: DeviateNoise + :members: + :show-inheritance: + +.. autoclass:: VariableGaussianNoise + :members: + :show-inheritance: + +.. autoclass:: CCDNoise + :members: + :show-inheritance: diff --git a/docs/api/photon_shooting.rst b/docs/api/photon_shooting.rst new file mode 100644 index 00000000..f0796e86 --- /dev/null +++ b/docs/api/photon_shooting.rst @@ -0,0 +1,22 @@ +Photon Shooting +================== + +.. currentmodule:: jax_galsim + +Photon Arrays +----------------- + +.. autofunction:: jax_galsim.photon_array.fixed_photon_array_size + + +.. autoclass:: PhotonArray + :members: + :show-inheritance: + + +Sensors +---------- + +.. autoclass:: Sensor + :members: + :show-inheritance: diff --git a/docs/api/wcs.rst b/docs/api/wcs.rst new file mode 100644 index 00000000..e7e84ce2 --- /dev/null +++ b/docs/api/wcs.rst @@ -0,0 +1,48 @@ +World Coordinate Systems +========================== + +.. currentmodule:: jax_galsim + +WCS classes +-------------- + +.. autoclass:: BaseWCS + :members: + :show-inheritance: + +.. autoclass:: PixelScale + :members: + :show-inheritance: + +.. autoclass:: OffsetWCS + :members: + :show-inheritance: + +.. autoclass:: JacobianWCS + :members: + :show-inheritance: + +.. autoclass:: AffineTransform + :members: + :show-inheritance: + +.. autoclass:: ShearWCS + :members: + :show-inheritance: + +.. autoclass:: OffsetShearWCS + :members: + :show-inheritance: + +.. autoclass:: TanWCS + :members: + :show-inheritance: + +.. autoclass:: FitsWCS + :members: + :show-inheritance: + +.. autoclass:: GSFitsWCS + :members: + :show-inheritance: + diff --git a/docs/api/weak-lensing.rst b/docs/api/weak-lensing.rst new file mode 100644 index 00000000..6ad3d196 --- /dev/null +++ b/docs/api/weak-lensing.rst @@ -0,0 +1,11 @@ +Weak Lensing +================== + +.. currentmodule:: jax_galsim + +Shear +------- + +.. autoclass:: Shear + :members: + :show-inheritance: diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..1c4c31a6 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,106 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of configuration options, see: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import os +import sys + +# Make the package importable without installing it. +sys.path.insert(0, os.path.abspath("..")) +# Make the custom extension importable. +sys.path.insert(0, os.path.abspath("_ext")) + +# --------------------------------------------------------------------------- +# Project information +# --------------------------------------------------------------------------- + +project = "JAX-GalSim" +author = "GalSim Developers" +copyright = "2023, GalSim Developers" + +try: + from jax_galsim._version import version as release +except ImportError: + release = "0.0.1.dev0" + +version = ".".join(release.split(".")[:2]) + +# --------------------------------------------------------------------------- +# General configuration +# --------------------------------------------------------------------------- + +# Extension load order matters: +# 1. sphinx.ext.autodoc – must be first; it defines autodoc-process-docstring +# 2. galsim_docstring – our handler runs before Napoleon sees the lines +# 3. sphinx.ext.napoleon – converts the cleaned-up Parameters: block to RST +extensions = [ + "sphinx.ext.autodoc", # API docs from docstrings (defines the event) + "galsim_docstring", # custom – splits implements() docstrings + "sphinx.ext.napoleon", # Google/NumPy-style docstring parsing + "sphinx.ext.viewcode", # "View source" links + "sphinx.ext.intersphinx", # cross-links to external docs + "sphinx_design", # dropdown / collapsible directives +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# --------------------------------------------------------------------------- +# Napoleon (docstring parsing) +# --------------------------------------------------------------------------- + +napoleon_google_docstring = True +napoleon_numpy_docstring = False +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = False +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = True +napoleon_use_admonition_for_references = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = False + +# --------------------------------------------------------------------------- +# Autodoc +# --------------------------------------------------------------------------- + +autodoc_default_options = { + "members": True, + "undoc-members": False, + "show-inheritance": True, + "member-order": "bysource", +} +autodoc_typehints = "description" +autoclass_content = "class" # use only the class docstring (not __init__) + +# Packages that are imported by jax_galsim but may not be present at +# documentation build time. +autodoc_mock_imports = [] + +# --------------------------------------------------------------------------- +# Intersphinx mappings +# --------------------------------------------------------------------------- + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "jax": ("https://jax.readthedocs.io/en/latest", None), + "galsim": ("https://galsim-developers.github.io/GalSim/_build/html", None), +} + +# --------------------------------------------------------------------------- +# HTML output +# --------------------------------------------------------------------------- + +html_theme = "furo" + +html_theme_options = { + "sidebar_hide_name": False, +} + +html_static_path = ["_static"] +html_css_files = ["custom.css"] + +html_title = f"{project} {version}" diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..9d130899 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,149 @@ +JAX-GalSim +========== + +.. toctree:: + :maxdepth: 1 + :hidden: + + installation + quickstart + sharp-bits + api-coverage + versioning + api/index + +|ci-badge| |ruff-badge| |precommit-badge| + +.. |ci-badge| image:: https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg + :target: https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml + +.. |ruff-badge| image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + +.. |precommit-badge| image:: https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg + :target: https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main + +.. warning:: + + This project is still in an early development phase. Please use the + `reference GalSim implementation `_ + for any scientific applications. + +**JAX-GalSim** is a JAX re-implementation of the `GalSim +`_ galaxy image simulation +toolkit. It exposes (nearly) the same API as GalSim while enabling +automatic differentiation, JIT compilation, and hardware acceleration via +`JAX `_. + +Why JAX-GalSim? +--------------- + +.. grid:: 3 + :gutter: 2 + + .. grid-item-card:: ⚡ JIT Compilation + :class-card: sd-border-0 + + Compile simulation pipelines with ``jax.jit`` for significant + speedups, especially on GPU. + + .. grid-item-card:: 🔁 Automatic Differentiation + :class-card: sd-border-0 + + Compute gradients of simulation outputs with respect to galaxy + parameters using ``jax.grad``. + + .. grid-item-card:: 🔀 Vectorization + :class-card: sd-border-0 + + Batch simulations over parameter grids with ``jax.vmap`` — no + explicit loops needed. + +Quick Install +------------- + +.. code-block:: bash + + pip install jax-galsim + +See :doc:`installation` for GPU support and development setup. + +Minimal Example +--------------- + +.. code-block:: python + + import jax + import jax_galsim + + # Define a galaxy and PSF + gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + + # Convolve and draw + final = jax_galsim.Convolve([gal, psf]) + image = final.drawImage(scale=0.2) + +JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate +the entire pipeline: + +.. code-block:: python + + @jax.jit(static_argnames=['slen', 'fft_size']) + def simulate(flux, sigma, *, slen=21, fft_size=128): + gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + return jax_galsim.Convolve([gal, psf]).withGSParams(gsparams) \ + .drawImage(nx=slen, ny=slen, scale=0.2).array.sum() + + # Compute gradients with respect to galaxy parameters + dflux, dsigma = jax.grad(simulate, argnums=(0, 1))(1e5, 2.0) + + +Getting Started +--------------- + +.. grid:: 2 + :gutter: 3 + + .. grid-item-card:: 📖 API Reference + :link: api/index + :link-type: doc + + Auto-generated documentation for every public class, function, and + module in ``jax_galsim``. + + .. grid-item-card:: 🔗 GalSim upstream + :link: https://galsim-developers.github.io/GalSim/_build/html + :link-type: url + + The original GalSim documentation. Many docstrings in JAX-GalSim + are derived from GalSim and expanded with JAX-specific notes. + + .. grid-item-card:: 🚀 Quick Start + :link: quickstart + :link-type: doc + + Walk through a complete simulation with JIT, grad, and vmap. + + .. grid-item-card:: 🔪 JAX-GalSim - The Sharp Bits 🔪 + :link: sharp-bits + :link-type: doc + + What changes when GalSim runs on JAX — immutability, tracing, + PyTrees, and more. + +About the Documentation +------------------------ + +Each class and function that mirrors an upstream GalSim object is annotated +with :func:`jax_galsim.core.utils.implements`. This decorator copies the +original GalSim docstring and prepends any JAX-specific caveats. In the :doc:`api/index` you will therefore find: + +* A **summary** and optional **🔪 JAX-GalSim Sharp Bits** block at the top of + each entry highlighting important caveats. +* An explicit **Parameters** table derived from the original GalSim + documentation. +* A collapsible **Original GalSim Documentation** block containing the full + upstream narrative. diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 00000000..98fd95c1 --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,78 @@ +Installation +============ + +Quick Install +------------- + +.. code-block:: bash + + pip install jax-galsim + +This installs JAX-GalSim and its dependencies (JAX, NumPy, GalSim, Astropy). + +GPU Support +----------- + +JAX-GalSim inherits GPU support from JAX. To use NVIDIA GPUs, install the +appropriate JAX variant: + +.. code-block:: bash + + pip install -U "jax[cuda12]" + +See the `JAX installation guide `_ +for other accelerators and platform-specific instructions. + +Development Install +------------------- + +To contribute to JAX-GalSim or run the test suite: + +.. code-block:: bash + + # Clone with submodules (required for GalSim reference tests) + git clone --recurse-submodules https://github.com/GalSim-developers/JAX-GalSim + cd JAX-GalSim + + # Create a virtual environment + python -m venv .venv && source .venv/bin/activate + + # Install in editable mode with dev dependencies + pip install -e ".[dev]" + + # Install pre-commit hooks + pre-commit install + +Running Tests +^^^^^^^^^^^^^ + +.. code-block:: bash + + # Run all tests + pytest + + # Run a specific test file + pytest tests/jax/test_api.py + + # Run a specific test + pytest tests/jax/test_api.py::test_api_same + + # Verbose output with timing + pytest -vv --durations=100 + +Linting and Formatting +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + # Lint + ruff check . --fix + + # Format + ruff format . + + # Or run both via pre-commit + pre-commit run --all-files + +See `CONTRIBUTING.md `_ +for full contribution guidelines. diff --git a/docs/quickstart.rst b/docs/quickstart.rst new file mode 100644 index 00000000..fe6b7e89 --- /dev/null +++ b/docs/quickstart.rst @@ -0,0 +1,131 @@ +Quick Start +=========== + +A complete galaxy image simulation, then JAX transformations (``jit``, ``grad``, ``vmap``) on top. + +A Simple Simulation +------------------- + +A Gaussian galaxy convolved with a Gaussian PSF, drawn and noised — equivalent to GalSim's +``demo1.py``. + +.. code-block:: python + + import jax_galsim + + # Galaxy parameters + gal_flux = 1e5 # total counts + gal_sigma = 2.0 # arcsec + psf_sigma = 1.0 # arcsec + pixel_scale = 0.2 # arcsec/pixel + noise_sigma = 30.0 # counts per pixel + + # Define profiles + gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + + # Convolve galaxy with PSF + final = jax_galsim.Convolve([gal, psf]) + + # Draw the image + image = final.drawImage(scale=pixel_scale) + + # Add Gaussian noise + image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) + + # Write to FITS + image.write("output/demo1.fits") + +Most GalSim code translates directly by replacing ``import galsim`` with +``import jax_galsim``. + +JIT Compilation +--------------- + +Wrap your simulation in ``jax.jit`` to compile it into an optimised XLA computation: + +.. code-block:: python + + import jax + + @jax.jit(static_argnames=['slen', 'fft_size']) + def simulate(flux, sigma, *, slen, fft_size): + gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(nx=slen, ny=slen, scale=0.2) + + # First call compiles; subsequent calls are fast + image = simulate(1e5, 2.0, slen=21, fft_size=128) + +.. note:: + + Any arguments that affect control flow (like image size) must be marked as + ``static_argnames`` for JIT to work. + +Here is an alternative using ``functools.partial``: + +.. code-block:: python + + from jax import jit + from functools import partial + + def simulate(flux, sigma, *, slen, fft_size): + gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(nx=slen, ny=slen, scale=0.2) + + simulate_jitted = jit(partial(simulate, slen=21, fft_size=128)) + image = simulate_jitted(1e5, 2.0) + +Automatic Differentiation +-------------------------- + +Compute gradients of any scalar output with respect to parameters: + +.. code-block:: python + + def total_flux(gal_sigma, psf_sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=gal_sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + final = jax_galsim.Convolve([gal, psf]) + image = final.drawImage(scale=0.2) + return image.array.sum() + + # Gradient of total image flux with respect to both sigmas + grad_fn = jax.grad(total_flux, argnums=(0, 1)) + d_gal, d_psf = grad_fn(2.0, 1.0) + +Useful for fitting galaxy models to data via gradient-based optimisation. + +Vectorization with vmap +----------------------- + +Batch-simulate galaxies with different parameters without explicit loops: + +.. code-block:: python + + import jax.numpy as jnp + + sigmas = jnp.linspace(1.0, 4.0, 10) + + @jax.jit + @jax.vmap + def batch_simulate(sigma): + gsparams = jax_galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128) + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams) + return final.drawImage(scale=0.2, nx=64, ny=64).array + + # Simulate all 10 galaxies in parallel + images = batch_simulate(sigmas) # shape: (10, 64, 64) + +Next Steps +---------- + +- :doc:`sharp-bits` — What changes when GalSim runs on JAX +- :doc:`api/index` — Full API documentation diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst new file mode 100644 index 00000000..7901176e --- /dev/null +++ b/docs/sharp-bits.rst @@ -0,0 +1,259 @@ +🔪 JAX-GalSim - The Sharp Bits 🔪 +================================== + +JAX-GalSim is designed as a drop-in replacement for GalSim — replacing +``import galsim`` with ``import jax_galsim`` works for all supported features. +However, JAX's execution model introduces several fundamental differences +that you should understand before porting code or writing new simulations. + +Immutability +------------ + +JAX arrays are **immutable**. Any GalSim operation that originally modified data +in-place now creates a new array that overwrites the original one. Take +``__iadd__`` as an example: + +.. code-block:: python + + # GalSim — mutates the image in-place + # i.e. no new numpy array is created + image += 1.0 + # under the hood: self.array[:,:] += a (no new array) + + # JAX-GalSim — creates a new array and overwrites original one + image += 1.0 + # under the hood: image._array = image._array + 1.0 (new JAX array) + +This can be a subtle source of bugs if you are used to NumPy in-place +mutability. Here is a concrete illustration: + +.. code-block:: python + + # galsim + image = galsim.ImageD(11, 11) + arr1 = image.array + image += 1.0 + arr1.sum(), image.array.sum() # -> 121.0, 121.0 + + # jax-galsim + image = jax_galsim.ImageD(11, 11) + arr1 = image.array + image += 1.0 + arr1.sum(), image.array.sum() # -> 0.0, 121.0 (original unmodified!) + +For more details, see the `JAX Sharp Bits page on in-place updates +`_. + +Array Views +----------- + +NumPy supports **array views** — slices that share memory with the original +array. JAX does not. In GalSim, you can obtain a real-valued view of a complex +image (the real part shares memory with the underlying complex buffer). +In JAX-GalSim these operations return **copies** instead. Modifying the copy +does not affect the original. + +.. code-block:: python + + # GalSim — real_part is a view, shares memory with complex_image + real_part = complex_image.real + + # JAX-GalSim — real_part is a copy + real_part = complex_image.real # independent array + +Random Number Generation +------------------------ + +JAX uses a **functional PRNG** — random state is explicit and must be passed +through computations. This has several consequences: + +**Determinism**: Given the same seed, JAX-GalSim produces identical results +across runs and platforms (CPU, GPU, TPU). GalSim's results may vary by +platform. + +**Explicit state**: Random deviates carry their state explicitly. Under the +hood, JAX-GalSim wraps JAX's key-based PRNG in GalSim's familiar noise API, +so the user-facing interface looks the same: + +.. code-block:: python + + noise = jax_galsim.GaussianNoise(sigma=30.0) + image.addNoise(noise) # state is managed internally + +**Different sequences**: Even with the same seed value, the actual random +number sequences differ from GalSim. Results will not match GalSim +number-for-number. This is expected — the underlying PRNG algorithms are +completely different. + +**No in-place fill**: GalSim deviates can "fill" existing arrays. JAX deviates +always return new arrays, consistent with JAX's immutability model. + +PyTree Registration +------------------- + +All JAX-GalSim objects are registered as JAX **PyTrees**. This is what allows +you to pass them directly to ``jax.jit``, ``jax.grad``, and ``jax.vmap``. + +A PyTree splits each object into two parts: + +.. list-table:: + :header-rows: 1 + :widths: 20 35 30 15 + + * - Part + - What it contains + - Examples + - Effect of changing + * - **Children** (traced) + - Values JAX differentiates through + - ``flux``, ``sigma``, ``half_light_radius`` + - Re-evaluation, not recompilation + * - **Auxiliary data** (static) + - Structure and configuration + - ``GSParams``, enum flags + - Full recompilation under ``jit`` + +For ``GSObject``, profile parameters live in a ``_params`` dict (children) and +numerical configuration lives in ``_gsparams`` (auxiliary): + +.. code-block:: python + + gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) + # gal._params = {"flux": 1e5, "sigma": 2.0} — traced by JAX + # gal._gsparams = GSParams(...) — static, triggers recompile + +Because ``GSParams`` is static auxiliary data, changing it between calls to a +``jit``-compiled function triggers recompilation. Keep ``GSParams`` constant +across calls when possible: + +.. code-block:: python + + import jax + + gsparams = jax_galsim.GSParams(minimum_fft_size=8192, maximum_fft_size=8192) + slen = 21 + + @jax.jit + def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma, gsparams=gsparams) + return gal.drawImage(nx=slen, ny=slen, scale=0.2).array.sum() + +Control Flow and Tracing +------------------------ + +JAX's JIT compiler works by **tracing** — it records operations on abstract +values to build a computation graph. This restricts what Python code can do +inside ``jit``-compiled functions. + +No branching on traced values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You cannot use Python ``if``/``else`` on values that JAX is tracing (e.g., +profile parameters passed into a ``jit``-compiled function): + +.. code-block:: python + + @jax.jit + def bad(sigma): + if sigma > 1.0: # ERROR: sigma is a tracer, not a concrete value + return sigma * 2 + return sigma + + @jax.jit + def good(sigma): + return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma) + +JAX-GalSim uses an internal ``has_tracers()`` utility to detect tracing and +avoid problematic control flow in its own implementations. + +Fixed output shapes +^^^^^^^^^^^^^^^^^^^ + +Under ``jit``, the **shape** of every array must be determinable at compile time. +When using ``jax.jit`` or ``jax.vmap`` you must specify fixed image dimensions: + +.. code-block:: python + + @jax.jit + @jax.vmap + def batch(sigma): + gsparams = jax_galsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma).withGSParams(gsparams) + # Must specify nx, ny so all images have the same shape + return gal.drawImage(scale=0.2, nx=64, ny=64).array + +The default drawing procedure uses an FFT whose k-space image size normally +depends on traced galaxy parameters (e.g. size). Fix it explicitly via +``GSParams``: + +.. code-block:: python + + gsparams = jax_galsim.GSParams(minimum_fft_size=256, maximum_fft_size=256) + +Both ``minimum_fft_size`` and ``maximum_fft_size`` must be set to the same value. + +The ``__init__`` gotcha +^^^^^^^^^^^^^^^^^^^^^^^ + +During ``jit`` tracing, JAX calls constructors with **tracer objects** rather +than concrete Python numbers. Type checks like ``isinstance(sigma, float)`` will +fail on tracers. JAX-GalSim handles this internally, but if you subclass any +JAX-GalSim object, be aware that ``__init__`` may receive tracers: + +.. code-block:: python + + from jax_galsim.core.utils import has_tracers + + class MyProfile(jax_galsim.GSObject): + def __init__(self, sigma, gsparams=None): + if not has_tracers(sigma): + # Only validate with concrete values + if sigma <= 0: + raise ValueError("sigma must be positive") + ... + +Profile Restrictions +-------------------- + +Some GalSim features are not yet implemented in JAX-GalSim: + +- **Truncated Moffat profiles**: The ``trunc`` parameter is not supported. +- **ChromaticObject**: All chromatic functionality (wavelength-dependent + profiles) is not available. +- **InterpolatedKImage**: Not implemented. +- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: See :doc:`api-coverage` for + the full list. + +The project currently implements **22.5 %** of the GalSim public API, focused +on the most commonly used profiles and operations. + +Numerical Precision +------------------- + +Simulation results may differ slightly from GalSim at the floating-point level: + +- **Operation reordering**: JAX (via XLA) may reorder floating-point operations + for performance. Floating-point addition is not associative, so different + orderings produce slightly different results. +- **Different math kernels**: XLA-compiled math kernels may differ from system + math libraries (e.g. ``libm``) used by GalSim via NumPy/C++. +- **Gradient-safe functions**: JAX-GalSim uses special implementations (e.g. + ``safe_sqrt`` to avoid ``NaN`` gradients at zero) where GalSim uses standard + library functions. These may produce slightly different values at edge cases. +- **Default precision**: JAX defaults to 32-bit floats. Enable 64-bit with + ``jax.config.update("jax_enable_x64", True)`` for higher precision matching + GalSim's default behaviour. + +These differences are typically at the level of floating-point round-off +(:math:`{\sim}10^{-7}` for float32, :math:`{\sim}10^{-15}` for float64) and +should not affect scientific conclusions. + +⚠️ Additional Sharp Bits +-------------------------- + +In the :doc:`api/index` you will find **🔪 JAX-GalSim - The Sharp Bits 🔪** blocks highlighting additional important caveats for specific classes and or methods. These could include things like: + +- Many classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`). +- Certain profiles might not be auto-differentiable with respect to some of their parameters (e.g., :class:`~jax_galsim.Spergel`, :class:`~jax_galsim.Moffat`) +- Limitations regarding what types of inputes are handled (e.g., :meth:`~jax_galsim.Image.calculate_fft` does not accept complex dtypes.) + diff --git a/docs/versioning.rst b/docs/versioning.rst new file mode 100644 index 00000000..62e2e564 --- /dev/null +++ b/docs/versioning.rst @@ -0,0 +1,6 @@ +Versioning and API Policy +--------------------------- + +JAX-GalSim follows `Calver `_ with a version number ``YYYY.MM.MICRO`` with ``MICRO`` resetting to ``0`` at the start of each month. + +For APIs which are also present in GalSim (e.g., you can import the same thing by substituting galsim for jax_galsim), JAX-GalSim is a strict subset of the GalSim APIs. All other APIs may change without notice for any version part increment. We thus recommend pinning the entire JAX-GalSim version if you use this code in your work. diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 80c9d5fe..10442ad4 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -27,8 +27,8 @@ BOUNDS_LAX_DESCR = """\ The JAX implementation - - will not always test whether the bounds are valid - - will not always test whether BoundsI is initialized with integers +- will not always test whether the bounds are valid +- will not always test whether BoundsI is initialized with integers Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index e956dc0c..4ca931a3 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -116,7 +116,7 @@ def calculate_mean_n_photons( max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). Returns: - n_photons: The number of photons. + The number of photons. """ npd = _NPhotonsData( n_photons=0.0, @@ -176,125 +176,120 @@ def calculate_n_photons( is False] Returns: - n_photons: The number of photons. - g: The flux ratio to use. Combine with a pre-existing gain via ``g /= gain`` and then multiply - the flux per photon by ``g``. - rng: The final random number generator used. - + A tuple of ``(n_photons, g, rng)`` where ``n_photons`` is the number of photons, ``g`` is the flux ratio, and ``rng`` is the final random number generator used. Notes: - - It is easiest to simply copy the original code from GSObject._calculate_nphotons - into the doc string here in order to document what this function does. - - # the old doc string: - Calculate how many photons to shoot and what flux_ratio (called g) each one should - have in order to produce an image with the right S/N and total flux. - - This routine is normally called by `drawPhot`. - - Returns: - n_photons, g - - # For profiles that are positive definite, then N = flux. Easy. - # - # However, some profiles shoot some of their photons with negative flux. This means that - # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the - # fraction of shot photons that have negative flux. - # - # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 - # N^2 = Var(S) = (N+ + N-) = Ntot - # - # So flux = (S/N)^2 = Ntot (1-2eta)^2 - # Ntot = flux / (1-2eta)^2 - # - # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). - # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right - # total flux. - # - # That's all the easy case. The trickier case is when we are sky-background dominated. - # Then we can usually get away with fewer shot photons than the above. In particular, - # if the noise from the photon shooting is much less than the sky noise, then we can - # use fewer shot photons and essentially have each photon have a flux > 1. This is ok - # as long as the additional noise due to this approximation is "much less than" the - # noise we'll be adding to the image for the sky noise. - # - # Let's still have Ntot photons, but now each with a flux of g. And let's look at the - # noise we get in the brightest pixel that has a nominal total flux of Imax. - # - # The number of photons hitting this pixel will be Imax/flux * Ntot. - # The variance of this number is the same thing (Poisson counting). - # So the noise in that pixel is: - # - # N^2 = Imax/flux * Ntot * g^2 - # - # And the signal in that pixel will be: - # - # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so - # g = flux / Ntot(1-2eta) - # N^2 = Imax/Ntot * flux / (1-2eta)^2 - # - # As expected, we see that lowering Ntot will increase the noise in that (and every - # other) pixel. - # The input max_extra_noise parameter is the maximum value of spurious noise we want - # to allow. - # - # So setting N^2 = Imax + nu, we get - # - # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) - # g = (1 - 2eta) * (1 + nu/Imax) - # - # Returns the total flux placed inside the image bounds by photon shooting. - # - - flux = self.flux - if flux == 0.0: - return 0, 1.0 - - # The _flux_per_photon property is (1-2eta) - # This factor will already be accounted for by the shoot function, so don't include - # that as part of our scaling here. There may be other adjustments though, so g=1 here. - eta_factor = self._flux_per_photon - mod_flux = flux / (eta_factor * eta_factor) - g = 1. - - # If requested, let the target flux value vary as a Poisson deviate - if poisson_flux: - # If we have both positive and negative photons, then the mix of these - # already gives us some variation in the flux value from the variance - # of how many are positive and how many are negative. - # The number of negative photons varies as a binomial distribution. - # = eta * Ntot * g - # = (1-eta) * Ntot * g - # = (1-2eta) * Ntot * g = flux - # Var(F-) = eta * (1-eta) * Ntot * g^2 - # F+ = Ntot * g - F- is not an independent variable, so - # Var(F+ - F-) = Var(Ntot*g - 2*F-) - # = 4 * Var(F-) - # = 4 * eta * (1-eta) * Ntot * g^2 - # = 4 * eta * (1-eta) * flux - # We want the variance to be equal to flux, so we need an extra: - # delta Var = (1 - 4*eta + 4*eta^2) * flux - # = (1-2eta)^2 * flux - absflux = abs(flux) - mean = eta_factor*eta_factor * absflux - pd = PoissonDeviate(rng, mean) - pd_val = pd() - mean + absflux - ratio = pd_val / absflux - g *= ratio - mod_flux *= ratio - - if n_photons == 0.: - n_photons = abs(mod_flux) - if max_extra_noise > 0.: - gfactor = 1. + max_extra_noise / abs(self.max_sb) - n_photons /= gfactor - g *= gfactor - - # Make n_photons an integer. - iN = int(n_photons + 0.5) - - return iN, g + It is easiest to look at the original code from ``GSObject._calculate_nphotons`` + to understand what this function does: + + .. code-block:: python + + # Calculate how many photons to shoot and what flux_ratio (called g) each one should + # have in order to produce an image with the right S/N and total flux. + # + # This routine is normally called by drawPhot. + # + # Returns: + # n_photons, g + + # For profiles that are positive definite, then N = flux. Easy. + # + # However, some profiles shoot some of their photons with negative flux. This means that + # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the + # fraction of shot photons that have negative flux. + # + # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 + # N^2 = Var(S) = (N+ + N-) = Ntot + # + # So flux = (S/N)^2 = Ntot (1-2eta)^2 + # Ntot = flux / (1-2eta)^2 + # + # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). + # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right + # total flux. + # + # That's all the easy case. The trickier case is when we are sky-background dominated. + # Then we can usually get away with fewer shot photons than the above. In particular, + # if the noise from the photon shooting is much less than the sky noise, then we can + # use fewer shot photons and essentially have each photon have a flux > 1. This is ok + # as long as the additional noise due to this approximation is "much less than" the + # noise we'll be adding to the image for the sky noise. + # + # Let's still have Ntot photons, but now each with a flux of g. And let's look at the + # noise we get in the brightest pixel that has a nominal total flux of Imax. + # + # The number of photons hitting this pixel will be Imax/flux * Ntot. + # The variance of this number is the same thing (Poisson counting). + # So the noise in that pixel is: + # + # N^2 = Imax/flux * Ntot * g^2 + # + # And the signal in that pixel will be: + # + # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so + # g = flux / Ntot(1-2eta) + # N^2 = Imax/Ntot * flux / (1-2eta)^2 + # + # As expected, we see that lowering Ntot will increase the noise in that (and every + # other) pixel. + # The input max_extra_noise parameter is the maximum value of spurious noise we want + # to allow. + # + # So setting N^2 = Imax + nu, we get + # + # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) + # g = (1 - 2eta) * (1 + nu/Imax) + # + # Returns the total flux placed inside the image bounds by photon shooting. + + flux = self.flux + if flux == 0.0: + return 0, 1.0 + + # The _flux_per_photon property is (1-2eta) + # This factor will already be accounted for by the shoot function, so don't include + # that as part of our scaling here. There may be other adjustments though, so g=1 here. + eta_factor = self._flux_per_photon + mod_flux = flux / (eta_factor * eta_factor) + g = 1.0 + + # If requested, let the target flux value vary as a Poisson deviate + if poisson_flux: + # If we have both positive and negative photons, then the mix of these + # already gives us some variation in the flux value from the variance + # of how many are positive and how many are negative. + # The number of negative photons varies as a binomial distribution. + # = eta * Ntot * g + # = (1-eta) * Ntot * g + # = (1-2eta) * Ntot * g = flux + # Var(F-) = eta * (1-eta) * Ntot * g^2 + # F+ = Ntot * g - F- is not an independent variable, so + # Var(F+ - F-) = Var(Ntot*g - 2*F-) + # = 4 * Var(F-) + # = 4 * eta * (1-eta) * Ntot * g^2 + # = 4 * eta * (1-eta) * flux + # We want the variance to be equal to flux, so we need an extra: + # delta Var = (1 - 4*eta + 4*eta^2) * flux + # = (1-2eta)^2 * flux + absflux = abs(flux) + mean = eta_factor * eta_factor * absflux + pd = PoissonDeviate(rng, mean) + pd_val = pd() - mean + absflux + ratio = pd_val / absflux + g *= ratio + mod_flux *= ratio + + if n_photons == 0.0: + n_photons = abs(mod_flux) + if max_extra_noise > 0.0: + gfactor = 1.0 + max_extra_noise / abs(self.max_sb) + n_photons /= gfactor + g *= gfactor + + # Make n_photons an integer. + iN = int(n_photons + 0.5) + + return iN, g """ n_photons_data = _NPhotonsData( diff --git a/jax_galsim/core/interpolate.py b/jax_galsim/core/interpolate.py index 7b545148..34e23398 100644 --- a/jax_galsim/core/interpolate.py +++ b/jax_galsim/core/interpolate.py @@ -14,28 +14,21 @@ def akima_interp_coeffs(x, y, use_jax=True): continuous second derivatives at the interpolation points. See https://en.wikipedia.org/wiki/Akima_spline and - Akima (1970), "A new method of interpolation and smooth curve fitting based on local procedures", - Journal of the ACM. 17: 589-602 for a description of the technique. - - Parameters - ---------- - x : array-like - The x-coordinates of the data points. These must be sorted into increasing order - and cannot contain any duplicates. - y : array-like - The y-coordinates of the data points. - use_jax : bool, optional - Whether to use JAX for computation. Default is True. If False, the - coefficients are computed using NumPy on the host device. This can be - useful when embded inside JAX code w/ JIT applied to pre-compute the - coefficients. - - Returns - ------- - tuple - A tuple of arrays (a, b, c, d) where each array has shape (N-1,) and - contains the coefficients for the cubic polynomial that interpolates - the data points between x[i] and x[i+1]. + Akima (1970), "A new method of interpolation and smooth curve fitting based on local + procedures", Journal of the ACM. 17: 589-602 for a description of the technique. + + Parameters: + x: The x-coordinates of the data points. These must be sorted into + increasing order and cannot contain any duplicates. + y: The y-coordinates of the data points. + use_jax: Whether to use JAX for computation. If False, coefficients are computed + using NumPy on the host device, which can be useful when embedded inside + JAX code with JIT applied to pre-compute the coefficients. [default: True] + + Returns: + A tuple of arrays ``(a, b, c, d)`` where each array has shape ``(N-1,)`` + and contains the coefficients for the cubic polynomial that interpolates + the data points between ``x[i]`` and ``x[i+1]``. """ if use_jax: return _akima_interp_coeffs_jax(x, y) @@ -113,29 +106,21 @@ def _akima_interp_coeffs_jax(x, y): @functools.partial(jax.jit, static_argnames=("fixed_spacing",)) def akima_interp(x, xp, yp, coeffs, fixed_spacing=False): - """Conmpute the values of an Akima cubic spline at a set of points given the + """Compute the values of an Akima cubic spline at a set of points given the interpolation coefficients. - Parameters - ---------- - x : array-like - The x-coordinates of the points where the interpolation is computed. - xp : array-like - The x-coordinates of the data points. These must be sorted into increasing order - and cannot contain any duplicates. - yp : array-like - The y-coordinates of the data points. Not used currently. - coeffs : tuple - The interpolation coefficients returned by `akima_interp_coeffs`. - fixed_spacing : bool, optional - Whether the data points are evenly spaced. Default is False. If True, the - code uses a faster technique to compute the index of the data points x into - the array xp such that xp[i] <= x < xp[i+1]. - - Returns - ------- - array-like - The values of the Akima cubic spline at the points x. + Parameters: + x: The x-coordinates of the points where the interpolation is computed. + xp: The x-coordinates of the data points. These must be sorted into + increasing order and cannot contain any duplicates. + yp: The y-coordinates of the data points. Not used currently. + coeffs: The interpolation coefficients returned by ``akima_interp_coeffs``. + fixed_spacing: Whether the data points are evenly spaced. If True, a faster + technique is used to find the index ``i`` such that + ``xp[i] <= x < xp[i+1]``. [default: False] + + Returns: + The values of the Akima cubic spline at the points ``x``. """ xp = jnp.asarray(xp) # yp = jnp.array(yp) # unused diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 915583f7..c14077fe 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -572,14 +572,14 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): @implements( _galsim.GSObject.drawImage, lax_description="""\ -The JAX-GalSim version of `drawImage` - - - does not do extensive (any?) checking of the input settings. - - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain - - requires that the maxN option be a constant since PhotonArrays are allocated - with `maxN` photons when this option is used and arrays in JAX must have static sizes. +The JAX-GalSim version of ``drawImage`` + +- does not do extensive (any?) checking of the input settings. +- uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain +- requires that the ``maxN`` option be a constant since PhotonArrays are allocated + with ``maxN`` photons when this option is used and arrays in JAX must have static sizes. """, ) def drawImage( @@ -1078,12 +1078,12 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): @implements( _galsim.GSObject.makePhot, lax_description="""\ -The JAX-GalSim version of `makePhot` +The JAX-GalSim version of ``makePhot`` - - does little to no error checking on the inputs - - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain +- does little to no error checking on the inputs +- uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain """, ) def makePhot( @@ -1146,13 +1146,13 @@ def makePhot( @implements( _galsim.GSObject.drawPhot, lax_description="""\ -The JAX-GalSim version of `drawPhot` +The JAX-GalSim version of ``drawPhot`` - - does little to no error checking on the inputs - - uses a default of ``n_photons=None`` instead of ``n_photons=0`` - to indicate that the number of photons should be determined - from the flux and gain - - requires that the maxN option must be a constant +- does little to no error checking on the inputs +- uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain +- requires that the ``maxN`` option must be a constant """, ) def drawPhot( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index efce0e10..25fdd7ec 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -20,11 +20,12 @@ Contrary to GalSim native Image, this implementation does not support sharing of the underlying numpy array between different Images or Views. This is due to the fact that in JAX numpy arrays are immutable, so any -operation applied to this Image will create a new jnp.ndarray. +operation applied to this Image will create a new ``jnp.ndarray``. In particular the following methods will create a copy of the Image: - - Image.view() - - Image.subImage() + +- ``Image.view()`` +- ``Image.subImage()`` """ diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 8561c159..0d7c24c6 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -1,6 +1,5 @@ import copy import math -import textwrap from functools import partial import galsim as _galsim @@ -54,18 +53,17 @@ def __dir__(cls): return list(keys) -@implements( - _galsim.InterpolatedImage, - lax_description=textwrap.dedent( - """The JAX equivalent of galsim.InterpolatedImage does not support +LAX_INTERPOLATED_IMAGE = """\ +The JAX equivalent of galsim.InterpolatedImage does not support: - - noise padding - - the pad_image options - - depixelize - - most of the bounds checks, type checks, and dtype casts done by galsim - """ - ), -) +- noise padding +- the pad_image options +- depixelize +- most of the bounds checks, type checks, and dtype casts done by galsim +""" + + +@implements(_galsim.InterpolatedImage, lax_description=LAX_INTERPOLATED_IMAGE) @register_pytree_node_class class InterpolatedImage(Transformation, metaclass=DirMeta): _req_params = {"image": str} diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index f6fffe51..2070f981 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -41,35 +41,36 @@ def fixed_photon_array_size(size): lax_description="""\ JAX-GalSim PhotonArrays have significant differences from the original GalSim. - - They always copy input data and operations on them always copy. - - They (usually) do not do any type/size checking on input data. - - They do not support indexed assignement directly on the attributes. - - The additional properties `dxdz`, `dydz`, `wavelength`, `pupil_u`, `pupil_v`, - and `time` are set to arrays of NaNs by default. They are thus always allocated. - However, the methods like `hasAllocatedAngles` etc. return false if the arrays - are all NaNs. - -Further, a context manager `fixed_photon_array_size` is provided to temporarily +- They always copy input data and operations on them always copy. +- They (usually) do not do any type/size checking on input data. +- They do not support indexed assignement directly on the attributes. +- The additional properties ``dxdz``, ``dydz``, ``wavelength``, ``pupil_u``, ``pupil_v``, + and ``time`` are set to arrays of NaNs by default. They are thus always allocated. + However, the methods like `hasAllocatedAngles` etc. return false if the arrays + are all NaNs. + +Further, a context manager ``fixed_photon_array_size`` is provided to temporarily set a fixed size for photon arrays. - - This functionality is useful when apply JIT to operations that vary the - number of photons drawn using Poisson statistics. - - When using this context manager, the attribute `_nokeep` stores a boolean mask - indicating which photons are to be kept. - - The attribute `_num_keep` stores the number of photons to be kept. If you set - this attribute, the `_nokeep` mask is updated by sorting _nokeep so that things - to be kept are at the start, the first `_num_keep` photons are marked to be kept, - and finally the array is sorted back to its original order. - - You may get an error if you ask for more photons than the fixed size, but not always, - especially in JITed code. - - Operations on photon arrays with fixed sizes but different `_num_keep` values are not - defined and will not raise an error. - - The `.flux` property scales `._flux` by the ratio of the fixed size to the number of kept photons - and sets non-kept photons to zero flux. Setting `.flux` to `._flux` will break things badly. - - Profiles should always draw the full number of photons given by `.size()` or `len()` - so that they use fixed array sizes and things are JIT compatible. - -**The `_nokeep`, `_num_keep`, and associated methods are private and should not be set by hand +- This functionality is useful when applying JIT to operations that vary the + number of photons drawn using Poisson statistics. +- When using this context manager, the attribute ``_nokeep`` stores a boolean mask + indicating which photons are to be kept. +- The attribute ``_num_keep`` stores the number of photons to be kept. If you set + this attribute, the ``_nokeep`` mask is updated by sorting ``_nokeep`` so that things + to be kept are at the start, the first ``_num_keep`` photons are marked to be kept, + and finally the array is sorted back to its original order. +- You may get an error if you ask for more photons than the fixed size, but not always, + especially in JITed code. +- Operations on photon arrays with fixed sizes but different `_num_keep` values are not + defined and will not raise an error. +- The ``.flux`` property scales ``._flux`` by the ratio of the fixed size to the number + of kept photons and sets non-kept photons to zero flux. Setting ``.flux`` to ``._flux`` + will break things badly. +- Profiles should always draw the full number of photons given by ``.size()`` or ``len()`` + so that they use fixed array sizes and things are JIT compatible. + +**The ``_nokeep``, ``_num_keep``, and associated methods are private and should not be set by hand unless you know what you are doing!** """, ) @@ -260,8 +261,8 @@ def tree_unflatten(cls, aux_data, children): ret._is_corr = children[1]["is_corr"] return ret + @implements(_galsim.PhotonArray.size) def size(self): - """Return the size of the photon array. Equivalent to ``len(self)``.""" return self._Ntot def __len__(self): diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index fbc756db..200d197d 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -240,25 +240,25 @@ def _spergel_hlr_pade(x): return pm / qm -@implements( - _galsim.Spergel, - lax_description=r"""The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is - .. math:: +LAX_SPERGEL_DESCRIPTION = r""" +The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is - I(r) = flux \times \left(2\pi 2^\nu \Gamma(1+\nu) r_0^2\right)^{-1} - \times \left(\frac{r}{r_0}\right)^\nu K_\nu\left(\frac{r}{r_0}\right) +.. math:: + I(r) = flux \times \left(2\pi 2^\nu \Gamma(1+\nu) r_0^2\right)^{-1} \times \left(\frac{r}{r_0}\right)^\nu K_\nu\left(\frac{r}{r_0}\right) - with the following Fourier expression - .. math:: +with the following Fourier expression - \hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu} +.. math:: + \hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu} - where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0] +where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0] - The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for - real-space evaluations. - """, -) +The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for +real-space evaluations. +""" + + +@implements(_galsim.Spergel, lax_description=LAX_SPERGEL_DESCRIPTION) @register_pytree_node_class class Spergel(GSObject): _has_hard_edges = False diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 958e6bfa..8dae7aca 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -12,14 +12,14 @@ @implements( - _galsim.Add, lax_description="Does not support `ChromaticObject` at this point." + _galsim.Add, lax_description="Does not support ``ChromaticObject`` at this point." ) def Add(*args, **kwargs): return Sum(*args, **kwargs) @implements( - _galsim.Sum, lax_description="Does not support `ChromaticObject` at this point." + _galsim.Sum, lax_description="Does not support ``ChromaticObject`` at this point." ) @register_pytree_node_class class Sum(GSObject): diff --git a/pyproject.toml b/pyproject.toml index 55f929f3..1a3315f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,30 +1,20 @@ [build-system] -requires = [ - "setuptools>=45", - "setuptools-scm>=8", -] +requires = ["setuptools>=45", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" [project] -name = "JAX-GalSim" -authors = [ - {name = "GalSim Developers"}, -] +name = "JAX-GalSim" +authors = [{ name = "GalSim Developers" }] requires-python = ">= 3.11" -description = "The modular galaxy image simulation toolkit, but in JAX" -dynamic = ["version"] -license = {file = "LICENSE"} -readme = "README.md" -dependencies = [ - "numpy >=1.18.0", - "galsim >=2.7.0", - "jax >=0.6.0", - "astropy >=2.0", - "quadax", -] +description = "The modular galaxy image simulation toolkit, but in JAX" +dynamic = ["version"] +license = { file = "LICENSE" } +readme = "README.md" +dependencies = ["numpy >=1.18.0", "galsim >=2.7.0", "jax >=0.6.0", "astropy >=2.0", "quadax"] [project.optional-dependencies] -dev = ["pytest", "pytest-codspeed"] +dev = ["pytest", "pytest-codspeed"] +docs = ["sphinx>=7.0", "furo>=2024.1.29", "sphinx-design>=0.5"] [project.urls] home = "https://github.com/GalSim-developers/JAX-GalSim" @@ -33,25 +23,19 @@ home = "https://github.com/GalSim-developers/JAX-GalSim" include = ["jax_galsim*"] [tool.setuptools_scm] -write_to = "jax_galsim/_version.py" +write_to = "jax_galsim/_version.py" fallback_version = "0.0.1.dev0" [tool.ruff.lint] -select = ["E", "F", "I", "W"] -ignore = ["C901", "E203", "E501"] +select = ["E", "F", "I", "W"] +ignore = ["C901", "E203", "E501"] preview = true [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "I001"] [tool.pytest.ini_options] -minversion = "6.0" -addopts = "-ra -q" -testpaths = [ - "tests/GalSim/tests/", - "tests/jax", - "tests/Coord/tests/", -] -filterwarnings = [ - "ignore::DeprecationWarning", -] +minversion = "6.0" +addopts = "-ra -q" +testpaths = ["tests/GalSim/tests/", "tests/jax", "tests/Coord/tests/"] +filterwarnings = ["ignore::DeprecationWarning"]