diff --git a/pyproject.toml b/pyproject.toml index 055af2f..7d9658f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dynamic = ["version"] dependencies = [ "ezmsg>=3.9.0", "ezmsg-baseproc>=1.10.2", - "ezmsg-sigproc>=2.22.0", + "ezmsg-sigproc>=2.23.0", "sparse>=0.17.0", "numpy>=2.2.6", ] @@ -65,3 +65,4 @@ known-third-party = ["ezmsg"] [tool.uv.sources] # Uncomment to use development version of ezmsg from git #ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "dev" } +ezmsg-sigproc = { git = "https://github.com/ezmsg-org/ezmsg-sigproc.git", branch = "generalize-binner" } diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 4d49dab..553156a 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -1,127 +1,113 @@ +"""Bin an event stream into a lower-rate count (or rate) signal. + +The binning itself is delegated to +:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so the spike-rate branch +shares a *single* bin-boundary implementation with the spike-band-power branch +(``Pow -> BinnedAggregate(MEAN)``). With ``fractional=True`` (the default) bins +span a fractional ``bin_duration * fs`` samples with a carry accumulator, track +wall-clock time, and are labelled with the nominal ``bin_duration`` gain -- +identical to how :obj:`ezmsg.event.rate.EventRate` is consumed downstream. That +makes the two branches land on the same grid at any input rate (including the +off-nominal ~30012 Hz of real recordings), so a downstream ``Merge`` aligns them +with no post-hoc reconciler. + +Sparse ``sparse.COO`` inputs (the default ``ThresholdCrossing`` output) are +densified to per-sample contributions before binning; dense inputs are used as +is. Set ``scale_by_value=True`` to weight each event by its stored value instead +of counting occurrences, and ``scale_output=True`` to divide the per-bin count +by ``bin_duration`` (events/second). +""" + import ezmsg.core as ez -import numpy as np -import numpy.typing as npt -from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, - processor_state, -) +import sparse +from array_api_compat import get_namespace +from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit +from ezmsg.sigproc.aggregate import AggregationFunction +from ezmsg.sigproc.binned_aggregate import BinnedAggregateSettings, BinnedAggregateTransformer from ezmsg.util.messages.axisarray import AxisArray, replace class BinnedEventAggregatorSettings(ez.Settings): bin_duration: float = 0.05 - """ - Duration of each bin in seconds. - This is the time interval over which events will be counted. - """ + """Duration of each output bin in seconds.""" scale_output: bool = True - """ - If True, the output will be scaled by the bin duration. - This is useful for converting counts to rates. - """ + """If True, divide each bin's count by ``bin_duration`` (events/second).""" axis: str = "time" + """Name of the axis to bin along.""" + fractional: bool = True + """If True (default), share ``EventRate``'s fractional wall-clock grid via + :obj:`BinnedAggregate` (nominal ``bin_duration`` gain). If False, use a fixed + ``round(bin_duration * fs)`` sample-locked grid. See :obj:`BinnedAggregate`.""" -@processor_state -class BinnedEventAggregatorState: - n_overflow: int = 0 - counts_in_overflow: npt.NDArray[np.int64] | None = None - + scale_by_value: bool = False + """If True, weight each event by its stored value; if False (default), every + nonzero entry contributes a count of 1.""" -class BinnedEventAggregator( - BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState] -): - def _hash_message(self, message: AxisArray) -> int: - targ_ax_idx = message.get_axis_idx(self.settings.axis) - non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :] - return hash(tuple(non_targ_dims)) - def _reset_state(self, message: AxisArray) -> None: - self._state.n_overflow = 0 - targ_axis_idx = message.get_axis_idx(self.settings.axis) - buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :] - self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64) +class BinnedEventAggregator(BaseTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray]): + """Count events per fixed-duration bin, delegating binning to sigproc. - def _process(self, message: AxisArray) -> AxisArray: - # Quick maths - targ_ax_idx = message.get_axis_idx(self.settings.axis) - targ_axis = message.axes[self.settings.axis] - samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain)) - - # We will be slicing the data several times, so create a variable to hold the slices - var_slice = [slice(None)] * message.data.ndim - - # Store for later use - n_prev_overflow = self._state.n_overflow - - if self._state.n_overflow > 0: - # Calculate how many samples from the input msg we can fit into the first bin, - # given the current overflow state - n_first = samples_per_bin - self._state.n_overflow - # Sum the number of samples in the first bin then add to self._state.counts_in_overflow - var_slice[targ_ax_idx] = slice(0, n_first) - first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() - first_bin_counts += self._state.counts_in_overflow - else: - n_first = 0 - first_bin_counts = self._state.counts_in_overflow - assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration." - - # Calculate how many samples remain after the first bin - n_remaining = message.data.shape[targ_ax_idx] - n_first - n_full_bins = int(n_remaining / samples_per_bin) - - # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins - split_idx = n_first + n_full_bins * samples_per_bin - var_slice[targ_ax_idx] = slice(n_first, split_idx) - full_bins_data = message.data[tuple(var_slice)] - - # Reshape and sum for full bins - new_shape = ( - full_bins_data.shape[:targ_ax_idx] - + (n_full_bins, samples_per_bin) - + full_bins_data.shape[targ_ax_idx + 1 :] - ) - middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + The per-bin reduction, carry across message boundaries, and output time axis + all come from :obj:`BinnedAggregateTransformer`; this wrapper only converts + events to per-sample contributions and optionally rate-normalizes. + """ - # Prepare output - if self._state.n_overflow > 0: - first_bin_counts = first_bin_counts.reshape( - first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:] + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._binner = BinnedAggregateTransformer( + BinnedAggregateSettings( + axis=self.settings.axis, + bin_duration=self.settings.bin_duration, + operation=AggregationFunction.SUM, + fractional=self.settings.fractional, ) - output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx) - else: - output_data = middle_bin_counts - - if self.settings.scale_output: - output_data = output_data / self.settings.bin_duration - - # Create the new output axis - # For the target axis, backup the offset by the number of samples in the overflow - out_axis = replace( - targ_axis, - gain=targ_axis.gain * samples_per_bin, - offset=targ_axis.offset - n_prev_overflow * targ_axis.gain, ) - out_msg = replace( - message, - data=output_data, - axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()}, - ) - - # Calculate and store the overflow state. - var_slice[targ_ax_idx] = slice(split_idx, None) - overflow_data = message.data[tuple(var_slice)] - self._state.n_overflow = overflow_data.shape[targ_ax_idx] - self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense() - return out_msg + def _process(self, message: AxisArray) -> AxisArray: + data = message.data + if isinstance(data, sparse.SparseArray): + data = data.todense() + xp = get_namespace(data) + # Per-sample contribution: the event value, or 1 per nonzero entry. + # float64 where usable (exact integer counts; matches the legacy output + # dtype); float32 otherwise. MLX *exposes* ``mx.float64`` as an attribute + # but the GPU rejects it ("float64 is not supported on the GPU"), so + # ``hasattr(xp, "float64")`` is not a sufficient capability check -- + # detect the MLX namespace explicitly and fall back to float32. + is_mlx = getattr(xp, "__name__", "") == "mlx.core" + float_dtype = xp.float64 if (hasattr(xp, "float64") and not is_mlx) else xp.float32 + contrib = data if self.settings.scale_by_value else (data != 0) + contrib = contrib.astype(float_dtype) + + binned = self._binner(replace(message, data=contrib)) + + if self.settings.scale_output and binned.data.size: + binned = replace(binned, data=binned.data / self.settings.bin_duration) + return binned class BinnedEventAggregatorUnit( BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator] ): SETTINGS = BinnedEventAggregatorSettings + + +def binned_event_aggregator( + bin_duration: float = 0.05, + scale_output: bool = True, + axis: str = "time", + fractional: bool = True, + scale_by_value: bool = False, +) -> BinnedEventAggregator: + return BinnedEventAggregator( + BinnedEventAggregatorSettings( + bin_duration=bin_duration, + scale_output=scale_output, + axis=axis, + fractional=fractional, + scale_by_value=scale_by_value, + ) + ) diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index e2d07dc..c7b1d37 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -2,40 +2,43 @@ from ezmsg.baseproc import BaseTransformerUnit from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.event.kernel_activation import ( - ActivationKernelType, - BinAggregation, - BinnedKernelActivation, - BinnedKernelActivationSettings, -) +from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings class EventRateSettings(ez.Settings): bin_duration: float = 0.05 + fractional: bool = True + """If True (default), bins track wall-clock time on the nominal-``bin_duration`` + grid (fractional samples-per-bin with a carry accumulator). If False, bins are + a fixed ``round(bin_duration * fs)`` samples (sample-locked). See + :obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate`.""" -class Rate(BinnedKernelActivation): + +class Rate(BinnedEventAggregator): """ Event rate calculator (events per second). - Counts events per bin and divides by bin_duration to get rate in events/second. + Counts events per bin and divides by ``bin_duration`` to get rate in + events/second. Binning is delegated to + :obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate`, so the spike-rate + output shares one bin-boundary implementation -- and therefore the same + output grid -- with the spike-band-power branch. """ def __init__(self, settings: EventRateSettings) -> None: super().__init__( - BinnedKernelActivationSettings( - kernel_type=ActivationKernelType.COUNT, - tau=1.0, # Not used for COUNT + BinnedEventAggregatorSettings( bin_duration=settings.bin_duration, - aggregation=BinAggregation.SUM, - scale_by_value=False, - normalize=False, - rate_normalize=True, + scale_output=True, # counts -> events/second + axis="time", + fractional=settings.fractional, + scale_by_value=False, # count events, ignore stored peak values ) ) class EventRate(BaseTransformerUnit[EventRateSettings, AxisArray, AxisArray, Rate]): - """Unit for computing event rate from sparse events.""" + """Unit for computing event rate from sparse (or dense) events.""" SETTINGS = EventRateSettings diff --git a/tests/test_binned.py b/tests/test_binned.py index 69c2797..236c737 100644 --- a/tests/test_binned.py +++ b/tests/test_binned.py @@ -1,6 +1,7 @@ import time import numpy as np +import pytest import sparse from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray @@ -87,3 +88,124 @@ def test_binned_event_aggregator_empty_time_first(): out_normal = proc(msg_normal) assert out_normal.data.ndim == 2 assert out_normal.data.shape[1] == N_CH + + +def _sparse_chunks(spk: np.ndarray, fs: float, block: int) -> list[AxisArray]: + out = [] + for start in range(0, spk.shape[0], block): + out.append( + AxisArray( + data=sparse.COO.from_numpy(spk[start : start + block]), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=start / fs)}, + ) + ) + return out + + +def _run(proc, msgs): + return [r for r in (proc(m) for m in msgs) if r.data.size] + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0]) +def test_fractional_grid_offnominal(fs: float): + """At an off-nominal rate the fractional binner stays on the nominal-gain + wall-clock grid (gain == bin_duration, n_bins == int(n / (bin_duration*fs))), + which is what aligns it with the spike-band-power branch. The legacy + sample-locked path (Window) would instead report gain int(bin*fs)/fs.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(0).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = bin_dur * fs + assert sum(m.data.shape[0] for m in out) == int(n / spb) + assert out[0].axes["time"].gain == pytest.approx(bin_dur) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0]) +def test_chunk_invariance_offnominal(fs: float): + """Output is identical regardless of how the stream is chunked.""" + bin_dur = 0.02 + spk = (np.random.default_rng(1).random((120_000, N_CH)) < 0.02).astype(float) + + whole = np.concatenate( + [m.data for m in _run(BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 120_000))], + axis=0, + ) + frag = np.concatenate( + [m.data for m in _run(BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 137))], + axis=0, + ) + assert whole.shape == frag.shape + np.testing.assert_array_equal(whole, frag) + + +def test_count_vs_rate_scaling(): + """scale_output divides counts by bin_duration; otherwise raw counts.""" + fs = 30_000.0 + bin_dur = 0.02 + spk = (np.random.default_rng(2).random((60_000, N_CH)) < 0.02).astype(float) + msgs = _sparse_chunks(spk, fs, 60_000) + + counts = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False)), msgs + )[0] + rate = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=True)), msgs + )[0] + + # Counts are exact integers; rate is counts / bin_duration. + assert np.allclose(counts.data, np.round(counts.data)) + np.testing.assert_allclose(rate.data, counts.data / bin_dur, rtol=0, atol=1e-9) + + +def test_scale_by_value_weights_by_stored_value(): + """scale_by_value=True sums each event's stored value per bin; the default + counts nonzero entries (1 each).""" + fs = 30_000.0 + bin_dur = 0.02 # 600 samples/bin (exact at this fs) + spb = int(bin_dur * fs) + n = 3 * spb + + # One channel; known events with distinct values: two in bin 0, one in bin 1. + values = np.zeros((n, 1), dtype=float) + values[10, 0] = 2.0 + values[20, 0] = 3.0 # bin 0 -> count 2, value sum 5.0 + values[spb + 5, 0] = 7.0 # bin 1 -> count 1, value sum 7.0 + msg = AxisArray( + data=sparse.COO.from_numpy(values), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=0.0)}, + ) + + count = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=False) + )(msg) + weighted = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=True) + )(msg) + + np.testing.assert_array_equal(count.data[:, 0], [2.0, 1.0, 0.0]) + np.testing.assert_allclose(weighted.data[:, 0], [5.0, 7.0, 0.0]) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0]) +def test_fractional_false_sample_locked(fs: float): + """fractional=False bins a fixed round(bin_duration*fs) sample count, so the + output gain is round(bin*fs)/fs (sample-locked) -- at an off-nominal rate this + differs from the nominal bin_duration gain of the fractional grid.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(3).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, fractional=False)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = round(bin_dur * fs) + assert sum(m.data.shape[0] for m in out) == n // spb + assert out[0].axes["time"].gain == pytest.approx(spb / fs)