Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 104 additions & 2 deletions bpd/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,28 @@ def get_timing_figure(
n_chains = int(n_gals_str) # new fmt

t_per_obj_warmup = t_warmup / n_chains
t_per_obj_per_sample_sampling = t_sampling / (n_chains * n_samples)
t_per_obj_per_sample_sampling = t_sampling / (n_chains * n_samples) / avg_ess
t_per_obj_arr = (
t_per_obj_warmup + t_per_obj_per_sample_sampling * n_samples_array
)
t_per_obj_dict[n_chains] = t_per_obj_arr / avg_ess
t_per_obj_dict[n_chains] = t_per_obj_arr

if n_gals_str == max_n_gal_str:
print(
f"Global best efficiency: {t_per_obj_per_sample_sampling / avg_ess:.2g} sec"
)
print(f"Global best warmup: {t_per_obj_warmup:.2g} sec")

the_idx = np.where(n_samples_array == 300)[0][0]
t1 = t_per_obj_arr[the_idx].item() * n_chains
t2 = t_per_obj_arr[the_idx].item()

print(
f"Total time (300 effective samples) with {n_gals_str} chains: {t1:.4g} sec"
)
print(
f"Time per galaxy (300 effective samples) with {n_gals_str} chains: {t2:.4g} sec"
)
# first option
fig1, ax = plt.subplots(1, 1, figsize=figsize)
ax.set_prop_cycle(cycles)
Expand All @@ -131,6 +141,8 @@ def get_timing_figure(
ax.set_xlabel(r"\rm \# of effective samples")

for n_chains, t_per_obj_array in t_per_obj_dict.items():
if n_chains == 5:
continue
ax.plot(n_samples_array, t_per_obj_array, label=f"${n_chains}$")

ax.legend(
Expand All @@ -144,6 +156,96 @@ def get_timing_figure(
return fig1, fig2


def get_total_timing_figure(
results: dict, *, max_n_gal_str: str, avg_ess: float, figsize=(10, 10)
) -> Figure:
all_n_gals = [n_gals for n_gals in results]

_, n_samples = results[max_n_gal_str]["samples"]["lf"].shape

total_time_warmup = []
total_time_sampling = []
n_chains_arr = np.array([int(n_gals) for n_gals in results])

for n_gals_str in all_n_gals:
t_warmup = results[n_gals_str]["t_warmup"]
t_sampling = results[n_gals_str]["t_sampling"] / n_samples * 300 / avg_ess

total_time_warmup.append(t_warmup)
total_time_sampling.append(t_sampling)

total_time_warmup = np.array(total_time_warmup)
total_time_sampling = np.array(total_time_sampling)
total_time = total_time_sampling + total_time_warmup

# first option
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax.set_ylabel(r"\rm Total time (sec)")
ax.set_xlabel(r"\rm \# of Galaxies")

ax.plot(n_chains_arr, total_time_warmup, "-o", label=r"\rm Warmup")
ax.plot(n_chains_arr, total_time_sampling, "-o", label=r"\rm Inference")
ax.plot(n_chains_arr, total_time, "-o", label=r"\rm Total")

ax.plot(n_chains_arr, total_time[0] * n_chains_arr, "k--", label=r"\rm Worst")

ax.legend(loc="best", fancybox=True, shadow=False)

ax.set_xscale("log")
ax.set_yscale("log")

return fig


def get_timing_table(
results: dict, *, max_n_gal_str: str, avg_ess: float, fpath: str
) -> Figure:
all_n_gals = [n_gals for n_gals in results]
warmup_times_per_obj = {}
inference_times = {}
t_300_dict = {} # after warmup

_, n_samples = results[max_n_gal_str]["samples"]["lf"].shape

for n_gals_str in all_n_gals:
t_warmup = results[n_gals_str]["t_warmup"]
t_sampling = results[n_gals_str]["t_sampling"]

n_chains = int(n_gals_str) # new fmt

# (avg.) time to warmup 1 object
t_per_obj_warmup = t_warmup / n_chains

# (avg.) time to produce 1 effective sample for 1 object (ignoring warmup)
t_per_obj_per_sample_sampling = t_sampling / (n_chains * n_samples) / avg_ess
t_300 = t_per_obj_per_sample_sampling * 300 + t_per_obj_warmup

# save
warmup_times_per_obj[n_chains] = t_per_obj_warmup
inference_times[n_chains] = t_per_obj_per_sample_sampling
t_300_dict[n_chains] = t_300

if n_gals_str == max_n_gal_str:
print(f"Global best efficiency: {t_per_obj_per_sample_sampling:.3g} sec")
print(f"Global best warmup: {t_per_obj_warmup:.3g} sec")

# create latex table with rows for n_chains and columns for t_per_obj_warmup, t_per_obj_per_sample_sampling,
# and eff_samples_per_sec
table_str = "\\begin{tabular}{|c|c|c|c|}\n"
table_str += "\\hline\n"
table_str += "\\# of Galaxies \\newline in Parallel & Warmup time (sec) & Inference time / eff. sample (sec) & Time to produce \\newline 300 eff. samples (sec)\\\\\n"
table_str += "\\hline\n"
for n_chains in sorted(t_300_dict.keys()):
t_per_obj_warmup = warmup_times_per_obj[n_chains]
table_str += f"{n_chains} & {t_per_obj_warmup:.2g} & {inference_times[n_chains]:.2g} & {t_300_dict[n_chains]:.2g} \\\\\n"
table_str += "\\hline\n"
table_str += "\\end{tabular}"

with open(fpath, "w", encoding="utf-8") as f:
f.write(table_str)


def get_jack_bias(
g_plus_jack: np.ndarray, g_minus_jack: np.ndarray, g1_true: float
) -> tuple:
Expand Down
9 changes: 4 additions & 5 deletions bpd/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
inv_shear_transformation,
)

_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))
_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None))


