Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dynamic = ["version"]
dependencies = [
"ezmsg>=3.9.0",
"ezmsg-baseproc>=1.10.2",
"ezmsg-sigproc>=2.23.0",
"ezmsg-sigproc>=2.26.0",
"sparse>=0.17.0",
"numpy>=2.2.6",
]
Expand Down
175 changes: 73 additions & 102 deletions src/ezmsg/event/binned.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,95 @@
"""Bin an event stream into a lower-rate count (or rate) signal.

The binning is delegated to
:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so this shares one
bin-boundary implementation with every other consumer of that binner. With
``fractional=True`` (the default) bins span a fractional ``bin_duration * fs``
samples with a carry accumulator and are labelled with the nominal
``bin_duration`` gain; with ``fractional=False`` they span a fixed
``int(bin_duration * fs)`` samples (sample-locked, matching
:obj:`ezmsg.sigproc.window.Window`). Because the grid comes from the shared
binner, two streams binned this way at the same ``bin_duration`` land on the
same grid for any input rate and can be aligned downstream (e.g. with
``ezmsg.sigproc.merge.Merge``).

Sparse ``sparse.COO`` inputs (e.g. the default
:obj:`ezmsg.event.peak.ThresholdCrossing` output) are densified to per-sample
contributions before binning; dense inputs are used as is. Set
``scale_by_value=True`` to weight each event by its stored value instead of
counting occurrences, and ``scale_output=True`` to divide the per-bin count by
``bin_duration`` (events/second).
"""

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
import sparse
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
from ezmsg.sigproc.aggregate import AggregationFunction
from ezmsg.sigproc.binned_aggregate import BinnedAggregateSettings, BinnedAggregateTransformer
from ezmsg.util.messages.axisarray import AxisArray, replace


class BinnedEventAggregatorSettings(ez.Settings):
bin_duration: float = 0.05
"""
Duration of each bin in seconds.
This is the time interval over which events will be counted.
"""
"""Duration of each output bin in seconds."""

scale_output: bool = True
"""
If True, the output will be scaled by the bin duration.
This is useful for converting counts to rates.
"""
"""If True, divide each bin's count by ``bin_duration`` (events/second)."""

axis: str = "time"
"""Name of the axis to bin along."""

fractional: bool = True
"""If True (default), bins span a fractional ``bin_duration * fs`` samples via
:obj:`BinnedAggregate` and are labelled with the nominal ``bin_duration``
gain. If False, bins span a fixed ``int(bin_duration * fs)`` samples
(sample-locked). See :obj:`BinnedAggregate`."""

@processor_state
class BinnedEventAggregatorState:
n_overflow: int = 0
counts_in_overflow: npt.NDArray[np.int64] | None = None

scale_by_value: bool = False
"""If True, weight each event by its stored value; if False (default), every
nonzero entry contributes a count of 1."""

class BinnedEventAggregator(
BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState]
):
def _hash_message(self, message: AxisArray) -> int:
targ_ax_idx = message.get_axis_idx(self.settings.axis)
non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :]
return hash(tuple(non_targ_dims))

