diff --git a/src/ezmsg/sigproc/concat.py b/src/ezmsg/sigproc/concat.py index 38f4fb3..10ef674 100644 --- a/src/ezmsg/sigproc/concat.py +++ b/src/ezmsg/sigproc/concat.py @@ -393,6 +393,12 @@ class ConcatSettings(ez.Settings): new_key: str | None = None """Output AxisArray key. If None, uses the key from signal A.""" + auto_coerce_backend: bool = False + """If True, silently coerce signal B to signal A's array namespace when the + two inputs are on mismatched backends (e.g. MLX vs numpy). Defaults to False + (strict): a backend mismatch raises a clear error instead, since it is almost + always an upstream bug and silent coercion hides device<->host copies.""" + @dataclass class ConcatState: @@ -450,6 +456,18 @@ def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray: new_axis = concat_dim not in a.dims xp = get_namespace(a.data) + xp_b = get_namespace(b.data) + if xp_b is not xp: + if self.settings.auto_coerce_backend: + b = replace(b, data=xp.asarray(b.data)) + else: + raise TypeError( + f"Concat received inputs on mismatched backends: " + f"a.data namespace={xp.__name__}, b.data namespace={xp_b.__name__}. " + f"Coerce both inputs to one backend upstream " + f"(e.g., via ezmsg.sigproc.asarray.AsArrayTransformer) before merging, " + f"or set ConcatSettings(auto_coerce_backend=True) to coerce B to A's backend." + ) # expand_dims for new-axis concatenation. if new_axis: diff --git a/src/ezmsg/sigproc/resample.py b/src/ezmsg/sigproc/resample.py index 5fec60d..4c98d67 100644 --- a/src/ezmsg/sigproc/resample.py +++ b/src/ezmsg/sigproc/resample.py @@ -7,6 +7,7 @@ import ezmsg.core as ez import numpy as np import scipy.interpolate +from array_api_compat import get_namespace from ezmsg.baseproc import ( BaseConsumerUnit, BaseStatefulProcessor, @@ -235,6 +236,13 @@ def __next__(self) -> AxisArray: # Calculate output resampled_data = f(xnew) + # scipy.interpolate.interp1d always returns numpy. Coerce back to the + # source's namespace so downstream merges with same-backend streams + # don't see a backend mismatch. + src_xp = get_namespace(src_axarr.data) + if get_namespace(resampled_data) is not src_xp: + resampled_data = src_xp.asarray(resampled_data) + # Create output message if hasattr(ref_ax, "data"): out_ax = replace(ref_ax, data=xnew) diff --git a/tests/unit/test_concat.py b/tests/unit/test_concat.py index d3ebda1..d4ccdc1 100644 --- a/tests/unit/test_concat.py +++ b/tests/unit/test_concat.py @@ -366,6 +366,65 @@ def _msg_mx(np_data, ch_labels): np.testing.assert_array_equal(np.asarray(result.data[:, 2:]), 3.0) +class TestBackendMismatch: + """Concat across mismatched array backends should fail loudly (default) + or coerce on opt-in, instead of raising a cryptic MLX-internals error.""" + + @staticmethod + def _msg_mx(np_data, ch_labels): + import mlx.core as mx + + return AxisArray( + data=mx.array(np_data), + dims=["time", "ch"], + axes=frozendict( + { + "time": AxisArray.TimeAxis(fs=100.0, offset=0.0), + "ch": CoordinateAxis(data=np.array(ch_labels), dims=["ch"], unit="label"), + } + ), + key="", + ) + + @requires_mlx + def test_mismatch_raises_clear_error(self): + """MLX a + numpy b -> clear, named backend-mismatch error.""" + n = 5 + msg_a = self._msg_mx(np.ones((n, 2), dtype=np.float32), ["A0", "A1"]) + msg_b = _make_msg(np.ones((n, 2), dtype=np.float32) * 2, ch_labels=["B0", "B1"]) + + proc = ConcatProcessor(ConcatSettings(axis="feature")) + with pytest.raises(TypeError, match="mismatched backends"): + proc._concat(msg_a, msg_b) + + @requires_mlx + def test_auto_coerce_backend_coerces_b_to_a(self): + """auto_coerce_backend=True coerces b onto a's namespace instead of raising.""" + import mlx.core as mx + + n = 5 + msg_a = self._msg_mx(np.ones((n, 2), dtype=np.float32), ["A0", "A1"]) + msg_b = _make_msg(np.ones((n, 2), dtype=np.float32) * 2, ch_labels=["B0", "B1"]) + + proc = ConcatProcessor(ConcatSettings(axis="feature", auto_coerce_backend=True)) + result = proc._concat(msg_a, msg_b) + + assert isinstance(result.data, mx.array) + assert result.data.shape == (n, 2, 2) + np.testing.assert_array_equal(np.asarray(result.data[:, :, 0]), 1.0) + np.testing.assert_array_equal(np.asarray(result.data[:, :, 1]), 2.0) + + def test_same_backend_numpy_unaffected(self): + """Matched numpy inputs are never coerced or rejected (default settings).""" + n = 5 + msg_a = _make_msg(np.ones((n, 2)), ch_labels=["A0", "A1"]) + msg_b = _make_msg(np.ones((n, 2)) * 2, ch_labels=["B0", "B1"]) + + result = ConcatProcessor(ConcatSettings(axis="feature"))._concat(msg_a, msg_b) + assert isinstance(result.data, np.ndarray) + assert result.data.shape == (n, 2, 2) + + class TestAssertIdenticalSharedAxes: def test_identical_passes(self): settings = ConcatSettings(axis="feature", assert_identical_shared_axes=True) diff --git a/tests/unit/test_resample.py b/tests/unit/test_resample.py index 0a9825c..0f8953d 100644 --- a/tests/unit/test_resample.py +++ b/tests/unit/test_resample.py @@ -7,6 +7,7 @@ from ezmsg.util.messages.chunker import array_chunker from ezmsg.sigproc.resample import ResampleProcessor +from tests.helpers.util import requires_mlx @pytest.fixture @@ -128,6 +129,51 @@ async def test_resample_project(irregular_messages): print(result) +@requires_mlx +def test_resample_preserves_mlx_backend(): + """ResampleProcessor must keep the source's array namespace. + + scipy.interpolate.interp1d always returns numpy, so without the + namespace coercion an MLX input silently downgrades to numpy and a + downstream same-backend merge then fails with a cryptic error. + """ + import mlx.core as mx + from array_api_compat import get_namespace + + np.random.seed(0) + fs = 128.0 + n = 200 + tvec = np.sort(np.arange(n) / fs + np.random.normal(0, 0.001, n)) + ch_ax = AxisArray.CoordinateAxis(data=np.arange(3).astype(str), dims=["ch"], unit="label") + + def _mk(a: int, b: int) -> AxisArray: + return AxisArray( + data=mx.array(np.random.randn(b - a, 3).astype(np.float32)), + dims=["time", "ch"], + axes={ + "time": AxisArray.CoordinateAxis(data=tvec[a:b], dims=["time"], unit="s"), + "ch": ch_ax, + }, + key="mlx_stream", + ) + + resample = ResampleProcessor(resample_rate=100.0, buffer_duration=4.0) + mlx_ns = get_namespace(mx.array([1.0])) + seen_output = False + for i in range(0, n, 40): + resample(_mk(i, min(i + 40, n))) + result = next(resample) + if result.data.shape[0] == 0: + continue + seen_output = True + assert isinstance(result.data, mx.array), ( + f"Expected mx.array out, got {type(result.data).__name__}" + ) + assert get_namespace(result.data) is mlx_ns + + assert seen_output, "Resampler never produced a non-empty output to check." + + @pytest.mark.asyncio async def test_resample_no_input(): """Test calling next() on ResampleProcessor before receiving any input messages."""