From 57e08ebc4b5cfd06b6b8fd8eb61075e1428a81e4 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 00:12:09 +0200 Subject: [PATCH 01/18] add option to treat input systematic histograms as difference with respect to nominal --- rabbit/tensorwriter.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 2c1dde8..7543e20 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -344,10 +344,12 @@ def add_systematic( mirror=True, symmetrize="average", add_to_data_covariance=False, + as_difference=False, **kargs, ): """ h: either a single histogram with the systematic variation if mirror=True or a list of two histograms with the up and down variation + as_difference: if True, interpret the histogram values as the difference with respect to the nominal (i.e. the absolute variation is norm + h) """ norm = self.dict_norm[channel][process] @@ -365,6 +367,10 @@ def add_systematic( syst_up = self.get_flat_values(h[0], flow=flow) syst_down = self.get_flat_values(h[1], flow=flow) + if as_difference: + syst_up = norm + syst_up + syst_down = norm + syst_down + logkup_proc = self.get_logk( syst_up, norm, kfactor, systematic_type=systematic_type ) @@ -385,6 +391,10 @@ def add_systematic( elif mirror: self._check_hist_and_channel(h, channel) syst = self.get_flat_values(h, flow=flow) + + if as_difference: + syst = norm + syst + logkavg_proc = self.get_logk( syst, norm, kfactor, systematic_type=systematic_type ) From c80b409dc943a91b0931d00e85959dfa6a55a522 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 00:13:26 +0200 Subject: [PATCH 02/18] add test for sparse mode --- tests/test_sparse_fit.py | 255 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 tests/test_sparse_fit.py diff --git a/tests/test_sparse_fit.py b/tests/test_sparse_fit.py new file mode 100644 index 0000000..c3649bd --- /dev/null +++ b/tests/test_sparse_fit.py @@ -0,0 +1,255 @@ +""" +Test that writes a simple tensor in both dense and sparse modes, +runs a fit on each, and verifies that sparse mode produces consistent results. +""" + +import os +import tempfile +from types import SimpleNamespace + +import hist +import numpy as np + +from rabbit import fitter, inputdata, tensorwriter +from rabbit.poi_models.helpers import load_model + + +def make_histograms(): + """Generate common histograms for test tensors.""" + np.random.seed(42) + + ax = hist.axis.Regular(20, -5, 5, name="x") + + h_data = hist.Hist(ax, storage=hist.storage.Double()) + h_sig = hist.Hist(ax, storage=hist.storage.Weight()) + h_bkg = hist.Hist(ax, storage=hist.storage.Weight()) + + x_sig = np.random.normal(0, 1, 10000) + x_bkg = np.random.uniform(-5, 5, 5000) + + h_data.fill(np.concatenate([x_sig, x_bkg])) + h_sig.fill(x_sig, weight=np.ones(len(x_sig))) + h_bkg.fill(x_bkg, weight=np.ones(len(x_bkg))) + + # scale signal down by 10% so the fit has something to recover + h_sig.values()[...] = h_sig.values() * 0.9 + + # shape systematic on background: linear tilt + bin_centers = h_bkg.axes[0].centers + bin_centers_shifted = bin_centers - bin_centers[0] + weights = 0.01 * bin_centers_shifted - 0.05 + + h_bkg_syst_up = h_bkg.copy() + h_bkg_syst_dn = h_bkg.copy() + h_bkg_syst_up.values()[...] = h_bkg.values() * (1 + weights) + h_bkg_syst_dn.values()[...] = h_bkg.values() * (1 - weights) + + # difference histograms (variation - nominal) + h_bkg_syst_up_diff = h_bkg.copy() + h_bkg_syst_dn_diff = h_bkg.copy() + h_bkg_syst_up_diff.values()[...] = h_bkg.values() * weights + h_bkg_syst_dn_diff.values()[...] = h_bkg.values() * (-weights) + + return dict( + data=h_data, + sig=h_sig, + bkg=h_bkg, + syst_up=h_bkg_syst_up, + syst_dn=h_bkg_syst_dn, + syst_up_diff=h_bkg_syst_up_diff, + syst_dn_diff=h_bkg_syst_dn_diff, + ) + + +def make_test_tensor(outdir, sparse=False, as_difference=False): + """Create a simple tensor with signal + background + one shape systematic.""" + + hists = make_histograms() + + writer = tensorwriter.TensorWriter(sparse=sparse) + + writer.add_channel(hists["data"].axes, "ch0") + writer.add_data(hists["data"], "ch0") + + writer.add_process(hists["sig"], "sig", "ch0", signal=True) + writer.add_process(hists["bkg"], "bkg", "ch0") + + writer.add_norm_systematic("bkg_norm", "bkg", "ch0", 1.05) + + if as_difference: + writer.add_systematic( + [hists["syst_up_diff"], hists["syst_dn_diff"]], + "bkg_shape", + "bkg", + "ch0", + symmetrize="average", + as_difference=True, + ) + else: + writer.add_systematic( + [hists["syst_up"], hists["syst_dn"]], + "bkg_shape", + "bkg", + "ch0", + symmetrize="average", + ) + + suffix = "sparse" if sparse else "dense" + if as_difference: + suffix += "_diff" + name = f"test_{suffix}" + writer.write(outfolder=outdir, outfilename=name) + return os.path.join(outdir, f"{name}.hdf5") + + +def make_options(**kwargs): + """Create a minimal options namespace for the Fitter.""" + defaults = dict( + earlyStopping=-1, + noBinByBinStat=False, + binByBinStatMode="lite", + binByBinStatType="automatic", + covarianceFit=False, + chisqFit=False, + diagnostics=False, + minimizerMethod="trust-krylov", + prefitUnconstrainedNuisanceUncertainty=0.0, + freezeParameters=[], + setConstraintMinimum=[], + ) + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def run_fit(filename): + """Load tensor, set up fitter, run fit to data, return results.""" + + indata_obj = inputdata.FitInputData(filename) + poi_model = load_model("Mu", indata_obj) + + options = make_options() + f = fitter.Fitter(indata_obj, poi_model, options) + + # fit to observed data + f.set_nobs(indata_obj.data_obs) + f.minimize() + + # compute hessian covariance + val, grad, hess = f.loss_val_grad_hess() + from rabbit.tfhelpers import edmval_cov + + edmval, cov = edmval_cov(grad, hess) + + poi_val = f.x[: poi_model.npoi].numpy() + theta_val = f.x[poi_model.npoi :].numpy() + cov_np = cov.numpy() if hasattr(cov, "numpy") else np.asarray(cov) + poi_err = np.sqrt(np.diag(cov_np)[: poi_model.npoi]) + nll = f.reduced_nll().numpy() + + return dict( + poi=poi_val, + theta=theta_val, + poi_err=poi_err, + nll=nll, + edmval=edmval, + parms=f.parms, + ) + + +def check_results(label_a, res_a, label_b, res_b, atol=1e-5, rtol=1e-4): + """Compare two fit results and return True if they match.""" + + print(f"\n--- {label_a} vs {label_b} ---") + + for label, res in [(label_a, res_a), (label_b, res_b)]: + print(f"\n{label}:") + for i, name in enumerate(res["parms"][: len(res["poi"])]): + print(f" {name}: {res['poi'][i]:.6f} +/- {res['poi_err'][i]:.6f}") + for i, name in enumerate(res["parms"][len(res["poi"]) :]): + print(f" {name}: {res['theta'][i]:.6f}") + print(f" reduced NLL: {res['nll']:.6f}") + print(f" EDM: {res['edmval']:.2e}") + + poi_match = np.allclose(res_a["poi"], res_b["poi"], atol=atol, rtol=rtol) + theta_match = np.allclose(res_a["theta"], res_b["theta"], atol=atol, rtol=rtol) + err_match = np.allclose(res_a["poi_err"], res_b["poi_err"], atol=atol, rtol=rtol) + nll_match = np.isclose(res_a["nll"], res_b["nll"], atol=atol, rtol=rtol) + + all_ok = poi_match and theta_match and err_match and nll_match + + print(f"\n POI values match: {poi_match}") + print(f" Theta values match: {theta_match}") + print(f" POI uncertainties match: {err_match}") + print(f" NLL values match: {nll_match}") + + if not poi_match: + print(f" {label_a} POI: {res_a['poi']}") + print(f" {label_b} POI: {res_b['poi']}") + print(f" diff: {res_a['poi'] - res_b['poi']}") + + if not theta_match: + print(f" {label_a} theta: {res_a['theta']}") + print(f" {label_b} theta: {res_b['theta']}") + print(f" diff: {res_a['theta'] - res_b['theta']}") + + if not nll_match: + print(f" {label_a} NLL: {res_a['nll']}") + print(f" {label_b} NLL: {res_b['nll']}") + print(f" diff: {res_a['nll'] - res_b['nll']}") + + return all_ok + + +def main(): + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + + with tempfile.TemporaryDirectory() as tmpdir: + # create tensors in all four modes + dense_file = make_test_tensor(tmpdir, sparse=False) + sparse_file = make_test_tensor(tmpdir, sparse=True) + dense_diff_file = make_test_tensor(tmpdir, sparse=False, as_difference=True) + sparse_diff_file = make_test_tensor(tmpdir, sparse=True, as_difference=True) + + configs = [ + ("Dense", dense_file), + ("Sparse", sparse_file), + ("Dense (as_difference)", dense_diff_file), + ("Sparse (as_difference)", sparse_diff_file), + ] + + results = {} + for label, fpath in configs: + print("=" * 60) + print(f"Running {label} fit...") + print("=" * 60) + results[label] = run_fit(fpath) + print() + + # check consistency across all pairs vs the dense baseline + print("=" * 60) + print("Consistency checks") + print("=" * 60) + + checks = [ + ("Dense", "Sparse"), + ("Dense", "Dense (as_difference)"), + ("Dense", "Sparse (as_difference)"), + ] + + all_ok = True + for label_a, label_b in checks: + ok = check_results(label_a, results[label_a], label_b, results[label_b]) + all_ok = all_ok and ok + + print() + if all_ok: + print("ALL CHECKS PASSED") + else: + print("SOME CHECKS FAILED") + raise SystemExit(1) + + +if __name__ == "__main__": + main() From 0e17ff62a5b38a56d8b0125ed54f74bf580ae223 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 00:54:58 +0200 Subject: [PATCH 03/18] Support scipy sparse array inputs in TensorWriter and add as_difference option Add `as_difference` parameter to `add_systematic` to interpret input histograms as differences from nominal. Add full scipy sparse array support for `add_process` and `add_systematic`: in sparse mode, norm is stored as flat CSR and logk is computed only at nonzero positions, avoiding full-size dense intermediates. Extend test_sparse_fit.py to cover all modes including scipy sparse inputs. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/tensorwriter.py | 456 +++++++++++++++++++++++++++++++++------ tests/test_sparse_fit.py | 80 +++++-- 2 files changed, 460 insertions(+), 76 deletions(-) diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 7543e20..ab3d552 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -71,23 +71,65 @@ def __init__( self.dtype = "float64" self.chunkSize = 4 * 1024**2 + @staticmethod + def _issparse(h): + """Check if h is a scipy sparse array/matrix.""" + return hasattr(h, "toarray") and hasattr(h, "tocoo") + + @staticmethod + def _sparse_to_flat_csr(h, dtype): + """Flatten a scipy sparse array/matrix to CSR with shape (1, prod(shape)). + + The returned CSR array has sorted indices suitable for searchsorted lookups. + """ + import scipy.sparse + + size = int(np.prod(h.shape)) + coo = scipy.sparse.coo_array(h) + if coo.ndim == 2: + flat_indices = np.ravel_multi_index((coo.row, coo.col), h.shape) + elif coo.ndim == 1: + flat_indices = coo.coords[0] + else: + raise ValueError( + f"Unsupported dimensionality {coo.ndim} for scipy sparse input" + ) + sort_order = np.argsort(flat_indices) + sorted_indices = flat_indices[sort_order].astype(np.int32) + sorted_data = coo.data[sort_order].astype(dtype) + indptr = np.array([0, len(sorted_data)], dtype=np.int32) + return scipy.sparse.csr_array( + (sorted_data, sorted_indices, indptr), shape=(1, size) + ) + + def _to_flat_dense(self, h): + """Convert any array-like (including scipy sparse) to a flat dense numpy array.""" + if self._issparse(h): + return np.asarray(h.toarray()).flatten().astype(self.dtype) + return np.asarray(h).flatten().astype(self.dtype) + def get_flat_values(self, h, flow=False): if hasattr(h, "values"): values = h.values(flow=flow) + elif self._issparse(h): + values = np.asarray(h.toarray()) else: values = h - return values.flatten().astype(self.dtype) + return np.asarray(values).flatten().astype(self.dtype) def get_flat_variances(self, h, flow=False): if hasattr(h, "variances"): variances = h.variances(flow=flow) + elif self._issparse(h): + variances = np.asarray(h.toarray()) else: variances = h + variances = np.asarray(variances).flatten().astype(self.dtype) if (variances < 0.0).any(): raise ValueError("Negative variances encountered") - return variances.flatten().astype(self.dtype) + return variances def add_data(self, h, channel="ch0", variances=None): self._check_hist_and_channel(h, channel) @@ -134,19 +176,41 @@ def add_process(self, h, name, channel="ch0", signal=False, variances=None): self.dict_logkhalfdiff_indices[channel][name] = {} flow = self.channels[channel]["flow"] - norm = self.get_flat_values(h, flow) - sumw2 = self.get_flat_variances(h if variances is None else variances, flow) - if not self.allow_negative_expectation: - norm = np.maximum(norm, 0.0) + if self.sparse and self._issparse(h): + # Store as flat CSR, avoiding full dense conversion + norm = self._sparse_to_flat_csr(h, self.dtype) + if not np.all(np.isfinite(norm.data)): + raise RuntimeError( + f"NaN or Inf values encountered in nominal histogram for {name}!" + ) + if not self.allow_negative_expectation: + has_negative = np.any(norm.data < 0.0) + if has_negative: + norm = norm.copy() + norm.data[:] = np.maximum(norm.data, 0.0) + norm.eliminate_zeros() + else: + norm = self.get_flat_values(h, flow) + if not self.allow_negative_expectation: + norm = np.maximum(norm, 0.0) + if not np.all(np.isfinite(norm)): + raise RuntimeError( + f"{len(norm)-sum(np.isfinite(norm))} NaN or Inf values encountered in nominal histogram for {name}!" + ) + + # variances are always stored dense (needed for sumw2 output assembly) + if variances is not None: + sumw2 = self.get_flat_variances(variances, flow) + elif self._issparse(h): + sumw2 = self._to_flat_dense(h) + else: + sumw2 = self.get_flat_variances(h, flow) + if not np.all(np.isfinite(sumw2)): raise RuntimeError( f"{len(sumw2)-sum(np.isfinite(sumw2))} NaN or Inf values encountered in variances for {name}!" ) - if not np.all(np.isfinite(norm)): - raise RuntimeError( - f"{len(norm)-sum(np.isfinite(norm))} NaN or Inf values encountered in nominal histogram for {name}!" - ) self.dict_norm[channel][name] = norm self.dict_sumw2[channel][name] = sumw2 @@ -197,6 +261,13 @@ def _check_hist_and_channel(self, h, channel): \nHistogram axes: {[a.edges for a in axes]} \nChannel axes: {[a.edges for a in channel_axes]} """) + elif self._issparse(h): + size_in = int(np.prod(h.shape)) + size_this = int(np.prod([len(a) for a in self.channels[channel]["axes"]])) + if size_in != size_this: + raise RuntimeError( + f"Total number of elements in sparse input different from channel size '{size_in}' != '{size_this}'" + ) else: shape_in = h.shape shape_this = tuple([len(a) for a in self.channels[channel]["axes"]]) @@ -214,10 +285,22 @@ def _compute_asym_syst( channel, symmetrize="average", add_to_data_covariance=False, + _sparse_info=None, **kargs, ): + """Compute symmetrized logk from asymmetric up/down variations. + + When _sparse_info is set to (nnz_indices, size), logkup/logkdown are value + arrays at those indices and internal book_logk calls use sparse tuples. + """ var_name_out = name + def _wrap(vals): + """Wrap values as sparse tuple if in sparse mode.""" + if _sparse_info is not None: + return (_sparse_info[0], vals, _sparse_info[1]) + return vals + if symmetrize == "conservative": # symmetrize by largest magnitude of up and down variations logkavg_proc = np.where( @@ -242,7 +325,9 @@ def _compute_asym_syst( var_name_out_diff = name + "SymDiff" # special case, book the extra systematic - self.book_logk_avg(logkdiffavg_proc, channel, process, var_name_out_diff) + self.book_logk_avg( + _wrap(logkdiffavg_proc), channel, process, var_name_out_diff + ) self.book_systematic( var_name_out_diff, add_to_data_covariance=add_to_data_covariance, @@ -259,7 +344,7 @@ def _compute_asym_syst( logkavg_proc = 0.5 * (logkup + logkdown) logkhalfdiff_proc = 0.5 * (logkup - logkdown) - self.book_logk_halfdiff(logkhalfdiff_proc, channel, process, name) + self.book_logk_halfdiff(_wrap(logkhalfdiff_proc), channel, process, name) logkup = None logkdown = None @@ -293,39 +378,77 @@ def add_norm_systematic( for p, u in zip(process, uncertainty): norm = self.dict_norm[channel][p] - if isinstance(u, (list, tuple, np.ndarray)): - if len(u) != 2: - raise RuntimeError( - f"lnN uncertainty can only be a scalar for a symmetric or a list of 2 elements for asymmetric lnN uncertainties, but got a list of {len(u)} elements" - ) - # asymmetric lnN uncertainty - syst_up = norm * u[0] - syst_down = norm * u[1] - logkup_proc = self.get_logk( - syst_up, norm, systematic_type=systematic_type - ) - logkdown_proc = -self.get_logk( - syst_down, norm, systematic_type=systematic_type - ) + if self._issparse(norm): + # Sparse norm path: compute logk at nonzero positions only + norm_vals = norm.data + nnz_idx = norm.indices + size = norm.shape[1] - logkavg_proc, var_name_out = self._compute_asym_syst( - logkup_proc, - logkdown_proc, - name, - process, - channel, - symmetrize=symmetrize, - add_to_data_covariance=add_to_data_covariance, - **kargs, + if isinstance(u, (list, tuple, np.ndarray)): + if len(u) != 2: + raise RuntimeError( + f"lnN uncertainty can only be a scalar for a symmetric or a list of 2 elements for asymmetric lnN uncertainties, but got a list of {len(u)} elements" + ) + logkup_proc = self._get_logk_sparse( + norm_vals * u[0], norm_vals, 1.0, systematic_type + ) + logkdown_proc = -self._get_logk_sparse( + norm_vals * u[1], norm_vals, 1.0, systematic_type + ) + logkavg_proc, var_name_out = self._compute_asym_syst( + logkup_proc, + logkdown_proc, + name, + process, + channel, + symmetrize=symmetrize, + add_to_data_covariance=add_to_data_covariance, + _sparse_info=(nnz_idx, size), + **kargs, + ) + else: + logkavg_proc = self._get_logk_sparse( + norm_vals * u, norm_vals, 1.0, systematic_type + ) + + self.book_logk_avg( + (nnz_idx, logkavg_proc, size), channel, p, var_name_out ) else: - syst = norm * u - logkavg_proc = self.get_logk( - syst, norm, systematic_type=systematic_type - ) + if isinstance(u, (list, tuple, np.ndarray)): + if len(u) != 2: + raise RuntimeError( + f"lnN uncertainty can only be a scalar for a symmetric or a list of 2 elements for asymmetric lnN uncertainties, but got a list of {len(u)} elements" + ) + # asymmetric lnN uncertainty + syst_up = norm * u[0] + syst_down = norm * u[1] - self.book_logk_avg(logkavg_proc, channel, p, var_name_out) + logkup_proc = self.get_logk( + syst_up, norm, systematic_type=systematic_type + ) + logkdown_proc = -self.get_logk( + syst_down, norm, systematic_type=systematic_type + ) + + logkavg_proc, var_name_out = self._compute_asym_syst( + logkup_proc, + logkdown_proc, + name, + process, + channel, + symmetrize=symmetrize, + add_to_data_covariance=add_to_data_covariance, + **kargs, + ) + else: + syst = norm * u + logkavg_proc = self.get_logk( + syst, norm, systematic_type=systematic_type + ) + + self.book_logk_avg(logkavg_proc, channel, p, var_name_out) self.book_systematic( var_name_out, @@ -334,6 +457,84 @@ def add_norm_systematic( **kargs, ) + def _add_systematic_sparse( + self, + h, + name, + process, + channel, + norm, + kfactor, + mirror, + symmetrize, + add_to_data_covariance, + as_difference, + **kargs, + ): + """Sparse path for add_systematic when norm is stored as scipy sparse CSR. + + Computes logk only at norm's nonzero positions, avoiding full-size dense + intermediate arrays. The logk result is a tuple (indices, values, size). + """ + systematic_type = "normal" if add_to_data_covariance else self.systematic_type + flow = self.channels[channel]["flow"] + nnz_idx = norm.indices + norm_vals = norm.data + size = norm.shape[1] + + var_name_out = name + + if isinstance(h, (list, tuple)): + self._check_hist_and_channel(h[0], channel) + self._check_hist_and_channel(h[1], channel) + + syst_up_vals = self._get_syst_at_norm_nnz(h[0], norm, flow) + syst_down_vals = self._get_syst_at_norm_nnz(h[1], norm, flow) + + if as_difference: + syst_up_vals = norm_vals + syst_up_vals + syst_down_vals = norm_vals + syst_down_vals + + logkup_vals = self._get_logk_sparse( + syst_up_vals, norm_vals, kfactor, systematic_type + ) + logkdown_vals = -self._get_logk_sparse( + syst_down_vals, norm_vals, kfactor, systematic_type + ) + + logkavg_vals, var_name_out = self._compute_asym_syst( + logkup_vals, + logkdown_vals, + name, + process, + channel, + symmetrize, + add_to_data_covariance, + _sparse_info=(nnz_idx, size), + **kargs, + ) + elif mirror: + self._check_hist_and_channel(h, channel) + + syst_vals = self._get_syst_at_norm_nnz(h, norm, flow) + + if as_difference: + syst_vals = norm_vals + syst_vals + + logkavg_vals = self._get_logk_sparse( + syst_vals, norm_vals, kfactor, systematic_type + ) + else: + raise RuntimeError( + "Only one histogram given but mirror=False, can not construct a variation" + ) + + logkavg_proc = (nnz_idx, logkavg_vals, size) + self.book_logk_avg(logkavg_proc, channel, process, var_name_out) + self.book_systematic( + var_name_out, add_to_data_covariance=add_to_data_covariance, **kargs + ) + def add_systematic( self, h, @@ -354,6 +555,22 @@ def add_systematic( norm = self.dict_norm[channel][process] + # Use sparse path when norm is stored as scipy sparse CSR + if self._issparse(norm): + return self._add_systematic_sparse( + h, + name, + process, + channel, + norm, + kfactor, + mirror, + symmetrize, + add_to_data_covariance, + as_difference, + **kargs, + ) + var_name_out = name systematic_type = "normal" if add_to_data_covariance else self.systematic_type @@ -461,6 +678,37 @@ def add_beta_variations( self.has_beta_variations = True + @staticmethod + def _sparse_values_at(sparse_csr, indices): + """Extract values from a flat CSR array at the given flat indices. + + Uses searchsorted on the sorted CSR indices to avoid any dense conversion. + Returns a dense 1D array of values at the requested positions. + """ + result = np.zeros(len(indices), dtype=sparse_csr.dtype) + positions = np.searchsorted(sparse_csr.indices, indices) + valid = (positions < len(sparse_csr.indices)) & ( + sparse_csr.indices[positions] == indices + ) + result[valid] = sparse_csr.data[positions[valid]] + return result + + def _get_syst_at_norm_nnz(self, h, norm_csr, flow): + """Extract flat systematic values only at norm's nonzero positions. + + h can be a histogram, scipy sparse, or dense array. + Returns a 1D dense array of length norm_csr.nnz. + """ + nnz_idx = norm_csr.indices + if hasattr(h, "values"): + values = h.values(flow=flow) + return values.flatten().astype(self.dtype)[nnz_idx] + elif self._issparse(h): + syst_csr = self._sparse_to_flat_csr(h, self.dtype) + return self._sparse_values_at(syst_csr, nnz_idx) + else: + return np.asarray(h).flatten().astype(self.dtype)[nnz_idx] + def get_logk(self, syst, norm, kfac=1.0, systematic_type=None): if not np.all(np.isfinite(syst)): raise RuntimeError( @@ -490,6 +738,34 @@ def get_logk(self, syst, norm, kfac=1.0, systematic_type=None): f"Invalid systematic_type {systematic_type}, valid choices are 'log_normal' or 'normal'" ) + def _get_logk_sparse(self, syst_vals, norm_vals, kfac, systematic_type): + """Compute logk values at norm's nonzero positions only. + + syst_vals and norm_vals are dense 1D arrays of equal length (nnz of norm). + Returns a 1D dense array of logk values at those positions. + """ + if not np.all(np.isfinite(syst_vals)): + raise RuntimeError( + f"{len(syst_vals)-sum(np.isfinite(syst_vals))} NaN or Inf values encountered in systematic!" + ) + + if systematic_type == "log_normal": + _logk = kfac * np.log(syst_vals / norm_vals) + _logk = np.where( + np.equal(np.sign(norm_vals * syst_vals), 1), + _logk, + self.logkepsilon, + ) + if self.clipSystVariations > 0.0: + _logk = np.clip(_logk, -self.clip, self.clip) + return _logk + elif systematic_type == "normal": + return kfac * (syst_vals - norm_vals) + else: + raise RuntimeError( + f"Invalid systematic_type {systematic_type}, valid choices are 'log_normal' or 'normal'" + ) + def book_logk_avg(self, *args): self.book_logk( self.dict_logkavg, @@ -513,6 +789,16 @@ def book_logk( process, syst_name, ): + if isinstance(logk, tuple): + # Sparse logk from _add_systematic_sparse: (indices, values, size) + nnz_idx, logk_vals, size = logk + nonzero_mask = logk_vals != 0.0 + indices = nnz_idx[nonzero_mask].reshape(-1, 1) + values = logk_vals[nonzero_mask] + dict_logk_indices[channel][process][syst_name] = indices + dict_logk[channel][process][syst_name] = values + return + norm = self.dict_norm[channel][process] # ensure that systematic tensor is sparse where normalization matrix is sparse logk = np.where(np.equal(norm, 0.0), 0.0, logk) @@ -584,7 +870,11 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= if proc not in self.dict_norm[chan]: continue - sumw[ibin : ibin + nbinschan, iproc] = self.dict_norm[chan][proc] + norm_proc = self.dict_norm[chan][proc] + if self._issparse(norm_proc): + sumw[ibin + norm_proc.indices, iproc] = norm_proc.data + else: + sumw[ibin : ibin + nbinschan, iproc] = norm_proc sumw2[ibin : ibin + nbinschan, iproc] = self.dict_sumw2[chan][proc] if not chan_info["masked"]: @@ -627,29 +917,56 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= continue norm_proc = dict_norm_chan[proc] - norm_indices = np.transpose(np.nonzero(norm_proc)) - norm_values = np.reshape(norm_proc[norm_indices], [-1]) + if self._issparse(norm_proc): + # Use scipy sparse structure directly + norm_indices = norm_proc.indices.reshape(-1, 1) + norm_values = norm_proc.data.copy() - nvals = len(norm_values) - oldlength = norm_sparse_size - norm_sparse_size = oldlength + nvals - norm_sparse_indices.resize([norm_sparse_size, 2]) - norm_sparse_values.resize([norm_sparse_size]) + nvals = len(norm_values) + oldlength = norm_sparse_size + norm_sparse_size = oldlength + nvals + norm_sparse_indices.resize([norm_sparse_size, 2]) + norm_sparse_values.resize([norm_sparse_size]) - out_indices = np.array([[ibin, iproc]]) + np.pad( - norm_indices, ((0, 0), (0, 1)), "constant" - ) - norm_indices = None + out_indices = np.array([[ibin, iproc]]) + np.pad( + norm_indices, ((0, 0), (0, 1)), "constant" + ) + norm_indices = None - norm_sparse_indices[oldlength:norm_sparse_size] = out_indices - out_indices = None + norm_sparse_indices[oldlength:norm_sparse_size] = out_indices + out_indices = None - norm_sparse_values[oldlength:norm_sparse_size] = norm_values - norm_values = None + norm_sparse_values[oldlength:norm_sparse_size] = norm_values + norm_values = None - norm_idx_map = ( - np.cumsum(np.not_equal(norm_proc, 0.0)) - 1 + oldlength - ) + # sorted CSR indices allow searchsorted in logk mapping below + norm_nnz_idx = norm_proc.indices + oldlength_norm = oldlength + else: + norm_indices = np.transpose(np.nonzero(norm_proc)) + norm_values = np.reshape(norm_proc[norm_indices], [-1]) + + nvals = len(norm_values) + oldlength = norm_sparse_size + norm_sparse_size = oldlength + nvals + norm_sparse_indices.resize([norm_sparse_size, 2]) + norm_sparse_values.resize([norm_sparse_size]) + + out_indices = np.array([[ibin, iproc]]) + np.pad( + norm_indices, ((0, 0), (0, 1)), "constant" + ) + norm_indices = None + + norm_sparse_indices[oldlength:norm_sparse_size] = out_indices + out_indices = None + + norm_sparse_values[oldlength:norm_sparse_size] = norm_values + norm_values = None + + norm_idx_map = ( + np.cumsum(np.not_equal(norm_proc, 0.0)) - 1 + oldlength + ) + norm_nnz_idx = None dict_logkavg_proc_indices = dict_logkavg_chan_indices[proc] dict_logkavg_proc_values = dict_logkavg_chan_values[proc] @@ -671,7 +988,15 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= # first dimension of output indices are NOT in the dense [nbin,nproc] space, but rather refer to indices in the norm_sparse vectors # second dimension is flattened in the [2,nsyst] space, where logkavg corresponds to [0,isyst] flattened to isyst # two dimensions are kept in separate arrays for now to reduce the number of copies needed later - out_normindices = norm_idx_map[logkavg_proc_indices] + if norm_nnz_idx is not None: + # scipy sparse norm: use searchsorted on sorted CSR indices + flat_positions = logkavg_proc_indices.flatten() + out_normindices = ( + np.searchsorted(norm_nnz_idx, flat_positions) + + oldlength_norm + ).reshape(-1, 1) + else: + out_normindices = norm_idx_map[logkavg_proc_indices] logkavg_proc_indices = None logk_sparse_normindices[oldlength:logk_sparse_size] = ( @@ -704,7 +1029,16 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= # first dimension of output indices are NOT in the dense [nbin,nproc] space, but rather refer to indices in the norm_sparse vectors # second dimension is flattened in the [2,nsyst] space, where logkhalfdiff corresponds to [1,isyst] flattened to nsyst + isyst # two dimensions are kept in separate arrays for now to reduce the number of copies needed later - out_normindices = norm_idx_map[logkhalfdiff_proc_indices] + if norm_nnz_idx is not None: + flat_positions = logkhalfdiff_proc_indices.flatten() + out_normindices = ( + np.searchsorted(norm_nnz_idx, flat_positions) + + oldlength_norm + ).reshape(-1, 1) + else: + out_normindices = norm_idx_map[ + logkhalfdiff_proc_indices + ] logkhalfdiff_proc_indices = None logk_sparse_normindices[oldlength:logk_sparse_size] = ( diff --git a/tests/test_sparse_fit.py b/tests/test_sparse_fit.py index c3649bd..db48295 100644 --- a/tests/test_sparse_fit.py +++ b/tests/test_sparse_fit.py @@ -61,7 +61,16 @@ def make_histograms(): ) -def make_test_tensor(outdir, sparse=False, as_difference=False): +def _to_scipy_sparse(h): + """Convert a hist histogram to a scipy sparse CSR array of its values.""" + import scipy.sparse + + return scipy.sparse.csr_array(h.values()) + + +def make_test_tensor( + outdir, sparse=False, as_difference=False, scipy_sparse_input=False +): """Create a simple tensor with signal + background + one shape systematic.""" hists = make_histograms() @@ -71,14 +80,34 @@ def make_test_tensor(outdir, sparse=False, as_difference=False): writer.add_channel(hists["data"].axes, "ch0") writer.add_data(hists["data"], "ch0") - writer.add_process(hists["sig"], "sig", "ch0", signal=True) - writer.add_process(hists["bkg"], "bkg", "ch0") + if scipy_sparse_input: + writer.add_process( + _to_scipy_sparse(hists["sig"]), + "sig", + "ch0", + signal=True, + variances=hists["sig"].variances(), + ) + writer.add_process( + _to_scipy_sparse(hists["bkg"]), + "bkg", + "ch0", + variances=hists["bkg"].variances(), + ) + else: + writer.add_process(hists["sig"], "sig", "ch0", signal=True) + writer.add_process(hists["bkg"], "bkg", "ch0") writer.add_norm_systematic("bkg_norm", "bkg", "ch0", 1.05) if as_difference: + syst_up = hists["syst_up_diff"] + syst_dn = hists["syst_dn_diff"] + if scipy_sparse_input: + syst_up = _to_scipy_sparse(syst_up) + syst_dn = _to_scipy_sparse(syst_dn) writer.add_systematic( - [hists["syst_up_diff"], hists["syst_dn_diff"]], + [syst_up, syst_dn], "bkg_shape", "bkg", "ch0", @@ -86,8 +115,13 @@ def make_test_tensor(outdir, sparse=False, as_difference=False): as_difference=True, ) else: + syst_up = hists["syst_up"] + syst_dn = hists["syst_dn"] + if scipy_sparse_input: + syst_up = _to_scipy_sparse(syst_up) + syst_dn = _to_scipy_sparse(syst_dn) writer.add_systematic( - [hists["syst_up"], hists["syst_dn"]], + [syst_up, syst_dn], "bkg_shape", "bkg", "ch0", @@ -97,6 +131,8 @@ def make_test_tensor(outdir, sparse=False, as_difference=False): suffix = "sparse" if sparse else "dense" if as_difference: suffix += "_diff" + if scipy_sparse_input: + suffix += "_scipy" name = f"test_{suffix}" writer.write(outfolder=outdir, outfilename=name) return os.path.join(outdir, f"{name}.hdf5") @@ -116,6 +152,7 @@ def make_options(**kwargs): prefitUnconstrainedNuisanceUncertainty=0.0, freezeParameters=[], setConstraintMinimum=[], + unblind=[], ) defaults.update(kwargs) return SimpleNamespace(**defaults) @@ -206,17 +243,28 @@ def main(): tf.config.experimental.enable_op_determinism() with tempfile.TemporaryDirectory() as tmpdir: - # create tensors in all four modes - dense_file = make_test_tensor(tmpdir, sparse=False) - sparse_file = make_test_tensor(tmpdir, sparse=True) - dense_diff_file = make_test_tensor(tmpdir, sparse=False, as_difference=True) - sparse_diff_file = make_test_tensor(tmpdir, sparse=True, as_difference=True) - + # create tensors in all modes configs = [ - ("Dense", dense_file), - ("Sparse", sparse_file), - ("Dense (as_difference)", dense_diff_file), - ("Sparse (as_difference)", sparse_diff_file), + ("Dense", make_test_tensor(tmpdir, sparse=False)), + ("Sparse", make_test_tensor(tmpdir, sparse=True)), + ( + "Dense (as_difference)", + make_test_tensor(tmpdir, sparse=False, as_difference=True), + ), + ( + "Sparse (as_difference)", + make_test_tensor(tmpdir, sparse=True, as_difference=True), + ), + ( + "Sparse (scipy)", + make_test_tensor(tmpdir, sparse=True, scipy_sparse_input=True), + ), + ( + "Sparse (scipy+diff)", + make_test_tensor( + tmpdir, sparse=True, as_difference=True, scipy_sparse_input=True + ), + ), ] results = {} @@ -236,6 +284,8 @@ def main(): ("Dense", "Sparse"), ("Dense", "Dense (as_difference)"), ("Dense", "Sparse (as_difference)"), + ("Dense", "Sparse (scipy)"), + ("Dense", "Sparse (scipy+diff)"), ] all_ok = True From 91ce1b3739c7f3b08ef47645927602fa0e97fc96 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 02:11:43 +0200 Subject: [PATCH 04/18] Add multi-systematic dispatch in add_systematic and use wums.SparseHist add_systematic now detects extra axes in the input histogram beyond the channel axes (or via an explicit syst_axes argument) and books one systematic per bin combination on those extra axes, with auto-generated names from the bin labels. Works for hist inputs as well as for SparseHist inputs from wums, in both dense and sparse TensorWriter modes. The local SparseHist implementation has been moved to wums.sparse_hist and is re-exported here for convenience. SparseHist now always uses the with-flow layout internally, and the writer extracts either the with-flow or no-flow representation depending on the channel's flow setting. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/tensorwriter.py | 138 +++++++++- tests/test_multi_systematic.py | 448 +++++++++++++++++++++++++++++++++ 2 files changed, 579 insertions(+), 7 deletions(-) create mode 100644 tests/test_multi_systematic.py diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index ab3d552..5cf26e4 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -4,6 +4,7 @@ import h5py import numpy as np +from wums.sparse_hist import SparseHist # noqa: F401 re-exported for convenience from rabbit import common, h5pyutils_write @@ -77,11 +78,19 @@ def _issparse(h): return hasattr(h, "toarray") and hasattr(h, "tocoo") @staticmethod - def _sparse_to_flat_csr(h, dtype): + def _sparse_to_flat_csr(h, dtype, flow=False): """Flatten a scipy sparse array/matrix to CSR with shape (1, prod(shape)). + For SparseHist inputs, forwards ``flow`` to ``h.to_flat_csr`` so the + wrapper can convert from its internal with-flow layout to the requested + layout. For raw scipy sparse inputs, the row-major flatten of ``h.shape`` + is used directly (the user is responsible for matching the channel layout). + The returned CSR array has sorted indices suitable for searchsorted lookups. """ + if hasattr(h, "to_flat_csr"): + return h.to_flat_csr(dtype, flow=flow) + import scipy.sparse size = int(np.prod(h.shape)) @@ -102,8 +111,13 @@ def _sparse_to_flat_csr(h, dtype): (sorted_data, sorted_indices, indptr), shape=(1, size) ) - def _to_flat_dense(self, h): - """Convert any array-like (including scipy sparse) to a flat dense numpy array.""" + def _to_flat_dense(self, h, flow=False): + """Convert any array-like (including scipy sparse) to a flat dense numpy array. + + For SparseHist inputs, ``flow`` selects the with-flow or no-flow layout. + """ + if isinstance(h, SparseHist): + return np.asarray(h.toarray(flow=flow)).flatten().astype(self.dtype) if self._issparse(h): return np.asarray(h.toarray()).flatten().astype(self.dtype) return np.asarray(h).flatten().astype(self.dtype) @@ -111,6 +125,8 @@ def _to_flat_dense(self, h): def get_flat_values(self, h, flow=False): if hasattr(h, "values"): values = h.values(flow=flow) + elif isinstance(h, SparseHist): + values = h.toarray(flow=flow) elif self._issparse(h): values = np.asarray(h.toarray()) else: @@ -120,6 +136,8 @@ def get_flat_values(self, h, flow=False): def get_flat_variances(self, h, flow=False): if hasattr(h, "variances"): variances = h.variances(flow=flow) + elif isinstance(h, SparseHist): + variances = h.toarray(flow=flow) elif self._issparse(h): variances = np.asarray(h.toarray()) else: @@ -179,7 +197,7 @@ def add_process(self, h, name, channel="ch0", signal=False, variances=None): if self.sparse and self._issparse(h): # Store as flat CSR, avoiding full dense conversion - norm = self._sparse_to_flat_csr(h, self.dtype) + norm = self._sparse_to_flat_csr(h, self.dtype, flow=flow) if not np.all(np.isfinite(norm.data)): raise RuntimeError( f"NaN or Inf values encountered in nominal histogram for {name}!" @@ -203,7 +221,7 @@ def add_process(self, h, name, channel="ch0", signal=False, variances=None): if variances is not None: sumw2 = self.get_flat_variances(variances, flow) elif self._issparse(h): - sumw2 = self._to_flat_dense(h) + sumw2 = self._to_flat_dense(h, flow=flow) else: sumw2 = self.get_flat_variances(h, flow) @@ -535,6 +553,85 @@ def _add_systematic_sparse( var_name_out, add_to_data_covariance=add_to_data_covariance, **kargs ) + @staticmethod + def _bin_label(ax, idx): + """Return a string label for a hist axis bin, preferring string values.""" + try: + v = ax.value(idx) + if isinstance(v, (str, bytes)): + return v.decode() if isinstance(v, bytes) else v + except Exception: + pass + return str(idx) + + def _get_systematic_slices(self, h, name, channel, syst_axes=None): + """Detect extra axes in h beyond the channel and return list of (sub_name, sub_h) slices. + + Returns None if there are no extra axes (i.e. single-systematic case). + + h may be a single histogram or a list/tuple of two (up/down) histograms. + Both elements of a pair must share the same extra-axis structure. + + syst_axes: + - None (default): auto-detect any axes in h not present in the channel + - list of axis names: use exactly these axes as systematic axes + - empty list: disable detection entirely + """ + if syst_axes is not None and len(syst_axes) == 0: + return None + + if isinstance(h, (list, tuple)): + h_ref = h[0] + is_pair = True + else: + h_ref = h + is_pair = False + + # only hist-like objects (with .axes) support multi-systematic + if not hasattr(h_ref, "axes"): + return None + + h_axis_names = [a.name for a in h_ref.axes] + channel_axis_names = [a.name for a in self.channels[channel]["axes"]] + + if syst_axes is None: + extra_axis_names = [n for n in h_axis_names if n not in channel_axis_names] + else: + for n in syst_axes: + if n not in h_axis_names: + raise RuntimeError( + f"Requested systematic axis '{n}' not found in histogram axes {h_axis_names}" + ) + if n in channel_axis_names: + raise RuntimeError( + f"Systematic axis '{n}' overlaps with channel axes {channel_axis_names}" + ) + extra_axis_names = list(syst_axes) + + if not extra_axis_names: + return None + + extra_axes = [h_ref.axes[n] for n in extra_axis_names] + extra_sizes = [len(a) for a in extra_axes] + + import itertools + + slices = [] + for idx_tuple in itertools.product(*[range(s) for s in extra_sizes]): + labels = [self._bin_label(ax, i) for ax, i in zip(extra_axes, idx_tuple)] + sub_name = "_".join([name, *labels]) + + slice_dict = {n: i for n, i in zip(extra_axis_names, idx_tuple)} + + if is_pair: + sub_h = [h[0][slice_dict], h[1][slice_dict]] + else: + sub_h = h[slice_dict] + + slices.append((sub_name, sub_h)) + + return slices + def add_systematic( self, h, @@ -546,13 +643,39 @@ def add_systematic( symmetrize="average", add_to_data_covariance=False, as_difference=False, + syst_axes=None, **kargs, ): """ h: either a single histogram with the systematic variation if mirror=True or a list of two histograms with the up and down variation as_difference: if True, interpret the histogram values as the difference with respect to the nominal (i.e. the absolute variation is norm + h) + syst_axes: optional list of axis names in h that represent independent systematics. + If None (default) and h is a hist-like object with axes beyond the channel, + the extra axes are auto-detected and each bin combination becomes a separate + systematic with name "{name}_{label_0}_{label_1}_...". Pass an empty list + to disable auto-detection. """ + # multi-systematic dispatch: if h has extra axes beyond the channel, + # iterate over those and book each combination as an independent systematic + slices = self._get_systematic_slices(h, name, channel, syst_axes) + if slices is not None: + for sub_name, sub_h in slices: + self.add_systematic( + sub_h, + sub_name, + process, + channel, + kfactor=kfactor, + mirror=mirror, + symmetrize=symmetrize, + add_to_data_covariance=add_to_data_covariance, + as_difference=as_difference, + syst_axes=[], + **kargs, + ) + return + norm = self.dict_norm[channel][process] # Use sparse path when norm is stored as scipy sparse CSR @@ -696,7 +819,8 @@ def _sparse_values_at(sparse_csr, indices): def _get_syst_at_norm_nnz(self, h, norm_csr, flow): """Extract flat systematic values only at norm's nonzero positions. - h can be a histogram, scipy sparse, or dense array. + h can be a histogram, scipy sparse, SparseHist, or dense array. + ``flow`` controls the flat layout (must match the channel/norm layout). Returns a 1D dense array of length norm_csr.nnz. """ nnz_idx = norm_csr.indices @@ -704,7 +828,7 @@ def _get_syst_at_norm_nnz(self, h, norm_csr, flow): values = h.values(flow=flow) return values.flatten().astype(self.dtype)[nnz_idx] elif self._issparse(h): - syst_csr = self._sparse_to_flat_csr(h, self.dtype) + syst_csr = self._sparse_to_flat_csr(h, self.dtype, flow=flow) return self._sparse_values_at(syst_csr, nnz_idx) else: return np.asarray(h).flatten().astype(self.dtype)[nnz_idx] diff --git a/tests/test_multi_systematic.py b/tests/test_multi_systematic.py new file mode 100644 index 0000000..986dd5e --- /dev/null +++ b/tests/test_multi_systematic.py @@ -0,0 +1,448 @@ +""" +Test that add_systematic correctly handles a histogram with extra axes +representing multiple independent systematics. The result should be identical +to booking each systematic individually. +""" + +import os +import tempfile + +import h5py +import hist +import numpy as np + +from rabbit import tensorwriter + + +def make_base_histograms(nsyst): + """Build a nominal background plus per-syst variation histograms.""" + np.random.seed(42) + + ax_x = hist.axis.Regular(20, -5, 5, name="x") + + h_bkg = hist.Hist(ax_x, storage=hist.storage.Weight()) + h_bkg.fill(np.random.uniform(-5, 5, 5000), weight=np.ones(5000)) + + bin_centers = ax_x.centers - ax_x.centers[0] + base_weights = 0.01 * bin_centers - 0.05 + + # build a different variation per systematic + variations_up = [] + variations_dn = [] + for i in range(nsyst): + scale = 1.0 + 0.5 * i + h_up = h_bkg.copy() + h_dn = h_bkg.copy() + h_up.values()[...] = h_bkg.values() * (1 + scale * base_weights) + h_dn.values()[...] = h_bkg.values() * (1 - scale * base_weights) + variations_up.append(h_up) + variations_dn.append(h_dn) + + return h_bkg, variations_up, variations_dn + + +def make_writer_with_individual_systs( + h_bkg, variations_up, variations_dn, name_prefix, sparse=False +): + """Reference: book each systematic separately via the existing API.""" + writer = tensorwriter.TensorWriter(sparse=sparse) + writer.add_channel(h_bkg.axes, "ch0") + writer.add_data(h_bkg, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + + for i, (h_up, h_dn) in enumerate(zip(variations_up, variations_dn)): + writer.add_systematic( + [h_up, h_dn], + f"{name_prefix}_{i}", + "bkg", + "ch0", + symmetrize="average", + ) + return writer + + +def make_writer_with_multi_axis( + h_bkg, variations_up, variations_dn, name_prefix, sparse=False +): + """New path: pack the variations into a single histogram with an extra 'syst' axis.""" + nsyst = len(variations_up) + ax_x = h_bkg.axes[0] + ax_syst = hist.axis.Integer(0, nsyst, underflow=False, overflow=False, name="syst") + + h_up_combined = hist.Hist(ax_x, ax_syst, storage=hist.storage.Weight()) + h_dn_combined = hist.Hist(ax_x, ax_syst, storage=hist.storage.Weight()) + for i in range(nsyst): + h_up_combined.values()[:, i] = variations_up[i].values() + h_dn_combined.values()[:, i] = variations_dn[i].values() + + writer = tensorwriter.TensorWriter(sparse=sparse) + writer.add_channel(h_bkg.axes, "ch0") + writer.add_data(h_bkg, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + writer.add_systematic( + [h_up_combined, h_dn_combined], + name_prefix, + "bkg", + "ch0", + symmetrize="average", + ) + return writer + + +def _embed_no_flow_into_with_flow(values_no_flow, axes): + """Embed a no-flow dense array into a with-flow dense array of the given axes. + + Flow bins are filled with zeros. Used to construct SparseHist data when the + user only has values for the regular bins. + """ + full_shape = tuple(int(ax.extent) for ax in axes) + full = np.zeros(full_shape, dtype=values_no_flow.dtype) + slices = tuple( + slice( + tensorwriter.SparseHist._underflow_offset(ax), + tensorwriter.SparseHist._underflow_offset(ax) + len(ax), + ) + for ax in axes + ) + full[slices] = values_no_flow + return full + + +def make_writer_with_sparsehist_multi_axis( + h_bkg, variations_up, variations_dn, name_prefix +): + """Sparse mode + with-flow SparseHist input on a no-flow channel. + + Exercises the conversion from SparseHist's internal with-flow layout to the + no-flow CSR layout used by the channel. + """ + import scipy.sparse + + nsyst = len(variations_up) + ax_x = h_bkg.axes[0] + ax_syst = hist.axis.Integer(0, nsyst, underflow=False, overflow=False, name="syst") + + # Build no-flow (x_size, nsyst) data, then embed into with-flow shape + up_no_flow = np.zeros((len(ax_x), nsyst)) + dn_no_flow = np.zeros((len(ax_x), nsyst)) + for i in range(nsyst): + up_no_flow[:, i] = variations_up[i].values() + dn_no_flow[:, i] = variations_dn[i].values() + + up_full = _embed_no_flow_into_with_flow(up_no_flow, [ax_x, ax_syst]) + dn_full = _embed_no_flow_into_with_flow(dn_no_flow, [ax_x, ax_syst]) + bkg_full = _embed_no_flow_into_with_flow(h_bkg.values(), [ax_x]) + + sh_up = tensorwriter.SparseHist(scipy.sparse.csr_array(up_full), [ax_x, ax_syst]) + sh_dn = tensorwriter.SparseHist(scipy.sparse.csr_array(dn_full), [ax_x, ax_syst]) + sh_bkg = tensorwriter.SparseHist( + scipy.sparse.csr_array(bkg_full.reshape(1, -1)), [ax_x] + ) + + writer = tensorwriter.TensorWriter(sparse=True) + writer.add_channel(h_bkg.axes, "ch0") # flow=False + writer.add_data(h_bkg, "ch0") + writer.add_process(sh_bkg, "bkg", "ch0", signal=True, variances=h_bkg.variances()) + writer.add_systematic( + [sh_up, sh_dn], + name_prefix, + "bkg", + "ch0", + symmetrize="average", + ) + return writer + + +def make_writer_masked_flow_individual( + h_bkg, variations_up, variations_dn, name_prefix +): + """Reference: masked channel with flow=True, hist process and individual hist systematics.""" + ax_x = h_bkg.axes[0] + writer = tensorwriter.TensorWriter(sparse=True) + + # Regular non-masked data channel (needed because every TensorWriter must have data) + writer.add_channel(h_bkg.axes, "ch0") + writer.add_data(h_bkg, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + + # Masked channel with flow=True + writer.add_channel([ax_x], "masked0", masked=True, flow=True) + writer.add_process(h_bkg, "bkg", "masked0", signal=True) + + for i, (h_up, h_dn) in enumerate(zip(variations_up, variations_dn)): + writer.add_systematic( + [h_up, h_dn], + f"{name_prefix}_{i}", + "bkg", + "masked0", + symmetrize="average", + ) + + return writer + + +def make_writer_masked_flow_sparsehist( + h_bkg, variations_up, variations_dn, name_prefix +): + """SparseHist (always with-flow internally) on a masked channel with flow=True.""" + import scipy.sparse + + nsyst = len(variations_up) + ax_x = h_bkg.axes[0] # Regular axis with under/overflow + ax_syst = hist.axis.Integer(0, nsyst, underflow=False, overflow=False, name="syst") + + # Build no-flow data and embed into with-flow shape via the helper. + up_no_flow = np.zeros((len(ax_x), nsyst)) + dn_no_flow = np.zeros((len(ax_x), nsyst)) + for i in range(nsyst): + up_no_flow[:, i] = variations_up[i].values() + dn_no_flow[:, i] = variations_dn[i].values() + + up_full = _embed_no_flow_into_with_flow(up_no_flow, [ax_x, ax_syst]) + dn_full = _embed_no_flow_into_with_flow(dn_no_flow, [ax_x, ax_syst]) + bkg_full = _embed_no_flow_into_with_flow(h_bkg.values(), [ax_x]) + + sh_up = tensorwriter.SparseHist(scipy.sparse.csr_array(up_full), [ax_x, ax_syst]) + sh_dn = tensorwriter.SparseHist(scipy.sparse.csr_array(dn_full), [ax_x, ax_syst]) + sh_bkg = tensorwriter.SparseHist( + scipy.sparse.csr_array(bkg_full.reshape(1, -1)), [ax_x] + ) + + writer = tensorwriter.TensorWriter(sparse=True) + writer.add_channel(h_bkg.axes, "ch0") + writer.add_data(h_bkg, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + + writer.add_channel([ax_x], "masked0", masked=True, flow=True) + writer.add_process( + sh_bkg, "bkg", "masked0", signal=True, variances=np.zeros(int(ax_x.extent)) + ) + writer.add_systematic( + [sh_up, sh_dn], + name_prefix, + "bkg", + "masked0", + symmetrize="average", + ) + return writer + + +def make_writer_with_str_category(h_bkg, variations_up, variations_dn): + """Variant using a StrCategory axis to verify name labels come from bin values.""" + nsyst = len(variations_up) + ax_x = h_bkg.axes[0] + labels = [f"var{i}" for i in range(nsyst)] + ax_syst = hist.axis.StrCategory(labels, name="kind") + + h_up_combined = hist.Hist(ax_x, ax_syst, storage=hist.storage.Weight()) + h_dn_combined = hist.Hist(ax_x, ax_syst, storage=hist.storage.Weight()) + for i in range(nsyst): + h_up_combined.values()[:, i] = variations_up[i].values() + h_dn_combined.values()[:, i] = variations_dn[i].values() + + writer = tensorwriter.TensorWriter() + writer.add_channel(h_bkg.axes, "ch0") + writer.add_data(h_bkg, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + writer.add_systematic( + [h_up_combined, h_dn_combined], + "shape", + "bkg", + "ch0", + symmetrize="average", + ) + return writer, labels + + +def read_hdf5_arrays(path): + """Load systs, dense norm, and dense logk from a written tensor file. + + Materializes dense arrays from the sparse storage format if needed so that + sparse-mode and dense-mode outputs can be compared. + """ + with h5py.File(path, "r") as f: + systs = [s.decode() for s in f["hsysts"][...]] + nproc = len(f["hprocs"][...]) + nsyst = len(systs) + + if "hnorm" in f: + hnorm = np.asarray(f["hnorm"]).reshape( + tuple(f["hnorm"].attrs["original_shape"]) + ) + hlogk = np.asarray(f["hlogk"]).reshape( + tuple(f["hlogk"].attrs["original_shape"]) + ) + return {"systs": systs, "hnorm": hnorm, "hlogk": hlogk} + + # Sparse format: reconstruct dense (nbinsfull, nproc) and (nbinsfull, nproc, nsyst) + # writeFlatInChunks stores the original shape as an attribute + norm_idx_dset = f["hnorm_sparse"]["indices"] + norm_indices = np.asarray(norm_idx_dset).reshape( + tuple(norm_idx_dset.attrs["original_shape"]) + ) + norm_values = np.asarray(f["hnorm_sparse"]["values"]) + nbinsfull, _ = f["hnorm_sparse"].attrs["dense_shape"] + hnorm = np.zeros((int(nbinsfull), int(nproc))) + hnorm[norm_indices[:, 0], norm_indices[:, 1]] = norm_values + + logk_idx_dset = f["hlogk_sparse"]["indices"] + logk_indices = np.asarray(logk_idx_dset).reshape( + tuple(logk_idx_dset.attrs["original_shape"]) + ) + logk_values = np.asarray(f["hlogk_sparse"]["values"]) + # logk_indices[:, 0] indexes into norm_sparse; [:, 1] is syst (or syst*2 for asym) + logk_nsyst_dim = f["hlogk_sparse"].attrs["dense_shape"][1] + symmetric = logk_nsyst_dim == nsyst + if symmetric: + hlogk = np.zeros((int(nbinsfull), int(nproc), int(nsyst))) + else: + hlogk = np.zeros((int(nbinsfull), int(nproc), 2, int(nsyst))) + + for k in range(len(logk_indices)): + ni = logk_indices[k, 0] # index into norm_sparse + si = logk_indices[k, 1] # syst dim index + bin_idx, proc_idx = norm_indices[ni] + if symmetric: + hlogk[bin_idx, proc_idx, si] = logk_values[k] + else: + if si < nsyst: + hlogk[bin_idx, proc_idx, 0, si] = logk_values[k] + else: + hlogk[bin_idx, proc_idx, 1, si - nsyst] = logk_values[k] + + return {"systs": systs, "hnorm": hnorm, "hlogk": hlogk} + + +def main(): + nsyst = 4 + + with tempfile.TemporaryDirectory() as tmpdir: + h_bkg, var_up, var_dn = make_base_histograms(nsyst) + + # Reference path: individual systematics + ref_writer = make_writer_with_individual_systs(h_bkg, var_up, var_dn, "shape") + ref_path = os.path.join(tmpdir, "ref.hdf5") + ref_writer.write(outfolder=tmpdir, outfilename="ref") + + # New path: single histogram with extra axis + multi_writer = make_writer_with_multi_axis(h_bkg, var_up, var_dn, "shape") + multi_writer.write(outfolder=tmpdir, outfilename="multi") + multi_path = os.path.join(tmpdir, "multi.hdf5") + + ref = read_hdf5_arrays(ref_path) + multi = read_hdf5_arrays(multi_path) + + # Auto-generated names for an Integer axis named "syst" should be "shape_0", "shape_1", ... + expected_names = [f"shape_{i}" for i in range(nsyst)] + + print("Reference systs: ", ref["systs"]) + print("Multi-axis systs: ", multi["systs"]) + print("Expected names: ", expected_names) + + assert ( + ref["systs"] == expected_names + ), f"Reference systs {ref['systs']} != expected {expected_names}" + assert ( + multi["systs"] == expected_names + ), f"Multi-axis systs {multi['systs']} != expected {expected_names}" + + assert np.allclose(ref["hnorm"], multi["hnorm"]), "norm mismatch" + assert np.allclose(ref["hlogk"], multi["hlogk"]), "logk mismatch" + + print("PASS: multi-axis Integer matches individual systematics") + + # StrCategory axis: names should come from the string bin labels + cat_writer, cat_labels = make_writer_with_str_category(h_bkg, var_up, var_dn) + cat_writer.write(outfolder=tmpdir, outfilename="cat") + cat_path = os.path.join(tmpdir, "cat.hdf5") + cat = read_hdf5_arrays(cat_path) + + expected_cat_names = sorted([f"shape_{lbl}" for lbl in cat_labels]) + print("StrCategory systs: ", cat["systs"]) + assert ( + cat["systs"] == expected_cat_names + ), f"Category systs {cat['systs']} != expected {expected_cat_names}" + # Same logk values, just different names + assert np.allclose(ref["hnorm"], cat["hnorm"]), "norm mismatch (cat)" + assert np.allclose(ref["hlogk"], cat["hlogk"]), "logk mismatch (cat)" + + print("PASS: multi-axis StrCategory matches individual systematics") + + # Sparse mode: hist with extra axis + sparse_ref_writer = make_writer_with_individual_systs( + h_bkg, var_up, var_dn, "shape", sparse=True + ) + sparse_ref_writer.write(outfolder=tmpdir, outfilename="sparse_ref") + sparse_ref = read_hdf5_arrays(os.path.join(tmpdir, "sparse_ref.hdf5")) + + sparse_multi_writer = make_writer_with_multi_axis( + h_bkg, var_up, var_dn, "shape", sparse=True + ) + sparse_multi_writer.write(outfolder=tmpdir, outfilename="sparse_multi") + sparse_multi = read_hdf5_arrays(os.path.join(tmpdir, "sparse_multi.hdf5")) + + print("Sparse-mode multi-axis systs:", sparse_multi["systs"]) + assert sparse_multi["systs"] == expected_names + assert np.allclose( + sparse_ref["hnorm"], sparse_multi["hnorm"] + ), "norm mismatch (sparse multi vs sparse ref)" + assert np.allclose( + sparse_ref["hlogk"], sparse_multi["hlogk"] + ), "logk mismatch (sparse multi vs sparse ref)" + # Sparse and dense paths should agree + assert np.allclose( + ref["hnorm"], sparse_ref["hnorm"] + ), "norm mismatch (sparse ref vs dense ref)" + assert np.allclose( + ref["hlogk"], sparse_ref["hlogk"] + ), "logk mismatch (sparse ref vs dense ref)" + print("PASS: sparse mode multi-axis matches sparse mode individual") + + # SparseHist input + sparse mode + multi-axis + sh_writer = make_writer_with_sparsehist_multi_axis( + h_bkg, var_up, var_dn, "shape" + ) + sh_writer.write(outfolder=tmpdir, outfilename="sparsehist_multi") + sh = read_hdf5_arrays(os.path.join(tmpdir, "sparsehist_multi.hdf5")) + + print("SparseHist multi-axis systs:", sh["systs"]) + assert sh["systs"] == expected_names + assert np.allclose( + sparse_ref["hnorm"], sh["hnorm"] + ), "norm mismatch (SparseHist multi vs sparse ref)" + assert np.allclose( + sparse_ref["hlogk"], sh["hlogk"] + ), "logk mismatch (SparseHist multi vs sparse ref)" + print("PASS: SparseHist multi-axis matches sparse mode individual") + + # Flow test: SparseHist with flow=True on a masked channel with flow=True + masked_ref_writer = make_writer_masked_flow_individual( + h_bkg, var_up, var_dn, "shape" + ) + masked_ref_writer.write(outfolder=tmpdir, outfilename="masked_ref") + masked_ref = read_hdf5_arrays(os.path.join(tmpdir, "masked_ref.hdf5")) + + masked_sh_writer = make_writer_masked_flow_sparsehist( + h_bkg, var_up, var_dn, "shape" + ) + masked_sh_writer.write(outfolder=tmpdir, outfilename="masked_sh") + masked_sh = read_hdf5_arrays(os.path.join(tmpdir, "masked_sh.hdf5")) + + print("Masked-flow individual systs:", masked_ref["systs"]) + print("Masked-flow SparseHist systs:", masked_sh["systs"]) + assert masked_ref["systs"] == expected_names + assert masked_sh["systs"] == expected_names + assert np.allclose( + masked_ref["hnorm"], masked_sh["hnorm"] + ), "norm mismatch (masked SparseHist flow vs masked individual)" + assert np.allclose( + masked_ref["hlogk"], masked_sh["hlogk"] + ), "logk mismatch (masked SparseHist flow vs masked individual)" + print("PASS: SparseHist on masked flow=True channel matches individual") + + print() + print("ALL CHECKS PASSED") + + +if __name__ == "__main__": + main() From dab00bca2153a388c8cb82bc1ee9fb1fb6cbac05 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 03:53:03 +0200 Subject: [PATCH 05/18] Add external likelihood term (gradient + hessian) support TensorWriter.add_external_likelihood_term accepts a 1D hist for the gradient and a 2D hist (or wums.SparseHist) for the hessian, both indexed by hist.axis.StrCategory axes whose bin labels identify the parameters. Both grad and hess (when provided together) must use the same parameter list in the same order; the matrix is indexed by a single parameter list. Multiple terms can be added with distinct names. Sparse hessians via SparseHist preserve sparsity through the writer and the fit. The terms are serialized under an external_terms HDF5 group, loaded back in FitInputData, and resolved against the full fit parameter list (POIs + systs) at Fitter init. Fitter._compute_external_nll adds an additive g^T x_sub + 0.5 x_sub^T H x_sub contribution to the NLL, fully differentiable through TF autodiff so all existing loss_val_grad and hessian methods pick it up automatically. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 71 ++++++++++ rabbit/inputdata.py | 53 +++++++ rabbit/tensorwriter.py | 189 +++++++++++++++++++++++++ tests/test_external_term.py | 272 ++++++++++++++++++++++++++++++++++++ 4 files changed, 585 insertions(+) create mode 100644 tests/test_external_term.py diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 5e3c29f..e40645d 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -318,6 +318,52 @@ def init_fit_parms( # one common regularization strength parameter self.tau = tf.Variable(1.0, trainable=True, name="tau", dtype=tf.float64) + # External likelihood terms (additive g^T x + 0.5 x^T H x contributions + # to the NLL). Resolve parameter name strings against the full fit + # parameter list (POIs + systs). + self.external_terms = [] + parms_str = self.parms.astype(str) + for term in self.indata.external_terms: + params = np.asarray(term["params"]).astype(str) + indices = np.empty(len(params), dtype=np.int64) + for i, p in enumerate(params): + matches = np.where(parms_str == p)[0] + if len(matches) != 1: + raise RuntimeError( + f"External likelihood term '{term['name']}' parameter " + f"'{p}' matched {len(matches)} entries in fit parameters" + ) + indices[i] = matches[0] + tf_indices = tf.constant(indices, dtype=tf.int64) + + tf_grad = ( + tf.constant(term["grad_values"], dtype=self.indata.dtype) + if term["grad_values"] is not None + else None + ) + + tf_hess_dense = None + tf_hess_sparse = None + if term["hess_dense"] is not None: + tf_hess_dense = tf.constant(term["hess_dense"], dtype=self.indata.dtype) + elif term["hess_sparse"] is not None: + rows, cols, vals = term["hess_sparse"] + tf_hess_sparse = ( + tf.constant(rows, dtype=tf.int64), + tf.constant(cols, dtype=tf.int64), + tf.constant(vals, dtype=self.indata.dtype), + ) + + self.external_terms.append( + { + "name": term["name"], + "indices": tf_indices, + "grad": tf_grad, + "hess_dense": tf_hess_dense, + "hess_sparse": tf_hess_sparse, + } + ) + # constraint minima for nuisance parameters self.theta0 = tf.Variable( self.theta0default, @@ -2087,6 +2133,27 @@ def _compute_nll_components(self, profile=True, full_nll=False): return ln, lc, lbeta, lpenalty, beta + def _compute_external_nll(self): + """Sum of external likelihood term contributions: sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub).""" + if not self.external_terms: + return None + total = tf.zeros([], dtype=self.indata.dtype) + for term in self.external_terms: + x_sub = tf.gather(self.x, term["indices"]) + if term["grad"] is not None: + total = total + tf.reduce_sum(term["grad"] * x_sub) + if term["hess_dense"] is not None: + # 0.5 * x_sub^T H x_sub + total = total + 0.5 * tf.reduce_sum( + x_sub * tf.linalg.matvec(term["hess_dense"], x_sub) + ) + elif term["hess_sparse"] is not None: + rows, cols, vals = term["hess_sparse"] + total = total + 0.5 * tf.reduce_sum( + vals * tf.gather(x_sub, rows) * tf.gather(x_sub, cols) + ) + return total + def _compute_nll(self, profile=True, full_nll=False): ln, lc, lbeta, lpenalty, beta = self._compute_nll_components( profile=profile, full_nll=full_nll @@ -2098,6 +2165,10 @@ def _compute_nll(self, profile=True, full_nll=False): if lpenalty is not None: l = l + lpenalty + + lext = self._compute_external_nll() + if lext is not None: + l = l + lext return l def _compute_loss(self, profile=True): diff --git a/rabbit/inputdata.py b/rabbit/inputdata.py index f9bad08..35329f8 100644 --- a/rabbit/inputdata.py +++ b/rabbit/inputdata.py @@ -182,6 +182,59 @@ def __init__(self, filename, pseudodata=None): self.axis_procs = hist.axis.StrCategory(self.procs, name="processes") + # Load external likelihood terms (optional). + # Each entry is a dict with keys: + # name: str + # params: 1D ndarray of parameter name strings + # grad_values: 1D float ndarray or None + # hess_dense: 2D float ndarray or None + # hess_sparse: tuple (rows, cols, values) or None + self.external_terms = [] + if "external_terms" in f.keys(): + names = [ + s.decode() if isinstance(s, bytes) else s + for s in f["hexternal_term_names"][...] + ] + ext_group = f["external_terms"] + for tname in names: + tg = ext_group[tname] + raw_params = tg["params"][...] + params = np.array( + [s.decode() if isinstance(s, bytes) else s for s in raw_params] + ) + grad_values = ( + np.asarray(maketensor(tg["grad_values"])) + if "grad_values" in tg.keys() + else None + ) + hess_dense = ( + np.asarray(maketensor(tg["hess_dense"])) + if "hess_dense" in tg.keys() + else None + ) + hess_sparse = None + if "hess_sparse" in tg.keys(): + hg = tg["hess_sparse"] + idx_dset = hg["indices"] + if "original_shape" in idx_dset.attrs: + idx_shape = tuple(idx_dset.attrs["original_shape"]) + indices = np.asarray(idx_dset).reshape(idx_shape) + else: + indices = np.asarray(idx_dset) + rows = indices[:, 0] + cols = indices[:, 1] + vals = np.asarray(hg["values"]) + hess_sparse = (rows, cols, vals) + self.external_terms.append( + { + "name": tname, + "params": params, + "grad_values": grad_values, + "hess_dense": hess_dense, + "hess_sparse": hess_sparse, + } + ) + @tf.function def expected_events_nominal(self): rnorm = tf.ones(self.nproc, dtype=self.dtype) diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 5cf26e4..597a9c8 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -60,6 +60,19 @@ def __init__( self.has_beta_variations = False + # External likelihood terms. Each term is a dict with keys: + # name: identifier + # params: 1D ndarray of parameter name strings; both grad and hess + # refer to this same parameter list in the same order + # grad_values: 1D float ndarray (length == len(params)) or None + # hess_dense: 2D float ndarray of shape (len(params), len(params)) or None + # hess_sparse: tuple (rows, cols, values) for sparse hessian or None + # Exactly one of hess_dense / hess_sparse may be set, or neither + # (gradient-only term). Parameter names are resolved against the full + # fit parameter list (POIs + systs) at fit time. See + # add_external_likelihood_term for details. + self.external_terms = [] + self.clipSystVariations = False if self.clipSystVariations > 0.0: self.clip = np.abs(np.log(self.clipSystVariations)) @@ -801,6 +814,137 @@ def add_beta_variations( self.has_beta_variations = True + @staticmethod + def _strcategory_labels(ax): + """Return the bin labels of a hist StrCategory axis as a numpy string array. + + Raises if ``ax`` is not a StrCategory axis. + """ + import hist as _hist + + if not isinstance(ax, _hist.axis.StrCategory): + raise TypeError( + f"External term axes must be hist.axis.StrCategory; got {type(ax).__name__}" + ) + return np.array([ax.value(i) for i in range(len(ax))], dtype=object) + + def add_external_likelihood_term(self, grad=None, hess=None, name=None): + """Add an additive quadratic term to the negative log-likelihood. + + The term has the form + + L_ext(x) = g^T x_sub + 0.5 * x_sub^T H x_sub + + where ``x_sub`` is the slice of the full fit parameter vector + corresponding to the parameters identified by the StrCategory axes + of ``grad`` / ``hess``. Both ``grad`` and ``hess`` must use the same + parameter list in the same order. The parameter names are stored as + strings and resolved against the full parameter list (POIs + systs) + at fit time. + + Parameters + ---------- + grad : hist.Hist, optional + 1D histogram with one ``hist.axis.StrCategory`` axis whose bin + labels are parameter names. Values are the gradient ``g``. + hess : hist.Hist or wums.SparseHist, optional + 2D histogram with two ``hist.axis.StrCategory`` axes; both must + have identical bin labels equal to the gradient parameter list + (if ``grad`` is also given). Values are the hessian ``H``. May + be a dense ``hist.Hist`` or a ``wums.SparseHist`` for sparse + storage. ``H`` should be symmetric (the formula is + ``0.5 x^T H x``); the user is responsible for symmetrizing. + name : str, optional + Identifier for this term. Auto-generated if not provided. + Multiple terms can be added by calling this method repeatedly. + """ + if grad is None and hess is None: + raise ValueError( + "add_external_likelihood_term requires at least one of grad or hess" + ) + + if name is None: + name = f"ext{len(self.external_terms)}" + if any(t["name"] == name for t in self.external_terms): + raise RuntimeError(f"External likelihood term '{name}' already added") + + params = None + + # Process gradient + grad_values = None + if grad is not None: + if not hasattr(grad, "axes") or len(grad.axes) != 1: + raise ValueError( + f"grad must be a 1D histogram, got {type(grad).__name__} with " + f"{len(grad.axes) if hasattr(grad, 'axes') else 0} axes" + ) + grad_params = self._strcategory_labels(grad.axes[0]) + grad_values = np.asarray(grad.values()).flatten().astype(self.dtype) + if len(grad_values) != len(grad_params): + raise RuntimeError( + f"grad values length {len(grad_values)} does not match params length {len(grad_params)}" + ) + params = grad_params + + # Process hessian + hess_dense = None + hess_sparse = None + if hess is not None: + if len(hess.axes) != 2: + raise ValueError( + f"hess must be a 2D histogram, got {len(hess.axes)} axes" + ) + hess_params0 = self._strcategory_labels(hess.axes[0]) + hess_params1 = self._strcategory_labels(hess.axes[1]) + if not np.array_equal(hess_params0, hess_params1): + raise ValueError( + "hess must have identical labels on both axes (since it is " + "indexed by the same parameter list)" + ) + if params is not None: + if not np.array_equal(params, hess_params0): + raise ValueError( + "grad and hess must use the same parameter list in the " + f"same order; got grad params {params.tolist()} vs " + f"hess params {hess_params0.tolist()}" + ) + else: + params = hess_params0 + + if isinstance(hess, SparseHist): + # extract sparse coordinates and values from with-flow layout + # (StrCategory axes never have flow, so this matches the no-flow layout) + csr = hess.to_flat_csr(self.dtype, flow=False) + # csr has shape (1, n*n); convert flat positions to (row, col) + n = len(params) + flat = np.asarray(csr.indices, dtype=np.int64) + rows = (flat // n).astype(np.int64) + cols = (flat % n).astype(np.int64) + vals = np.asarray(csr.data, dtype=self.dtype) + hess_sparse = (rows, cols, vals) + elif self._issparse(hess): + raise ValueError( + "raw scipy sparse hess inputs are not supported; " + "wrap in wums.SparseHist with the parameter axes attached" + ) + else: + hess_dense = np.asarray(hess.values()).astype(self.dtype) + if hess_dense.shape != (len(params), len(params)): + raise RuntimeError( + f"hess shape {hess_dense.shape} does not match " + f"params length {len(params)}" + ) + + self.external_terms.append( + { + "name": name, + "params": np.asarray(params), + "grad_values": grad_values, + "hess_dense": hess_dense, + "hess_sparse": hess_sparse, + } + ) + @staticmethod def _sparse_values_at(sparse_csr, indices): """Extract values from a flat CSR array at the given flat indices. @@ -1482,6 +1626,51 @@ def create_dataset( ) beta_variations = None + # Write external likelihood terms + if self.external_terms: + ext_group = f.create_group("external_terms") + create_dataset( + "external_term_names", + [t["name"] for t in self.external_terms], + ) + for term in self.external_terms: + term_group = ext_group.create_group(term["name"]) + params_ds = term_group.create_dataset( + "params", + [len(term["params"])], + dtype=h5py.special_dtype(vlen=str), + compression="gzip", + ) + params_ds[...] = [str(p) for p in term["params"]] + + if term["grad_values"] is not None: + nbytes += h5pyutils_write.writeFlatInChunks( + term["grad_values"], + term_group, + "grad_values", + maxChunkBytes=self.chunkSize, + ) + + if term["hess_dense"] is not None: + nbytes += h5pyutils_write.writeFlatInChunks( + term["hess_dense"], + term_group, + "hess_dense", + maxChunkBytes=self.chunkSize, + ) + elif term["hess_sparse"] is not None: + rows, cols, vals = term["hess_sparse"] + n = len(term["params"]) + indices = np.stack([rows, cols], axis=-1).astype(self.idxdtype) + nbytes += h5pyutils_write.writeSparse( + indices, + vals.astype(self.dtype), + (n, n), + term_group, + "hess_sparse", + maxChunkBytes=self.chunkSize, + ) + logger.info(f"Total raw bytes in arrays = {nbytes}") def get_systsstandard(self): diff --git a/tests/test_external_term.py b/tests/test_external_term.py new file mode 100644 index 0000000..9300224 --- /dev/null +++ b/tests/test_external_term.py @@ -0,0 +1,272 @@ +"""Test external likelihood terms (gradient + hessian) added to TensorWriter and Fitter. + +The external term has the form + + L_ext(x) = g^T x_sub + 0.5 x_sub^T H x_sub + +where x_sub is the slice of fit parameters identified by the StrCategory axes +of grad/hess. With Asimov data and a single Gaussian-constrained nuisance, +the analytical post-fit value of the nuisance is + + theta = -g / (1 + h) + +where the +1 is the prefit Gaussian constraint and +h is the external hessian +contribution. This script verifies that prediction for several configurations, +including dense and sparse (wums.SparseHist) hessian storage. +""" + +import os +import tempfile +from types import SimpleNamespace + +import hist +import numpy as np +import scipy.sparse +from wums.sparse_hist import SparseHist + +from rabbit import fitter, inputdata, tensorwriter +from rabbit.poi_models.helpers import load_model + + +def make_options(**kwargs): + defaults = dict( + earlyStopping=-1, + noBinByBinStat=True, + binByBinStatMode="lite", + binByBinStatType="automatic", + covarianceFit=False, + chisqFit=False, + diagnostics=False, + minimizerMethod="trust-krylov", + prefitUnconstrainedNuisanceUncertainty=0.0, + freezeParameters=[], + setConstraintMinimum=[], + unblind=[], + ) + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def build_writer(grad=None, hess=None): + """Build a TensorWriter with one bkg process and a single shape systematic.""" + np.random.seed(0) + ax = hist.axis.Regular(20, -5, 5, name="x") + + h_data = hist.Hist(ax, storage=hist.storage.Double()) + h_bkg = hist.Hist(ax, storage=hist.storage.Weight()) + + x_bkg = np.random.uniform(-5, 5, 5000) + h_data.fill(x_bkg) + h_bkg.fill(x_bkg, weight=np.ones(len(x_bkg))) + + bin_centers = ax.centers - ax.centers[0] + weights = 0.01 * bin_centers - 0.05 + h_up = h_bkg.copy() + h_dn = h_bkg.copy() + h_up.values()[...] = h_bkg.values() * (1 + weights) + h_dn.values()[...] = h_bkg.values() * (1 - weights) + + writer = tensorwriter.TensorWriter() + writer.add_channel([ax], "ch0") + writer.add_data(h_data, "ch0") + writer.add_process(h_bkg, "bkg", "ch0", signal=True) + writer.add_systematic([h_up, h_dn], "shape", "bkg", "ch0", symmetrize="average") + + if grad is not None or hess is not None: + writer.add_external_likelihood_term(grad=grad, hess=hess) + + return writer + + +def run_fit(filename): + indata_obj = inputdata.FitInputData(filename) + poi_model = load_model("Mu", indata_obj) + options = make_options() + f = fitter.Fitter(indata_obj, poi_model, options) + + # use Asimov data so the only force on the nuisance is the constraint + external term + f.set_nobs(f.expected_yield()) + f.minimize() + + parms_str = f.parms.astype(str) + return { + "parms": parms_str, + "x": f.x.numpy(), + } + + +def loss_grad_hess_at(filename, x_override=None): + """Return (loss, grad, hess) for the loaded tensor evaluated at x_override + (or the default starting x if None). Uses Asimov data.""" + import tensorflow as tf + + indata_obj = inputdata.FitInputData(filename) + poi_model = load_model("Mu", indata_obj) + options = make_options() + f = fitter.Fitter(indata_obj, poi_model, options) + f.set_nobs(f.expected_yield()) + if x_override is not None: + f.x.assign(tf.constant(x_override, dtype=f.x.dtype)) + val, grad, hess = f.loss_val_grad_hess() + return ( + f.parms.astype(str), + val.numpy(), + grad.numpy(), + hess.numpy(), + ) + + +def get_param_value(result, name): + idx = np.where(result["parms"] == name)[0][0] + return result["x"][idx] + + +def get_param_index(parms, name): + return int(np.where(parms == name)[0][0]) + + +def make_grad_hist(values, param_names): + """Build a 1D hist with a StrCategory axis for an external gradient.""" + ax = hist.axis.StrCategory(param_names, name="params") + h = hist.Hist(ax, storage=hist.storage.Double()) + h.values()[...] = np.asarray(values) + return h + + +def make_hess_hist(values, param_names): + """Build a 2D hist with two StrCategory axes for an external hessian.""" + ax0 = hist.axis.StrCategory(param_names, name="params0") + ax1 = hist.axis.StrCategory(param_names, name="params1") + h = hist.Hist(ax0, ax1, storage=hist.storage.Double()) + h.values()[...] = np.asarray(values) + return h + + +def make_hess_sparsehist(values, param_names): + """Same as make_hess_hist but using a wums.SparseHist. + + StrCategory axes have an overflow bin by default, so SparseHist's + with-flow layout has shape (n+1, n+1). The user data goes in the + first n x n block; the overflow row/col is filled with zeros. + """ + ax0 = hist.axis.StrCategory(param_names, name="params0") + ax1 = hist.axis.StrCategory(param_names, name="params1") + n = len(param_names) + full = np.zeros((ax0.extent, ax1.extent), dtype=np.float64) + full[:n, :n] = np.asarray(values, dtype=np.float64) + return SparseHist(scipy.sparse.csr_array(full), [ax0, ax1]) + + +def main(): + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + + SHAPE = "shape" + + with tempfile.TemporaryDirectory() as tmpdir: + + # --- Baseline: no external term --- + baseline_writer = build_writer() + baseline_writer.write(outfolder=tmpdir, outfilename="baseline") + baseline = run_fit(os.path.join(tmpdir, "baseline.hdf5")) + baseline_shape = get_param_value(baseline, SHAPE) + print(f"Baseline (no external): {SHAPE} = {baseline_shape:.6f}") + assert ( + abs(baseline_shape) < 1e-6 + ), f"Asimov baseline should give {SHAPE} ~ 0, got {baseline_shape}" + print("PASS: baseline Asimov fit gives shape ~ 0") + + # Reference loss/grad/hess at the baseline x (no external term). + # The contribution of L_ext(x) = g^T x + 0.5 x^T H x to the NLL gradient + # at any x is (g + H x), and to the NLL hessian is H. We test these + # exactly (not analytical post-fit values, which depend on the data + # Hessian and the constraint and don't have a clean closed form). + parms, val0, grad0, hess0 = loss_grad_hess_at( + os.path.join(tmpdir, "baseline.hdf5") + ) + i_shape = get_param_index(parms, SHAPE) + x0 = baseline["x"].copy() + # the test below evaluates external terms at the baseline minimum + # where x[i_shape] = 0, so H x_sub = 0 → grad delta == g exactly. + + configs = [ + ( + "grad only (g=1)", + build_writer(grad=make_grad_hist([1.0], [SHAPE])), + {i_shape: 1.0}, + {(i_shape, i_shape): 0.0}, + ), + ( + "grad+dense hess (g=1, h=2)", + build_writer( + grad=make_grad_hist([1.0], [SHAPE]), + hess=make_hess_hist([[2.0]], [SHAPE]), + ), + {i_shape: 1.0}, + {(i_shape, i_shape): 2.0}, + ), + ( + "grad+SparseHist hess (g=1, h=2)", + build_writer( + grad=make_grad_hist([1.0], [SHAPE]), + hess=make_hess_sparsehist([[2.0]], [SHAPE]), + ), + {i_shape: 1.0}, + {(i_shape, i_shape): 2.0}, + ), + ( + "hess only (h=5)", + build_writer(hess=make_hess_hist([[5.0]], [SHAPE])), + {i_shape: 0.0}, + {(i_shape, i_shape): 5.0}, + ), + ] + + for label, writer, expected_grad_delta, expected_hess_delta in configs: + tag = ( + label.replace(" ", "_") + .replace("(", "") + .replace(")", "") + .replace(",", "") + .replace("=", "") + ) + writer.write(outfolder=tmpdir, outfilename=tag) + _, val, grad, hess = loss_grad_hess_at( + os.path.join(tmpdir, f"{tag}.hdf5"), + x_override=x0, + ) + for idx, expected in expected_grad_delta.items(): + actual = grad[idx] - grad0[idx] + print( + f"{label}: grad delta @ idx {idx} = {actual:+.6f} (expected {expected:+.6f})" + ) + assert ( + abs(actual - expected) < 1e-8 + ), f"{label}: grad delta {actual} != expected {expected}" + for (i, j), expected in expected_hess_delta.items(): + actual = hess[i, j] - hess0[i, j] + print( + f"{label}: hess delta @ ({i},{j}) = {actual:+.6f} (expected {expected:+.6f})" + ) + assert ( + abs(actual - expected) < 1e-8 + ), f"{label}: hess delta {actual} != expected {expected}" + print(f"PASS: {label}") + + # Sanity check: also verify that running the full fit shifts the + # baseline shape value in the expected direction (negative for g=+1). + grad_only_writer = build_writer(grad=make_grad_hist([1.0], [SHAPE])) + grad_only_writer.write(outfolder=tmpdir, outfilename="grad_only_fit") + grad_only = run_fit(os.path.join(tmpdir, "grad_only_fit.hdf5")) + v = get_param_value(grad_only, SHAPE) + print(f"Full fit with g=+1: shape = {v:.6f} (expected negative)") + assert v < -1e-3, f"Expected shape to pull negative, got {v}" + print("PASS: full fit with positive gradient pulls shape negative") + + print() + print("ALL CHECKS PASSED") + + +if __name__ == "__main__": + main() From 47f1f904cb7b6c8bb5cee59b84e6e38d9862107d Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 15:58:51 +0200 Subject: [PATCH 06/18] Add efficient SparseHist multi-systematic dispatch in TensorWriter The generic _get_systematic_slices loop calls h[slice_dict] once per combination on the extra (systematic) axes, which for SparseHist input is O(nnz) per slice and prohibitively slow when there are many extra bins (e.g. ~108k corparms over a ~31M nnz SparseHist would take hours). Add a fast path that pre-extracts the with-flow flat representation once, computes a linear systematic index from the extra-axis coordinates, sorts globally, and then yields contiguous per-bin runs. Empty combinations yield an empty SparseHist over the kept axes so the caller can still book the corresponding systematic name (allowing it to be constrained by an external term even when the template variation is identically zero). This is O(nnz log nnz) total instead of O(nnz) per slice, and supports both single and asymmetric (up/down) inputs. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/tensorwriter.py | 161 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 597a9c8..4c43fc8 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -577,6 +577,109 @@ def _bin_label(ax, idx): pass return str(idx) + @staticmethod + def _make_empty_sparsehist(axes, size): + """Construct an empty SparseHist over ``axes`` with the given flat size.""" + import scipy.sparse + + empty_csr = scipy.sparse.csr_array( + ( + np.zeros(0, dtype=np.float64), + np.zeros(0, dtype=np.int64), + np.array([0, 0], dtype=np.int64), + ), + shape=(1, int(size)), + ) + return SparseHist(empty_csr, axes) + + def _sparse_per_syst_slices(self, h, extra_axes, extra_axis_names, keep_axes): + """Yield ``(linear_idx, sub_h)`` for every bin combination on the extra axes. + + Single-pass O(nnz log nnz) algorithm: extract all entries via the + with-flow flat layout, compute a linear syst index from the extra-axis + coordinates, sort once, then iterate contiguous per-bin runs. Empty + slots yield an empty SparseHist over the kept axes so that callers can + still book the corresponding systematic name (allowing it to be + constrained externally even when the template variation is exactly + zero). + """ + import scipy.sparse + + extra_sizes = [int(len(a)) for a in extra_axes] + n_total = int(np.prod(extra_sizes)) + keep_extent = tuple(int(a.extent) for a in keep_axes) + keep_size = int(np.prod(keep_extent)) if keep_axes else 1 + + h_axis_names = [a.name for a in h.axes] + extra_positions = [h_axis_names.index(n) for n in extra_axis_names] + keep_positions = [ + i for i in range(len(h_axis_names)) if i not in extra_positions + ] + + csr = h.to_flat_csr(np.float64, flow=True) + flat_idx = np.asarray(csr.indices, dtype=np.int64) + values = np.asarray(csr.data, dtype=np.float64) + + if len(flat_idx) == 0: + for linear_idx in range(n_total): + yield linear_idx, self._make_empty_sparsehist(keep_axes, keep_size) + return + + multi = np.unravel_index(flat_idx, h.shape) + + # Drop entries that fall in flow bins of any extra axis (we only + # iterate over the regular bins of those axes, matching the existing + # multi-systematic dispatch convention). + valid = np.ones(len(flat_idx), dtype=bool) + per_extra_idx = [] + for ax_pos in extra_positions: + ax = h.axes[ax_pos] + u = SparseHist._underflow_offset(ax) + s = int(len(ax)) + valid &= (multi[ax_pos] >= u) & (multi[ax_pos] < u + s) + per_extra_idx.append(multi[ax_pos] - u) + + if not valid.all(): + multi = tuple(m[valid] for m in multi) + values = values[valid] + per_extra_idx = [arr[valid] for arr in per_extra_idx] + + if len(extra_positions) == 1: + syst_linear = per_extra_idx[0] + else: + syst_linear = np.ravel_multi_index(per_extra_idx, extra_sizes) + + sort_order = np.argsort(syst_linear, kind="stable") + sorted_syst = syst_linear[sort_order] + sorted_values = values[sort_order] + sorted_keep_multi = tuple(multi[i][sort_order] for i in keep_positions) + + boundaries = np.searchsorted(sorted_syst, np.arange(n_total + 1), side="left") + + for linear_idx in range(n_total): + start = int(boundaries[linear_idx]) + end = int(boundaries[linear_idx + 1]) + if start == end: + yield linear_idx, self._make_empty_sparsehist(keep_axes, keep_size) + continue + + sub_keep_multi = tuple(arr[start:end] for arr in sorted_keep_multi) + if len(keep_extent) == 1: + sub_flat = sub_keep_multi[0] + else: + sub_flat = np.ravel_multi_index(sub_keep_multi, keep_extent) + sub_vals = sorted_values[start:end] + order = np.argsort(sub_flat) + sub_csr = scipy.sparse.csr_array( + ( + sub_vals[order].astype(np.float64), + sub_flat[order].astype(np.int64), + np.array([0, len(sub_vals)], dtype=np.int64), + ), + shape=(1, keep_size), + ) + yield linear_idx, SparseHist(sub_csr, keep_axes) + def _get_systematic_slices(self, h, name, channel, syst_axes=None): """Detect extra axes in h beyond the channel and return list of (sub_name, sub_h) slices. @@ -585,6 +688,15 @@ def _get_systematic_slices(self, h, name, channel, syst_axes=None): h may be a single histogram or a list/tuple of two (up/down) histograms. Both elements of a pair must share the same extra-axis structure. + For ``SparseHist`` inputs, an efficient single-pass algorithm is used + that pre-extracts the underlying flat representation, partitions + entries by their extra-axis indices via a global sort, and then yields + one sub-``SparseHist`` per bin combination on the extra axes. This is + O(nnz log nnz) total instead of O(nnz) per slice, and it always emits + a (possibly empty) sub-hist for *every* combination so that the + downstream booking sees all systematic names even where the + per-bin variation is identically zero. + syst_axes: - None (default): auto-detect any axes in h not present in the channel - list of axis names: use exactly these axes as systematic axes @@ -626,7 +738,56 @@ def _get_systematic_slices(self, h, name, channel, syst_axes=None): extra_axes = [h_ref.axes[n] for n in extra_axis_names] extra_sizes = [len(a) for a in extra_axes] + keep_axes_ref = [a for a in h_ref.axes if a.name not in extra_axis_names] + + # Fast path for SparseHist inputs (the slow per-slice loop below would + # be O(nnz) per slice, which is prohibitive for large syst axes). + if isinstance(h_ref, SparseHist): + if is_pair: + if not isinstance(h[1], SparseHist): + raise TypeError( + "Mixed SparseHist/non-SparseHist pair not supported" + ) + up_iter = self._sparse_per_syst_slices( + h[0], extra_axes, extra_axis_names, keep_axes_ref + ) + dn_iter = self._sparse_per_syst_slices( + h[1], extra_axes, extra_axis_names, keep_axes_ref + ) + paired = zip(up_iter, dn_iter) + else: + paired = ( + (item, None) + for item in self._sparse_per_syst_slices( + h, extra_axes, extra_axis_names, keep_axes_ref + ) + ) + slices = [] + for linear_idx in range(int(np.prod(extra_sizes))): + if is_pair: + (lu, sub_up), (ld, sub_dn) = next(paired) + assert lu == linear_idx and ld == linear_idx + sub_h = [sub_up, sub_dn] + else: + (li, sub), _ = next(paired) + assert li == linear_idx + sub_h = sub + # Decode the extra-axis multi-dim index for label construction + if len(extra_axes) == 1: + idx_tuple = (linear_idx,) + else: + idx_tuple = tuple( + int(x) for x in np.unravel_index(linear_idx, extra_sizes) + ) + labels = [ + self._bin_label(ax, i) for ax, i in zip(extra_axes, idx_tuple) + ] + sub_name = "_".join([name, *labels]) + slices.append((sub_name, sub_h)) + return slices + + # Generic path: hist-like object using its own __getitem__ slicing. import itertools slices = [] From 08537b856f27c02c4fc2bc46605d7cb5ba86e1c4 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Tue, 7 Apr 2026 20:54:03 +0200 Subject: [PATCH 07/18] Speed up TensorWriter for large multi-systematic SparseHist workloads Several independent optimizations to the writer + write() path. On a realistic 2-channel jpsi calibration tensor with ~108k corparm systematics and a 330M-nnz external hessian, total wall time drops from ~4m30s to ~1m13s. 1. Vectorized SparseHist multi-syst dispatch in add_systematic. New _add_systematics_sparsehist_batched does all per-entry math (channel flat index, norm lookup, sign-flip-protected logk) once over the full ~25M-entry array, partitions by linear systematic index via a single argsort + searchsorted, and bulk-inserts per-syst (indices, values) directly into dict_logkavg / dict_logkavg_indices. Empty bin combinations still get an entry and a corresponding book_systematic call so they appear in the fit parameter list and can be constrained externally. Triggered when the input is a single SparseHist with extra axes plus mirror=True, as_difference=True, no add_to_data_covariance. Per-channel booking goes from ~93s to ~9s. 2. Pre-allocate sparse assembly buffers in write(). The previous loop grew norm_sparse_* and logk_sparse_* via np.ndarray.resize once per (channel, process, syst), which is O(N^2) total because each resize allocates a new buffer and copies all elements. A quick first pass over the dict structures now computes the total nnz so the buffers can be allocated once and filled in place. 3. Replace list.index() with a dict in get_groups, get_constraintweights, get_noiidxs. The old code did systs.index(name) once per group member, giving O(nsysts*nmembers) behaviour: with 108k systs all in a single corparms group this was the dominant cost of write(), eating ~75 seconds. 4. Skip the unnecessary to_flat_csr sort in add_external_likelihood_term. For SparseHist hess input, access _flat_indices/_values directly and recover (rows, cols) via np.divmod, instead of going through to_flat_csr(flow=False) which sorts ~330M entries we then never read in order. ~30s saved. 5. Switch h5py compression from gzip to Blosc2 LZ4 in h5pyutils_write. ~5x faster on integer arrays at slightly better compression ratios. h5pyutils_read imports hdf5plugin so the filter is registered for read-back. 6. Add a compress=True parameter to writeFlatInChunks and have writeSparse pass compress=False for the values payload of an explicitly sparse tensor. Densely packed nonzero floats from real physics tensors compress only ~4% at 5x the write cost, so the compression is pure overhead there. Index buffers continue to compress (~10x ratio with negligible overhead). Also adds a regression test in test_multi_systematic.py that constructs a multi-syst SparseHist and asserts the batched fast path produces bit-identical hnorm/hlogk to per-syst manual booking, with log_normal + as_difference=True and entries that exercise the logkepsilon sign-flip fallback. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/h5pyutils_read.py | 1 + rabbit/h5pyutils_write.py | 48 +++- rabbit/tensorwriter.py | 450 ++++++++++++++++++++++++++++++--- tests/test_multi_systematic.py | 87 +++++++ 4 files changed, 535 insertions(+), 51 deletions(-) diff --git a/rabbit/h5pyutils_read.py b/rabbit/h5pyutils_read.py index 9bc9c9c..c8ba15e 100644 --- a/rabbit/h5pyutils_read.py +++ b/rabbit/h5pyutils_read.py @@ -1,3 +1,4 @@ +import hdf5plugin # noqa: F401 registers Blosc2/LZ4 filter used by the writer import tensorflow as tf diff --git a/rabbit/h5pyutils_write.py b/rabbit/h5pyutils_write.py index 789a78d..a5573c1 100644 --- a/rabbit/h5pyutils_write.py +++ b/rabbit/h5pyutils_write.py @@ -1,30 +1,53 @@ import math +import hdf5plugin import numpy as np +# Compression strategy for the HDF5 write path. +# +# By default dense arrays are written with Blosc2 + LZ4 byte-shuffle. This is +# much faster than gzip (typically ~5x on write) while achieving equal or +# better ratios, and works well for dense tensor buffers that often contain +# lots of structural zeros from sparsity patterns. +# +# Callers that know the data is already densely packed with unstructured +# nonzero values (e.g. the ``values`` payload of an explicitly sparse tensor) +# can pass ``compress=False`` to skip compression entirely. For those inputs +# Blosc2 LZ4 buys only ~4-5% of file size at ~5x the write cost, so turning +# it off is a strict win. +# +# The HDF5 filter pipeline is fundamentally single-threaded per chunk, so +# multi-threaded compression via BLOSC2_NTHREADS does not take effect through +# h5py; the main speedup comes from switching compressor and skipping the +# uncompressible buffers. +# +# Reading requires the hdf5plugin filter to be registered, which happens +# automatically via the ``import hdf5plugin`` at module import time in both +# h5pyutils_write and h5pyutils_read. +_DEFAULT_COMPRESSION_KWARGS = hdf5plugin.Blosc2(cname="lz4", clevel=5) -def writeFlatInChunks(arr, h5group, outname, maxChunkBytes=1024**2): + +def writeFlatInChunks(arr, h5group, outname, maxChunkBytes=1024**2, compress=True): arrflat = arr.reshape(-1) esize = np.dtype(arrflat.dtype).itemsize nbytes = arrflat.size * esize - # special handling for empty datasets, which should not use chunked storage or compression + # Empty datasets must not use chunked storage or compression. if arrflat.size == 0: chunksize = 1 - chunks = None - compression = None + extra_kwargs = {"chunks": None} else: chunksize = int(min(arrflat.size, max(1, math.floor(maxChunkBytes / esize)))) - chunks = (chunksize,) - compression = "gzip" + extra_kwargs = {"chunks": (chunksize,)} + if compress: + extra_kwargs.update(_DEFAULT_COMPRESSION_KWARGS) h5dset = h5group.create_dataset( outname, arrflat.shape, - chunks=chunks, dtype=arrflat.dtype, - compression=compression, + **extra_kwargs, ) # write in chunks, preserving sparsity if relevant @@ -42,8 +65,15 @@ def writeSparse(indices, values, dense_shape, h5group, outname, maxChunkBytes=10 outgroup = h5group.create_group(outname) nbytes = 0 + # Index arrays compress extremely well (~10x for the tensor-sparse + # structures used by rabbit), so keep the default compression. nbytes += writeFlatInChunks(indices, outgroup, "indices", maxChunkBytes) - nbytes += writeFlatInChunks(values, outgroup, "values", maxChunkBytes) + # Values of a sparse tensor are already densely packed nonzeros; real + # physics values typically give only ~4% compression gain at 5x the + # write cost, so skip compression here. + nbytes += writeFlatInChunks( + values, outgroup, "values", maxChunkBytes, compress=False + ) outgroup.attrs["dense_shape"] = np.array(dense_shape, dtype="int64") return nbytes diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 4c43fc8..0c71d75 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -806,6 +806,321 @@ def _get_systematic_slices(self, h, name, channel, syst_axes=None): return slices + def _detect_extra_syst_axes(self, h, channel, syst_axes): + """Determine the extra (systematic) axes of ``h`` for a given channel. + + Returns a tuple ``(h_ref, is_pair, extra_axis_names)`` if there are + extra axes, or ``None`` otherwise. Shared by ``_get_systematic_slices`` + and the batched fast path so the detection logic stays in one place. + """ + if syst_axes is not None and len(syst_axes) == 0: + return None + + if isinstance(h, (list, tuple)): + h_ref = h[0] + is_pair = True + else: + h_ref = h + is_pair = False + + if not hasattr(h_ref, "axes"): + return None + + h_axis_names = [a.name for a in h_ref.axes] + channel_axis_names = [a.name for a in self.channels[channel]["axes"]] + + if syst_axes is None: + extra_axis_names = [n for n in h_axis_names if n not in channel_axis_names] + else: + for n in syst_axes: + if n not in h_axis_names: + raise RuntimeError( + f"Requested systematic axis '{n}' not found in histogram axes {h_axis_names}" + ) + if n in channel_axis_names: + raise RuntimeError( + f"Systematic axis '{n}' overlaps with channel axes {channel_axis_names}" + ) + extra_axis_names = list(syst_axes) + + if not extra_axis_names: + return None + + return h_ref, is_pair, extra_axis_names + + def _add_systematics_sparsehist_batched( + self, + h, + name, + process, + channel, + kfactor, + as_difference, + extra_axis_names, + **kargs, + ): + """Vectorized booking of one shape systematic per bin combination on + the extra axes of a single ``SparseHist`` input. + + Used as a fast path for the multi-systematic dispatch when the input + is a single (non-paired) :class:`wums.sparse_hist.SparseHist` with at + least one extra axis beyond the channel axes and ``as_difference=True``. + + All per-entry math (channel-flat-index computation, norm lookup, + sign-flip-protected logk evaluation) is done once over the entire + ~nnz array using vectorised numpy operations. The result is then + partitioned by linear systematic index via a single ``argsort`` + + ``searchsorted`` and bulk-inserted into ``dict_logkavg`` / + ``dict_logkavg_indices`` (sparse storage) or ``dict_logkavg`` (dense + storage). Empty bin combinations on the extra axes still get an entry + and a corresponding ``book_systematic`` call so they appear in the + fit parameter list. + """ + import scipy.sparse # noqa: F401 used implicitly via SparseHist methods + + chan_info = self.channels[channel] + chan_flow = chan_info["flow"] + channel_axes_obj = chan_info["axes"] + channel_axis_names = [a.name for a in channel_axes_obj] + + h_axis_names = [a.name for a in h.axes] + for n in channel_axis_names: + if n not in h_axis_names: + raise RuntimeError( + f"Channel axis '{n}' not found in histogram axes {h_axis_names}" + ) + + extra_positions = [h_axis_names.index(n) for n in extra_axis_names] + keep_positions = [h_axis_names.index(n) for n in channel_axis_names] + keep_axes = [h.axes[i] for i in keep_positions] + extra_axes = [h.axes[i] for i in extra_positions] + + extra_sizes = [int(len(a)) for a in extra_axes] + n_total_systs = int(np.prod(extra_sizes)) + + keep_extent = tuple(int(a.extent) for a in keep_axes) + keep_no_flow = tuple(int(len(a)) for a in keep_axes) + target_size = ( + int(np.prod(keep_extent)) if chan_flow else int(np.prod(keep_no_flow)) + ) + + norm = self.dict_norm[channel][process] + norm_is_sparse = self._issparse(norm) + systematic_type = self.systematic_type + + # ---- Step 1: get h's flat with-flow (indices, values) ---- + # Access the SparseHist's internal flat buffers directly to skip the + # O(nnz log nnz) sort that ``to_flat_csr`` would otherwise do (we do + # our own sort later, by syst index, so pre-sorting by flat index + # would be wasted work). + flat_idx = np.asarray(h._flat_indices, dtype=np.int64) + delta_vals = np.asarray(h._values, dtype=np.float64) + + if len(flat_idx) > 0: + multi = np.unravel_index(flat_idx, h.shape) + # free flat_idx — we only need the per-axis multi-dim arrays now + flat_idx = None + + # ---- Step 2: drop entries in flow bins ---- + # Build the validity mask in a single pass over all relevant axes + # (extra axes always, channel axes only if the channel is no-flow). + if chan_flow: + check_positions = list(extra_positions) + else: + check_positions = list(extra_positions) + list(keep_positions) + + valid = None + for ax_pos in check_positions: + ax = h.axes[ax_pos] + u = SparseHist._underflow_offset(ax) + s = int(len(ax)) + ax_arr = multi[ax_pos] + if u == 0: + cond = ax_arr < s + else: + # 1 underflow bin: valid = ax_arr in [1, 1+s) + cond = (ax_arr >= u) & (ax_arr < u + s) + if valid is None: + valid = cond + else: + valid &= cond + + if valid is not None and not valid.all(): + multi = tuple(m[valid] for m in multi) + delta_vals = delta_vals[valid] + valid = None # free + + # ---- Step 3: compute linear systematic index from extra axes ---- + if len(extra_positions) == 1: + ax_pos = extra_positions[0] + u = SparseHist._underflow_offset(h.axes[ax_pos]) + if u == 0: + syst_linear = multi[ax_pos].astype(np.int64, copy=False) + else: + syst_linear = (multi[ax_pos] - u).astype(np.int64, copy=False) + else: + per = [] + for ax_pos in extra_positions: + u = SparseHist._underflow_offset(h.axes[ax_pos]) + per.append(multi[ax_pos] - u if u else multi[ax_pos]) + syst_linear = np.ravel_multi_index(per, extra_sizes) + + # ---- Step 4: compute channel flat index in target layout ---- + if chan_flow: + chan_flat = np.ravel_multi_index( + tuple(multi[i] for i in keep_positions), keep_extent + ) + else: + chan_no_flow_multi = [] + for ax_pos in keep_positions: + u = SparseHist._underflow_offset(h.axes[ax_pos]) + chan_no_flow_multi.append(multi[ax_pos] - u if u else multi[ax_pos]) + chan_flat = np.ravel_multi_index(chan_no_flow_multi, keep_no_flow) + multi = None # free the per-axis arrays; we no longer need them + + # ---- Step 5: look up norm at chan_flat; drop where norm == 0 ---- + if norm_is_sparse: + norm_indices_arr = np.asarray(norm.indices, dtype=np.int64) + norm_data_arr = np.asarray(norm.data, dtype=np.float64) + positions = np.searchsorted(norm_indices_arr, chan_flat) + in_range = positions < len(norm_indices_arr) + match = np.zeros(len(chan_flat), dtype=bool) + match[in_range] = ( + norm_indices_arr[positions[in_range]] == chan_flat[in_range] + ) + chan_flat = chan_flat[match] + delta_vals = delta_vals[match] + syst_linear = syst_linear[match] + norm_at_pos = norm_data_arr[positions[match]] + else: + norm_arr = np.asarray(norm, dtype=np.float64) + norm_at_pos = norm_arr[chan_flat] + nonzero_norm = norm_at_pos != 0.0 + if not nonzero_norm.all(): + chan_flat = chan_flat[nonzero_norm] + delta_vals = delta_vals[nonzero_norm] + syst_linear = syst_linear[nonzero_norm] + norm_at_pos = norm_at_pos[nonzero_norm] + + # ---- Step 6: validate finiteness (mirrors get_logk's check) ---- + if not np.all(np.isfinite(delta_vals)): + n_bad = int((~np.isfinite(delta_vals)).sum()) + raise RuntimeError( + f"{n_bad} NaN or Inf values encountered in systematic!" + ) + + # ---- Step 7: compute logk vectorized ---- + syst_at_pos = norm_at_pos + delta_vals # as_difference=True + + if systematic_type == "log_normal": + with np.errstate(divide="ignore", invalid="ignore"): + logk_vals = kfactor * np.log(syst_at_pos / norm_at_pos) + logk_vals = np.where( + np.equal(np.sign(norm_at_pos * syst_at_pos), 1), + logk_vals, + self.logkepsilon, + ) + if self.clipSystVariations > 0.0: + logk_vals = np.clip(logk_vals, -self.clip, self.clip) + elif systematic_type == "normal": + logk_vals = kfactor * (syst_at_pos - norm_at_pos) + else: + raise RuntimeError( + f"Invalid systematic_type {systematic_type}, valid choices are 'log_normal' or 'normal'" + ) + + # ---- Step 8: drop exactly-zero logk entries ---- + nonzero_logk = logk_vals != 0.0 + if not nonzero_logk.all(): + chan_flat = chan_flat[nonzero_logk] + logk_vals = logk_vals[nonzero_logk] + syst_linear = syst_linear[nonzero_logk] + + # ---- Step 9: sort by linear syst index for partitioning ---- + # Use the default (non-stable) quicksort since the intra-syst + # order does not affect correctness. + sort_order = np.argsort(syst_linear) + sorted_syst = syst_linear[sort_order] + sorted_chan_flat = chan_flat[sort_order] + sorted_logk = logk_vals[sort_order] + sort_order = None + syst_linear = None + chan_flat = None + logk_vals = None + + # ---- Step 10: per-syst boundaries via searchsorted ---- + boundaries = np.searchsorted( + sorted_syst, np.arange(n_total_systs + 1), side="left" + ) + sorted_syst = None + else: + boundaries = np.zeros(n_total_systs + 1, dtype=np.int64) + sorted_chan_flat = np.empty(0, dtype=np.int64) + sorted_logk = np.empty(0, dtype=np.float64) + + # ---- Step 11: bulk insert into the writer's internal storage ---- + dict_logk_proc = self.dict_logkavg[channel][process] + if self.sparse: + dict_logk_idx_proc = self.dict_logkavg_indices[channel][process] + + # Pre-compute per-axis label lists once (much faster than calling + # the generic _bin_label helper per combination, since the value() + # method on boost_histogram axes has non-trivial per-call overhead). + def _axis_labels(ax): + # Check whether the axis stores string categories; if so, decode + # them in bulk. Otherwise fall back to integer bin indices. + n = int(len(ax)) + if n == 0: + return [] + try: + v0 = ax.value(0) + except Exception: + v0 = None + if isinstance(v0, (str, bytes)): + out = [] + for i in range(n): + v = ax.value(i) + if isinstance(v, bytes): + v = v.decode() + out.append(v) + return out + return [str(i) for i in range(n)] + + axis_label_lists = [_axis_labels(ax) for ax in extra_axes] + + if len(extra_axes) == 1: + labels0 = axis_label_lists[0] + sub_names = [f"{name}_{labels0[i]}" for i in range(extra_sizes[0])] + else: + sub_names = [] + for linear_idx in range(n_total_systs): + multi_syst = np.unravel_index(linear_idx, extra_sizes) + labels = [ + axis_label_lists[k][int(multi_syst[k])] + for k in range(len(extra_axes)) + ] + sub_names.append("_".join([name, *labels])) + + for linear_idx in range(n_total_systs): + sub_name = sub_names[linear_idx] + s = int(boundaries[linear_idx]) + e = int(boundaries[linear_idx + 1]) + + if self.sparse: + # Sparse storage: store views into the sorted buffers (they + # keep the big arrays alive, but that is fine — we need the + # data anyway and sharing storage avoids a full per-syst copy). + dict_logk_idx_proc[sub_name] = sorted_chan_flat[s:e].reshape(-1, 1) + dict_logk_proc[sub_name] = sorted_logk[s:e] + else: + # Dense storage: scatter into a full-size logk array + logk_dense = np.zeros(target_size, dtype=np.float64) + if e > s: + logk_dense[sorted_chan_flat[s:e]] = sorted_logk[s:e] + dict_logk_proc[sub_name] = logk_dense + + self.book_systematic(sub_name, **kargs) + def add_systematic( self, h, @@ -830,6 +1145,36 @@ def add_systematic( to disable auto-detection. """ + # Fast batched path for SparseHist multi-systematic input. Conditions: + # - extra (systematic) axes are present + # - input is a single SparseHist (not an asymmetric pair) + # - mirror=True (single-hist symmetric input) + # - as_difference=True (so missing entries cleanly mean "no variation" + # for both log_normal and normal systematic types) + # - not added to the data covariance (which goes through a different + # bookkeeping path) + extra_info = self._detect_extra_syst_axes(h, channel, syst_axes) + if ( + extra_info is not None + and not extra_info[1] # is_pair + and isinstance(extra_info[0], SparseHist) + and mirror + and as_difference + and not add_to_data_covariance + ): + _, _, extra_axis_names = extra_info + self._add_systematics_sparsehist_batched( + h, + name, + process, + channel, + kfactor=kfactor, + as_difference=as_difference, + extra_axis_names=extra_axis_names, + **kargs, + ) + return + # multi-systematic dispatch: if h has extra axes beyond the channel, # iterate over those and book each combination as an independent systematic slices = self._get_systematic_slices(h, name, channel, syst_axes) @@ -1073,15 +1418,17 @@ def add_external_likelihood_term(self, grad=None, hess=None, name=None): params = hess_params0 if isinstance(hess, SparseHist): - # extract sparse coordinates and values from with-flow layout - # (StrCategory axes never have flow, so this matches the no-flow layout) - csr = hess.to_flat_csr(self.dtype, flow=False) - # csr has shape (1, n*n); convert flat positions to (row, col) + # Access the SparseHist's internal flat (indices, values) + # buffers directly. Going through ``to_flat_csr`` would do an + # O(nnz log nnz) sort that we don't need here, since the + # downstream representation is unordered (rows, cols, values). + # The flat indices live in the with-flow layout of the dense + # shape, but for StrCategory axes with overflow=False the + # extents equal the sizes so there are no flow bins to drop. n = len(params) - flat = np.asarray(csr.indices, dtype=np.int64) - rows = (flat // n).astype(np.int64) - cols = (flat % n).astype(np.int64) - vals = np.asarray(csr.data, dtype=self.dtype) + flat = np.asarray(hess._flat_indices, dtype=np.int64) + vals = np.asarray(hess._values, dtype=self.dtype) + rows, cols = np.divmod(flat, n) hess_sparse = (rows, cols, vals) elif self._issparse(hess): raise ValueError( @@ -1326,14 +1673,46 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= ibin = 0 if self.sparse: logger.info(f"Write out sparse array") - norm_sparse_size = 0 - norm_sparse_indices = np.zeros([norm_sparse_size, 2], self.idxdtype) - norm_sparse_values = np.zeros([norm_sparse_size], self.dtype) + # Pre-compute total sizes so we can allocate the assembly buffers + # once instead of growing them per (channel, process, syst) which + # is O(N^2) total via np.ndarray.resize. This pass only touches + # python dict structures and is essentially free. + norm_sparse_size_total = 0 + logk_sparse_size_total = 0 + for chan_pre in self.channels.keys(): + dict_norm_chan_pre = self.dict_norm[chan_pre] + dict_logkavg_chan_idx_pre = self.dict_logkavg_indices[chan_pre] + dict_logkhalfdiff_chan_idx_pre = self.dict_logkhalfdiff_indices[ + chan_pre + ] + for proc_pre in procs: + if proc_pre not in dict_norm_chan_pre: + continue + norm_proc_pre = dict_norm_chan_pre[proc_pre] + if self._issparse(norm_proc_pre): + norm_sparse_size_total += int(len(norm_proc_pre.indices)) + else: + norm_sparse_size_total += int(np.count_nonzero(norm_proc_pre)) + proc_logk_idx_pre = dict_logkavg_chan_idx_pre[proc_pre] + for syst_idx_arr in proc_logk_idx_pre.values(): + logk_sparse_size_total += int(syst_idx_arr.shape[0]) + proc_halfdiff_idx_pre = dict_logkhalfdiff_chan_idx_pre[proc_pre] + for syst_idx_arr in proc_halfdiff_idx_pre.values(): + logk_sparse_size_total += int(syst_idx_arr.shape[0]) + + norm_sparse_indices = np.empty([norm_sparse_size_total, 2], self.idxdtype) + norm_sparse_values = np.empty([norm_sparse_size_total], self.dtype) + logk_sparse_normindices = np.empty( + [logk_sparse_size_total, 1], self.idxdtype + ) + logk_sparse_systindices = np.empty( + [logk_sparse_size_total, 1], self.idxdtype + ) + logk_sparse_values = np.empty([logk_sparse_size_total], self.dtype) + + norm_sparse_size = 0 logk_sparse_size = 0 - logk_sparse_normindices = np.zeros([logk_sparse_size, 1], self.idxdtype) - logk_sparse_systindices = np.zeros([logk_sparse_size, 1], self.idxdtype) - logk_sparse_values = np.zeros([logk_sparse_size], self.dtype) for chan in self.channels.keys(): nbinschan = self.nbinschan[chan] @@ -1349,13 +1728,11 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= if self._issparse(norm_proc): # Use scipy sparse structure directly norm_indices = norm_proc.indices.reshape(-1, 1) - norm_values = norm_proc.data.copy() + norm_values = norm_proc.data nvals = len(norm_values) oldlength = norm_sparse_size norm_sparse_size = oldlength + nvals - norm_sparse_indices.resize([norm_sparse_size, 2]) - norm_sparse_values.resize([norm_sparse_size]) out_indices = np.array([[ibin, iproc]]) + np.pad( norm_indices, ((0, 0), (0, 1)), "constant" @@ -1378,8 +1755,6 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= nvals = len(norm_values) oldlength = norm_sparse_size norm_sparse_size = oldlength + nvals - norm_sparse_indices.resize([norm_sparse_size, 2]) - norm_sparse_values.resize([norm_sparse_size]) out_indices = np.array([[ibin, iproc]]) + np.pad( norm_indices, ((0, 0), (0, 1)), "constant" @@ -1410,9 +1785,6 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= nvals_proc = len(logkavg_proc_values) oldlength = logk_sparse_size logk_sparse_size = oldlength + nvals_proc - logk_sparse_normindices.resize([logk_sparse_size, 1]) - logk_sparse_systindices.resize([logk_sparse_size, 1]) - logk_sparse_values.resize([logk_sparse_size]) # first dimension of output indices are NOT in the dense [nbin,nproc] space, but rather refer to indices in the norm_sparse vectors # second dimension is flattened in the [2,nsyst] space, where logkavg corresponds to [0,isyst] flattened to isyst @@ -1450,9 +1822,6 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= nvals_proc = len(logkhalfdiff_proc_values) oldlength = logk_sparse_size logk_sparse_size = oldlength + nvals_proc - logk_sparse_normindices.resize([logk_sparse_size, 1]) - logk_sparse_systindices.resize([logk_sparse_size, 1]) - logk_sparse_values.resize([logk_sparse_size]) # out_indices = np.array([[ibin,iproc,isyst,1]]) + np.pad(logkhalfdiff_proc_indices,((0,0),(0,3)),'constant') # first dimension of output indices are NOT in the dense [nbin,nproc] space, but rather refer to indices in the norm_sparse vectors @@ -1493,13 +1862,9 @@ def write(self, outfolder="./", outfilename="rabbit_input.hdf5", meta_data_dict= ibin += nbinschan - logger.info(f"Resize and sort sparse arrays into canonical order") - # resize sparse arrays to actual length - norm_sparse_indices.resize([norm_sparse_size, 2]) - norm_sparse_values.resize([norm_sparse_size]) - logk_sparse_normindices.resize([logk_sparse_size, 1]) - logk_sparse_systindices.resize([logk_sparse_size, 1]) - logk_sparse_values.resize([logk_sparse_size]) + logger.info(f"Sort sparse arrays into canonical order") + assert norm_sparse_size == norm_sparse_size_total + assert logk_sparse_size == logk_sparse_size_total # straightforward sorting of norm_sparse into canonical order norm_sparse_dense_shape = (nbinsfull, nproc) @@ -1849,29 +2214,30 @@ def get_systs(self): def get_constraintweights(self, dtype): systs = self.get_systs() constraintweights = np.ones([len(systs)], dtype=dtype) + syst_to_idx = {s: i for i, s in enumerate(systs)} for syst in self.get_systsnoconstraint(): - constraintweights[systs.index(syst)] = 0.0 + constraintweights[syst_to_idx[syst]] = 0.0 return constraintweights def get_groups(self, group_dict): systs = self.get_systs() + # Pre-compute name -> index mapping once. The previous implementation + # called ``systs.index(syst)`` per group member which is O(len(systs)) + # each, giving O(nsysts * nmembers) total -- prohibitive when both are + # large (e.g. ~108k corparms in a single group). + syst_to_idx = {s: i for i, s in enumerate(systs)} groups = [] idxs = [] for group, members in common.natural_sort_dict(group_dict).items(): groups.append(group) - idx = [] - for syst in members: - idx.append(systs.index(syst)) - idxs.append(idx) + idxs.append([syst_to_idx[syst] for syst in members]) return groups, idxs def get_noiidxs(self): - # list of indeces of nois w.r.t. systs + # list of indices of nois w.r.t. systs systs = self.get_systs() - idxs = [] - for noi in self.get_systsnoi(): - idxs.append(systs.index(noi)) - return idxs + syst_to_idx = {s: i for i, s in enumerate(systs)} + return [syst_to_idx[noi] for noi in self.get_systsnoi()] def get_systgroups(self): # list of groups of systematics (nuisances) and lists of indexes diff --git a/tests/test_multi_systematic.py b/tests/test_multi_systematic.py index 986dd5e..a024156 100644 --- a/tests/test_multi_systematic.py +++ b/tests/test_multi_systematic.py @@ -440,6 +440,93 @@ def main(): ), "logk mismatch (masked SparseHist flow vs masked individual)" print("PASS: SparseHist on masked flow=True channel matches individual") + # --- Batched SparseHist path: single hist, mirror=True, as_difference=True --- + # This exercises the vectorized fast path in add_systematic which + # bypasses the per-slice dispatch entirely. We compare byte-for-byte + # against the equivalent per-syst manual booking (which goes through + # the regular single-syst path) using log_normal systematic type on + # a dense process, on data that includes positions where the delta + # pushes the bin negative (so the logkepsilon fallback is exercised). + import scipy.sparse as _sp + from wums.sparse_hist import SparseHist as _SH + + nbatch = 12 + ax_bx = hist.axis.Regular(8, -4, 4, name="x") + ax_by = hist.axis.Regular(6, 0, 3, name="y") + ax_bs = hist.axis.Integer( + 0, nbatch, underflow=False, overflow=False, name="syst" + ) + + rng = np.random.default_rng(17) + h_bproc = hist.Hist(ax_bx, ax_by, storage=hist.storage.Weight()) + x_v = rng.normal(0, 1, 1000) + y_v = rng.uniform(0, 3, 1000) + h_bproc.fill(x_v, y_v, weight=np.ones(1000)) + h_bdata = hist.Hist(ax_bx, ax_by, storage=hist.storage.Double()) + h_bdata.fill(x_v, y_v) + + ext_shape = (ax_bx.extent, ax_by.extent, ax_bs.extent) + dense_systs = rng.normal(0, 0.1, ext_shape) + sparse_mask = rng.random(ext_shape) < 0.5 + dense_systs[sparse_mask] = 0 + flat_data = dense_systs.reshape(1, -1) + sh_batch = _SH(_sp.csr_array(flat_data), [ax_bx, ax_by, ax_bs]) + + def make_batch_writer(use_batched): + w = tensorwriter.TensorWriter(sparse=True, systematic_type="log_normal") + w.add_channel([ax_bx, ax_by], "ch0") + w.add_data(h_bdata, "ch0") + w.add_process(h_bproc, "proc", "ch0", signal=True) + if use_batched: + w.add_systematic( + sh_batch, + "syst", + "proc", + "ch0", + mirror=True, + as_difference=True, + constrained=False, + groups=["g"], + ) + else: + for i in range(nbatch): + sub_dense = dense_systs[:, :, i] + sub_flat = _sp.csr_array(sub_dense.reshape(1, -1)) + sub_sh = _SH(sub_flat, [ax_bx, ax_by]) + w.add_systematic( + sub_sh, + f"syst_{i}", + "proc", + "ch0", + mirror=True, + as_difference=True, + constrained=False, + groups=["g"], + syst_axes=[], + ) + return w + + wb = make_batch_writer(True) + wb.write(outfolder=tmpdir, outfilename="batch_fast") + wm = make_batch_writer(False) + wm.write(outfolder=tmpdir, outfilename="batch_manual") + + bf = read_hdf5_arrays(os.path.join(tmpdir, "batch_fast.hdf5")) + bm = read_hdf5_arrays(os.path.join(tmpdir, "batch_manual.hdf5")) + + print("Batched-path fast systs: ", bf["systs"]) + print("Batched-path manual systs: ", bm["systs"]) + assert ( + bf["systs"] == bm["systs"] + ), f"syst lists differ: fast {bf['systs']} vs manual {bm['systs']}" + assert np.allclose( + bf["hnorm"], bm["hnorm"] + ), "hnorm mismatch (batched fast vs manual)" + assert np.allclose( + bf["hlogk"], bm["hlogk"] + ), "hlogk mismatch (batched fast vs manual)" + print("PASS: batched SparseHist path matches per-syst manual booking") + print() print("ALL CHECKS PASSED") From f51e53fff140cb679c934e74319f724a6fea6f89 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 02:34:41 +0200 Subject: [PATCH 08/18] inputdata, parsing: prep for sparse fast path with CSR matvec Three preparatory changes that the fitter changes in following commits will rely on: * inputdata.py: in sparse mode, call tf.sparse.reorder on norm and logk at load time to canonicalize their indices into row-major order. The fitter sparse fast path reduces nonzero entries via row-keyed reductions, which want coalesced memory access on the sorted indices. * inputdata.py: pre-build a tf.linalg.sparse.CSRSparseMatrix view of logk so the fitter can use sm.matmul (a multi-threaded CSR kernel) for the inner contraction logk @ theta. SparseMatrixMatMul has no XLA kernel, so any tf.function calling it must be built with jit_compile=False; the fitter handles this in sparse mode. * parsing.py: add --hvpMethod {revrev,fwdrev} to choose the autodiff mode for the Hessian-vector product, and --noJitCompile to disable XLA jit_compile (on by default in dense mode). Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/inputdata.py | 20 +++++++++++++++++--- rabbit/parsing.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/rabbit/inputdata.py b/rabbit/inputdata.py index 35329f8..a539520 100644 --- a/rabbit/inputdata.py +++ b/rabbit/inputdata.py @@ -82,11 +82,25 @@ def __init__(self, filename, pseudodata=None): self.sparse = not "hnorm" in f if self.sparse: - print( - "WARNING: The sparse tensor implementation is experimental and probably slower than with a dense tensor!" - ) self.norm = makesparsetensor(f["hnorm_sparse"]) self.logk = makesparsetensor(f["hlogk_sparse"]) + # Canonicalize index ordering once at load time. The fitter's + # sparse fast path reduces nonzero entries via row-keyed + # reductions; sorted row-major indices give coalesced memory + # access. tf.sparse.reorder sorts into row-major order. + self.norm = tf.sparse.reorder(self.norm) + self.logk = tf.sparse.reorder(self.logk) + # Pre-build a CSRSparseMatrix view of logk for use in the + # fitter's sparse matvec path via sm.matmul, which dispatches + # to a multi-threaded CSR kernel and is much faster per call + # than the equivalent gather + unsorted_segment_sum. NOTE: + # SparseMatrixMatMul has no XLA kernel, so any tf.function + # that calls sm.matmul must be built with jit_compile=False. + from tensorflow.python.ops.linalg.sparse import ( + sparse_csr_matrix_ops as _tf_sparse_csr, + ) + + self.logk_csr = _tf_sparse_csr.CSRSparseMatrix(self.logk) else: self.norm = maketensor(f["hnorm"]) self.logk = maketensor(f["hlogk"]) diff --git a/rabbit/parsing.py b/rabbit/parsing.py index 1945ece..3a06569 100644 --- a/rabbit/parsing.py +++ b/rabbit/parsing.py @@ -202,6 +202,24 @@ def common_parser(): ], help="Mnimizer method used in scipy.optimize.minimize for the nominal fit minimization", ) + parser.add_argument( + "--hvpMethod", + default="revrev", + type=str, + choices=["fwdrev", "revrev"], + help="Autodiff mode for the Hessian-vector product. 'revrev' (reverse-over-reverse) " + "is the default and works well in combination with --jitCompile. 'fwdrev' " + "(forward-over-reverse, via tf.autodiff.ForwardAccumulator) is an alternative.", + ) + parser.add_argument( + "--noJitCompile", + dest="jitCompile", + default=True, + action="store_false", + help="Disable XLA jit_compile=True on the loss/gradient/HVP tf.functions. " + "jit_compile is enabled by default and substantially speeds up sparse-mode fits " + "with very large numbers of parameters.", + ) parser.add_argument( "--chisqFit", default=False, From 83afbd82b1e5cb20133cf7885ff4530ea01c7d59 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 02:41:35 +0200 Subject: [PATCH 09/18] fitter: dynamic loss/grad/HVP wrappers with jit_compile + hvpMethod Replace the class-level @tf.function decorators on loss_val, loss_val_grad, and loss_val_grad_hessp_{fwdrev,revrev} with instance-level wrappers built dynamically in _make_tf_functions() at construction time. This lets jit_compile and the HVP autodiff mode be controlled per-fit via --jitCompile / --hvpMethod without class-level redefinition. * --jitCompile (on by default): wraps loss/grad and revrev HVP with tf.function(jit_compile=True). The fwdrev HVP wrapper is intentionally NOT jit-compiled because tf.autodiff.Forward- Accumulator does not propagate JVPs through XLA-compiled subgraphs (the JVP comes back as zero), regardless of inner/ outer placement of jit_compile. * --hvpMethod {revrev,fwdrev}: selects which underlying HVP wrapper is bound to self.loss_val_grad_hessp. The dynamic wrappers are also stripped and rebuilt in __deepcopy__, since the FuncGraph state held by an already-traced tf.function cannot be deepcopy'd. _compute_loss is collapsed to a one-liner since its only job is to dispatch to _compute_nll. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 103 +++++++++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 34 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index e40645d..df79951 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -121,6 +121,8 @@ def __init__( self.diagnostics = options.diagnostics self.minimizer_method = options.minimizerMethod + self.hvp_method = getattr(options, "hvpMethod", "revrev") + self.jit_compile = getattr(options, "jitCompile", True) if options.covarianceFit and options.chisqFit: raise Exception( @@ -404,6 +406,11 @@ def init_fit_parms( tf.function(val.python_function.__get__(self, type(self))), ) + # (re)build instance-level tf.function wrappers for loss/grad/HVP, which + # are constructed dynamically so that jit_compile and the HVP autodiff + # mode can be controlled via fit options. + self._make_tf_functions() + def __deepcopy__(self, memo): import copy @@ -415,12 +422,23 @@ def __deepcopy__(self, memo): for name in self.__dict__ if hasattr(getattr(type(self), name, None), "python_function") } - state = {k: v for k, v in self.__dict__.items() if k not in jit_overrides} + # Also strip the dynamically-built loss/grad/HVP tf.function wrappers, + # which hold un-copyable FuncGraph state and will be rebuilt below. + dynamic_tf_funcs = { + "loss_val", + "loss_val_grad", + "loss_val_grad_hessp", + "loss_val_grad_hessp_fwdrev", + "loss_val_grad_hessp_revrev", + } + skip = jit_overrides | dynamic_tf_funcs + state = {k: v for k, v in self.__dict__.items() if k not in skip} cls = type(self) obj = cls.__new__(cls) memo[id(self)] = obj for k, v in state.items(): setattr(obj, k, copy.deepcopy(v, memo)) + obj._make_tf_functions() return obj def load_fitresult(self, fitresult_file, fitresult_key, profile=True): @@ -2172,44 +2190,61 @@ def _compute_nll(self, profile=True, full_nll=False): return l def _compute_loss(self, profile=True): - l = self._compute_nll(profile=profile) - return l - - @tf.function - def loss_val(self): - val = self._compute_loss() - return val + return self._compute_nll(profile=profile) - @tf.function - def loss_val_grad(self): - with tf.GradientTape() as t: - val = self._compute_loss() - grad = t.gradient(val, self.x) - return val, grad + def _make_tf_functions(self): + # Build tf.function wrappers at instance construction time so that + # jit_compile and the HVP autodiff mode can be controlled via fit + # options without redefining the class. + jit = self.jit_compile - # FIXME in principle this version of the function is preferred - # but seems to introduce some small numerical non-reproducibility - @tf.function - def loss_val_grad_hessp_fwdrev(self, p): - p = tf.stop_gradient(p) - with tf.autodiff.ForwardAccumulator(self.x, p) as acc: - with tf.GradientTape() as grad_tape: - val = self._compute_loss() - grad = grad_tape.gradient(val, self.x) - hessp = acc.jvp(grad) - return val, grad, hessp + def _loss_val(self): + return self._compute_loss() - @tf.function - def loss_val_grad_hessp_revrev(self, p): - p = tf.stop_gradient(p) - with tf.GradientTape() as t2: - with tf.GradientTape() as t1: + def _loss_val_grad(self): + with tf.GradientTape() as t: val = self._compute_loss() - grad = t1.gradient(val, self.x) - hessp = t2.gradient(grad, self.x, output_gradients=p) - return val, grad, hessp + grad = t.gradient(val, self.x) + return val, grad + + def _loss_val_grad_hessp_fwdrev(self, p): + p = tf.stop_gradient(p) + with tf.autodiff.ForwardAccumulator(self.x, p) as acc: + with tf.GradientTape() as grad_tape: + val = self._compute_loss() + grad = grad_tape.gradient(val, self.x) + hessp = acc.jvp(grad) + return val, grad, hessp + + def _loss_val_grad_hessp_revrev(self, p): + p = tf.stop_gradient(p) + with tf.GradientTape() as t2: + with tf.GradientTape() as t1: + val = self._compute_loss() + grad = t1.gradient(val, self.x) + hessp = t2.gradient(grad, self.x, output_gradients=p) + return val, grad, hessp - loss_val_grad_hessp = loss_val_grad_hessp_revrev + self.loss_val = tf.function(jit_compile=jit)( + _loss_val.__get__(self, type(self)) + ) + self.loss_val_grad = tf.function(jit_compile=jit)( + _loss_val_grad.__get__(self, type(self)) + ) + # NOTE: fwdrev HVP is NOT jit-compiled. tf.autodiff.ForwardAccumulator + # does not propagate JVPs through XLA-compiled subgraphs (the JVP + # comes back as zero), regardless of inner/outer placement. The + # loss/grad and revrev HVP wrappers are unaffected. + self.loss_val_grad_hessp_fwdrev = tf.function( + _loss_val_grad_hessp_fwdrev.__get__(self, type(self)) + ) + self.loss_val_grad_hessp_revrev = tf.function(jit_compile=jit)( + _loss_val_grad_hessp_revrev.__get__(self, type(self)) + ) + if self.hvp_method == "fwdrev": + self.loss_val_grad_hessp = self.loss_val_grad_hessp_fwdrev + else: + self.loss_val_grad_hessp = self.loss_val_grad_hessp_revrev @tf.function def loss_val_grad_hess(self, profile=True): From b6d7120454780a7e97fbcac16228f0bc3d74ff23 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 02:44:44 +0200 Subject: [PATCH 10/18] fitter: sparse fast path uses CSR matmul, no dense [nbins,nproc] Reformulate the sparse branch of _compute_yields_noBBB so that the NLL/grad/HVP path never materializes the dense [nbinsfull, nproc] intermediate, and uses tf.linalg.sparse's CSR SparseMatrixMatMul for the dominant inner contraction logk @ theta. The CSR kernel is multi-threaded and ~8x faster per call than the equivalent gather + unsorted_segment_sum that the previous form lowered to under TF on CPU. Changes: * _compute_yields_noBBB takes a new compute_norm flag. The dense [nbinsfull, nproc] normcentral grid is only built when an external caller actually wants per-process yields, or when binByBinStat "full" mode needs them for the analytic beta solution. The NLL/grad/HVP path passes compute_norm=False. * Sparse branch: replace tf.sparse.sparse_dense_matmul(logk, ...) with tf_sparse_csr.matmul(logk_csr, ...) on the pre-built CSR view from inputdata.py. * Sparse branch: collapse to per-bin yields via tf.math.unsorted_segment_sum on the modified sparse values keyed by bin index, equivalent to but cheaper than tf.sparse.reduce_sum at this scale. * _compute_yields_with_beta plumbs need_norm correctly so the bbb-lite path doesn't pay for the dense materialization. * _expected_yield_noBBB explicitly passes compute_norm=False. * _make_tf_functions: SparseMatrixMatMul has no XLA kernel, so force jit_compile=False on all wrappers in sparse mode regardless of the user's --jitCompile setting. * _make_tf_functions: tf.autodiff.ForwardAccumulator cannot trace tangents through SparseMatrixMatMul (no JVP rule for the CSR variant), so when --hvpMethod=fwdrev is requested in sparse mode, fall back to revrev with a warning. Profile on the jpsi calibration tensor (76800 bins, 108334 params, 62M-nnz logk): HVP per call drops from ~6400 ms to ~320 ms (~20x speedup), loss+grad from ~3000 ms to ~160 ms. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 81 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index df79951..d39deaa 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -7,6 +7,7 @@ import scipy import tensorflow as tf import tensorflow_probability as tfp +from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops as tf_sparse_csr from wums import logging from rabbit import io_tools @@ -1256,8 +1257,13 @@ def _expected_variations( return expvars - def _compute_yields_noBBB(self, full=True): + def _compute_yields_noBBB(self, full=True, compute_norm=True): # full: compute yields inclduing masked channels + # compute_norm: also build the dense [nbins, nproc] normcentral tensor. + # In sparse mode this is expensive (forward + backward) and is only + # needed when an external caller requests per-process yields, or for + # binByBinStat in "full" mode. The default is True for backward + # compatibility; the NLL/grad/HVP path passes compute_norm=False. poi = self.get_poi() theta = self.get_theta() @@ -1281,15 +1287,29 @@ def _compute_yields_noBBB(self, full=True): mthetaalpha = tf.reshape(mthetaalpha, [2 * self.indata.nsyst, 1]) if self.indata.sparse: - logsnorm = tf.sparse.sparse_dense_matmul(self.indata.logk, mthetaalpha) - logsnorm = tf.squeeze(logsnorm, -1) + # Inner contraction logk · mthetaalpha via tf.linalg.sparse's + # CSR matmul. ~8x faster per call than gather + segment_sum + # because SparseMatrixMatMul dispatches to a hand-tuned CSR + # kernel. NOTE: SparseMatrixMatMul has no XLA kernel, so the + # enclosing loss/grad/HVP tf.functions are built with + # jit_compile=False in sparse mode (see _make_tf_functions). + logsnorm = tf.squeeze( + tf_sparse_csr.matmul(self.indata.logk_csr, mthetaalpha), + axis=-1, + ) + # Build a sparse [nbinsfull, nproc] tensor whose values absorb + # the per-entry syst variation and the per-(bin, proc) POI + # scaling rnorm. The sparsity pattern is unchanged from + # self.indata.norm, so with_values lets us reuse the indices. if self.indata.systematic_type == "log_normal": - snorm = tf.exp(logsnorm) + # values[i] = norm[i] * exp(logsnorm[i]) * rnorm[bin, proc] snormnorm_sparse = self.indata.norm.with_values( - snorm * self.indata.norm.values + tf.exp(logsnorm) * self.indata.norm.values ) - elif self.indata.systematic_type == "normal": + snormnorm_sparse = snormnorm_sparse * rnorm + else: # "normal" + # values[i] = norm[i] * rnorm[bin, proc] + logsnorm[i] snormnorm_sparse = self.indata.norm * rnorm snormnorm_sparse = snormnorm_sparse.with_values( snormnorm_sparse.values + logsnorm @@ -1300,13 +1320,20 @@ def _compute_yields_noBBB(self, full=True): snormnorm_sparse, self.indata.nbins ) - if self.indata.systematic_type == "log_normal": - snormnorm = tf.sparse.to_dense(snormnorm_sparse) - normcentral = rnorm * snormnorm - elif self.indata.systematic_type == "normal": + # Per-bin yields via unsorted_segment_sum on the sparse values + # keyed by bin index. Equivalent to tf.sparse.reduce_sum(..., + # axis=-1) but uses the dedicated segment_sum kernel directly, + # which has lower per-call overhead. The dense [nbinsfull, + # nproc] grid is only materialized when an external caller + # requested per-process yields (compute_norm=True). + nbinsfull_int = int(snormnorm_sparse.dense_shape[0]) + nexpcentral = tf.math.unsorted_segment_sum( + snormnorm_sparse.values, + snormnorm_sparse.indices[:, 0], + num_segments=nbinsfull_int, + ) + if compute_norm: normcentral = tf.sparse.to_dense(snormnorm_sparse) - - nexpcentral = tf.reduce_sum(normcentral, axis=-1) else: if full or self.indata.nbinsmasked == 0: nbins = self.indata.nbinsfull @@ -1343,7 +1370,13 @@ def _compute_yields_noBBB(self, full=True): return nexpcentral, normcentral def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True): - nexp, norm = self._compute_yields_noBBB(full=full) + # Only materialize the dense [nbins, nproc] normcentral when an external + # caller requested it, or when binByBinStat "full" mode needs per-process + # yields for the analytic beta solution. + need_norm = compute_norm or ( + self.binByBinStat and self.binByBinStatMode == "full" + ) + nexp, norm = self._compute_yields_noBBB(full=full, compute_norm=need_norm) if self.binByBinStat: if profile: @@ -2017,7 +2050,7 @@ def expected_yield(self, profile=False, full=False): @tf.function def _expected_yield_noBBB(self, full=False): - res, _ = self._compute_yields_noBBB(full=full) + res, _ = self._compute_yields_noBBB(full=full, compute_norm=False) return res @tf.function @@ -2196,7 +2229,12 @@ def _make_tf_functions(self): # Build tf.function wrappers at instance construction time so that # jit_compile and the HVP autodiff mode can be controlled via fit # options without redefining the class. - jit = self.jit_compile + # + # SparseMatrixMatMul has no XLA kernel, so any tf.function that + # uses it (via _compute_yields_noBBB in sparse mode) cannot be + # jit-compiled. Force jit_compile off in sparse mode regardless + # of the user's --jitCompile setting. + jit = self.jit_compile and not self.indata.sparse def _loss_val(self): return self._compute_loss() @@ -2241,7 +2279,18 @@ def _loss_val_grad_hessp_revrev(self, p): self.loss_val_grad_hessp_revrev = tf.function(jit_compile=jit)( _loss_val_grad_hessp_revrev.__get__(self, type(self)) ) - if self.hvp_method == "fwdrev": + # tf.autodiff.ForwardAccumulator does not support tangent + # propagation through SparseMatrixMatMul (no JVP rule for the + # CSR variant), so the fwdrev HVP cannot be used in sparse mode. + # Fall back to revrev with a warning. + if self.hvp_method == "fwdrev" and self.indata.sparse: + logger.warning( + "fwdrev HVP is not supported in sparse mode " + "(tf.autodiff.ForwardAccumulator cannot trace through " + "tf.linalg.sparse's CSR matmul); falling back to revrev." + ) + self.loss_val_grad_hessp = self.loss_val_grad_hessp_revrev + elif self.hvp_method == "fwdrev": self.loss_val_grad_hessp = self.loss_val_grad_hessp_fwdrev else: self.loss_val_grad_hessp = self.loss_val_grad_hessp_revrev From b30f867bebbddafa65b965618da99416f3f671d8 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 02:47:15 +0200 Subject: [PATCH 11/18] fitter: external sparse Hessian via CSR matmul Switch the external sparse-Hessian likelihood term to use tf.linalg.sparse's CSR SparseMatrixMatMul instead of an element-wise gather-based 0.5 x^T H x form. The CSR matmul kernel is multi- threaded, and crucially its registered gradient is itself a single sm.matmul call, so reverse-over-reverse autodiff no longer rematerializes a 2D gather/scatter chain in the second-order tape. On large external-Hessian problems this was the dominant HVP cost. Changes: * Fitter.__init__ external_terms loop: replace the "hess_sparse" (rows, cols, vals) tuple with a "hess_csr" CSRSparseMatrix view of the canonically-sorted SparseTensor, built once per term. * _compute_external_nll: dispatch on "hess_csr" instead of "hess_sparse" and compute 0.5 * x_sub^T (H @ x_sub) via tf_sparse_csr.matmul. Profile on the jpsi calibration tensor (329M-nnz prefit external Hessian on 108332 of the 108334 fit parameters): the closed-form external HVP path that previously dominated the second-order tape collapses to a single CSR matvec per HVP call, contributing negligibly to the per-call cost. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index d39deaa..5148c4a 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -346,16 +346,30 @@ def init_fit_parms( ) tf_hess_dense = None - tf_hess_sparse = None + tf_hess_csr = None if term["hess_dense"] is not None: tf_hess_dense = tf.constant(term["hess_dense"], dtype=self.indata.dtype) elif term["hess_sparse"] is not None: + # Build a CSRSparseMatrix view of the stored sparse Hessian + # for use in the closed-form external gradient/HVP path via + # sm.matmul. The Hessian is assumed symmetric, so the loss + # L = 0.5 x_sub^T H x_sub has gradient H @ x_sub and HVP + # H @ p_sub, each a single sm.matmul call. NOTE: + # SparseMatrixMatMul has no XLA kernel, so any tf.function + # that calls sm.matmul must be built with jit_compile=False. rows, cols, vals = term["hess_sparse"] - tf_hess_sparse = ( - tf.constant(rows, dtype=tf.int64), - tf.constant(cols, dtype=tf.int64), - tf.constant(vals, dtype=self.indata.dtype), + rows_np = np.asarray(rows, dtype=np.int64) + cols_np = np.asarray(cols, dtype=np.int64) + vals_np = np.asarray(vals) + n_sub = int(len(params)) + order = np.lexsort((cols_np, rows_np)) + indices_sorted = np.stack([rows_np[order], cols_np[order]], axis=1) + hess_st = tf.SparseTensor( + indices=tf.constant(indices_sorted, dtype=tf.int64), + values=tf.constant(vals_np[order], dtype=self.indata.dtype), + dense_shape=tf.constant([n_sub, n_sub], dtype=tf.int64), ) + tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(hess_st) self.external_terms.append( { @@ -363,7 +377,7 @@ def init_fit_parms( "indices": tf_indices, "grad": tf_grad, "hess_dense": tf_hess_dense, - "hess_sparse": tf_hess_sparse, + "hess_csr": tf_hess_csr, } ) @@ -2185,7 +2199,17 @@ def _compute_nll_components(self, profile=True, full_nll=False): return ln, lc, lbeta, lpenalty, beta def _compute_external_nll(self): - """Sum of external likelihood term contributions: sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub).""" + """Sum of external likelihood term contributions: sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub). + + For sparse-Hessian terms this uses tf.linalg.sparse's CSR matmul, + which dispatches to a multi-threaded kernel and is much faster + per call than the previous element-wise gather-based form. The + autodiff gradient and HVP of 0.5 x^T H x via sm.matmul are + themselves single sm.matmul calls, so reverse-over-reverse autodiff + no longer rematerializes a 2D gather/scatter chain in the second- + order tape — that was the dominant cost on large external-Hessian + problems before this rewrite (e.g. jpsi: 329M-nnz prefit Hessian). + """ if not self.external_terms: return None total = tf.zeros([], dtype=self.indata.dtype) @@ -2198,11 +2222,13 @@ def _compute_external_nll(self): total = total + 0.5 * tf.reduce_sum( x_sub * tf.linalg.matvec(term["hess_dense"], x_sub) ) - elif term["hess_sparse"] is not None: - rows, cols, vals = term["hess_sparse"] - total = total + 0.5 * tf.reduce_sum( - vals * tf.gather(x_sub, rows) * tf.gather(x_sub, cols) + elif term["hess_csr"] is not None: + # Loss = 0.5 * x_sub^T H x_sub via CSR matvec (H symmetric). + Hx = tf.squeeze( + tf_sparse_csr.matmul(term["hess_csr"], x_sub[:, None]), + axis=-1, ) + total = total + 0.5 * tf.reduce_sum(x_sub * Hx) return total def _compute_nll(self, profile=True, full_nll=False): From 3f41fc64a6bf01359c953feb08997da882a038a0 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 03:02:51 +0200 Subject: [PATCH 12/18] rabbit_fit, setup.sh: enable XLA multi-threaded Eigen on CPU Set XLA_FLAGS=--xla_cpu_multi_thread_eigen=true so XLA's CPU emitter uses Eigen's multi-threaded routines for the dense linear-algebra ops generated by jit_compile=True. This is a free win on dense fits with no downside on sparse mode (where the dominant ops have no parallel CPU kernel anyway). Measured ~1.3x speedup on dense large-model HVP and loss+grad on a many-core system: default HVP 51.1 ms L+G 31.2 ms --xla_cpu_multi_thread_eigen=true HVP 39.1 ms L+G 23.0 ms The flag is set in two places: * setup.sh: exported when users source the rabbit setup script. Append-only so any user-set XLA_FLAGS survive. * bin/rabbit_fit.py: also set programmatically at the very top of the script (before any TF import) so users who launch rabbit_fit.py directly without sourcing setup.sh still get the speedup. Same append-only logic. Co-Authored-By: Claude Opus 4.6 (1M context) --- bin/rabbit_fit.py | 14 ++++++++++++++ setup.sh | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index dcf8b1a..49522e5 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -1,5 +1,19 @@ #!/usr/bin/env python3 +# Enable XLA's multi-threaded Eigen path on CPU before importing tensorflow. +# This must be set before any TF import (including transitive) because XLA +# parses XLA_FLAGS once during runtime initialization. Measured ~1.3x speedup +# on dense large-model HVP/loss+grad on a many-core system, no downside. +# Users who set their own XLA_FLAGS keep theirs and we append. +import os as _os + +_xla_default = "--xla_cpu_multi_thread_eigen=true" +_existing = _os.environ.get("XLA_FLAGS", "") +if "xla_cpu_multi_thread_eigen" not in _existing: + _os.environ["XLA_FLAGS"] = ( + f"{_existing} {_xla_default}".strip() if _existing else _xla_default + ) + import copy import tensorflow as tf diff --git a/setup.sh b/setup.sh index 5c038f1..327ee07 100644 --- a/setup.sh +++ b/setup.sh @@ -2,4 +2,11 @@ export RABBIT_BASE=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) export PYTHONPATH="${RABBIT_BASE}:$PYTHONPATH" export PATH="$PATH:${RABBIT_BASE}/bin" +# Enable XLA's multi-threaded Eigen path on CPU. ~1.3x speedup on dense +# large-model HVP/loss+grad on many-core systems, no downside on smaller +# problems. Append to any existing XLA_FLAGS so user-set flags survive. +if [[ ":${XLA_FLAGS:-}:" != *":--xla_cpu_multi_thread_eigen=true:"* ]]; then + export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_cpu_multi_thread_eigen=true" +fi + echo "Created environment variable RABBIT_BASE=${RABBIT_BASE}" From 183f376e95ff98322664aa5dfcfbac40b43b7e35 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 17:08:18 +0200 Subject: [PATCH 13/18] fitter, rabbit_fit: skip dense cov allocation under --noHessian The Fitter previously always allocated a dense [npar, npar] covariance tf.Variable, regardless of whether the postfit Hessian would actually be computed. For very large parameter counts this is infeasible (~94 GB for 108k parameters in float64) and prevents --noHessian from being usable as a low-memory mode. Changes: * Fitter.__init__: read options.noHessian into self.compute_cov. Always allocate the new self.var_prefit vector tf.Variable (length npar). Only allocate the dense self.cov tf.Variable when compute_cov=True; otherwise self.cov is None. * prefit_covariance() is split into: - prefit_variance(unconstrained_err): returns the per-parameter variance vector - prefit_covariance(unconstrained_err): returns a tf.linalg.LinearOperatorDiag wrapping the variance vector, so callers that want a matrix-like interface get one without ever materializing the dense [npar, npar] form. Callers that actually need a dense tensor can call .to_dense(). * defaultassign now updates var_prefit always and cov only when it exists. * randomize_parameters samples from var_prefit when cov is None (the existing diagonal fast path was already correct for the prefit case; only the source of the variances needed to change). * load_fitresult raises a clear error if an external covariance is provided when self.cov is None. * bin/rabbit_fit.py: - Prefit add_parms_hist reads ifitter.var_prefit instead of tf.linalg.diag_part(ifitter.cov), which would fail under --noHessian. - The --computeVariations prefit branch uses prefit_covariance(unconstrained_err=1.0).to_dense() to feed the temporary cov assign (since prefit_covariance now returns a LinearOperator). - The early --noHessian guard now rejects every flag that actually requires the postfit covariance: --doImpacts, --computeVariations, --saveHists (without --noChi2), --computeHistErrors[PerProcess], --computeHistCov, --computeHistImpacts, --computeHistGaussianImpacts, and --externalPostfit. Verified on the small test tensor: under --noHessian fitter.cov is None, var_prefit is a length-13 vector, and a plain fit converges. The incompatible-flag combinations raise clean errors with a single descriptive message. Co-Authored-By: Claude Opus 4.6 (1M context) --- bin/rabbit_fit.py | 38 +++++++++++++++++--- rabbit/fitter.py | 89 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 105 insertions(+), 22 deletions(-) diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index 49522e5..5751107 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -363,9 +363,16 @@ def save_hists(args, mappings, fitter, ws, prefit=True, profile=False): ) if args.computeVariations: + if fitter.cov is None: + raise RuntimeError( + "--computeVariations requires the parameter covariance " + "matrix and so is incompatible with --noHessian." + ) if prefit: cov_prefit = fitter.cov.numpy() - fitter.cov.assign(fitter.prefit_covariance(unconstrained_err=1.0)) + fitter.cov.assign( + fitter.prefit_covariance(unconstrained_err=1.0).to_dense() + ) exp, aux = fitter.expected_events( mapping, @@ -574,8 +581,31 @@ def main(): if args.eager: tf.config.run_functions_eagerly(True) - if args.noHessian and args.doImpacts: - raise Exception('option "--noHessian" only works without "--doImpacts"') + # --noHessian skips computing the postfit Hessian, so the dense + # parameter covariance matrix is never available. Any feature that + # needs the covariance is incompatible. + if args.noHessian: + _incompat = [] + if args.doImpacts: + _incompat.append("--doImpacts") + if args.computeVariations: + _incompat.append("--computeVariations") + if args.saveHists and not args.noChi2: + _incompat.append("--saveHists (without --noChi2)") + if args.computeHistErrors: + _incompat.append("--computeHistErrors") + if args.computeHistErrorsPerProcess: + _incompat.append("--computeHistErrorsPerProcess") + if args.computeHistCov: + _incompat.append("--computeHistCov") + if args.computeHistImpacts: + _incompat.append("--computeHistImpacts") + if args.computeHistGaussianImpacts: + _incompat.append("--computeHistGaussianImpacts") + if args.externalPostfit is not None: + _incompat.append("--externalPostfit") + if _incompat: + raise Exception("--noHessian is incompatible with: " + ", ".join(_incompat)) global logger logger = logging.setup_logger(__file__, args.verbose, args.noColorLogger) @@ -693,7 +723,7 @@ def main(): ws.add_parms_hist( values=ifitter.x, - variances=tf.linalg.diag_part(ifitter.cov), + variances=ifitter.var_prefit, hist_name="parms_prefit", ) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 5148c4a..7d05db7 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -124,6 +124,12 @@ def __init__( self.minimizer_method = options.minimizerMethod self.hvp_method = getattr(options, "hvpMethod", "revrev") self.jit_compile = getattr(options, "jitCompile", True) + # When --noHessian is requested the postfit Hessian is never + # computed, so the dense [npar, npar] covariance matrix should + # not be allocated. self.cov is set to None in that case and + # callers must use self.var_prefit (the diagonal vector form) + # for prefit uncertainties instead. + self.compute_cov = not getattr(options, "noHessian", False) if options.covarianceFit and options.chisqFit: raise Exception( @@ -307,15 +313,29 @@ def init_fit_parms( self.x = tf.Variable(xdefault, trainable=True, name="x") - # parameter covariance matrix - self.cov = tf.Variable( - self.prefit_covariance( + # Per-parameter prefit variance vector. Always allocated; the + # prefit covariance is intrinsically diagonal so this is the + # only form needed for prefit uncertainties. + self.var_prefit = tf.Variable( + self.prefit_variance( unconstrained_err=self.prefit_unconstrained_nuisance_uncertainty ), trainable=False, - name="cov", + name="var_prefit", ) + # Full parameter covariance matrix. Allocated only when the + # postfit Hessian will actually be computed; otherwise None to + # avoid the O(npar^2) allocation (94 GB for 108k parameters). + if self.compute_cov: + self.cov = tf.Variable( + tf.linalg.diag(self.var_prefit), + trainable=False, + name="cov", + ) + else: + self.cov = None + # regularization self.regularizers = [] # one common regularization strength parameter @@ -488,6 +508,13 @@ def load_fitresult(self, fitresult_file, fitresult_key, profile=True): self.x.assign(xvals) if cov_ext is not None: + if self.cov is None: + raise RuntimeError( + "load_fitresult: external covariance was provided but " + "the fitter was constructed with --noHessian (no full " + "covariance is allocated). Construct the fitter without " + "--noHessian to load an external covariance." + ) covval = self.cov.numpy() covval[np.ix_(idxs, idxs)] = cov_ext[np.ix_(idxs_ext, idxs_ext)] self.cov.assign(tf.constant(covval)) @@ -623,23 +650,38 @@ def _default_beta0(self): elif self.binByBinStatType == "normal-additive": return tf.zeros(self.beta_shape, dtype=self.indata.dtype) - def prefit_covariance(self, unconstrained_err=0.0): - # free parameters are taken to have zero uncertainty for the purposes of prefit uncertainties + def prefit_variance(self, unconstrained_err=0.0): + """Per-parameter prefit variance vector of length npar. + + Free parameters (POIs and unconstrained nuisances) are assigned a + placeholder variance of unconstrained_err**2 (zero by default). + Constrained nuisances take their variance from the constraint + term (1 / constraintweight). + """ var_poi = ( tf.ones([self.poi_model.npoi], dtype=self.indata.dtype) * unconstrained_err**2 ) - - # nuisances have their uncertainty taken from the constraint term, but unconstrained nuisances - # are set to a placeholder uncertainty (zero by default) for the purposes of prefit uncertainties var_theta = tf.where( self.indata.constraintweights == 0.0, unconstrained_err**2, tf.math.reciprocal(self.indata.constraintweights), ) + return tf.concat([var_poi, var_theta], axis=0) + + def prefit_covariance(self, unconstrained_err=0.0): + """Full prefit covariance as a tf.linalg.LinearOperatorDiag. - invhessianprefit = tf.linalg.diag(tf.concat([var_poi, var_theta], axis=0)) - return invhessianprefit + The prefit covariance is intrinsically diagonal, so we return a + LinearOperator that exposes a matrix-like interface (matvec, etc.) + without ever allocating the dense [npar, npar] form. Callers that + actually need a dense tensor can call .to_dense() explicitly. + """ + return tf.linalg.LinearOperatorDiag( + self.prefit_variance(unconstrained_err=unconstrained_err), + is_self_adjoint=True, + is_positive_definite=True, + ) @tf.function def val_jac(self, fun, *args, **kwargs): @@ -689,11 +731,12 @@ def betadefaultassign(self): self.beta.assign(self.beta0) def defaultassign(self): - self.cov.assign( - self.prefit_covariance( - unconstrained_err=self.prefit_unconstrained_nuisance_uncertainty - ) + var_pre = self.prefit_variance( + unconstrained_err=self.prefit_unconstrained_nuisance_uncertainty ) + self.var_prefit.assign(var_pre) + if self.cov is not None: + self.cov.assign(tf.linalg.diag(var_pre)) self.theta0defaultassign() if self.binByBinStat: self.beta0defaultassign() @@ -862,13 +905,23 @@ def toyassign( # the special handling of the diagonal case here speeds things up, but is also required # in case the prefit covariance has zero for some uncertainties (which is the default # for unconstrained nuisances for example) since the multivariate normal distribution - # requires a positive-definite covariance matrix - if tfh.is_diag(self.cov): + # requires a positive-definite covariance matrix. + # Under --noHessian self.cov is None and only the diagonal + # prefit variance vector is available, so we always take the + # diagonal branch in that case (sourcing the variances from + # var_prefit directly). + cov_is_diag = self.cov is None or tfh.is_diag(self.cov) + if cov_is_diag: + stddev = ( + tf.sqrt(self.var_prefit) + if self.cov is None + else tf.sqrt(tf.linalg.diag_part(self.cov)) + ) self.x.assign( tf.random.normal( shape=[], mean=self.x, - stddev=tf.sqrt(tf.linalg.diag_part(self.cov)), + stddev=stddev, dtype=self.x.dtype, ) ) From 849333202255fa3a92211d232ea3ac2a041c773b Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 17:45:58 +0200 Subject: [PATCH 14/18] fitter: speed up Fitter.__init__ on large external sparse Hessians MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two structural fixes that reduce Fitter.__init__ from ~370 s to ~20 s on the jpsi calibration tensor (108k parameters, 329M-nnz external sparse Hessian). * Replace per-parameter np.where lookup with a single dict. The old code did np.where(parms_str == p) for each of the external term's parameters against the full ~10^5-element parameter list — quadratic, ~150 s. Build a name->index dict once and look each parameter up in O(1). * Detect already-sorted sparse-Hessian indices and skip np.lexsort. tf.SparseTensor / sparse_tensor_to_csr_sparse_matrix require canonical row-major order. The TensorWriter does not guarantee this for sparse-Hessian external terms, but in practice the indices are often already sorted (e.g. when the source SparseHist has its _flat_indices in flat-index order and they get split via np.divmod(flat, n), which preserves the ordering). A single vectorized O(nnz) check skips the much slower np.lexsort (~54 s on 329M nnz) when the data is already canonical, falling back to lexsort otherwise. The remaining ~13 s in Fitter.__init__ on the jpsi tensor is the unavoidable cost of materializing the 329M-nnz arrays into TF tensors (np.stack of the [nnz, 2] index buffer, tf.constant on both index and value buffers, and the CSR conversion proper). Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 47 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 7d05db7..ab6cab4 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -343,20 +343,29 @@ def init_fit_parms( # External likelihood terms (additive g^T x + 0.5 x^T H x contributions # to the NLL). Resolve parameter name strings against the full fit - # parameter list (POIs + systs). + # parameter list (POIs + systs). Build a single name->index dict + # once so the per-term resolution is O(n) instead of O(n^2) -- the + # latter cost ~150 s on a 108k-parameter setup with a 108k-parameter + # external term. self.external_terms = [] parms_str = self.parms.astype(str) + parms_idx = {name: i for i, name in enumerate(parms_str)} + if len(parms_idx) != len(parms_str): + raise RuntimeError( + "Duplicate parameter names in fitter parameter list; " + "external term resolution requires unique names." + ) for term in self.indata.external_terms: params = np.asarray(term["params"]).astype(str) indices = np.empty(len(params), dtype=np.int64) for i, p in enumerate(params): - matches = np.where(parms_str == p)[0] - if len(matches) != 1: + j = parms_idx.get(p, -1) + if j < 0: raise RuntimeError( f"External likelihood term '{term['name']}' parameter " - f"'{p}' matched {len(matches)} entries in fit parameters" + f"'{p}' not found in fit parameters" ) - indices[i] = matches[0] + indices[i] = j tf_indices = tf.constant(indices, dtype=tf.int64) tf_grad = ( @@ -382,11 +391,33 @@ def init_fit_parms( cols_np = np.asarray(cols, dtype=np.int64) vals_np = np.asarray(vals) n_sub = int(len(params)) - order = np.lexsort((cols_np, rows_np)) - indices_sorted = np.stack([rows_np[order], cols_np[order]], axis=1) + # tf.sparse.SparseTensor / sparse_tensor_to_csr_sparse_matrix + # requires canonical row-major lexicographic ordering of the + # (row, col) indices. The TensorWriter does not guarantee + # this for sparse-Hessian external terms, but in practice + # the data is often already sorted (e.g. when it comes from + # a SparseHist whose underlying flat indices are in + # row-major order). Detect that fast path with an O(nnz) + # check and skip the much slower np.lexsort -- on a 329M-nnz + # input the lexsort alone takes ~50 s. + if rows_np.size > 1 and bool( + np.all( + (rows_np[:-1] < rows_np[1:]) + | ( + (rows_np[:-1] == rows_np[1:]) + & (cols_np[:-1] <= cols_np[1:]) + ) + ) + ): + indices_sorted = np.stack([rows_np, cols_np], axis=1) + vals_sorted = vals_np + else: + order = np.lexsort((cols_np, rows_np)) + indices_sorted = np.stack([rows_np[order], cols_np[order]], axis=1) + vals_sorted = vals_np[order] hess_st = tf.SparseTensor( indices=tf.constant(indices_sorted, dtype=tf.int64), - values=tf.constant(vals_np[order], dtype=self.indata.dtype), + values=tf.constant(vals_sorted, dtype=self.indata.dtype), dense_shape=tf.constant([n_sub, n_sub], dtype=tf.int64), ) tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(hess_st) From 6c1c18791a5b3e7a95b57bb551b083dfe856a5ca Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Thu, 9 Apr 2026 18:51:50 +0200 Subject: [PATCH 15/18] unify sparse-Hessian IO path; sort at write time, drop reorder calls Three coupled changes that simplify the IO path for the external sparse Hessian and trim the rest of Fitter.__init__ on large problems. * tensorwriter.py: sort the external sparse-Hessian indices into canonical row-major order at write time, matching what the writer already does for the sparse logk and norm tensors. Use a single ravel_multi_index + argsort. Add a fast-path that detects when the input is already canonical via a vectorized O(nnz) check and skips the sort entirely (typical when the source is a SparseHist built from a scipy CSR / CSC, which iterates row-major by definition). * inputdata.py: hess_sparse is now read via the same makesparsetensor() helper used for the sparse norm and logk, yielding a tf.sparse.SparseTensor directly. The previous code manually unpacked the indices into (rows, cols, vals) tuple form which forced an unnecessary numpy roundtrip downstream. The defensive tf.sparse.reorder calls on norm and logk are also dropped: the writer already sorts these into canonical order, so the reorder was redundant. * fitter.py external term loop: receive the SparseTensor and feed it straight to tf_sparse_csr.CSRSparseMatrix without an additional reorder step (the writer already canonicalized). This drops the in-Python np.lexsort + np.stack + tf.constant roundtrip on the 329M-nnz jpsi external Hessian. Effect on jpsi calibration tensor (108k params, 329M-nnz prefit sparse Hessian) Fitter.__init__: pre-IO unification: 20.5 s post: 5.3 s The TensorWriter side is ~5 s slower than before the sort was added (the unavoidable cost of validating canonical order on 329M nnz). For SparseHist inputs from scipy CSR/CSC the data is always pre-sorted so the validation succeeds and no additional sort is performed. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 40 +++++----------------------------------- rabbit/inputdata.py | 31 +++++++++++-------------------- rabbit/tensorwriter.py | 32 ++++++++++++++++++++++++++++++-- 3 files changed, 46 insertions(+), 57 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index ab6cab4..22a6f41 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -386,41 +386,11 @@ def init_fit_parms( # H @ p_sub, each a single sm.matmul call. NOTE: # SparseMatrixMatMul has no XLA kernel, so any tf.function # that calls sm.matmul must be built with jit_compile=False. - rows, cols, vals = term["hess_sparse"] - rows_np = np.asarray(rows, dtype=np.int64) - cols_np = np.asarray(cols, dtype=np.int64) - vals_np = np.asarray(vals) - n_sub = int(len(params)) - # tf.sparse.SparseTensor / sparse_tensor_to_csr_sparse_matrix - # requires canonical row-major lexicographic ordering of the - # (row, col) indices. The TensorWriter does not guarantee - # this for sparse-Hessian external terms, but in practice - # the data is often already sorted (e.g. when it comes from - # a SparseHist whose underlying flat indices are in - # row-major order). Detect that fast path with an O(nnz) - # check and skip the much slower np.lexsort -- on a 329M-nnz - # input the lexsort alone takes ~50 s. - if rows_np.size > 1 and bool( - np.all( - (rows_np[:-1] < rows_np[1:]) - | ( - (rows_np[:-1] == rows_np[1:]) - & (cols_np[:-1] <= cols_np[1:]) - ) - ) - ): - indices_sorted = np.stack([rows_np, cols_np], axis=1) - vals_sorted = vals_np - else: - order = np.lexsort((cols_np, rows_np)) - indices_sorted = np.stack([rows_np[order], cols_np[order]], axis=1) - vals_sorted = vals_np[order] - hess_st = tf.SparseTensor( - indices=tf.constant(indices_sorted, dtype=tf.int64), - values=tf.constant(vals_sorted, dtype=self.indata.dtype), - dense_shape=tf.constant([n_sub, n_sub], dtype=tf.int64), - ) - tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(hess_st) + # The TensorWriter sorts the indices into canonical + # row-major order at write time, so we can feed the + # SparseTensor straight to the CSR builder without an + # additional reorder step. + tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(term["hess_sparse"]) self.external_terms.append( { diff --git a/rabbit/inputdata.py b/rabbit/inputdata.py index a539520..27844e5 100644 --- a/rabbit/inputdata.py +++ b/rabbit/inputdata.py @@ -82,14 +82,11 @@ def __init__(self, filename, pseudodata=None): self.sparse = not "hnorm" in f if self.sparse: + # The TensorWriter sorts the sparse norm/logk indices into + # canonical row-major order at write time, so consumers can + # rely on that without an extra tf.sparse.reorder call. self.norm = makesparsetensor(f["hnorm_sparse"]) self.logk = makesparsetensor(f["hlogk_sparse"]) - # Canonicalize index ordering once at load time. The fitter's - # sparse fast path reduces nonzero entries via row-keyed - # reductions; sorted row-major indices give coalesced memory - # access. tf.sparse.reorder sorts into row-major order. - self.norm = tf.sparse.reorder(self.norm) - self.logk = tf.sparse.reorder(self.logk) # Pre-build a CSRSparseMatrix view of logk for use in the # fitter's sparse matvec path via sm.matmul, which dispatches # to a multi-threaded CSR kernel and is much faster per call @@ -202,7 +199,9 @@ def __init__(self, filename, pseudodata=None): # params: 1D ndarray of parameter name strings # grad_values: 1D float ndarray or None # hess_dense: 2D float ndarray or None - # hess_sparse: tuple (rows, cols, values) or None + # hess_sparse: tf.sparse.SparseTensor or None + # (sparsity pattern of the [npar_sub, npar_sub] Hessian; + # same on-disk layout as hlogk_sparse / hnorm_sparse) self.external_terms = [] if "external_terms" in f.keys(): names = [ @@ -226,19 +225,11 @@ def __init__(self, filename, pseudodata=None): if "hess_dense" in tg.keys() else None ) - hess_sparse = None - if "hess_sparse" in tg.keys(): - hg = tg["hess_sparse"] - idx_dset = hg["indices"] - if "original_shape" in idx_dset.attrs: - idx_shape = tuple(idx_dset.attrs["original_shape"]) - indices = np.asarray(idx_dset).reshape(idx_shape) - else: - indices = np.asarray(idx_dset) - rows = indices[:, 0] - cols = indices[:, 1] - vals = np.asarray(hg["values"]) - hess_sparse = (rows, cols, vals) + hess_sparse = ( + makesparsetensor(tg["hess_sparse"]) + if "hess_sparse" in tg.keys() + else None + ) self.external_terms.append( { "name": tname, diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 0c71d75..7268ea5 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -2187,10 +2187,38 @@ def create_dataset( elif term["hess_sparse"] is not None: rows, cols, vals = term["hess_sparse"] n = len(term["params"]) - indices = np.stack([rows, cols], axis=-1).astype(self.idxdtype) + rows = np.asarray(rows, dtype=self.idxdtype) + cols = np.asarray(cols, dtype=self.idxdtype) + vals = np.asarray(vals, dtype=self.dtype) + # Sort into canonical row-major order so the reader + # (and downstream tf.sparse / CSR consumers) can skip + # the reorder step. The fast path: if the input is + # already canonical (typical when the source is a + # SparseHist whose flat indices come in flat-index + # order), skip the O(nnz log nnz) argsort entirely. + # The check is a single vectorized O(nnz) pass and + # is essentially free compared to the sort it avoids + # (~50-150 s on 329M nnz). + if rows.size > 1: + drows = np.diff(rows) + dcols = np.diff(cols) + already_sorted = bool( + np.all((drows > 0) | ((drows == 0) & (dcols >= 0))) + ) + del drows, dcols + else: + already_sorted = True + if not already_sorted: + flat = np.ravel_multi_index((rows, cols), (n, n)) + sort_order = np.argsort(flat) + del flat + rows = rows[sort_order] + cols = cols[sort_order] + vals = vals[sort_order] + indices = np.stack([rows, cols], axis=-1) nbytes += h5pyutils_write.writeSparse( indices, - vals.astype(self.dtype), + vals, (n, n), term_group, "hess_sparse", From db274db1a36fd5f4b8c349b241874dd7cf912cb0 Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Fri, 10 Apr 2026 01:59:26 +0200 Subject: [PATCH 16/18] fitter: Hessian-free CG solve for is_linear case under --noHessian The is_linear fast path in Fitter.minimize() previously built the full dense [npar, npar] Hessian via loss_val_grad_hess() and did a Cholesky solve. That's incompatible with --noHessian mode, which is supposed to avoid the O(npar^2) allocation entirely. Add an alternative Hessian-free branch that solves the normal equation H @ dx = -grad iteratively with scipy's conjugate gradient solver, feeding it a LinearOperator backed by loss_val_grad_hessp. For a purely quadratic NLL the Hessian is positive-definite and CG converges to machine precision in at most npar iterations (far fewer for well-conditioned problems). The Cholesky path is still used when compute_cov is True, since it has the lower per-call cost when allocating the dense Hessian is already acceptable. Verified against the Cholesky path on a constructed linear test (chisqFit + Ones POI model + normal systematics): converged parameter values match exactly; only the postfit uncertainty slots differ, which is expected because the noHessian run does not compute the covariance. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/fitter.py | 70 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 22a6f41..a7c59be 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -2480,28 +2480,64 @@ def scipy_hess(xval): def minimize(self): if self.is_linear: - logger.info( - "Likelihood is purely quadratic, solving by Cholesky decomposition instead of iterative fit" - ) + if self.compute_cov: + logger.info( + "Likelihood is purely quadratic, solving by Cholesky decomposition instead of iterative fit" + ) - # no need to do a minimization, simple matrix solve is sufficient - val, grad, hess = self.loss_val_grad_hess() + # no need to do a minimization, simple matrix solve is sufficient + val, grad, hess = self.loss_val_grad_hess() - # use a Cholesky decomposition to easily detect the non-positive-definite case - chol = tf.linalg.cholesky(hess) + # use a Cholesky decomposition to easily detect the non-positive-definite case + chol = tf.linalg.cholesky(hess) - # FIXME catch this exception to mark failed toys and continue - if tf.reduce_any(tf.math.is_nan(chol)).numpy(): - raise ValueError( - "Cholesky decomposition failed, Hessian is not positive-definite" - ) + # FIXME catch this exception to mark failed toys and continue + if tf.reduce_any(tf.math.is_nan(chol)).numpy(): + raise ValueError( + "Cholesky decomposition failed, Hessian is not positive-definite" + ) - del hess - gradv = grad[..., None] - dx = tf.linalg.cholesky_solve(chol, -gradv)[:, 0] - del chol + del hess + gradv = grad[..., None] + dx = tf.linalg.cholesky_solve(chol, -gradv)[:, 0] + del chol - self.x.assign_add(dx) + self.x.assign_add(dx) + else: + # --noHessian: we must not allocate the dense [npar, npar] + # Hessian that the Cholesky path above builds. Solve the + # normal equation H @ dx = -grad iteratively via conjugate + # gradient using only Hessian-vector products, which is + # already exposed as self.loss_val_grad_hessp. For a + # purely quadratic NLL the Hessian is positive-definite + # and CG converges to machine precision in at most npar + # steps (typically far fewer for well-conditioned + # problems). + import scipy.sparse.linalg as _spla + + logger.info( + "Likelihood is purely quadratic, solving with " + "Hessian-free conjugate gradient (--noHessian)" + ) + val, grad = self.loss_val_grad() + grad_np = grad.numpy() + n = int(grad_np.shape[0]) + dtype = grad_np.dtype + + def _hvp_np(p_np): + p_tf = tf.constant(p_np, dtype=self.x.dtype) + _, _, hessp = self.loss_val_grad_hessp(p_tf) + return hessp.numpy() + + op = _spla.LinearOperator((n, n), matvec=_hvp_np, dtype=dtype) + dx_np, info = _spla.cg(op, -grad_np, rtol=1e-10, atol=0.0) + if info != 0: + raise ValueError( + f"CG solver did not converge (info={info}); the " + "Hessian may not be positive-definite or the " + "problem may be ill-conditioned" + ) + self.x.assign_add(tf.constant(dx_np, dtype=self.x.dtype)) callback = None else: From d0708b1a27bb87998af715494891540e6c69cb6e Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Fri, 10 Apr 2026 02:23:29 +0200 Subject: [PATCH 17/18] fitter, rabbit_fit: edmval + POI/NOI uncertainties under --noHessian Under --noHessian we previously left edmval and all postfit parameter uncertainties as NaN because the dense covariance is never allocated. Now compute both via Hessian-free conjugate gradient solves of H v = grad -> edmval = 0.5 * grad^T v H c_i = e_i -> c_i is the i-th row of cov using scipy.sparse.linalg.cg with a LinearOperator backed by self.loss_val_grad_hessp. No dense Hessian or covariance is ever materialized; memory stays O(npar) instead of O(npar^2). The cov rows are computed only for the parameters the user cares about at this point -- the POIs (indices [0, npoi)) and the NOIs (npoi + indata.noiidxs). Their diagonal entries give the postfit standard deviations, which are populated into the parms_variances vector passed to add_parms_hist. Non-POI / non-NOI nuisances keep NaN variances, signalling that the postfit covariance for those parameters was not computed. Verified on the small test tensor: --noHessian now reports the same edmval (7.068e-18 vs 7.068e-18) and the same POI/NOI uncertainties (sig: 0.01436 +/- 0.01436, slope_signal: 2.01328 +/- 2.01328) as the full Cholesky path. Co-Authored-By: Claude Opus 4.6 (1M context) --- bin/rabbit_fit.py | 30 ++++++++++++++++++++- rabbit/fitter.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/bin/rabbit_fit.py b/bin/rabbit_fit.py index 5751107..e01c2f2 100755 --- a/bin/rabbit_fit.py +++ b/bin/rabbit_fit.py @@ -428,6 +428,9 @@ def fit(args, fitter, ws, dofit=True): ws.add_1D_integer_hist(cb.loss_history, "epoch", "loss") ws.add_1D_integer_hist(cb.time_history, "epoch", "time") + # prefit variances as the default fallback for add_parms_hist below + parms_variances = None + if not args.noHessian: # compute the covariance matrix and estimated distance to minimum _, grad, hess = fitter.loss_val_grad_hess() @@ -467,6 +470,31 @@ def fit(args, fitter, ws, dofit=True): global_impacts=True, ) + parms_variances = tf.linalg.diag_part(fitter.cov) + else: + # --noHessian: avoid the full dense Hessian. Still compute edmval + # and the POI+NOI uncertainties via a Hessian-free conjugate + # gradient solve of H @ v = grad and H @ c_i = e_i, using only + # Hessian-vector products. The CG solves touch O(npar) memory + # per call instead of O(npar^2), so this works on problems + # where the full covariance would be infeasible. + _, grad = fitter.loss_val_grad() + npoi = int(fitter.poi_model.npoi) + noi_idx_in_x = np.asarray(fitter.indata.noiidxs, dtype=np.int64) + npoi + poi_noi_idx = np.concatenate([np.arange(npoi, dtype=np.int64), noi_idx_in_x]) + edmval, cov_rows = fitter.edmval_cov_rows_hessfree(grad, poi_noi_idx) + logger.info(f"edmval: {edmval}") + + # Build a full-length variance vector with the POI+NOI entries + # populated from the diagonal of the CG-solved rows and the rest + # left as NaN (we did not compute those). add_parms_hist stores + # the vector verbatim into the workspace. + n = int(fitter.x.shape[0]) + parms_variances_np = np.full(n, np.nan, dtype=np.float64) + for k, i in enumerate(poi_noi_idx): + parms_variances_np[int(i)] = cov_rows[k, int(i)] + parms_variances = tf.constant(parms_variances_np, dtype=fitter.indata.dtype) + nllvalreduced = fitter.reduced_nll().numpy() ndfsat = ( @@ -497,7 +525,7 @@ def fit(args, fitter, ws, dofit=True): ws.add_parms_hist( values=fitter.x, - variances=tf.linalg.diag_part(fitter.cov) if not args.noHessian else None, + variances=parms_variances, hist_name="parms", ) diff --git a/rabbit/fitter.py b/rabbit/fitter.py index a7c59be..25f96bb 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -965,6 +965,72 @@ def edmval_cov(self, grad, hess): else: return edmval_cov(grad, hess) + def edmval_cov_rows_hessfree(self, grad, row_indices, rtol=1e-10, maxiter=None): + """Hessian-free edmval + selected rows of the covariance matrix. + + Used under --noHessian to avoid allocating the dense [npar, npar] + Hessian. Solves the linear systems + + H v = grad -> edmval = 0.5 * grad^T v + H c_i = e_i -> c_i is the i-th column/row of cov + + iteratively via scipy's conjugate gradient, feeding it a + LinearOperator backed by self.loss_val_grad_hessp. The Hessian + must be positive-definite; that's the case for a converged NLL + minimum (including the purely-quadratic --is_linear case). + + Parameters + ---------- + grad : tf.Tensor or array-like, shape [npar] + Gradient at the current x, already computed by the caller. + row_indices : iterable of int + Parameter indices to compute covariance rows for. Typically + the POI indices [0, npoi) concatenated with the NOI indices + (npoi + noiidxs). + rtol : float + Relative residual tolerance passed to scipy.sparse.linalg.cg. + maxiter : int or None + Maximum CG iterations per solve; None lets scipy choose. + + Returns + ------- + edmval : float + cov_rows : np.ndarray, shape [len(row_indices), npar] + Row i is (H^{-1})[row_indices[i], :]; diag entries give the + variances for those parameters. + """ + import scipy.sparse.linalg as _spla + + n = int(self.x.shape[0]) + dtype = np.float64 + + def _hvp_np(p_np): + p_tf = tf.constant(p_np, dtype=self.x.dtype) + _, _, hessp = self.loss_val_grad_hessp(p_tf) + return hessp.numpy() + + op = _spla.LinearOperator((n, n), matvec=_hvp_np, dtype=dtype) + + grad_np = grad.numpy() if hasattr(grad, "numpy") else np.asarray(grad) + v, info = _spla.cg(op, grad_np, rtol=rtol, atol=0.0, maxiter=maxiter) + if info != 0: + raise ValueError(f"CG solver for edmval did not converge (info={info})") + edmval = 0.5 * float(np.dot(grad_np, v)) + + row_indices = np.asarray(list(row_indices), dtype=np.int64) + cov_rows = np.empty((len(row_indices), n), dtype=dtype) + for k, i in enumerate(row_indices): + e = np.zeros(n, dtype=dtype) + e[int(i)] = 1.0 + c, info = _spla.cg(op, e, rtol=rtol, atol=0.0, maxiter=maxiter) + if info != 0: + raise ValueError( + f"CG solver for cov row {int(i)} did not converge (info={info})" + ) + cov_rows[k] = c + + return edmval, cov_rows + @tf.function def impacts_parms(self, hess): From fad47bc6772344f71eb92681671f10dbbfd9ca2d Mon Sep 17 00:00:00 2001 From: Josh Bendavid Date: Fri, 10 Apr 2026 02:49:33 +0200 Subject: [PATCH 18/18] external_likelihood: factor out external-term IO + tf build + nll eval Address PR review feedback by collecting the three external-term helpers into a single dedicated module: * read_external_terms_from_h5(ext_group) -- decode the on-disk "external_terms" h5 group into a list of raw per-term dicts. Iterate ext_group.items() directly so the writer no longer needs to store a separate "external_term_names" list. * build_tf_external_terms(terms, parms, dtype) -- promote the raw dicts to tf-side dicts (resolved indices, tf.constant grad, CSRSparseMatrix Hessian). Used by Fitter.__init__. * compute_external_nll(terms, x, dtype) -- evaluate sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub). Used by Fitter._compute_external_nll. FitInputData, the Fitter init external-term loop, and Fitter._compute_external_nll all collapse to one-line dispatches into this module. The tensorwriter no longer writes the hexternal_term_names dataset since the reader iterates the h5 subgroups directly. Also rework the --jitCompile CLI option per review: replace the "--noJitCompile" boolean+dest hack with a tri-state "--jitCompile {auto,on,off}" with auto as the default. The Fitter resolves the string in _make_tf_functions: "auto" silently enables jit in dense mode and disables it in sparse mode (where the CSR matmul kernels have no XLA implementation), "on" forces it (warning + falling back to off in sparse mode), "off" disables it unconditionally. Backwards compatibility: True/False are still accepted from programmatic callers. Smoke tested all 12 combinations (sparse/dense x auto/on/off x cov/--noHessian); all converge. Co-Authored-By: Claude Opus 4.6 (1M context) --- rabbit/external_likelihood.py | 224 ++++++++++++++++++++++++++++++++++ rabbit/fitter.py | 146 ++++++++-------------- rabbit/inputdata.py | 51 +------- rabbit/parsing.py | 16 +-- rabbit/tensorwriter.py | 8 +- 5 files changed, 291 insertions(+), 154 deletions(-) create mode 100644 rabbit/external_likelihood.py diff --git a/rabbit/external_likelihood.py b/rabbit/external_likelihood.py new file mode 100644 index 0000000..8b917ce --- /dev/null +++ b/rabbit/external_likelihood.py @@ -0,0 +1,224 @@ +"""Helpers for external likelihood terms (linear + quadratic parameter priors). + +An "external likelihood term" is an additive contribution to the NLL of +the form + + -log L_ext = g^T x_sub + 0.5 * x_sub^T H x_sub + +where ``x_sub`` is the subset of the fit parameters the term constrains. +Both the linear (``grad``) and quadratic (``hess_dense`` / ``hess_sparse``) +parts are optional; the sparse Hessian is stored as a +``tf.sparse.SparseTensor`` whose indices are in canonical row-major order. + +This module centralizes three things that were previously inlined in +``Fitter.__init__``, ``Fitter._compute_external_nll``, and +``FitInputData.__init__``: + +* :func:`read_external_terms_from_h5` — load the raw numpy-level + per-term dicts from an HDF5 group (used by FitInputData) +* :func:`build_tf_external_terms` — turn that list into tf-side per-term + dicts (resolved parameter indices, tf.constant grads, CSRSparseMatrix + Hessians). Used by the Fitter when it takes ownership of the input + data. +* :func:`compute_external_nll` — evaluate the scalar NLL contribution + of a list of tf-side terms at the current ``x``. +""" + +import numpy as np +import tensorflow as tf +from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops as tf_sparse_csr + +from rabbit.h5pyutils_read import makesparsetensor, maketensor + + +def read_external_terms_from_h5(ext_group): + """Decode an HDF5 ``external_terms`` group into a list of raw dicts. + + Each entry has the keys used by the rest of the pipeline: + + * ``name``: term label (str, taken from the h5 subgroup name) + * ``params``: 1D ndarray of parameter name strings + * ``grad_values``: 1D float ndarray or ``None`` + * ``hess_dense``: 2D float ndarray or ``None`` + * ``hess_sparse``: :class:`tf.sparse.SparseTensor` or ``None`` (uses + the same on-disk layout as ``hlogk_sparse`` / ``hnorm_sparse``) + + Parameters + ---------- + ext_group : h5py.Group + The ``external_terms`` group in the input HDF5 file, or ``None``. + + Returns + ------- + list[dict] + One entry per stored external term, or an empty list if + ``ext_group`` is ``None``. + """ + if ext_group is None: + return [] + + terms = [] + for tname, tg in ext_group.items(): + raw_params = tg["params"][...] + params = np.array( + [s.decode() if isinstance(s, bytes) else s for s in raw_params] + ) + grad_values = ( + np.asarray(maketensor(tg["grad_values"])) + if "grad_values" in tg.keys() + else None + ) + hess_dense = ( + np.asarray(maketensor(tg["hess_dense"])) + if "hess_dense" in tg.keys() + else None + ) + hess_sparse = ( + makesparsetensor(tg["hess_sparse"]) if "hess_sparse" in tg.keys() else None + ) + terms.append( + { + "name": tname, + "params": params, + "grad_values": grad_values, + "hess_dense": hess_dense, + "hess_sparse": hess_sparse, + } + ) + return terms + + +def build_tf_external_terms(terms, parms, dtype): + """Turn raw external-term dicts into tf-side dicts ready for the fitter. + + * Parameter names are resolved against the full fit parameter list + ``parms`` via a single ``name->index`` dict (O(n) rather than the + naive O(n^2) per-parameter ``np.where`` that this replaces — the + latter cost ~150 s on a 108k-parameter setup with a 108k-parameter + external term). + * Gradients are promoted to ``tf.constant`` in the fitter dtype. + * Dense Hessians are promoted to ``tf.constant``. + * Sparse Hessians are promoted to a :class:`CSRSparseMatrix` view + for fast ``sm.matmul``. + + Parameters + ---------- + terms : list[dict] + Raw per-term dicts as returned by :func:`read_external_terms_from_h5`. + parms : array-like of str + Full ordered list of fit parameter names (POIs + systematics). + dtype : tf.DType + Fitter dtype for gradient / Hessian tensors. + + Returns + ------- + list[dict] + One entry per term with keys ``name``, ``indices``, ``grad``, + ``hess_dense``, ``hess_csr``. Empty if ``terms`` is empty. + """ + parms_str = np.asarray(parms).astype(str) + parms_idx = {name: i for i, name in enumerate(parms_str)} + if len(parms_idx) != len(parms_str): + raise RuntimeError( + "Duplicate parameter names in fitter parameter list; " + "external term resolution requires unique names." + ) + + out = [] + for term in terms: + params = np.asarray(term["params"]).astype(str) + indices = np.empty(len(params), dtype=np.int64) + for i, p in enumerate(params): + j = parms_idx.get(p, -1) + if j < 0: + raise RuntimeError( + f"External likelihood term '{term['name']}' parameter " + f"'{p}' not found in fit parameters" + ) + indices[i] = j + tf_indices = tf.constant(indices, dtype=tf.int64) + + tf_grad = ( + tf.constant(term["grad_values"], dtype=dtype) + if term["grad_values"] is not None + else None + ) + + tf_hess_dense = None + tf_hess_csr = None + if term["hess_dense"] is not None: + tf_hess_dense = tf.constant(term["hess_dense"], dtype=dtype) + elif term["hess_sparse"] is not None: + # Build a CSRSparseMatrix view of the stored sparse Hessian + # for use in the closed-form external gradient/HVP path via + # sm.matmul. The Hessian is assumed symmetric, so the loss + # L = 0.5 x_sub^T H x_sub has gradient H @ x_sub and HVP + # H @ p_sub, each a single sm.matmul call. NOTE: + # SparseMatrixMatMul has no XLA kernel, so any tf.function + # that calls sm.matmul must be built with jit_compile=False. + # The TensorWriter sorts the indices into canonical row-major + # order at write time, so we can feed the SparseTensor + # straight to the CSR builder without an additional reorder + # step. + tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(term["hess_sparse"]) + + out.append( + { + "name": term["name"], + "indices": tf_indices, + "grad": tf_grad, + "hess_dense": tf_hess_dense, + "hess_csr": tf_hess_csr, + } + ) + return out + + +def compute_external_nll(terms, x, dtype): + """Evaluate the scalar NLL contribution of a list of external terms. + + For each term, adds ``g^T x_sub + 0.5 * x_sub^T H x_sub`` to the + running total. Sparse Hessian terms use ``sm.matmul`` for the + ``H @ x_sub`` product, which dispatches to a multi-threaded CSR + kernel and is much faster per call than the previous element-wise + gather-based form. The autodiff gradient and HVP of + ``0.5 x^T H x`` via ``sm.matmul`` are themselves single + ``sm.matmul`` calls, so reverse-over-reverse autodiff no longer + rematerializes a 2D gather/scatter chain in the second-order tape + — that was the dominant cost on large external-Hessian problems + (e.g. jpsi: 329M-nnz prefit Hessian). + + Parameters + ---------- + terms : list[dict] + tf-side per-term dicts as returned by :func:`build_tf_external_terms`. + x : tf.Tensor + Current full parameter vector. + dtype : tf.DType + Dtype for the accumulator. + + Returns + ------- + tf.Tensor or None + Scalar contribution to the NLL, or ``None`` if ``terms`` is empty. + """ + if not terms: + return None + total = tf.zeros([], dtype=dtype) + for term in terms: + x_sub = tf.gather(x, term["indices"]) + if term["grad"] is not None: + total = total + tf.reduce_sum(term["grad"] * x_sub) + if term["hess_dense"] is not None: + # 0.5 * x_sub^T H x_sub + total = total + 0.5 * tf.reduce_sum( + x_sub * tf.linalg.matvec(term["hess_dense"], x_sub) + ) + elif term["hess_csr"] is not None: + # Loss = 0.5 * x_sub^T H x_sub via CSR matvec (H symmetric). + Hx = tf.squeeze( + tf_sparse_csr.matmul(term["hess_csr"], x_sub[:, None]), + axis=-1, + ) + total = total + 0.5 * tf.reduce_sum(x_sub * Hx) + return total diff --git a/rabbit/fitter.py b/rabbit/fitter.py index 25f96bb..42d88b0 100644 --- a/rabbit/fitter.py +++ b/rabbit/fitter.py @@ -10,7 +10,7 @@ from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops as tf_sparse_csr from wums import logging -from rabbit import io_tools +from rabbit import external_likelihood, io_tools from rabbit import tfhelpers as tfh from rabbit.impacts import global_impacts, nonprofiled_impacts, traditional_impacts from rabbit.tfhelpers import edmval_cov @@ -123,7 +123,20 @@ def __init__( self.diagnostics = options.diagnostics self.minimizer_method = options.minimizerMethod self.hvp_method = getattr(options, "hvpMethod", "revrev") - self.jit_compile = getattr(options, "jitCompile", True) + # jitCompile is tri-state: "auto" (default, enable in dense mode + # and disable in sparse mode), "on" (force on, warn-and-fall-back + # in sparse mode), or "off" (force off). Backwards compatibility: + # accept legacy True / False values from programmatic callers. + _jit_opt = getattr(options, "jitCompile", "auto") + if _jit_opt is True: + _jit_opt = "on" + elif _jit_opt is False: + _jit_opt = "off" + if _jit_opt not in ("auto", "on", "off"): + raise ValueError( + f"jitCompile must be one of 'auto', 'on', 'off'; got {_jit_opt!r}" + ) + self.jit_compile = _jit_opt # When --noHessian is requested the postfit Hessian is never # computed, so the dense [npar, npar] covariance matrix should # not be allocated. self.cov is set to None in that case and @@ -341,66 +354,14 @@ def init_fit_parms( # one common regularization strength parameter self.tau = tf.Variable(1.0, trainable=True, name="tau", dtype=tf.float64) - # External likelihood terms (additive g^T x + 0.5 x^T H x contributions - # to the NLL). Resolve parameter name strings against the full fit - # parameter list (POIs + systs). Build a single name->index dict - # once so the per-term resolution is O(n) instead of O(n^2) -- the - # latter cost ~150 s on a 108k-parameter setup with a 108k-parameter - # external term. - self.external_terms = [] - parms_str = self.parms.astype(str) - parms_idx = {name: i for i, name in enumerate(parms_str)} - if len(parms_idx) != len(parms_str): - raise RuntimeError( - "Duplicate parameter names in fitter parameter list; " - "external term resolution requires unique names." - ) - for term in self.indata.external_terms: - params = np.asarray(term["params"]).astype(str) - indices = np.empty(len(params), dtype=np.int64) - for i, p in enumerate(params): - j = parms_idx.get(p, -1) - if j < 0: - raise RuntimeError( - f"External likelihood term '{term['name']}' parameter " - f"'{p}' not found in fit parameters" - ) - indices[i] = j - tf_indices = tf.constant(indices, dtype=tf.int64) - - tf_grad = ( - tf.constant(term["grad_values"], dtype=self.indata.dtype) - if term["grad_values"] is not None - else None - ) - - tf_hess_dense = None - tf_hess_csr = None - if term["hess_dense"] is not None: - tf_hess_dense = tf.constant(term["hess_dense"], dtype=self.indata.dtype) - elif term["hess_sparse"] is not None: - # Build a CSRSparseMatrix view of the stored sparse Hessian - # for use in the closed-form external gradient/HVP path via - # sm.matmul. The Hessian is assumed symmetric, so the loss - # L = 0.5 x_sub^T H x_sub has gradient H @ x_sub and HVP - # H @ p_sub, each a single sm.matmul call. NOTE: - # SparseMatrixMatMul has no XLA kernel, so any tf.function - # that calls sm.matmul must be built with jit_compile=False. - # The TensorWriter sorts the indices into canonical - # row-major order at write time, so we can feed the - # SparseTensor straight to the CSR builder without an - # additional reorder step. - tf_hess_csr = tf_sparse_csr.CSRSparseMatrix(term["hess_sparse"]) - - self.external_terms.append( - { - "name": term["name"], - "indices": tf_indices, - "grad": tf_grad, - "hess_dense": tf_hess_dense, - "hess_csr": tf_hess_csr, - } - ) + # External likelihood terms (additive g^T x + 0.5 x^T H x + # contributions to the NLL). See rabbit.external_likelihood for + # the construction helper and the matching scalar evaluator. + self.external_terms = external_likelihood.build_tf_external_terms( + self.indata.external_terms, + self.parms, + self.indata.dtype, + ) # constraint minima for nuisance parameters self.theta0 = tf.Variable( @@ -2319,37 +2280,10 @@ def _compute_nll_components(self, profile=True, full_nll=False): return ln, lc, lbeta, lpenalty, beta def _compute_external_nll(self): - """Sum of external likelihood term contributions: sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub). - - For sparse-Hessian terms this uses tf.linalg.sparse's CSR matmul, - which dispatches to a multi-threaded kernel and is much faster - per call than the previous element-wise gather-based form. The - autodiff gradient and HVP of 0.5 x^T H x via sm.matmul are - themselves single sm.matmul calls, so reverse-over-reverse autodiff - no longer rematerializes a 2D gather/scatter chain in the second- - order tape — that was the dominant cost on large external-Hessian - problems before this rewrite (e.g. jpsi: 329M-nnz prefit Hessian). - """ - if not self.external_terms: - return None - total = tf.zeros([], dtype=self.indata.dtype) - for term in self.external_terms: - x_sub = tf.gather(self.x, term["indices"]) - if term["grad"] is not None: - total = total + tf.reduce_sum(term["grad"] * x_sub) - if term["hess_dense"] is not None: - # 0.5 * x_sub^T H x_sub - total = total + 0.5 * tf.reduce_sum( - x_sub * tf.linalg.matvec(term["hess_dense"], x_sub) - ) - elif term["hess_csr"] is not None: - # Loss = 0.5 * x_sub^T H x_sub via CSR matvec (H symmetric). - Hx = tf.squeeze( - tf_sparse_csr.matmul(term["hess_csr"], x_sub[:, None]), - axis=-1, - ) - total = total + 0.5 * tf.reduce_sum(x_sub * Hx) - return total + """Sum of external likelihood term contributions: sum_i (g_i^T x_sub + 0.5 x_sub^T H_i x_sub).""" + return external_likelihood.compute_external_nll( + self.external_terms, self.x, self.indata.dtype + ) def _compute_nll(self, profile=True, full_nll=False): ln, lc, lbeta, lpenalty, beta = self._compute_nll_components( @@ -2378,9 +2312,29 @@ def _make_tf_functions(self): # # SparseMatrixMatMul has no XLA kernel, so any tf.function that # uses it (via _compute_yields_noBBB in sparse mode) cannot be - # jit-compiled. Force jit_compile off in sparse mode regardless - # of the user's --jitCompile setting. - jit = self.jit_compile and not self.indata.sparse + # jit-compiled. Resolve the tri-state self.jit_compile setting: + # + # "auto" -> enable jit in dense mode, silently disable in + # sparse mode (the default; sparse mode just can't + # use it). + # "on" -> enable jit when possible. In sparse mode emit a + # warning and disable, since the user explicitly + # asked for it but it's structurally impossible. + # "off" -> never enable jit. + if self.jit_compile == "off": + jit = False + elif self.jit_compile == "on": + if self.indata.sparse: + logger.warning( + "--jitCompile=on requested but input data is sparse; " + "XLA has no kernel for the sparse matmul ops used in " + "sparse mode, so jit_compile will be disabled." + ) + jit = False + else: + jit = True + else: # "auto" + jit = not self.indata.sparse def _loss_val(self): return self._compute_loss() diff --git a/rabbit/inputdata.py b/rabbit/inputdata.py index 27844e5..57c6f32 100644 --- a/rabbit/inputdata.py +++ b/rabbit/inputdata.py @@ -193,52 +193,11 @@ def __init__(self, filename, pseudodata=None): self.axis_procs = hist.axis.StrCategory(self.procs, name="processes") - # Load external likelihood terms (optional). - # Each entry is a dict with keys: - # name: str - # params: 1D ndarray of parameter name strings - # grad_values: 1D float ndarray or None - # hess_dense: 2D float ndarray or None - # hess_sparse: tf.sparse.SparseTensor or None - # (sparsity pattern of the [npar_sub, npar_sub] Hessian; - # same on-disk layout as hlogk_sparse / hnorm_sparse) - self.external_terms = [] - if "external_terms" in f.keys(): - names = [ - s.decode() if isinstance(s, bytes) else s - for s in f["hexternal_term_names"][...] - ] - ext_group = f["external_terms"] - for tname in names: - tg = ext_group[tname] - raw_params = tg["params"][...] - params = np.array( - [s.decode() if isinstance(s, bytes) else s for s in raw_params] - ) - grad_values = ( - np.asarray(maketensor(tg["grad_values"])) - if "grad_values" in tg.keys() - else None - ) - hess_dense = ( - np.asarray(maketensor(tg["hess_dense"])) - if "hess_dense" in tg.keys() - else None - ) - hess_sparse = ( - makesparsetensor(tg["hess_sparse"]) - if "hess_sparse" in tg.keys() - else None - ) - self.external_terms.append( - { - "name": tname, - "params": params, - "grad_values": grad_values, - "hess_dense": hess_dense, - "hess_sparse": hess_sparse, - } - ) + # Load external likelihood terms (optional). See + # rabbit.external_likelihood for the per-entry dict schema. + from rabbit.external_likelihood import read_external_terms_from_h5 + + self.external_terms = read_external_terms_from_h5(f.get("external_terms")) @tf.function def expected_events_nominal(self): diff --git a/rabbit/parsing.py b/rabbit/parsing.py index 3a06569..e622601 100644 --- a/rabbit/parsing.py +++ b/rabbit/parsing.py @@ -212,13 +212,15 @@ def common_parser(): "(forward-over-reverse, via tf.autodiff.ForwardAccumulator) is an alternative.", ) parser.add_argument( - "--noJitCompile", - dest="jitCompile", - default=True, - action="store_false", - help="Disable XLA jit_compile=True on the loss/gradient/HVP tf.functions. " - "jit_compile is enabled by default and substantially speeds up sparse-mode fits " - "with very large numbers of parameters.", + "--jitCompile", + default="auto", + type=str, + choices=["auto", "on", "off"], + help="Control XLA jit_compile=True on the loss/gradient/HVP tf.functions. " + "'auto' (default) enables jit_compile in dense mode and disables it in " + "sparse mode (where the CSR matmul kernels have no XLA implementation). " + "'on' forces jit_compile on (falling back to off with a warning in sparse " + "mode). 'off' disables jit_compile unconditionally.", ) parser.add_argument( "--chisqFit", diff --git a/rabbit/tensorwriter.py b/rabbit/tensorwriter.py index 7268ea5..9d40ef3 100644 --- a/rabbit/tensorwriter.py +++ b/rabbit/tensorwriter.py @@ -2152,13 +2152,11 @@ def create_dataset( ) beta_variations = None - # Write external likelihood terms + # Write external likelihood terms. Each term is written as a + # subgroup under "external_terms"; the reader iterates the + # subgroups directly, so no separate names list is needed. if self.external_terms: ext_group = f.create_group("external_terms") - create_dataset( - "external_term_names", - [t["name"] for t in self.external_terms], - ) for term in self.external_terms: term_group = ext_group.create_group(term["name"]) params_ds = term_group.create_dataset(