def _reset_state(self, message: AxisArray) -> None:
self._state.n_overflow = 0
targ_axis_idx = message.get_axis_idx(self.settings.axis)
buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64)
class BinnedEventAggregator(BaseTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray]):
"""Count events per fixed-duration bin, delegating binning to sigproc.

def _process(self, message: AxisArray) -> AxisArray:
# Quick maths
targ_ax_idx = message.get_axis_idx(self.settings.axis)
targ_axis = message.axes[self.settings.axis]
samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain))

# We will be slicing the data several times, so create a variable to hold the slices
var_slice = [slice(None)] * message.data.ndim

# Store for later use
n_prev_overflow = self._state.n_overflow

if self._state.n_overflow > 0:
# Calculate how many samples from the input msg we can fit into the first bin,
# given the current overflow state
n_first = samples_per_bin - self._state.n_overflow
# Sum the number of samples in the first bin then add to self._state.counts_in_overflow
var_slice[targ_ax_idx] = slice(0, n_first)
first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
first_bin_counts += self._state.counts_in_overflow
else:
n_first = 0
first_bin_counts = self._state.counts_in_overflow
assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration."

# Calculate how many samples remain after the first bin
n_remaining = message.data.shape[targ_ax_idx] - n_first
n_full_bins = int(n_remaining / samples_per_bin)

# Slice the n_first:-next_overflow samples into a segment that divides evenly into bins
split_idx = n_first + n_full_bins * samples_per_bin
var_slice[targ_ax_idx] = slice(n_first, split_idx)
full_bins_data = message.data[tuple(var_slice)]

# Reshape and sum for full bins
new_shape = (
full_bins_data.shape[:targ_ax_idx]
+ (n_full_bins, samples_per_bin)
+ full_bins_data.shape[targ_ax_idx + 1 :]
)
middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()
The per-bin reduction, carry across message boundaries, and output time axis
all come from :obj:`BinnedAggregateTransformer`; this wrapper only converts
events to per-sample contributions and optionally rate-normalizes.
"""

# Prepare output
if self._state.n_overflow > 0:
first_bin_counts = first_bin_counts.reshape(
first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._binner = BinnedAggregateTransformer(
BinnedAggregateSettings(
axis=self.settings.axis,
bin_duration=self.settings.bin_duration,
operation=AggregationFunction.SUM,
fractional=self.settings.fractional,
)
output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx)
else:
output_data = middle_bin_counts

if self.settings.scale_output:
output_data = output_data / self.settings.bin_duration

# Create the new output axis
# For the target axis, backup the offset by the number of samples in the overflow
out_axis = replace(
targ_axis,
gain=targ_axis.gain * samples_per_bin,
offset=targ_axis.offset - n_prev_overflow * targ_axis.gain,
)
out_msg = replace(
message,
data=output_data,
axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()},
)

# Calculate and store the overflow state.
var_slice[targ_ax_idx] = slice(split_idx, None)
overflow_data = message.data[tuple(var_slice)]
self._state.n_overflow = overflow_data.shape[targ_ax_idx]
self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense()

return out_msg
def _process(self, message: AxisArray) -> AxisArray:
    """Turn events into per-sample weights, bin them, and optionally rate-scale.

    Args:
        message: Input whose ``data`` is either a ``sparse.SparseArray`` (which
            is densified first) or an already-dense array.

    Returns:
        The message produced by ``self._binner`` for the weighted samples. When
        ``scale_output`` is set and the binned data is non-empty, its data is
        divided by ``bin_duration``.
    """
    samples = message.data
    if isinstance(samples, sparse.SparseArray):
        samples = samples.todense()
    xp = get_namespace(samples)

    # Choose the working float dtype. float64 is preferred where it is usable
    # (exact integer counts; matches the legacy output dtype). MLX *exposes*
    # ``mx.float64`` as an attribute but the GPU rejects it ("float64 is not
    # supported on the GPU"), so the presence of the attribute alone is not a
    # sufficient capability check -- the MLX namespace is detected by name and
    # routed to float32 instead.
    float64_usable = hasattr(xp, "float64") and getattr(xp, "__name__", "") != "mlx.core"
    if float64_usable:
        target_dtype = xp.float64
    else:
        target_dtype = xp.float32

    # Per-sample weight: the stored event value, or 1 for every nonzero entry.
    if self.settings.scale_by_value:
        weights = samples
    else:
        weights = samples != 0
    weights = weights.astype(target_dtype)

    result = self._binner(replace(message, data=weights))

    # Leave the result untouched when rate scaling is off or there is no
    # binned data to scale.
    if not self.settings.scale_output or not result.data.size:
        return result
    return replace(result, data=result.data / self.settings.bin_duration)


class BinnedEventAggregatorUnit(
Expand Down
Loading
Loading