diff --git a/pyproject.toml b/pyproject.toml index 91929ff..80a2305 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dynamic = ["version"] dependencies = [ "ezmsg>=3.9.0", "ezmsg-baseproc>=1.10.2", - "ezmsg-sigproc>=2.23.0", + "ezmsg-sigproc>=2.26.0", "sparse>=0.17.0", "numpy>=2.2.6", ] diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 4d49dab..474fb3d 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -1,124 +1,95 @@ +"""Bin an event stream into a lower-rate count (or rate) signal. + +The binning is delegated to +:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so this shares one +bin-boundary implementation with every other consumer of that binner. With +``fractional=True`` (the default) bins span a fractional ``bin_duration * fs`` +samples with a carry accumulator and are labelled with the nominal +``bin_duration`` gain; with ``fractional=False`` they span a fixed +``int(bin_duration * fs)`` samples (sample-locked, matching +:obj:`ezmsg.sigproc.window.Window`). Because the grid comes from the shared +binner, two streams binned this way at the same ``bin_duration`` land on the +same grid for any input rate and can be aligned downstream (e.g. with +``ezmsg.sigproc.merge.Merge``). + +Sparse ``sparse.COO`` inputs (e.g. the default +:obj:`ezmsg.event.peak.ThresholdCrossing` output) are densified to per-sample +contributions before binning; dense inputs are used as is. Set +``scale_by_value=True`` to weight each event by its stored value instead of +counting occurrences, and ``scale_output=True`` to divide the per-bin count by +``bin_duration`` (events/second). 
+""" + import ezmsg.core as ez -import numpy as np -import numpy.typing as npt -from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, - processor_state, -) +import sparse +from array_api_compat import get_namespace +from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit +from ezmsg.sigproc.aggregate import AggregationFunction +from ezmsg.sigproc.binned_aggregate import BinnedAggregateSettings, BinnedAggregateTransformer from ezmsg.util.messages.axisarray import AxisArray, replace class BinnedEventAggregatorSettings(ez.Settings): bin_duration: float = 0.05 - """ - Duration of each bin in seconds. - This is the time interval over which events will be counted. - """ + """Duration of each output bin in seconds.""" scale_output: bool = True - """ - If True, the output will be scaled by the bin duration. - This is useful for converting counts to rates. - """ + """If True, divide each bin's count by ``bin_duration`` (events/second).""" axis: str = "time" + """Name of the axis to bin along.""" + fractional: bool = True + """If True (default), bins span a fractional ``bin_duration * fs`` samples via + :obj:`BinnedAggregate` and are labelled with the nominal ``bin_duration`` + gain. If False, bins span a fixed ``int(bin_duration * fs)`` samples + (sample-locked). 
See :obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate`.""" -@processor_state -class BinnedEventAggregatorState: - n_overflow: int = 0 - counts_in_overflow: npt.NDArray[np.int64] | None = None - + scale_by_value: bool = False + """If True, weight each event by its stored value; if False (default), every + nonzero entry contributes a count of 1.""" -class BinnedEventAggregator( - BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState] -): - def _hash_message(self, message: AxisArray) -> int: - targ_ax_idx = message.get_axis_idx(self.settings.axis) - non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :] - return hash(tuple(non_targ_dims)) - def _reset_state(self, message: AxisArray) -> None: - self._state.n_overflow = 0 - targ_axis_idx = message.get_axis_idx(self.settings.axis) - buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :] - self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64) +class BinnedEventAggregator(BaseTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray]): + """Count events per fixed-duration bin, delegating binning to sigproc. 
- def _process(self, message: AxisArray) -> AxisArray: - # Quick maths - targ_ax_idx = message.get_axis_idx(self.settings.axis) - targ_axis = message.axes[self.settings.axis] - samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain)) - - # We will be slicing the data several times, so create a variable to hold the slices - var_slice = [slice(None)] * message.data.ndim - - # Store for later use - n_prev_overflow = self._state.n_overflow - - if self._state.n_overflow > 0: - # Calculate how many samples from the input msg we can fit into the first bin, - # given the current overflow state - n_first = samples_per_bin - self._state.n_overflow - # Sum the number of samples in the first bin then add to self._state.counts_in_overflow - var_slice[targ_ax_idx] = slice(0, n_first) - first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() - first_bin_counts += self._state.counts_in_overflow - else: - n_first = 0 - first_bin_counts = self._state.counts_in_overflow - assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration." 
- - # Calculate how many samples remain after the first bin - n_remaining = message.data.shape[targ_ax_idx] - n_first - n_full_bins = int(n_remaining / samples_per_bin) - - # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins - split_idx = n_first + n_full_bins * samples_per_bin - var_slice[targ_ax_idx] = slice(n_first, split_idx) - full_bins_data = message.data[tuple(var_slice)] - - # Reshape and sum for full bins - new_shape = ( - full_bins_data.shape[:targ_ax_idx] - + (n_full_bins, samples_per_bin) - + full_bins_data.shape[targ_ax_idx + 1 :] - ) - middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + The per-bin reduction, carry across message boundaries, and output time axis + all come from :obj:`BinnedAggregateTransformer`; this wrapper only converts + events to per-sample contributions and optionally rate-normalizes. + """ - # Prepare output - if self._state.n_overflow > 0: - first_bin_counts = first_bin_counts.reshape( - first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:] + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._binner = BinnedAggregateTransformer( + BinnedAggregateSettings( + axis=self.settings.axis, + bin_duration=self.settings.bin_duration, + operation=AggregationFunction.SUM, + fractional=self.settings.fractional, ) - output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx) - else: - output_data = middle_bin_counts - - if self.settings.scale_output: - output_data = output_data / self.settings.bin_duration - - # Create the new output axis - # For the target axis, backup the offset by the number of samples in the overflow - out_axis = replace( - targ_axis, - gain=targ_axis.gain * samples_per_bin, - offset=targ_axis.offset - n_prev_overflow * targ_axis.gain, ) - out_msg = replace( - message, - data=output_data, - axes={k: v if k != self.settings.axis else out_axis for k, v in 
message.axes.items()}, - ) - - # Calculate and store the overflow state. - var_slice[targ_ax_idx] = slice(split_idx, None) - overflow_data = message.data[tuple(var_slice)] - self._state.n_overflow = overflow_data.shape[targ_ax_idx] - self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense() - return out_msg + def _process(self, message: AxisArray) -> AxisArray: + data = message.data + if isinstance(data, sparse.SparseArray): + data = data.todense() + xp = get_namespace(data) + # Per-sample contribution: the event value, or 1 per nonzero entry. + # float64 where usable (exact integer counts; matches the legacy output + # dtype); float32 otherwise. MLX *exposes* ``mx.float64`` as an attribute + # but the GPU rejects it ("float64 is not supported on the GPU"), so + # ``hasattr(xp, "float64")`` is not a sufficient capability check -- + # detect the MLX namespace explicitly and fall back to float32. + is_mlx = getattr(xp, "__name__", "") == "mlx.core" + float_dtype = xp.float64 if (hasattr(xp, "float64") and not is_mlx) else xp.float32 + contrib = data if self.settings.scale_by_value else (data != 0) + contrib = contrib.astype(float_dtype) + + binned = self._binner(replace(message, data=contrib)) + + if self.settings.scale_output and binned.data.size: + binned = replace(binned, data=binned.data / self.settings.bin_duration) + return binned class BinnedEventAggregatorUnit( diff --git a/src/ezmsg/event/kernel_activation.py b/src/ezmsg/event/kernel_activation.py index 8f3d510..471129a 100644 --- a/src/ezmsg/event/kernel_activation.py +++ b/src/ezmsg/event/kernel_activation.py @@ -23,6 +23,7 @@ import sparse from array_api_compat import get_namespace, is_numpy_array from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state +from ezmsg.sigproc.util.binning import BinSchedule from ezmsg.util.messages.axisarray import AxisArray, replace @@ -77,7 +78,16 @@ class BinnedKernelActivationSettings(ez.Settings): """If True, 
normalize kernel so integral equals 1.""" rate_normalize: bool = False - """If True, divide output by bin_duration to get events/second (for COUNT kernel).""" + """If True, divide output by the bin's duration to get events/second (for COUNT + kernel). The divisor is the *actual* bin duration (``samples_per_bin / fs``), + which equals ``bin_duration`` exactly in fractional mode.""" + + fractional: bool = True + """If True (default), bins span a *fractional* ``bin_duration * fs`` samples + with a carry accumulator; bins track the nominal duration and the output gain + is exactly ``bin_duration``. If False, bins span a *fixed* ``int(bin_duration * + fs)`` samples and the output gain is ``int(bin_duration * fs) / fs`` + (sample-locked, matching :obj:`ezmsg.sigproc.window.Window`).""" @processor_state @@ -96,8 +106,10 @@ class BinnedKernelActivationState: # Input sample rate (cached from first message) fs: float | None = None - # Accumulated fractional bin samples for proper bin alignment - bin_accumulator: float = 0.0 + # Shared bin-boundary schedule. Owns samples-per-bin, the fractional carry, + # the global bin index, and the output gain/offset -- the same primitive the + # ezmsg-sigproc dense binner uses, so both land on an identical grid. + schedule: BinSchedule | None = None class BinnedKernelActivation( @@ -141,7 +153,6 @@ def _reset_state(self, message: AxisArray) -> None: self._state.activation = np.zeros(n_channels, dtype=np.float64) self._state.samples_since_update = np.zeros(n_channels, dtype=np.int64) - self._state.bin_accumulator = 0.0 # For alpha kernel, we need auxiliary state if self.settings.kernel_type == ActivationKernelType.ALPHA: @@ -152,6 +163,11 @@ def _reset_state(self, message: AxisArray) -> None: if time_axis.gain > 0: self._state.fs = 1.0 / time_axis.gain + # The schedule owns the boundary arithmetic from here on; fractional mode + # tracks the nominal bin_duration, sample-locked mode matches Window's grid. 
+ self._state.schedule = BinSchedule(bin_duration=self.settings.bin_duration, fractional=self.settings.fractional) + self._state.schedule.reset(self._state.fs) + def _decay_to_sample(self, channel: int, target_sample: int) -> None: """ Decay activation state to target sample. @@ -238,16 +254,16 @@ def _process_events(self, message: AxisArray) -> AxisArray: n_samples = sparse_data.shape[0] n_channels = sparse_data.shape[1] if sparse_data.ndim > 1 else 1 - # Calculate bin parameters - samples_per_bin = self.settings.bin_duration * self._state.fs - total_samples = n_samples + self._state.bin_accumulator - n_bins = int(total_samples / samples_per_bin) + # Boundary arithmetic is delegated to the shared schedule (samples-per-bin, + # fractional carry, global bin index, output gain/offset). + in_offset = message.axes["time"].offset if "time" in message.axes else 0.0 + n_carry_before = self._state.schedule.carry_count + step = self._state.schedule.advance(n_new=n_samples, in_offset=in_offset, gain_in=1.0 / self._state.fs) + n_bins = step.n_bins if n_bins == 0: - # Not enough samples for a full bin yet - self._state.bin_accumulator = total_samples - - # Still need to process events to update state + # Not enough samples for a full bin yet (the schedule has folded these + # samples into its carry). Still process events to update state. 
if hasattr(sparse_data, "coords") and hasattr(sparse_data, "data"): coords = sparse_data.coords values = sparse_data.data @@ -263,18 +279,15 @@ def _process_events(self, message: AxisArray) -> AxisArray: data=np.zeros((0, n_channels), dtype=np.float64), axes={ **message.axes, - "time": replace(message.axes["time"], gain=self.settings.bin_duration), + "time": replace(message.axes["time"], gain=step.output_gain), }, ) - # Calculate bin boundaries (in input samples, relative to chunk start) - # Account for accumulator from previous chunk - accumulator_before = self._state.bin_accumulator # Save for offset calculation - first_bin_end = samples_per_bin - self._state.bin_accumulator - bin_ends = first_bin_end + np.arange(n_bins) * samples_per_bin - - # Update accumulator for next chunk - self._state.bin_accumulator = total_samples - n_bins * samples_per_bin + # Bin ends in input samples relative to *this chunk's* start. The schedule + # returns cut points into [carry ++ new]; subtracting the pre-advance carry + # count maps them back to chunk-local indices (identical to the legacy + # `(spb - acc) + arange*spb` truncated formula). + bin_ends = np.asarray(step.cut_points, dtype=np.int64) - n_carry_before # Collect events sorted by time events = [] @@ -362,15 +375,10 @@ def _process_events(self, message: AxisArray) -> AxisArray: # Update state sample counters relative to next chunk self._state.samples_since_update -= n_samples - # Apply rate normalization if requested (divide by bin_duration to get events/second) + # Apply rate normalization if requested (divide by the bin's actual + # duration to get events/second; == bin_duration in fractional mode). 
if self.settings.rate_normalize: - output = output / self.settings.bin_duration - - # Calculate output time offset - # The first bin starts at (input_offset - accumulator_time) - input_offset = message.axes["time"].offset if "time" in message.axes else 0.0 - accumulator_time = accumulator_before / self._state.fs - output_offset = input_offset - accumulator_time + output = output / step.output_gain return replace( message, @@ -378,8 +386,8 @@ def _process_events(self, message: AxisArray) -> AxisArray: axes={ **message.axes, "time": AxisArray.TimeAxis( - fs=1.0 / self.settings.bin_duration, - offset=output_offset, + fs=1.0 / step.output_gain, + offset=step.output_offset, ), }, ) @@ -397,10 +405,10 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: n_samples = data.shape[0] feature_shape = tuple(data.shape[1:]) - samples_per_bin = self.settings.bin_duration * self._state.fs - accumulator_before = self._state.bin_accumulator - total_samples = n_samples + accumulator_before - n_bins = int(total_samples / samples_per_bin) + in_offset = message.axes["time"].offset if "time" in message.axes else 0.0 + n_carry_before = self._state.schedule.carry_count + step = self._state.schedule.advance(n_new=n_samples, in_offset=in_offset, gain_in=1.0 / self._state.fs) + n_bins = step.n_bins # Per-sample contribution: 1 per non-zero, or the value itself if scaling. # Use the .astype() method form so the same call works for both numpy and mlx @@ -419,20 +427,18 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: # No complete bins this chunk — accumulate everything into the carry-over. 
new_overflow = overflow_xp + (xp.sum(contrib, axis=0) if n_samples > 0 else overflow_xp * 0) self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape) - self._state.bin_accumulator = total_samples return replace( message, data=xp.zeros((0,) + feature_shape, dtype=xp.float32), axes={ **message.axes, - "time": replace(message.axes["time"], gain=self.settings.bin_duration), + "time": replace(message.axes["time"], gain=step.output_gain), }, ) - # Bin boundaries (in input-sample space, integer-truncated as in the event-based path). - first_bin_end = samples_per_bin - accumulator_before - bin_ends_float = first_bin_end + np.arange(n_bins) * samples_per_bin - bin_end_samples = bin_ends_float.astype(np.int64) + # Bin boundaries in this chunk's sample space, from the shared schedule + # (cut points into [carry ++ new] mapped back to chunk-local indices). + bin_end_samples = np.asarray(step.cut_points, dtype=np.int64) - n_carry_before bin_start_samples = np.concatenate(([np.int64(0)], bin_end_samples[:-1])) # Cumulative sum, prepended with zeros so cumsum_padded[k] = sum(contrib[:k]). 
@@ -462,14 +468,9 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: else: new_overflow = xp.zeros(feature_shape, dtype=cumsum.dtype) self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape) - self._state.bin_accumulator = total_samples - n_bins * samples_per_bin if self.settings.rate_normalize: - output = output / self.settings.bin_duration - - accumulator_time = accumulator_before / self._state.fs - input_offset = message.axes["time"].offset if "time" in message.axes else 0.0 - output_offset = input_offset - accumulator_time + output = output / step.output_gain return replace( message, @@ -477,8 +478,8 @@ def _process_dense_count_sum(self, message: AxisArray) -> AxisArray: axes={ **message.axes, "time": AxisArray.TimeAxis( - fs=1.0 / self.settings.bin_duration, - offset=output_offset, + fs=1.0 / step.output_gain, + offset=step.output_offset, ), }, ) diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index e2d07dc..f63de69 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -13,6 +13,12 @@ class EventRateSettings(ez.Settings): bin_duration: float = 0.05 + fractional: bool = True + """If True (default), bins span a fractional ``bin_duration * fs`` samples and + the output rate is exactly ``1 / bin_duration``. 
If False, bins are + sample-locked to ``int(bin_duration * fs)`` samples (output rate + ``fs / int(bin_duration * fs)``), matching :obj:`ezmsg.sigproc.window.Window`.""" + class Rate(BinnedKernelActivation): """ @@ -31,6 +37,7 @@ def __init__(self, settings: EventRateSettings) -> None: scale_by_value=False, normalize=False, rate_normalize=True, + fractional=settings.fractional, ) ) diff --git a/tests/test_bin_schedule_crosspackage.py b/tests/test_bin_schedule_crosspackage.py new file mode 100644 index 0000000..acb9233 --- /dev/null +++ b/tests/test_bin_schedule_crosspackage.py @@ -0,0 +1,110 @@ +"""Cross-package proof that EventRate and the ezmsg-sigproc dense binner share a grid. + +Both now route their bin boundaries through ``ezmsg.sigproc.util.binning.BinSchedule``. +This test feeds identical data to the real ``EventRate`` (count/sum path) and to +``BinnedAggregateTransformer`` (operation=SUM) and asserts they produce the same +output time axis (gain + offset) and the same bin counts -- the alignment property +the shared primitive exists to guarantee -- including at an off-nominal sample +rate and under adversarial chunking. + +EventRate has ``rate_normalize=True`` (counts / bin_duration), so its values equal +the SUM binner's values divided by bin_duration; everything else must match exactly. 
+""" + +import numpy as np +import pytest +from ezmsg.sigproc.aggregate import AggregationFunction +from ezmsg.sigproc.binned_aggregate import BinnedAggregateTransformer +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.event.rate import EventRateSettings, Rate + + +def _make_msg(arr: np.ndarray, fs: float, offset: float) -> AxisArray: + return AxisArray( + data=arr, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=offset), + "ch": AxisArray.CoordinateAxis(data=np.arange(arr.shape[1]), dims=["ch"]), + }, + ) + + +@pytest.mark.parametrize("fs", [1000.0, 30012.0, 30030.0]) +@pytest.mark.parametrize("block_size", [50000, 1, 777]) +def test_eventrate_and_dense_binner_share_grid(fs: float, block_size: int): + bin_duration = 0.02 + rng = np.random.default_rng(0) + # Dense spike train (0/1). EventRate's dense COUNT+SUM path counts non-zeros; + # BinnedAggregate(SUM) sums them -- same quantity per bin. + spikes = (rng.random((40000, 3)) < 0.01).astype(np.float32) + + rate = Rate(EventRateSettings(bin_duration=bin_duration)) + binner = BinnedAggregateTransformer( + axis="time", bin_duration=bin_duration, operation=AggregationFunction.SUM, fractional=True + ) + + rate_out, binner_out, samp_off = [], [], 0 + for start in range(0, spikes.shape[0], block_size): + chunk = spikes[start : start + block_size] + offset = samp_off / fs + rate_out.append(rate(_make_msg(chunk, fs, offset))) + binner_out.append(binner(_make_msg(chunk, fs, offset))) + samp_off += chunk.shape[0] + + # Per-message: identical gain + offset on every non-empty output, and matching + # bin counts -- enough to align two streams binned at the same bin_duration. 
+ for r, b in zip(rate_out, binner_out): + assert r.data.shape[0] == b.data.shape[0] + if r.data.shape[0] == 0: + continue + assert r.axes["time"].gain == pytest.approx(b.axes["time"].gain) + assert r.axes["time"].offset == pytest.approx(b.axes["time"].offset) + + rate_all = np.concatenate([m.data for m in rate_out], axis=0) + binner_all = np.concatenate([m.data for m in binner_out], axis=0) + assert rate_all.shape == binner_all.shape + # EventRate divides counts by bin_duration (rate_normalize); undo to compare. + np.testing.assert_allclose(rate_all * bin_duration, binner_all, rtol=0, atol=1e-6) + + +@pytest.mark.parametrize("fs", [30012.0, 30030.0]) +@pytest.mark.parametrize("block_size", [40000, 1, 333]) +def test_eventrate_fractional_false_is_sample_locked(fs: float, block_size: int): + """fractional=False EventRate bins on Window's int(bin_duration*fs) grid and + shares it with BinnedAggregate(fractional=False). Its output rate is the + sample-locked fs/int(bin_duration*fs), not the nominal 1/bin_duration.""" + bin_duration = 0.02 + window_samples = int(bin_duration * fs) + expected_gain = window_samples / fs + rng = np.random.default_rng(1) + spikes = (rng.random((40000, 2)) < 0.01).astype(np.float32) + + rate = Rate(EventRateSettings(bin_duration=bin_duration, fractional=False)) + binner = BinnedAggregateTransformer( + axis="time", bin_duration=bin_duration, operation=AggregationFunction.SUM, fractional=False + ) + + rate_out, binner_out, samp_off = [], [], 0 + for start in range(0, spikes.shape[0], block_size): + chunk = spikes[start : start + block_size] + offset = samp_off / fs + rate_out.append(rate(_make_msg(chunk, fs, offset))) + binner_out.append(binner(_make_msg(chunk, fs, offset))) + samp_off += chunk.shape[0] + + for r, b in zip(rate_out, binner_out): + assert r.data.shape[0] == b.data.shape[0] + if r.data.shape[0] == 0: + continue + # Sample-locked gain (== Window's), not the nominal bin_duration. 
+ assert r.axes["time"].gain == pytest.approx(expected_gain) + assert r.axes["time"].gain == pytest.approx(b.axes["time"].gain) + assert r.axes["time"].offset == pytest.approx(b.axes["time"].offset) + + rate_all = np.concatenate([m.data for m in rate_out], axis=0) + binner_all = np.concatenate([m.data for m in binner_out], axis=0) + assert rate_all.shape == binner_all.shape + # rate_normalize now divides by the actual bin duration (== expected_gain). + np.testing.assert_allclose(rate_all * expected_gain, binner_all, rtol=0, atol=1e-6) diff --git a/tests/test_binned.py b/tests/test_binned.py index 69c2797..462f5bc 100644 --- a/tests/test_binned.py +++ b/tests/test_binned.py @@ -1,6 +1,7 @@ import time import numpy as np +import pytest import sparse from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray @@ -87,3 +88,135 @@ def test_binned_event_aggregator_empty_time_first(): out_normal = proc(msg_normal) assert out_normal.data.ndim == 2 assert out_normal.data.shape[1] == N_CH + + +def _sparse_chunks(spk: np.ndarray, fs: float, block: int) -> list[AxisArray]: + out = [] + for start in range(0, spk.shape[0], block): + out.append( + AxisArray( + data=sparse.COO.from_numpy(spk[start : start + block]), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=start / fs)}, + ) + ) + return out + + +def _run(proc, msgs): + return [r for r in (proc(m) for m in msgs) if r.data.size] + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0, 30_030.0]) +def test_fractional_grid_offnominal(fs: float): + """At an off-nominal rate the fractional binner stays on the nominal-gain grid + (gain == bin_duration, n_bins == int(n / (bin_duration*fs))). 
The sample-locked + path (fractional=False, matching Window) would instead report gain + int(bin*fs)/fs.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(0).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = bin_dur * fs + assert sum(m.data.shape[0] for m in out) == int(n / spb) + assert out[0].axes["time"].gain == pytest.approx(bin_dur) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0]) +def test_chunk_invariance_offnominal(fs: float): + """Output is identical regardless of how the stream is chunked.""" + bin_dur = 0.02 + spk = (np.random.default_rng(1).random((120_000, N_CH)) < 0.02).astype(float) + + whole = np.concatenate( + [ + m.data + for m in _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 120_000), + ) + ], + axis=0, + ) + frag = np.concatenate( + [ + m.data + for m in _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)), + _sparse_chunks(spk, fs, 137), + ) + ], + axis=0, + ) + assert whole.shape == frag.shape + np.testing.assert_array_equal(whole, frag) + + +def test_count_vs_rate_scaling(): + """scale_output divides counts by bin_duration; otherwise raw counts.""" + fs = 30_000.0 + bin_dur = 0.02 + spk = (np.random.default_rng(2).random((60_000, N_CH)) < 0.02).astype(float) + msgs = _sparse_chunks(spk, fs, 60_000) + + counts = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False)), msgs + )[0] + rate = _run( + BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=True)), msgs + )[0] + + # Counts are exact integers; rate is counts / bin_duration. 
+ assert np.allclose(counts.data, np.round(counts.data)) + np.testing.assert_allclose(rate.data, counts.data / bin_dur, rtol=0, atol=1e-9) + + +def test_scale_by_value_weights_by_stored_value(): + """scale_by_value=True sums each event's stored value per bin; the default + counts nonzero entries (1 each).""" + fs = 30_000.0 + bin_dur = 0.02 # 600 samples/bin (exact at this fs) + spb = int(bin_dur * fs) + n = 3 * spb + + # One channel; known events with distinct values: two in bin 0, one in bin 1. + values = np.zeros((n, 1), dtype=float) + values[10, 0] = 2.0 + values[20, 0] = 3.0 # bin 0 -> count 2, value sum 5.0 + values[spb + 5, 0] = 7.0 # bin 1 -> count 1, value sum 7.0 + msg = AxisArray( + data=sparse.COO.from_numpy(values), + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=fs, offset=0.0)}, + ) + + count = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=False) + )(msg) + weighted = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, scale_output=False, scale_by_value=True) + )(msg) + + np.testing.assert_array_equal(count.data[:, 0], [2.0, 1.0, 0.0]) + np.testing.assert_allclose(weighted.data[:, 0], [5.0, 7.0, 0.0]) + + +@pytest.mark.parametrize("fs", [30_000.0, 30_012.0, 30_030.0]) +def test_fractional_false_sample_locked(fs: float): + """fractional=False bins a fixed int(bin_duration*fs) sample count, so the + output gain is int(bin*fs)/fs (sample-locked, matching Window). 
fs=30030 + (0.02*fs = 600.6) is the discriminating case: truncation gives 600 where + round() would give 601.""" + bin_dur = 0.02 + n = 300_000 + spk = (np.random.default_rng(3).random((n, N_CH)) < 0.01).astype(float) + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur, fractional=False)) + out = _run(proc, _sparse_chunks(spk, fs, 7000)) + + spb = int(bin_dur * fs) + assert sum(m.data.shape[0] for m in out) == n // spb + assert out[0].axes["time"].gain == pytest.approx(spb / fs)