From ed47fe6913444a3f41171d0da71eba23f1d34ac3 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Tue, 23 Jun 2026 20:43:00 -0600 Subject: [PATCH 1/3] add support for mlp and kalman --- .../collection/sample_adapt_regressor.py | 215 ++++++++++++++++-- tests/unit/test_sample_adapt_regressor.py | 104 +++++++++ 2 files changed, 301 insertions(+), 18 deletions(-) diff --git a/src/ezmsg/learn/collection/sample_adapt_regressor.py b/src/ezmsg/learn/collection/sample_adapt_regressor.py index 6613d53..1a19c92 100644 --- a/src/ezmsg/learn/collection/sample_adapt_regressor.py +++ b/src/ezmsg/learn/collection/sample_adapt_regressor.py @@ -1,31 +1,160 @@ from dataclasses import field import ezmsg.core as ez -from ezmsg.baseproc import SampleTriggerMessage +import numpy as np +from ezmsg.baseproc import ( + BaseStatefulTransformer, + BaseTransformerUnit, + SampleTriggerMessage, + processor_state, +) from ezmsg.sigproc.resample import ResampleSettings, ResampleUnit from ezmsg.sigproc.window import Window, WindowSettings from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace from ezmsg.learn.process.adaptive_linear_regressor import ( AdaptiveLinearRegressorSettings, AdaptiveLinearRegressorUnit, ) from ezmsg.learn.process.flatten import Flatten, FlattenSettings +from ezmsg.learn.process.refit_kalman import ( + RefitKalmanFilterSettings, + RefitKalmanFilterUnit, +) from ezmsg.learn.process.seqseqsampler import SeqSeqSamplerSettings, SeqSeqSamplerUnit +from ezmsg.learn.process.torch import TorchModelSettings, TorchModelUnit from ezmsg.learn.util import AdaptiveLinearRegressor +#: Default torch model class used when ``model_type == "mlp"``. +DEFAULT_TORCH_MODEL_CLASS = "ezmsg.learn.model.mlp.MLP" + +#: ``model_type`` tokens routed to a non-linear regressor engine. Everything +#: else (``linear``/``logistic``/``sgd``/``par``/``ridge``) is handled by +#: :class:`AdaptiveLinearRegressorUnit` as before. +_TORCH_MODEL_TYPE = "mlp" +_KALMAN_MODEL_TYPE = "kalman" + + +def _model_type_token(model_type) -> str: + if isinstance(model_type, AdaptiveLinearRegressor): + return model_type.value + return str(model_type).strip().lower() + + +def _model_backend(model_type) -> str: + """Map ``model_type`` to the regressor engine that handles it: + ``"torch"`` (MLP), ``"kalman"``, or ``"linear"`` (River/sklearn).""" + token = _model_type_token(model_type) + if token == _TORCH_MODEL_TYPE: + return "torch" + if token == _KALMAN_MODEL_TYPE: + return "kalman" + return "linear" + + +class DecodeOutputAdapterSettings(ez.Settings): + output_labels: list | None = None + """Channel labels for the decoded output. None -> generic ``ch0..chN``.""" + + +@processor_state +class DecodeOutputAdapterState: + ch_axis: AxisArray.CoordinateAxis | None = None + + +class DecodeOutputAdapterProcessor( + BaseStatefulTransformer[ + DecodeOutputAdapterSettings, + AxisArray, + AxisArray, + DecodeOutputAdapterState, + ] +): + """Normalize a decoder output into a ``(time, ch)`` AxisArray. + + The torch (``{"output": ...}``-keyed) and Kalman (``["time", "state"]``) + engines emit differently-shaped outputs than the River/sklearn regressor. + This rebuilds a uniform ``(time, ch=output_labels)`` message — keyed + ``_pred`` like :class:`AdaptiveLinearRegressorUnit` — so downstream + consumers see one contract regardless of backend. + """ + + def _reset_state(self, message: AxisArray) -> None: + if self.settings.output_labels is not None: + self.state.ch_axis = AxisArray.CoordinateAxis( + data=np.asarray(self.settings.output_labels), dims=["ch"] + ) + + def _process(self, message: AxisArray) -> AxisArray | None: + data = np.asarray(message.data, dtype=float) + if data.size == 0: + return None + + if self.settings.output_labels is not None: + n_outputs = len(self.settings.output_labels) + data = data.reshape((-1, n_outputs)) + ch_axis = self.state.ch_axis + else: + data = data.reshape((data.shape[0], -1)) if data.ndim > 1 else data.reshape((1, -1)) + ch_axis = AxisArray.CoordinateAxis( + data=np.asarray([f"ch{i}" for i in range(data.shape[-1])]), dims=["ch"] + ) + + axes = {"ch": ch_axis} + if "time" in message.axes: + axes["time"] = message.axes["time"] + return replace( + message, + data=data, + dims=["time", "ch"], + axes=axes, + key=f"{message.key}_pred", + ) + + +class DecodeOutputAdapter( + BaseTransformerUnit[ + DecodeOutputAdapterSettings, + AxisArray, + AxisArray, + DecodeOutputAdapterProcessor, + ] +): + SETTINGS = DecodeOutputAdapterSettings + class SampleAdaptRegressorSettings(ez.Settings): - # AdaptiveLinearRegressor settings - model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR - """Adaptive regressor backend/model.""" + # Regressor backend/model. Accepts the AdaptiveLinearRegressor enum (or its + # string value) for the River/sklearn engines, plus the strings ``"mlp"`` + # and ``"kalman"`` which route to the torch / refit-Kalman engines. + model_type: AdaptiveLinearRegressor | str = AdaptiveLinearRegressor.LINEAR + """Regressor backend/model.""" model_path: str | None = None - """Optional path to a pickled preconfigured model.""" + """Optional path to a pre-trained checkpoint. Format depends on the + backend: a pickled River/sklearn estimator, a ``torch.save`` artifact + (mlp), or a pickled state-space matrix dict (kalman).""" model_kwargs: dict = field(default_factory=dict) """Extra kwargs passed to the underlying regressor.""" + # Torch (mlp) settings + model_class: str = DEFAULT_TORCH_MODEL_CLASS + """Fully-qualified torch model class used when ``model_type == "mlp"``.""" + + device: str | None = None + """Torch device for the mlp backend. None -> auto (cuda/mps/cpu).""" + + # Kalman settings + steady_state: bool = True + """Kalman steady-state gain flag, used when ``model_type == "kalman"``.""" + + # Output adapter (mlp/kalman) + output_labels: list | None = None + """Decoded-output channel labels for the mlp/kalman adapter. None -> + generic ``ch0..chN``.""" + # Resampling settings resample_axis: str = "time" """Axis to resample along.""" @@ -57,8 +186,17 @@ class SampleAdaptRegressor(ez.Collection): WINDOW = Window() FLATTEN = Flatten() REGRESSOR = AdaptiveLinearRegressorUnit() + # Alternate engines for mlp/kalman. Declared unconditionally; only the one + # matching model_type is wired in network() — the others stay inert. + TORCH_REGRESSOR = TorchModelUnit() + KALMAN_REGRESSOR = RefitKalmanFilterUnit() + ADAPTER = DecodeOutputAdapter() + + def _backend(self) -> str: + return _model_backend(self.SETTINGS.model_type) def configure(self) -> None: + backend = self._backend() self.RESAMPLE.apply_settings( ResampleSettings( axis=self.SETTINGS.resample_axis, @@ -91,34 +229,75 @@ def configure(self) -> None: feature_axis="ch", ) ) + # The linear engine carries model_type through; for mlp/kalman it is + # declared-but-unwired, so give it a benign model_type that is a valid + # AdaptiveLinearRegressor and no checkpoint (keeps it inert). self.REGRESSOR.apply_settings( AdaptiveLinearRegressorSettings( - model_type=self.SETTINGS.model_type, - settings_path=self.SETTINGS.model_path, + model_type=self.SETTINGS.model_type + if backend == "linear" + else AdaptiveLinearRegressor.LINEAR, + settings_path=self.SETTINGS.model_path if backend == "linear" else None, model_kwargs=self.SETTINGS.model_kwargs, ) ) + self.TORCH_REGRESSOR.apply_settings( + TorchModelSettings( + model_class=self.SETTINGS.model_class, + checkpoint_path=self.SETTINGS.model_path if backend == "torch" else None, + model_kwargs=dict(self.SETTINGS.model_kwargs), + device=self.SETTINGS.device, + ) + ) + self.KALMAN_REGRESSOR.apply_settings( + RefitKalmanFilterSettings( + checkpoint_path=self.SETTINGS.model_path if backend == "kalman" else None, + steady_state=self.SETTINGS.steady_state, + ) + ) + self.ADAPTER.apply_settings( + DecodeOutputAdapterSettings(output_labels=self.SETTINGS.output_labels) + ) def network(self) -> ez.NetworkDefinition: - network = [ - (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), - (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), - (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), - (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), - (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), - (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), - ] + backend = self._backend() + + network = [] + if backend == "linear": + # Online-adaptation sample path (River/sklearn only). + network.extend( + [ + (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), + (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), + (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), + (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), + (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), + ] + ) + regressor = self.REGRESSOR + elif backend == "torch": + regressor = self.TORCH_REGRESSOR + else: + regressor = self.KALMAN_REGRESSOR if self.SETTINGS.decode_window_dur is None: - network.append((self.INPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL)) + network.append((self.INPUT_SIGNAL, regressor.INPUT_SIGNAL)) else: network.extend( [ (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL), - (self.FLATTEN.OUTPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL), + (self.FLATTEN.OUTPUT_SIGNAL, regressor.INPUT_SIGNAL), ] ) - network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + # The River/sklearn regressor already emits the canonical (time, ch) + # ``_pred`` contract; the torch/kalman engines need the adapter to match. + if backend == "linear": + network.append((regressor.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + else: + network.append((regressor.OUTPUT_SIGNAL, self.ADAPTER.INPUT_SIGNAL)) + network.append((self.ADAPTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + return tuple(network) diff --git a/tests/unit/test_sample_adapt_regressor.py b/tests/unit/test_sample_adapt_regressor.py index 8977b45..80d96de 100644 --- a/tests/unit/test_sample_adapt_regressor.py +++ b/tests/unit/test_sample_adapt_regressor.py @@ -1,6 +1,9 @@ +import pytest + from ezmsg.learn.collection.sample_adapt_regressor import ( SampleAdaptRegressor, SampleAdaptRegressorSettings, + _model_backend, ) @@ -19,3 +22,104 @@ def test_sample_adapt_regressor_uses_windowed_decode_branch_when_configured(): assert (collection.WINDOW.OUTPUT_SIGNAL, collection.FLATTEN.INPUT_SIGNAL) in network assert (collection.FLATTEN.OUTPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) not in network + + +@pytest.mark.parametrize( + "model_type, expected", + [ + ("linear", "linear"), + ("logistic", "linear"), + ("sgd", "linear"), + ("par", "linear"), + ("ridge", "linear"), + ("mlp", "torch"), + ("MLP", "torch"), + ("kalman", "kalman"), + ("Kalman", "kalman"), + ], +) +def test_model_backend_routes_model_type_to_engine(model_type, expected): + assert _model_backend(model_type) == expected + + +def test_linear_backend_wires_sample_path_and_no_adapter(): + collection = SampleAdaptRegressor( + settings=SampleAdaptRegressorSettings(model_type="linear") + ) + collection.configure() + + network = collection.network() + + # Online-adaptation sample path is present for the linear engine. + assert (collection.INPUT_TRIGGER, collection.SEQSEQSAMPLER.INPUT_TRIGGER) in network + assert ( + collection.SEQSEQSAMPLER.OUTPUT_SAMPLE, + collection.REGRESSOR.INPUT_SAMPLE, + ) in network + # Linear emits the canonical _pred contract directly; no adapter in the graph. + assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network + assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network + + +def test_torch_backend_wires_decode_only_through_adapter(): + collection = SampleAdaptRegressor( + settings=SampleAdaptRegressorSettings(model_type="mlp") + ) + collection.configure() + + network = collection.network() + + # Decode-only path: signal -> torch engine -> adapter -> output. + assert (collection.INPUT_SIGNAL, collection.TORCH_REGRESSOR.INPUT_SIGNAL) in network + assert ( + collection.TORCH_REGRESSOR.OUTPUT_SIGNAL, + collection.ADAPTER.INPUT_SIGNAL, + ) in network + assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network + # Linear engine and its sample/adapt path stay inert. + assert ( + collection.SEQSEQSAMPLER.OUTPUT_SAMPLE, + collection.REGRESSOR.INPUT_SAMPLE, + ) not in network + assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network + + +def test_kalman_backend_wires_decode_only_through_adapter(): + collection = SampleAdaptRegressor( + settings=SampleAdaptRegressorSettings(model_type="kalman") + ) + collection.configure() + + network = collection.network() + + assert (collection.INPUT_SIGNAL, collection.KALMAN_REGRESSOR.INPUT_SIGNAL) in network + assert ( + collection.KALMAN_REGRESSOR.OUTPUT_SIGNAL, + collection.ADAPTER.INPUT_SIGNAL, + ) in network + assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network + assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network + + +def test_non_linear_backend_uses_windowed_decode_branch(): + collection = SampleAdaptRegressor( + settings=SampleAdaptRegressorSettings( + model_type="mlp", + decode_window_dur=0.2, + decode_window_shift=0.01, + ) + ) + collection.configure() + + network = collection.network() + + # Windowing feeds the torch engine, not the (inert) linear regressor. + assert (collection.INPUT_SIGNAL, collection.WINDOW.INPUT_SIGNAL) in network + assert ( + collection.FLATTEN.OUTPUT_SIGNAL, + collection.TORCH_REGRESSOR.INPUT_SIGNAL, + ) in network + assert ( + collection.FLATTEN.OUTPUT_SIGNAL, + collection.REGRESSOR.INPUT_SIGNAL, + ) not in network From 0685f45ffeef0e729e5df6f96d57c46f80e961aa Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 26 Jun 2026 10:49:38 -0600 Subject: [PATCH 2/3] update to factory pattern --- .../collection/sample_adapt_regressor.py | 297 ++++++++++-------- tests/unit/test_sample_adapt_regressor.py | 164 ++++++---- 2 files changed, 267 insertions(+), 194 deletions(-) diff --git a/src/ezmsg/learn/collection/sample_adapt_regressor.py b/src/ezmsg/learn/collection/sample_adapt_regressor.py index 1a19c92..1ba5a0f 100644 --- a/src/ezmsg/learn/collection/sample_adapt_regressor.py +++ b/src/ezmsg/learn/collection/sample_adapt_regressor.py @@ -101,14 +101,22 @@ def _process(self, message: AxisArray) -> AxisArray | None: data=np.asarray([f"ch{i}" for i in range(data.shape[-1])]), dims=["ch"] ) - axes = {"ch": ch_axis} - if "time" in message.axes: - axes["time"] = message.axes["time"] + # The decoder engines carry a ``time`` axis through (kalman keeps the + # input's; the torch path inherits the windower's renamed ``win``->``time`` + # axis). Require it rather than silently emitting untimed samples — a + # missing time axis means the upstream layout changed and downstream + # timing/outlet behavior would be wrong. + if "time" not in message.axes: + raise ValueError( + "DecodeOutputAdapter expected a 'time' axis on the decoder output " + f"(got dims={message.dims}, axes={list(message.axes)}); the upstream " + "windowing/regressor layout changed." + ) return replace( message, data=data, dims=["time", "ch"], - axes=axes, + axes={"ch": ch_axis, "time": message.axes["time"]}, key=f"{message.key}_pred", ) @@ -173,131 +181,158 @@ class SampleAdaptRegressorSettings(ez.Settings): """Optional inference-side feature window shift in seconds.""" -class SampleAdaptRegressor(ez.Collection): - SETTINGS = SampleAdaptRegressorSettings - - INPUT_LABELS = ez.InputTopic(AxisArray) - INPUT_SIGNAL = ez.InputTopic(AxisArray) - INPUT_TRIGGER = ez.InputTopic(SampleTriggerMessage) - OUTPUT_SIGNAL = ez.OutputTopic(AxisArray) - - RESAMPLE = ResampleUnit() - SEQSEQSAMPLER = SeqSeqSamplerUnit() - WINDOW = Window() - FLATTEN = Flatten() - REGRESSOR = AdaptiveLinearRegressorUnit() - # Alternate engines for mlp/kalman. Declared unconditionally; only the one - # matching model_type is wired in network() — the others stay inert. - TORCH_REGRESSOR = TorchModelUnit() - KALMAN_REGRESSOR = RefitKalmanFilterUnit() - ADAPTER = DecodeOutputAdapter() - - def _backend(self) -> str: - return _model_backend(self.SETTINGS.model_type) - - def configure(self) -> None: - backend = self._backend() - self.RESAMPLE.apply_settings( - ResampleSettings( - axis=self.SETTINGS.resample_axis, - max_chunk_delay=float("inf"), - fill_value="extrapolate", - buffer_duration=self.SETTINGS.resample_buffer_duration, - ) - ) - self.SEQSEQSAMPLER.apply_settings( - SeqSeqSamplerSettings( - max_buffer_dur=self.SETTINGS.sampler_max_buffer_dur, - ) - ) - self.WINDOW.apply_settings( - WindowSettings( - axis="time", - newaxis="win", - window_dur=self.SETTINGS.decode_window_dur, - window_shift=self.SETTINGS.decode_window_shift, - # Window requires zero_pad_until="input" when window_shift is - # None (1:1 mode, e.g. no inference-side windowing); using - # "none" there only logs a warning and is coerced to "input". - zero_pad_until="none" if self.SETTINGS.decode_window_shift is not None else "input", - ) - ) - self.FLATTEN.apply_settings( - FlattenSettings( - preserve_axis="win", - sample_axis="time", - feature_axis="ch", - ) - ) - # The linear engine carries model_type through; for mlp/kalman it is - # declared-but-unwired, so give it a benign model_type that is a valid - # AdaptiveLinearRegressor and no checkpoint (keeps it inert). - self.REGRESSOR.apply_settings( - AdaptiveLinearRegressorSettings( - model_type=self.SETTINGS.model_type - if backend == "linear" - else AdaptiveLinearRegressor.LINEAR, - settings_path=self.SETTINGS.model_path if backend == "linear" else None, - model_kwargs=self.SETTINGS.model_kwargs, - ) - ) - self.TORCH_REGRESSOR.apply_settings( - TorchModelSettings( - model_class=self.SETTINGS.model_class, - checkpoint_path=self.SETTINGS.model_path if backend == "torch" else None, - model_kwargs=dict(self.SETTINGS.model_kwargs), - device=self.SETTINGS.device, - ) - ) - self.KALMAN_REGRESSOR.apply_settings( - RefitKalmanFilterSettings( - checkpoint_path=self.SETTINGS.model_path if backend == "kalman" else None, - steady_state=self.SETTINGS.steady_state, - ) - ) - self.ADAPTER.apply_settings( - DecodeOutputAdapterSettings(output_labels=self.SETTINGS.output_labels) - ) - - def network(self) -> ez.NetworkDefinition: - backend = self._backend() - - network = [] - if backend == "linear": - # Online-adaptation sample path (River/sklearn only). - network.extend( - [ - (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), - (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), - (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), - (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), - (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), - (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), - ] - ) - regressor = self.REGRESSOR - elif backend == "torch": - regressor = self.TORCH_REGRESSOR - else: - regressor = self.KALMAN_REGRESSOR - - if self.SETTINGS.decode_window_dur is None: - network.append((self.INPUT_SIGNAL, regressor.INPUT_SIGNAL)) - else: - network.extend( - [ - (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), - (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL), - (self.FLATTEN.OUTPUT_SIGNAL, regressor.INPUT_SIGNAL), - ] - ) - - # The River/sklearn regressor already emits the canonical (time, ch) - # ``_pred`` contract; the torch/kalman engines need the adapter to match. - if backend == "linear": - network.append((regressor.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) - else: - network.append((regressor.OUTPUT_SIGNAL, self.ADAPTER.INPUT_SIGNAL)) - network.append((self.ADAPTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) +def _build_regressor_unit(settings: SampleAdaptRegressorSettings): + """Factory: construct the single regressor unit for ``settings.model_type``. - return tuple(network) + Returns ``(unit, backend)`` where ``backend`` is ``"linear"`` (River/sklearn + via :class:`AdaptiveLinearRegressorUnit`), ``"torch"`` (mlp), or ``"kalman"``. + """ + backend = _model_backend(settings.model_type) + if backend == "torch": + return TorchModelUnit(), backend + if backend == "kalman": + return RefitKalmanFilterUnit(), backend + return AdaptiveLinearRegressorUnit(), backend + + +def build_sample_adapt_regressor( + settings: SampleAdaptRegressorSettings, +) -> ez.Collection: + """Build a decode collection wired around a single regressor engine. + + The regressor backend (River/sklearn, torch-mlp, or refit-Kalman) is selected + from ``settings.model_type`` and the collection class is defined dynamically + so the graph contains exactly the units that backend uses — no inert, + declared-but-unwired units. The signal path (and, for the linear engine, the + online-adaptation sample path) wire to that one unit, so there is no per- + backend wiring to keep in sync. + """ + regressor, backend = _build_regressor_unit(settings) + use_window = settings.decode_window_dur is not None + use_sample_path = backend == "linear" # online-adaptation path (River/sklearn) + needs_adapter = backend != "linear" # torch/kalman outputs need normalizing + + class SampleAdaptRegressor(ez.Collection): + SETTINGS = SampleAdaptRegressorSettings + + INPUT_LABELS = ez.InputTopic(AxisArray) + INPUT_SIGNAL = ez.InputTopic(AxisArray) + INPUT_TRIGGER = ez.InputTopic(SampleTriggerMessage) + OUTPUT_SIGNAL = ez.OutputTopic(AxisArray) + + REGRESSOR = regressor + if use_window: + WINDOW = Window() + FLATTEN = Flatten() + if use_sample_path: + RESAMPLE = ResampleUnit() + SEQSEQSAMPLER = SeqSeqSamplerUnit() + if needs_adapter: + ADAPTER = DecodeOutputAdapter() + + def configure(self) -> None: + if backend == "linear": + self.REGRESSOR.apply_settings( + AdaptiveLinearRegressorSettings( + model_type=self.SETTINGS.model_type, + settings_path=self.SETTINGS.model_path, + model_kwargs=self.SETTINGS.model_kwargs, + ) + ) + elif backend == "torch": + self.REGRESSOR.apply_settings( + TorchModelSettings( + model_class=self.SETTINGS.model_class, + checkpoint_path=self.SETTINGS.model_path, + model_kwargs=dict(self.SETTINGS.model_kwargs), + device=self.SETTINGS.device, + ) + ) + else: + self.REGRESSOR.apply_settings( + RefitKalmanFilterSettings( + checkpoint_path=self.SETTINGS.model_path, + steady_state=self.SETTINGS.steady_state, + ) + ) + + if use_window: + self.WINDOW.apply_settings( + WindowSettings( + axis="time", + newaxis="win", + window_dur=self.SETTINGS.decode_window_dur, + window_shift=self.SETTINGS.decode_window_shift, + # Window requires zero_pad_until="input" when + # window_shift is None (1:1 mode); "none" there only + # warns and is coerced to "input". + zero_pad_until="none" + if self.SETTINGS.decode_window_shift is not None + else "input", + ) + ) + self.FLATTEN.apply_settings( + FlattenSettings( + preserve_axis="win", + sample_axis="time", + feature_axis="ch", + ) + ) + if use_sample_path: + self.RESAMPLE.apply_settings( + ResampleSettings( + axis=self.SETTINGS.resample_axis, + max_chunk_delay=float("inf"), + fill_value="extrapolate", + buffer_duration=self.SETTINGS.resample_buffer_duration, + ) + ) + self.SEQSEQSAMPLER.apply_settings( + SeqSeqSamplerSettings( + max_buffer_dur=self.SETTINGS.sampler_max_buffer_dur, + ) + ) + if needs_adapter: + self.ADAPTER.apply_settings( + DecodeOutputAdapterSettings( + output_labels=self.SETTINGS.output_labels + ) + ) + + def network(self) -> ez.NetworkDefinition: + network = [] + if use_sample_path: + # Online-adaptation sample path (River/sklearn only). + network.extend( + [ + (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), + (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), + (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), + (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), + (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), + ] + ) + + if use_window: + network.extend( + [ + (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), + (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL), + (self.FLATTEN.OUTPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL), + ] + ) + else: + network.append((self.INPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL)) + + # River/sklearn already emits the canonical (time, ch) ``_pred`` + # contract; torch/kalman route through the adapter to match it. + if needs_adapter: + network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.ADAPTER.INPUT_SIGNAL)) + network.append((self.ADAPTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + else: + network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + + return tuple(network) + + return SampleAdaptRegressor(settings=settings) diff --git a/tests/unit/test_sample_adapt_regressor.py b/tests/unit/test_sample_adapt_regressor.py index 80d96de..4c40774 100644 --- a/tests/unit/test_sample_adapt_regressor.py +++ b/tests/unit/test_sample_adapt_regressor.py @@ -1,27 +1,27 @@ +import numpy as np import pytest +from ezmsg.util.messages.axisarray import AxisArray from ezmsg.learn.collection.sample_adapt_regressor import ( - SampleAdaptRegressor, + DecodeOutputAdapterProcessor, SampleAdaptRegressorSettings, + _build_regressor_unit, _model_backend, + build_sample_adapt_regressor, ) +from ezmsg.learn.process.adaptive_linear_regressor import AdaptiveLinearRegressorUnit +from ezmsg.learn.process.refit_kalman import RefitKalmanFilterUnit +from ezmsg.learn.process.torch import TorchModelUnit -def test_sample_adapt_regressor_uses_windowed_decode_branch_when_configured(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings( - decode_window_dur=0.2, - decode_window_shift=0.01, - ) - ) +def _build(**kwargs): + """Build + configure a decode collection for the given settings.""" + collection = build_sample_adapt_regressor(SampleAdaptRegressorSettings(**kwargs)) collection.configure() + return collection - network = collection.network() - assert (collection.INPUT_SIGNAL, collection.WINDOW.INPUT_SIGNAL) in network - assert (collection.WINDOW.OUTPUT_SIGNAL, collection.FLATTEN.INPUT_SIGNAL) in network - assert (collection.FLATTEN.OUTPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network - assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) not in network +# --- backend routing --------------------------------------------------------- @pytest.mark.parametrize( @@ -42,14 +42,32 @@ def test_model_backend_routes_model_type_to_engine(model_type, expected): assert _model_backend(model_type) == expected -def test_linear_backend_wires_sample_path_and_no_adapter(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings(model_type="linear") - ) - collection.configure() +@pytest.mark.parametrize( + "model_type, expected_backend, expected_unit", + [ + ("linear", "linear", AdaptiveLinearRegressorUnit), + ("mlp", "torch", TorchModelUnit), + ("kalman", "kalman", RefitKalmanFilterUnit), + ], +) +def test_build_regressor_unit_selects_engine(model_type, expected_backend, expected_unit): + unit, backend = _build_regressor_unit(SampleAdaptRegressorSettings(model_type=model_type)) + assert backend == expected_backend + assert isinstance(unit, expected_unit) + + +# --- collection topology ----------------------------------------------------- + +def test_linear_backend_wires_sample_path_and_no_adapter(): + collection = _build(model_type="linear") network = collection.network() + # The factory builds only the units the linear engine uses. + assert hasattr(collection, "RESAMPLE") + assert hasattr(collection, "SEQSEQSAMPLER") + assert not hasattr(collection, "ADAPTER") + # Online-adaptation sample path is present for the linear engine. assert (collection.INPUT_TRIGGER, collection.SEQSEQSAMPLER.INPUT_TRIGGER) in network assert ( @@ -58,68 +76,88 @@ def test_linear_backend_wires_sample_path_and_no_adapter(): ) in network # Linear emits the canonical _pred contract directly; no adapter in the graph. assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network - assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network + # No windowing by default: signal flows straight into the regressor. + assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network -def test_torch_backend_wires_decode_only_through_adapter(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings(model_type="mlp") - ) - collection.configure() - +@pytest.mark.parametrize( + "model_type, expected_unit", + [("mlp", TorchModelUnit), ("kalman", RefitKalmanFilterUnit)], +) +def test_non_linear_backend_wires_decode_only_through_adapter(model_type, expected_unit): + collection = _build(model_type=model_type) network = collection.network() - # Decode-only path: signal -> torch engine -> adapter -> output. - assert (collection.INPUT_SIGNAL, collection.TORCH_REGRESSOR.INPUT_SIGNAL) in network + # Only the chosen engine + adapter exist; no inert sample-path units. + assert isinstance(collection.REGRESSOR, expected_unit) + assert hasattr(collection, "ADAPTER") + assert not hasattr(collection, "RESAMPLE") + assert not hasattr(collection, "SEQSEQSAMPLER") + + # Decode-only path: signal -> engine -> adapter -> output. + assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network assert ( - collection.TORCH_REGRESSOR.OUTPUT_SIGNAL, + collection.REGRESSOR.OUTPUT_SIGNAL, collection.ADAPTER.INPUT_SIGNAL, ) in network assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network - # Linear engine and its sample/adapt path stay inert. - assert ( - collection.SEQSEQSAMPLER.OUTPUT_SAMPLE, - collection.REGRESSOR.INPUT_SAMPLE, - ) not in network - assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network -def test_kalman_backend_wires_decode_only_through_adapter(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings(model_type="kalman") +@pytest.mark.parametrize("model_type", ["linear", "mlp", "kalman"]) +def test_windowed_decode_branch_when_configured(model_type): + collection = _build( + model_type=model_type, + decode_window_dur=0.2, + decode_window_shift=0.01, ) - collection.configure() - network = collection.network() - assert (collection.INPUT_SIGNAL, collection.KALMAN_REGRESSOR.INPUT_SIGNAL) in network + assert (collection.INPUT_SIGNAL, collection.WINDOW.INPUT_SIGNAL) in network + assert (collection.WINDOW.OUTPUT_SIGNAL, collection.FLATTEN.INPUT_SIGNAL) in network assert ( - collection.KALMAN_REGRESSOR.OUTPUT_SIGNAL, - collection.ADAPTER.INPUT_SIGNAL, + collection.FLATTEN.OUTPUT_SIGNAL, + collection.REGRESSOR.INPUT_SIGNAL, ) in network - assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network - assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) not in network + # Windowing replaces the direct signal->regressor edge. + assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) not in network + + +def test_non_windowed_backend_has_no_window_units(): + collection = _build(model_type="mlp") + assert not hasattr(collection, "WINDOW") + assert not hasattr(collection, "FLATTEN") + + +# --- decode output adapter --------------------------------------------------- + + +def _adapter_message(data, *, dims, with_time=True, key="dec"): + axes = {} + if with_time: + axes["time"] = AxisArray.TimeAxis(fs=50.0) + return AxisArray(data=np.asarray(data, dtype=float), dims=dims, axes=axes, key=key) -def test_non_linear_backend_uses_windowed_decode_branch(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings( - model_type="mlp", - decode_window_dur=0.2, - decode_window_shift=0.01, - ) +def test_adapter_normalizes_output_to_time_ch(): + # Kalman-style output: (time, state) with state_dim == len(output_labels). + proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) + message = _adapter_message( + np.arange(8).reshape(4, 2), dims=["time", "state"], key="kf" ) - collection.configure() - network = collection.network() + result = proc(message) - # Windowing feeds the torch engine, not the (inert) linear regressor. - assert (collection.INPUT_SIGNAL, collection.WINDOW.INPUT_SIGNAL) in network - assert ( - collection.FLATTEN.OUTPUT_SIGNAL, - collection.TORCH_REGRESSOR.INPUT_SIGNAL, - ) in network - assert ( - collection.FLATTEN.OUTPUT_SIGNAL, - collection.REGRESSOR.INPUT_SIGNAL, - ) not in network + assert result.dims == ["time", "ch"] + assert result.data.shape == (4, 2) + assert list(result.get_axis("ch").data) == ["vx", "vy"] + assert result.key == "kf_pred" + + +def test_adapter_requires_time_axis(): + proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) + message = _adapter_message( + np.arange(2).reshape(1, 2), dims=["win", "ch"], with_time=False + ) + + with pytest.raises(ValueError, match="time"): + proc(message) From 0fa6c817deb91602ad382c5e09625ed14fd0ac3f Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 26 Jun 2026 13:54:50 -0400 Subject: [PATCH 3/3] Add windowed-path integration test for decode adapter Chains the real Window -> Flatten -> DecodeOutputAdapter processors to lock the win->time axis rename that makes the adapter's time-axis guard safe on the windowed mlp/kalman path. A future change to Flatten's sample_axis semantics now fails here instead of only at runtime. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/unit/test_sample_adapt_regressor.py | 71 +++++++++++++++++++++-- 1 file changed, 65 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_sample_adapt_regressor.py b/tests/unit/test_sample_adapt_regressor.py index 4c40774..8f77099 100644 --- a/tests/unit/test_sample_adapt_regressor.py +++ b/tests/unit/test_sample_adapt_regressor.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from ezmsg.sigproc.window import WindowSettings, WindowTransformer from ezmsg.util.messages.axisarray import AxisArray from ezmsg.learn.collection.sample_adapt_regressor import ( @@ -10,6 +11,7 @@ build_sample_adapt_regressor, ) from ezmsg.learn.process.adaptive_linear_regressor import AdaptiveLinearRegressorUnit +from ezmsg.learn.process.flatten import FlattenSettings, FlattenTransformer from ezmsg.learn.process.refit_kalman import RefitKalmanFilterUnit from ezmsg.learn.process.torch import TorchModelUnit @@ -141,9 +143,7 @@ def _adapter_message(data, *, dims, with_time=True, key="dec"): def test_adapter_normalizes_output_to_time_ch(): # Kalman-style output: (time, state) with state_dim == len(output_labels). proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) - message = _adapter_message( - np.arange(8).reshape(4, 2), dims=["time", "state"], key="kf" - ) + message = _adapter_message(np.arange(8).reshape(4, 2), dims=["time", "state"], key="kf") result = proc(message) @@ -155,9 +155,68 @@ def test_adapter_normalizes_output_to_time_ch(): def test_adapter_requires_time_axis(): proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) - message = _adapter_message( - np.arange(2).reshape(1, 2), dims=["win", "ch"], with_time=False - ) + message = _adapter_message(np.arange(2).reshape(1, 2), dims=["win", "ch"], with_time=False) with pytest.raises(ValueError, match="time"): proc(message) + + +# --- windowed path integration ---------------------------------------------- + + +def test_windowed_path_renames_win_to_time_and_feeds_adapter(): + """End-to-end check of the windowed mlp/kalman feature path. + + The adapter's ``time``-axis guard is only safe because Window + the + learn-side Flatten rename the window axis (``win``) to ``time`` on output. + This chains the real Window -> Flatten -> adapter processors with the exact + settings ``configure()`` applies for the windowed path, so a future change + to Flatten's ``sample_axis`` semantics would fail here instead of only + surfacing at runtime. The torch/kalman engine in between preserves + ``message.axes``, so feeding the flattened output straight to the adapter + exercises the same time-axis plumbing. + """ + fs = 100.0 + window_dur, window_shift = 0.2, 0.01 + n_time, n_ch = 60, 3 + sig = AxisArray( + data=np.arange(n_time * n_ch, dtype=float).reshape(n_time, n_ch), + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=0.0), + "ch": AxisArray.CoordinateAxis(data=np.array(["c0", "c1", "c2"]), dims=["ch"]), + }, + key="neural", + ) + + # Settings mirror SampleAdaptRegressor.configure() for the windowed branch. + windower = WindowTransformer( + WindowSettings( + axis="time", + newaxis="win", + window_dur=window_dur, + window_shift=window_shift, + zero_pad_until="none", + ) + ) + flatten = FlattenTransformer(FlattenSettings(preserve_axis="win", sample_axis="time", feature_axis="ch")) + adapter = DecodeOutputAdapterProcessor(output_labels=None) + + windowed = windower(sig) + assert windowed.dims == ["win", "time", "ch"] + + flat = flatten(windowed) + # The window axis is preserved but renamed to "time"; the inner lag dim and + # channels fold into the feature axis. + assert flat.dims == ["time", "ch"] + assert "time" in flat.axes + # The renamed axis carries the window-rate cadence (one sample per shift), + # not the original 100 Hz sample rate. + assert flat.axes["time"].gain == pytest.approx(window_shift) + + # The adapter accepts the windowed output (no raise) and emits the contract. + result = adapter(flat) + assert result.dims == ["time", "ch"] + assert result.data.shape[0] == flat.data.shape[0] + assert result.key == "neural_pred" + assert result.axes["time"].gain == pytest.approx(window_shift)