Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,5 @@ dmypy.json

# Pyre type checker
.pyre/
.worktrees/
.worktrees/
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ python setup.py install --user

(modifying the Makefile to your `gcc` as needed)

### Multiprocessing

Zeus21 now includes built-in process-level caching for CLASS cosmology objects. When running many simulations in parallel, each worker process will only initialize CLASS once and reuse the cached object for subsequent calls. This provides a significant speedup for large-scale inference or Monte-Carlo sampling.

Because CLASS uses C-global state, **always use `spawn` (not `fork`)** when creating a multiprocessing pool:

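```python
import multiprocessing as mp
import zeus21

# The guard is required with "spawn": each worker re-imports the main module.
if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=4) as pool:
        pool.map(my_zeus21_worker, params_list)
```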

See `examples/benchmark_multiprocess.py` for a runnable benchmark that compares a single-process run against a `spawn`-based multiprocess run.

### NumPy 2.0 Compatibility

Zeus21 is now compatible with NumPy 2.0+. All deprecated `np.trapz` calls have been replaced with `np.trapezoid`. Note that `np.trapezoid` was introduced in NumPy 2.0, so NumPy 2.0 or later is now required.

## Citation

If you find this code useful please cite:
Expand Down
80 changes: 80 additions & 0 deletions examples/benchmark_multiprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Benchmark script demonstrating Zeus21 multiprocess performance with process-level caching.

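This script compares the time to run multiple simulations sequentially in a
single process against running them in a `spawn`-based multiprocessing pool.
Both runs go through `zeus21.cosmology.runclass()` and its built-in CLASS caching.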

Usage:
python examples/benchmark_multiprocess.py

Requirements:
- zeus21 installed (including CLASS)
- numpy
"""

import time
import multiprocessing as mp
import numpy as np
import zeus21


def _sim_worker_cached(_):
    """Run one complete Zeus21 simulation and return the summed ``T21avg``.

    The single positional argument is ignored; it only exists so the function
    can be handed directly to ``pool.map``. The CLASS object is obtained via
    ``zeus21.runclass``. The result is converted to a plain ``float``.
    """
    user = zeus21.User_Parameters()
    cosmo_in = zeus21.Cosmo_Parameters_Input()
    class_cosmo = zeus21.runclass(cosmo_in)

    # Build the derived objects in the same order as the T21 call expects them.
    cosmo = zeus21.Cosmo_Parameters(user, cosmo_in, class_cosmo)
    hmf = zeus21.HMF_interpolator(user, cosmo, class_cosmo)
    astro = zeus21.Astro_Parameters(user, cosmo)

    t21_coefficients = zeus21.get_T21_coefficients(
        user, cosmo, class_cosmo, astro, hmf, zmin=10.0
    )
    total = t21_coefficients.T21avg.sum()
    return float(total)


def main():
    """Time a batch of simulations serially and then in a spawn-based pool.

    Runs ``N_SIMS`` calls of ``_sim_worker_cached`` first in a plain loop in the
    current process and then through a ``multiprocessing`` pool of
    ``N_WORKERS`` workers created with the ``spawn`` start method. Prints the
    elapsed time of both runs and the resulting speedup. Takes no arguments
    and returns ``None``.
    """
    N_SIMS = 8
    N_WORKERS = 4

    print("=" * 60)
    print("Zeus21 Multiprocess Caching Benchmark")
    print("=" * 60)

    # Warmup: run one simulation up front so one-time start-up costs in the
    # main process are not charged to the single-process timing below.
    print("\n[Warmup] Running first simulation...")
    _sim_worker_cached(None)
    print("Done.")

    # Single-process benchmark.
    # perf_counter() is a monotonic, high-resolution clock; time.time() follows
    # the system clock and can jump while the benchmark is running.
    print("\n[Single-process benchmark]")
    t0 = time.perf_counter()
    for _ in range(N_SIMS):
        _sim_worker_cached(None)
    t_single = time.perf_counter() - t0
    print(f" {N_SIMS} simulations in {t_single:.1f}s ({t_single / N_SIMS:.2f}s/sim)")

    # Multiprocess benchmark with spawn (safe for CLASS).
    # The timed region deliberately includes pool start-up and shutdown.
    print("\n[Multiprocess benchmark]")
    print(f" Using spawn context with {N_WORKERS} workers for {N_SIMS} simulations")
    ctx = mp.get_context("spawn")
    t0 = time.perf_counter()
    with ctx.Pool(processes=N_WORKERS) as pool:
        pool.map(_sim_worker_cached, range(N_SIMS))
    t_multi = time.perf_counter() - t0
    print(f" {N_SIMS} simulations in {t_multi:.1f}s ({t_multi / N_SIMS:.2f}s/sim)")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Single-process: {t_single:.1f}s total")
    print(f"Multiprocess : {t_multi:.1f}s total")
    if t_multi > 0:  # guard the division against a zero elapsed time
        print(f"Speedup : {t_single / t_multi:.1f}x")
    print("\nNote: The first simulation in each new process incurs the")
    print("CLASS initialization cost (~few seconds). Subsequent calls in")
    print("the same process reuse the cached CLASS object automatically.")
    print("=" * 60)


# Entry-point guard: main() creates a pool with the "spawn" start method, which
# re-imports this module in every worker; the guard keeps workers from
# re-running the benchmark on import.
if __name__ == "__main__":
    main()
8 changes: 4 additions & 4 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,23 @@ def test_inputs():
#test Pop II Xray SED
Energylisttest = np.logspace(2,np.log10(AstroParams.Emax_xray_norm),100)
SEDXtab_test = AstroParams.SED_XRAY(Energylisttest, 2) #same in both models
normalization_XraySED = np.trapz(Energylisttest * SEDXtab_test,Energylisttest)
normalization_XraySED = np.trapezoid(Energylisttest * SEDXtab_test,Energylisttest)
assert( normalization_XraySED == pytest.approx(1.0, 0.05) ) #5% is enough here

#test Pop III Xray SED
SEDXtab_test = AstroParams.SED_XRAY(Energylisttest, 3) #same in both models
normalization_XraySED = np.trapz(Energylisttest * SEDXtab_test,Energylisttest)
normalization_XraySED = np.trapezoid(Energylisttest * SEDXtab_test,Energylisttest)
assert( normalization_XraySED == pytest.approx(1.0, 0.05) ) #5% is enough here


#test Pop II LyA SED
nulisttest = np.linspace(zeus21.constants.freqLyA, zeus21.constants.freqLyCont, 100)
SEDLtab_test = AstroParams.SED_LyA(nulisttest, 2) #same in both models
normalization_LyASED = np.trapz(SEDLtab_test,nulisttest)
normalization_LyASED = np.trapezoid(SEDLtab_test,nulisttest)
assert( normalization_LyASED == pytest.approx(1.0, 0.05) ) #5% is enough here

#test Pop III LyA SED
nulisttest = np.linspace(zeus21.constants.freqLyA, zeus21.constants.freqLyCont, 100)
SEDLtab_test = AstroParams.SED_LyA(nulisttest, 3) #same in both models
normalization_LyASED = np.trapz(SEDLtab_test,nulisttest)
normalization_LyASED = np.trapezoid(SEDLtab_test,nulisttest)
assert( normalization_LyASED == pytest.approx(1.0, 0.05) ) #5% is enough here
4 changes: 2 additions & 2 deletions zeus21/UVLFs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def UVLF_binned(Astro_Parameters,Cosmo_Parameters,HMF_interpolator, zcenter, zwi
xlo = np.subtract.outer(MUVcutlo, currMUV )/(np.sqrt(2) * sigmaUV)
weights = (erf(xhi) - erf(xlo)).T/(2.0 * MUVwidths)

UVLF_filtered = np.trapz(weights.T * HMFcurr, HMF_interpolator.Mhtab, axis=-1)
UVLF_filtered = np.trapezoid(weights.T * HMFcurr, HMF_interpolator.Mhtab, axis=-1)

if(Astro_Parameters.USE_POPIII==False):
return UVLF_filtered
Expand All @@ -98,7 +98,7 @@ def UVLF_binned(Astro_Parameters,Cosmo_Parameters,HMF_interpolator, zcenter, zwi
xlo = np.subtract.outer(MUVcutlo, MUVbarlist_III)/(np.sqrt(2) * sigmaUV)
weights = (erf(xhi) - erf(xlo)).T/(2.0 * MUVwidths)

UVLF_filtered_III = np.trapz(weights.T * HMFcurr, HMF_interpolator.Mhtab, axis=-1)
UVLF_filtered_III = np.trapezoid(weights.T * HMFcurr, HMF_interpolator.Mhtab, axis=-1)

return UVLF_filtered, UVLF_filtered_III

Expand Down
4 changes: 2 additions & 2 deletions zeus21/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def __init__(self, User_Parameters, Cosmo_Parameters, Astro_Parameters, ClassCos
#the z>zmax part of the integral we do aside. Assume Tk=Tadiabatic from CLASS.
_zlisthighz_ = np.linspace(T21_coefficients.zintegral[-1], 99., 100) #beyond z=100 need to explictly tell CLASS to save growth
_dgrowthhighz_ = cosmology.dgrowth_dz(Cosmo_Parameters, _zlisthighz_)
_hizintegral = np.trapz(cosmology.Tadiabatic(Cosmo_Parameters,_zlisthighz_)
_hizintegral = np.trapezoid(cosmology.Tadiabatic(Cosmo_Parameters,_zlisthighz_)
/(1+_zlisthighz_)**2 * _dgrowthhighz_, _zlisthighz_)

self._betaTad_ = -2./3. * _factor_adi_/self._lingrowthd * (np.cumsum(_integrand_adi[::-1])[::-1] + _hizintegral) #units of Tk_avg. Internal sum goes from high to low z (backwards), minus sign accounts for it properly so it's positive.
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def get_Pk_from_xi(self, rsinput, xiinput):
#
# Probdtab = np.exp(-dtab**2/sigmaRRsq/2.0)
#
# norm = np.trapz(NionEPS * Probdtab, dtab)
# norm = np.trapezoid(NionEPS * Probdtab, dtab)
# NionEPS/=norm
#
# bindex = min(range(len(NionEPS)), key=lambda i: abs(NionEPS[i]-_invQbar))
Expand Down
27 changes: 22 additions & 5 deletions zeus21/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

import numpy as np
import os
from classy import Class
from scipy.interpolate import RegularGridInterpolator
from scipy.interpolate import interp1d
Expand All @@ -21,6 +22,10 @@
from .inputs import Cosmo_Parameters, Cosmo_Parameters_Input
from .correlations import Correlations

# Process-level cache for CLASS cosmology objects. Keyed by cosmological parameters
# and os.getpid() so that each spawned worker only initializes CLASS once.
_COSMO_CACHE = {}

def cosmo_wrapper(User_Parameters, Cosmo_Parameters_Input):
"""
Wrapper function for all the cosmology. It takes Cosmo_Parameters_Input and returns:
Expand All @@ -41,6 +46,17 @@ def cosmo_wrapper(User_Parameters, Cosmo_Parameters_Input):

def runclass(CosmologyIn):
"Set up CLASS cosmology. Takes CosmologyIn class input and returns CLASS Cosmology object"
cache_key = (
CosmologyIn.omegab, CosmologyIn.omegac, CosmologyIn.h_fid,
CosmologyIn.As, CosmologyIn.ns, CosmologyIn.tau_fid,
CosmologyIn.kmax_CLASS, CosmologyIn.zmax_CLASS,
CosmologyIn.zmin_CLASS, CosmologyIn.Flag_emulate_21cmfast,
CosmologyIn.USE_RELATIVE_VELOCITIES, CosmologyIn.HMF_CHOICE,
os.getpid(),
)
if cache_key in _COSMO_CACHE:
return _COSMO_CACHE[cache_key]

ClassCosmo = Class()
ClassCosmo.set({'omega_b': CosmologyIn.omegab,'omega_cdm': CosmologyIn.omegac,
'h': CosmologyIn.h_fid,'A_s': CosmologyIn.As,'n_s': CosmologyIn.ns,'tau_reio': CosmologyIn.tau_fid})
Expand Down Expand Up @@ -75,13 +91,13 @@ def runclass(CosmologyIn):
theta_b = velTransFunc['t_b']
theta_c = velTransFunc['t_cdm']

sigma_vcb = np.sqrt(np.trapz(CosmologyIn.As * (kVel/0.05)**(CosmologyIn.ns-1) /kVel * (theta_b - theta_c)**2/kVel**2, kVel)) * constants.c_kms
sigma_vcb = np.sqrt(np.trapezoid(CosmologyIn.As * (kVel/0.05)**(CosmologyIn.ns-1) /kVel * (theta_b - theta_c)**2/kVel**2, kVel)) * constants.c_kms
ClassCosmo.pars['sigma_vcb'] = sigma_vcb

###HAC: now computing average velocity assuming a Maxwell-Boltzmann distribution of velocities
velArr = np.geomspace(0.01, constants.c_kms, 1000) #in km/s
vavgIntegrand = (3 / (2 * np.pi * sigma_vcb**2))**(3/2) * 4 * np.pi * velArr**2 * np.exp(-3 * velArr**2 / (2 * sigma_vcb**2))
ClassCosmo.pars['v_avg'] = np.trapz(vavgIntegrand * velArr, velArr)
ClassCosmo.pars['v_avg'] = np.trapezoid(vavgIntegrand * velArr, velArr)

###HAC: Computing Vcb Power Spectrum
ClassCosmo.pars['k_vcb'] = kVel
Expand All @@ -99,8 +115,8 @@ def runclass(CosmologyIn):
j0bessel = lambda x: np.sin(x)/x
j2bessel = lambda x: (3 / x**2 - 1) * np.sin(x)/x - 3*np.cos(x)/x**2

psi0 = 1 / 3 / (sigma_vcb/constants.c_kms)**2 * np.trapz(kVelIntp**2 / 2 / np.pi**2 * p_vcb_intp(np.log(kVelIntp)) * j0bessel(kVelIntp * np.transpose([rVelIntp])), kVelIntp, axis = 1)
psi2 = -2 / 3 / (sigma_vcb/constants.c_kms)**2 * np.trapz(kVelIntp**2 / 2 / np.pi**2 * p_vcb_intp(np.log(kVelIntp)) * j2bessel(kVelIntp * np.transpose([rVelIntp])), kVelIntp, axis = 1)
psi0 = 1 / 3 / (sigma_vcb/constants.c_kms)**2 * np.trapezoid(kVelIntp**2 / 2 / np.pi**2 * p_vcb_intp(np.log(kVelIntp)) * j0bessel(kVelIntp * np.transpose([rVelIntp])), kVelIntp, axis = 1)
psi2 = -2 / 3 / (sigma_vcb/constants.c_kms)**2 * np.trapezoid(kVelIntp**2 / 2 / np.pi**2 * p_vcb_intp(np.log(kVelIntp)) * j2bessel(kVelIntp * np.transpose([rVelIntp])), kVelIntp, axis = 1)

k_eta, P_eta = mcfit.xi2P(rVelIntp, l=0, lowring = True)((6 * psi0**2 + 3 * psi2**2), extrap = False)

Expand All @@ -112,7 +128,8 @@ def runclass(CosmologyIn):
else:
ClassCosmo.pars['v_avg'] = 0.0
ClassCosmo.pars['sigma_vcb'] = 1.0 #Avoids excess computation, but doesn't matter what value we set it to because the flag in inputs.py sets all feedback parameters to zero


_COSMO_CACHE[cache_key] = ClassCosmo
return ClassCosmo

def Hub(Cosmo_Parameters, z):
Expand Down
Loading