From 91beabd4813483119978a97c12056d403bc79c24 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 00:21:14 -0400 Subject: [PATCH 1/6] Delegate BinnedKernelActivation boundaries to BinSchedule EventRate's bin grid (samples-per-bin, fractional carry, global bin index, output gain/offset) was computed inline and duplicated the same arithmetic ezmsg-sigproc's dense binner uses. Both now consume the shared ezmsg.sigproc.util.binning.BinSchedule, so they land on an identical grid by shared code rather than by coincidence. - BinnedKernelActivationState drops bin_accumulator in favor of a BinSchedule (fractional, since EventRate bins always track the nominal bin_duration). - _process_events and _process_dense_count_sum now get n_bins, chunk-local bin ends, output gain and output offset from BinSchedule.advance(); the per-bin reduction and activation-state logic are unchanged. - New cross-package test asserts EventRate and BinnedAggregate(SUM) produce the same output time axis and bin counts at fs in {1000, 30012, 30030} under adversarial chunking -- the alignment the shared primitive guarantees. Requires an ezmsg-sigproc release that ships BinSchedule; the dependency pin will need bumping before this can merge. --- src/ezmsg/event/kernel_activation.py | 75 +++++++++++-------------- tests/test_bin_schedule_crosspackage.py | 69 +++++++++++++++++++++++ 2 files changed, 103 insertions(+), 41 deletions(-) create mode 100644 tests/test_bin_schedule_crosspackage.py diff --git a/src/ezmsg/event/kernel_activation.py b/src/ezmsg/event/kernel_activation.py index 8f3d510..f7136d1 100644 --- a/src/ezmsg/event/kernel_activation.py +++ b/src/ezmsg/event/kernel_activation.py @@ -23,6 +23,7 @@ import sparse from array_api_compat import get_namespace, is_numpy_array from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state +from ezmsg.sigproc.util.binning import BinSchedule from ezmsg.util.messages.axisarray import AxisArray, replace @@ -96,8 +97,10 @@ class BinnedKernelActivationState: # Input sample rate (cached from first message) fs: float | None = None - # Accumulated fractional bin samples for proper bin alignment - bin_accumulator: float = 0.0 + # Shared bin-boundary schedule. Owns samples-per-bin, the fractional carry, + # the global bin index, and the output gain/offset -- the same primitive the + # ezmsg-sigproc dense binner uses, so both land on an identical grid. + schedule: BinSchedule | None = None class BinnedKernelActivation( @@ -141,7 +144,6 @@ def _reset_state(self, message: AxisArray) -> None: self._state.activation = np.zeros(n_channels, dtype=np.float64) self._state.samples_since_update = np.zeros(n_channels, dtype=np.int64) - self._state.bin_accumulator = 0.0 # For alpha kernel, we need auxiliary state if self.settings.kernel_type == ActivationKernelType.ALPHA: @@ -152,6 +154,11 @@ def _reset_state(self, message: AxisArray) -> None: if time_axis.gain > 0: self._state.fs = 1.0 / time_axis.gain + # EventRate's binning is always fractional (bins track the nominal + # bin_duration). The schedule owns the boundary arithmetic from here on. + self._state.schedule = BinSchedule(bin_duration=self.settings.bin_duration, fractional=True) + self._state.schedule.reset(self._state.fs) + def _decay_to_sample(self, channel: int, target_sample: int) -> None: """ Decay activation state to target sample. @@ -238,16 +245,16 @@ def _process_events(self, message: AxisArray) -> AxisArray: n_samples = sparse_data.shape[0] n_channels = sparse_data.shape[1] if sparse_data.ndim > 1 else 1 - # Calculate bin parameters - samples_per_bin = self.settings.bin_duration * self._state.fs - total_samples = n_samples + self._state.bin_accumulator - n_bins = int(total_samples / samples_per_bin) + # Boundary arithmetic is delegated to the shared schedule (samples-per-bin, + # fractional carry, global bin index, output gain/offset). + in_offset = message.axes["time"].offset if "time" in message.axes else 0.0 + n_carry_before = self._state.schedule.carry_count + step = self._state.schedule.advance(n_new=n_samples, in_offset=in_offset, gain_in=1.0 / self._state.fs) + n_bins = step.n_bins if n_bins == 0: - # Not enough samples for a full bin yet - self._state.bin_accumulator = total_samples - - # Still need to process events to update state + # Not enough samples for a full bin yet (the schedule has folded these + # samples into its carry). Still process events to update state. if hasattr(sparse_data, "coords") and hasattr(sparse_data, "data"): coords = sparse_data.coords values = sparse_data.data @@ -263,18 +270,15 @@ def _process_events(self, message: AxisArray) -> AxisArray: data=np.zeros((0, n_channels), dtype=np.float64), axes={ **message.axes, - "time": replace(message.axes["time"], gain=self.settings.bin_duration), + "time": replace(message.axes["time"], gain=step.output_gain), }, ) - # Calculate bin boundaries (in input samples, relative to chunk start) - # Account for accumulator from previous chunk - accumulator_before = self._state.bin_accumulator # Save for offset calculation - first_bin_end = samples_per_bin - self._state.bin_accumulator - bin_ends = first_bin_end + np.arange(n_bins) * samples_per_bin - - # Update accumulator for next chunk - self._state.bin_accumulator = total_samples - n_bins * samples_per_bin + # Bin ends in input samples relative to *this chunk's* start. The schedule + # returns cut points into [carry ++ new]; subtracting the pre-advance carry + # count maps them back to chunk-local indices (identical to the legacy + # `(spb - acc) + arange*spb` truncated formula). + bin_ends = np.asarray(step.cut_points, dtype=np.int64) - n_carry_before # Collect events sorted by time events = [] @@ -366,12 +370,6 @@ def _process_events(self, message: AxisArray) -> AxisArray: if self.settings.rate_normalize: output = output / self.settings.bin_duration - # Calculate output time offset - # The first bin starts at (input_offset - accumulator_time) - input_offset = message.axes["time"].offset if "time" in message.axes else 0.0 - accumulator_time = accumulator_before / self._state.fs - output_offset = input_offset - accumulator_time - return replace( message, data=output, @@ -379,7 +377,7 @@ def _process_events(self, message: AxisArray) -> AxisArray: **message.axes, "time": AxisArray.TimeAxis( fs=1.0 / self.settings.bin_duration, - offset=output_offset, + offset=step.output_offset, ), }, ) @@ -397,10 +395,10 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: n_samples = data.shape[0] feature_shape = tuple(data.shape[1:]) - samples_per_bin = self.settings.bin_duration * self._state.fs - accumulator_before = self._state.bin_accumulator - total_samples = n_samples + accumulator_before - n_bins = int(total_samples / samples_per_bin) + in_offset = message.axes["time"].offset if "time" in message.axes else 0.0 + n_carry_before = self._state.schedule.carry_count + step = self._state.schedule.advance(n_new=n_samples, in_offset=in_offset, gain_in=1.0 / self._state.fs) + n_bins = step.n_bins # Per-sample contribution: 1 per non-zero, or the value itself if scaling. # Use the .astype() method form so the same call works for both numpy and mlx @@ -419,20 +417,18 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: # No complete bins this chunk — accumulate everything into the carry-over. new_overflow = overflow_xp + (xp.sum(contrib, axis=0) if n_samples > 0 else overflow_xp * 0) self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape) - self._state.bin_accumulator = total_samples return replace( message, data=xp.zeros((0,) + feature_shape, dtype=xp.float32), axes={ **message.axes, - "time": replace(message.axes["time"], gain=self.settings.bin_duration), + "time": replace(message.axes["time"], gain=step.output_gain), }, ) - # Bin boundaries (in input-sample space, integer-truncated as in the event-based path). - first_bin_end = samples_per_bin - accumulator_before - bin_ends_float = first_bin_end + np.arange(n_bins) * samples_per_bin - bin_end_samples = bin_ends_float.astype(np.int64) + # Bin boundaries in this chunk's sample space, from the shared schedule + # (cut points into [carry ++ new] mapped back to chunk-local indices). + bin_end_samples = np.asarray(step.cut_points, dtype=np.int64) - n_carry_before bin_start_samples = np.concatenate(([np.int64(0)], bin_end_samples[:-1])) # Cumulative sum, prepended with zeros so cumsum_padded[k] = sum(contrib[:k]). @@ -462,14 +458,11 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: else: new_overflow = xp.zeros(feature_shape, dtype=cumsum.dtype) self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape) - self._state.bin_accumulator = total_samples - n_bins * samples_per_bin if self.settings.rate_normalize: output = output / self.settings.bin_duration - accumulator_time = accumulator_before / self._state.fs - input_offset = message.axes["time"].offset if "time" in message.axes else 0.0 - output_offset = input_offset - accumulator_time + output_offset = step.output_offset return replace( message, diff --git a/tests/test_bin_schedule_crosspackage.py b/tests/test_bin_schedule_crosspackage.py new file mode 100644 index 0000000..a19d2d3 --- /dev/null +++ b/tests/test_bin_schedule_crosspackage.py @@ -0,0 +1,69 @@ +"""Cross-package proof that EventRate and the ezmsg-sigproc dense binner share a grid. + +Both now route their bin boundaries through ``ezmsg.sigproc.util.binning.BinSchedule``. +This test feeds identical data to the real ``EventRate`` (count/sum path) and to +``BinnedAggregateTransformer`` (operation=SUM) and asserts they produce the same +output time axis (gain + offset) and the same bin counts -- the alignment property +the shared primitive exists to guarantee -- including at an off-nominal sample +rate and under adversarial chunking. + +EventRate has ``rate_normalize=True`` (counts / bin_duration), so its values equal +the SUM binner's values divided by bin_duration; everything else must match exactly. +""" + +import numpy as np +import pytest +from ezmsg.sigproc.aggregate import AggregationFunction +from ezmsg.sigproc.binned_aggregate import BinnedAggregateTransformer +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.event.rate import EventRateSettings, Rate + + +def _make_msg(arr: np.ndarray, fs: float, offset: float) -> AxisArray: + return AxisArray( + data=arr, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=offset), + "ch": AxisArray.CoordinateAxis(data=np.arange(arr.shape[1]), dims=["ch"]), + }, + ) + + +@pytest.mark.parametrize("fs", [1000.0, 30012.0, 30030.0]) +@pytest.mark.parametrize("block_size", [50000, 1, 777]) +def test_eventrate_and_dense_binner_share_grid(fs: float, block_size: int): + bin_duration = 0.02 + rng = np.random.default_rng(0) + # Dense spike train (0/1). EventRate's dense COUNT+SUM path counts non-zeros; + # BinnedAggregate(SUM) sums them -- same quantity per bin. + spikes = (rng.random((40000, 3)) < 0.01).astype(np.float32) + + rate = Rate(EventRateSettings(bin_duration=bin_duration)) + binner = BinnedAggregateTransformer( + axis="time", bin_duration=bin_duration, operation=AggregationFunction.SUM, fractional=True + ) + + rate_out, binner_out, samp_off = [], [], 0 + for start in range(0, spikes.shape[0], block_size): + chunk = spikes[start : start + block_size] + offset = samp_off / fs + rate_out.append(rate(_make_msg(chunk, fs, offset))) + binner_out.append(binner(_make_msg(chunk, fs, offset))) + samp_off += chunk.shape[0] + + # Per-message: identical gain + offset on every non-empty output, and matching + # bin counts. This is exactly what a downstream Merge(align_axis="time") needs. + for r, b in zip(rate_out, binner_out): + assert r.data.shape[0] == b.data.shape[0] + if r.data.shape[0] == 0: + continue + assert r.axes["time"].gain == pytest.approx(b.axes["time"].gain) + assert r.axes["time"].offset == pytest.approx(b.axes["time"].offset) + + rate_all = np.concatenate([m.data for m in rate_out], axis=0) + binner_all = np.concatenate([m.data for m in binner_out], axis=0) + assert rate_all.shape == binner_all.shape + # EventRate divides counts by bin_duration (rate_normalize); undo to compare. + np.testing.assert_allclose(rate_all * bin_duration, binner_all, rtol=0, atol=1e-6) From e685c02775aa1de76d0ff4c06fd11f7671a3102f Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 00:34:25 -0400 Subject: [PATCH 2/6] Add fractional=False (sample-locked) mode to EventRate BinSchedule already carries the fractional flag, so exposing a sample-locked mode on EventRate is just plumbing plus using the schedule's actual gain. - EventRateSettings/BinnedKernelActivationSettings gain a `fractional` field (default True; threaded through Rate). False bins on Window's int(bin_duration*fs) grid so a sample-locked SBP branch and the spike-rate branch share one grid. - Both output paths now label the axis with step.output_gain (the schedule's actual seconds-per-bin) instead of the nominal bin_duration, and rate_normalize divides by that same actual duration. In fractional mode output_gain == bin_duration, so default behavior is unchanged (full suite still green). - Cross-package test asserts EventRate(fractional=False) is sample-locked to Window's gain and shares the grid with BinnedAggregate(fractional=False). --- src/ezmsg/event/kernel_activation.py | 32 +++++++++++-------- src/ezmsg/event/rate.py | 8 +++++ tests/test_bin_schedule_crosspackage.py | 41 +++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/src/ezmsg/event/kernel_activation.py b/src/ezmsg/event/kernel_activation.py index f7136d1..471129a 100644 --- a/src/ezmsg/event/kernel_activation.py +++ b/src/ezmsg/event/kernel_activation.py @@ -78,7 +78,16 @@ class BinnedKernelActivationSettings(ez.Settings): """If True, normalize kernel so integral equals 1.""" rate_normalize: bool = False - """If True, divide output by bin_duration to get events/second (for COUNT kernel).""" + """If True, divide output by the bin's duration to get events/second (for COUNT + kernel). The divisor is the *actual* bin duration (``samples_per_bin / fs``), + which equals ``bin_duration`` exactly in fractional mode.""" + + fractional: bool = True + """If True (default), bins span a *fractional* ``bin_duration * fs`` samples + with a carry accumulator; bins track the nominal duration and the output gain + is exactly ``bin_duration``. If False, bins span a *fixed* ``int(bin_duration * + fs)`` samples and the output gain is ``int(bin_duration * fs) / fs`` + (sample-locked, matching :obj:`ezmsg.sigproc.window.Window`).""" @processor_state @@ -154,9 +163,9 @@ def _reset_state(self, message: AxisArray) -> None: if time_axis.gain > 0: self._state.fs = 1.0 / time_axis.gain - # EventRate's binning is always fractional (bins track the nominal - # bin_duration). The schedule owns the boundary arithmetic from here on. - self._state.schedule = BinSchedule(bin_duration=self.settings.bin_duration, fractional=True) + # The schedule owns the boundary arithmetic from here on; fractional mode + # tracks the nominal bin_duration, sample-locked mode matches Window's grid. + self._state.schedule = BinSchedule(bin_duration=self.settings.bin_duration, fractional=self.settings.fractional) self._state.schedule.reset(self._state.fs) def _decay_to_sample(self, channel: int, target_sample: int) -> None: @@ -366,9 +375,10 @@ def _process_events(self, message: AxisArray) -> AxisArray: # Update state sample counters relative to next chunk self._state.samples_since_update -= n_samples - # Apply rate normalization if requested (divide by bin_duration to get events/second) + # Apply rate normalization if requested (divide by the bin's actual + # duration to get events/second; == bin_duration in fractional mode). if self.settings.rate_normalize: - output = output / self.settings.bin_duration + output = output / step.output_gain return replace( message, @@ -376,7 +386,7 @@ def _process_events(self, message: AxisArray) -> AxisArray: axes={ **message.axes, "time": AxisArray.TimeAxis( - fs=1.0 / self.settings.bin_duration, + fs=1.0 / step.output_gain, offset=step.output_offset, ), }, @@ -460,9 +470,7 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape) if self.settings.rate_normalize: - output = output / self.settings.bin_duration - - output_offset = step.output_offset + output = output / step.output_gain return replace( message, @@ -470,8 +478,8 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: axes={ **message.axes, "time": AxisArray.TimeAxis( - fs=1.0 / self.settings.bin_duration, - offset=output_offset, + fs=1.0 / step.output_gain, + offset=step.output_offset, ), }, ) diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index e2d07dc..85e864c 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -13,6 +13,13 @@ class EventRateSettings(ez.Settings): bin_duration: float = 0.05 + fractional: bool = True + """If True (default), bins span a fractional ``bin_duration * fs`` samples and + the output rate is exactly ``1 / bin_duration``. If False, bins are + sample-locked to ``int(bin_duration * fs)`` samples (output rate + ``fs / int(bin_duration * fs)``), matching :obj:`ezmsg.sigproc.window.Window` + so a sample-locked SBP branch and the spike-rate branch share one grid.""" + class Rate(BinnedKernelActivation): """ @@ -31,6 +38,7 @@ def __init__(self, settings: EventRateSettings) -> None: scale_by_value=False, normalize=False, rate_normalize=True, + fractional=settings.fractional, ) ) diff --git a/tests/test_bin_schedule_crosspackage.py b/tests/test_bin_schedule_crosspackage.py index a19d2d3..e4f3701 100644 --- a/tests/test_bin_schedule_crosspackage.py +++ b/tests/test_bin_schedule_crosspackage.py @@ -67,3 +67,44 @@ def test_eventrate_and_dense_binner_share_grid(fs: float, block_size: int): assert rate_all.shape == binner_all.shape # EventRate divides counts by bin_duration (rate_normalize); undo to compare. np.testing.assert_allclose(rate_all * bin_duration, binner_all, rtol=0, atol=1e-6) + + +@pytest.mark.parametrize("fs", [30012.0, 30030.0]) +@pytest.mark.parametrize("block_size", [40000, 1, 333]) +def test_eventrate_fractional_false_is_sample_locked(fs: float, block_size: int): + """fractional=False EventRate bins on Window's int(bin_duration*fs) grid and + shares it with BinnedAggregate(fractional=False). Its output rate is the + sample-locked fs/int(bin_duration*fs), not the nominal 1/bin_duration.""" + bin_duration = 0.02 + window_samples = int(bin_duration * fs) + expected_gain = window_samples / fs + rng = np.random.default_rng(1) + spikes = (rng.random((40000, 2)) < 0.01).astype(np.float32) + + rate = Rate(EventRateSettings(bin_duration=bin_duration, fractional=False)) + binner = BinnedAggregateTransformer( + axis="time", bin_duration=bin_duration, operation=AggregationFunction.SUM, fractional=False + ) + + rate_out, binner_out, samp_off = [], [], 0 + for start in range(0, spikes.shape[0], block_size): + chunk = spikes[start : start + block_size] + offset = samp_off / fs + rate_out.append(rate(_make_msg(chunk, fs, offset))) + binner_out.append(binner(_make_msg(chunk, fs, offset))) + samp_off += chunk.shape[0] + + for r, b in zip(rate_out, binner_out): + assert r.data.shape[0] == b.data.shape[0] + if r.data.shape[0] == 0: + continue + # Sample-locked gain (== Window's), not the nominal bin_duration. + assert r.axes["time"].gain == pytest.approx(expected_gain) + assert r.axes["time"].gain == pytest.approx(b.axes["time"].gain) + assert r.axes["time"].offset == pytest.approx(b.axes["time"].offset) + + rate_all = np.concatenate([m.data for m in rate_out], axis=0) + binner_all = np.concatenate([m.data for m in binner_out], axis=0) + assert rate_all.shape == binner_all.shape + # rate_normalize now divides by the actual bin duration (== expected_gain). + np.testing.assert_allclose(rate_all * expected_gain, binner_all, rtol=0, atol=1e-6) From 4224c513ca3d0df6518ecf57f8a13681e5ea4805 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 01:00:38 -0400 Subject: [PATCH 3/6] bump sigproc dep --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 91929ff..80a2305 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dynamic = ["version"] dependencies = [ "ezmsg>=3.9.0", "ezmsg-baseproc>=1.10.2", - "ezmsg-sigproc>=2.23.0", + "ezmsg-sigproc>=2.26.0", "sparse>=0.17.0", "numpy>=2.2.6", ] From 5b69c9289a582af5d8464592a993e1ee8ea04bd3 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 01:21:24 -0400 Subject: [PATCH 4/6] Migrate BinnedEventAggregator onto the shared sigproc binner binned.py still carried its own integer-locked event-counting binner, the last event binner not on the shared grid. Rewrite it as a thin wrapper over ezmsg.sigproc.binned_aggregate.BinnedAggregate (operation=SUM): densify sparse events to per-sample contributions, delegate all boundary/carry/axis math to the shared binner, and optionally rate-normalize. Ported from #32 (Kyle McGraw), adjusted for the enhanced sigproc shipped in 2.26.0: - fractional=False is documented and tested as int(bin_duration*fs) truncation (matching Window and BinSchedule), not round(); the sample-locked test now includes fs=30030 where truncation (600) and rounding (601) diverge. - Adds scale_by_value (weight events by stored value), a binned_event_aggregator factory, and tests for the off-nominal fractional grid, chunk invariance, count-vs-rate scaling, scale_by_value, and the sample-locked grid. This leaves Rate/EventRate on the kernel_activation path (event-optimized, O(n_events) sparse) from the companion commits; BinnedEventAggregator is now also on the shared grid, so no event binner remains on the old divergent one. --- src/ezmsg/event/binned.py | 190 ++++++++++++++++++-------------------- tests/test_binned.py | 133 ++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 102 deletions(-) diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 4d49dab..4fa37d1 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -1,127 +1,113 @@ +"""Bin an event stream into a lower-rate count (or rate) signal. + +The binning itself is delegated to +:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so the spike-rate branch +shares a *single* bin-boundary implementation with the spike-band-power branch +(``Pow -> BinnedAggregate(MEAN)``). With ``fractional=True`` (the default) bins +span a fractional ``bin_duration * fs`` samples with a carry accumulator, track +wall-clock time, and are labelled with the nominal ``bin_duration`` gain -- +identical to how :obj:`ezmsg.event.rate.EventRate` is consumed downstream. That +makes the two branches land on the same grid at any input rate (including the +off-nominal ~30012 Hz of real recordings), so a downstream ``Merge`` aligns them +with no post-hoc reconciler. + +Sparse ``sparse.COO`` inputs (the default ``ThresholdCrossing`` output) are +densified to per-sample contributions before binning; dense inputs are used as +is. Set ``scale_by_value=True`` to weight each event by its stored value instead +of counting occurrences, and ``scale_output=True`` to divide the per-bin count +by ``bin_duration`` (events/second). +""" + import ezmsg.core as ez -import numpy as np -import numpy.typing as npt -from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, - processor_state, -) +import sparse +from array_api_compat import get_namespace +from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit +from ezmsg.sigproc.aggregate import AggregationFunction +from ezmsg.sigproc.binned_aggregate import BinnedAggregateSettings, BinnedAggregateTransformer from ezmsg.util.messages.axisarray import AxisArray, replace class BinnedEventAggregatorSettings(ez.Settings): bin_duration: float = 0.05 - """ - Duration of each bin in seconds. - This is the time interval over which events will be counted. - """ + """Duration of each output bin in seconds.""" scale_output: bool = True - """ - If True, the output will be scaled by the bin duration. - This is useful for converting counts to rates. - """ + """If True, divide each bin's count by ``bin_duration`` (events/second).""" axis: str = "time" + """Name of the axis to bin along.""" + fractional: bool = True + """If True (default), share ``EventRate``'s fractional wall-clock grid via + :obj:`BinnedAggregate` (nominal ``bin_duration`` gain). If False, use a fixed + ``int(bin_duration * fs)`` sample-locked grid. See :obj:`BinnedAggregate`.""" -@processor_state -class BinnedEventAggregatorState: - n_overflow: int = 0 - counts_in_overflow: npt.NDArray[np.int64] | None = None - + scale_by_value: bool = False + """If True, weight each event by its stored value; if False (default), every + nonzero entry contributes a count of 1.""" -class BinnedEventAggregator( - BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState] -): - def _hash_message(self, message: AxisArray) -> int: - targ_ax_idx = message.get_axis_idx(self.settings.axis) - non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :] - return hash(tuple(non_targ_dims)) - def _reset_state(self, message: AxisArray) -> None: - self._state.n_overflow = 0 - targ_axis_idx = message.get_axis_idx(self.settings.axis) - buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :] - self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64) +class BinnedEventAggregator(BaseTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray]): + """Count events per fixed-duration bin, delegating binning to sigproc. - def _process(self, message: AxisArray) -> AxisArray: - # Quick maths - targ_ax_idx = message.get_axis_idx(self.settings.axis) - targ_axis = message.axes[self.settings.axis] - samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain)) - - # We will be slicing the data several times, so create a variable to hold the slices - var_slice = [slice(None)] * message.data.ndim - - # Store for later use - n_prev_overflow = self._state.n_overflow - - if self._state.n_overflow > 0: - # Calculate how many samples from the input msg we can fit into the first bin, - # given the current overflow state - n_first = samples_per_bin - self._state.n_overflow - # Sum the number of samples in the first bin then add to self._state.counts_in_overflow - var_slice[targ_ax_idx] = slice(0, n_first) - first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() - first_bin_counts += self._state.counts_in_overflow - else: - n_first = 0 - first_bin_counts = self._state.counts_in_overflow - assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration." - - # Calculate how many samples remain after the first bin - n_remaining = message.data.shape[targ_ax_idx] - n_first - n_full_bins = int(n_remaining / samples_per_bin) - - # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins - split_idx = n_first + n_full_bins * samples_per_bin - var_slice[targ_ax_idx] = slice(n_first, split_idx) - full_bins_data = message.data[tuple(var_slice)] - - # Reshape and sum for full bins - new_shape = ( - full_bins_data.shape[:targ_ax_idx] - + (n_full_bins, samples_per_bin) - + full_bins_data.shape[targ_ax_idx + 1 :] - ) - middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + The per-bin reduction, carry across message boundaries, and output time axis + all come from :obj:`BinnedAggregateTransformer`; this wrapper only converts + events to per-sample contributions and optionally rate-normalizes. + """ - # Prepare output - if self._state.n_overflow > 0: - first_bin_counts = first_bin_counts.reshape( - first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:] + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._binner = BinnedAggregateTransformer( + BinnedAggregateSettings( + axis=self.settings.axis, + bin_duration=self.settings.bin_duration, + operation=AggregationFunction.SUM, + fractional=self.settings.fractional, ) - output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx) - else: - output_data = middle_bin_counts - - if self.settings.scale_output: - output_data = output_data / self.settings.bin_duration - - # Create the new output axis - # For the target axis, backup the offset by the number of samples in the overflow - out_axis = replace( - targ_axis, - gain=targ_axis.gain * samples_per_bin, - offset=targ_axis.offset - n_prev_overflow * targ_axis.gain, ) - out_msg = replace( - message, - data=output_data, - axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()}, - ) - - # Calculate and store the overflow state. - var_slice[targ_ax_idx] = slice(split_idx, None) - overflow_data = message.data[tuple(var_slice)] - self._state.n_overflow = overflow_data.shape[targ_ax_idx] - self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense() - return out_msg + def _process(self, message: AxisArray) -> AxisArray: + data = message.data + if isinstance(data, sparse.SparseArray): + data = data.todense() + xp = get_namespace(data) + # Per-sample contribution: the event value, or 1 per nonzero entry. + # float64 where usable (exact integer counts; matches the legacy output + # dtype); float32 otherwise. MLX *exposes* ``mx.float64`` as an attribute + # but the GPU rejects it ("float64 is not supported on the GPU"), so + # ``hasattr(xp, "float64")`` is not a sufficient capability check -- + # detect the MLX namespace explicitly and fall back to float32. + is_mlx = getattr(xp, "__name__", "") == "mlx.core" + float_dtype = xp.float64 if (hasattr(xp, "float64") and not is_mlx) else xp.float32 + contrib = data if self.settings.scale_by_value else (data != 0) + contrib = contrib.astype(float_dtype) + + binned = self._binner(replace(message, data=contrib)) + + if self.settings.scale_output and binned.data.size: + binned = replace(binned, data=binned.data / self.settings.bin_duration) + return binned class BinnedEventAggregatorUnit( BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator] ): SETTINGS = BinnedEventAggregatorSettings + + +def binned_event_aggregator( + bin_duration: float = 0.05, + scale_output: bool = True, + axis: str = "time", + fractional: bool = True, + scale_by_value: bool = False, +) -> BinnedEventAggregator: + return BinnedEventAggregator( + BinnedEventAggregatorSettings( + bin_duration=bin_duration, + scale_output=scale_output, + axis=axis, + fractional=fractional, + scale_by_value=scale_by_value, + ) + ) diff --git a/tests/test_binned.py b/tests/test_binned.py index 69c2797..209201e 100644 --- a/tests/test_binned.py +++ b/tests/test_binned.py @@ -1,6 +1,7 @@ import time import numpy as np +import pytest import sparse from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray @@ -87,3 +88,135 @@ def test_binned_event_aggregator_empty_time_first(): out_normal = proc(msg_normal) assert out_normal.data.ndim == 2 assert out_normal.data.shape[1] == N_CH + + +def _sparse_chunks(spk: np.ndarray, fs: float, block: int) -> list[AxisArray]: + out = [] + for start in range(0, spk.shape[0], block): + out.append( + AxisArray( + data=sparse.COO.from_numpy(spk[start : start + block]), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=start / fs)}, + ) + ) + return out + + +def _run(proc, msgs): + return [r for r in (proc(m) for m in msgs) if r.data.size] + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0, 30_030.0]) +def test_fractional_grid_offnominal(fs: float): + """At an off-nominal rate the fractional binner stays on the nominal-gain + wall-clock grid (gain == bin_duration, n_bins == int(n / (bin_duration*fs))), + which is what aligns it with the spike-band-power branch. The legacy + sample-locked path (Window) would instead report gain int(bin*fs)/fs.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(0).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = bin_dur * fs + assert sum(m.data.shape[0] for m in out) == int(n / spb) + assert out[0].axes["time"].gain == pytest.approx(bin_dur) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0]) +def test_chunk_invariance_offnominal(fs: float): + """Output is identical regardless of how the stream is chunked.""" + bin_dur = 0.02 + spk = (np.random.default_rng(1).random((120_000, N_CH)) < 0.02).astype(float) + + whole = np.concatenate( + [ + m.data + for m in _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 120_000), + ) + ], + axis=0, + ) + frag = np.concatenate( + [ + m.data + for m in _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 137), + ) + ], + axis=0, + ) + assert whole.shape == frag.shape + np.testing.assert_array_equal(whole, frag) + + +def test_count_vs_rate_scaling(): + """scale_output divides counts by bin_duration; otherwise raw counts.""" + fs = 30_000.0 + bin_dur = 0.02 + spk = (np.random.default_rng(2).random((60_000, N_CH)) < 0.02).astype(float) + msgs = _sparse_chunks(spk, fs, 60_000) + + counts = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False)), msgs + )[0] + rate = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=True)), msgs + )[0] + + # Counts are exact integers; rate is counts / bin_duration. + assert np.allclose(counts.data, np.round(counts.data)) + np.testing.assert_allclose(rate.data, counts.data / bin_dur, rtol=0, atol=1e-9) + + +def test_scale_by_value_weights_by_stored_value(): + """scale_by_value=True sums each event's stored value per bin; the default + counts nonzero entries (1 each).""" + fs = 30_000.0 + bin_dur = 0.02 # 600 samples/bin (exact at this fs) + spb = int(bin_dur * fs) + n = 3 * spb + + # One channel; known events with distinct values: two in bin 0, one in bin 1. + values = np.zeros((n, 1), dtype=float) + values[10, 0] = 2.0 + values[20, 0] = 3.0 # bin 0 -> count 2, value sum 5.0 + values[spb + 5, 0] = 7.0 # bin 1 -> count 1, value sum 7.0 + msg = AxisArray( + data=sparse.COO.from_numpy(values), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=0.0)}, + ) + + count = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=False) + )(msg) + weighted = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=True) + )(msg) + + np.testing.assert_array_equal(count.data[:, 0], [2.0, 1.0, 0.0]) + np.testing.assert_allclose(weighted.data[:, 0], [5.0, 7.0, 0.0]) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0, 30_030.0]) +def test_fractional_false_sample_locked(fs: float): + """fractional=False bins a fixed int(bin_duration*fs) sample count, so the + output gain is int(bin*fs)/fs (sample-locked, matching Window). fs=30030 + (0.02*fs = 600.6) is the discriminating case: truncation gives 600 where + round() would give 601.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(3).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, fractional=False)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = int(bin_dur * fs) + assert sum(m.data.shape[0] for m in out) == n // spb + assert out[0].axes["time"].gain == pytest.approx(spb / fs) From 3d080e98a2481cb4bb13adf27b04f9c2958f83bd Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 01:43:05 -0400 Subject: [PATCH 5/6] Scope binning docs to ezmsg-event/ezmsg-sigproc Drop references to consumer-pipeline concepts that public users of these packages won't recognize -- named "spike-band-power"/"spike-rate" branches, the "downstream Merge with no post-hoc reconciler" narrative, "real recordings", and the confusing "wall-clock" framing. Describe the behavior in terms of this package and the sigproc binner it delegates to (BinnedAggregate, Window, BinSchedule, Merge, ThresholdCrossing) instead. Comment-only. --- src/ezmsg/event/binned.py | 41 +++++++++++++------------ src/ezmsg/event/rate.py | 3 +- tests/test_bin_schedule_crosspackage.py | 2 +- tests/test_binned.py | 8 ++--- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 4fa37d1..0e64ded 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -1,21 +1,23 @@ """Bin an event stream into a lower-rate count (or rate) signal. -The binning itself is delegated to -:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so the spike-rate branch -shares a *single* bin-boundary implementation with the spike-band-power branch -(``Pow -> BinnedAggregate(MEAN)``). With ``fractional=True`` (the default) bins -span a fractional ``bin_duration * fs`` samples with a carry accumulator, track -wall-clock time, and are labelled with the nominal ``bin_duration`` gain -- -identical to how :obj:`ezmsg.event.rate.EventRate` is consumed downstream. That -makes the two branches land on the same grid at any input rate (including the -off-nominal ~30012 Hz of real recordings), so a downstream ``Merge`` aligns them -with no post-hoc reconciler. - -Sparse ``sparse.COO`` inputs (the default ``ThresholdCrossing`` output) are -densified to per-sample contributions before binning; dense inputs are used as -is. Set ``scale_by_value=True`` to weight each event by its stored value instead -of counting occurrences, and ``scale_output=True`` to divide the per-bin count -by ``bin_duration`` (events/second). +The binning is delegated to +:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so this shares one +bin-boundary implementation with every other consumer of that binner. With +``fractional=True`` (the default) bins span a fractional ``bin_duration * fs`` +samples with a carry accumulator and are labelled with the nominal +``bin_duration`` gain; with ``fractional=False`` they span a fixed +``int(bin_duration * fs)`` samples (sample-locked, matching +:obj:`ezmsg.sigproc.window.Window`). Because the grid comes from the shared +binner, two streams binned this way at the same ``bin_duration`` land on the +same grid for any input rate and can be aligned downstream (e.g. with +``ezmsg.sigproc.merge.Merge``). + +Sparse ``sparse.COO`` inputs (e.g. the default +:obj:`ezmsg.event.peak.ThresholdCrossing` output) are densified to per-sample +contributions before binning; dense inputs are used as is. Set +``scale_by_value=True`` to weight each event by its stored value instead of +counting occurrences, and ``scale_output=True`` to divide the per-bin count by +``bin_duration`` (events/second). """ import ezmsg.core as ez @@ -38,9 +40,10 @@ class BinnedEventAggregatorSettings(ez.Settings): """Name of the axis to bin along.""" fractional: bool = True - """If True (default), share ``EventRate``'s fractional wall-clock grid via - :obj:`BinnedAggregate` (nominal ``bin_duration`` gain). If False, use a fixed - ``int(bin_duration * fs)`` sample-locked grid. See :obj:`BinnedAggregate`.""" + """If True (default), bins span a fractional ``bin_duration * fs`` samples via + :obj:`BinnedAggregate` and are labelled with the nominal ``bin_duration`` + gain. If False, bins span a fixed ``int(bin_duration * fs)`` samples + (sample-locked). See :obj:`BinnedAggregate`.""" scale_by_value: bool = False """If True, weight each event by its stored value; if False (default), every diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index 85e864c..f63de69 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -17,8 +17,7 @@ class EventRateSettings(ez.Settings): """If True (default), bins span a fractional ``bin_duration * fs`` samples and the output rate is exactly ``1 / bin_duration``. If False, bins are sample-locked to ``int(bin_duration * fs)`` samples (output rate - ``fs / int(bin_duration * fs)``), matching :obj:`ezmsg.sigproc.window.Window` - so a sample-locked SBP branch and the spike-rate branch share one grid.""" + ``fs / int(bin_duration * fs)``), matching :obj:`ezmsg.sigproc.window.Window`.""" class Rate(BinnedKernelActivation): diff --git a/tests/test_bin_schedule_crosspackage.py b/tests/test_bin_schedule_crosspackage.py index e4f3701..acb9233 100644 --- a/tests/test_bin_schedule_crosspackage.py +++ b/tests/test_bin_schedule_crosspackage.py @@ -54,7 +54,7 @@ def test_eventrate_and_dense_binner_share_grid(fs: float, block_size: int): samp_off += chunk.shape[0] # Per-message: identical gain + offset on every non-empty output, and matching - # bin counts. This is exactly what a downstream Merge(align_axis="time") needs. + # bin counts -- enough to align two streams binned at the same bin_duration. for r, b in zip(rate_out, binner_out): assert r.data.shape[0] == b.data.shape[0] if r.data.shape[0] == 0: diff --git a/tests/test_binned.py b/tests/test_binned.py index 209201e..462f5bc 100644 --- a/tests/test_binned.py +++ b/tests/test_binned.py @@ -109,10 +109,10 @@ def _run(proc, msgs): @pytest.mark.parametrize("fs", [30_000.0, 30_012.0, 30_030.0]) def test_fractional_grid_offnominal(fs: float): - """At an off-nominal rate the fractional binner stays on the nominal-gain - wall-clock grid (gain == bin_duration, n_bins == int(n / (bin_duration*fs))), - which is what aligns it with the spike-band-power branch. The legacy - sample-locked path (Window) would instead report gain int(bin*fs)/fs.""" + """At an off-nominal rate the fractional binner stays on the nominal-gain grid + (gain == bin_duration, n_bins == int(n / (bin_duration*fs))). The sample-locked + path (fractional=False, matching Window) would instead report gain + int(bin*fs)/fs.""" bin_dur = 0.02 n = 300_000 spk = (np.random.default_rng(0).random((n, N_CH)) < 0.01).astype(float) From 9400ed5d9743bf76b9b561ca1020eb16ec23e5ef Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 01:45:49 -0400 Subject: [PATCH 6/6] Drop binned_event_aggregator factory function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The small functional-constructor helpers are a deprecated pattern; new modules shouldn't add them. BinnedEventAggregator(BinnedEventAggregatorSettings(...)) is the supported construction path. Unused — no callers. --- src/ezmsg/event/binned.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 0e64ded..474fb3d 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -96,21 +96,3 @@ class BinnedEventAggregatorUnit( BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator] ): SETTINGS = BinnedEventAggregatorSettings - - -def binned_event_aggregator( - bin_duration: float = 0.05, - scale_output: bool = True, - axis: str = "time", - fractional: bool = True, - scale_by_value: bool = False, -) -> BinnedEventAggregator: - return BinnedEventAggregator( - BinnedEventAggregatorSettings( - bin_duration=bin_duration, - scale_output=scale_output, - axis=axis, - fractional=fractional, - scale_by_value=scale_by_value, - ) - )