Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dynamic = ["version"]
dependencies = [
"ezmsg>=3.9.0",
"ezmsg-baseproc>=1.10.2",
"ezmsg-sigproc>=2.22.0",
"ezmsg-sigproc>=2.23.0",
"sparse>=0.17.0",
"numpy>=2.2.6",
]
Expand Down Expand Up @@ -65,3 +65,4 @@ known-third-party = ["ezmsg"]
[tool.uv.sources]
# Uncomment to use development version of ezmsg from git
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "dev" }
ezmsg-sigproc = { git = "https://github.com/ezmsg-org/ezmsg-sigproc.git", branch = "generalize-binner" }
190 changes: 88 additions & 102 deletions src/ezmsg/event/binned.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,113 @@
"""Bin an event stream into a lower-rate count (or rate) signal.

The binning itself is delegated to
:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate` so the spike-rate branch
shares a *single* bin-boundary implementation with the spike-band-power branch
(``Pow -> BinnedAggregate(MEAN)``). With ``fractional=True`` (the default) bins
span a fractional ``bin_duration * fs`` samples with a carry accumulator, track
wall-clock time, and are labelled with the nominal ``bin_duration`` gain --
identical to how :obj:`ezmsg.event.rate.EventRate` is consumed downstream. That
makes the two branches land on the same grid at any input rate (including the
off-nominal ~30012 Hz of real recordings), so a downstream ``Merge`` aligns them
with no post-hoc reconciler.

Sparse ``sparse.COO`` inputs (the default ``ThresholdCrossing`` output) are
densified to per-sample contributions before binning; dense inputs are used as
is. Set ``scale_by_value=True`` to weight each event by its stored value instead
of counting occurrences, and ``scale_output=True`` to divide the per-bin count
by ``bin_duration`` (events/second).
"""

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
import sparse
from array_api_compat import get_namespace
from ezmsg.baseproc import BaseTransformer, BaseTransformerUnit
from ezmsg.sigproc.aggregate import AggregationFunction
from ezmsg.sigproc.binned_aggregate import BinnedAggregateSettings, BinnedAggregateTransformer
from ezmsg.util.messages.axisarray import AxisArray, replace


class BinnedEventAggregatorSettings(ez.Settings):
bin_duration: float = 0.05
"""
Duration of each bin in seconds.
This is the time interval over which events will be counted.
"""
"""Duration of each output bin in seconds."""

scale_output: bool = True
"""
If True, the output will be scaled by the bin duration.
This is useful for converting counts to rates.
"""
"""If True, divide each bin's count by ``bin_duration`` (events/second)."""

axis: str = "time"
"""Name of the axis to bin along."""

fractional: bool = True
"""If True (default), share ``EventRate``'s fractional wall-clock grid via
:obj:`BinnedAggregate` (nominal ``bin_duration`` gain). If False, use a fixed
``round(bin_duration * fs)`` sample-locked grid. See :obj:`BinnedAggregate`."""

@processor_state
class BinnedEventAggregatorState:
n_overflow: int = 0
counts_in_overflow: npt.NDArray[np.int64] | None = None

scale_by_value: bool = False
"""If True, weight each event by its stored value; if False (default), every
nonzero entry contributes a count of 1."""

class BinnedEventAggregator(
BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState]
):
def _hash_message(self, message: AxisArray) -> int:
targ_ax_idx = message.get_axis_idx(self.settings.axis)
non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :]
return hash(tuple(non_targ_dims))

def _reset_state(self, message: AxisArray) -> None:
self._state.n_overflow = 0
targ_axis_idx = message.get_axis_idx(self.settings.axis)
buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64)
class BinnedEventAggregator(BaseTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray]):
"""Count events per fixed-duration bin, delegating binning to sigproc.

def _process(self, message: AxisArray) -> AxisArray:
# Quick maths
targ_ax_idx = message.get_axis_idx(self.settings.axis)
targ_axis = message.axes[self.settings.axis]
samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain))

# We will be slicing the data several times, so create a variable to hold the slices
var_slice = [slice(None)] * message.data.ndim

# Store for later use
n_prev_overflow = self._state.n_overflow

if self._state.n_overflow > 0:
# Calculate how many samples from the input msg we can fit into the first bin,
# given the current overflow state
n_first = samples_per_bin - self._state.n_overflow
# Sum the number of samples in the first bin then add to self._state.counts_in_overflow
var_slice[targ_ax_idx] = slice(0, n_first)
first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
first_bin_counts += self._state.counts_in_overflow
else:
n_first = 0
first_bin_counts = self._state.counts_in_overflow
assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration."

# Calculate how many samples remain after the first bin
n_remaining = message.data.shape[targ_ax_idx] - n_first
n_full_bins = int(n_remaining / samples_per_bin)

# Slice the n_first:-next_overflow samples into a segment that divides evenly into bins
split_idx = n_first + n_full_bins * samples_per_bin
var_slice[targ_ax_idx] = slice(n_first, split_idx)
full_bins_data = message.data[tuple(var_slice)]

# Reshape and sum for full bins
new_shape = (
full_bins_data.shape[:targ_ax_idx]
+ (n_full_bins, samples_per_bin)
+ full_bins_data.shape[targ_ax_idx + 1 :]
)
middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()
The per-bin reduction, carry across message boundaries, and output time axis
all come from :obj:`BinnedAggregateTransformer`; this wrapper only converts
events to per-sample contributions and optionally rate-normalizes.
"""

# Prepare output
if self._state.n_overflow > 0:
first_bin_counts = first_bin_counts.reshape(
first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._binner = BinnedAggregateTransformer(
BinnedAggregateSettings(
axis=self.settings.axis,
bin_duration=self.settings.bin_duration,
operation=AggregationFunction.SUM,
fractional=self.settings.fractional,
)
output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx)
else:
output_data = middle_bin_counts

if self.settings.scale_output:
output_data = output_data / self.settings.bin_duration

# Create the new output axis
# For the target axis, backup the offset by the number of samples in the overflow
out_axis = replace(
targ_axis,
gain=targ_axis.gain * samples_per_bin,
offset=targ_axis.offset - n_prev_overflow * targ_axis.gain,
)
out_msg = replace(
message,
data=output_data,
axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()},
)

# Calculate and store the overflow state.
var_slice[targ_ax_idx] = slice(split_idx, None)
overflow_data = message.data[tuple(var_slice)]
self._state.n_overflow = overflow_data.shape[targ_ax_idx]
self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense()

return out_msg
def _process(self, message: AxisArray) -> AxisArray:
data = message.data
if isinstance(data, sparse.SparseArray):
data = data.todense()
xp = get_namespace(data)
# Per-sample contribution: the event value, or 1 per nonzero entry.
# float64 where usable (exact integer counts; matches the legacy output
# dtype); float32 otherwise. MLX *exposes* ``mx.float64`` as an attribute
# but the GPU rejects it ("float64 is not supported on the GPU"), so
# ``hasattr(xp, "float64")`` is not a sufficient capability check --
# detect the MLX namespace explicitly and fall back to float32.
is_mlx = getattr(xp, "__name__", "") == "mlx.core"
float_dtype = xp.float64 if (hasattr(xp, "float64") and not is_mlx) else xp.float32
contrib = data if self.settings.scale_by_value else (data != 0)
contrib = contrib.astype(float_dtype)

binned = self._binner(replace(message, data=contrib))

if self.settings.scale_output and binned.data.size:
binned = replace(binned, data=binned.data / self.settings.bin_duration)
return binned


class BinnedEventAggregatorUnit(
BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator]
):
SETTINGS = BinnedEventAggregatorSettings


def binned_event_aggregator(
bin_duration: float = 0.05,
scale_output: bool = True,
axis: str = "time",
fractional: bool = True,
scale_by_value: bool = False,
) -> BinnedEventAggregator:
return BinnedEventAggregator(
BinnedEventAggregatorSettings(
bin_duration=bin_duration,
scale_output=scale_output,
axis=axis,
fractional=fractional,
scale_by_value=scale_by_value,
)
)
35 changes: 19 additions & 16 deletions src/ezmsg/event/rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,43 @@
from ezmsg.baseproc import BaseTransformerUnit
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.event.kernel_activation import (
ActivationKernelType,
BinAggregation,
BinnedKernelActivation,
BinnedKernelActivationSettings,
)
from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings


class EventRateSettings(ez.Settings):
bin_duration: float = 0.05

fractional: bool = True
"""If True (default), bins track wall-clock time on the nominal-``bin_duration``
grid (fractional samples-per-bin with a carry accumulator). If False, bins are
a fixed ``round(bin_duration * fs)`` samples (sample-locked). See
:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate`."""

class Rate(BinnedKernelActivation):

class Rate(BinnedEventAggregator):
"""
Event rate calculator (events per second).

Counts events per bin and divides by bin_duration to get rate in events/second.
Counts events per bin and divides by ``bin_duration`` to get rate in
events/second. Binning is delegated to
:obj:`ezmsg.sigproc.binned_aggregate.BinnedAggregate`, so the spike-rate
output shares one bin-boundary implementation -- and therefore the same
output grid -- with the spike-band-power branch.
"""

def __init__(self, settings: EventRateSettings) -> None:
super().__init__(
BinnedKernelActivationSettings(
kernel_type=ActivationKernelType.COUNT,
tau=1.0, # Not used for COUNT
BinnedEventAggregatorSettings(
bin_duration=settings.bin_duration,
aggregation=BinAggregation.SUM,
scale_by_value=False,
normalize=False,
rate_normalize=True,
scale_output=True, # counts -> events/second
axis="time",
fractional=settings.fractional,
scale_by_value=False, # count events, ignore stored peak values
)
)


class EventRate(BaseTransformerUnit[EventRateSettings, AxisArray, AxisArray, Rate]):
"""Unit for computing event rate from sparse events."""
"""Unit for computing event rate from sparse (or dense) events."""

SETTINGS = EventRateSettings
Loading
Loading