def ellip_mag_prior(e_mag: ArrayLike, sigma: float) -> ArrayLike:
"""Prior for the magnitude of the ellipticity with domain (0, 1).
Expand Down Expand Up @@ -43,11 +47,6 @@ def ellip_prior_e1e2(e1e2: Array, sigma: float) -> ArrayLike:
return (1 - e_mag**2) ** 2 * jnp.exp(-(e_mag**2) / (2 * sigma**2)) / _norm


_grad_fnc1 = vmap(vmap(grad(inv_shear_func1), in_axes=(0, None)), in_axes=(0, None))
_grad_fnc2 = vmap(vmap(grad(inv_shear_func2), in_axes=(0, None)), in_axes=(0, None))
_inv_shear_trans = vmap(inv_shear_transformation, in_axes=(0, None))


def interim_gprops_logprior(
params: dict[str, Array],
*,
Expand Down
3 changes: 3 additions & 0 deletions experiments/exp101/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Experiment 10.1

Low noise galaxies and joint multiplicative bias joint inference.
2 changes: 2 additions & 0 deletions experiments/exp24/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Experiment 24
Investigate MCMC on noisy galaxies
1,175 changes: 1,175 additions & 0 deletions experiments/exp24/noisy.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion experiments/exp81/eta_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def shear_eta_target(g, *, data, sigma_e: float, sigma_e_int: float):
etas = data

# P(eta' | alpha, g) = P(eps | alpha) * (del eps' / del eta') * (del eps / del eps')
# jacobian on eta cancels between num and denom so we ignore it.
# first jacobian on eta cancels between num and denom so we ignore it.
eps_sheared = vmap(vmap(eta2g))(etas)
eps = _inv_shear_trans(eps_sheared, g)
num1 = jnp.log(ellip_prior_e1e2(eps, sigma_e))
Expand Down
191 changes: 191 additions & 0 deletions notebooks/feb20-26-test_moffat-new.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"id": "f56dbc17-62c9-4396-b3e7-7a98d18e1211",
"metadata": {},
"outputs": [],
"source": [
"import galsim\n",
"import jax.numpy as jnp\n",
"import jax_galsim as xgalsim\n",
"from jax import random\n",
"from jax._src.prng import PRNGKeyArray\n",
"from jax.typing import ArrayLike\n",
"from jax_galsim import GSParams\n",
"\n",
"from functools import partial\n",
"\n",
"import jax "
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1a9cae5a-ae0f-41e8-815b-ee802f936ad8",
"metadata": {},
"outputs": [],
"source": [
"def draw_gaussian(\n",
" *,\n",
" f: float,\n",
" hlr: float,\n",
" e1: float,\n",
" e2: float,\n",
" x: float, # pixels\n",
" y: float,\n",
" slen: int,\n",
" fft_size: int, # rule of thumb: at least 4 times `slen`\n",
" psf_fwhm: float = 0.8,\n",
" pixel_scale: float = 0.2,\n",
"):\n",
" gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)\n",
"\n",
" gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)\n",
" gal = gal.shear(g1=e1, g2=e2)\n",
"\n",
" psf = xgalsim.Gaussian(flux=1.0, fwhm=0.8)\n",
" gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)\n",
" image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))\n",
" return image.array"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f1e41815-b46b-441d-900e-7df655a889ee",
"metadata": {},
"outputs": [],
"source": [
"def draw_gaussian_moffat(\n",
" *,\n",
" f: float,\n",
" hlr: float,\n",
" e1: float,\n",
" e2: float,\n",
" x: float, # pixels\n",
" y: float,\n",
" slen: int,\n",
" fft_size: int, # rule of thumb: at least 4 times `slen`\n",
" psf_fwhm: float = 0.8,\n",
" pixel_scale: float = 0.2,\n",
"):\n",
" gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)\n",
"\n",
" gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)\n",
" gal = gal.shear(g1=e1, g2=e2)\n",
"\n",
" psf = xgalsim.Moffat(flux=1.0, scale_radius=0.8, beta=2.0)\n",
" gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)\n",
" image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))\n",
" return image.array"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "ed271d22-b47c-4b82-8921-65bd8b32d719",
"metadata": {},
"outputs": [],
"source": [
"_func1 = jax.jit(partial(draw_gaussian, slen=63, fft_size=256))\n",
"_ = _func1(f=1.0, hlr=1.0, e1=0.2, e2=0.2, x=0., y=0.0)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "62501c9e-7a53-4376-8dd3-83fc31dbc288",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"228 μs ± 4.69 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_func1(f=1.0, hlr=1.0, e1=0.2, e2=0.2, x=0., y=0.0)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6b297349-3516-4226-b863-3f3f846e9676",
"metadata": {},
"outputs": [],
"source": [
"_func2 = jax.jit(partial(draw_gaussian_moffat, slen=63, fft_size=256))\n",
"_ = _func2(f=1.0, hlr=1.0, e1=0.2, e2=0.2, x=0., y=0.0)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a2b10347-57b8-4b43-92f1-ffe91bae6ba3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.16 ms ± 20.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"_func2(f=1.0, hlr=1.0, e1=0.2, e2=0.2, x=0., y=0.0) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae7287d2-d57a-457a-a79d-4b3afd9394fb",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "11f5b52b-d105-43b9-b7e9-714b80efa748",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4085d869-ecfa-4e77-8cd2-314fa7bc54ab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "bpd_uv",
"language": "python",
"name": "bpd_uv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading