Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/ezmsg/sigproc/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ class ConcatSettings(ez.Settings):
new_key: str | None = None
"""Output AxisArray key. If None, uses the key from signal A."""

auto_coerce_backend: bool = False
"""If True, silently coerce signal B to signal A's array namespace when the
two inputs are on mismatched backends (e.g. MLX vs numpy). Defaults to False
(strict): a backend mismatch raises a clear error instead, since it is almost
always an upstream bug and silent coercion hides device<->host copies."""


@dataclass
class ConcatState:
Expand Down Expand Up @@ -450,6 +456,18 @@ def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
new_axis = concat_dim not in a.dims

xp = get_namespace(a.data)
xp_b = get_namespace(b.data)
if xp_b is not xp:
if self.settings.auto_coerce_backend:
b = replace(b, data=xp.asarray(b.data))
else:
raise TypeError(
f"Concat received inputs on mismatched backends: "
f"a.data namespace={xp.__name__}, b.data namespace={xp_b.__name__}. "
f"Coerce both inputs to one backend upstream "
f"(e.g., via ezmsg.sigproc.asarray.AsArrayTransformer) before merging, "
f"or set ConcatSettings(auto_coerce_backend=True) to coerce B to A's backend."
)

# expand_dims for new-axis concatenation.
if new_axis:
Expand Down
8 changes: 8 additions & 0 deletions src/ezmsg/sigproc/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ezmsg.core as ez
import numpy as np
import scipy.interpolate
from array_api_compat import get_namespace
from ezmsg.baseproc import (
BaseConsumerUnit,
BaseStatefulProcessor,
Expand Down Expand Up @@ -235,6 +236,13 @@ def __next__(self) -> AxisArray:
# Calculate output
resampled_data = f(xnew)

# scipy.interpolate.interp1d always returns numpy. Coerce back to the
# source's namespace so downstream merges with same-backend streams
# don't see a backend mismatch.
src_xp = get_namespace(src_axarr.data)
if get_namespace(resampled_data) is not src_xp:
resampled_data = src_xp.asarray(resampled_data)

# Create output message
if hasattr(ref_ax, "data"):
out_ax = replace(ref_ax, data=xnew)
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,65 @@ def _msg_mx(np_data, ch_labels):
np.testing.assert_array_equal(np.asarray(result.data[:, 2:]), 3.0)


class TestBackendMismatch:
    """Backend-mismatch handling in ``ConcatProcessor._concat``.

    With default settings a mismatch between the two inputs' array backends
    must raise a clear, named error rather than a cryptic MLX-internals
    error; with ``auto_coerce_backend=True`` input B is coerced onto input
    A's backend instead.
    """

    @staticmethod
    def _msg_mx(np_data, ch_labels):
        """Build a (time, ch) AxisArray whose data is an MLX array."""
        import mlx.core as mx

        time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0)
        ch_axis = CoordinateAxis(data=np.array(ch_labels), dims=["ch"], unit="label")
        return AxisArray(
            data=mx.array(np_data),
            dims=["time", "ch"],
            axes=frozendict({"time": time_axis, "ch": ch_axis}),
            key="",
        )

    @requires_mlx
    def test_mismatch_raises_clear_error(self):
        """An MLX ``a`` with a numpy ``b`` raises the named mismatch error."""
        n_samples = 5
        mlx_msg = self._msg_mx(
            np.ones((n_samples, 2), dtype=np.float32),
            ["A0", "A1"],
        )
        numpy_msg = _make_msg(
            np.full((n_samples, 2), 2.0, dtype=np.float32),
            ch_labels=["B0", "B1"],
        )

        strict_proc = ConcatProcessor(ConcatSettings(axis="feature"))
        with pytest.raises(TypeError, match="mismatched backends"):
            strict_proc._concat(mlx_msg, numpy_msg)

    @requires_mlx
    def test_auto_coerce_backend_coerces_b_to_a(self):
        """Opting in to coercion moves ``b`` onto ``a``'s backend, no error."""
        import mlx.core as mx

        n_samples = 5
        mlx_msg = self._msg_mx(
            np.ones((n_samples, 2), dtype=np.float32),
            ["A0", "A1"],
        )
        numpy_msg = _make_msg(
            np.full((n_samples, 2), 2.0, dtype=np.float32),
            ch_labels=["B0", "B1"],
        )

        coercing_settings = ConcatSettings(axis="feature", auto_coerce_backend=True)
        merged = ConcatProcessor(coercing_settings)._concat(mlx_msg, numpy_msg)

        # The output stays on A's backend and stacks the inputs on a new last axis.
        assert isinstance(merged.data, mx.array)
        assert merged.data.shape == (n_samples, 2, 2)
        np.testing.assert_array_equal(np.asarray(merged.data[:, :, 0]), 1.0)
        np.testing.assert_array_equal(np.asarray(merged.data[:, :, 1]), 2.0)

    def test_same_backend_numpy_unaffected(self):
        """Two numpy inputs pass through the default (strict) path untouched."""
        n_samples = 5
        first = _make_msg(np.ones((n_samples, 2)), ch_labels=["A0", "A1"])
        second = _make_msg(np.full((n_samples, 2), 2.0), ch_labels=["B0", "B1"])

        default_proc = ConcatProcessor(ConcatSettings(axis="feature"))
        merged = default_proc._concat(first, second)

        assert isinstance(merged.data, np.ndarray)
        assert merged.data.shape == (n_samples, 2, 2)


class TestAssertIdenticalSharedAxes:
def test_identical_passes(self):
settings = ConcatSettings(axis="feature", assert_identical_shared_axes=True)
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ezmsg.util.messages.chunker import array_chunker

from ezmsg.sigproc.resample import ResampleProcessor
from tests.helpers.util import requires_mlx


@pytest.fixture
Expand Down Expand Up @@ -128,6 +129,51 @@ async def test_resample_project(irregular_messages):
print(result)


@requires_mlx
def test_resample_preserves_mlx_backend():
    """ResampleProcessor must keep the source's array namespace.

    scipy.interpolate.interp1d always returns numpy, so without the
    namespace coercion an MLX input silently downgrades to numpy and a
    downstream same-backend merge then fails with a cryptic error.

    The test streams a jittered, MLX-backed signal through the processor in
    fixed-size chunks and checks the backend of every non-empty output.
    """
    import mlx.core as mx
    from array_api_compat import get_namespace

    # Use a local generator instead of np.random.seed(): the test stays
    # deterministic without reseeding numpy's global RNG, which would leak
    # into whichever tests run after this one in the same session.
    rng = np.random.default_rng(0)
    fs = 128.0
    n = 200
    chunk_len = 40
    # Nominal 1/fs timestamps plus small jitter, sorted to keep them ordered.
    tvec = np.sort(np.arange(n) / fs + rng.normal(0, 0.001, n))
    ch_ax = AxisArray.CoordinateAxis(data=np.arange(3).astype(str), dims=["ch"], unit="label")

    def _mk(a: int, b: int) -> AxisArray:
        """Return an MLX-backed message covering samples ``a`` to ``b`` (exclusive)."""
        return AxisArray(
            data=mx.array(rng.standard_normal((b - a, 3), dtype=np.float32)),
            dims=["time", "ch"],
            axes={
                "time": AxisArray.CoordinateAxis(data=tvec[a:b], dims=["time"], unit="s"),
                "ch": ch_ax,
            },
            key="mlx_stream",
        )

    resample = ResampleProcessor(resample_rate=100.0, buffer_duration=4.0)
    mlx_ns = get_namespace(mx.array([1.0]))
    seen_output = False
    for start in range(0, n, chunk_len):
        resample(_mk(start, min(start + chunk_len, n)))
        result = next(resample)
        if result.data.shape[0] == 0:
            # Nothing emitted for this chunk; there is no output to inspect.
            continue
        seen_output = True
        assert isinstance(result.data, mx.array), (
            f"Expected mx.array out, got {type(result.data).__name__}"
        )
        assert get_namespace(result.data) is mlx_ns

    # Guard against a vacuous pass where every chunk produced empty output.
    assert seen_output, "Resampler never produced a non-empty output to check."


@pytest.mark.asyncio
async def test_resample_no_input():
"""Test calling next() on ResampleProcessor before receiving any input messages."""
Expand Down
Loading